Skip to content

Commit

Permalink
fix (provider/bedrock): support parallel tool calls in streaming mode (
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Sep 19, 2024
1 parent 626cf00 commit 8f080f4
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 27 deletions.
5 changes: 5 additions & 0 deletions .changeset/six-poets-sparkle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/amazon-bedrock': patch
---

fix (provider/bedrock): support parallel tool calls in streaming mode
142 changes: 139 additions & 3 deletions packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,17 @@ describe('doStream', () => {
},
{
contentBlockDelta: {
contentBlockIndex: 1,
contentBlockIndex: 0,
delta: { toolUse: { input: '{"value":' } },
},
},
{
contentBlockDelta: {
contentBlockIndex: 2,
contentBlockIndex: 0,
delta: { toolUse: { input: '"Sparkle Day"}' } },
},
},
{ contentBlockStop: { contentBlockIndex: 3 } },
{ contentBlockStop: { contentBlockIndex: 0 } },
{ messageStop: { stopReason: 'tool_use' } },
];

Expand Down Expand Up @@ -415,6 +415,142 @@ describe('doStream', () => {
]);
});

it('should stream parallel tool calls', async () => {
const streamData: ConverseStreamOutput[] = [
{
contentBlockStart: {
contentBlockIndex: 0,
start: {
toolUse: { toolUseId: 'tool-use-id-1', name: 'test-tool-1' },
},
},
},
{
contentBlockDelta: {
contentBlockIndex: 0,
delta: { toolUse: { input: '{"value1":' } },
},
},
{
contentBlockStart: {
contentBlockIndex: 1,
start: {
toolUse: { toolUseId: 'tool-use-id-2', name: 'test-tool-2' },
},
},
},
{
contentBlockDelta: {
contentBlockIndex: 1,
delta: { toolUse: { input: '{"value2":' } },
},
},
{
contentBlockDelta: {
contentBlockIndex: 1,
delta: { toolUse: { input: '"Sparkle Day"}' } },
},
},
{
contentBlockDelta: {
contentBlockIndex: 0,
delta: { toolUse: { input: '"Sparkle Day"}' } },
},
},
{ contentBlockStop: { contentBlockIndex: 0 } },
{ contentBlockStop: { contentBlockIndex: 1 } },
{ messageStop: { stopReason: 'tool_use' } },
];

bedrockMock.on(ConverseStreamCommand).resolves({
stream: convertArrayToAsyncIterable(streamData),
});

const { stream } = await model.doStream({
inputFormat: 'prompt',
mode: {
type: 'regular',
tools: [
{
type: 'function',
name: 'test-tool-1',
parameters: {
type: 'object',
properties: { value1: { type: 'string' } },
required: ['value'],
additionalProperties: false,
$schema: 'http://json-schema.org/draft-07/schema#',
},
},
{
type: 'function',
name: 'test-tool-2',
parameters: {
type: 'object',
properties: { value2: { type: 'string' } },
required: ['value'],
additionalProperties: false,
$schema: 'http://json-schema.org/draft-07/schema#',
},
},
],
toolChoice: { type: 'tool', toolName: 'test-tool' },
},
prompt: TEST_PROMPT,
});

expect(await convertReadableStreamToArray(stream)).toStrictEqual([
{
type: 'tool-call-delta',
toolCallId: 'tool-use-id-1',
toolCallType: 'function',
toolName: 'test-tool-1',
argsTextDelta: '{"value1":',
},
{
type: 'tool-call-delta',
toolCallId: 'tool-use-id-2',
toolCallType: 'function',
toolName: 'test-tool-2',
argsTextDelta: '{"value2":',
},
{
type: 'tool-call-delta',
toolCallId: 'tool-use-id-2',
toolCallType: 'function',
toolName: 'test-tool-2',
argsTextDelta: '"Sparkle Day"}',
},
{
type: 'tool-call-delta',
toolCallId: 'tool-use-id-1',
toolCallType: 'function',
toolName: 'test-tool-1',
argsTextDelta: '"Sparkle Day"}',
},
{
type: 'tool-call',
toolCallId: 'tool-use-id-1',
toolCallType: 'function',
toolName: 'test-tool-1',
args: '{"value1":"Sparkle Day"}',
},
{
type: 'tool-call',
toolCallId: 'tool-use-id-2',
toolCallType: 'function',
toolName: 'test-tool-2',
args: '{"value2":"Sparkle Day"}',
},
{
type: 'finish',
finishReason: 'tool-calls',
usage: { promptTokens: NaN, completionTokens: NaN },
providerMetadata: undefined,
},
]);
});

it('should handle error stream parts', async () => {
bedrockMock.on(ConverseStreamCommand).resolves({
stream: convertArrayToAsyncIterable([
Expand Down
67 changes: 43 additions & 24 deletions packages/amazon-bedrock/src/bedrock-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,14 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
},
});

let toolName = '';
let toolCallId = '';
let toolCallArgs = '';
const toolCallContentBlocks: Record<
number,
{
toolCallId: string;
toolName: string;
jsonText: string;
}
> = {};

return {
stream: stream.pipeThrough(
Expand Down Expand Up @@ -317,36 +322,50 @@ export class BedrockChatLanguageModel implements LanguageModelV1 {
});
}

if (value.contentBlockStart?.start?.toolUse) {
// store the tool name and id for the next chunk
const toolUse = value.contentBlockStart.start.toolUse;
toolName = toolUse.name ?? '';
toolCallId = toolUse.toolUseId ?? '';
const contentBlockStart = value.contentBlockStart;
if (contentBlockStart?.start?.toolUse != null) {
const toolUse = contentBlockStart.start.toolUse;
toolCallContentBlocks[contentBlockStart.contentBlockIndex!] = {
toolCallId: toolUse.toolUseId!,
toolName: toolUse.name!,
jsonText: '',
};
}

if (value.contentBlockDelta?.delta?.toolUse) {
// continue to get the chunks of the tool call args
toolCallArgs += value.contentBlockDelta.delta.toolUse.input ?? '';
const contentBlockDelta = value.contentBlockDelta;
if (contentBlockDelta?.delta?.toolUse) {
const contentBlock =
toolCallContentBlocks[contentBlockDelta.contentBlockIndex!];
const delta = contentBlockDelta.delta.toolUse.input ?? '';

controller.enqueue({
type: 'tool-call-delta',
toolCallType: 'function',
toolCallId,
toolName,
argsTextDelta:
value.contentBlockDelta.delta.toolUse.input ?? '',
toolCallId: contentBlock.toolCallId,
toolName: contentBlock.toolName,
argsTextDelta: delta,
});

contentBlock.jsonText += delta;
}

// if the content is done and a tool call was made, send it
if (value.contentBlockStop && toolCallArgs.length > 0) {
controller.enqueue({
type: 'tool-call',
toolCallType: 'function',
toolCallId,
toolName,
args: toolCallArgs,
});
const contentBlockStop = value.contentBlockStop;
if (contentBlockStop != null) {
const index = contentBlockStop.contentBlockIndex!;
const contentBlock = toolCallContentBlocks[index];

// when finishing a tool call block, send the full tool call:
if (contentBlock != null) {
controller.enqueue({
type: 'tool-call',
toolCallType: 'function',
toolCallId: contentBlock.toolCallId,
toolName: contentBlock.toolName,
args: contentBlock.jsonText,
});

delete toolCallContentBlocks[index];
}
}
},

Expand Down

0 comments on commit 8f080f4

Please sign in to comment.