Skip to content

Commit

Permalink
update async_save_info (#9181)
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Sep 23, 2024
1 parent 0c615ef commit e0de9d3
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2155,16 +2155,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
self.model_wrapped.get_all_parameters(convert2cpu=True)

if self.args.should_save_model_state:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# backup and remove unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)

# recover unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
os.makedirs(output_dir, exist_ok=True)
Expand Down Expand Up @@ -2427,10 +2418,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`

local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
if (
strtobool(os.getenv("FLAG_LLM_PDC", "False"))
and local_rank == 0
and paddle.distributed.get_rank() == 0
and self.args.unified_checkpoint
and "async_save" in self.args.unified_checkpoint_config
):
Expand All @@ -2451,7 +2441,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

if self.args.unified_checkpoint:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# backup and remove unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)

# recover unified_checkpoint_config for not trine stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup

return

merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
Expand Down

0 comments on commit e0de9d3

Please sign in to comment.