diff --git a/dragonfly/engines/backend_kaldi/compiler.py b/dragonfly/engines/backend_kaldi/compiler.py index 7118511c..0d43b4bd 100644 --- a/dragonfly/engines/backend_kaldi/compiler.py +++ b/dragonfly/engines/backend_kaldi/compiler.py @@ -32,8 +32,7 @@ from dragonfly.grammar import elements as elements_ from dragonfly.engines.base import CompilerBase, CompilerError -from dragonfly.engines.backend_kaldi.dictation import (AlternativeDictation, - DefaultDictation, +from dragonfly.engines.backend_kaldi.dictation import (_get_dictation_nonterminal, UserDictation) #--------------------------------------------------------------------------- @@ -336,8 +335,7 @@ def _compile_dictation(self, element, src_state, dst_state, grammar, kaldi_rule, src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst) # fst.add_arc(src_state, dst_state, '#nonterm:dictation', olabel=WFST.eps) extra_state = fst.add_state() - cloud_dictation = isinstance(element, (AlternativeDictation, DefaultDictation)) and element.cloud - dictation_nonterm = '#nonterm:dictation_cloud' if cloud_dictation else '#nonterm:dictation' + dictation_nonterm = _get_dictation_nonterminal(element) fst.add_arc(src_state, extra_state, '#nonterm:dictation', dictation_nonterm) # Accepts zero or more words fst.add_arc(extra_state, dst_state, WFST.eps, '#nonterm:end') diff --git a/dragonfly/engines/backend_kaldi/dictation.py b/dragonfly/engines/backend_kaldi/dictation.py index bf1ac079..f426e503 100644 --- a/dragonfly/engines/backend_kaldi/dictation.py +++ b/dragonfly/engines/backend_kaldi/dictation.py @@ -86,6 +86,12 @@ def value(self, node): #--------------------------------------------------------------------------- # Alternative dictation classes -- elements capable of default or alternative dictation. +def _get_dictation_nonterminal(element): + if isinstance(element, (AlternativeDictation, DefaultDictation)) and element.alternative: + return '#nonterm:dictation_cloud' + return '#nonterm:dictation' + + class AlternativeDictation(BaseDictation): alternative_default = True diff --git a/dragonfly/test/suites.py b/dragonfly/test/suites.py index d19519a6..27ca5be4 100644 --- a/dragonfly/test/suites.py +++ b/dragonfly/test/suites.py @@ -120,6 +120,7 @@ ] + common_names, "kaldi": [ + "test_dictation_kaldi", "test_engine_kaldi", "test_language_en_number", diff --git a/dragonfly/test/test_dictation_kaldi.py b/dragonfly/test/test_dictation_kaldi.py new file mode 100644 index 00000000..7bd66051 --- /dev/null +++ b/dragonfly/test/test_dictation_kaldi.py @@ -0,0 +1,29 @@ +# +# This file is part of Dragonfly. +# Licensed under the LGPL. +# + +import unittest + +from dragonfly.grammar.elements_basic import Dictation +from dragonfly.engines.backend_kaldi.dictation import ( + AlternativeDictation, + DefaultDictation, + _get_dictation_nonterminal, +) + + +class TestKaldiDictationNonterminal(unittest.TestCase): + + def test_specialized_dictation_elements_choose_expected_nonterminal(self): + cases = ( + (Dictation("text"), '#nonterm:dictation'), + (AlternativeDictation("text"), '#nonterm:dictation_cloud'), + (DefaultDictation("text"), '#nonterm:dictation'), + (AlternativeDictation("text", alternative=False), '#nonterm:dictation'), + (DefaultDictation("text", alternative=True), '#nonterm:dictation_cloud'), + ) + + for element, expected in cases: + with self.subTest(element=repr(element)): + self.assertEqual(expected, _get_dictation_nonterminal(element))