# Source code for IMLCV.implementations.tensorflow.CV

import os
import tempfile
from importlib import import_module

import numpy as np
import tensorflow as tfl
from IMLCV.base.CV import CV
from IMLCV.base.CV import CvFunBase
from IMLCV.base.CV import NeighbourList
from jax.experimental.jax2tf import call_tf
from keras.api._v2 import keras as KerasAPI

# Resolve ``tensorflow.keras`` dynamically when this module is imported.
# The ``KerasAPI`` annotation (from ``keras.api._v2``) is presumably there so
# static type checkers can see the Keras API surface — confirm.
keras: KerasAPI = import_module(name="tensorflow.keras")
class PeriodicLayer(keras.layers.Layer):
    """Keras layer that wraps coordinates back into a box along periodic dimensions.

    Args:
        bbox: per-dimension box bounds. ``bbox[:, 0]`` is used as the lower
            bound and ``bbox[:, 1]`` as the upper bound, so it is expected to
            be 2-D with two columns (presumably ``(n_dims, 2)`` — confirm).
        periodicity: per-dimension mask. Only dimensions where it is true are
            wrapped; the others pass through unchanged. It is coerced to
            ``bool`` so that 0/1 masks are accepted by ``tf.where``.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, bbox, periodicity, **kwargs):
        super().__init__(**kwargs)
        # The box is a fixed geometric constant, not a learnable parameter.
        # A default (trainable) tf.Variable would be tracked by Keras as a
        # trainable weight. tf.math.mod has a gradient with respect to its
        # divisor, so the optimizer would silently move the box.
        self.bbox = tfl.Variable(np.array(bbox), trainable=False)
        self.periodicity = np.array(periodicity, dtype=bool)

    def call(self, inputs):
        """Map ``inputs`` into ``[lower, upper)`` along the periodic dimensions only."""
        lower = self.bbox[:, 0]
        width = self.bbox[:, 1] - lower
        wrapped = tfl.math.mod(inputs - lower, width) + lower
        return tfl.where(self.periodicity, wrapped, inputs)

    def metric(self, r):
        """Return the Euclidean norm along axis 1 of the displacement ``r``.

        On periodic dimensions the displacement is first mapped to its
        minimum image, i.e. into ``(-width/2, width/2]``. Non-periodic
        dimensions are left untouched, consistent with :meth:`call`.
        """
        width = self.bbox[:, 1] - self.bbox[:, 0]
        wrapped = tfl.math.mod(r, width)
        wrapped = tfl.where(wrapped > width / 2, wrapped - width, wrapped)
        r = tfl.where(self.periodicity, wrapped, r)
        return tfl.norm(r, axis=1)

    def get_config(self):
        """Return the layer config, including ``bbox`` and ``periodicity`` for deserialization."""
        config = super().get_config().copy()
        config.update(
            {
                "bbox": np.array(self.bbox),
                "periodicity": self.periodicity,
            },
        )
        return config
class KerasFunBase(CvFunBase):
    """Collective-variable function backed by a Keras encoder evaluated via ``jax2tf.call_tf``.

    A freshly constructed instance reaches the encoder through
    ``reducer.encoder``. Pickling serializes only the encoder model, not the
    reducer. An unpickled instance therefore has ``reducer = None`` and holds
    the restored model in ``self.encoder``. :meth:`_get_encoder` resolves both
    cases.
    """

    def __init__(self, reducer) -> None:
        # NOTE(review): ``reducer`` is assumed to expose the Keras model as
        # ``reducer.encoder`` (that is how it is used below) — confirm with callers.
        self.reducer = reducer

    def _get_encoder(self):
        """Return the Keras encoder, from the reducer if present, else from unpickled state."""
        reducer = getattr(self, "reducer", None)
        if reducer is not None:
            return reducer.encoder
        return self.encoder

    def _calc(
        self,
        x: CV,
        nl: NeighbourList,
        reverse=False,
        conditioners: list[CV] | None = None,
    ) -> CV:
        """Evaluate the encoder on ``x.cv`` and wrap the result in a ``CV``.

        An unbatched ``x`` is treated as a single sample. It is reshaped to
        ``(1, -1)`` before the call, and the output is flattened back to 1-D.
        ``nl`` is not used. Reverse evaluation and conditioners are not supported.
        """
        assert conditioners is None, "conditioners are not supported"
        assert not reverse, "reverse evaluation is not supported"

        batched = x.batched
        y = x.cv if batched else x.cv.reshape((1, -1))

        out = call_tf(self._get_encoder().call, has_side_effects=False)(y)

        if not batched:
            out = out.reshape((-1,))
        return CV(cv=out, _combine_dims=x._combine_dims, _stack_dims=x._stack_dims)

    def __getstate__(self):
        """Serialize the encoder as the raw bytes of a Keras HDF5 model file."""
        # See https://stackoverflow.com/questions/48295661/how-to-pickle-keras-model
        # Use a temporary *directory* rather than NamedTemporaryFile: the latter
        # cannot be reopened by name on Windows while it is still held open.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "model.hdf5")
            keras.models.save_model(self._get_encoder(), path, overwrite=True)
            with open(path, "rb") as fh:
                model_str = fh.read()
        return {"model_str": model_str}

    def __setstate__(self, state):
        """Rebuild the encoder from the bytes produced by :meth:`__getstate__`."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "model.hdf5")
            with open(path, "wb") as fh:
                fh.write(state["model_str"])
            # PeriodicLayer is a custom layer, so Keras must be told how to
            # deserialize it.
            custom_objects = {"PeriodicLayer": PeriodicLayer}
            with keras.utils.custom_object_scope(custom_objects):
                model = keras.models.load_model(path)
        # The reducer is not serialized; _get_encoder falls back to self.encoder.
        self.reducer = None
        self.encoder = model