Skip to content

Commit

Permalink
fix (ai/core): support tool calls without arguments (#3073)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Sep 19, 2024
1 parent bb447d4 commit fea6bec
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 54 deletions.
5 changes: 5 additions & 0 deletions .changeset/popular-suits-jam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

fix (ai/core): support tool calls without arguments
5 changes: 3 additions & 2 deletions packages/ai/core/generate-text/generate-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import {
calculateLanguageModelUsage,
} from '../types/usage';
import { GenerateTextResult } from './generate-text-result';
import { parseToolCall } from './parse-tool-call';
import { StepResult } from './step-result';
import { toResponseMessages } from './to-response-messages';
import { ToToolCallArray, parseToolCall } from './tool-call';
import { ToToolCallArray } from './tool-call';
import { ToToolResultArray } from './tool-result';
import { StepResult } from './step-result';

const originalGenerateId = createIdGenerator({ prefix: 'aitxt-', size: 24 });

Expand Down
110 changes: 110 additions & 0 deletions packages/ai/core/generate-text/parse-tool-call.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { z } from 'zod';
import { InvalidToolArgumentsError } from '../../errors/invalid-tool-arguments-error';
import { NoSuchToolError } from '../../errors/no-such-tool-error';
import { tool } from '../tool';
import { parseToolCall } from './parse-tool-call';

it('should successfully parse a valid tool call', () => {
const result = parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '{"param1": "test", "param2": 42}',
},
tools: {
testTool: tool({
parameters: z.object({
param1: z.string(),
param2: z.number(),
}),
}),
} as const,
});

expect(result).toEqual({
type: 'tool-call',
toolCallId: '123',
toolName: 'testTool',
args: { param1: 'test', param2: 42 },
});
});

it('should successfully process empty calls for tools that have no parameters', () => {
const result = parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '',
},
tools: {
testTool: tool({
parameters: z.object({}),
}),
} as const,
});

expect(result).toEqual({
type: 'tool-call',
toolCallId: '123',
toolName: 'testTool',
args: {},
});
});

it('should throw NoSuchToolError when tools is null', () => {
expect(() =>
parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '{}',
},
tools: undefined,
}),
).toThrow(NoSuchToolError);
});

it('should throw NoSuchToolError when tool is not found', () => {
expect(() =>
parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'nonExistentTool',
toolCallId: '123',
args: '{}',
},
tools: {
testTool: tool({
parameters: z.object({
param1: z.string(),
param2: z.number(),
}),
}),
} as const,
}),
).toThrow(NoSuchToolError);
});

it('should throw InvalidToolArgumentsError when args are invalid', () => {
expect(() =>
parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '{"param1": "test"}', // Missing required param2
},
tools: {
testTool: tool({
parameters: z.object({
param1: z.string(),
param2: z.number(),
}),
}),
} as const,
}),
).toThrow(InvalidToolArgumentsError);
});
57 changes: 57 additions & 0 deletions packages/ai/core/generate-text/parse-tool-call.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { LanguageModelV1FunctionToolCall } from '@ai-sdk/provider';
import { safeParseJSON, safeValidateTypes } from '@ai-sdk/provider-utils';
import { Schema, asSchema } from '@ai-sdk/ui-utils';
import { InvalidToolArgumentsError } from '../../errors/invalid-tool-arguments-error';
import { NoSuchToolError } from '../../errors/no-such-tool-error';
import { CoreTool } from '../tool';
import { inferParameters } from '../tool/tool';
import { ToToolCall } from './tool-call';

export function parseToolCall<TOOLS extends Record<string, CoreTool>>({
toolCall,
tools,
}: {
toolCall: LanguageModelV1FunctionToolCall;
tools?: TOOLS;
}): ToToolCall<TOOLS> {
const toolName = toolCall.toolName as keyof TOOLS & string;

if (tools == null) {
throw new NoSuchToolError({ toolName: toolCall.toolName });
}

const tool = tools[toolName];

if (tool == null) {
throw new NoSuchToolError({
toolName: toolCall.toolName,
availableTools: Object.keys(tools),
});
}

const schema = asSchema(tool.parameters) as Schema<
inferParameters<TOOLS[keyof TOOLS]['parameters']>
>;

// when the tool call has no arguments, we try passing an empty object to the schema
// (many LLMs generate empty strings for tool calls with no arguments)
const parseResult =
toolCall.args.trim() === ''
? safeValidateTypes({ value: {}, schema })
: safeParseJSON({ text: toolCall.args, schema });

if (parseResult.success === false) {
throw new InvalidToolArgumentsError({
toolName,
toolArgs: toolCall.args,
cause: parseResult.error,
});
}

return {
type: 'tool-call',
toolCallId: toolCall.toolCallId,
toolName,
args: parseResult.value,
};
}
5 changes: 3 additions & 2 deletions packages/ai/core/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import { selectTelemetryAttributes } from '../telemetry/select-telemetry-attribu
import { TelemetrySettings } from '../telemetry/telemetry-settings';
import { CoreTool } from '../tool';
import {
LanguageModelUsage,
FinishReason,
LanguageModelUsage,
LogProbs,
ProviderMetadata,
} from '../types';
import { calculateLanguageModelUsage } from '../types/usage';
import { parseToolCall, ToToolCall } from './tool-call';
import { parseToolCall } from './parse-tool-call';
import { ToToolCall } from './tool-call';
import { ToToolResult } from './tool-result';

export type SingleRequestTextStreamPart<
Expand Down
50 changes: 0 additions & 50 deletions packages/ai/core/generate-text/tool-call.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import { LanguageModelV1FunctionToolCall } from '@ai-sdk/provider';
import { safeParseJSON } from '@ai-sdk/provider-utils';
import { Schema, asSchema } from '@ai-sdk/ui-utils';
import { InvalidToolArgumentsError } from '../../errors/invalid-tool-arguments-error';
import { NoSuchToolError } from '../../errors/no-such-tool-error';
import { CoreTool } from '../tool';
import { inferParameters } from '../tool/tool';
import { ValueOf } from '../util/value-of';
Expand Down Expand Up @@ -41,48 +36,3 @@ export type ToToolCall<TOOLS extends Record<string, CoreTool>> = ValueOf<{
export type ToToolCallArray<TOOLS extends Record<string, CoreTool>> = Array<
ToToolCall<TOOLS>
>;

export function parseToolCall<TOOLS extends Record<string, CoreTool>>({
toolCall,
tools,
}: {
toolCall: LanguageModelV1FunctionToolCall;
tools?: TOOLS;
}): ToToolCall<TOOLS> {
const toolName = toolCall.toolName as keyof TOOLS & string;

if (tools == null) {
throw new NoSuchToolError({ toolName: toolCall.toolName });
}

const tool = tools[toolName];

if (tool == null) {
throw new NoSuchToolError({
toolName: toolCall.toolName,
availableTools: Object.keys(tools),
});
}

const parseResult = safeParseJSON({
text: toolCall.args,
schema: asSchema(tool.parameters) as Schema<
inferParameters<TOOLS[keyof TOOLS]['parameters']>
>,
});

if (parseResult.success === false) {
throw new InvalidToolArgumentsError({
toolName,
toolArgs: toolCall.args,
cause: parseResult.error,
});
}

return {
type: 'tool-call',
toolCallId: toolCall.toolCallId,
toolName,
args: parseResult.value,
};
}

0 comments on commit fea6bec

Please sign in to comment.