-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: #1 add supports complex struct parsing with streaming
- Loading branch information
Showing
9 changed files
with
222 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ dist | |
openai_streaming.egg-info/ | ||
.benchmarks | ||
junit | ||
.venv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Relative import: `handler` is a sibling module inside this package (it itself
# uses relative imports such as `..stream_processing`), so a bare
# `from handler import ...` would only work if the package dir were on sys.path.
from .handler import process_struct_response, Terminate, BaseHandler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
from typing import Protocol, Literal, AsyncGenerator, Optional, Type, TypeVar, Union | ||
|
||
from pydantic import BaseModel | ||
|
||
from json_streamer import Parser, JsonParser | ||
from .yaml_parser import YamlParser | ||
from ..stream_processing import OAIResponse, process_response | ||
|
||
# Type variable for the user-supplied pydantic model being parsed from the stream.
TModel = TypeVar('TModel', bound=BaseModel)
|
||
|
||
class Terminate:
    """Sentinel type a handler returns to stop processing the stream early."""
|
||
|
||
class BaseHandler(Protocol[TModel]):
    """
    Structural protocol for handlers that consume partially parsed models.

    Implementations supply the target pydantic model and react to each
    partially parsed instance as the stream progresses.
    """

    def model(self) -> Type[TModel]:
        """
        The Pydantic Data Model that we parse
        :return: type of the Pydantic model
        """
        ...

    async def handle_partially_parsed(self, data: TModel) -> Optional[Terminate]:
        """
        Handle partially parsed model
        :param data: The partially parsed object
        :return: None or Terminate if we want to terminate the parsing
        """
        ...

    async def terminated(self):
        """
        Called when the parsing was terminated
        """
|
||
|
||
# Wire formats the streamed structured output may arrive in.
OutputSerialization = Literal["json", "yaml"]
|
||
|
||
class _ContentHandler:
    """
    Internal adapter that feeds streamed response tokens into a streaming
    JSON/YAML parser and forwards each partially parsed dict to the user handler.
    """

    # Streaming parser selected in __init__ (a JsonParser or YamlParser instance).
    parser: Parser = None
    # The last value produced by the handler: a parsed model or Terminate.
    _last_resp: Optional[Union[TModel, Terminate]] = None

    def __init__(self, handler: BaseHandler, output_serialization: OutputSerialization = "yaml"):
        """
        :param handler: User handler that receives partially parsed models
        :param output_serialization: Wire format of the streamed content ("json" or "yaml")
        :raises ValueError: If `output_serialization` is not a supported format
        """
        self.handler = handler
        if output_serialization == "json":
            self.parser = JsonParser()
        elif output_serialization == "yaml":
            self.parser = YamlParser()
        else:
            # Fail fast here instead of surfacing a confusing
            # "'NoneType' object is not callable" later in handle_content.
            raise ValueError(f"Unsupported output serialization: {output_serialization}")

    async def handle_content(self, content: AsyncGenerator[str, None]):
        """
        Handle the content of the response from OpenAI.
        :param content: A generator that yields the content of the response from OpenAI
        :return: None
        """

        loader = self.parser()  # create a Streaming loader (a generator)
        next(loader)  # prime the generator so it can accept send()

        last_resp = None

        async for token in content:
            parsed = loader.send(token)  # send the token to the JSON loader
            while parsed:  # loop through the parsed parts as the loader yields them
                last_resp = await self._handle_parsed(parsed[1])  # parsed is (state, dict)
                if isinstance(last_resp, Terminate):
                    break
                try:
                    parsed = next(loader)
                except StopIteration:
                    break
            if isinstance(last_resp, Terminate):
                break

        if not last_resp:
            return
        if isinstance(last_resp, Terminate):
            await self.handler.terminated()

        self._last_resp = last_resp

    async def _handle_parsed(self, part) -> Optional[Union[TModel, Terminate]]:
        """
        Handle a parsed part of the response from OpenAI.
        It parses the "parsed dictionary" as a type of `TModel` object and processes it with the handler.
        :param part: A dictionary containing the parsed part of the response
        :return: The parsed part of the response as an `TModel` object, `Terminate` to terminate the handling,
            or `None` if the part is not valid
        """
        try:
            parsed = self.handler.model()(**part)
        except (TypeError, ValueError):
            # The partial dict does not (yet) satisfy the model; skip this part.
            return

        ret = await self.handler.handle_partially_parsed(parsed)
        return ret if ret else parsed

    def get_last_response(self) -> Optional[Union[TModel, Terminate]]:
        """
        Get the last response from OpenAI.
        :return: The last response from OpenAI, or None if nothing was parsed
        """
        return self._last_resp
|
||
|
||
async def process_struct_response(
        response: OAIResponse,
        handler: BaseHandler,
        output_serialization: OutputSerialization = "json"
):
    """
    Process a streamed OpenAI response as a structured object.

    :param response: The streamed response from OpenAI
    :param handler: A handler implementing `BaseHandler` that receives partial parses
    :param output_serialization: Wire format of the streamed content ("json" or "yaml")
    :return: The processing result
    :raises ValueError: If no valid structured response could be parsed
    """
    # Use a distinct name so the `handler` parameter is not shadowed.
    content_handler = _ContentHandler(handler, output_serialization)
    _, result = await process_response(response, content_handler.handle_content, self=content_handler)
    if not content_handler.get_last_response():
        raise ValueError("Probably invalid response from OpenAI")

    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from typing import List, Dict, Tuple, Generator, Optional | ||
from json_streamer import Parser, ParseState | ||
|
||
|
||
class YamlParser(Parser):
    """
    Parse partial YAML

    Streaming variant: each partial document is decoded with PyYAML's
    `safe_load`, and parse states are reported as UNKNOWN because YAML
    offers no reliable way to tell a partial from a complete document.
    """

    @staticmethod
    def opening_symbols() -> List[str]:
        # NOTE: annotation fixed — `chr` is a builtin function, not a type;
        # these are one-character strings.
        return ['{', '[']

    def raw_decode(self, s: str) -> Tuple[Dict, int]:
        """
        Decode a (possibly partial) YAML string.
        :param s: The YAML text accumulated so far
        :return: The decoded mapping and -1 (no end-index tracking for YAML)
        :raises ImportError: If the optional PyYAML dependency is missing
        """
        try:
            from yaml import safe_load
        except ImportError:
            raise ImportError("You must install PyYAML to use the YamlParser: pip install PyYAML")
        return safe_load(s), -1

    def parse_part(self, part: str) -> Generator[Tuple[ParseState, dict], None, None]:
        """
        Yield parsed parts; the state is always UNKNOWN for YAML streams.
        """
        for y in super().parse_part(part):
            yield ParseState.UNKNOWN, y[1]
|
||
|
||
def loads(s: Optional[Generator[str, None, None]] = None) -> Generator[Tuple[ParseState, dict], Optional[str], None]:
    """
    Create a primed streaming YAML loader.

    Annotation fixed: `chr` is a builtin function, not a type — the generator
    yields/accepts one-character strings (`str`).

    :param s: Optional generator of characters to seed the parser with
    :return: A generator accepting string tokens via send() and yielding parsed parts
    """
    return YamlParser()(s)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
openai==1.14.0 | ||
json-streamer==0.1.0 | ||
pydantic==2.6.4 | ||
docstring-parser==0.15 | ||
docstring-parser==0.15 | ||
PyYAML==6.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
from time import sleep | ||
|
||
from openai import AsyncOpenAI | ||
import asyncio | ||
|
||
from pydantic import BaseModel | ||
|
||
from typing import Optional | ||
from openai_streaming.struct import BaseHandler, process_struct_response, Terminate | ||
|
||
# Initialize OpenAI Client (API key is read from the environment; None if unset,
# in which case the SDK raises on first use)
client = AsyncOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
)
|
||
|
||
class Letter(BaseModel):
    """A short letter; only `title` is required while streaming fills the rest."""

    title: str
    to: Optional[str] = None
    content: Optional[str] = None
|
||
|
||
# Define handler | ||
class Handler(BaseHandler): | ||
def model(self): | ||
return Letter | ||
|
||
last_content = "" | ||
|
||
async def handle_partially_parsed(self, data: Letter) -> Optional[Terminate]: | ||
if data.to and data.to.lower() != "larry": | ||
print("You can only write a letter to Larry") | ||
return Terminate() | ||
if data.content: | ||
# here we mingle with the content a bit for the sake of the animation | ||
data.content = data.content[len(self.last_content):] | ||
self.last_content = self.last_content + data.content | ||
print(data.content, end="") | ||
sleep(0.1) | ||
|
||
async def terminated(self): | ||
print("Terminated") | ||
|
||
|
||
# Invoke Function in a streaming request | ||
async def main(): | ||
# Request and process stream | ||
resp = await client.chat.completions.create( | ||
model="gpt-3.5-turbo", | ||
messages=[{ | ||
"role": "system", | ||
"content": | ||
"You are a letter writer able to communicate only with VALID YAML. " | ||
"You must include only these fields: title, to, content." | ||
"ONLY write the YAML, without any other text or wrapping it in a code block." | ||
}, {"role": "user", "content": | ||
"Write a SHORT letter to my friend Larry congratulating him for his newborn baby Lily." | ||
"It should be funny and rhythmic. It MUST be very short!" | ||
}], | ||
stream=True | ||
) | ||
await process_struct_response(resp, Handler(), 'yaml') | ||
|
||
|
||
# Start the script asynchronously — guarded so importing this module
# (e.g. from tests) does not fire an API request.
if __name__ == "__main__":
    asyncio.run(main())