
Commit

update async_save_info
DesmonDay committed Sep 23, 2024
1 parent 93fdaed commit 85d3173
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions paddlenlp/trainer/trainer.py
@@ -2155,16 +2155,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             self.model_wrapped.get_all_parameters(convert2cpu=True)
 
         if self.args.should_save_model_state:
-            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
-            # backup and remove unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = []
-
             self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
-
-            # recover unified_checkpoint_config for not trine stage
-            if not self.is_in_train:
-                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(output_dir, exist_ok=True)
@@ -2427,10 +2418,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
         # Save a trained model and configuration using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
 
-        local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
         if (
             strtobool(os.getenv("FLAG_LLM_PDC", "False"))
-            and local_rank == 0
+            and paddle.distributed.get_rank() == 0
             and self.args.unified_checkpoint
             and "async_save" in self.args.unified_checkpoint_config
         ):
@@ -2441,9 +2431,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
                 "ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
                 "skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
             }
-            if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
-                with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
-                    json.dump(save_info, f)
+            with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
+                json.dump(save_info, f)
 
         if self.args.should_save:
             if self.tokenizer is not None:
@@ -2452,7 +2441,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
             paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
         if self.args.unified_checkpoint:
+            unified_checkpoint_config_backup = self.args.unified_checkpoint_config
+            # backup and remove unified_checkpoint_config for not trine stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = []
+
             self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
+
+            # recover unified_checkpoint_config for not trine stage
+            if not self.is_in_train:
+                self.args.unified_checkpoint_config = unified_checkpoint_config_backup
+
             return
 
         merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
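Taken together, the commit makes three changes to _save: the async-save marker is gated on the global rank (paddle.distributed.get_rank() == 0) instead of the node-local PADDLE_RANK_IN_NODE rank, async_save_info.json is rewritten on every save rather than only when the file is missing, and the backup/restore of unified_checkpoint_config moves from save_model down to the unified-checkpoint branch of _save. Below is a minimal standalone sketch of the last two behaviours, assuming nothing about the real Trainer beyond what the diff shows; the helper names and the bare args/handler parameters are illustrative, not PaddleNLP APIs.

# Minimal sketch only; not the PaddleNLP implementation.
import json
import os


def write_async_save_info(logging_dir, save_info):
    # The "only write if the file does not exist" guard is gone, so the marker
    # file is rewritten on every save and always reflects the latest run.
    os.makedirs(logging_dir, exist_ok=True)
    with open(os.path.join(logging_dir, "async_save_info.json"), "w") as f:
        json.dump(save_info, f)


def save_unified_with_config_guard(args, handler, model, optimizer, output_dir, is_in_train):
    # Backup/clear/restore of unified_checkpoint_config now sits next to the
    # unified checkpoint save instead of in save_model around _save().
    backup = args.unified_checkpoint_config
    if not is_in_train:
        args.unified_checkpoint_config = []

    handler.save_unified_checkpoint(model, optimizer, output_dir)

    if not is_in_train:
        args.unified_checkpoint_config = backup

A try/finally around the save call would also restore the config if the checkpoint write raised; the diff as committed restores it only after the call returns.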

0 comments on commit 85d3173
