Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions dragonfly/engines/backend_kaldi/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@

from dragonfly.grammar import elements as elements_
from dragonfly.engines.base import CompilerBase, CompilerError
from dragonfly.engines.backend_kaldi.dictation import (AlternativeDictation,
DefaultDictation,
from dragonfly.engines.backend_kaldi.dictation import (_get_dictation_nonterminal,
UserDictation)

#---------------------------------------------------------------------------
Expand Down Expand Up @@ -336,8 +335,7 @@ def _compile_dictation(self, element, src_state, dst_state, grammar, kaldi_rule,
src_state = self.add_weight_linkage(src_state, dst_state, self.get_weight(element), fst)
# fst.add_arc(src_state, dst_state, '#nonterm:dictation', olabel=WFST.eps)
extra_state = fst.add_state()
cloud_dictation = isinstance(element, (AlternativeDictation, DefaultDictation)) and element.cloud
dictation_nonterm = '#nonterm:dictation_cloud' if cloud_dictation else '#nonterm:dictation'
dictation_nonterm = _get_dictation_nonterminal(element)
fst.add_arc(src_state, extra_state, '#nonterm:dictation', dictation_nonterm)
# Accepts zero or more words
fst.add_arc(extra_state, dst_state, WFST.eps, '#nonterm:end')
Expand Down
6 changes: 6 additions & 0 deletions dragonfly/engines/backend_kaldi/dictation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def value(self, node):
#---------------------------------------------------------------------------
# Alternative dictation classes -- elements capable of default or alternative dictation.

def _get_dictation_nonterminal(element):
if isinstance(element, (AlternativeDictation, DefaultDictation)) and element.alternative:
return '#nonterm:dictation_cloud'
return '#nonterm:dictation'


class AlternativeDictation(BaseDictation):

alternative_default = True
Expand Down
1 change: 1 addition & 0 deletions dragonfly/test/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
] + common_names,

"kaldi": [
"test_dictation_kaldi",
"test_engine_kaldi",
"test_language_en_number",

Expand Down
29 changes: 29 additions & 0 deletions dragonfly/test/test_dictation_kaldi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#
# This file is part of Dragonfly.
# Licensed under the LGPL.
#

import unittest

from dragonfly.grammar.elements_basic import Dictation
from dragonfly.engines.backend_kaldi.dictation import (
AlternativeDictation,
DefaultDictation,
_get_dictation_nonterminal,
)


class TestKaldiDictationNonterminal(unittest.TestCase):

def test_specialized_dictation_elements_choose_expected_nonterminal(self):
cases = (
(Dictation("text"), '#nonterm:dictation'),
(AlternativeDictation("text"), '#nonterm:dictation_cloud'),
(DefaultDictation("text"), '#nonterm:dictation'),
(AlternativeDictation("text", alternative=False), '#nonterm:dictation'),
(DefaultDictation("text", alternative=True), '#nonterm:dictation_cloud'),
)

for element, expected in cases:
with self.subTest(element=repr(element)):
self.assertEqual(expected, _get_dictation_nonterminal(element))