Skip to content

Commit

Permalink
[whisper] pass attention_mask to generate_with_fallback() (huggingface#33145)
Browse files Browse the repository at this point in the history

pass attention_mask to generate_with_fallback
  • Loading branch information
benniekiss authored and zucchini-nlp committed Aug 30, 2024
1 parent efe7f68 commit 9b19efa
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def generate(
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
is_shortform=is_shortform,
batch_size=batch_size,
attention_mask=attention_mask,
kwargs=kwargs,
)

Expand Down Expand Up @@ -790,6 +791,7 @@ def generate_with_fallback(
do_condition_on_prev_tokens,
is_shortform,
batch_size,
attention_mask,
kwargs,
):
kwargs = copy.copy(kwargs)
Expand Down Expand Up @@ -837,6 +839,7 @@ def generate_with_fallback(
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)

Expand Down

0 comments on commit 9b19efa

Please sign in to comment.