Rewrite curl implementation using Plenary and add cancel stream functionality #45

Status: Open · wants to merge 1 commit into base: main
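
This PR replaces the hand-rolled curl subprocess in lua/neoai/chat/models/openai.lua (previously spawned through utils.exec) with plenary.nvim's curl wrapper. curl.post returns a job handle that the module now keeps, so the new M.cancel_stream() can abort an in-flight request with handler:shutdown(); M.reset() in lua/neoai/chat/init.lua invokes it whenever the chat is reset. The chunk parser is also cleaned up: the goto-based field walk becomes a local helper, and end-of-stream is detected via OpenAI's [DONE] sentinel.

For orientation, plenary.curl's streaming pattern looks roughly like this (a minimal sketch, not part of the diff; the URL and print callback are placeholders):

local curl = require("plenary.curl")

-- Passing a stream callback makes the request asynchronous; the
-- return value is the underlying plenary Job.
local handle = curl.post({
    url = "https://example.com/stream", -- placeholder endpoint
    raw = { "--no-buffer" },            -- extra flags handed straight to curl
    stream = function(_, chunk)
        -- Runs off the main loop for every stdout chunk, so schedule
        -- back onto it before touching editor state.
        vim.schedule(function()
            print(chunk)
        end)
    end,
})

-- Cancelling later is just shutting the job down.
handle:shutdown()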
43 changes: 21 additions & 22 deletions lua/neoai/chat/history.lua
@@ -12,40 +12,39 @@ local ChatHistory = { model = "", params = {}, messages = {} }
---@param context string | nil The context to use
---@return ChatHistory
function ChatHistory:new(model, params, context)
    local obj = {}

    setmetatable(obj, self)
    self.__index = self

    self.model = model
    self.params = params or {}
    self.messages = {}

    if context ~= nil then
        local context_prompt = config.options.prompts.context_prompt(context)
        self:set_prompt(context_prompt)
    end
    return obj
end

--- @param prompt string system prompt
function ChatHistory:set_prompt(prompt)
    local system_msg = {
        role = "system",
        content = prompt,
    }
    table.insert(self.messages, system_msg)
end

---@param user boolean True if user sent msg
---@param msg string The message to add
function ChatHistory:add_message(user, msg)
    local role = user and "user" or "assistant"
    table.insert(self.messages, {
        role = role,
        content = msg,
    })
end

return ChatHistory
1 change: 1 addition & 0 deletions lua/neoai/chat/init.lua
@@ -62,6 +62,7 @@ end
 M.reset = function()
     M.context = nil
     M.chat_history = nil
+    M.get_current_model().name.cancel_stream()
 end

 local chunks = {}
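With this hook in place, resetting the chat also aborts a response that is still streaming instead of letting it keep appending output. The same cancel path can be invoked directly (hypothetical call site, assuming the model module below is in scope):

-- Abort an in-flight completion without waiting for [DONE]
-- (module path taken from the diff below).
local openai = require("neoai.chat.models.openai")
openai.cancel_stream()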
144 changes: 79 additions & 65 deletions lua/neoai/chat/models/openai.lua
@@ -1,91 +1,105 @@
-local utils = require("neoai.utils")
 local config = require("neoai.config")
+local curl = require("plenary.curl")
+local utils = require("neoai.utils")

 ---@type ModelModule
 local M = {}

 M.name = "OpenAI"

+local handler
 local chunks = {}
 local raw_chunks = {}

+---@brief Cancel the current stream and shut down the handler
+M.cancel_stream = function()
+    if handler ~= nil then
+        handler:shutdown()
+        handler = nil
+    end
+end
+
 M.get_current_output = function()
     return table.concat(chunks, "")
 end

 ---@param chunk string
 ---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
-M._recieve_chunk = function(chunk, on_stdout_chunk)
-    for line in chunk:gmatch("[^\n]+") do
-        local raw_json = string.gsub(line, "^data: ", "")
-        table.insert(raw_chunks, raw_json)
-        local ok, path = pcall(vim.json.decode, raw_json)
-        if not ok then
-            goto continue
-        end
-        path = path.choices
-        if path == nil then
-            goto continue
-        end
-        path = path[1]
-        if path == nil then
-            goto continue
-        end
-        path = path.delta
-        if path == nil then
-            goto continue
-        end
-        path = path.content
-        if path == nil then
-            goto continue
-        end
-        on_stdout_chunk(path)
-        -- append_to_output(path, 0)
-        table.insert(chunks, path)
-        ::continue::
-    end
+M._receive_chunk = function(chunk, on_stdout_chunk)
+    local function safely_extract_delta_content(decoded_json)
+        local path = decoded_json.choices
+        if not path then
+            return nil
+        end
+
+        path = path[1]
+        if not path then
+            return nil
+        end
+
+        path = path.delta
+        if not path then
+            return nil
+        end
+
+        return path.content
+    end
+
+    -- Remove "data:" prefix from chunk
+    local raw_json = string.gsub(chunk, "%s*data:%s*", "")
+    table.insert(raw_chunks, raw_json)
+
+    local ok, decoded_json = pcall(vim.json.decode, raw_json)
+    if not ok then
+        return -- Ignore invalid JSON chunks
+    end
+
+    local delta_content = safely_extract_delta_content(decoded_json)
+    if delta_content then
+        table.insert(chunks, delta_content)
+        on_stdout_chunk(delta_content)
+    end
 end

 ---@param chat_history ChatHistory
 ---@param on_stdout_chunk fun(chunk: string) Function to call whenever a stdout chunk occurs
 ---@param on_complete fun(err?: string, output?: string) Function to call when model has finished
 M.send_to_model = function(chat_history, on_stdout_chunk, on_complete)
     local api_key = os.getenv(config.options.open_api_key_env)

     local data = {
         model = chat_history.model,
         stream = true,
         messages = chat_history.messages,
     }
     data = vim.tbl_deep_extend("force", {}, data, chat_history.params)

     chunks = {}
     raw_chunks = {}
-    utils.exec("curl", {
-        "--silent",
-        "--show-error",
-        "--no-buffer",
-        "https://api.openai.com/v1/chat/completions",
-        "-H",
-        "Content-Type: application/json",
-        "-H",
-        "Authorization: Bearer " .. api_key,
-        "-d",
-        vim.json.encode(data),
-    }, function(chunk)
-        M._recieve_chunk(chunk, on_stdout_chunk)
-    end, function(err, _)
-        local total_message = table.concat(raw_chunks, "")
-        local ok, json = pcall(vim.json.decode, total_message)
-        if ok then
-            if json.error ~= nil then
-                on_complete(json.error.message, nil)
-                return
-            end
-        end
-        on_complete(err, M.get_current_output())
-    end)
+    handler = curl.post({
+        url = "https://api.openai.com/v1/chat/completions",
+        raw = { "--no-buffer" },
+        headers = {
+            content_type = "application/json",
+            Authorization = "Bearer " .. api_key,
+        },
+        body = vim.json.encode(data),
+        stream = function(_, chunk)
+            if chunk ~= "" then
+                -- The following workaround helps to identify when the model has completed its task.
+                if string.match(chunk, "%[DONE%]") then
+                    vim.schedule(function()
+                        on_complete(nil, M.get_current_output())
+                    end)
+                else
+                    vim.schedule(function()
+                        M._receive_chunk(chunk, on_stdout_chunk)
+                    end)
+                end
+            end
+        end,
+        on_error = function(err, _, _)
+            return on_complete(err, nil)
+        end,
+    })
 end

 return M
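
For reference, the payloads the stream callback receives are OpenAI's server-sent events: each event is a data:-prefixed JSON object whose choices[1].delta.content carries the next text fragment, and the final event is the [DONE] sentinel that send_to_model matches. Illustrative excerpts (fields trimmed):

data: {"choices":[{"index":0,"delta":{"content":"Hel"}}]}

data: {"choices":[{"index":0,"delta":{"content":"lo"}}]}

data: [DONE]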