Skip to content

Commit

Permalink
Revert calc_confidence public method change
Browse files Browse the repository at this point in the history
Support returned float confidence level in addition to enum
  • Loading branch information
Daniel McKnight committed Sep 17, 2024
1 parent 7d30340 commit 588c727
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
23 changes: 14 additions & 9 deletions ovos_workshop/skills/common_query_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from abc import abstractmethod
from enum import IntEnum
from os.path import dirname
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

from ovos_bus_client import Message
from ovos_utils.file_utils import resolve_resource_file
Expand Down Expand Up @@ -150,7 +150,13 @@ def __handle_question_query(self, message: Message):
level = result[1]
answer = result[2]
callback = result[3] if len(result) > 3 else {}
confidence = self.calc_confidence(match, search_phrase, level, answer)
if isinstance(level, float):
LOG.debug(f"Confidence directly reported by skill")
confidence = level
else:
LOG.info(f"Calculating confidence for level {level}")
confidence = self.__calc_confidence(match, search_phrase, level,
answer)
if confidence > 1.0:
LOG.warning(f"Calculated confidence {confidence} > 1.0")
confidence = 1.0
Expand All @@ -167,8 +173,8 @@ def __handle_question_query(self, message: Message):
"skill_id": self.skill_id,
"searching": False}))

def __get_cq(self, search_phrase: str) -> (str, CQSMatchLevel, str,
Optional[dict]):
def __get_cq(self, search_phrase: str) -> (str, Union[CQSMatchLevel, float],
str, Optional[dict]):
"""
Invoke the CQS handler to let the skill perform its search
@param search_phrase: parsed question to get an answer for
Expand Down Expand Up @@ -198,11 +204,10 @@ def remove_noise(self, phrase: str, lang: str = None) -> str:
phrase = ' '.join(phrase.split())
return phrase.strip()

def calc_confidence(self, match: str, phrase: str, level: CQSMatchLevel,
answer: str) -> float:
def __calc_confidence(self, match: str, phrase: str, level: CQSMatchLevel,
answer: str) -> float:
"""
Calculate a confidence level for the skill response. Skills may override
this method to implement custom confidence calculation
Calculate a confidence level for the skill response.
@param match: Matched portion of the input phrase
@param phrase: User input phrase that was evaluated
@param level: Skill-determined match level of the answer
Expand Down Expand Up @@ -298,7 +303,7 @@ def __handle_query_action(self, message: Message):

@abstractmethod
def CQS_match_query_phrase(self, phrase: str) -> \
Optional[Tuple[str, CQSMatchLevel, Optional[dict]]]:
Optional[Tuple[str, Union[CQSMatchLevel, float], Optional[dict]]]:
"""
Determine an answer to the input phrase and return match information, or
`None` if no answer can be determined.
Expand Down
15 changes: 7 additions & 8 deletions test/unittests/skills/test_common_query_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_get_cq(self):
mock_return = Mock()
self.skill.CQS_match_query_phrase = Mock(return_value=mock_return)
result = self.skill._CommonQuerySkill__get_cq(test_phrase)
self.skill.CQS_match_query_phrase .assert_called_once_with(test_phrase)
self.skill.CQS_match_query_phrase.assert_called_once_with(test_phrase)
self.assertEqual(result, mock_return)

self.skill.CQS_match_query_phrase.side_effect = Exception()
Expand All @@ -60,13 +60,12 @@ def test_calc_confidence(self):
cw_answer = ("The drink diet coke has 32 milligrams of caffeine in "
"250 milliliters.</speak> Provided by CaffeineWiz.")

generic_conf = self.skill.calc_confidence("coca cola", generic_q,
CQSMatchLevel.GENERAL,
cw_answer)
exact_conf = self.skill.calc_confidence("coca cola", specific_q,
CQSMatchLevel.EXACT, cw_answer)
low_conf = self.skill.calc_confidence("coca cola", specific_q_2,
CQSMatchLevel.GENERAL, cw_answer)
generic_conf = self.skill._CommonQuerySkill__calc_confidence(
"coca cola", generic_q, CQSMatchLevel.GENERAL, cw_answer)
exact_conf = self.skill._CommonQuerySkill__calc_confidence(
"coca cola", specific_q, CQSMatchLevel.EXACT, cw_answer)
low_conf = self.skill._CommonQuerySkill__calc_confidence(
"coca cola", specific_q_2, CQSMatchLevel.GENERAL, cw_answer)

self.assertEqual(exact_conf, 1.0)
self.assertLess(generic_conf, exact_conf)
Expand Down

0 comments on commit 588c727

Please sign in to comment.