Skip to content

Commit

Permalink
Fixed Bryley#50, Add Spark LLM support
Browse files Browse the repository at this point in the history
Add `selected_model_index` config option
  • Loading branch information
skyfireitdiy committed Sep 15, 2023
1 parent 248c200 commit 2e69aeb
Show file tree
Hide file tree
Showing 5 changed files with 375 additions and 11 deletions.
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,18 @@ require("neoai").setup({
output_popup_height = 80, -- As percentage eg. 80%
submit = "<Enter>", -- Key binding to submit the prompt
},
selected_model_index = 0,
models = {
{
name = "openai",
model = "gpt-3.5-turbo",
params = nil,
},
{
name = "spark",
model = "v1",
params = nil,
}
},
register_output = {
["g"] = function(output)
Expand Down Expand Up @@ -222,6 +228,16 @@ require("neoai").setup({
-- end,
},
},
spark = {
random_threshold = 0.5,
max_tokens = 4096,
version = "v1",
api_key = {
appid_env = "SPARK_APPID",
secret_env = "SPARK_SECRET",
apikey_env = "SPARK_APIKEY",
},
},
shortcuts = {
{
name = "textify",
Expand Down Expand Up @@ -305,6 +321,19 @@ end
- `api_key.value`: The OpenAI API key, which takes precedence over `api_key .env`.
- `api_key.get`: A function that retrieves the OpenAI API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence.

### Spark Options:
- `random_threshold` Kernel sampling threshold. Used to determine the randomness of the outcome, the higher the value, the stronger the randomness, that is, the higher the probability of different answers to the same question
- `max_tokens` The maximum length of tokens answered by the model
- `version` The model version, `v1` or `v2`
- `api_key.appid_env` The environment variable containing the Spark appid. The default value is "SPARK_APPID".
- `api_key.secret_env` The environment variable containing the Spark secret key. The default value is "SPARK_SECRET".
- `api_key.apikey_env` The environment variable containing the Spark api key. The default value is "SPARK_APIKEY".
- `api_key.appid` App appid, obtained from an app created in the Open Platform console
- `api_key.secret` App secret key, btained from an app created in the Open Platform console
- `api_key.apikey` App api key, btained from an app created in the Open Platform console
- `api_key.get` A function that retrieves the Spark API key. For an example implementation, refer to the [Setup](#Setup) section. It has the higher precedence.


### Mappings
- `mappings`: A table containing the following actions that can be keys:

Expand Down
2 changes: 1 addition & 1 deletion lua/neoai/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ local append_to_output = nil

---@type {name: ModelModule, model: string, params: table<string, string> | nil}[] A list of models
M.models = {}
M.selected_model = 0
M.selected_model = config.options.selected_model_index

M.setup_models = function()
for _, model_obj in ipairs(config.options.models) do
Expand Down
108 changes: 108 additions & 0 deletions lua/neoai/chat/models/spark.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
local utils = require("neoai.utils")
local config = require("neoai.config")

---@type ModelModule
local M = {}

M.name = "Spark"

M._chunks = {}
local raw_chunks = {}

M.get_current_output = function()
return table.concat(M._chunks, "")
end

---@param chunk string
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
M._recieve_chunk = function(chunk, on_stdout_chunk)
for line in chunk:gmatch("[^\n]+") do
local raw_json = line

table.insert(raw_chunks, raw_json)

local ok, path = pcall(vim.json.decode, raw_json)
if not ok then
on_stdout_chunk("decode error")
goto continue
end

path = path.payload
if path == nil then
goto continue
end

path = path.choices
if path == nil then
goto continue
end

path = path.text
if path == nil then
goto continue
end

path = path[1]
if path == nil then
goto continue
end

path = path.content
if path == nil then
goto continue
end


on_stdout_chunk(path)
-- append_to_output(path, 0)
table.insert(M._chunks, path)
::continue::
end
end

---@param chat_history ChatHistory
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
---@param on_complete fun(err?: string, output?: string) Function to call when model has finished
M.send_to_model = function(chat_history, on_stdout_chunk, on_complete)
local appid, secret, apikey = config.options.spark.api_key.get()
local ver = config.options.spark.version
local random_threshold = config.options.spark.random_threshold
local max_tokens = config.options.spark.max_tokens

local get_script_dir = function()
local info = debug.getinfo(1, "S")
local script_path = info.source:sub(2)
return script_path:match("(.*/)")
end

local py_script_path = get_script_dir() .. "/spark.py"

chunks = {}
raw_chunks = {}
utils.exec(py_script_path, {
appid,
secret,
apikey,
vim.json.encode(chat_history.messages),
"--ver",
ver,
"--random_threshold",
random_threshold,
"--max_tokens",
max_tokens
}, function(chunk)
M._recieve_chunk(chunk, on_stdout_chunk)
end, function(err, _)
local total_message = table.concat(raw_chunks, "")
local ok, json = pcall(vim.json.decode, total_message)
if ok then
if json.error ~= nil then
on_complete(json.error.message, nil)
return
end
end
on_complete(err, M.get_current_output())
end)
end

return M
172 changes: 172 additions & 0 deletions lua/neoai/chat/models/spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
ba#!/usr/bin/env python3

import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import argparse

import websocket # 使用websocket_client
answer = ""


class WSParam(object):
# 初始化
def __init__(self, APPID, APIKey, APISecret, spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(spark_url).netloc
self.path = urlparse(spark_url).path
self.spark_url = spark_url

# 生成url
def create_url(self):
# 生成RFC1123格式的时间戳
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))

# 拼接字符串
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"

# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()

signature_sha_base64 = base64.b64encode(
signature_sha).decode(encoding='utf-8')

authorization_origin = f'apikey="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

authorization = base64.b64encode(
authorization_origin.encode('utf-8')).decode(encoding='utf-8')

# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
# 拼接鉴权参数,生成url
url = self.spark_url + '?' + urlencode(v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
return url


# 收到websocket错误的处理
def on_error(ws, error):
print("### error:", error)


# 收到websocket关闭的处理
def on_close(ws, one, two):
print(" ")


# 收到websocket连接建立的处理
def on_open(ws):
thread.start_new_thread(run, (ws,))


def run(ws, *args):
data = json.dumps(gen_params(
appid=ws.appid,
domain=ws.domain,
messages=ws.messages,
random_threshold=ws.random_threshold,
max_tokens=ws.max_tokens))
ws.send(data)


# 收到websocket消息的处理


def gen_params(appid, domain, random_threshold, max_tokens, messages):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": random_threshold,
"max_tokens": max_tokens,
"auditing": "default"
}
},
"payload": {
"message": {
"text": messages
}
}
}
return data


def Request(appid, secret, apikey, messages, version, random_threshold, max_token):
if version == "v1":
spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"
domain = "general"
elif version == "v2":
spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"
domain = "generalv2"
ws_param = WSParam(appid, apikey, secret, spark_url)
ws_url = ws_param.create_url()
answer = ""

def on_message(ws, message):
data = json.loads(message)
print(json.dumps(data, ensure_ascii=False))
code = data['header']['code']
if code != 0:
# print(f'请求错误: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
nonlocal answer
answer += content
if status == 2:
ws.close()
ws = websocket.WebSocketApp(
ws_url, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.messages = messages
ws.domain = domain
ws.random_threshold = random_threshold
ws.max_tokens = max_token
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})


def main():
parser = argparse.ArgumentParser()
parser.add_argument("appid")
parser.add_argument("secret")
parser.add_argument("apikey")
parser.add_argument("messages")
parser.add_argument("--ver", "-v", default="v1")
parser.add_argument("--random_threshold", "-r", default=0.5, type=float)
parser.add_argument("--max_tokens", "-t", default=4096, type=int)
parse_result = parser.parse_args()
messages = json.loads(parse_result.messages)
Request(parse_result.appid, parse_result.secret, parse_result.apikey, messages,
parse_result.ver, parse_result.random_threshold, parse_result.max_tokens)


if __name__ == "__main__":
main()
Loading

0 comments on commit 2e69aeb

Please sign in to comment.