Skip to content

Commit

Permalink
.Net: Align metadata names with underlying library ones (#7207)
Browse files Browse the repository at this point in the history
### Motivation and Context
This PR aligns metadata names with those provided by the underlying
Azure.AI.OpenAI library. Additionally, it adds a few unit tests to
increase code coverage.
  • Loading branch information
SergeyMenshykh committed Jul 11, 2024
1 parent 64120d3 commit a10e9f2
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.Services;
using Moq;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Services;

Expand All @@ -19,12 +21,23 @@ namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Services;
/// </summary>
public class AzureOpenAITextEmbeddingGenerationServiceTests
{
[Fact]
public void ItCanBeInstantiatedAndPropertiesSetAsExpected()
private readonly Mock<ILoggerFactory> _mockLoggerFactory;

public AzureOpenAITextEmbeddingGenerationServiceTests()
{
this._mockLoggerFactory = new Mock<ILoggerFactory>();
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public void ItCanBeInstantiatedAndPropertiesSetAsExpected(bool includeLoggerFactory)
{
// Arrange
var sut = new AzureOpenAITextEmbeddingGenerationService("deployment-name", "https://endpoint", "api-key", modelId: "model", dimensions: 2);
var sutWithAzureOpenAIClient = new AzureOpenAITextEmbeddingGenerationService("deployment-name", new AzureOpenAIClient(new Uri("https://endpoint"), new ApiKeyCredential("apiKey")), modelId: "model", dimensions: 2);
var sut = includeLoggerFactory ?
new AzureOpenAITextEmbeddingGenerationService("deployment-name", "https://endpoint", "api-key", modelId: "model", dimensions: 2, loggerFactory: this._mockLoggerFactory.Object) :
new AzureOpenAITextEmbeddingGenerationService("deployment-name", "https://endpoint", "api-key", modelId: "model", dimensions: 2);
var sutWithAzureOpenAIClient = new AzureOpenAITextEmbeddingGenerationService("deployment-name", new AzureOpenAIClient(new Uri("https://endpoint"), new ApiKeyCredential("apiKey")), modelId: "model", dimensions: 2, loggerFactory: this._mockLoggerFactory.Object);

// Assert
Assert.NotNull(sut);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ public AzureOpenAITextToImageServiceTests()
public void ConstructorsAddRequiredMetadata()
{
// Case #1
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host/", "api-key", "model");
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host/", "api-key", "model", loggerFactory: this._mockLoggerFactory.Object);
Assert.Equal("deployment", sut.Attributes[ClientCore.DeploymentNameKey]);
Assert.Equal("model", sut.Attributes[AIServiceExtensions.ModelIdKey]);

// Case #2
sut = new AzureOpenAITextToImageService("deployment", "https://api-hostapi/", new Mock<TokenCredential>().Object, "model");
sut = new AzureOpenAITextToImageService("deployment", "https://api-hostapi/", new Mock<TokenCredential>().Object, "model", loggerFactory: this._mockLoggerFactory.Object);
Assert.Equal("deployment", sut.Attributes[ClientCore.DeploymentNameKey]);
Assert.Equal("model", sut.Attributes[AIServiceExtensions.ModelIdKey]);

// Case #3
sut = new AzureOpenAITextToImageService("deployment", new AzureOpenAIClient(new Uri("https://api-host/"), "api-key"), "model");
sut = new AzureOpenAITextToImageService("deployment", new AzureOpenAIClient(new Uri("https://api-host/"), "api-key"), "model", loggerFactory: this._mockLoggerFactory.Object);
Assert.Equal("deployment", sut.Attributes[ClientCore.DeploymentNameKey]);
Assert.Equal("model", sut.Attributes[AIServiceExtensions.ModelIdKey]);
}
Expand All @@ -68,7 +68,7 @@ public void ConstructorsAddRequiredMetadata()
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string modelId)
{
// Arrange
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host", "api-key", modelId, this._httpClient);
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host", "api-key", modelId, this._httpClient, loggerFactory: this._mockLoggerFactory.Object);

// Act
var result = await sut.GenerateImageAsync("description", width, height);
Expand All @@ -84,6 +84,65 @@ public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string
Assert.Equal($"{width}x{height}", request["size"]?.ToString());
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ItShouldUseProvidedEndpoint(bool useTokeCredential)
{
// Arrange
var sut = useTokeCredential ?
new AzureOpenAITextToImageService("deployment", endpoint: "https://api-host", new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient) :
new AzureOpenAITextToImageService("deployment", endpoint: "https://api-host", "api-key", "dall-e-3", this._httpClient);

// Act
var result = await sut.GenerateImageAsync("description", 1024, 1024);

// Assert
Assert.StartsWith("https://api-host", this._messageHandlerStub.RequestUri?.AbsoluteUri);
}

[Theory]
[InlineData(true, "")]
[InlineData(true, null)]
[InlineData(false, "")]
[InlineData(false, null)]
public async Task ItShouldUseHttpClientUriIfNoEndpointProvided(bool useTokeCredential, string? endpoint)
{
// Arrange
this._httpClient.BaseAddress = new Uri("https://api-host");

var sut = useTokeCredential ?
new AzureOpenAITextToImageService("deployment", endpoint: endpoint!, new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient) :
new AzureOpenAITextToImageService("deployment", endpoint: endpoint!, "api-key", "dall-e-3", this._httpClient);

// Act
var result = await sut.GenerateImageAsync("description", 1024, 1024);

// Assert
Assert.StartsWith("https://api-host", this._messageHandlerStub.RequestUri?.AbsoluteUri);
}

[Theory]
[InlineData(true, "")]
[InlineData(true, null)]
[InlineData(false, "")]
[InlineData(false, null)]
public void ItShouldThrowExceptionIfNoEndpointProvided(bool useTokeCredential, string? endpoint)
{
// Arrange
this._httpClient.BaseAddress = null;

// Act & Assert
if (useTokeCredential)
{
Assert.Throws<ArgumentException>(() => new AzureOpenAITextToImageService("deployment", endpoint: endpoint!, new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient));
}
else
{
Assert.Throws<ArgumentException>(() => new AzureOpenAITextToImageService("deployment", endpoint: endpoint!, "api-key", "dall-e-3", this._httpClient));
}
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;
/// </summary>
internal partial class ClientCore
{
private const string PromptFilterResultsMetadataKey = "PromptFilterResults";
private const string ContentFilterResultsMetadataKey = "ContentFilterResults";
private const string LogProbabilityInfoMetadataKey = "LogProbabilityInfo";
private const string ContentFilterResultForPromptKey = "ContentFilterResultForPrompt";
private const string ContentFilterResultForResponseKey = "ContentFilterResultForResponse";
private const string ContentTokenLogProbabilitiesKey = "ContentTokenLogProbabilities";
private const string ModelProvider = "openai";
private record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice Choice, bool AutoInvoke);

Expand Down Expand Up @@ -92,25 +92,25 @@ private record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice Choice,
private static Dictionary<string, object?> GetChatCompletionMetadata(OpenAIChatCompletion completions)
{
#pragma warning disable AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
return new Dictionary<string, object?>(8)
return new Dictionary<string, object?>
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.CreatedAt), completions.CreatedAt },
{ PromptFilterResultsMetadataKey, completions.GetContentFilterResultForPrompt() },
{ ContentFilterResultForPromptKey, completions.GetContentFilterResultForPrompt() },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },
{ nameof(completions.Usage), completions.Usage },
{ ContentFilterResultsMetadataKey, completions.GetContentFilterResultForResponse() },
{ ContentFilterResultForResponseKey, completions.GetContentFilterResultForResponse() },

// Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it.
{ nameof(completions.FinishReason), completions.FinishReason.ToString() },
{ LogProbabilityInfoMetadataKey, completions.ContentTokenLogProbabilities },
{ ContentTokenLogProbabilitiesKey, completions.ContentTokenLogProbabilities },
};
#pragma warning restore AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
}

private static Dictionary<string, object?> GetChatCompletionMetadata(StreamingChatCompletionUpdate completionUpdate)
{
return new Dictionary<string, object?>(4)
return new Dictionary<string, object?>
{
{ nameof(completionUpdate.Id), completionUpdate.Id },
{ nameof(completionUpdate.CreatedAt), completionUpdate.CreatedAt },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace Microsoft.SemanticKernel.Connectors.OpenAI;
/// </summary>
internal partial class ClientCore
{
private const string LogProbabilityInfoMetadataKey = "LogProbabilityInfo";
private const string ContentTokenLogProbabilitiesKey = "ContentTokenLogProbabilities";
private const string ModelProvider = "openai";
private record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice Choice, bool AutoInvoke);

Expand Down Expand Up @@ -88,7 +88,7 @@ private record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice Choice,

private static Dictionary<string, object?> GetChatCompletionMetadata(OpenAIChatCompletion completions)
{
return new Dictionary<string, object?>(8)
return new Dictionary<string, object?>
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.CreatedAt), completions.CreatedAt },
Expand All @@ -97,13 +97,13 @@ private record ToolCallingConfig(IList<ChatTool>? Tools, ChatToolChoice Choice,

// Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it.
{ nameof(completions.FinishReason), completions.FinishReason.ToString() },
{ LogProbabilityInfoMetadataKey, completions.ContentTokenLogProbabilities },
{ ContentTokenLogProbabilitiesKey, completions.ContentTokenLogProbabilities },
};
}

private static Dictionary<string, object?> GetChatCompletionMetadata(StreamingChatCompletionUpdate completionUpdate)
{
return new Dictionary<string, object?>(4)
return new Dictionary<string, object?>
{
{ nameof(completionUpdate.Id), completionUpdate.Id },
{ nameof(completionUpdate.CreatedAt), completionUpdate.CreatedAt },
Expand Down

0 comments on commit a10e9f2

Please sign in to comment.