
[BC 4.37 -> 4.38] for Llama family, memory and speed #29753

Merged
merged 22 commits into from
Mar 20, 2024

Conversation

ArthurZucker
Collaborator

@ArthurZucker ArthurZucker commented Mar 20, 2024

What does this PR do?

Fixes the BC issues between the two versions in terms of memory consumption.
This fix is made a lot easier by all the tests, so thanks a lot @gante!

fixes #29412, fixes #29484 , fixes #29644, fixes #29651

@ArthurZucker ArthurZucker changed the title [BC 4.37 -> 4.38] [BC 4.37 -> 4.38] for Llama family, memory and speed Mar 20, 2024
Member

@gante gante left a comment


In general this looks good to me, although I'm not 100% sure about the causal_mask *= torch.arange(target_length, device=device) > cache_position[0] line in combination with assisted generation -- going to have a look

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker marked this pull request as ready for review March 20, 2024 12:52
@ArthurZucker
Collaborator Author

torchscript can be fixed by @fxmarty in a follow-up PR + patch IMO!

@fxmarty
Contributor

fxmarty commented Mar 20, 2024

@ArthurZucker that would be great, as long as the torchscript/fx tests pass and

optimum-cli export onnx --model fxmarty/tiny-llama-fast-tokenizer llama_onnx

does not break

Contributor

@younesbelkada younesbelkada left a comment


Thanks for the offline explanation! This shouldn't affect FA2, as we always return attention_mask without processing it for FA2 modules in _update_causal_mask. As you explained offline, causal_mask *= torch.arange(target_length, device=device) > cache_position[0] is used to mask out the cached hidden states.

@gante
Member

gante commented Mar 20, 2024

@ArthurZucker the line causal_mask *= torch.arange(target_length, device=device) > cache_position[0] should become causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)

In the previous version, in a situation with 17 cached tokens and 4 new assistant tokens, the causal mask would become

(Pdb) causal_mask
tensor([[[[     0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0., -65504., -65504., -65504., -65504.]]]],
       device='cuda:0', dtype=torch.float16)

i.e. not upper triangular. After the suggested change, it becomes

(Pdb) causal_mask
tensor([[[[     0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0., -65504., -65504., -65504.],
          [     0.,      0.,      0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0., -65504., -65504.],
          [     0.,      0.,      0.,      0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0., -65504.],
          [     0.,      0.,      0.,      0.,      0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.,
               -0.,     -0.,     -0.,     -0.,     -0.,     -0.,     -0.]]]],
       device='cuda:0', dtype=torch.float16)
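
To make the difference concrete, here is a minimal, self-contained sketch of the two masking rules (illustrative only, not the actual transformers code; the helper name and the 17-cached / 4-new setup simply follow the example above):

import torch

def build_mask(cache_position, target_length, per_row=True):
    # Additive causal mask for the new (query) tokens: `cache_position` holds
    # their absolute positions, `target_length` is cached + new length.
    new_len = cache_position.shape[0]
    min_value = torch.finfo(torch.float16).min
    mask = torch.full((new_len, target_length), min_value, dtype=torch.float16)
    mask = torch.triu(mask, diagonal=1)
    if per_row:
        # suggested fix: compare against each query position (broadcasts per row)
        mask = mask * (torch.arange(target_length) > cache_position.reshape(-1, 1))
    else:
        # behaviour being fixed: a single scalar applied to every row
        mask = mask * (torch.arange(target_length) > cache_position[0])
    return mask

# 17 tokens already in the cache, 4 assistant candidate tokens at positions 17..20
cache_position = torch.arange(17, 21)
buggy = build_mask(cache_position, 21, per_row=False)  # rectangular block of masked positions
fixed = build_mask(cache_position, 21, per_row=True)   # causal staircase
print((buggy != fixed).sum())  # rows 1..3 differ, matching the dumps above

With the scalar comparison every row keeps the same rectangular block of masked positions, so later candidate tokens cannot attend to the earlier ones; the per-row broadcast restores the causal staircase.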

@ArthurZucker
Collaborator Author

Any idea why our tests don't complain?

@gante
Member

gante commented Mar 20, 2024

Any idea why our tests don't complain?

I think we don't have hard correctness checks for assisted generation 🙈 only API checks

@ArthurZucker ArthurZucker merged commit ff84190 into main Mar 20, 2024
19 checks passed
@ArthurZucker ArthurZucker deleted the fix-causal-mask-dispatch branch March 20, 2024 22:47
ArthurZucker added a commit that referenced this pull request Mar 20, 2024
* attempt to fix

* the actual fix that works with compilation!

* this?

* temporary update

* nit?

* dispatcg to memory efficient?

* update both models that have static cache support

* fix copies fix compile

* make sure fix

* fix cohere and gemma

* fix beams?

* nit

* slipped through the cracks

* nit

* nits

* update

* fix-copies

* skip failing tests

* nits
@poedator
Contributor

Congratulations on this PR fixing many important things!
Unfortunately, it broke custom 4D mask support (and my heart).
Try RUN_SLOW=1 python -m pytest -v ./tests/test_modeling_utils.py::Mask4DTestHard::test_partial_stacked_causal_mask

@gante and I introduced this test recently in #29731 -- please make it part of the test suite for everything related to attention_mask and StaticCache.

Here is what happens (numbers based on test_partial_stacked_causal_mask, after line 2170):
In modeling_llama.py::_update_causal_mask(), the causal_mask has shape (1, 1, sequence_length, target_length), i.e. (1, 1, 9, 12) with 3 tokens in the cache.
If the custom 4D mask enters with the same shape, it triggers offset = 3 and then causes an error when it is copied over the causal mask.
If the custom 4D mask enters with shape (1, 1, 12, 12), offset stays at zero, but this again causes an error when it is copied over the causal mask, because the mask_slice and causal_mask shapes don't match.
There are also other changes which may affect this test. See the sketch below for the two shapes involved.
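
To make the two shapes concrete, here is a hedged sketch (plain placeholder tensors with made-up names, not the transformers masking code), using the numbers from the test:

import torch

# 3 tokens already cached, 9 new tokens, so target_length = 12
cached, new, target_length = 3, 9, 12

# Case 1: the custom 4D mask covers only the new tokens. Its shape matches the
# internally built causal_mask (1, 1, 9, 12), but an offset of 3 gets applied
# and the copy over the causal mask fails.
mask_new_only = torch.zeros(1, 1, new, target_length)          # (1, 1, 9, 12)

# Case 2: the custom 4D mask also covers the cached tokens. The offset stays at
# zero, but the resulting slice (1, 1, 12, 12) no longer matches the causal_mask
# shape (1, 1, 9, 12), so the copy fails again.
mask_full = torch.zeros(1, 1, cached + new, target_length)     # (1, 1, 12, 12)

print(mask_new_only.shape, mask_full.shape)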

I hesitate to offer a PR because it may break other things you are trying to do with this part of the code. But please make the test work. By the way, it may be OK to change the test instead, for instance passing the whole, bigger mask (including the cached items) by editing the corresponding line to mask_1b = mask_1.

Special note on StaticCache: I like this feature and want to use custom 4D masks with it. So far this combination is not tested. I'd be glad to contribute such a test once this issue is fixed; it would look like test_partial_stacked_causal_mask, only with StaticCache.

cc @ArthurZucker @gante

@ArthurZucker
Collaborator Author

I can try to fix it, and yes, I thought it would be tested automatically. It should be part of tokenization_common or at least LlamaIntegrationTests or something. cc @gante -- if you can take a look, I'll gladly review a PR.

@gante
Member

gante commented Mar 28, 2024

^ this PR should fix it 🤗

itazap pushed a commit that referenced this pull request May 14, 2024