IMLCV.implementations.CvDiscovery
=================================

.. py:module:: IMLCV.implementations.CvDiscovery


Classes
-------

.. autoapisummary::

   IMLCV.implementations.CvDiscovery.Encoder
   IMLCV.implementations.CvDiscovery.Decoder
   IMLCV.implementations.CvDiscovery.VAE
   IMLCV.implementations.CvDiscovery.TranformerAutoEncoder
   IMLCV.implementations.CvDiscovery.TransoformerLDA
   IMLCV.implementations.CvDiscovery.TransformerMAF


Functions
---------

.. autoapisummary::

   IMLCV.implementations.CvDiscovery._LDA_trans
   IMLCV.implementations.CvDiscovery._LDA_rescale
   IMLCV.implementations.CvDiscovery._scale_trans


Module Contents
---------------

.. py:class:: Encoder

   Bases: :py:obj:`flax.linen.Module`


   Base class for all neural network modules.

   Layers and models should subclass this class.

   All Flax Modules are Python 3.7
   `dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. Since
   dataclasses take over ``__init__``, you should instead override :meth:`setup`,
   which is automatically called to initialize the module.

   Modules can contain submodules, and in this way can be nested in a tree
   structure. Submodules can be assigned as regular attributes inside the
   :meth:`setup` method.

   You can define arbitrary "forward pass" methods on your Module subclass.
   While no methods are special-cased, ``__call__`` is a popular choice because
   it allows you to use module instances as if they are functions::

     >>> from flax import linen as nn
     >>> from typing import Tuple

     >>> class Module(nn.Module):
     ...   features: Tuple[int, ...] = (16, 4)
     ...
     ...   def setup(self):
     ...     self.dense1 = nn.Dense(self.features[0])
     ...     self.dense2 = nn.Dense(self.features[1])
     ...
     ...   def __call__(self, x):
     ...     return self.dense2(nn.relu(self.dense1(x)))

   Optionally, for more concise module implementations where submodule
   definitions are co-located with their usage, you can use the
   :meth:`compact` wrapper.
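
   For comparison, the two-layer module from the example above can be written
   with :meth:`compact`, declaring each submodule inline where it is used. This
   is a minimal sketch using only the public ``flax.linen`` API; the class name
   ``CompactModule`` is illustrative. Inline submodules get automatic names
   (``Dense_0``, ``Dense_1``) instead of attribute names.

   .. code-block:: python

      from typing import Tuple

      import jax
      import jax.numpy as jnp
      from flax import linen as nn


      class CompactModule(nn.Module):
          features: Tuple[int, ...] = (16, 4)

          @nn.compact
          def __call__(self, x):
              # Submodules are created inline; Flax names them Dense_0 and Dense_1.
              x = nn.Dense(self.features[0])(x)
              return nn.Dense(self.features[1])(nn.relu(x))


      model = CompactModule()
      x = jnp.ones((2, 3))
      # init() runs __call__ once to create the parameters.
      variables = model.init(jax.random.PRNGKey(0), x)
      y = model.apply(variables, x)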


   .. py:attribute:: latents
      :type:  int


   .. py:attribute:: layers
      :type:  int


   .. py:attribute:: nunits
      :type:  int


   .. py:attribute:: dim
      :type:  int


   .. py:method:: __call__(x)


.. py:class:: Decoder

   Bases: :py:obj:`flax.linen.Module`




   .. py:attribute:: latents
      :type:  int


   .. py:attribute:: layers
      :type:  int


   .. py:attribute:: nunits
      :type:  int


   .. py:attribute:: dim
      :type:  int


   .. py:method:: __call__(z)


.. py:class:: VAE

   Bases: :py:obj:`flax.linen.Module`




   .. py:attribute:: latents
      :type:  int


   .. py:attribute:: layers
      :type:  int


   .. py:attribute:: nunits
      :type:  int


   .. py:attribute:: dim
      :type:  int


   .. py:method:: setup()

      Initializes a Module lazily (similar to a lazy ``__init__``).

      ``setup`` is called once lazily on a module instance when a module
      is bound, immediately before any other methods like ``__call__`` are
      invoked, or before a ``setup``-defined attribute on ``self`` is accessed.

      This can happen in three cases:

        1. Immediately when invoking :meth:`apply`, :meth:`init` or
           :meth:`init_and_output`.

        2. Once the module is given a name by being assigned to an attribute of
           another module inside the other module's ``setup`` method
           (see :meth:`__setattr__`)::

              >>> class MyModule(nn.Module):
              ...   def setup(self):
              ...     submodule = nn.Conv(...)
              ...
              ...     # Accessing `submodule` attributes does not yet work here.
              ...
              ...     # The following line invokes `self.__setattr__`, which gives
              ...     # `submodule` the name "conv1".
              ...     self.conv1 = submodule
              ...
              ...     # Accessing `submodule` attributes or methods is now safe and
              ...     # triggers the submodule's setup() if it has not run yet.

        3. Once a module is constructed inside a method wrapped with
           :meth:`compact`, immediately before another method is called or a
           ``setup``-defined attribute is accessed.
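
      A short sketch of this lazy behaviour, using the public ``flax.linen``
      methods ``init``, ``apply`` and ``bind``; the module name ``Wrapper`` is
      hypothetical. ``init`` binds the module and runs ``setup`` (case 1), the
      submodule assigned in ``setup`` takes its attribute name ``dense``
      (case 2), and a bound copy runs ``setup`` on first attribute access.

      .. code-block:: python

         import jax
         import jax.numpy as jnp
         from flax import linen as nn


         class Wrapper(nn.Module):
             def setup(self):
                 # Assigning to self.dense names the submodule "dense" (case 2).
                 self.dense = nn.Dense(2)

             def __call__(self, x):
                 return self.dense(x)


         module = Wrapper()
         x = jnp.ones((1, 3))
         # Case 1: init() binds the module and runs setup() before __call__.
         variables = module.init(jax.random.PRNGKey(0), x)

         # A bound copy runs setup() lazily when `dense` is first accessed.
         bound = module.bind(variables)
         dense = bound.dense
         y = bound(x)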



   .. py:method:: __call__(x, z_rng)


   .. py:method:: encode(x)


   .. py:method:: reparameterize(rng, mean, logvar)
      :classmethod:
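
      The signature suggests the standard VAE reparameterization trick: draw
      :math:`z = \mu + \sigma \epsilon` with
      :math:`\sigma = \exp(\tfrac{1}{2} \log \sigma^2)` and
      :math:`\epsilon \sim \mathcal{N}(0, I)`, so that gradients flow through
      ``mean`` and ``logvar`` while the randomness lives only in ``rng``. The
      standalone function below is a sketch under that assumption (``logvar``
      taken to be the log-variance), not the IMLCV source.

      .. code-block:: python

         import jax
         import jax.numpy as jnp


         def reparameterize(rng, mean, logvar):
             """Sample z ~ N(mean, exp(logvar)) differentiably (assumed semantics)."""
             std = jnp.exp(0.5 * logvar)                 # sigma = sqrt(exp(logvar))
             eps = jax.random.normal(rng, logvar.shape)  # noise independent of parameters
             return mean + std * eps


         key = jax.random.PRNGKey(0)
         mean = jnp.array([1.0, -2.0])
         logvar = jnp.array([0.0, -20.0])  # second latent has a tiny variance
         z = reparameterize(key, mean, logvar)

         # The sample is differentiable w.r.t. the mean: dz/dmean = 1 per component.
         grad_mean = jax.grad(lambda m: jnp.sum(reparameterize(key, m, logvar)))(mean)

      Reusing the same ``rng`` key reproduces the same sample, which is why a
      fresh key (such as the ``z_rng`` argument of ``__call__``) is normally
      passed for each forward pass.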



.. py:class:: TranformerAutoEncoder(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.CVDiscovery.Transformer`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: nunits
      :type:  int
      :value: 250



   .. py:attribute:: nlayers
      :type:  int
      :value: 3



   .. py:attribute:: lr
      :type:  float
      :value: 0.0001



   .. py:attribute:: num_epochs
      :type:  int
      :value: 100



   .. py:attribute:: batch_size
      :type:  int
      :value: 32



   .. py:method:: _fit(cv: list[IMLCV.base.CV.CV], cv_t: list[IMLCV.base.CV.CV], w: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000)
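
   The ``lr``, ``num_epochs`` and ``batch_size`` defaults read like the settings
   of a minibatch gradient-descent loop. The sketch below shows such a loop for
   a deliberately tiny autoencoder (one ``tanh`` encoder layer and a linear
   decoder, rather than the ``nlayers``-deep network of width ``nunits`` the
   class presumably builds), trained with plain SGD via
   ``jax.value_and_grad``. The helpers ``init_params``,
   ``reconstruction_loss``, ``sgd_step`` and ``fit`` are hypothetical; this is
   not how ``_fit`` is actually implemented.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def init_params(key, dim, latents):
          # Hypothetical: small random linear maps for the encoder and decoder.
          k_enc, k_dec = jax.random.split(key)
          return {
              "enc": 0.1 * jax.random.normal(k_enc, (dim, latents)),
              "dec": 0.1 * jax.random.normal(k_dec, (latents, dim)),
          }


      def reconstruction_loss(params, x):
          z = jnp.tanh(x @ params["enc"])  # low-dimensional latent (the candidate CV)
          x_rec = z @ params["dec"]        # map the latent back to input space
          return jnp.mean((x - x_rec) ** 2)


      @jax.jit
      def sgd_step(params, batch, lr):
          loss, grads = jax.value_and_grad(reconstruction_loss)(params, batch)
          params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
          return params, loss


      def fit(key, data, latents=1, lr=1e-4, num_epochs=100, batch_size=32):
          key, init_key = jax.random.split(key)
          params = init_params(init_key, data.shape[1], latents)
          n = data.shape[0]
          for _ in range(num_epochs):
              # Reshuffle every epoch, then sweep over consecutive minibatches.
              key, perm_key = jax.random.split(key)
              perm = jax.random.permutation(perm_key, n)
              for start in range(0, n, batch_size):
                  batch = data[perm[start:start + batch_size]]
                  params, _ = sgd_step(params, batch, lr)
          return params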


.. py:function:: _LDA_trans(cv: IMLCV.base.CV.CV, nl: IMLCV.base.CV.NeighbourList | None, shmap, shmap_kwargs, alpha, outdim, solver)

.. py:function:: _LDA_rescale(cv: IMLCV.base.CV.CV, nl: IMLCV.base.CV.NeighbourList | None, shmap, shmap_kwargs, mean)

.. py:function:: _scale_trans(cv: IMLCV.base.CV.CV, nl: IMLCV.base.CV.NeighbourList | None, shmap, shmap_kwargs, alpha: jax.Array, scale_factor: jax.Array)

.. py:class:: TransoformerLDA(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.CVDiscovery.Transformer`




   .. py:attribute:: kernel
      :value: False



   .. py:attribute:: optimizer
      :value: None



   .. py:attribute:: solver
      :type:  str
      :value: 'eigen'



   .. py:attribute:: method
      :type:  str
      :value: 'pymanopt'



   .. py:attribute:: harmonic
      :value: True



   .. py:attribute:: min_gradient_norm
      :type:  float
      :value: 0.001



   .. py:attribute:: min_step_size
      :type:  float
      :value: 0.001



   .. py:attribute:: max_iterations
      :type:  int
      :value: 25



   .. py:method:: _fit(cv_list: list[IMLCV.base.CV.CV], cv_t_list: list[IMLCV.base.CV.CV], w: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000)
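
   ``solver = 'eigen'`` matches the name scikit-learn uses for solving Fisher's
   LDA as the generalized eigenvalue problem :math:`S_B v = \lambda S_W v`,
   with between-class scatter :math:`S_B` and within-class scatter
   :math:`S_W` (``method = 'pymanopt'`` points to a manifold-optimization route
   instead). Assuming the eigen solver follows that formulation, a standalone
   JAX sketch is shown below; ``lda_eigen`` and its ``reg`` ridge term are
   hypothetical.

   .. code-block:: python

      import jax
      import jax.numpy as jnp
      from jax.scipy.linalg import solve_triangular


      def lda_eigen(x, labels, n_classes, outdim=1, reg=1e-6):
          """Fisher LDA directions from S_B v = lam S_W v (hypothetical helper)."""
          d = x.shape[1]
          mean_all = jnp.mean(x, axis=0)
          s_w = jnp.zeros((d, d))
          s_b = jnp.zeros((d, d))
          for c in range(n_classes):
              xc = x[labels == c]
              mc = jnp.mean(xc, axis=0)
              s_w = s_w + (xc - mc).T @ (xc - mc)        # within-class scatter
              diff = (mc - mean_all)[:, None]
              s_b = s_b + xc.shape[0] * (diff @ diff.T)  # between-class scatter
          s_w = s_w + reg * jnp.eye(d)                   # ridge keeps S_W positive definite

          # Reduce to a symmetric standard problem with the Cholesky factor S_W = L L^T.
          l = jnp.linalg.cholesky(s_w)
          m = solve_triangular(l, solve_triangular(l, s_b, lower=True).T, lower=True)
          evals, evecs = jnp.linalg.eigh(m)              # ascending eigenvalues
          v = solve_triangular(l.T, evecs, lower=False)  # back-transform: v = L^{-T} u
          order = jnp.argsort(evals)[::-1][:outdim]      # most discriminative first
          return v[:, order], evals[order]

   With two classes :math:`S_B` has rank one, so only ``outdim = 1`` carries
   discriminative information.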


.. py:class:: TransformerMAF(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.CVDiscovery.Transformer`




   .. py:attribute:: eps
      :type:  float
      :value: 1e-05



   .. py:attribute:: eps_pre
      :type:  float
      :value: 1e-05



   .. py:attribute:: max_features
      :type:  int
      :value: 500



   .. py:attribute:: max_features_pre
      :type:  int
      :value: 500



   .. py:attribute:: sym
      :type:  bool
      :value: True



   .. py:attribute:: use_w
      :type:  bool
      :value: True



   .. py:attribute:: min_t_frac
      :type:  float
      :value: 0.1



   .. py:attribute:: trans
      :type:  IMLCV.base.CV.CvTrans | None
      :value: None



   .. py:attribute:: T_scale
      :type:  float
      :value: 1.0



   .. py:method:: _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, macro_chunk=1000, chunk_size=None, verbose=True)
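
   The ``x``/``x_t`` arguments (apparently time-lagged pairs) and the ``eps``
   and ``sym`` settings suggest a time-lagged linear analysis; reading MAF as
   *maximum autocorrelation factors* is an assumption. Under that assumption, a
   TICA-style sketch whitens the instantaneous covariance :math:`C_0`, drops
   directions with eigenvalues below ``eps``, and diagonalizes the lagged
   covariance :math:`C_\tau` in the whitened basis. Only the symmetrized case
   (as with ``sym = True``) is covered, and ``maf_sketch`` is hypothetical.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def maf_sketch(x, x_t, eps=1e-5):
          """Linear combinations with maximal lag autocorrelation (hypothetical)."""
          n = x.shape[0]
          mean = 0.5 * (jnp.mean(x, axis=0) + jnp.mean(x_t, axis=0))
          xc, xtc = x - mean, x_t - mean
          # Symmetrized instantaneous (C_0) and time-lagged (C_tau) covariances.
          c0 = 0.5 * (xc.T @ xc + xtc.T @ xtc) / n
          ct = 0.5 * (xc.T @ xtc + xtc.T @ xc) / n
          # Whiten C_0, discarding near-singular directions with eigenvalue <= eps.
          s, u = jnp.linalg.eigh(c0)
          keep = jnp.nonzero(s > eps)[0]
          w = u[:, keep] / jnp.sqrt(s[keep])
          # In the whitened basis the autocorrelations are eigenvalues of W^T C_tau W.
          rho, q = jnp.linalg.eigh(w.T @ ct @ w)
          order = jnp.argsort(rho)[::-1]  # slowest (most autocorrelated) first
          return (w @ q)[:, order], rho[order]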


