import jax.lax
import jax.numpy as jnp
import numpy as onp
import scipy.special
from jax import custom_jvp
from jax import pure_callback
from jax.custom_batching import custom_vmap
# see https://github.com/google/jax/issues/11002
# see https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
def generate_bessel(function, type, sign=1, exp_scaled=False):
    """Wrap a SciPy Bessel-type function as a differentiable JAX function.

    The SciPy routine is evaluated on the host through ``jax.pure_callback``.
    Batching is provided by a ``custom_vmap`` rule and the derivative with
    respect to the argument by a ``custom_jvp`` rule built from the DLMF
    recurrence relations.

    Args:
        function: Host callable ``function(v, x)`` (e.g. ``scipy.special.jv``)
            that accepts NumPy arrays and returns a NumPy array/scalar.
        type: Selects the derivative formula used by the JVP rule:
            ``0`` - DLMF 10.6.1 (Jv, Yv, Hankel functions),
            ``1`` - DLMF 10.29.1 (Iv, Kv),
            ``2`` - DLMF 10.51.2 (spherical Bessel functions).
        sign: Factor applied to the ``type == 1`` recurrence (``+1`` gives the
            Iv form, ``-1`` the Kv form). With ``exp_scaled`` it also selects
            which scaling correction is added.
        exp_scaled: If true, add the chain-rule term coming from the
            exponential scaling factor (``+f`` for ``sign == -1``,
            ``-sign(x) * f`` for ``sign == 1``).

    Returns:
        A ``custom_jvp`` function ``cv(v, z)`` that takes scalar (0-d) inputs;
        use ``jax.vmap`` to evaluate it on batches.

    Raises:
        ValueError: When the returned function is differentiated and ``type``
            is not 0, 1 or 2.

    Notes:
        * Only the tangent of ``z`` is propagated; the tangent of the order
          ``v`` is ignored, so derivatives with respect to ``v`` are zero.
        * The callback result is cast to the dtype of ``z``.
          NOTE(review): for wrapped functions whose SciPy result is complex
          for real input (presumably ``hankel1``/``hankel2``) this cast would
          drop the imaginary part - confirm whether that is intended.
    """

    def _function(v, x):
        # Host-side evaluation; inputs arrive as NumPy-compatible buffers.
        v, x = onp.asarray(v), onp.asarray(x)
        return function(v, x).astype(x.dtype)

    @custom_vmap
    def cv_inner(v, z):
        res_dtype_shape = jax.ShapeDtypeStruct(
            # Both call sites pass operands of equal shape; broadcasting keeps
            # the declared shape right even if they ever differ.
            shape=jnp.broadcast_shapes(v.shape, z.shape),
            dtype=z.dtype,
        )
        # No ``vectorized=``/``vmap_method=`` argument: batching never reaches
        # the callback itself because the custom_vmap rule below hands it
        # already-batched arrays. The ``vectorized`` keyword is deprecated in
        # recent JAX releases, so omitting it keeps old and new versions working.
        return pure_callback(_function, res_dtype_shape, v, z)

    @cv_inner.def_vmap
    def _function_vmap(axis_size, in_batched, v, x):
        v_batched, x_batched = in_batched
        # Give every unbatched operand a leading batch axis so that both
        # operands have the same shape when they reach the host callback.
        if not v_batched:
            v = jax.lax.broadcast(v, [axis_size])
        if not x_batched:
            x = jax.lax.broadcast(x, [axis_size])
        return cv_inner(v, x), True

    @custom_jvp
    def cv(v, z):
        v, z = jnp.asarray(v), jnp.asarray(z)
        # Promote the input to inexact (float/complex).
        # Note that jnp.result_type() accounts for the enable_x64 flag.
        z = z.astype(jnp.result_type(float, z.dtype))
        assert v.ndim == 0 and z.ndim == 0, "batch with vmap"
        return cv_inner(v, z)

    @cv.defjvp
    def cv_jvp(primals, tangents):
        v, x = primals
        v, x = jnp.asarray(v), jnp.asarray(x)
        # The tangent of the order ``v`` is deliberately not used.
        _, dx = tangents
        primal_out = cv(v, x)

        # Stand-in order for the generic branch so that it never evaluates
        # the recurrence at order -1 when v == 0 (that case has its own
        # branch below).
        v_safe = jax.lax.cond(v == 0, lambda: jnp.ones_like(v), lambda: v)

        if type == 0:
            # Functions Jv, Yv, Hv_1, Hv_2.
            # https://dlmf.nist.gov/10.6 formula 10.6.1
            tangents_out = jax.lax.cond(
                v == 0,
                lambda: -cv(v + 1, x),
                lambda: 0.5 * (cv(v_safe - 1, x) - cv(v_safe + 1, x)),
            )
        elif type == 1:
            # Functions Kv and Iv.
            # https://dlmf.nist.gov/10.29 formula 10.29.1
            tangents_out = jax.lax.cond(
                v == 0,
                lambda: sign * cv(v + 1, x),
                lambda: 0.5 * (sign * cv(v_safe - 1, x) + sign * cv(v_safe + 1, x)),
            )
        elif type == 2:
            # Spherical Bessel functions.
            # https://dlmf.nist.gov/10.51 formula 10.51.2
            tangents_out = jax.lax.cond(
                v == 0,
                lambda: -cv(v + 1, x),
                lambda: (v * cv(v_safe - 1, x) - (v_safe + 1) * cv(v_safe + 1, x))
                / (2 * v_safe + 1),
            )
        else:
            raise ValueError(
                f"unsupported Bessel derivative type {type!r}; expected 0, 1 or 2"
            )

        # Chain rule for the exponential scaling factor.
        if exp_scaled:
            if sign == -1:
                tangents_out = tangents_out + primal_out
            elif sign == 1:
                tangents_out = tangents_out - jnp.sign(x) * primal_out

        return primal_out, tangents_out * dx

    return cv
# Bessel functions of the first/second kind and Hankel functions; the JVP
# uses the type-0 recurrence (DLMF 10.6.1).
jv = generate_bessel(scipy.special.jv, type=0)
yv = generate_bessel(scipy.special.yv, type=0)
hankel1 = generate_bessel(scipy.special.hankel1, type=0)
hankel2 = generate_bessel(scipy.special.hankel2, type=0)

# Modified Bessel functions; the JVP uses the type-1 recurrence
# (DLMF 10.29.1) with sign=-1 for Kv and sign=+1 for Iv.
kv = generate_bessel(scipy.special.kv, sign=-1, type=1)
iv = generate_bessel(scipy.special.iv, sign=+1, type=1)

# Spherical Bessel functions; the JVP uses the type-2 recurrence
# (DLMF 10.51.2).
spherical_jn = generate_bessel(scipy.special.spherical_jn, type=2)
spherical_yn = generate_bessel(scipy.special.spherical_yn, type=2)

# Exponentially scaled modified Bessel functions; exp_scaled=True adds the
# chain-rule term of the scaling factor to the type-1 derivative.
ive = generate_bessel(scipy.special.ive, sign=+1, type=1, exp_scaled=True)
kve = generate_bessel(scipy.special.kve, sign=-1, type=1, exp_scaled=True)