Skip to content

Commit

Permalink
add more confidence info
Browse files Browse the repository at this point in the history
  • Loading branch information
daanzu committed May 30, 2020
1 parent 30e9686 commit c353cfa
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 18 deletions.
7 changes: 4 additions & 3 deletions examples/plain_dictation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys, wave
import logging, sys, wave
from kaldi_active_grammar import PlainDictationRecognizer

# logging.basicConfig(level=10)
recognizer = PlainDictationRecognizer() # Or supply non-default model_dir, tmp_dir, or fst_file
filename = sys.argv[1] if len(sys.argv) > 1 else 'test.wav'
wave_file = wave.open(filename, 'rb')
data = wave_file.readframes(wave_file.getnframes())
output_str, likelihood = recognizer.decode_utterance(data)
print(repr(output_str), likelihood) # -> 'alpha bravo charlie' 1.1923989057540894
output_str, info = recognizer.decode_utterance(data)
print(repr(output_str), info) # -> 'it depends on the context'
4 changes: 2 additions & 2 deletions kaldi_active_grammar/plain_dictation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ def decode_utterance(self, samples_data):
and returning a tuple of (output (*text*), info (*dict* of confidence information, e.g. ``likelihood``)).
"""
self.decoder.decode(samples_data, True)
output_str, likelihood = self.decoder.get_output()
output_str, info = self.decoder.get_output()
output_str = remove_nonterms_in_text(output_str)
return (output_str, likelihood)
return (output_str, info)
41 changes: 28 additions & 13 deletions kaldi_active_grammar/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ def get_output(self, output_max_length=4*1024):
likelihood_p = _ffi.new('double *')
result = self._lib.get_output_gmm(self._model, output_p, output_max_length, likelihood_p)
output_str = _ffi.string(output_p)
likelihood = likelihood_p[0]
return output_str, likelihood
info = {
'likelihood': likelihood_p[0],
}
return output_str, info


########################################################################################################################
Expand Down Expand Up @@ -208,8 +210,10 @@ def get_output(self, output_max_length=4*1024):
likelihood_p = _ffi.new('double *')
result = self._lib.get_output_otf_gmm(self._model, output_p, output_max_length, likelihood_p)
output_str = _ffi.string(output_p)
likelihood = likelihood_p[0]
return output_str, likelihood
info = {
'likelihood': likelihood_p[0],
}
return output_str, info


########################################################################################################################
Expand Down Expand Up @@ -314,9 +318,10 @@ def get_output(self, output_max_length=4*1024):
if not result:
raise KaldiError("get_output error")
output_str = de(_ffi.string(output_p))
likelihood = likelihood_p[0]
# _log.debug("get_output: likelihood %f, %r", likelihood_p[0], output_str)
return output_str, likelihood
info = {
'likelihood': likelihood_p[0],
}
return output_str, info

def get_word_align(self, output):
"""Returns list of tuples: words (including nonterminals but not eps), each's time (in bytes), and each's length (in bytes)."""
Expand Down Expand Up @@ -355,7 +360,8 @@ class KaldiAgfNNet3Decoder(KaldiNNet3Decoder):
extern "C" DRAGONFLY_API bool remove_grammar_fst_agf_nnet3(void* model_vp, int32_t grammar_fst_index);
extern "C" DRAGONFLY_API bool decode_agf_nnet3(void* model_vp, float samp_freq, int32_t num_frames, float* frames, bool finalize,
bool* grammars_activity_cp, int32_t grammars_activity_cp_size, bool save_adaptation_state);
extern "C" DRAGONFLY_API bool get_output_agf_nnet3(void* model_vp, char* output, int32_t output_max_length, double* likelihood_p);
extern "C" DRAGONFLY_API bool get_output_agf_nnet3(void* model_vp, char* output, int32_t output_max_length,
float* likelihood_p, float* am_score_p, float* lm_score_p, float* confidence_p, float* expected_error_rate_p);
extern "C" DRAGONFLY_API bool get_word_align_agf_nnet3(void* model_vp, int32_t* times_cp, int32_t* lengths_cp, int32_t num_words);
extern "C" DRAGONFLY_API bool save_adaptation_state_agf_nnet3(void* model_vp);
extern "C" DRAGONFLY_API bool reset_adaptation_state_agf_nnet3(void* model_vp);
Expand Down Expand Up @@ -469,14 +475,23 @@ def decode(self, frames, finalize, grammars_activity=None):

def get_output(self, output_max_length=4*1024):
output_p = _ffi.new('char[]', output_max_length)
likelihood_p = _ffi.new('double *')
result = self._lib.get_output_agf_nnet3(self._model, output_p, output_max_length, likelihood_p)
likelihood_p = _ffi.new('float *')
am_score_p = _ffi.new('float *')
lm_score_p = _ffi.new('float *')
confidence_p = _ffi.new('float *')
expected_error_rate_p = _ffi.new('float *')
result = self._lib.get_output_agf_nnet3(self._model, output_p, output_max_length, likelihood_p, am_score_p, lm_score_p, confidence_p, expected_error_rate_p)
if not result:
raise KaldiError("get_output error")
output_str = de(_ffi.string(output_p))
likelihood = likelihood_p[0]
# _log.debug("get_output: likelihood %f, %r", likelihood_p[0], output_str)
return output_str, likelihood
info = {
'likelihood': likelihood_p[0],
'am_score': am_score_p[0],
'lm_score': lm_score_p[0],
'confidence': confidence_p[0],
'expected_error_rate': expected_error_rate_p[0],
}
return output_str, info

def get_word_align(self, output):
"""Returns list of tuples: words (including nonterminals but not eps), each's time (in bytes), and each's length (in bytes)."""
Expand Down

0 comments on commit c353cfa

Please sign in to comment.