:py:mod:`IMLCV.base.CV`
=======================

.. py:module:: IMLCV.base.CV


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   IMLCV.base.CV.SystemParams
   IMLCV.base.CV.NeighbourList
   IMLCV.base.CV.CV
   IMLCV.base.CV.CvMetric
   IMLCV.base.CV.CvFunInput
   IMLCV.base.CV.CvFunBase
   IMLCV.base.CV.CvFun
   IMLCV.base.CV.CvFunNn
   IMLCV.base.CV.CvFunDistrax
   IMLCV.base.CV.CombinedCvFun
   IMLCV.base.CV.CvTrans
   IMLCV.base.CV.CvTransNN
   IMLCV.base.CV.NormalizingFlow
   IMLCV.base.CV.CvFlow
   IMLCV.base.CV.CollectiveVariable




.. py:class:: SystemParams

   .. py:property:: batched


   .. py:property:: shape


   .. py:attribute:: coordinates
      :type: jax.Array

      

   .. py:attribute:: cell
      :type: Array | None

      

   .. py:method:: __post_init__()


   .. py:method:: __getitem__(slices) -> SystemParams


   .. py:method:: __iter__()


   .. py:method:: __add__(other)


   .. py:method:: __str__()

      Return str(self).


   .. py:method:: batch() -> SystemParams


   .. py:method:: unbatch() -> SystemParams


   .. py:method:: angles(deg=True) -> jax.Array


   .. py:method:: _get_neighbour_list(r_cut, r_skin: float = 0.0, z_array: tuple[int] | None = None, z_unique: tuple[int] | None = None, num_z_unique: tuple[int] | None = None, num_neighs: int | None = None, nxyz: tuple[int] | None = None) -> tuple[bool, NeighbourList | None]


   .. py:method:: get_neighbour_list(r_cut, z_array: list[int] | Array, r_skin=0.0) -> NeighbourList | None


   .. py:method:: minkowski_reduce() -> tuple[SystemParams, jax.Array]

      based on code from ASE: https://wiki.fysik.dtu.dk/ase/_modules/ase/geometry/minkowski_reduction.html#minkowski_reduce


   .. py:method:: wrap_positions(min=False) -> tuple[SystemParams, jax.Array]

      wrap pos to lie within unit cell


   .. py:method:: _wrap_pos(cell: Array | None, coordinates: jax.Array, min=False) -> tuple[jax.Array, jax.Array]
      :staticmethod:


   .. py:method:: canoncialize(min=False) -> tuple[SystemParams, jax.Array, jax.Array]


   .. py:method:: min_distance(index_1, index_2)



.. py:class:: NeighbourList

   .. py:property:: batched


   .. py:property:: shape


   .. py:property:: num_neigh


   .. py:attribute:: r_cut
      :type: jax_dataclasses.Static[jax.numpy.floating]

      

   .. py:attribute:: atom_indices
      :type: jax.Array

      

   .. py:attribute:: op_cell
      :type: Array | None

      

   .. py:attribute:: op_coor
      :type: Array | None

      

   .. py:attribute:: op_center
      :type: Array | None

      

   .. py:attribute:: r_skin
      :type: jax_dataclasses.Static[jax.numpy.floating]

      

   .. py:attribute:: sp_orig
      :type: SystemParams | None

      

   .. py:attribute:: ijk_indices
      :type: Array | None

      

   .. py:attribute:: nxyz
      :type: jax_dataclasses.Static[tuple[int] | None]

      

   .. py:attribute:: z_array
      :type: jax_dataclasses.Static[tuple[int] | None]

      

   .. py:attribute:: z_unique
      :type: jax_dataclasses.Static[tuple[int] | None]

      

   .. py:attribute:: num_z_unique
      :type: jax_dataclasses.Static[tuple[int] | None]

      

   .. py:method:: _pos(sp_orig)


   .. py:method:: apply_fun_neighbour(sp: SystemParams, func, r_cut=None, fill_value=0, reduce='full', split_z=False, exclude_self=False)


   .. py:method:: apply_fun_neighbour_pair(sp: SystemParams, func_double, func_single=None, r_cut=None, fill_value=0, reduce='full', split_z=False, exclude_self=True, unique=True)

      Args::

         func_single=lambda r_ij, atom_index_j: (1,),
         func_double=lambda r_ij, atom_index_j, data_j, r_ik, atom_index_k, data_k: (
             r_ij,
             atom_index_j,
             r_ik,
             atom_index_k,
         ),



   .. py:method:: __getitem__(slices)


   .. py:method:: update(sp: SystemParams) -> tuple[bool, NeighbourList]


   .. py:method:: nl_split_z(p)


   .. py:method:: match_kernel(p1: jax.Array, p2: jax.Array, nl1: NeighbourList, nl2: NeighbourList, matching='REMatch', alpha=0.01, mode='divergence', jit=True)
      :staticmethod:


   .. py:method:: __add__(other)


   .. py:method:: stack(*nls: NeighbourList) -> NeighbourList
      :staticmethod:



.. py:class:: CV

   .. py:property:: batched


   .. py:property:: batch_dim


   .. py:property:: dim


   .. py:property:: size


   .. py:property:: shape


   .. py:property:: combine_dims


   .. py:property:: stack_dims


   .. py:attribute:: cv
      :type: jax.Array

      

   .. py:attribute:: mapped
      :type: jax_dataclasses.Static[bool]
      :value: False

      

   .. py:attribute:: atomic
      :type: jax_dataclasses.Static[bool]
      :value: False

      

   .. py:attribute:: _combine_dims
      :type: jax_dataclasses.Static[list | None]

      

   .. py:attribute:: _stack_dims
      :type: jax_dataclasses.Static[list | None]

      

   .. py:method:: __add__(other) -> CV


   .. py:method:: __radd__(other) -> CV


   .. py:method:: __sub__(other) -> CV


   .. py:method:: __rsub__(other) -> CV


   .. py:method:: __mul__(other) -> CV


   .. py:method:: __rmul__(other) -> CV


   .. py:method:: __matmul__(other) -> CV


   .. py:method:: __rmatmul__(other) -> CV


   .. py:method:: __div__(other) -> CV


   .. py:method:: __rdiv__(other) -> CV


   .. py:method:: batch() -> CV


   .. py:method:: __iter__()


   .. py:method:: __getitem__(idx)


   .. py:method:: unbatch() -> CV


   .. py:method:: stack(*cvs: CV) -> CV
      :staticmethod:

      stacks a list of CVs into a single CV. The dimensions are stored such that it can later be unstacked into separate CVs. The CVs are stacked over the batch dimension


   .. py:method:: unstack() -> list[CV]


   .. py:method:: split(flatten=False) -> list[CV]

      inverse operation of combine


   .. py:method:: combine(*cvs: CV, flatten=False) -> CV
      :staticmethod:

      merges a list of CVs into a single CV. The dimensions are stored such that it can later be split into separate CVs. The CVs are combined over the last dimension



.. py:class:: CvMetric(periodicities, bounding_box=None)

   class to keep track of the topology of a given CV. Identifies the periodicities of CVs and maps to the unit square with the correct periodicities

   .. py:property:: ndim


   .. py:method:: norm(x1: CV, x2: CV, k=1.0)


   .. py:method:: periodic_wrap(x: CV, min=False) -> CV


   .. py:method:: difference(x1: CV, x2: CV) -> jax.Array


   .. py:method:: min_cv(cv: jax.Array)


   .. py:method:: __periodic_wrap(xs: jax.Array, min=False)

      Translate CVs over the unit cell.

      min=True calculates distances, False translates one vector inside box


   .. py:method:: map(x: jax.Array, displace=True) -> jax.Array

      transform CVs to lie in unit square.


   .. py:method:: unmap(x: jax.Array, displace=True) -> jax.Array

      inverse of ``map``: transform CVs from the unit square back to the original CV space.


   .. py:method:: __add__(other)


   .. py:method:: grid(n, endpoints=None, margin=None)

      forms regular grid in mapped space. If coordinate is periodic, last rows are omitted.

      :param n: number of points in each dim
      :param map: boolean. True: work in mapped space (default), False: calculate grid in space without metric
      :param endpoints: if

      :returns: meshgrid and vector with distances between points



.. py:class:: CvFunInput

   .. py:attribute:: input
      :type: int

      

   .. py:attribute:: conditioners
      :type: list[int] | None

      

   .. py:method:: split(x: CV)


   .. py:method:: combine(x: CV, res: CV)



.. py:class:: CvFunBase

   .. py:attribute:: _
      :type: dataclasses.KW_ONLY

      

   .. py:attribute:: cv_input
      :type: CvFunInput | None

      

   .. py:method:: calc(x: CV, nl: NeighbourList | None, reverse=False, log_det=False) -> tuple[CV, Array | None]


   .. py:method:: _calc(x: CV, nl: NeighbourList | None, reverse=False, conditioners: list[CV] | None = None) -> CV
      :abstractmethod:


   .. py:method:: _log_Jf(x: CV, nl: NeighbourList | None, reverse=False, conditioners: list[CV] | None = None) -> tuple[CV, Array | None]

      naive automated implementation, override this



.. py:class:: CvFun

   Bases: :py:obj:`CvFunBase`

   .. py:attribute:: forward
      :type: Callable[[CV, NeighbourList | None, CV | None], CV] | None

      

   .. py:attribute:: backward
      :type: Callable[[CV, NeighbourList | None, CV | None], CV] | None

      

   .. py:method:: _calc(x: CV, nl: NeighbourList | None, reverse=False, conditioners: list[CV] | None = None) -> CV



.. py:class:: CvFunNn

   Bases: :py:obj:`flax.linen.Module`, :py:obj:`CvFunBase`

   used to instantiate flax linen CvTrans

   .. py:method:: setup()
      :abstractmethod:

      Initializes a Module lazily (similar to a lazy ``__init__``).

      ``setup`` is called once lazily on a module instance when a module
      is bound, immediately before any other methods like ``__call__`` are
      invoked, or before a ``setup``-defined attribute on `self` is accessed.

      This can happen in three cases:

        1. Immediately when invoking :meth:`apply`, :meth:`init` or
           :meth:`init_and_output`.

        2. Once the module is given a name by being assigned to an attribute of
           another module inside the other module's ``setup`` method
           (see :meth:`__setattr__`)::

             class MyModule(nn.Module):
               def setup(self):
                 submodule = Conv(...)

                 # Accessing `submodule` attributes does not yet work here.

                 # The following line invokes `self.__setattr__`, which gives
                 # `submodule` the name "conv1".
                 self.conv1 = submodule

                 # Accessing `submodule` attributes or methods is now safe and
                 # either causes setup() to be called once.

        3. Once a module is constructed inside a method wrapped with
           :meth:`compact`, immediately before another method is called or
           ``setup`` defined attribute is accessed.


   .. py:method:: _calc(x: CV, nl: NeighbourList | None, reverse=False, conditioners: list[CV] | None = None) -> CV


   .. py:method:: forward(x: CV, nl: NeighbourList | None, conditioners: list[CV] | None = None) -> CV
      :abstractmethod:


   .. py:method:: backward(x: CV, nl: NeighbourList | None, conditioners: list[CV] | None = None) -> CV
      :abstractmethod:



.. py:class:: CvFunDistrax

   Bases: :py:obj:`flax.linen.Module`, :py:obj:`CvFunBase`

   creates bijective CV function based on a distrax flow. The setup function should initialize the bijector, for example:

   .. code-block:: python

      class RealNVP(CvFunDistrax):
          _: dataclasses.KW_ONLY
          latent_dim: int

          def setup(self):
              self.s = Dense(features=self.latent_dim)
              self.t = Dense(features=self.latent_dim)

              # Alternating binary mask.
              self.bijector = distrax.as_bijector(
                  tfp.bijectors.RealNVP(
                      fraction_masked=0.5,
                      shift_and_log_scale_fn=self.shift_and_scale,
                  )
              )

          def shift_and_scale(self, x0, input_depth, **condition_kwargs):
              return self.s(x0), self.t(x0)

   .. py:attribute:: bijector
      :type: distrax.Bijector

      

   .. py:method:: setup()
      :abstractmethod:

      sets up self.bijector


   .. py:method:: _calc(x: CV, nl: NeighbourList | None, reverse=False, log_det=False, conditioners: list[CV] | None = None) -> CV


   .. py:method:: _log_Jf(x: CV, nl: NeighbourList | None, reverse=False, conditioners: list[CV] | None = None) -> tuple[CV, Array | None]

      naive implementation, override this



.. py:class:: CombinedCvFun

   Bases: :py:obj:`CvFunBase`

   .. py:attribute:: classes
      :type: list[list[CvFunBase]]

      

   .. py:method:: calc(x: CV, nl: NeighbourList | None, reverse=False, log_det=False) -> tuple[CV, Array | None]


   .. py:method:: _log_Jf(x: CV, nl: NeighbourList | None, reverse=False, conditioners: list[CV] | None = None) -> tuple[CV, Array | None]
      :abstractmethod:

      naive automated implementation, override this



.. py:class:: CvTrans

   f can either be a single CV transformation or a list of transformations

   .. py:attribute:: trans
      :type: list[CvFunBase]

      

   .. py:method:: from_array_function(f: collections.abc.Callable[[jax.Array, NeighbourList | None, None], jax.Array])
      :staticmethod:


   .. py:method:: from_cv_function(f: collections.abc.Callable[[CV, NeighbourList | None, CV | None], CV]) -> CvTrans
      :staticmethod:


   .. py:method:: from_cv_fun(proto: CvFunBase)
      :staticmethod:


   .. py:method:: compute_cv_trans(x: CV, nl: NeighbourList | None = None, reverse=False, log_Jf=False) -> tuple[CV, Array | None]

      result is always batched
      arg: CV


   .. py:method:: __mul__(other)


   .. py:method:: __add__(other: CvTrans) -> CvTrans


   .. py:method:: stack(*cv_trans: CvTrans)
      :staticmethod:



.. py:class:: CvTransNN

   Bases: :py:obj:`flax.linen.Module`, :py:obj:`CvTrans`

   Base class for all neural network modules. Layers and models should subclass this class.

   All Flax Modules are Python 3.7
   `dataclasses <https://docs.python.org/3/library/dataclasses.html>`_. Since
   dataclasses take over ``__init__``, you should instead override :meth:`setup`,
   which is automatically called to initialize the module.

   Modules can contain submodules, and in this way can be nested in a tree
   structure. Submodels can be assigned as regular attributes inside the
   :meth:`setup` method.

   You can define arbitrary "forward pass" methods on your Module subclass.
   While no methods are special-cased, ``__call__`` is a popular choice because
   it allows you to use module instances as if they are functions::

     from flax import linen as nn

     class Module(nn.Module):
       features: Tuple[int, ...] = (16, 4)

       def setup(self):
         self.dense1 = Dense(self.features[0])
         self.dense2 = Dense(self.features[1])

       def __call__(self, x):
         return self.dense2(nn.relu(self.dense1(x)))

   Optionally, for more concise module implementations where submodules
   definitions are co-located with their usage, you can use the
   :meth:`compact` wrapper.

   .. py:attribute:: trans
      :type: list[CvFunBase]

      

   .. py:method:: setup() -> None

      Initializes a Module lazily (similar to a lazy ``__init__``).

      ``setup`` is called once lazily on a module instance when a module
      is bound, immediately before any other methods like ``__call__`` are
      invoked, or before a ``setup``-defined attribute on `self` is accessed.

      This can happen in three cases:

        1. Immediately when invoking :meth:`apply`, :meth:`init` or
           :meth:`init_and_output`.

        2. Once the module is given a name by being assigned to an attribute of
           another module inside the other module's ``setup`` method
           (see :meth:`__setattr__`)::

             class MyModule(nn.Module):
               def setup(self):
                 submodule = Conv(...)

                 # Accessing `submodule` attributes does not yet work here.

                 # The following line invokes `self.__setattr__`, which gives
                 # `submodule` the name "conv1".
                 self.conv1 = submodule

                 # Accessing `submodule` attributes or methods is now safe and
                 # either causes setup() to be called once.

        3. Once a module is constructed inside a method wrapped with
           :meth:`compact`, immediately before another method is called or
           ``setup`` defined attribute is accessed.


   .. py:method:: compute_cv_trans(x: CV, nl: NeighbourList | None, reverse=False, log_Jf=False) -> tuple[CV, Array | None]

      result is always batched
      arg: CV


   .. py:method:: __mul__(other)



.. py:class:: NormalizingFlow

   Bases: :py:obj:`flax.linen.Module`

   normalizing flow. _ProtoCvTransNN are stored separately because they need to be initialized by this module in setup

   .. py:attribute:: flow
      :type: CvTransNN | CvTransNN

      

   .. py:method:: setup() -> None

      Initializes a Module lazily (similar to a lazy ``__init__``).

      ``setup`` is called once lazily on a module instance when a module
      is bound, immediately before any other methods like ``__call__`` are
      invoked, or before a ``setup``-defined attribute on `self` is accessed.

      This can happen in three cases:

        1. Immediately when invoking :meth:`apply`, :meth:`init` or
           :meth:`init_and_output`.

        2. Once the module is given a name by being assigned to an attribute of
           another module inside the other module's ``setup`` method
           (see :meth:`__setattr__`)::

             class MyModule(nn.Module):
               def setup(self):
                 submodule = Conv(...)

                 # Accessing `submodule` attributes does not yet work here.

                 # The following line invokes `self.__setattr__`, which gives
                 # `submodule` the name "conv1".
                 self.conv1 = submodule

                 # Accessing `submodule` attributes or methods is now safe and
                 # either causes setup() to be called once.

        3. Once a module is constructed inside a method wrapped with
           :meth:`compact`, immediately before another method is called or
           ``setup`` defined attribute is accessed.


   .. py:method:: calc(x: CV, nl: NeighbourList | None, reverse: bool, test_log_det=False)



.. py:class:: CvFlow(func: collections.abc.Callable[[SystemParams, NeighbourList | None], CV], trans: CvTrans | None = None)

   .. py:method:: from_function(f: collections.abc.Callable[[SystemParams, NeighbourList | None], jax.Array], atomic=False) -> CvFlow
      :staticmethod:


   .. py:method:: compute_cv_flow(x: SystemParams, nl: NeighbourList | None = None, jit=True, chunk_size: int | None = None) -> CV


   .. py:method:: __add__(other)


   .. py:method:: __mul__(other)


   .. py:method:: find_sp(x0: SystemParams, target: CV, target_nl: NeighbourList, nl0: NeighbourList | None = None, maxiter=10000, tol=0.0001, norm=lambda cv1, cv2, nl1, nl2: jnp.linalg.norm(cv1 - cv2), solver=jaxopt.GradientDescent) -> SystemParams



.. py:class:: CollectiveVariable(f: CvFlow, metric: CvMetric, jac=jacrev)

   .. py:property:: n


   .. py:method:: _jit_f(sp, nl)


   .. py:method:: _jit_df(sp, nl)


   .. py:method:: compute_cv(sp: SystemParams, nl: NeighbourList | None = None, jacobian=False, jit=True, chunk_size: int | None = None) -> tuple[CV, CV]



