[BC 4.37 -> 4.38] for Llama family, memory and speed #29753
Conversation
In general looks good to me, although I'm not 100% sure about the `causal_mask *= torch.arange(target_length, device=device) > cache_position[0]` line + assisted generation -- going to have a look.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
TorchScript can be fixed by @fxmarty in a follow-up PR + patch IMO!
@ArthurZucker that would be great if torchscript/fx tests pass & does not break
Thanks for the offline explanation! This shouldn't affect FA2, as we always return `attention_mask` without processing it for FA2 modules in `_update_causal_mask`. As you explained offline, `causal_mask *= torch.arange(target_length, device=device) > cache_position[0]` is used to mask out the cached hidden states.
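To make the effect of that line concrete, here is a minimal numpy sketch (not the actual torch code in `_update_causal_mask`; shapes and the fill value are assumptions for illustration). The mask starts fully masked with the dtype's minimum value; multiplying by the boolean comparison zeroes out (i.e. leaves attendable) every key position up to the current token, while static-cache slots beyond it, which hold stale cached states, keep the masking value:

```python
import numpy as np

min_value = np.finfo(np.float32).min
target_length = 8                # assumed static-cache length for this sketch
cache_position = np.array([5])   # one new token after 5 cached ones

# Start fully masked, then zero out (= make attendable) every key position
# that is not strictly after the current token; positions beyond it keep
# the masking value, hiding the stale static-cache slots.
causal_mask = np.full((1, target_length), min_value, dtype=np.float32)
causal_mask *= np.arange(target_length) > cache_position[0]

assert (causal_mask[0, :6] == 0).all()          # valid prefix + self: visible
assert (causal_mask[0, 6:] == min_value).all()  # unused cache slots: masked
```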
@ArthurZucker with that line, in the previous version, in a situation with 17 cached tokens and 4 new assistant tokens, the causal mask would become not upper triangular. After the suggested change, it becomes upper triangular again.
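The assisted-generation failure mode can be sketched in numpy (a toy illustration with assumed shapes, not the exact transformers code). Comparing key positions against the scalar `cache_position[0]` gives every new row the same visibility cutoff, so later draft tokens cannot attend to earlier ones (or even to themselves); comparing per row against `cache_position[:, None]` restores the expected triangular structure:

```python
import numpy as np

past, new, total = 17, 4, 21                  # 17 cached + 4 draft tokens
cache_position = np.arange(past, past + new)  # [17, 18, 19, 20]
keys = np.arange(total)

# True means "masked" in this sketch.
# Scalar comparison: every row masks keys 18..20, so draft token at
# position 20 is blind to positions 18..20, including itself.
broken = np.broadcast_to(keys > cache_position[0], (new, total))
# Per-row comparison: row i unmasks keys up to its own absolute position.
fixed = keys[None, :] > cache_position[:, None]

assert broken[3, 20]       # last draft token cannot attend to itself
assert not fixed[3, 20]    # fixed: it can
assert not fixed[2, 18]    # and can attend to earlier draft tokens
```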
Any idea why our tests don't complain?
I think we don't have hard correctness checks for assisted generation 🙈 only API checks
* attempt to fix
* the actual fix that works with compilation!
* this?
* temporary update
* nit?
* dispatcg to memory efficient?
* update both models that have static cache support
* fix copies fix compile
* make sure fix
* fix cohere and gemma
* fix beams?
* nit
* slipped through the cracks
* nit
* nits
* update
* fix-copies
* skip failing tests
* nits
Congratulations with this PR fixing many important things! @gante and myself introduced this test recently in #29731 - please make it a part of the test suite for all things related to attention_masks, and here is what happens (numbers based on

I hesitate to offer a PR because it may break other things that you try to do with this part of the code. But, please, make the test work. BTW, it may be OK to change the test, for instance passing the whole bigger mask including cached items by editing this line.

Special note on StaticCache: I like this feature and I want to use custom 4D masks with it. So far this is not tested. I'd be glad to contribute such a test once this issue is fixed. It will look like
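One way the "whole bigger mask including cached items" idea could look is sketched below in numpy (hypothetical names and shapes; this is not an existing transformers API). The user supplies a full `(batch, 1, total, total)` custom 4D mask, and the rows for the tokens in the current forward pass are selected with `cache_position`:

```python
import numpy as np

total, new = 6, 2
cache_position = np.arange(total - new, total)  # positions of the new tokens

# User-supplied custom 4D mask covering cached + new positions
# (here just a plain causal mask for illustration; 1 = attendable).
full_mask = np.tril(np.ones((1, 1, total, total)))
# Slice out the query rows for this forward pass.
step_mask = full_mask[:, :, cache_position, :]

assert step_mask.shape == (1, 1, new, total)
assert (step_mask[0, 0, 0] == np.array([1, 1, 1, 1, 1, 0])).all()
```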
I can try to fix it, and yes, I thought it would be tested automatically. It should be part of the
^ this PR should fix it 🤗
What does this PR do?
Fixes the BC issues between the two versions in terms of memory consumption.
This fix is made a lot easier by all the tests, so thanks a lot @gante!
Fixes #29412, fixes #29484, fixes #29644, fixes #29651