IMLCV.base.datastructures

Attributes

TNode

field

P

P2

T

Classes

MyPyTreeNode

Base class for dataclasses that should act like a JAX pytree node.

Functions

jit_decorator(f[, static_argnums, static_argnames])

vmap_decorator(f[, in_axes, out_axes])

custom_jvp_decorator(f[, nondiff_argnums])

Partial_decorator(f, *partial_args, **partial_kwargs)

Module Contents

IMLCV.base.datastructures.TNode
class IMLCV.base.datastructures.MyPyTreeNode(*args, **kwargs)

Base class for dataclasses that should act like a JAX pytree node.

classmethod __init_subclass__(**kwargs)
abstract replace(**overrides) → TNode
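The entry above only states that subclasses of MyPyTreeNode act like JAX pytree nodes, with `__init_subclass__` and an abstract `replace`. The following is a minimal sketch of how such a base class can work. It is not the IMLCV implementation. It assumes the base class turns each subclass into a frozen dataclass and registers it with `jax.tree_util.register_pytree_node`. The names `PyTreeNodeSketch` and `Point` are hypothetical, and `replace` is implemented concretely here with `dataclasses.replace`.

```python
import dataclasses
from typing import Any, TypeVar

import jax
import jax.numpy as jnp

SelfNode = TypeVar("SelfNode", bound="PyTreeNodeSketch")


class PyTreeNodeSketch:
    """Hypothetical stand-in for MyPyTreeNode (not the IMLCV implementation)."""

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Assumption: every subclass becomes a frozen dataclass (modified in place).
        dataclasses.dataclass(frozen=True)(cls)
        names = tuple(f.name for f in dataclasses.fields(cls))

        def flatten(node):
            # Every dataclass field is a pytree child; there is no static aux data.
            return [getattr(node, name) for name in names], None

        def unflatten(_aux, children):
            return cls(*children)

        jax.tree_util.register_pytree_node(cls, flatten, unflatten)

    def replace(self: SelfNode, **overrides: Any) -> SelfNode:
        # Return a modified copy; frozen dataclasses cannot be mutated in place.
        return dataclasses.replace(self, **overrides)


class Point(PyTreeNodeSketch):
    x: jax.Array
    y: jax.Array


p = Point(jnp.array(1.0), jnp.array(2.0))
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, p)  # still a Point
moved = p.replace(x=jnp.array(5.0))
total = jax.jit(lambda q: q.x + q.y)(p)  # a Point passes through jit as a pytree
```

Because registration happens in `__init_subclass__`, declaring the subclass is enough. No separate decorator or registration call is needed.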
IMLCV.base.datastructures.field
IMLCV.base.datastructures.P
IMLCV.base.datastructures.P2
IMLCV.base.datastructures.T
IMLCV.base.datastructures.jit_decorator(f: Callable[P, T], static_argnums: int | Sequence[int] | None = None, static_argnames: str | Iterable[str] | None = None)
IMLCV.base.datastructures.vmap_decorator(f: Callable[P, T], in_axes: int | Sequence[Any] | None = 0, out_axes: Any = 0)
IMLCV.base.datastructures.custom_jvp_decorator(f: Callable[P, T], nondiff_argnums: Sequence[int] = ()) → Callable[P, T]
IMLCV.base.datastructures.Partial_decorator(f: Callable[P, T], *partial_args, **partial_kwargs)
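`Partial_decorator` binds `*partial_args` and `**partial_kwargs` to `f`. The sketch below assumes it builds a `jax.tree_util.Partial` rather than a `functools.partial`. This page does not confirm that choice. The difference matters under `jax.jit`: a `Partial` is a pytree whose bound arguments are leaves, so the partially applied function can be passed directly as a jitted argument. `partial_decorator_sketch`, `scale` and `apply` are hypothetical names.

```python
from collections.abc import Callable
from typing import Any

import jax
import jax.numpy as jnp


def partial_decorator_sketch(
    f: Callable[..., Any], *partial_args: Any, **partial_kwargs: Any
) -> jax.tree_util.Partial:
    # Assumed behaviour: like functools.partial, but the bound arguments are
    # pytree leaves, so the result is a valid argument to jitted functions.
    return jax.tree_util.Partial(f, *partial_args, **partial_kwargs)


def scale(factor, x):
    return factor * x


double = partial_decorator_sketch(scale, jnp.array(2.0))


@jax.jit
def apply(fn, x):
    # `fn` is flattened as a pytree: the bound `factor` is traced, `scale` is static.
    return fn(x)


result = apply(double, jnp.array(3.0))
```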