Source code for IMLCV.external.tf2jax
import functools
import tensorflow as tf
from jax.experimental.jax2tf import call_tf
def loop_batcher(prim, args, dims, **params):
    """Batching rule that maps a TensorFlow callable over a leading batch axis.

    The arguments are split into *batched* ones (``dims[i] == 0``) and
    *static* ones (``dims[i] is None``).  The callable stored in
    ``params["callable_flat_tf"]`` is then evaluated once per batch element
    with ``tf.vectorized_map``, wrapped in ``call_tf``.

    Args:
        prim: Unused by this function; kept so the signature is unchanged
            (presumably the primitive being batched -- confirm against the
            place where this rule is registered).
        args: Positional arguments of the call being batched.
        dims: For every entry of ``args``, either ``None`` (not batched) or
            the index of its batch axis.  Only axis ``0`` is supported.
        **params: Must contain ``"callable_flat_tf"``, the function that is
            applied to the re-assembled argument list.

    Returns:
        A pair ``(ret, out_dims)`` where ``ret`` is the result of the
        vectorized call and ``out_dims`` is ``(0,) * len(ret)``, i.e. every
        output is reported as batched along axis 0.  NOTE(review): this
        assumes ``ret`` is a sequence of outputs rather than a single array.

    Raises:
        NotImplementedError: If an argument is batched along an axis other
            than 0, or if the number of batched arguments is not exactly one.
    """

    # do partial application of args on given position
    def apply(batch_args, f, static_args, static_pos):
        # Re-interleave batched and static arguments into their original
        # positional order before calling ``f``.
        arguments = []
        total = len(batch_args) + len(static_args)
        static_idx = 0
        batch_idx = 0
        for i in range(total):
            if i in static_pos:
                arguments.append(static_args[static_idx])
                static_idx += 1
            else:
                arguments.append(batch_args[batch_idx])
                batch_idx += 1
        return f(*arguments)

    static_pos = []
    static_args = []
    batch_args = []

    # find arguments for partial application
    for i, (arg, batch_axis) in enumerate(zip(args, dims)):
        if batch_axis is None:
            static_pos.append(i)
            static_args.append(arg)
        else:
            # Raise explicitly instead of asserting: asserts vanish under
            # ``python -O`` and the unsupported case would pass silently.
            if batch_axis != 0:
                raise NotImplementedError("other position not yet implemented")
            batch_args.append(arg)

    # Check the supported arity before building anything else.
    if len(batch_args) != 1:
        raise NotImplementedError(
            "loop_batcher supports exactly one batched argument, "
            f"got {len(batch_args)}"
        )

    # vectorize
    def par_fun(batch_args, static_args):
        return tf.vectorized_map(
            fn=functools.partial(
                apply,
                f=params["callable_flat_tf"],
                static_args=static_args,
                static_pos=static_pos,
            ),
            elems=batch_args,
        )

    # execute with given arguments
    ret = call_tf(par_fun)(batch_args, static_args)

    return (ret, (0,) * len(ret))