Source code for IMLCV.implementations.tensorflow.CvDiscovery

import jax.numpy as jnp
import umap
from IMLCV.base.CV import CV
from IMLCV.base.CV import CvTrans
from IMLCV.base.CV import NeighbourList
from IMLCV.base.CVDiscovery import Transformer
from IMLCV.implementations.CV import un_atomize
from IMLCV.implementations.tensorflow.CV import KerasFunBase
from IMLCV.implementations.tensorflow.CV import PeriodicLayer


[docs]class TranformerUMAP(Transformer):
[docs] def _fit( self, x: list[CV], nl: list[NeighbourList] | None = None, decoder=False, nunits=256, nlayers=3, parametric=True, metric=None, **kwargs, ): x = CV.stack(*x) x = un_atomize.compute_cv_trans(x, None)[0] dims = x.shape[1:] kwargs["n_components"] = self.outdim if metric is None: pl = PeriodicLayer(bbox=self.bounding_box, periodicity=self.periodicity) kwargs["output_metric"] = pl.metric else: kwargs["output_metric"] = metric if parametric: from tensorflow import keras act = keras.activations.tanh layers = [ keras.layers.InputLayer(input_shape=dims), *[ keras.layers.Dense( units=nunits, activation=act, ) for _ in range(nlayers) ], keras.layers.Dense(units=self.outdim), ] if metric is None: layers.append(pl) encoder = keras.Sequential(layers) kwargs["encoder"] = encoder if decoder: decoder = keras.Sequential( [ keras.layers.InputLayer(input_shape=(self.outdim)), *[keras.layers.Dense(units=nunits, activation=act) for _ in range(nlayers)], keras.layers.Dense(units=jnp.prod(jnp.array(x.shape[1:]))), ], ) kwargs["decoder"] = decoder reducer = umap.parametric_umap.ParametricUMAP(**kwargs) else: reducer = umap.UMAP(**kwargs) reducer.fit(x.cv) assert parametric f = CvTrans(trans=[KerasFunBase(reducer)]) return f.compute_cv_trans(x)[0], un_atomize * f