Skip to content

Commit

Permalink
fix copies
Browse files Browse the repository at this point in the history
  • Loading branch information
AvivSham committed Jul 8, 2024
1 parent 0df4173 commit 9a2f7a6
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,15 @@ def get_prompt_ids(self, text: str, return_tensors="np"):

# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
def _strip_prompt(self, token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int):
has_prompt = isinstance(token_ids, list) and token_ids and token_ids[0] == prompt_token_id
if not isinstance(token_ids, list):
token_ids = self._convert_to_list(token_ids)

# handle case of empty token_ids for decoding with timestamps.
# at this point token_ids is a list, so it is safe to use if not check.
if not token_ids:
return token_ids

has_prompt = token_ids[0] == prompt_token_id
if has_prompt:
if decoder_start_token_id in token_ids:
return token_ids[token_ids.index(decoder_start_token_id) :]
Expand Down

0 comments on commit 9a2f7a6

Please sign in to comment.