-
Notifications
You must be signed in to change notification settings - Fork 3.1k
/
open_ai_assistant_channel.py
97 lines (74 loc) · 3.51 KB
/
open_ai_assistant_channel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import AsyncIterable
from typing import TYPE_CHECKING, Any
from semantic_kernel.contents.function_call_content import FunctionCallContent
if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover
from openai import AsyncOpenAI
from semantic_kernel.agents.channels.agent_channel import AgentChannel
from semantic_kernel.agents.open_ai.assistant_content_generation import create_chat_message, generate_message_content
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.exceptions.agent_exceptions import AgentChatException
if TYPE_CHECKING:
from semantic_kernel.agents.agent import Agent
class OpenAIAssistantChannel(AgentChannel):
    """OpenAI Assistant Channel.

    An ``AgentChannel`` backed by a single OpenAI assistants-API thread: received
    messages are posted to that thread, agents are invoked against it, and the
    channel history is read back from it.
    """

    def __init__(self, client: AsyncOpenAI, thread_id: str) -> None:
        """Initialize the OpenAI Assistant Channel.

        Args:
            client: The OpenAI client used for every thread and assistant call.
            thread_id: The id of the thread this channel reads from and writes to.
        """
        self.client = client
        self.thread_id = thread_id

    @override
    async def receive(self, history: list["ChatMessageContent"]) -> None:
        """Receive the conversation messages.

        Each message is posted to the thread via ``create_chat_message``; messages
        containing any ``FunctionCallContent`` item are skipped entirely.

        Args:
            history: The conversation messages.
        """
        for message in history:
            # Messages carrying function calls are not forwarded to the thread.
            if any(isinstance(item, FunctionCallContent) for item in message.items):
                continue
            await create_chat_message(self.client, self.thread_id, message)

    @override
    async def invoke(self, agent: "Agent") -> AsyncIterable[tuple[bool, "ChatMessageContent"]]:
        """Invoke the agent against this channel's thread.

        Args:
            agent: The agent to invoke. Must be a non-deleted ``OpenAIAssistantBase``.

        Yields:
            tuple[bool, ChatMessageContent]: ``(is_visible, message)`` pairs as produced
            by the agent's ``_invoke_internal``.

        Raises:
            AgentChatException: If ``agent`` is not an ``OpenAIAssistantBase`` or has
                been deleted.
        """
        # Local import; presumably avoids a circular import with the assistant base module.
        from semantic_kernel.agents.open_ai.open_ai_assistant_base import OpenAIAssistantBase

        if not isinstance(agent, OpenAIAssistantBase):
            # Name the expected class itself; type(OpenAIAssistantBase) would be its metaclass.
            raise AgentChatException(f"Agent is not of the expected type {OpenAIAssistantBase.__name__}.")
        if agent._is_deleted:
            raise AgentChatException("Agent is deleted.")

        async for is_visible, message in agent._invoke_internal(thread_id=self.thread_id):
            yield is_visible, message

    @override
    async def get_history(self) -> AsyncIterable["ChatMessageContent"]:
        """Get the conversation history.

        Messages are read newest-first (``order="desc"``) and converted with
        ``generate_message_content``; messages that convert to content with no items
        are skipped. Each assistant referenced by a message is retrieved at most once
        per call, and its name (or its id, when it has no name) is used as the
        message author name.

        Yields:
            ChatMessageContent: The conversation history.
        """
        # assistant_id -> resolved display name; caches unnamed assistants too, so
        # they are not re-retrieved for every message they authored.
        agent_names: dict[str, Any] = {}

        # NOTE(review): only the first page (up to 100 messages) is read; older messages
        # are not paginated in. Confirm this truncation is intended.
        thread_messages = await self.client.beta.threads.messages.list(
            thread_id=self.thread_id, limit=100, order="desc"
        )
        for message in thread_messages.data:
            assistant_id = message.assistant_id
            assistant_name = assistant_id
            if assistant_id:
                if assistant_id not in agent_names:
                    agent = await self.client.beta.assistants.retrieve(assistant_id)
                    # Fall back to the id when the assistant has no name.
                    agent_names[assistant_id] = agent.name or assistant_id
                assistant_name = agent_names[assistant_id]

            # NOTE(review): when the message has no assistant id (presumably user messages)
            # this passes the string "None"; confirm whether generate_message_content
            # should receive None instead.
            content: ChatMessageContent = generate_message_content(str(assistant_name), message)
            if content.items:
                yield content

    @override
    async def reset(self) -> None:
        """Reset the channel by deleting its thread.

        Raises:
            AgentChatException: If deleting the thread fails; the original error is
                chained as the cause.
        """
        try:
            await self.client.beta.threads.delete(thread_id=self.thread_id)
        except Exception as e:
            raise AgentChatException(f"Failed to delete thread: {e}") from e