from functools import partial
import jax.debug
import jax.dtypes
import jax.lax
import jax.numpy as jnp
import jax.numpy.linalg
import jax.random
import jax.scipy
import jaxopt
import matplotlib.pyplot as plt
import scipy.special
from IMLCV.base.CV import NeighbourList
from IMLCV.base.CV import SystemParams
from IMLCV.tools.bessel_callback import ive
from IMLCV.tools.bessel_callback import spherical_jn
from jax import Array
from jax import jit
from jax import lax
from jax import vmap
from scipy.special import legendre as sp_legendre
# todo: Optimizing many-body atomic descriptors for enhanced computational performance of
# machine learning based interatomic potentials
# @partial(jit, static_argnums=(1,))
[docs]def legendre(x, n):
c = jnp.array(sp_legendre(n).c, dtype=x.dtype)
y = jnp.zeros_like(x)
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, c, length=n + 1)
return y
# @partial(jit, static_argnums=(2, 3))
[docs]def p_i(sp: SystemParams, nl: NeighbourList, p, r_cut):
if sp.batched:
return vmap(p_i, in_axes=(0, 0, None, None))(sp, nl, p, r_cut)
ps, pd = p
_, val0 = nl.apply_fun_neighbour_pair(
sp=sp,
func_single=ps,
func_double=pd,
r_cut=r_cut,
fill_value=0.0,
reduce="full",
unique=True,
split_z=True,
)
return val0
# @partial(vmap, in_axes=(0, None, None), out_axes=0)
# @partial(jit, static_argnums=(0,))
[docs]def lengendre_l(l, pj, pk):
# https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
n2 = jnp.dot(pj, pj) * jnp.dot(pk, pk)
n2s = jax.lax.cond(
n2 > 0,
lambda: n2,
lambda: jnp.ones_like(n2),
)
cos_ang = jnp.dot(pj, pk) * jax.lax.cond(
n2 > 0,
lambda: 1 / n2s ** (0.5),
lambda: jnp.zeros_like(n2),
)
return legendre(cos_ang, l)
[docs]def p_innl_soap(l_max, n_max, r_cut, sigma_a, r_delta, num=50):
# for explanation soap:
# https://aip.scitation.org/doi/suppl/10.1063/1.5111045
@jit
def phi(n, n_max, r, r_cut, sigma_a):
return jnp.exp(-((r - r_cut * n / n_max) ** 2) / (2 * sigma_a**2))
@partial(vmap, in_axes=(None, 0, None), out_axes=1)
@partial(vmap, in_axes=(0, None, None), out_axes=0)
@jit
def I_prime_ml(n, l_vec, r_ij):
def f(r):
# https://mathworld.wolfram.com/ModifiedSphericalBesselFunctionoftheFirstKind.html
def fi(r, r_ij):
return (
r ** (3 / 2)
* jnp.sqrt(sigma_a**2 * jnp.pi / (2 * r_ij))
* phi(n, n_max, r, r_cut, sigma_a)
* jnp.exp(-((r - r_ij) ** 2) / (2 * sigma_a**2))
* ive(l_vec + 0.5, r * r_ij / sigma_a**2)
)
return jnp.where(
r > 0,
jnp.where(
r_ij > 0,
fi(jnp.where(r > 0, r, 1), jnp.where(r_ij > 0, r_ij, 1)),
0.0,
),
0.0,
)
x = jnp.linspace(0, r_cut, num=num)
y = vmap(f)(x)
return jnp.apply_along_axis(lambda y: jnp.trapz(y=y, x=x), axis=0, arr=y)
@jit
def f_cut(r):
return lax.cond(
r > r_cut,
lambda: 0.0,
lambda: lax.cond(
r < r_cut - r_delta,
lambda: 1.0,
lambda: 0.5 * (1 + jnp.cos(jnp.pi * (r - r_cut + r_delta) / r_delta)),
),
)
def S_nm(ind):
def g(r):
return phi(ind[0], n_max, r, r_cut, sigma_a) * phi(ind[1], n_max, r, r_cut, sigma_a) * r**2
x = jnp.linspace(0, r_cut, num=num)
y = g(x)
return jnp.trapz(y=y, x=x)
l_vec = jnp.arange(0, l_max + 1)
n_vec = jnp.arange(0, n_max + 1)
indices = jnp.array(jnp.meshgrid(n_vec, n_vec))
S = jnp.apply_along_axis(S_nm, axis=0, arr=indices)
L, V = jnp.linalg.eigh(S)
L = L.at[L < 0].set(0)
U = jnp.diag(jnp.sqrt(L)) @ V.T
U_inv_nm = jnp.linalg.pinv(U)
l_list = list(range(l_max + 1))
@jit
def _l(p_ij, p_ik):
return jnp.array([lengendre_l(l, p_ij, p_ik) for l in l_list])
def a_nlj(r_ij):
return U_inv_nm @ I_prime_ml(n_vec, l_vec, r_ij) * f_cut(r_ij)
@jit
def _p_i_soap_2_s(p_ij, atom_index_j):
r_ij2 = jnp.dot(p_ij, p_ij)
r_ij2 = jax.lax.cond(r_ij2 == 0, lambda: jnp.ones_like(r_ij2), lambda: r_ij2)
shape = jax.eval_shape(a_nlj, r_ij2)
a_jnl = jax.lax.cond(
r_ij2 == 0,
lambda: jnp.full(shape=shape.shape, fill_value=0.0, dtype=shape.dtype),
lambda: a_nlj(jnp.sqrt(r_ij2)),
)
return a_jnl
@jit
def _p_i_soap_2_d(p_ij, atom_index_j, data_j, p_ik, atom_index_k, data_k):
a_nlj = data_j
a_nlk = data_k
b_ljk = _l(p_ij, p_ik)
return jnp.einsum(
"l,al,bl,l->abl",
4 * jnp.pi * (2 * l_vec + 1),
a_nlj,
a_nlk,
b_ljk,
)
return _p_i_soap_2_s, _p_i_soap_2_d
[docs]def p_inl_sb(l_max, n_max, r_cut):
# for explanation soap:
# https://aip.scitation.org/doi/suppl/10.1063/1.5111045
assert l_max <= n_max, "l_max should be smaller or equal to n_max"
def spherical_jn_zeros(n, m):
return vmap(
lambda x: jaxopt.GradientDescent(
lambda x: spherical_jn(n, x) ** 2,
maxiter=1000,
)
.run(x)
.params,
)(
jnp.array(
(scipy.special.jn_zeros(n + 1, m) + scipy.special.jn_zeros(n, m)) / 2,
),
)
def show_spherical_jn_zeros(n, m, ngrid=100):
"""Graphical test for the above function"""
zeros = spherical_jn_zeros(n, m)
zeros_guess = (scipy.special.jn_zeros(n + 1, m) + scipy.special.jn_zeros(n, m)) / 2
x = jnp.linspace(0, jnp.max(zeros), num=1000)
y = spherical_jn(n, x)
plt.plot(x, y)
[plt.axvline(x0, color="r") for x0 in zeros]
[plt.axvline(x0, color="b") for x0 in zeros_guess]
plt.axhline(0, color="k")
u_ln = jnp.array([spherical_jn_zeros(n, l_max + 2) for n in range(n_max + 2)]).T
def e(x):
l, n = x
return (
u_ln[l, n - 1] ** 2
* u_ln[l, n + 1] ** 2
/ ((u_ln[l, n - 1] ** 2 + u_ln[l, n] ** 2) * (u_ln[l, n + 1] ** 2 + u_ln[l, n] ** 2))
)
e_nl = jnp.apply_along_axis(
func1d=e,
axis=0,
arr=jnp.array(jnp.meshgrid(jnp.arange(l_max + 2), jnp.arange(n_max + 2))),
)
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@jit
def f_nl(n, l, r):
def f(r):
return (
u_ln[l, n + 1] / spherical_jn(l + 1, u_ln[l, n]) * spherical_jn(l, r * u_ln[l, n] / r_cut)
- u_ln[l, n] / spherical_jn(l + 1, u_ln[l, n + 1]) * spherical_jn(l, r * u_ln[l, n + 1] / r_cut)
) * (2 / (u_ln[l, n] ** 2 + u_ln[l, n + 1]) / r_cut**3) ** (0.5)
return f(r)
l_list = list(range(l_max + 1))
l_vec = jnp.array(l_list)
n_vec = jnp.arange(n_max + 1)
nm1_vec = jnp.arange(n_max)
@jit
def _l(p_ij, p_ik):
return jnp.array([lengendre_l(l, p_ij, p_ik) for l in l_list])
@jit
def g_nl(r: Array):
fnl = f_nl(n_vec, l_vec, r)
def body(args, n):
def inner(args):
d_xlm, g_xlm = args
d_xl = 1 - e_nl[n, l_vec] / d_xlm
g_xl = 1 / jnp.sqrt(d_xl) * (fnl[n, :] + jnp.sqrt(e_nl[n, l_vec] / d_xlm) * g_xlm)
return (d_xl, g_xl), g_xl
def first(args):
return (jnp.ones_like(fnl[0, :]), fnl[0, :]), fnl[0, :]
return jax.lax.cond(n == 0, first, inner, args)
state, out = lax.scan(
f=body,
init=(
fnl[0, :] * 0 + 1,
fnl[0, :],
),
xs=nm1_vec,
)
return out
@jit
def _p_i_sb_2_s(p_ij, atom_index_j):
r_ij_sq = jnp.dot(p_ij, p_ij)
r_ij_sq_safe = jax.lax.cond(
r_ij_sq == 0,
lambda: jnp.ones_like(r_ij_sq),
lambda: r_ij_sq,
)
shape = jax.eval_shape(g_nl, r_ij_sq)
a_jnl = jax.lax.cond(
r_ij_sq == 0,
lambda: jnp.full(shape=shape.shape, fill_value=0.0, dtype=shape.dtype),
lambda: g_nl(jnp.sqrt(r_ij_sq_safe)),
)
return a_jnl
@jit
def _p_i_sb_2_d(p_ij, atom_index_j, data_j, p_ik, atom_index_k, data_k):
a_jnl = data_j
a_knl = data_k
b_ljk = _l(p_ij, p_ik)
@partial(vmap, in_axes=(None, 0, None), out_axes=1)
@partial(vmap, in_axes=(0, None, None), out_axes=0)
def a_nml_l(n, l, a):
return lax.cond(
l <= n,
lambda: a[n - l, l],
lambda: jnp.zeros_like(a[0, 0]),
)
g_nml_l_j = a_nml_l(n_vec, l_vec, a_jnl)
g_nml_l_k = a_nml_l(n_vec, l_vec, a_knl)
out = jnp.einsum(
"l,nl,nl,l -> nl",
(2 * l_vec + 1) / (4 * jnp.pi),
g_nml_l_j,
g_nml_l_k,
b_ljk,
)
return out
return _p_i_sb_2_s, _p_i_sb_2_d