From 0b72dae39d269963740b27af0d163510d269de4b Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Mon, 23 Sep 2024 03:26:25 -0400
Subject: [PATCH] feat(jax): support neural networks (#4156)

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced JAX support, enhancing functionality and compatibility with the JAX library.
  - Added a new `JAXBackend` class for backend integration with JAX.
  - New functions for converting between NumPy and JAX arrays.

- **Bug Fixes**
  - Improved compatibility of neural network layers with the array API standard.

- **Tests**
  - Added tests for JAX functionality and consistency checks against reference outputs.
  - Enhanced the testing framework for activation functions and type embeddings.

- **Chores**
  - Updated dependency requirements to include the JAX library.

---------

Signed-off-by: Jinzhe Zeng
---
 .github/workflows/test_cuda.yml                  |   2 +-
 .github/workflows/test_python.yml                |   2 +-
 deepmd/backend/jax.py                            | 110 ++++++++++++++++++
 deepmd/dpmodel/common.py                         |  22 ++++
 deepmd/dpmodel/utils/network.py                  |  50 ++++++--
 deepmd/dpmodel/utils/type_embed.py               |  14 ++-
 deepmd/jax/__init__.py                           |   2 +
 deepmd/jax/common.py                             |  37 ++++++
 deepmd/jax/env.py                                |  14 +++
 deepmd/jax/utils/__init__.py                     |   1 +
 deepmd/jax/utils/network.py                      |  29 +++++
 deepmd/jax/utils/type_embed.py                   |  21 ++++
 pyproject.toml                                   |   3 +
 .../array_api/test_activation_functions.py      |   1 +
 source/tests/consistent/common.py               |  59 ++++++++++
 source/tests/consistent/test_activation.py      |  26 +++++
 .../tests/consistent/test_type_embedding.py     |  18 +++
 17 files changed, 393 insertions(+), 18 deletions(-)
 create mode 100644 deepmd/backend/jax.py
 create mode 100644 deepmd/jax/__init__.py
 create mode 100644 deepmd/jax/common.py
 create mode 100644 deepmd/jax/env.py
 create mode 100644 deepmd/jax/utils/__init__.py
 create mode 100644 deepmd/jax/utils/network.py
 create mode 100644 deepmd/jax/utils/type_embed.py
 create mode 100644 source/tests/common/dpmodel/array_api/test_activation_functions.py

diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml
index 2883f01b5a..d60a9c909a 100644
--- a/.github/workflows/test_cuda.yml
+++ b/.github/workflows/test_cuda.yml
@@ -51,7 +51,7 @@ jobs:
       - run: |
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
          export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
-          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch] mpi4py
+          source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py
         env:
           DP_VARIANT: cuda
           DP_ENABLE_NATIVE_OPTIMIZATION: 1
diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml
index 36f9bd78b8..8274921909 100644
--- a/.github/workflows/test_python.yml
+++ b/.github/workflows/test_python.yml
@@ -28,7 +28,7 @@ jobs:
           source/install/uv_with_retry.sh pip install --system mpich
           source/install/uv_with_retry.sh pip install --system "torch==2.3.0+cpu.cxx11.abi" -i https://download.pytorch.org/whl/
           export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
-          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test] horovod[tensorflow-cpu] mpi4py
+          source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py
         env:
           # Please note that uv has some issues with finding
           # existing TensorFlow package. Currently, it uses
diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py
new file mode 100644
index 0000000000..ece0761772
--- /dev/null
+++ b/deepmd/backend/jax.py
@@ -0,0 +1,110 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from importlib.util import (
+    find_spec,
+)
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    ClassVar,
+    List,
+    Type,
+)
+
+from deepmd.backend.backend import (
+    Backend,
+)
+
+if TYPE_CHECKING:
+    from argparse import (
+        Namespace,
+    )
+
+    from deepmd.infer.deep_eval import (
+        DeepEvalBackend,
+    )
+    from deepmd.utils.neighbor_stat import (
+        NeighborStat,
+    )
+
+
+@Backend.register("jax")
+class JAXBackend(Backend):
+    """JAX backend."""
+
+    name = "JAX"
+    """The formal name of the backend."""
+    features: ClassVar[Backend.Feature] = (
+        Backend.Feature(0)
+        # Backend.Feature.ENTRY_POINT
+        # | Backend.Feature.DEEP_EVAL
+        # | Backend.Feature.NEIGHBOR_STAT
+        # | Backend.Feature.IO
+    )
+    """The features of the backend."""
+    suffixes: ClassVar[List[str]] = []
+    """The suffixes of the backend."""
+
+    def is_available(self) -> bool:
+        """Check if the backend is available.
+
+        Returns
+        -------
+        bool
+            Whether the backend is available.
+        """
+        return find_spec("jax") is not None
+
+    @property
+    def entry_point_hook(self) -> Callable[["Namespace"], None]:
+        """The entry point hook of the backend.
+
+        Returns
+        -------
+        Callable[[Namespace], None]
+            The entry point hook of the backend.
+        """
+        raise NotImplementedError
+
+    @property
+    def deep_eval(self) -> Type["DeepEvalBackend"]:
+        """The Deep Eval backend of the backend.
+
+        Returns
+        -------
+        type[DeepEvalBackend]
+            The Deep Eval backend of the backend.
+        """
+        raise NotImplementedError
+
+    @property
+    def neighbor_stat(self) -> Type["NeighborStat"]:
+        """The neighbor statistics of the backend.
+
+        Returns
+        -------
+        type[NeighborStat]
+            The neighbor statistics of the backend.
+        """
+        raise NotImplementedError
+
+    @property
+    def serialize_hook(self) -> Callable[[str], dict]:
+        """The serialize hook to convert the model file to a dictionary.
+
+        Returns
+        -------
+        Callable[[str], dict]
+            The serialize hook of the backend.
+        """
+        raise NotImplementedError
+
+    @property
+    def deserialize_hook(self) -> Callable[[str, dict], None]:
+        """The deserialize hook to convert the dictionary to a model file.
+
+        Returns
+        -------
+        Callable[[str, dict], None]
+            The deserialize hook of the backend.
+        """
+        raise NotImplementedError
diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py
index 56cb8ec1e9..d9d57d2d6c 100644
--- a/deepmd/dpmodel/common.py
+++ b/deepmd/dpmodel/common.py
@@ -3,6 +3,10 @@
     ABC,
     abstractmethod,
 )
+from typing import (
+    Any,
+    Optional,
+)
 
 import ml_dtypes
 import numpy as np
@@ -59,6 +63,24 @@ def __call__(self, *args, **kwargs):
         return self.call(*args, **kwargs)
 
 
+def to_numpy_array(x: Any) -> Optional[np.ndarray]:
+    """Convert an array to a NumPy array.
+
+    Parameters
+    ----------
+    x : Any
+        The array to be converted.
+
+    Returns
+    -------
+    Optional[np.ndarray]
+        The NumPy array.
+    """
+    if x is None:
+        return None
+    return np.asarray(x)
+
+
 __all__ = [
     "GLOBAL_NP_FLOAT_PRECISION",
     "GLOBAL_ENER_FLOAT_PRECISION",
+ """ + if x is None: + return None + return np.asarray(x) + + __all__ = [ "GLOBAL_NP_FLOAT_PRECISION", "GLOBAL_ENER_FLOAT_PRECISION", diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 941e2cfc86..22e85c9890 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -15,6 +15,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -22,6 +23,12 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.array_api import ( + support_array_api, +) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -105,9 +112,9 @@ def serialize(self) -> dict: The serialized layer. """ data = { - "w": self.w, - "b": self.b, - "idt": self.idt, + "w": to_numpy_array(self.w), + "b": to_numpy_array(self.b), + "idt": to_numpy_array(self.idt), } return { "@class": "Layer", @@ -215,6 +222,7 @@ def dim_in(self) -> int: def dim_out(self) -> int: return self.w.shape[1] + @support_array_api(version="2022.12") def call(self, x: np.ndarray) -> np.ndarray: """Forward pass. @@ -230,11 +238,12 @@ def call(self, x: np.ndarray) -> np.ndarray: """ if self.w is None or self.activation_function is None: raise ValueError("w, b, and activation_function must be set") + xp = array_api_compat.array_namespace(x) fn = get_activation_fn(self.activation_function) y = ( - np.matmul(x, self.w) + self.b + xp.matmul(x, self.w) + self.b if self.b is not None - else np.matmul(x, self.w) + else xp.matmul(x, self.w) ) y = fn(y) if self.idt is not None: @@ -242,47 +251,64 @@ def call(self, x: np.ndarray) -> np.ndarray: if self.resnet and self.w.shape[1] == self.w.shape[0]: y += x elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: - y += np.concatenate([x, x], axis=-1) + y += xp.concatenate([x, x], axis=-1) return y +@support_array_api(version="2022.12") def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]: activation_function = activation_function.lower() if activation_function == "tanh": - return np.tanh + + def fn(x): + xp = array_api_compat.array_namespace(x) + return xp.tanh(x) + + return fn elif activation_function == "relu": def fn(x): + xp = array_api_compat.array_namespace(x) # https://stackoverflow.com/a/47936476/9567349 - return x * (x > 0) + return x * xp.astype(x > 0, x.dtype) return fn elif activation_function in ("gelu", "gelu_tf"): def fn(x): + xp = array_api_compat.array_namespace(x) # generated by GitHub Copilot - return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3))) + return ( + 0.5 + * x + * (1 + xp.tanh(xp.sqrt(xp.asarray(2 / xp.pi)) * (x + 0.044715 * x**3))) + ) return fn elif activation_function == "relu6": def fn(x): + xp = array_api_compat.array_namespace(x) # generated by GitHub Copilot - return np.minimum(np.maximum(x, 0), 6) + return xp.where( + x < 0, xp.full_like(x, 0), xp.where(x > 6, xp.full_like(x, 6), x) + ) return fn elif activation_function == "softplus": def fn(x): + xp = array_api_compat.array_namespace(x) # generated by GitHub Copilot - return np.log(1 + np.exp(x)) + return xp.log(1 + xp.exp(x)) return fn elif activation_function == "sigmoid": def fn(x): + xp = array_api_compat.array_namespace(x) # generated by GitHub Copilot - return 1 / (1 + np.exp(-x)) + return 1 / (1 + xp.exp(-x)) return fn elif activation_function.lower() in ("none", "linear"): diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index 2e695171d6..e11c415cfd 100644 --- 
diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py
index 2e695171d6..e11c415cfd 100644
--- a/deepmd/dpmodel/utils/type_embed.py
+++ b/deepmd/dpmodel/utils/type_embed.py
@@ -5,8 +5,12 @@
     Union,
 )
 
+import array_api_compat
 import numpy as np
 
+from deepmd.dpmodel.array_api import (
+    support_array_api,
+)
 from deepmd.dpmodel.common import (
     PRECISION_DICT,
     NativeOP,
@@ -92,16 +96,18 @@ def __init__(
             bias=self.use_tebd_bias,
         )
 
+    @support_array_api(version="2022.12")
     def call(self) -> np.ndarray:
         """Compute the type embedding network."""
+        sample_array = self.embedding_net[0]["w"]
+        xp = array_api_compat.array_namespace(sample_array)
         if not self.use_econf_tebd:
-            embed = self.embedding_net(
-                np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
-            )
+            embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype))
         else:
             embed = self.embedding_net(self.econf_tebd)
         if self.padding:
-            embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
+            embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype)
+            embed = xp.concatenate([embed, embed_pad], axis=0)
         return embed
 
     @classmethod
diff --git a/deepmd/jax/__init__.py b/deepmd/jax/__init__.py
new file mode 100644
index 0000000000..2ff078e797
--- /dev/null
+++ b/deepmd/jax/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""JAX backend."""
diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py
new file mode 100644
index 0000000000..550b168b29
--- /dev/null
+++ b/deepmd/jax/common.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Optional,
+    overload,
+)
+
+import numpy as np
+
+from deepmd.jax.env import (
+    jnp,
+)
+
+
+@overload
+def to_jax_array(array: np.ndarray) -> jnp.ndarray: ...
+
+
+@overload
+def to_jax_array(array: None) -> None: ...
+
+
+def to_jax_array(array: Optional[np.ndarray]) -> Optional[jnp.ndarray]:
+    """Convert a NumPy array to a JAX array.
+
+    Parameters
+    ----------
+    array : np.ndarray
+        The NumPy array to convert.
+
+    Returns
+    -------
+    jnp.ndarray
+        The JAX array.
+    """
+    if array is None:
+        return None
+    return jnp.array(array)
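
*Editor's note:* a quick round trip through the converter above; `deepmd.jax.env`, introduced next, must be imported first so that float64 is enabled. An illustrative sketch:

```python
import numpy as np

from deepmd.jax.common import to_jax_array
from deepmd.jax.env import jnp

x = np.arange(6.0).reshape(2, 3)
x_jax = to_jax_array(x)

# None passes through; arrays become jnp arrays and keep float64
# because deepmd.jax.env sets jax_enable_x64 at import time.
assert to_jax_array(None) is None
assert isinstance(x_jax, jnp.ndarray)
assert x_jax.dtype == jnp.float64
np.testing.assert_array_equal(np.asarray(x_jax), x)
```
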
+ """ + if array is None: + return None + return jnp.array(array) diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py new file mode 100644 index 0000000000..34e4aa6240 --- /dev/null +++ b/deepmd/jax/env.py @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import os + +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + +import jax +import jax.numpy as jnp + +jax.config.update("jax_enable_x64", True) + +__all__ = [ + "jax", + "jnp", +] diff --git a/deepmd/jax/utils/__init__.py b/deepmd/jax/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/utils/network.py b/deepmd/jax/utils/network.py new file mode 100644 index 0000000000..629b51b8cd --- /dev/null +++ b/deepmd/jax/utils/network.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) +from deepmd.jax.common import ( + to_jax_array, +) + + +class NativeLayer(NativeLayerDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"w", "b", "idt"}: + value = to_jax_array(value) + return super().__setattr__(name, value) + + +NativeNet = make_multilayer_network(NativeLayer, NativeOP) +EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py new file mode 100644 index 0000000000..bc7c469524 --- /dev/null +++ b/deepmd/jax/utils/type_embed.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.network import ( + EmbeddingNet, +) + + +class TypeEmbedNet(TypeEmbedNetDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"econf_tebd"}: + value = to_jax_array(value) + if name in {"embedding_net"}: + value = EmbeddingNet.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/pyproject.toml b/pyproject.toml index f181b616a3..28fe114e01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,6 +132,9 @@ cu12 = [ "nvidia-cudnn-cu12<9", "nvidia-cuda-nvcc-cu12", ] +jax = [ + 'jax>=0.4.33;python_version>="3.10"', +] [tool.deepmd_build_backend.scripts] dp = "deepmd.main:main" diff --git a/source/tests/common/dpmodel/array_api/test_activation_functions.py b/source/tests/common/dpmodel/array_api/test_activation_functions.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/common/dpmodel/array_api/test_activation_functions.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index edafc7c02e..e8873e528a 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -35,6 +35,7 @@ INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() +INSTALLED_JAX = Backend.get_backend("jax")().is_available() if os.environ.get("CI") and not (INSTALLED_TF and INSTALLED_PT): raise 
ImportError("TensorFlow or PyTorch should be tested in the CI") @@ -57,6 +58,7 @@ "CommonTest", "INSTALLED_TF", "INSTALLED_PT", + "INSTALLED_JAX", ] @@ -71,6 +73,8 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[Optional[type]] """PyTorch model class.""" + jax_class: ClassVar[Optional[type]] + """JAX model class.""" args: ClassVar[Optional[Union[Argument, List[Argument]]]] """Arguments that maps to the `data`.""" skip_dp: ClassVar[bool] = False @@ -79,6 +83,9 @@ class CommonTest(ABC): """Whether to skip the TensorFlow model.""" skip_pt: ClassVar[bool] = not INSTALLED_PT """Whether to skip the PyTorch model.""" + # we may usually skip jax before jax is fully supported + skip_jax: ClassVar[bool] = True + """Whether to skip the JAX model.""" rtol = 1e-10 """Relative tolerance for comparing the return value. Override for float32.""" atol = 1e-10 @@ -149,12 +156,23 @@ def eval_pt(self, pt_obj: Any) -> Any: The object of PT """ + def eval_jax(self, jax_obj: Any) -> Any: + """Evaluate the return value of JAX. + + Parameters + ---------- + jax_obj : Any + The object of JAX + """ + raise NotImplementedError("Not implemented") + class RefBackend(Enum): """Reference backend.""" TF = 1 DP = 2 PT = 3 + JAX = 5 @abstractmethod def extract_ret(self, ret: Any, backend: RefBackend) -> Tuple[np.ndarray, ...]: @@ -215,6 +233,11 @@ def get_dp_ret_serialization_from_cls(self, obj): data = obj.serialize() return ret, data + def get_jax_ret_serialization_from_cls(self, obj): + ret = self.eval_jax(obj) + data = obj.serialize() + return ret, data + def get_reference_backend(self): """Get the reference backend. @@ -226,6 +249,8 @@ def get_reference_backend(self): return self.RefBackend.TF if not self.skip_pt: return self.RefBackend.PT + if not self.skip_jax: + return self.RefBackend.JAX raise ValueError("No available reference") def get_reference_ret_serialization(self, ref: RefBackend): @@ -359,6 +384,40 @@ def test_pt_self_consistent(self): else: self.assertEqual(rr1, rr2) + def test_jax_consistent_with_ref(self): + """Test whether JAX and reference are consistent.""" + if self.skip_jax: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.JAX: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + jax_obj = self.jax_class.deserialize(data1) + ret2 = self.eval_jax(jax_obj) + ret2 = self.extract_ret(ret2, self.RefBackend.JAX) + data2 = jax_obj.serialize() + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + + def test_jax_self_consistent(self): + """Test whether JAX is self consistent.""" + if self.skip_jax: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.jax_class) + ret1, data1 = self.get_jax_ret_serialization_from_cls(obj1) + obj1 = self.jax_class.deserialize(data1) + ret2, data2 = self.get_jax_ret_serialization_from_cls(obj1) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + else: + self.assertEqual(rr1, rr2) + def tearDown(self) -> None: """Clear the TF session.""" if not self.skip_tf: diff --git 
diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py
index 3fcb9b2fa5..5630e913a8 100644
--- a/source/tests/consistent/test_activation.py
+++ b/source/tests/consistent/test_activation.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import sys
 import unittest
 
 import numpy as np
@@ -12,6 +13,7 @@
     GLOBAL_SEED,
 )
 from .common import (
+    INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
     parameterized,
@@ -28,6 +30,10 @@
     from deepmd.tf.env import (
         tf,
     )
+if INSTALLED_JAX:
+    from deepmd.jax.env import (
+        jnp,
+    )
 
 
 @parameterized(
@@ -57,3 +63,23 @@ def test_pt_consistent_with_ref(self):
                 ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input))
             )
             np.testing.assert_allclose(self.ref, test, atol=1e-10)
+
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict(self):
+        import array_api_strict as xp
+
+        xp.set_array_api_strict_flags(
+            api_version=get_activation_fn_dp.array_api_version
+        )
+        input = xp.asarray(self.random_input)
+        test = get_activation_fn_dp(self.activation)(input)
+        np.testing.assert_allclose(self.ref, np.array(test), atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
+    def test_jax_consistent_with_ref(self):
+        input = jnp.from_dlpack(self.random_input)
+        test = get_activation_fn_dp(self.activation)(input)
+        self.assertTrue(isinstance(test, jnp.ndarray))
+        np.testing.assert_allclose(self.ref, np.from_dlpack(test), atol=1e-10)
diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py
index 6583dddb5f..c66ef0fbaa 100644
--- a/source/tests/consistent/test_type_embedding.py
+++ b/source/tests/consistent/test_type_embedding.py
@@ -13,6 +13,7 @@
 )
 
 from .common import (
+    INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
     CommonTest,
@@ -30,6 +31,13 @@
     from deepmd.tf.utils.type_embed import TypeEmbedNet as TypeEmbedNetTF
 else:
     TypeEmbedNetTF = object
+if INSTALLED_JAX:
+    from deepmd.jax.env import (
+        jnp,
+    )
+    from deepmd.jax.utils.type_embed import TypeEmbedNet as TypeEmbedNetJAX
+else:
+    TypeEmbedNetJAX = object
 
 
 @parameterized(
@@ -63,7 +71,9 @@ def data(self) -> dict:
     tf_class = TypeEmbedNetTF
     dp_class = TypeEmbedNetDP
     pt_class = TypeEmbedNetPT
+    jax_class = TypeEmbedNetJAX
     args = type_embedding_args()
+    skip_jax = not INSTALLED_JAX
 
     @property
     def addtional_data(self) -> dict:
@@ -103,6 +113,14 @@ def eval_pt(self, pt_obj: Any) -> Any:
             for x in (pt_obj(device=PT_DEVICE),)
         ]
 
+    def eval_jax(self, jax_obj: Any) -> Any:
+        out = jax_obj()
+        # ensure the output is not a NumPy array
+        for x in (out,):
+            if isinstance(x, np.ndarray):
+                raise ValueError("Output is a NumPy array")
+        return [np.array(x) if isinstance(x, jnp.ndarray) else x for x in (out,)]
+
     def extract_ret(self, ret: Any, backend) -> Tuple[np.ndarray, ...]:
         return (ret[0],)
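
*Editor's note (a minimal sketch of the round trip the type-embedding test exercises):* the constructor arguments below follow the dpmodel `TypeEmbedNet` signature, which this diff only touches in part, and tolerances match the test defaults:

```python
import numpy as np

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.jax.env import jnp
from deepmd.jax.utils.type_embed import TypeEmbedNet

# Build the reference net with the NumPy backend, then round-trip it
# through serialize()/deserialize() into the JAX implementation.
dp_net = TypeEmbedNetDP(ntypes=2, neuron=[8], seed=0)
jax_net = TypeEmbedNet.deserialize(dp_net.serialize())

embed = jax_net.call()
assert isinstance(embed, jnp.ndarray)
np.testing.assert_allclose(np.asarray(embed), dp_net.call(), rtol=1e-10, atol=1e-10)
```
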