Add support for Hugging Face models (#264)
chidiwilliams committed Dec 26, 2022
1 parent 82bdd30 commit 3dceb11
Showing 11 changed files with 516 additions and 215 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
@@ -7,4 +7,4 @@ omit =
 directory = coverage/html
 
 [report]
-fail_under = 75
+fail_under = 78
2 changes: 1 addition & 1 deletion Makefile
@@ -43,7 +43,7 @@ clean:
 	rm -rf dist/* || true
 
 test: buzz/whisper_cpp.py
-	pytest --cov=buzz --cov-report=xml --cov-report=html
+	pytest -vv --cov=buzz --cov-report=xml --cov-report=html
 
 dist/Buzz dist/Buzz.app: buzz/whisper_cpp.py
 	pyinstaller --noconfirm Buzz.spec
2 changes: 1 addition & 1 deletion buzz/cache.py
@@ -23,7 +23,7 @@ def load(self) -> List[FileTranscriptionTask]:
                 return pickle.load(file)
         except FileNotFoundError:
             return []
-        except pickle.UnpicklingError:  # delete corrupted cache
+        except (pickle.UnpicklingError, AttributeError):  # delete corrupted cache
             os.remove(self.file_path)
             return []
 
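Why AttributeError: a cache pickled by an older version of the app can reference a class or attribute that no longer exists, so unpickling fails while the object graph is being rebuilt rather than when the file is opened. A minimal self-contained sketch of the pattern, assuming a TasksCache shaped like the class this hunk patches (only the except clause is taken verbatim from the diff):

import os
import pickle
from typing import List


class TasksCache:
    def __init__(self, file_path: str):
        self.file_path = file_path

    def load(self) -> List:
        try:
            with open(self.file_path, 'rb') as file:
                return pickle.load(file)
        except FileNotFoundError:
            return []  # no cache written yet
        except (pickle.UnpicklingError, AttributeError):
            # Corrupted cache, or one pickled against a class whose
            # definition has since changed: delete it and start over.
            os.remove(self.file_path)
            return []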
248 changes: 200 additions & 48 deletions buzz/gui.py

Large diffs are not rendered by default.

124 changes: 91 additions & 33 deletions buzz/model_loader.py
@@ -1,17 +1,41 @@
+import enum
 import hashlib
 import logging
 import os
 import warnings
+from dataclasses import dataclass
 from typing import Optional
 
 import requests
 import whisper
 from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot
 from platformdirs import user_cache_dir
 
-from buzz.transcriber import Model
+from buzz import transformers_whisper
 
-MODELS_SHA256 = {
+
+class WhisperModelSize(enum.Enum):
+    TINY = 'tiny'
+    BASE = 'base'
+    SMALL = 'small'
+    MEDIUM = 'medium'
+    LARGE = 'large'
+
+
+class ModelType(enum.Enum):
+    WHISPER = 'Whisper'
+    WHISPER_CPP = 'Whisper.cpp'
+    HUGGING_FACE = 'Hugging Face'
+
+
+@dataclass()
+class TranscriptionModel:
+    model_type: ModelType = ModelType.WHISPER
+    whisper_model_size: Optional[WhisperModelSize] = WhisperModelSize.TINY
+    hugging_face_model_id: Optional[str] = None
+
+
+WHISPER_CPP_MODELS_SHA256 = {
     'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
     'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
     'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
@@ -20,53 +24,85 @@
 }


+def get_hugging_face_dataset_file_url(author: str, repository_name: str, filename: str):
+    return f'https://huggingface.co/datasets/{author}/{repository_name}/resolve/main/{filename}'
+
+
 class ModelLoader(QObject):
     progress = pyqtSignal(tuple)  # (current, total)
     finished = pyqtSignal(str)
     error = pyqtSignal(str)
     stopped = False
 
-    def __init__(self, model: Model, parent: Optional['QObject'] = None) -> None:
+    def __init__(self, model: TranscriptionModel, parent: Optional['QObject'] = None) -> None:
         super().__init__(parent)
-        self.name = model.model_name()
-        self.use_whisper_cpp = model.is_whisper_cpp()
+        self.model_type = model.model_type
+        self.whisper_model_size = model.whisper_model_size
+        self.hugging_face_model_id = model.hugging_face_model_id
 
     @pyqtSlot()
     def run(self):
+        if self.model_type == ModelType.WHISPER_CPP:
+            root_dir = user_cache_dir('Buzz')
+            model_name = self.whisper_model_size.value
+            url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
+                                                    filename=f'ggml-{model_name}.bin')
+            file_path = os.path.join(root_dir, f'ggml-model-whisper-{model_name}.bin')
+            expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
+            self.download_model(url, file_path, expected_sha256)
+            return
+
+        if self.model_type == ModelType.WHISPER:
+            root_dir = os.getenv(
+                "XDG_CACHE_HOME",
+                os.path.join(os.path.expanduser("~"), ".cache", "whisper")
+            )
+            model_name = self.whisper_model_size.value
+            url = whisper._MODELS[model_name]
+            file_path = os.path.join(root_dir, os.path.basename(url))
+            expected_sha256 = url.split('/')[-2]
+            self.download_model(url, file_path, expected_sha256)
+            return
+
+        if self.model_type == ModelType.HUGGING_FACE:
+            self.progress.emit((0, 100))
+
+            try:
+                # Loads the model from cache or download if not in cache
+                transformers_whisper.load_model(self.hugging_face_model_id)
+            except (FileNotFoundError, EnvironmentError) as exception:
+                self.error.emit(f'{exception}')
+                return
+
+            self.progress.emit((100, 100))
+            self.finished.emit(self.hugging_face_model_id)
+            return
+
+    def download_model(self, url: str, file_path: str, expected_sha256: Optional[str]):
         try:
-            if self.use_whisper_cpp:
-                root = user_cache_dir('Buzz')
-                url = f'https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-{self.name}.bin'
-                model_path = os.path.join(root, f'ggml-model-whisper-{self.name}.bin')
-            else:
-                root = os.getenv(
-                    "XDG_CACHE_HOME",
-                    os.path.join(os.path.expanduser("~"), ".cache", "whisper")
-                )
-                url = whisper._MODELS[self.name]
-                model_path = os.path.join(root, os.path.basename(url))
-
-            os.makedirs(root, exist_ok=True)
-
-            if os.path.exists(model_path) and not os.path.isfile(model_path):
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+            if os.path.exists(file_path) and not os.path.isfile(file_path):
                 raise RuntimeError(
-                    f"{model_path} exists and is not a regular file")
+                    f"{file_path} exists and is not a regular file")
 
-            expected_sha256 = MODELS_SHA256[self.name] if self.use_whisper_cpp else url.split(
-                "/")[-2]
-            if os.path.isfile(model_path):
-                model_bytes = open(model_path, "rb").read()
+            if os.path.isfile(file_path):
+                if expected_sha256 is None:
+                    self.finished.emit(file_path)
+                    return
+
+                model_bytes = open(file_path, "rb").read()
                 model_sha256 = hashlib.sha256(model_bytes).hexdigest()
                 if model_sha256 == expected_sha256:
-                    self.finished.emit(model_path)
+                    self.finished.emit(file_path)
                     return
                 else:
                     warnings.warn(
-                        f"{model_path} exists, but the SHA256 checksum does not match; re-downloading the file")
+                        f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
 
             # Downloads the model using the requests module instead of urllib to
             # use the certs from certifi when the app is running in frozen mode
-            with requests.get(url, stream=True, timeout=15) as source, open(model_path, 'wb') as output:
+            with requests.get(url, stream=True, timeout=15) as source, open(file_path, 'wb') as output:
                 source.raise_for_status()
                 total_size = float(source.headers.get('Content-Length', 0))
                 current = 0.0
@@ -78,12 +134,14 @@ def run(self):
                     current += len(chunk)
                     self.progress.emit((current, total_size))
 
-            model_bytes = open(model_path, "rb").read()
-            if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
-                raise RuntimeError(
-                    "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model.")
+            if expected_sha256 is not None:
+                model_bytes = open(file_path, "rb").read()
+                if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+                    raise RuntimeError(
+                        "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
+                        "model.")
 
-            self.finished.emit(model_path)
+            self.finished.emit(file_path)
         except RuntimeError as exc:
             self.error.emit(str(exc))
             logging.exception('')
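The loader now branches on an explicit ModelType instead of parsing a display string out of an enum, and download_model treats expected_sha256 as optional, skipping checksum verification when it is None. A usage sketch of the new types (not part of the commit; the 'openai/whisper-tiny' ID is only an illustrative value):

from buzz.model_loader import (ModelType, TranscriptionModel, WhisperModelSize,
                               get_hugging_face_dataset_file_url)

# Whisper and Whisper.cpp models are selected by size...
whisper_model = TranscriptionModel(model_type=ModelType.WHISPER,
                                   whisper_model_size=WhisperModelSize.SMALL)
# ...while Hugging Face models are addressed by repository ID.
hf_model = TranscriptionModel(model_type=ModelType.HUGGING_FACE,
                              hugging_face_model_id='openai/whisper-tiny')

# Whisper.cpp weights are fetched from a Hugging Face dataset repository:
url = get_hugging_face_dataset_file_url(author='ggerganov', repository_name='whisper.cpp',
                                        filename='ggml-tiny.bin')
# -> https://huggingface.co/datasets/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin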
93 changes: 37 additions & 56 deletions buzz/transcriber.py
@@ -15,7 +15,6 @@
 from multiprocessing.connection import Connection
 from threading import Thread
 from typing import Any, Callable, List, Optional, Tuple, Union
-import typing
 
 import ffmpeg
 import numpy as np
@@ -25,7 +24,9 @@
 from PyQt6.QtCore import QObject, QProcess, pyqtSignal, pyqtSlot, QThread
 from sounddevice import PortAudioError
 
+from . import transformers_whisper
 from .conn import pipe_stderr
+from .model_loader import TranscriptionModel, ModelType
 
 # Catch exception from whisper.dll not getting loaded.
 # TODO: Remove flag and try-except when issue with loading
@@ -53,32 +54,11 @@ class Segment:
     text: str
 
 
-class Model(enum.Enum):
-    WHISPER_TINY = 'Whisper - Tiny'
-    WHISPER_BASE = 'Whisper - Base'
-    WHISPER_SMALL = 'Whisper - Small'
-    WHISPER_MEDIUM = 'Whisper - Medium'
-    WHISPER_LARGE = 'Whisper - Large'
-    WHISPER_CPP_TINY = 'Whisper.cpp - Tiny'
-    WHISPER_CPP_BASE = 'Whisper.cpp - Base'
-    WHISPER_CPP_SMALL = 'Whisper.cpp - Small'
-    WHISPER_CPP_MEDIUM = 'Whisper.cpp - Medium'
-    WHISPER_CPP_LARGE = 'Whisper.cpp - Large'
-
-    def is_whisper_cpp(self) -> bool:
-        model_type, _ = self.value.split(' - ')
-        return model_type == 'Whisper.cpp'
-
-    def model_name(self) -> str:
-        _, model_name = self.value.split(' - ')
-        return model_name.lower()
-
-
 @dataclass()
 class TranscriptionOptions:
     language: Optional[str] = None
     task: Task = Task.TRANSCRIBE
-    model: Model = Model.WHISPER_TINY
+    model: TranscriptionModel = TranscriptionModel()
     word_level_timings: bool = False
     temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
     initial_prompt: str = ''
@@ -373,7 +353,7 @@ class WhisperFileTranscriber(QObject):
 
     current_process: multiprocessing.Process
    progress = pyqtSignal(tuple)  # (current, total)
-    completed = pyqtSignal(tuple)  # (exit_code: int, segments: List[Segment])
+    completed = pyqtSignal(list)  # List[Segment]
    error = pyqtSignal(str)
    running = False
    read_line_thread: Optional[Thread] = None
@@ -390,6 +370,8 @@ def __init__(self, task: FileTranscriptionTask,
         self.temperature = task.transcription_options.temperature
         self.initial_prompt = task.transcription_options.initial_prompt
         self.model_path = task.model_path
+        self.transcription_options = task.transcription_options
+        self.transcription_task = task
         self.segments = []
 
     @pyqtSlot()
@@ -406,13 +388,8 @@ def run(self):
 
         recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
 
-        self.current_process = multiprocessing.Process(
-            target=transcribe_whisper,
-            args=(
-                send_pipe, model_path, self.file_path,
-                self.language, self.task, self.word_level_timings,
-                self.temperature, self.initial_prompt
-            ))
+        self.current_process = multiprocessing.Process(target=transcribe_whisper,
+                                                       args=(send_pipe, self.transcription_task))
         self.current_process.start()
 
         self.read_line_thread = Thread(
@@ -430,8 +407,9 @@ def run(self):
 
         self.read_line_thread.join()
 
-        if self.current_process.exitcode != 0:
-            self.completed.emit((self.current_process.exitcode, []))
+        # TODO: fix error handling when process crashes
+        if self.current_process.exitcode != 0 and self.current_process.exitcode is not None:
+            self.completed.emit([])
 
         self.running = False
 
@@ -458,7 +436,7 @@ def read_line(self, pipe: Connection):
                 ) for segment in segments_dict]
                 self.current_process.join()
                 # TODO: move this back to the parent thread
-                self.completed.emit((self.current_process.exitcode, segments))
+                self.completed.emit(segments)
             else:
                 try:
                     progress = int(line.split('|')[0].strip().strip('%'))
@@ -467,26 +445,30 @@
                     continue
 
 
-def transcribe_whisper(
-        stderr_conn: Connection, model_path: str, file_path: str,
-        language: Optional[str], task: Task,
-        word_level_timings: bool, temperature: Tuple[float, ...], initial_prompt: str):
+def transcribe_whisper(stderr_conn: Connection, task: FileTranscriptionTask):
     with pipe_stderr(stderr_conn):
-        model = whisper.load_model(model_path)
-
-        if word_level_timings:
-            stable_whisper.modify_model(model)
-            result = model.transcribe(
-                audio=file_path, language=language,
-                task=task.value, temperature=temperature,
-                initial_prompt=initial_prompt, pbar=True)
+        if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
+            model = transformers_whisper.load_model(task.model_path)
+            language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
+            result = model.transcribe(audio_path=task.file_path, language=language,
+                                      task=task.transcription_options.task.value, verbose=False)
+            whisper_segments = result.get('segments')
         else:
-            result = model.transcribe(
-                audio=file_path, language=language, task=task.value, temperature=temperature,
-                initial_prompt=initial_prompt, verbose=False)
-
-        whisper_segments = stable_whisper.group_word_timestamps(
-            result) if word_level_timings else result.get('segments')
+            model = whisper.load_model(task.model_path)
+            if task.transcription_options.word_level_timings:
+                stable_whisper.modify_model(model)
+                result = model.transcribe(
+                    audio=task.file_path, language=task.transcription_options.language,
+                    task=task.transcription_options.task.value, temperature=task.transcription_options.temperature,
+                    initial_prompt=task.transcription_options.initial_prompt, pbar=True)
+                whisper_segments = stable_whisper.group_word_timestamps(result)
+            else:
+                result = model.transcribe(
+                    audio=task.file_path, language=task.transcription_options.language,
+                    task=task.transcription_options.task.value,
+                    temperature=task.transcription_options.temperature,
+                    initial_prompt=task.transcription_options.initial_prompt, verbose=False)
+                whisper_segments = result.get('segments')
 
         segments = [
             Segment(
@@ -638,7 +620,7 @@ def run(self):
             self.completed.emit()
             return
 
-        if self.current_task.transcription_options.model.is_whisper_cpp():
+        if self.current_task.transcription_options.model.model_type == ModelType.WHISPER_CPP:
             self.current_transcriber = WhisperCppFileTranscriber(
                 task=self.current_task)
         else:
@@ -688,10 +670,9 @@ def on_task_progress(self, progress: Tuple[int, int]):
         self.current_task.fraction_completed = progress[0] / progress[1]
         self.task_updated.emit(self.current_task)
 
-    @pyqtSlot(tuple)
-    def on_task_completed(self, result: Tuple[int, List[Segment]]):
+    @pyqtSlot(list)
+    def on_task_completed(self, segments: List[Segment]):
         if self.current_task is not None:
-            _, segments = result
             self.current_task.status = FileTranscriptionTask.Status.COMPLETED
             self.current_task.segments = segments
             self.task_updated.emit(self.current_task)
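The buzz.transformers_whisper module is among the files whose diff is not rendered above, so the sketch below only mirrors its call sites in this commit: load_model accepts a Hugging Face model ID (in ModelLoader) or a local path (task.model_path here), and the returned object exposes a transcribe method whose result carries a 'segments' list. Treat the exact signatures as inferred, not documented:

from buzz import transformers_whisper

# Loads from the local Hugging Face cache, downloading on first use
# (the model ID is an illustrative value, as above).
model = transformers_whisper.load_model('openai/whisper-tiny')

# Keyword arguments mirror the HUGGING_FACE branch of transcribe_whisper.
result = model.transcribe(audio_path='recording.wav', language='en',
                          task='transcribe', verbose=False)
segments = result.get('segments')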
Diffs for the remaining changed files are not rendered.
