Google Vertex API (#1)
* Remove json config as deprecated

* Updated install files

* Updated readme

* Updated requirements

* Updated main and rmq

* Updated model file name

* Updated folder name

* Updated model

* Added API key init

* Fixed private var use

* Use latest version of models

* context_depth only even
NeonBohdan committed Jan 15, 2024
1 parent 461dd8d commit 3b9d471
Showing 11 changed files with 76 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/license_tests.yml
@@ -9,4 +9,4 @@ jobs:
   license_tests:
     uses: neongeckocom/.github/.github/workflows/license_tests.yml@master
     with:
-      packages-exclude: '^(neon-llm-chatgpt|tqdm).*'
+      packages-exclude: '^(neon-llm-palm2|tqdm).*'
4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,7 +1,7 @@
 FROM python:3.9-slim
 
 LABEL vendor=neon.ai \
-    ai.neon.name="neon-llm-chatgpt"
+    ai.neon.name="neon-llm-palm2"
 
 ENV OVOS_CONFIG_BASE_FOLDER neon
 ENV OVOS_CONFIG_FILENAME diana.yaml
@@ -12,4 +12,4 @@ WORKDIR /app
 COPY . /app
 RUN pip install /app
 
-CMD [ "neon-llm-chatgpt" ]
+CMD [ "neon-llm-palm2" ]
17 changes: 8 additions & 9 deletions README.md
@@ -1,5 +1,5 @@
-# NeonAI LLM ChatGPT
-Proxies API calls to ChatGPT.
+# NeonAI LLM Palm2
+Proxies API calls to Google Palm2.
 
 ## Request Format
 API requests should include `history`, a list of tuples of strings, and the current
@@ -25,12 +25,11 @@ MQ:
   port: <MQ Port>
   server: <MQ Hostname or IP>
   users:
-    neon_llm_chat_gpt:
-      password: <neon_chatgpt user's password>
-      user: neon_chatgpt
-LLM_CHAT_GPT:
-  key: ""
-  model: "gpt-3.5-turbo"
+    neon_llm_palm2:
+      password: <neon_palm2 user's password>
+      user: neon_palm2
+LLM_PALM2:
+  key_path: ""
   role: "You are trying to give a short answer in less than 40 words."
   context_depth: 3
   max_tokens: 100
@@ -40,6 +39,6 @@ LLM_CHAT_GPT:
 For example, if your configuration resides in `~/.config`:
 ```shell
 export CONFIG_PATH="/home/${USER}/.config"
-docker run -v ${CONFIG_PATH}:/config neon_llm_chatgpt
+docker run -v ${CONFIG_PATH}:/config neon_llm_palm2
 ```
 > Note: If connecting to a local MQ server, you may need to specify `--network host`
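
The former `key` and `model` options are gone: the model is now fixed in code (`chat-bison`), and `key_path` should point to a Google Cloud service-account JSON file, which the service exports as `GOOGLE_APPLICATION_CREDENTIALS` during warmup. A minimal sketch of wiring that up, assuming a hypothetical key file named `vertex-sa.json` kept next to the configuration:

```shell
# Hypothetical filename and container path for illustration only;
# set LLM_PALM2.key_path in diana.yaml to the path as seen in-container.
export CONFIG_PATH="/home/${USER}/.config"
cp vertex-sa.json ${CONFIG_PATH}/neon/vertex-sa.json
docker run -v ${CONFIG_PATH}:/config neon_llm_palm2
```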
3 changes: 1 addition & 2 deletions docker_overlay/etc/neon/diana.yaml
@@ -14,8 +14,7 @@ MQ:
   mq_handler:
     user: neon_api_utils
     password: Klatchat2021
-LLM_CHAT_GPT:
-  model: "gpt-3.5-turbo"
+LLM_PALM2:
   role: "You are trying to give a short answer in less than 40 words."
   context_depth: 3
   max_tokens: 100
19 changes: 0 additions & 19 deletions neon_llm_chatgpt/default_config.json

This file was deleted.

File renamed without changes.
8 changes: 4 additions & 4 deletions neon_llm_chatgpt/__main__.py → neon_llm_palm2/__main__.py
@@ -24,15 +24,15 @@
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from neon_llm_chatgpt.rmq import ChatgptMQ
+from neon_llm_palm2.rmq import Palm2MQ
 
 
 def main():
     # Run RabbitMQ
-    chatgptMQ = ChatgptMQ()
-    chatgptMQ.run(run_sync=False, run_consumers=True,
+    palm2MQ = Palm2MQ()
+    palm2MQ.run(run_sync=False, run_consumers=True,
                   daemonize_consumers=True)
-    chatgptMQ.observer_thread.join()
+    palm2MQ.observer_thread.join()
 
 
 if __name__ == "__main__":
75 changes: 50 additions & 25 deletions neon_llm_chatgpt/chatgpt.py → neon_llm_palm2/palm2.py
@@ -24,29 +24,40 @@
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import openai
-from openai.embeddings_utils import get_embeddings, distances_from_embeddings
+import os
+from vertexai.language_models import ChatModel, ChatMessage, TextEmbeddingModel
+from openai.embeddings_utils import distances_from_embeddings
 
 from typing import List, Dict
 from neon_llm_core.llm import NeonLLM
 
 
-class ChatGPT(NeonLLM):
+class Palm2(NeonLLM):
 
     mq_to_llm_role = {
         "user": "user",
-        "llm": "assistant"
+        "llm": "bot"
     }
 
     def __init__(self, config):
         super().__init__(config)
-        self.model_name = config["model"]
+        self._embedding = None
+        self._context_depth = 0
+
         self.role = config["role"]
         self.context_depth = config["context_depth"]
         self.max_tokens = config["max_tokens"]
-        self.api_key = config["key"]
+        self.api_key_path = config["key_path"]
         self.warmup()
 
+    @property
+    def context_depth(self):
+        return self._context_depth
+
+    @context_depth.setter
+    def context_depth(self, value):
+        self._context_depth = value + value % 2
+
     @property
     def tokenizer(self) -> None:
         return self._tokenizer
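
The new `context_depth` setter implements the "context_depth only even" bullet from the commit message: odd values are rounded up to the next even number, presumably so the `chat_history[-self.context_depth:]` slice below always keeps whole user/bot turn pairs for Palm2's `message_history`. A standalone sketch of the rounding (the function name is illustrative, not part of the code):

```python
def round_up_to_even(value: int) -> int:
    # value % 2 is 1 for odd input and 0 for even input
    return value + value % 2

assert round_up_to_even(3) == 4  # odd depth widens to keep turn pairs whole
assert round_up_to_even(4) == 4  # even depth is unchanged
```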
@@ -56,11 +67,16 @@ def tokenizer_model_name(self) -> str:
         return ""
 
     @property
-    def model(self) -> openai:
+    def model(self) -> ChatModel:
         if self._model is None:
-            openai.api_key = self.api_key
-            self._model = openai
+            self._model = ChatModel.from_pretrained("chat-bison")
         return self._model
 
+    @property
+    def embedding(self) -> TextEmbeddingModel:
+        if self._embedding is None:
+            self._embedding = TextEmbeddingModel.from_pretrained("textembedding-gecko@latest")
+        return self._embedding
+
     @property
     def llm_model_name(self) -> str:
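
Both properties lazily instantiate their Vertex AI models on first access and cache them. A minimal standalone sketch of the same calls, assuming `google-cloud-aiplatform` is installed and `GOOGLE_APPLICATION_CREDENTIALS` already points at a valid service-account key:

```python
from vertexai.language_models import ChatModel, TextEmbeddingModel

# The same model identifiers hardcoded by the properties above
chat_model = ChatModel.from_pretrained("chat-bison")
embedding_model = TextEmbeddingModel.from_pretrained("textembedding-gecko@latest")
```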
@@ -71,7 +87,9 @@ def _system_prompt(self) -> str:
         return self.role
 
     def warmup(self):
+        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.api_key_path
         self.model
+        self.embedding
 
     def get_sorted_answer_indexes(self, question: str, answers: List[str], persona: dict) -> List[int]:
         """
@@ -88,43 +106,49 @@ def get_sorted_answer_indexes(self, question: str, answers: List[str], persona:
         sorted_items_indexes = [x[0] for x in sorted_items]
         return sorted_items_indexes
 
-    def _call_model(self, prompt: List[Dict[str, str]]) -> str:
+    def _call_model(self, prompt: Dict) -> str:
         """
-        Wrapper for ChatGPT Model generation logic
+        Wrapper for Palm2 Model generation logic
         :param prompt: Input messages sequence
         :returns: Output text sequence generated by model
         """
 
-        response = openai.ChatCompletion.create(
-            model=self.llm_model_name,
-            messages=prompt,
+        chat = self._model.start_chat(
+            context=prompt["system_prompt"],
+            message_history=prompt["chat_history"],
+            max_output_tokens=self.max_tokens,
             temperature=0,
-            max_tokens=self.max_tokens,
         )
-        text = response.choices[0].message['content']
+        response = chat.send_message(
+            prompt["message"],
+        )
+        text = response.text
 
         return text
 
     def _assemble_prompt(self, message: str, chat_history: List[List[str]], persona: dict) -> List[Dict[str, str]]:
         """
         Assembles prompt engineering logic
         Setup Guidance:
-        https://platform.openai.com/docs/guides/gpt/chat-completions-api
+        https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/overview
         :param message: Incoming prompt
         :param chat_history: History of preceding conversation
         :returns: assembled prompt
         """
         system_prompt = persona.get("description", self._system_prompt)
-        messages = [
-            {"role": "system", "content": system_prompt},
-        ]
         # Context N messages
+        messages = []
         for role, content in chat_history[-self.context_depth:]:
-            role_chatgpt = self.convert_role(role)
-            messages.append({"role": role_chatgpt, "content": content})
-        messages.append({"role": "user", "content": message})
-        return messages
+            role_palm2 = self.convert_role(role)
+            messages.append(ChatMessage(content, role_palm2))
+        prompt = {
+            "system_prompt": system_prompt,
+            "chat_history": messages,
+            "message": message
+        }
+
+        return prompt
 
     def _score(self, prompt: str, targets: List[str], persona: dict) -> List[float]:
         """
@@ -150,7 +174,8 @@ def _embeddings(self, question: str, answers: List[str], persona: dict) -> (List
         """
         response = self.ask(question, [], persona=persona)
         texts = [response] + answers
-        embeddings = get_embeddings(texts, engine="text-embedding-ada-002")
+        embeddings_obj = self.embedding.get_embeddings(texts)
+        embeddings = [embedding.values for embedding in embeddings_obj]
         question_embeddings = embeddings[0]
         answers_embeddings = embeddings[1:]
         return question_embeddings, answers_embeddings
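
Embeddings are now produced by `textembedding-gecko` instead of `text-embedding-ada-002`, while answer ranking still uses `distances_from_embeddings` from the retained `openai[embeddings]` extra. A sketch of that ranking idea with placeholder vectors (the real code compares candidate answers against the model's own response):

```python
from openai.embeddings_utils import distances_from_embeddings

question_embedding = [0.10, 0.20, 0.30]   # placeholder vectors
answer_embeddings = [[0.11, 0.19, 0.31],  # close to the question
                     [0.90, 0.05, 0.02]]  # far from the question
distances = distances_from_embeddings(question_embedding, answer_embeddings,
                                       distance_metric="cosine")
sorted_indexes = sorted(range(len(distances)), key=lambda i: distances[i])
print(sorted_indexes)  # [0, 1]: the closer answer ranks first
```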
10 changes: 5 additions & 5 deletions neon_llm_chatgpt/rmq.py → neon_llm_palm2/rmq.py
@@ -25,12 +25,12 @@
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 from neon_llm_core.rmq import NeonLLMMQConnector
 
-from neon_llm_chatgpt.chatgpt import ChatGPT
+from neon_llm_palm2.palm2 import Palm2
 
 
-class ChatgptMQ(NeonLLMMQConnector):
+class Palm2MQ(NeonLLMMQConnector):
     """
-    Module for processing MQ requests to ChatGPT
+    Module for processing MQ requests to Palm2
     """
 
     def __init__(self):
@@ -39,12 +39,12 @@ def __init__(self):
 
     @property
     def name(self):
-        return "chat_gpt"
+        return "palm2"
 
     @property
     def model(self):
         if self._model is None:
-            self._model = ChatGPT(self.model_config)
+            self._model = Palm2(self.model_config)
         return self._model
 
     def warmup(self):
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -1,4 +1,5 @@
 # model
+google-cloud-aiplatform
 openai[embeddings]~=0.27
 # networking
 neon_llm_core~=0.1.0
8 changes: 4 additions & 4 deletions setup.py
@@ -67,12 +67,12 @@ def get_requirements(requirements_filename: str):
             version = line.split("'")[1]
 
 setup(
-    name='neon-llm-chatgpt',
+    name='neon-llm-palm2',
     version=version,
-    description='LLM service for Chat GPT',
+    description='LLM service for Palm2',
     long_description=long_description,
     long_description_content_type="text/markdown",
-    url='https://github.com/NeonGeckoCom/neon-llm-chatgpt',
+    url='https://github.com/NeonGeckoCom/neon-llm-palm2',
     author='Neongecko',
     author_email='developers@neon.ai',
     license='BSD-3.0',
@@ -85,7 +85,7 @@ def get_requirements(requirements_filename: str):
     ],
     entry_points={
         'console_scripts': [
-            'neon-llm-chatgpt=neon_llm_chatgpt.__main__:main'
+            'neon-llm-palm2=neon_llm_palm2.__main__:main'
         ]
     }
 )
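
With the renamed package and entry point, a local install-and-run could look like this (a sketch; assumes a populated diana.yaml as described in the README):

```shell
pip install .
neon-llm-palm2
```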
