Skip to content

Commit

Permalink
Add AI to theme plugin, which allows users to generate themes by AI.
Browse files Browse the repository at this point in the history
  • Loading branch information
qianlifeng committed May 31, 2024
1 parent 9e0b9a9 commit efd1b9d
Show file tree
Hide file tree
Showing 20 changed files with 416 additions and 233 deletions.
1 change: 1 addition & 0 deletions Wox/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
github.com/otiai10/copy v1.14.0
github.com/parsiya/golnk v0.0.0-20221103095132-740a4c27c4ff
github.com/petermattis/goid v0.0.0-20240327183114-c42a807a84ba
github.com/pkg/errors v0.9.1
github.com/robotn/gohook v0.41.0
github.com/rs/cors v1.10.1
github.com/sahilm/fuzzy v0.1.1
Expand Down
2 changes: 2 additions & 0 deletions Wox/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ github.com/parsiya/golnk v0.0.0-20221103095132-740a4c27c4ff h1:japdIZgV4tJIgn7Nq
github.com/parsiya/golnk v0.0.0-20221103095132-740a4c27c4ff/go.mod h1:A24WXUol4NXZlK8grjh/CsZnPlimfwaQFt5PQsqS27s=
github.com/petermattis/goid v0.0.0-20240327183114-c42a807a84ba h1:3jPgmsFGBID1wFfU2AbYocNcN4wqU68UaHSdMjiw/7U=
github.com/petermattis/goid v0.0.0-20240327183114-c42a807a84ba/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
1 change: 0 additions & 1 deletion Wox/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import _ "wox/plugin/host" // import all hosts
import _ "wox/plugin/system"
import _ "wox/plugin/system/app"
import _ "wox/plugin/system/calculator"
import _ "wox/plugin/system/llm"
import _ "wox/plugin/system/file"

func main() {
Expand Down
21 changes: 21 additions & 0 deletions Wox/plugin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/samber/lo"
"path"
"wox/i18n"
"wox/plugin/llm"
"wox/setting"
"wox/setting/definition"
"wox/share"
Expand Down Expand Up @@ -33,6 +34,8 @@ type API interface {
OnSettingChanged(ctx context.Context, callback func(key string, value string))
OnGetDynamicSetting(ctx context.Context, callback func(key string) definition.PluginSettingDefinitionItem)
RegisterQueryCommands(ctx context.Context, commands []MetadataCommand)
LLMChat(ctx context.Context, conversations []llm.Conversation) (string, error)
LLMChatStream(ctx context.Context, conversations []llm.Conversation) (llm.ChatStream, error)
}

type APIImpl struct {
Expand Down Expand Up @@ -152,6 +155,24 @@ func (a *APIImpl) RegisterQueryCommands(ctx context.Context, commands []Metadata
a.pluginInstance.SaveSetting(ctx)
}

// LLMChat sends the conversation history to the currently configured LLM
// provider and blocks until the complete response text is available.
// It returns an error when no provider has been configured yet.
func (a *APIImpl) LLMChat(ctx context.Context, conversations []llm.Conversation) (string, error) {
	if provider, model := llm.GetInstance(); provider != nil {
		return provider.Chat(ctx, model, conversations)
	}
	return "", fmt.Errorf("no LLM provider found")
}

// LLMChatStream sends the conversation history to the currently configured
// LLM provider and returns a stream from which response chunks can be read
// incrementally. It returns an error when no provider has been configured yet.
func (a *APIImpl) LLMChatStream(ctx context.Context, conversations []llm.Conversation) (llm.ChatStream, error) {
	if provider, model := llm.GetInstance(); provider != nil {
		return provider.ChatStream(ctx, model, conversations)
	}
	return nil, fmt.Errorf("no LLM provider found")
}

func NewAPI(instance *Instance) API {
apiImpl := &APIImpl{pluginInstance: instance}
logFolder := path.Join(util.GetLocation().GetLogPluginDirectory(), instance.Metadata.Name)
Expand Down
17 changes: 17 additions & 0 deletions Wox/plugin/llm/instance.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package llm

// provider and model hold the single, process-wide LLM configuration.
// NOTE(review): access is not synchronized — concurrent GetInstance/SetInstance
// calls would be a data race; confirm all callers run on one goroutine.
var provider Provider
var model Model

// GetInstance returns the currently configured provider and model.
// The provider is nil until SetInstance has been called.
func GetInstance() (Provider, Model) {
	return provider, model
}

// SetInstance installs p and m as the process-wide LLM provider and model.
func SetInstance(p Provider, m Model) {
	provider = p
	model = m
}

// IsInstanceReady reports whether both a provider and a named model have
// been configured via SetInstance.
func IsInstanceReady() bool {
	return provider != nil && model.Name != ""
}
69 changes: 69 additions & 0 deletions Wox/plugin/llm/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package llm

import (
	"context"
	"fmt"
)

// ConversationRole identifies who authored a message in a conversation.
type ConversationRole string

// Declared as constants: these are fixed protocol values, not mutable state.
const (
	ConversationRoleUser   ConversationRole = "user"
	ConversationRoleSystem ConversationRole = "system"
)

// Conversation is a single message exchanged with an LLM.
type Conversation struct {
	Role      ConversationRole
	Text      string
	Timestamp int64 // creation time — TODO confirm unit (seconds vs milliseconds)
}

// ModelProviderName identifies a supported LLM backend.
type ModelProviderName string

const (
	ModelProviderNameOpenAI ModelProviderName = "openai"
	ModelProviderNameGoogle ModelProviderName = "google"
	ModelProviderNameOllama ModelProviderName = "ollama"
	ModelProviderNameGroq   ModelProviderName = "groq"
)

// Model describes one concrete model offered by a provider.
type Model struct {
	DisplayName string
	Name        string
	Provider    ModelProviderName
}

// Provider is the interface every LLM backend implements.
type Provider interface {
	Close(ctx context.Context) error
	ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error)
	Chat(ctx context.Context, model Model, conversations []Conversation) (string, error)
	Models(ctx context.Context) ([]Model, error)
}

// ChatStream yields the incremental chunks of a streaming chat response.
type ChatStream interface {
	Receive(ctx context.Context) (string, error) // will return io.EOF if no more messages
}

// ProviderConnectContext carries the credentials and endpoint needed to
// connect to a provider.
type ProviderConnectContext struct {
	Provider ModelProviderName

	ApiKey string
	Host   string // E.g. "https://api.openai.com:8908"
}

// NewProvider constructs the Provider implementation matching
// connectContext.Provider. It returns an error naming the provider when the
// value is not one of the supported backends.
func NewProvider(ctx context.Context, connectContext ProviderConnectContext) (Provider, error) {
	switch connectContext.Provider {
	case ModelProviderNameGoogle:
		return NewGoogleProvider(ctx, connectContext), nil
	case ModelProviderNameOpenAI:
		return NewOpenAIClient(ctx, connectContext), nil
	case ModelProviderNameOllama:
		return NewOllamaProvider(ctx, connectContext), nil
	case ModelProviderNameGroq:
		return NewGroqProvider(ctx, connectContext), nil
	default:
		return nil, fmt.Errorf("unknown model provider: %s", connectContext.Provider)
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

type GoogleProvider struct {
connectContext providerConnectContext
connectContext ProviderConnectContext
client *genai.Client
}

Expand All @@ -20,7 +20,7 @@ type GoogleProviderStream struct {
conversations []Conversation
}

func NewGoogleProvider(ctx context.Context, connectContext providerConnectContext) Provider {
func NewGoogleProvider(ctx context.Context, connectContext ProviderConnectContext) Provider {
return &GoogleProvider{connectContext: connectContext}
}

Expand All @@ -44,7 +44,7 @@ func (g *GoogleProvider) Close(ctx context.Context) error {
return nil
}

func (g *GoogleProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) {
func (g *GoogleProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) {
if ensureClientErr := g.ensureClient(ctx); ensureClientErr != nil {
return nil, ensureClientErr
}
Expand All @@ -57,7 +57,7 @@ func (g *GoogleProvider) ChatStream(ctx context.Context, model model, conversati
return &GoogleProviderStream{conversations: conversations, stream: stream}, nil
}

func (g *GoogleProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) {
func (g *GoogleProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) {
if ensureClientErr := g.ensureClient(ctx); ensureClientErr != nil {
return "", ensureClientErr
}
Expand All @@ -80,17 +80,17 @@ func (g *GoogleProvider) Chat(ctx context.Context, model model, conversations []
return "", errors.New("no text in response")
}

func (g *GoogleProvider) Models(ctx context.Context) ([]model, error) {
return []model{
func (g *GoogleProvider) Models(ctx context.Context) ([]Model, error) {
return []Model{
{
DisplayName: "google-gemini-1.0-pro",
Name: "gemini-1.0-pro",
Provider: modelProviderNameGoogle,
Provider: ModelProviderNameGoogle,
},
{
DisplayName: "google-gemini-1.5-pro",
Name: "gemini-1.5-pro",
Provider: modelProviderNameGoogle,
Provider: ModelProviderNameGoogle,
},
}, nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,28 @@ import (
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
"io"
"wox/plugin"
"wox/util"
)

type GroqProvider struct {
connectContext providerConnectContext
connectContext ProviderConnectContext
client *openai.LLM
}

type GroqProviderStream struct {
conversations []Conversation
reader io.Reader
api plugin.API
}

func NewGroqProvider(ctx context.Context, connectContext providerConnectContext) Provider {
func NewGroqProvider(ctx context.Context, connectContext ProviderConnectContext) Provider {
return &GroqProvider{connectContext: connectContext}
}

func (g *GroqProvider) Close(ctx context.Context) error {
return nil
}

func (o *GroqProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) {
func (o *GroqProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) {
client, clientErr := openai.New(openai.WithModel(model.Name), openai.WithBaseURL("https://api.groq.com/openai/v1"), openai.WithToken(o.connectContext.ApiKey))
if clientErr != nil {
return nil, clientErr
Expand All @@ -43,7 +41,6 @@ func (o *GroqProvider) ChatStream(ctx context.Context, model model, conversation
r, w := nio.Pipe(buf)
util.Go(ctx, "Groq chat stream", func() {
_, err := client.GenerateContent(ctx, o.convertConversations(conversations), llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
o.connectContext.api.Log(ctx, plugin.LogLevelDebug, fmt.Sprintf("Groq: receive chunks from model: %s", string(chunk)))
w.Write(chunk)
return nil
}))
Expand All @@ -57,7 +54,7 @@ func (o *GroqProvider) ChatStream(ctx context.Context, model model, conversation
return &GroqProviderStream{conversations: conversations, reader: r}, nil
}

func (o *GroqProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) {
func (o *GroqProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) {
client, clientErr := openai.New(openai.WithModel(model.Name), openai.WithBaseURL("https://api.groq.com/openai/v1"), openai.WithToken(o.connectContext.ApiKey))
if clientErr != nil {
return "", clientErr
Expand All @@ -71,27 +68,27 @@ func (o *GroqProvider) Chat(ctx context.Context, model model, conversations []Co
return response.Choices[0].Content, nil
}

func (o *GroqProvider) Models(ctx context.Context) (models []model, err error) {
return []model{
func (o *GroqProvider) Models(ctx context.Context) (models []Model, err error) {
return []Model{
{
Name: "llama3-8b-8192",
DisplayName: "llama3-8b-8192",
Provider: modelProviderNameGroq,
Provider: ModelProviderNameGroq,
},
{
Name: "llama3-70b-8192",
DisplayName: "llama3-70b-8192",
Provider: modelProviderNameGroq,
Provider: ModelProviderNameGroq,
},
{
Name: "mixtral-8x7b-32768",
DisplayName: "mixtral-8x7b-32768",
Provider: modelProviderNameGroq,
Provider: ModelProviderNameGroq,
},
{
Name: "gemma-7b-it",
DisplayName: "gemma-7b-it",
Provider: modelProviderNameGroq,
Provider: ModelProviderNameGroq,
},
}, nil
}
Expand Down Expand Up @@ -123,7 +120,3 @@ func (s *GroqProviderStream) Receive(ctx context.Context) (string, error) {
util.GetLogger().Debug(util.NewTraceContext(), fmt.Sprintf("Groq: Send response: %s", resp))
return resp, nil
}

func (s *GroqProviderStream) Close(ctx context.Context) {
// no-op
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ import (
"github.com/tmc/langchaingo/llms/ollama"
"github.com/tmc/langchaingo/schema"
"io"
"wox/plugin"
"wox/util"
)

type OllamaProvider struct {
connectContext providerConnectContext
connectContext ProviderConnectContext
client *ollama.LLM
}

Expand All @@ -25,15 +24,15 @@ type OllamaProviderStream struct {
reader io.Reader
}

func NewOllamaProvider(ctx context.Context, connectContext providerConnectContext) Provider {
func NewOllamaProvider(ctx context.Context, connectContext ProviderConnectContext) Provider {
return &OllamaProvider{connectContext: connectContext}
}

func (o *OllamaProvider) Close(ctx context.Context) error {
return nil
}

func (o *OllamaProvider) ChatStream(ctx context.Context, model model, conversations []Conversation) (ProviderChatStream, error) {
func (o *OllamaProvider) ChatStream(ctx context.Context, model Model, conversations []Conversation) (ChatStream, error) {
client, clientErr := ollama.New(ollama.WithServerURL(o.connectContext.Host), ollama.WithModel(model.Name))
if clientErr != nil {
return nil, clientErr
Expand All @@ -43,7 +42,6 @@ func (o *OllamaProvider) ChatStream(ctx context.Context, model model, conversati
r, w := nio.Pipe(buf)
util.Go(ctx, "ollama chat stream", func() {
_, err := client.GenerateContent(ctx, o.convertConversations(conversations), llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
o.connectContext.api.Log(ctx, plugin.LogLevelDebug, fmt.Sprintf("OLLAMA: receive chunks from model: %s", string(chunk)))
w.Write(chunk)
return nil
}))
Expand All @@ -57,7 +55,7 @@ func (o *OllamaProvider) ChatStream(ctx context.Context, model model, conversati
return &OllamaProviderStream{conversations: conversations, reader: r}, nil
}

func (o *OllamaProvider) Chat(ctx context.Context, model model, conversations []Conversation) (string, error) {
func (o *OllamaProvider) Chat(ctx context.Context, model Model, conversations []Conversation) (string, error) {
client, clientErr := ollama.New(ollama.WithServerURL(o.connectContext.Host), ollama.WithModel(model.Name))
if clientErr != nil {
return "", clientErr
Expand All @@ -71,17 +69,17 @@ func (o *OllamaProvider) Chat(ctx context.Context, model model, conversations []
return response.Choices[0].Content, nil
}

func (o *OllamaProvider) Models(ctx context.Context) (models []model, err error) {
func (o *OllamaProvider) Models(ctx context.Context) (models []Model, err error) {
body, err := util.HttpGet(ctx, o.connectContext.Host+"/api/tags")
if err != nil {
return nil, err
}

gjson.Get(string(body), "models.#.name").ForEach(func(key, value gjson.Result) bool {
models = append(models, model{
models = append(models, Model{
DisplayName: value.String(),
Name: value.String(),
Provider: modelProviderNameOllama,
Provider: ModelProviderNameOllama,
})
return true
})
Expand Down Expand Up @@ -116,7 +114,3 @@ func (s *OllamaProviderStream) Receive(ctx context.Context) (string, error) {
util.GetLogger().Debug(util.NewTraceContext(), fmt.Sprintf("OLLAMA: Send response: %s", resp))
return resp, nil
}

func (s *OllamaProviderStream) Close(ctx context.Context) {
// no-op
}
Loading

0 comments on commit efd1b9d

Please sign in to comment.