.Net Agents - Streaming Bug Fix and Support Additional Assistant Option #8852

Merged 12 commits on Sep 17, 2024
148 changes: 131 additions & 17 deletions dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs
@@ -38,14 +38,8 @@ public async Task UseAutoFunctionInvocationFilterWithAgentInvocationAsync()
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

-// Display the chat history.
-Console.WriteLine("================================");
-Console.WriteLine("CHAT HISTORY");
-Console.WriteLine("================================");
-foreach (ChatMessageContent message in chat)
-{
-    this.WriteAgentChatMessage(message);
-}
+// Display the entire chat history.
+WriteChatHistory(chat);

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
@@ -91,15 +85,8 @@ public async Task UseAutoFunctionInvocationFilterWithAgentChatAsync()
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

-// Display the chat history.
-Console.WriteLine("================================");
-Console.WriteLine("CHAT HISTORY");
-Console.WriteLine("================================");
-ChatMessageContent[] history = await chat.GetChatMessagesAsync().ToArrayAsync();
-for (int index = history.Length; index > 0; --index)
-{
-    this.WriteAgentChatMessage(history[index - 1]);
-}
+// Display the entire chat history.
+WriteChatHistory(await chat.GetChatMessagesAsync().ToArrayAsync());

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
@@ -115,6 +102,133 @@ async Task InvokeAgentAsync(string input)
}
}

[Fact]
public async Task UseAutoFunctionInvocationFilterWithStreamingAgentInvocationAsync()
{
// Define the agent
ChatCompletionAgent agent =
new()
{
Instructions = "Answer questions about the menu.",
Kernel = CreateKernelWithFilter(),
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }),
};

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
agent.Kernel.Plugins.Add(plugin);

// Create the chat history to capture the agent interaction.
ChatHistory chat = [];

// Respond to user input, invoking functions where appropriate.
await InvokeAgentAsync("Hello");
await InvokeAgentAsync("What is the special soup?");
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

// Display the entire chat history.
WriteChatHistory(chat);

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
ChatMessageContent message = new(AuthorRole.User, input);
chat.Add(message);
this.WriteAgentChatMessage(message);

int historyCount = chat.Count;

bool isFirst = false;
await foreach (StreamingChatMessageContent response in agent.InvokeStreamingAsync(chat))
{
if (string.IsNullOrEmpty(response.Content))
{
continue;
}

if (!isFirst)
{
Console.WriteLine($"\n# {response.Role} - {response.AuthorName ?? "*"}:");
isFirst = true;
}

Console.WriteLine($"\t > streamed: '{response.Content}'");
}

if (historyCount <= chat.Count)
{
for (int index = historyCount; index < chat.Count; index++)
{
this.WriteAgentChatMessage(chat[index]);
}
}
}
}

[Fact]
public async Task UseAutoFunctionInvocationFilterWithStreamingAgentChatAsync()
{
// Define the agent
ChatCompletionAgent agent =
new()
{
Instructions = "Answer questions about the menu.",
Kernel = CreateKernelWithFilter(),
Arguments = new KernelArguments(new OpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() }),
};

KernelPlugin plugin = KernelPluginFactory.CreateFromType<MenuPlugin>();
agent.Kernel.Plugins.Add(plugin);

// Create a chat for agent interaction.
AgentGroupChat chat = new();

// Respond to user input, invoking functions where appropriate.
await InvokeAgentAsync("Hello");
await InvokeAgentAsync("What is the special soup?");
await InvokeAgentAsync("What is the special drink?");
await InvokeAgentAsync("Thank you");

// Display the entire chat history.
WriteChatHistory(await chat.GetChatMessagesAsync().ToArrayAsync());

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
ChatMessageContent message = new(AuthorRole.User, input);
chat.AddChatMessage(message);
this.WriteAgentChatMessage(message);

bool isFirst = false;
await foreach (StreamingChatMessageContent response in chat.InvokeStreamingAsync(agent))
{
if (string.IsNullOrEmpty(response.Content))
{
continue;
}

if (!isFirst)
{
Console.WriteLine($"\n# {response.Role} - {response.AuthorName ?? "*"}:");
isFirst = true;
}

Console.WriteLine($"\t > streamed: '{response.Content}'");
}
}
}

private void WriteChatHistory(IEnumerable<ChatMessageContent> chat)
{
Console.WriteLine("================================");
Console.WriteLine("CHAT HISTORY");
Console.WriteLine("================================");
foreach (ChatMessageContent message in chat)
{
this.WriteAgentChatMessage(message);
}
}

private Kernel CreateKernelWithFilter()
{
IKernelBuilder builder = Kernel.CreateBuilder();
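CreateKernelWithFilter() is elided in this excerpt, so the filter that drives these "function termination" scenarios is not shown. As a rough sketch, assuming Semantic Kernel's IAutoFunctionInvocationFilter interface, the registered filter would look something like this (the class name and termination policy are illustrative, not part of this PR):

private sealed class AutoInvocationFilter : IAutoFunctionInvocationFilter
{
    public async Task OnAutoFunctionInvocationAsync(
        AutoFunctionInvocationContext context,
        Func<AutoFunctionInvocationContext, Task> next)
    {
        // Run the function (and any downstream filters) first.
        await next(context);

        // Terminate the auto-invocation loop so the function result
        // becomes the final message of the turn.
        context.Terminate = true;
    }
}

Such a filter is typically registered inside CreateKernelWithFilter() via builder.Services.AddSingleton<IAutoFunctionInvocationFilter>(new AutoInvocationFilter()).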
10 changes: 7 additions & 3 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
@@ -94,7 +94,7 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
StringBuilder builder = new();
await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false))
{
-role ??= message.Role;
+role = message.Role;
message.Role ??= AuthorRole.Assistant;
message.AuthorName = this.Name;

@@ -103,8 +103,6 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream
yield return message;
}

-chat.Add(new(role ?? AuthorRole.Assistant, builder.ToString()) { AuthorName = this.Name });

// Capture mutated messages related to function calling / tools
for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++)
{
@@ -114,6 +112,12 @@ public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStream

history.Add(message);
}

+// Do not duplicate the terminated function result in the history.
+if (role != AuthorRole.Tool)
+{
+    history.Add(new(role ?? AuthorRole.Assistant, builder.ToString()) { AuthorName = this.Name });
+}
}

internal static (IChatCompletionService service, PromptExecutionSettings? executionSettings) GetChatCompletionService(Kernel kernel, KernelArguments? arguments)
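Two behavioral changes land in InvokeStreamingAsync here. First, role now tracks the last streamed message instead of freezing on the first one (role = message.Role rather than ??=), so a turn that ends on a terminated function call is recognized by its Tool role. Second, the aggregated streamed text is appended to the caller's history only after the mutated function-calling messages are copied over, and is skipped entirely when the final role is Tool, which previously duplicated the terminated function result. A compact sketch of that bookkeeping, with names assumed for illustration rather than taken from the private implementation:

static void CaptureStreamedTurn(
    ChatHistory history,
    ChatHistory mutatedChat,
    int messageCount,
    AuthorRole? finalRole,
    string aggregatedContent,
    string? agentName)
{
    // Copy messages the function-calling loop appended (function calls and results).
    for (int index = messageCount; index < mutatedChat.Count; index++)
    {
        mutatedChat[index].AuthorName = agentName;
        history.Add(mutatedChat[index]);
    }

    // Append the aggregated assistant message only when the stream did not end on
    // a terminated function result; otherwise it would duplicate the tool output.
    if (finalRole != AuthorRole.Tool)
    {
        history.Add(new(finalRole ?? AuthorRole.Assistant, aggregatedContent) { AuthorName = agentName });
    }
}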
@@ -24,6 +24,7 @@ public static RunCreationOptions GenerateOptions(OpenAIAssistantDefinition defin
RunCreationOptions options =
new()
{
AdditionalInstructions = invocationOptions?.AdditionalInstructions ?? definition.ExecutionOptions?.AdditionalInstructions,
MaxCompletionTokens = ResolveExecutionSetting(invocationOptions?.MaxCompletionTokens, definition.ExecutionOptions?.MaxCompletionTokens),
MaxPromptTokens = ResolveExecutionSetting(invocationOptions?.MaxPromptTokens, definition.ExecutionOptions?.MaxPromptTokens),
ModelOverride = invocationOptions?.ModelName,
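The null-coalescing line above gives per-invocation options precedence over the definition-level execution options. A sketch of the observable behavior, mirroring the unit tests later in this diff (AssistantRunOptionsFactory is internal, so this mirrors test usage rather than a public API):

OpenAIAssistantDefinition definition =
    new("gpt-anything")
    {
        ExecutionOptions = new() { AdditionalInstructions = "Respond in English." },
    };

OpenAIAssistantInvocationOptions invocationOptions =
    new()
    {
        AdditionalInstructions = "Address the user as Jane.",
    };

RunCreationOptions options = AssistantRunOptionsFactory.GenerateOptions(definition, invocationOptions);
// options.AdditionalInstructions == "Address the user as Jane."
// With null invocation options it falls back to "Respond in English."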
@@ -11,7 +11,6 @@
using Azure;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;

using OpenAI;
using OpenAI.Assistants;

6 changes: 6 additions & 0 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantExecutionOptions.cs
@@ -11,6 +11,12 @@ namespace Microsoft.SemanticKernel.Agents.OpenAI;
/// </remarks>
public sealed class OpenAIAssistantExecutionOptions
{
/// <summary>
/// Appends additional instructions.
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AdditionalInstructions { get; init; }

/// <summary>
/// The maximum number of completion tokens that may be used over the course of the run.
/// </summary>
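Like the other optional settings on this type, the new property is dropped from JSON when unset, so persisted assistant definitions do not pick up a null field. A small sketch with System.Text.Json defaults:

using System.Text.Json;

var withValue = new OpenAIAssistantExecutionOptions { AdditionalInstructions = "test" };
var withoutValue = new OpenAIAssistantExecutionOptions();

// Includes "AdditionalInstructions":"test".
Console.WriteLine(JsonSerializer.Serialize(withValue));

// Omits the property entirely, per JsonIgnoreCondition.WhenWritingNull.
Console.WriteLine(JsonSerializer.Serialize(withoutValue));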
@@ -18,6 +18,12 @@ public sealed class OpenAIAssistantInvocationOptions
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? ModelName { get; init; }

/// <summary>
/// Appends additional instructions.
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? AdditionalInstructions { get; init; }

/// <summary>
/// Set if code_interpreter tool is enabled.
/// </summary>
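Since OpenAIAssistantInvocationOptions travels with a single invocation, the appended instructions can differ from turn to turn without mutating the assistant definition. A hedged sketch of per-call usage; the exact InvokeAsync overload shape is assumed here, not shown in this diff:

OpenAIAssistantInvocationOptions options =
    new()
    {
        AdditionalInstructions = "Address the user as Jane.",
    };

// Assumed overload: thread id plus per-run options.
await foreach (ChatMessageContent message in agent.InvokeAsync(threadId, options))
{
    Console.WriteLine(message.Content);
}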
@@ -23,6 +23,11 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest()
new("gpt-anything")
{
Temperature = 0.5F,
ExecutionOptions =
new()
{
AdditionalInstructions = "test",
},
};

// Act
@@ -32,6 +37,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsNullTest()
Assert.NotNull(options);
Assert.Null(options.Temperature);
Assert.Null(options.NucleusSamplingFactor);
Assert.Equal("test", options.AdditionalInstructions);
Assert.Empty(options.Metadata);
}

@@ -77,13 +83,15 @@ public void AssistantRunOptionsFactoryExecutionOptionsOverrideTest()
ExecutionOptions =
new()
{
AdditionalInstructions = "test1",
TruncationMessageCount = 5,
},
};

OpenAIAssistantInvocationOptions invocationOptions =
new()
{
AdditionalInstructions = "test2",
Temperature = 0.9F,
TruncationMessageCount = 8,
EnableJsonResponse = true,
@@ -96,6 +104,7 @@ public void AssistantRunOptionsFactoryExecutionOptionsOverrideTest()
Assert.NotNull(options);
Assert.Equal(0.9F, options.Temperature);
Assert.Equal(8, options.TruncationStrategy.LastMessages);
Assert.Equal("test2", options.AdditionalInstructions);
Assert.Equal(AssistantResponseFormat.JsonObject, options.ResponseFormat);
Assert.Null(options.NucleusSamplingFactor);
}
@@ -62,6 +62,7 @@ public void VerifyOpenAIAssistantDefinitionAssignment()
ExecutionOptions =
new()
{
AdditionalInstructions = "test instructions",
MaxCompletionTokens = 1000,
MaxPromptTokens = 1000,
ParallelToolCallsEnabled = false,
@@ -83,6 +84,7 @@ public void VerifyOpenAIAssistantDefinitionAssignment()
Assert.Equal(2, definition.Temperature);
Assert.Equal(0, definition.TopP);
Assert.NotNull(definition.ExecutionOptions);
Assert.Equal("test instructions", definition.ExecutionOptions.AdditionalInstructions);
Assert.Equal(1000, definition.ExecutionOptions.MaxCompletionTokens);
Assert.Equal(1000, definition.ExecutionOptions.MaxPromptTokens);
Assert.Equal(12, definition.ExecutionOptions.TruncationMessageCount);
@@ -22,6 +22,7 @@ public void OpenAIAssistantInvocationOptionsInitialState()

// Assert
Assert.Null(options.ModelName);
Assert.Null(options.AdditionalInstructions);
Assert.Null(options.Metadata);
Assert.Null(options.Temperature);
Assert.Null(options.TopP);
@@ -48,6 +49,7 @@ public void OpenAIAssistantInvocationOptionsAssignment()
new()
{
ModelName = "testmodel",
AdditionalInstructions = "test instructions",
Metadata = new Dictionary<string, string>() { { "a", "1" } },
MaxCompletionTokens = 1000,
MaxPromptTokens = 1000,
@@ -62,6 +64,7 @@

// Assert
Assert.Equal("testmodel", options.ModelName);
Assert.Equal("test instructions", options.AdditionalInstructions);
Assert.Equal(2, options.Temperature);
Assert.Equal(0, options.TopP);
Assert.Equal(1000, options.MaxCompletionTokens);
@@ -12,6 +12,7 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using xRetry;
using Xunit;

namespace SemanticKernel.IntegrationTests.Agents;
@@ -32,7 +33,7 @@ public sealed class ChatCompletionAgentTests()
/// Integration test for <see cref="ChatCompletionAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// </summary>
-[Theory]
+[RetryTheory(typeof(HttpOperationException))]
[InlineData("What is the special soup?", "Clam Chowder", false)]
[InlineData("What is the special soup?", "Clam Chowder", true)]
public async Task AzureChatCompletionAgentAsync(string input, string expectedAnswerContains, bool useAutoFunctionTermination)
@@ -96,7 +97,7 @@ public async Task AzureChatCompletionAgentAsync(string input, string expectedAns
/// Integration test for <see cref="ChatCompletionAgent"/> using new function calling model
/// and targeting Azure OpenAI services.
/// </summary>
-[Theory]
+[RetryTheory(typeof(HttpOperationException))]
[InlineData("What is the special soup?", "Clam Chowder", false)]
[InlineData("What is the special soup?", "Clam Chowder", true)]
public async Task AzureChatCompletionAgentUsingNewFunctionCallingModelAsync(string input, string expectedAnswerContains, bool useAutoFunctionTermination)
@@ -160,7 +161,7 @@ public async Task AzureChatCompletionAgentUsingNewFunctionCallingModelAsync(stri
/// Integration test for <see cref="ChatCompletionAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// </summary>
-[Fact]
+[RetryFact(typeof(HttpOperationException))]
public async Task AzureChatCompletionStreamingAsync()
{
// Arrange
@@ -206,7 +207,7 @@ public async Task AzureChatCompletionStreamingAsync()
/// Integration test for <see cref="ChatCompletionAgent"/> using new function calling model
/// and targeting Azure OpenAI services.
/// </summary>
-[Fact]
+[RetryFact(typeof(HttpOperationException))]
public async Task AzureChatCompletionStreamingUsingNewFunctionCallingModelAsync()
{
// Arrange
3 changes: 2 additions & 1 deletion dotnet/src/IntegrationTests/Agents/MixedAgentTests.cs
@@ -11,6 +11,7 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using xRetry;
using Xunit;

namespace SemanticKernel.IntegrationTests.Agents;
@@ -50,7 +51,7 @@ await this.VerifyAgentExecutionAsync(
/// Integration test for <see cref="OpenAIAssistantAgent"/> using function calling
/// and targeting Azure OpenAI services.
/// </summary>
-[Theory]
+[RetryTheory(typeof(HttpOperationException))]
[InlineData(false)]
[InlineData(true)]
public async Task AzureOpenAIMixedAgentAsync(bool useNewFunctionCallingModel)
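Throughout the integration tests, [Fact]/[Theory] become xRetry's [RetryFact]/[RetryTheory] so that a transient HttpOperationException from the live Azure OpenAI endpoint re-runs the test instead of failing the build. A minimal sketch of the pattern; retry count and delay are left at xRetry's defaults, and the test body is a placeholder:

using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using xRetry;

public sealed class LiveEndpointTests
{
    // Re-runs the test when an HttpOperationException is thrown, per xRetry defaults.
    [RetryFact(typeof(HttpOperationException))]
    public async Task CallsLiveServiceAsync()
    {
        // Arrange/act against the live service here; transient HTTP
        // failures trigger a retry rather than an immediate failure.
        await Task.CompletedTask;
    }
}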