Skip to content

Commit

Permalink
.Net: OpenAI V2 Version Update and Adjustments (#8392)
Browse files Browse the repository at this point in the history
### Motivation and Context

- Update Azure SDK to `beta.3`
- Update OpenAI SDK to `beta.10`
- Make the Azure OpenAI package resilient to mismatched OpenAI SDK
versions
- Adapt to latest breaking changes from `OpenAI beta.6` +
- Adapt to latest breaking changes from `System.ClientModel beta.5` +

### Impact

Some changes introduced by `OpenAI SDK beta.9` cannot be worked around
for custom endpoints: the SDK now enforces a `v1/` prefix on all
endpoints, so the non-`v1` test scenarios had to be removed.
  • Loading branch information
RogerBarreto committed Aug 27, 2024
1 parent c262d99 commit 4d4e3ad
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 94 deletions.
6 changes: 3 additions & 3 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
<ManagePackageVersionsCentrally>true</ManagePackageVersionsCentrally>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="OpenAI" Version="2.0.0-beta.5" />
<PackageVersion Include="System.ClientModel" Version="1.1.0-beta.4" />
<PackageVersion Include="OpenAI" Version="2.0.0-beta.10" />
<PackageVersion Include="System.ClientModel" Version="1.1.0-beta.7" />
<PackageVersion Include="Azure.AI.ContentSafety" Version="1.0.0" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.0.0-beta.2" />
<PackageVersion Include="Azure.AI.OpenAI" Version="2.0.0-beta.3" />
<PackageVersion Include="Azure.Identity" Version="1.12.0" />
<PackageVersion Include="Azure.Monitor.OpenTelemetry.Exporter" Version="1.3.0" />
<PackageVersion Include="Azure.Search.Documents" Version="11.6.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ await agent.CreateThreadAsync(
finally
{
await agent.DeleteThreadAsync(threadId);
await agent.DeleteAsync();
await agent.DeleteAsync(CancellationToken.None);
await vectorStoreClient.DeleteVectorStoreAsync(vectorStore);
await fileClient.DeleteFileAsync(fileInfo);
await fileClient.DeleteFileAsync(fileInfo.Id);
}

// Local function to invoke agent and display the conversation messages.
Expand Down
11 changes: 9 additions & 2 deletions dotnet/src/Agents/OpenAI/Extensions/KernelFunctionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,17 @@ public static FunctionToolDefinition ToToolDefinition(this KernelFunction functi
required,
};

return new FunctionToolDefinition(FunctionName.ToFullyQualifiedName(function.Name, pluginName), function.Description, BinaryData.FromObjectAsJson(spec));
return new FunctionToolDefinition(FunctionName.ToFullyQualifiedName(function.Name, pluginName))
{
Description = function.Description,
Parameters = BinaryData.FromObjectAsJson(spec)
};
}

return new FunctionToolDefinition(FunctionName.ToFullyQualifiedName(function.Name, pluginName), function.Description);
return new FunctionToolDefinition(FunctionName.ToFullyQualifiedName(function.Name, pluginName))
{
Description = function.Description
};
}

private static string ConvertType(Type? type)
Expand Down
45 changes: 28 additions & 17 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright (c) Microsoft. All rights reserved.
using System.ClientModel;
using System.Collections.Generic;
using System.Linq;
using System.Net;
Expand Down Expand Up @@ -52,7 +53,9 @@ public static async Task<string> CreateThreadAsync(AssistantClient client, OpenA
{
foreach (ChatMessageContent message in options.Messages)
{
ThreadInitializationMessage threadMessage = new(AssistantMessageFactory.GetMessageContents(message));
ThreadInitializationMessage threadMessage = new(
role: message.Role == AuthorRole.User ? MessageRole.User : MessageRole.Assistant,
content: AssistantMessageFactory.GetMessageContents(message));
createOptions.InitialMessages.Add(threadMessage);
}
}
Expand Down Expand Up @@ -89,6 +92,7 @@ public static async Task CreateMessageAsync(AssistantClient client, string threa

await client.CreateMessageAsync(
threadId,
message.Role == AuthorRole.User ? MessageRole.User : MessageRole.Assistant,
AssistantMessageFactory.GetMessageContents(message),
options,
cancellationToken).ConfigureAwait(false);
Expand All @@ -105,28 +109,31 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
{
Dictionary<string, string?> agentNames = []; // Cache agent names by their identifier

await foreach (ThreadMessage message in client.GetMessagesAsync(threadId, ListOrder.NewestFirst, cancellationToken).ConfigureAwait(false))
await foreach (PageResult<ThreadMessage> page in client.GetMessagesAsync(threadId, new() { Order = ListOrder.NewestFirst }, cancellationToken).ConfigureAwait(false))
{
AuthorRole role = new(message.Role.ToString());

string? assistantName = null;
if (!string.IsNullOrWhiteSpace(message.AssistantId) &&
!agentNames.TryGetValue(message.AssistantId, out assistantName))
foreach (var message in page.Values)
{
Assistant assistant = await client.GetAssistantAsync(message.AssistantId).ConfigureAwait(false); // SDK BUG - CANCEL TOKEN (https://github.com/microsoft/semantic-kernel/issues/7431)
if (!string.IsNullOrWhiteSpace(assistant.Name))
AuthorRole role = new(message.Role.ToString());

string? assistantName = null;
if (!string.IsNullOrWhiteSpace(message.AssistantId) &&
!agentNames.TryGetValue(message.AssistantId, out assistantName))
{
agentNames.Add(assistant.Id, assistant.Name);
Assistant assistant = await client.GetAssistantAsync(message.AssistantId, cancellationToken).ConfigureAwait(false);
if (!string.IsNullOrWhiteSpace(assistant.Name))
{
agentNames.Add(assistant.Id, assistant.Name);
}
}
}

assistantName ??= message.AssistantId;
assistantName ??= message.AssistantId;

ChatMessageContent content = GenerateMessageContent(assistantName, message);
ChatMessageContent content = GenerateMessageContent(assistantName, message);

if (content.Items.Count > 0)
{
yield return content;
if (content.Items.Count > 0)
{
yield return content;
}
}
}
}
Expand Down Expand Up @@ -190,7 +197,11 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}");
}

RunStep[] steps = await client.GetRunStepsAsync(run).ToArrayAsync(cancellationToken).ConfigureAwait(false);
List<RunStep> steps = [];
await foreach (var page in client.GetRunStepsAsync(run).ConfigureAwait(false))
{
steps.AddRange(page.Values);
};

// Is tool action required?
if (run.Status == RunStatus.RequiresAction)
Expand Down
9 changes: 6 additions & 3 deletions dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,12 @@ public static async IAsyncEnumerable<OpenAIAssistantDefinition> ListDefinitionsA
AssistantClient client = CreateClient(provider);

// Query and enumerate assistant definitions
await foreach (Assistant model in client.GetAssistantsAsync(ListOrder.NewestFirst, cancellationToken).ConfigureAwait(false))
await foreach (var page in client.GetAssistantsAsync(new AssistantCollectionOptions() { Order = ListOrder.NewestFirst }, cancellationToken).ConfigureAwait(false))
{
yield return CreateAssistantDefinition(model);
foreach (Assistant model in page.Values)
{
yield return CreateAssistantDefinition(model);
}
}
}

Expand All @@ -132,7 +135,7 @@ public static async Task<OpenAIAssistantAgent> RetrieveAsync(
AssistantClient client = CreateClient(provider);

// Retrieve the assistant
Assistant model = await client.GetAssistantAsync(id).ConfigureAwait(false); // SDK BUG - CANCEL TOKEN (https://github.com/microsoft/semantic-kernel/issues/7431)
Assistant model = await client.GetAssistantAsync(id, cancellationToken).ConfigureAwait(false);

// Instantiate the agent
return
Expand Down
11 changes: 5 additions & 6 deletions dotnet/src/Agents/OpenAI/OpenAIClientProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static OpenAIClientProvider ForAzureOpenAI(ApiKeyCredential apiKey, Uri e
Verify.NotNull(apiKey, nameof(apiKey));
Verify.NotNull(endpoint, nameof(endpoint));

AzureOpenAIClientOptions clientOptions = CreateAzureClientOptions(endpoint, httpClient);
AzureOpenAIClientOptions clientOptions = CreateAzureClientOptions(httpClient);

return new(new AzureOpenAIClient(endpoint, apiKey!, clientOptions), CreateConfigurationKeys(endpoint, httpClient));
}
Expand All @@ -66,7 +66,7 @@ public static OpenAIClientProvider ForAzureOpenAI(TokenCredential credential, Ur
Verify.NotNull(credential, nameof(credential));
Verify.NotNull(endpoint, nameof(endpoint));

AzureOpenAIClientOptions clientOptions = CreateAzureClientOptions(endpoint, httpClient);
AzureOpenAIClientOptions clientOptions = CreateAzureClientOptions(httpClient);

return new(new AzureOpenAIClient(endpoint, credential, clientOptions), CreateConfigurationKeys(endpoint, httpClient));
}
Expand Down Expand Up @@ -102,12 +102,11 @@ public static OpenAIClientProvider FromClient(OpenAIClient client)
return new(client, [client.GetType().FullName!, client.GetHashCode().ToString()]);
}

private static AzureOpenAIClientOptions CreateAzureClientOptions(Uri? endpoint, HttpClient? httpClient)
private static AzureOpenAIClientOptions CreateAzureClientOptions(HttpClient? httpClient)
{
AzureOpenAIClientOptions options = new()
{
ApplicationId = HttpHeaderConstant.Values.UserAgent,
Endpoint = endpoint,
ApplicationId = HttpHeaderConstant.Values.UserAgent
};

ConfigureClientOptions(httpClient, options);
Expand All @@ -128,7 +127,7 @@ private static OpenAIClientOptions CreateOpenAIClientOptions(Uri? endpoint, Http
return options;
}

private static void ConfigureClientOptions(HttpClient? httpClient, OpenAIClientOptions options)
private static void ConfigureClientOptions(HttpClient? httpClient, ClientPipelineOptions options)
{
options.AddPolicy(CreateRequestHeaderPolicy(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(OpenAIAssistantAgent))), PipelinePosition.PerCall);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Diagnostics;
using Azure.AI.OpenAI;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Diagnostics;
using OpenAI.Chat;
using OpenAIChatCompletion = OpenAI.Chat.ChatCompletion;

#pragma warning disable CA2208 // Instantiate argument exceptions correctly

Expand All @@ -18,34 +16,9 @@ namespace Microsoft.SemanticKernel.Connectors.AzureOpenAI;
/// </summary>
internal partial class AzureClientCore
{
private const string ContentFilterResultForPromptKey = "ContentFilterResultForPrompt";
private const string ContentFilterResultForResponseKey = "ContentFilterResultForResponse";

/// <inheritdoc/>
protected override OpenAIPromptExecutionSettings GetSpecializedExecutionSettings(PromptExecutionSettings? executionSettings)
{
return AzureOpenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);
}

/// <inheritdoc/>
protected override Dictionary<string, object?> GetChatCompletionMetadata(OpenAIChatCompletion completions)
{
#pragma warning disable AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
return new Dictionary<string, object?>
{
{ nameof(completions.Id), completions.Id },
{ nameof(completions.CreatedAt), completions.CreatedAt },
{ ContentFilterResultForPromptKey, completions.GetContentFilterResultForPrompt() },
{ nameof(completions.SystemFingerprint), completions.SystemFingerprint },
{ nameof(completions.Usage), completions.Usage },
{ ContentFilterResultForResponseKey, completions.GetContentFilterResultForResponse() },

// Serialization of this struct behaves as an empty object {}, need to cast to string to avoid it.
{ nameof(completions.FinishReason), completions.FinishReason.ToString() },
{ nameof(completions.ContentTokenLogProbabilities), completions.ContentTokenLogProbabilities },
};
#pragma warning restore AOAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
}
=> AzureOpenAIPromptExecutionSettings.FromExecutionSettings(executionSettings);

/// <inheritdoc/>
protected override Activity? StartCompletionActivity(ChatHistory chatHistory, PromptExecutionSettings settings)
Expand All @@ -71,7 +44,7 @@ protected override ChatCompletionOptions CreateChatCompletionOptions(
FrequencyPenalty = (float?)executionSettings.FrequencyPenalty,
PresencePenalty = (float?)executionSettings.PresencePenalty,
Seed = executionSettings.Seed,
User = executionSettings.User,
EndUserId = executionSettings.User,
TopLogProbabilityCount = executionSettings.TopLogprobs,
IncludeLogProbabilities = executionSettings.Logprobs,
ResponseFormat = GetResponseFormat(azureSettings) ?? ChatResponseFormat.Text,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ public void ItUsesEndpointAsExpected(string? clientBaseAddress, string? provided
var clientCore = new ClientCore("model", "apiKey", endpoint: endpoint, httpClient: client);

// Assert
Assert.Equal(endpoint ?? client?.BaseAddress ?? new Uri("https://api.openai.com/v1"), clientCore.Endpoint);
Assert.Equal(endpoint ?? client?.BaseAddress ?? new Uri("https://api.openai.com/"), clientCore.Endpoint);
Assert.True(clientCore.Attributes.ContainsKey(AIServiceExtensions.EndpointKey));
Assert.Equal(endpoint?.ToString() ?? client?.BaseAddress?.ToString() ?? "https://api.openai.com/v1", clientCore.Attributes[AIServiceExtensions.EndpointKey]);
Assert.Equal(endpoint?.ToString() ?? client?.BaseAddress?.ToString() ?? "https://api.openai.com/", clientCore.Attributes[AIServiceExtensions.EndpointKey]);

client?.Dispose();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,10 @@ public void ConstructorWithApiKeyWorksCorrectly(bool includeLoggerFactory)
}

[Theory]
[InlineData("http://localhost:1234/chat/completions", "http://localhost:1234/chat/completions")] // Uses full path when provided
[InlineData("http://localhost:1234/v2/chat/completions", "http://localhost:1234/v2/chat/completions")] // Uses full path when provided
[InlineData("http://localhost:1234", "http://localhost:1234/v1/chat/completions")]
[InlineData("http://localhost:1234/v1/chat/completions", "http://localhost:1234/v1/chat/completions")] // Uses full path when provided
[InlineData("http://localhost:1234/", "http://localhost:1234/v1/chat/completions")]
[InlineData("http://localhost:8080", "http://localhost:8080/v1/chat/completions")]
[InlineData("https://something:8080", "https://something:8080/v1/chat/completions")] // Accepts TLS Secured endpoints
[InlineData("http://localhost:1234/v2", "http://localhost:1234/v2/chat/completions")]
[InlineData("http://localhost:8080/v2", "http://localhost:8080/v2/chat/completions")]
public async Task ItUsesCustomEndpointsWhenProvidedDirectlyAsync(string endpointProvided, string expectedEndpoint)
{
// Arrange
Expand All @@ -98,13 +95,10 @@ public async Task ItUsesCustomEndpointsWhenProvidedDirectlyAsync(string endpoint
}

[Theory]
[InlineData("http://localhost:1234/chat/completions", "http://localhost:1234/chat/completions")] // Uses full path when provided
[InlineData("http://localhost:1234/v2/chat/completions", "http://localhost:1234/v2/chat/completions")] // Uses full path when provided
[InlineData("http://localhost:1234", "http://localhost:1234/v1/chat/completions")]
[InlineData("http://localhost:1234/v1/chat/completions", "http://localhost:1234/v1/chat/completions")] // Uses full path when provided
[InlineData("http://localhost:1234/", "http://localhost:1234/v1/chat/completions")]
[InlineData("http://localhost:8080", "http://localhost:8080/v1/chat/completions")]
[InlineData("https://something:8080", "https://something:8080/v1/chat/completions")] // Accepts TLS Secured endpoints
[InlineData("http://localhost:1234/v2", "http://localhost:1234/v2/chat/completions")]
[InlineData("http://localhost:8080/v2", "http://localhost:8080/v2/chat/completions")]
public async Task ItUsesCustomEndpointsWhenProvidedAsBaseAddressAsync(string endpointProvided, string expectedEndpoint)
{
// Arrange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
using (var activity = this.StartCompletionActivity(chat, chatExecutionSettings))
{
// Make the request.
AsyncResultCollection<StreamingChatCompletionUpdate> response;
AsyncCollectionResult<StreamingChatCompletionUpdate> response;
try
{
response = RunRequest(() => this.Client!.GetChatClient(targetModel).CompleteChatStreamingAsync(chatForRequest, chatOptions, cancellationToken));
Expand Down Expand Up @@ -644,7 +644,7 @@ protected virtual ChatCompletionOptions CreateChatCompletionOptions(
FrequencyPenalty = (float?)executionSettings.FrequencyPenalty,
PresencePenalty = (float?)executionSettings.PresencePenalty,
Seed = executionSettings.Seed,
User = executionSettings.User,
EndUserId = executionSettings.User,
TopLogProbabilityCount = executionSettings.TopLogprobs,
IncludeLogProbabilities = executionSettings.Logprobs,
ResponseFormat = GetResponseFormat(executionSettings) ?? ChatResponseFormat.Text,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ internal async Task<IReadOnlyList<AudioContent>> GetAudioContentsAsync(
Speed = audioExecutionSettings.Speed,
};

ClientResult<BinaryData> response = await RunRequestAsync(() => this.Client!.GetAudioClient(targetModel).GenerateSpeechFromTextAsync(prompt, GetGeneratedSpeechVoice(audioExecutionSettings?.Voice), options, cancellationToken)).ConfigureAwait(false);
ClientResult<BinaryData> response = await RunRequestAsync(() => this.Client!.GetAudioClient(targetModel).GenerateSpeechAsync(prompt, GetGeneratedSpeechVoice(audioExecutionSettings?.Voice), options, cancellationToken)).ConfigureAwait(false);

return [new AudioContent(response.Value.ToArray(), mimeType)];
}
Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/Connectors/Connectors.OpenAI/Core/ClientCore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ internal partial class ClientCore
/// <summary>
/// Default OpenAI API endpoint.
/// </summary>
private const string OpenAIV1Endpoint = "https://api.openai.com/v1";
private const string OpenAIEndpoint = "https://api.openai.com/";

/// <summary>
/// Identifier of the default model to use
Expand Down Expand Up @@ -104,7 +104,7 @@ internal ClientCore(
if (this.Endpoint is null)
{
Verify.NotNullOrWhiteSpace(apiKey); // For Public OpenAI Endpoint a key must be provided.
this.Endpoint = new Uri(OpenAIV1Endpoint);
this.Endpoint = new Uri(OpenAIEndpoint);
}
else if (string.IsNullOrEmpty(apiKey))
{
Expand Down
Loading

0 comments on commit 4d4e3ad

Please sign in to comment.