A crude way of using OpenAI Whisper for alternative dictation in KaldiAG #73

Open · wants to merge 4 commits into base: master
24 changes: 21 additions & 3 deletions kaldi_active_grammar/compiler.py
@@ -16,6 +16,7 @@
from .wfst import WFST, NativeWFST, SymbolTable
from .model import Model
from .wrapper import KaldiAgfCompiler, KaldiAgfNNet3Decoder, KaldiLafNNet3Decoder
import kaldi_active_grammar.whisper_dictation as whisper_dictation
import kaldi_active_grammar.defaults as defaults

_log = _log.getChild('compiler')
@@ -646,6 +647,7 @@ def parse_output_for_rule(self, kaldi_rule, output):
self._log.error("parsed_output(%r).lower() != output(%r)" % (parsed_output, output))
return words

plain_dictation_regex = re.compile(r'(?<=#nonterm:dictation )(.*?)(?= #nonterm:end)') # lookbehind & lookahead assertions
alternative_dictation_regex = re.compile(r'(?<=#nonterm:dictation_cloud )(.*?)(?= #nonterm:end)') # lookbehind & lookahead assertions

def parse_output(self, output, dictation_info_func=None):
@@ -659,10 +661,16 @@
kaldi_rule_id = int(nonterm_token[len('#nonterm:rule'):])
kaldi_rule = self.kaldi_rule_by_id_dict[kaldi_rule_id]

if self.alternative_dictation and dictation_info_func and kaldi_rule.has_dictation and '#nonterm:dictation_cloud' in parsed_output:
# Debug dictation settings
#print("DEBUG: ", self.alternative_dictation, "B", dictation_info_func, "C", kaldi_rule.has_dictation, "D", parsed_output)

#if self.alternative_dictation and dictation_info_func and kaldi_rule.has_dictation and '#nonterm:dictation_cloud' in parsed_output:
if self.alternative_dictation and dictation_info_func and kaldi_rule.has_dictation and '#nonterm:dictation' in parsed_output:
try:
if callable(self.alternative_dictation):
alternative_text_func = self.alternative_dictation
elif self.alternative_dictation == 'whisper':
alternative_text_func = whisper_dictation.Whisper.transcribe_data_sync
else:
raise TypeError("Invalid alternative_dictation value: %r" % self.alternative_dictation)

@@ -677,7 +685,8 @@
'offset_end': times[words.index('#nonterm:end', index)],
}
for index, (word, time, length) in enumerate(word_align)
if word.startswith('#nonterm:dictation_cloud')]
if word.startswith('#nonterm:dictation')]
#if word.startswith('#nonterm:dictation_cloud')]

# If last dictation is at end of utterance, include rest of audio_data; else, include half of audio_data between dictation end and start of next word
dictation_span = dictation_spans[-1]
@@ -688,9 +697,17 @@
dictation_span['offset_end'] = (dictation_span['offset_end'] + next_word_time) // 2

def replace_dictation(matchobj):
orig_text = matchobj.group(1)
orig_text = matchobj.group(1) # "orig_text" holds the dictation result from Kaldi.
dictation_span = dictation_spans.pop(0)
dictation_audio = audio_data[dictation_span['offset_start'] : dictation_span['offset_end']]
if self.alternative_dictation == 'whisper':
self.cloud_dictation_lang = "en-US" # FIXME: hardcoded language!
# The Whisper dictation backend can take audio data in a wav file.
# Store a file in the system temp folder (this should work on Linux and Windows, and probably OS X).
#import tempfile
#temp_dir = tempfile.TemporaryDirectory().name
#audio_filename = os.path.join(temp_dir,"whisper.wav")
#whisper_dictation.write_wav('/tmp/whisper.wav', dictation_audio)
kwargs = dict(language_code=self.cloud_dictation_lang)
with debug_timer(self._log.debug, 'alternative_dictation call'):
alternative_text = alternative_text_func(dictation_audio, **kwargs)
@@ -699,6 +716,7 @@ def replace_dictation(matchobj):
return (alternative_text or orig_text)

parsed_output = self.alternative_dictation_regex.sub(replace_dictation, parsed_output)
parsed_output = self.plain_dictation_regex.sub(replace_dictation, parsed_output)
except Exception as e:
self._log.exception("Exception performing alternative dictation")

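For context, here is a minimal usage sketch of how this backend would be switched on from a dragonfly loader script, assuming the Kaldi engine forwards its alternative_dictation option to this compiler the same way it does for the existing 'gcloud' backend; the engine option names below are illustrative and may differ in your setup:

    # Illustrative sketch only: enable the Whisper alternative-dictation backend.
    # Assumes dragonfly's Kaldi engine passes alternative_dictation through to KaldiAG's Compiler.
    from dragonfly import get_engine

    engine = get_engine("kaldi",
                        model_dir="kaldi_model",
                        alternative_dictation="whisper")  # instead of 'gcloud' or a callable
    engine.connect()
    engine.do_recognition()
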
108 changes: 108 additions & 0 deletions kaldi_active_grammar/whisper_dictation.py
@@ -0,0 +1,108 @@
# A crude way of using OpenAI Whisper for dictation in KaldiAG.
# This is the RPC client, which sends audio data to the local whisper RPC server process.
# By Shervin Emami (www.shervinemami.com) 2022
# Based on "alternative_dictation.py" from KaldiAG v1.8, when KaldiAG had some basic support for GCloud dictation.
#
# KaldiAG is (c) Copyright 2019 by David Zurow
# Licensed under the AGPL-3.0; see LICENSE.txt file.
#

# Compatibility between Python 2 and Python 3:
from __future__ import print_function # print function with Python 2/3 compatibility
from __future__ import division

import sys
if sys.version_info[0] == 3:
# Python3
from xmlrpc.client import ServerProxy
else:
# Python2
from xmlrpclib import ServerProxy
import wave

verbose = False

WHISPER_SERVER_ACCESS = "http://127.0.0.1:8002" # Where to find our whisper server. Note that Shervin's KaldiAG setup already runs RPC servers on ports 8000 and 8001
whisper_client = ServerProxy(WHISPER_SERVER_ACCESS, allow_none=True)

# Choose what to do if whisper dictation fails (e.g. trouble connecting to our local whisper RPC server).
# Some users will want to return None so that their Kaldi or other dictation backend performs the dictation without interrupting the user,
# but others will want the entire speech engine to close, so that it's obvious when whisper didn't work.
EXIT_IF_WHISPER_FAILED = True


# Create a new process for the whisper_server to run in the background. It expects "whisper_server.py" to be in the same folder as this Python file.
import subprocess
import os
pardir = os.path.abspath(os.path.join(__file__, os.pardir))
whisper_server = os.path.abspath(os.path.join(pardir, "whisper_server.py"))
subprocess.Popen([sys.executable, whisper_server])



def write_wav(filename, audio_data, sample_rate=16000):
wf = wave.open(filename, 'wb')
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(sample_rate)
wf.writeframes(audio_data)
wf.close()


def testCUDA():
print("Test CUDA")

import torch
# Making the code device-agnostic
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
if device_name == 'cuda':
print(f"CUDA version: {torch.version.cuda}")
print(f"Name of current CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Creating a test tensor
x = torch.randint(1, 100, (100, 1000))
# Checking the device name:
# Should return 'cpu' by default
print("Default pytorch device (should be 'CPU'): ", x.device)
# Transferring tensor to GPU
x = x.to(torch.device(device_name))
# Checking the device name:
# Should return 'cuda:0'
print("CUDA pytorch device (should be 'cuda:0'): ", x.device)
# Applying same GPU-accelerated tensor operation
res_gpu = x ** 2
res_cpu = res_gpu.cpu()
print("result: ", res_cpu)



class Whisper(object):

# Use Whisper to convert the audio data into a text string. If speech_data is not given, the audio will be loaded from a wav file.
@staticmethod
def transcribe_data_sync(speech_data=None, model='default', language_code='en-US'):
# It's possible that calling a GPU-accelerated PyTorch function within the KaldiAG Dragonfly process will cause Dragonfly's
# calls to xdotool via Text() to have quite long latency on Linux (~200ms instead of ~50ms per call!). So
# here (within the Dragonfly process) we make an RPC interprocess call to our whisper process, which can be GPU-accelerated.

# For debugging latency of GPU-accelerated PyTorch:
#testCUDA()
#return "words"

try:
print("Calling the whisper_server RPC server.")
result = whisper_client.transcribe_using_whisper(speech_data, model, language_code)
if result:
return result
except Exception as e:
print("Warning: Exception ", e)
print("Couldn't access the whisper_server at", WHISPER_SERVER_ACCESS, ", is it running?")

# If we've gotten to this line here, then whisper dictation failed.
if EXIT_IF_WHISPER_FAILED:
print("Exiting the speech recognition engine, since whisper failed.")
os.kill(os.getpid(), 9)
sys.exit(1)

return None

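Note that whisper_server.py itself is not part of this diff. Below is a minimal sketch of what the server side might look like, assuming it exposes transcribe_using_whisper over XML-RPC on port 8002 and uses the openai-whisper package; the names and model choice are illustrative, not necessarily what the PR's actual server does, and handling of speech_data=None (wav-file input) is omitted:

    # whisper_server.py -- illustrative sketch only, not the PR's actual server.
    import numpy as np
    import whisper
    from xmlrpc.server import SimpleXMLRPCServer

    model = whisper.load_model("base")  # assumption: the real server may pick a different model

    def transcribe_using_whisper(speech_data, model_name, language_code):
        # speech_data arrives as an xmlrpc Binary wrapper holding 16-bit mono PCM at 16 kHz,
        # matching the audio_data slices that compiler.py passes to transcribe_data_sync().
        raw = speech_data.data if hasattr(speech_data, "data") else speech_data
        audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
        result = model.transcribe(audio, language=language_code.split("-")[0])
        return result["text"].strip()

    server = SimpleXMLRPCServer(("127.0.0.1", 8002), allow_none=True)
    server.register_function(transcribe_using_whisper)
    server.serve_forever()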