Fix attn mask ignore logic in training-time trace #32613

Open · wants to merge 15 commits into base: main
1 change: 0 additions & 1 deletion src/transformers/cache_utils.py
@@ -1192,7 +1192,6 @@ def __init__(
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
     ) -> None:
-        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
2 changes: 1 addition & 1 deletion src/transformers/modeling_attn_mask_utils.py
@@ -281,7 +281,7 @@ def _ignore_causal_mask_sdpa(
         elif sliding_window is None or key_value_length < sliding_window:
             if len(attention_mask.shape) == 4:
                 return False
-            elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
+            elif not is_tracing and torch.all(attention_mask == 1):
Contributor Author: It's actually a paradoxical branch, since `is_tracing` and `torch.all(attention_mask == 1)` can never be true together (see the sketch after this hunk).

                 if query_length == 1 or key_value_length == query_length:
                     # For query_length == 1, causal attention and bi-directional attention are the same.
                     ignore_causal_mask = True
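
For illustration only, a minimal standalone sketch (not part of this PR and not the transformers implementation; the function name and the explicit `is_tracing` flag are stand-ins) of why the data-dependent mask check only applies in eager mode, which is what the simplified `not is_tracing` condition expresses:

import torch

def can_ignore_causal_mask(attention_mask: torch.Tensor, is_tracing: bool) -> bool:
    # Under tracing (torch.jit / torch.fx / dynamo), a data-dependent condition such as
    # torch.all(attention_mask == 1) cannot be resolved into a concrete Python bool, so this
    # branch is only ever taken in eager mode and the mask is kept whenever we are tracing.
    if not is_tracing and torch.all(attention_mask == 1):
        return True
    return False

mask = torch.ones(2, 10, dtype=torch.int64)
print(can_ignore_causal_mask(mask, is_tracing=False))  # True: mask is all ones, safe to drop it
print(can_ignore_causal_mask(mask, is_tracing=True))   # False: data-dependent check skipped while tracing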
5 changes: 4 additions & 1 deletion tests/models/gemma/test_modeling_gemma.py
@@ -321,6 +321,9 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # used in `test_torch_compile`
     _torch_compile_test_ckpt = "google/gemma-2b"
 
+    # used in `test_torch_compile_for_training`
+    _torch_compile_train_cls = GemmaForCausalLM if is_torch_available() else None
+
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
         self, pipeline_test_casse_name, config_class, model_architecture, tokenizer_name, processor_name
@@ -811,7 +814,7 @@ def test_compile_static_cache(self):
 
         prompts = ["Hello I am doing", "Hi today"]
         tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", pad_token="</s>", padding_side="right")
-        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map="sequential", torch_dtype=torch.float16)
+        model = GemmaForCausalLM.from_pretrained("google/gemma-2b", device_map=torch_device, torch_dtype=torch.float16)
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
 
         # Dynamic Cache
5 changes: 4 additions & 1 deletion tests/models/llama/test_modeling_llama.py
@@ -317,6 +317,9 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
     # used in `test_torch_compile`
     _torch_compile_test_ckpt = "meta-llama/Llama-2-7b-hf"
 
+    # used in `test_torch_compile_for_training`
+    _torch_compile_train_cls = LlamaForCausalLM if is_torch_available() else None
+
     def setUp(self):
         self.model_tester = LlamaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=LlamaConfig, hidden_size=37)
@@ -872,7 +875,7 @@ def test_compile_static_cache(self):
         ]
         tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
         model = LlamaForCausalLM.from_pretrained(
-            "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
+            "meta-llama/Llama-2-7b-hf", device_map=torch_device, torch_dtype=torch.float16
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
 
2 changes: 1 addition & 1 deletion tests/models/mistral/test_modeling_mistral.py
@@ -676,7 +676,7 @@ def test_compile_static_cache(self):
         tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
         tokenizer.pad_token = tokenizer.eos_token
         model = MistralForCausalLM.from_pretrained(
-            "mistralai/Mistral-7B-v0.1", device_map="sequential", torch_dtype=torch.float16
+            "mistralai/Mistral-7B-v0.1", device_map=torch_device, torch_dtype=torch.float16
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
 
2 changes: 2 additions & 0 deletions tests/models/nemotron/test_modeling_nemotron.py
@@ -94,6 +94,8 @@ class NemotronModelTest(GemmaModelTest):
 
     # used in `test_torch_compile`
     _torch_compile_test_ckpt = "nvidia/nemotron-3-8b-base-4k-hf"
+    # used in `test_torch_compile_for_training`
+    _torch_compile_train_cls = NemotronForCausalLM if is_torch_available() else None
 
     def setUp(self):
         self.model_tester = NemotronModelTester(self)
43 changes: 43 additions & 0 deletions tests/test_modeling_common.py
@@ -4739,6 +4739,49 @@ def test_torch_compile(self):
         for i in range(n_iter):
             _ = model.generate(**input_ids, do_sample=False)
 
+    @slow
+    @require_torch_gpu
+    def test_torch_compile_for_training(self):
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        if not hasattr(self, "_torch_compile_train_cls"):
+            self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_train_cls`.")
+
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+        cls = self._torch_compile_train_cls
+        model = cls(config).to(torch_device)
+
+        inputs = {
+            "input_ids": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+            "attention_mask": torch.tensor(
+                [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
+                dtype=torch.int64,
+                device=torch_device,
+            ),
+            "position_ids": torch.arange(0, 10, device=torch_device).unsqueeze(0),
+            "labels": torch.randint(low=1, high=model.config.vocab_size, size=(2, 10), device=torch_device),
+        }
+
+        # eager backward
+        set_seed(42)
+        loss = model(**inputs).loss
+        loss.backward()
+
+        params = {name: param.grad.clone().detach().cpu() for name, param in model.named_parameters()}
+        model.zero_grad()
+        del loss
+
+        model = torch.compile(model, fullgraph=True, mode="reduce-overhead")
+        # forward compilation
+        set_seed(42)
+        loss = model(**inputs).loss
+        # backward compilation
+        loss.backward()
+        # check grad matches
+        for name, param in model._orig_mod.named_parameters():
+            torch.testing.assert_close(param.grad.detach().cpu(), params[name], rtol=1e-4, atol=1e-4)
+
     @slow
     @require_torch_gpu # Testing cuda graphs.
     @require_read_token
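
As a usage note, a minimal self-contained sketch (using a toy `nn.Sequential`, not a transformers model) of the pattern the new test exercises: run an eager backward pass, keep the gradients as a reference, then check them against the gradients produced after `torch.compile(..., fullgraph=True, mode="reduce-overhead")`:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
inputs = torch.randn(4, 8)
targets = torch.randn(4, 8)

# eager backward: keep a copy of the gradients as the reference
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()
eager_grads = {name: p.grad.clone() for name, p in model.named_parameters()}
model.zero_grad()

# compiled forward + backward, using the same compile flags as the new test
compiled = torch.compile(model, fullgraph=True, mode="reduce-overhead")
loss = nn.functional.mse_loss(compiled(inputs), targets)
loss.backward()

# the compiled run should reproduce the eager gradients within a small tolerance
for name, p in model.named_parameters():
    torch.testing.assert_close(p.grad, eager_grads[name], rtol=1e-4, atol=1e-4)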