Skip to content

Commit

Permalink
fix: bugs in uts for polar and dipole fit (#3837)
Browse files Browse the repository at this point in the history
Fix following trivial bugs in dipole and polar fit uts:
1. `box` was not used in `extend_input_and_build_neighbor_list` (which
means they were all tested in nopbc mode, if shifted coord is outside
the box (sometimes) and normalized explicitly, results are not the
same.) Input for fitting also used extended_atype instead of atype.
(Only same when nopbc.)
2. Using of `mixed_types` is disordered, mismatched with descriptor or
sometimes with nlist. Now only use `mixed_types`==False since the
descriptor output is not in mixed types.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **Tests**
  - Improved consistency in parameter handling for various test methods.
- Updated `mixed_types` parameter to dynamically use
`self.dd0.mixed_types()` across multiple test functions for better
flexibility and accuracy.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
3 people committed May 31, 2024
1 parent 3a7fbcf commit 1c18950
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 31 deletions.
40 changes: 25 additions & 15 deletions source/tests/pt/model/test_dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def test_consistency(
self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE
)

for mixed_types, nfp, nap in itertools.product(
[True, False],
for nfp, nap in itertools.product(
[0, 3],
[0, 4],
):
Expand All @@ -84,7 +83,7 @@ def test_consistency(
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
ft1 = DPDipoleFitting.deserialize(ft0.serialize())
ft2 = DipoleFittingNet.deserialize(ft1.serialize())
Expand Down Expand Up @@ -159,9 +158,10 @@ def test_rot(self):
atype = self.atype.reshape(1, 5)
rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE)
coord_rot = torch.matmul(self.coord, rmat)
# use larger cell to rotate only coord and shift to the center of cell
cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE)
rng = np.random.default_rng()
for mixed_types, nfp, nap in itertools.product(
[True, False],
for nfp, nap in itertools.product(
[0, 3],
[0, 4],
):
Expand All @@ -171,7 +171,7 @@ def test_rot(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
if nfp > 0:
ifp = torch.tensor(
Expand All @@ -196,7 +196,12 @@ def test_rot(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz + self.shift, atype, self.rcut, self.sel, not mixed_types
xyz + self.shift,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=cell_rot,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -205,7 +210,7 @@ def test_rot(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["dipole"])

np.testing.assert_allclose(
Expand All @@ -220,7 +225,7 @@ def test_permu(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=False,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
res = []
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
Expand All @@ -231,7 +236,12 @@ def test_permu(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
coord[idx_perm], atype, self.rcut, self.sel, True
coord[idx_perm],
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -240,7 +250,7 @@ def test_permu(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
res.append(ret0["dipole"])

np.testing.assert_allclose(
Expand All @@ -261,7 +271,7 @@ def test_trans(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
res = []
for xyz in [self.coord, coord_s]:
Expand All @@ -271,7 +281,7 @@ def test_trans(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz, atype, self.rcut, self.sel, False
xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -280,7 +290,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
res.append(ret0["dipole"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand All @@ -305,7 +315,7 @@ def setUp(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping)
Expand Down
52 changes: 36 additions & 16 deletions source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def test_consistency(
self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE
)

for mixed_types, nfp, nap, fit_diag, scale in itertools.product(
[True, False],
for nfp, nap, fit_diag, scale in itertools.product(
[0, 3],
[0, 4],
[True, False],
Expand All @@ -72,7 +71,7 @@ def test_consistency(
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=mixed_types,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand Down Expand Up @@ -166,9 +165,10 @@ def test_rot(self):
atype = self.atype.reshape(1, 5)
rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE)
coord_rot = torch.matmul(self.coord, rmat)
# use larger cell to rotate only coord and shift to the center of cell
cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE)

for mixed_types, nfp, nap, fit_diag, scale in itertools.product(
[True, False],
for nfp, nap, fit_diag, scale in itertools.product(
[0, 3],
[0, 4],
[True, False],
Expand All @@ -180,7 +180,7 @@ def test_rot(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=nfp,
numb_aparam=nap,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand All @@ -207,7 +207,12 @@ def test_rot(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz + self.shift, atype, self.rcut, self.sel, mixed_types
xyz + self.shift,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=cell_rot,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -216,7 +221,7 @@ def test_rot(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["polarizability"])
np.testing.assert_allclose(
to_numpy_array(res[1]),
Expand All @@ -237,7 +242,7 @@ def test_permu(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand All @@ -250,7 +255,12 @@ def test_permu(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
coord[idx_perm], atype, self.rcut, self.sel, False
coord[idx_perm],
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -259,7 +269,7 @@ def test_permu(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None)
ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None)
res.append(ret0["polarizability"])

np.testing.assert_allclose(
Expand All @@ -269,15 +279,20 @@ def test_permu(self):

def test_trans(self):
atype = self.atype.reshape(1, 5)
coord_s = self.coord + self.shift
coord_s = torch.matmul(
torch.remainder(
torch.matmul(self.coord + self.shift, torch.linalg.inv(self.cell)), 1.0
),
self.cell,
)
for fit_diag, scale in itertools.product([True, False], [None, self.scale]):
ft0 = PolarFittingNet(
self.nt,
self.dd0.dim_out,
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
fit_diag=fit_diag,
scale=scale,
).to(env.DEVICE)
Expand All @@ -289,7 +304,12 @@ def test_trans(self):
_,
nlist,
) = extend_input_and_build_neighbor_list(
xyz, atype, self.rcut, self.sel, False
xyz,
atype,
self.rcut,
self.sel,
self.dd0.mixed_types(),
box=self.cell,
)

rd0, gr0, _, _, _ = self.dd0(
Expand All @@ -298,7 +318,7 @@ def test_trans(self):
nlist,
)

ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0)
ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0)
res.append(ret0["polarizability"])

np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1]))
Expand All @@ -323,7 +343,7 @@ def setUp(self):
embedding_width=self.dd0.get_dim_emb(),
numb_fparam=0,
numb_aparam=0,
mixed_types=True,
mixed_types=self.dd0.mixed_types(),
).to(env.DEVICE)
self.type_mapping = ["O", "H", "B"]
self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)
Expand Down

0 comments on commit 1c18950

Please sign in to comment.