Skip to content

Commit

Permalink
fix: set weights_only=True for torch.load
Browse files Browse the repository at this point in the history
Fix deepmodeling#4143.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Sep 19, 2024
1 parent e1b6aec commit 1c0f994
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 18 deletions.
8 changes: 6 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def train(FLAGS):
FLAGS.init_model is not None or FLAGS.init_frz_model is not None
) and FLAGS.use_pretrain_script:
if FLAGS.init_model is not None:
init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE)
init_state_dict = torch.load(
FLAGS.init_model, map_location=DEVICE, weights_only=True
)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]
Expand Down Expand Up @@ -358,7 +360,9 @@ def freeze(FLAGS):

def change_bias(FLAGS):
if FLAGS.INPUT.endswith(".pt"):
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
old_state_dict = torch.load(
FLAGS.INPUT, map_location=env.DEVICE, weights_only=True
)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_params = model_state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.endswith(".pth"):
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def __init__(
self.output_def = output_def
self.model_path = model_file
if str(self.model_path).endswith(".pt"):
state_dict = torch.load(model_file, map_location=env.DEVICE)
state_dict = torch.load(
model_file, map_location=env.DEVICE, weights_only=True
)
if "model" in state_dict:
state_dict = state_dict["model"]
self.input_param = state_dict["_extra_state"]["model_params"]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/infer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
- config: The Dict-like configuration with training options.
"""
# Model
state_dict = torch.load(model_ckpt, map_location=DEVICE)
state_dict = torch.load(model_ckpt, map_location=DEVICE, weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
model_params = state_dict["_extra_state"]["model_params"]
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,9 @@ def get_lr(lr_params):
optimizer_state_dict = None
if resuming:
log.info(f"Resuming from {resume_model}.")
state_dict = torch.load(resume_model, map_location=DEVICE)
state_dict = torch.load(
resume_model, map_location=DEVICE, weights_only=True
)
if "model" in state_dict:
optimizer_state_dict = (
state_dict["optimizer"] if finetune_model is None else None
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_finetune_rules(
Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs.
"""
multi_task = "model_dict" in model_config
state_dict = torch.load(finetune_model, map_location=env.DEVICE)
state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
last_model_params = state_dict["_extra_state"]["model_params"]
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def serialize_from_file(model_file: str) -> dict:
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())
elif model_file.endswith(".pt"):
state_dict = torch.load(model_file, map_location="cpu")
state_dict = torch.load(model_file, map_location="cpu", weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
model_def_script = state_dict["_extra_state"]["model_params"]
Expand Down
10 changes: 6 additions & 4 deletions source/tests/pt/model/test_descriptor_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,15 @@ def test_descriptor_block(self):
des = DescrptBlockSeAtten(
**dparams,
).to(env.DEVICE)
des.load_state_dict(torch.load(self.file_model_param))
des.load_state_dict(torch.load(self.file_model_param, weights_only=True))
coord = self.coord
atype = self.atype
box = self.cell
# handel type_embedding
type_embedding = TypeEmbedNet(ntypes, 8, use_tebd_bias=True).to(env.DEVICE)
type_embedding.load_state_dict(torch.load(self.file_type_embed))
type_embedding.load_state_dict(
torch.load(self.file_type_embed, weights_only=True)
)

## to save model parameters
# torch.save(des.state_dict(), 'model_weights.pth')
Expand Down Expand Up @@ -299,8 +301,8 @@ def test_descriptor(self):
**dparams,
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
type_embd_dict = torch.load(self.file_type_embed)
source_dict = torch.load(self.file_model_param, weights_only=True)
type_embd_dict = torch.load(self.file_type_embed, weights_only=True)
target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1(
target_dict,
source_dict,
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/test_descriptor_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ def test_descriptor(self):
**dparams,
).to(env.DEVICE)
target_dict = des.state_dict()
source_dict = torch.load(self.file_model_param)
source_dict = torch.load(self.file_model_param, weights_only=True)
# type_embd of repformer is removed
source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias")
type_embd_dict = torch.load(self.file_type_embed)
type_embd_dict = torch.load(self.file_type_embed, weights_only=True)
target_dict = translate_type_embd_dicts_to_dpa2(
target_dict,
source_dict,
Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_saveload_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr)
optimizer.zero_grad()
if read:
wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE))
wrapper.load_state_dict(
torch.load(model_file, map_location=env.DEVICE, weights_only=True)
)
os.remove(model_file)
else:
torch.save(wrapper.state_dict(), model_file)
Expand Down
4 changes: 3 additions & 1 deletion source/tests/pt/model/test_saveload_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr)
optimizer.zero_grad()
if read:
wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE))
wrapper.load_state_dict(
torch.load(model_file, map_location=env.DEVICE, weights_only=True)
)
os.remove(model_file)
else:
torch.save(wrapper.state_dict(), model_file)
Expand Down
10 changes: 7 additions & 3 deletions source/tests/pt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def test_change_bias_with_data(self):
run_dp(
f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}"
)
state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE)
state_dict = torch.load(
str(self.model_path_data_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
wrapper = ModelWrapper(model_for_wrapper)
Expand All @@ -114,7 +116,7 @@ def test_change_bias_with_data_sys_file(self):
f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}"
)
state_dict = torch.load(
str(self.model_path_data_file_bias), map_location=DEVICE
str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
Expand All @@ -134,7 +136,9 @@ def test_change_bias_with_user_defined(self):
run_dp(
f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}"
)
state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE)
state_dict = torch.load(
str(self.model_path_user_bias), map_location=DEVICE, weights_only=True
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
wrapper = ModelWrapper(model_for_wrapper)
Expand Down

0 comments on commit 1c0f994

Please sign in to comment.