XinXiKuaiBaoYuan/django-backend/venv/lib/python3.9/site-packages/openai/_streaming.py

243 lines
6.9 KiB
Python
Raw Normal View History

# Note: initially copied from https://github.com/florimondmanca/httpx-sse/blob/master/src/httpx_sse/_decoders.py
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
from typing_extensions import override
import httpx
from ._types import ResponseT
from ._utils import is_mapping
from ._exceptions import APIError
if TYPE_CHECKING:
from ._base_client import SyncAPIClient, AsyncAPIClient
class Stream(Generic[ResponseT]):
    """Provides the core interface to iterate over a synchronous stream response.

    The instance is both an iterable and an iterator: ``__iter__`` and
    ``__next__`` both draw from one shared generator created in ``__init__``,
    so items consumed through one are not replayed by the other.
    """

    response: httpx.Response

    def __init__(
        self,
        *,
        cast_to: type[ResponseT],
        response: httpx.Response,
        client: SyncAPIClient,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = SSEDecoder()
        # Calling the generator function does not start reading the response;
        # nothing is pulled from the network until the first item is requested.
        self._iterator = self.__stream__()

    def __next__(self) -> ResponseT:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[ResponseT]:
        for item in self._iterator:
            yield item

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        """Yield the server-sent events decoded from the response's lines."""
        yield from self._decoder.iter(self.response.iter_lines())

    def __stream__(self) -> Iterator[ResponseT]:
        """Yield one processed item per unnamed server-sent event.

        Iteration stops at the first event whose data starts with ``[DONE]``.
        Events that carry an ``event`` name are skipped.

        Raises:
            APIError: if an event's JSON payload is a mapping with a truthy
                ``"error"`` entry.
        """
        cast_to = self._cast_to
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            for sse in iterator:
                if sse.data.startswith("[DONE]"):
                    break

                if sse.event is None:
                    data = sse.json()
                    if is_mapping(data) and data.get("error"):
                        raise APIError(
                            message="An error ocurred during streaming",
                            request=self.response.request,
                            body=data["error"],
                        )

                    yield process_data(data=data, cast_to=cast_to, response=response)

            # Ensure the entire stream is consumed
            for _sse in iterator:
                ...
        finally:
            # Release the connection even when we leave early: an APIError or
            # decode error above, or the consumer abandoning/closing the
            # generator before the stream was drained.
            response.close()
class AsyncStream(Generic[ResponseT]):
    """Provides the core interface to iterate over an asynchronous stream response.

    The instance is both an async iterable and an async iterator:
    ``__aiter__`` and ``__anext__`` both draw from one shared async generator
    created in ``__init__``, so items consumed through one are not replayed by
    the other.
    """

    response: httpx.Response

    def __init__(
        self,
        *,
        cast_to: type[ResponseT],
        response: httpx.Response,
        client: AsyncAPIClient,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = SSEDecoder()
        # Calling the async generator function does not start reading the
        # response; nothing is awaited until the first item is requested.
        self._iterator = self.__stream__()

    async def __anext__(self) -> ResponseT:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[ResponseT]:
        async for item in self._iterator:
            yield item

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        """Yield the server-sent events decoded from the response's lines."""
        async for sse in self._decoder.aiter(self.response.aiter_lines()):
            yield sse

    async def __stream__(self) -> AsyncIterator[ResponseT]:
        """Yield one processed item per unnamed server-sent event.

        Iteration stops at the first event whose data starts with ``[DONE]``.
        Events that carry an ``event`` name are skipped.

        Raises:
            APIError: if an event's JSON payload is a mapping with a truthy
                ``"error"`` entry.
        """
        cast_to = self._cast_to
        response = self.response
        process_data = self._client._process_response_data
        iterator = self._iter_events()

        try:
            async for sse in iterator:
                if sse.data.startswith("[DONE]"):
                    break

                if sse.event is None:
                    data = sse.json()
                    if is_mapping(data) and data.get("error"):
                        raise APIError(
                            message="An error ocurred during streaming",
                            request=self.response.request,
                            body=data["error"],
                        )

                    yield process_data(data=data, cast_to=cast_to, response=response)

            # Ensure the entire stream is consumed
            async for _sse in iterator:
                ...
        finally:
            # Release the connection even when we leave early: an APIError or
            # decode error above, or the consumer abandoning/closing the
            # generator before the stream was drained.
            await response.aclose()
class ServerSentEvent:
    """A single decoded server-sent event, exposed through read-only properties."""

    def __init__(
        self,
        *,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None,
    ) -> None:
        # A falsy event name (None or "") is stored as None.
        self._event = event if event else None
        # A missing payload is stored as "" so that `data` is always a str.
        self._data = "" if data is None else data
        self._id = id
        self._retry = retry

    @property
    def event(self) -> str | None:
        """The event name, or None when no non-empty name was given."""
        return self._event

    @property
    def id(self) -> str | None:
        """The event id passed to the constructor."""
        return self._id

    @property
    def retry(self) -> int | None:
        """The retry value passed to the constructor."""
        return self._retry

    @property
    def data(self) -> str:
        """The raw event payload; an empty string when none was given."""
        return self._data

    def json(self) -> Any:
        """Parse the payload as JSON and return the resulting object."""
        payload = self.data
        return json.loads(payload)

    @override
    def __repr__(self) -> str:
        return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class SSEDecoder:
    """Incremental, line-oriented decoder for a server-sent-events stream.

    Field lines (``event``, ``data``, ``id``, ``retry``) are accumulated on the
    instance; a blank line turns the accumulated fields into one
    ``ServerSentEvent``. The last seen event id is kept across events.
    """

    _data: list[str]
    _event: str | None
    _retry: int | None
    _last_event_id: str | None

    def __init__(self) -> None:
        self._event = None
        self._data = []
        self._last_event_id = None
        self._retry = None

    def iter(self, iterator: Iterator[str]) -> Iterator[ServerSentEvent]:
        """Given an iterator that yields lines, iterate over it & yield every event encountered"""
        for line in iterator:
            # SSE lines may be terminated by LF, CR or CRLF. Strip both
            # characters so a CRLF-terminated blank line is still recognised
            # as the end of an event and values do not keep a stray "\r".
            line = line.rstrip("\r\n")
            sse = self.decode(line)
            if sse is not None:
                yield sse

    async def aiter(self, iterator: AsyncIterator[str]) -> AsyncIterator[ServerSentEvent]:
        """Given an async iterator that yields lines, iterate over it & yield every event encountered"""
        async for line in iterator:
            # Same line-terminator handling as `iter`.
            line = line.rstrip("\r\n")
            sse = self.decode(line)
            if sse is not None:
                yield sse

    def decode(self, line: str) -> ServerSentEvent | None:
        """Feed one line (without its terminator) into the decoder.

        Returns:
            A ``ServerSentEvent`` when ``line`` is blank and at least one field
            has been accumulated (or an event id is remembered); otherwise
            ``None``.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation  # noqa: E501

        if not line:
            # Blank line: dispatch what has been collected, if anything.
            if not self._event and not self._data and not self._last_event_id and self._retry is None:
                return None

            sse = ServerSentEvent(
                event=self._event,
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = None
            self._data = []
            self._retry = None

            return sse

        if line.startswith(":"):
            # Comment line.
            return None

        fieldname, _, value = line.partition(":")

        # Only a single leading space after the colon is part of the syntax.
        if value.startswith(" "):
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Ids containing a NUL character are ignored.
            if "\0" not in value:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                # Non-integer retry values are ignored.
                pass
        else:
            pass  # Field is ignored.

        return None