IMLCV.base.CV
=============

.. py:module:: IMLCV.base.CV


Attributes
----------

.. autoapisummary::

   IMLCV.base.CV.T
   IMLCV.base.CV.S
   IMLCV.base.CV.P
   IMLCV.base.CV.P2
   IMLCV.base.CV.X
   IMLCV.base.CV.X2


Classes
-------

.. autoapisummary::

   IMLCV.base.CV.ShmapKwargs
   IMLCV.base.CV.SystemParams
   IMLCV.base.CV.NeighbourListInfo
   IMLCV.base.CV.NeighbourListUpdate
   IMLCV.base.CV.NeighbourList
   IMLCV.base.CV.CV
   IMLCV.base.CV.CvMetric
   IMLCV.base.CV.CvFunBase
   IMLCV.base.CV.CvFun
   IMLCV.base.CV._SerialCvTrans
   IMLCV.base.CV._ParralelCvTrans
   IMLCV.base.CV.CvTrans
   IMLCV.base.CV.CollectiveVariable


Functions
---------

.. autoapisummary::

   IMLCV.base.CV._n_pad
   IMLCV.base.CV._n_unpad
   IMLCV.base.CV._shard
   IMLCV.base.CV._shard_out
   IMLCV.base.CV.padded_shard_map
   IMLCV.base.CV.padded_vmap
   IMLCV.base.CV.macro_chunk_map_fun
   IMLCV.base.CV.macro_chunk_map
   IMLCV.base.CV._macro_chunk_map


Module Contents
---------------

.. py:data:: T

.. py:data:: S

.. py:data:: P

.. py:data:: P2

.. py:class:: ShmapKwargs(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: axis
      :type:  int
      :value: 0



   .. py:attribute:: out_axes
      :type:  int
      :value: 0



   .. py:attribute:: axis_name
      :type:  str | None


   .. py:attribute:: n_devices
      :type:  int | None
      :value: None



   .. py:attribute:: pmap
      :type:  bool
      :value: False



   .. py:attribute:: explicit_shmap
      :type:  bool
      :value: False



   .. py:attribute:: verbose
      :type:  bool
      :value: False



   .. py:attribute:: device_get
      :type:  bool
      :value: True



   .. py:attribute:: devices
      :type:  tuple[Any]


   .. py:attribute:: mesh
      :type:  jax.sharding.Mesh | None


   .. py:method:: create(axis: int = 0, out_axes: int = 0, axis_name: str | None = 'i', n_devices: int | None = None, pmap: bool = False, explicit_shmap: bool = False, verbose: bool = False, device_get: bool = True, devices=None, mesh=None)
      :staticmethod:



.. py:function:: _n_pad(x, axis, p, chunk_size, reshape=False, move_axis=True, n_chunk_move=False)

.. py:function:: _n_unpad(x, axis, shape, reshape=False, move_axis=True, trim=True, n_chunk_move=False)

.. py:function:: _shard(x_padded, axis, axis_name, mesh, put=True, unflatten=True)

.. py:function:: _shard_out(out_tree_def, axis, axis_name, mesh)

.. py:function:: padded_shard_map(f: Callable[P, T], kwargs: ShmapKwargs = ShmapKwargs.create())

.. py:function:: padded_vmap(f: Callable[P, T], chunk_size=None, axis=0, out_axes: int = 0, vmap=True, verbose=False)

.. py:data:: X

.. py:data:: X2

.. py:function:: macro_chunk_map_fun(f: Callable[[X, NeighbourList | None], X2], y: list[X], w: list[jax.Array] | None = None, nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, chunk_func: Callable[[T, X2, X2 | None, jax.Array | None, jax.Array | None], T] | None = None, chunk_func_init_args: T = None, w_t: list[jax.Array] | None = None, print_every=10, jit_f=True) -> T

.. py:function:: macro_chunk_map(f: Callable[[X, NeighbourList | None], X2], y: list[X], nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, print_every=10, jit_f=True) -> list[X2] | tuple[list[X2], list[X2] | None]

.. py:function:: _macro_chunk_map(f: Callable[[X, NeighbourList | None], X2], y: list[X], w: list[jax.Array] | None = None, nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, chunk_func: Callable[[T, X2, X2 | None, jax.Array | None, jax.Array | None], T] | None = None, chunk_func_init_args: T = None, w_t: list[jax.Array] | None = None, print_every=10, jit_f=True)

.. py:class:: SystemParams(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: coordinates
      :type:  jax.Array


   .. py:attribute:: cell
      :type:  jax.Array | None


   .. py:method:: __getitem__(slices) -> SystemParams


   .. py:method:: __iter__()


   .. py:property:: batched


   .. py:property:: batch_dim


   .. py:property:: shape


   .. py:method:: __add__(other)


   .. py:method:: stack(*sps: SystemParams) -> SystemParams
      :staticmethod:



   .. py:method:: batch() -> SystemParams


   .. py:method:: unbatch() -> SystemParams


   .. py:method:: angles(deg=True) -> jax.Array


   .. py:method:: _get_neighbour_list(info: NeighbourListInfo, update: NeighbourListUpdate | None = None, chunk_size: int | None = 100, verbose=False, chunk_size_inner: int | None = 100, shmap=False, shmap_kwargs: ShmapKwargs = ShmapKwargs.create(), only_update=False)


   .. py:method:: get_neighbour_list(info: NeighbourListInfo, chunk_size: int | None = None, chunk_size_inner=100, verbose=False, shmap=False, shmap_kwargs: ShmapKwargs = ShmapKwargs.create(), only_update=False) -> NeighbourList | None


   .. py:method:: minkowski_reduce() -> tuple[SystemParams, jax.Array]

      Based on code from ASE: https://wiki.fysik.dtu.dk/ase/_modules/ase/geometry/minkowski_reduction.html



   .. py:method:: apply_minkowski_reduction(op)


   .. py:method:: rotate_cell() -> tuple[SystemParams, tuple[jax.Array, jax.Array] | None]


   .. py:method:: apply_rotation(op)


   .. py:method:: to_relative() -> tuple[SystemParams, jax.Array | None]


   .. py:method:: to_absolute() -> SystemParams


   .. py:method:: wrap_positions(min=False) -> tuple[SystemParams, jax.Array]

      Wrap positions to lie within the unit cell.



   .. py:method:: apply_wrap(wrap_op: jax.Array) -> SystemParams


   .. py:method:: canonicalize(min=False, qr=False, chunk_size=None) -> tuple[SystemParams, tuple[jax.Array, jax.Array, tuple[jax.Array, jax.Array] | None]]


   .. py:method:: apply_canonicalize(ops)


   .. py:method:: min_distance(index_1, index_2)


   .. py:method:: super_cell(n: int | list[int], info: NeighbourListInfo | None = None) -> tuple[SystemParams, NeighbourListInfo | None]


   .. py:method:: volume()


   .. py:method:: to_ase(static_trajectory_info: IMLCV.base.MdEngine.StaticMdInfo)


.. py:class:: NeighbourListInfo(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: r_cut
      :type:  float


   .. py:attribute:: r_skin
      :type:  float


   .. py:attribute:: z_array
      :type:  tuple[int] | None


   .. py:attribute:: z_unique
      :type:  tuple[int] | None


   .. py:attribute:: num_z_unique
      :type:  tuple[int] | None


   .. py:method:: create(r_cut, z_array, r_skin=None)
      :staticmethod:



   .. py:method:: nl_split_z(p: T) -> tuple[jax.Array, list[jax.Array], list[T]]


   .. py:method:: __getstate__()


   .. py:method:: __setstate__(state)


.. py:class:: NeighbourListUpdate(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: nxyz
      :type:  tuple[int] | None


   .. py:attribute:: stack_dims
      :type:  tuple[int] | None


   .. py:attribute:: num_neighs
      :type:  int | None


   .. py:method:: create(nxyz=None, stack_dims=None, num_neighs: int | None = None)
      :staticmethod:



   .. py:method:: __getstate__()


   .. py:method:: __setstate__(state)


.. py:class:: NeighbourList(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: sp_orig
      :type:  SystemParams | None


   .. py:attribute:: info
      :type:  NeighbourListInfo


   .. py:attribute:: update
      :type:  NeighbourListUpdate


   .. py:attribute:: atom_indices
      :type:  jax.Array | None


   .. py:attribute:: op_cell
      :type:  jax.Array | None


   .. py:attribute:: op_coor
      :type:  jax.Array | None


   .. py:attribute:: op_center
      :type:  jax.Array | None


   .. py:attribute:: ijk_indices
      :type:  jax.Array | None


   .. py:attribute:: _padding_bools
      :type:  jax.Array | None


   .. py:method:: create(r_cut, sp_orig, atom_indices=None, r_skin=None, ijk_indices=None, z_array=None, nxyz=None, op_cell=None, op_coor=None, op_center=None, stack_dims=None, num_neighs=None)
      :staticmethod:



   .. py:method:: nneighs(sp=None)


   .. py:property:: needs_calculation


   .. py:property:: padding_bools


   .. py:method:: canonicalized_sp(sp: SystemParams) -> SystemParams


   .. py:method:: neighbour_pos(sp_orig: SystemParams) -> jax.Array


   .. py:method:: apply_fun_neighbour(sp: SystemParams, func: Callable[[jax.Array, jax.Array], T], r_cut: float | None = None, fill_value=0, reduce='full', exclude_self: bool = False, chunk_size_neigbourgs: int | None = None, chunk_size_atoms: int | None = None, chunk_size_batch: int | None = None, shmap=False, split_z=False, shmap_kwargs=ShmapKwargs.create())


   .. py:method:: apply_fun_neighbour_pair(sp: SystemParams, func_double: Callable[[jax.Array, jax.Array, T, jax.Array, jax.Array, T], S], func_single: Callable[[jax.Array, jax.Array], T] = lambda x, y: None, r_cut=None, fill_value=0.0, reduce='full', split_z=False, exclude_self=True, unique=True, chunk_size_neigbourgs=None, chunk_size_atoms=None, chunk_size_batch=None, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[jax.Array, S]

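      :param func_single: called as ``func_single(r_ij, atom_index_j)``.
      :param func_double: called as ``func_double(p_ij, atom_index_j, data_j, p_ik, atom_index_k, data_k)``, where ``data_j`` and ``data_k`` have the return type of ``func_single``.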



   .. py:property:: batched


   .. py:property:: batch_dim


   .. py:property:: shape


   .. py:method:: __getitem__(slices)


   .. py:method:: needs_update(sp: SystemParams) -> bool


   .. py:method:: update_nl(sp: SystemParams, chunk_size: int | None = None, chunk_size_inner: int | None = 10, shmap=False, shmap_kwargs=ShmapKwargs.create(), verbose=False)


   .. py:method:: nl_split_z(p)


   .. py:method:: batch()


   .. py:method:: __add__(other)


   .. py:method:: stack(*nls: NeighbourList) -> NeighbourList
      :staticmethod:



   .. py:method:: unstack() -> list[NeighbourList]


   .. py:method:: __getstate__()


   .. py:method:: __setstate__(statedict: dict)


   .. py:property:: stack_dims


.. py:class:: CV(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: cv
      :type:  jax.Array


   .. py:attribute:: mapped
      :type:  bool


   .. py:attribute:: atomic
      :type:  bool


   .. py:attribute:: _combine_dims
      :type:  tuple[int | Any] | None


   .. py:attribute:: _stack_dims
      :type:  tuple[int] | None


   .. py:method:: create(cv: jax.Array, mapped=False, atomic=False, combine_dims=None, stack_dims=None)
      :staticmethod:



   .. py:property:: batched


   .. py:property:: batch_dim


   .. py:property:: dim


   .. py:property:: size


   .. py:property:: shape


   .. py:property:: combine_dims


   .. py:property:: stack_dims


   .. py:method:: __add__(other) -> CV


   .. py:method:: __radd__(other) -> CV


   .. py:method:: __sub__(other) -> CV


   .. py:method:: __rsub__(other) -> CV


   .. py:method:: __mul__(other) -> CV


   .. py:method:: __rmul__(other) -> CV


   .. py:method:: __matmul__(other) -> CV


   .. py:method:: __rmatmul__(other) -> CV


   .. py:method:: __div__(other) -> CV


   .. py:method:: batch() -> CV


   .. py:method:: __iter__()


   .. py:method:: __getitem__(idx)


   .. py:method:: unbatch() -> CV


   .. py:method:: stack(*cvs: CV) -> CV
      :staticmethod:


      Stack a list of CVs into a single CV. The dimensions are stored so that the result can later be unstacked into separate CVs. The CVs are stacked over the batch dimension.



   .. py:method:: unstack() -> list[CV]


   .. py:method:: split(flatten=False) -> list[CV]

      Inverse operation of :py:meth:`combine`.



   .. py:method:: combine(*cvs: CV, flatten=False) -> CV
      :staticmethod:


      Merge a list of CVs into a single CV. The dimensions are stored so that the result can later be split into separate CVs. The CVs are combined over the last dimension.



   .. py:method:: __getstate__()


   .. py:method:: __setstate__(statedict: dict)


.. py:class:: CvMetric(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Class to keep track of the topology of a given CV. Identifies the periodicities of the CVs and maps them to the unit square with the correct periodicities.


   .. py:attribute:: bounding_box
      :type:  jax.Array


   .. py:attribute:: periodicities
      :type:  jax.Array


   .. py:method:: create(periodicities=None, bounding_box=None) -> CvMetric
      :classmethod:



   .. py:method:: norm(x1: CV, x2: CV, k=1.0)


   .. py:method:: periodic_wrap(x: CV, min=False) -> CV


   .. py:method:: difference(x1: CV, x2: CV) -> jax.Array


   .. py:method:: min_cv(cv: jax.Array)


   .. py:method:: __periodic_wrap(xs: jax.Array, min=False)

      Translate CVs over the unit cell.

      ``min=True`` calculates distances; ``min=False`` translates a single vector inside the box.



   .. py:method:: map(x: jax.Array, displace=True) -> jax.Array

      Transform CVs to lie in the unit square.



   .. py:method:: unmap(x: jax.Array, displace=True) -> jax.Array

      Inverse of :py:meth:`map`: transform CVs from the unit square back to the original CV space.



   .. py:method:: __add__(other)


   .. py:method:: get_n(samples_per_bin, samples, n_dims, max_bins=None, max_bins_per_dim=30)
      :staticmethod:



   .. py:method:: grid(n=30, bounds=None, margin=0.1, indexing='ij')

      Form a regular grid in mapped space. If a coordinate is periodic, the last rows are omitted.

      :param n: number of points in each dimension
      :param bounds: bounds of the grid (optional)
      :param indexing: indexing convention of the meshgrid (``'ij'`` or ``'xy'``)

      :returns: meshgrid and vector with distances between points



   .. py:property:: ndim


   .. py:method:: bounds_from_cv(cv_0: list[CV], percentile=0.1, weights: list[jax.Array] | None = None, rho: list[jax.Array] | None = None, margin=None, chunk_size: int | None = None, n=400, macro_chunk: int | None = 5000, verbose=True)
      :staticmethod:



   .. py:method:: __getstate__()


   .. py:method:: __setstate__(statedict: dict)


   .. py:method:: __getitem__(idx)


.. py:class:: CvFunBase(*args, **kwargs)

   Bases: :py:obj:`abc.ABC`, :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Abstract base class for CV functions. Subclasses implement
   :py:meth:`_compute_cv`.


   .. py:attribute:: kwargs
      :type:  dict


   .. py:attribute:: static_kwargs
      :type:  dict


   .. py:method:: compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, chunk_size=None, jacobian=False, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[CV, CV | None]


   .. py:method:: _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV
      :abstractmethod:



   .. py:method:: __getstate__()


   .. py:method:: __setstate__(statedict: dict)


.. py:class:: CvFun(*args, **kwargs)

   Bases: :py:obj:`CvFunBase`


   CV function defined by the ``forward`` and ``backward``
   callables.


   .. py:attribute:: forward
      :type:  Callable[[CV, NeighbourList | None, CV | None], CV] | None


   .. py:attribute:: backward
      :type:  Callable[[CV, NeighbourList | None, CV | None], CV] | None


   .. py:method:: _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV


.. py:class:: _SerialCvTrans(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Chain of CV transformations stored in ``trans``. Each element is either a single CV function or a nested serial or parallel transformation.


   .. py:attribute:: trans
      :type:  tuple[_SerialCvTrans | _ParralelCvTrans | CvFun, Ellipsis]


   .. py:method:: _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV


.. py:class:: _ParralelCvTrans(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: trans
      :type:  tuple[_SerialCvTrans | _ParralelCvTrans, Ellipsis]


   .. py:method:: _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV


.. py:class:: CvTrans(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: trans
      :type:  _ParralelCvTrans | _SerialCvTrans


   .. py:method:: from_cv_function(f: Callable, static_argnames=None, check_input: bool = True, **kwargs) -> CvTrans
      :staticmethod:



   .. py:method:: compute_cv(x: X, nl: NeighbourList | None = None, chunk_size=None, jacobian=False, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[CV, X | None]


   .. py:method:: __getstate__()


   .. py:method:: __setstate__(statedict: dict)


   .. py:method:: __mul__(other: CvTrans)


   .. py:method:: __add__(other: CvTrans)


.. py:class:: CollectiveVariable(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: f
      :type:  CvTrans


   .. py:attribute:: metric
      :type:  CvMetric


   .. py:attribute:: jac
      :type:  Callable


   .. py:method:: compute_cv(sp: SystemParams, nl: NeighbourList | None = None, jacobian=False, chunk_size: int | None = None, shmap=False, push_jac=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[CV, SystemParams | None]


   .. py:property:: n


   .. py:method:: save(file)


   .. py:method:: load(file, **kwargs) -> CollectiveVariable
      :staticmethod:



   .. py:method:: __getstate__()


   .. py:method:: __setstate__(statedict: dict)


   .. py:method:: __getitem__(tup)


