Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed May 17, 2024
1 parent 014f16f commit 694d500
Show file tree
Hide file tree
Showing 26 changed files with 354 additions and 227 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ def call_lower(
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict

forward_lower = call_lower

def input_type_cast(
self,
coord: np.ndarray,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ banned-module-level-imports = [
"deepmd/pt/**" = ["TID253"]
"source/tests/tf/**" = ["TID253"]
"source/tests/pt/**" = ["TID253"]
"source/tests/universal/pt/**" = ["TID253"]
"source/ipi/tests/**" = ["TID253"]
"source/lmp/tests/**" = ["TID253"]
"**/*.ipynb" = ["T20"] # printing in a nb file is expected
Expand Down
51 changes: 0 additions & 51 deletions source/tests/universal/atomic_model/test_ener_model.py

This file was deleted.

File renamed without changes.
23 changes: 23 additions & 0 deletions source/tests/universal/common/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Common test case."""

from abc import (
ABC,
abstractmethod,
)


class BackendTestCase(ABC):
"""Backend test case."""

module: object
"""Module to test."""

@property
@abstractmethod
def modules_to_test(self) -> list:
pass

@abstractmethod
def forward_wrapper(self, x):
pass
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
18 changes: 18 additions & 0 deletions source/tests/universal/common/cases/atomic_model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
AtomicModelTestCase,
)


class EnerAtomicModelTest(AtomicModelTestCase):
def setUp(self) -> None:
self.expected_rcut = 5.0
self.expected_type_map = ["foo", "bar"]
self.expected_dim_fparam = 0
self.expected_dim_aparam = 0
self.expected_sel_type = [0, 1]
self.expected_aparam_nall = False
self.expected_model_output_type = ["energy", "mask"]
self.expected_sel = [8, 12]
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Callable,
List,
)

Expand All @@ -9,13 +11,8 @@
extend_input_and_build_neighbor_list,
)

from ..utils import (
CommonTestCase,
forward_wrapper,
)


class AtomicModelTestCase(CommonTestCase):
class AtomicModelTestCase:
"""Common test case for atomic model."""

expected_type_map: List[str]
Expand All @@ -34,6 +31,8 @@ class AtomicModelTestCase(CommonTestCase):
"""Expected output type for the model."""
expected_sel: List[int]
"""Expected number of neighbors."""
forward_wrapper: Callable[[Any], Any]
"""Calss wrapper for forward method."""

def test_get_type_map(self):
"""Test get_type_map."""
Expand Down Expand Up @@ -100,7 +99,7 @@ def test_forward(self):
)
ret_lower = []
for module in self.modules_to_test:
module = forward_wrapper(module)
module = self.forward_wrapper(module)

ret_lower.append(module(coord_ext, atype_ext, nlist))
for kk in ret_lower[0].keys():
Expand All @@ -110,6 +109,9 @@ def test_forward(self):
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
1 change: 1 addition & 0 deletions source/tests/universal/common/cases/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
18 changes: 18 additions & 0 deletions source/tests/universal/common/cases/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from .utils import (
ModelTestCase,
)


class EnerModelTest(ModelTestCase):
def setUp(self) -> None:
self.expected_rcut = 5.0
self.expected_type_map = ["foo", "bar"]
self.expected_dim_fparam = 0
self.expected_dim_aparam = 0
self.expected_sel_type = [0, 1]
self.expected_aparam_nall = False
self.expected_model_output_type = ["energy", "mask"]
self.expected_sel = [8, 12]
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Callable,
List,
)

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.utils.nlist import (
extend_input_and_build_neighbor_list,
)

from ..utils import (
INSTALLED_PT,
CommonTestCase,
forward_wrapper,
)


class ModelTestCase(CommonTestCase):
class ModelTestCase:
"""Common test case for model."""

expected_type_map: List[str]
Expand All @@ -38,12 +31,8 @@ class ModelTestCase(CommonTestCase):
"""Expected output type for the model."""
expected_sel: List[int]
"""Expected number of neighbors."""

@property
def modules_to_test(self):
if INSTALLED_PT:
return [*super().modules_to_test, self.pt_script_module]
return super().modules_to_test
forward_wrapper: Callable[[Any], Any]
"""Calss wrapper for forward method."""

def test_get_type_map(self):
"""Test get_type_map."""
Expand Down Expand Up @@ -118,13 +107,7 @@ def test_forward(self):
ret = []
ret_lower = []
for module in self.modules_to_test:
if isinstance(module, NativeOP):
# skip dp:
# 1. different keys
# 2. no forward_lower
# needs to be fixed
continue
module = forward_wrapper(module)
module = self.forward_wrapper(module)
ret.append(module(coord, atype, cell))

ret_lower.append(module.forward_lower(coord_ext, atype_ext, nlist))
Expand All @@ -135,19 +118,25 @@ def test_forward(self):
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
for kk in ret_lower[0].keys():
subret = []
for rr in ret_lower:
if rr is not None:
subret.append(rr[kk])
if len(subret):
for ii, rr in enumerate(subret[1:]):
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
if subret[0] is None:
assert rr is None
else:
np.testing.assert_allclose(
subret[0], rr, err_msg=f"compare {kk} between 0 and {ii}"
)
same_keys = set(ret[0].keys()) & set(ret_lower[0].keys())
self.assertTrue(same_keys)
for key in same_keys:
Expand Down
1 change: 1 addition & 0 deletions source/tests/universal/dpmodel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
1 change: 1 addition & 0 deletions source/tests/universal/dpmodel/atomc_model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)

from ...common.cases.atomic_model.ener_model import (
EnerAtomicModelTest,
)
from ..backend import (
DPTestCase,
)


class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase):
def setUp(self):
EnerAtomicModelTest.setUp(self)
ds = DescrptSeA(
rcut=self.expected_rcut,
rcut_smth=self.expected_rcut / 2,
sel=self.expected_sel,
)
ft = EnergyFittingNet(
ntypes=len(self.expected_type_map),
dim_descrpt=ds.get_dim_out(),
mixed_types=ds.mixed_types(),
)
self.module = DPAtomicModel(
ds,
ft,
type_map=self.expected_type_map,
)
30 changes: 30 additions & 0 deletions source/tests/universal/dpmodel/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.common import (
NativeOP,
)

from ..common.backend import (
BackendTestCase,
)


class DPTestCase(BackendTestCase):
"""Common test case."""

module: NativeOP
"""DP module to test."""

def forward_wrapper(self, x):
return x

@property
def deserialized_module(self):
return self.module.deserialize(self.module.serialize())

@property
def modules_to_test(self):
modules = [
self.module,
self.deserialized_module,
]
return modules
1 change: 1 addition & 0 deletions source/tests/universal/dpmodel/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
39 changes: 39 additions & 0 deletions source/tests/universal/dpmodel/model/test_ener_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)

from ...common.cases.model.ener_model import (
EnerModelTest,
)
from ..backend import (
DPTestCase,
)


class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase):
def setUp(self):
EnerModelTest.setUp(self)
ds = DescrptSeA(
rcut=self.expected_rcut,
rcut_smth=self.expected_rcut / 2,
sel=self.expected_sel,
)
ft = EnergyFittingNet(
ntypes=len(self.expected_type_map),
dim_descrpt=ds.get_dim_out(),
mixed_types=ds.mixed_types(),
)
self.module = EnergyModel(
ds,
ft,
type_map=self.expected_type_map,
)
Loading

0 comments on commit 694d500

Please sign in to comment.