IMLCV.base.CVDiscovery
Classes
- Transformer: Base class for dataclasses that should act like a JAX pytree node.
- CombineTransformer: Base class for dataclasses that should act like a JAX pytree node.
- IdentityTransformer: Base class for dataclasses that should act like a JAX pytree node.
- CvTransTransformer: Base class for dataclasses that should act like a JAX pytree node.
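To illustrate what "acting like a JAX pytree node" means, here is a minimal sketch using `jax.tree_util.register_pytree_node_class`. It is not IMLCV's `MyPyTreeNode` implementation; the class `ScaledShift`, its fields and the `apply` function are hypothetical. Array fields are exposed as leaves, while static metadata (here `name`) travels in the auxiliary data, so JAX transformations such as `jax.jit` and `jax.tree_util.tree_map` can traverse the object.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class ScaledShift:
    """Hypothetical dataclass that behaves like a JAX pytree node."""

    scale: jax.Array  # leaf: traced by jit/grad
    shift: jax.Array  # leaf: traced by jit/grad
    name: str = "scaled_shift"  # static metadata, not traced

    def tree_flatten(self):
        # Children are the array leaves; aux_data holds the static fields.
        return (self.scale, self.shift), (self.name,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        scale, shift = children
        (name,) = aux_data
        return cls(scale=scale, shift=shift, name=name)

    def __call__(self, x):
        return self.scale * x + self.shift


@jax.jit
def apply(t: ScaledShift, x: jax.Array) -> jax.Array:
    # `t` can be passed straight into a jitted function because it is a pytree.
    return t(x)


t = ScaledShift(scale=jnp.array(2.0), shift=jnp.array(1.0))
print(apply(t, jnp.arange(3.0)))  # [1. 3. 5.]

# tree_map rebuilds the dataclass with transformed leaves; `name` is kept.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, t)
print(float(doubled.scale), float(doubled.shift), doubled.name)  # 4.0 2.0 scaled_shift
```

Keeping non-array configuration in the auxiliary data means it is treated as part of the compiled function's static signature rather than as traced input.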
Module Contents
- class IMLCV.base.CVDiscovery.Transformer(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- descriptor: IMLCV.base.CV.CvTrans | None = None
- pre_fit(dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, shmap=True, shmap_kwargs=ShmapKwargs.create(), verbose=False, macro_chunk=10000)
- static static_fit(transformer: Transformer, **kwargs)
- fit(dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, plot=True, plot_folder: str | pathlib.Path | None = None, shmap=True, percentile=5.0, jac=jax.jacrev, transform_FES=True, koopman=True, max_fes_bias: float | None = None, n_max=100000.0, samples_per_bin=20, min_samples_per_bin=3, verbose=True, cv_titles: list[str] | bool = True, vmax=100 * kjmol, macro_chunk=1000, shmap_kwargs=ShmapKwargs.create(), **kwargs) -> tuple[list[IMLCV.base.CV.CV], IMLCV.base.CV.CollectiveVariable, IMLCV.base.bias.Bias]
- abstract _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]
- static plot_app(collective_variables: list[IMLCV.base.CV.CollectiveVariable], cv_data: list[list[IMLCV.base.CV.CV]] | list[list[list[IMLCV.base.CV.CV]]] | None = None, biases: list[IMLCV.base.bias.Bias] | list[list[IMLCV.base.bias.Bias]] | None = None, indicate_plots: None | str | list[list[str | None]] = 'lightblue', duplicate_cv_data=True, name: str | pathlib.Path | None = None, labels=None, cv_titles: bool | list[str] = True, data_titles=None, color_trajectories=False, margin=0.1, plot_FES=False, T: float | None = None, vmin=0, vmax=100 * kjmol, dpi=300, n_max_bias=1000000.0, row_color=None, macro_chunk=10000, cmap='jet', offset=True, bar_label='FES [kJ/mol]')
Plot the app for CV discovery. All 1D and 2D CVs are plotted directly; CVs with three or more dimensions are plotted as 2D slices.
- static _grid_spec_iterator(fig: matplotlib.figure.Figure, dims, ncv, ndata, skip, vmin, vmax, cmap, bar_label='FES [kJ/mol]', cv_titles=None, data_titles=None, indicate_plots=None, plot_FES=False)
- static _plot_1d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple | None = None, margin=None, T=None, vmin=0, vmax=100 * kjmol, cmap=plt.get_cmap('jet'), **scatter_kwargs)
- static _plot_2d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple | None = None, margin: float = 0.1, vmin=0, vmax=100 * kjmol, T=None, print_labels=False, cmap=plt.get_cmap('jet'), plot_y=True, **scatter_kwargs)
- static _plot_3d(fig: matplotlib.figure.Figure, grid: matplotlib.gridspec.GridSpec, data, colors, labels, collective_variable: IMLCV.base.CV.CollectiveVariable, fesses: dict[int, dict[tuple, IMLCV.base.bias.Bias]] | None = None, indices: tuple[int, ...] | None = None, margin: float = 0.1, vmin=0, vmax=100 * kjmol, T=None, cmap=plt.get_cmap('jet'), **scatter_kwargs)
- static _get_color_data(a: list[IMLCV.base.CV.CV], dim: int, color_trajectories=False, color_1d=True, metric: IMLCV.base.CV.CvMetric | None = None, max_val=None, min_val=None, margin=None) -> IMLCV.base.CV.CV
- __mul__(other)
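The `plot_app` docstring states that the panel layout depends on the CV dimension. A minimal sketch of that rule follows, under the assumption that "2D slices" means one panel per pair of CV dimensions, with the data projected onto that pair. The helpers `panel_indices` and `project` are hypothetical; IMLCV may choose or order its slices differently (the class also provides `_plot_3d`).

```python
from itertools import combinations


def panel_indices(ncv: int) -> list[tuple[int, ...]]:
    """Hypothetical helper: which CV dimensions each panel shows.

    1D and 2D CVs get a single panel over all their dimensions; for three or
    more dimensions, every pair of dimensions gets its own 2D panel.
    """
    if ncv < 1:
        raise ValueError("a collective variable needs at least one dimension")
    if ncv <= 2:
        return [tuple(range(ncv))]
    return list(combinations(range(ncv), 2))


def project(points: list[tuple[float, ...]], dims: tuple[int, ...]) -> list[tuple[float, ...]]:
    """Keep only the coordinates listed in `dims` for each point."""
    return [tuple(p[d] for d in dims) for p in points]


for n in (1, 2, 3):
    print(n, panel_indices(n))
# 1 [(0,)]
# 2 [(0, 1)]
# 3 [(0, 1), (0, 2), (1, 2)]

print(project([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], (0, 2)))  # [(1.0, 3.0), (4.0, 6.0)]
```

With this rule an n-dimensional CV (n ≥ 3) needs n(n-1)/2 panels, so the figure grows quadratically with the CV dimension.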
- class IMLCV.base.CVDiscovery.CombineTransformer(*args, **kwargs)
Bases: Transformer
Base class for dataclasses that should act like a JAX pytree node.
- transformers: list[Transformer]
- static create(transformers: list[Transformer]) -> CombineTransformer
- _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]
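The relationship between `fit`, the abstract `_fit`, `CombineTransformer.create` and `Transformer.__mul__` can be sketched in plain Python. This is a hypothetical, heavily simplified model, not IMLCV's implementation: CVs are plain lists of floats, a `CvTrans` is a plain function, the weights (`w`, `w_t`), the time-lagged data (`x_t`) and `dlo` are omitted, and the stage classes `Center` and `Scale` are invented for illustration. It also assumes that `a * b` builds a `CombineTransformer` that fits `a` first and then fits `b` on the output of `a`; the reference above does not document the semantics of `__mul__`.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

# Hypothetical stand-ins: a "CV" is a list of floats and a "trans" is a plain
# function. These are NOT the IMLCV types (CV, CvTrans, DataLoaderOutput).
Trans = Callable[[list[float]], list[float]]


def compose(first: Trans, second: Trans) -> Trans:
    """Return a function that applies `first`, then `second`."""
    return lambda x: second(first(x))


class Transformer:
    def fit(self, x: list[float]) -> tuple[list[float], Trans]:
        # Public entry point; subclasses only implement _fit.
        return self._fit(x)

    def _fit(self, x: list[float]) -> tuple[list[float], Trans]:
        raise NotImplementedError

    def __mul__(self, other: Transformer) -> CombineTransformer:
        # Assumption: `a * b` fits `a` first, then `b` on a's output.
        return CombineTransformer.create([self, other])


@dataclass
class CombineTransformer(Transformer):
    transformers: list[Transformer]

    @staticmethod
    def create(transformers: list[Transformer]) -> CombineTransformer:
        return CombineTransformer(transformers=list(transformers))

    def _fit(self, x):
        trans: Trans = lambda v: v  # start from the identity map
        for t in self.transformers:
            # Each stage is fitted on the output of the previous stage.
            x, t_trans = t._fit(x)
            trans = compose(trans, t_trans)
        return x, trans


class IdentityTransformer(Transformer):
    def _fit(self, x):
        return x, (lambda v: v)


class Center(Transformer):
    """Hypothetical stage: subtract the mean learned from the data."""

    def _fit(self, x):
        mean = sum(x) / len(x)
        trans = lambda v: [vi - mean for vi in v]
        return trans(x), trans


class Scale(Transformer):
    """Hypothetical stage: divide by the largest absolute value in the data."""

    def _fit(self, x):
        s = max(abs(xi) for xi in x) or 1.0
        trans = lambda v: [vi / s for vi in v]
        return trans(x), trans


pipeline = IdentityTransformer() * Center() * Scale()
x_out, trans = pipeline.fit([1.0, 2.0, 3.0])
print(x_out)         # [-1.0, 0.0, 1.0]
print(trans([4.0]))  # [2.0]
```

Fitting each stage on the previous stage's output, and composing the learned maps in the same order, keeps the returned transformation consistent with the transformed training data.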
- class IMLCV.base.CVDiscovery.IdentityTransformer(*args, **kwargs)
Bases: Transformer
Base class for dataclasses that should act like a JAX pytree node.
- _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000, **fit_kwargs) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV] | None, IMLCV.base.CV.CvTrans, list[jax.Array] | None]
- class IMLCV.base.CVDiscovery.CvTransTransformer(*args, **kwargs)
Bases: Transformer
Base class for dataclasses that should act like a JAX pytree node.
- trans: IMLCV.base.CV.CvTrans
- abstract _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size: int | None = None, verbose=True, macro_chunk=1000) -> tuple[list[IMLCV.base.CV.CV], list[IMLCV.base.CV.CV], IMLCV.base.CV.CvTrans, list[jax.Array] | None]