Skip to content

Commit

Permalink
change the atomic model's init interface of the dpmodel
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Apr 4, 2024
1 parent 9e89b25 commit 75591a1
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 9 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@
class BaseAtomicModel(BaseAtomicModel_):
def __init__(
self,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.type_map = type_map
self.reinit_atom_exclude(atom_exclude_types)
self.reinit_pair_exclude(pair_exclude_types)

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def reinit_atom_exclude(
self,
exclude_types: List[int] = [],
Expand Down
6 changes: 1 addition & 5 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self.descriptor = descriptor
self.fitting = fitting
self.type_map = type_map
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
Expand All @@ -67,10 +67,6 @@ def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
assert len(err_msg) == 0, "\n".join(err_msg)
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)
super().__init__(type_map, **kwargs)

def mixed_types(self) -> bool:
"""If true, the model
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
type_map: List[str],
**kwargs,
):
super().__init__()
super().__init__(type_map, **kwargs)
self.tab_file = tab_file
self.rcut = rcut
self.type_map = type_map
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
class BaseAtomicModel(torch.nn.Module, BaseAtomicModel_):
def __init__(
self,
type_map,
type_map: List[str],
atom_exclude_types: List[int] = [],
pair_exclude_types: List[Tuple[int, int]] = [],
):
Expand Down
2 changes: 1 addition & 1 deletion source/tests/common/dpmodel/test_dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_methods(self):

md0 = DPAtomicModel(ds, ft, type_map=type_map)

self.assertEqual(md0.get_output_keys(), ["energy", "mask"])
self.assertEqual(list(md0.atomic_output_def().keys()), ["energy", "mask"])
self.assertEqual(md0.get_type_map(), ["foo", "bar"])
self.assertEqual(md0.get_ntypes(), 2)
self.assertAlmostEqual(md0.get_rcut(), self.rcut)
Expand Down

0 comments on commit 75591a1

Please sign in to comment.