From 694d50015c6a2c8aaa6586e9bdda712729fe9e0c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 17 May 2024 01:07:15 -0400 Subject: [PATCH] refactor Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/make_model.py | 2 + pyproject.toml | 1 + .../universal/atomic_model/test_ener_model.py | 51 ----------- .../{atomic_model => common}/__init__.py | 0 source/tests/universal/common/backend.py | 23 +++++ .../{model => common/cases}/__init__.py | 0 .../common/cases/atomic_model/__init__.py | 1 + .../common/cases/atomic_model/ener_model.py | 18 ++++ .../{ => common/cases}/atomic_model/utils.py | 22 ++--- .../universal/common/cases/model/__init__.py | 1 + .../common/cases/model/ener_model.py | 18 ++++ .../{ => common/cases}/model/utils.py | 47 ++++------ source/tests/universal/dpmodel/__init__.py | 1 + .../universal/dpmodel/atomc_model/__init__.py | 1 + .../atomc_model/test_ener_atomic_model.py | 39 ++++++++ source/tests/universal/dpmodel/backend.py | 30 +++++++ .../tests/universal/dpmodel/model/__init__.py | 1 + .../dpmodel/model/test_ener_model.py | 39 ++++++++ .../tests/universal/model/test_ener_model.py | 49 ----------- source/tests/universal/pt/__init__.py | 1 + .../universal/pt/atomc_model/__init__.py | 1 + .../pt/atomc_model/test_ener_atomic_model.py | 39 ++++++++ source/tests/universal/pt/backend.py | 59 +++++++++++++ source/tests/universal/pt/model/__init__.py | 1 + .../universal/pt/model/test_ener_model.py | 48 ++++++++++ source/tests/universal/utils.py | 88 ------------------- 26 files changed, 354 insertions(+), 227 deletions(-) delete mode 100644 source/tests/universal/atomic_model/test_ener_model.py rename source/tests/universal/{atomic_model => common}/__init__.py (100%) create mode 100644 source/tests/universal/common/backend.py rename source/tests/universal/{model => common/cases}/__init__.py (100%) create mode 100644 source/tests/universal/common/cases/atomic_model/__init__.py create mode 100644 source/tests/universal/common/cases/atomic_model/ener_model.py rename source/tests/universal/{ => common/cases}/atomic_model/utils.py (88%) create mode 100644 source/tests/universal/common/cases/model/__init__.py create mode 100644 source/tests/universal/common/cases/model/ener_model.py rename source/tests/universal/{ => common/cases}/model/utils.py (83%) create mode 100644 source/tests/universal/dpmodel/__init__.py create mode 100644 source/tests/universal/dpmodel/atomc_model/__init__.py create mode 100644 source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py create mode 100644 source/tests/universal/dpmodel/backend.py create mode 100644 source/tests/universal/dpmodel/model/__init__.py create mode 100644 source/tests/universal/dpmodel/model/test_ener_model.py delete mode 100644 source/tests/universal/model/test_ener_model.py create mode 100644 source/tests/universal/pt/__init__.py create mode 100644 source/tests/universal/pt/atomc_model/__init__.py create mode 100644 source/tests/universal/pt/atomc_model/test_ener_atomic_model.py create mode 100644 source/tests/universal/pt/backend.py create mode 100644 source/tests/universal/pt/model/__init__.py create mode 100644 source/tests/universal/pt/model/test_ener_model.py delete mode 100644 source/tests/universal/utils.py diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 273301f924..7993f10abd 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -236,6 +236,8 @@ def call_lower( model_predict = self.output_type_cast(model_predict, input_prec) return model_predict + forward_lower = call_lower + def input_type_cast( self, coord: np.ndarray, diff --git a/pyproject.toml b/pyproject.toml index 23d42e73d2..f10f2746e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -350,6 +350,7 @@ banned-module-level-imports = [ "deepmd/pt/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] +"source/tests/universal/pt/**" = ["TID253"] "source/ipi/tests/**" = ["TID253"] "source/lmp/tests/**" = ["TID253"] "**/*.ipynb" = ["T20"] # printing in a nb file is expected diff --git a/source/tests/universal/atomic_model/test_ener_model.py b/source/tests/universal/atomic_model/test_ener_model.py deleted file mode 100644 index 1c84ee7e5c..0000000000 --- a/source/tests/universal/atomic_model/test_ener_model.py +++ /dev/null @@ -1,51 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import unittest - -from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP -from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP -from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP - -from ..utils import ( - INSTALLED_PT, -) - -if INSTALLED_PT: - from deepmd.pt.model.atomic_model.energy_atomic_model import ( - DPEnergyAtomicModel as DPEnergyAtomicModelPT, - ) - -from .utils import ( - AtomicModelTestCase, -) - - -class TestEnerAtomicModel(unittest.TestCase, AtomicModelTestCase): - def setUp(self) -> None: - self.expected_rcut = 5.0 - self.expected_type_map = ["foo", "bar"] - self.expected_dim_fparam = 0 - self.expected_dim_aparam = 0 - self.expected_sel_type = [0, 1] - self.expected_aparam_nall = False - self.expected_model_output_type = ["energy", "mask"] - self.expected_sel = [8, 12] - ds = DescrptSeADP( - rcut=self.expected_rcut, - rcut_smth=self.expected_rcut / 2, - sel=self.expected_sel, - ) - ft = EnergyFittingNetDP( - ntypes=len(self.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - ) - self.dp_module = DPAtomicModelDP( - ds, - ft, - type_map=self.expected_type_map, - ) - - if INSTALLED_PT: - self.pt_module = DPEnergyAtomicModelPT.deserialize( - self.dp_module.serialize() - ) diff --git a/source/tests/universal/atomic_model/__init__.py b/source/tests/universal/common/__init__.py similarity index 100% rename from source/tests/universal/atomic_model/__init__.py rename to source/tests/universal/common/__init__.py diff --git a/source/tests/universal/common/backend.py b/source/tests/universal/common/backend.py new file mode 100644 index 0000000000..d5747b77b7 --- /dev/null +++ b/source/tests/universal/common/backend.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Common test case.""" + +from abc import ( + ABC, + abstractmethod, +) + + +class BackendTestCase(ABC): + """Backend test case.""" + + module: object + """Module to test.""" + + @property + @abstractmethod + def modules_to_test(self) -> list: + pass + + @abstractmethod + def forward_wrapper(self, x): + pass diff --git a/source/tests/universal/model/__init__.py b/source/tests/universal/common/cases/__init__.py similarity index 100% rename from source/tests/universal/model/__init__.py rename to source/tests/universal/common/cases/__init__.py diff --git a/source/tests/universal/common/cases/atomic_model/__init__.py b/source/tests/universal/common/cases/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/atomic_model/ener_model.py b/source/tests/universal/common/cases/atomic_model/ener_model.py new file mode 100644 index 0000000000..0f1daaf87b --- /dev/null +++ b/source/tests/universal/common/cases/atomic_model/ener_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from .utils import ( + AtomicModelTestCase, +) + + +class EnerAtomicModelTest(AtomicModelTestCase): + def setUp(self) -> None: + self.expected_rcut = 5.0 + self.expected_type_map = ["foo", "bar"] + self.expected_dim_fparam = 0 + self.expected_dim_aparam = 0 + self.expected_sel_type = [0, 1] + self.expected_aparam_nall = False + self.expected_model_output_type = ["energy", "mask"] + self.expected_sel = [8, 12] diff --git a/source/tests/universal/atomic_model/utils.py b/source/tests/universal/common/cases/atomic_model/utils.py similarity index 88% rename from source/tests/universal/atomic_model/utils.py rename to source/tests/universal/common/cases/atomic_model/utils.py index 543a0fe1e2..fc30fe254d 100644 --- a/source/tests/universal/atomic_model/utils.py +++ b/source/tests/universal/common/cases/atomic_model/utils.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, + Callable, List, ) @@ -9,13 +11,8 @@ extend_input_and_build_neighbor_list, ) -from ..utils import ( - CommonTestCase, - forward_wrapper, -) - -class AtomicModelTestCase(CommonTestCase): +class AtomicModelTestCase: """Common test case for atomic model.""" expected_type_map: List[str] @@ -34,6 +31,8 @@ class AtomicModelTestCase(CommonTestCase): """Expected output type for the model.""" expected_sel: List[int] """Expected number of neighbors.""" + forward_wrapper: Callable[[Any], Any] + """Calss wrapper for forward method.""" def test_get_type_map(self): """Test get_type_map.""" @@ -100,7 +99,7 @@ def test_forward(self): ) ret_lower = [] for module in self.modules_to_test: - module = forward_wrapper(module) + module = self.forward_wrapper(module) ret_lower.append(module(coord_ext, atype_ext, nlist)) for kk in ret_lower[0].keys(): @@ -110,6 +109,9 @@ def test_forward(self): subret.append(rr[kk]) if len(subret): for ii, rr in enumerate(subret[1:]): - np.testing.assert_allclose( - subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" - ) + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) diff --git a/source/tests/universal/common/cases/model/__init__.py b/source/tests/universal/common/cases/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/common/cases/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/common/cases/model/ener_model.py b/source/tests/universal/common/cases/model/ener_model.py new file mode 100644 index 0000000000..35d44f9784 --- /dev/null +++ b/source/tests/universal/common/cases/model/ener_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + + +from .utils import ( + ModelTestCase, +) + + +class EnerModelTest(ModelTestCase): + def setUp(self) -> None: + self.expected_rcut = 5.0 + self.expected_type_map = ["foo", "bar"] + self.expected_dim_fparam = 0 + self.expected_dim_aparam = 0 + self.expected_sel_type = [0, 1] + self.expected_aparam_nall = False + self.expected_model_output_type = ["energy", "mask"] + self.expected_sel = [8, 12] diff --git a/source/tests/universal/model/utils.py b/source/tests/universal/common/cases/model/utils.py similarity index 83% rename from source/tests/universal/model/utils.py rename to source/tests/universal/common/cases/model/utils.py index 85f179cc9e..0f25d41634 100644 --- a/source/tests/universal/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -1,25 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Any, + Callable, List, ) import numpy as np -from deepmd.dpmodel.common import ( - NativeOP, -) from deepmd.dpmodel.utils.nlist import ( extend_input_and_build_neighbor_list, ) -from ..utils import ( - INSTALLED_PT, - CommonTestCase, - forward_wrapper, -) - -class ModelTestCase(CommonTestCase): +class ModelTestCase: """Common test case for model.""" expected_type_map: List[str] @@ -38,12 +31,8 @@ class ModelTestCase(CommonTestCase): """Expected output type for the model.""" expected_sel: List[int] """Expected number of neighbors.""" - - @property - def modules_to_test(self): - if INSTALLED_PT: - return [*super().modules_to_test, self.pt_script_module] - return super().modules_to_test + forward_wrapper: Callable[[Any], Any] + """Calss wrapper for forward method.""" def test_get_type_map(self): """Test get_type_map.""" @@ -118,13 +107,7 @@ def test_forward(self): ret = [] ret_lower = [] for module in self.modules_to_test: - if isinstance(module, NativeOP): - # skip dp: - # 1. different keys - # 2. no forward_lower - # needs to be fixed - continue - module = forward_wrapper(module) + module = self.forward_wrapper(module) ret.append(module(coord, atype, cell)) ret_lower.append(module.forward_lower(coord_ext, atype_ext, nlist)) @@ -135,9 +118,12 @@ def test_forward(self): subret.append(rr[kk]) if len(subret): for ii, rr in enumerate(subret[1:]): - np.testing.assert_allclose( - subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" - ) + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) for kk in ret_lower[0].keys(): subret = [] for rr in ret_lower: @@ -145,9 +131,12 @@ def test_forward(self): subret.append(rr[kk]) if len(subret): for ii, rr in enumerate(subret[1:]): - np.testing.assert_allclose( - subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" - ) + if subret[0] is None: + assert rr is None + else: + np.testing.assert_allclose( + subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}" + ) same_keys = set(ret[0].keys()) & set(ret_lower[0].keys()) self.assertTrue(same_keys) for key in same_keys: diff --git a/source/tests/universal/dpmodel/__init__.py b/source/tests/universal/dpmodel/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/atomc_model/__init__.py b/source/tests/universal/dpmodel/atomc_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/atomc_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py b/source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py new file mode 100644 index 0000000000..6cf4598646 --- /dev/null +++ b/source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting.ener_fitting import ( + EnergyFittingNet, +) + +from ...common.cases.atomic_model.ener_model import ( + EnerAtomicModelTest, +) +from ..backend import ( + DPTestCase, +) + + +class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase): + def setUp(self): + EnerAtomicModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = DPAtomicModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/dpmodel/backend.py b/source/tests/universal/dpmodel/backend.py new file mode 100644 index 0000000000..61982fea98 --- /dev/null +++ b/source/tests/universal/dpmodel/backend.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.common import ( + NativeOP, +) + +from ..common.backend import ( + BackendTestCase, +) + + +class DPTestCase(BackendTestCase): + """Common test case.""" + + module: NativeOP + """DP module to test.""" + + def forward_wrapper(self, x): + return x + + @property + def deserialized_module(self): + return self.module.deserialize(self.module.serialize()) + + @property + def modules_to_test(self): + modules = [ + self.module, + self.deserialized_module, + ] + return modules diff --git a/source/tests/universal/dpmodel/model/__init__.py b/source/tests/universal/dpmodel/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/dpmodel/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/dpmodel/model/test_ener_model.py b/source/tests/universal/dpmodel/model/test_ener_model.py new file mode 100644 index 0000000000..506564260f --- /dev/null +++ b/source/tests/universal/dpmodel/model/test_ener_model.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting.ener_fitting import ( + EnergyFittingNet, +) +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, +) + +from ...common.cases.model.ener_model import ( + EnerModelTest, +) +from ..backend import ( + DPTestCase, +) + + +class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase): + def setUp(self): + EnerModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = EnergyModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/model/test_ener_model.py b/source/tests/universal/model/test_ener_model.py deleted file mode 100644 index ab4a7f55b1..0000000000 --- a/source/tests/universal/model/test_ener_model.py +++ /dev/null @@ -1,49 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import unittest - -from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP -from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP -from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP - -from ..utils import ( - INSTALLED_PT, -) - -if INSTALLED_PT: - from deepmd.pt.model.model.ener_model import ( - EnergyModel as EnergyModelPT, - ) - -from .utils import ( - ModelTestCase, -) - - -class TestEnerModel(unittest.TestCase, ModelTestCase): - def setUp(self) -> None: - self.expected_rcut = 5.0 - self.expected_type_map = ["foo", "bar"] - self.expected_dim_fparam = 0 - self.expected_dim_aparam = 0 - self.expected_sel_type = [0, 1] - self.expected_aparam_nall = False - self.expected_model_output_type = ["energy", "mask"] - self.expected_sel = [8, 12] - ds = DescrptSeADP( - rcut=self.expected_rcut, - rcut_smth=self.expected_rcut / 2, - sel=self.expected_sel, - ) - ft = EnergyFittingNetDP( - ntypes=len(self.expected_type_map), - dim_descrpt=ds.get_dim_out(), - mixed_types=ds.mixed_types(), - ) - self.dp_module = EnergyModelDP( - ds, - ft, - type_map=self.expected_type_map, - ) - - if INSTALLED_PT: - self.pt_module = EnergyModelPT.deserialize(self.dp_module.serialize()) diff --git a/source/tests/universal/pt/__init__.py b/source/tests/universal/pt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/atomc_model/__init__.py b/source/tests/universal/pt/atomc_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/atomc_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/atomc_model/test_ener_atomic_model.py b/source/tests/universal/pt/atomc_model/test_ener_atomic_model.py new file mode 100644 index 0000000000..5ba3be0fad --- /dev/null +++ b/source/tests/universal/pt/atomc_model/test_ener_atomic_model.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.pt.model.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, +) + +from ...common.cases.atomic_model.ener_model import ( + EnerAtomicModelTest, +) +from ..backend import ( + PTTestCase, +) + + +class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, PTTestCase): + def setUp(self): + EnerAtomicModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = DPAtomicModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/pt/backend.py b/source/tests/universal/pt/backend.py new file mode 100644 index 0000000000..61110a0cc6 --- /dev/null +++ b/source/tests/universal/pt/backend.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import torch + +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + + +class PTTestCase: + """Common test case.""" + + module: "torch.nn.Module" + """PT module to test.""" + + @property + def script_module(self): + return torch.jit.script(self.module) + + @property + def deserialized_module(self): + return self.module.deserialize(self.module.serialize()) + + @property + def modules_to_test(self): + modules = [ + self.module, + self.deserialized_module, + ] + return modules + + def test_jit(self): + self.script_module + + def forward_wrapper(self, module): + def create_wrapper_method(method): + def wrapper_method(self, *args, **kwargs): + # convert to torch tensor + args = [to_torch_tensor(arg) for arg in args] + kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()} + # forward + output = method(*args, **kwargs) + # convert to numpy array + if isinstance(output, tuple): + output = tuple(to_numpy_array(o) for o in output) + elif isinstance(output, dict): + output = {k: to_numpy_array(v) for k, v in output.items()} + else: + output = to_numpy_array(output) + return output + + return wrapper_method + + class wrapper_module: + __call__ = create_wrapper_method(module.__call__) + if hasattr(module, "forward_lower"): + forward_lower = create_wrapper_method(module.forward_lower) + + return wrapper_module() diff --git a/source/tests/universal/pt/model/__init__.py b/source/tests/universal/pt/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/universal/pt/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/universal/pt/model/test_ener_model.py b/source/tests/universal/pt/model/test_ener_model.py new file mode 100644 index 0000000000..af5d77d5b4 --- /dev/null +++ b/source/tests/universal/pt/model/test_ener_model.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.ener_model import ( + EnergyModel, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, +) + +from ...common.cases.model.ener_model import ( + EnerModelTest, +) +from ..backend import ( + PTTestCase, +) + + +class TestEnergyModelDP(unittest.TestCase, EnerModelTest, PTTestCase): + @property + def modules_to_test(self): + # for Model, we can test script module API + modules = [ + *PTTestCase.modules_to_test.fget(self), + self.script_module, + ] + return modules + + def setUp(self): + EnerModelTest.setUp(self) + ds = DescrptSeA( + rcut=self.expected_rcut, + rcut_smth=self.expected_rcut / 2, + sel=self.expected_sel, + ) + ft = EnergyFittingNet( + ntypes=len(self.expected_type_map), + dim_descrpt=ds.get_dim_out(), + mixed_types=ds.mixed_types(), + ) + self.module = EnergyModel( + ds, + ft, + type_map=self.expected_type_map, + ) diff --git a/source/tests/universal/utils.py b/source/tests/universal/utils.py deleted file mode 100644 index 5830819726..0000000000 --- a/source/tests/universal/utils.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -"""Common test case.""" - -import unittest - -from deepmd.backend.backend import ( - Backend, -) -from deepmd.dpmodel.common import ( - NativeOP, -) - -INSTALLED_PT = Backend.get_backend("pytorch")().is_available() -if INSTALLED_PT: - import torch - - -class CommonTestCase: - """Common test case.""" - - dp_module: NativeOP - """DP module to test.""" - - pt_module: "torch.nn.Module" - """PT module to test.""" - - @property - def pt_script_module(self): - return torch.jit.script(self.pt_module) - - @property - def pt_deserialized_module(self): - return self.pt_module.deserialize(self.pt_module.serialize()) - - @property - def dp_deserialized_module(self): - return self.dp_module.deserialize(self.dp_module.serialize()) - - @property - def modules_to_test(self): - modules = [ - self.dp_module, - self.dp_deserialized_module, - ] - if INSTALLED_PT: - modules.extend([self.pt_module, self.pt_deserialized_module]) - return modules - - @unittest.skipIf(not INSTALLED_PT, "PyTorch is not installed.") - def test_pt_jit(self): - self.pt_script_module - - -def forward_wrapper(module): - if isinstance(module, NativeOP): - return module - elif INSTALLED_PT and isinstance(module, torch.nn.Module): - from deepmd.pt.utils.utils import ( - to_numpy_array, - to_torch_tensor, - ) - - def create_wrapper_method(method): - def wrapper_method(self, *args, **kwargs): - # convert to torch tensor - args = [to_torch_tensor(arg) for arg in args] - kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()} - # forward - output = method(*args, **kwargs) - # convert to numpy array - if isinstance(output, tuple): - output = tuple(to_numpy_array(o) for o in output) - elif isinstance(output, dict): - output = {k: to_numpy_array(v) for k, v in output.items()} - else: - output = to_numpy_array(output) - return output - - return wrapper_method - - class wrapper_module: - __call__ = create_wrapper_method(module.__call__) - if hasattr(module, "forward_lower"): - forward_lower = create_wrapper_method(module.forward_lower) - - return wrapper_module() - else: - raise ValueError(f"Unsupported module type: {type(module)}")