
Fix the initialization of the cache when we have multi gpu #33303

Merged (12 commits into main, Sep 13, 2024)

Conversation

@SunMarc (Member) commented on Sep 4, 2024

What does this PR do?

Fixes #33287 (comment)
This PR initializes the cache on the right device when we are in a multi-GPU setup. Before this PR, we would move the tensors to the right device during update(), which created a few issues with export(). This is also a cleaner solution in general.
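To make the mechanism concrete, here is a minimal sketch (not the PR's exact code) of how a per-layer device map can be derived from a sharded model's `hf_device_map` and handed to a manually built static cache. The `StaticCache` keyword names below follow the PR discussion and are assumptions that may differ across versions.

```python
# Hedged sketch: derive a per-layer device map from an accelerate-style
# `hf_device_map` so each layer's cache tensors live on the device that
# runs that layer. StaticCache keyword names are assumptions.
import torch
from transformers import AutoModelForCausalLM, StaticCache

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="balanced", torch_dtype=torch.float16
)

layer_device_map = {}
for idx in range(model.config.num_hidden_layers):
    # Sharded layers appear as "model.layers.<idx>" in hf_device_map; fall back
    # to the model's main device for anything not listed explicitly.
    layer_device_map[idx] = model.hf_device_map.get(f"model.layers.{idx}", model.device)

cache = StaticCache(
    config=model.config,
    max_batch_size=2,
    max_cache_len=256,
    dtype=torch.float16,
    layer_device_map=layer_device_map,  # argument introduced by this PR
)
```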

cc @gante I would love quick feedback from you
cc @ArthurZucker as you were also interested in this

Tested with:

RUN_SLOW=True CUDA_VISIBLE_DEVICES=0,1 pytest tests/generation/test_utils.py -k "test_generate_with_static_cache_multi_gpu"

RUN_SLOW=True CUDA_VISIBLE_DEVICES=0,1 pytest tests/generation/test_utils.py -k "test_init_static_cache_multi_gpu"

RUN_SLOW=1 pytest tests/utils/test_cache_utils.py -k test_static_cache_exportability

and:

```python
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch

NUM_TOKENS_TO_GENERATE = 40
# Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test
# was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs.
EXPECTED_TEXT_COMPLETION = [
    "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
    "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
    "theory of relativ",
    "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
    "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
]

prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="balanced", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
# Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text)
print(dynamic_text)

# Static Cache
generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# We get the same expected output!
print(static_text)
```


@gante (Member) left a comment

Nice, thank you for the quick jump to the solution! In general it looks great :D

I would add three more things:

  1. make `device` and `layer_device_map` mutually exclusive OR make the `device` argument also accept a device map (whichever is most consistent across HF libraries); see the sketch after this list
  2. make sure we throw an exception somewhere: if a user is using multi GPU, then a layer map has to be passed
  3. tests 🤗
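For illustration, a minimal sketch of the mutual-exclusivity guard suggested in (1), assuming both arguments reach the cache constructor; the names follow the discussion, not the merged code.

```python
# Hypothetical guard, not code from this PR: reject ambiguous device arguments
# when a user passes both a single device and a per-layer device map.
def validate_cache_device_args(device=None, layer_device_map=None):
    if device is not None and layer_device_map is not None:
        raise ValueError(
            "Pass either `device` (whole cache on one device) or "
            "`layer_device_map` (per-layer placement), not both."
        )
```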

@SunMarc (Member, author) commented on Sep 5, 2024

> make sure we throw an exception somewhere: if a user is using multi GPU, then a layer map has to be passed

The only way to know whether we are using multi-GPU is through the device_map. Since we compute the layer_device_map from the device_map in get_layer_device_map, that should be good enough, no? If we wanted to do the check inside the cache, we would have to pass the device_map arg.
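A hedged illustration of the point above: the device map is the only multi-GPU signal available, so the detection has to happen where that map is visible (e.g. in generate), not inside the cache class.

```python
def needs_layer_device_map(hf_device_map: dict) -> bool:
    """Return True when the accelerate-style device map spans more than one
    execution device, i.e. a per-layer device map is needed for the cache."""
    execution_devices = {d for d in hf_device_map.values() if d not in ("cpu", "disk")}
    return len(execution_devices) > 1
```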

@gante (Member) commented on Sep 5, 2024

@SunMarc

That's a good point... 👀 My question at the moment is the following: if a user decides to instantiate a cache manually, is using multi-gpu, and doesn't pass the new argument, how can we let the user know that they should have used the new argument?

SunMarc marked this pull request as ready for review on September 5, 2024 at 16:41.
SunMarc requested a review from @gante on September 5, 2024 at 16:41.
@SunMarc (Member, author) commented on Sep 5, 2024

Refactored a bit and added an integration test to check the devices of the cache! Note that we do the check after generation, as we need to initialize the cache, which means we need the model. We already have tests for the multi-GPU static cache, so I don't think we need to add those. Let me know if there are other tests you would like to see.

RUN_SLOW=True CUDA_VISIBLE_DEVICES=0,1 pytest tests/generation/test_utils.py -k "test_generate_with_static_cache_multi_gpu"

> That's a good point... 👀 My question at the moment is the following: if a user decides to instantiate a cache manually, is using multi-gpu, and doesn't pass the new argument, how can we let the user know that they should have used the new argument?

When passing the cache to the model in generate, is there a way to post-initialize the cache again? It seems to be the best angle to fix this issue. If one is using multi-GPU, we would have to initialize the model before creating the cache object if we want to pass a relevant arg such as the device_map. However, I prefer not passing the device_map, as we would have to recreate layer_device_map afterwards in the cache init, plus it is not friendly to the user.

@dvrogozh (Contributor) left a comment

Works for me to address #33178.

@gante (Member) commented on Sep 6, 2024

> is there a way to post-initialize the cache again?

There is not atm, and I think we would hit a few issues if we do so 🤔 We would have to either:

  1. check whether there is data to copy across devices (data checks = not compile-friendly)
  2. never copy data when creating tensors in new devices (unexpected behavior from a user perspective)
  3. always copy data (we would have to use `.to()` to copy data across devices, which is essentially the issue #33287, "Unbreak torch export with static cache", tries to solve)

What about a simple device check at update time? If the cache tensor device != the new cache data device, throw an informative exception.

@SunMarc (Member, author) commented on Sep 6, 2024

> What about a simple device check at update time? If the cache tensor device != the new cache data device, throw an informative exception.

Sounds good @gante! I've updated the PR with the check. Let me know if this is better! I'll add a quick test to check that the error is raised correctly.

@gante (Member) left a comment

LGTM, thank you for iterating 💛

@guangy10 (Contributor) left a comment

Thanks for the fix. 🚀 🚀 🚀

@guangy10 (Contributor) commented on Sep 9, 2024

@SunMarc can you include `RUN_SLOW=1 pytest tests/utils/test_cache_utils.py -k test_static_cache_exportability` in the Test Plan to ensure the export test is fixed?

@SunMarc (Member, author) commented on Sep 10, 2024

> @SunMarc can you include `RUN_SLOW=1 pytest tests/utils/test_cache_utils.py -k test_static_cache_exportability` in the Test Plan to ensure the export test is fixed?

Test passed!

@guangy10 (Contributor) commented on Sep 11, 2024

@gante @LysandreJik can we prioritize getting this fix merged? I will need it to unblock the ExecuTorch integration. Thanks!

@LysandreJik (Member) commented on Sep 12, 2024

Thanks! For cache-related PRs I recommend pinging @ArthurZucker for a review (pinging him now).

@ArthurZucker (Collaborator) left a comment

LGTM, but I don't think we should check in `update`. Maybe in `generate`?
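For context, a hedged sketch of what a one-off check in `generate` could look like, as opposed to the per-step check shown in the review snippet below; the `cache` and `layer_device_map` names are hypothetical.

```python
# Hypothetical generate-time validation, not code from this PR: verify once,
# before decoding, that each layer's cache tensors sit on the expected device.
# Expected devices are assumed to be strings like "cuda:1" or torch.device objects.
import torch

def check_cache_devices(cache, layer_device_map):
    for layer_idx, expected in layer_device_map.items():
        actual = cache.key_cache[layer_idx].device
        if actual != torch.device(expected):
            raise ValueError(
                f"Static cache for layer {layer_idx} is on {actual}, expected "
                f"{expected}. Pass `layer_device_map` when building the cache "
                "manually on a multi-GPU model."
            )
```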

Comment on lines 1101 to 1111
```python
for state_str, state_device, self_state_device in [
    ("key_states", key_states.device, self.key_cache[layer_idx].device),
    ("value_states", value_states.device, self.value_cache[layer_idx].device),
]:
    if state_device != self_state_device:
        raise ValueError(
            f"Computed {state_str} from layer {layer_idx} is on device {state_device} "
            f"whereas stored {state_str} is on device {self_state_device}. "
            f"If you are manually initializing the cache, make sure to pass the argument `layer_device_map` if you are using multi-gpu. "
            " Otherwise, you can just pass `cache_implementation` in `model.generate()` to correctly initialize the cache."
        )
```
A Collaborator commented:

I am really unsure it is worth it for us to run this at every forward pass. I know we want to help our users, but we would need to make sure it does not cost us anything.

@gante (Member) commented on Sep 12, 2024

there is a long discussion above, but tl;dr the options are:

  • don't warn at all
  • check devices in update (this implementation)

With torch.compile, these lines should get ignored anyway when called correctly (at tracing time they have the same device). We should benchmark compile to confirm, though. Assuming they have no throughput cost, I think it's a win to have the error.

A Member commented:

We can also wrap `update` with a try/except, rather than using an if/else.
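A hedged sketch of that try/except alternative (hypothetical, not what was merged), assuming a StaticCache-style update that index-copies into preallocated tensors:

```python
# Hypothetical variant of a StaticCache-style update, not code from this PR:
# let the cross-device copy fail naturally and re-raise with a helpful hint.
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
    cache_position = cache_kwargs.get("cache_position")
    k_out = self.key_cache[layer_idx]
    v_out = self.value_cache[layer_idx]
    try:
        # Dim 2 is the sequence dimension of the preallocated [batch, heads, seq, dim] tensors.
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
    except RuntimeError as exc:
        raise RuntimeError(
            "Cache update failed, likely because the cache was allocated on a "
            "different device than the incoming key/value states. If you built "
            "the cache manually for a multi-GPU model, pass `layer_device_map`."
        ) from exc
    return k_out, v_out
```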

@ArthurZucker (Collaborator) commented:

Let's remove it for now and merge

@ArthurZucker (Collaborator) commented:

cc @SunMarc

@SunMarc (Member, author) commented on Sep 12, 2024

I ran a quick benchmark to see what the impact on generate is with NUM_TOKENS_TO_GENERATE = 2000 (a sketch of the timing setup follows the results below):

  • without the check: ~145.5s
  • with the check: ~145s
    -> We have a very small overhead
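A minimal sketch of how such a timing comparison could be run; the harness is an assumption, with `model` and `inputs` as in the example at the top of the PR.

```python
# Rough wall-clock timing of generate with the static cache; the ~145s figures
# above came from comparing builds with and without the device check.
import time
import torch

torch.cuda.synchronize()
start = time.perf_counter()
model.generate(**inputs, max_new_tokens=2000, do_sample=False, cache_implementation="static")
torch.cuda.synchronize()
print(f"generate with static cache took {time.perf_counter() - start:.1f}s")
```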

But yeah, let's remove that for the release, and I can do a follow-up PR with either the same patch or a try/except as Joao suggested. Feel free to merge the PR for the release, @ArthurZucker, if you are fine with the modification!

@ArthurZucker (Collaborator) left a comment

Thanks, the tests look great.

SunMarc merged commit 6cc4dfe into main on Sep 13, 2024 (24 checks passed).
SunMarc deleted the init-multi-gpu-cache branch on September 13, 2024 at 13:06.
facebook-github-bot pushed a commit to pytorch/executorch that referenced this pull request on Sep 14, 2024:
Summary:
bypass-github-export-checks

[Done] ~~Require PR [Make StaticCache configurable at model construct time](huggingface/transformers#32830) in order to export, lower and run the 🤗 model OOTB.~~
[Done] ~~Require huggingface/transformers#33303 or huggingface/transformers#33287 to be merged to 🤗 `transformers` to resolve the export issue introduced by huggingface/transformers#32543.~~

-----------

Now we can take the integration point from 🤗 `transformers` to lower compatible models to ExecuTorch OOTB.
  - This PR creates a simple script with recipe of XNNPACK.
  - This PR also created a secret `EXECUTORCH_HT_TOKEN` to allow downloading checkpoints in the CI
  - This PR connects the 🤗 "Export to ExecuTorch" e2e workflow to ExecuTorch CI

### Instructions to run the demo:

1. Run export_hf_model.py to lower gemma-2b to ExecuTorch:
```
python -m extension.export_util.export_hf_model -hfm "google/gemma-2b" # The model is exported with static dims and a static KV cache
```
2. Run tokenizer.py to generate the binary format for the ExecuTorch runtime:
```
python -m extension.llm.tokenizer.tokenizer -t <path_to_downloaded_gemma_checkpoint_dir>/tokenizer.model -o tokenizer.bin
```
3. Build the LLM runner by following [step 4 of this guide](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#step-4-run-on-your-computer-to-validate)

4. Run the lowered model
```
cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=tokenizer.bin --prompt="My name is"
```
OOTB output and perf
```
I 00:00:00.003110 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003360 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003380 executorch:cpuinfo_utils.cpp:158] Number of efficient cores 4
I 00:00:00.003384 executorch:main.cpp:65] Resetting threadpool with num threads = 6
I 00:00:00.014716 executorch:runner.cpp:51] Creating LLaMa runner: model_path=gemma.pte, tokenizer_path=tokenizer_gemma.bin
I 00:00:03.065359 executorch:runner.cpp:66] Reading metadata from model
I 00:00:03.065391 executorch:metadata_util.h:43] get_n_bos: 1
I 00:00:03.065396 executorch:metadata_util.h:43] get_n_eos: 1
I 00:00:03.065399 executorch:metadata_util.h:43] get_max_seq_len: 123
I 00:00:03.065402 executorch:metadata_util.h:43] use_kv_cache: 1
I 00:00:03.065404 executorch:metadata_util.h:41] The model does not contain use_sdpa_with_kv_cache method, using default value 0
I 00:00:03.065405 executorch:metadata_util.h:43] use_sdpa_with_kv_cache: 0
I 00:00:03.065407 executorch:metadata_util.h:41] The model does not contain append_eos_to_prompt method, using default value 0
I 00:00:03.065409 executorch:metadata_util.h:43] append_eos_to_prompt: 0
I 00:00:03.065411 executorch:metadata_util.h:41] The model does not contain enable_dynamic_shape method, using default value 0
I 00:00:03.065412 executorch:metadata_util.h:43] enable_dynamic_shape: 0
I 00:00:03.130388 executorch:metadata_util.h:43] get_vocab_size: 256000
I 00:00:03.130405 executorch:metadata_util.h:43] get_bos_id: 2
I 00:00:03.130408 executorch:metadata_util.h:43] get_eos_id: 1
My name is Melle. I am a 20 year old girl from Belgium. I am living in the southern part of Belgium. I am 165 cm tall and I weigh 45kg. I like to play sports like swimming, running and playing tennis. I am very interested in music and I like to listen to classical music. I like to sing and I can play the piano. I would like to go to the USA because I like to travel a lot. I am looking for a boy from the USA who is between 18 and 25 years old. I
PyTorchObserver {"prompt_tokens":4,"generated_tokens":118,"model_load_start_ms":1723685715497,"model_load_end_ms":1723685718612,"inference_start_ms":1723685718612,"inference_end_ms":1723685732965,"prompt_eval_end_ms":1723685719087,"first_token_ms":1723685719087,"aggregate_sampling_time_ms":182,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:17.482472 executorch:stats.h:70] 	Prompt Tokens: 4    Generated Tokens: 118
I 00:00:17.482475 executorch:stats.h:76] 	Model Load Time:		3.115000 (seconds)
I 00:00:17.482481 executorch:stats.h:86] 	Total inference time:		14.353000 (seconds)		 Rate: 	8.221278 (tokens/second)
I 00:00:17.482483 executorch:stats.h:94] 		Prompt evaluation:	0.475000 (seconds)		 Rate: 	8.421053 (tokens/second)
I 00:00:17.482485 executorch:stats.h:105] 		Generated 118 tokens:	13.878000 (seconds)		 Rate: 	8.502666 (tokens/second)
I 00:00:17.482486 executorch:stats.h:113] 	Time to first generated token:	0.475000 (seconds)
I 00:00:17.482488 executorch:stats.h:120] 	Sampling time over 122 tokens:	0.182000 (seconds)
```

Pull Request resolved: #4723

Reviewed By: huydhn, kirklandsign

Differential Revision: D62543933

Pulled By: guangy10

fbshipit-source-id: 00401a39ba03d7383e4b284d25c8fc62a6695b34