diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py
index 43a360a9a49c..e74aad33c3e3 100644
--- a/applications/Colossal-LLaMA/train.py
+++ b/applications/Colossal-LLaMA/train.py
@@ -128,6 +128,12 @@ def main() -> None:
     parser.add_argument("--zero", type=int, default=1)
     parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
     parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
+    parser.add_argument(
+        "--skip_save_each_epoch",
+        action="store_true",
+        default=False,
+        help="skip saving the model checkpoint after each epoch is completed.",
+    )
     args = parser.parse_args()
 
     with open(args.config_file, "w") as f:
@@ -370,11 +376,17 @@ def main() -> None:
                     )
                 total_loss.fill_(0.0)
                 pbar.update()
+
             # Save modeling.
-            if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
-                step + 1
-            ) == len(dataloader):
+            save_model_condition = (
+                args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+            )
+
+            if not args.skip_save_each_epoch:
+                save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+
+            if save_model_condition:
                 coordinator.print_on_master("\nStart saving model checkpoint with running states")
 
                 if args.use_neft:
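For reference, the refactored trigger keeps the periodic save (every `save_interval` optimizer updates, i.e. every `save_interval * accumulation_steps` dataloader steps) and makes the end-of-epoch save opt-out via the new flag. The sketch below is a hypothetical standalone helper, not part of `train.py`; `should_save` and its parameters are illustrative names that mirror the `save_model_condition` logic in the diff:

```python
# Minimal sketch of the checkpoint trigger introduced in this diff.
# `should_save` is a hypothetical helper that mirrors save_model_condition.


def should_save(
    step: int,
    num_steps: int,
    save_interval: int,
    accumulation_steps: int,
    skip_save_each_epoch: bool,
) -> bool:
    # Periodic trigger: every `save_interval` optimizer updates, which is
    # every save_interval * accumulation_steps dataloader steps.
    save = save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0
    # End-of-epoch trigger: fires on the last dataloader step unless skipped.
    if not skip_save_each_epoch:
        save = save or (step + 1) == num_steps
    return save


if __name__ == "__main__":
    # With 10 steps, save_interval=2, accumulation_steps=2: periodic saves
    # after steps 3 and 7, plus step 9 (end of epoch) unless skipped.
    for skip in (False, True):
        triggers = [s for s in range(10) if should_save(s, 10, 2, 2, skip)]
        print(f"skip_save_each_epoch={skip}: save after steps {triggers}")
```

Under these assumptions the change is behavior-preserving when the flag is absent (`save_model_condition` reduces to the old `or`-expression), and `--skip_save_each_epoch` only removes the `(step + 1) == len(dataloader)` branch.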