Skip to content

Commit

Permalink
Add default file name setting (#559)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Aug 5, 2023
1 parent 64b15f1 commit f5f77b3
Show file tree
Hide file tree
Showing 23 changed files with 843 additions and 620 deletions.
10 changes: 10 additions & 0 deletions buzz/dialogs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from PyQt6.QtWidgets import QWidget, QMessageBox


def show_model_download_error_dialog(parent: QWidget, error: str):
    """Show a critical message box reporting a Whisper model loading failure.

    Args:
        parent: Widget passed to ``QMessageBox.critical`` as the dialog
            parent; its ``tr`` method translates the fixed message parts.
        error: Error description inserted after the translated prefix. A
            trailing period is appended when it does not already end in one.
    """
    error_sentence = error if error.endswith('.') else f'{error}.'
    # The trailing space separates the error sentence from the follow-up
    # hint; without it the text reads "...: <error>.Please retry ...".
    message = (
        parent.tr('An error occurred while loading the Whisper model')
        + f': {error_sentence} '
        + parent.tr(
            "Please retry or check the application logs for more information.")
    )

    QMessageBox.critical(parent, '', message)
563 changes: 24 additions & 539 deletions buzz/gui.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion buzz/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tqdm.auto import tqdm


class WhisperModelSize(enum.Enum):
class WhisperModelSize(str, enum.Enum):
TINY = 'tiny'
BASE = 'base'
SMALL = 'small'
Expand Down
8 changes: 6 additions & 2 deletions buzz/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,18 @@ class Key(enum.Enum):
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings'
FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats'

DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name'

SHORTCUTS = 'shortcuts'

def set_value(self, key: Key, value: typing.Any) -> None:
    """Store ``value`` under the settings name given by ``key.value``."""
    settings_name = key.value
    self.settings.setValue(settings_name, value)

def value(self, key: Key, default_value: typing.Any, value_type: typing.Optional[type] = None) -> typing.Any:
def value(self, key: Key, default_value: typing.Any,
value_type: typing.Optional[type] = None) -> typing.Any:
return self.settings.value(key.value, default_value,
value_type if value_type is not None else type(default_value))
value_type if value_type is not None else type(
default_value))

def clear(self):
    """Clear the wrapped settings object by delegating to its ``clear()``."""
    store = self.settings
    store.clear()
93 changes: 64 additions & 29 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ class TranscriptionOptions:
word_level_timings: bool = False
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
initial_prompt: str = ''
openai_access_token: str = field(default='', metadata=config(exclude=Exclude.ALWAYS))
openai_access_token: str = field(default='',
metadata=config(exclude=Exclude.ALWAYS))


@dataclass()
class FileTranscriptionOptions:
file_paths: List[str]
output_formats: Set['OutputFormat'] = field(default_factory=set)
default_output_file_name: str = ''


@dataclass_json
Expand Down Expand Up @@ -127,12 +129,11 @@ def run(self):
self.completed.emit(segments)

for output_format in self.transcription_task.file_transcription_options.output_formats:
default_path = get_default_output_file_path(
task=self.transcription_task.transcription_options.task,
input_file_path=self.transcription_task.file_path,
output_format=output_format)
default_path = get_default_output_file_path(task=self.transcription_task,
output_format=output_format)

write_output(path=default_path, segments=segments, output_format=output_format)
write_output(path=default_path, segments=segments,
output_format=output_format)

@abstractmethod
def transcribe(self) -> List[Segment]:
Expand Down Expand Up @@ -172,17 +173,22 @@ def transcribe(self) -> List[Segment]:
logging.debug(
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, '
'word level timings = %s',
self.file_path, self.language, self.task, model_path, self.word_level_timings)
self.file_path, self.language, self.task, model_path,
self.word_level_timings)

audio = whisper.audio.load_audio(self.file_path)
self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE

whisper_params = whisper_cpp_params(language=self.language if self.language is not None else '', task=self.task,
word_level_timings=self.word_level_timings)
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(id(self.state))
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback)
whisper_params = whisper_cpp_params(
language=self.language if self.language is not None else '', task=self.task,
word_level_timings=self.word_level_timings)
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(
id(self.state))
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(
self.encoder_begin_callback)
whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state))
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(self.new_segment_callback)
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(
self.new_segment_callback)

model = WhisperCpp(model=model_path)
result = model.transcribe(audio=self.file_path, params=whisper_params)
Expand All @@ -199,13 +205,15 @@ def new_segment_callback(self, ctx, _state, _n_new, user_data):
# t1 seems to sometimes be larger than the duration when the
# audio ends in silence. Trim to fix the displayed progress.
progress = min(t1 * 10, self.duration_audio_ms)
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
ctypes.py_object).value
if state.running:
self.progress.emit((progress, self.duration_audio_ms))

@staticmethod
def encoder_begin_callback(_ctx, _state, user_data):
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data, ctypes.py_object).value
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
ctypes.py_object).value
return state.running == 1

def stop(self):
Expand All @@ -219,8 +227,10 @@ def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = No
self.task = task.transcription_options.task

def transcribe(self) -> List[Segment]:
logging.debug('Starting OpenAI Whisper API file transcription, file path = %s, task = %s', self.file_path,
self.task)
logging.debug(
'Starting OpenAI Whisper API file transcription, file path = %s, task = %s',
self.file_path,
self.task)

wav_file = tempfile.mktemp() + '.wav'
(
Expand All @@ -235,14 +245,18 @@ def transcribe(self) -> List[Segment]:
language = self.transcription_task.transcription_options.language
response_format = "verbose_json"
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
transcript = openai.Audio.translate("whisper-1", audio_file, response_format=response_format,
transcript = openai.Audio.translate("whisper-1", audio_file,
response_format=response_format,
language=language)
else:
transcript = openai.Audio.transcribe("whisper-1", audio_file, response_format=response_format,
transcript = openai.Audio.transcribe("whisper-1", audio_file,
response_format=response_format,
language=language)

segments = [Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for segment in
transcript["segments"]]
segments = [
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for
segment in
transcript["segments"]]
return segments

def stop(self):
Expand Down Expand Up @@ -273,7 +287,8 @@ def transcribe(self) -> List[Segment]:
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)

self.current_process = multiprocessing.Process(target=self.transcribe_whisper,
args=(send_pipe, self.transcription_task))
args=(send_pipe,
self.transcription_task))
if not self.stopped:
self.current_process.start()
self.started_process = True
Expand All @@ -291,15 +306,17 @@ def transcribe(self) -> List[Segment]:

logging.debug(
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started, len(self.segments))
self.current_process.exitcode, datetime.datetime.now() - time_started,
len(self.segments))

if self.current_process.exitcode != 0:
raise Exception('Unknown error')

return self.segments

@classmethod
def transcribe_whisper(cls, stderr_conn: Connection, task: FileTranscriptionTask) -> None:
def transcribe_whisper(cls, stderr_conn: Connection,
task: FileTranscriptionTask) -> None:
with pipe_stderr(stderr_conn):
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
segments = cls.transcribe_hugging_face(task)
Expand All @@ -308,7 +325,8 @@ def transcribe_whisper(cls, stderr_conn: Connection, task: FileTranscriptionTask
elif task.transcription_options.model.model_type == ModelType.WHISPER:
segments = cls.transcribe_openai_whisper(task)
else:
raise Exception(f"Invalid model type: {task.transcription_options.model.model_type}")
raise Exception(
f"Invalid model type: {task.transcription_options.model.model_type}")

segments_json = json.dumps(
segments, ensure_ascii=True, default=vars)
Expand All @@ -321,7 +339,8 @@ def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
model = transformers_whisper.load_model(task.model_path)
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
result = model.transcribe(audio=task.file_path, language=language,
task=task.transcription_options.task.value, verbose=False)
task=task.transcription_options.task.value,
verbose=False)
return [
Segment(
start=int(segment.get('start') * 1000),
Expand Down Expand Up @@ -368,7 +387,8 @@ def transcribe_openai_whisper(cls, task: FileTranscriptionTask) -> List[Segment]
stable_whisper.modify_model(model)
result = model.transcribe(
audio=task.file_path, language=task.transcription_options.language,
task=task.transcription_options.task.value, temperature=task.transcription_options.temperature,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
segments = stable_whisper.group_word_timestamps(result)
return [Segment(
Expand Down Expand Up @@ -423,7 +443,8 @@ def read_line(self, pipe: Connection):

def write_output(path: str, segments: List[Segment], output_format: OutputFormat):
logging.debug(
'Writing transcription output, path = %s, output format = %s, number of segments = %s', path, output_format,
'Writing transcription output, path = %s, output format = %s, number of segments = %s',
path, output_format,
len(segments))

with open(path, 'w', encoding='utf-8') as file:
Expand Down Expand Up @@ -473,8 +494,22 @@ def to_timestamp(ms: float, ms_separator='.') -> str:
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)'


def get_default_output_file_path(task: Task, input_file_path: str, output_format: OutputFormat):
return f'{os.path.splitext(input_file_path)[0]} ({task.value.title()}d on {datetime.datetime.now():%d-%b-%Y %H-%M-%S}).{output_format.value}'
def get_default_output_file_path(task: FileTranscriptionTask,
                                 output_format: OutputFormat):
    """Build the default export path for a transcription task.

    The template stored in
    ``task.file_transcription_options.default_output_file_name`` may contain
    the placeholders ``{{ input_file_name }}``, ``{{ task }}``,
    ``{{ language }}``, ``{{ model_type }}``, ``{{ model_size }}`` and
    ``{{ date_time }}``. Each one is replaced by the matching value and
    ``.<output_format.value>`` is appended.

    Args:
        task: The file transcription task that supplies the template, the
            input file path and the transcription options.
        output_format: Export format; its ``value`` becomes the extension.

    Returns:
        The expanded template followed by the extension. An empty template
        therefore yields just ``.<extension>``.
    """
    # Local import: the file's top-level import block is not visible here.
    import re

    options = task.transcription_options
    model_size = options.model.whisper_model_size
    values = {
        # Input path without its extension (directory part included).
        'input_file_name': os.path.splitext(task.file_path)[0],
        'task': options.task.value,
        'language': options.language or '',
        'model_type': options.model.model_type.value,
        'model_size': model_size.value if model_size is not None else '',
        'date_time': datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S'),
    }

    def substitute(match):
        # Unknown placeholders are left in place, as chained replace() did.
        return values.get(match.group(1), match.group(0))

    # Substitute in a single pass so that text inserted for one placeholder
    # (e.g. an input path that itself contains "{{ task }}") is never
    # re-scanned and rewritten by a later replacement.
    template = task.file_transcription_options.default_output_file_name
    file_name = re.sub(r'\{\{ (\w+) \}\}', substitute, template)
    return f'{file_name}.{output_format.value}'


def whisper_cpp_params(
Expand Down
19 changes: 17 additions & 2 deletions buzz/widgets/menu_bar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import webbrowser
from typing import Dict

from PyQt6.QtCore import pyqtSignal
Expand All @@ -15,11 +16,14 @@ class MenuBar(QMenuBar):
import_action_triggered = pyqtSignal()
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)

def __init__(self, shortcuts: Dict[str, str], parent: QWidget):
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
parent: QWidget):
super().__init__(parent)

self.shortcuts = shortcuts
self.default_export_file_name = default_export_file_name

self.import_action = QAction(_("Import Media File..."), self)
self.import_action.triggered.connect(
Expand All @@ -31,13 +35,17 @@ def __init__(self, shortcuts: Dict[str, str], parent: QWidget):
self.preferences_action = QAction(_("Preferences..."), self)
self.preferences_action.triggered.connect(self.on_preferences_action_triggered)

help_action = QAction(f'{_("Help")}', self)
help_action.triggered.connect(self.on_help_action_triggered)

self.set_shortcuts(shortcuts)

file_menu = self.addMenu(_("File"))
file_menu.addAction(self.import_action)

help_menu = self.addMenu(_("Help"))
help_menu.addAction(about_action)
help_menu.addAction(help_action)
help_menu.addAction(self.preferences_action)

def on_import_action_triggered(self):
Expand All @@ -48,11 +56,18 @@ def on_about_action_triggered(self):
about_dialog.open()

def on_preferences_action_triggered(self):
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, parent=self)
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self)
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
preferences_dialog.default_export_file_name_changed.connect(
self.default_export_file_name_changed)
preferences_dialog.open()

def on_help_action_triggered(self):
    """Open the Buzz documentation URL via the ``webbrowser`` module."""
    docs_url = 'https://chidiwilliams.github.io/buzz/docs'
    webbrowser.open(docs_url)

def set_shortcuts(self, shortcuts: Dict[str, str]):
self.shortcuts = shortcuts

Expand Down
22 changes: 17 additions & 5 deletions buzz/widgets/preferences_dialog/general_preferences_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,45 @@

import openai
from PyQt6.QtCore import QRunnable, QObject, pyqtSignal, QThreadPool
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox
from PyQt6.QtWidgets import QWidget, QFormLayout, QPushButton, QMessageBox, QLineEdit
from openai.error import AuthenticationError

from buzz.store.keyring_store import KeyringStore
from buzz.widgets.line_edit import LineEdit
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit


class GeneralPreferencesWidget(QWidget):
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)

def __init__(self, keyring_store=KeyringStore(), parent: Optional[QWidget] = None):
def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(),
parent: Optional[QWidget] = None):
super().__init__(parent)

self.openai_api_key = keyring_store.get_password(KeyringStore.Key.OPENAI_API_KEY)
self.openai_api_key = keyring_store.get_password(
KeyringStore.Key.OPENAI_API_KEY)

layout = QFormLayout(self)

self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self)
self.openai_api_key_line_edit.key_changed.connect(self.on_openai_api_key_changed)
self.openai_api_key_line_edit.key_changed.connect(
self.on_openai_api_key_changed)

self.test_openai_api_key_button = QPushButton('Test')
self.test_openai_api_key_button.clicked.connect(self.on_click_test_openai_api_key_button)
self.test_openai_api_key_button.clicked.connect(
self.on_click_test_openai_api_key_button)
self.update_test_openai_api_key_button()

layout.addRow('OpenAI API Key', self.openai_api_key_line_edit)
layout.addRow('', self.test_openai_api_key_button)

default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
default_export_file_name_line_edit.textChanged.connect(
self.default_export_file_name_changed)
default_export_file_name_line_edit.setMinimumWidth(200)
layout.addRow('Default export file name', default_export_file_name_line_edit)

self.setLayout(layout)

def update_test_openai_api_key_button(self):
Expand Down
8 changes: 6 additions & 2 deletions buzz/widgets/preferences_dialog/preferences_dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
class PreferencesDialog(QDialog):
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)

def __init__(self, shortcuts: Dict[str, str],
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
parent: Optional[QWidget] = None) -> None:
super().__init__(parent)

Expand All @@ -25,8 +26,11 @@ def __init__(self, shortcuts: Dict[str, str],
layout = QVBoxLayout(self)
tab_widget = QTabWidget(self)

general_tab_widget = GeneralPreferencesWidget(parent=self)
general_tab_widget = GeneralPreferencesWidget(
default_export_file_name=default_export_file_name, parent=self)
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
general_tab_widget.default_export_file_name_changed.connect(
self.default_export_file_name_changed)
tab_widget.addTab(general_tab_widget, _('General'))

models_tab_widget = ModelsPreferencesWidget(parent=self)
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions buzz/widgets/transcriber/advanced_settings_button.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Optional

from PyQt6.QtWidgets import QPushButton, QWidget


class AdvancedSettingsButton(QPushButton):
    """Push button with the fixed label ``'Advanced...'``."""

    def __init__(self, parent: Optional[QWidget] = None) -> None:
        """Create the button.

        Args:
            parent: Optional parent widget, forwarded to ``QPushButton``.
                Defaults to ``None`` so the annotation and the signature
                agree; existing callers that pass a parent are unaffected.
        """
        super().__init__('Advanced...', parent)
Loading

0 comments on commit f5f77b3

Please sign in to comment.