# Source code for IMLCV.implementations.energy

import os
from collections.abc import Callable
from pathlib import Path

import ase.calculators.calculator
import ase.cell
import ase.geometry
import ase.stress
import ase.units
import numpy as np
import yaff
from ase.calculators.cp2k import CP2K
from IMLCV.base.bias import Energy
from IMLCV.base.bias import EnergyError
from IMLCV.base.bias import EnergyResult
from IMLCV.configs.config_general import get_cp2k
from molmod.units import angstrom
from molmod.units import electronvolt

# Switch yaff's logger to its silent level as soon as this module is imported.
yaff.log.set_level(yaff.log.silent)


class YaffEnergy(Energy):
    """Energy evaluator backed by a yaff force field.

    The force field is created by calling the zero-argument factory ``f``.
    The factory (not the force field) is what gets pickled, so the force
    field is rebuilt from it when the object is restored.
    """

    def __init__(self, f: Callable[[], yaff.ForceField]) -> None:
        """Store the factory ``f`` and build the force field from it."""
        super().__init__()
        self.f = f
        self.ff: yaff.ForceField = f()

    @property
    def cell(self):
        """Cell vectors of the force field's system, or ``None`` if empty."""
        out = self.ff.system.cell.rvecs[:]

        # empty cell represented as array with shape (0,3)
        if out.size == 0:
            return None
        return out

    @cell.setter
    def cell(self, cell):
        # Only convert real cells; ``None`` is forwarded to yaff unchanged.
        if cell is not None:
            cell = np.array(cell, dtype=np.double)
        self.ff.update_rvecs(cell)

    @property
    def coordinates(self):
        """Positions of the force field's system."""
        return self.ff.system.pos[:]

    @coordinates.setter
    def coordinates(self, coordinates):
        self.ff.update_pos(coordinates)

    def _compute_coor(self, gpos=False, vir=False) -> EnergyResult:
        """Evaluate the force field at its current positions and cell.

        Args:
            gpos: when true, also compute the gradient w.r.t. positions.
            vir: when true, also compute the virial tensor.

        Returns:
            ``EnergyResult(energy, gpos, vtens)``; entries that were not
            requested are ``None``.

        Raises:
            EnergyError: if the yaff computation raises an exception.
        """
        # yaff fills the output arrays in place; ``None`` skips that quantity.
        gpos_out = np.zeros_like(self.ff.gpos) if gpos else None
        vtens_out = np.zeros_like(self.ff.vtens) if vir else None

        try:
            ener = self.ff.compute(gpos=gpos_out, vtens=vtens_out)
        except Exception as be:
            # ``Exception`` rather than ``BaseException`` so that
            # KeyboardInterrupt/SystemExit are not disguised as EnergyError;
            # chain the cause to keep the original traceback.
            raise EnergyError(
                f"calculating yaff energy raised execption:\n{be}\n",
            ) from be

        return EnergyResult(ener, gpos_out, vtens_out)

    def __getstate__(self):
        """Pickle only the factory and ``sp``; the force field is rebuilt."""
        return {"f": self.f, "sp": self.sp}

    def __setstate__(self, state):
        """Restore the factory, rebuild the force field and reassign ``sp``."""
        self.f = state["f"]
        self.ff = self.f()
        self.sp = state["sp"]
        return self
class AseEnergy(Energy):
    """Conversion to ASE energy.

    Wraps an ``ase.Atoms`` object. Values read from ASE are scaled with the
    molmod ``angstrom`` / ``electronvolt`` factors, and values written to ASE
    are divided by them. Subclasses may create the calculator lazily by
    overriding :meth:`_calculator`.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        calculator: ase.calculators.calculator.Calculator | None = None,
    ):
        """Store ``atoms`` and, if given, attach ``calculator`` to it."""
        self.atoms = atoms
        # Always define the attribute: it is read by __getstate__, and the
        # original code read it here before it had ever been assigned.
        self.calculator = calculator

        if calculator is not None:
            self.atoms.calc = calculator

    @property
    def cell(self):
        """Cell of the ASE atoms, multiplied by ``angstrom``."""
        return self.atoms.get_cell()[:] * angstrom

    @cell.setter
    def cell(self, cell):
        self.atoms.set_cell(ase.geometry.Cell(np.array(cell[:]) / angstrom))

    @property
    def coordinates(self):
        """Positions of the ASE atoms, multiplied by ``angstrom``."""
        return self.atoms.get_positions() * angstrom

    @coordinates.setter
    def coordinates(self, coordinates):
        self.atoms.set_positions(np.array(coordinates[:]) / angstrom)

    def _compute_coor(self, gpos=False, vir=False) -> EnergyResult:
        """Compute energy (and optionally gradient / virial) through ASE.

        ASE results are rescaled with ``electronvolt`` and ``angstrom``.

        Args:
            gpos: when true, also return the position gradient (minus the
                ASE forces).
            vir: when true, also return ``volume * stress``.

        Returns:
            ``EnergyResult(energy, gpos, vtens)``; entries that were not
            requested are ``None``.

        Raises:
            EnergyError: if ASE fails to provide an energy.
        """
        # Build the calculator on first use when none was attached.
        if self.atoms.calc is None:
            self.atoms.calc = self._calculator()

        try:
            energy = self.atoms.get_potential_energy() * electronvolt
        except Exception as exc:
            # ``Exception`` rather than ``BaseException`` so that
            # KeyboardInterrupt/SystemExit propagate unchanged.
            self._handle_exception()
            # _handle_exception is expected to raise; if an override returns
            # instead, fail explicitly rather than on an unbound ``energy``.
            raise EnergyError("Ase failed to provide an energy\n") from exc

        gpos_out = None
        vtens_out = None

        if gpos:
            forces = self.atoms.get_forces()
            gpos_out = -forces * electronvolt / angstrom

        if vir:
            cell = self.atoms.get_cell()
            volume = np.linalg.det(cell)
            stress = self.atoms.get_stress(voigt=False)
            vtens_out = volume * stress * electronvolt

        return EnergyResult(energy, gpos_out, vtens_out)

    def _calculator(self) -> ase.calculators.calculator.Calculator:
        """Create the calculator lazily; must be overridden by subclasses."""
        raise NotImplementedError

    def _handle_exception(self):
        """Turn a failed ASE energy evaluation into an ``EnergyError``."""
        raise EnergyError("Ase failed to provide an energy\n")

    def __getstate__(self):
        """Serialise the calculator class, its parameters and the atoms."""
        # Fall back to the calculator attached to the atoms when it was
        # created lazily and never stored on ``self.calculator``.
        calc = self.calculator if self.calculator is not None else self.atoms.calc

        extra_args = {
            "label": calc.label,
        }

        state = {
            "cc": calc.__class__,
            "calc_args": {**calc.todict(), **extra_args},
            "atoms": self.atoms.todict(),
        }

        return state

    def __setstate__(self, state):
        """Rebuild the atoms and the calculator from the pickled state."""
        clss = state["cc"]
        calc_params = state["calc_args"]
        atom_params = state["atoms"]

        # ``Atoms.fromdict`` takes the dictionary itself, not keyword args.
        self.atoms = ase.Atoms.fromdict(atom_params)
        self.calculator = clss(**calc_params)
        self.atoms.calc = self.calculator
class Cp2kEnergy(AseEnergy):
    """ASE-based energy that runs CP2K from a templated input file.

    The input file is read as a ``str.format`` template; every entry of
    ``input_kwargs`` names a file whose path (relative to the current working
    directory) is substituted into the template.
    """

    # override default params, only if explicitly set
    default_parameters = dict(
        auto_write=False,
        basis_set=None,
        basis_set_file=None,
        charge=None,
        cutoff=None,
        force_eval_method=None,
        inp="",
        max_scf=None,
        potential_file=None,
        pseudo_potential=None,
        stress_tensor=True,
        uks=False,
        poisson_solver=None,
        xc=None,
        print_level="LOW",
    )

    def __init__(
        self,
        atoms: ase.Atoms,
        input_file,
        input_kwargs: dict,
        cp2k_path: Path | None = None,
        **kwargs,
    ):
        """Store the settings; the calculator itself is created lazily.

        Args:
            atoms: the ASE atoms to compute energies for.
            input_file: path of the CP2K input template (stored absolute).
            input_kwargs: maps template field names to existing file paths.
            cp2k_path: stored as ``self.rp``.
            **kwargs: extra CP2K calculator parameters that override
                ``default_parameters``.
        """
        self.atoms = atoms
        self.cp2k_inp = os.path.abspath(input_file)
        self.input_kwargs = input_kwargs
        self.kwargs = kwargs
        super().__init__(atoms)

        self.rp = cp2k_path

    def _calculator(self):
        """Build the ASE ``CP2K`` calculator from the input template.

        Raises:
            FileNotFoundError: if a path in ``input_kwargs`` does not exist.
        """

        def relative(target: Path, origin: Path):
            """return path of target relative to origin"""
            try:
                return Path(target).resolve().relative_to(Path(origin).resolve())
            except ValueError:  # target does not start with origin
                # recursion with origin (eventually origin is root so try will succeed)
                return Path("..").joinpath(relative(target, Path(origin).parent))

        rp = Path.cwd()
        rp.mkdir(parents=True, exist_ok=True)

        print(f"saving CP2K output in {rp}")

        new_dict = {}
        for key, val in self.input_kwargs.items():
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not Path(val).exists():
                raise FileNotFoundError(
                    f"file for CP2K input key {key!r} does not exist: {val}",
                )
            new_dict[key] = relative(val, rp)

        with open(self.cp2k_inp) as f:
            inp = f.read().format(**new_dict)

        params = self.default_parameters.copy()
        params.update({"inp": inp, **self.kwargs})

        # The calculator always runs in the current directory with the
        # default label; user-supplied values for these are dropped.
        if "label" in params:
            del params["label"]
            print("ignoring label for Cp2kEnergy")

        if "directory" in params:
            del params["directory"]
            print("ignoring directory for Cp2kEnergy")

        params["directory"] = "."
        params["command"] = get_cp2k()

        calc = CP2K(**params)

        return calc

    def _handle_exception(self):
        """Raise an ``EnergyError`` quoting the tail of ``cp2k.out``.

        Raises:
            EnergyError: always; the message holds up to the last 50 lines of
                the CP2K output, or says why no output could be shown.
        """
        p = f"{self.atoms.calc.directory}/cp2k.out"
        if not os.path.exists(p):
            raise EnergyError("no cp2k output file after failure")

        with open(p) as f:
            lines = f.readlines()

        if not lines:
            raise EnergyError("cp2k.out doesn't contain output")

        # Lines from readlines() keep their newline, so join without adding
        # another separator (joining on "\n" doubled every line break).
        file = "".join(lines[-50:])

        raise EnergyError(
            f"The cp2k calculator failed to provide an energy. The end of the output from cp2k.out is { file}",
        )

    def __getstate__(self):
        """Serialise the constructor arguments (``cp2k_path`` is not kept)."""
        return [
            self.atoms.todict(),
            self.cp2k_inp,
            self.input_kwargs,
            self.kwargs,
        ]

    def __setstate__(self, state):
        """Re-run ``__init__`` with the pickled constructor arguments."""
        atoms_dict, cp2k_inp, input_kwargs, kwargs = state

        self.__init__(
            ase.Atoms.fromdict(atoms_dict),
            cp2k_inp,
            input_kwargs,
            **kwargs,
        )