From 8132a608e167e3f5675315b69d24115e4fffe2e3 Mon Sep 17 00:00:00 2001
From: Lars Grammel
Date: Fri, 20 Sep 2024 13:27:11 +0200
Subject: [PATCH] feat (provider/openai): support reasoning token usage and
 max_completion_tokens (#3078)

---
 .changeset/shy-toys-sip.md                    |  5 ++
 .../02-providers-and-models.mdx               |  2 +
 content/docs/02-guides/04-o1.mdx              |  7 +-
 .../01-ai-sdk-providers/01-openai.mdx         | 33 ++++++++
 .../providers/01-ai-sdk-providers/index.mdx   |  2 +
 .../generate-text/openai-reasoning-model.ts   | 22 ++++++
 .../src/openai-chat-language-model.test.ts    | 76 +++++++++++++++++++
 .../openai/src/openai-chat-language-model.ts  | 31 ++++++--
 8 files changed, 170 insertions(+), 8 deletions(-)
 create mode 100644 .changeset/shy-toys-sip.md
 create mode 100644 examples/ai-core/src/generate-text/openai-reasoning-model.ts

diff --git a/.changeset/shy-toys-sip.md b/.changeset/shy-toys-sip.md
new file mode 100644
index 00000000000..9a7e6349ca5
--- /dev/null
+++ b/.changeset/shy-toys-sip.md
@@ -0,0 +1,5 @@
+---
+'@ai-sdk/openai': patch
+---
+
+feat (provider/openai): support reasoning token usage and max_completion_tokens
diff --git a/content/docs/02-foundations/02-providers-and-models.mdx b/content/docs/02-foundations/02-providers-and-models.mdx
index e4471e24328..7b6fe23e71e 100644
--- a/content/docs/02-foundations/02-providers-and-models.mdx
+++ b/content/docs/02-foundations/02-providers-and-models.mdx
@@ -60,6 +60,8 @@ Here are the capabilities of popular models:
 | [OpenAI](/providers/ai-sdk-providers/openai)       | `gpt-4o-mini`                | | | | |
 | [OpenAI](/providers/ai-sdk-providers/openai)       | `gpt-4-turbo`                | | | | |
 | [OpenAI](/providers/ai-sdk-providers/openai)       | `gpt-4`                      | | | | |
+| [OpenAI](/providers/ai-sdk-providers/openai)       | `o1-preview`                 | | | | |
+| [OpenAI](/providers/ai-sdk-providers/openai)       | `o1-mini`                    | | | | |
 | [Anthropic](/providers/ai-sdk-providers/anthropic) | `claude-3-5-sonnet-20240620` | | | | |
 | [Mistral](/providers/ai-sdk-providers/mistral)     | `mistral-large-latest`       | | | | |
 | [Mistral](/providers/ai-sdk-providers/mistral)     | `mistral-small-latest`       | | | | |
diff --git a/content/docs/02-guides/04-o1.mdx b/content/docs/02-guides/04-o1.mdx
index 549ab37d827..8f98a2edd57 100644
--- a/content/docs/02-guides/04-o1.mdx
+++ b/content/docs/02-guides/04-o1.mdx
@@ -155,8 +155,9 @@ The useChat hook on your root page (`app/page.tsx`) will make a request to your
 Ready to get started? Here's how you can dive in:
 
 1. Explore the documentation at [sdk.vercel.ai/docs](/docs) to understand the full capabilities of the AI SDK.
-2. Check out practical examples at [sdk.vercel.ai/examples](/examples) to see the SDK in action and get inspired for your own projects.
-3. Dive deeper with advanced guides on topics like Retrieval-Augmented Generation (RAG) and multi-modal chat at [sdk.vercel.ai/docs/guides](/docs/guides).
-4. Check out ready-to-deploy AI templates at [vercel.com/templates?type=ai](https://vercel.com/templates?type=ai).
+1. Check out our support for the o1 series of reasoning models in the [OpenAI Provider](/providers/ai-sdk-providers/openai#reasoning-models).
+1. Check out practical examples at [sdk.vercel.ai/examples](/examples) to see the SDK in action and get inspired for your own projects.
+1. Dive deeper with advanced guides on topics like Retrieval-Augmented Generation (RAG) and multi-modal chat at [sdk.vercel.ai/docs/guides](/docs/guides).
+1. Check out ready-to-deploy AI templates at [vercel.com/templates?type=ai](https://vercel.com/templates?type=ai).
 
 Remember that OpenAI o1 models are currently in beta with limited features and access. Stay tuned for updates as OpenAI expands access and adds more features to these powerful reasoning models.
diff --git a/content/providers/01-ai-sdk-providers/01-openai.mdx b/content/providers/01-ai-sdk-providers/01-openai.mdx
index 0e1446d464c..064c010a539 100644
--- a/content/providers/01-ai-sdk-providers/01-openai.mdx
+++ b/content/providers/01-ai-sdk-providers/01-openai.mdx
@@ -127,6 +127,8 @@ OpenAI language models can also be used in the `streamText`, `generateObject`,
 | `gpt-4-turbo`   | | | | |
 | `gpt-4`         | | | | |
 | `gpt-3.5-turbo` | | | | |
+| `o1-preview`    | | | | |
+| `o1-mini`       | | | | |
 
 <Note>
   The table above lists popular models. You can also pass any available provider
@@ -288,6 +290,37 @@ const result = await generateText({
 });
 ```
 
+#### Reasoning Models
+
+OpenAI has introduced the `o1` series of [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+Currently, `o1-mini` and `o1-preview` are available.
+
+Reasoning models currently have several limitations and are only supported using `generateText`.
+They support two additional options:
+
+- You can use the request `experimental_providerMetadata` to set the `maxCompletionTokens` option, which determines the maximum number
+  of both reasoning and output tokens that the model generates.
+- You can use the response `experimental_providerMetadata` to read the number of reasoning tokens that the model generated.
+
+```ts highlight="4,7-9,15"
+import { openai } from '@ai-sdk/openai';
+import { generateText } from 'ai';
+
+const { text, usage, experimental_providerMetadata } = await generateText({
+  model: openai('o1-mini'),
+  prompt: 'Invent a new holiday and describe its traditions.',
+  experimental_providerMetadata: {
+    openai: { maxCompletionTokens: 1000 },
+  },
+});
+
+console.log(text);
+console.log('Usage:', {
+  ...usage,
+  reasoningTokens: experimental_providerMetadata?.openai?.reasoningTokens,
+});
+```
+
 ### Completion Models
 
 You can create models that call the [OpenAI completions API](https://platform.openai.com/docs/api-reference/completions) using the `.completion()` factory method.
diff --git a/content/providers/01-ai-sdk-providers/index.mdx b/content/providers/01-ai-sdk-providers/index.mdx
index 445667b6f95..d8bbcace312 100644
--- a/content/providers/01-ai-sdk-providers/index.mdx
+++ b/content/providers/01-ai-sdk-providers/index.mdx
@@ -23,6 +23,8 @@ Not all providers support all AI SDK features. Here's a quick comparison of the
 | [OpenAI](/providers/ai-sdk-providers/openai)       | `gpt-4o-mini`                | | | | |
 | [OpenAI](/providers/ai-sdk-providers/openai)       | `gpt-4-turbo`                | | | | |
 | [OpenAI](/providers/ai-sdk-providers/openai)       | `gpt-4`                      | | | | |
+| [OpenAI](/providers/ai-sdk-providers/openai)       | `o1-preview`                 | | | | |
+| [OpenAI](/providers/ai-sdk-providers/openai)       | `o1-mini`                    | | | | |
 | [Anthropic](/providers/ai-sdk-providers/anthropic) | `claude-3-5-sonnet-20240620` | | | | |
 | [Mistral](/providers/ai-sdk-providers/mistral)     | `mistral-large-latest`       | | | | |
 | [Mistral](/providers/ai-sdk-providers/mistral)     | `mistral-small-latest`       | | | | |
diff --git a/examples/ai-core/src/generate-text/openai-reasoning-model.ts b/examples/ai-core/src/generate-text/openai-reasoning-model.ts
new file mode 100644
index 00000000000..8dded52c3ec
--- /dev/null
+++ b/examples/ai-core/src/generate-text/openai-reasoning-model.ts
@@ -0,0 +1,22 @@
+import { openai } from '@ai-sdk/openai';
+import { generateText } from 'ai';
+import 'dotenv/config';
+
+async function main() {
+  const { text, usage, experimental_providerMetadata } = await generateText({
+    model: openai('o1-mini'),
+    prompt: 'Invent a new holiday and describe its traditions.',
+    experimental_providerMetadata: {
+      openai: { maxCompletionTokens: 1000 },
+    },
+  });
+
+  console.log(text);
+  console.log();
+  console.log('Usage:', {
+    ...usage,
+    reasoningTokens: experimental_providerMetadata?.openai?.reasoningTokens,
+  });
+}
+
+main().catch(console.error);
diff --git a/packages/openai/src/openai-chat-language-model.test.ts b/packages/openai/src/openai-chat-language-model.test.ts
index 45662329b25..f3075d561c4 100644
--- a/packages/openai/src/openai-chat-language-model.test.ts
+++ b/packages/openai/src/openai-chat-language-model.test.ts
@@ -173,6 +173,9 @@ describe('doGenerate', () => {
       prompt_tokens?: number;
       total_tokens?: number;
       completion_tokens?: number;
+      completion_tokens_details?: {
+        reasoning_tokens?: number;
+      };
     };
     logprobs?: {
       content:
@@ -835,6 +838,79 @@ describe('doGenerate', () => {
       },
     ]);
   });
+
+  describe('reasoning models', () => {
+    it('should clear out temperature, top_p, frequency_penalty, presence_penalty', async () => {
+      prepareJsonResponse();
+
+      const model = provider.chat('o1-preview');
+
+      await model.doGenerate({
+        inputFormat: 'prompt',
+        mode: { type: 'regular' },
+        prompt: TEST_PROMPT,
+        temperature: 0.5,
+        topP: 0.7,
+        frequencyPenalty: 0.2,
+        presencePenalty: 0.3,
+      });
+
+      expect(await server.getRequestBodyJson()).toStrictEqual({
+        model: 'o1-preview',
+        messages: [{ role: 'user', content: 'Hello' }],
+      });
+    });
+  });
+
+  it('should return the reasoning tokens in the provider metadata', async () => {
+    prepareJsonResponse({
+      usage: {
+        prompt_tokens: 15,
+        completion_tokens: 20,
+        total_tokens: 35,
+        completion_tokens_details: {
+          reasoning_tokens: 10,
+        },
+      },
+    });
+
+    const model = provider.chat('o1-preview');
+
+    const result = await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+    });
+
+    expect(result.providerMetadata).toStrictEqual({
+      openai: {
+        reasoningTokens: 10,
+      },
+    });
+  });
+
+  it('should send max_completion_tokens extension setting', async () => {
+    prepareJsonResponse();
+
+    const model = provider.chat('o1-preview');
+
+    await model.doGenerate({
+      inputFormat: 'prompt',
+      mode: { type: 'regular' },
+      prompt: TEST_PROMPT,
+      providerMetadata: {
+        openai: {
+          maxCompletionTokens: 255,
+        },
+      },
+    });
+
+    expect(await server.getRequestBodyJson()).toStrictEqual({
+      model: 'o1-preview',
+      messages: [{ role: 'user', content: 'Hello' }],
+      max_completion_tokens: 255,
+    });
+  });
 });
 
 describe('doStream', () => {
diff --git a/packages/openai/src/openai-chat-language-model.ts b/packages/openai/src/openai-chat-language-model.ts
index c5dd487f762..e5e5d0ae196 100644
--- a/packages/openai/src/openai-chat-language-model.ts
+++ b/packages/openai/src/openai-chat-language-model.ts
@@ -83,6 +83,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
     stopSequences,
     responseFormat,
     seed,
+    providerMetadata,
   }: Parameters<LanguageModelV1['doGenerate']>[0]) {
     const type = mode.type;
 
@@ -152,6 +153,10 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
       stop: stopSequences,
       seed,
 
+      // openai specific settings:
+      max_completion_tokens:
+        providerMetadata?.openai?.maxCompletionTokens ?? undefined,
+
       // response format:
       response_format:
         responseFormat?.type === 'json' ? { type: 'json_object' } : undefined,
@@ -163,12 +168,12 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
       }),
     };
 
-    // reasoning models have fixed params:
+    // reasoning models have fixed params, remove them if they are set:
     if (this.modelId === 'o1-preview' || this.modelId === 'o1-mini') {
-      baseArgs.temperature = 1;
-      baseArgs.top_p = 1;
-      baseArgs.frequency_penalty = 0;
-      baseArgs.presence_penalty = 0;
+      baseArgs.temperature = undefined;
+      baseArgs.top_p = undefined;
+      baseArgs.frequency_penalty = undefined;
+      baseArgs.presence_penalty = undefined;
     }
 
     switch (type) {
@@ -278,6 +283,16 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
     const { messages: rawPrompt, ...rawSettings } = args;
     const choice = response.choices[0];
 
+    const providerMetadata =
+      response.usage?.completion_tokens_details?.reasoning_tokens != null
+        ? {
+            openai: {
+              reasoningTokens:
+                response.usage?.completion_tokens_details?.reasoning_tokens,
+            },
+          }
+        : undefined;
+
     return {
       text: choice.message.content ?? undefined,
       toolCalls:
@@ -306,6 +321,7 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
       response: getResponseMetadata(response),
       warnings,
       logprobs: mapOpenAIChatLogProbsOutput(choice.logprobs),
+      providerMetadata,
     };
   }
 
@@ -568,6 +584,11 @@ const openAITokenUsageSchema = z
   .object({
     prompt_tokens: z.number().nullish(),
    completion_tokens: z.number().nullish(),
+    completion_tokens_details: z
+      .object({
+        reasoning_tokens: z.number().nullish(),
+      })
+      .nullish(),
   })
   .nullish();
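
Note: the two extension points this patch wires up map directly onto fields of the OpenAI chat completions API: `max_completion_tokens` in the request body and `usage.completion_tokens_details.reasoning_tokens` in the response. The following sketch is not part of the patch; it shows that wire-level round trip with plain `fetch`, assuming an `OPENAI_API_KEY` environment variable and Node 18+.

```ts
// Wire-level view of what the provider sends and reads after this patch.
// Field names follow the OpenAI API reference; this is an illustrative
// sketch, not code from the repository.
async function demo() {
  const res = await fetch('https://api.openai.com/v1/chat/completions', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      Authorization: `Bearer ${process.env.OPENAI_API_KEY}`,
    },
    body: JSON.stringify({
      model: 'o1-mini',
      messages: [{ role: 'user', content: 'Hello' }],
      // set by the provider from providerMetadata.openai.maxCompletionTokens:
      max_completion_tokens: 1000,
    }),
  });

  const body = await res.json();
  // surfaced by the provider as providerMetadata.openai.reasoningTokens:
  console.log(body.usage?.completion_tokens_details?.reasoning_tokens);
}

demo().catch(console.error);
```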