
[NeMo-UX] Support save_last="link" #10548

Status: Open. ashors1 wants to merge 22 commits into main.
Conversation

ashors1 (Collaborator) commented Sep 20, 2024

What does this PR do?

Adds support for creating a symlink for -last checkpoints. The implementation is compatible with both synchronous and asynchronous checkpointing.

Collection: llm

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • A minimal usage sketch is shown below (illustrative only).
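A minimal usage sketch (not taken from the PR itself; nl.ModelCheckpoint, nl.Trainer, and the exact argument names are assumptions based on the nemo.lightning API used in the test file below, and only save_last="link" is the new behavior):

# Illustrative only: argument names other than save_last="link" are assumptions.
import nemo.lightning as nl

checkpoint = nl.ModelCheckpoint(
    save_last="link",       # write the <name>-last checkpoint as a symlink instead of a full copy
    save_top_k=3,
    every_n_train_steps=100,
)

trainer = nl.Trainer(
    max_steps=1000,
    strategy=nl.MegatronStrategy(ckpt_async_save=True),  # linking is also compatible with async saves
    callbacks=[checkpoint],
)
# trainer.fit(model, datamodule) as usual; the -last checkpoint then appears on disk as a symlink.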

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and re-add the label.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: Anna Shors <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors1@users.noreply.github.com>
Review thread on nemo/lightning/pytorch/callbacks/model_checkpoint.py (outdated, resolved):
            logging.info(f'Scheduled async checkpoint save for {filepath}')
        else:
            finalize_fn()

    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, torch.Tensor]) -> None:
Collaborator:

I'm wondering if there is some way to avoid overriding the whole method. This is always risky, since we lose touch with the upstream.

How is our flow different from the one in PTL, such that we need to add the saved_current_step logic and also not rely on self.last_model_path?
Is it because PTL links to any available last checkpoint (not necessarily one from the last iteration)?

ashors1 (Collaborator, Author):

Yeah, I made two changes:

  1. Made sure to add a symlink only when the current step was actually saved (sketched below). As you suggested, PTL always links to the last checkpoint saved, which might not correspond to the latest step.
  2. Added these lines, which fix the last_model_path value saved to the *-last checkpoint's state dict when using symlinks.

I'll think about whether we can make these fixes without overwriting the entire _save_last_checkpoint method.
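To illustrate the first change, a standalone sketch of that guard (hypothetical helper and names, not the PR's diff): the symlink is only created when the checkpoint it would point to was written for the current step; otherwise the caller falls back to writing a real -last checkpoint.

# Illustrative helper; not taken from the PR.
from pathlib import Path

def maybe_link_last(last_saved_path: str, last_saved_step: int, current_step: int, linkpath: str) -> bool:
    """Symlink linkpath -> last_saved_path only if that checkpoint is from current_step."""
    if last_saved_step != current_step:
        # Caller should write a real -last checkpoint instead of linking to a stale one.
        return False
    link = Path(linkpath)
    if link.is_symlink() or link.is_file():
        link.unlink()
    link.symlink_to(Path(last_saved_path).resolve())
    return True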

Collaborator:

Thanks. Maybe overwriting _save_last_checkpoint is inevitable, in which case the current version is OK.

ashors1 (Collaborator, Author):

Running some final tests now, but I think I was able to avoid overwriting _save_last_checkpoint. Please let me know if you have any concerns with the current approach

mikolajblaz (Collaborator) commented Sep 24, 2024:

Great, thanks!

Do you know how last_model_path is used during restart? I'm wondering whether the loaded state dict will be valid if, for example, a failure happens between the regular and "last" checkpoint saves.

ashors1 (Collaborator, Author):

I think last_model_path is only used when removing the previous -last checkpoint, to ensure we retain only a single -last checkpoint. If a failure happens between the regular and last checkpoint saves, I don't think the state dict will be valid, but I also don't think this is a concern: we'd end up restoring from the previously saved -last checkpoint, which does have the correct state dict.

Review thread on nemo/lightning/pytorch/callbacks/model_checkpoint.py (outdated, resolved)
ashors1 and others added 7 commits September 23, 2024 09:19
        # NeMo distributed checkpoints are directories, so link the directory paths
        filepath = ckpt_to_dir(filepath)
        linkpath = ckpt_to_dir(linkpath)
        super()._link_checkpoint(trainer, filepath, linkpath)

Collaborator:

Is there a way to avoid overriding PTL's _link_checkpoint method? We want to avoid overriding PTL's private methods so that the code stays stable.

ashors1 (Collaborator, Author):

I can think about it, but it might be challenging to support linking with async checkpointing without overwriting this method. Also, the addition of saved_current_step is needed to fix a bug that seems to exist in PTL's link implementation, in which the -last checkpoint gets linked to the most recently saved checkpoint even if that corresponds to a different step.

athitten (Collaborator) commented Sep 23, 2024:

I see, thanks @ashors1! Then maybe we should file an issue with PTL and ask them to fix this bug? That way we could avoid overriding private methods.

ashors1 (Collaborator, Author):

Yeah, it would be great if PTL could fix that issue! But we'd still have to figure out how to handle async checkpointing, and I do think that would require us to overwrite either _link_checkpoint or _save_last_checkpoint (where _link_checkpoint is invoked).

Collaborator:

Agreed that for handling async save we have to override _link_checkpoint anyway.
But since we call super()._link_checkpoint, I don't think there is much risk in that.
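For context, a rough sketch of the override pattern being discussed (an illustrative subclass of PTL's ModelCheckpoint, not NeMo's class; the deferral hook and the path handling are approximations): both paths are resolved to checkpoint directories, and with async saving the actual link is deferred until the save is finalized.

# Rough sketch only; in NeMo the paths go through ckpt_to_dir and the deferred callback
# is queued with the async checkpointing machinery rather than a plain list.
from pathlib import Path
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint as PTLModelCheckpoint

class LinkingModelCheckpoint(PTLModelCheckpoint):
    def __init__(self, *args, async_save: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.async_save = async_save
        self._deferred = []  # link callbacks to run when an async save is finalized

    def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
        # Distributed checkpoints are directories; approximate ckpt_to_dir by stripping ".ckpt".
        filepath = str(Path(filepath).with_suffix(""))
        linkpath = str(Path(linkpath).with_suffix(""))
        do_link = lambda: super(LinkingModelCheckpoint, self)._link_checkpoint(trainer, filepath, linkpath)
        if self.async_save:
            # Defer the link until the checkpoint actually exists on disk.
            self._deferred.append(do_link)
        else:
            do_link()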

ashors1 and others added 3 commits September 23, 2024 22:36
@@ -0,0 +1,143 @@
import os
from dataclasses import dataclass

CodeQL (code scanning) notice: Import of 'dataclass' is not used.
import pytest
import pytorch_lightning as pl
import torch
from megatron.core import ModelParallelConfig

CodeQL (code scanning) notice: Import of 'ModelParallelConfig' is not used.
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

import nemo.lightning as nl
from nemo.collections import llm

CodeQL (code scanning) notice: Import of 'llm' is not used.
model = ExampleModel()

data = RandomDataset(32, 64)
save_top_k = 3

CodeQL (code scanning) notice: Local variable save_top_k is not used.
use_datetime_version=False,
)

strategy = nl.MegatronStrategy(ckpt_async_save=True, replace_progress_bar=False)

CodeQL (code scanning) notice: Local variable strategy is not used.