Skip to content

Commit

Permalink
update async_save_info
Browse files — browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Sep 23, 2024
1 parent 93fdaed commit 3f05fb9
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2427,12 +2427,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`

local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
if (
strtobool(os.getenv("FLAG_LLM_PDC", "False"))
and local_rank == 0
and paddle.distributed.get_rank() == 0
and self.args.unified_checkpoint
and "async_save" in self.args.unified_checkpoint_config
):
os.makedirs(self.args.logging_dir, exist_ok=True)
world_size = paddle.distributed.get_world_size()
Expand All @@ -2441,9 +2439,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
"ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
"skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
}
if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
json.dump(save_info, f)
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
json.dump(save_info, f)

if self.args.should_save:
if self.tokenizer is not None:
Expand Down

0 comments on commit 3f05fb9

Please sign in to comment.