
Fix attn mask ignore logic in training-time trace #32613

Open · wants to merge 15 commits into base: main

Conversation

@zhenglongjiepheonix (Contributor) commented Aug 11, 2024

This PR fixes a scenario where we want to use a dynamo trace in training mode. The current attention-mask ignore logic contains a data-dependent branch condition, torch.all(attn_mask == 1), which causes graph breaks and disables full-graph tracing. The solution here is to disable the mask-ignore logic whenever we are in tracing mode, regardless of whether we are in the training or inference phase.

This enables compilation for training (forward + backward) like this:

import torch
from transformers import LlamaForCausalLM

# `config` is a LlamaConfig; `inputs` is a tokenized batch that includes `labels`.
model = LlamaForCausalLM(config).cuda()
model = torch.compile(model, fullgraph=True)
loss = model(**inputs)[0]  # the loss is the first output when labels are passed
loss.backward()


@ArthurZucker (Collaborator) left a comment

Sounds good! Similarly to

def test_torch_compile(self):

can you add a training compile test? 🤗

@zhenglongjiepheonix (Contributor, Author) commented Aug 12, 2024

I added a simple gradient-match test in training mode, @ArthurZucker. In my local tests, the steady-state backward latency stays roughly the same as eager mode (whether or not CUDA graphs are enabled), but the forward pass still benefits during training, and a training-time trace does help in my case, where I need to do graph analysis in training mode.
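
For reference, a minimal sketch of what such a gradient-match test could look like (hypothetical code with an assumed tiny LlamaConfig, not the exact test added in this PR):

import copy
import torch
from transformers import LlamaConfig, LlamaForCausalLM

def test_compile_training_gradients_match():
    torch.manual_seed(0)
    config = LlamaConfig(
        vocab_size=128, hidden_size=64, intermediate_size=128,
        num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
    )
    eager = LlamaForCausalLM(config)
    compiled = torch.compile(copy.deepcopy(eager), fullgraph=True)

    input_ids = torch.randint(1, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    # One forward + backward pass in each mode, reusing the inputs as labels.
    eager(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids).loss.backward()
    compiled(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids).loss.backward()

    # Gradients from the compiled step should closely match eager mode.
    for p_eager, p_compiled in zip(eager.parameters(), compiled.parameters()):
        torch.testing.assert_close(p_eager.grad, p_compiled.grad, rtol=1e-4, atol=1e-4)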

@@ -276,7 +276,7 @@ def _ignore_causal_mask_sdpa(
     elif sliding_window is None or key_value_length < sliding_window:
         if len(attention_mask.shape) == 4:
             return False
-        elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
+        elif not is_tracing and torch.all(attention_mask == 1):
@zhenglongjiepheonix (Contributor, Author) commented on the changed line:

It's actually a contradictory branch: is_tracing and torch.all(attention_mask == 1) can never hold at the same time, since the data-dependent check cannot be evaluated while tracing.
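
A minimal sketch (my own illustration, not code from this PR) of why such a data-dependent condition cannot be evaluated under full-graph tracing: converting the traced tensor produced by torch.all(mask == 1) into a Python bool forces a graph break, and with fullgraph=True dynamo raises instead of falling back to eager.

import torch

def forward(x, mask):
    # Branching on a tensor value is data-dependent control flow.
    if torch.all(mask == 1):
        return x
    return x * mask

compiled = torch.compile(forward, fullgraph=True)
x, mask = torch.randn(2, 4), torch.ones(2, 4)
try:
    compiled(x, mask)
except Exception as err:  # dynamo refuses to break the graph under fullgraph=True
    print(type(err).__name__)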

@ArthurZucker (Collaborator) left a comment

Nice, can you try to run the slow tests for this! 🤗

"input_ids": torch.randint(
low=1, high=model.config.vocab_size, size=(batch_size, seq_len), device=torch_device
),
"attention_mask": torch.ones((batch_size, seq_len), dtype=torch.int64, device=torch_device),
@ArthurZucker (Collaborator) commented on this snippet:

An attention mask that is all ones is more prone to skipping some branches; I would try with a mix of ones and zeros as well!
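
A minimal sketch (an assumed variant, not the exact test input) of a padded mask that mixes ones and zeros, so the all-ones shortcut can never be taken:

import torch

batch_size, seq_len = 2, 16
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64)
attention_mask[0, : seq_len // 2] = 0  # left-pad the first sequence with zeros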

@ArthurZucker (Collaborator) commented

The failure is most probably not related to you, but the slow test is bad / badly designed: it should either not use accelerate, or not use parallelism, so that we are comparing apples to apples.

@ArthurZucker (Collaborator) commented

It would be nice if you could update the tests 🙏🏻

@zhenglongjiepheonix (Contributor, Author) commented Sep 7, 2024

Accelerate will kick in if there is not enough memory. I think the best solution is just to use the current torch device rather than specifying device_map='sequential', although that may cause GPU OOM. Pytest also seems to have trouble avoiding GPU memory fragmentation (see https://discuss.pytorch.org/t/torch-pytest-leads-to-memory-fragmentation-how-to-do-proper-integration-testing-of-a-lot-of-torch-models/201231); I have run into similar issues where a test passes when run alone but fails with OOM when all the tests run together.

According to the OOM failure information, the GPU memory that is allocated but unused is only about 7MB, so fragmentation is not severe enough to cause the failure; we simply need GPUs with more memory.
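
As an aside, a minimal sketch (hypothetical, not part of this PR) of a pytest fixture that frees cached CUDA memory between tests, which can help when many model tests share one GPU session:

import gc
import pytest
import torch

@pytest.fixture(autouse=True)
def release_cuda_memory():
    # Run the test, then drop Python references and release unused cached GPU memory.
    yield
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()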

@zhenglongjiepheonix (Contributor, Author) commented Sep 7, 2024

> The failure is most probably not related to you, but the slow test is bad / badly designed: it should either not use accelerate, or not use parallelism, so that we are comparing apples to apples.

Yes, we should never use accelerate, but in order to make the tests pass robustly we might need GPU runners with more memory; a T4 will definitely not be enough, because simply loading a single model like llama-7b, without doing anything else, already requires about 13GB.
