Skip to content

Commit

Permalink
fixed annotation parsing for py3.9-
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Feb 22, 2024
1 parent 07f5e02 commit 3a723a1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions python/semantic_kernel/functions/kernel_function_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ def decorator(func: Callable):
func.__kernel_function_description__ = description or func.__doc__
func.__kernel_function_name__ = name or func.__name__
func.__kernel_function_streaming__ = isasyncgenfunction(func) or isgeneratorfunction(func)
logger.debug(f"Parsing decorator for function: {func.__kernel_function_name__}")

func_sig = signature(func)
logger.debug(f"{func_sig=}")
func.__kernel_function_context_parameters__ = [
_parse_parameter(param) for param in func_sig.parameters.values() if param.name != "self"
]
Expand All @@ -65,11 +67,13 @@ def decorator(func: Callable):


def _parse_parameter(param: Parameter):
logger.debug(f"Parsing param: {param}")
param_description = ""
type_ = "str"
required = True
if param != Parameter.empty:
param_description, type_, required = _parse_annotation(param.annotation)
logger.debug(f"{param_description=}, {type_=}, {required=}")
return {
"name": param.name,
"description": param_description,
Expand All @@ -80,17 +84,18 @@ def _parse_parameter(param: Parameter):


def _parse_annotation(annotation: Parameter) -> Tuple[str, str, bool]:
logger.debug(f"Parsing annotation: {annotation}")
if isinstance(annotation, str):
return "", annotation, True
logger.debug(f"{annotation=}")
description = ""
if getattr(annotation, "__name__", None) == "Annotated":
if hasattr(annotation, "__metadata__") and annotation.__metadata__:
description = annotation.__metadata__[0]
return (description, *_parse_internal_annotation(annotation, True))


def _parse_internal_annotation(annotation: Parameter, required: bool) -> Tuple[str, bool]:
logger.debug(f"{annotation=}")
logger.debug(f"Internal {annotation=}")
if hasattr(annotation, "__forward_arg__"):
return annotation.__forward_arg__, required
if getattr(annotation, "__name__", None) == "Optional":
Expand Down
2 changes: 1 addition & 1 deletion python/tests/unit/kernel/test_kernel_service_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_get_service_with_multiple_types(kernel_with_service: Kernel):
assert service_get == kernel_with_service.services["service"]


@pytest.mark.skipif(sys.version_info < (3, 10))
@pytest.mark.skipif(sys.version_info < (3, 10), reason="This is valid syntax only in python 3.10+.")
def test_get_service_with_multiple_types_union(kernel_with_service: Kernel):
"""This is valid syntax only in python 3.10+. It is skipped for older versions."""
service_get = kernel_with_service.get_service("service", type=Union[AIServiceClientBase, ChatCompletionClientBase])
Expand Down

0 comments on commit 3a723a1

Please sign in to comment.