diff --git a/Wox/go.mod b/Wox/go.mod index 3d7f9ab9f..ecf40ccb8 100644 --- a/Wox/go.mod +++ b/Wox/go.mod @@ -27,6 +27,7 @@ require ( github.com/otiai10/copy v1.14.0 github.com/parsiya/golnk v0.0.0-20221103095132-740a4c27c4ff github.com/petermattis/goid v0.0.0-20240327183114-c42a807a84ba + github.com/pkg/errors v0.9.1 github.com/robotn/gohook v0.41.0 github.com/rs/cors v1.10.1 github.com/sahilm/fuzzy v0.1.1 diff --git a/Wox/go.sum b/Wox/go.sum index a3e2cc70c..5a45bc1e8 100644 --- a/Wox/go.sum +++ b/Wox/go.sum @@ -135,6 +135,8 @@ github.com/parsiya/golnk v0.0.0-20221103095132-740a4c27c4ff h1:japdIZgV4tJIgn7Nq github.com/parsiya/golnk v0.0.0-20221103095132-740a4c27c4ff/go.mod h1:A24WXUol4NXZlK8grjh/CsZnPlimfwaQFt5PQsqS27s= github.com/petermattis/goid v0.0.0-20240327183114-c42a807a84ba h1:3jPgmsFGBID1wFfU2AbYocNcN4wqU68UaHSdMjiw/7U= github.com/petermattis/goid v0.0.0-20240327183114-c42a807a84ba/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/Wox/main.go b/Wox/main.go index 0c9aa2161..7ebb3defb 100644 --- a/Wox/main.go +++ b/Wox/main.go @@ -25,7 +25,6 @@ import _ "wox/plugin/host" // import all hosts import _ "wox/plugin/system" import _ "wox/plugin/system/app" import _ "wox/plugin/system/calculator" -import _ "wox/plugin/system/llm" import _ "wox/plugin/system/file" func main() { diff --git a/Wox/plugin/api.go b/Wox/plugin/api.go index 9a38fbdc5..3db445633 100644 --- a/Wox/plugin/api.go +++ b/Wox/plugin/api.go @@ -6,6 +6,7 @@ import ( "github.com/samber/lo" "path" "wox/i18n" + "wox/plugin/llm" "wox/setting" "wox/setting/definition" "wox/share" @@ -33,6 +34,8 @@ type API interface { OnSettingChanged(ctx context.Context, callback func(key string, value string)) OnGetDynamicSetting(ctx context.Context, callback func(key string) definition.PluginSettingDefinitionItem) RegisterQueryCommands(ctx context.Context, commands []MetadataCommand) + LLMChat(ctx context.Context, conversations []llm.Conversation) (string, error) + LLMChatStream(ctx context.Context, conversations []llm.Conversation) (llm.ChatStream, error) } type APIImpl struct { @@ -152,6 +155,24 @@ func (a *APIImpl) RegisterQueryCommands(ctx context.Context, commands []Metadata a.pluginInstance.SaveSetting(ctx) } +func (a *APIImpl) LLMChat(ctx context.Context, conversations []llm.Conversation) (string, error) { + provider, model := llm.GetInstance() + if provider == nil { + return "", fmt.Errorf("no LLM provider found") + } + + return provider.Chat(ctx, model, conversations) +} + +func (a *APIImpl) LLMChatStream(ctx context.Context, conversations []llm.Conversation) (llm.ChatStream, error) { + provider, model := llm.GetInstance() + if provider == nil { + return nil, fmt.Errorf("no LLM provider found") + } + + return provider.ChatStream(ctx, model, conversations) +} + func NewAPI(instance *Instance) API { apiImpl := &APIImpl{pluginInstance: instance} logFolder := path.Join(util.GetLocation().GetLogPluginDirectory(), instance.Metadata.Name) diff --git a/Wox/plugin/llm/instance.go b/Wox/plugin/llm/instance.go new file mode 100644 index 000000000..42dee358a --- /dev/null +++ b/Wox/plugin/llm/instance.go @@ -0,0 +1,17 @@ +package llm + +var provider Provider +var model Model + +func GetInstance() (Provider, Model) { + return provider, model +} + +func SetInstance(p Provider, m Model) { + provider = p + model = m +} + +func IsInstanceReady() bool { + return provider != nil && model.Name != "" +} diff --git a/Wox/plugin/llm/provider.go b/Wox/plugin/llm/provider.go new file mode 100644 index 000000000..4b0e092ff --- /dev/null +++ b/Wox/plugin/llm/provider.go @@ -0,0 +1,69 @@ +package llm + +import ( + "context" + "errors" +) + +type ConversationRole string + +var ( + ConversationRoleUser ConversationRole = "user" + ConversationRoleSystem ConversationRole = "system" +) + +type Conversation struct { + Role ConversationRole + Text string + Timestamp int64 +} + +type ModelProviderName string + +var ( + ModelProviderNameOpenAI ModelProviderName = "openai" + ModelProviderNameGoogle ModelProviderName = "google" + ModelProviderNameOllama ModelProviderName = "ollama" + ModelProviderNameGroq ModelProviderName = "groq" +) + +type Model struct { + DisplayName string + Name string + Provider ModelProviderName +} + +type Provider interface { + Close(ctx context.Context) error + ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) + Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) + Models(ctx context.Context) ([]Model, error) +} + +type ChatStream interface { + Receive(ctx context.Context) (string, error) // will return io.EOF if no more messages +} + +type ProviderConnectContext struct { + Provider ModelProviderName + + ApiKey string + Host string // E.g. "https://api.openai.com:8908" +} + +func NewProvider(ctx context.Context, connectContext ProviderConnectContext) (Provider, error) { + if connectContext.Provider == ModelProviderNameGoogle { + return NewGoogleProvider(ctx, connectContext), nil + } + if connectContext.Provider == ModelProviderNameOpenAI { + return NewOpenAIClient(ctx, connectContext), nil + } + if connectContext.Provider == ModelProviderNameOllama { + return NewOllamaProvider(ctx, connectContext), nil + } + if connectContext.Provider == ModelProviderNameGroq { + return NewGroqProvider(ctx, connectContext), nil + } + + return nil, errors.New("unknown model provider") +} diff --git a/Wox/plugin/system/llm/provider_google.go b/Wox/plugin/llm/provider_google.go similarity index 88% rename from Wox/plugin/system/llm/provider_google.go rename to Wox/plugin/llm/provider_google.go index 0eb7c61fc..4516bb662 100644 --- a/Wox/plugin/system/llm/provider_google.go +++ b/Wox/plugin/llm/provider_google.go @@ -11,7 +11,7 @@ import ( ) type GoogleProvider struct { - connectContext providerConnectContext + connectContext ProviderConnectContext client *genai.Client } @@ -20,7 +20,7 @@ type GoogleProviderStream struct { conversations []Conversation } -func NewGoogleProvider(ctx context.Context, connectContext providerConnectContext) Provider { +func NewGoogleProvider(ctx context.Context, connectContext ProviderConnectContext) Provider { return &GoogleProvider{connectContext: connectContext} } @@ -44,7 +44,7 @@ func (g *GoogleProvider) Close(ctx context.Context) error { return nil } -func (g *GoogleProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) { +func (g *GoogleProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) { if ensureClientErr := g.ensureClient(ctx); ensureClientErr != nil { return nil, ensureClientErr } @@ -57,7 +57,7 @@ func (g *GoogleProvider) ChatStream(ctx context.Context, model model, conversati return &GoogleProviderStream{conversations: conversations, stream: stream}, nil } -func (g *GoogleProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) { +func (g *GoogleProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) { if ensureClientErr := g.ensureClient(ctx); ensureClientErr != nil { return "", ensureClientErr } @@ -80,17 +80,17 @@ func (g *GoogleProvider) Chat(ctx context.Context, model model, conversations [] return "", errors.New("no text in response") } -func (g *GoogleProvider) Models(ctx context.Context) ([]model, error) { - return []model{ +func (g *GoogleProvider) Models(ctx context.Context) ([]Model, error) { + return []Model{ { DisplayName: "google-gemini-1.0-pro", Name: "gemini-1.0-pro", - Provider: modelProviderNameGoogle, + Provider: ModelProviderNameGoogle, }, { DisplayName: "google-gemini-1.5-pro", Name: "gemini-1.5-pro", - Provider: modelProviderNameGoogle, + Provider: ModelProviderNameGoogle, }, }, nil } diff --git a/Wox/plugin/system/llm/provider_groq.go b/Wox/plugin/llm/provider_groq.go similarity index 78% rename from Wox/plugin/system/llm/provider_groq.go rename to Wox/plugin/llm/provider_groq.go index 50c8f57d7..42f063fad 100644 --- a/Wox/plugin/system/llm/provider_groq.go +++ b/Wox/plugin/llm/provider_groq.go @@ -10,22 +10,20 @@ import ( "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/schema" "io" - "wox/plugin" "wox/util" ) type GroqProvider struct { - connectContext providerConnectContext + connectContext ProviderConnectContext client *openai.LLM } type GroqProviderStream struct { conversations []Conversation reader io.Reader - api plugin.API } -func NewGroqProvider(ctx context.Context, connectContext providerConnectContext) Provider { +func NewGroqProvider(ctx context.Context, connectContext ProviderConnectContext) Provider { return &GroqProvider{connectContext: connectContext} } @@ -33,7 +31,7 @@ func (g *GroqProvider) Close(ctx context.Context) error { return nil } -func (o *GroqProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) { +func (o *GroqProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) { client, clientErr := openai.New(openai.WithModel(model.Name), openai.WithBaseURL("https://api.groq.com/openai/v1"), openai.WithToken(o.connectContext.ApiKey)) if clientErr != nil { return nil, clientErr @@ -43,7 +41,6 @@ func (o *GroqProvider) ChatStream(ctx context.Context, model model, conversation r, w := nio.Pipe(buf) util.Go(ctx, "Groq chat stream", func() { _, err := client.GenerateContent(ctx, o.convertConversations(conversations), llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error { - o.connectContext.api.Log(ctx, plugin.LogLevelDebug, fmt.Sprintf("Groq: receive chunks from model: %s", string(chunk))) w.Write(chunk) return nil })) @@ -57,7 +54,7 @@ func (o *GroqProvider) ChatStream(ctx context.Context, model model, conversation return &GroqProviderStream{conversations: conversations, reader: r}, nil } -func (o *GroqProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) { +func (o *GroqProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) { client, clientErr := openai.New(openai.WithModel(model.Name), openai.WithBaseURL("https://api.groq.com/openai/v1"), openai.WithToken(o.connectContext.ApiKey)) if clientErr != nil { return "", clientErr @@ -71,27 +68,27 @@ func (o *GroqProvider) Chat(ctx context.Context, model model, conversations []Co return response.Choices[0].Content, nil } -func (o *GroqProvider) Models(ctx context.Context) (models []model, err error) { - return []model{ +func (o *GroqProvider) Models(ctx context.Context) (models []Model, err error) { + return []Model{ { Name: "llama3-8b-8192", DisplayName: "llama3-8b-8192", - Provider: modelProviderNameGroq, + Provider: ModelProviderNameGroq, }, { Name: "llama3-70b-8192", DisplayName: "llama3-70b-8192", - Provider: modelProviderNameGroq, + Provider: ModelProviderNameGroq, }, { Name: "mixtral-8x7b-32768", DisplayName: "mixtral-8x7b-32768", - Provider: modelProviderNameGroq, + Provider: ModelProviderNameGroq, }, { Name: "gemma-7b-it", DisplayName: "gemma-7b-it", - Provider: modelProviderNameGroq, + Provider: ModelProviderNameGroq, }, }, nil } @@ -123,7 +120,3 @@ func (s *GroqProviderStream) Receive(ctx context.Context) (string, error) { util.GetLogger().Debug(util.NewTraceContext(), fmt.Sprintf("Groq: Send response: %s", resp)) return resp, nil } - -func (s *GroqProviderStream) Close(ctx context.Context) { - // no-op -} diff --git a/Wox/plugin/system/llm/provider_ollama.go b/Wox/plugin/llm/provider_ollama.go similarity index 81% rename from Wox/plugin/system/llm/provider_ollama.go rename to Wox/plugin/llm/provider_ollama.go index 63312d832..6cd341579 100644 --- a/Wox/plugin/system/llm/provider_ollama.go +++ b/Wox/plugin/llm/provider_ollama.go @@ -11,12 +11,11 @@ import ( "github.com/tmc/langchaingo/llms/ollama" "github.com/tmc/langchaingo/schema" "io" - "wox/plugin" "wox/util" ) type OllamaProvider struct { - connectContext providerConnectContext + connectContext ProviderConnectContext client *ollama.LLM } @@ -25,7 +24,7 @@ type OllamaProviderStream struct { reader io.Reader } -func NewOllamaProvider(ctx context.Context, connectContext providerConnectContext) Provider { +func NewOllamaProvider(ctx context.Context, connectContext ProviderConnectContext) Provider { return &OllamaProvider{connectContext: connectContext} } @@ -33,7 +32,7 @@ func (o *OllamaProvider) Close(ctx context.Context) error { return nil } -func (o *OllamaProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) { +func (o *OllamaProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) { client, clientErr := ollama.New(ollama.WithServerURL(o.connectContext.Host), ollama.WithModel(model.Name)) if clientErr != nil { return nil, clientErr @@ -43,7 +42,6 @@ func (o *OllamaProvider) ChatStream(ctx context.Context, model model, conversati r, w := nio.Pipe(buf) util.Go(ctx, "ollama chat stream", func() { _, err := client.GenerateContent(ctx, o.convertConversations(conversations), llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error { - o.connectContext.api.Log(ctx, plugin.LogLevelDebug, fmt.Sprintf("OLLAMA: receive chunks from model: %s", string(chunk))) w.Write(chunk) return nil })) @@ -57,7 +55,7 @@ func (o *OllamaProvider) ChatStream(ctx context.Context, model model, conversati return &OllamaProviderStream{conversations: conversations, reader: r}, nil } -func (o *OllamaProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) { +func (o *OllamaProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) { client, clientErr := ollama.New(ollama.WithServerURL(o.connectContext.Host), ollama.WithModel(model.Name)) if clientErr != nil { return "", clientErr @@ -71,17 +69,17 @@ func (o *OllamaProvider) Chat(ctx context.Context, model model, conversations [] return response.Choices[0].Content, nil } -func (o *OllamaProvider) Models(ctx context.Context) (models []model, err error) { +func (o *OllamaProvider) Models(ctx context.Context) (models []Model, err error) { body, err := util.HttpGet(ctx, o.connectContext.Host+"/api/tags") if err != nil { return nil, err } gjson.Get(string(body), "models.#.name").ForEach(func(key, value gjson.Result) bool { - models = append(models, model{ + models = append(models, Model{ DisplayName: value.String(), Name: value.String(), - Provider: modelProviderNameOllama, + Provider: ModelProviderNameOllama, }) return true }) @@ -116,7 +114,3 @@ func (s *OllamaProviderStream) Receive(ctx context.Context) (string, error) { util.GetLogger().Debug(util.NewTraceContext(), fmt.Sprintf("OLLAMA: Send response: %s", resp)) return resp, nil } - -func (s *OllamaProviderStream) Close(ctx context.Context) { - // no-op -} diff --git a/Wox/plugin/system/llm/provider_openai.go b/Wox/plugin/llm/provider_openai.go similarity index 83% rename from Wox/plugin/system/llm/provider_openai.go rename to Wox/plugin/llm/provider_openai.go index ac2750e8e..73b7e286c 100644 --- a/Wox/plugin/system/llm/provider_openai.go +++ b/Wox/plugin/llm/provider_openai.go @@ -7,7 +7,7 @@ import ( ) type OpenAIProvider struct { - connectContext providerConnectContext + connectContext ProviderConnectContext client *openai.Client } @@ -16,7 +16,7 @@ type OpenAIProviderStream struct { conversations []Conversation } -func NewOpenAIClient(ctx context.Context, connectContext providerConnectContext) Provider { +func NewOpenAIClient(ctx context.Context, connectContext ProviderConnectContext) Provider { return &OpenAIProvider{connectContext: connectContext} } @@ -32,7 +32,7 @@ func (o *OpenAIProvider) ensureClient(ctx context.Context) error { return nil } -func (o *OpenAIProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) { +func (o *OpenAIProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) { if ensureClientErr := o.ensureClient(ctx); ensureClientErr != nil { return nil, ensureClientErr } @@ -49,7 +49,7 @@ func (o *OpenAIProvider) ChatStream(ctx context.Context, model model, conversati return &OpenAIProviderStream{conversations: conversations, stream: createdStream}, nil } -func (o *OpenAIProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) { +func (o *OpenAIProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) { if ensureClientErr := o.ensureClient(ctx); ensureClientErr != nil { return "", ensureClientErr } @@ -65,12 +65,12 @@ func (o *OpenAIProvider) Chat(ctx context.Context, model model, conversations [] return resp.Choices[0].Message.Content, nil } -func (o *OpenAIProvider) Models(ctx context.Context) ([]model, error) { - return []model{ +func (o *OpenAIProvider) Models(ctx context.Context) ([]Model, error) { + return []Model{ { DisplayName: "chatgpt-3.5-turbo", Name: "gpt-3.5-turbo", - Provider: modelProviderNameOpenAI, + Provider: ModelProviderNameOpenAI, }, }, nil } @@ -78,6 +78,8 @@ func (o *OpenAIProvider) Models(ctx context.Context) ([]model, error) { func (s *OpenAIProviderStream) Receive(ctx context.Context) (string, error) { response, err := s.stream.Recv() if err != nil { + s.stream.Close() + // no more messages if err == io.EOF { return "", io.EOF @@ -92,10 +94,6 @@ func (s *OpenAIProviderStream) Receive(ctx context.Context) (string, error) { return response.Choices[0].Delta.Content, nil } -func (s *OpenAIProviderStream) Close(ctx context.Context) { - s.stream.Close() -} - func (o *OpenAIProvider) convertConversations(conversations []Conversation) []openai.ChatCompletionMessage { var chatMessages []openai.ChatCompletionMessage for _, conversation := range conversations { diff --git a/Wox/plugin/manager.go b/Wox/plugin/manager.go index fc90737ea..2cb746eb0 100644 --- a/Wox/plugin/manager.go +++ b/Wox/plugin/manager.go @@ -855,6 +855,11 @@ func (m *Manager) ExecuteRefresh(ctx context.Context, refreshableResultWithId Re newResult := resultCache.Refresh(ctx, refreshableResult) newResult = m.PolishRefreshableResult(ctx, resultCache.PluginInstance, refreshableResultWithId.ResultId, newResult) + // update result cache + resultCache.ResultTitle = newResult.Title + resultCache.ResultSubTitle = newResult.SubTitle + resultCache.ContextData = newResult.ContextData + return RefreshableResultWithResultId{ ResultId: refreshableResultWithId.ResultId, Title: newResult.Title, diff --git a/Wox/plugin/metadata.go b/Wox/plugin/metadata.go index fd4f64468..49d14d001 100644 --- a/Wox/plugin/metadata.go +++ b/Wox/plugin/metadata.go @@ -25,6 +25,9 @@ const ( // enable this feature to get query env in plugin MetadataFeatureQueryEnv MetadataFeatureName = "queryEnv" + + // enable this feature to chat with llm model in plugin + MetadataFeatureLLMChat MetadataFeatureName = "llmChat" ) // Metadata parsed from plugin.json, see `Plugin.json.md` for more detail diff --git a/Wox/plugin/system/app/app_darwin_test.go b/Wox/plugin/system/app/app_darwin_test.go index ca8d99ec0..47a90a5eb 100644 --- a/Wox/plugin/system/app/app_darwin_test.go +++ b/Wox/plugin/system/app/app_darwin_test.go @@ -48,6 +48,10 @@ func (e emptyAPIImpl) OnSettingChanged(ctx context.Context, callback func(key st func (e emptyAPIImpl) RegisterQueryCommands(ctx context.Context, commands []plugin.MetadataCommand) { } +func (e emptyAPIImpl) LLMChat(ctx context.Context, conversations []plugin.LLMConversation) (string, error) { + return "", nil +} + func TestMacRetriever_ParseAppInfo(t *testing.T) { if util.IsMacOS() { appRetriever.UpdateAPI(emptyAPIImpl{}) diff --git a/Wox/plugin/system/llm/plugin.go b/Wox/plugin/system/llm.go similarity index 92% rename from Wox/plugin/system/llm/plugin.go rename to Wox/plugin/system/llm.go index f9f093759..7cf0e8cc3 100644 --- a/Wox/plugin/system/llm/plugin.go +++ b/Wox/plugin/system/llm.go @@ -1,18 +1,20 @@ -package llm +package system import ( "context" "encoding/json" - "errors" "fmt" "github.com/samber/lo" "github.com/tidwall/gjson" - "io" "strings" + "time" "wox/plugin" + "wox/plugin/llm" "wox/setting/definition" "wox/share" "wox/util" + "wox/util/clipboard" + "wox/util/keyboard" ) var llmIcon = plugin.NewWoxImageBase64(``) @@ -30,9 +32,7 @@ func init() { } type Plugin struct { - api plugin.API - provider Provider - model model + api plugin.API } func (c *Plugin) GetMetadata() plugin.Metadata { @@ -62,23 +62,23 @@ func (c *Plugin) GetMetadata() plugin.Metadata { Key: "provider", Label: "Provider", Tooltip: "The LLM service provider", - DefaultValue: string(modelProviderNameOpenAI), + DefaultValue: string(llm.ModelProviderNameOpenAI), Options: []definition.PluginSettingValueSelectOption{ { Label: "OpenAI", - Value: string(modelProviderNameOpenAI), + Value: string(llm.ModelProviderNameOpenAI), }, { Label: "Google", - Value: string(modelProviderNameGoogle), + Value: string(llm.ModelProviderNameGoogle), }, { Label: "Ollama", - Value: string(modelProviderNameOllama), + Value: string(llm.ModelProviderNameOllama), }, { Label: "Groq", - Value: string(modelProviderNameGroq), + Value: string(llm.ModelProviderNameGroq), }, }, Style: definition.PluginSettingValueStyle{ @@ -151,6 +151,9 @@ func (c *Plugin) GetMetadata() plugin.Metadata { { Name: plugin.MetadataFeatureQuerySelection, }, + { + Name: plugin.MetadataFeatureLLMChat, + }, }, } } @@ -185,7 +188,7 @@ func (c *Plugin) Init(ctx context.Context, initParams plugin.InitParams) { func (c *Plugin) Query(ctx context.Context, query plugin.Query) []plugin.QueryResult { if query.Type == plugin.QueryTypeSelection { - if c.provider == nil || query.Selection.Type == util.SelectionTypeFile { + if !llm.IsInstanceReady() || query.Selection.Type == util.SelectionTypeFile { return []plugin.QueryResult{} } @@ -218,10 +221,10 @@ func (c *Plugin) Query(ctx context.Context, query plugin.Query) []plugin.QueryRe return results } - if c.provider == nil { + if !llm.IsInstanceReady() { return []plugin.QueryResult{ { - Title: "Provider not initialized", + Title: "LLM setting not initialized", SubTitle: "Please complete the settings", Icon: llmIcon, }, @@ -243,7 +246,7 @@ func (c *Plugin) getDynamicSetting(ctx context.Context, key string) definition.P } var options []definition.PluginSettingValueSelectOption - for _, m := range c.getProviderModels(ctx, modelProviderName(provider)) { + for _, m := range c.getProviderModels(ctx, llm.ModelProviderName(provider)) { options = append(options, definition.PluginSettingValueSelectOption{ Label: m.DisplayName, Value: m.Name, @@ -265,11 +268,11 @@ func (c *Plugin) getDynamicSetting(ctx context.Context, key string) definition.P } if key == "dynamic_host" { - if c.api.GetSetting(ctx, "provider") == string(modelProviderNameOllama) { + if c.api.GetSetting(ctx, "provider") == string(llm.ModelProviderNameOllama) { return definition.PluginSettingDefinitionItem{ Type: definition.PluginSettingDefinitionTypeTextBox, Value: &definition.PluginSettingValueTextBox{ - Key: "host_" + string(modelProviderNameOllama), + Key: "host_" + string(llm.ModelProviderNameOllama), Label: "Host", Tooltip: "The Ollama host", DefaultValue: "http://localhost:11434", @@ -283,7 +286,7 @@ func (c *Plugin) getDynamicSetting(ctx context.Context, key string) definition.P } if key == "dynamic_api_key" { - if c.api.GetSetting(ctx, "provider") != string(modelProviderNameOllama) { + if c.api.GetSetting(ctx, "provider") != string(llm.ModelProviderNameOllama) { return definition.PluginSettingDefinitionItem{ Type: definition.PluginSettingDefinitionTypeTextBox, Value: &definition.PluginSettingValueTextBox{ @@ -302,10 +305,9 @@ func (c *Plugin) getDynamicSetting(ctx context.Context, key string) definition.P return definition.PluginSettingDefinitionItem{} } -func (c *Plugin) getProviderModels(ctx context.Context, providerName modelProviderName) (models []model) { - if provider, providerErr := NewProvider(ctx, providerConnectContext{ +func (c *Plugin) getProviderModels(ctx context.Context, providerName llm.ModelProviderName) (models []llm.Model) { + if provider, providerErr := llm.NewProvider(ctx, llm.ProviderConnectContext{ Provider: providerName, - api: c.api, ApiKey: c.api.GetSetting(ctx, string("api_key_"+providerName)), Host: c.api.GetSetting(ctx, string("host_"+providerName)), }); providerErr == nil { @@ -333,9 +335,8 @@ func (c *Plugin) loadClient(ctx context.Context) { return } - provider, providerErr := NewProvider(ctx, providerConnectContext{ - Provider: modelProviderName(providerName), - api: c.api, + provider, providerErr := llm.NewProvider(ctx, llm.ProviderConnectContext{ + Provider: llm.ModelProviderName(providerName), ApiKey: c.api.GetSetting(ctx, "api_key_"+providerName), Host: c.api.GetSetting(ctx, "host_"+providerName), }) @@ -345,25 +346,25 @@ func (c *Plugin) loadClient(ctx context.Context) { } //close previous provider - if c.provider != nil { - closeErr := c.provider.Close(ctx) + currentProvider, currentModel := llm.GetInstance() + if currentProvider != nil { + closeErr := currentProvider.Close(ctx) if closeErr != nil { c.api.Log(ctx, plugin.LogLevelError, fmt.Sprintf("failed to close llm provider: %s", closeErr.Error())) } else { - c.api.Log(ctx, plugin.LogLevelInfo, fmt.Sprintf("llm provider closed, model: %s", c.model.Name)) + c.api.Log(ctx, plugin.LogLevelInfo, fmt.Sprintf("llm provider closed, model: %s", currentModel.Name)) } } - models := c.getProviderModels(ctx, modelProviderName(providerName)) - var availableModel model + var availableModel llm.Model + models := c.getProviderModels(ctx, llm.ModelProviderName(providerName)) for _, m := range models { if m.Name == modelName { availableModel = m break } } - c.model = availableModel - c.provider = provider + llm.SetInstance(provider, availableModel) c.api.Log(ctx, plugin.LogLevelInfo, fmt.Sprintf("llm provider created with model: %s", modelName)) } @@ -470,17 +471,17 @@ func (c *Plugin) queryCommand(ctx context.Context, query plugin.Query) []plugin. } var prompts = strings.Split(queryTool.Prompt, "|") - var conversations []Conversation + var conversations []llm.Conversation for index, message := range prompts { msg := fmt.Sprintf(message, query.Search) if index%2 == 0 { - conversations = append(conversations, Conversation{ - Role: ConversationRoleUser, + conversations = append(conversations, llm.Conversation{ + Role: llm.ConversationRoleUser, Text: msg, }) } else { - conversations = append(conversations, Conversation{ - Role: ConversationRoleSystem, + conversations = append(conversations, llm.Conversation{ + Role: llm.ConversationRoleSystem, Text: msg, }) } @@ -490,6 +491,7 @@ func (c *Plugin) queryCommand(ctx context.Context, query plugin.Query) []plugin. current.Icon = llmLoadingIcon current.Preview.PreviewData += deltaAnswer current.Preview.ScrollPosition = plugin.WoxPreviewScrollPositionBottom + current.ContextData = current.Preview.PreviewData return current } onAnswerErr := func(current plugin.RefreshableResult, err error) plugin.RefreshableResult { @@ -504,74 +506,38 @@ func (c *Plugin) queryCommand(ctx context.Context, query plugin.Query) []plugin. return current } + _, model := llm.GetInstance() return []plugin.QueryResult{{ Title: fmt.Sprintf("Chat with %s", query.Command), - SubTitle: fmt.Sprintf("%s - %s", c.model.Provider, c.model.DisplayName), + SubTitle: fmt.Sprintf("%s - %s", model.Provider, model.DisplayName), Preview: plugin.WoxPreview{PreviewType: plugin.WoxPreviewTypeMarkdown, PreviewData: ""}, Icon: llmLoadingIcon, RefreshInterval: 100, - OnRefresh: c.generateGptResultRefresh(ctx, conversations, onAnswering, onAnswerErr, onAnswerFinished), + OnRefresh: createLLMOnRefreshHandler(ctx, c.api.LLMChatStream, conversations, func() bool { + return true + }, onAnswering, onAnswerErr, onAnswerFinished), + Actions: []plugin.QueryResultAction{ + { + Name: "Copy", + Action: func(ctx context.Context, actionContext plugin.ActionContext) { + clipboard.WriteText(actionContext.ContextData) + }, + }, + { + Name: "Copy and Paste to active app", + Action: func(ctx context.Context, actionContext plugin.ActionContext) { + clipboard.WriteText(actionContext.ContextData) + util.Go(context.Background(), "clipboard to copy", func() { + time.Sleep(time.Millisecond * 100) + err := keyboard.SimulatePaste() + if err != nil { + c.api.Log(ctx, plugin.LogLevelError, fmt.Sprintf("simulate paste clipboard failed, err=%s", err.Error())) + } else { + c.api.Log(ctx, plugin.LogLevelInfo, "simulate paste clipboard success") + } + }) + }, + }, + }, }} } - -// generate a result which will send chat messages to openai and show the result automatically -func (c *Plugin) generateGptResultRefresh(ctx context.Context, conversations []Conversation, - onAnswering func(plugin.RefreshableResult, string) plugin.RefreshableResult, - onAnswerErr func(plugin.RefreshableResult, error) plugin.RefreshableResult, - onAnswerFinished func(plugin.RefreshableResult) plugin.RefreshableResult) func(ctx context.Context, current plugin.RefreshableResult) plugin.RefreshableResult { - - var stream ProviderChatStream - var creatingStream bool - return func(ctx context.Context, current plugin.RefreshableResult) plugin.RefreshableResult { - if stream == nil { - if creatingStream { - c.api.Log(ctx, plugin.LogLevelInfo, "Already creating stream, waiting create finish") - return current - } - - startTime := util.GetSystemTimestamp() - c.api.Log(ctx, plugin.LogLevelInfo, "creating stream") - creatingStream = true - createdStream, createErr := c.provider.ChatStream(ctx, c.model, conversations) - creatingStream = false - c.api.Log(ctx, plugin.LogLevelInfo, fmt.Sprintf("created stream (cost %d ms)", util.GetSystemTimestamp()-startTime)) - if createErr != nil { - if onAnswerErr != nil { - current = onAnswerErr(current, createErr) - } - current.RefreshInterval = 0 // stop refreshing - return current - } - stream = createdStream - } - - c.api.Log(ctx, plugin.LogLevelInfo, fmt.Sprintf("reading stream, model=%s", c.model.Name)) - response, streamErr := stream.Receive(ctx) - if errors.Is(streamErr, io.EOF) { - c.api.Log(ctx, plugin.LogLevelInfo, "read stream completed") - stream.Close(ctx) - if onAnswerFinished != nil { - current = onAnswerFinished(current) - } - current.RefreshInterval = 0 // stop refreshing - return current - } - - if streamErr != nil { - c.api.Log(ctx, plugin.LogLevelError, fmt.Sprintf("failed to read stream: %s", streamErr.Error())) - stream.Close(ctx) - if onAnswerErr != nil { - current = onAnswerErr(current, streamErr) - } - current.RefreshInterval = 0 // stop refreshing - return current - } - - if onAnswering != nil { - c.api.Log(ctx, plugin.LogLevelInfo, fmt.Sprintf("streamed %d text", len(response))) - current = onAnswering(current, response) - } - - return current - } -} diff --git a/Wox/plugin/system/llm/provider.go b/Wox/plugin/system/llm/provider.go deleted file mode 100644 index d2d5a31e0..000000000 --- a/Wox/plugin/system/llm/provider.go +++ /dev/null @@ -1,79 +0,0 @@ -package llm - -import ( - "context" - "errors" - "wox/plugin" -) - -type ConversationRole string - -var ( - ConversationRoleUser ConversationRole = "user" - ConversationRoleSystem ConversationRole = "system" -) - -type Conversation struct { - Role ConversationRole - Text string - Timestamp int64 -} - -type modelProviderName string - -var ( - modelProviderNameOpenAI modelProviderName = "openai" - modelProviderNameGoogle modelProviderName = "google" - modelProviderNameOllama modelProviderName = "ollama" - modelProviderNameGroq modelProviderName = "groq" -) - -var modelProviderNames = []modelProviderName{ - modelProviderNameOpenAI, - modelProviderNameGoogle, - modelProviderNameOllama, - modelProviderNameGroq, -} - -type model struct { - DisplayName string - Name string - Provider modelProviderName -} - -type providerConnectContext struct { - Provider modelProviderName - api plugin.API - - ApiKey string - Host string // E.g. "https://api.openai.com:8908" -} - -type Provider interface { - Close(ctx context.Context) error - ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) - Chat(ctx context.Context, model model, conversations []Conversation) (string, error) - Models(ctx context.Context) ([]model, error) -} - -type ProviderChatStream interface { - Receive(ctx context.Context) (string, error) // will return io.EOF if no more messages - Close(ctx context.Context) -} - -func NewProvider(ctx context.Context, connectContext providerConnectContext) (Provider, error) { - if connectContext.Provider == modelProviderNameGoogle { - return NewGoogleProvider(ctx, connectContext), nil - } - if connectContext.Provider == modelProviderNameOpenAI { - return NewOpenAIClient(ctx, connectContext), nil - } - if connectContext.Provider == modelProviderNameOllama { - return NewOllamaProvider(ctx, connectContext), nil - } - if connectContext.Provider == modelProviderNameGroq { - return NewGroqProvider(ctx, connectContext), nil - } - - return nil, errors.New("unknown model provider") -} diff --git a/Wox/plugin/system/theme.go b/Wox/plugin/system/theme.go index e5f755628..2647d4038 100644 --- a/Wox/plugin/system/theme.go +++ b/Wox/plugin/system/theme.go @@ -2,9 +2,15 @@ package system import ( "context" + "encoding/json" + "fmt" + "github.com/google/uuid" "github.com/samber/lo" "wox/plugin" + "wox/plugin/llm" + "wox/resource" "wox/share" + "wox/util" ) var themeIcon = plugin.NewWoxImageBase64(``) @@ -32,7 +38,12 @@ func (c *ThemePlugin) GetMetadata() plugin.Metadata { TriggerKeywords: []string{ "theme", }, - Commands: []plugin.MetadataCommand{}, + Commands: []plugin.MetadataCommand{ + { + Command: "ai", + Description: "Generate a new theme with AI", + }, + }, SupportedOS: []string{ "Windows", "Macos", @@ -46,6 +57,10 @@ func (c *ThemePlugin) Init(ctx context.Context, initParams plugin.InitParams) { } func (c *ThemePlugin) Query(ctx context.Context, query plugin.Query) []plugin.QueryResult { + if query.Command == "ai" { + return c.queryAI(ctx, query) + } + ui := plugin.GetPluginManager().GetUI() return lo.FilterMap(ui.GetAllThemes(ctx), func(theme share.Theme, _ int) (plugin.QueryResult, bool) { match, _ := IsStringMatchScore(ctx, theme.ThemeName, query.Search) @@ -68,3 +83,107 @@ func (c *ThemePlugin) Query(ctx context.Context, query plugin.Query) []plugin.Qu } }) } + +func (c *ThemePlugin) queryAI(ctx context.Context, query plugin.Query) []plugin.QueryResult { + if query.Search == "" { + return []plugin.QueryResult{ + { + Title: "Please describe the theme you want to generate", + Icon: themeIcon, + }, + } + } + + embedThemes := resource.GetEmbedThemes(ctx) + if len(embedThemes) == 0 { + return []plugin.QueryResult{ + { + Title: "No embed theme found", + Icon: themeIcon, + }, + } + } + + exampleThemeJson := embedThemes[0] + + var conversations []llm.Conversation + conversations = append(conversations, llm.Conversation{ + Role: llm.ConversationRoleUser, + Text: ` +我正在编写Wox的主题,该主题是由一段json组成,例如:` + exampleThemeJson + ` + +现在我想让你根据上面的格式生成一个新的主题,主题的要求是:` + query.Search + `。 + +有一些注意点需要你遵守: +1. 你的回答结果必须是JSON格式,且只能回答json相关内容,忽略解释,注释等信息 +2. 主题名称你自己决定,主题ID为随机生成的UUID +3. 主题作者统一为:Wox launcher AI +4. IsSystemTheme字段必须为false +`, + }) + + onAnswering := func(current plugin.RefreshableResult, deltaAnswer string) plugin.RefreshableResult { + current.SubTitle = "Generating..." + current.Preview.PreviewData += deltaAnswer + current.Preview.ScrollPosition = plugin.WoxPreviewScrollPositionBottom + current.ContextData = current.Preview.PreviewData + return current + } + onAnswerErr := func(current plugin.RefreshableResult, err error) plugin.RefreshableResult { + current.Preview.PreviewData += fmt.Sprintf("\n\nError: %s", err.Error()) + current.RefreshInterval = 0 // stop refreshing + return current + } + onAnswerFinished := func(current plugin.RefreshableResult) plugin.RefreshableResult { + current.RefreshInterval = 0 // stop refreshing + current.Title = "Theme generated" + util.Go(ctx, "theme generated", func() { + themeJson := current.ContextData + if themeJson == "" { + return + } + + // use regex to get json snippet from the whole text + group := util.FindRegexGroup(`(?ms){(?P.*?)}`, themeJson) + if len(group) == 0 { + c.api.Notify(ctx, "Failed to extract json", "") + return + } + + var jsonTheme = fmt.Sprintf("{%s}", group["json"]) + var theme share.Theme + unmarshalErr := json.Unmarshal([]byte(jsonTheme), &theme) + if unmarshalErr != nil { + c.api.Notify(ctx, "Failed to unmarshal theme json", unmarshalErr.Error()) + return + } + + theme.ThemeId = uuid.NewString() + plugin.GetPluginManager().GetUI().InstallTheme(ctx, theme) + }) + return current + } + + startGenerate := false + return []plugin.QueryResult{ + { + Title: "Generate theme with ai", + SubTitle: "Enter to generate", + Icon: themeIcon, + Preview: plugin.WoxPreview{PreviewType: plugin.WoxPreviewTypeMarkdown, PreviewData: ""}, + RefreshInterval: 100, + OnRefresh: createLLMOnRefreshHandler(ctx, c.api.LLMChatStream, conversations, func() bool { + return startGenerate + }, onAnswering, onAnswerErr, onAnswerFinished), + Actions: []plugin.QueryResultAction{ + { + Name: "Apply", + PreventHideAfterAction: true, + Action: func(ctx context.Context, actionContext plugin.ActionContext) { + startGenerate = true + }, + }, + }, + }, + } +} diff --git a/Wox/plugin/system/util.go b/Wox/plugin/system/util.go index 462b204a9..a72df897f 100644 --- a/Wox/plugin/system/util.go +++ b/Wox/plugin/system/util.go @@ -3,6 +3,7 @@ package system import ( "context" "crypto/md5" + "errors" "fmt" "github.com/disintegration/imaging" "github.com/mat/besticon/besticon" @@ -11,6 +12,7 @@ import ( "os" "path" "wox/plugin" + "wox/plugin/llm" "wox/setting" "wox/util" ) @@ -94,3 +96,69 @@ func getWebsiteIconWithCache(ctx context.Context, websiteUrl string) (plugin.Wox return woxImage, nil } + +func createLLMOnRefreshHandler(ctx context.Context, + chatStream func(ctx context.Context, conversations []llm.Conversation) (llm.ChatStream, error), + conversations []llm.Conversation, + shouldStartAnswering func() bool, + onAnswering func(plugin.RefreshableResult, string) plugin.RefreshableResult, + onAnswerErr func(plugin.RefreshableResult, error) plugin.RefreshableResult, + onAnswerFinished func(plugin.RefreshableResult) plugin.RefreshableResult) func(ctx context.Context, current plugin.RefreshableResult) plugin.RefreshableResult { + + var stream llm.ChatStream + var creatingStream bool + return func(ctx context.Context, current plugin.RefreshableResult) plugin.RefreshableResult { + if !shouldStartAnswering() { + return current + } + + if stream == nil { + if creatingStream { + util.GetLogger().Info(ctx, "Already creating stream, waiting create finish") + return current + } + + startTime := util.GetSystemTimestamp() + util.GetLogger().Info(ctx, "creating stream") + creatingStream = true + createdStream, createErr := chatStream(ctx, conversations) + creatingStream = false + util.GetLogger().Info(ctx, fmt.Sprintf("created stream (cost %d ms)", util.GetSystemTimestamp()-startTime)) + if createErr != nil { + if onAnswerErr != nil { + current = onAnswerErr(current, createErr) + } + current.RefreshInterval = 0 // stop refreshing + return current + } + stream = createdStream + } + + util.GetLogger().Info(ctx, fmt.Sprintf("reading stream")) + response, streamErr := stream.Receive(ctx) + if errors.Is(streamErr, io.EOF) { + util.GetLogger().Info(ctx, "read stream completed") + if onAnswerFinished != nil { + current = onAnswerFinished(current) + } + current.RefreshInterval = 0 // stop refreshing + return current + } + + if streamErr != nil { + util.GetLogger().Info(ctx, fmt.Sprintf("failed to read stream: %s", streamErr.Error())) + if onAnswerErr != nil { + current = onAnswerErr(current, streamErr) + } + current.RefreshInterval = 0 // stop refreshing + return current + } + + if onAnswering != nil { + util.GetLogger().Info(ctx, fmt.Sprintf("streamed %d text", len(response))) + current = onAnswering(current, response) + } + + return current + } +} diff --git a/Wox/resource/resource.go b/Wox/resource/resource.go index 28fcca8f4..2bc3b090e 100644 --- a/Wox/resource/resource.go +++ b/Wox/resource/resource.go @@ -92,9 +92,6 @@ func parseThemes(ctx context.Context) error { if err != nil { return err } - if err != nil { - return err - } if len(dir) == 0 { return fmt.Errorf("no theme file found") } diff --git a/Wox/share/ui.go b/Wox/share/ui.go index a16d3791d..7fc9192c9 100644 --- a/Wox/share/ui.go +++ b/Wox/share/ui.go @@ -45,6 +45,7 @@ type UI interface { GetServerPort(ctx context.Context) int GetAllThemes(ctx context.Context) []Theme ChangeTheme(ctx context.Context, theme Theme) + InstallTheme(ctx context.Context, theme Theme) } type ShowContext struct { diff --git a/Wox/ui/ui_impl.go b/Wox/ui/ui_impl.go index 3e46749ce..84faff7ac 100644 --- a/Wox/ui/ui_impl.go +++ b/Wox/ui/ui_impl.go @@ -58,6 +58,11 @@ func (u *uiImpl) ChangeTheme(ctx context.Context, theme share.Theme) { u.invokeWebsocketMethod(ctx, "ChangeTheme", theme) } +func (u *uiImpl) InstallTheme(ctx context.Context, theme share.Theme) { + logger.Info(ctx, fmt.Sprintf("install theme: %s", theme.ThemeName)) + GetUIManager().AddTheme(ctx, theme) +} + func (u *uiImpl) OpenSettingWindow(ctx context.Context, windowContext share.SettingWindowContext) { u.invokeWebsocketMethod(ctx, "OpenSettingWindow", windowContext) }