IMLCV.base.CV
Attributes
- T
- S
- P
- P2
- X
- X2
Classes
- ShmapKwargs: Base class for dataclasses that should act like a JAX pytree node.
- SystemParams: Base class for dataclasses that should act like a JAX pytree node.
- NeighbourListInfo: Base class for dataclasses that should act like a JAX pytree node.
- NeighbourListUpdate: Base class for dataclasses that should act like a JAX pytree node.
- NeighbourList: Base class for dataclasses that should act like a JAX pytree node.
- CV: Base class for dataclasses that should act like a JAX pytree node.
- CvMetric: Class to keep track of the topology of a given CV. Identifies the periodicity of each CV and maps the CVs to the unit square with the correct periodicities.
- CvFunBase: Helper class that provides a standard way to create an ABC using inheritance.
- CvFun: Helper class that provides a standard way to create an ABC using inheritance.
- _SerialCvTrans: f can either be a single CV transformation or a list of transformations.
- _ParralelCvTrans: Base class for dataclasses that should act like a JAX pytree node.
- CvTrans: Base class for dataclasses that should act like a JAX pytree node.
- CollectiveVariable: Base class for dataclasses that should act like a JAX pytree node.
Functions
- _n_pad
- _n_unpad
- _shard
- _shard_out
- padded_shard_map
- padded_vmap
- macro_chunk_map_fun
- macro_chunk_map
- _macro_chunk_map
Module Contents
- IMLCV.base.CV.T
- IMLCV.base.CV.S
- IMLCV.base.CV.P
- IMLCV.base.CV.P2
- class IMLCV.base.CV.ShmapKwargs(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- mesh: jax.sharding.Mesh | None
- IMLCV.base.CV._n_pad(x, axis, p, chunk_size, reshape=False, move_axis=True, n_chunk_move=False)
- IMLCV.base.CV._n_unpad(x, axis, shape, reshape=False, move_axis=True, trim=True, n_chunk_move=False)
- IMLCV.base.CV._shard(x_padded, axis, axis_name, mesh, put=True, unflatten=True)
- IMLCV.base.CV._shard_out(out_tree_def, axis, axis_name, mesh)
- IMLCV.base.CV.padded_shard_map(f: Callable[P, T], kwargs: ShmapKwargs = ShmapKwargs.create())
- IMLCV.base.CV.P
- IMLCV.base.CV.T
- IMLCV.base.CV.padded_vmap(f: Callable[P, T], chunk_size=None, axis=0, out_axes: int = 0, vmap=True, verbose=False)
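The names `padded_vmap`, `_n_pad` and `_n_unpad` suggest a common trick for mapping over a batch in fixed-size chunks. The batch axis is padded up to a multiple of `chunk_size`, each chunk is processed, and the padding is trimmed from the result. Splitting a batch over devices in `padded_shard_map` presumably needs the same padding. The plain-Python sketch below illustrates only that idea. `pad_to_multiple`, `chunked_map` and `unpad` are hypothetical names, and IMLCV's JAX implementation (axis handling, reshaping, sharding) is not reproduced.

```python
from typing import Callable, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def pad_to_multiple(xs: list[A], chunk_size: int) -> tuple[list[A], int]:
    """Pad xs by repeating its last element until len(xs) % chunk_size == 0.

    Returns the padded list and the number of padding entries added.
    Repeating a real element (rather than inserting zeros) keeps the padded
    inputs valid for functions that would misbehave on arbitrary values.
    """
    if not xs:
        return [], 0
    n_pad = (-len(xs)) % chunk_size
    return xs + [xs[-1]] * n_pad, n_pad


def unpad(ys: list[B], n_pad: int) -> list[B]:
    """Remove the trailing padding entries again."""
    return ys[: len(ys) - n_pad]


def chunked_map(f: Callable[[A], B], xs: list[A], chunk_size: int) -> list[B]:
    """Apply f element-wise while processing the input in equal-sized chunks.

    In JAX, equal chunk sizes mean a jitted, vmapped function is compiled for
    a single input shape only; the inner loop stands in for that vmapped call.
    """
    padded, n_pad = pad_to_multiple(xs, chunk_size)
    out: list[B] = []
    for start in range(0, len(padded), chunk_size):
        chunk = padded[start : start + chunk_size]
        out.extend(f(x) for x in chunk)
    return unpad(out, n_pad)


print(pad_to_multiple([1, 2, 3, 4, 5], chunk_size=4))  # ([1, 2, 3, 4, 5, 5, 5, 5], 3)
print(chunked_map(lambda x: x * x, [1, 2, 3, 4, 5], chunk_size=4))  # [1, 4, 9, 16, 25]
```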
- IMLCV.base.CV.T
- IMLCV.base.CV.X
- IMLCV.base.CV.X2
- IMLCV.base.CV.macro_chunk_map_fun(f: Callable[[X, NeighbourList | None], X2], y: list[X], w: list[jax.Array] | None = None, nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, chunk_func: Callable[[T, X2, X2 | None, jax.Array | None, jax.Array | None], T] | None = None, chunk_func_init_args: T = None, w_t: list[jax.Array] | None = None, print_every=10, jit_f=True) -> T
- IMLCV.base.CV.macro_chunk_map(f: Callable[[X, NeighbourList | None], X2], y: list[X], nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, print_every=10, jit_f=True) -> list[X2] | tuple[list[X2], list[X2] | None]
- IMLCV.base.CV._macro_chunk_map(f: Callable[[X, NeighbourList | None], X2], y: list[X], w: list[jax.Array] | None = None, nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, chunk_func: Callable[[T, X2, X2 | None, jax.Array | None, jax.Array | None], T] | None = None, chunk_func_init_args: T = None, w_t: list[jax.Array] | None = None, print_every=10, jit_f=True)
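The signature of `macro_chunk_map_fun` points to a map-reduce pattern. `f` is applied to blocks of at most `macro_chunk` samples. `chunk_func` then folds each block's output, plus optional weights, into an accumulator that starts at `chunk_func_init_args`. The sketch below reproduces only that pattern in plain Python, assuming this reading of the signature. It ignores neighbour lists and the second data set (`ft`, `y_t`, `w_t`), and `macro_chunk_fold`, `square_all` and `accumulate` are hypothetical names.

```python
from typing import Callable, Optional, TypeVar

X = TypeVar("X")
X2 = TypeVar("X2")
T = TypeVar("T")


def macro_chunk_fold(
    f: Callable[[list[X]], list[X2]],
    y: list[list[X]],
    chunk_func: Callable[[T, list[X2], Optional[list[float]]], T],
    init: T,
    w: Optional[list[list[float]]] = None,
    macro_chunk: int = 1000,
) -> T:
    """Fold chunk_func over f applied to blocks of at most macro_chunk samples.

    y is a list of trajectories (one list of samples each) and w, if given,
    holds one weight per sample. For simplicity, blocks never straddle two
    trajectories in this sketch.
    """
    state = init
    for i, traj in enumerate(y):
        for start in range(0, len(traj), macro_chunk):
            block = traj[start : start + macro_chunk]
            w_block = None if w is None else w[i][start : start + macro_chunk]
            state = chunk_func(state, f(block), w_block)
    return state


def square_all(block: list[float]) -> list[float]:
    return [v * v for v in block]


def accumulate(
    state: tuple[float, float], out: list[float], w_block: Optional[list[float]]
) -> tuple[float, float]:
    """Running (weighted sum, total weight); unweighted samples count as 1."""
    total, norm = state
    weights = w_block if w_block is not None else [1.0] * len(out)
    return total + sum(o * wi for o, wi in zip(out, weights)), norm + sum(weights)


total, norm = macro_chunk_fold(
    square_all, [[1.0, 2.0, 3.0], [4.0]], accumulate, (0.0, 0.0), macro_chunk=2
)
print(total / norm)  # 7.5, the mean of 1, 4, 9 and 16
```

Only the accumulator is kept between blocks, so peak memory is bounded by `macro_chunk` rather than by the total number of samples.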
- class IMLCV.base.CV.SystemParams(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- __getitem__(slices) -> SystemParams
- __iter__()
- property batched
- property batch_dim
- property shape
- __add__(other)
- static stack(*sps: SystemParams) -> SystemParams
- batch() -> SystemParams
- unbatch() -> SystemParams
- _get_neighbour_list(info: NeighbourListInfo, update: NeighbourListUpdate | None = None, chunk_size: int | None = 100, verbose=False, chunk_size_inner: int | None = 100, shmap=False, shmap_kwargs: ShmapKwargs = ShmapKwargs.create(), only_update=False)
- get_neighbour_list(info: NeighbourListInfo, chunk_size: int | None = None, chunk_size_inner=100, verbose=False, shmap=False, shmap_kwargs: ShmapKwargs = ShmapKwargs.create(), only_update=False) -> NeighbourList | None
- minkowski_reduce() -> tuple[SystemParams, jax.Array]
Based on code from ASE: https://wiki.fysik.dtu.dk/ase/_modules/ase/geometry/minkowski_reduction.html
- apply_minkowski_reduction(op)
- apply_rotation(op)
- to_relative() -> tuple[SystemParams, jax.Array | None]
- to_absolute() -> SystemParams
- wrap_positions(min=False) -> tuple[SystemParams, jax.Array]
Wrap the positions so that they lie within the unit cell.
- apply_wrap(wrap_op: jax.Array) -> SystemParams
- canonicalize(min=False, qr=False, chunk_size=None) -> tuple[SystemParams, tuple[jax.Array, jax.Array, tuple[jax.Array, jax.Array] | None]]
- apply_canonicalize(ops)
- min_distance(index_1, index_2)
- super_cell(n: int | list[int], info: NeighbourListInfo | None = None) -> tuple[SystemParams, NeighbourListInfo | None]
- volume()
- to_ase(static_trajectory_info: IMLCV.base.MdEngine.StaticMdInfo)
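Several `SystemParams` methods (`to_relative`, `to_absolute`, `wrap_positions`, `volume`) rest on the relation between Cartesian positions and the cell. This sketch assumes the cell is a 3x3 matrix whose rows are the lattice vectors (the ASE convention; whether IMLCV stores it the same way is an assumption). Under that assumption:

- fractional coordinates `s` satisfy `r = s @ cell`;
- wrapping into the cell subtracts `floor(s)`;
- the volume is `|det(cell)|`.

The stdlib-only code below illustrates these relations. All function names are hypothetical, and returning the integer shifts from `wrap_positions` is only a guess at what the array returned next to the new `SystemParams` might hold.

```python
import math

Vec3 = list[float]
Mat3 = list[list[float]]  # rows are the lattice vectors a, b, c (assumed convention)


def det3(m: Mat3) -> float:
    """Determinant of a 3x3 matrix by cofactor expansion along the first row."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


def inv3(m: Mat3) -> Mat3:
    """Inverse of a 3x3 matrix: transpose of the cofactor matrix divided by det."""
    d = det3(m)
    if d == 0:
        raise ValueError("singular cell")
    cof = [[0.0] * 3 for _ in range(3)]
    for i in range(3):
        for j in range(3):
            minor = [[v for c, v in enumerate(row) if c != j] for r, row in enumerate(m) if r != i]
            cof[i][j] = (-1) ** (i + j) * (minor[0][0] * minor[1][1] - minor[0][1] * minor[1][0])
    return [[cof[j][i] / d for j in range(3)] for i in range(3)]


def row_times(v: Vec3, m: Mat3) -> Vec3:
    """Row vector times matrix: (v @ m)[j] = sum_k v[k] * m[k][j]."""
    return [sum(v[k] * m[k][j] for k in range(3)) for j in range(3)]


def to_relative(pos: list[Vec3], cell: Mat3) -> list[Vec3]:
    """Cartesian -> fractional coordinates: s = r @ inv(cell)."""
    inv = inv3(cell)
    return [row_times(r, inv) for r in pos]


def to_absolute(frac: list[Vec3], cell: Mat3) -> list[Vec3]:
    """Fractional -> Cartesian coordinates: r = s @ cell."""
    return [row_times(s, cell) for s in frac]


def wrap_positions(pos: list[Vec3], cell: Mat3) -> tuple[list[Vec3], list[list[int]]]:
    """Map every atom into the cell and return the integer lattice shifts used."""
    frac = to_relative(pos, cell)
    shifts = [[math.floor(c) for c in s] for s in frac]
    wrapped = [[c - n for c, n in zip(s, ns)] for s, ns in zip(frac, shifts)]
    return to_absolute(wrapped, cell), shifts


def volume(cell: Mat3) -> float:
    """Cell volume |det(cell)|."""
    return abs(det3(cell))


cell = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [1.0, 0.0, 4.0]]  # a non-orthogonal cell
wrapped, shifts = wrap_positions([[5.0, -1.0, 2.0]], cell)
print(wrapped, shifts)  # [[1.0, 1.0, 2.0]] [[2, -1, 0]]
print(volume(cell))  # 16.0
```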
- class IMLCV.base.CV.NeighbourListInfo(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- static create(r_cut, z_array, r_skin=None)
- __getstate__()
- __setstate__(state)
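`NeighbourListInfo.create` accepts an optional `r_skin` next to `r_cut`. In the common Verlet-list scheme, neighbours are collected within `r_cut + r_skin`. The list stays valid until some atom has moved more than `r_skin / 2` since the list was built. The sketch below shows that textbook criterion for a non-periodic system. Whether `NeighbourList.needs_update` uses exactly this rule is an assumption, and `build_verlet_list` and `needs_rebuild` are hypothetical names.

```python
import math

Vec3 = tuple[float, float, float]


def build_verlet_list(pos: list[Vec3], r_cut: float, r_skin: float) -> list[list[int]]:
    """For each atom i, collect all j != i closer than r_cut + r_skin (no periodicity)."""
    r_list = r_cut + r_skin
    return [
        [j for j, pj in enumerate(pos) if j != i and math.dist(pi, pj) < r_list]
        for i, pi in enumerate(pos)
    ]


def needs_rebuild(pos_ref: list[Vec3], pos_now: list[Vec3], r_skin: float) -> bool:
    """True once any atom has moved more than r_skin / 2 since the list was built.

    A pair that started farther apart than r_cut + r_skin can only come closer
    than r_cut if the two displacements add up to more than r_skin, which
    requires at least one atom to move more than r_skin / 2.
    """
    max_disp = max(math.dist(a, b) for a, b in zip(pos_ref, pos_now))
    return max_disp > r_skin / 2


pos0 = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.0, 0.0, 0.0)]
print(build_verlet_list(pos0, r_cut=1.5, r_skin=0.5))  # [[1], [0], []]
```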
- class IMLCV.base.CV.NeighbourListUpdate(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- __getstate__()
- __setstate__(state)
- class IMLCV.base.CV.NeighbourList(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- sp_orig: SystemParams | None
- info: NeighbourListInfo
- update: NeighbourListUpdate
- static create(r_cut, sp_orig, atom_indices=None, r_skin=None, ijk_indices=None, z_array=None, nxyz=None, op_cell=None, op_coor=None, op_center=None, stack_dims=None, num_neighs=None)
- nneighs(sp=None)
- property needs_calculation
- property padding_bools
- canonicalized_sp(sp: SystemParams) -> SystemParams
- neighbour_pos(sp_orig: SystemParams) -> jax.Array
- apply_fun_neighbour(sp: SystemParams, func: Callable[[jax.Array, jax.Array], T], r_cut: float | None = None, fill_value=0, reduce='full', exclude_self: bool = False, chunk_size_neigbourgs: int | None = None, chunk_size_atoms: int | None = None, chunk_size_batch: int | None = None, shmap=False, split_z=False, shmap_kwargs=ShmapKwargs.create())
- apply_fun_neighbour_pair(sp: SystemParams, func_double: Callable[[jax.Array, jax.Array, T, jax.Array, jax.Array, T], S], func_single: Callable[[jax.Array, jax.Array], T] = lambda x, y: ..., r_cut=None, fill_value=0.0, reduce='full', split_z=False, exclude_self=True, unique=True, chunk_size_neigbourgs=None, chunk_size_atoms=None, chunk_size_batch=None, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[jax.Array, S]
Args:
func_single(r_ij, atom_index_j) = ..
func_double(p_ij, atom_index_j, data_j, p_ik, atom_index_k, data_k) = …
- property batched
- property batch_dim
- property shape
- __getitem__(slices)
- needs_update(sp: SystemParams) -> bool
- update_nl(sp: SystemParams, chunk_size: int | None = None, chunk_size_inner: int | None = 10, shmap=False, shmap_kwargs=ShmapKwargs.create(), verbose=False)
- nl_split_z(p)
- batch()
- __add__(other)
- static stack(*nls: NeighbourList) -> NeighbourList
- unstack() -> list[NeighbourList]
- __getstate__()
- property stack_dims
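`apply_fun_neighbour` evaluates a user function `func(r_ij, atom_index_j)` on the neighbours of every atom and reduces the results. `reduce='full'` is read here as a sum over neighbours, which is an assumption. The brute-force sketch below shows that pattern for a cubic periodic box using the minimum-image convention. It computes a coordination number per atom. `apply_fun_neighbour_sketch` and `min_image` are hypothetical. The real method works on padded neighbour arrays, supports general cells and chunking, and may orient `r_ij` differently.

```python
import math
from typing import Callable

Vec3 = tuple[float, float, float]


def min_image(d: float, box: float) -> float:
    """Shortest periodic image of a 1D separation in a box of length `box`."""
    return d - box * round(d / box)


def apply_fun_neighbour_sketch(
    pos: list[Vec3],
    box: float,
    func: Callable[[Vec3, int], float],
    r_cut: float,
    exclude_self: bool = True,
) -> list[float]:
    """Per-atom sum of func(r_ij, j) over all neighbours j within r_cut.

    Assumes a cubic periodic box with r_cut < box / 2, so every pair is seen
    once through its minimum image. Here r_ij points from atom i to atom j.
    """
    out = []
    for i, pi in enumerate(pos):
        total = 0.0
        for j, pj in enumerate(pos):
            if exclude_self and i == j:
                continue
            r_ij = tuple(min_image(b - a, box) for a, b in zip(pi, pj))
            if math.hypot(*r_ij) < r_cut:
                total += func(r_ij, j)
        out.append(total)
    return out


pos = [(0.5, 0.5, 0.5), (1.5, 0.5, 0.5), (9.7, 0.5, 0.5)]
# Coordination numbers: atom 0 sees atom 1 directly and atom 2 through the boundary.
print(apply_fun_neighbour_sketch(pos, 10.0, lambda r, j: 1.0, r_cut=1.2))  # [2.0, 1.0, 1.0]
```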
- class IMLCV.base.CV.CV(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- property batched
- property batch_dim
- property dim
- property size
- property shape
- property combine_dims
- property stack_dims
- __iter__()
- __getitem__(idx)
- static stack(*cvs: CV) -> CV
Stacks a list of CVs into a single CV. The CVs are stacked over the batch dimension, and their dimensions are stored so that the result can later be unstacked into the separate CVs.
- static combine(*cvs: CV, flatten=False) -> CV
Merges a list of CVs into a single CV. The CVs are combined over the last dimension, and their dimensions are stored so that the result can later be split into the separate CVs.
- __getstate__()
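The docstrings of `CV.stack` and `CV.combine` describe two kinds of bookkeeping:

- `stack` concatenates CVs along the batch dimension;
- `combine` concatenates them along the last (component) dimension.

In both cases the original dimensions are recorded so the operation can be undone. This sketch assumes that `stack_dims` holds the batch size of each part and `combine_dims` the width of each part. It mimics that bookkeeping on nested lists. `ToyCV`, `stack`, `unstack`, `combine` and `split` are hypothetical and do not reflect IMLCV's actual data layout.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyCV:
    """Hypothetical stand-in for CV: one row per sample, one column per component."""

    data: list[list[float]]
    stack_dims: Optional[tuple[int, ...]] = None    # batch size of each stacked part
    combine_dims: Optional[tuple[int, ...]] = None  # width of each combined part


def stack(*cvs: ToyCV) -> ToyCV:
    """Concatenate along the batch (row) axis and record each part's batch size."""
    rows = [row for cv in cvs for row in cv.data]
    return ToyCV(rows, stack_dims=tuple(len(cv.data) for cv in cvs))


def unstack(cv: ToyCV) -> list[ToyCV]:
    """Invert stack() using the recorded batch sizes."""
    parts, start = [], 0
    for n in cv.stack_dims or (len(cv.data),):
        parts.append(ToyCV(cv.data[start : start + n]))
        start += n
    return parts


def combine(*cvs: ToyCV) -> ToyCV:
    """Concatenate along the last (component) axis and record each part's width."""
    if len({len(cv.data) for cv in cvs}) != 1:
        raise ValueError("combine() needs CVs with the same number of samples")
    n_samples = len(cvs[0].data)
    rows = [[v for cv in cvs for v in cv.data[i]] for i in range(n_samples)]
    widths = tuple(len(cv.data[0]) if n_samples else 0 for cv in cvs)
    return ToyCV(rows, combine_dims=widths)


def split(cv: ToyCV) -> list[ToyCV]:
    """Invert combine() using the recorded widths."""
    parts, start = [], 0
    for w in cv.combine_dims or (len(cv.data[0]),):
        parts.append(ToyCV([row[start : start + w] for row in cv.data]))
        start += w
    return parts


a = ToyCV([[1.0, 2.0], [3.0, 4.0]])
b = ToyCV([[5.0], [6.0]])
c = combine(a, b)
print(c.data, c.combine_dims)  # [[1.0, 2.0, 5.0], [3.0, 4.0, 6.0]] (2, 1)
s = stack(a, ToyCV([[7.0, 8.0]]))
print(s.data, s.stack_dims)  # [[1.0, 2.0], [3.0, 4.0], [7.0, 8.0]] (2, 1)
```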
- class IMLCV.base.CV.CvMetric(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Class to keep track of the topology of a given CV. Identifies the periodicity of each CV and maps the CVs to the unit square with the correct periodicities.
- __periodic_wrap(xs: jax.Array, min=False)
Translate CVs over the unit cell.
min=True calculates distances; min=False translates one vector inside the box.
- __add__(other)
- static get_n(samples_per_bin, samples, n_dims, max_bins=None, max_bins_per_dim=30)
- grid(n=30, bounds=None, margin=0.1, indexing='ij')
Forms a regular grid in mapped space. If a coordinate is periodic, the last rows are omitted.
- Parameters:
n – number of points in each dim
map – boolean. True: work in mapped space (default), False: calculate grid in space without metric
endpoints – if
- Returns:
meshgrid and vector with distances between points
- property ndim
- static bounds_from_cv(cv_0: list[CV], percentile=0.1, weights: list[jax.Array] | None = None, rho: list[jax.Array] | None = None, margin=None, chunk_size: int | None = None, n=400, macro_chunk: int | None = 5000, verbose=True)
- __getstate__()
- __getitem__(idx)
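`CvMetric` maps CVs to a unit box and treats periodic coordinates specially. This sketch assumes each coordinate has already been rescaled to [0, 1] using its bounds. Under that assumption, the two modes of `__periodic_wrap` correspond to two familiar operations:

- `min=False` wraps a point back into the box with `x - floor(x)`;
- `min=True` computes a minimum-image difference with `d - round(d)`.

Likewise, `grid` drops the last point along a periodic dimension because it coincides with the first one. The 1D sketch below shows both ideas. `periodic_wrap` and `grid_1d` are hypothetical helpers, not the class's methods.

```python
import math


def periodic_wrap(x: float, periodic: bool, min_image: bool = False) -> float:
    """Wrap one coordinate given in unit-box units (bounds rescaled to [0, 1]).

    min_image=False: translate the point into [0, 1).
    min_image=True:  treat x as a difference and return its shortest
                     periodic image, which lies in [-0.5, 0.5].
    Non-periodic coordinates are returned unchanged.
    """
    if not periodic:
        return x
    if min_image:
        return x - round(x)
    return x - math.floor(x)


def grid_1d(lo: float, hi: float, n: int, periodic: bool) -> tuple[list[float], float]:
    """n evenly spaced points from lo to hi (inclusive) and their spacing.

    For a periodic coordinate hi is the same point as lo, so the last point
    is dropped to avoid sampling it twice.
    """
    step = (hi - lo) / (n - 1)
    points = [lo + k * step for k in range(n)]
    return (points[:-1] if periodic else points), step


print(periodic_wrap(1.25, periodic=True))  # 0.25
print(periodic_wrap(0.75, periodic=True, min_image=True))  # -0.25
print(grid_1d(0.0, 1.0, 5, periodic=True))  # ([0.0, 0.25, 0.5, 0.75], 0.25)
```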
- class IMLCV.base.CV.CvFunBase(*args, **kwargs)
Bases: abc.ABC, IMLCV.base.datastructures.MyPyTreeNode
Helper class that provides a standard way to create an ABC using inheritance.
- compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, chunk_size=None, jacobian=False, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[CV, CV | None]
- abstract _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV
- __getstate__()
- class IMLCV.base.CV.CvFun(*args, **kwargs)
Bases: CvFunBase
Helper class that provides a standard way to create an ABC using inheritance.
- _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV
- class IMLCV.base.CV._SerialCvTrans(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
f can either be a single CV transformation or a list of transformations.
- trans: tuple[_SerialCvTrans | _ParralelCvTrans | CvFun, ...]
- _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV
- class IMLCV.base.CV._ParralelCvTrans(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- trans: tuple[_SerialCvTrans | _ParralelCvTrans, ...]
- _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> CV
- class IMLCV.base.CV.CvTrans(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- trans: _ParralelCvTrans | _SerialCvTrans
- static from_cv_function(f: Callable, static_argnames=None, check_input: bool = True, **kwargs) -> CvTrans
- compute_cv(x: X, nl: NeighbourList | None = None, chunk_size=None, jacobian=False, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[CV, X | None]
- __getstate__()
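`CvTrans` wraps either a `_SerialCvTrans` or a `_ParralelCvTrans`, and their `trans` fields hold nested transformations. One natural reading is an assumption here:

- a serial transformation feeds each step's output into the next step (function composition);
- a parallel transformation applies every branch to the same input and concatenates the results along the CV dimension.

The toy classes below illustrate that composition pattern on lists of floats. `Fn`, `Serial` and `Parallel` are hypothetical and ignore neighbour lists, Jacobians and the `reverse` flag.

```python
from dataclasses import dataclass
from typing import Callable, Union

Vector = list[float]


@dataclass(frozen=True)
class Fn:
    """A leaf transformation: maps one CV vector to another."""

    f: Callable[[Vector], Vector]

    def __call__(self, x: Vector) -> Vector:
        return self.f(x)


@dataclass(frozen=True)
class Serial:
    """Apply the steps one after another: steps[-1](... steps[0](x))."""

    steps: tuple["Trans", ...]

    def __call__(self, x: Vector) -> Vector:
        for step in self.steps:
            x = step(x)
        return x


@dataclass(frozen=True)
class Parallel:
    """Apply every branch to the same input and concatenate the outputs."""

    branches: tuple["Trans", ...]

    def __call__(self, x: Vector) -> Vector:
        out: Vector = []
        for branch in self.branches:
            out.extend(branch(x))
        return out


Trans = Union[Fn, Serial, Parallel]

center = Fn(lambda x: [v - 1.0 for v in x])         # shift every component
norm = Fn(lambda x: [sum(v * v for v in x) ** 0.5])  # Euclidean length
first = Fn(lambda x: x[:1])                          # keep the first component
cv = Serial((center, Parallel((norm, first))))
print(cv([4.0, 5.0]))  # [5.0, 3.0]
```

Because serial and parallel nodes can contain each other, arbitrary trees of transformations can be built from a few leaf functions.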
- class IMLCV.base.CV.CollectiveVariable(*args, **kwargs)
Bases: IMLCV.base.datastructures.MyPyTreeNode
Base class for dataclasses that should act like a JAX pytree node.
- jac: Callable
- compute_cv(sp: SystemParams, nl: NeighbourList | None = None, jacobian=False, chunk_size: int | None = None, shmap=False, push_jac=False, shmap_kwargs=ShmapKwargs.create()) -> tuple[CV, SystemParams | None]
- property n
- save(file)
- static load(file, **kwargs) -> CollectiveVariable
- __getstate__()
- __getitem__(tup)