Skip to content

Commit

Permalink
Ensure first message is user message for anthropic (#1530)
Browse files Browse the repository at this point in the history
Co-authored-by: Carson <cpsievert1@gmail.com>
  • Loading branch information
wch and cpsievert committed Jul 16, 2024
1 parent c3758b6 commit 00ff063
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
20 changes: 17 additions & 3 deletions shiny/ui/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def messages(

messages = self._messages()
if token_limits is not None:
messages = self._trim_messages(messages, token_limits)
messages = self._trim_messages(messages, token_limits, format)

res: list[ChatMessage | ProviderMessage] = []
for i, m in enumerate(messages):
Expand Down Expand Up @@ -827,7 +827,8 @@ def _store_message(
@staticmethod
def _trim_messages(
messages: tuple[StoredMessage, ...],
token_limits: tuple[int, int] = (4096, 1000),
token_limits: tuple[int, int],
format: MISSING_TYPE | ProviderMessageFormat,
) -> tuple[StoredMessage, ...]:

n_total, n_reserve = token_limits
Expand Down Expand Up @@ -863,6 +864,7 @@ def _trim_messages(
)

messages2: list[StoredMessage] = []
n_other_messages2: int = 0
for m in reversed(messages):
if m["role"] == "system":
messages2.append(m)
Expand All @@ -871,10 +873,22 @@ def _trim_messages(
remaining_non_system_tokens -= count
if remaining_non_system_tokens >= 0:
messages2.append(m)
n_other_messages2 += 1

# Anthropic doesn't support `role: system` and requires a user message to come 1st
if format == "anthropic":
if n_system_messages > 0:
raise ValueError(
"Anthropic requires a system prompt to be specified in it's `.create()` method "
"(not in the chat messages with `role: system`)."
)
while n_other_messages2 > 0 and messages2[-1]["role"] != "user":
messages2.pop()
n_other_messages2 -= 1

messages2.reverse()

if len(messages2) == n_system_messages and n_other_messages > 0:
if len(messages2) == n_system_messages and n_other_messages2 > 0:
raise ValueError(
f"Only system messages fit within `.messages(token_limits={token_limits})`. "
"Consider increasing the 1st value of `token_limit` or setting it to "
Expand Down
26 changes: 21 additions & 5 deletions tests/pytest/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from shiny import Session
from shiny._namespaces import ResolvedId, Root
from shiny.session import session_context
from shiny.types import MISSING
from shiny.ui import Chat
from shiny.ui._chat import as_transformed_message
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_chat_message_trimming():

# Throws since system message is too long
with pytest.raises(ValueError):
chat._trim_messages(msgs, token_limits=(100, 0))
chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)

msgs = (
as_stored_message(
Expand All @@ -79,10 +80,10 @@ def test_chat_message_trimming():

# Throws since only the system message fits
with pytest.raises(ValueError):
chat._trim_messages(msgs, token_limits=(100, 0))
chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)

# Raising the limit should allow both messages to fit
trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING)
assert len(trimmed) == 2
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "User message"]
Expand All @@ -100,7 +101,7 @@ def test_chat_message_trimming():
)

# Should discard the 1st user message
trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING)
assert len(trimmed) == 2
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "User message 2"]
Expand All @@ -121,11 +122,26 @@ def test_chat_message_trimming():
)

# Should discard the 1st user message
trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING)
assert len(trimmed) == 3
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "System message 2", "User message 2"]

msgs = (
as_stored_message(
{"content": "Assistant message", "role": "assistant"}, token_count=50
),
as_stored_message(
{"content": "User message", "role": "user"}, token_count=10
),
)

# Anthropic requires 1st message to be a user message
trimmed = chat._trim_messages(msgs, token_limits=(30, 0), format="anthropic")
assert len(trimmed) == 1
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["User message"]


# ------------------------------------------------------------------------------------
# Unit tests for normalize_message() and normalize_message_chunk().
Expand Down

0 comments on commit 00ff063

Please sign in to comment.