Skip to content

Commit

Permalink
Merge branch 'main' into feature-connectors-openai
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerBarreto committed Jul 23, 2024
2 parents ecd3fee + f8878be commit 89773be
Show file tree
Hide file tree
Showing 17 changed files with 648 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/dotnet-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ jobs:

# Generate test reports and check coverage
- name: Generate test reports
uses: danielpalme/ReportGenerator-GitHub-Action@5.2.4
uses: danielpalme/ReportGenerator-GitHub-Action@5.3.8
with:
reports: "./TestResults/Coverage/**/coverage.cobertura.xml"
targetdir: "./TestResults/Reports"
Expand Down
12 changes: 8 additions & 4 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@
<PackageVersion Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="8.0.1" />
<PackageVersion Include="Microsoft.Extensions.Hosting" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Http" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Http.Resilience" Version="8.6.0" />
<PackageVersion Include="Microsoft.Extensions.Http.Resilience" Version="8.7.0" />
<PackageVersion Include="Microsoft.Extensions.Logging" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.1" />
<PackageVersion Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Options.DataAnnotations" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.5.0" />
<PackageVersion Include="Microsoft.Extensions.TimeProvider.Testing" Version="8.7.0" />
<!-- Test -->
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageVersion Include="Moq" Version="[4.18.4]" />
Expand Down Expand Up @@ -107,12 +107,12 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageVersion Include="xunit.analyzers" Version="1.14.0" />
<PackageVersion Include="xunit.analyzers" Version="1.15.0" />
<PackageReference Include="xunit.analyzers">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageVersion Include="Moq.Analyzers" Version="0.1.0" />
<PackageVersion Include="Moq.Analyzers" Version="0.1.1" />
<PackageReference Include="Moq.Analyzers">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand All @@ -132,5 +132,9 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<!-- OnnxRuntimeGenAI -->
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.3.0"/>
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.3.0"/>
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.3.0"/>
</ItemGroup>
</Project>
2 changes: 1 addition & 1 deletion dotnet/nuget/nuget-package.props
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project>
<PropertyGroup>
<!-- Central version prefix - applies to all nuget packages. -->
<VersionPrefix>1.16.0</VersionPrefix>
<VersionPrefix>1.16.1</VersionPrefix>

<PackageVersion Condition="'$(VersionSuffix)' != ''">$(VersionPrefix)-$(VersionSuffix)</PackageVersion>
<PackageVersion Condition="'$(VersionSuffix)' == ''">$(VersionPrefix)</PackageVersion>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// history: if they don't want it, they can remove it, but this makes the data available,
// including metadata like usage.
chatRequest.AddMessage(chatChoice.Message!);
chatHistory.Add(this.ToChatMessageContent(modelId, responseData, chatChoice));

var chatMessageContent = this.ToChatMessageContent(modelId, responseData, chatChoice);
chatHistory.Add(chatMessageContent);

// We must send back a response for every tool call, regardless of whether we successfully executed it or not.
// If we successfully execute it, we'll add the result. If we don't, we'll add an error.
Expand Down Expand Up @@ -172,8 +174,9 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory, chatMessageContent)
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
Expand Down Expand Up @@ -404,8 +407,9 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes

// Now, invoke the function, and add the resulting tool call message to the chat options.
FunctionResult functionResult = new(function) { Culture = kernel.Culture };
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory)
AutoFunctionInvocationContext invocationContext = new(kernel, function, functionResult, chatHistory, chatHistory.Last())
{
ToolCallId = toolCall.Id,
Arguments = functionArgs,
RequestSequenceIndex = requestIndex - 1,
FunctionSequenceIndex = toolCallIndex,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Xunit;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

/// <summary>
/// Unit tests for <see cref="OnnxKernelBuilderExtensions"/>.
/// </summary>
public class OnnxExtensionsTests
{
    [Fact]
    public void AddOnnxRuntimeGenAIChatCompletionToServiceCollection()
    {
        // Arrange
        var collection = new ServiceCollection();

        // Act
        // Register directly against the service collection (the extension under test),
        // then build a kernel so the service can be resolved through it.
        collection.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath");
        collection.AddKernel();
        var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
        var service = kernel.GetRequiredService<IChatCompletionService>();

        // Assert
        Assert.NotNull(service);
        Assert.IsType<OnnxRuntimeGenAIChatCompletionService>(service);
    }

    [Fact]
    public void AddOnnxRuntimeGenAIChatCompletionToKernelBuilder()
    {
        // Arrange
        var collection = new ServiceCollection();
        var kernelBuilder = collection.AddKernel();

        // Act
        // Register through the kernel builder (the extension under test).
        kernelBuilder.AddOnnxRuntimeGenAIChatCompletion("modelId", "modelPath");
        var kernel = collection.BuildServiceProvider().GetRequiredService<Kernel>();
        var service = kernel.GetRequiredService<IChatCompletionService>();

        // Assert
        Assert.NotNull(service);
        Assert.IsType<OnnxRuntimeGenAIChatCompletionService>(service);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Xunit;

namespace SemanticKernel.Connectors.Onnx.UnitTests;

/// <summary>
/// Unit tests for <see cref="OnnxRuntimeGenAIPromptExecutionSettings"/>.
/// </summary>
public class OnnxRuntimeGenAIPromptExecutionSettingsTests
{
    [Fact]
    public void FromExecutionSettingsWhenAlreadyOnnxShouldReturnSame()
    {
        // Arrange
        var executionSettings = new OnnxRuntimeGenAIPromptExecutionSettings();

        // Act
        var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

        // Assert
        // No conversion should occur when the settings are already the specialized type.
        Assert.Same(executionSettings, onnxExecutionSettings);
    }

    [Fact]
    public void FromExecutionSettingsWhenNullShouldReturnDefaultSettings()
    {
        // Arrange
        PromptExecutionSettings? executionSettings = null;

        // Act
        var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

        // Assert
        // A null input yields a settings instance with every option unset.
        Assert.Null(onnxExecutionSettings.TopK);
        Assert.Null(onnxExecutionSettings.TopP);
        Assert.Null(onnxExecutionSettings.Temperature);
        Assert.Null(onnxExecutionSettings.RepetitionPenalty);
        Assert.Null(onnxExecutionSettings.PastPresentShareBuffer);
        Assert.Null(onnxExecutionSettings.NumReturnSequences);
        Assert.Null(onnxExecutionSettings.NumBeams);
        Assert.Null(onnxExecutionSettings.NoRepeatNgramSize);
        Assert.Null(onnxExecutionSettings.MinTokens);
        Assert.Null(onnxExecutionSettings.MaxTokens);
        Assert.Null(onnxExecutionSettings.LengthPenalty);
        Assert.Null(onnxExecutionSettings.DiversityPenalty);
        Assert.Null(onnxExecutionSettings.EarlyStopping);
        Assert.Null(onnxExecutionSettings.DoSample);
    }

    [Fact]
    public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecialized()
    {
        // Arrange
        // Snake-case JSON property names must map onto the specialized settings.
        string jsonSettings = """
            {
                "top_k": 2,
                "top_p": 0.9,
                "temperature": 0.5,
                "repetition_penalty": 0.1,
                "past_present_share_buffer": true,
                "num_return_sequences": 200,
                "num_beams": 20,
                "no_repeat_ngram_size": 15,
                "min_tokens": 10,
                "max_tokens": 100,
                "length_penalty": 0.2,
                "diversity_penalty": 0.3,
                "early_stopping": false,
                "do_sample": true
            }
            """;

        // Act
        var executionSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(jsonSettings);
        var onnxExecutionSettings = OnnxRuntimeGenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

        // Assert
        Assert.Equal(2, onnxExecutionSettings.TopK);
        Assert.Equal(0.9f, onnxExecutionSettings.TopP);
        Assert.Equal(0.5f, onnxExecutionSettings.Temperature);
        Assert.Equal(0.1f, onnxExecutionSettings.RepetitionPenalty);
        Assert.True(onnxExecutionSettings.PastPresentShareBuffer);
        Assert.Equal(200, onnxExecutionSettings.NumReturnSequences);
        Assert.Equal(20, onnxExecutionSettings.NumBeams);
        Assert.Equal(15, onnxExecutionSettings.NoRepeatNgramSize);
        Assert.Equal(10, onnxExecutionSettings.MinTokens);
        Assert.Equal(100, onnxExecutionSettings.MaxTokens);
        Assert.Equal(0.2f, onnxExecutionSettings.LengthPenalty);
        Assert.Equal(0.3f, onnxExecutionSettings.DiversityPenalty);
        Assert.False(onnxExecutionSettings.EarlyStopping);
        Assert.True(onnxExecutionSettings.DoSample);
    }
}
6 changes: 6 additions & 0 deletions dotnet/src/Connectors/Connectors.Onnx/Connectors.Onnx.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@
<PackageReference Include="System.Numerics.Tensors" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI" Condition=" '$(Configuration)' == 'Debug' OR '$(Configuration)' == 'Release' " />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Condition=" '$(Configuration)' == 'Debug_Cuda' OR '$(Configuration)' == 'Release_Cuda' " />
<PackageReference Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Condition=" '$(Configuration)' == 'Debug_DirectML' OR '$(Configuration)' == 'Release_DirectML' " />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

using System.IO;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.Onnx;
using Microsoft.SemanticKernel.Embeddings;

Expand All @@ -14,6 +16,29 @@ namespace Microsoft.SemanticKernel;
/// </summary>
public static class OnnxKernelBuilderExtensions
{
/// <summary>
/// Add OnnxRuntimeGenAI Chat Completion services to the kernel builder.
/// </summary>
/// <param name="builder">The kernel builder.</param>
/// <param name="modelId">Model Id.</param>
/// <param name="modelPath">The generative AI ONNX model path.</param>
/// <param name="serviceId">The optional service ID.</param>
/// <returns>The updated kernel builder.</returns>
public static IKernelBuilder AddOnnxRuntimeGenAIChatCompletion(
    this IKernelBuilder builder,
    string modelId,
    string modelPath,
    string? serviceId = null)
{
    // Register the chat completion service as a keyed singleton so it can be
    // resolved by the optional service ID; the logger factory is resolved
    // lazily from the container (may be absent).
    builder.Services.AddKeyedSingleton<IChatCompletionService>(serviceId, (serviceProvider, _) =>
    {
        var loggerFactory = serviceProvider.GetService<ILoggerFactory>();
        return new OnnxRuntimeGenAIChatCompletionService(
            modelId,
            modelPath: modelPath,
            loggerFactory: loggerFactory);
    });

    return builder;
}

/// <summary>Adds a text embedding generation service using a BERT ONNX model.</summary>
/// <param name="builder">The <see cref="IKernelBuilder"/> instance to augment.</param>
/// <param name="onnxModelPath">The path to the ONNX model file.</param>
Expand Down
Loading

0 comments on commit 89773be

Please sign in to comment.