Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Migrate AzureOpenAITextToImageService to Azure.AI.OpenAI v2 #7093

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,6 @@
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/test/*.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

<ItemGroup>
<Compile Remove="Services\AzureOpenAITextToImageServiceTests.cs" />
</ItemGroup>

<ItemGroup>
<None Include="Services\AzureOpenAITextToImageServiceTests.cs" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Connectors.AzureOpenAI\Connectors.AzureOpenAI.csproj" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
using System;
using System.IO;
using System.Net.Http;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.Services;
using Moq;
using OpenAI;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Services;

/// <summary>
/// Unit tests for <see cref="AzureOpenAITextToImageServiceTests"/> class.
/// Unit tests for <see cref="AzureOpenAITextToImageService"/> class.
/// </summary>
public sealed class AzureOpenAITextToImageServiceTests : IDisposable
{
Expand All @@ -35,25 +38,21 @@ public AzureOpenAITextToImageServiceTests()
}

[Fact]
public void ConstructorWorksCorrectly()
public void ConstructorsAddRequiredMetadata()
{
// Arrange & Act
var sut = new AzureOpenAITextToImageServiceTests("model", "api-key", "organization");

// Assert
Assert.NotNull(sut);
Assert.Equal("organization", sut.Attributes[ClientCore.OrganizationKey]);
// Case #1
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host/", "api-key", "model");
Assert.Equal("deployment", sut.Attributes[ClientCore.DeploymentNameKey]);
Assert.Equal("model", sut.Attributes[AIServiceExtensions.ModelIdKey]);
}

[Fact]
public void OpenAIClientConstructorWorksCorrectly()
{
// Arrange
var sut = new AzureOpenAITextToImageServiceTests("model", new OpenAIClient("apikey"));
// Case #2
sut = new AzureOpenAITextToImageService("deployment", "https://api-hostapi/", new Mock<TokenCredential>().Object, "model");
Assert.Equal("deployment", sut.Attributes[ClientCore.DeploymentNameKey]);
Assert.Equal("model", sut.Attributes[AIServiceExtensions.ModelIdKey]);

// Assert
Assert.NotNull(sut);
// Case #3
sut = new AzureOpenAITextToImageService("deployment", new AzureOpenAIClient(new Uri("https://api-host/"), "api-key"), "model");
Assert.Equal("deployment", sut.Attributes[ClientCore.DeploymentNameKey]);
Assert.Equal("model", sut.Attributes[AIServiceExtensions.ModelIdKey]);
}

Expand All @@ -69,34 +68,20 @@ public void OpenAIClientConstructorWorksCorrectly()
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string modelId)
{
// Arrange
var sut = new AzureOpenAITextToImageServiceTests(modelId, "api-key", httpClient: this._httpClient);
Assert.Equal(modelId, sut.Attributes["ModelId"]);
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host", "api-key", modelId, this._httpClient);

// Act
var result = await sut.GenerateImageAsync("description", width, height);

// Assert
Assert.Equal("https://image-url/", result);
}

[Fact]
public async Task GenerateImageDoesLogActionAsync()
{
// Assert
var modelId = "dall-e-2";
var logger = new Mock<ILogger<AzureOpenAITextToImageServiceTests>>();
logger.Setup(l => l.IsEnabled(It.IsAny<LogLevel>())).Returns(true);

this._mockLoggerFactory.Setup(x => x.CreateLogger(It.IsAny<string>())).Returns(logger.Object);

// Arrange
var sut = new AzureOpenAITextToImageServiceTests(modelId, "apiKey", httpClient: this._httpClient, loggerFactory: this._mockLoggerFactory.Object);

// Act
await sut.GenerateImageAsync("description", 256, 256);

// Assert
logger.VerifyLog(LogLevel.Information, $"Action: {nameof(AzureOpenAITextToImageServiceTests.GenerateImageAsync)}. OpenAI Model ID: {modelId}.", Times.Once());
var request = JsonSerializer.Deserialize<JsonObject>(this._messageHandlerStub.RequestContent); // {"prompt":"description","model":"deployment","response_format":"url","size":"179x124"}
Assert.NotNull(request);
Assert.Equal("description", request["prompt"]?.ToString());
Assert.Equal("deployment", request["model"]?.ToString());
Assert.Equal("url", request["response_format"]?.ToString());
Assert.Equal($"{width}x{height}", request["size"]?.ToString());
}

public void Dispose()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"created": 1702575371,
"data": [
{
"revised_prompt": "A photo capturing the diversity of the Earth's landscapes.",
"url": "https://image-url/"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,10 @@
<Description>Semantic Kernel connectors for Azure OpenAI. Contains clients for text generation, chat completion, embedding and DALL-E text to image.</Description>
</PropertyGroup>

<ItemGroup>
<Compile Remove="Core\ClientCore.TextToImage.cs" />
<Compile Remove="Services\AzureOpenAITextToImageService.cs" />
</ItemGroup>

<ItemGroup>
<InternalsVisibleTo Include="SemanticKernel.Connectors.AzureOpenAI.UnitTests" />
</ItemGroup>

<ItemGroup>
<None Include="Core\ClientCore.TextToImage.cs" />
<None Include="Services\AzureOpenAITextToImageService.cs" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" VersionOverride="2.0.0-beta.2" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;

/// <summary>
/// Base class for AI clients that provides common functionality for interacting with OpenAI services.
/// Base class for AI clients that provides common functionality for interacting with Azure OpenAI services.
/// </summary>
internal partial class ClientCore
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;

/// <summary>
/// Base class for AI clients that provides common functionality for interacting with OpenAI services.
/// Base class for AI clients that provides common functionality for interacting with Azure OpenAI services.
/// </summary>
internal partial class ClientCore
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;

/// <summary>
/// Base class for AI clients that provides common functionality for interacting with OpenAI services.
/// Base class for AI clients that provides common functionality for interacting with Azure OpenAI services.
/// </summary>
internal partial class ClientCore
{
Expand Down Expand Up @@ -36,9 +36,8 @@ internal async Task<string> GenerateImageAsync(
ResponseFormat = GeneratedImageFormat.Uri
};

ClientResult<GeneratedImage> response = await RunRequestAsync(() => this.Client.GetImageClient(this.ModelId).GenerateImageAsync(prompt, imageOptions, cancellationToken)).ConfigureAwait(false);
var generatedImage = response.Value;
ClientResult<GeneratedImage> response = await RunRequestAsync(() => this.Client.GetImageClient(this.DeploymentOrModelName).GenerateImageAsync(prompt, imageOptions, cancellationToken)).ConfigureAwait(false);

return generatedImage.ImageUri?.ToString() ?? throw new KernelException("The generated image is not in url format");
return response.Value.ImageUri?.ToString() ?? throw new KernelException("The generated image is not in url format");
}
}
12 changes: 6 additions & 6 deletions dotnet/src/Connectors/Connectors.AzureOpenAI/Core/ClientCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;

/// <summary>
/// Base class for AI clients that provides common functionality for interacting with OpenAI services.
/// Base class for AI clients that provides common functionality for interacting with Azure OpenAI services.
/// </summary>
internal partial class ClientCore
{
Expand Down Expand Up @@ -135,13 +135,13 @@ internal ClientCore(

/// <summary>Gets options to use for an <see cref="AzureOpenAIClient"/>.</summary>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="serviceVersion">Optional API version.</param>
/// <returns>An instance of <see cref="AzureOpenAIClientOptions"/>.</returns>
internal static AzureOpenAIClientOptions GetAzureOpenAIClientOptions(HttpClient? httpClient)
internal static AzureOpenAIClientOptions GetAzureOpenAIClientOptions(HttpClient? httpClient, AzureOpenAIClientOptions.ServiceVersion? serviceVersion = null)
{
AzureOpenAIClientOptions options = new()
{
ApplicationId = HttpHeaderConstant.Values.UserAgent,
};
AzureOpenAIClientOptions options = serviceVersion is not null
? new(serviceVersion.Value) { ApplicationId = HttpHeaderConstant.Values.UserAgent }
: new() { ApplicationId = HttpHeaderConstant.Values.UserAgent };

options.AddPolicy(CreateRequestHeaderPolicy(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(ClientCore))), PipelinePosition.PerCall);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,18 @@ public AzureOpenAITextEmbeddingGenerationService(
/// Creates a new <see cref="AzureOpenAITextEmbeddingGenerationService"/> client.
/// </summary>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="openAIClient">Custom <see cref="AzureOpenAIClient"/> for HTTP requests.</param>
/// <param name="azureOpenAIClient">Custom <see cref="AzureOpenAIClient"/> for HTTP requests.</param>
/// <param name="modelId">Azure OpenAI model id, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
/// <param name="dimensions">The number of dimensions the resulting output embeddings should have. Only supported in "text-embedding-3" and later models.</param>
public AzureOpenAITextEmbeddingGenerationService(
string deploymentName,
AzureOpenAIClient openAIClient,
AzureOpenAIClient azureOpenAIClient,
SergeyMenshykh marked this conversation as resolved.
Show resolved Hide resolved
string? modelId = null,
ILoggerFactory? loggerFactory = null,
int? dimensions = null)
{
this._core = new(deploymentName, openAIClient, loggerFactory?.CreateLogger(typeof(AzureOpenAITextEmbeddingGenerationService)));
this._core = new(deploymentName, azureOpenAIClient, loggerFactory?.CreateLogger(typeof(AzureOpenAITextEmbeddingGenerationService)));

this._core.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Services;
using Microsoft.SemanticKernel.TextToImage;
using OpenAI;

namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;

/// <summary>
/// OpenAI text to image service.
/// Azure OpenAI text to image service.
/// </summary>
[Experimental("SKEXP0010")]
public class AzureOpenAITextToImageService : ITextToImageService
Expand All @@ -26,41 +28,111 @@ public class AzureOpenAITextToImageService : ITextToImageService
/// <summary>
/// Initializes a new instance of the <see cref="AzureOpenAITextToImageService"/> class.
/// </summary>
/// <param name="modelId">The model to use for image generation.</param>
/// <param name="apiKey">OpenAI API key, see https://platform.openai.com/account/api-keys</param>
/// <param name="organizationId">OpenAI organization id. This is usually optional unless your account belongs to multiple organizations.</param>
/// <param name="endpoint">Non-default endpoint for the OpenAI API.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="apiKey">Azure OpenAI API key, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="modelId">Azure OpenAI model id, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
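/// <param name="apiVersion">Azure OpenAI service API version, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart. NOTE(review): this constructor does not read the value; it always requests <c>AzureOpenAIClientOptions.ServiceVersion.V2024_05_01_Preview</c>.</param>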
public AzureOpenAITextToImageService(
string modelId,
string? apiKey = null,
string? organizationId = null,
Uri? endpoint = null,
string deploymentName,
string endpoint,
string apiKey,
string? modelId,
SergeyMenshykh marked this conversation as resolved.
Show resolved Hide resolved
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null)
ILoggerFactory? loggerFactory = null,
string? apiVersion = null)
{
this._client = new(modelId, apiKey, organizationId, endpoint, httpClient, loggerFactory?.CreateLogger(this.GetType()));
Verify.NotNullOrWhiteSpace(apiKey);

var connectorEndpoint = !string.IsNullOrWhiteSpace(endpoint) ? endpoint! : httpClient?.BaseAddress?.AbsoluteUri;
if (connectorEndpoint is null)
{
throw new ArgumentException($"The {nameof(httpClient)}.{nameof(HttpClient.BaseAddress)} and {nameof(endpoint)} are both null or empty. Please ensure at least one is provided.");
}

var options = ClientCore.GetAzureOpenAIClientOptions(
httpClient,
AzureOpenAIClientOptions.ServiceVersion.V2024_05_01_Preview); // DALL-E 3 is supported in the latest API releases - https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#image-generation

var azureOpenAIClient = new AzureOpenAIClient(new Uri(connectorEndpoint), apiKey, options);

this._client = new(deploymentName, azureOpenAIClient, loggerFactory?.CreateLogger(this.GetType()));

if (modelId is not null)
{
this._client.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
}
}

/// <summary>
/// Initializes a new instance of the <see cref="AzureOpenAITextToImageService"/> class.
/// </summary>
/// <param name="modelId">Model name</param>
/// <param name="openAIClient">Custom <see cref="OpenAIClient"/> for HTTP requests.</param>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="endpoint">Azure OpenAI deployment URL, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart</param>
/// <param name="credential">Token credentials, e.g. DefaultAzureCredential, ManagedIdentityCredential, EnvironmentCredential, etc.</param>
/// <param name="modelId">Azure OpenAI model id, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="httpClient">Custom <see cref="HttpClient"/> for HTTP requests.</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
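/// <param name="apiVersion">Azure OpenAI service API version, see https://learn.microsoft.com/azure/cognitive-services/openai/quickstart. NOTE(review): this constructor does not read the value; it always requests <c>AzureOpenAIClientOptions.ServiceVersion.V2024_05_01_Preview</c>.</param>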
public AzureOpenAITextToImageService(
string modelId,
OpenAIClient openAIClient,
string deploymentName,
string endpoint,
TokenCredential credential,
string? modelId,
HttpClient? httpClient = null,
ILoggerFactory? loggerFactory = null,
string? apiVersion = null)
{
Verify.NotNull(credential);

var connectorEndpoint = !string.IsNullOrWhiteSpace(endpoint) ? endpoint! : httpClient?.BaseAddress?.AbsoluteUri;
if (connectorEndpoint is null)
{
throw new ArgumentException($"The {nameof(httpClient)}.{nameof(HttpClient.BaseAddress)} and {nameof(endpoint)} are both null or empty. Please ensure at least one is provided.");
}

var options = ClientCore.GetAzureOpenAIClientOptions(
httpClient,
AzureOpenAIClientOptions.ServiceVersion.V2024_05_01_Preview); // DALL-E 3 is supported in the latest API releases - https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#image-generation

var azureOpenAIClient = new AzureOpenAIClient(new Uri(connectorEndpoint), credential, options);

this._client = new(deploymentName, azureOpenAIClient, loggerFactory?.CreateLogger(this.GetType()));

if (modelId is not null)
{
this._client.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
}
}

/// <summary>
/// Initializes a new instance of the <see cref="AzureOpenAITextToImageService"/> class.
/// </summary>
/// <param name="deploymentName">Azure OpenAI deployment name, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="azureOpenAIClient">Custom <see cref="AzureOpenAIClient"/>.</param>
/// <param name="modelId">Azure OpenAI model id, see https://learn.microsoft.com/azure/cognitive-services/openai/how-to/create-resource</param>
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
public AzureOpenAITextToImageService(
string deploymentName,
AzureOpenAIClient azureOpenAIClient,
string? modelId,
ILoggerFactory? loggerFactory = null)
{
this._client = new(modelId, openAIClient, loggerFactory?.CreateLogger(typeof(OpenAITextEmbeddingGenerationService)));
Verify.NotNull(azureOpenAIClient);

this._client = new(deploymentName, azureOpenAIClient, loggerFactory?.CreateLogger(this.GetType()));

if (modelId is not null)
{
this._client.AddAttribute(AIServiceExtensions.ModelIdKey, modelId);
}
}

/// <inheritdoc/>
public Task<string> GenerateImageAsync(string description, int width, int height, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
    // Record the action details through the underlying client before issuing the request.
    this._client.LogActionDetails();

    // Delegate to the underlying client; the kernel argument is not forwarded,
    // only the description, dimensions and cancellation token are.
    Task<string> pendingImage = this._client.GenerateImageAsync(
        description,
        width,
        height,
        cancellationToken);

    return pendingImage;
}
}
Loading