# Source code for IMLCV.base.bias

from __future__ import annotations

from abc import ABC
from abc import abstractmethod
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING

import cloudpickle
import jax
import jax.numpy as jnp
import jax_dataclasses
import matplotlib.pyplot as plt
import numpy as np
import yaff
from IMLCV.base.CV import CollectiveVariable
from IMLCV.base.CV import CV
from IMLCV.base.CV import NeighbourList
from IMLCV.base.CV import SystemParams
from IMLCV.configs.bash_app_python import bash_app_python
from IMLCV.tools.tools import HashableArrayWrapper
from jax import Array
from jax import jit
from jax import value_and_grad
from jax import vmap
from molmod.units import angstrom
from molmod.units import electronvolt
from molmod.units import kjmol
from parsl.data_provider.files import File

yaff.log.set_level(yaff.log.silent)

if TYPE_CHECKING:
    from IMLCV.base.MdEngine import MDEngine


######################################
#              Energy                #
######################################


@jax_dataclasses.pytree_dataclass
class EnergyResult:
    """Energy of a configuration together with optional derivatives.

    Attributes:
        energy: the energy value. ``__str__`` divides it by molmod's
            ``electronvolt`` constant to print it in eV.
        gpos: derivative of the energy with respect to the coordinates, or
            ``None`` when it was not computed.
        vtens: virial tensor, or ``None`` when it was not computed.
    """

    energy: float
    gpos: Array | None = None
    vtens: Array | None = None

    def __post_init__(self):
        # The instance ``__dict__`` is written directly instead of using
        # attribute assignment (presumably because the pytree dataclass is
        # frozen -- TODO confirm against jax_dataclasses).
        if isinstance(self.gpos, Array):
            self.__dict__["gpos"] = jnp.array(self.gpos)
        if isinstance(self.vtens, Array):
            self.__dict__["vtens"] = jnp.array(self.vtens)

    def __add__(self, other) -> EnergyResult:
        """Return the term-wise sum of two results as a new ``EnergyResult``.

        ``gpos`` must be present on both operands or on neither. ``vtens``
        may be missing on either side; a missing virial counts as absent
        rather than zero, so the sum carries whichever virial is available.
        Neither operand is modified.
        """
        assert isinstance(other, EnergyResult)

        if self.gpos is None:
            assert other.gpos is None
            gpos = None
        else:
            assert other.gpos is not None
            # Out-of-place addition: ``+=`` would mutate ``self.gpos`` when
            # it holds a NumPy array instead of an immutable jax array.
            gpos = self.gpos + other.gpos

        if other.vtens is None:
            vtens = self.vtens
        elif self.vtens is None:
            vtens = other.vtens
        else:
            vtens = self.vtens + other.vtens

        return EnergyResult(energy=self.energy + other.energy, gpos=gpos, vtens=vtens)

    def __str__(self) -> str:
        """Format the energy (and derivatives, when present) in eV units."""
        # Local is deliberately not called ``str`` so the builtin stays usable.
        out = f"energy [eV]: {self.energy/electronvolt}"
        if self.gpos is not None:
            out += f"\ndE/dx^i_j [eV/angstrom] \n {self.gpos[:]*angstrom/electronvolt}"
        if self.vtens is not None:
            out += f"\n viriaal [eV] \n {self.vtens[:] / electronvolt }"
        return out
class BC:
    """Common base of the energy and bias classes in this module.

    It only contributes (de)serialisation helpers built on ``cloudpickle``.
    """

    def __init__(self) -> None:
        pass

    # def compute_from_system_params(
    #     self,
    #     gpos=False,
    #     vir=False,
    #     sp: SystemParams | None = None,
    #     nl: NeighbourList | None = None,
    # ) -> EnergyResult:
    #     """Computes the bias, the gradient of the bias wrt the coordinates and
    #     the virial."""
    #     raise NotImplementedError

    def save(self, filename: str | Path):
        """Pickle this object to ``filename``, creating parent directories."""
        target = Path(filename) if isinstance(filename, str) else filename
        # ``exist_ok=True`` makes a separate existence check unnecessary.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "wb") as handle:
            cloudpickle.dump(self, handle)

    @staticmethod
    def load(filename) -> BC:
        """Unpickle and return the object stored in ``filename``.

        NOTE: unpickling executes arbitrary code; only load trusted files.
        """
        with open(filename, "rb") as handle:
            restored = cloudpickle.load(handle)
        return restored
class EnergyError(Exception):
    """Exception type reserved for energy-related failures.

    No raising site is visible in this part of the file.
    """
class Energy(BC):
    """Interface of an energy source that owns a cell and coordinates.

    Subclasses provide the ``cell`` / ``coordinates`` properties and
    ``_compute_coor``.

    NOTE: the class does not derive from ``ABC``, so the ``@abstractmethod``
    markers below document intent but are not enforced at instantiation.
    """

    @staticmethod
    def load(filename) -> Energy:
        """Unpickle ``filename`` and check that it holds an ``Energy``."""
        energy = BC.load(filename=filename)
        assert isinstance(energy, Energy)
        return energy

    @property
    @abstractmethod
    def cell(self):
        pass

    @cell.setter
    @abstractmethod
    def cell(self, cell):
        pass

    @property
    @abstractmethod
    def coordinates(self):
        pass

    @coordinates.setter
    @abstractmethod
    def coordinates(self, coordinates):
        pass

    @property
    def sp(self) -> SystemParams:
        """Current coordinates and cell bundled as ``SystemParams``."""
        return SystemParams(coordinates=self.coordinates, cell=self.cell)

    @sp.setter
    def sp(self, sp: SystemParams):
        self.cell = sp.cell
        self.coordinates = sp.coordinates

    @abstractmethod
    def _compute_coor(self, gpos=False, vir=False) -> EnergyResult:
        """Evaluate the energy for the currently stored state."""
        pass

    def _handle_exception(self):
        return ""

    def compute_from_system_params(
        self,
        gpos=False,
        vir=False,
        sp: SystemParams | None = None,
        nl: NeighbourList | None = None,
    ) -> EnergyResult:
        """Evaluate the energy of the currently stored system state.

        Args:
            gpos: forwarded to ``_compute_coor``.
            vir: forwarded to ``_compute_coor``.
            sp: must be ``None``; overriding the stored state is rejected.
            nl: accepted for signature compatibility, not used here.

        Raises:
            NotImplementedError: if ``sp`` is given ("untested").
        """
        if sp is not None:
            # The original assigned ``self.sp = sp`` after this raise, which
            # could never run; re-add it here once the path has been tested.
            raise NotImplementedError("untested")

        return self._compute_coor(gpos=gpos, vir=vir)
class PlumedEnerg(Energy):
    """Placeholder ``Energy`` subclass; it adds nothing to the base class."""
######################################
#              Biases                #
######################################
class BiasError(Exception):
    """Exception type reserved for bias-related failures.

    No raising site is visible in this part of the file.
    """
class Bias(BC, ABC):
    """Base class for bias potentials defined on collective variables.

    Subclasses implement ``_compute`` (the bias as a function of the CV and
    of the values returned by ``get_args``) and ``get_args``.
    """

    def __init__(
        self,
        collective_variable: CollectiveVariable,
        start=None,
        step=None,
    ) -> None:
        """Store the collective variable and the update schedule.

        Args:
            collective_variable: collective variable the bias acts on.
            start: number of MD steps before the first update; ``None``
                disables scheduled updates.
            step: number of MD steps between updates; ``None`` disables
                scheduled updates.
        """
        super().__init__()
        self.collective_variable = collective_variable
        self.start = start
        self.step = step
        # NOTE(review): "couter" looks like a typo of "counter"; the name is
        # kept because it is part of the (pickled) instance state.
        self.couter = 0
        self.finalized = False

    def update_bias(
        self,
        md: MDEngine,
    ):
        """Update the bias.

        The base implementation does nothing. Implementations may only
        change the properties returned by ``get_args``.
        """

    def _update_bias(self):
        """Advance the update countdown and report whether an update is due.

        Returns ``False`` when the bias is finalized or no schedule is set.
        Otherwise ``start`` is used as a countdown: when it reaches 0 it is
        reset to ``step - 1`` and ``True`` is returned, else it is
        decremented and ``False`` is returned.
        """
        if self.finalized:
            return False

        if self.start is None or self.step is None:
            return False

        if self.start == 0:
            self.start += self.step - 1
            return True

        self.start -= 1
        return False

    def compute_from_system_params(
        self,
        sp: SystemParams,
        gpos=False,
        vir=False,
        nl: NeighbourList | None = None,
        jit=True,
    ) -> tuple[CV, EnergyResult]:
        """Compute the CV, the bias energy and optionally its derivatives.

        Args:
            sp: system parameters; a batched ``sp`` is mapped with ``vmap``.
            gpos: also return the gradient with respect to the coordinates.
            vir: also return the virial (only when ``sp.cell`` is set).
            nl: optional neighbour list; must be batched when ``sp`` is.
            jit: whether the CV and bias evaluations are jit-compiled.

        Returns:
            The collective variable value(s) and the matching
            ``EnergyResult``.
        """
        if sp.batched:
            if nl is not None:
                assert nl.batched
                nl_axis = 0
            else:
                nl_axis = None

            # gpos / vir / jit are closed over so that the caller's ``jit``
            # choice also applies to the per-sample evaluation (it used to be
            # silently reset to the default on this path).
            def _single(sp_i, nl_i):
                return self.compute_from_system_params(sp_i, gpos, vir, nl_i, jit)

            return vmap(_single, in_axes=(0, nl_axis))(sp, nl)

        [cvs, jac] = self.collective_variable.compute_cv(
            sp=sp,
            nl=nl,
            jacobian=gpos or vir,
            jit=jit,
        )
        [ener, de] = self.compute_from_cv(cvs, diff=(gpos or vir), jit=jit)

        def _resum(sp, jac, de):
            # Chain rule: contract dE/dCV with the CV jacobian.
            e_gpos = None
            if gpos:
                es = "nj,njkl->nkl"
                if not sp.batched:
                    es = es.replace("n", "")
                e_gpos = jnp.einsum(es, de.cv, jac.cv.coordinates)

            e_vir = None
            if vir and sp.cell is not None:
                # transpose, see https://pubs.acs.org/doi/suppl/10.1021/acs.jctc.5b00748/suppl_file/ct5b00748_si_001.pdf s1.4 and S1.22
                es = "nji,nk,nkjl->nli"
                if not sp.batched:
                    es = es.replace("n", "")
                e_vir = jnp.einsum(es, sp.cell, de.cv, jac.cv.cell)

            return EnergyResult(ener, e_gpos, e_vir)

        if jit:
            _resum = jax.jit(_resum)

        return cvs, _resum(sp, jac, de)

    def compute_from_cv(self, cvs: CV, diff=False, jit=True) -> tuple[Array, CV | None]:
        """Compute the bias energy and, optionally, its derivative wrt the CV.

        Args:
            cvs: collective variable value(s); batched input is mapped with
                ``vmap``.
            diff: also return the derivative (via ``value_and_grad``).
            jit: jit-compile the evaluation; the ``get_args`` values are then
                passed to ``_compute`` as static arguments.

        Returns:
            ``(energy, derivative)``; the derivative is ``None`` when
            ``diff`` is false.
        """
        assert isinstance(cvs, CV)

        def f0(x):
            # Wrapped so the arrays are hashable, which static jit arguments
            # require.
            args = [HashableArrayWrapper(a) for a in self.get_args()]
            static_array_argnums = tuple(range(1, len(args) + 1))

            if jit:
                return jax.jit(self._compute, static_argnums=static_array_argnums)(
                    x,
                    *args,
                )
            return self._compute(x, *args)

        def f1(x):
            return value_and_grad(f0)(x) if diff else (f0(x), None)

        def f2(cvs):
            return vmap(f1)(cvs) if cvs.batched else f1(cvs)

        if jit:
            f2 = jax.jit(f2)

        return f2(cvs)

    @abstractmethod
    def _compute(self, cvs, *args):
        """Calculate the bias potential; the CVs live in mapped space.

        ``args`` are the values returned by ``get_args``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_args(self):
        """Return the list of extra positional arguments of ``_compute``."""
        return []

    def finalize(self):
        """Mark the bias as final; ``_update_bias`` returns False afterwards.

        Should be called at the end of a metadynamics simulation.
        """
        self.finalized = True

    @staticmethod
    def load(filename) -> Bias:
        """Unpickle ``filename`` and check that it holds a ``Bias``."""
        bias = BC.load(filename=filename)
        assert isinstance(bias, Bias)
        return bias

    @staticmethod
    def _unit_label_and_factor(unit: str | None) -> tuple[str, float]:
        """Return the axis label and the divisor converting to ``unit``."""
        if unit == "rad":
            # Radians need no conversion; the factor is used as a divisor.
            return "rad", 1
        if unit == "ang":
            return "Ang", angstrom
        return "a.u.", 1

    def plot(
        self,
        name,
        x_unit: str | None = None,
        y_unit: str | None = None,
        n=50,
        traj: list[CV] | None = None,
        vmin=0,
        vmax=100 * kjmol,
        map=False,
        inverted=False,
        margin=None,
        x_lim=None,
        y_lim=None,
        bins=None,
    ):
        """Plot the bias on a grid and save the figure to ``name``.

        Only one- and two-dimensional collective variables are supported.

        Args:
            name: output path handed to ``plt.savefig``; parent directories
                are created.
            x_unit: ``"rad"``, ``"ang"`` or ``None`` (atomic units) for the
                first CV axis.
            y_unit: same as ``x_unit`` for the second CV axis (2D only).
            n: number of grid points requested from the CV metric.
            traj: optional CV trajectory (or list of them) to overlay, as a
                histogram in 1D and as a scatter plot in 2D.
            vmin: lower bound of the displayed bias; divided by ``kjmol``.
            vmax: upper bound of the displayed bias; divided by ``kjmol``.
            map: unused; kept for interface compatibility.
            inverted: plot the negated bias.
            margin: forwarded to the CV metric grid.
            x_lim: explicit limits of the first axis, used as given
                (assumed to be in the displayed unit -- TODO confirm).
            y_lim: explicit limits of the second axis, see ``x_lim``.
            bins: pre-computed grid; generated from the CV metric if omitted.

        Raises:
            ValueError: if the collective variable is not 1D or 2D.
        """
        n_cv = self.collective_variable.n
        if n_cv not in (1, 2):
            raise ValueError("plot only supports 1 or 2 collective variables")

        @jit
        def f(point):
            return self.compute_from_cv(
                CV(cv=point),
                diff=False,
            )

        if not isinstance(traj, Iterable) and traj is not None:
            traj = [traj]

        x_unit_label, x_fact = self._unit_label_and_factor(x_unit)

        if n_cv == 1:
            if bins is None:
                [bins] = self.collective_variable.metric.grid(
                    n=n,
                    endpoints=True,
                    margin=margin,
                )

            if x_lim is None:
                xlim = [bins.min() / x_fact, bins.max() / x_fact]
            else:
                xlim = [x_lim[0], x_lim[1]]

            bias, _ = jnp.apply_along_axis(f, axis=0, arr=jnp.array([bins]))

            if inverted:
                bias = -bias
            # Shift so that the lowest finite value sits at zero.
            bias -= bias[~np.isnan(bias)].min()
            bias = jnp.reshape(bias, (-1,))

            fig, ax = plt.subplots()
            ax.set_xlim(*xlim)
            ax.set_ylim(vmin / kjmol, vmax / kjmol)
            # The abscissa is converted with the same factor as the limits.
            ax.plot(bins / x_fact, bias / kjmol)

            ax2 = ax.twinx()

            ax.set_xlabel(f"cv1 [{x_unit_label}]", fontsize=16)
            ax.set_ylabel("Bias [kJ/mol]", fontsize=16)
            ax.tick_params(axis="both", which="major", labelsize=18)
            ax.tick_params(axis="both", which="minor", labelsize=16)

            if traj is not None:
                for tr in traj:
                    # trajs are ij indexed
                    ax2.hist(tr.cv / x_fact, density=True, histtype="step")
        else:
            if bins is None:
                bins = self.collective_variable.metric.grid(
                    n=n,
                    endpoints=True,
                    margin=margin,
                )
            mg = np.meshgrid(*bins, indexing="xy")

            y_unit_label, y_fact = self._unit_label_and_factor(y_unit)

            if x_lim is None:
                xlim = [mg[0].min() / x_fact, mg[0].max() / x_fact]
            else:
                xlim = [x_lim[0], x_lim[1]]
            if y_lim is None:
                ylim = [mg[1].min() / y_fact, mg[1].max() / y_fact]
            else:
                ylim = [y_lim[0], y_lim[1]]
            extent = [xlim[0], xlim[1], ylim[0], ylim[1]]

            bias, _ = jnp.apply_along_axis(f, axis=0, arr=np.array(mg))

            if inverted:
                bias = -bias
            # Shift so that the lowest finite value sits at zero.
            bias -= bias[~np.isnan(bias)].min()

            fig, ax = plt.subplots()
            p = ax.imshow(
                bias / kjmol,
                cmap=plt.get_cmap("rainbow"),
                origin="lower",
                extent=extent,
                vmin=vmin / kjmol,
                vmax=vmax / kjmol,
            )

            ax.set_xlabel(f"cv1 [{x_unit_label}]", fontsize=16)
            ax.set_ylabel(f"cv2 [{y_unit_label}]", fontsize=16)
            ax.tick_params(axis="both", which="major", labelsize=18)
            ax.tick_params(axis="both", which="minor", labelsize=16)

            cbar = fig.colorbar(p)
            cbar.set_label("Bias [kJ/mol]", size=18)

            if traj is not None:
                for tr in traj:
                    # trajs are ij indexed
                    ax.scatter(tr.cv[:, 0] / x_fact, tr.cv[:, 1] / y_fact, s=3)

        # write out
        Path(name).parent.mkdir(parents=True, exist_ok=True)
        plt.tight_layout()
        plt.savefig(name)
        plt.close(fig=fig)
@bash_app_python(executors=["default"])
def plot_app(
    bias: Bias,
    outputs: list[File],
    n: int = 50,
    vmin: float = 0,
    vmax: float = 100 * kjmol,
    map: bool = True,
    inverted=False,
    traj: list[CV] | None = None,
    margin=None,
    x_unit=None,
    y_unit=None,
    x_lim=None,
    y_lim=None,
    bins=None,
):
    """Render ``bias`` with ``Bias.plot`` into the first output file.

    Every argument other than ``bias`` and ``outputs`` is passed to
    ``Bias.plot`` unchanged; the figure path is ``outputs[0].filepath``.
    """
    figure_path = outputs[0].filepath
    plot_options = dict(
        n=n,
        traj=traj,
        vmin=vmin,
        vmax=vmax,
        map=map,
        inverted=inverted,
        margin=margin,
        x_unit=x_unit,
        y_unit=y_unit,
        x_lim=x_lim,
        y_lim=y_lim,
        bins=bins,
    )
    bias.plot(name=figure_path, **plot_options)
class CompositeBias(Bias):
    """Bias that combines several biases into one.

    Each member bias is evaluated on the same CV and the scalar results are
    reduced with ``fun`` (``jnp.sum`` by default).
    """

    def __init__(self, biases: Iterable[Bias], fun=jnp.sum) -> None:
        """Collect the member biases.

        Args:
            biases: biases to combine. ``NoneBias`` instances are skipped,
                unless every entry is one, in which case the first is kept
                so that a collective variable is available.
            fun: reduction applied to the array of member bias values.
        """
        self.init = True

        self.biases: list[Bias] = []
        # Cumulative offsets into the flat argument list of ``get_args``:
        # member i owns args[args_shape[i]:args_shape[i + 1]].
        self.args_shape = np.array([0])
        self.collective_variable: CollectiveVariable = None  # type: ignore

        # Materialise once: the iterable is traversed and indexed below.
        candidates = list(biases)
        for bias in candidates:
            self._append_bias(bias)

        if not self.biases and candidates:
            # Only dummy biases were supplied; keep one of them.
            self._register_bias(candidates[0])

        self.fun = fun

        super().__init__(collective_variable=self.collective_variable, start=0, step=1)
        self.init = True

    def _register_bias(self, b: Bias):
        """Add ``b`` unconditionally and extend the argument offsets."""
        self.biases.append(b)
        self.args_shape = np.append(
            self.args_shape,
            len(b.get_args()) + self.args_shape[-1],
        )

        # The first registered bias defines the collective variable; later
        # ones are not checked against it.
        if self.collective_variable is None:
            self.collective_variable = b.collective_variable

    def _append_bias(self, b: Bias):
        """Add ``b`` to the composite unless it is a dummy ``NoneBias``."""
        # ``b is NoneBias`` compared an instance with the class and was never
        # true for instances; the class itself is still accepted as a marker.
        if b is NoneBias or isinstance(b, NoneBias):
            return

        self._register_bias(b)

    def _compute(self, cvs, *args):
        """Evaluate every member on ``cvs`` and reduce the values with ``fun``."""
        offsets = self.args_shape
        values = [
            jnp.reshape(member._compute(cvs, *args[lo:hi]), ())
            for member, lo, hi in zip(self.biases, offsets[:-1], offsets[1:])
        ]
        return self.fun(jnp.array(values))

    def finalize(self):
        """Finalize every member bias and the composite itself."""
        for b in self.biases:
            b.finalize()
        # Keep the composite's own ``finalized`` flag in sync with its members.
        super().finalize()

    def update_bias(
        self,
        md: MDEngine,
    ):
        """Forward the update request to every member bias."""
        for b in self.biases:
            b.update_bias(md=md)

    def get_args(self):
        """Return the concatenated ``get_args`` of all member biases."""
        return [a for b in self.biases for a in b.get_args()]
class BiasF(Bias):
    """Bias given directly as a function ``g`` of the collective variable."""

    def __init__(self, cvs: CollectiveVariable, g=None):
        """Store ``g``; without one the bias is the constant zero function."""
        # ``g`` is intentionally not wrapped in ``jit``: that leads to
        # pickler issues.
        self.g = (lambda _: jnp.array(0.0)) if g is None else g
        super().__init__(cvs, start=None, step=None)

    def _compute(self, cvs):
        """Evaluate ``g`` on the CV value."""
        value = self.g(cvs)
        return value

    def get_args(self):
        """No extra arguments: the bias depends on the CV only."""
        return list()
class NoneBias(BiasF):
    """Dummy bias: a ``BiasF`` without a function, i.e. zero everywhere."""

    def __init__(self, cvs: CollectiveVariable):
        # Explicitly request the default (zero) function of ``BiasF``.
        super().__init__(cvs, g=None)