IMLCV.base.CVDiscovery
======================

.. py:module:: IMLCV.base.CVDiscovery


Classes
-------

.. autoapisummary::

   IMLCV.base.CVDiscovery.Transformer
   IMLCV.base.CVDiscovery.CombineTransformer
   IMLCV.base.CVDiscovery.IdentityTransformer
   IMLCV.base.CVDiscovery.CvTransTransformer


Module Contents
---------------

.. py:class:: Transformer(*args, **kwargs)

   Bases: :py:obj:`IMLCV.base.datastructures.MyPyTreeNode`


   Base class for dataclasses that should act like a JAX pytree node.


   .. py:attribute:: outdim
      :type:  int


   .. py:attribute:: descriptor
      :type:  IMLCV.base.CV.CvTrans | None
      :value: None



   .. py:attribute:: pre_scale
      :type:  bool
      :value: True



   .. py:attribute:: post_scale
      :type:  bool
      :value: True



   .. py:method:: pre_fit(dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, shmap=True, shmap_kwargs=ShmapKwargs.create(), verbose=False, macro_chunk=10000)


   .. py:method:: static_fit(transformer: Transformer, **kwargs)
      :staticmethod:



   .. py:method:: fit(dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, plot=True, plot_folder: str | pathlib.Path | None = None, shmap=True, percentile=5.0, jac=jax.jacrev, transform_FES=True, koopman=True, max_fes_bias: float | None = None, n_max=100000.0, samples_per_bin=20, min_samples_per_bin=3, verbose=True, cv_titles: list[str] | bool = True, vmax=100 * kjmol, macro_chunk=1000, shmap_kwargs=ShmapKwargs.create(), **kwargs) -> tuple[list[IMLCV.base.CV.CV], IMLCV.base.CV.CollectiveVariable, IMLCV.base.bias.Bias]


   .. py:method:: _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]
      :abstractmethod:



   .. py:method:: plot_app(collective_variables: list[IMLCV.base.CV.CollectiveVariable], cv_data: list[list[IMLCV.base.CV.CV]] | list[list[list[IMLCV.base.CV.CV]]] | None = None, biases: list[IMLCV.base.bias.Bias] | list[list[IMLCV.base.bias.Bias]] | None = None, indicate_plots: None | str | list[list[str | None]] = 'lightblue', duplicate_cv_data=True, name: str | pathlib.Path | None = None, labels=None, cv_titles: bool | list[str] = True, data_titles=None, color_trajectories=False, margin=0.1, plot_FES=False, T: float | None = None, vmin=0, vmax=100 * kjmol, dpi=300, n_max_bias=1000000.0, row_color=None, macro_chunk=10000, cmap='jet', offset=True, bar_label='FES [kJ/mol]')
      :staticmethod:


      Plot the app for CV discovery. All 1D and 2D plots are drawn directly; 3D and higher-dimensional data are plotted as 2D slices.



   .. py:method:: _grid_spec_iterator(fig: matplotlib.figure.Figure, dims, ncv, ndata, skip, vmin, vmax, cmap, bar_label='FES [kJ/mol]', cv_titles=None, data_titles=None, indicate_plots=None, plot_FES=False)
      :staticmethod:



   .. py:method:: _plot_1d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple | None = None, margin=None, T=None, vmin=0, vmax=100 * kjmol, cmap=plt.get_cmap('jet'), **scatter_kwargs)
      :staticmethod:



   .. py:method:: _plot_2d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple | None = None, margin: float = 0.1, vmin=0, vmax=100 * kjmol, T=None, print_labels=False, cmap=plt.get_cmap('jet'), plot_y=True, **scatter_kwargs)
      :staticmethod:



   .. py:method:: _plot_3d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple[int, Ellipsis] | None = None, margin: float = 0.1, vmin=0, vmax=100 * kjmol, T=None, cmap=plt.get_cmap('jet'), **scatter_kwargs)
      :staticmethod:



   .. py:method:: _get_color_data(a: list[IMLCV.base.CV.CV], dim: int, color_trajectories=False, color_1d=True, metric: IMLCV.base.CV.CvMetric | None = None, max_val=None, min_val=None, margin=None) -> IMLCV.base.CV.CV
      :staticmethod:



   .. py:method:: __mul__(other)


.. py:class:: CombineTransformer(*args, **kwargs)

   Bases: :py:obj:`Transformer`


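   Transformer that combines a list of :py:obj:`Transformer` objects (stored in :py:attr:`transformers`). Like every :py:obj:`Transformer`, it is a dataclass that acts like a JAX pytree node.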


   .. py:attribute:: transformers
      :type:  list[Transformer]


   .. py:method:: create(transformers: list[Transformer]) -> CombineTransformer
      :staticmethod:



   .. py:method:: _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]


.. py:class:: IdentityTransformer(*args, **kwargs)

   Bases: :py:obj:`Transformer`


   Identity :py:obj:`Transformer`. Like every :py:obj:`Transformer`, it is a dataclass that acts like a JAX pytree node.


   .. py:method:: _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000, **fit_kwargs) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV] | None, IMLCV.base.CV.CvTrans, list[jax.Array] | None]


.. py:class:: CvTransTransformer(*args, **kwargs)

   Bases: :py:obj:`Transformer`


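   Transformer built around a fixed :py:class:`~IMLCV.base.CV.CvTrans` (stored in :py:attr:`trans`). Like every :py:obj:`Transformer`, it is a dataclass that acts like a JAX pytree node.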


   .. py:attribute:: trans
      :type:  IMLCV.base.CV.CvTrans


   .. py:method:: _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]
      :abstractmethod:



