Skip to content

Commit

Permalink
Merge pull request #8 from AlmogBaku/fix/struct_with_generics
Browse files Browse the repository at this point in the history
fix: make struct processing fully defined with generics + add documen…
  • Loading branch information
AlmogBaku committed May 17, 2024
2 parents b9542ff + 9717483 commit f98f45f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 24 deletions.
56 changes: 56 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,62 @@ async def main():
asyncio.run(main())
```

## 🤓Streaming structured data (advanced usage)

The library also supports streaming structured data.
For example, you might ask the model to provide reasoning and content, but you want to stream only the content to the
user.

This is where the `process_struct_response()` function comes in handy.
To do this, you need to define a model and a handler for the structured data, then pass them to
the `process_struct_response()` function.

```python
class MathProblem(BaseModel):
steps: List[str]
answer: Optional[int] = None


# Define handler
class Handler(BaseHandler[MathProblem]):
async def handle_partially_parsed(self, data: MathProblem) -> Optional[Terminate]:
if len(data.steps) == 0 and data.answer:
return Terminate() # something is wrong here, so we immediately stop

if data.answer:
self.ws.send(data.answer) # show to the user with WebSocket

async def terminated(self):
self.ws.close()  # close the WebSocket


# Invoke OpenAI request
async def main():
resp = await client.chat.completions.create(
messages=[{
"role": "system",
"content":
"For every question asked, you must first state the steps, and then the answer."
"Your response should be in the following format: \n"
" steps: List[str]\n"
" answer: int\n"
"ONLY write the YAML, without any other text or wrapping it in a code block."
"YAML should be VALID, and strings must be in double quotes."
}, {"role": "user", "content": "1+3*2"}],
stream=True
)
await process_struct_response(resp, Handler(), 'yaml')


asyncio.run(main())
```

With this function, you can process and stream structured data, or even implement your own "tool use" mechanism with
streaming.

You can also specify the output serialization format, either `json` or `yaml`, to parse the response (Friendly tip: YAML
works better with LLMs).

# 🤔 What's the big deal? Why use this library?

The OpenAI Streaming API is robust but challenging to navigate. Using the `stream=True` flag, we get tokens as they are
Expand Down
30 changes: 18 additions & 12 deletions openai_streaming/struct/handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Protocol, Literal, AsyncGenerator, Optional, Type, TypeVar, Union, Dict, Any, Tuple
from typing import Protocol, Literal, AsyncGenerator, Optional, TypeVar, Union, Dict, Any, Tuple, get_args, \
runtime_checkable

from pydantic import BaseModel

Expand All @@ -13,17 +14,13 @@ class Terminate:
pass


@runtime_checkable
class BaseHandler(Protocol[TModel]):
"""
The base handler for the structured response from OpenAI.
"""
def model(self) -> Type[TModel]:
"""
The Pydantic Data Model that we parse
:return: type of the Pydantic model
"""
pass
:param TModel: The `BaseModel` to parse the structured response to
"""

async def handle_partially_parsed(self, data: TModel) -> Optional[Terminate]:
"""
Expand All @@ -48,9 +45,9 @@ class _ContentHandler:

def __init__(self, handler: BaseHandler, output_serialization: OutputSerialization = "yaml"):
self.handler = handler
if output_serialization == "json":
if output_serialization.lower() == "json":
self.parser = JsonParser()
elif output_serialization == "yaml":
elif output_serialization.lower() == "yaml":
self.parser = YamlParser()

async def handle_content(self, content: AsyncGenerator[str, None]):
Expand Down Expand Up @@ -95,7 +92,8 @@ async def _handle_parsed(self, part) -> Optional[Union[TModel, Terminate]]:
or `None` if the part is not valid
"""
try:
parsed = self.handler.model()(**part)
typ = get_args(type(self.handler).__orig_bases__[0])[0]
parsed = typ(**part)
except (TypeError, ValueError):
return

Expand All @@ -121,11 +119,19 @@ async def process_struct_response(
contains reasoning, and content - but we want to stream only the content to the user.
:param response: The response from OpenAI
:param handler: The handler for the response. It should be a subclass of `BaseHandler`
:param handler: The handler for the response. It should be a subclass of `BaseHandler[BaseModel]` with a generic
type provided
:param output_serialization: The output serialization of the response. It should be either "json" or "yaml"
:return: A tuple of the last parsed response, and a dictionary containing the OpenAI response
"""

if not issubclass(type(handler), BaseHandler):
raise ValueError("handler should be a subclass of BaseHandler")

tmodel = get_args(type(handler).__orig_bases__[0])[0]
if tmodel == TModel:
raise ValueError("handler should be a subclass of BaseHandler with a generic type")

handler = _ContentHandler(handler, output_serialization)
_, result = await process_response(response, handler.handle_content, self=handler)
if not handler.get_last_response():
Expand Down
5 changes: 1 addition & 4 deletions tests/example_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@ class MathProblem(BaseModel):


# Define handler
class Handler(BaseHandler):
def model(self):
return MathProblem

class Handler(BaseHandler[MathProblem]):
async def handle_partially_parsed(self, data: MathProblem) -> Optional[Terminate]:
if len(data.steps) == 0 and data.answer:
return Terminate()
Expand Down
10 changes: 2 additions & 8 deletions tests/test_with_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,15 @@ class MathProblem(BaseModel):


# Define handler
class Handler(BaseHandler):
def model(self):
return MathProblem

class Handler(BaseHandler[MathProblem]):
async def handle_partially_parsed(self, data: MathProblem) -> Optional[Terminate]:
pass

async def terminated(self):
pass


class Handler2(BaseHandler):
def model(self):
return MathProblem

class Handler2(BaseHandler[MathProblem]):
async def handle_partially_parsed(self, data: MathProblem) -> Optional[Terminate]:
return Terminate()

Expand Down

0 comments on commit f98f45f

Please sign in to comment.