Skip to content

Commit

Permalink
[Whisper test] Fix some failing tests (#33450)
Browse files · Browse the repository at this point in the history
* Fix failing tensor placement in Whisper

* fix long form generation tests

* more return_timestamps=True

* make fixup

* [run_slow] whisper

* [run_slow] whisper
  • Loading branch information
ylacombe committed Sep 16, 2024
1 parent c2d0589 commit 98adf24
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions tests/models/whisper/test_modeling_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,9 +1683,9 @@ def test_labels_sequence_max_length_correct(self):
input_features = input_dict["input_features"]

labels_length = config.max_target_positions
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)

model = model_class(config)
model = model_class(config).to(torch_device)
model(input_features=input_features, labels=labels)

def test_labels_sequence_max_length_correct_after_changing_config(self):
Expand All @@ -1697,9 +1697,9 @@ def test_labels_sequence_max_length_correct_after_changing_config(self):
config.max_target_positions += 100

labels_length = config.max_target_positions
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)

model = model_class(config)
model = model_class(config).to(torch_device)
model(input_features=input_features, labels=labels)

def test_labels_sequence_max_length_error(self):
Expand All @@ -1709,21 +1709,21 @@ def test_labels_sequence_max_length_error(self):
input_features = input_dict["input_features"]

labels_length = config.max_target_positions + 1
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)

model = model_class(config)
model = model_class(config).to(torch_device)
with self.assertRaises(ValueError):
model(input_features=input_features, labels=labels)

def test_labels_sequence_max_length_error_after_changing_config(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_generative_model_classes:
model = model_class(config)
model = model_class(config).to(torch_device)
input_features = input_dict["input_features"]

labels_length = config.max_target_positions + 1
labels = torch.ones(1, labels_length, dtype=torch.int64)
labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device)

new_max_length = config.max_target_positions + 100
model.config.max_length = new_max_length
Expand Down Expand Up @@ -2385,7 +2385,9 @@ def test_tiny_token_timestamp_generation_longform(self):
)

inputs = inputs.to(torch_device)
generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True)
generate_outputs = model.generate(
**inputs, return_segments=True, return_token_timestamps=True, return_timestamps=True
)

token_timestamps_shape = [
[segment["token_timestamps"].shape for segment in segment_list]
Expand Down Expand Up @@ -2550,14 +2552,14 @@ def test_default_multilingual_transcription_long_form(self):
).input_features.to(torch_device)

# task defaults to transcribe
sequences = model.generate(input_features)
sequences = model.generate(input_features, return_timestamps=True)

transcription = processor.batch_decode(sequences)[0]

assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?"

# set task to translate
sequences = model.generate(input_features, task="translate")
sequences = model.generate(input_features, task="translate", return_timestamps=True)
transcription = processor.batch_decode(sequences)[0]

assert (
Expand Down Expand Up @@ -3264,6 +3266,7 @@ def test_whisper_empty_longform(self):
"num_beams": 5,
"language": "fr",
"task": "transcribe",
"return_timestamps": True,
}

torch.manual_seed(0)
Expand Down

0 comments on commit 98adf24

Please sign in to comment.