Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix (ai/core): support tool calls without arguments #3073

Merged
merged 6 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/popular-suits-jam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

fix (ai/core): support tool calls without arguments
5 changes: 3 additions & 2 deletions packages/ai/core/generate-text/generate-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import {
calculateLanguageModelUsage,
} from '../types/usage';
import { GenerateTextResult } from './generate-text-result';
import { parseToolCall } from './parse-tool-call';
import { StepResult } from './step-result';
import { toResponseMessages } from './to-response-messages';
import { ToToolCallArray, parseToolCall } from './tool-call';
import { ToToolCallArray } from './tool-call';
import { ToToolResultArray } from './tool-result';
import { StepResult } from './step-result';

const originalGenerateId = createIdGenerator({ prefix: 'aitxt-', size: 24 });

Expand Down
110 changes: 110 additions & 0 deletions packages/ai/core/generate-text/parse-tool-call.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { z } from 'zod';
import { InvalidToolArgumentsError } from '../../errors/invalid-tool-arguments-error';
import { NoSuchToolError } from '../../errors/no-such-tool-error';
import { tool } from '../tool';
import { parseToolCall } from './parse-tool-call';

it('should successfully parse a valid tool call', () => {
const result = parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '{"param1": "test", "param2": 42}',
},
tools: {
testTool: tool({
parameters: z.object({
param1: z.string(),
param2: z.number(),
}),
}),
} as const,
});

expect(result).toEqual({
type: 'tool-call',
toolCallId: '123',
toolName: 'testTool',
args: { param1: 'test', param2: 42 },
});
});

it('should successfully process empty calls for tools that have no parameters', () => {
const result = parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '',
},
tools: {
testTool: tool({
parameters: z.object({}),
}),
} as const,
});

expect(result).toEqual({
type: 'tool-call',
toolCallId: '123',
toolName: 'testTool',
args: {},
});
});

it('should throw NoSuchToolError when tools is null', () => {
expect(() =>
parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '{}',
},
tools: undefined,
}),
).toThrow(NoSuchToolError);
});

it('should throw NoSuchToolError when tool is not found', () => {
expect(() =>
parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'nonExistentTool',
toolCallId: '123',
args: '{}',
},
tools: {
testTool: tool({
parameters: z.object({
param1: z.string(),
param2: z.number(),
}),
}),
} as const,
}),
).toThrow(NoSuchToolError);
});

it('should throw InvalidToolArgumentsError when args are invalid', () => {
expect(() =>
parseToolCall({
toolCall: {
toolCallType: 'function',
toolName: 'testTool',
toolCallId: '123',
args: '{"param1": "test"}', // Missing required param2
},
tools: {
testTool: tool({
parameters: z.object({
param1: z.string(),
param2: z.number(),
}),
}),
} as const,
}),
).toThrow(InvalidToolArgumentsError);
});
57 changes: 57 additions & 0 deletions packages/ai/core/generate-text/parse-tool-call.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { LanguageModelV1FunctionToolCall } from '@ai-sdk/provider';
import { safeParseJSON, safeValidateTypes } from '@ai-sdk/provider-utils';
import { Schema, asSchema } from '@ai-sdk/ui-utils';
import { InvalidToolArgumentsError } from '../../errors/invalid-tool-arguments-error';
import { NoSuchToolError } from '../../errors/no-such-tool-error';
import { CoreTool } from '../tool';
import { inferParameters } from '../tool/tool';
import { ToToolCall } from './tool-call';

export function parseToolCall<TOOLS extends Record<string, CoreTool>>({
toolCall,
tools,
}: {
toolCall: LanguageModelV1FunctionToolCall;
tools?: TOOLS;
}): ToToolCall<TOOLS> {
const toolName = toolCall.toolName as keyof TOOLS & string;

if (tools == null) {
throw new NoSuchToolError({ toolName: toolCall.toolName });
}

const tool = tools[toolName];

if (tool == null) {
throw new NoSuchToolError({
toolName: toolCall.toolName,
availableTools: Object.keys(tools),
});
}

const schema = asSchema(tool.parameters) as Schema<
inferParameters<TOOLS[keyof TOOLS]['parameters']>
>;

// when the tool call has no arguments, we try passing an empty object to the schema
// (many LLMs generate empty strings for tool calls with no arguments)
const parseResult =
toolCall.args.trim() === ''
? safeValidateTypes({ value: {}, schema })
: safeParseJSON({ text: toolCall.args, schema });

if (parseResult.success === false) {
throw new InvalidToolArgumentsError({
toolName,
toolArgs: toolCall.args,
cause: parseResult.error,
});
}

return {
type: 'tool-call',
toolCallId: toolCall.toolCallId,
toolName,
args: parseResult.value,
};
}
5 changes: 3 additions & 2 deletions packages/ai/core/generate-text/run-tools-transformation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import { selectTelemetryAttributes } from '../telemetry/select-telemetry-attribu
import { TelemetrySettings } from '../telemetry/telemetry-settings';
import { CoreTool } from '../tool';
import {
LanguageModelUsage,
FinishReason,
LanguageModelUsage,
LogProbs,
ProviderMetadata,
} from '../types';
import { calculateLanguageModelUsage } from '../types/usage';
import { parseToolCall, ToToolCall } from './tool-call';
import { parseToolCall } from './parse-tool-call';
import { ToToolCall } from './tool-call';
import { ToToolResult } from './tool-result';

export type SingleRequestTextStreamPart<
Expand Down
50 changes: 0 additions & 50 deletions packages/ai/core/generate-text/tool-call.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
import { LanguageModelV1FunctionToolCall } from '@ai-sdk/provider';
import { safeParseJSON } from '@ai-sdk/provider-utils';
import { Schema, asSchema } from '@ai-sdk/ui-utils';
import { InvalidToolArgumentsError } from '../../errors/invalid-tool-arguments-error';
import { NoSuchToolError } from '../../errors/no-such-tool-error';
import { CoreTool } from '../tool';
import { inferParameters } from '../tool/tool';
import { ValueOf } from '../util/value-of';
Expand Down Expand Up @@ -41,48 +36,3 @@ export type ToToolCall<TOOLS extends Record<string, CoreTool>> = ValueOf<{
export type ToToolCallArray<TOOLS extends Record<string, CoreTool>> = Array<
ToToolCall<TOOLS>
>;

export function parseToolCall<TOOLS extends Record<string, CoreTool>>({
toolCall,
tools,
}: {
toolCall: LanguageModelV1FunctionToolCall;
tools?: TOOLS;
}): ToToolCall<TOOLS> {
const toolName = toolCall.toolName as keyof TOOLS & string;

if (tools == null) {
throw new NoSuchToolError({ toolName: toolCall.toolName });
}

const tool = tools[toolName];

if (tool == null) {
throw new NoSuchToolError({
toolName: toolCall.toolName,
availableTools: Object.keys(tools),
});
}

const parseResult = safeParseJSON({
text: toolCall.args,
schema: asSchema(tool.parameters) as Schema<
inferParameters<TOOLS[keyof TOOLS]['parameters']>
>,
});

if (parseResult.success === false) {
throw new InvalidToolArgumentsError({
toolName,
toolArgs: toolCall.args,
cause: parseResult.error,
});
}

return {
type: 'tool-call',
toolCallId: toolCall.toolCallId,
toolName,
args: parseResult.value,
};
}
Loading