"""MD engine class peforms MD simulations in a given NVT/NPT ensemble.
Currently, the MD is done with YAFF/OpenMM
"""
from __future__ import annotations
import tempfile
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from time import time
import cloudpickle
import h5py
import jax.numpy as jnp
import numpy as np
import yaff.analysis.biased_sampling
import yaff.external
import yaff.log
import yaff.pes.bias
import yaff.pes.ext
import yaff.sampling.iterative
from IMLCV.base.bias import Bias
from IMLCV.base.bias import Energy
from IMLCV.base.bias import EnergyResult
from IMLCV.base.CV import CV
from IMLCV.base.CV import NeighbourList
from IMLCV.base.CV import SystemParams
from jax import Array
from molmod.periodic import periodic
from molmod.units import angstrom
from molmod.units import bar
from molmod.units import kjmol
yaff.log.set_level(yaff.log.silent)
######################################
# Trajectory #
######################################
@dataclass
[docs]class StaticTrajectoryInfo:
[docs] _attr = [
"timestep",
"r_cut",
"timecon_thermo",
"T",
"P",
"timecon_baro",
"write_step",
"equilibration",
"screen_log",
"max_grad",
]
[docs] _arr = [
"atomic_numbers",
]
[docs] r_cut: float | None = None
[docs] timecon_baro: float | None = None
[docs] equilibration: float | None = None
[docs] max_grad: float | None = 200 * kjmol / angstrom
@property
[docs] def masses(self):
return jnp.array([periodic[int(n)].mass for n in self.atomic_numbers])
@property
[docs] def thermostat(self):
return self.T is not None
@property
[docs] def barostat(self):
return self.P is not None
[docs] def __post_init__(self):
if self.thermostat:
assert self.timecon_thermo is not None
if self.barostat:
assert self.timecon_baro is not None
# if self.equilibration is None:
# self.equilibration = 200 * self.timestep
[docs] def _save(self, hf: h5py.File):
for name in self._arr:
prop = self.__getattribute__(name)
if prop is not None:
hf[name] = prop
for name in self._attr:
prop = self.__getattribute__(name)
if prop is not None:
hf.attrs[name] = prop
[docs] def save(self, filename: str | Path):
if isinstance(filename, str):
filename = Path(filename)
if not filename.parent.exists():
filename.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(str(filename), "w") as hf:
self._save(hf=hf)
@staticmethod
[docs] def _load(hf: h5py.File) -> StaticTrajectoryInfo:
props_static = {}
attrs_static = {}
for key, val in hf.items():
props_static[key] = val[:]
for key, val in hf.attrs.items():
attrs_static[key] = val
return StaticTrajectoryInfo(**attrs_static, **props_static)
@staticmethod
[docs] def load(filename) -> StaticTrajectoryInfo:
with h5py.File(str(filename), "r") as hf:
return StaticTrajectoryInfo._load(hf=hf)
@dataclass
[docs]class TrajectoryInfo:
[docs] _cell: Array | None = None
[docs] _charges: Array | None = None
[docs] _e_pot: Array | None = None
[docs] _e_pot_gpos: Array | None = None
[docs] _e_pot_vtens: Array | None = None
[docs] _e_bias: Array | None = None
[docs] _e_bias_gpos: Array | None = None
[docs] _e_bias_vtens: Array | None = None
[docs] _cv: Array | None = None
[docs] _T: Array | None = None
[docs] _P: Array | None = None
[docs] _err: Array | None = None
[docs] _t: Array | None = None
[docs] _items_scal = ["_t", "_e_pot", "_e_bias", "_T", "_P", "_err"]
[docs] _items_vec = [
"_positions",
"_cell",
"_e_pot_gpos",
"_e_pot_vtens",
"_e_bias_gpos",
"_e_bias_vtens",
"_charges",
"_cv",
]
# https://stackoverflow.com/questions/7133885/fastest-way-to-grow-a-numpy-numeric-array
[docs] def __post_init__(self):
if self._capacity == -1:
self._capacity = 1
if self._size == -1:
self._size = 1
# batch
if len(self._positions.shape) == 2:
for name in [*self._items_vec, *self._items_scal]:
prop = self.__getattribute__(name)
if prop is not None:
self.__setattr__(name, np.array([prop]))
# test wether cell is truly not None
if self._cell is not None:
if self._cell.shape[-2] == 0:
self._cell = None
[docs] def __getitem__(self, slices):
"gets slice from indices. the output is truncated to the to include only items wihtin _size"
slz = (jnp.ones(self._capacity).cumsum() - 1)[slices]
ind = slz <= self._size
# print(f"ind: {ind}, cap = {jnp.sum(ind)}, t : {self.t[slices][ind].shape}")
return TrajectoryInfo(
_positions=self._positions[slices, :][ind],
_cell=self._cell[slices, :][ind] if self._cell is not None else None,
_charges=(self._charges[slices, :][ind] if self._cell is not None else None)
if self._charges is not None
else None,
_e_pot=self._e_pot[slices][ind] if self._e_pot is not None else None,
_e_pot_gpos=self._e_pot_gpos[slices, :][ind] if self._e_pot_gpos is not None else None,
_e_pot_vtens=self._e_pot_vtens[slices, :][ind] if self._e_pot_vtens is not None else None,
_e_bias=self._e_bias[slices][ind] if self._e_bias is not None else None,
_e_bias_gpos=self._e_bias_gpos[slices, :][ind] if self._e_bias_gpos is not None else None,
_e_bias_vtens=self._e_bias_vtens[slices, :][ind] if self._e_bias_vtens is not None else None,
_cv=self._cv[slices, :][ind] if self._cv is not None else None,
_T=self._T[slices][ind] if self._T is not None else None,
_P=self._P[slices][ind] if self._P is not None else None,
_err=self._err[slices][ind] if self._err is not None else None,
_t=self._t[slices][ind] if self._t is not None else None,
_capacity=jnp.sum(ind),
_size=jnp.sum(ind),
)
[docs] def __add__(self, ti: TrajectoryInfo):
sz = ti._size
while self._capacity <= self._size + ti._size:
self._expand_capacity()
for name in self._items_vec:
prop_ti = ti.__getattribute__(name)
prop_self = self.__getattribute__(name)
if prop_ti is None:
assert prop_self is None
else:
prop_self[self._size : self._size + sz, :] = prop_ti[0:sz, :]
for name in self._items_scal:
prop_ti = ti.__getattribute__(name)
prop_self = self.__getattribute__(name)
if prop_ti is None:
assert prop_self is None
else:
prop_self[self._size : self._size + sz] = prop_ti[0:sz]
self._size += sz
return self
[docs] def _expand_capacity(self):
nc = min(self._capacity * 2, self._capacity + 10000)
delta = nc - self._capacity
self._capacity = nc
for name in self._items_vec:
prop = self.__getattribute__(name)
if prop is not None:
self.__setattr__(
name,
np.vstack((prop, np.zeros((delta, *prop.shape[1:])))),
)
for name in self._items_scal:
prop = self.__getattribute__(name)
if prop is not None:
self.__setattr__(
name,
np.hstack([prop, np.zeros(delta)]),
)
[docs] def _shrink_capacity(self):
for name in self._items_vec:
prop = self.__getattribute__(name)
if prop is not None:
self.__setattr__(name, prop[: self._size, :])
for name in self._items_scal:
prop = self.__getattribute__(name)
if prop is not None:
self.__setattr__(name, prop[: self._size])
self._capacity = self._size
[docs] def save(self, filename: str | Path):
self._shrink_capacity()
if isinstance(filename, str):
filename = Path(filename)
if not filename.parent.exists():
filename.parent.mkdir(parents=True, exist_ok=True)
with h5py.File(str(filename), "w") as hf:
self._save(hf=hf)
[docs] def _save(self, hf: h5py.File):
for name in [*self._items_scal, *self._items_vec]:
prop = self.__getattribute__(name)
if prop is not None:
hf[name] = prop
hf.attrs.create("_capacity", self._capacity)
hf.attrs.create("_size", self._size)
@staticmethod
[docs] def load(filename) -> TrajectoryInfo:
with h5py.File(str(filename), "r") as hf:
return TrajectoryInfo._load(hf=hf)
@staticmethod
[docs] def _load(hf: h5py.File):
props = {}
attrs = {}
for key, val in hf.items():
# if key == "static_info":
# tic = StaticTrajectoryInfo._load(hf[key])
# continue
props[key] = val[:]
for key, val in hf.attrs.items():
attrs[key] = val
return TrajectoryInfo(
# static_info=tic,
**props,
**attrs,
)
@property
[docs] def sp(self) -> SystemParams:
return SystemParams(
coordinates=jnp.array(self._positions[0 : self._size, :]),
cell=jnp.array(self._cell[0 : self._size, :]) if self._cell is not None else None,
)
@property
[docs] def positions(self) -> Array | None:
if self._positions is None:
return None
return self._positions[0 : self._size, :]
@property
[docs] def cell(self) -> Array | None:
if self._cell is None:
return None
return self._cell[0 : self._size, :]
@property
[docs] def volume(self):
if self.cell is not None:
return jnp.linalg.det(self._cell)
return None
@property
[docs] def charges(self) -> Array | None:
if self._charges is None:
return None
return self._charges[0 : self._size, :]
@property
[docs] def e_pot(self) -> Array | None:
if self._e_pot is None:
return None
return self._e_pot[0 : self._size]
@property
[docs] def e_pot_gpos(self) -> Array | None:
if self._e_pot_gpos is None:
return None
return self._e_pot_gpos[0 : self._size, :]
@property
[docs] def e_pot_vtens(self) -> Array | None:
if self._e_pot_vtens is None:
return None
return self._e_pot_vtens[0 : self._size, :]
@property
[docs] def e_bias(self) -> Array | None:
if self._e_bias is None:
return None
return self._e_bias[0 : self._size]
@property
[docs] def e_bias_gpos(self) -> Array | None:
if self._e_bias_gpos is None:
return None
return self._e_bias_gpos[0 : self._size, :]
@property
[docs] def e_bias_vtens(self) -> Array | None:
if self._e_bias_vtens is None:
return None
return self._e_bias_vtens[0 : self._size, :]
@property
[docs] def cv(self) -> Array | None:
if self._cv is None:
return None
return self._cv[0 : self._size, :]
@property
[docs] def T(self) -> Array | None:
if self._T is None:
return None
return self._T[0 : self._size]
@property
[docs] def P(self) -> Array | None:
if self._P is None:
return None
return self._P[0 : self._size]
@property
[docs] def err(self) -> Array | None:
if self._err is None:
return None
return self._err[0 : self._size]
@property
[docs] def t(self) -> Array | None:
if self._t is None:
return None
return self._t[0 : self._size]
@property
[docs] def shape(self):
return self._size
@property
[docs] def CV(self) -> CV | None:
if self._cv is not None:
return CV(cv=self._cv[0 : self._size, :])
return None
######################################
# MDEngine #
######################################
[docs]class MDEngine(ABC):
"""Base class for MD engine."""
[docs] keys = [
"bias",
"energy",
"static_trajectory_info",
"trajectory_file",
# "sp",
# "sp",
]
def __init__(
self,
bias: Bias,
energy: Energy,
static_trajectory_info: StaticTrajectoryInfo,
trajectory_file=None,
sp: SystemParams | None = None,
) -> None:
self.static_trajectory_info = static_trajectory_info
self.bias = bias
self.energy = energy
self.last_bias = EnergyResult(0)
self.last_ener = EnergyResult(0)
self.last_cv: CV | None = None
# self._sp = sp
self.trajectory_info: TrajectoryInfo | None = None
self.step = 1
if sp is not None:
self.sp = sp
self.trajectory_file = trajectory_file
self.time0 = time()
self._nl: NeighbourList | None = None
@property
[docs] def sp(self) -> SystemParams:
return self.energy.sp
@sp.setter
def sp(self, sp: SystemParams):
self.energy.sp = sp
@property
[docs] def nl(self) -> NeighbourList | None:
if self.static_trajectory_info.r_cut is None:
return None
def _nl():
return self.sp.get_neighbour_list(
r_cut=self.static_trajectory_info.r_cut,
z_array=self.static_trajectory_info.atomic_numbers,
r_skin=0.0,
)
if self._nl is None:
nl = _nl()
else:
b, nl = self._nl.update(self.sp) # jitted update
if not b:
nl = _nl()
self._nl = nl
return nl
[docs] def save(self, file):
with open(file, "wb") as f:
cloudpickle.dump(self, f)
[docs] def __getstate__(self):
return {key: self.__getattribute__(key) for key in MDEngine.keys}
[docs] def __setstate__(self, state):
self.__init__(**state)
return self
@staticmethod
[docs] def load(file, **kwargs) -> MDEngine:
with open(file, "rb") as f:
self = cloudpickle.load(f)
print("Loading MD engine")
for key in kwargs.keys():
print(f"setting {key}={kwargs[key]}")
self.__setattr__(key, kwargs[key])
return self
[docs] def new_bias(self, bias: Bias, **kwargs) -> MDEngine:
with tempfile.NamedTemporaryFile() as tmp:
self.save(tmp.name)
mde = MDEngine.load(tmp.name, **{"bias": bias, **kwargs})
return mde
[docs] def run(self, steps):
"""run the integrator for a given number of steps.
Args:
steps: number of MD steps
"""
print(f"running for {int(steps)} steps!")
# try:
# with jax.debug_nans():
# with jax.disable_jit():
# try:
self._run(int(steps))
# except Exception as err:
# if self.step == 1:
# raise err
# print(f"The calculator finished early with error {err=},{type(err)=}")
self.trajectory_info._shrink_capacity()
if self.trajectory_file is not None:
self.trajectory_info.save(self.trajectory_file)
@abstractmethod
[docs] def _run(self, steps):
raise NotImplementedError
[docs] def get_trajectory(self) -> TrajectoryInfo:
assert self.trajectory_info is not None
self.trajectory_info._shrink_capacity()
return self.trajectory_info
[docs] def save_step(self, T=None, P=None, t=None, err=None):
ti = TrajectoryInfo(
_positions=self.sp.coordinates,
_cell=self.sp.cell,
_e_pot=self.last_ener.energy,
_e_pot_gpos=self.last_ener.gpos,
_e_bias=self.last_bias.energy,
_e_bias_gpos=self.last_bias.gpos,
_e_pot_vtens=self.last_ener.vtens,
_e_bias_vtens=self.last_bias.vtens,
_cv=self.last_cv.cv,
_T=T,
_P=P,
_t=t,
_err=err,
)
if self.step == 1:
str = f"{ 'step': ^10s}"
str += f"|{ 'cons err': ^10s}"
str += f"|{ 'e_pot[Kj/mol]': ^15s}"
str += f"|{ 'e_bias[Kj/mol]': ^15s}"
if ti._P is not None:
str += f"|{'P[bar]': ^10s}"
str += f"|{'T[K]': ^10s}|{'walltime[s]': ^11s}"
ss = "|\u2207\u2093U\u1D47|[Kj/\u212B]"
str += f"|{ ss : ^13s}"
str += f"|{' CV': ^10s}"
print(str, sep="")
print(f"{'='*len(str)}")
if self.step % self.static_trajectory_info.screen_log == 0:
str = f"{ self.step : >10d}"
assert ti._err is not None
assert ti._T is not None
assert ti._e_pot is not None
assert ti._e_bias is not None
str += f"|{ ti._err[0] : >10.4f}"
str += f"|{ ti._e_pot[0] /kjmol : >15.8f}"
str += f"|{ ti._e_bias[0] /kjmol : >15.8f}"
if ti._P is not None:
str += f" { ti._P[0]/bar : >10.2f}"
str += f" { ti._T[0] : >10.2f} { time()-self.time0 : >11.2f}"
str += f"|{ jnp.max(jnp.linalg.norm(ti._e_bias_gpos,axis=1) /kjmol*angstrom ) : >13.2f}"
str += f"| {ti._cv[0,:]}"
print(str)
# write step to trajectory
if self.trajectory_info is None:
self.trajectory_info = ti
else:
self.trajectory_info += ti
if self.step % self.static_trajectory_info.write_step == 0:
if self.trajectory_file is not None:
self.trajectory_info.save(self.trajectory_file) # type: ignore
self.bias.update_bias(self)
self.step += 1
[docs] def get_energy(self, gpos: bool = False, vtens: bool = False) -> EnergyResult:
return self.energy.compute_from_system_params(
gpos,
vtens,
)
[docs] def get_bias(
self,
gpos: bool = False,
vtens: bool = False,
) -> tuple[CV, EnergyResult]:
# with jax.disable_jit()
# with jax.debug_nans():
cv, ener = self.bias.compute_from_system_params(
sp=self.sp,
nl=self.nl,
gpos=gpos,
vir=vtens,
)
# if jnp.any(jnp.isnan(ener.gpos)):
# import jax
# with jax.disable_jit():
# with jax.debug_nans():
# cv, ener = self.bias.compute_from_system_params(
# sp=self.sp, nl=self.nl, gpos=gpos, vir=vtens, jit=False
# )
# if (self.static_trajectory_info.max_grad is not None) and (
# ener.gpos is not None
# ):
# ns = jnp.linalg.norm(ener.gpos, axis=1)
# norms = jnp.max(ns)
# if (fact := norms / self.static_trajectory_info.max_grad) > 1:
# ener = EnergyResult(
# ener.energy,
# ener.gpos / fact,
# ener.vtens if ener.vtens is not None else None,
# )
# print(f"clipped, fact={fact}")
return cv, ener
@property
[docs] def yaff_system(self) -> MDEngine.YaffSys:
return self.YaffSys(self.energy, self.static_trajectory_info)
# definitons of different interfaces. These encode the state of the system in the format of a given md engine
@dataclass
[docs] class YaffSys:
[docs] _tic: StaticTrajectoryInfo
@dataclass
[docs] class YaffCell:
@property
[docs] def rvecs(self):
if self._ener.cell is None:
return np.zeros((0, 3))
return np.array(self._ener.cell)
@rvecs.setter
def rvecs(self, rvecs):
self._ener.cell = rvecs
[docs] def update_rvecs(self, rvecs):
self.rvecs = rvecs
@property
[docs] def nvec(self):
return self.rvecs.shape[0]
@property
[docs] def volume(self):
if self.nvec == 0:
return np.nan
return np.linalg.det(self.rvecs)
[docs] def __post_init__(self):
self._cell = self.YaffCell(_ener=self._ener)
@property
[docs] def numbers(self):
return self._tic.atomic_numbers
@property
[docs] def masses(self):
return np.array(self._tic.masses)
@property
[docs] def charges(self):
return None
@property
[docs] def cell(self):
return self._cell
@property
[docs] def pos(self):
return np.array(self._ener.coordinates)
@pos.setter
def pos(self, pos):
self._ener.coordinates = pos
@property
[docs] def natom(self):
return self.pos.shape[0]