IMLCV.implementations.CvDiscovery
Classes
Encoder | Base class for all neural network modules.
Decoder | Base class for all neural network modules.
VAE | Base class for all neural network modules.
TranformerAutoEncoder | Base class for dataclasses that should act like a JAX pytree node.
TransoformerLDA | Base class for dataclasses that should act like a JAX pytree node.
TransformerMAF | Base class for dataclasses that should act like a JAX pytree node.
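Functions
_LDA_trans(cv, nl, shmap, shmap_kwargs, alpha, outdim, solver)
_LDA_rescale(cv, nl, shmap, shmap_kwargs, mean)
_scale_trans(cv, nl, shmap, shmap_kwargs, alpha, scale_factor)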
Module Contents
- class IMLCV.implementations.CvDiscovery.Encoder
Bases: flax.linen.Module
Base class for all neural network modules.
Layers and models should subclass this class.
All Flax Modules are Python 3.7 dataclasses. Since dataclasses take over __init__, you should instead override setup(), which is automatically called to initialize the module.
Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the setup() method.
You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice because it allows you to use module instances as if they are functions:
>>> from flax import linen as nn
>>> from typing import Tuple
>>> class Module(nn.Module):
...   features: Tuple[int, ...] = (16, 4)
...   def setup(self):
...     self.dense1 = nn.Dense(self.features[0])
...     self.dense2 = nn.Dense(self.features[1])
...   def __call__(self, x):
...     return self.dense2(nn.relu(self.dense1(x)))
Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the compact() wrapper.
- __call__(x)
- class IMLCV.implementations.CvDiscovery.Decoder
Bases: flax.linen.Module
- __call__(z)
- class IMLCV.implementations.CvDiscovery.VAE
Bases: flax.linen.Module
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # causes setup() to be called once.
3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- __call__(x, z_rng)
- encode(x)
- classmethod reparameterize(rng, mean, logvar)
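The reparameterize(rng, mean, logvar) signature matches the usual VAE reparameterization trick: draw ε ~ N(0, I) and return z = mean + exp(logvar / 2) · ε, which keeps z differentiable with respect to mean and logvar. A minimal sketch as a plain JAX function; in IMLCV it is a classmethod, and its exact body is not shown on this page.

```python
import jax
import jax.numpy as jnp


def reparameterize(rng, mean, logvar):
    """Sample z ~ N(mean, exp(logvar)) in a differentiable way."""
    std = jnp.exp(0.5 * logvar)                 # sigma = exp(logvar / 2)
    eps = jax.random.normal(rng, logvar.shape)  # eps ~ N(0, I)
    return mean + eps * std                     # z = mu + sigma * eps


key = jax.random.PRNGKey(0)
mean = jnp.array([[1.0, -2.0]])
logvar = jnp.full((1, 2), -50.0)  # tiny variance, so z is essentially the mean
z = reparameterize(key, mean, logvar)

# gradient of sum(z) with respect to the mean is exactly 1 for every entry
grad_mean = jax.grad(lambda m: reparameterize(key, m, logvar).sum())(mean)
```

Because the randomness enters only through eps, gradients flow through the mean and the log-variance, which is what allows the encoder to be trained by backpropagation.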
- class IMLCV.implementations.CvDiscovery.TranformerAutoEncoder(*args, **kwargs)
Bases: IMLCV.base.CVDiscovery.Transformer
Base class for dataclasses that should act like a JAX pytree node.
- _fit(cv: list[IMLCV.base.CV.CV], cv_t: list[IMLCV.base.CV.CV], w: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000)
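_fit receives a list of weights w alongside the data. Assuming these act as per-sample reweighting factors and that the autoencoder is the VAE above, a weighted objective of squared reconstruction error plus the Gaussian KL term can be sketched as follows. The function names `kl_divergence` and `weighted_vae_loss` and the squared-error reconstruction term are hypothetical; IMLCV's actual loss is not documented here.

```python
import jax.numpy as jnp


def kl_divergence(mean, logvar):
    # KL( N(mean, exp(logvar)) || N(0, I) ), summed over latent dimensions
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


def weighted_vae_loss(x, x_rec, mean, logvar, w):
    # per-sample loss: squared reconstruction error + KL regulariser
    rec = jnp.sum((x - x_rec) ** 2, axis=-1)
    per_sample = rec + kl_divergence(mean, logvar)
    # weighted average, normalised by the total weight
    return jnp.sum(w * per_sample) / jnp.sum(w)


x = jnp.zeros((2, 3))
x_rec = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # only sample 1 is off
mean = jnp.zeros((2, 2))
logvar = jnp.zeros((2, 2))  # posterior equals the prior, so KL = 0
w = jnp.array([1.0, 3.0])
loss = weighted_vae_loss(x, x_rec, mean, logvar, w)  # (1*0 + 3*1) / 4
```

Normalising by the total weight keeps the loss scale independent of how many samples or how large the weights are.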
- IMLCV.implementations.CvDiscovery._LDA_trans(cv: IMLCV.base.CV.CV, nl: IMLCV.base.CV.NeighbourList | None, shmap, shmap_kwargs, alpha, outdim, solver)
- IMLCV.implementations.CvDiscovery._LDA_rescale(cv: IMLCV.base.CV.CV, nl: IMLCV.base.CV.NeighbourList | None, shmap, shmap_kwargs, mean)
- IMLCV.implementations.CvDiscovery._scale_trans(cv: IMLCV.base.CV.CV, nl: IMLCV.base.CV.NeighbourList | None, shmap, shmap_kwargs, alpha: jax.Array, scale_factor: jax.Array)
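These three helpers carry no docstrings. Judging only from their parameter names, _LDA_trans appears to apply a linear projection with alpha truncated to outdim components, _LDA_rescale to subtract a mean, and _scale_trans to apply a scale_factor. The sketch below illustrates that project, center, scale pipeline on plain arrays. The names `project`, `center` and `scale` are hypothetical; the `nl`, `shmap`, `shmap_kwargs` and `solver` arguments are ignored, and because the role of `alpha` in _scale_trans is not documented here, it is left out.

```python
import jax.numpy as jnp


def project(x, alpha, outdim):
    # hypothetical analogue of _LDA_trans: keep the leading `outdim` directions
    return x @ alpha[:, :outdim]


def center(y, mean):
    # hypothetical analogue of _LDA_rescale: shift the reference mean to 0
    return y - mean


def scale(y, scale_factor):
    # hypothetical analogue of _scale_trans (its `alpha` argument is omitted)
    return y * scale_factor


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
alpha = jnp.eye(2)  # columns are the projection directions
y = scale(center(project(x, alpha, outdim=1), jnp.array([2.0])), jnp.array([0.5]))
```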
- class IMLCV.implementations.CvDiscovery.TransoformerLDA(*args, **kwargs)
Bases: IMLCV.base.CVDiscovery.Transformer
Base class for dataclasses that should act like a JAX pytree node.
- kernel = False
- optimizer = None
- harmonic = True
- _fit(cv_list: list[IMLCV.base.CV.CV], cv_t_list: list[IMLCV.base.CV.CV], w: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, chunk_size=None, verbose=True, macro_chunk=1000)
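The harmonic = True attribute suggests harmonic LDA (HLDA). For two metastable states A and B with means μ_A, μ_B and covariances Σ_A, Σ_B, Fisher LDA uses the within-class scatter S_w = Σ_A + Σ_B, while HLDA uses the harmonic combination S_w⁻¹ = Σ_A⁻¹ + Σ_B⁻¹; in both cases the CV direction is w ∝ S_w⁻¹(μ_A − μ_B). A minimal two-state sketch, assuming that this is what the `harmonic` flag toggles; the function name `lda_direction` is hypothetical, and the `kernel`, `optimizer` and `solver` options are not modelled.

```python
import jax.numpy as jnp


def lda_direction(xa, xb, harmonic=True):
    """Two-state (H)LDA direction from samples xa, xb of shape (n, d)."""
    mu_a, mu_b = xa.mean(axis=0), xb.mean(axis=0)
    cov_a = jnp.cov(xa, rowvar=False)
    cov_b = jnp.cov(xb, rowvar=False)
    if harmonic:
        # HLDA: inverse within-class scatter is the sum of inverse covariances
        sw_inv = jnp.linalg.inv(cov_a) + jnp.linalg.inv(cov_b)
    else:
        # Fisher LDA: within-class scatter is the sum of covariances
        sw_inv = jnp.linalg.inv(cov_a + cov_b)
    w = sw_inv @ (mu_a - mu_b)
    return w / jnp.linalg.norm(w)  # unit-norm CV direction


offsets = jnp.array([[0.1, 0.0], [-0.1, 0.0], [0.0, 0.1], [0.0, -0.1]])
xa = offsets + jnp.array([1.0, 0.0])   # state A around (1, 0)
xb = offsets + jnp.array([-1.0, 0.0])  # state B around (-1, 0)
w = lda_direction(xa, xb)
cv_a = xa @ w  # projecting samples onto w gives the scalar CV
```

With isotropic, equal covariances both variants recover the axis separating the two states; they differ when one state is much broader than the other, where the harmonic form is dominated by the narrower state.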
- class IMLCV.implementations.CvDiscovery.TransformerMAF(*args, **kwargs)
Bases: IMLCV.base.CVDiscovery.Transformer
Base class for dataclasses that should act like a JAX pytree node.
- trans: IMLCV.base.CV.CvTrans | None = None
- _fit(x: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams], x_t: list[IMLCV.base.CV.CV] | list[IMLCV.base.CV.SystemParams] | None, w: list[jax.Array], w_t: list[jax.Array], dlo: IMLCV.base.rounds.DataLoaderOutput, macro_chunk=1000, chunk_size=None, verbose=True)
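Assuming "MAF" here refers to a masked autoregressive flow, its core building block is an affine autoregressive layer u_i = (x_i − μ_i(x_<i)) · exp(−α_i(x_<i)). Because each output depends only on earlier inputs, the Jacobian ∂u/∂x is lower-triangular and log|det ∂u/∂x| = −Σ_i α_i. The sketch below uses a single masked linear map for μ and α instead of the stacked MADE networks of a full MAF; all names are hypothetical and nothing here reflects how IMLCV's `trans` or `_fit` is implemented.

```python
import jax
import jax.numpy as jnp


def maf_forward(x, w_mu, w_alpha, b_mu, b_alpha):
    """One affine autoregressive layer mapping data x (shape (d,)) to noise u."""
    d = x.shape[-1]
    # strictly lower-triangular mask: output i depends only on inputs j < i
    mask = jnp.tril(jnp.ones((d, d)), k=-1)
    mu = x @ (w_mu * mask).T + b_mu
    alpha = x @ (w_alpha * mask).T + b_alpha
    u = (x - mu) * jnp.exp(-alpha)
    # du/dx is lower-triangular with diagonal exp(-alpha)
    log_det = -jnp.sum(alpha, axis=-1)
    return u, log_det


def log_prob(x, *params):
    # change of variables: log p(x) = log N(u; 0, I) + log |det du/dx|
    u, log_det = maf_forward(x, *params)
    d = x.shape[-1]
    log_base = -0.5 * jnp.sum(u**2, axis=-1) - 0.5 * d * jnp.log(2 * jnp.pi)
    return log_base + log_det


keys = jax.random.split(jax.random.PRNGKey(0), 5)
d = 3
params = (
    0.5 * jax.random.normal(keys[0], (d, d)),
    0.5 * jax.random.normal(keys[1], (d, d)),
    0.1 * jax.random.normal(keys[2], (d,)),
    0.1 * jax.random.normal(keys[3], (d,)),
)
x = jax.random.normal(keys[4], (d,))
u, log_det = maf_forward(x, *params)
jac = jax.jacfwd(lambda v: maf_forward(v, *params)[0])(x)
```

The triangular Jacobian is what makes density evaluation cheap: the log-determinant is a sum over α instead of an O(d³) determinant.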