forked from Bryley/neoai.nvim
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixed Bryley#50, Add Spark LLM support
Add `selected_model_index` config option
- Loading branch information
1 parent
248c200
commit 2e69aeb
Showing
5 changed files
with
375 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
local utils = require("neoai.utils") | ||
local config = require("neoai.config") | ||
|
||
---@type ModelModule | ||
local M = {} | ||
|
||
M.name = "Spark" | ||
|
||
M._chunks = {} | ||
local raw_chunks = {} | ||
|
||
M.get_current_output = function() | ||
return table.concat(M._chunks, "") | ||
end | ||
|
||
---@param chunk string | ||
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs | ||
M._recieve_chunk = function(chunk, on_stdout_chunk) | ||
for line in chunk:gmatch("[^\n]+") do | ||
local raw_json = line | ||
|
||
table.insert(raw_chunks, raw_json) | ||
|
||
local ok, path = pcall(vim.json.decode, raw_json) | ||
if not ok then | ||
on_stdout_chunk("decode error") | ||
goto continue | ||
end | ||
|
||
path = path.payload | ||
if path == nil then | ||
goto continue | ||
end | ||
|
||
path = path.choices | ||
if path == nil then | ||
goto continue | ||
end | ||
|
||
path = path.text | ||
if path == nil then | ||
goto continue | ||
end | ||
|
||
path = path[1] | ||
if path == nil then | ||
goto continue | ||
end | ||
|
||
path = path.content | ||
if path == nil then | ||
goto continue | ||
end | ||
|
||
|
||
on_stdout_chunk(path) | ||
-- append_to_output(path, 0) | ||
table.insert(M._chunks, path) | ||
::continue:: | ||
end | ||
end | ||
|
||
---@param chat_history ChatHistory | ||
---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs | ||
---@param on_complete fun(err?: string, output?: string) Function to call when model has finished | ||
M.send_to_model = function(chat_history, on_stdout_chunk, on_complete) | ||
local appid, secret, apikey = config.options.spark.api_key.get() | ||
local ver = config.options.spark.version | ||
local random_threshold = config.options.spark.random_threshold | ||
local max_tokens = config.options.spark.max_tokens | ||
|
||
local get_script_dir = function() | ||
local info = debug.getinfo(1, "S") | ||
local script_path = info.source:sub(2) | ||
return script_path:match("(.*/)") | ||
end | ||
|
||
local py_script_path = get_script_dir() .. "/spark.py" | ||
|
||
chunks = {} | ||
raw_chunks = {} | ||
utils.exec(py_script_path, { | ||
appid, | ||
secret, | ||
apikey, | ||
vim.json.encode(chat_history.messages), | ||
"--ver", | ||
ver, | ||
"--random_threshold", | ||
random_threshold, | ||
"--max_tokens", | ||
max_tokens | ||
}, function(chunk) | ||
M._recieve_chunk(chunk, on_stdout_chunk) | ||
end, function(err, _) | ||
local total_message = table.concat(raw_chunks, "") | ||
local ok, json = pcall(vim.json.decode, total_message) | ||
if ok then | ||
if json.error ~= nil then | ||
on_complete(json.error.message, nil) | ||
return | ||
end | ||
end | ||
on_complete(err, M.get_current_output()) | ||
end) | ||
end | ||
|
||
return M |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
ba#!/usr/bin/env python3 | ||
|
||
import _thread as thread | ||
import base64 | ||
import datetime | ||
import hashlib | ||
import hmac | ||
import json | ||
from urllib.parse import urlparse | ||
import ssl | ||
from datetime import datetime | ||
from time import mktime | ||
from urllib.parse import urlencode | ||
from wsgiref.handlers import format_date_time | ||
import argparse | ||
|
||
import websocket # 使用websocket_client | ||
answer = "" | ||
|
||
|
||
class WSParam(object): | ||
# 初始化 | ||
def __init__(self, APPID, APIKey, APISecret, spark_url): | ||
self.APPID = APPID | ||
self.APIKey = APIKey | ||
self.APISecret = APISecret | ||
self.host = urlparse(spark_url).netloc | ||
self.path = urlparse(spark_url).path | ||
self.spark_url = spark_url | ||
|
||
# 生成url | ||
def create_url(self): | ||
# 生成RFC1123格式的时间戳 | ||
now = datetime.now() | ||
date = format_date_time(mktime(now.timetuple())) | ||
|
||
# 拼接字符串 | ||
signature_origin = "host: " + self.host + "\n" | ||
signature_origin += "date: " + date + "\n" | ||
signature_origin += "GET " + self.path + " HTTP/1.1" | ||
|
||
# 进行hmac-sha256进行加密 | ||
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), | ||
digestmod=hashlib.sha256).digest() | ||
|
||
signature_sha_base64 = base64.b64encode( | ||
signature_sha).decode(encoding='utf-8') | ||
|
||
authorization_origin = f'apikey="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' | ||
|
||
authorization = base64.b64encode( | ||
authorization_origin.encode('utf-8')).decode(encoding='utf-8') | ||
|
||
# 将请求的鉴权参数组合为字典 | ||
v = { | ||
"authorization": authorization, | ||
"date": date, | ||
"host": self.host | ||
} | ||
# 拼接鉴权参数,生成url | ||
url = self.spark_url + '?' + urlencode(v) | ||
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 | ||
return url | ||
|
||
|
||
# 收到websocket错误的处理 | ||
def on_error(ws, error): | ||
print("### error:", error) | ||
|
||
|
||
# 收到websocket关闭的处理 | ||
def on_close(ws, one, two): | ||
print(" ") | ||
|
||
|
||
# 收到websocket连接建立的处理 | ||
def on_open(ws): | ||
thread.start_new_thread(run, (ws,)) | ||
|
||
|
||
def run(ws, *args): | ||
data = json.dumps(gen_params( | ||
appid=ws.appid, | ||
domain=ws.domain, | ||
messages=ws.messages, | ||
random_threshold=ws.random_threshold, | ||
max_tokens=ws.max_tokens)) | ||
ws.send(data) | ||
|
||
|
||
# 收到websocket消息的处理 | ||
|
||
|
||
def gen_params(appid, domain, random_threshold, max_tokens, messages): | ||
""" | ||
通过appid和用户的提问来生成请参数 | ||
""" | ||
data = { | ||
"header": { | ||
"app_id": appid, | ||
"uid": "1234" | ||
}, | ||
"parameter": { | ||
"chat": { | ||
"domain": domain, | ||
"random_threshold": random_threshold, | ||
"max_tokens": max_tokens, | ||
"auditing": "default" | ||
} | ||
}, | ||
"payload": { | ||
"message": { | ||
"text": messages | ||
} | ||
} | ||
} | ||
return data | ||
|
||
|
||
def Request(appid, secret, apikey, messages, version, random_threshold, max_token): | ||
if version == "v1": | ||
spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" | ||
domain = "general" | ||
elif version == "v2": | ||
spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" | ||
domain = "generalv2" | ||
ws_param = WSParam(appid, apikey, secret, spark_url) | ||
ws_url = ws_param.create_url() | ||
answer = "" | ||
|
||
def on_message(ws, message): | ||
data = json.loads(message) | ||
print(json.dumps(data, ensure_ascii=False)) | ||
code = data['header']['code'] | ||
if code != 0: | ||
# print(f'请求错误: {code}, {data}') | ||
ws.close() | ||
else: | ||
choices = data["payload"]["choices"] | ||
status = choices["status"] | ||
content = choices["text"][0]["content"] | ||
nonlocal answer | ||
answer += content | ||
if status == 2: | ||
ws.close() | ||
ws = websocket.WebSocketApp( | ||
ws_url, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open) | ||
ws.appid = appid | ||
ws.messages = messages | ||
ws.domain = domain | ||
ws.random_threshold = random_threshold | ||
ws.max_tokens = max_token | ||
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("appid") | ||
parser.add_argument("secret") | ||
parser.add_argument("apikey") | ||
parser.add_argument("messages") | ||
parser.add_argument("--ver", "-v", default="v1") | ||
parser.add_argument("--random_threshold", "-r", default=0.5, type=float) | ||
parser.add_argument("--max_tokens", "-t", default=4096, type=int) | ||
parse_result = parser.parse_args() | ||
messages = json.loads(parse_result.messages) | ||
Request(parse_result.appid, parse_result.secret, parse_result.apikey, messages, | ||
parse_result.ver, parse_result.random_threshold, parse_result.max_tokens) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.