Skip to content

Commit

Permalink
add dp model format for sea. some features not supported
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 9, 2024
1 parent a971d92 commit 158f184
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
6 changes: 6 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .common import (
DEFAULT_PRECISION,
PRECISION_DICT,
)
from .env_mat import (
Expand All @@ -13,8 +14,12 @@
save_dp_model,
traverse_model_dict,
)
from .se_e2_a import (
DescrptSeA,
)

__all__ = [
"DescrptSeA",
"EnvMat",
"EmbeddingNet",
"NativeLayer",
Expand All @@ -23,4 +28,5 @@
"save_dp_model",
"traverse_model_dict",
"PRECISION_DICT",
"DEFAULT_PRECISION",
]
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def __init__(
self.neuron = neuron
self.activation_function = activation_function
self.resnet_dt = resnet_dt
self.precision = precision

def serialize(self) -> dict:
"""Serialize the network to a dict.
Expand All @@ -393,6 +394,7 @@ def serialize(self) -> dict:
"neuron": self.neuron.copy(),
"activation_function": self.activation_function,
"resnet_dt": self.resnet_dt,
"precision": self.precision,
"layers": [layer.serialize() for layer in self.layers],
}

Expand Down
45 changes: 40 additions & 5 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from deepmd_utils.model_format import (
DescrptSeA,
EmbeddingNet,
EnvMat,
NativeLayer,
Expand Down Expand Up @@ -97,12 +98,18 @@ def test_deserialize(self):
np.testing.assert_array_equal(network[1]["resnet"], True)

def test_embedding_net(self):
for ni, idt, act in itertools.product(
for ni, act, idt, prec in itertools.product(
[1, 10],
[True, False],
["tanh", "none"],
[True, False],
["double", "single"],
):
en0 = EmbeddingNet(ni)
en0 = EmbeddingNet(
ni,
activation_function=act,
precision=prec,
resnet_dt=idt,
)
en1 = EmbeddingNet.deserialize(en0.serialize())
inp = np.ones([ni])
np.testing.assert_allclose(en0.call(inp), en1.call(inp))
Expand Down Expand Up @@ -141,7 +148,7 @@ def tearDown(self) -> None:
os.remove(self.filename)


class TestEnvMat(unittest.TestCase):
class TestCaseSingleFrameWithNlist:
def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
Expand All @@ -158,17 +165,23 @@ def setUp(self):
).reshape([1, self.nall * 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
self.nlist = np.array(
[
[1, 3, -1, -1, -1, 2, -1],
[0, -1, -1, -1, -1, 2, -1],
[0, 1, -1, -1, -1, 0, -1],
],
dtype=int,
).reshape([1, self.nloc, 7])
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 0.4
self.rcut_smth = 2.2


class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)

def test_self_consistency(
self,
):
Expand All @@ -183,3 +196,25 @@ def test_self_consistency(
mm1, ww1 = em1.call(self.nlist, self.coord_ext, self.atype_ext, davg, dstd)
np.testing.assert_allclose(mm0, mm1)
np.testing.assert_allclose(ww0, ww1)


class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)

def test_self_consistency(
self,
):
rng = np.random.default_rng()
nf, nloc, nnei = self.nlist.shape
davg = rng.normal(size=(self.nt, nnei, 4))
dstd = rng.normal(size=(self.nt, nnei, 4))
dstd = 0.1 + np.abs(dstd)

em0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel)
em0.davg = davg
em0.dstd = dstd
em1 = DescrptSeA.deserialize(em0.serialize())
mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist)
mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist)
np.testing.assert_allclose(mm0, mm1)

0 comments on commit 158f184

Please sign in to comment.