From f149f95299203102a6c572b8db98f79f066b0050 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:32:15 +0100 Subject: [PATCH] .Net: Add Ollama Connector (#7362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Motivation and Context This PR adds support for the Ollama Connector. The connector uses the `OllamaSharp` client library to call Ollama's native endpoints. --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com> Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Co-authored-by: Ikko Eltociear Ashimine Co-authored-by: Chris <66376200+crickman@users.noreply.github.com> Co-authored-by: ShuaiHua Du Co-authored-by: Krzysztof Kasprowicz <60486987+Krzysztof318@users.noreply.github.com> Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Co-authored-by: Nico Möller Co-authored-by: Nico Möller Co-authored-by: westey <164392973+westey-m@users.noreply.github.com> Co-authored-by: Tao Chen Co-authored-by: Eduard van Valkenburg Co-authored-by: NEWTON MALLICK <38786893+N-E-W-T-O-N@users.noreply.github.com> Co-authored-by: qowlsdn8007 <33804074+qowlsdn8007@users.noreply.github.com> Co-authored-by: Gil LaHaye --- dotnet/Directory.Packages.props | 5 +- dotnet/SK-dotnet.sln | 20 +- .../ChatCompletion/Ollama_ChatCompletion.cs | 73 ++++++ .../Ollama_ChatCompletionStreaming.cs | 161 ++++++++++++ .../ChatCompletion/OpenAI_ChatCompletion.cs | 21 +- .../OpenAI_ChatCompletionStreaming.cs | 13 +- dotnet/samples/Concepts/Concepts.csproj | 1 + .../Memory/Ollama_EmbeddingGeneration.cs | 35 +++ .../TextGeneration/Ollama_TextGeneration.cs | 76 ++++++ .../Ollama_TextGenerationStreaming.cs | 57 ++++ .../Demos/AIModelRouter/AIModelRouter.csproj | 1 + .../Demos/AIModelRouter/CustomRouter.cs | 4 +- dotnet/samples/Demos/AIModelRouter/Program.cs | 5 +- .../AIModelRouter/SelectedServiceFilter.cs | 2 +- .../Connectors.Ollama.UnitTests.csproj | 50 ++++ .../OllamaKernelBuilderExtensionsTests.cs | 59 +++++ .../OllamaServiceCollectionExtensionsTests.cs | 57 ++++ .../Services/OllamaChatCompletionTests.cs | 218 ++++++++++++++++ .../OllamaTextEmbeddingGenerationTests.cs | 67 +++++ .../Services/OllamaTextGenerationTests.cs | 200 ++++++++++++++ .../OllamaPromptExecutionSettingsTests.cs | 65 +++++ .../chat_completion_test_response_stream.txt | 6 + .../TestData/embeddings_test_response.json | 19 ++ .../text_generation_test_response_stream.txt | 6 + .../Connectors.Ollama/AssemblyInfo.cs | 6 + .../Connectors.Ollama.csproj | 34 +++ .../Connectors.Ollama/Core/ServiceBase.cs | 62 +++++ .../OllamaKernelBuilderExtensions.cs | 231 +++++++++++++++++ .../OllamaServiceCollectionExtensions.cs | 243 ++++++++++++++++++ .../Services/OllamaChatCompletionService.cs | 182 +++++++++++++ .../OllamaTextEmbeddingGenerationService.cs | 93 +++++++ .../Services/OllamaTextGenerationService.cs | 142 ++++++++++ .../Settings/OllamaPromptExecutionSettings.cs | 122 +++++++++ .../Ollama/OllamaCompletionTests.cs | 182 +++++++++++++ .../Ollama/OllamaTextEmbeddingTests.cs | 70 +++++ .../Ollama/OllamaTextGenerationTests.cs | 181 +++++++++++++ .../IntegrationTests/IntegrationTests.csproj | 1 + .../TestSettings/OllamaConfiguration.cs | 13 +
.../samples/InternalUtilities/BaseTest.cs | 12 + .../InternalUtilities/TestConfiguration.cs | 9 + 40 files changed, 2767 insertions(+), 37 deletions(-) create mode 100644 dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletion.cs create mode 100644 dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletionStreaming.cs create mode 100644 dotnet/samples/Concepts/Memory/Ollama_EmbeddingGeneration.cs create mode 100644 dotnet/samples/Concepts/TextGeneration/Ollama_TextGeneration.cs create mode 100644 dotnet/samples/Concepts/TextGeneration/Ollama_TextGenerationStreaming.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaKernelBuilderExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaServiceCollectionExtensionsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/Settings/OllamaPromptExecutionSettingsTests.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json create mode 100644 dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt create mode 100644 dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs create mode 100644 dotnet/src/Connectors/Connectors.Ollama/Settings/OllamaPromptExecutionSettings.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs create mode 100644 dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs create mode 100644 dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index b91877ab8856..1f437f772202 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -39,6 +39,7 @@ + @@ -135,8 +136,8 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - - + + \ No newline at end of file diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index 3358ed644325..b495d5eeedf0 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -314,12 +314,16 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Redis.UnitTests" EndProject 
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Qdrant.UnitTests", "src\Connectors\Connectors.Qdrant.UnitTests\Connectors.Qdrant.UnitTests.csproj", "{E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "StepwisePlannerMigration", "samples\Demos\StepwisePlannerMigration\StepwisePlannerMigration.csproj", "{38374C62-0263-4FE8-A18C-70FC8132912B}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Ollama", "src\Connectors\Connectors.Ollama\Connectors.Ollama.csproj", "{E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.AzureCosmosDBMongoDB.UnitTests", "src\Connectors\Connectors.AzureCosmosDBMongoDB.UnitTests\Connectors.AzureCosmosDBMongoDB.UnitTests.csproj", "{2918478E-BC86-4D53-9D01-9C318F80C14F}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AIModelRouter", "samples\Demos\AIModelRouter\AIModelRouter.csproj", "{E06818E3-00A5-41AC-97ED-9491070CDEA1}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Ollama.UnitTests", "src\Connectors\Connectors.Ollama.UnitTests\Connectors.Ollama.UnitTests.csproj", "{924DB138-1223-4C99-B6E6-0938A3FA14EF}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "StepwisePlannerMigration", "samples\Demos\StepwisePlannerMigration\StepwisePlannerMigration.csproj", "{38374C62-0263-4FE8-A18C-70FC8132912B}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.AzureCosmosDBNoSQL.UnitTests", "src\Connectors\Connectors.AzureCosmosDBNoSQL.UnitTests\Connectors.AzureCosmosDBNoSQL.UnitTests.csproj", "{385A8FE5-87E2-4458-AE09-35E10BD2E67F}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Connectors.Weaviate.UnitTests", "src\Connectors\Connectors.Weaviate.UnitTests\Connectors.Weaviate.UnitTests.csproj", "{AD9ECE32-088A-49D8-8ACB-890E79F1E7B8}" @@ -787,6 +791,18 @@ Global {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}.Publish|Any CPU.Build.0 = Debug|Any CPU {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}.Release|Any CPU.ActiveCfg = Release|Any CPU {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF}.Release|Any CPU.Build.0 = Release|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Publish|Any CPU.Build.0 = Publish|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD}.Release|Any CPU.Build.0 = Release|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Publish|Any CPU.Build.0 = Debug|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {924DB138-1223-4C99-B6E6-0938A3FA14EF}.Release|Any CPU.Build.0 = Release|Any CPU {38374C62-0263-4FE8-A18C-70FC8132912B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {38374C62-0263-4FE8-A18C-70FC8132912B}.Debug|Any CPU.Build.0 = Debug|Any CPU {38374C62-0263-4FE8-A18C-70FC8132912B}.Publish|Any CPU.ActiveCfg = Debug|Any CPU @@ -941,6 +957,8 @@ Global {B0B3901E-AF56-432B-8FAA-858468E5D0DF} = {24503383-A8C4-4255-9998-28D70FE8E99A} 
{1D4667B9-9381-4E32-895F-123B94253EE8} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {E92AE954-8F3A-4A6F-A4F9-DC12017E5AAF} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} + {E7E60E1D-1A44-4DE9-A44D-D5052E809DDD} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} + {924DB138-1223-4C99-B6E6-0938A3FA14EF} = {1B4CBDE0-10C2-4E7D-9CD0-FE7586C96ED1} {38374C62-0263-4FE8-A18C-70FC8132912B} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} {2918478E-BC86-4D53-9D01-9C318F80C14F} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {E06818E3-00A5-41AC-97ED-9491070CDEA1} = {5D4C0700-BBB5-418F-A7B2-F392B9A18263} diff --git a/dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletion.cs new file mode 100644 index 000000000000..b76b4fff88a1 --- /dev/null +++ b/dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletion.cs @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; + +namespace ChatCompletion; + +// The following example shows how to use Semantic Kernel with Ollama Chat Completion API +public class Ollama_ChatCompletion(ITestOutputHelper output) : BaseTest(output) +{ + [Fact] + public async Task ServicePromptAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + Console.WriteLine("======== Ollama - Chat Completion ========"); + + var chatService = new OllamaChatCompletionService( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId); + + Console.WriteLine("Chat content:"); + Console.WriteLine("------------------------"); + + var chatHistory = new ChatHistory("You are a librarian, expert about books"); + + // First user message + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); + this.OutputLastMessage(chatHistory); + + // First assistant message + var reply = await chatService.GetChatMessageContentAsync(chatHistory); + chatHistory.Add(reply); + this.OutputLastMessage(chatHistory); + + // Second user message + chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); + this.OutputLastMessage(chatHistory); + + // Second assistant message + reply = await chatService.GetChatMessageContentAsync(chatHistory); + chatHistory.Add(reply); + this.OutputLastMessage(chatHistory); + } + + [Fact] + public async Task ChatPromptAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + StringBuilder chatPrompt = new(""" + You are a librarian, expert about books + Hi, I'm looking for book suggestions + """); + + var kernel = Kernel.CreateBuilder() + .AddOllamaChatCompletion( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint ?? 
"http://localhost:11434"), + modelId: TestConfiguration.Ollama.ModelId) + .Build(); + + var reply = await kernel.InvokePromptAsync(chatPrompt.ToString()); + + chatPrompt.AppendLine($""); + chatPrompt.AppendLine("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); + + reply = await kernel.InvokePromptAsync(chatPrompt.ToString()); + + Console.WriteLine(reply); + } +} diff --git a/dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletionStreaming.cs new file mode 100644 index 000000000000..d83aac04e9bf --- /dev/null +++ b/dotnet/samples/Concepts/ChatCompletion/Ollama_ChatCompletionStreaming.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; + +namespace ChatCompletion; + +/// +/// These examples demonstrate the ways different content types are streamed by Ollama via the chat completion service. +/// +public class Ollama_ChatCompletionStreaming(ITestOutputHelper output) : BaseTest(output) +{ + /// + /// This example demonstrates chat completion streaming using Ollama. + /// + [Fact] + public Task StreamChatAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + Console.WriteLine("======== Ollama - Chat Completion Streaming ========"); + + var chatService = new OllamaChatCompletionService( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId); + + return this.StartStreamingChatAsync(chatService); + } + + [Fact] + public async Task StreamChatPromptAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + StringBuilder chatPrompt = new(""" + You are a librarian, expert about books + Hi, I'm looking for book suggestions + """); + + var kernel = Kernel.CreateBuilder() + .AddOllamaChatCompletion( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId) + .Build(); + + var reply = await StreamMessageOutputFromKernelAsync(kernel, chatPrompt.ToString()); + + chatPrompt.AppendLine($""); + chatPrompt.AppendLine("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); + + reply = await StreamMessageOutputFromKernelAsync(kernel, chatPrompt.ToString()); + + Console.WriteLine(reply); + } + + /// + /// This example demonstrates how the chat completion service streams text content. + /// It shows how to access the response update via StreamingChatMessageContent.Content property + /// and alternatively via the StreamingChatMessageContent.Items property. + /// + [Fact] + public async Task StreamTextFromChatAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + Console.WriteLine("======== Stream Text from Chat Content ========"); + + // Create chat completion service + var chatService = new OllamaChatCompletionService( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId); + + // Create chat history with initial system and user messages + ChatHistory chatHistory = new("You are a librarian, an expert on books."); + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions."); + chatHistory.AddUserMessage("I love history and philosophy. 
I'd like to learn something new about Greece, any suggestion?"); + + // Start streaming chat based on the chat history + await foreach (StreamingChatMessageContent chatUpdate in chatService.GetStreamingChatMessageContentsAsync(chatHistory)) + { + // Access the response update via StreamingChatMessageContent.Content property + Console.Write(chatUpdate.Content); + + // Alternatively, the response update can be accessed via the StreamingChatMessageContent.Items property + Console.Write(chatUpdate.Items.OfType().FirstOrDefault()); + } + } + + private async Task StartStreamingChatAsync(IChatCompletionService chatCompletionService) + { + Console.WriteLine("Chat content:"); + Console.WriteLine("------------------------"); + + var chatHistory = new ChatHistory("You are a librarian, expert about books"); + this.OutputLastMessage(chatHistory); + + // First user message + chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); + this.OutputLastMessage(chatHistory); + + // First assistant message + await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); + + // Second user message + chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion?"); + this.OutputLastMessage(chatHistory); + + // Second assistant message + await StreamMessageOutputAsync(chatCompletionService, chatHistory, AuthorRole.Assistant); + } + + private async Task StreamMessageOutputAsync(IChatCompletionService chatCompletionService, ChatHistory chatHistory, AuthorRole authorRole) + { + bool roleWritten = false; + string fullMessage = string.Empty; + + await foreach (var chatUpdate in chatCompletionService.GetStreamingChatMessageContentsAsync(chatHistory)) + { + if (!roleWritten && chatUpdate.Role.HasValue) + { + Console.Write($"{chatUpdate.Role.Value}: {chatUpdate.Content}"); + roleWritten = true; + } + + if (chatUpdate.Content is { Length: > 0 }) + { + fullMessage += chatUpdate.Content; + Console.Write(chatUpdate.Content); + } + } + + Console.WriteLine("\n------------------------"); + chatHistory.AddMessage(authorRole, fullMessage); + } + + private async Task StreamMessageOutputFromKernelAsync(Kernel kernel, string prompt) + { + bool roleWritten = false; + string fullMessage = string.Empty; + + await foreach (var chatUpdate in kernel.InvokePromptStreamingAsync(prompt)) + { + if (!roleWritten && chatUpdate.Role.HasValue) + { + Console.Write($"{chatUpdate.Role.Value}: {chatUpdate.Content}"); + roleWritten = true; + } + + if (chatUpdate.Content is { Length: > 0 }) + { + fullMessage += chatUpdate.Content; + Console.Write(chatUpdate.Content); + } + } + + Console.WriteLine("\n------------------------"); + return fullMessage; + } +} diff --git a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs index 42164d3fe8dc..a92c86dd977d 100644 --- a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs +++ b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletion.cs @@ -89,33 +89,20 @@ private async Task StartChatAsync(IChatCompletionService chatGPT) // First user message chatHistory.AddUserMessage("Hi, I'm looking for book suggestions"); - await MessageOutputAsync(chatHistory); + OutputLastMessage(chatHistory); // First bot assistant message var reply = await chatGPT.GetChatMessageContentAsync(chatHistory); chatHistory.Add(reply); - await MessageOutputAsync(chatHistory); + OutputLastMessage(chatHistory); // Second user message 
chatHistory.AddUserMessage("I love history and philosophy, I'd like to learn something new about Greece, any suggestion"); - await MessageOutputAsync(chatHistory); + OutputLastMessage(chatHistory); // Second bot assistant message reply = await chatGPT.GetChatMessageContentAsync(chatHistory); chatHistory.Add(reply); - await MessageOutputAsync(chatHistory); - } - - /// - /// Outputs the last message of the chat history - /// - private Task MessageOutputAsync(ChatHistory chatHistory) - { - var message = chatHistory.Last(); - - Console.WriteLine($"{message.Role}: {message.Content}"); - Console.WriteLine("------------------------"); - - return Task.CompletedTask; + OutputLastMessage(chatHistory); } } diff --git a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs index bd1285e29af3..fe0052a52db2 100644 --- a/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs +++ b/dotnet/samples/Concepts/ChatCompletion/OpenAI_ChatCompletionStreaming.cs @@ -99,7 +99,7 @@ public async Task StreamFunctionCallContentAsync() OpenAIPromptExecutionSettings settings = new() { ToolCallBehavior = ToolCallBehavior.EnableKernelFunctions }; // Create chat history with initial user question - ChatHistory chatHistory = new(); + ChatHistory chatHistory = []; chatHistory.AddUserMessage("Hi, what is the current time?"); // Start streaming chat based on the chat history @@ -162,15 +162,4 @@ private async Task StreamMessageOutputAsync(IChatCompletionService chatCompletio Console.WriteLine("\n------------------------"); chatHistory.AddMessage(authorRole, fullMessage); } - - /// - /// Outputs the last message of the chat history - /// - private void OutputLastMessage(ChatHistory chatHistory) - { - var message = chatHistory.Last(); - - Console.WriteLine($"{message.Role}: {message.Content}"); - Console.WriteLine("------------------------"); - } } diff --git a/dotnet/samples/Concepts/Concepts.csproj b/dotnet/samples/Concepts/Concepts.csproj index ada25229d413..348f31c399f9 100644 --- a/dotnet/samples/Concepts/Concepts.csproj +++ b/dotnet/samples/Concepts/Concepts.csproj @@ -67,6 +67,7 @@ + diff --git a/dotnet/samples/Concepts/Memory/Ollama_EmbeddingGeneration.cs b/dotnet/samples/Concepts/Memory/Ollama_EmbeddingGeneration.cs new file mode 100644 index 000000000000..5ba0a45440b2 --- /dev/null +++ b/dotnet/samples/Concepts/Memory/Ollama_EmbeddingGeneration.cs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Embeddings; +using xRetry; + +#pragma warning disable format // Format item can be simplified +#pragma warning disable CA1861 // Avoid constant arrays as arguments + +namespace Memory; + +// The following example shows how to use Semantic Kernel with Ollama API. +public class Ollama_EmbeddingGeneration(ITestOutputHelper output) : BaseTest(output) +{ + [RetryFact(typeof(HttpOperationException))] + public async Task RunEmbeddingAsync() + { + Assert.NotNull(TestConfiguration.Ollama.EmbeddingModelId); + + Console.WriteLine("\n======= Ollama - Embedding Example ========\n"); + + Kernel kernel = Kernel.CreateBuilder() + .AddOllamaTextEmbeddingGeneration( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.EmbeddingModelId) + .Build(); + + var embeddingGenerator = kernel.GetRequiredService(); + + // Generate embeddings for each chunk. 
+ var embeddings = await embeddingGenerator.GenerateEmbeddingsAsync(["John: Hello, how are you?\nRoger: Hey, I'm Roger!"]); + + Console.WriteLine($"Generated {embeddings.Count} embeddings for the provided text"); + } +} diff --git a/dotnet/samples/Concepts/TextGeneration/Ollama_TextGeneration.cs b/dotnet/samples/Concepts/TextGeneration/Ollama_TextGeneration.cs new file mode 100644 index 000000000000..719d5eb9f951 --- /dev/null +++ b/dotnet/samples/Concepts/TextGeneration/Ollama_TextGeneration.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.TextGeneration; +using xRetry; + +#pragma warning disable format // Format item can be simplified +#pragma warning disable CA1861 // Avoid constant arrays as arguments + +namespace TextGeneration; + +// The following example shows how to use Semantic Kernel with Ollama Text Generation API. +public class Ollama_TextGeneration(ITestOutputHelper helper) : BaseTest(helper) +{ + [Fact] + public async Task KernelPromptAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + Console.WriteLine("\n======== Ollama Text Generation example ========\n"); + + Kernel kernel = Kernel.CreateBuilder() + .AddOllamaTextGeneration( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId) + .Build(); + + var questionAnswerFunction = kernel.CreateFunctionFromPrompt("Question: {{$input}}; Answer:"); + + var result = await kernel.InvokeAsync(questionAnswerFunction, new() { ["input"] = "What is New York?" }); + + Console.WriteLine(result.GetValue<string>()); + } + + [Fact] + public async Task ServicePromptAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + Console.WriteLine("\n======== Ollama Text Generation example ========\n"); + + Kernel kernel = Kernel.CreateBuilder() + .AddOllamaTextGeneration( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId) + .Build(); + + var service = kernel.GetRequiredService<ITextGenerationService>(); + var result = await service.GetTextContentAsync("Question: What is New York?; Answer:"); + + Console.WriteLine(result); + } + + [RetryFact(typeof(HttpOperationException))] + public async Task RunStreamingExampleAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + string model = TestConfiguration.Ollama.ModelId; + + Console.WriteLine($"\n======== Ollama {model} streaming example ========\n"); + + Kernel kernel = Kernel.CreateBuilder() + .AddOllamaTextGeneration( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: TestConfiguration.Ollama.ModelId) + .Build(); + + var questionAnswerFunction = kernel.CreateFunctionFromPrompt("Question: {{$input}}; Answer:"); + + await foreach (string text in kernel.InvokePromptStreamingAsync<string>("Question: {{$input}}; Answer:", new() { ["input"] = "What is New York?" })) + { + Console.Write(text); + } + } +} diff --git a/dotnet/samples/Concepts/TextGeneration/Ollama_TextGenerationStreaming.cs b/dotnet/samples/Concepts/TextGeneration/Ollama_TextGenerationStreaming.cs new file mode 100644 index 000000000000..35e0c31074f4 --- /dev/null +++ b/dotnet/samples/Concepts/TextGeneration/Ollama_TextGenerationStreaming.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft. All rights reserved.
+ +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.TextGeneration; + +#pragma warning disable format // Format item can be simplified +#pragma warning disable CA1861 // Avoid constant arrays as arguments + +namespace TextGeneration; + +// The following example shows how to use Semantic Kernel with Ollama Text Generation API. +public class Ollama_TextGenerationStreaming(ITestOutputHelper helper) : BaseTest(helper) +{ + [Fact] + public async Task RunKernelStreamingExampleAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + string model = TestConfiguration.Ollama.ModelId; + + Console.WriteLine($"\n======== Ollama {model} streaming example ========\n"); + + Kernel kernel = Kernel.CreateBuilder() + .AddOllamaTextGeneration( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: model) + .Build(); + + await foreach (string text in kernel.InvokePromptStreamingAsync("Question: {{$input}}; Answer:", new() { ["input"] = "What is New York?" })) + { + Console.Write(text); + } + } + + [Fact] + public async Task RunServiceStreamingExampleAsync() + { + Assert.NotNull(TestConfiguration.Ollama.ModelId); + + string model = TestConfiguration.Ollama.ModelId; + + Console.WriteLine($"\n======== Ollama {model} streaming example ========\n"); + + Kernel kernel = Kernel.CreateBuilder() + .AddOllamaTextGeneration( + endpoint: new Uri(TestConfiguration.Ollama.Endpoint), + modelId: model) + .Build(); + + var service = kernel.GetRequiredService(); + + await foreach (var content in service.GetStreamingTextContentsAsync("Question: What is New York?; Answer:")) + { + Console.Write(content); + } + } +} diff --git a/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj b/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj index fb5862e3270a..542082ca8960 100644 --- a/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj +++ b/dotnet/samples/Demos/AIModelRouter/AIModelRouter.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/Demos/AIModelRouter/CustomRouter.cs b/dotnet/samples/Demos/AIModelRouter/CustomRouter.cs index ff2767a289c8..4d324bacdcd1 100644 --- a/dotnet/samples/Demos/AIModelRouter/CustomRouter.cs +++ b/dotnet/samples/Demos/AIModelRouter/CustomRouter.cs @@ -11,7 +11,7 @@ namespace AIModelRouter; /// In a real-world scenario, you would use a more sophisticated routing mechanism, such as another local model for /// deciding which service to use based on the user's input or any other criteria. /// -public class CustomRouter() +internal sealed class CustomRouter() { /// /// Returns the best service id to use based on the user's input. @@ -21,7 +21,7 @@ public class CustomRouter() /// User's input prompt /// List of service ids to choose from in order of importance, defaulting to the first /// Service id. - public string FindService(string lookupPrompt, IReadOnlyList serviceIds) + internal string FindService(string lookupPrompt, IReadOnlyList serviceIds) { // The order matters, if the keyword is not found, the first one is used. 
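        // Each service id doubles as a routing keyword: the first id mentioned in the prompt is selected, and the first id in the list acts as the fallback.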
foreach (var serviceId in serviceIds) diff --git a/dotnet/samples/Demos/AIModelRouter/Program.cs b/dotnet/samples/Demos/AIModelRouter/Program.cs index 5bafa4934883..74dbf367e955 100644 --- a/dotnet/samples/Demos/AIModelRouter/Program.cs +++ b/dotnet/samples/Demos/AIModelRouter/Program.cs @@ -6,11 +6,12 @@ #pragma warning disable SKEXP0001 #pragma warning disable SKEXP0010 +#pragma warning disable SKEXP0070 #pragma warning disable CA2249 // Consider using 'string.Contains' instead of 'string.IndexOf' namespace AIModelRouter; -internal sealed partial class Program +internal sealed class Program { private static async Task Main(string[] args) { @@ -23,7 +24,7 @@ private static async Task Main(string[] args) // Adding multiple connectors targeting different providers / models. services.AddKernel() /* LMStudio model is selected in server side. */ .AddOpenAIChatCompletion(serviceId: "lmstudio", modelId: "N/A", endpoint: new Uri("http://localhost:1234"), apiKey: null) - .AddOpenAIChatCompletion(serviceId: "ollama", modelId: "phi3", endpoint: new Uri("http://localhost:11434"), apiKey: null) + .AddOllamaChatCompletion(serviceId: "ollama", modelId: "phi3", endpoint: new Uri("http://localhost:11434")) .AddOpenAIChatCompletion(serviceId: "openai", modelId: "gpt-4o", apiKey: config["OpenAI:ApiKey"]!) // Adding a custom filter to capture router selected service id diff --git a/dotnet/samples/Demos/AIModelRouter/SelectedServiceFilter.cs b/dotnet/samples/Demos/AIModelRouter/SelectedServiceFilter.cs index 9824d57ebd55..0c5334fc58a0 100644 --- a/dotnet/samples/Demos/AIModelRouter/SelectedServiceFilter.cs +++ b/dotnet/samples/Demos/AIModelRouter/SelectedServiceFilter.cs @@ -11,7 +11,7 @@ namespace AIModelRouter; /// /// Using a filter to log the service being used for the prompt. /// -public class SelectedServiceFilter : IPromptRenderFilter +internal sealed class SelectedServiceFilter : IPromptRenderFilter { /// public Task OnPromptRenderAsync(PromptRenderContext context, Func next) diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj new file mode 100644 index 000000000000..78afaac82621 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Connectors.Ollama.UnitTests.csproj @@ -0,0 +1,50 @@ + + + + SemanticKernel.Connectors.Ollama.UnitTests + SemanticKernel.Connectors.Ollama.UnitTests + net8.0 + 12 + LatestMajor + true + enable + disable + false + CA2007,CA1861,VSTHRD111,CS1591,SKEXP0001,SKEXP0070 + + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + + + + Always + + + + diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaKernelBuilderExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaKernelBuilderExtensionsTests.cs new file mode 100644 index 000000000000..668044164ded --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaKernelBuilderExtensionsTests.cs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.TextGeneration; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests.Extensions; + +/// +/// Unit tests of . +/// +public class OllamaKernelBuilderExtensionsTests +{ + [Fact] + public void AddOllamaTextGenerationCreatesService() + { + var builder = Kernel.CreateBuilder(); + builder.AddOllamaTextGeneration("model", new Uri("http://localhost:11434")); + + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaChatCompletionCreatesService() + { + var builder = Kernel.CreateBuilder(); + builder.AddOllamaChatCompletion("model", new Uri("http://localhost:11434")); + + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaTextEmbeddingGenerationCreatesService() + { + var builder = Kernel.CreateBuilder(); + builder.AddOllamaTextEmbeddingGeneration("model", new Uri("http://localhost:11434")); + + var kernel = builder.Build(); + var service = kernel.GetRequiredService(); + + Assert.NotNull(kernel); + Assert.NotNull(service); + Assert.IsType(service); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaServiceCollectionExtensionsTests.cs new file mode 100644 index 000000000000..2c3a4e79df04 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Extensions/OllamaServiceCollectionExtensionsTests.cs @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.TextGeneration; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests.Extensions; + +/// +/// Unit tests of . 
+/// +public class OllamaServiceCollectionExtensionsTests +{ + [Fact] + public void AddOllamaTextGenerationToServiceCollection() + { + var services = new ServiceCollection(); + services.AddOllamaTextGeneration("model", new Uri("http://localhost:11434")); + + var serviceProvider = services.BuildServiceProvider(); + var service = serviceProvider.GetRequiredService(); + + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaChatCompletionToServiceCollection() + { + var services = new ServiceCollection(); + services.AddOllamaChatCompletion("model", new Uri("http://localhost:11434")); + + var serviceProvider = services.BuildServiceProvider(); + var service = serviceProvider.GetRequiredService(); + + Assert.NotNull(service); + Assert.IsType(service); + } + + [Fact] + public void AddOllamaTextEmbeddingsGenerationToServiceCollection() + { + var services = new ServiceCollection(); + services.AddOllamaTextEmbeddingGeneration("model", new Uri("http://localhost:11434")); + + var serviceProvider = services.BuildServiceProvider(); + var service = serviceProvider.GetRequiredService(); + + Assert.NotNull(service); + Assert.IsType(service); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs new file mode 100644 index 000000000000..40e1b840beaf --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaChatCompletionTests.cs @@ -0,0 +1,218 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using OllamaSharp.Models.Chat; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests.Services; + +public sealed class OllamaChatCompletionTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public OllamaChatCompletionTests() + { + this._messageHandlerStub = new() + { + ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StreamContent(File.OpenRead("TestData/chat_completion_test_response_stream.txt")) + } + }; + this._httpClient = new HttpClient(this._messageHandlerStub, false) { BaseAddress = new Uri("http://localhost:11434") }; + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + httpClient: this._httpClient); + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + await sut.GetChatMessageContentsAsync(chat); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.Equal("fake-text", requestPayload.Messages!.First().Content); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + httpClient: this._httpClient); + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + var messages = await sut.GetChatMessageContentsAsync(chat); + + //Assert + Assert.NotNull(messages); + + var message = messages.SingleOrDefault(); + Assert.NotNull(message); + Assert.Equal("This is test 
completion response", message.Content); + } + + [Fact] + public async Task GetChatMessageContentsShouldHaveModelAndInnerContentAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "phi3", + httpClient: this._httpClient); + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + //Act + var messages = await sut.GetChatMessageContentsAsync(chat); + + //Assert + Assert.NotNull(messages); + var message = messages.SingleOrDefault(); + Assert.NotNull(message); + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Null(requestPayload.Options.Stop); + Assert.Null(requestPayload.Options.Temperature); + Assert.Null(requestPayload.Options.TopK); + Assert.Null(requestPayload.Options.TopP); + + Assert.NotNull(message.ModelId); + Assert.Equal("phi3", message.ModelId); + } + + [Fact] + public async Task GetStreamingChatMessageContentsShouldHaveModelAndInnerContentAsync() + { + //Arrange + var expectedModel = "phi3"; + var sut = new OllamaChatCompletionService( + expectedModel, + httpClient: this._httpClient); + + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + + // Act + StreamingChatMessageContent? lastMessage = null; + await foreach (var message in sut.GetStreamingChatMessageContentsAsync(chat)) + { + lastMessage = message; + Assert.NotNull(message.InnerContent); + } + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Null(requestPayload.Options.Stop); + Assert.Null(requestPayload.Options.Temperature); + Assert.Null(requestPayload.Options.TopK); + Assert.Null(requestPayload.Options.TopP); + + Assert.NotNull(lastMessage!.ModelId); + Assert.Equal(expectedModel, lastMessage.ModelId); + + Assert.IsType(lastMessage.InnerContent); + var innerContent = lastMessage.InnerContent as ChatDoneResponseStream; + Assert.NotNull(innerContent); + Assert.True(innerContent.Done); + } + + [Fact] + public async Task GetStreamingChatMessageContentsExecutionSettingsMustBeSentAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + httpClient: this._httpClient); + var chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + string jsonSettings = """ + { + "stop": ["stop me"], + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100 + } + """; + + var executionSettings = JsonSerializer.Deserialize(jsonSettings); + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Act + await sut.GetStreamingChatMessageContentsAsync(chat, ollamaExecutionSettings).GetAsyncEnumerator().MoveNextAsync(); + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Equal(ollamaExecutionSettings.Stop, requestPayload.Options.Stop); + Assert.Equal(ollamaExecutionSettings.Temperature, requestPayload.Options.Temperature); + Assert.Equal(ollamaExecutionSettings.TopP, requestPayload.Options.TopP); + Assert.Equal(ollamaExecutionSettings.TopK, requestPayload.Options.TopK); + } + + [Fact] + public async Task GetChatMessageContentsExecutionSettingsMustBeSentAsync() + { + //Arrange + var sut = new OllamaChatCompletionService( + "fake-model", + httpClient: this._httpClient); + var 
chat = new ChatHistory(); + chat.AddMessage(AuthorRole.User, "fake-text"); + string jsonSettings = """ + { + "stop": ["stop me"], + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100 + } + """; + + var executionSettings = JsonSerializer.Deserialize(jsonSettings); + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Act + await sut.GetChatMessageContentsAsync(chat, ollamaExecutionSettings); + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Equal(ollamaExecutionSettings.Stop, requestPayload.Options.Stop); + Assert.Equal(ollamaExecutionSettings.Temperature, requestPayload.Options.Temperature); + Assert.Equal(ollamaExecutionSettings.TopP, requestPayload.Options.TopP); + Assert.Equal(ollamaExecutionSettings.TopK, requestPayload.Options.TopK); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs new file mode 100644 index 000000000000..ec1e63c1cd56 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextEmbeddingGenerationTests.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.Connectors.Ollama; +using OllamaSharp.Models; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests.Services; + +public sealed class OllamaTextEmbeddingGenerationTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public OllamaTextEmbeddingGenerationTests() + { + this._messageHandlerStub = new(); + this._messageHandlerStub.ResponseToReturn.Content = new StringContent(File.ReadAllText("TestData/embeddings_test_response.json")); + this._httpClient = new HttpClient(this._messageHandlerStub, false) { BaseAddress = new Uri("http://localhost:11434") }; + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var sut = new OllamaTextEmbeddingGenerationService( + "fake-model", + httpClient: this._httpClient); + + //Act + await sut.GenerateEmbeddingsAsync(["fake-text"]); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.Equal("fake-text", requestPayload.Input[0]); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OllamaTextEmbeddingGenerationService( + "fake-model", + httpClient: this._httpClient); + + //Act + var contents = await sut.GenerateEmbeddingsAsync(["fake-text"]); + + //Assert + Assert.NotNull(contents); + Assert.Equal(2, contents.Count); + + var content = contents[0]; + Assert.Equal(5, content.Length); + } + + public void Dispose() + { + this._httpClient.Dispose(); + this._messageHandlerStub.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs new file mode 100644 index 000000000000..c765bf1d678d --- /dev/null +++ 
b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Services/OllamaTextGenerationTests.cs @@ -0,0 +1,200 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Text.Json; +using System.Threading.Tasks; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp.Models; +using OllamaSharp.Models.Chat; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests.Services; + +public sealed class OllamaTextGenerationTests : IDisposable +{ + private readonly HttpMessageHandlerStub _messageHandlerStub; + private readonly HttpClient _httpClient; + + public OllamaTextGenerationTests() + { + this._messageHandlerStub = new() + { + ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK) + { + Content = new StreamContent(File.OpenRead("TestData/text_generation_test_response_stream.txt")) + } + }; + this._httpClient = new HttpClient(this._messageHandlerStub, false) { BaseAddress = new Uri("http://localhost:11434") }; + } + + [Fact] + public async Task ShouldSendPromptToServiceAsync() + { + //Arrange + var expectedModel = "phi3"; + var sut = new OllamaTextGenerationService( + expectedModel, + httpClient: this._httpClient); + + //Act + await sut.GetTextContentsAsync("fake-text"); + + //Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.Equal("fake-text", requestPayload.Prompt); + } + + [Fact] + public async Task ShouldHandleServiceResponseAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + httpClient: this._httpClient); + + //Act + var contents = await sut.GetTextContentsAsync("fake-test"); + + //Assert + Assert.NotNull(contents); + + var content = contents.SingleOrDefault(); + Assert.NotNull(content); + Assert.Equal("This is test completion response", content.Text); + } + + [Fact] + public async Task GetTextContentsShouldHaveModelIdDefinedAsync() + { + //Arrange + var expectedModel = "phi3"; + var sut = new OllamaTextGenerationService( + expectedModel, + httpClient: this._httpClient); + + // Act + var textContent = await sut.GetTextContentAsync("Any prompt"); + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Null(requestPayload.Options.Stop); + Assert.Null(requestPayload.Options.Temperature); + Assert.Null(requestPayload.Options.TopK); + Assert.Null(requestPayload.Options.TopP); + + Assert.NotNull(textContent.ModelId); + Assert.Equal(expectedModel, textContent.ModelId); + } + + [Fact] + public async Task GetStreamingTextContentsShouldHaveModelIdDefinedAsync() + { + //Arrange + var expectedModel = "phi3"; + var sut = new OllamaTextGenerationService( + expectedModel, + httpClient: this._httpClient); + + // Act + StreamingTextContent? 
lastTextContent = null; + await foreach (var textContent in sut.GetStreamingTextContentsAsync("Any prompt")) + { + lastTextContent = textContent; + } + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Null(requestPayload.Options.Stop); + Assert.Null(requestPayload.Options.Temperature); + Assert.Null(requestPayload.Options.TopK); + Assert.Null(requestPayload.Options.TopP); + + Assert.NotNull(lastTextContent!.ModelId); + Assert.Equal(expectedModel, lastTextContent.ModelId); + } + + [Fact] + public async Task GetStreamingTextContentsExecutionSettingsMustBeSentAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + httpClient: this._httpClient); + + string jsonSettings = """ + { + "stop": ["stop me"], + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100 + } + """; + + var executionSettings = JsonSerializer.Deserialize(jsonSettings); + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Act + await sut.GetStreamingTextContentsAsync("Any prompt", ollamaExecutionSettings).GetAsyncEnumerator().MoveNextAsync(); + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Equal(ollamaExecutionSettings.Stop, requestPayload.Options.Stop); + Assert.Equal(ollamaExecutionSettings.Temperature, requestPayload.Options.Temperature); + Assert.Equal(ollamaExecutionSettings.TopP, requestPayload.Options.TopP); + Assert.Equal(ollamaExecutionSettings.TopK, requestPayload.Options.TopK); + } + + [Fact] + public async Task GetTextContentsExecutionSettingsMustBeSentAsync() + { + //Arrange + var sut = new OllamaTextGenerationService( + "fake-model", + httpClient: this._httpClient); + string jsonSettings = """ + { + "stop": ["stop me"], + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100 + } + """; + + var executionSettings = JsonSerializer.Deserialize(jsonSettings); + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Act + await sut.GetTextContentsAsync("Any prompt", ollamaExecutionSettings); + + // Assert + var requestPayload = JsonSerializer.Deserialize(this._messageHandlerStub.RequestContent); + Assert.NotNull(requestPayload); + Assert.NotNull(requestPayload.Options); + Assert.Equal(ollamaExecutionSettings.Stop, requestPayload.Options.Stop); + Assert.Equal(ollamaExecutionSettings.Temperature, requestPayload.Options.Temperature); + Assert.Equal(ollamaExecutionSettings.TopP, requestPayload.Options.TopP); + Assert.Equal(ollamaExecutionSettings.TopK, requestPayload.Options.TopK); + } + + /// + /// Disposes resources used by this class. + /// + public void Dispose() + { + this._messageHandlerStub.Dispose(); + + this._httpClient.Dispose(); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Settings/OllamaPromptExecutionSettingsTests.cs b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Settings/OllamaPromptExecutionSettingsTests.cs new file mode 100644 index 000000000000..b7ff3d1c57c5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/Settings/OllamaPromptExecutionSettingsTests.cs @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Linq; +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Xunit; + +namespace SemanticKernel.Connectors.Ollama.UnitTests.Settings; + +/// +/// Unit tests of . +/// +public class OllamaPromptExecutionSettingsTests +{ + [Fact] + public void FromExecutionSettingsWhenAlreadyOllamaShouldReturnSame() + { + // Arrange + var executionSettings = new OllamaPromptExecutionSettings(); + + // Act + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Assert + Assert.Same(executionSettings, ollamaExecutionSettings); + } + + [Fact] + public void FromExecutionSettingsWhenNullShouldReturnDefault() + { + // Arrange + OllamaPromptExecutionSettings? executionSettings = null; + + // Act + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + // Assert + Assert.Null(ollamaExecutionSettings.Stop); + Assert.Null(ollamaExecutionSettings.Temperature); + Assert.Null(ollamaExecutionSettings.TopP); + Assert.Null(ollamaExecutionSettings.TopK); + } + + [Fact] + public void FromExecutionSettingsWhenSerializedHasPropertiesShouldPopulateSpecialized() + { + string jsonSettings = """ + { + "stop": ["stop me"], + "temperature": 0.5, + "top_p": 0.9, + "top_k": 100 + } + """; + + var executionSettings = JsonSerializer.Deserialize(jsonSettings); + var ollamaExecutionSettings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + + Assert.Equal("stop me", ollamaExecutionSettings.Stop?.FirstOrDefault()); + Assert.Equal(0.5f, ollamaExecutionSettings.Temperature); + Assert.Equal(0.9f, ollamaExecutionSettings.TopP!.Value, 0.1f); + Assert.Equal(100, ollamaExecutionSettings.TopK); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt new file mode 100644 index 000000000000..55b26d234500 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/chat_completion_test_response_stream.txt @@ -0,0 +1,6 @@ +{"model":"phi3","created_at":"2024-07-02T11:45:16.216898458Z","message":{"role":"assistant","content":"This "},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.22693076Z","message":{"role":"assistant","content":"is "},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.236570847Z","message":{"role":"assistant","content":"test "},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.246538945Z","message":{"role":"assistant","content":"completion "},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.25611096Z","message":{"role":"assistant","content":"response"},"done":false} +{"model":"phi3","created_at":"2024-07-02T11:45:16.265598822Z","message":{"role":"assistant","content":""},"done_reason":"stop","done":true,"total_duration":58123571935,"load_duration":55561676662,"prompt_eval_count":10,"prompt_eval_duration":34847000,"eval_count":239,"eval_duration":2381751000} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json new file mode 100644 index 000000000000..3316addba6dd --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/embeddings_test_response.json @@ -0,0 +1,19 @@ +{ + "model": "fake-model", + 
"embeddings": [ + [ + 0.020765934, + 0.007495159, + 0.01268963, + 0.013938076, + -0.04621073 + ], + [ + 0.025005031, + 0.009804744, + -0.016960088, + -0.024823941, + -0.02756831 + ] + ] +} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt new file mode 100644 index 000000000000..d2fe45f536c9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama.UnitTests/TestData/text_generation_test_response_stream.txt @@ -0,0 +1,6 @@ +{"model":"phi3","created_at":"2024-07-02T12:22:37.03627019Z","response":"This ","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.048915655Z","response":"is ","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.060968719Z","response":"test ","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.072390403Z","response":"completion ","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.072390403Z","response":"response","done":false} +{"model":"phi3","created_at":"2024-07-02T12:22:37.091017292Z","response":"","done":true,"done_reason":"stop","context":[32010,3750,338,278,14744,7254,29973,32007,32001,450,2769,278,14744,5692,7254,304,502,373,11563,756,304,437,411,278,14801,292,310,6575,4366,491,278,25005,29889,8991,4366,29892,470,4796,3578,29892,338,1754,701,310,263,18272,310,11955,393,508,367,3595,297,263,17251,17729,313,1127,29892,24841,29892,13328,29892,7933,29892,7254,29892,1399,5973,29892,322,28008,1026,467,910,18272,310,11955,338,2998,408,4796,3578,1363,372,3743,599,278,1422,281,6447,1477,29879,12420,4208,29889,13,13,10401,6575,4366,24395,11563,29915,29879,25005,29892,21577,13206,21337,763,21767,307,1885,322,288,28596,14801,20511,29899,29893,6447,1477,3578,313,9539,322,28008,1026,29897,901,1135,5520,29899,29893,6447,1477,3578,313,1127,322,13328,467,4001,1749,5076,526,901,20502,304,7254,3578,322,278,8991,5692,901,4796,515,1749,18520,373,11563,2861,304,445,14801,292,2779,29892,591,17189,573,278,14744,408,7254,29889,13,13,2528,17658,29892,5998,1716,7254,322,28008,1026,281,6447,1477,29879,310,3578,526,29574,22829,491,4799,13206,21337,29892,1749,639,1441,338,451,28482,491,278,28008,1026,2927,1951,5199,5076,526,3109,20502,304,372,29889,12808,29892,6575,4366,20888,11563,29915,29879,7101,756,263,6133,26171,297,278,13328,29899,12692,760,310,278,18272,9401,304,2654,470,28008,1026,11955,2861,304,9596,280,1141,14801,292,29892,607,4340,26371,2925,1749,639,1441,310,278,7254,14744,29889,13,13,797,15837,29892,278,14801,292,310,20511,281,6447,1477,3578,313,9539,322,28008,1026,29897,491,11563,29915,29879,25005,9946,502,304,1074,263,758,24130,10835,7254,14744,2645,2462,4366,6199,29889,32007],"total_duration":64697743903,"load_duration":61368714283,"prompt_eval_count":10,"prompt_eval_duration":40919000,"eval_count":304,"eval_duration":3237325000} \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs b/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs new file mode 100644 index 000000000000..fe66371dbc58 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; + +// This assembly is currently experimental. 
+[assembly: Experimental("SKEXP0070")] diff --git a/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj b/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj new file mode 100644 index 000000000000..1ce5397d2e07 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Connectors.Ollama.csproj @@ -0,0 +1,34 @@ + + + + + Microsoft.SemanticKernel.Connectors.Ollama + $(AssemblyName) + net8;netstandard2.0 + alpha + + + + + + + + + Semantic Kernel - Ollama AI connectors + Semantic Kernel connector for Ollama. Contains services for text generation, chat completion and text embeddings. + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs new file mode 100644 index 000000000000..f9ed8fb7b4ff --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Core/ServiceBase.cs @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.Services; +using OllamaSharp; + +namespace Microsoft.SemanticKernel.Connectors.Ollama.Core; + +/// +/// Represents the core of a service. +/// +public abstract class ServiceBase +{ + /// + /// Attributes of the service. + /// + internal Dictionary AttributesInternal { get; } = []; + + /// + /// Internal Ollama Sharp client. + /// + internal readonly OllamaApiClient _client; + + internal ServiceBase(string model, + Uri? endpoint, + HttpClient? httpClient = null, + ILoggerFactory? loggerFactory = null) + { + Verify.NotNullOrWhiteSpace(model); + this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, model); + + if (httpClient is not null) + { + this._client = new(httpClient, model); + } + else + { +#pragma warning disable CA2000 // Dispose objects before losing scope + // Client needs to be created to be able to inject Semantic Kernel headers + var internalClient = HttpClientProvider.GetHttpClient(); + internalClient.BaseAddress = endpoint; + internalClient.DefaultRequestHeaders.Add("User-Agent", HttpHeaderConstant.Values.UserAgent); + internalClient.DefaultRequestHeaders.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(Kernel))); + + this._client = new(internalClient, model); +#pragma warning restore CA2000 // Dispose objects before losing scope + } + } + + internal ServiceBase(string model, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + { + Verify.NotNullOrWhiteSpace(model); + this._client = ollamaClient; + this.AttributesInternal.Add(AIServiceExtensions.ModelIdKey, model); + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs new file mode 100644 index 000000000000..0ad8d895bdd7 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaKernelBuilderExtensions.cs @@ -0,0 +1,231 @@ +// Copyright (c) Microsoft. All rights reserved. 
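+
+// The extension methods below are thin registration wrappers over the connector's
+// services. A minimal end-to-end sketch (model name and endpoint are illustrative):
+//
+//   var kernel = Kernel.CreateBuilder()
+//       .AddOllamaTextGeneration("phi3", new Uri("http://localhost:11434"))
+//       .Build();
+//   var answer = await kernel.InvokePromptAsync("Why is the sky blue?");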
+ +using System; +using System.Net.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for adding Ollama Text Generation service to the kernel builder. +/// +public static class OllamaKernelBuilderExtensions +{ + #region Text Generation + + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The endpoint to Ollama hosted service. + /// The optional service ID. + /// The optional custom HttpClient. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextGeneration( + this IKernelBuilder builder, + string modelId, + Uri endpoint, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + modelId: modelId, + endpoint: endpoint, + loggerFactory: serviceProvider.GetService())); + return builder; + } + + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The optional service ID. + /// The optional custom HttpClient. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextGeneration( + this IKernelBuilder builder, + string modelId, + string? serviceId = null, + HttpClient? httpClient = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + modelId: modelId, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + loggerFactory: serviceProvider.GetService())); + return builder; + } + + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextGeneration( + this IKernelBuilder builder, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + modelId: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + return builder; + } + + #endregion + + #region Chat Completion + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The endpoint to Ollama hosted service. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaChatCompletion( + this IKernelBuilder builder, + string modelId, + Uri endpoint, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOllamaChatCompletion(modelId, endpoint, serviceId); + + return builder; + } + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The optional custom HttpClient. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaChatCompletion( + this IKernelBuilder builder, + string modelId, + HttpClient? 
httpClient = null, + string? serviceId = null + ) + { + Verify.NotNull(builder); + + builder.Services.AddOllamaChatCompletion(modelId, httpClient, serviceId); + + return builder; + } + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaChatCompletion( + this IKernelBuilder builder, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOllamaChatCompletion(modelId, ollamaClient, serviceId); + + return builder; + } + + #endregion + + #region Text Embeddings + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The endpoint to Ollama hosted service. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextEmbeddingGeneration( + this IKernelBuilder builder, + string modelId, + Uri endpoint, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOllamaTextEmbeddingGeneration(modelId, endpoint, serviceId); + + return builder; + } + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The optional custom HttpClient. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextEmbeddingGeneration( + this IKernelBuilder builder, + string modelId, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOllamaTextEmbeddingGeneration(modelId, httpClient, serviceId); + + return builder; + } + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The kernel builder. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IKernelBuilder AddOllamaTextEmbeddingGeneration( + this IKernelBuilder builder, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(builder); + + builder.Services.AddOllamaTextEmbeddingGeneration(modelId, ollamaClient, serviceId); + + return builder; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs new file mode 100644 index 000000000000..9ef438515e35 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Extensions/OllamaServiceCollectionExtensions.cs @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Net.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Http; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp; + +namespace Microsoft.SemanticKernel; + +/// +/// Extension methods for adding Ollama Text Generation service to the kernel builder. 
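+// Note: unlike the IKernelBuilder overloads above, these methods target IServiceCollection
+// directly, so the connector can be registered in any host's DI container. A sketch
+// (values illustrative):
+//
+//   var services = new ServiceCollection();
+//   services.AddOllamaChatCompletion("phi3", new Uri("http://localhost:11434"));
+//   using var provider = services.BuildServiceProvider();
+//   var chat = provider.GetRequiredService<IChatCompletionService>();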
+/// +public static class OllamaServiceCollectionExtensions +{ + #region Text Generation + + /// + /// Add Ollama Text Generation service to the specified service collection. + /// + /// The target service collection. + /// The model for text generation. + /// The endpoint to Ollama hosted service. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextGeneration( + this IServiceCollection services, + string modelId, + Uri endpoint, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + modelId: modelId, + endpoint: endpoint, + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Generation service to the specified service collection. + /// + /// The target service collection. + /// The model for text generation. + /// Optional custom HttpClient, picked from ServiceCollection if not provided. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextGeneration( + this IServiceCollection services, + string modelId, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + modelId: modelId, + httpClient: HttpClientProvider.GetHttpClient(serviceProvider), + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Generation service to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextGeneration( + this IServiceCollection services, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextGenerationService( + modelId: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + } + + #endregion + + #region Chat Completion + + /// + /// Add Ollama Chat Completion and Text Generation services to the specified service collection. + /// + /// The target service collection. + /// The model for text generation. + /// The endpoint to Ollama hosted service. + /// Optional service ID. + /// The updated service collection. + public static IServiceCollection AddOllamaChatCompletion( + this IServiceCollection services, + string modelId, + Uri endpoint, + string? serviceId = null) + { + Verify.NotNull(services); + + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + modelId: modelId, + endpoint: endpoint, + loggerFactory: serviceProvider.GetService())); + + return services; + } + + /// + /// Add Ollama Chat Completion and Text Generation services to the specified service collection. + /// + /// The target service collection. + /// The model for text generation. + /// Optional custom HttpClient, picked from ServiceCollection if not provided. + /// Optional service ID. + /// The updated service collection. + public static IServiceCollection AddOllamaChatCompletion( + this IServiceCollection services, + string modelId, + HttpClient? httpClient = null, + string? 
serviceId = null) + { + Verify.NotNull(services); + + services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + modelId: modelId, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + loggerFactory: serviceProvider.GetService())); + + return services; + } + + /// + /// Add Ollama Chat Completion service to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaChatCompletion( + this IServiceCollection services, + string modelId, + OllamaApiClient ollamaClient, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaChatCompletionService( + modelId: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + } + + #endregion + + #region Text Embeddings + + /// + /// Add Ollama Text Embedding Generation services to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The endpoint to Ollama hosted service. + /// Optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextEmbeddingGeneration( + this IServiceCollection services, + string modelId, + Uri endpoint, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + modelId: modelId, + endpoint: endpoint, + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Embedding Generation services to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// Optional custom HttpClient, picked from ServiceCollection if not provided. + /// Optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextEmbeddingGeneration( + this IServiceCollection services, + string modelId, + HttpClient? httpClient = null, + string? serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + modelId: modelId, + httpClient: HttpClientProvider.GetHttpClient(httpClient, serviceProvider), + loggerFactory: serviceProvider.GetService())); + } + + /// + /// Add Ollama Text Embeddings Generation service to the kernel builder. + /// + /// The target service collection. + /// The model for text generation. + /// The Ollama Sharp library client. + /// The optional service ID. + /// The updated kernel builder. + public static IServiceCollection AddOllamaTextEmbeddingGeneration( + this IServiceCollection services, + string modelId, + OllamaApiClient ollamaClient, + string? 
serviceId = null) + { + Verify.NotNull(services); + + return services.AddKeyedSingleton(serviceId, (serviceProvider, _) => + new OllamaTextEmbeddingGenerationService( + modelId: modelId, + ollamaClient: ollamaClient, + loggerFactory: serviceProvider.GetService())); + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs new file mode 100644 index 000000000000..e8e0c2e965e9 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaChatCompletionService.cs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.Ollama.Core; +using OllamaSharp; +using OllamaSharp.Models.Chat; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents a chat completion service using Ollama Original API. +/// +public sealed class OllamaChatCompletionService : ServiceBase, IChatCompletionService +{ + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The endpoint including the port where Ollama server is hosted + /// Optional logger factory to be used for logging. + public OllamaChatCompletionService( + string modelId, + Uri endpoint, + ILoggerFactory? loggerFactory = null) + : base(modelId, endpoint, null, loggerFactory) + { + Verify.NotNull(endpoint); + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// HTTP client to be used for communication with the Ollama API. + /// Optional logger factory to be used for logging. + public OllamaChatCompletionService( + string modelId, + HttpClient httpClient, + ILoggerFactory? loggerFactory = null) + : base(modelId, null, httpClient, loggerFactory) + { + Verify.NotNull(httpClient); + Verify.NotNull(httpClient.BaseAddress); + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The Ollama API client. + /// Optional logger factory to be used for logging. + public OllamaChatCompletionService( + string modelId, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + : base(modelId, ollamaClient, loggerFactory) + { + } + + /// + public IReadOnlyDictionary Attributes => this.AttributesInternal; + + /// + public async Task> GetChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + var request = CreateChatRequest(chatHistory, settings, this._client.SelectedModel); + var chatMessageContent = new ChatMessageContent(); + var fullContent = new StringBuilder(); + string? modelId = null; + AuthorRole? 
authorRole = null; + List innerContent = []; + + await foreach (var responseStreamChunk in this._client.Chat(request, cancellationToken).ConfigureAwait(false)) + { + if (responseStreamChunk is null) + { + continue; + } + + innerContent.Add(responseStreamChunk); + + if (responseStreamChunk.Message.Content is not null) + { + fullContent.Append(responseStreamChunk.Message.Content); + } + + if (responseStreamChunk.Message.Role is not null) + { + authorRole = GetAuthorRole(responseStreamChunk.Message.Role)!.Value; + } + + modelId ??= responseStreamChunk.Model; + } + + return [new ChatMessageContent( + role: authorRole ?? new(), + content: fullContent.ToString(), + modelId: modelId, + innerContent: innerContent)]; + } + + /// + public async IAsyncEnumerable GetStreamingChatMessageContentsAsync( + ChatHistory chatHistory, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + var request = CreateChatRequest(chatHistory, settings, this._client.SelectedModel); + + await foreach (var message in this._client.Chat(request, cancellationToken).ConfigureAwait(false)) + { + yield return new StreamingChatMessageContent( + role: GetAuthorRole(message!.Message.Role), + content: message.Message.Content, + modelId: message.Model, + innerContent: message); + } + } + + #region Private + + private static AuthorRole? GetAuthorRole(ChatRole? role) => role?.ToString().ToUpperInvariant() switch + { + "USER" => AuthorRole.User, + "ASSISTANT" => AuthorRole.Assistant, + "SYSTEM" => AuthorRole.System, + null => null, + _ => new AuthorRole(role.ToString()!) + }; + + private static ChatRequest CreateChatRequest(ChatHistory chatHistory, OllamaPromptExecutionSettings settings, string selectedModel) + { + var messages = new List(); + foreach (var chatHistoryMessage in chatHistory) + { + ChatRole role = ChatRole.User; + if (chatHistoryMessage.Role == AuthorRole.System) + { + role = ChatRole.System; + } + else if (chatHistoryMessage.Role == AuthorRole.Assistant) + { + role = ChatRole.Assistant; + } + + messages.Add(new Message(role, chatHistoryMessage.Content!)); + } + + var request = new ChatRequest + { + Options = new() + { + Temperature = settings.Temperature, + TopP = settings.TopP, + TopK = settings.TopK, + Stop = settings.Stop?.ToArray() + }, + Messages = messages, + Model = selectedModel, + Stream = true + }; + + return request; + } + + #endregion +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs new file mode 100644 index 000000000000..f5bee67d4ec5 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextEmbeddingGenerationService.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Connectors.Ollama.Core; +using Microsoft.SemanticKernel.Embeddings; +using Microsoft.SemanticKernel.Services; +using OllamaSharp; +using OllamaSharp.Models; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents a embedding generation service using Ollama Original API. 
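+// Usage sketch for this service (model and endpoint illustrative; the service calls
+// OllamaSharp's embed API and converts each returned vector from double to float):
+//
+//   var service = new OllamaTextEmbeddingGenerationService("all-minilm", new Uri("http://localhost:11434"));
+//   var vectors = await service.GenerateEmbeddingsAsync(["first sentence", "second sentence"]);
+//   Console.WriteLine(vectors[0].Length); // embedding dimensionality, e.g. 384 for all-minilm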
+/// +public sealed class OllamaTextEmbeddingGenerationService : ServiceBase, ITextEmbeddingGenerationService +{ + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The endpoint including the port where Ollama server is hosted + /// Optional logger factory to be used for logging. + public OllamaTextEmbeddingGenerationService( + string modelId, + Uri endpoint, + ILoggerFactory? loggerFactory = null) + : base(modelId, endpoint, null, loggerFactory) + { + Verify.NotNull(endpoint); + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// HTTP client to be used for communication with the Ollama API. + /// Optional logger factory to be used for logging. + public OllamaTextEmbeddingGenerationService( + string modelId, + HttpClient httpClient, + ILoggerFactory? loggerFactory = null) + : base(modelId, null, httpClient, loggerFactory) + { + Verify.NotNull(httpClient); + Verify.NotNull(httpClient.BaseAddress); + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The Ollama API client. + /// Optional logger factory to be used for logging. + public OllamaTextEmbeddingGenerationService( + string modelId, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + : base(modelId, ollamaClient, loggerFactory) + { + } + + /// + public IReadOnlyDictionary Attributes => this.AttributesInternal; + + /// + public async Task>> GenerateEmbeddingsAsync( + IList data, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var request = new EmbedRequest + { + Model = this.GetModelId()!, + Input = data.ToList(), + }; + + var response = await this._client.Embed(request, cancellationToken: cancellationToken).ConfigureAwait(false); + + List> embeddings = []; + foreach (var embedding in response.Embeddings) + { + embeddings.Add(embedding.Select(@decimal => (float)@decimal).ToArray()); + } + + return embeddings; + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs new file mode 100644 index 000000000000..a9432c15d839 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Services/OllamaTextGenerationService.cs @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.Connectors.Ollama.Core; +using Microsoft.SemanticKernel.TextGeneration; +using OllamaSharp; +using OllamaSharp.Models; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Represents a text generation service using Ollama Original API. +/// +public sealed class OllamaTextGenerationService : ServiceBase, ITextGenerationService +{ + /// + /// Initializes a new instance of the class. + /// + /// The Ollama model for the text generation service. + /// The endpoint including the port where Ollama server is hosted + /// Optional logger factory to be used for logging. + public OllamaTextGenerationService( + string modelId, + Uri endpoint, + ILoggerFactory? loggerFactory = null) + : base(modelId, endpoint, null, loggerFactory) + { + Verify.NotNull(endpoint); + } + + /// + /// Initializes a new instance of the class. + /// + /// The Ollama model for the text generation service. 
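+// Usage sketch for this service as a whole (values illustrative):
+//   var service = new OllamaTextGenerationService("phi3", new Uri("http://localhost:11434"));
+//   var contents = await service.GetTextContentsAsync("Why is the sky blue?");
+//   Console.WriteLine(contents[0].Text);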
+ /// HTTP client to be used for communication with the Ollama API. + /// Optional logger factory to be used for logging. + public OllamaTextGenerationService( + string modelId, + HttpClient httpClient, + ILoggerFactory? loggerFactory = null) + : base(modelId, null, httpClient, loggerFactory) + { + Verify.NotNull(httpClient); + Verify.NotNull(httpClient.BaseAddress); + } + + /// + /// Initializes a new instance of the class. + /// + /// The hosted model. + /// The Ollama API client. + /// Optional logger factory to be used for logging. + public OllamaTextGenerationService( + string modelId, + OllamaApiClient ollamaClient, + ILoggerFactory? loggerFactory = null) + : base(modelId, ollamaClient, loggerFactory) + { + } + + /// + public IReadOnlyDictionary Attributes => this.AttributesInternal; + + /// + public async Task> GetTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + CancellationToken cancellationToken = default) + { + var fullContent = new StringBuilder(); + List innerContent = []; + string? modelId = null; + + var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + var request = CreateRequest(settings, this._client.SelectedModel); + request.Prompt = prompt; + + await foreach (var responseStreamChunk in this._client.Generate(request, cancellationToken).ConfigureAwait(false)) + { + if (responseStreamChunk is null) + { + continue; + } + + innerContent.Add(responseStreamChunk); + fullContent.Append(responseStreamChunk.Response); + + modelId ??= responseStreamChunk.Model; + } + + return [new TextContent( + text: fullContent.ToString(), + modelId: modelId, + innerContent: innerContent)]; + } + + /// + public async IAsyncEnumerable GetStreamingTextContentsAsync( + string prompt, + PromptExecutionSettings? executionSettings = null, + Kernel? kernel = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var settings = OllamaPromptExecutionSettings.FromExecutionSettings(executionSettings); + var request = CreateRequest(settings, this._client.SelectedModel); + request.Prompt = prompt; + + await foreach (var content in this._client.Generate(request, cancellationToken).ConfigureAwait(false)) + { + yield return new StreamingTextContent( + text: content?.Response, + modelId: content?.Model, + innerContent: content); + } + } + + private static GenerateRequest CreateRequest(OllamaPromptExecutionSettings settings, string selectedModel) + { + var request = new GenerateRequest + { + Options = new() + { + Temperature = settings.Temperature, + TopP = settings.TopP, + TopK = settings.TopK, + Stop = settings.Stop?.ToArray() + }, + Model = selectedModel, + Stream = true + }; + + return request; + } +} diff --git a/dotnet/src/Connectors/Connectors.Ollama/Settings/OllamaPromptExecutionSettings.cs b/dotnet/src/Connectors/Connectors.Ollama/Settings/OllamaPromptExecutionSettings.cs new file mode 100644 index 000000000000..30032bb981d4 --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Ollama/Settings/OllamaPromptExecutionSettings.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Text; + +namespace Microsoft.SemanticKernel.Connectors.Ollama; + +/// +/// Ollama Prompt Execution Settings. 
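+// FromExecutionSettings (below) round-trips generic settings through JSON, so callers may
+// pass a plain PromptExecutionSettings whose ExtensionData carries Ollama-named keys. Sketch:
+//
+//   var generic = new PromptExecutionSettings
+//   {
+//       ExtensionData = new Dictionary<string, object> { ["temperature"] = 0.5f, ["top_k"] = 40 }
+//   };
+//   var ollama = OllamaPromptExecutionSettings.FromExecutionSettings(generic);
+//   // ollama.Temperature == 0.5f, ollama.TopK == 40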
+/// +public sealed class OllamaPromptExecutionSettings : PromptExecutionSettings +{ + /// + /// Gets the specialization for the Ollama execution settings. + /// + /// Generic prompt execution settings. + /// Specialized Ollama execution settings. + public static OllamaPromptExecutionSettings FromExecutionSettings(PromptExecutionSettings? executionSettings) + { + switch (executionSettings) + { + case null: + return new(); + case OllamaPromptExecutionSettings settings: + return settings; + } + + var json = JsonSerializer.Serialize(executionSettings); + var ollamaExecutionSettings = JsonSerializer.Deserialize(json, JsonOptionsCache.ReadPermissive); + if (ollamaExecutionSettings is not null) + { + return ollamaExecutionSettings; + } + + throw new ArgumentException( + $"Invalid execution settings, cannot convert to {nameof(OllamaPromptExecutionSettings)}", + nameof(executionSettings)); + } + + /// + /// Sets the stop sequences to use. When this pattern is encountered the + /// LLM will stop generating text and return. Multiple stop patterns may + /// be set by specifying multiple separate stop parameters in a model file. + /// + [JsonPropertyName("stop")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public List? Stop + { + get => this._stop; + + set + { + this.ThrowIfFrozen(); + this._stop = value; + } + } + + /// + /// Reduces the probability of generating nonsense. A higher value + /// (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) + /// will be more conservative. (Default: 40) + /// + [JsonPropertyName("top_k")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public int? TopK + { + get => this._topK; + + set + { + this.ThrowIfFrozen(); + this._topK = value; + } + } + + /// + /// Works together with top-k. A higher value (e.g., 0.95) will lead to + /// more diverse text, while a lower value (e.g., 0.5) will generate more + /// focused and conservative text. (Default: 0.9) + /// + [JsonPropertyName("top_p")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? TopP + { + get => this._topP; + + set + { + this.ThrowIfFrozen(); + this._topP = value; + } + } + + /// + /// The temperature of the model. Increasing the temperature will make the + /// model answer more creatively. (Default: 0.8) + /// + [JsonPropertyName("temperature")] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public float? Temperature + { + get => this._temperature; + + set + { + this.ThrowIfFrozen(); + this._temperature = value; + } + } + + #region private + + private List? _stop; + private float? _temperature; + private float? _topP; + private int? _topK; + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs new file mode 100644 index 000000000000..5dced3f7b4b4 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaCompletionTests.cs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft. All rights reserved. 
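+
+// These tests are skipped by default ("For manual verification only"). To run them
+// locally, point the "Ollama" section of testsettings.development.json or user secrets
+// at a running server, e.g.:
+//   { "Ollama": { "ModelId": "llama3", "Endpoint": "http://localhost:11434" } }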
+ +using System; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using OllamaSharp.Models.Chat; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Ollama; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class OllamaCompletionTests(ITestOutputHelper output) : IDisposable +{ + private const string InputParameterName = "input"; + private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeStreamingWorksAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureChatOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + StringBuilder fullResult = new(); + // Act + await foreach (var content in target.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + Assert.NotNull(content.InnerContent); + if (content is StreamingChatMessageContent messageContent) + { + Assert.NotNull(messageContent.Role); + } + + fullResult.Append(content); + } + + // Assert + Assert.Contains(expectedAnswerContains, fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItShouldReturnInnerContentAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + + this.ConfigureChatOllama(this._kernelBuilder); + + var kernel = this._kernelBuilder.Build(); + + var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + StreamingKernelContent? lastUpdate = null; + await foreach (var update in kernel.InvokeStreamingAsync(plugin["FunPlugin"]["Limerick"])) + { + lastUpdate = update; + } + + // Assert + Assert.NotNull(lastUpdate); + Assert.NotNull(lastUpdate.InnerContent); + Assert.IsType(lastUpdate.InnerContent); + var innerContent = lastUpdate.InnerContent as ChatDoneResponseStream; + Assert.NotNull(innerContent); + Assert.NotNull(innerContent.CreatedAt); + Assert.True(innerContent.Done); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task ItCompletesWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. 
" + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + const string ExpectedAnswerContains = "result"; + + this._kernelBuilder.Services.AddSingleton(this._logger); + this.ConfigureChatOllama(this._kernelBuilder); + + Kernel target = this._kernelBuilder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(ExpectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItInvokePromptTestAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + this.ConfigureChatOllama(builder); + Kernel target = builder.Build(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + FunctionResult actual = await target.InvokePromptAsync(prompt, new(new OllamaPromptExecutionSettings() { Temperature = 0.5f })); + + // Assert + Assert.Contains("Pike Place", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeTestAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureChatOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(expectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + #region internals + + private readonly XunitLogger _logger = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + + public void Dispose() + { + this._logger.Dispose(); + this._testOutputHelper.Dispose(); + } + + private void ConfigureChatOllama(IKernelBuilder kernelBuilder) + { + var config = this._configuration.GetSection("Ollama").Get(); + + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ModelId); + + kernelBuilder.AddOllamaChatCompletion( + modelId: config.ModelId, + endpoint: new Uri(config.Endpoint)); + } + + #endregion +} diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs new file mode 100644 index 000000000000..222873eccfb6 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextEmbeddingTests.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.SemanticKernel.Connectors.Ollama; +using Microsoft.SemanticKernel.Embeddings; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Ollama; + +public sealed class OllamaTextEmbeddingTests +{ + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "For manual verification only")] + [InlineData("mxbai-embed-large", 1024)] + [InlineData("nomic-embed-text", 768)] + [InlineData("all-minilm", 384)] + public async Task GenerateEmbeddingHasExpectedLengthForModelAsync(string modelId, int expectedVectorLength) + { + // Arrange + const string TestInputString = "test sentence"; + + OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get(); + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + + var embeddingGenerator = new OllamaTextEmbeddingGenerationService( + modelId, + new Uri(config.Endpoint)); + + // Act + var result = await embeddingGenerator.GenerateEmbeddingAsync(TestInputString); + + // Assert + Assert.Equal(expectedVectorLength, result.Length); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("mxbai-embed-large", 1024)] + [InlineData("nomic-embed-text", 768)] + [InlineData("all-minilm", 384)] + public async Task GenerateEmbeddingsHasExpectedResultsLengthForModelAsync(string modelId, int expectedVectorLength) + { + // Arrange + string[] testInputStrings = ["test sentence 1", "test sentence 2", "test sentence 3"]; + + OllamaConfiguration? config = this._configuration.GetSection("Ollama").Get(); + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + + var embeddingGenerator = new OllamaTextEmbeddingGenerationService( + modelId, + new Uri(config.Endpoint)); + + // Act + var result = await embeddingGenerator.GenerateEmbeddingsAsync(testInputStrings); + + // Assert + Assert.Equal(testInputStrings.Length, result.Count); + Assert.All(result, r => Assert.Equal(expectedVectorLength, r.Length)); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs new file mode 100644 index 000000000000..126980f57ede --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Ollama/OllamaTextGenerationTests.cs @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Connectors.Ollama; +using OllamaSharp.Models; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Ollama; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. 
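+
+// Mirrors OllamaCompletionTests above, but registers the connector via
+// AddOllamaTextGeneration instead of AddOllamaChatCompletion, so prompts flow through
+// ITextGenerationService and the inner content surfaces OllamaSharp's
+// GenerateDoneResponseStream rather than ChatDoneResponseStream.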
+ +public sealed class OllamaTextGenerationTests(ITestOutputHelper output) : IDisposable +{ + private const string InputParameterName = "input"; + private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeStreamingWorksAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureTextOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + StringBuilder fullResult = new(); + // Act + await foreach (var content in target.InvokeStreamingAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt })) + { + fullResult.Append(content); + Assert.NotNull(content.InnerContent); + } + + // Assert + Assert.Contains(expectedAnswerContains, fullResult.ToString(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItShouldReturnInnerContentAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + + this.ConfigureTextOllama(this._kernelBuilder); + + var kernel = this._kernelBuilder.Build(); + + var plugin = TestHelpers.ImportSamplePlugins(kernel, "FunPlugin"); + + // Act + StreamingKernelContent? lastUpdate = null; + await foreach (var update in kernel.InvokeStreamingAsync(plugin["FunPlugin"]["Limerick"])) + { + lastUpdate = update; + } + + // Assert + Assert.NotNull(lastUpdate); + Assert.NotNull(lastUpdate.InnerContent); + + Assert.IsType(lastUpdate.InnerContent); + var innerContent = lastUpdate.InnerContent as GenerateDoneResponseStream; + Assert.NotNull(innerContent); + Assert.NotNull(innerContent.CreatedAt); + Assert.True(innerContent.Done); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("\n")] + [InlineData("\r\n")] + public async Task ItCompletesWithDifferentLineEndingsAsync(string lineEnding) + { + // Arrange + var prompt = + "Given a json input and a request. Apply the request on the json input and return the result. 
" + + $"Put the result in between tags{lineEnding}" + + $$"""Input:{{lineEnding}}{"name": "John", "age": 30}{{lineEnding}}{{lineEnding}}Request:{{lineEnding}}name"""; + + const string ExpectedAnswerContains = "result"; + + this._kernelBuilder.Services.AddSingleton(this._logger); + this.ConfigureTextOllama(this._kernelBuilder); + + Kernel target = this._kernelBuilder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(ExpectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Fact(Skip = "For manual verification only")] + public async Task ItInvokePromptTestAsync() + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + this.ConfigureTextOllama(builder); + Kernel target = builder.Build(); + + var prompt = "Where is the most famous fish market in Seattle, Washington, USA?"; + + // Act + FunctionResult actual = await target.InvokePromptAsync(prompt, new(new OllamaPromptExecutionSettings() { Temperature = 0.5f })); + + // Assert + Assert.Contains("Pike Place", actual.GetValue(), StringComparison.OrdinalIgnoreCase); + } + + [Theory(Skip = "For manual verification only")] + [InlineData("Where is the most famous fish market in Seattle, Washington, USA?", "Pike Place")] + public async Task ItInvokeTestAsync(string prompt, string expectedAnswerContains) + { + // Arrange + this._kernelBuilder.Services.AddSingleton(this._logger); + var builder = this._kernelBuilder; + + this.ConfigureTextOllama(this._kernelBuilder); + + Kernel target = builder.Build(); + + IReadOnlyKernelPluginCollection plugins = TestHelpers.ImportSamplePlugins(target, "ChatPlugin"); + + // Act + FunctionResult actual = await target.InvokeAsync(plugins["ChatPlugin"]["Chat"], new() { [InputParameterName] = prompt }); + + // Assert + Assert.Contains(expectedAnswerContains, actual.GetValue(), StringComparison.OrdinalIgnoreCase); + var content = actual.GetValue(); + Assert.NotNull(content); + Assert.NotNull(content.InnerContent); + } + + #region internals + + private readonly XunitLogger _logger = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + + public void Dispose() + { + this._logger.Dispose(); + this._testOutputHelper.Dispose(); + } + + private void ConfigureTextOllama(IKernelBuilder kernelBuilder) + { + var config = this._configuration.GetSection("Ollama").Get(); + + Assert.NotNull(config); + Assert.NotNull(config.Endpoint); + Assert.NotNull(config.ModelId); + + kernelBuilder.AddOllamaTextGeneration( + modelId: config.ModelId, + endpoint: new Uri(config.Endpoint)); + } + + #endregion +} diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index 5686e8e3e96e..0ab7bcc04b90 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -74,6 +74,7 @@ + diff --git a/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs b/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs new file mode 100644 index 000000000000..51e8d77eee0a --- /dev/null +++ b/dotnet/src/IntegrationTests/TestSettings/OllamaConfiguration.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Diagnostics.CodeAnalysis; + +namespace SemanticKernel.IntegrationTests.TestSettings; + +[SuppressMessage("Performance", "CA1812:Internal class that is apparently never instantiated", + Justification = "Configuration classes are instantiated through IConfiguration.")] +internal sealed class OllamaConfiguration +{ + public string? ModelId { get; set; } + public string? Endpoint { get; set; } +} diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs index d71d3c1f0032..5b1916984d30 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/BaseTest.cs @@ -4,6 +4,7 @@ using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.ChatCompletion; public abstract class BaseTest { @@ -101,6 +102,17 @@ public void WriteLine(string? message) public void Write(object? target = null) => this.Output.WriteLine(target ?? string.Empty); + /// + /// Outputs the last message in the chat history. + /// + /// Chat history + protected void OutputLastMessage(ChatHistory chatHistory) + { + var message = chatHistory.Last(); + + Console.WriteLine($"{message.Role}: {message.Content}"); + Console.WriteLine("------------------------"); + } protected sealed class LoggingHandler(HttpMessageHandler innerHandler, ITestOutputHelper output) : DelegatingHandler(innerHandler) { private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() { WriteIndented = true }; diff --git a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs index 1a86413a5e05..6b0cabe9b795 100644 --- a/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs +++ b/dotnet/src/InternalUtilities/samples/InternalUtilities/TestConfiguration.cs @@ -19,6 +19,7 @@ public static void Initialize(IConfigurationRoot configRoot) s_instance = new TestConfiguration(configRoot); } + public static OllamaConfig Ollama => LoadSection(); public static OpenAIConfig OpenAI => LoadSection(); public static AzureOpenAIConfig AzureOpenAI => LoadSection(); public static AzureOpenAIConfig AzureOpenAIImages => LoadSection(); @@ -220,6 +221,14 @@ public class GeminiConfig } } + public class OllamaConfig + { + public string? ModelId { get; set; } + public string? EmbeddingModelId { get; set; } + + public string Endpoint { get; set; } = "http://localhost:11434"; + } + public class AzureCosmosDbMongoDbConfig { public string ConnectionString { get; set; }