Skip to content

Commit

Permalink
Add delete model and show model location options (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams committed Aug 3, 2023
1 parent f1550f8 commit af4d57b
Show file tree
Hide file tree
Showing 22 changed files with 438 additions and 284 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ locale/**/*.mo

benchmarks.json

.eggs
.eggs
*.egg-info
4 changes: 2 additions & 2 deletions buzz/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PyQt6.QtCore import QCommandLineParser, QCommandLineOption

from buzz.gui import Application
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel, get_local_model_path
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel
from buzz.store.keyring_store import KeyringStore
from buzz.transcriber import Task, FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, LANGUAGES, \
OutputFormat
Expand Down Expand Up @@ -103,7 +103,7 @@ def parse(app: Application, parser: QCommandLineParser):

model = TranscriptionModel(model_type=ModelType[model_type.name], whisper_model_size=model_size,
hugging_face_model_id=hugging_face_model_id)
model_path = get_local_model_path(model)
model_path = model.get_local_model_path()

if model_path is None:
raise CommandLineError('Model not found')
Expand Down
142 changes: 9 additions & 133 deletions buzz/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,22 @@
from PyQt6.QtCore import (QObject, Qt, QThread,
QTimer, QUrl, pyqtSignal, QModelIndex, QPoint,
QUrlQuery, QMetaObject, QEvent, QThreadPool)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
from PyQt6.QtGui import (QCloseEvent, QIcon,
QKeySequence, QTextCursor, QValidator, QKeyEvent, QPainter, QColor)
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkReply, QNetworkRequest
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QDialogButtonBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox, QMenuBar, QFormLayout,
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QGroupBox,
QFormLayout,
QAbstractItemView, QListWidget, QListWidgetItem, QSizePolicy)

from buzz.cache import TasksCache
from .__version__ import VERSION
from .action import Action
from .assets import get_asset_path
from .icon import Icon
from .widgets.icon import Icon, BUZZ_ICON_PATH
from .locale import _
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, get_local_model_path, \
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \
ModelDownloader
from .paths import file_paths_as_title
from .recording import RecordingAmplitudeListener
Expand All @@ -40,10 +41,10 @@
from .recording_transcriber import RecordingTranscriber
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
from .widgets.line_edit import LineEdit
from .widgets.menu_bar import MenuBar
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from .widgets.model_type_combo_box import ModelTypeComboBox
from .widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
from .widgets.preferences_dialog import PreferencesDialog
from .widgets.toolbar import ToolBar
from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
Expand Down Expand Up @@ -264,7 +265,7 @@ def on_transcription_options_changed(self, transcription_options: TranscriptionO
def on_click_run(self):
self.run_button.setDisabled(True)

model_path = get_local_model_path(model=self.transcription_options.model)
model_path = self.transcription_options.model.get_local_model_path()
if model_path is not None:
self.on_model_loaded(model_path)
return
Expand Down Expand Up @@ -508,7 +509,7 @@ def on_record_button_clicked(self):
def start_recording(self):
self.record_button.setDisabled(True)

model_path = get_local_model_path(model=self.transcription_options.model)
model_path = self.transcription_options.model.get_local_model_path()
if model_path is not None:
self.on_model_loaded(model_path)
return
Expand Down Expand Up @@ -641,89 +642,13 @@ def closeEvent(self, event: QCloseEvent) -> None:
return super().closeEvent(event)


BUZZ_ICON_PATH = get_asset_path('assets/buzz.ico')
BUZZ_LARGE_ICON_PATH = get_asset_path('assets/buzz-icon-1024.png')
RECORD_ICON_PATH = get_asset_path('assets/mic_FILL0_wght700_GRAD0_opsz48.svg')
EXPAND_ICON_PATH = get_asset_path('assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg')
ADD_ICON_PATH = get_asset_path('assets/add_FILL0_wght700_GRAD0_opsz48.svg')
TRASH_ICON_PATH = get_asset_path('assets/delete_FILL0_wght700_GRAD0_opsz48.svg')
CANCEL_ICON_PATH = get_asset_path('assets/cancel_FILL0_wght700_GRAD0_opsz48.svg')


class AboutDialog(QDialog):
GITHUB_API_LATEST_RELEASE_URL = 'https://api.github.com/repos/chidiwilliams/buzz/releases/latest'
GITHUB_LATEST_RELEASE_URL = 'https://github.com/chidiwilliams/buzz/releases/latest'

def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None) -> None:
super().__init__(parent)

self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
self.setWindowTitle(f'{_("About")} {APP_NAME}')

if network_access_manager is None:
network_access_manager = QNetworkAccessManager()

self.network_access_manager = network_access_manager
self.network_access_manager.finished.connect(self.on_latest_release_reply)

layout = QVBoxLayout(self)

image_label = QLabel()
pixmap = QPixmap(BUZZ_LARGE_ICON_PATH).scaled(
80, 80, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
image_label.setPixmap(pixmap)
image_label.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))

buzz_label = QLabel(APP_NAME)
buzz_label.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
buzz_label_font = QtGui.QFont()
buzz_label_font.setBold(True)
buzz_label_font.setPointSize(20)
buzz_label.setFont(buzz_label_font)

version_label = QLabel(f"{_('Version')} {VERSION}")
version_label.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))

self.check_updates_button = QPushButton(_('Check for updates'), self)
self.check_updates_button.clicked.connect(self.on_click_check_for_updates)

button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
QDialogButtonBox.StandardButton.Close), self)
button_box.accepted.connect(self.accept)
button_box.rejected.connect(self.reject)

layout.addWidget(image_label)
layout.addWidget(buzz_label)
layout.addWidget(version_label)
layout.addWidget(self.check_updates_button)
layout.addWidget(button_box)

self.setLayout(layout)

def on_click_check_for_updates(self):
url = QUrl(self.GITHUB_API_LATEST_RELEASE_URL)
self.network_access_manager.get(QNetworkRequest(url))
self.check_updates_button.setDisabled(True)

def on_latest_release_reply(self, reply: QNetworkReply):
if reply.error() == QNetworkReply.NetworkError.NoError:
response = json.loads(reply.readAll().data())
tag_name = response.get('name')
if self.is_version_lower(VERSION, tag_name[1:]):
QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL))
else:
QMessageBox.information(self, '', _("You're up to date!"))
self.check_updates_button.setEnabled(True)

@staticmethod
def is_version_lower(version_a: str, version_b: str):
return version_a.replace('.', '') < version_b.replace('.', '')


class MainWindowToolbar(ToolBar):
new_transcription_action_triggered: pyqtSignal
open_transcript_action_triggered: pyqtSignal
Expand Down Expand Up @@ -1244,55 +1169,6 @@ def on_hugging_face_model_changed(self, model: str):
self.transcription_options_changed.emit(self.transcription_options)


class MenuBar(QMenuBar):
import_action_triggered = pyqtSignal()
shortcuts_changed = pyqtSignal(dict)
openai_api_key_changed = pyqtSignal(str)

def __init__(self, shortcuts: Dict[str, str], parent: QWidget):
super().__init__(parent)

self.shortcuts = shortcuts

self.import_action = QAction(_("Import Media File..."), self)
self.import_action.triggered.connect(
self.on_import_action_triggered)

about_action = QAction(f'{_("About")} {APP_NAME}', self)
about_action.triggered.connect(self.on_about_action_triggered)

self.preferences_action = QAction(_("Preferences..."), self)
self.preferences_action.triggered.connect(self.on_preferences_action_triggered)

self.set_shortcuts(shortcuts)

file_menu = self.addMenu(_("File"))
file_menu.addAction(self.import_action)

help_menu = self.addMenu(_("Help"))
help_menu.addAction(about_action)
help_menu.addAction(self.preferences_action)

def on_import_action_triggered(self):
self.import_action_triggered.emit()

def on_about_action_triggered(self):
about_dialog = AboutDialog(parent=self)
about_dialog.open()

def on_preferences_action_triggered(self):
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts, parent=self)
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
preferences_dialog.open()

def set_shortcuts(self, shortcuts: Dict[str, str]):
self.shortcuts = shortcuts

self.import_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]))
self.preferences_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name]))


class Application(QApplication):
window: MainWindow

Expand Down
Loading

0 comments on commit af4d57b

Please sign in to comment.