Source code for IMLCV.implementations.tensorflow
# Optional TensorFlow integration, executed at import time.
#
# When TensorFlow is importable this module:
#   1. asks TensorFlow to hide every GPU, so that only non-GPU devices stay
#      visible to it, and
#   2. registers ``loop_batcher`` as the batching rule for the ``call_tf_p``
#      primitive in ``jax.interpreters.batching.primitive_batchers``.
#
# When TensorFlow (or one of the bridge modules imported below) is missing,
# the whole setup is skipped silently.
import functools
import logging

try:
    import tensorflow as tf

    # tf.keras.backend.set_floatx("float64")

    try:
        # Disable all GPUS
        tf.config.set_visible_devices([], "GPU")
    except (RuntimeError, ValueError) as err:
        # Invalid device or cannot modify virtual devices once initialized.
        # This is best-effort: keep importing, but make the failure visible
        # instead of crashing or hiding it.
        logging.getLogger(__name__).warning(
            "Could not hide GPUs from TensorFlow: %s", err
        )
    else:
        # Sanity check: after a successful call no GPU may remain visible.
        visible_devices = tf.config.get_visible_devices()
        for device in visible_devices:
            assert device.device_type != "GPU", (
                f"GPU device still visible to TensorFlow: {device}"
            )

    # Imported only after TensorFlow has been configured, matching the
    # original statement order.
    from jax.experimental.jax2tf.call_tf import call_tf_p
    from jax.interpreters import batching

    from IMLCV.external.tf2jax import loop_batcher

    batching.primitive_batchers[call_tf_p] = functools.partial(loop_batcher, call_tf_p)
except ImportError:
    # TensorFlow or the JAX/TF bridge is not installed (ModuleNotFoundError is
    # a subclass of ImportError); the integration is optional, so skip it.
    pass