Switching to pipeline for HF whisper #814

Merged (1 commit) on Jun 25, 2024
2 changes: 1 addition & 1 deletion buzz/transcriber/recording_transcriber.py
@@ -81,7 +81,7 @@
logging.debug("Will use whisper API on %s, %s",
custom_openai_base_url, self.whisper_api_model)
else: # ModelType.HUGGING_FACE
model = transformers_whisper.load_model(model_path)
model = TransformersWhisper(model_path)

Codecov / codecov/patch: added line buzz/transcriber/recording_transcriber.py#L84 was not covered by tests.

initial_prompt = self.transcription_options.initial_prompt

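The call-site change above swaps the module-level load_model() helper for direct construction of TransformersWhisper. Judging from the rest of this diff, the constructor now only stores the model path/id and the actual weights are loaded inside transcribe(). A minimal usage sketch (the audio path is a placeholder, not part of the PR):

from buzz.transformers_whisper import TransformersWhisper

model = TransformersWhisper("openai/whisper-tiny")  # stores the model id only; nothing is loaded yet
result = model.transcribe(audio="speech.wav", language="en", task="transcribe")
print(result["text"])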
8 changes: 5 additions & 3 deletions buzz/transcriber/whisper_file_transcriber.py
@@ -12,9 +12,9 @@
import tqdm
from PyQt6.QtCore import QObject

from buzz import transformers_whisper
from buzz.conn import pipe_stderr
from buzz.model_loader import ModelType
from buzz.transformers_whisper import TransformersWhisper
from buzz.transcriber.file_transcriber import FileTranscriber
from buzz.transcriber.transcriber import FileTranscriptionTask, Segment

@@ -87,7 +87,10 @@ def transcribe_whisper(
) -> None:
with pipe_stderr(stderr_conn):
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
# TODO Find a way to emit real progress
sys.stderr.write("0%\n")
segments = cls.transcribe_hugging_face(task)
sys.stderr.write("100%\n")
elif (
task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
):
@@ -105,7 +108,7 @@ def transcribe_whisper(

@classmethod
def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
model = transformers_whisper.load_model(task.model_path)
model = TransformersWhisper(task.model_path)
language = (
task.transcription_options.language
if task.transcription_options.language is not None
@@ -115,7 +118,6 @@ def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
audio=task.file_path,
language=language,
task=task.transcription_options.task.value,
verbose=False,
)
return [
Segment(
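The two sys.stderr.write() calls above are a stopgap: with the Hugging Face pipeline the worker only reports 0% before and 100% after transcription (hence the TODO about emitting real progress), and those lines reach the parent through pipe_stderr. A hedged sketch of how such percentage lines can be parsed on the reading side; the stream and function names are illustrative, not Buzz's actual implementation:

import re
import sys
from typing import Optional

def parse_progress(line: str) -> Optional[int]:
    # Lines look like "0%\n" or "100%\n"; return the integer if present.
    match = re.search(r"(\d+)%", line)
    return int(match.group(1)) if match else None

for raw_line in sys.stdin:  # stand-in for the piped stderr of the worker process
    percent = parse_progress(raw_line)
    if percent is not None:
        print(f"transcription progress: {percent}%")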
109 changes: 36 additions & 73 deletions buzz/transformers_whisper.py
@@ -1,98 +1,61 @@
import sys
import logging
from typing import Optional, Union

import numpy as np
from tqdm import tqdm

import whisper
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

def cuda_is_viable(min_vram_gb=10):
if not torch.cuda.is_available():
return False

total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 # Convert bytes to GB
if total_memory < min_vram_gb:
return False

return True


def load_model(model_name_or_path: str):
processor = WhisperProcessor.from_pretrained(model_name_or_path)
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)

if cuda_is_viable():
logging.debug("CUDA is available and has enough VRAM, moving model to GPU.")
model.to("cuda")

return TransformersWhisper(processor, model)
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


class TransformersWhisper:
def __init__(
self, processor: WhisperProcessor, model: WhisperForConditionalGeneration
self, model_id: str
):
self.processor = processor
self.model = model
self.SAMPLE_RATE = whisper.audio.SAMPLE_RATE
self.N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
self.model_id = model_id

# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
# https://github.com/huggingface/transformers/pull/20620.
def transcribe(
self,
audio: Union[str, np.ndarray],
language: str,
task: str,
verbose: Optional[bool] = None,
):
if isinstance(audio, str):
audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE)

self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
task=task, language=language
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
self.model_id, torch_dtype=torch_dtype, use_safetensors=True
)

segments = []
all_predicted_ids = []
model.generation_config.language = language
model.to(device)

num_samples = audio.size
seek = 0
with tqdm(
total=num_samples, unit="samples", disable=verbose is not False
) as progress_bar:
while seek < num_samples:
chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK]
input_features = self.processor(
chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
).input_features.to(self.model.device)
predicted_ids = self.model.generate(input_features)
all_predicted_ids.extend(predicted_ids)
text: str = self.processor.batch_decode(
predicted_ids, skip_special_tokens=True
)[0]
if text.strip() != "":
segments.append(
{
"start": seek / self.SAMPLE_RATE,
"end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples)
/ self.SAMPLE_RATE,
"text": text,
}
)
processor = AutoProcessor.from_pretrained(self.model_id)

pipe = pipeline(
"automatic-speech-recognition",
generate_kwargs={"language": language, "task": task},
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
torch_dtype=torch_dtype,
device=device,
)

progress_bar.update(
min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek
)
seek += self.N_SAMPLES_IN_CHUNK
transcript = pipe(audio, return_timestamps=True)

segments = []
for chunk in transcript['chunks']:
start, end = chunk['timestamp']
text = chunk['text']
segments.append({
"start": start,
"end": end,
"text": text,
"translation": ""
})

return {
"text": self.processor.batch_decode(
all_predicted_ids, skip_special_tokens=True
)[0],
"text": transcript['text'],
"segments": segments,
}

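Taken together, the rewritten transformers_whisper.py delegates chunking, decoding, and timestamping to transformers' automatic-speech-recognition pipeline instead of the previous hand-rolled 30-second generate() loop. A standalone sketch of the same approach, with a placeholder model id and audio path (it mirrors the code above rather than adding anything new):

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

model_id = "openai/whisper-tiny"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,  # long-form audio is processed in 30-second windows
    torch_dtype=torch_dtype,
    device=device,
    generate_kwargs={"language": "en", "task": "transcribe"},
)

result = asr("speech.wav", return_timestamps=True)
print(result["text"])
for chunk in result["chunks"]:
    start, end = chunk["timestamp"]
    print(start, end, chunk["text"])

Using float16 only when CUDA is available keeps CPU inference in float32, and return_timestamps=True is what supplies the per-chunk (start, end) pairs that the new transcribe() turns into Buzz segments.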
4 changes: 2 additions & 2 deletions tests/transformers_whisper_test.py
@@ -1,7 +1,7 @@
import platform
import pytest

from buzz.transformers_whisper import load_model
from buzz.transformers_whisper import TransformersWhisper
from tests.audio import test_audio_path


@@ -11,7 +11,7 @@
)
class TestTransformersWhisper:
def test_should_transcribe(self):
model = load_model("openai/whisper-tiny")
model = TransformersWhisper("openai/whisper-tiny")
result = model.transcribe(
audio=test_audio_path, language="fr", task="transcribe"
)
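The test keeps the same coverage but constructs TransformersWhisper directly instead of going through load_model. Assuming the project's dev dependencies are installed, something like pytest tests/transformers_whisper_test.py -k test_should_transcribe should run just this test; it downloads openai/whisper-tiny on first use.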