From 8f080f4ec642ebb99f04ac5753448bd6455642c2 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Thu, 19 Sep 2024 16:23:05 +0200 Subject: [PATCH] fix (provider/bedrock): support parallel tool calls in streaming mode (#3071) --- .changeset/six-poets-sparkle.md | 5 + .../src/bedrock-chat-language-model.test.ts | 142 +++++++++++++++++- .../src/bedrock-chat-language-model.ts | 67 ++++++--- 3 files changed, 187 insertions(+), 27 deletions(-) create mode 100644 .changeset/six-poets-sparkle.md diff --git a/.changeset/six-poets-sparkle.md b/.changeset/six-poets-sparkle.md new file mode 100644 index 00000000000..a9f9a7fb754 --- /dev/null +++ b/.changeset/six-poets-sparkle.md @@ -0,0 +1,5 @@ +--- +'@ai-sdk/amazon-bedrock': patch +--- + +fix (provider/bedrock): support parallel tool calls in streaming mode diff --git a/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts b/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts index 04b0a925cf5..c81885fbd16 100644 --- a/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts +++ b/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts @@ -344,17 +344,17 @@ describe('doStream', () => { }, { contentBlockDelta: { - contentBlockIndex: 1, + contentBlockIndex: 0, delta: { toolUse: { input: '{"value":' } }, }, }, { contentBlockDelta: { - contentBlockIndex: 2, + contentBlockIndex: 0, delta: { toolUse: { input: '"Sparkle Day"}' } }, }, }, - { contentBlockStop: { contentBlockIndex: 3 } }, + { contentBlockStop: { contentBlockIndex: 0 } }, { messageStop: { stopReason: 'tool_use' } }, ]; @@ -415,6 +415,142 @@ describe('doStream', () => { ]); }); + it('should stream parallel tool calls', async () => { + const streamData: ConverseStreamOutput[] = [ + { + contentBlockStart: { + contentBlockIndex: 0, + start: { + toolUse: { toolUseId: 'tool-use-id-1', name: 'test-tool-1' }, + }, + }, + }, + { + contentBlockDelta: { + contentBlockIndex: 0, + delta: { toolUse: { input: '{"value1":' } }, + }, + }, + { + contentBlockStart: { + contentBlockIndex: 1, + start: { + toolUse: { toolUseId: 'tool-use-id-2', name: 'test-tool-2' }, + }, + }, + }, + { + contentBlockDelta: { + contentBlockIndex: 1, + delta: { toolUse: { input: '{"value2":' } }, + }, + }, + { + contentBlockDelta: { + contentBlockIndex: 1, + delta: { toolUse: { input: '"Sparkle Day"}' } }, + }, + }, + { + contentBlockDelta: { + contentBlockIndex: 0, + delta: { toolUse: { input: '"Sparkle Day"}' } }, + }, + }, + { contentBlockStop: { contentBlockIndex: 0 } }, + { contentBlockStop: { contentBlockIndex: 1 } }, + { messageStop: { stopReason: 'tool_use' } }, + ]; + + bedrockMock.on(ConverseStreamCommand).resolves({ + stream: convertArrayToAsyncIterable(streamData), + }); + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { + type: 'regular', + tools: [ + { + type: 'function', + name: 'test-tool-1', + parameters: { + type: 'object', + properties: { value1: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + { + type: 'function', + name: 'test-tool-2', + parameters: { + type: 'object', + properties: { value2: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + ], + toolChoice: { type: 'tool', toolName: 'test-tool' }, + }, + prompt: TEST_PROMPT, + }); + + expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { + type: 'tool-call-delta', + toolCallId: 'tool-use-id-1', + toolCallType: 'function', + toolName: 'test-tool-1', + argsTextDelta: '{"value1":', + }, + { + type: 'tool-call-delta', + toolCallId: 'tool-use-id-2', + toolCallType: 'function', + toolName: 'test-tool-2', + argsTextDelta: '{"value2":', + }, + { + type: 'tool-call-delta', + toolCallId: 'tool-use-id-2', + toolCallType: 'function', + toolName: 'test-tool-2', + argsTextDelta: '"Sparkle Day"}', + }, + { + type: 'tool-call-delta', + toolCallId: 'tool-use-id-1', + toolCallType: 'function', + toolName: 'test-tool-1', + argsTextDelta: '"Sparkle Day"}', + }, + { + type: 'tool-call', + toolCallId: 'tool-use-id-1', + toolCallType: 'function', + toolName: 'test-tool-1', + args: '{"value1":"Sparkle Day"}', + }, + { + type: 'tool-call', + toolCallId: 'tool-use-id-2', + toolCallType: 'function', + toolName: 'test-tool-2', + args: '{"value2":"Sparkle Day"}', + }, + { + type: 'finish', + finishReason: 'tool-calls', + usage: { promptTokens: NaN, completionTokens: NaN }, + providerMetadata: undefined, + }, + ]); + }); + it('should handle error stream parts', async () => { bedrockMock.on(ConverseStreamCommand).resolves({ stream: convertArrayToAsyncIterable([ diff --git a/packages/amazon-bedrock/src/bedrock-chat-language-model.ts b/packages/amazon-bedrock/src/bedrock-chat-language-model.ts index 9b922d07355..da33c002665 100644 --- a/packages/amazon-bedrock/src/bedrock-chat-language-model.ts +++ b/packages/amazon-bedrock/src/bedrock-chat-language-model.ts @@ -246,9 +246,14 @@ export class BedrockChatLanguageModel implements LanguageModelV1 { }, }); - let toolName = ''; - let toolCallId = ''; - let toolCallArgs = ''; + const toolCallContentBlocks: Record< + number, + { + toolCallId: string; + toolName: string; + jsonText: string; + } + > = {}; return { stream: stream.pipeThrough( @@ -317,36 +322,50 @@ export class BedrockChatLanguageModel implements LanguageModelV1 { }); } - if (value.contentBlockStart?.start?.toolUse) { - // store the tool name and id for the next chunk - const toolUse = value.contentBlockStart.start.toolUse; - toolName = toolUse.name ?? ''; - toolCallId = toolUse.toolUseId ?? ''; + const contentBlockStart = value.contentBlockStart; + if (contentBlockStart?.start?.toolUse != null) { + const toolUse = contentBlockStart.start.toolUse; + toolCallContentBlocks[contentBlockStart.contentBlockIndex!] = { + toolCallId: toolUse.toolUseId!, + toolName: toolUse.name!, + jsonText: '', + }; } - if (value.contentBlockDelta?.delta?.toolUse) { - // continue to get the chunks of the tool call args - toolCallArgs += value.contentBlockDelta.delta.toolUse.input ?? ''; + const contentBlockDelta = value.contentBlockDelta; + if (contentBlockDelta?.delta?.toolUse) { + const contentBlock = + toolCallContentBlocks[contentBlockDelta.contentBlockIndex!]; + const delta = contentBlockDelta.delta.toolUse.input ?? ''; controller.enqueue({ type: 'tool-call-delta', toolCallType: 'function', - toolCallId, - toolName, - argsTextDelta: - value.contentBlockDelta.delta.toolUse.input ?? '', + toolCallId: contentBlock.toolCallId, + toolName: contentBlock.toolName, + argsTextDelta: delta, }); + + contentBlock.jsonText += delta; } - // if the content is done and a tool call was made, send it - if (value.contentBlockStop && toolCallArgs.length > 0) { - controller.enqueue({ - type: 'tool-call', - toolCallType: 'function', - toolCallId, - toolName, - args: toolCallArgs, - }); + const contentBlockStop = value.contentBlockStop; + if (contentBlockStop != null) { + const index = contentBlockStop.contentBlockIndex!; + const contentBlock = toolCallContentBlocks[index]; + + // when finishing a tool call block, send the full tool call: + if (contentBlock != null) { + controller.enqueue({ + type: 'tool-call', + toolCallType: 'function', + toolCallId: contentBlock.toolCallId, + toolName: contentBlock.toolName, + args: contentBlock.jsonText, + }); + + delete toolCallContentBlocks[index]; + } } },