IMLCV.base.CVDiscovery

Classes

Transformer

Abstract base class for CV-discovery transformers; subclasses implement _fit. It is a dataclass that acts like a JAX pytree node (via MyPyTreeNode).

CombineTransformer

Transformer that holds a list of Transformer instances. Like all transformers, it is a dataclass that acts like a JAX pytree node.

IdentityTransformer

Identity Transformer with a concrete _fit. Like all transformers, it is a dataclass that acts like a JAX pytree node.

CvTransTransformer

Transformer that holds a CvTrans (trans); its _fit remains abstract. Like all transformers, it is a dataclass that acts like a JAX pytree node.

Module Contents

class IMLCV.base.CVDiscovery.Transformer(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Abstract base class for CV-discovery transformers; subclasses implement _fit. It is a dataclass that acts like a JAX pytree node (via MyPyTreeNode).
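The return annotation of _fit suggests the contract a subclass has to satisfy: given per-trajectory inputs x, optional paired inputs x_t, and weights w and w_t, it returns the transformed x, the transformed x_t, the fitted CvTrans, and an optional list of arrays (plausibly updated weights). The sketch below mimics only the shape of that contract in plain Python and does not use IMLCV. Lists of floats stand in for CV, a plain function stands in for CvTrans, and SketchTransformer and MeanShiftTransformer are hypothetical names.

```python
from abc import ABC, abstractmethod
from typing import Callable, Optional

# Stand-ins (assumptions, not IMLCV types): one trajectory of CV values is a
# list of floats, and a fitted transformation is a function float -> float.
Trajectory = list[float]
Trans = Callable[[float], float]


class SketchTransformer(ABC):
    """Hypothetical stand-in mirroring the shape of Transformer._fit."""

    @abstractmethod
    def _fit(
        self,
        x: list[Trajectory],
        x_t: Optional[list[Trajectory]],
        w: list[list[float]],
        w_t: list[list[float]],
    ) -> tuple[
        list[Trajectory],
        Optional[list[Trajectory]],
        Trans,
        Optional[list[list[float]]],
    ]:
        """Return (transformed x, transformed x_t, transformation, new weights)."""


class MeanShiftTransformer(SketchTransformer):
    """Hypothetical subclass: subtracts the weighted mean of all samples in x."""

    def _fit(self, x, x_t, w, w_t):
        total_w = sum(sum(wi) for wi in w)
        mean = sum(v * wv for xi, wi in zip(x, w) for v, wv in zip(xi, wi)) / total_w

        def trans(v: float) -> float:
            return v - mean

        x_new = [[trans(v) for v in xi] for xi in x]
        x_t_new = None if x_t is None else [[trans(v) for v in xi] for xi in x_t]
        # In this sketch, None for the weights means "weights unchanged".
        return x_new, x_t_new, trans, None


x, x_t, trans, w_new = MeanShiftTransformer()._fit([[0.0, 10.0]], None, [[3.0, 1.0]], [[3.0, 1.0]])
print(x, trans(2.5))  # [[-2.5, 7.5]] 0.0
```

Because _fit is abstract, the base sketch cannot be instantiated, just as the real Transformer is meant to be subclassed rather than used directly.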

outdim: int
descriptor: IMLCV.base.CV.CvTrans | None = None
pre_scale: bool = True
post_scale: bool = True
pre_fit(dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, shmap=True, shmap_kwargs=ShmapKwargs.create(), verbose=False, macro_chunk=10000)
static static_fit(transformer: Transformer, **kwargs)
fit(dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, plot=True, plot_folder: str | pathlib.Path | None = None, shmap=True, percentile=5.0, jac=jax.jacrev, transform_FES=True, koopman=True, max_fes_bias: float | None = None, n_max=100000.0, samples_per_bin=20, min_samples_per_bin=3, verbose=True, cv_titles: list[str] | bool = True, vmax=100 * kjmol, macro_chunk=1000, shmap_kwargs=ShmapKwargs.create(), **kwargs) -> tuple[list[IMLCV.base.CV.CV], IMLCV.base.CV.CollectiveVariable, IMLCV.base.bias.Bias]
abstract _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]
static plot_app(collective_variables: list[IMLCV.base.CV.CollectiveVariable], cv_data: list[list[IMLCV.base.CV.CV]] | list[list[list[IMLCV.base.CV.CV]]] | None = None, biases: list[IMLCV.base.bias.Bias] | list[list[IMLCV.base.bias.Bias]] | None = None, indicate_plots: None | str | list[list[str | None]] = 'lightblue', duplicate_cv_data=True, name: str | pathlib.Path | None = None, labels=None, cv_titles: bool | list[str] = True, data_titles=None, color_trajectories=False, margin=0.1, plot_FES=False, T: float | None = None, vmin=0, vmax=100 * kjmol, dpi=300, n_max_bias=1000000.0, row_color=None, macro_chunk=10000, cmap='jet', offset=True, bar_label='FES [kJ/mol]')

Plot the app for CV discovery. All 1D and 2D plots are drawn directly; 3D and higher-dimensional plots are drawn as 2D slices.
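The docstring does not say how the 2D slices are chosen. One plausible scheme, shown here as an assumption and not as IMLCV's actual logic, is one panel per unordered pair of CV dimensions, which gives n(n-1)/2 panels for an n-dimensional CV. The helper panel_layout is hypothetical.

```python
from itertools import combinations


def panel_layout(ncv: int) -> list[tuple[int, ...]]:
    """Hypothetical: which CV dimensions each panel shows.

    1D and 2D CVs get a single panel over all their dimensions; for 3D and
    higher, one 2D panel per unordered pair of dimensions is assumed.
    """
    if ncv < 1:
        raise ValueError("a CV needs at least one dimension")
    if ncv <= 2:
        return [tuple(range(ncv))]
    return list(combinations(range(ncv), 2))


for n in (1, 2, 3, 4):
    print(n, panel_layout(n))
```

The number of panels grows quadratically with the CV dimension, which is one reason slice-based plots are usually reserved for low-dimensional CVs. Note that the class also defines _plot_3d, so 3D data may receive dedicated handling in practice.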

static _grid_spec_iterator(fig: matplotlib.figure.Figure, dims, ncv, ndata, skip, vmin, vmax, cmap, bar_label='FES [kJ/mol]', cv_titles=None, data_titles=None, indicate_plots=None, plot_FES=False)
static _plot_1d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple | None = None, margin=None, T=None, vmin=0, vmax=100 * kjmol, cmap=plt.get_cmap('jet'), **scatter_kwargs)
static _plot_2d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple | None = None, margin: float = 0.1, vmin=0, vmax=100 * kjmol, T=None, print_labels=False, cmap=plt.get_cmap('jet'), plot_y=True, **scatter_kwargs)
static _plot_3d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple[int, ...] | None = None, margin: float = 0.1, vmin=0, vmax=100 * kjmol, T=None, cmap=plt.get_cmap('jet'), **scatter_kwargs)
static _get_color_data(a: list[IMLCV.base.CV.CV], dim: int, color_trajectories=False, color_1d=True, metric: IMLCV.base.CV.CvMetric | None = None, max_val=None, min_val=None, margin=None) -> IMLCV.base.CV.CV
__mul__(other)
class IMLCV.base.CVDiscovery.CombineTransformer(*args, **kwargs)

Bases: Transformer

Transformer that holds a list of Transformer instances. Like all transformers, it is a dataclass that acts like a JAX pytree node.

transformers: list[Transformer]
static create(transformers: list[Transformer]) -> CombineTransformer
_fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]
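CombineTransformer stores a list of transformers, and Transformer defines __mul__. A natural reading, which this page does not confirm, is that a * b builds a CombineTransformer that fits each transformer in turn on the previous one's output and composes the fitted transformations. The sketch below illustrates that assumed behaviour with hypothetical stand-in classes (Step, Chain, AddConst, Scale) acting on plain lists of floats; it does not call IMLCV, and the left-to-right order is also an assumption.

```python
from typing import Callable

Trans = Callable[[float], float]


class Step:
    """Hypothetical stand-in for a Transformer that works on a list of floats."""

    def fit(self, x: list[float]) -> tuple[list[float], Trans]:
        raise NotImplementedError

    def __mul__(self, other: "Step") -> "Chain":
        # Assumption: a * b means "fit a first, then fit b on a's output".
        left = self.steps if isinstance(self, Chain) else [self]
        right = other.steps if isinstance(other, Chain) else [other]
        return Chain(left + right)


class Chain(Step):
    """Hypothetical analogue of CombineTransformer: fits its steps in order."""

    def __init__(self, steps: list[Step]):
        self.steps = steps

    def fit(self, x):
        fitted: list[Trans] = []
        for step in self.steps:
            x, t = step.fit(x)  # each step is fitted on the previous output
            fitted.append(t)

        def composed(v: float) -> float:
            for t in fitted:
                v = t(v)
            return v

        return x, composed


class AddConst(Step):
    """Shifts every value by a fixed constant c."""

    def __init__(self, c: float):
        self.c = c

    def fit(self, x):
        return [v + self.c for v in x], lambda v: v + self.c


class Scale(Step):
    """Divides by the largest absolute value seen during fitting."""

    def fit(self, x):
        s = max(abs(v) for v in x)
        return [v / s for v in x], lambda v: v / s


x, f = (AddConst(1.0) * Scale()).fit([1.0, 3.0])
print(x, f(7.0))  # [0.5, 1.0] 2.0
```

Flattening nested chains in __mul__ keeps a * b * c as one flat list of steps, mirroring how a single transformers list would be stored.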
class IMLCV.base.CVDiscovery.IdentityTransformer(*args, **kwargs)

Bases: Transformer

Identity Transformer with a concrete _fit. Like all transformers, it is a dataclass that acts like a JAX pytree node.

_fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000, **fit_kwargs) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV] | None, IMLCV.base.CV.CvTrans, list[jax.Array] | None]
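Unlike the base-class signature, this _fit allows None as the second return value, which is consistent with passing x_t through unchanged when it is None. The sketch below shows that assumed pass-through behaviour with a hypothetical function identity_fit; it is not the IMLCV implementation, and returning None for the weights is an assumption.

```python
def identity_fit(x, x_t, w, w_t, **fit_kwargs):
    """Hypothetical stand-in for IdentityTransformer._fit.

    Returns the inputs untouched, an identity transformation, and None for
    the weights (assumed here to mean "leave the weights as they are").
    Extra keyword arguments are accepted and ignored, mirroring **fit_kwargs.
    """

    def identity(v):
        return v

    return x, x_t, identity, None


x = [[1.0, 2.0]]
x_out, x_t_out, trans, w_out = identity_fit(x, None, [[1.0, 1.0]], [[1.0, 1.0]], some_option=True)
print(x_out is x, x_t_out, trans(42))  # True None 42
```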
class IMLCV.base.CVDiscovery.CvTransTransformer(*args, **kwargs)

Bases: Transformer

Transformer that holds a CvTrans (trans); its _fit remains abstract. Like all transformers, it is a dataclass that acts like a JAX pytree node.

trans: IMLCV.base.CV.CvTrans
abstract _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]