IMLCV.base.CV

Attributes

T

S

P

P2

X

X2

Classes

ShmapKwargs

Base class for dataclasses that should act like a JAX pytree node.

SystemParams

Base class for dataclasses that should act like a JAX pytree node.

NeighbourListInfo

Base class for dataclasses that should act like a JAX pytree node.

NeighbourListUpdate

Base class for dataclasses that should act like a JAX pytree node.

NeighbourList

Base class for dataclasses that should act like a JAX pytree node.

CV

Base class for dataclasses that should act like a JAX pytree node.

CvMetric

Class to keep track of the topology of a given CV. It identifies the periodicities of the CVs and maps them to the unit square with the correct periodicities.

CvFunBase

Helper class that provides a standard way to create an ABC using inheritance.

CvFun

Helper class that provides a standard way to create an ABC using inheritance.

_SerialCvTrans

f can be either a single CV transformation or a list of transformations.

_ParralelCvTrans

Base class for dataclasses that should act like a JAX pytree node.

CvTrans

Base class for dataclasses that should act like a JAX pytree node.

CollectiveVariable

Base class for dataclasses that should act like a JAX pytree node.

Functions

_n_pad(x, axis, p, chunk_size[, reshape, move_axis, ...])

_n_unpad(x, axis, shape[, reshape, move_axis, trim, ...])

_shard(x_padded, axis, axis_name, mesh[, put, unflatten])

_shard_out(out_tree_def, axis, axis_name, mesh)

padded_shard_map(f[, kwargs])

padded_vmap(f[, chunk_size, axis, out_axes, vmap, verbose])

macro_chunk_map_fun(f, y[, w, nl, ft, y_t, nl_t, ...]) → T

macro_chunk_map(f, y[, nl, ft, y_t, nl_t, ...]) → list[X2] | tuple[list[X2], list[X2] | None]

_macro_chunk_map(f, y[, w, nl, ft, y_t, nl_t, ...])

Module Contents

IMLCV.base.CV.T
IMLCV.base.CV.S
IMLCV.base.CV.P
IMLCV.base.CV.P2
class IMLCV.base.CV.ShmapKwargs(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

axis: int = 0
out_axes: int = 0
axis_name: str | None
n_devices: int | None = None
pmap: bool = False
explicit_shmap: bool = False
verbose: bool = False
device_get: bool = True
devices: tuple[Any]
mesh: jax.sharding.Mesh | None
static create(axis: int = 0, out_axes: int = 0, axis_name: str | None = 'i', n_devices: int | None = None, pmap: bool = False, explicit_shmap: bool = False, verbose: bool = False, device_get: bool = True, devices=None, mesh=None)
IMLCV.base.CV._n_pad(x, axis, p, chunk_size, reshape=False, move_axis=True, n_chunk_move=False)
IMLCV.base.CV._n_unpad(x, axis, shape, reshape=False, move_axis=True, trim=True, n_chunk_move=False)
IMLCV.base.CV._shard(x_padded, axis, axis_name, mesh, put=True, unflatten=True)
IMLCV.base.CV._shard_out(out_tree_def, axis, axis_name, mesh)
IMLCV.base.CV.padded_shard_map(f: Callable[P, T], kwargs: ShmapKwargs = ShmapKwargs.create())
IMLCV.base.CV.padded_vmap(f: Callable[P, T], chunk_size=None, axis=0, out_axes: int = 0, vmap=True, verbose=False)
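`padded_vmap` and the helpers `_n_pad` / `_n_unpad` have no docstrings. Their names suggest that they evaluate a function over a batch in fixed-size chunks after padding the batch axis. The minimal sketch below shows that idea in plain JAX under this assumption. `chunked_vmap` is a hypothetical name, not part of IMLCV, and it only handles a single array input mapped over `axis=0`. Padding rows are zeros that get evaluated and then discarded, so `f` must not fail (for example, divide by zero) on zero input.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, chunk_size):
    """Hypothetical sketch: vmap f over axis 0, evaluated chunk by chunk.

    The leading axis is zero-padded to a multiple of chunk_size and reshaped
    to (n_chunks, chunk_size, ...). Each chunk is then vmapped, with
    jax.lax.map looping over the chunks, and the result is trimmed back to
    the original length.
    """
    vf = jax.vmap(f)

    def wrapped(x):
        n = x.shape[0]
        n_pad = (-n) % chunk_size  # rows needed to reach a multiple of chunk_size
        x_pad = jnp.pad(x, [(0, n_pad)] + [(0, 0)] * (x.ndim - 1))
        x_chunks = x_pad.reshape((-1, chunk_size) + x.shape[1:])
        out = jax.lax.map(vf, x_chunks)  # sequential loop over equally sized chunks
        out = out.reshape((-1,) + out.shape[2:])  # merge chunk and row axes
        return out[:n]  # drop results computed on the padding

    return wrapped


x = jnp.arange(10.0).reshape(5, 2)
y = chunked_vmap(lambda row: jnp.sum(row**2), chunk_size=2)(x)
```

Padding every chunk to the same shape means the per-chunk function is traced once. Without padding, a shorter final chunk would be traced and compiled separately.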
IMLCV.base.CV.X
IMLCV.base.CV.X2
IMLCV.base.CV.macro_chunk_map_fun(f: Callable[[X, NeighbourList | None], X2], y: list[X], w: list[jax.Array] | None = None, nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, chunk_func: Callable[[T, X2, X2 | None, jax.Array | None, jax.Array | None], T] | None = None, chunk_func_init_args: T = None, w_t: list[jax.Array] | None = None, print_every=10, jit_f=True) → T
IMLCV.base.CV.macro_chunk_map(f: Callable[[X, NeighbourList | None], X2], y: list[X], nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, print_every=10, jit_f=True) → list[X2] | tuple[list[X2], list[X2] | None]
IMLCV.base.CV._macro_chunk_map(f: Callable[[X, NeighbourList | None], X2], y: list[X], w: list[jax.Array] | None = None, nl: list[NeighbourList] | NeighbourList | None = None, ft: Callable[[X, NeighbourList | None], X2] | None = None, y_t: list[X] | None = None, nl_t: list[NeighbourList] | NeighbourList | None = None, macro_chunk: int | None = 1000, verbose=False, chunk_func: Callable[[T, X2, X2 | None, jax.Array | None, jax.Array | None], T] | None = None, chunk_func_init_args: T = None, w_t: list[jax.Array] | None = None, print_every=10, jit_f=True)
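The signatures of `macro_chunk_map` and `macro_chunk_map_fun` suggest the following behaviour: `f` is applied to a list of arrays `y` in blocks of at most `macro_chunk` samples, and the results come back as a list that matches `y`. The minimal sketch below illustrates that pattern under this assumption. It uses plain arrays and ignores neighbour lists, weights, the second data set (`ft`, `y_t`) and the `chunk_func` reduction. `macro_chunk_map_sketch` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp


def macro_chunk_map_sketch(f, y, macro_chunk=1000):
    """Hypothetical sketch: apply f to a list of arrays in blocks of macro_chunk rows.

    The arrays in y are concatenated along axis 0, and a jitted f is evaluated
    on successive slices. The outputs are then split back into a list with
    the same lengths as the inputs.
    """
    lengths = [yi.shape[0] for yi in y]
    y_all = jnp.concatenate(y, axis=0)
    f_jit = jax.jit(f)
    out = [f_jit(y_all[i : i + macro_chunk]) for i in range(0, y_all.shape[0], macro_chunk)]
    out_all = jnp.concatenate(out, axis=0)
    split_points = np.cumsum(lengths)[:-1].tolist()  # boundaries between inputs
    return jnp.split(out_all, split_points, axis=0)


trajs = [jnp.ones((3, 2)), 2 * jnp.ones((4, 2))]
res = macro_chunk_map_sketch(lambda x: x.sum(axis=1), trajs, macro_chunk=2)
```

The last block is usually shorter than `macro_chunk`, so `jax.jit` compiles `f` a second time for it. Padding that block would avoid the extra compilation.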
class IMLCV.base.CV.SystemParams(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

coordinates: jax.Array
cell: jax.Array | None
__getitem__(slices) → SystemParams
__iter__()
property batched
property batch_dim
property shape
__add__(other)
static stack(*sps: SystemParams) → SystemParams
batch() → SystemParams
unbatch() → SystemParams
angles(deg=True) → jax.Array
_get_neighbour_list(info: NeighbourListInfo, update: NeighbourListUpdate | None = None, chunk_size: int | None = 100, verbose=False, chunk_size_inner: int | None = 100, shmap=False, shmap_kwargs: ShmapKwargs = ShmapKwargs.create(), only_update=False)
get_neighbour_list(info: NeighbourListInfo, chunk_size: int | None = None, chunk_size_inner=100, verbose=False, shmap=False, shmap_kwargs: ShmapKwargs = ShmapKwargs.create(), only_update=False) → NeighbourList | None
minkowski_reduce() → tuple[SystemParams, jax.Array]

Based on code from ASE: https://wiki.fysik.dtu.dk/ase/_modules/ase/geometry/minkowski_reduction.html

apply_minkowski_reduction(op)
rotate_cell() → tuple[SystemParams, tuple[jax.Array, jax.Array] | None]
apply_rotation(op)
to_relative() → tuple[SystemParams, jax.Array | None]
to_absolute() → SystemParams
wrap_positions(min=False) → tuple[SystemParams, jax.Array]

Wrap the positions so that they lie within the unit cell.

apply_wrap(wrap_op: jax.Array) → SystemParams
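Methods such as `to_relative` and `wrap_positions` are usually implemented by wrapping positions into the unit cell through fractional coordinates. The sketch below shows that approach. It assumes that the rows of `cell` are the lattice vectors (the ASE convention) and that `coordinates` has shape `(n_atoms, 3)`. `wrap_positions_sketch` and its `centered` flag are hypothetical, and how the library's `min` argument maps onto them is not confirmed here.

```python
import jax.numpy as jnp


def wrap_positions_sketch(coordinates, cell, centered=False):
    """Hypothetical sketch: wrap Cartesian positions into the unit cell.

    Assumes coordinates = frac @ cell, i.e. the rows of cell are lattice vectors.
    centered=False maps fractional coordinates into [0, 1); centered=True maps
    them to the nearest-integer-subtracted range around zero.
    Returns the wrapped positions and the integer lattice shifts that were removed.
    """
    frac = coordinates @ jnp.linalg.inv(cell)  # Cartesian -> fractional
    shift = jnp.round(frac) if centered else jnp.floor(frac)
    wrapped = (frac - shift) @ cell  # fractional -> Cartesian
    return wrapped, shift


cell = jnp.diag(jnp.array([2.0, 2.0, 2.0]))
coords = jnp.array([[2.5, -0.5, 1.0], [4.1, 0.0, -3.9]])
wrapped, shift = wrap_positions_sketch(coords, cell)
```

Returning the integer shifts alongside the wrapped positions allows the same wrap to be reapplied or undone later. The `(SystemParams, jax.Array)` return type of `wrap_positions` together with the separate `apply_wrap(wrap_op)` method may follow the same idea.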
canonicalize(min=False, qr=False, chunk_size=None) → tuple[SystemParams, tuple[jax.Array, jax.Array, tuple[jax.Array, jax.Array] | None]]
apply_canonicalize(ops)
min_distance(index_1, index_2)
super_cell(n: int | list[int], info: NeighbourListInfo | None = None) → tuple[SystemParams, NeighbourListInfo | None]
volume()
to_ase(static_trajectory_info: IMLCV.base.MdEngine.StaticMdInfo)
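`volume()` and `angles(deg=True)` have no docstrings. The sketch below gives the standard definitions they most likely correspond to. It assumes the rows of `cell` are the lattice vectors a, b, c and that the angles follow the crystallographic convention alpha = ∠(b, c), beta = ∠(a, c), gamma = ∠(a, b). Both the convention and the helper names are assumptions.

```python
import jax.numpy as jnp


def cell_volume(cell):
    """Volume of the parallelepiped spanned by the rows (lattice vectors) of cell."""
    return jnp.abs(jnp.linalg.det(cell))


def cell_angles(cell, deg=True):
    """Cell angles (alpha, beta, gamma): angles between (b, c), (a, c) and (a, b)."""
    a, b, c = cell

    def angle(u, v):
        cos = jnp.dot(u, v) / (jnp.linalg.norm(u) * jnp.linalg.norm(v))
        return jnp.arccos(jnp.clip(cos, -1.0, 1.0))  # clip guards against rounding

    ang = jnp.stack([angle(b, c), angle(a, c), angle(a, b)])
    return jnp.degrees(ang) if deg else ang


# hexagonal cell: |a| = |b| = 1, |c| = 2, gamma = 120 degrees
hexagonal = jnp.array([[1.0, 0.0, 0.0], [-0.5, 3.0**0.5 / 2, 0.0], [0.0, 0.0, 2.0]])
volume = cell_volume(hexagonal)
angles = cell_angles(hexagonal)
```

For this cell the volume is |a × b| · |c| = sin(120°) · 2 = √3, and the angles are 90°, 90° and 120°.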
class IMLCV.base.CV.NeighbourListInfo(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

r_cut: float
r_skin: float
z_array: tuple[int] | None
z_unique: tuple[int] | None
num_z_unique: tuple[int] | None
static create(r_cut, z_array, r_skin=None)
nl_split_z(p: T) → tuple[jax.Array, list[jax.Array], list[T]]
__getstate__()
__setstate__(state)
class IMLCV.base.CV.NeighbourListUpdate(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

nxyz: tuple[int] | None
stack_dims: tuple[int] | None
num_neighs: int | None
static create(nxyz=None, stack_dims=None, num_neighs: int | None = None)
__getstate__()
__setstate__(state)
class IMLCV.base.CV.NeighbourList(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

sp_orig: SystemParams | None
info: NeighbourListInfo
update: NeighbourListUpdate
atom_indices: jax.Array | None
op_cell: jax.Array | None
op_coor: jax.Array | None
op_center: jax.Array | None
ijk_indices: jax.Array | None
_padding_bools: jax.Array | None
static create(r_cut, sp_orig, atom_indices=None, r_skin=None, ijk_indices=None, z_array=None, nxyz=None, op_cell=None, op_coor=None, op_center=None, stack_dims=None, num_neighs=None)
nneighs(sp=None)
property needs_calculation
property padding_bools
canonicalized_sp(sp: SystemParams) → SystemParams
neighbour_pos(sp_orig: SystemParams) → jax.Array
apply_fun_neighbour(sp: SystemParams, func: Callable[[jax.Array, jax.Array], T], r_cut: float | None = None, fill_value=0, reduce='full', exclude_self: bool = False, chunk_size_neigbourgs: int | None = None, chunk_size_atoms: int | None = None, chunk_size_batch: int | None = None, shmap=False, split_z=False, shmap_kwargs=ShmapKwargs.create())
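`apply_fun_neighbour` has no docstring here. The docstring of `apply_fun_neighbour_pair` names the per-neighbour callback `func_single(r_ij, atom_index_j)`. The O(N²) sketch below shows the underlying idea for a non-periodic system without a neighbour list: evaluate the callback for every atom pair inside `r_cut`, replace all other pairs by `fill_value`, and sum over neighbours. Treating `reduce='full'` as a sum over neighbours is an assumption, and `apply_fun_neighbour_sketch` is hypothetical (no cell, no chunking, scalar-valued `func` only).

```python
import jax
import jax.numpy as jnp


def apply_fun_neighbour_sketch(coordinates, func, r_cut, exclude_self=False, fill_value=0.0):
    """Hypothetical sketch: sum func(r_ij, j) over all neighbours j within r_cut of atom i.

    Dense, non-periodic illustration: pairs outside the cutoff are replaced
    by fill_value before the reduction over j.
    """
    n = coordinates.shape[0]
    r_ij = coordinates[None, :, :] - coordinates[:, None, :]  # r_ij[i, j] = r_j - r_i
    idx_j = jnp.arange(n)
    # outer vmap over atoms i, inner vmap over candidate neighbours j
    vals = jax.vmap(lambda r_row: jax.vmap(func)(r_row, idx_j))(r_ij)
    mask = jnp.linalg.norm(r_ij, axis=-1) < r_cut
    if exclude_self:
        mask = mask & ~jnp.eye(n, dtype=bool)
    vals = jnp.where(mask, vals, fill_value)
    return vals.sum(axis=1)  # per-atom reduction over neighbours


coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.2, 0.0]])
# callback: squared distance to each neighbour
out = apply_fun_neighbour_sketch(coords, lambda r, j: jnp.sum(r**2), r_cut=1.5, exclude_self=True)
```

Here atoms 1 and 2 are about 1.56 apart, outside the cutoff. Each atom therefore only sums the squared distances to its neighbours within 1.5.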
apply_fun_neighbour_pair(sp: SystemParams, func_double: Callable[[jax.Array, jax.Array, T, jax.Array, jax.Array, T], S], func_single: Callable[[jax.Array, jax.Array], T] = lambda x, y: ..., r_cut=None, fill_value=0.0, reduce='full', split_z=False, exclude_self=True, unique=True, chunk_size_neigbourgs=None, chunk_size_atoms=None, chunk_size_batch=None, shmap=False, shmap_kwargs=ShmapKwargs.create()) → tuple[jax.Array, S]

Args:

  • func_single(r_ij, atom_index_j) = ..
  • func_double(p_ij, atom_index_j, data_j, p_ik, atom_index_k, data_k) = …

property batched
property batch_dim
property shape
__getitem__(slices)
needs_update(sp: SystemParams) → bool
update_nl(sp: SystemParams, chunk_size: int | None = None, chunk_size_inner: int | None = 10, shmap=False, shmap_kwargs=ShmapKwargs.create(), verbose=False)
nl_split_z(p)
batch()
__add__(other)
static stack(*nls: NeighbourList) → NeighbourList
unstack() → list[NeighbourList]
__getstate__()
__setstate__(statedict: dict)
property stack_dims
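`NeighbourListInfo` stores both `r_cut` and `r_skin`, and `NeighbourList.needs_update` decides whether a rebuild is due. The sketch below shows the standard Verlet-list scheme behind such a pair of parameters. Pairs are collected within `r_cut + r_skin`, and the list stays valid until some atom has moved more than `r_skin / 2` since the build. Below that threshold, two atoms moving towards each other can close their gap by at most `r_skin`, so no pair outside the list can come within `r_cut`. Whether `needs_update` uses exactly this criterion is an assumption. Both helpers are hypothetical and ignore periodic images.

```python
import jax.numpy as jnp


def build_pairs(coordinates, r_cut, r_skin):
    """Boolean (n, n) mask of pairs within r_cut + r_skin (non-periodic, self excluded)."""
    d = jnp.linalg.norm(coordinates[:, None, :] - coordinates[None, :, :], axis=-1)
    return (d < r_cut + r_skin) & ~jnp.eye(coordinates.shape[0], dtype=bool)


def needs_update_sketch(coordinates, coordinates_at_build, r_skin):
    """Verlet criterion: rebuild once any atom has moved more than r_skin / 2."""
    disp = jnp.linalg.norm(coordinates - coordinates_at_build, axis=-1)
    return bool(jnp.max(disp) > 0.5 * r_skin)


x0 = jnp.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
pairs = build_pairs(x0, r_cut=2.5, r_skin=1.0)  # distance 3.0 < 3.5: pair is kept
small_move = x0 + jnp.array([[0.2, 0.0, 0.0], [0.0, 0.0, 0.0]])
large_move = x0 + jnp.array([[0.7, 0.0, 0.0], [0.0, 0.0, 0.0]])
```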
class IMLCV.base.CV.CV(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

cv: jax.Array
mapped: bool
atomic: bool
_combine_dims: tuple[int | Any] | None
_stack_dims: tuple[int] | None
static create(cv: jax.Array, mapped=False, atomic=False, combine_dims=None, stack_dims=None)
property batched
property batch_dim
property dim
property size
property shape
property combine_dims
property stack_dims
__add__(other) → CV
__radd__(other) → CV
__sub__(other) → CV
__rsub__(other) → CV
__mul__(other) → CV
__rmul__(other) → CV
__matmul__(other) → CV
__rmatmul__(other) → CV
__div__(other) → CV
batch() → CV
__iter__()
__getitem__(idx)
unbatch() → CV
static stack(*cvs: CV) → CV

Stacks a list of CVs into a single CV over the batch dimension. The dimensions are stored so that the result can later be unstacked into separate CVs.
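The bookkeeping described for `stack` could be realised as in the sketch below. It assumes the stored dimensions are the batch sizes of the inputs. The `_stack_dims` field hints at this, but the exact layout is an assumption. The helper names are hypothetical, and they operate on raw arrays rather than `CV` objects.

```python
import numpy as np
import jax.numpy as jnp


def stack_sketch(*arrays):
    """Concatenate batched CV arrays along axis 0 and remember each batch size."""
    stack_dims = tuple(a.shape[0] for a in arrays)
    return jnp.concatenate(arrays, axis=0), stack_dims


def unstack_sketch(stacked, stack_dims):
    """Inverse of stack_sketch: split along axis 0 at the stored boundaries."""
    boundaries = np.cumsum(stack_dims)[:-1].tolist()
    return jnp.split(stacked, boundaries, axis=0)


a = jnp.zeros((2, 3))
b = jnp.ones((4, 3))
stacked, dims = stack_sketch(a, b)
parts = unstack_sketch(stacked, dims)
```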

unstack() → list[CV]
split(flatten=False) → list[CV]

Inverse operation of combine.

static combine(*cvs: CV, flatten=False) → CV

Merges a list of CVs into a single CV over the last dimension. The dimensions are stored so that the result can later be split into separate CVs.
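The same bookkeeping applies to `combine` and `split`, which work over the last dimension instead of the batch dimension. The sketch below assumes flat integer widths. The real `_combine_dims` is typed `tuple[int | Any]` and may hold nested structure (compare the `flatten` flag), which is not modelled here. The helper names are hypothetical.

```python
import numpy as np
import jax.numpy as jnp


def combine_sketch(*arrays):
    """Concatenate CV arrays along the last axis and remember each width."""
    combine_dims = tuple(a.shape[-1] for a in arrays)
    return jnp.concatenate(arrays, axis=-1), combine_dims


def split_sketch(combined, combine_dims):
    """Inverse of combine_sketch: split the last axis at the stored widths."""
    boundaries = np.cumsum(combine_dims)[:-1].tolist()
    return jnp.split(combined, boundaries, axis=-1)


dihedrals = jnp.zeros((5, 2))  # e.g. two angles per sample
distance = jnp.ones((5, 1))  # e.g. one distance per sample
combined, dims = combine_sketch(dihedrals, distance)
parts = split_sketch(combined, dims)
```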

__getstate__()
__setstate__(statedict: dict)
class IMLCV.base.CV.CvMetric(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Class to keep track of the topology of a given CV. It identifies the periodicities of the CVs and maps them to the unit square with the correct periodicities.

bounding_box: jax.Array
periodicities: jax.Array
classmethod create(periodicities=None, bounding_box=None) → CvMetric
norm(x1: CV, x2: CV, k=1.0)
periodic_wrap(x: CV, min=False) → CV
difference(x1: CV, x2: CV) → jax.Array
min_cv(cv: jax.Array)
__periodic_wrap(xs: jax.Array, min=False)

Translate CVs over the unit cell.

With min=True, distances are calculated; with min=False, a single vector is translated inside the box.
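The sketch below illustrates periodic wrapping in mapped (unit-box) coordinates, following the description of `__periodic_wrap`. With `min=True` the result is the minimum-image value, which suits differences between CVs. With `min=False` a point is translated into the box. It assumes `periodicities` is a boolean array with one entry per CV dimension. `periodic_wrap_sketch` is hypothetical.

```python
import jax.numpy as jnp


def periodic_wrap_sketch(xs, periodicities, min=False):
    """Hypothetical sketch: wrap mapped CVs along their periodic dimensions.

    min=False: translate into [0, 1).
    min=True:  translate into [-0.5, 0.5] (minimum image), e.g. for differences.
    Non-periodic dimensions are returned unchanged.
    """
    shift = jnp.round(xs) if min else jnp.floor(xs)
    return jnp.where(periodicities, xs - shift, xs)


periodicities = jnp.array([True, False])
wrapped = periodic_wrap_sketch(jnp.array([1.25, 1.25]), periodicities)
# difference of two mapped CVs: 0.9 along a periodic axis is really -0.1
dmin = periodic_wrap_sketch(jnp.array([0.9, 0.9]), periodicities, min=True)
```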

map(x: jax.Array, displace=True) → jax.Array

Transform CVs to lie in the unit square.

unmap(x: jax.Array, displace=True) → jax.Array

Transform CVs from the unit square back to their original range (the inverse of map).
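The sketch below shows the affine map to the unit square and its inverse. It assumes `bounding_box` has shape `(n_dims, 2)` with a lower and an upper bound per CV. It also guesses that `displace=False` skips the offset, so that difference vectors are only rescaled. This page confirms neither the layout nor the flag semantics. The helper names are hypothetical.

```python
import jax.numpy as jnp


def map_sketch(x, bounding_box, displace=True):
    """Affinely map CVs into the unit square (hypothetical sketch).

    Assumes bounding_box[:, 0] are lower and bounding_box[:, 1] upper bounds.
    displace=False applies only the scaling, as for a difference vector (assumption).
    """
    lo, hi = bounding_box[:, 0], bounding_box[:, 1]
    return (x - lo) / (hi - lo) if displace else x / (hi - lo)


def unmap_sketch(y, bounding_box, displace=True):
    """Inverse of map_sketch."""
    lo, hi = bounding_box[:, 0], bounding_box[:, 1]
    return y * (hi - lo) + lo if displace else y * (hi - lo)


box = jnp.array([[-jnp.pi, jnp.pi], [0.0, 2.0]])  # e.g. a dihedral and a distance
x = jnp.array([0.0, 0.5])
y = map_sketch(x, box)
```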

__add__(other)
static get_n(samples_per_bin, samples, n_dims, max_bins=None, max_bins_per_dim=30)
grid(n=30, bounds=None, margin=0.1, indexing='ij')

Forms a regular grid in mapped space. If a coordinate is periodic, the last row along that coordinate is omitted, because it coincides with the first.

Parameters:
  • n – number of points in each dimension

  • map – boolean. True: work in mapped space (default); False: calculate the grid in space without the metric

  • endpoints – if

Returns:

A meshgrid and a vector with the distances between points.
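The sketch below shows the grid construction described above. It assumes `bounds` has shape `(n_dims, 2)` and `periodicities` is a boolean array. For a periodic coordinate the endpoint is the same point as the start, so it is dropped and that axis keeps `n - 1` points. On every axis the spacing between neighbouring points, including the wrap-around step on periodic axes, is `(upper - lower) / (n - 1)`. `grid_sketch` is hypothetical and ignores the `margin` argument.

```python
import jax.numpy as jnp


def grid_sketch(n, bounds, periodicities, indexing="ij"):
    """Hypothetical sketch: regular grid with the duplicate endpoint dropped on periodic axes.

    Returns the meshgrid (list of arrays) and the spacing per dimension.
    """
    axes = []
    for (lo, hi), periodic in zip(bounds.tolist(), periodicities.tolist()):
        ax = jnp.linspace(lo, hi, n)
        axes.append(ax[:-1] if periodic else ax)  # endpoint == start on a periodic axis
    spacing = jnp.array([(hi - lo) / (n - 1) for lo, hi in bounds.tolist()])
    return jnp.meshgrid(*axes, indexing=indexing), spacing


mesh, spacing = grid_sketch(5, jnp.array([[0.0, 1.0], [0.0, 1.0]]), jnp.array([True, False]))
```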

property ndim
static bounds_from_cv(cv_0: list[CV], percentile=0.1, weights: list[jax.Array] | None = None, rho: list[jax.Array] | None = None, margin=None, chunk_size: int | None = None, n=400, macro_chunk: int | None = 5000, verbose=True)
__getstate__()
__setstate__(statedict: dict)
__getitem__(idx)
class IMLCV.base.CV.CvFunBase(*args, **kwargs)

Bases: abc.ABC, IMLCV.base.datastructures.MyPyTreeNode

Helper class that provides a standard way to create an ABC using inheritance.

kwargs: dict
static_kwargs: dict
compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, chunk_size=None, jacobian=False, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) → tuple[CV, CV | None]
abstract _compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) → CV
__getstate__()
__setstate__(statedict: dict)
class IMLCV.base.CV.CvFun(*args, **kwargs)

Bases: CvFunBase

Helper class that provides a standard way to create an ABC using inheritance.

forward: Callable[[CV, NeighbourList | None, CV | None], CV] | None
backward: Callable[[CV, NeighbourList | None, CV | None], CV] | None
_compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) → CV
class IMLCV.base.CV._SerialCvTrans(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

f can be either a single CV transformation or a list of transformations.

trans: tuple[_SerialCvTrans | _ParralelCvTrans | CvFun, ...]
_compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) → CV
class IMLCV.base.CV._ParralelCvTrans(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

trans: tuple[_SerialCvTrans | _ParralelCvTrans, ...]
_compute_cv(x: CV | SystemParams, nl: NeighbourList | None = None, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) → CV
class IMLCV.base.CV.CvTrans(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

trans: _ParralelCvTrans | _SerialCvTrans
static from_cv_function(f: Callable, static_argnames=None, check_input: bool = True, **kwargs) → CvTrans
compute_cv(x: X, nl: NeighbourList | None = None, chunk_size=None, jacobian=False, reverse=False, shmap=False, shmap_kwargs=ShmapKwargs.create()) → tuple[CV, X | None]
__getstate__()
__setstate__(statedict: dict)
__mul__(other: CvTrans)
__add__(other: CvTrans)
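`CvTrans` wraps either a `_SerialCvTrans` or a `_ParralelCvTrans`, and both of these hold further transformations in `trans`. A transformation is therefore a tree of serial and parallel compositions. The sketch below mimics that structure with plain callables. It assumes that a serial node feeds each output into the next transformation. It also assumes that a parallel node applies all transformations to the same input and concatenates the results along the last axis, analogous to `CV.combine`. The class names are hypothetical, and this page does not state whether `__mul__` and `__add__` build these nodes.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

import jax.numpy as jnp


@dataclass(frozen=True)
class Serial:
    """Apply each transformation in turn: x -> t_n(...t_2(t_1(x)))."""

    trans: Tuple[Callable, ...]

    def __call__(self, x):
        for t in self.trans:
            x = t(x)
        return x


@dataclass(frozen=True)
class Parallel:
    """Apply every transformation to the same input and concatenate along the last axis."""

    trans: Tuple[Callable, ...]

    def __call__(self, x):
        return jnp.concatenate([t(x) for t in self.trans], axis=-1)


pipeline = Serial(
    (
        lambda x: x - x.mean(axis=-1, keepdims=True),  # centre each sample
        Parallel((lambda x: x**2, jnp.abs)),  # two features side by side
    )
)
out = pipeline(jnp.array([[0.0, 4.0]]))
```

Nesting `Serial` and `Parallel` nodes gives the same tree shape as the `trans` fields above, where a serial node may contain parallel nodes and vice versa.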
class IMLCV.base.CV.CollectiveVariable(*args, **kwargs)

Bases: IMLCV.base.datastructures.MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

f: CvTrans
metric: CvMetric
jac: Callable
compute_cv(sp: SystemParams, nl: NeighbourList | None = None, jacobian=False, chunk_size: int | None = None, shmap=False, push_jac=False, shmap_kwargs=ShmapKwargs.create()) → tuple[CV, SystemParams | None]
property n
save(file)
static load(file, **kwargs) → CollectiveVariable
__getstate__()
__setstate__(statedict: dict)
__getitem__(tup)
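`CollectiveVariable.compute_cv(..., jacobian=True)` returns a second value typed `SystemParams | None`, which suggests a derivative laid out like the coordinates. The sketch below computes such a Jacobian for a toy distance CV with `jax.jacrev`. Using reverse mode is a choice made here, not something this page specifies. `distance_cv` is a hypothetical example, not a CV shipped with IMLCV.

```python
import jax
import jax.numpy as jnp


def distance_cv(coordinates):
    """Toy CV: distance between atoms 0 and 1, returned as a length-1 vector."""
    return jnp.linalg.norm(coordinates[1] - coordinates[0])[None]


coords = jnp.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [1.0, 1.0, 1.0]])
cv_value = distance_cv(coords)  # shape (1,)
jac = jax.jacrev(distance_cv)(coords)  # shape (1, n_atoms, 3): d cv / d coordinates
```

The derivative with respect to atom 1 is the unit vector from atom 0 to atom 1, the derivative with respect to atom 0 is its negative, and atom 2 does not contribute.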