Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assistants: fallback type in discriminated unions #1615

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
extract_type_arg,
is_annotated_type,
strip_annotated_type,
wrap_in_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
Expand Down Expand Up @@ -356,7 +357,10 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:
return field_get_default(field)

if PYDANTIC_V2:
type_ = field.annotation
if field.metadata:
type_ = wrap_in_annotated_type(field)
else:
type_ = field.annotation
Comment on lines -359 to +363
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the case where the response data for a union field doesn't conform to expected (i.e. we fail the validate_type call), we seem to lose the metadata (which has the discriminator info) when constructing the field. This makes the code construct the first type present in the union instead of creating the actual known type. Test case illustrating this here

else:
type_ = cast(type, field.outer_type_) # type: ignore

Expand Down Expand Up @@ -609,8 +613,13 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias

if field_info.annotation and is_literal_type(field_info.annotation):
for entry in get_args(field_info.annotation):
if hasattr(field_info, "annotation"):
field_annotation = cast(type, field_info.annotation)
else:
# pydantic==1.9
field_annotation = cast(type, field_info.outer_type_) # type: ignore
if field_annotation and is_literal_type(field_annotation):
for entry in get_args(field_annotation):
if isinstance(entry, str):
mapping[entry] = variant

Expand Down
1 change: 1 addition & 0 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
is_required_type as is_required_type,
is_annotated_type as is_annotated_type,
strip_annotated_type as strip_annotated_type,
wrap_in_annotated_type as wrap_in_annotated_type,
extract_type_var_from_base as extract_type_var_from_base,
)
from ._streams import consume_sync_iterator as consume_sync_iterator, consume_async_iterator as consume_async_iterator
Expand Down
6 changes: 6 additions & 0 deletions src/openai/_utils/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from collections import abc as _c_abc
from typing_extensions import Required, Annotated, get_args, get_origin

from pydantic.fields import FieldInfo

from .._types import InheritsGeneric
from .._compat import is_union as _is_union

Expand Down Expand Up @@ -44,6 +46,10 @@ def strip_annotated_type(typ: type) -> type:
return typ


def wrap_in_annotated_type(typ: FieldInfo) -> object:
return Annotated[cast(type, typ.annotation), typ.metadata[0]]


def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/openai/types/beta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .assistant import Assistant as Assistant
from .vector_store import VectorStore as VectorStore
from .function_tool import FunctionTool as FunctionTool
from .assistant_tool import AssistantTool as AssistantTool
from .assistant_tool import BaseTool as BaseTool, AssistantTool as AssistantTool
from .thread_deleted import ThreadDeleted as ThreadDeleted
from .file_search_tool import FileSearchTool as FileSearchTool
from .assistant_deleted import AssistantDeleted as AssistantDeleted
Expand Down
24 changes: 21 additions & 3 deletions src/openai/types/beta/assistant_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ..._utils import PropertyInfo
from ..._compat import PYDANTIC_V2
from ..._models import BaseModel
from .function_tool import FunctionTool
from .file_search_tool import FileSearchTool
from .code_interpreter_tool import CodeInterpreterTool

__all__ = ["AssistantTool"]
if PYDANTIC_V2:
from pydantic import field_serializer


__all__ = ["AssistantTool", "BaseTool"]


class BaseTool(BaseModel):
type: Literal["unknown"]
"""A tool type"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


AssistantTool: TypeAlias = Annotated[
Union[CodeInterpreterTool, FileSearchTool, FunctionTool], PropertyInfo(discriminator="type")
Union[BaseTool, CodeInterpreterTool, FileSearchTool, FunctionTool], PropertyInfo(discriminator="type")
]
8 changes: 4 additions & 4 deletions src/openai/types/beta/threads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
from .text import Text as Text
from .message import Message as Message
from .image_url import ImageURL as ImageURL
from .annotation import Annotation as Annotation
from .annotation import Annotation as Annotation, BaseAnnotation as BaseAnnotation
from .image_file import ImageFile as ImageFile
from .run_status import RunStatus as RunStatus
from .text_delta import TextDelta as TextDelta
from .message_delta import MessageDelta as MessageDelta
from .image_url_delta import ImageURLDelta as ImageURLDelta
from .image_url_param import ImageURLParam as ImageURLParam
from .message_content import MessageContent as MessageContent
from .message_content import MessageContent as MessageContent, BaseContentBlock as BaseContentBlock
from .message_deleted import MessageDeleted as MessageDeleted
from .run_list_params import RunListParams as RunListParams
from .annotation_delta import AnnotationDelta as AnnotationDelta
from .annotation_delta import AnnotationDelta as AnnotationDelta, BaseDeltaAnnotation as BaseDeltaAnnotation
from .image_file_delta import ImageFileDelta as ImageFileDelta
from .image_file_param import ImageFileParam as ImageFileParam
from .text_delta_block import TextDeltaBlock as TextDeltaBlock
Expand All @@ -28,7 +28,7 @@
from .refusal_delta_block import RefusalDeltaBlock as RefusalDeltaBlock
from .file_path_annotation import FilePathAnnotation as FilePathAnnotation
from .image_url_delta_block import ImageURLDeltaBlock as ImageURLDeltaBlock
from .message_content_delta import MessageContentDelta as MessageContentDelta
from .message_content_delta import BaseDeltaBlock as BaseDeltaBlock, MessageContentDelta as MessageContentDelta
from .message_create_params import MessageCreateParams as MessageCreateParams
from .message_update_params import MessageUpdateParams as MessageUpdateParams
from .refusal_content_block import RefusalContentBlock as RefusalContentBlock
Expand Down
28 changes: 25 additions & 3 deletions src/openai/types/beta/threads/annotation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ...._utils import PropertyInfo
from ...._compat import PYDANTIC_V2
from ...._models import BaseModel
from .file_path_annotation import FilePathAnnotation
from .file_citation_annotation import FileCitationAnnotation

__all__ = ["Annotation"]
if PYDANTIC_V2:
from pydantic import field_serializer

Annotation: TypeAlias = Annotated[Union[FileCitationAnnotation, FilePathAnnotation], PropertyInfo(discriminator="type")]
__all__ = ["Annotation", "BaseAnnotation"]


class BaseAnnotation(BaseModel):
text: str
"""The index of the annotation in the text content part."""

type: Literal["unknown"]
"""The type of annotation"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


Annotation: TypeAlias = Annotated[
Union[BaseAnnotation, FileCitationAnnotation, FilePathAnnotation], PropertyInfo(discriminator="type")
]
27 changes: 24 additions & 3 deletions src/openai/types/beta/threads/annotation_delta.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ...._utils import PropertyInfo
from ...._compat import PYDANTIC_V2
from ...._models import BaseModel
from .file_path_delta_annotation import FilePathDeltaAnnotation
from .file_citation_delta_annotation import FileCitationDeltaAnnotation

__all__ = ["AnnotationDelta"]
if PYDANTIC_V2:
from pydantic import field_serializer


__all__ = ["AnnotationDelta", "BaseDeltaAnnotation"]


class BaseDeltaAnnotation(BaseModel):
index: int
"""The index of the annotation in the text content part."""

type: Literal["unknown"]
"""The type of annotation"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


AnnotationDelta: TypeAlias = Annotated[
Union[FileCitationDeltaAnnotation, FilePathDeltaAnnotation], PropertyInfo(discriminator="type")
Union[BaseDeltaAnnotation, FileCitationDeltaAnnotation, FilePathDeltaAnnotation], PropertyInfo(discriminator="type")
]
23 changes: 20 additions & 3 deletions src/openai/types/beta/threads/message_content.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ...._utils import PropertyInfo
from ...._compat import PYDANTIC_V2
from ...._models import BaseModel
from .text_content_block import TextContentBlock
from .refusal_content_block import RefusalContentBlock
from .image_url_content_block import ImageURLContentBlock
from .image_file_content_block import ImageFileContentBlock

__all__ = ["MessageContent"]
if PYDANTIC_V2:
from pydantic import field_serializer


__all__ = ["MessageContent", "BaseContentBlock"]


class BaseContentBlock(BaseModel):
type: Literal["unknown"]
"""The type of content part"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


MessageContent: TypeAlias = Annotated[
Union[ImageFileContentBlock, ImageURLContentBlock, TextContentBlock, RefusalContentBlock],
Union[BaseContentBlock, ImageFileContentBlock, ImageURLContentBlock, TextContentBlock, RefusalContentBlock],
PropertyInfo(discriminator="type"),
]
27 changes: 24 additions & 3 deletions src/openai/types/beta/threads/message_content_delta.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ...._utils import PropertyInfo
from ...._compat import PYDANTIC_V2
from ...._models import BaseModel
from .text_delta_block import TextDeltaBlock
from .refusal_delta_block import RefusalDeltaBlock
from .image_url_delta_block import ImageURLDeltaBlock
from .image_file_delta_block import ImageFileDeltaBlock

__all__ = ["MessageContentDelta"]
if PYDANTIC_V2:
from pydantic import field_serializer


__all__ = ["MessageContentDelta", "BaseDeltaBlock"]


class BaseDeltaBlock(BaseModel):
index: int
"""The index of the content part in the message."""

type: Literal["unknown"]
"""The type of content part"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


MessageContentDelta: TypeAlias = Annotated[
Union[ImageFileDeltaBlock, TextDeltaBlock, RefusalDeltaBlock, ImageURLDeltaBlock],
Union[BaseDeltaBlock, ImageFileDeltaBlock, TextDeltaBlock, RefusalDeltaBlock, ImageURLDeltaBlock],
PropertyInfo(discriminator="type"),
]
4 changes: 2 additions & 2 deletions src/openai/types/beta/threads/runs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

from .run_step import RunStep as RunStep
from .tool_call import ToolCall as ToolCall
from .tool_call import ToolCall as ToolCall, BaseToolCall as BaseToolCall
from .run_step_delta import RunStepDelta as RunStepDelta
from .tool_call_delta import ToolCallDelta as ToolCallDelta
from .tool_call_delta import ToolCallDelta as ToolCallDelta, BaseToolCallDelta as BaseToolCallDelta
from .step_list_params import StepListParams as StepListParams
from .function_tool_call import FunctionToolCall as FunctionToolCall
from .run_step_delta_event import RunStepDeltaEvent as RunStepDeltaEvent
Expand Down
29 changes: 26 additions & 3 deletions src/openai/types/beta/threads/runs/tool_call.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ....._utils import PropertyInfo
from ....._compat import PYDANTIC_V2
from ....._models import BaseModel
from .function_tool_call import FunctionToolCall
from .file_search_tool_call import FileSearchToolCall
from .code_interpreter_tool_call import CodeInterpreterToolCall

__all__ = ["ToolCall"]
if PYDANTIC_V2:
from pydantic import field_serializer


__all__ = ["ToolCall", "BaseToolCall"]


class BaseToolCall(BaseModel):
id: str
"""The ID of the tool call."""

type: Literal["unknown"]
"""The type of tool call.
"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


ToolCall: TypeAlias = Annotated[
Union[CodeInterpreterToolCall, FileSearchToolCall, FunctionToolCall], PropertyInfo(discriminator="type")
Union[BaseToolCall, CodeInterpreterToolCall, FileSearchToolCall, FunctionToolCall],
PropertyInfo(discriminator="type"),
]
28 changes: 25 additions & 3 deletions src/openai/types/beta/threads/runs/tool_call_delta.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,38 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Union
from typing_extensions import Annotated, TypeAlias
from typing_extensions import Literal, Annotated, TypeAlias

from ....._utils import PropertyInfo
from ....._compat import PYDANTIC_V2
from ....._models import BaseModel
from .function_tool_call_delta import FunctionToolCallDelta
from .file_search_tool_call_delta import FileSearchToolCallDelta
from .code_interpreter_tool_call_delta import CodeInterpreterToolCallDelta

__all__ = ["ToolCallDelta"]
if PYDANTIC_V2:
from pydantic import field_serializer


__all__ = ["ToolCallDelta", "BaseToolCallDelta"]


class BaseToolCallDelta(BaseModel):
index: int
"""The index of the tool call in the tool calls array."""

type: Literal["unknown"]
"""The type of tool call.
"""

if PYDANTIC_V2:

@field_serializer("type", when_used="always") # type: ignore
def serialize_unknown_type(self, type_: str) -> str:
return type_


ToolCallDelta: TypeAlias = Annotated[
Union[CodeInterpreterToolCallDelta, FileSearchToolCallDelta, FunctionToolCallDelta],
Union[BaseToolCallDelta, CodeInterpreterToolCallDelta, FileSearchToolCallDelta, FunctionToolCallDelta],
PropertyInfo(discriminator="type"),
]
Loading
Loading