import os
import shutil
from abc import ABC
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path
import ase
import h5py
import jax
import jax.numpy as jnp
import numpy as np
from filelock import FileLock
from IMLCV.base.bias import Bias
from IMLCV.base.bias import CompositeBias
from IMLCV.base.CV import CollectiveVariable
from IMLCV.base.CV import SystemParams
from IMLCV.base.MdEngine import MDEngine
from IMLCV.base.MdEngine import StaticTrajectoryInfo
from IMLCV.base.MdEngine import TrajectoryInfo
from IMLCV.configs.bash_app_python import bash_app_python
from jax import Array
from molmod.constants import boltzmann
from molmod.units import kjmol
from parsl.data_provider.files import File
# todo: invalidate with files instead of a db that gets deleted
@dataclass
@dataclass
[docs]class Rounds(ABC):
def __init__(
    self,
    folder: str | Path = "output",
    copy=True,
    new_folder=True,
) -> None:
    """Select the working folder, create the HDF5 index and re-index the data on disk.

    All relevant info is tracked in an HDF5 container that is built as follows::

        root
            round 0
                attrs:
                    - name_bias
                    - name_md
                    - valid
                    - num
                static_trajectory info
                    data, see static_trajectory info._save
                0:
                    attrs:
                        - valid
                        - num
                    trajectory_info
                        data, see trajectory info._save
                1:
                    ..
                ...
            round 1
                ...
            ...

    Args:
        folder: requested working directory.
        copy: when a numbered sibling folder is selected, copy the content of
            ``folder`` into it.
        new_folder: when ``folder`` already contains ``results.h5``, work in the
            first ``<folder>_NNN`` sibling that has no ``results.h5`` instead
            of in ``folder`` itself.
    """
    requested = Path(folder)
    target = requested

    # A folder only counts as "in use" when it already holds a results file.
    if new_folder and (requested / "results.h5").exists():
        # Walk <name>_000, <name>_001, ... until a free slot is found.
        suffix = 0
        target = requested.parent / f"{requested.name}_{suffix:0>3}"
        while (target / "results.h5").exists():
            suffix += 1
            target = requested.parent / f"{requested.name}_{suffix:0>3}"

        if copy:
            shutil.copytree(requested, target)

    if not target.exists():
        target.mkdir(parents=True)

    self.folder = target
    self.lock = FileLock(self.h5filelock_name)

    # The index file is always rebuilt from what is found on disk.
    self._make_file()
    self.recover()
######################################
# IO #
######################################
def _make_file(self):
    """Delete any existing ``results.h5`` and open a new, empty one on ``self.h5file``."""
    # todo: make backup
    target = self.h5file_name
    stale = Path(target)
    if stale.exists():
        stale.unlink()

    # The handle stays open for the lifetime of the object.
    self.h5file = h5py.File(target, mode="w")
@property
def h5file_name(self):
    """Path of the ``results.h5`` index file, as returned by ``full_path``."""
    index_file = "results.h5"
    return self.full_path(index_file)
@property
def h5filelock_name(self):
    """Path of the lock file guarding ``results.h5``, as returned by ``full_path``."""
    lock_file = "results.h5.lock"
    return self.full_path(lock_file)
[docs] def full_path(self, name: str | Path) -> str:
return str((self.folder / name).resolve())
[docs] def rel_path(self, name: str | Path):
return str(Path(name).relative_to(self.folder))
def path(self, r=None, i=None) -> Path:
    """Return the folder of the store, of round ``r`` or of trajectory ``i`` in round ``r``.

    Args:
        r: round index; when omitted the working folder itself is returned.
        i: trajectory index; only used together with ``r``.
    """
    base = Path(self.folder)
    if r is None:
        return base

    round_dir = base / f"round_{r}"
    if i is None:
        return round_dir

    return round_dir / f"md_{i}"
@staticmethod
def load(folder: str | Path, copy=False):
    """Construct a ``Rounds`` object for ``folder``; ``copy`` is forwarded to the constructor."""
    rounds = Rounds(folder=folder, copy=copy)
    return rounds
def write_xyz(self, r: int | None = None, num: int = 1, repeat=None):
    """Write a ``trajectory.xyz`` file into the folder of every selected trajectory.

    Args:
        r: last round to include (forwarded to ``iter_ase_atoms``).
        num: number of rounds to include (forwarded to ``iter_ase_atoms``).
        repeat: when given, every frame is passed through ``Atoms.repeat(repeat)``
            before it is written.
    """
    from ase.io.extxyz import write_extxyz

    for frames, round_info, traj_info in self.iter_ase_atoms(r=r, num=num):
        target = self.path(r=round_info.round, i=traj_info.num) / "trajectory.xyz"
        with open(target, mode="w") as handle:
            if repeat is not None:
                frames = [frame.repeat(repeat) for frame in frames]
            write_extxyz(handle, frames)
######################################
# storage #
######################################
def recover(self):
    """Re-index the rounds and trajectories found on disk into the HDF5 file.

    Every ``round_*`` folder that is not yet present in the HDF5 file is added
    with ``add_round``; every ``md_*`` folder inside it that is not yet present
    and that holds a trajectory file is added with ``add_md``. An ``invalid``
    marker file in a folder is translated into ``valid=False``.
    """
    f = self.h5file

    for round_dir in self.path().glob("round_*"):
        # folder name is "round_<r>"; r stays a string, like the HDF5 keys
        r = round_dir.name[len("round_"):]

        if r not in f.keys():
            attr = {}

            sti = StaticTrajectoryInfo.load(
                filename=self.path(r=r) / "static_trajectory_info.h5",
            )

            if (p := self.path(r) / "bias").exists():
                attr["name_bias"] = self.rel_path(p)

            if (p := self.path(r) / "engine").exists():
                attr["name_md"] = self.rel_path(p)

            if (self.path(r) / "invalid").exists():
                attr["valid"] = False

            self.add_round(attr=attr, stic=sti, r=r)

        for md_dir in round_dir.glob("md_*"):
            # folder name is "md_<i>"
            i = md_dir.name[len("md_"):]

            if i in f[f"{r}"].keys():
                continue

            # Prefer the regular trajectory file, fall back to the one left
            # behind by the bash app; skip folders that have neither.
            candidates = (
                md_dir / "trajectory_info.h5",
                md_dir / "bash_app_trajectory_info.h5",
            )
            tin = next((c for c in candidates if c.exists()), None)
            if tin is None:
                continue

            # run_par stores the updated bias as "bias_new"; "new_bias" is
            # still accepted so that folders using that name keep working.
            bias = None
            for bias_name in ("bias_new", "new_bias", "bias"):
                if (p := md_dir / bias_name).exists():
                    bias = self.rel_path(p)
                    break

            # must be a dict: the original started from None and crashed on
            # the item assignment below for every invalidated trajectory
            attr = {}
            if (md_dir / "invalid").exists():
                attr["valid"] = False

            traj = TrajectoryInfo.load(tin)
            self.add_md(i=i, r=r, d=traj, bias=bias, attrs=attr)
def add_md(self, i, d: TrajectoryInfo, attrs=None, bias: str | None = None, r=None):
    """Register trajectory ``i`` of round ``r`` in the HDF5 file.

    Args:
        i: trajectory index inside the round.
        d: trajectory data; saved to ``trajectory_info.h5`` in the trajectory
            folder unless that file already exists.
        attrs: extra attributes for the trajectory group; entries whose value
            is ``None`` are skipped. ``valid`` defaults to ``True`` and an
            explicit ``valid=False`` is kept. The dict is not modified.
        bias: name of the bias belonging to this trajectory, stored as
            ``name_bias``.
        r: round index, defaults to the current round.
    """
    if r is None:
        r = self.round

    # Work on a copy so the caller's dict is never modified.
    attrs = {} if attrs is None else dict(attrs)

    # A trajectory is valid unless the caller explicitly says otherwise.
    if attrs.get("valid") is None:
        attrs["valid"] = True

    if bias is not None:
        attrs["name_bias"] = bias

    with self.lock:
        f = self.h5file
        if f"{i}" not in f[f"{r}"]:
            f.create_group(f"{r}/{i}")

        # check if in recover mode: the file is already on disk then
        if not (p := self.path(r=r, i=i) / "trajectory_info.h5").exists():
            d.save(filename=p)

        md_group = f[f"{r}/{i}"]
        for key, val in attrs.items():
            if val is not None:
                md_group.attrs[key] = val

        # bookkeeping on the round: trajectory count and the list of indices
        round_group = f[f"{r}"]
        round_group.attrs["num"] += 1
        round_group.attrs["num_vals"] = np.append(
            round_group.attrs["num_vals"],
            int(i),
        )

        f.flush()
def add_round(self, stic: StaticTrajectoryInfo, r=None, attr=None):
    """Create round ``r`` on disk and in the HDF5 file.

    Args:
        stic: static trajectory info of the round; saved as
            ``static_trajectory_info.h5`` in the round folder.
        r: round index, defaults to the current round plus one.
        attr: attributes for the round group; entries whose value is ``None``
            are skipped. ``valid`` defaults to ``True`` and an explicit
            ``valid=False`` is kept.
    """
    if r is None:
        r = self.round + 1

    if attr is None:
        attr = {}

    round_dir = self.path(r=r)
    if not round_dir.exists():
        round_dir.mkdir(parents=True)

    with self.lock:
        f = self.h5file
        f.create_group(f"{r}")
        group = f[f"{r}"]

        for key, val in attr.items():
            if val is not None:
                group.attrs[key] = val

        stic.save(round_dir / "static_trajectory_info.h5")

        # a new round starts without trajectories
        group.attrs["num"] = 0
        group.attrs["num_vals"] = np.array([], dtype=np.int32)

        # Only fall back to "valid" when the caller did not decide; writing
        # True unconditionally would undo valid=False passed in via ``attr``.
        if attr.get("valid") is None:
            group.attrs["valid"] = True

        f.flush()
def add_round_from_md(self, md: MDEngine):
    """Open the next round from an MD engine.

    The engine and its bias are saved as ``engine`` and ``bias`` in the new
    round folder, and the round is registered with ``add_round`` using the
    engine's static trajectory info.
    """
    next_round = self.round + 1
    round_dir = self.path(r=next_round)
    if not round_dir.is_dir():
        round_dir.mkdir()

    engine_file = round_dir / "engine"
    bias_file = round_dir / "bias"

    md.save(engine_file)
    md.bias.save(bias_file)

    # attribute access kept from the original; its result is not used here
    md.bias.collective_variable

    self.add_round(
        attr={
            "name_md": self.rel_path(engine_file),
            "name_bias": self.rel_path(bias_file),
        },
        stic=md.static_trajectory_info,
        r=next_round,
    )
######################################
# retrieval #
######################################
def iter(
    self,
    start=None,
    stop=None,
    num=3,
    ignore_invalid=False,
) -> Iterable[tuple[RoundInformation, TrajectoryInformation]]:
    """Yield ``(round, trajectory)`` information pairs for the selected rounds.

    Args:
        start: first round that may be visited, defaults to 0.
        stop: last round visited (inclusive), defaults to the current round.
        num: at most this many rounds, counted back from ``stop``, are visited.
        ignore_invalid: when true, rounds and trajectories flagged as not valid
            are yielded as well.

    Trajectories whose ``ti._size`` is not positive are always skipped.
    """
    last = self.round if stop is None else stop
    first_allowed = 0 if start is None else start
    first = max(last - (num - 1), first_allowed)

    for round_index in range(first, last + 1):
        round_info = self.round_information(r=round_index)
        if not (round_info.valid or ignore_invalid):
            continue

        for md_index in round_info.num_vals:
            traj_info = self.get_trajectory_information(r=round_index, i=md_index)
            if not (traj_info.valid or ignore_invalid):
                continue

            # no points in collection
            if traj_info.ti._size <= 0:
                continue

            yield round_info, traj_info
def iter_ase_atoms(self, r: int | None = None, num: int = 3):
    """Yield ``(atoms, round, trajectory)`` for every trajectory selected by ``iter``.

    ``atoms`` is a list with one ``ase.Atoms`` object per frame. Positions (and
    the cell, when the trajectory has one) are divided by molmod's ``angstrom``
    before they are handed to ASE. Periodic boundary conditions are switched
    on only when a cell is present.

    Args:
        r: last round to include (passed to ``iter`` as ``stop``).
        num: number of rounds to include.
    """
    from molmod import angstrom

    for round_info, traj_info in self.iter(stop=r, num=num):
        data = traj_info.ti
        pos_A = data._positions / angstrom

        # Per-frame keyword arguments; the cell-related ones only exist for
        # periodic systems.
        if data._cell is None:
            per_frame = [dict(positions=pos) for pos in pos_A]
        else:
            cell_A = data._cell / angstrom
            per_frame = [
                dict(positions=pos, pbc=True, cell=cell)
                for pos, cell in zip(pos_A, cell_A)
            ]

        frames = [
            ase.Atoms(
                numbers=round_info.tic.atomic_numbers,
                masses=round_info.tic.masses,
                **frame_kwargs,
            )
            for frame_kwargs in per_frame
        ]

        # Stress, forces and energies are not attached to the frames.
        yield frames, round_info, traj_info
######################################
# Properties #
######################################
def _set_attr(self, name, value, r=None, i=None):
    """Write attribute ``name`` in the HDF5 file and flush it.

    The attribute is written on the root when ``r`` is omitted, on round ``r``
    when only ``r`` is given, and on trajectory ``i`` of round ``r`` when both
    are given.

    Raises:
        AssertionError: if ``i`` is given without ``r``.
    """
    if i is not None:
        assert r is not None, "also provide round"

    with self.lock:
        f = self.h5file

        if r is None:
            target = f
        elif i is None:
            target = f[f"{r}"]
        else:
            # trajectories live below their round ("<r>/<i>"); "/<i>" would
            # address round i at the root of the file
            target = f[f"{r}/{i}"]

        target.attrs[name] = value
        f.flush()
def _get_attr(self, name, r=None, i=None):
    """Read attribute ``name`` from the HDF5 file.

    The attribute is read from the root when ``r`` is omitted, from round ``r``
    when only ``r`` is given, and from trajectory ``i`` of round ``r`` when
    both are given.

    Raises:
        AssertionError: if ``i`` is given without ``r``.
    """
    if i is not None:
        assert r is not None, "also provide round"

    with self.lock:
        f = self.h5file

        if r is None:
            source = f
        elif i is None:
            source = f[f"{r}"]
        else:
            # trajectories live below their round ("<r>/<i>"); "/<i>" would
            # address round i at the root of the file
            source = f[f"{r}/{i}"]

        return source.attrs[name]
@property
def T(self):
    """``T`` of the static trajectory info of the current round."""
    current = self.round_information(r=self.round)
    return current.tic.T
@property
def P(self):
    """``P`` of the static trajectory info of the current round."""
    current = self.round_information(r=self.round)
    return current.tic.P
@property
def round(self):
    """Index of the latest round: the number of top-level HDF5 entries minus one."""
    with self.lock:
        n_entries = len(self.h5file.keys())
        return n_entries - 1
def n(self, r=None):
    """Return the ``num`` counter of round ``r``.

    Args:
        r: round index, defaults to the current round.
    """
    if r is None:
        r = self.round

    # query the requested round; the original always used the current one
    return self.round_information(r=r).num
def invalidate_data(self, r=None, i=None):
    """Flag round ``r`` (or trajectory ``i`` of round ``r``) as not valid.

    An empty ``invalid`` marker file is created in the matching folder when it
    is not there yet, and the ``valid`` attribute is set to ``False``.

    Args:
        r: round index, defaults to the current round.
        i: optional trajectory index.
    """
    if r is None:
        r = self.round

    marker = self.path(r=r, i=i) / "invalid"
    if not marker.exists():
        marker.touch()

    self._set_attr(name="valid", value=False, r=r, i=i)
def is_valid(self, r=None, i=None):
    """Return whether round ``r`` (or trajectory ``i`` of round ``r``) is valid.

    Data counts as valid as long as its folder holds no ``invalid`` marker
    file.

    Args:
        r: round index, defaults to the current round.
        i: optional trajectory index.
    """
    if r is None:
        r = self.round

    # the marker flags *invalid* data, so its presence means "not valid"
    return not (self.path(r=r, i=i) / "invalid").exists()
def get_collective_variable(self) -> CollectiveVariable:
    """Return the collective variable of the bias returned by ``get_bias()``."""
    return self.get_bias().collective_variable
def get_bias(self, r=None, i=None) -> Bias:
    """Load the bias stored under ``name_bias`` for round ``r`` or trajectory ``i`` of it.

    Args:
        r: round index, defaults to the current round.
        i: optional trajectory index.
    """
    round_index = self.round if r is None else r
    stored_name = self._get_attr("name_bias", r=round_index, i=i)
    return Bias.load(self.full_path(stored_name))
def get_engine(self, r=None) -> MDEngine:
    """Load the MD engine stored under ``name_md`` for round ``r``.

    Args:
        r: round index, defaults to the current round.
    """
    round_index = self.round if r is None else r
    stored_name = self._get_attr("name_md", r=round_index)
    engine_path = self.full_path(stored_name)
    return MDEngine.load(engine_path, filename=None)
######################################
# MD simulations #
######################################
def run(self, bias, steps):
    """Run a single simulation: ``run_par`` with a one-element list of biases."""
    single = [bias]
    self.run_par(single, steps)
def run_par(
    self,
    biases: Iterable[Bias],
    steps,
    plot=True,
    KEY=42,
    sp0: SystemParams | None = None,
):
    """Run one MD simulation per entry of ``biases`` in the current round.

    For every entry a folder ``md_<i>`` is created in the current round, the
    simulation bias (the round bias, combined with the entry unless it is
    ``None``) is saved there as ``bias``, a starting configuration is chosen
    and the simulation is submitted as an app. After all submissions, the
    method waits for the results and registers every trajectory with
    ``add_md``.

    Starting configurations:
        * ``sp0`` given: entry ``i`` of ``sp0`` is used for simulation ``i``.
        * otherwise the frames yielded by ``iter(start=0, num=2,
          ignore_invalid=True)`` are pooled (or, when there are none, the
          engine's own configuration is used) and one frame is drawn per
          simulation with probability proportional to
          ``exp(-(new bias - bias it was sampled under) / (T * boltzmann))``.

    Args:
        biases: extra bias per simulation; ``None`` entries run with the
            round bias only.
        steps: number of MD steps per simulation.
        plot: also submit a plot app per simulation, writing ``plot.pdf``.
        KEY: integer seed or JAX PRNG key used for drawing the start frames.
        sp0: optional explicit starting configurations, one per bias.

    Raises:
        AssertionError: if ``sp0`` is given and its first dimension differs
            from the number of biases.
    """
    from parsl.dataflow.dflow import AppFuture

    if isinstance(KEY, int):
        KEY = jax.random.PRNGKey(KEY)

    with self.lock:
        f = self.h5file
        common_bias_name = self.full_path(f[f"{self.round}"].attrs["name_bias"])
        common_md_name = self.full_path(f[f"{self.round}"].attrs["name_md"])

    tasks: list[tuple[int, AppFuture]] = []
    plot_tasks = []

    md_engine = MDEngine.load(common_md_name)

    r_cut = md_engine.static_trajectory_info.r_cut
    z_array = md_engine.static_trajectory_info.atomic_numbers

    # bias energy under which each pooled frame was sampled
    bias_prev = None

    if sp0 is None:
        tis: TrajectoryInfo | None = None

        # also get samples from the initial round
        for ri, ti in self.iter(start=0, num=2, ignore_invalid=True):
            round_bias = ri.get_bias()
            ti_bias = ti.get_bias()
            traj_info = ti.ti

            if traj_info.cv is None:
                nl = traj_info.sp.get_neighbour_list(
                    r_cut=ri.tic.r_cut,
                    z_array=ri.tic.atomic_numbers,
                )
                cv, _ = round_bias.collective_variable.compute_cv(traj_info.sp, nl)
            else:
                cv = traj_info.CV

            # Mirror the CV handling above: use the stored bias energy when
            # there is one and only recompute it otherwise. The original
            # tested ``ti_bias is None`` and then called a method on it.
            if traj_info.e_bias is None:
                epot_r_i, _ = ti_bias.compute_from_cv(cv)
            else:
                epot_r_i = traj_info.e_bias

            if tis is None:
                tis = ti.ti
                bias_prev = [epot_r_i]
            else:
                tis += ti.ti
                bias_prev.append(epot_r_i)

        if tis is None:
            # nothing sampled yet: start from the engine's configuration
            tis = TrajectoryInfo(
                _positions=md_engine.sp.coordinates,
                _cell=md_engine.sp.cell,
            )
    else:
        assert sp0.shape[0] == len(
            biases,
        ), f"The number of initials cvs provided {sp0.shape[0]} does not correspond to the number of biases {len(biases)}"

    if bias_prev is not None:
        bias_prev = jnp.hstack(bias_prev)

    # The two apps do not depend on the loop variables, so they are defined
    # once. The list defaults are kept as in the original app signatures.
    @bash_app_python(executors=["reference"])
    def run(
        steps: int,
        sp: SystemParams | None,
        inputs=[],
        outputs=[],
    ):
        # inputs: [engine file, bias file]; outputs: [updated bias, trajectory]
        bias = Bias.load(inputs[1].filepath)

        kwargs = dict(
            bias=bias,
            trajectory_file=outputs[1].filepath,
        )
        if sp is not None:
            kwargs["sp"] = sp
        md = MDEngine.load(inputs[0].filepath, **kwargs)
        md.run(steps)
        bias.save(outputs[0].filepath)
        d = md.get_trajectory()
        return d

    @bash_app_python(executors=["default"])
    def plot_app(
        st: StaticTrajectoryInfo,
        traj: TrajectoryInfo,
        inputs=[],
        outputs=[],
    ):
        # inputs: [bias file]; outputs: [plot file]
        bias = Bias.load(inputs[0].filepath)

        # drop the equilibration part when time stamps are available
        if st.equilibration is not None:
            if traj._t is not None:
                traj = traj[traj._t > st.equilibration]

        cvs = traj.CV
        if cvs is None:
            sp = traj.sp
            nl = sp.get_neighbour_list(
                r_cut=st.r_cut,
                z_array=st.atomic_numbers,
            )
            cvs, _ = bias.collective_variable.compute_cv(sp=sp, nl=nl)

        bias.plot(name=outputs[0].filepath, traj=[cvs])

    for i, bias in enumerate(biases):
        path_name = self.path(r=self.round, i=i)
        if not os.path.exists(path_name):
            os.mkdir(path_name)

        # construct the bias used in the simulation
        common_bias = Bias.load(common_bias_name)
        if bias is None:
            b = common_bias
        else:
            b = CompositeBias([common_bias, bias])

        # Bias used for the CV fallback and the log output: the extra bias as
        # in the original, or the simulation bias when the entry is None
        # (the original dereferenced None in that case).
        report_bias = b if bias is None else bias

        b_name = path_name / "bias"
        b_name_new = path_name / "bias_new"
        b.save(b_name)

        traj_name = path_name / "trajectory_info.h5"

        if sp0 is None:
            if tis.shape > 1:  # type: ignore
                if tis.cv is not None:
                    cv = tis.CV
                else:
                    sp = tis.sp
                    nl = sp.get_neighbour_list(r_cut=r_cut, z_array=z_array)
                    cv = report_bias.collective_variable.f.compute_cv_flow(sp=sp, nl=nl)

                # compensate for bias of current simulation
                bs, _ = b.compute_from_cv(cvs=cv)
                if bias_prev is not None:
                    bs -= bias_prev
                bs = jnp.reshape(bs, (-1))

                # shift by the mean before exponentiating
                bs -= jnp.mean(bs)

                probs = jnp.exp(
                    -bs / (md_engine.static_trajectory_info.T * boltzmann),
                )
                probs = probs / jnp.sum(probs)

                KEY, k = jax.random.split(KEY, 2)
                index = jax.random.choice(
                    a=probs.shape[0],
                    key=k,
                    p=probs,
                )
            else:
                index = 0

            tisi = tis[index]
            spi = tisi.sp.unbatch()

            nli = spi.get_neighbour_list(r_cut=r_cut, z_array=z_array)

            e_pot_kjmol = tisi.e_pot / kjmol if tisi.e_pot is not None else None
            new_bias_kjmol = (
                report_bias.compute_from_system_params(sp=spi, nl=nli)[1].energy / kjmol
            )
            print(
                f"new point got cv={tisi.CV}, e_pot={e_pot_kjmol} and new bias {new_bias_kjmol} ",
            )
        else:
            spi = sp0[i].unbatch()
            nli = spi.get_neighbour_list(r_cut=r_cut, z_array=z_array)
            cvi, bi = report_bias.compute_from_system_params(sp=spi, nl=nli)

            print(f"new point got cv={cvi}, new bias {bi.energy/kjmol} ")

        future = run(
            sp=spi,  # type: ignore
            inputs=[File(common_md_name), File(str(b_name))],
            outputs=[File(str(b_name_new)), File(str(traj_name))],
            steps=int(steps),
            execution_folder=path_name,
        )

        if plot:
            plot_file = path_name / "plot.pdf"

            plot_fut = plot_app(
                traj=future,
                st=md_engine.static_trajectory_info,
                inputs=[future.outputs[0]],
                outputs=[File(str(plot_file))],
                execution_folder=path_name,
            )

            plot_tasks.append(plot_fut)

        tasks.append((i, future))

    # wait for tasks to finish
    for i, future in tasks:
        d = future.result()
        self.add_md(d=d, bias=future.outputs[0].filename, i=i)

    # wait for plots to finish
    if plot:
        for future in plot_tasks:
            future.result()