diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 9133575ec..a7280cce1 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -275,7 +275,9 @@ def train(FLAGS): FLAGS.init_model is not None or FLAGS.init_frz_model is not None ) and FLAGS.use_pretrain_script: if FLAGS.init_model is not None: - init_state_dict = torch.load(FLAGS.init_model, map_location=DEVICE) + init_state_dict = torch.load( + FLAGS.init_model, map_location=DEVICE, weights_only=True + ) if "model" in init_state_dict: init_state_dict = init_state_dict["model"] config["model"] = init_state_dict["_extra_state"]["model_params"] @@ -358,7 +360,9 @@ def freeze(FLAGS): def change_bias(FLAGS): if FLAGS.INPUT.endswith(".pt"): - old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE) + old_state_dict = torch.load( + FLAGS.INPUT, map_location=env.DEVICE, weights_only=True + ) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) model_params = model_state_dict["_extra_state"]["model_params"] elif FLAGS.INPUT.endswith(".pth"): diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 353109d65..ab404cd8a 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -107,7 +107,9 @@ def __init__( self.output_def = output_def self.model_path = model_file if str(self.model_path).endswith(".pt"): - state_dict = torch.load(model_file, map_location=env.DEVICE) + state_dict = torch.load( + model_file, map_location=env.DEVICE, weights_only=True + ) if "model" in state_dict: state_dict = state_dict["model"] self.input_param = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index dfb7abdb2..b3d120cbc 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -34,7 +34,7 @@ def __init__( - config: The Dict-like configuration with training options. """ # Model - state_dict = torch.load(model_ckpt, map_location=DEVICE) + state_dict = torch.load(model_ckpt, map_location=DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_params = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index a7b9e25b4..d32b03bd8 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -401,7 +401,9 @@ def get_lr(lr_params): optimizer_state_dict = None if resuming: log.info(f"Resuming from {resume_model}.") - state_dict = torch.load(resume_model, map_location=DEVICE) + state_dict = torch.load( + resume_model, map_location=DEVICE, weights_only=True + ) if "model" in state_dict: optimizer_state_dict = ( state_dict["optimizer"] if finetune_model is None else None diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 2dd2230b5..96a420bf6 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -136,7 +136,7 @@ def get_finetune_rules( Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs. """ multi_task = "model_dict" in model_config - state_dict = torch.load(finetune_model, map_location=env.DEVICE) + state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] last_model_params = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index aab6d100a..1c6ea096a 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -33,7 +33,7 @@ def serialize_from_file(model_file: str) -> dict: model = get_model(model_def_script) model.load_state_dict(saved_model.state_dict()) elif model_file.endswith(".pt"): - state_dict = torch.load(model_file, map_location="cpu") + state_dict = torch.load(model_file, map_location="cpu", weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_def_script = state_dict["_extra_state"]["model_params"] diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index 488cc2f7f..a3d696516 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -245,13 +245,15 @@ def test_descriptor_block(self): des = DescrptBlockSeAtten( **dparams, ).to(env.DEVICE) - des.load_state_dict(torch.load(self.file_model_param)) + des.load_state_dict(torch.load(self.file_model_param, weights_only=True)) coord = self.coord atype = self.atype box = self.cell # handel type_embedding type_embedding = TypeEmbedNet(ntypes, 8, use_tebd_bias=True).to(env.DEVICE) - type_embedding.load_state_dict(torch.load(self.file_type_embed)) + type_embedding.load_state_dict( + torch.load(self.file_type_embed, weights_only=True) + ) ## to save model parameters # torch.save(des.state_dict(), 'model_weights.pth') @@ -299,8 +301,8 @@ def test_descriptor(self): **dparams, ).to(env.DEVICE) target_dict = des.state_dict() - source_dict = torch.load(self.file_model_param) - type_embd_dict = torch.load(self.file_type_embed) + source_dict = torch.load(self.file_model_param, weights_only=True) + type_embd_dict = torch.load(self.file_type_embed, weights_only=True) target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1( target_dict, source_dict, diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index ac04bfc41..17d609a2f 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -123,10 +123,10 @@ def test_descriptor(self): **dparams, ).to(env.DEVICE) target_dict = des.state_dict() - source_dict = torch.load(self.file_model_param) + source_dict = torch.load(self.file_model_param, weights_only=True) # type_embd of repformer is removed source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias") - type_embd_dict = torch.load(self.file_type_embed) + type_embd_dict = torch.load(self.file_type_embed, weights_only=True) target_dict = translate_type_embd_dicts_to_dpa2( target_dict, source_dict, diff --git a/source/tests/pt/model/test_saveload_dpa1.py b/source/tests/pt/model/test_saveload_dpa1.py index 3da06938b..5b2b6cd58 100644 --- a/source/tests/pt/model/test_saveload_dpa1.py +++ b/source/tests/pt/model/test_saveload_dpa1.py @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"): optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr) optimizer.zero_grad() if read: - wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE)) + wrapper.load_state_dict( + torch.load(model_file, map_location=env.DEVICE, weights_only=True) + ) os.remove(model_file) else: torch.save(wrapper.state_dict(), model_file) diff --git a/source/tests/pt/model/test_saveload_se_e2_a.py b/source/tests/pt/model/test_saveload_se_e2_a.py index 56ea3283d..d226f628b 100644 --- a/source/tests/pt/model/test_saveload_se_e2_a.py +++ b/source/tests/pt/model/test_saveload_se_e2_a.py @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"): optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr) optimizer.zero_grad() if read: - wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE)) + wrapper.load_state_dict( + torch.load(model_file, map_location=env.DEVICE, weights_only=True) + ) os.remove(model_file) else: torch.save(wrapper.state_dict(), model_file) diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index f76be40b3..febc439f5 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -92,7 +92,9 @@ def test_change_bias_with_data(self): run_dp( f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}" ) - state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE) + state_dict = torch.load( + str(self.model_path_data_bias), map_location=DEVICE, weights_only=True + ) model_params = state_dict["model"]["_extra_state"]["model_params"] model_for_wrapper = get_model_for_wrapper(model_params) wrapper = ModelWrapper(model_for_wrapper) @@ -114,7 +116,7 @@ def test_change_bias_with_data_sys_file(self): f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}" ) state_dict = torch.load( - str(self.model_path_data_file_bias), map_location=DEVICE + str(self.model_path_data_file_bias), map_location=DEVICE, weights_only=True ) model_params = state_dict["model"]["_extra_state"]["model_params"] model_for_wrapper = get_model_for_wrapper(model_params) @@ -134,7 +136,9 @@ def test_change_bias_with_user_defined(self): run_dp( f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}" ) - state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE) + state_dict = torch.load( + str(self.model_path_user_bias), map_location=DEVICE, weights_only=True + ) model_params = state_dict["model"]["_extra_state"]["model_params"] model_for_wrapper = get_model_for_wrapper(model_params) wrapper = ModelWrapper(model_for_wrapper)