-
Notifications
You must be signed in to change notification settings - Fork 502
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
- Loading branch information
Showing
26 changed files
with
354 additions
and
227 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
"""Common test case.""" | ||
|
||
from abc import ( | ||
ABC, | ||
abstractmethod, | ||
) | ||
|
||
|
||
class BackendTestCase(ABC): | ||
"""Backend test case.""" | ||
|
||
module: object | ||
"""Module to test.""" | ||
|
||
@property | ||
@abstractmethod | ||
def modules_to_test(self) -> list: | ||
pass | ||
|
||
@abstractmethod | ||
def forward_wrapper(self, x): | ||
pass |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
18 changes: 18 additions & 0 deletions
18
source/tests/universal/common/cases/atomic_model/ener_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
|
||
from .utils import ( | ||
AtomicModelTestCase, | ||
) | ||
|
||
|
||
class EnerAtomicModelTest(AtomicModelTestCase): | ||
def setUp(self) -> None: | ||
self.expected_rcut = 5.0 | ||
self.expected_type_map = ["foo", "bar"] | ||
self.expected_dim_fparam = 0 | ||
self.expected_dim_aparam = 0 | ||
self.expected_sel_type = [0, 1] | ||
self.expected_aparam_nall = False | ||
self.expected_model_output_type = ["energy", "mask"] | ||
self.expected_sel = [8, 12] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
|
||
from .utils import ( | ||
ModelTestCase, | ||
) | ||
|
||
|
||
class EnerModelTest(ModelTestCase): | ||
def setUp(self) -> None: | ||
self.expected_rcut = 5.0 | ||
self.expected_type_map = ["foo", "bar"] | ||
self.expected_dim_fparam = 0 | ||
self.expected_dim_aparam = 0 | ||
self.expected_sel_type = [0, 1] | ||
self.expected_aparam_nall = False | ||
self.expected_model_output_type = ["energy", "mask"] | ||
self.expected_sel = [8, 12] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
39 changes: 39 additions & 0 deletions
39
source/tests/universal/dpmodel/atomc_model/test_ener_atomic_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import unittest | ||
|
||
from deepmd.dpmodel.atomic_model.dp_atomic_model import ( | ||
DPAtomicModel, | ||
) | ||
from deepmd.dpmodel.descriptor.se_e2_a import ( | ||
DescrptSeA, | ||
) | ||
from deepmd.dpmodel.fitting.ener_fitting import ( | ||
EnergyFittingNet, | ||
) | ||
|
||
from ...common.cases.atomic_model.ener_model import ( | ||
EnerAtomicModelTest, | ||
) | ||
from ..backend import ( | ||
DPTestCase, | ||
) | ||
|
||
|
||
class TestEnergyAtomicModelDP(unittest.TestCase, EnerAtomicModelTest, DPTestCase): | ||
def setUp(self): | ||
EnerAtomicModelTest.setUp(self) | ||
ds = DescrptSeA( | ||
rcut=self.expected_rcut, | ||
rcut_smth=self.expected_rcut / 2, | ||
sel=self.expected_sel, | ||
) | ||
ft = EnergyFittingNet( | ||
ntypes=len(self.expected_type_map), | ||
dim_descrpt=ds.get_dim_out(), | ||
mixed_types=ds.mixed_types(), | ||
) | ||
self.module = DPAtomicModel( | ||
ds, | ||
ft, | ||
type_map=self.expected_type_map, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from deepmd.dpmodel.common import ( | ||
NativeOP, | ||
) | ||
|
||
from ..common.backend import ( | ||
BackendTestCase, | ||
) | ||
|
||
|
||
class DPTestCase(BackendTestCase): | ||
"""Common test case.""" | ||
|
||
module: NativeOP | ||
"""DP module to test.""" | ||
|
||
def forward_wrapper(self, x): | ||
return x | ||
|
||
@property | ||
def deserialized_module(self): | ||
return self.module.deserialize(self.module.serialize()) | ||
|
||
@property | ||
def modules_to_test(self): | ||
modules = [ | ||
self.module, | ||
self.deserialized_module, | ||
] | ||
return modules |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import unittest | ||
|
||
from deepmd.dpmodel.descriptor.se_e2_a import ( | ||
DescrptSeA, | ||
) | ||
from deepmd.dpmodel.fitting.ener_fitting import ( | ||
EnergyFittingNet, | ||
) | ||
from deepmd.dpmodel.model.ener_model import ( | ||
EnergyModel, | ||
) | ||
|
||
from ...common.cases.model.ener_model import ( | ||
EnerModelTest, | ||
) | ||
from ..backend import ( | ||
DPTestCase, | ||
) | ||
|
||
|
||
class TestEnergyModelDP(unittest.TestCase, EnerModelTest, DPTestCase): | ||
def setUp(self): | ||
EnerModelTest.setUp(self) | ||
ds = DescrptSeA( | ||
rcut=self.expected_rcut, | ||
rcut_smth=self.expected_rcut / 2, | ||
sel=self.expected_sel, | ||
) | ||
ft = EnergyFittingNet( | ||
ntypes=len(self.expected_type_map), | ||
dim_descrpt=ds.get_dim_out(), | ||
mixed_types=ds.mixed_types(), | ||
) | ||
self.module = EnergyModel( | ||
ds, | ||
ft, | ||
type_map=self.expected_type_map, | ||
) |
Oops, something went wrong.