diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 09be23a0d381f9..ebc9ce5ec358eb 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1683,9 +1683,9 @@ def test_labels_sequence_max_length_correct(self): input_features = input_dict["input_features"] labels_length = config.max_target_positions - labels = torch.ones(1, labels_length, dtype=torch.int64) + labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device) - model = model_class(config) + model = model_class(config).to(torch_device) model(input_features=input_features, labels=labels) def test_labels_sequence_max_length_correct_after_changing_config(self): @@ -1697,9 +1697,9 @@ def test_labels_sequence_max_length_correct_after_changing_config(self): config.max_target_positions += 100 labels_length = config.max_target_positions - labels = torch.ones(1, labels_length, dtype=torch.int64) + labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device) - model = model_class(config) + model = model_class(config).to(torch_device) model(input_features=input_features, labels=labels) def test_labels_sequence_max_length_error(self): @@ -1709,9 +1709,9 @@ def test_labels_sequence_max_length_error(self): input_features = input_dict["input_features"] labels_length = config.max_target_positions + 1 - labels = torch.ones(1, labels_length, dtype=torch.int64) + labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device) - model = model_class(config) + model = model_class(config).to(torch_device) with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) @@ -1719,11 +1719,11 @@ def test_labels_sequence_max_length_error_after_changing_config(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_generative_model_classes: - model = model_class(config) + model = model_class(config).to(torch_device) input_features = input_dict["input_features"] labels_length = config.max_target_positions + 1 - labels = torch.ones(1, labels_length, dtype=torch.int64) + labels = torch.ones(1, labels_length, dtype=torch.int64).to(torch_device) new_max_length = config.max_target_positions + 100 model.config.max_length = new_max_length @@ -2385,7 +2385,9 @@ def test_tiny_token_timestamp_generation_longform(self): ) inputs = inputs.to(torch_device) - generate_outputs = model.generate(**inputs, return_segments=True, return_token_timestamps=True) + generate_outputs = model.generate( + **inputs, return_segments=True, return_token_timestamps=True, return_timestamps=True + ) token_timestamps_shape = [ [segment["token_timestamps"].shape for segment in segment_list] @@ -2550,14 +2552,14 @@ def test_default_multilingual_transcription_long_form(self): ).input_features.to(torch_device) # task defaults to transcribe - sequences = model.generate(input_features) + sequences = model.generate(input_features, return_timestamps=True) transcription = processor.batch_decode(sequences)[0] assert transcription == " मिर्ची में कितने विबिन्द प्रजातियां हैं? मिर्ची में कितने विबिन्द प्रजातियां हैं?" # set task to translate - sequences = model.generate(input_features, task="translate") + sequences = model.generate(input_features, task="translate", return_timestamps=True) transcription = processor.batch_decode(sequences)[0] assert ( @@ -3264,6 +3266,7 @@ def test_whisper_empty_longform(self): "num_beams": 5, "language": "fr", "task": "transcribe", + "return_timestamps": True, } torch.manual_seed(0)