Skip to content

Commit

Permalink
Ensure first message is user message for anthropic
Browse files: browse the repository at this point in the history
  • Loading branch information
wch committed Jul 15, 2024
1 parent 6611277 commit b244bb7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
10 changes: 8 additions & 2 deletions shiny/ui/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def messages(

messages = self._messages()
if token_limits is not None:
- messages = self._trim_messages(messages, token_limits)
+ messages = self._trim_messages(messages, token_limits, format)

res: list[ChatMessage | ProviderMessage] = []
for i, m in enumerate(messages):
Expand Down Expand Up @@ -827,7 +827,8 @@ def _store_message(
@staticmethod
def _trim_messages(
messages: tuple[StoredMessage, ...],
- token_limits: tuple[int, int] = (4096, 1000),
+ token_limits: tuple[int, int],
+ format: MISSING_TYPE | ProviderMessageFormat,
) -> tuple[StoredMessage, ...]:

n_total, n_reserve = token_limits
Expand Down Expand Up @@ -872,6 +873,11 @@ def _trim_messages(
if remaining_non_system_tokens >= 0:
messages2.append(m)

if format == "anthropic":
# For anthropic, the first message must be a user message.
while messages2[-1]["role"] != "user":
messages2.pop()

messages2.reverse()

if len(messages2) == n_system_messages and n_other_messages > 0:
Expand Down
11 changes: 6 additions & 5 deletions tests/pytest/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from shiny import Session
from shiny._namespaces import ResolvedId, Root
from shiny.session import session_context
from shiny.types import MISSING
from shiny.ui import Chat
from shiny.ui._chat import as_transformed_message
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
Expand Down Expand Up @@ -66,7 +67,7 @@ def test_chat_message_trimming():

# Throws since system message is too long
with pytest.raises(ValueError):
- chat._trim_messages(msgs, token_limits=(100, 0))
+ chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)

msgs = (
as_stored_message(
Expand All @@ -79,10 +80,10 @@ def test_chat_message_trimming():

# Throws since only the system message fits
with pytest.raises(ValueError):
- chat._trim_messages(msgs, token_limits=(100, 0))
+ chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)

# Raising the limit should allow both messages to fit
- trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
+ trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING)
assert len(trimmed) == 2
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "User message"]
Expand All @@ -100,7 +101,7 @@ def test_chat_message_trimming():
)

# Should discard the 1st user message
- trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
+ trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING)
assert len(trimmed) == 2
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "User message 2"]
Expand All @@ -121,7 +122,7 @@ def test_chat_message_trimming():
)

# Should discard the 1st user message
- trimmed = chat._trim_messages(msgs, token_limits=(102, 0))
+ trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING)
assert len(trimmed) == 3
contents = [msg["content_server"] for msg in trimmed]
assert contents == ["System message", "System message 2", "User message 2"]
Expand Down

0 comments on commit b244bb7

Please sign in to comment.