diff --git a/ovos_workshop/skills/common_query_skill.py b/ovos_workshop/skills/common_query_skill.py index 024f261..80f2e3e 100644 --- a/ovos_workshop/skills/common_query_skill.py +++ b/ovos_workshop/skills/common_query_skill.py @@ -13,7 +13,7 @@ from abc import abstractmethod from enum import IntEnum from os.path import dirname -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union from ovos_bus_client import Message from ovos_utils.file_utils import resolve_resource_file @@ -150,7 +150,13 @@ def __handle_question_query(self, message: Message): level = result[1] answer = result[2] callback = result[3] if len(result) > 3 else {} - confidence = self.calc_confidence(match, search_phrase, level, answer) + if isinstance(level, float): + LOG.debug(f"Confidence directly reported by skill") + confidence = level + else: + LOG.info(f"Calculating confidence for level {level}") + confidence = self.__calc_confidence(match, search_phrase, level, + answer) if confidence > 1.0: LOG.warning(f"Calculated confidence {confidence} > 1.0") confidence = 1.0 @@ -167,8 +173,8 @@ def __handle_question_query(self, message: Message): "skill_id": self.skill_id, "searching": False})) - def __get_cq(self, search_phrase: str) -> (str, CQSMatchLevel, str, - Optional[dict]): + def __get_cq(self, search_phrase: str) -> (str, Union[CQSMatchLevel, float], + str, Optional[dict]): """ Invoke the CQS handler to let the skill perform its search @param search_phrase: parsed question to get an answer for @@ -198,11 +204,10 @@ def remove_noise(self, phrase: str, lang: str = None) -> str: phrase = ' '.join(phrase.split()) return phrase.strip() - def calc_confidence(self, match: str, phrase: str, level: CQSMatchLevel, - answer: str) -> float: + def __calc_confidence(self, match: str, phrase: str, level: CQSMatchLevel, + answer: str) -> float: """ - Calculate a confidence level for the skill response. Skills may override - this method to implement custom confidence calculation + Calculate a confidence level for the skill response. @param match: Matched portion of the input phrase @param phrase: User input phrase that was evaluated @param level: Skill-determined match level of the answer @@ -298,7 +303,7 @@ def __handle_query_action(self, message: Message): @abstractmethod def CQS_match_query_phrase(self, phrase: str) -> \ - Optional[Tuple[str, CQSMatchLevel, Optional[dict]]]: + Optional[Tuple[str, Union[CQSMatchLevel, float], Optional[dict]]]: """ Determine an answer to the input phrase and return match information, or `None` if no answer can be determined. diff --git a/test/unittests/skills/test_common_query_skill.py b/test/unittests/skills/test_common_query_skill.py index d5ab717..a9eff41 100644 --- a/test/unittests/skills/test_common_query_skill.py +++ b/test/unittests/skills/test_common_query_skill.py @@ -41,7 +41,7 @@ def test_get_cq(self): mock_return = Mock() self.skill.CQS_match_query_phrase = Mock(return_value=mock_return) result = self.skill._CommonQuerySkill__get_cq(test_phrase) - self.skill.CQS_match_query_phrase .assert_called_once_with(test_phrase) + self.skill.CQS_match_query_phrase.assert_called_once_with(test_phrase) self.assertEqual(result, mock_return) self.skill.CQS_match_query_phrase.side_effect = Exception() @@ -60,13 +60,12 @@ def test_calc_confidence(self): cw_answer = ("The drink diet coke has 32 milligrams of caffeine in " "250 milliliters. Provided by CaffeineWiz.") - generic_conf = self.skill.calc_confidence("coca cola", generic_q, - CQSMatchLevel.GENERAL, - cw_answer) - exact_conf = self.skill.calc_confidence("coca cola", specific_q, - CQSMatchLevel.EXACT, cw_answer) - low_conf = self.skill.calc_confidence("coca cola", specific_q_2, - CQSMatchLevel.GENERAL, cw_answer) + generic_conf = self.skill._CommonQuerySkill__calc_confidence( + "coca cola", generic_q, CQSMatchLevel.GENERAL, cw_answer) + exact_conf = self.skill._CommonQuerySkill__calc_confidence( + "coca cola", specific_q, CQSMatchLevel.EXACT, cw_answer) + low_conf = self.skill._CommonQuerySkill__calc_confidence( + "coca cola", specific_q_2, CQSMatchLevel.GENERAL, cw_answer) self.assertEqual(exact_conf, 1.0) self.assertLess(generic_conf, exact_conf)