From 76fe72976b84c98f9fc096001a97f75a4176e00e Mon Sep 17 00:00:00 2001 From: SergeyMenshykh Date: Tue, 6 Aug 2024 20:15:41 +0100 Subject: [PATCH 1/3] feat(function-calling): enable chat history mutation witin auto funciton invocation filters. --- .../AzureOpenAIChatCompletionServiceTests.cs | 164 ++++++++++++++++ .../OpenAIChatCompletionServiceTests.cs | 181 +++++++++++++++++- .../Core/ClientCore.ChatCompletion.cs | 87 ++------- 3 files changed, 360 insertions(+), 72 deletions(-) diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs index 2e639434e951..ed9314e2acfb 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs @@ -900,6 +900,150 @@ public async Task FunctionResultsCanBeProvidedToLLMAsManyResultsInOneChatMessage Assert.Equal("2", assistantMessage2.GetProperty("tool_call_id").GetString()); } + [Fact] + public async Task GetChatMessageContentShouldSendMutatedChatHistoryToLLM() + { + // Arrange + static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + { + // Remove the function call messages from the chat history to reduce token count. + context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. + + next(context); + } + + var kernel = new Kernel(); + kernel.ImportPluginFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => "rainy", "GetCurrentWeather")]); + kernel.AutoFunctionInvocationFilters.Add(new AutoFunctionInvocationFilter(MutateChatHistory)); + + using var firstResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_single_function_call_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn.Add(firstResponse); + + using var secondResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponsesToReturn.Add(secondResponse); + + var sut = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient); + + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What time is it?"), + new ChatMessageContent(AuthorRole.Assistant, [ + new FunctionCallContent("Date", "TimePlugin", "2") + ]), + new ChatMessageContent(AuthorRole.Tool, [ + new FunctionResultContent("Date", "TimePlugin", "2", "rainy") + ]), + new ChatMessageContent(AuthorRole.Assistant, "08/06/2024 00:00:00"), + new ChatMessageContent(AuthorRole.User, "Given the current time of day and weather, what is the likely color of the sky in Boston?") + }; + + // Act + await sut.GetChatMessageContentAsync(chatHistory, new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, kernel); + + // Assert + var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContents[1]!); + Assert.NotNull(actualRequestContent); + + var optionsJson = JsonSerializer.Deserialize(actualRequestContent); + + var messages = optionsJson.GetProperty("messages"); + Assert.Equal(5, messages.GetArrayLength()); + + var userFirstPrompt = messages[0]; + Assert.Equal("user", userFirstPrompt.GetProperty("role").GetString()); + Assert.Equal("What time is it?", userFirstPrompt.GetProperty("content").ToString()); + + var assistantFirstResponse = messages[1]; + Assert.Equal("assistant", assistantFirstResponse.GetProperty("role").GetString()); + Assert.Equal("08/06/2024 00:00:00", assistantFirstResponse.GetProperty("content").GetString()); + + var userSecondPrompt = messages[2]; + Assert.Equal("user", userSecondPrompt.GetProperty("role").GetString()); + Assert.Equal("Given the current time of day and weather, what is the likely color of the sky in Boston?", userSecondPrompt.GetProperty("content").ToString()); + + var assistantSecondResponse = messages[3]; + Assert.Equal("assistant", assistantSecondResponse.GetProperty("role").GetString()); + Assert.Equal("1", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("id").GetString()); + Assert.Equal("MyPlugin-GetCurrentWeather", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("function").GetProperty("name").GetString()); + + var functionResult = messages[4]; + Assert.Equal("tool", functionResult.GetProperty("role").GetString()); + Assert.Equal("rainy", functionResult.GetProperty("content").GetString()); + } + + [Fact] + public async Task GetStreamingChatMessageContentsShouldSendMutatedChatHistoryToLLM() + { + // Arrange + static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + { + // Remove the function call messages from the chat history to reduce token count. + context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. + + next(context); + } + + var kernel = new Kernel(); + kernel.ImportPluginFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => "rainy", "GetCurrentWeather")]); + kernel.AutoFunctionInvocationFilters.Add(new AutoFunctionInvocationFilter(MutateChatHistory)); + + using var firstResponse = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_single_function_call_test_response.txt")) }; + this._messageHandlerStub.ResponsesToReturn.Add(firstResponse); + + using var secondResponse = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) }; + this._messageHandlerStub.ResponsesToReturn.Add(secondResponse); + + var sut = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient); + + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What time is it?"), + new ChatMessageContent(AuthorRole.Assistant, [ + new FunctionCallContent("Date", "TimePlugin", "2") + ]), + new ChatMessageContent(AuthorRole.Tool, [ + new FunctionResultContent("Date", "TimePlugin", "2", "rainy") + ]), + new ChatMessageContent(AuthorRole.Assistant, "08/06/2024 00:00:00"), + new ChatMessageContent(AuthorRole.User, "Given the current time of day and weather, what is the likely color of the sky in Boston?") + }; + + // Act + await foreach (var update in sut.GetStreamingChatMessageContentsAsync(chatHistory, new AzureOpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, kernel)) + { + } + + // Assert + var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContents[1]!); + Assert.NotNull(actualRequestContent); + + var optionsJson = JsonSerializer.Deserialize(actualRequestContent); + + var messages = optionsJson.GetProperty("messages"); + Assert.Equal(5, messages.GetArrayLength()); + + var userFirstPrompt = messages[0]; + Assert.Equal("user", userFirstPrompt.GetProperty("role").GetString()); + Assert.Equal("What time is it?", userFirstPrompt.GetProperty("content").ToString()); + + var assistantFirstResponse = messages[1]; + Assert.Equal("assistant", assistantFirstResponse.GetProperty("role").GetString()); + Assert.Equal("08/06/2024 00:00:00", assistantFirstResponse.GetProperty("content").GetString()); + + var userSecondPrompt = messages[2]; + Assert.Equal("user", userSecondPrompt.GetProperty("role").GetString()); + Assert.Equal("Given the current time of day and weather, what is the likely color of the sky in Boston?", userSecondPrompt.GetProperty("content").ToString()); + + var assistantSecondResponse = messages[3]; + Assert.Equal("assistant", assistantSecondResponse.GetProperty("role").GetString()); + Assert.Equal("1", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("id").GetString()); + Assert.Equal("MyPlugin-GetCurrentWeather", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("function").GetProperty("name").GetString()); + + var functionResult = messages[4]; + Assert.Equal("tool", functionResult.GetProperty("role").GetString()); + Assert.Equal("rainy", functionResult.GetProperty("content").GetString()); + } + public void Dispose() { this._httpClient.Dispose(); @@ -917,4 +1061,24 @@ public void Dispose() { "json_object", "json_object" }, { "text", "text" } }; + + private class AutoFunctionInvocationFilter : IAutoFunctionInvocationFilter + { + private readonly Func, Task> _callback; + + public AutoFunctionInvocationFilter(Func, Task> callback) + { + this._callback = callback; + } + + public AutoFunctionInvocationFilter(Action> callback) + { + this._callback = (c, n) => { callback(c, n); return Task.CompletedTask; }; + } + + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + await this._callback(context, next); + } + } } diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs index 1a0145d137f2..b8855f590805 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs @@ -892,6 +892,150 @@ public async Task GetInvalidResponseThrowsExceptionAndIsCapturedByDiagnosticsAsy Assert.True(startedChatCompletionsActivity); } + [Fact] + public async Task GetChatMessageContentShouldSendMutatedChatHistoryToLLM() + { + // Arrange + static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + { + // Remove the function call messages from the chat history to reduce token count. + context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. + + next(context); + } + + var kernel = new Kernel(); + kernel.ImportPluginFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => "rainy", "GetCurrentWeather")]); + kernel.AutoFunctionInvocationFilters.Add(new AutoFunctionInvocationFilter(MutateChatHistory)); + + using var firstResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_single_function_call_test_response.json")) }; + this._messageHandlerStub.ResponseQueue.Enqueue(firstResponse); + + using var secondResponse = new HttpResponseMessage(System.Net.HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response.json")) }; + this._messageHandlerStub.ResponseQueue.Enqueue(secondResponse); + + var sut = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient); + + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What time is it?"), + new ChatMessageContent(AuthorRole.Assistant, [ + new FunctionCallContent("Date", "TimePlugin", "2") + ]), + new ChatMessageContent(AuthorRole.Tool, [ + new FunctionResultContent("Date", "TimePlugin", "2", "rainy") + ]), + new ChatMessageContent(AuthorRole.Assistant, "08/06/2024 00:00:00"), + new ChatMessageContent(AuthorRole.User, "Given the current time of day and weather, what is the likely color of the sky in Boston?") + }; + + // Act + await sut.GetChatMessageContentAsync(chatHistory, new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, kernel); + + // Assert + var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(actualRequestContent); + + var optionsJson = JsonSerializer.Deserialize(actualRequestContent); + + var messages = optionsJson.GetProperty("messages"); + Assert.Equal(5, messages.GetArrayLength()); + + var userFirstPrompt = messages[0]; + Assert.Equal("user", userFirstPrompt.GetProperty("role").GetString()); + Assert.Equal("What time is it?", userFirstPrompt.GetProperty("content").ToString()); + + var assistantFirstResponse = messages[1]; + Assert.Equal("assistant", assistantFirstResponse.GetProperty("role").GetString()); + Assert.Equal("08/06/2024 00:00:00", assistantFirstResponse.GetProperty("content").GetString()); + + var userSecondPrompt = messages[2]; + Assert.Equal("user", userSecondPrompt.GetProperty("role").GetString()); + Assert.Equal("Given the current time of day and weather, what is the likely color of the sky in Boston?", userSecondPrompt.GetProperty("content").ToString()); + + var assistantSecondResponse = messages[3]; + Assert.Equal("assistant", assistantSecondResponse.GetProperty("role").GetString()); + Assert.Equal("1", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("id").GetString()); + Assert.Equal("MyPlugin-GetCurrentWeather", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("function").GetProperty("name").GetString()); + + var functionResult = messages[4]; + Assert.Equal("tool", functionResult.GetProperty("role").GetString()); + Assert.Equal("rainy", functionResult.GetProperty("content").GetString()); + } + + [Fact] + public async Task GetStreamingChatMessageContentsShouldSendMutatedChatHistoryToLLM() + { + // Arrange + static void MutateChatHistory(AutoFunctionInvocationContext context, Func next) + { + // Remove the function call messages from the chat history to reduce token count. + context.ChatHistory.RemoveRange(1, 2); // Remove the `Date` function call and function result messages. + + next(context); + } + + var kernel = new Kernel(); + kernel.ImportPluginFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => "rainy", "GetCurrentWeather")]); + kernel.AutoFunctionInvocationFilters.Add(new AutoFunctionInvocationFilter(MutateChatHistory)); + + using var firstResponse = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_single_function_call_test_response.txt")) }; + this._messageHandlerStub.ResponseQueue.Enqueue(firstResponse); + + using var secondResponse = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StreamContent(File.OpenRead("TestData/chat_completion_streaming_test_response.txt")) }; + this._messageHandlerStub.ResponseQueue.Enqueue(secondResponse); + + var sut = new OpenAIChatCompletionService(modelId: "gpt-3.5-turbo", apiKey: "NOKEY", httpClient: this._httpClient); + + var chatHistory = new ChatHistory + { + new ChatMessageContent(AuthorRole.User, "What time is it?"), + new ChatMessageContent(AuthorRole.Assistant, [ + new FunctionCallContent("Date", "TimePlugin", "2") + ]), + new ChatMessageContent(AuthorRole.Tool, [ + new FunctionResultContent("Date", "TimePlugin", "2", "rainy") + ]), + new ChatMessageContent(AuthorRole.Assistant, "08/06/2024 00:00:00"), + new ChatMessageContent(AuthorRole.User, "Given the current time of day and weather, what is the likely color of the sky in Boston?") + }; + + // Act + await foreach (var update in sut.GetStreamingChatMessageContentsAsync(chatHistory, new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, kernel)) + { + } + + // Assert + var actualRequestContent = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent!); + Assert.NotNull(actualRequestContent); + + var optionsJson = JsonSerializer.Deserialize(actualRequestContent); + + var messages = optionsJson.GetProperty("messages"); + Assert.Equal(5, messages.GetArrayLength()); + + var userFirstPrompt = messages[0]; + Assert.Equal("user", userFirstPrompt.GetProperty("role").GetString()); + Assert.Equal("What time is it?", userFirstPrompt.GetProperty("content").ToString()); + + var assistantFirstResponse = messages[1]; + Assert.Equal("assistant", assistantFirstResponse.GetProperty("role").GetString()); + Assert.Equal("08/06/2024 00:00:00", assistantFirstResponse.GetProperty("content").GetString()); + + var userSecondPrompt = messages[2]; + Assert.Equal("user", userSecondPrompt.GetProperty("role").GetString()); + Assert.Equal("Given the current time of day and weather, what is the likely color of the sky in Boston?", userSecondPrompt.GetProperty("content").ToString()); + + var assistantSecondResponse = messages[3]; + Assert.Equal("assistant", assistantSecondResponse.GetProperty("role").GetString()); + Assert.Equal("1", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("id").GetString()); + Assert.Equal("MyPlugin-GetCurrentWeather", assistantSecondResponse.GetProperty("tool_calls")[0].GetProperty("function").GetProperty("name").GetString()); + + var functionResult = messages[4]; + Assert.Equal("tool", functionResult.GetProperty("role").GetString()); + Assert.Equal("rainy", functionResult.GetProperty("content").GetString()); + } + public void Dispose() { this._httpClient.Dispose(); @@ -899,6 +1043,28 @@ public void Dispose() this._multiMessageHandlerStub.Dispose(); } + private class AutoFunctionInvocationFilter : IAutoFunctionInvocationFilter + { + private readonly Func, Task> _callback; + + public AutoFunctionInvocationFilter(Func, Task> callback) + { + Verify.NotNull(callback, nameof(callback)); + this._callback = callback; + } + + public AutoFunctionInvocationFilter(Action> callback) + { + Verify.NotNull(callback, nameof(callback)); + this._callback = (c, n) => { callback(c, n); return Task.CompletedTask; }; + } + + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + await this._callback(context, next); + } + } + private const string ChatCompletionResponse = """ { "id": "chatcmpl-8IlRBQU929ym1EqAY2J4T7GGkW5Om", @@ -911,12 +1077,17 @@ public void Dispose() "message": { "role": "assistant", "content": null, - "function_call": { - "name": "TimePlugin_Date", - "arguments": "{}" - } + "tool_calls":[{ + "id": "1", + "type": "function", + "function": { + "name": "TimePlugin-Date", + "arguments": "{}" + } + } + ] }, - "finish_reason": "stop" + "finish_reason": "tool_calls" } ], "usage": { diff --git a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs index 5ad712255af5..1177fb7ec846 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.ChatCompletion.cs @@ -143,10 +143,10 @@ internal async Task> GetChatMessageContentsAsy ValidateMaxTokens(chatExecutionSettings.MaxTokens); - var chatForRequest = CreateChatCompletionMessages(chatExecutionSettings, chat); - for (int requestIndex = 0; ; requestIndex++) { + var chatForRequest = CreateChatCompletionMessages(chatExecutionSettings, chat); + var toolCallingConfig = this.GetToolCallingConfiguration(kernel, chatExecutionSettings, requestIndex); var chatOptions = this.CreateChatCompletionOptions(chatExecutionSettings, chat, toolCallingConfig, kernel); @@ -207,11 +207,8 @@ internal async Task> GetChatMessageContentsAsy this.Logger.LogTrace("Function call requests: {Requests}", string.Join(", ", chatCompletion.ToolCalls.OfType().Select(ftc => $"{ftc.FunctionName}({ftc.FunctionArguments})"))); } - // Add the original assistant message to the chat messages; this is required for the service - // to understand the tool call responses. Also add the result message to the caller's chat - // history: if they don't want it, they can remove it, but this makes the data available, - // including metadata like usage. - chatForRequest.Add(CreateRequestMessage(chatCompletion)); + // Add the result message to the caller's chat history; + // this is required for the service to understand the tool call responses. chat.Add(chatMessageContent); // We must send back a response for every tool call, regardless of whether we successfully executed it or not. @@ -223,7 +220,7 @@ internal async Task> GetChatMessageContentsAsy // We currently only know about function tool calls. If it's anything else, we'll respond with an error. if (functionToolCall.Kind != ChatToolCallKind.Function) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Tool call was not a function call.", functionToolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Tool call was not a function call.", functionToolCall, this.Logger); continue; } @@ -235,7 +232,7 @@ internal async Task> GetChatMessageContentsAsy } catch (JsonException) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Function call arguments were invalid JSON.", functionToolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Function call arguments were invalid JSON.", functionToolCall, this.Logger); continue; } @@ -245,14 +242,14 @@ internal async Task> GetChatMessageContentsAsy if (chatExecutionSettings.ToolCallBehavior?.AllowAnyRequestedKernelFunction is not true && !IsRequestableTool(chatOptions, openAIFunctionToolCall)) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Function call request for a function that wasn't defined.", functionToolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Function call request for a function that wasn't defined.", functionToolCall, this.Logger); continue; } // Find the function in the kernel and populate the arguments. if (!kernel!.Plugins.TryGetFunctionAndArguments(openAIFunctionToolCall, out KernelFunction? function, out KernelArguments? functionArgs)) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Requested function could not be found.", functionToolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Requested function could not be found.", functionToolCall, this.Logger); continue; } @@ -287,7 +284,7 @@ internal async Task> GetChatMessageContentsAsy catch (Exception e) #pragma warning restore CA1031 // Do not catch general exception types { - AddResponseMessage(chatForRequest, chat, null, $"Error: Exception while invoking function. {e.Message}", functionToolCall, this.Logger); + AddResponseMessage(chat, null, $"Error: Exception while invoking function. {e.Message}", functionToolCall, this.Logger); continue; } finally @@ -301,7 +298,7 @@ internal async Task> GetChatMessageContentsAsy object functionResultValue = functionResult.GetValue() ?? string.Empty; var stringResult = ProcessFunctionResult(functionResultValue, chatExecutionSettings.ToolCallBehavior); - AddResponseMessage(chatForRequest, chat, stringResult, errorMessage: null, functionToolCall, this.Logger); + AddResponseMessage(chat, stringResult, errorMessage: null, functionToolCall, this.Logger); // If filter requested termination, returning latest function result. if (invocationContext.Terminate) @@ -342,10 +339,10 @@ internal async IAsyncEnumerable GetStreamingC Dictionary? functionNamesByIndex = null; Dictionary? functionArgumentBuildersByIndex = null; - var chatForRequest = CreateChatCompletionMessages(chatExecutionSettings, chat); - for (int requestIndex = 0; ; requestIndex++) { + var chatForRequest = CreateChatCompletionMessages(chatExecutionSettings, chat); + var toolCallingConfig = this.GetToolCallingConfiguration(kernel, chatExecutionSettings, requestIndex); var chatOptions = this.CreateChatCompletionOptions(chatExecutionSettings, chat, toolCallingConfig, kernel); @@ -478,9 +475,7 @@ internal async IAsyncEnumerable GetStreamingC this.Logger.LogDebug("Function call requests: {Requests}", toolCalls.Length); } - // Add the original assistant message to the chat messages; this is required for the service - // to understand the tool call responses. - chatForRequest.Add(CreateRequestMessage(streamedRole ?? default, content, streamedName, toolCalls)); + // Add the result message to the caller's chat history; this is required for the service to understand the tool call responses. var chatMessageContent = this.CreateChatMessageContent(streamedRole ?? default, content, toolCalls, functionCallContents, metadata, streamedName); chat.Add(chatMessageContent); @@ -492,7 +487,7 @@ internal async IAsyncEnumerable GetStreamingC // We currently only know about function tool calls. If it's anything else, we'll respond with an error. if (string.IsNullOrEmpty(toolCall.FunctionName)) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Tool call was not a function call.", toolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Tool call was not a function call.", toolCall, this.Logger); continue; } @@ -504,7 +499,7 @@ internal async IAsyncEnumerable GetStreamingC } catch (JsonException) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Function call arguments were invalid JSON.", toolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Function call arguments were invalid JSON.", toolCall, this.Logger); continue; } @@ -514,14 +509,14 @@ internal async IAsyncEnumerable GetStreamingC if (chatExecutionSettings.ToolCallBehavior?.AllowAnyRequestedKernelFunction is not true && !IsRequestableTool(chatOptions, openAIFunctionToolCall)) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Function call request for a function that wasn't defined.", toolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Function call request for a function that wasn't defined.", toolCall, this.Logger); continue; } // Find the function in the kernel and populate the arguments. if (!kernel!.Plugins.TryGetFunctionAndArguments(openAIFunctionToolCall, out KernelFunction? function, out KernelArguments? functionArgs)) { - AddResponseMessage(chatForRequest, chat, result: null, "Error: Requested function could not be found.", toolCall, this.Logger); + AddResponseMessage(chat, result: null, "Error: Requested function could not be found.", toolCall, this.Logger); continue; } @@ -556,7 +551,7 @@ internal async IAsyncEnumerable GetStreamingC catch (Exception e) #pragma warning restore CA1031 // Do not catch general exception types { - AddResponseMessage(chatForRequest, chat, result: null, $"Error: Exception while invoking function. {e.Message}", toolCall, this.Logger); + AddResponseMessage(chat, result: null, $"Error: Exception while invoking function. {e.Message}", toolCall, this.Logger); continue; } finally @@ -570,7 +565,7 @@ internal async IAsyncEnumerable GetStreamingC object functionResultValue = functionResult.GetValue() ?? string.Empty; var stringResult = ProcessFunctionResult(functionResultValue, chatExecutionSettings.ToolCallBehavior); - AddResponseMessage(chatForRequest, chat, stringResult, errorMessage: null, toolCall, this.Logger); + AddResponseMessage(chat, stringResult, errorMessage: null, toolCall, this.Logger); // If filter requested termination, returning latest function result and breaking request iteration loop. if (invocationContext.Terminate) @@ -785,26 +780,6 @@ private static List CreateChatCompletionMessages(OpenAIPromptExecut return messages; } - private static ChatMessage CreateRequestMessage(ChatMessageRole chatRole, string content, string? name, ChatToolCall[]? tools) - { - if (chatRole == ChatMessageRole.User) - { - return new UserChatMessage(content) { ParticipantName = name }; - } - - if (chatRole == ChatMessageRole.System) - { - return new SystemChatMessage(content) { ParticipantName = name }; - } - - if (chatRole == ChatMessageRole.Assistant) - { - return new AssistantChatMessage(tools, content) { ParticipantName = name }; - } - - throw new NotImplementedException($"Role {chatRole} is not implemented"); - } - private static List CreateRequestMessages(ChatMessageContent message, ToolCallBehavior? toolCallBehavior) { if (message.Role == AuthorRole.System) @@ -955,26 +930,6 @@ private static ChatMessageContentPart GetImageContentItem(ImageContent imageCont throw new ArgumentException($"{nameof(ImageContent)} must have either Data or a Uri."); } - private static ChatMessage CreateRequestMessage(OpenAIChatCompletion completion) - { - if (completion.Role == ChatMessageRole.System) - { - return ChatMessage.CreateSystemMessage(completion.Content[0].Text); - } - - if (completion.Role == ChatMessageRole.Assistant) - { - return ChatMessage.CreateAssistantMessage(completion); - } - - if (completion.Role == ChatMessageRole.User) - { - return ChatMessage.CreateUserMessage(completion.Content); - } - - throw new NotSupportedException($"Role {completion.Role} is not supported."); - } - private OpenAIChatMessageContent CreateChatMessageContent(OpenAIChatCompletion completion, string targetModel) { var message = new OpenAIChatMessageContent(completion, targetModel, this.GetChatCompletionMetadata(completion)); @@ -1053,7 +1008,7 @@ private List GetFunctionCallContents(IEnumerable chatMessages, ChatHistory chat, string? result, string? errorMessage, ChatToolCall toolCall, ILogger logger) + private static void AddResponseMessage(ChatHistory chat, string? result, string? errorMessage, ChatToolCall toolCall, ILogger logger) { // Log any error if (errorMessage is not null && logger.IsEnabled(LogLevel.Debug)) @@ -1062,9 +1017,7 @@ private static void AddResponseMessage(List chatMessages, ChatHisto logger.LogDebug("Failed to handle tool request ({ToolId}). {Error}", toolCall.Id, errorMessage); } - // Add the tool response message to the chat messages result ??= errorMessage ?? string.Empty; - chatMessages.Add(new ToolChatMessage(toolCall.Id, result)); // Add the tool response message to the chat history. var message = new ChatMessageContent(role: AuthorRole.Tool, content: result, metadata: new Dictionary { { OpenAIChatMessageContent.ToolIdProperty, toolCall.Id } }); From b2140fcde7d8b60c7cdb43f3696c401e4e5c6510 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh Date: Tue, 6 Aug 2024 21:01:43 +0100 Subject: [PATCH 2/3] feat(function-calling): seal and remove redundant test helper classes. --- .../MultipleHttpMessageHandlerStub.cs | 53 ------------------- .../AzureOpenAIChatCompletionServiceTests.cs | 2 +- .../OpenAIChatCompletionServiceTests.cs | 2 +- 3 files changed, 2 insertions(+), 55 deletions(-) delete mode 100644 dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/MultipleHttpMessageHandlerStub.cs diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/MultipleHttpMessageHandlerStub.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/MultipleHttpMessageHandlerStub.cs deleted file mode 100644 index 0af66de6a519..000000000000 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/MultipleHttpMessageHandlerStub.cs +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Threading; -using System.Threading.Tasks; - -namespace SemanticKernel.Connectors.AzureOpenAI; - -internal sealed class MultipleHttpMessageHandlerStub : DelegatingHandler -{ - private int _callIteration = 0; - - public List RequestHeaders { get; private set; } - - public List ContentHeaders { get; private set; } - - public List RequestContents { get; private set; } - - public List RequestUris { get; private set; } - - public List Methods { get; private set; } - - public List ResponsesToReturn { get; set; } - - public MultipleHttpMessageHandlerStub() - { - this.RequestHeaders = []; - this.ContentHeaders = []; - this.RequestContents = []; - this.RequestUris = []; - this.Methods = []; - this.ResponsesToReturn = []; - } - - protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) - { - this._callIteration++; - - this.Methods.Add(request.Method); - this.RequestUris.Add(request.RequestUri); - this.RequestHeaders.Add(request.Headers); - this.ContentHeaders.Add(request.Content?.Headers); - - var content = request.Content is null ? null : await request.Content.ReadAsByteArrayAsync(cancellationToken); - - this.RequestContents.Add(content); - - return await Task.FromResult(this.ResponsesToReturn[this._callIteration - 1]); - } -} diff --git a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs index ed9314e2acfb..435caa3c425a 100644 --- a/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.AzureOpenAI.UnitTests/Services/AzureOpenAIChatCompletionServiceTests.cs @@ -1062,7 +1062,7 @@ public void Dispose() { "text", "text" } }; - private class AutoFunctionInvocationFilter : IAutoFunctionInvocationFilter + private sealed class AutoFunctionInvocationFilter : IAutoFunctionInvocationFilter { private readonly Func, Task> _callback; diff --git a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs index b8855f590805..ccda12afe6a6 100644 --- a/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs +++ b/dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Services/OpenAIChatCompletionServiceTests.cs @@ -1043,7 +1043,7 @@ public void Dispose() this._multiMessageHandlerStub.Dispose(); } - private class AutoFunctionInvocationFilter : IAutoFunctionInvocationFilter + private sealed class AutoFunctionInvocationFilter : IAutoFunctionInvocationFilter { private readonly Func, Task> _callback; From 4829caa928ae895dc2c70e6365acddc7db76d4e3 Mon Sep 17 00:00:00 2001 From: SergeyMenshykh Date: Wed, 7 Aug 2024 12:42:52 +0100 Subject: [PATCH 3/3] feat(function-calling): fix the "warning IDE0005: Using directive is unnecessary." issue --- .../StepwisePlannerMigration/Services/PlanProvider.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dotnet/samples/Demos/StepwisePlannerMigration/Services/PlanProvider.cs b/dotnet/samples/Demos/StepwisePlannerMigration/Services/PlanProvider.cs index a61251f9eb49..ed5bd4f03fe1 100644 --- a/dotnet/samples/Demos/StepwisePlannerMigration/Services/PlanProvider.cs +++ b/dotnet/samples/Demos/StepwisePlannerMigration/Services/PlanProvider.cs @@ -2,10 +2,13 @@ using System.IO; using System.Text.Json; -using Microsoft.SemanticKernel.ChatCompletion; #pragma warning disable IDE0005 // Using directive is unnecessary +using Microsoft.SemanticKernel.ChatCompletion; + +#pragma warning restore IDE0005 // Using directive is unnecessary + namespace StepwisePlannerMigration.Services; /// @@ -19,6 +22,3 @@ public ChatHistory GetPlan(string fileName) return JsonSerializer.Deserialize(plan)!; } } - -#pragma warning restore IDE0005 // Using directive is unnecessary -