Skip to content

Commit

Permalink
Copilot Chat: Update to newest SK nuget (microsoft#1811)
Browse files Browse the repository at this point in the history
### Motivation and Context

Update CopilotChat sample to the latest SK nuget version, absorbing
breaking changes. **I've done zero testing on this, other than "it
builds", so we'll want to validate it before merging.**
cc: @shawncal 

This cleans up some cruft related to skill definitions, though there's
still likely more that can be done for someone more familiar than I am
with this code. I expect we'll be able to clean it up further once the
exception handling changes land, as there's a lot of code dedicated to
checking whether an error has occurred and short-circuiting if one has.
(@SergeyMenshykh, this would be a good one to pay attention to, in
particular around how exceptions propagate out of semantic functions
instead of relying on manual checking of the returned context.)

### Contribution Checklist
<!-- Before submitting this PR, please make sure: -->
- [x] The code builds clean without any errors or warnings
- [x] The PR follows SK Contribution Guidelines
(https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
- [x] The code follows the .NET coding conventions
(https://learn.microsoft.com/dotnet/csharp/fundamentals/coding-style/coding-conventions)
verified with `dotnet format`
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄

Co-authored-by: Gina Triolo <51341242+gitri-ms@users.noreply.github.com>
Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
  • Loading branch information
3 people committed Jul 10, 2023
1 parent 76a5d24 commit 4945445
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
using Microsoft.SemanticKernel.Skills.MsGraph.Connectors;
using Microsoft.SemanticKernel.Skills.MsGraph.Connectors.Client;
using Microsoft.SemanticKernel.Skills.OpenAPI.Authentication;
using Microsoft.SemanticKernel.Skills.OpenAPI.Extensions;
using SemanticKernel.Service.CopilotChat.Hubs;
using SemanticKernel.Service.CopilotChat.Models;
using SemanticKernel.Service.CopilotChat.Skills.ChatSkills;
Expand Down Expand Up @@ -162,7 +163,10 @@ private async Task RegisterPlannerSkillsAsync(CopilotChatPlanner planner, OpenAp
using HttpClient importHttpClient = new(retryHandler, false);
importHttpClient.DefaultRequestHeaders.Add("User-Agent", "Microsoft.CopilotChat");
await planner.Kernel.ImportChatGptPluginSkillFromUrlAsync("KlarnaShoppingSkill", new Uri("https://www.klarna.com/.well-known/ai-plugin.json"),
importHttpClient);
new OpenApiSkillExecutionParameters
{
HttpClient = importHttpClient,
});
}

// GitHub
Expand All @@ -173,7 +177,10 @@ private async Task RegisterPlannerSkillsAsync(CopilotChatPlanner planner, OpenAp
await planner.Kernel.ImportOpenApiSkillFromFileAsync(
skillName: "GitHubSkill",
filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "CopilotChat", "Skills", "OpenApiSkills/GitHubSkill/openapi.json"),
authCallback: authenticationProvider.AuthenticateRequestAsync);
new OpenApiSkillExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
});
}

// Jira
Expand All @@ -186,8 +193,11 @@ await planner.Kernel.ImportOpenApiSkillFromFileAsync(
await planner.Kernel.ImportOpenApiSkillFromFileAsync(
skillName: "JiraSkill",
filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "CopilotChat", "Skills", "OpenApiSkills/JiraSkill/openapi.json"),
authCallback: authenticationProvider.AuthenticateRequestAsync,
serverUrlOverride: hasServerUrlOverride ? new Uri(serverUrlOverride!) : null);
new OpenApiSkillExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!) : null,
});
}

// Microsoft Graph
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Globalization;
using System.Linq;
using System.Text.Json;
Expand Down Expand Up @@ -93,10 +94,9 @@ public ChatSkill(
/// Extract user intent from the conversation history.
/// </summary>
/// <param name="context">The SKContext.</param>
[SKFunction("Extract user intent")]
[SKFunctionName("ExtractUserIntent")]
[SKFunctionContextParameter(Name = "chatId", Description = "Chat ID to extract history from")]
[SKFunctionContextParameter(Name = "audience", Description = "The audience the chat bot is interacting with.")]
[SKFunction, Description("Extract user intent")]
[SKParameter("chatId", "Chat ID to extract history from")]
[SKParameter("audience", "The audience the chat bot is interacting with.")]
public async Task<string> ExtractUserIntentAsync(SKContext context)
{
var tokenLimit = this._promptOptions.CompletionTokenLimit;
Expand Down Expand Up @@ -140,9 +140,8 @@ public async Task<string> ExtractUserIntentAsync(SKContext context)
/// Extract the list of participants from the conversation history.
/// Note that only those who have spoken will be included.
/// </summary>
[SKFunction("Extract audience list")]
[SKFunctionName("ExtractAudience")]
[SKFunctionContextParameter(Name = "chatId", Description = "Chat ID to extract history from")]
[SKFunction, Description("Extract audience list")]
[SKParameter("chatId", "Chat ID to extract history from")]
public async Task<string> ExtractAudienceAsync(SKContext context)
{
var tokenLimit = this._promptOptions.CompletionTokenLimit;
Expand Down Expand Up @@ -184,15 +183,11 @@ public async Task<string> ExtractAudienceAsync(SKContext context)
/// Extract chat history.
/// </summary>
/// <param name="context">Contains the 'tokenLimit' controlling the length of the prompt.</param>
[SKFunction("Extract chat history")]
[SKFunctionName("ExtractChatHistory")]
[SKFunctionContextParameter(Name = "chatId", Description = "Chat ID to extract history from")]
[SKFunctionContextParameter(Name = "tokenLimit", Description = "Maximum number of tokens")]
public async Task<string> ExtractChatHistoryAsync(SKContext context)
[SKFunction, Description("Extract chat history")]
public async Task<string> ExtractChatHistoryAsync(
[Description("Chat ID to extract history from")] string chatId,
[Description("Maximum number of tokens")] int tokenLimit)
{
var chatId = context["chatId"];
var tokenLimit = int.Parse(context["tokenLimit"], new NumberFormatInfo());

var messages = await this._chatMessageRepository.FindByChatIdAsync(chatId);
var sortedMessages = messages.OrderByDescending(m => m.Timestamp);

Expand Down Expand Up @@ -242,36 +237,19 @@ public async Task<string> ExtractChatHistoryAsync(SKContext context)
/// messages to memory, and fill in the necessary context variables for completing the
/// prompt that will be rendered by the template engine.
/// </summary>
/// <param name="message"></param>
/// <param name="context">Contains the 'tokenLimit' and the 'contextTokenLimit' controlling the length of the prompt.</param>
[SKFunction("Get chat response")]
[SKFunctionName("Chat")]
[SKFunctionInput(Description = "The new message")]
[SKFunctionContextParameter(Name = "userId", Description = "Unique and persistent identifier for the user")]
[SKFunctionContextParameter(Name = "userName", Description = "Name of the user")]
[SKFunctionContextParameter(Name = "chatId", Description = "Unique and persistent identifier for the chat")]
[SKFunctionContextParameter(Name = "proposedPlan", Description = "Previously proposed plan that is approved")]
[SKFunctionContextParameter(Name = "messageType", Description = "Type of the message")]
[SKFunctionContextParameter(Name = "responseMessageId", Description = "ID of the response message for planner")]
public async Task<SKContext> ChatAsync(string message, SKContext context)
[SKFunction, Description("Get chat response")]
public async Task<SKContext> ChatAsync(
[Description("The new message")] string message,
[Description("Unique and persistent identifier for the user")] string userId,
[Description("Name of the user")] string userName,
[Description("Unique and persistent identifier for the chat")] string chatId,
[Description("Type of the message")] string messageType,
[Description("Previously proposed plan that is approved"), DefaultValue(null), SKName("proposedPlan")] string? planJson,
[Description("ID of the response message for planner"), DefaultValue(null), SKName("responseMessageId")] string? messageId,
SKContext context)
{
// TODO: check if user has access to the chat
var userId = context["userId"];
var userName = context["userName"];
var chatId = context["chatId"];
var messageType = context["messageType"];

// Save this new message to memory such that subsequent chat responses can use it
try
{
await this.SaveNewMessageAsync(message, userId, userName, chatId, messageType);
}
catch (Exception ex) when (!ex.IsCriticalException())
{
context.Log.LogError("Unable to save new message: {0}", ex.Message);
context.Fail($"Unable to save new message: {ex.Message}", ex);
return context;
}
await this.SaveNewMessageAsync(message, userId, userName, chatId, messageType);

// Clone the context to avoid modifying the original context variables.
var chatContext = Utilities.CopyContextWithVariablesClone(context);
Expand All @@ -280,16 +258,15 @@ public async Task<SKContext> ChatAsync(string message, SKContext context)
// Check if plan exists in ask's context variables.
// If plan was returned at this point, that means it was approved or cancelled.
// Update the response previously saved in chat history with state
if (context.Variables.TryGetValue("proposedPlan", out string? planJson)
&& !string.IsNullOrWhiteSpace(planJson)
&& context.Variables.TryGetValue("responseMessageId", out string? messageId))
if (!string.IsNullOrWhiteSpace(planJson) &&
!string.IsNullOrEmpty(messageId))
{
await this.UpdateResponseAsync(planJson, messageId);
}

var response = chatContext.Variables.ContainsKey("userCancelledPlan")
? "I am sorry the plan did not meet your goals."
: await this.GetChatResponseAsync(chatContext);
: await this.GetChatResponseAsync(chatId, chatContext);

if (chatContext.ErrorOccurred)
{
Expand All @@ -299,24 +276,14 @@ public async Task<SKContext> ChatAsync(string message, SKContext context)

// Retrieve the prompt used to generate the response
// and return it to the caller via the context variables.
var prompt = chatContext.Variables.ContainsKey("prompt")
? chatContext.Variables["prompt"]
: string.Empty;
chatContext.Variables.TryGetValue("prompt", out string? prompt);
prompt ??= string.Empty;
context.Variables.Set("prompt", prompt);

// Save this response to memory such that subsequent chat responses can use it
try
{
ChatMessage botMessage = await this.SaveNewResponseAsync(response, prompt, chatId);
context.Variables.Set("messageId", botMessage.Id);
context.Variables.Set("messageType", ((int)botMessage.Type).ToString(CultureInfo.InvariantCulture));
}
catch (Exception ex) when (!ex.IsCriticalException())
{
context.Log.LogError("Unable to save new response: {0}", ex.Message);
context.Fail($"Unable to save new response: {ex.Message}");
return context;
}
ChatMessage botMessage = await this.SaveNewResponseAsync(response, prompt, chatId);
context.Variables.Set("messageId", botMessage.Id);
context.Variables.Set("messageType", ((int)botMessage.Type).ToString(CultureInfo.InvariantCulture));

// Extract semantic chat memory
await SemanticChatMemoryExtractor.ExtractSemanticChatMemoryAsync(
Expand All @@ -336,7 +303,7 @@ await SemanticChatMemoryExtractor.ExtractSemanticChatMemoryAsync(
/// </summary>
/// <param name="chatContext">The SKContext.</param>
/// <returns>A response from the model.</returns>
private async Task<string> GetChatResponseAsync(SKContext chatContext)
private async Task<string> GetChatResponseAsync(string chatId, SKContext chatContext)
{
// 0. Get the audience
var audience = await this.GetAudienceAsync(chatContext);
Expand Down Expand Up @@ -371,15 +338,15 @@ private async Task<string> GetChatResponseAsync(SKContext chatContext)

// 4. Query relevant semantic memories
var chatMemoriesTokenLimit = (int)(remainingToken * this._promptOptions.MemoriesResponseContextWeight);
var chatMemories = await this.QueryChatMemoriesAsync(chatContext, userIntent, chatMemoriesTokenLimit);
var chatMemories = await this._semanticChatMemorySkill.QueryMemoriesAsync(userIntent, chatId, chatMemoriesTokenLimit, chatContext.Memory);
if (chatContext.ErrorOccurred)
{
return string.Empty;
}

// 5. Query relevant document memories
var documentContextTokenLimit = (int)(remainingToken * this._promptOptions.DocumentContextWeight);
var documentMemories = await this.QueryDocumentsAsync(chatContext, userIntent, documentContextTokenLimit);
var documentMemories = await this._documentMemorySkill.QueryDocumentsAsync(userIntent, chatId, documentContextTokenLimit, chatContext.Memory);
if (chatContext.ErrorOccurred)
{
return string.Empty;
Expand All @@ -391,7 +358,7 @@ private async Task<string> GetChatResponseAsync(SKContext chatContext)
var chatContextTextTokenCount = remainingToken - Utilities.TokenCount(chatContextText);
if (chatContextTextTokenCount > 0)
{
var chatHistory = await this.GetChatHistoryAsync(chatContext, chatContextTextTokenCount);
var chatHistory = await this.ExtractChatHistoryAsync(chatId, chatContextTextTokenCount);
if (chatContext.ErrorOccurred)
{
return string.Empty;
Expand Down Expand Up @@ -490,62 +457,13 @@ private async Task<string> GetUserIntentAsync(SKContext context)
return userIntent;
}

/// <summary>
/// Helper function create the correct context variables to
/// extract chat history messages from the conversation history.
/// </summary>
private Task<string> GetChatHistoryAsync(SKContext context, int tokenLimit)
{
var contextVariables = new ContextVariables();
contextVariables.Set("chatId", context["chatId"]);
contextVariables.Set("tokenLimit", tokenLimit.ToString(new NumberFormatInfo()));

var chatHistoryContext = new SKContext(
contextVariables,
context.Memory,
context.Skills,
context.Log,
context.CancellationToken
);

var chatHistory = this.ExtractChatHistoryAsync(chatHistoryContext);

// Propagate the error
if (chatHistoryContext.ErrorOccurred)
{
context.Fail(chatHistoryContext.LastErrorDescription);
}

return chatHistory;
}

/// <summary>
/// Helper function create the correct context variables to
/// query chat memories from the chat memory store.
/// </summary>
private Task<string> QueryChatMemoriesAsync(SKContext context, string userIntent, int tokenLimit)
{
var contextVariables = new ContextVariables();
contextVariables.Set("chatId", context["chatId"]);
contextVariables.Set("tokenLimit", tokenLimit.ToString(new NumberFormatInfo()));

var chatMemoriesContext = new SKContext(
contextVariables,
context.Memory,
context.Skills,
context.Log,
context.CancellationToken
);

var chatMemories = this._semanticChatMemorySkill.QueryMemoriesAsync(userIntent, chatMemoriesContext);

// Propagate the error
if (chatMemoriesContext.ErrorOccurred)
{
context.Fail(chatMemoriesContext.LastErrorDescription);
}

return chatMemories;
return this._semanticChatMemorySkill.QueryMemoriesAsync(userIntent, context["chatId"], tokenLimit, context.Memory);
}

/// <summary>
Expand All @@ -554,27 +472,7 @@ private Task<string> QueryChatMemoriesAsync(SKContext context, string userIntent
/// </summary>
private Task<string> QueryDocumentsAsync(SKContext context, string userIntent, int tokenLimit)
{
var contextVariables = new ContextVariables();
contextVariables.Set("chatId", context["chatId"]);
contextVariables.Set("tokenLimit", tokenLimit.ToString(new NumberFormatInfo()));

var documentMemoriesContext = new SKContext(
contextVariables,
context.Memory,
context.Skills,
context.Log,
context.CancellationToken
);

var documentMemories = this._documentMemorySkill.QueryDocumentsAsync(userIntent, documentMemoriesContext);

// Propagate the error
if (documentMemoriesContext.ErrorOccurred)
{
context.Fail(documentMemoriesContext.LastErrorDescription);
}

return documentMemories;
return this._documentMemorySkill.QueryDocumentsAsync(userIntent, context["chatId"], tokenLimit, context.Memory);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Globalization;
using System.ComponentModel;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel.Memory;
using Microsoft.SemanticKernel.Orchestration;
using Microsoft.SemanticKernel.SkillDefinition;
using SemanticKernel.Service.CopilotChat.Options;

Expand Down Expand Up @@ -43,15 +42,13 @@ public DocumentMemorySkill(
/// </summary>
/// <param name="query">Query to match.</param>
/// <param name="context">The SkContext.</param>
[SKFunction("Query documents in the memory given a user message")]
[SKFunctionName("QueryDocuments")]
[SKFunctionInput(Description = "Query to match.")]
[SKFunctionContextParameter(Name = "chatId", Description = "ID of the chat that owns the documents")]
[SKFunctionContextParameter(Name = "tokenLimit", Description = "Maximum number of tokens")]
public async Task<string> QueryDocumentsAsync(string query, SKContext context)
[SKFunction, Description("Query documents in the memory given a user message")]
public async Task<string> QueryDocumentsAsync(
[Description("Query to match.")] string query,
[Description("ID of the chat that owns the documents")] string chatId,
[Description("Maximum number of tokens")] int tokenLimit,
ISemanticTextMemory textMemory)
{
string chatId = context.Variables["chatId"];
int tokenLimit = int.Parse(context.Variables["tokenLimit"], new NumberFormatInfo());
var remainingToken = tokenLimit;

// Search for relevant document snippets.
Expand All @@ -64,7 +61,7 @@ public async Task<string> QueryDocumentsAsync(string query, SKContext context)
List<MemoryQueryResult> relevantMemories = new();
foreach (var documentCollection in documentCollections)
{
var results = context.Memory.SearchAsync(
var results = textMemory.SearchAsync(
documentCollection,
query,
limit: 100,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Globalization;
using System.Linq;
using System.Text.Json;
Expand Down Expand Up @@ -64,12 +65,12 @@ public ExternalInformationSkill(
/// <summary>
/// Extract relevant additional knowledge using a planner.
/// </summary>
[SKFunction("Acquire external information")]
[SKFunctionName("AcquireExternalInformation")]
[SKFunctionInput(Description = "The intent to whether external information is needed")]
[SKFunctionContextParameter(Name = "tokenLimit", Description = "Maximum number of tokens")]
[SKFunctionContextParameter(Name = "proposedPlan", Description = "Previously proposed plan that is approved")]
public async Task<string> AcquireExternalInformationAsync(string userIntent, SKContext context)
[SKFunction, Description("Acquire external information")]
[SKParameter("tokenLimit", "Maximum number of tokens")]
[SKParameter("proposedPlan", "Previously proposed plan that is approved")]
public async Task<string> AcquireExternalInformationAsync(
[Description("The intent to whether external information is needed")] string userIntent,
SKContext context)
{
FunctionsView functions = this._planner.Kernel.Skills.GetFunctionsView(true, true);
if (functions.NativeFunctions.IsEmpty && functions.SemanticFunctions.IsEmpty)
Expand Down
Loading

0 comments on commit 4945445

Please sign in to comment.