Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chat.messages(format='internal') #1532

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions shiny/ui/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ async def _raise_exception(
def messages(
self,
*,
format: Literal["anthropic"] = "anthropic",
format: Literal["anthropic"],
token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand All @@ -357,7 +357,7 @@ def messages(
def messages(
self,
*,
format: Literal["google"] = "google",
format: Literal["google"],
token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand All @@ -367,7 +367,7 @@ def messages(
def messages(
self,
*,
format: Literal["langchain"] = "langchain",
format: Literal["langchain"],
token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand All @@ -377,7 +377,7 @@ def messages(
def messages(
self,
*,
format: Literal["openai"] = "openai",
format: Literal["openai"],
token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand All @@ -387,7 +387,7 @@ def messages(
def messages(
self,
*,
format: Literal["ollama"] = "ollama",
format: Literal["ollama"],
token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand All @@ -397,7 +397,7 @@ def messages(
def messages(
self,
*,
format: MISSING_TYPE = MISSING,
format: Literal["internal"],
token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand All @@ -406,7 +406,7 @@ def messages(
def messages(
self,
*,
format: MISSING_TYPE | ProviderMessageFormat = MISSING,
format: ProviderMessageFormat | Literal["internal"],
Copy link
Collaborator

@cpsievert cpsievert Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, this format is actually the least likely format that a user would want -- it is used internally in Shiny, but is not directly usable by any LLMs out there.

That's not actually true. This format can be passed directly to OpenAI, LangChain, Anthropic, Ollama, and probably others without modification. It won't pass a type checker, but that doesn't seem like reason enough to force everyone to have to specify a format.

Suggested change
format: ProviderMessageFormat | Literal["internal"],
format: ProviderMessageFormat | Literal["internal"] = "internal",

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, I think I misunderstood what the default format was -- I thought it was the actual internal format that the Chat class uses, but that's not right. The internal format is StoredMessage, not ChatMessage. That's why I thought it wouldn't make sense to return that format by default.

So to summarize, prior to this PR, the default format was ChatMessage, which is sort of a lowest-common-denominator format that has role and content.

So it might make sense to just drop this PR entirely. That said, it could still be useful to be able to access the StoredMessages in the chat object. That's something I actually wanted to do when I was trying to get a handle on the sizes of messages being sent to the LLM (but then I got sidetracked before I went in depth on it).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I'd be interested to know more about why you want access to StoredMessage.

I'll close this PR and create another one for the format="anthropic" thing

token_limits: tuple[int, int] | None = (4096, 1000),
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
Expand Down Expand Up @@ -466,7 +466,7 @@ def messages(

messages = self._messages()
if token_limits is not None:
messages = self._trim_messages(messages, token_limits)
messages = self._trim_messages(messages, token_limits, format)

res: list[ChatMessage | ProviderMessage] = []
for i, m in enumerate(messages):
Expand All @@ -479,7 +479,7 @@ def messages(
)
content_key = m["transform_key" if transform else "pre_transform_key"]
chat_msg = ChatMessage(content=m[content_key], role=m["role"])
if not isinstance(format, MISSING_TYPE):
if format != "internal":
chat_msg = as_provider_message(chat_msg, format)
res.append(chat_msg)

Expand Down Expand Up @@ -827,7 +827,8 @@ def _store_message(
@staticmethod
def _trim_messages(
messages: tuple[StoredMessage, ...],
token_limits: tuple[int, int] = (4096, 1000),
token_limits: tuple[int, int],
format: ProviderMessageFormat | Literal["internal"],
) -> tuple[StoredMessage, ...]:

n_total, n_reserve = token_limits
Expand Down Expand Up @@ -872,6 +873,11 @@ def _trim_messages(
if remaining_non_system_tokens >= 0:
messages2.append(m)

if format == "anthropic":
# For anthropic, the first message must be a user message.
while messages2[-1]["role"] != "user":
messages2.pop()
Comment on lines +878 to +879
Copy link
Collaborator

@cpsievert cpsievert Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's technically possible that messages2 is of length 0 here, so we'll need to be more careful about that.

from shiny import reactive
from shiny.express import ui

chat = ui.Chat(id="chat")
chat.ui()

@reactive.effect
def _():
    print(chat.messages(format="anthropic"))


messages2.reverse()

if len(messages2) == n_system_messages and n_other_messages > 0:
Expand Down
11 changes: 5 additions & 6 deletions tests/pytest/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_chat_message_trimming():

# Throws since system message is too long
with pytest.raises(ValueError):
chat._trim_messages(msgs, token_limits=(100, 0))
chat._trim_messages(msgs, token_limits=(100, 0), format="internal")

msgs = (
as_stored_message(
Expand All @@ -79,10 +79,9 @@ def test_chat_message_trimming():

# Throws since only the system message fits
with pytest.raises(ValueError):
chat._trim_messages(msgs, token_limits=(100, 0))

chat._trim_messages(msgs, token_limits=(100, 0), format="internal")
# Raising the limit should allow both messages to fit
trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format="internal")
assert len(trimmed) == 2
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "User message"]
Expand All @@ -100,7 +99,7 @@ def test_chat_message_trimming():
)

# Should discard the 1st user message
trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format="internal")
assert len(trimmed) == 2
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "User message 2"]
Expand All @@ -121,7 +120,7 @@ def test_chat_message_trimming():
)

# Should discard the 1st user message
trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format="internal")
assert len(trimmed) == 3
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "System message 2", "User message 2"]
Expand Down
Loading