Source code for IMLCV.base.CVDiscovery
import itertools
from pathlib import Path
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from IMLCV.base.CV import CollectiveVariable
from IMLCV.base.CV import CV
from IMLCV.base.CV import CvFlow
from IMLCV.base.CV import CvMetric
from IMLCV.base.CV import CvTrans
from IMLCV.base.CV import NeighbourList
from IMLCV.base.CV import SystemParams
from IMLCV.base.MdEngine import StaticTrajectoryInfo
from IMLCV.base.rounds import Rounds
from IMLCV.implementations.CV import distance_descriptor
from IMLCV.implementations.CV import sb_descriptor
from IMLCV.implementations.CV import scale_cv_trans
from jax import Array
from jax import jacrev
from jax import vmap
from jax.random import choice
from jax.random import PRNGKey
from jax.random import split
from matplotlib import gridspec
from matplotlib.colors import hsv_to_rgb
class Transformer:
    """Base class that maps sampled configurations to new collective variables.

    A fit runs in three stages whose transformations are composed into a single
    :class:`CollectiveVariable`:

    1. :meth:`pre_fit` evaluates the configured descriptor (optionally rescaled),
    2. :meth:`_fit` performs the actual reduction (implemented by subclasses),
    3. :meth:`post_fit` stacks the reduced values and optionally rescales them.
    """

    def __init__(
        self,
        outdim,
        periodicity=None,
        bounding_box=None,
        descriptor="sb",
        descriptor_kwargs=None,
    ) -> None:
        """
        Args:
            outdim: number of output collective-variable dimensions.
            periodicity: per-output-dimension periodicity flags; defaults to
                ``False`` for every one of the ``outdim`` dimensions.
            bounding_box: per-dimension ``[low, high]`` bounds; defaults to
                ``[0.0, 10.0]`` for every entry of ``periodicity``.
            descriptor: ``"sb"`` (uses :func:`sb_descriptor`) or ``"distance"``
                (uses :func:`distance_descriptor`).
            descriptor_kwargs: keyword arguments forwarded to the descriptor
                factory; ``None`` means no extra arguments.

        Raises:
            NotImplementedError: if ``descriptor`` is not a supported name.
        """
        # Avoid a shared mutable default dict.
        if descriptor_kwargs is None:
            descriptor_kwargs = {}

        self.outdim = outdim
        if periodicity is None:
            periodicity = [False for _ in range(self.outdim)]
        if bounding_box is None:
            bounding_box = np.array([[0.0, 10.0] for _ in periodicity])
        self.periodicity = periodicity
        self.bounding_box = bounding_box

        self.descriptor: CvFlow
        if descriptor == "sb":
            self.descriptor = sb_descriptor(**descriptor_kwargs)
        elif descriptor == "distance":
            self.descriptor = distance_descriptor(**descriptor_kwargs)
        else:
            raise NotImplementedError(
                f"unknown descriptor {descriptor!r}; expected 'sb' or 'distance'",
            )

    def pre_fit(
        self,
        z: list[SystemParams],
        nl: list[NeighbourList] | None,
        chunk_size=None,
        scale=True,
    ) -> tuple[list[CV], CvFlow]:
        """Evaluate the descriptor on every batch of system parameters.

        Args:
            z: batches of system parameters.
            nl: neighbour lists matching ``z`` element-wise, or ``None``.
            chunk_size: forwarded to ``compute_cv_flow``.
            scale: if True, fit one :func:`scale_cv_trans` on all descriptor
                values together and apply it to every batch.

        Returns:
            The (optionally scaled) descriptor values per batch, and the flow
            producing them (the descriptor, composed with the scaling if used).
        """
        f = self.descriptor

        x: list[CV] = []
        for i, zi in enumerate(z):
            nli = nl[i] if nl is not None else None
            x.append(f.compute_cv_flow(zi, nli, chunk_size=chunk_size))

        if scale:
            # A single scaling is fitted on all batches so they stay comparable.
            g = scale_cv_trans(CV.stack(*x))
            x = [g.compute_cv_trans(xi)[0] for xi in x]
            f = f * g

        return x, f

    def fit(
        self,
        sp_list: list[SystemParams],
        nl_list: list[NeighbourList] | None,
        chunk_size=None,
        prescale=True,
        postscale=True,
        jac=jacrev,
        *fit_args,
        **fit_kwargs,
    ) -> tuple[CV, CollectiveVariable]:
        """Fit a new collective variable on the given samples.

        Runs :meth:`pre_fit`, :meth:`_fit` and :meth:`post_fit` and composes
        their transformations into one :class:`CollectiveVariable`.

        Args:
            sp_list: batches of system parameters.
            nl_list: neighbour lists matching ``sp_list``, or ``None``.
            chunk_size: forwarded to :meth:`pre_fit`.
            prescale: whether to rescale the descriptor output.
            postscale: whether to rescale the output of :meth:`_fit`.
            jac: Jacobian transform passed to the resulting collective variable.
            *fit_args: forwarded positionally to :meth:`_fit`.
            **fit_kwargs: forwarded as keywords to :meth:`_fit`.

        Returns:
            The new CV values of all samples (stacked) and the new collective
            variable.
        """
        x, f = self.pre_fit(
            sp_list,
            nl_list,
            scale=prescale,
            chunk_size=chunk_size,
        )

        y, g = self._fit(
            x,
            nl_list,
            *fit_args,
            **fit_kwargs,
        )

        z, h = self.post_fit(y, scale=postscale)

        cv = CollectiveVariable(
            f=f * g * h,
            metric=CvMetric(periodicities=self.periodicity),
            jac=jac,
        )

        return z, cv

    def _fit(
        self,
        x: list[CV],
        nl: list[NeighbourList] | None,
        *args,
        **kwargs,
    ) -> tuple[list[CV], CvTrans]:
        """Reduce the descriptor values; must be implemented by subclasses.

        Accepts ``*args`` so that positional arguments forwarded by :meth:`fit`
        reach this method (and raise ``NotImplementedError``) instead of
        failing with a ``TypeError``.

        Returns:
            The reduced values per batch and the transformation producing them.
        """
        raise NotImplementedError

    def post_fit(self, y: list[CV], scale) -> tuple[CV, CvTrans]:
        """Stack the per-batch outputs of :meth:`_fit` and optionally rescale them.

        Returns:
            The stacked (optionally scaled) values and the applied
            transformation (the identity when ``scale`` is False).
        """
        y = CV.stack(*y)
        if not scale:
            return y, CvTrans.from_cv_function(lambda x, _: x)

        h = scale_cv_trans(y)
        return h.compute_cv_trans(y)[0], h
class CVDiscovery:
    """Convert a set of sampled coordinates into good collective variables.

    The fit itself is delegated to a :class:`Transformer`. This class loads and
    subsamples the data from :class:`Rounds` and writes diagnostic plots.
    """

    def __init__(self, transformer: Transformer) -> None:
        self.transformer = transformer

    def data_loader(
        self,
        rounds: Rounds,
        num=4,
        out=-1,
        split_data=False,
        new_r_cut=None,
    ) -> tuple[list[SystemParams], list[NeighbourList] | None, CollectiveVariable, StaticTrajectoryInfo]:
        """Collect system parameters from ``rounds`` and subsample them.

        Args:
            rounds: source of trajectories.
            num: forwarded as ``num`` to ``rounds.iter``.
            out: number of samples to draw (per trajectory if ``split_data``,
                otherwise in total). A negative value, or a value larger than the
                number of available frames, selects all available frames.
            split_data: if True, subsample every trajectory separately and return
                one entry per trajectory; otherwise all trajectories are
                concatenated and a single subsample is returned.
            new_r_cut: if not None, neighbour lists are built with this cutoff
                and returned alongside the system parameters.

        Returns:
            ``(sp_list, nl_list, colvar, sti)``: subsampled system parameters, the
            matching neighbour lists (or None), the collective variable of the
            rounds (possibly taken from a trajectory bias), and the static
            trajectory info of the first visited round.
        """
        weights = []
        colvar = rounds.get_collective_variable()
        sti: StaticTrajectoryInfo | None = None
        sp: list[SystemParams] = []
        nl: list[NeighbourList] | None = [] if new_r_cut is not None else None

        for rnd, traj in rounds.iter(stop=None, num=num):
            if sti is None:
                sti = rnd.tic

            sp0 = traj.ti.sp
            nl0 = (
                sp0.get_neighbour_list(
                    r_cut=new_r_cut,
                    z_array=rnd.tic.atomic_numbers,
                )
                if new_r_cut is not None
                else None
            )

            if (b0 := traj.ti.e_bias) is None:
                # Bias energies were not stored with the trajectory: recompute them.
                bias = traj.get_bias()

                if (cv0 := traj.ti.CV) is None:
                    # The CVs must be evaluated with the round's own cutoff, which may
                    # differ from new_r_cut; only build that neighbour list when needed.
                    if new_r_cut != rnd.tic.r_cut:
                        nlr = (
                            sp0.get_neighbour_list(
                                r_cut=rnd.tic.r_cut,
                                z_array=rnd.tic.atomic_numbers,
                            )
                            if rnd.tic.r_cut is not None
                            else None
                        )
                    else:
                        nlr = nl0

                    if colvar is None:
                        colvar = bias.collective_variable

                    cv0, _ = bias.collective_variable.compute_cv(sp=sp0, nl=nlr)

                b0, _ = bias.compute_from_cv(cvs=cv0)

            sp.append(sp0)
            if nl is not None:
                assert nl0 is not None
                nl.append(nl0)

            # Weight exp(beta * bias energy) with beta = 1 / T of the round.
            beta = 1 / rnd.tic.T
            weights.append(jnp.exp(beta * b0))

        assert sti is not None
        assert len(sp) != 0
        if nl is not None:
            assert len(nl) == len(sp)

        def choose(key, probs: Array):
            """Draw distinct frame indices; returns ``(next_key, indices)``."""
            key, key_return = split(key, 2)
            n_avail = probs.shape[0]
            # choice(..., replace=False) cannot draw more samples than exist, and a
            # negative request means "take everything".
            n_out = n_avail if out < 0 else min(int(out), n_avail)
            # NOTE: the weights are not passed as p=probs, so sampling is uniform.
            indices = choice(
                key=key,
                a=n_avail,
                shape=(n_out,),
                replace=False,
            )
            return key_return, indices

        key = PRNGKey(0)

        out_sp: list[SystemParams] = []
        out_nl: list[NeighbourList] | None = [] if nl is not None else None

        if split_data:
            for n, wi in enumerate(weights):
                probs = wi / jnp.sum(wi)
                key, indices = choose(key, probs)

                out_sp.append(sp[n][indices])
                if nl is not None:
                    assert out_nl is not None
                    out_nl.append(nl[n][indices])
        else:
            probs = jnp.hstack(weights)
            probs = probs / jnp.sum(probs)
            key, indices = choose(key, probs)

            out_sp.append(sum(sp[1:], sp[0])[indices])
            if nl is not None:
                assert out_nl is not None
                out_nl.append(sum(nl[1:], nl[0])[indices])

        return (out_sp, out_nl, colvar, sti)

    def compute(
        self,
        rounds: Rounds,
        num_rounds=4,
        samples=3e3,
        plot=True,
        new_r_cut=None,
        chunk_size=None,
        split_data=False,
        name=None,
        **kwargs,
    ) -> CollectiveVariable:
        """Load data from ``rounds``, fit a new collective variable and optionally plot it.

        Args:
            rounds: source of trajectories.
            num_rounds: forwarded as ``num`` to :meth:`data_loader`.
            samples: number of samples to draw, see ``out`` in :meth:`data_loader`.
            plot: if True, write comparison plots of the old and new CV under
                ``rounds.folder / f"round_{rounds.round}"``.
            new_r_cut: neighbour-list cutoff, see :meth:`data_loader`.
            chunk_size: forwarded to the transformer fit and to the plots.
            split_data: see :meth:`data_loader`.
            name: currently unused.
            **kwargs: forwarded to :meth:`Transformer.fit`.

        Returns:
            The newly fitted collective variable.
        """
        sp_list, nl_list, cv_old, _sti = self.data_loader(
            num=num_rounds,
            out=samples,
            rounds=rounds,
            new_r_cut=new_r_cut,
            split_data=split_data,
        )

        _cvs_new, new_cv = self.transformer.fit(
            sp_list,
            nl_list,
            chunk_size=chunk_size,
            **kwargs,
        )

        if plot:
            sp = sum(sp_list[1:], sp_list[0])
            nl = sum(nl_list[1:], nl_list[0]) if nl_list is not None else None

            # Plot at most 1000 randomly chosen samples.
            ind = np.random.choice(
                a=sp.shape[0],
                size=min(1000, sp.shape[0]),
                replace=False,
            )

            CVDiscovery.plot_app(
                name=str(rounds.folder / f"round_{rounds.round}" / "cvdiscovery"),
                old_cv=cv_old,
                new_cv=new_cv,
                sps=sp[ind],
                nl=nl[ind] if nl is not None else None,
                chunk_size=chunk_size,
            )

        return new_cv

    @staticmethod
    def plot_app(
        sps: SystemParams,
        nl: NeighbourList | None,
        old_cv: CollectiveVariable,
        new_cv: CollectiveVariable,
        name,
        labels=None,
        chunk_size: int | None = None,
    ):
        """Scatter-plot the old and new collective variables against each other.

        For each direction (old->new and new->old) whose input CV has at least two
        dimensions and whose output CV has two or three dimensions, one PDF is
        written to ``{name}__old_new.pdf`` or ``{name}__new_old.pdf`` (parent
        directories are created). For every pair of input dimensions, two rows
        are drawn: a 2D projection of the input CV and the output CV, both
        coloured by the first (row 1) or second (row 2) coordinate of the pair.

        Args:
            sps: system parameters to evaluate both CVs on.
            nl: neighbour lists matching ``sps``, or None.
            old_cv: the previous collective variable (input side of old->new).
            new_cv: the newly fitted collective variable.
            name: output path prefix.
            labels: ``[old_cv_labels, new_cv_labels]`` axis labels; defaults to
                ``"cv in k"`` / ``"cv out k"`` for every dimension.
            chunk_size: forwarded to ``compute_cv``.
        """

        def color(c, per):
            """Map values to RGB through the HSV hue channel.

            Non-periodic data only uses hues up to 330 degrees, so the two ends of
            the range do not share (nearly) the same colour.
            """
            c = np.asarray(c)
            span = c.max() - c.min()
            # Guard against constant data, which would otherwise yield NaN hues.
            c2 = (c - c.min()) / span if span > 0 else np.zeros(c.shape, dtype=float)
            if not per:
                c2 = c2 * (330.0 / 360.0)
            col = np.ones((len(c), 3))
            col[:, 0] = c2
            return hsv_to_rgb(col)

        scatter_kwargs = {"s": 0.2}

        def plot3d(data, ax, colors, axis_labels, mode):
            """Draw a 3D scatter plus projections onto the three coordinate planes.

            mode 0 shades each plane with a 2D histogram of the points; any other
            mode scatters the points flattened onto each plane instead.
            """
            ax.set_xlabel(axis_labels[0])
            ax.set_ylabel(axis_labels[1])
            ax.set_zlabel(axis_labels[2])

            if mode == 0:
                ax.scatter(
                    data[:, 0],
                    data[:, 1],
                    data[:, 2],
                    **scatter_kwargs,
                    c=colors,
                    zorder=1,
                )

                for a, b, plane in [[0, 1, "z"], [0, 2, "y"], [1, 2, "x"]]:
                    hist, x_edges, y_edges = np.histogram2d(data[:, a], data[:, b])
                    # histogram2d indexes hist[x_bin, y_bin], while meshgrid ('xy')
                    # produces (n_y, n_x) grids: transpose to align them.
                    hist = hist.T
                    x_mid = (x_edges[1:] + x_edges[:-1]) / 2
                    y_mid = (y_edges[1:] + y_edges[:-1]) / 2
                    X, Y = np.meshgrid(x_mid, y_mid)

                    span = hist.max() - hist.min()
                    hist = (hist - hist.min()) / span if span > 0 else np.zeros_like(hist)

                    kw = {
                        "facecolors": plt.cm.Greys(hist),
                        "shade": True,
                        "alpha": 1.0,
                        "zorder": 0,
                    }

                    zz = np.zeros(X.shape) - 0.1
                    if plane == "z":
                        ax.plot_surface(X, Y, zz, **kw)
                    elif plane == "y":
                        ax.plot_surface(X, zz, Y, **kw)
                    else:
                        ax.plot_surface(zz, X, Y, **kw)
            else:
                zz = np.zeros(data[:, 0].shape)
                # Projection onto the x = 0 plane.
                ax.scatter(zz, data[:, 1], data[:, 2], **scatter_kwargs, zorder=1, c=colors)
                # Projection onto the y = 0 plane.
                ax.scatter(data[:, 0], zz, data[:, 2], **scatter_kwargs, zorder=1, c=colors)
                # Projection onto the z = 0 plane.
                ax.scatter(data[:, 0], data[:, 1], zz, **scatter_kwargs, zorder=1, c=colors)

            ax.view_init(elev=20, azim=45)

        cvs = [old_cv, new_cv]

        cv_data = []
        cv_data_mapped = []
        for cv in cvs:
            cvd = cv.compute_cv(sps, nl, chunk_size=chunk_size)[0].cv
            cvdm = vmap(cv.metric.map)(cvd)
            cv_data.append(np.array(cvd))
            cv_data_mapped.append(np.array(cvdm))

        if labels is None:
            # Same strings as before for the first three dimensions, extended so
            # that CVs with more dimensions do not run out of labels.
            labels = [
                [f"cv in {k + 1}" for k in range(max(3, cvs[0].n))],
                [f"cv out {k + 1}" for k in range(max(3, cvs[1].n))],
            ]

        # Only the raw CV values are plotted; cv_data_mapped (index 1) is not used.
        for mapped_idx, data in enumerate([cv_data]):
            for i, j in [(0, 1), (1, 0)]:
                indim = cvs[i].n
                if indim == 1:
                    continue

                outdim = cvs[j].n
                if outdim == 2:
                    proj = None
                elif outdim == 3:
                    proj = "3d"
                else:
                    continue

                indim_pairs = list(itertools.combinations(range(indim), r=2))

                fig = plt.figure()
                if outdim == 2:
                    spec = gridspec.GridSpec(
                        nrows=len(indim_pairs) * 2,
                        ncols=2,
                        width_ratios=[1, 1],
                        wspace=0.5,
                    )
                else:
                    spec = gridspec.GridSpec(nrows=len(indim_pairs) * 2, ncols=3)

                for pair_idx, inpair in enumerate(indim_pairs):
                    for cc in range(2):
                        # Two rows per input pair: one coloured by each coordinate.
                        row = pair_idx * 2 + cc

                        col = color(
                            data[i][:, inpair[cc]],
                            cvs[i].metric.periodicities[inpair[cc]],
                        )

                        ax_l = fig.add_subplot(spec[row, 0])
                        if outdim == 2:
                            ax_r = fig.add_subplot(spec[row, 1], projection=proj)
                        else:
                            ax_r = [
                                fig.add_subplot(spec[row, 1], projection=proj),
                                fig.add_subplot(spec[row, 2], projection=proj),
                            ]

                        ax_l.scatter(*[data[i][:, k] for k in inpair], c=col, **scatter_kwargs)
                        ax_l.set_xlabel(labels[i][inpair[0]])
                        ax_l.set_ylabel(labels[i][inpair[1]])

                        if outdim == 2:
                            ax_r.scatter(
                                *[data[j][:, k] for k in range(2)],
                                c=col,
                                **scatter_kwargs,
                            )
                            ax_r.set_xlabel(labels[j][0])
                            ax_r.set_ylabel(labels[j][1])
                        else:
                            plot3d(data=data[j], ax=ax_r[0], colors=col, axis_labels=labels[j], mode=0)
                            plot3d(data=data[j], ax=ax_r[1], colors=col, axis_labels=labels[j], mode=1)

                n = Path(
                    f"{name}_{ 'mapped' if mapped_idx==1 else ''}_{'old_new' if i == 0 else 'new_old'}.pdf",
                )
                n.parent.mkdir(parents=True, exist_ok=True)
                fig.savefig(n)
                # Release the figure; otherwise every call leaks open figures.
                plt.close(fig)