Source code for IMLCV.tools._rbfinterp_pythran

import jax
import jax.numpy as jnp
from IMLCV.base.CV import CV
from IMLCV.base.CV import CvMetric


def linear(r):
    """Linear radial basis function: return the negated distance ``-r``."""
    negated = -r
    return negated
def thin_plate_spline(r):
    """Thin-plate-spline radial basis function ``r**2 * log(r)``.

    The value at ``r == 0`` is defined as ``0`` (the limit of the expression).

    The function is written with the "double ``where``" pattern so that it is
    safe to differentiate and to ``jax.vmap``: the logarithm is never evaluated
    at zero, which keeps both the value and the gradient finite at the origin
    (the gradient there is ``0``, the limit of ``2 * r * log(r) + r``).

    Parameters
    ----------
    r : float or array
        Radial distance(s). Scalars and arrays are both accepted.

    Returns
    -------
    Array with the same shape as ``r``.
    """
    at_origin = r == 0
    # Substitute a harmless argument (log(1) == 0) where r is zero so that the
    # unselected branch cannot inject NaN/inf into the value or the gradient.
    safe_r = jnp.where(at_origin, 1.0, r)
    return jnp.where(at_origin, 0.0, safe_r**2 * jnp.log(safe_r))
def cubic(r):
    """Cubic radial basis function: return ``r`` raised to the third power."""
    cubed = r**3
    return cubed
def quintic(r):
    """Quintic radial basis function: return the negated fifth power of ``r``."""
    fifth_power = r**5
    return -fifth_power
def multiquadric(r):
    """Multiquadric radial basis function: return ``-sqrt(r**2 + 1)``."""
    shifted_square = r**2 + 1
    return -jnp.sqrt(shifted_square)
def inverse_multiquadric(r):
    """Inverse multiquadric radial basis function: ``1 / sqrt(r**2 + 1)``."""
    root = jnp.sqrt(r**2 + 1)
    return 1 / root
def inverse_quadratic(r):
    """Inverse quadratic radial basis function: return ``1 / (r**2 + 1)``."""
    denominator = r**2 + 1
    return 1 / denominator
def gaussian(r):
    """Gaussian radial basis function: return ``exp(-r**2)``."""
    exponent = -(r**2)
    return jnp.exp(exponent)
# Lookup table from the kernel name (as passed in the ``kernel`` argument of
# the system builders in this module) to the radial basis function itself.
NAME_TO_FUNC = dict(
    linear=linear,
    thin_plate_spline=thin_plate_spline,
    cubic=cubic,
    quintic=quintic,
    multiquadric=multiquadric,
    inverse_multiquadric=inverse_multiquadric,
    inverse_quadratic=inverse_quadratic,
    gaussian=gaussian,
)
def scale(val, metric):
    """Rescale `val` by the extent of the metric's bounding box.

    Each component is divided by the width of the corresponding row of
    ``metric.bounding_box`` (column 1 minus column 0) and multiplied by 2.
    """
    bounds = metric.bounding_box
    width = bounds[:, 1] - bounds[:, 0]
    return val / width * 2
def cv_norm(x: CV, y: CV, metric: CvMetric, eps):
    """Return the scaled distance between the collective variables `x` and `y`.

    The difference ``metric.min_cv(x.cv) - metric.min_cv(y.cv)`` is rescaled
    with `scale`, multiplied by `eps`, and reduced with ``jnp.linalg.norm``.
    """
    displacement = metric.min_cv(x.cv) - metric.min_cv(y.cv)
    scaled = scale(displacement, metric=metric) * eps
    return jnp.linalg.norm(scaled)
def cv_vals(x: CV, powers, metric: CvMetric):
    """Return the scaled CV values of `x` raised element-wise to `powers`."""
    base = scale(metric.min_cv(x.cv), metric=metric)
    return base**powers
def kernel_vector(x: CV, y: CV, metric: CvMetric, epsilon, kernel_func):
    """Evaluate RBFs, with centers at `y`, at the point `x`.

    `kernel_func` is applied to the `cv_norm` distance between `x` and every
    center; the centers are mapped over with ``jax.vmap``.
    """

    def rbf_at_center(center):
        distance = cv_norm(x, center, metric, epsilon)
        return kernel_func(distance)

    return jax.vmap(rbf_at_center)(y)
def polynomial_vector(x: CV, powers, metric: CvMetric):
    """Evaluate monomials, with exponents from `powers`, at the point `x`.

    Each row of `powers` yields one monomial: the product over the components
    of `cv_vals` evaluated with that row of exponents.
    """

    def monomial(exponents):
        return jnp.prod(cv_vals(x, exponents, metric=metric))

    return jax.vmap(monomial)(powers)
def kernel_matrix(x: CV, metric: CvMetric, eps, kernel_func):
    """Evaluate RBFs, with centers at `x`, at `x`.

    Entry ``[i, j]`` of the result is `kernel_func` applied to the `cv_norm`
    distance between point ``i`` and point ``j`` of `x`.
    """

    def pair_distance(a, b):
        return cv_norm(a, b, metric, eps)

    # Inner map: first argument along output axis 0.
    over_first = jax.vmap(pair_distance, in_axes=(0, None), out_axes=0)
    # Outer map: second argument along output axis 1.
    over_both = jax.vmap(over_first, in_axes=(None, 0), out_axes=1)
    distances = over_both(x, x)

    # The kernel is applied entry by entry (mapped over both matrix axes).
    apply_kernel = jax.vmap(jax.vmap(kernel_func))
    return apply_kernel(distances)
def polynomial_matrix(x: CV, metric: CvMetric, powers):
    """Evaluate monomials, with exponents from `powers`, at `x`.

    Entry ``[i, k]`` of the result is the product of the components of
    `cv_vals` for point ``i`` of `x` with the exponents in row ``k`` of
    `powers`.
    """

    def monomial(point, exponents):
        return jnp.prod(cv_vals(point, exponents, metric=metric))

    # Inner map: points along output axis 0.
    over_points = jax.vmap(monomial, in_axes=(0, None), out_axes=0)
    # Outer map: exponent rows along output axis 1.
    over_points_and_powers = jax.vmap(over_points, in_axes=(None, 0), out_axes=1)
    return over_points_and_powers(x, powers)
# # pythran export _kernel_matrix(float[:, :], str) # def _kernel_matrix(x: CV, metric: Metric, eps, kernel): # """Return RBFs, with centers at `x`, evaluated at `x`.""" # assert isinstance(x, CV) # out = jnp.empty((x.shape[0], x.shape[0]), dtype=float) # kernel_func = NAME_TO_FUNC[kernel] # out = kernel_matrix(x, metric, eps, kernel_func) # return out # pythran export _polynomial_matrix(float[:, :], int[:, :])
def _polynomial_matrix(x: CV, powers, metric):
    """Return monomials, with exponents from `powers`, evaluated at `x`.

    Thin wrapper around `polynomial_matrix` that first checks `x` is a `CV`.
    """
    assert isinstance(x, CV)
    return polynomial_matrix(x=x, metric=metric, powers=powers)
def _build_system(y: CV, metric: CvMetric, d, smoothing, kernel, epsilon, powers):
    """Build the system used to solve for the RBF interpolant coefficients.

    Parameters
    ----------
    y : CV
        Data point coordinates (P points).
    metric : CvMetric
        Metric used to compute distances and to rescale the CV values.
    d : (P, S) float ndarray
        Data values at `y`.
    smoothing : (P,) float ndarray
        Smoothing parameter for each data point; added to the diagonal of the
        kernel matrix.
    kernel : str
        Name of the RBF (a key of ``NAME_TO_FUNC``).
    epsilon : float
        Shape parameter.
    powers : (R, N) int ndarray
        The exponents for each monomial in the polynomial.

    Returns
    -------
    lhs : (P + R, P + R) float ndarray
        Left-hand side matrix.
    rhs : (P + R, S) float ndarray
        Right-hand side matrix.
    """
    n_values = d.shape[1]
    n_monomials = powers.shape[0]
    rbf = NAME_TO_FUNC[kernel]

    kernel_block = kernel_matrix(y, metric, epsilon, rbf) + jnp.diag(smoothing)
    poly_block = polynomial_matrix(y, metric=metric, powers=powers)

    # Saddle-point layout: [[K, P], [P^T, 0]].
    n_poly_cols = poly_block.shape[1]
    zero_block = jnp.zeros((n_poly_cols, n_poly_cols))
    lhs = jnp.block([[kernel_block, poly_block], [poly_block.T, zero_block]])

    # The data values are padded with zero rows for the polynomial constraints.
    padding = jnp.zeros((n_monomials, n_values))
    rhs = jnp.concatenate((d, padding), axis=0)

    return lhs, rhs
# pythran export _build_evaluation_coefficients(float[:, :], # float[:, :], # str, # float, # int[:, :], # float[:], # float[:])
def _build_evaluation_coefficients(
    x: CV,
    y: CV,
    metric: CvMetric,
    kernel,
    epsilon,
    powers,
):
    """Construct the coefficients needed to evaluate the RBF.

    Parameters
    ----------
    x : CV
        Evaluation point coordinates (Q points, mapped over with ``jax.vmap``).
    y : CV
        Data point coordinates (P points).
    metric : CvMetric
        Metric used to compute distances and to rescale the CV values.
    kernel : str
        Name of the RBF (a key of ``NAME_TO_FUNC``).
    epsilon : float
        Shape parameter.
    powers : (R, N) int ndarray
        The exponents for each monomial in the polynomial.

    Returns
    -------
    (Q, P + R) float ndarray
        Kernel values followed by monomial values for every evaluation point.
    """
    rbf = NAME_TO_FUNC[kernel]

    kernel_part = jax.vmap(
        lambda point: kernel_vector(point, y, metric, epsilon, rbf),
    )(x)
    polynomial_part = jax.vmap(
        lambda point: polynomial_vector(point, powers, metric=metric),
    )(x)

    return jnp.hstack([kernel_part, polynomial_part])