Skip to content

Commit

Permalink
fix(pt/dp): share params of repinit_three_body (#4139)
Browse files Browse the repository at this point in the history
Fix #4137.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Enhanced descriptor handling to support three-body interactions,
improving model flexibility.
- Introduced a new model configuration, `model_dpa2tebd`, with advanced
parameters for better performance.

- **Tests**
- Added a new test class, `TestMultiTaskDPA2Tebd`, to expand testing
coverage for the new model configuration.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd committed Sep 19, 2024
1 parent ba9f02f commit 8f969ba
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 5 deletions.
12 changes: 10 additions & 2 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,16 +733,24 @@ def set_stat_mean_and_stddev(
stddev: List[np.ndarray],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt_list = [self.repinit, self.repformers]
if self.use_three_body:
descrpt_list.append(self.repinit_three_body)
for ii, descrpt in enumerate(descrpt_list):
descrpt.mean = mean[ii]
descrpt.stddev = stddev[ii]

def get_stat_mean_and_stddev(self) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""Get mean and stddev for descriptor."""
return [self.repinit.mean, self.repformers.mean], [
mean_list = [self.repinit.mean, self.repformers.mean]
stddev_list = [
self.repinit.stddev,
self.repformers.stddev,
]
if self.use_three_body:
mean_list.append(self.repinit_three_body.mean)
stddev_list.append(self.repinit_three_body.stddev)
return mean_list, stddev_list

def call(
self,
Expand Down
25 changes: 22 additions & 3 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def share_params(self, base_class, shared_level, resume=False):
if shared_level == 0:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
self.repinit.share_params(base_class.repinit, 0, resume=resume)
if self.use_three_body:
self.repinit_three_body.share_params(
base_class.repinit_three_body, 0, resume=resume
)
self._modules["g1_shape_tranform"] = base_class._modules[
"g1_shape_tranform"
]
Expand All @@ -398,6 +402,10 @@ def share_params(self, base_class, shared_level, resume=False):
elif shared_level == 1:
self._modules["type_embedding"] = base_class._modules["type_embedding"]
self.repinit.share_params(base_class.repinit, 0, resume=resume)
if self.use_three_body:
self.repinit_three_body.share_params(
base_class.repinit_three_body, 0, resume=resume
)
# shared_level: 2
# share all parameters in type_embedding and repformers
elif shared_level == 2:
Expand Down Expand Up @@ -499,7 +507,10 @@ def compute_input_stats(
The path to the stat file.
"""
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt_list = [self.repinit, self.repformers]
if self.use_three_body:
descrpt_list.append(self.repinit_three_body)
for ii, descrpt in enumerate(descrpt_list):
descrpt.compute_input_stats(merged, path)

def set_stat_mean_and_stddev(
Expand All @@ -508,16 +519,24 @@ def set_stat_mean_and_stddev(
stddev: List[torch.Tensor],
) -> None:
"""Update mean and stddev for descriptor."""
for ii, descrpt in enumerate([self.repinit, self.repformers]):
descrpt_list = [self.repinit, self.repformers]
if self.use_three_body:
descrpt_list.append(self.repinit_three_body)
for ii, descrpt in enumerate(descrpt_list):
descrpt.mean = mean[ii]
descrpt.stddev = stddev[ii]

def get_stat_mean_and_stddev(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Get mean and stddev for descriptor."""
return [self.repinit.mean, self.repformers.mean], [
mean_list = [self.repinit.mean, self.repformers.mean]
stddev_list = [
self.repinit.stddev,
self.repformers.stddev,
]
if self.use_three_body:
mean_list.append(self.repinit_three_body.mean)
stddev_list.append(self.repinit_three_body.stddev)
return mean_list, stddev_list

def serialize(self) -> dict:
repinit = self.repinit
Expand Down
52 changes: 52 additions & 0 deletions source/tests/pt/model/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,58 @@
},
}

model_dpa2tebd = {
"type_map": ["O", "H", "B"],
"descriptor": {
"type": "dpa2",
"repinit": {
"rcut": 6.0,
"rcut_smth": 0.5,
"nsel": 100,
"neuron": [2, 4, 8],
"axis_neuron": 4,
"activation_function": "tanh",
"three_body_sel": 40,
"three_body_rcut": 4.0,
"three_body_rcut_smth": 3.5,
"use_three_body": True,
},
"repformer": {
"rcut": 4.0,
"rcut_smth": 0.5,
"nsel": 40,
"nlayers": 6,
"g1_dim": 8,
"g2_dim": 5,
"attn2_hidden": 3,
"attn2_nhead": 1,
"attn1_hidden": 5,
"attn1_nhead": 1,
"axis_neuron": 4,
"update_h2": False,
"update_g1_has_conv": True,
"update_g1_has_grrg": True,
"update_g1_has_drrd": True,
"update_g1_has_attn": False,
"update_g2_has_g1g1": False,
"update_g2_has_attn": True,
"update_style": "res_residual",
"update_residual": 0.01,
"update_residual_init": "norm",
"attn2_has_gate": True,
"use_sqrt_nnei": True,
"g1_out_conv": True,
"g1_out_mlp": True,
},
"add_tebd_to_repinit_out": False,
},
"fitting_net": {
"neuron": [24, 24],
"resnet_dt": True,
"seed": 1,
},
}

model_dpa1 = {
"type_map": ["O", "H", "B"],
"descriptor": {
Expand Down
40 changes: 40 additions & 0 deletions source/tests/pt/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .model.test_permutation import (
model_dpa1,
model_dpa2,
model_dpa2tebd,
model_se_e2_a,
)

Expand Down Expand Up @@ -300,5 +301,44 @@ def tearDown(self) -> None:
MultiTaskTrainTest.tearDown(self)


class TestMultiTaskDPA2Tebd(unittest.TestCase, MultiTaskTrainTest):
def setUp(self):
multitask_DPA2 = deepcopy(multitask_template)
multitask_DPA2["model"]["shared_dict"]["my_descriptor"] = model_dpa2tebd[
"descriptor"
]
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.stat_files = "DPA2Tebd"
os.makedirs(self.stat_files, exist_ok=True)
self.config = multitask_DPA2
self.config["training"]["data_dict"]["model_1"]["training_data"]["systems"] = (
data_file
)
self.config["training"]["data_dict"]["model_1"]["validation_data"][
"systems"
] = data_file
self.config["training"]["data_dict"]["model_1"]["stat_file"] = (
f"{self.stat_files}/model_1"
)
self.config["training"]["data_dict"]["model_2"]["training_data"]["systems"] = (
data_file
)
self.config["training"]["data_dict"]["model_2"]["validation_data"][
"systems"
] = data_file
self.config["training"]["data_dict"]["model_2"]["stat_file"] = (
f"{self.stat_files}/model_2"
)
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1
self.origin_config = deepcopy(self.config)
self.config["model"], self.shared_links = preprocess_shared_params(
self.config["model"]
)

def tearDown(self) -> None:
MultiTaskTrainTest.tearDown(self)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8f969ba

Please sign in to comment.