-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Add async support #146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add async support #146
Changes from 3 commits
8e0b07e
1be13b6
e24131f
0ec3165
dbdd4dd
51da1d9
b2f365a
08a155b
0be0b30
b7235dd
421f02b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
from typing import Dict, Iterator, Optional, Tuple, Union, overload | ||
from urllib.parse import urlencode, urlsplit, urlunsplit | ||
|
||
import aiohttp | ||
import requests | ||
from typing_extensions import Literal | ||
|
||
|
@@ -44,6 +45,19 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]: | |
) | ||
|
||
|
||
def _aiohttp_proxies_arg(proxy) -> Optional[str]: | ||
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request.""" | ||
if proxy is None: | ||
return None | ||
elif isinstance(proxy, str): | ||
return proxy | ||
elif isinstance(proxy, dict): | ||
return proxy["https"] if "https" in proxy else proxy["http"] | ||
else: | ||
raise ValueError( | ||
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys." | ||
) | ||
|
||
def _make_session() -> requests.Session: | ||
if not openai.verify_ssl_certs: | ||
warnings.warn("verify_ssl_certs is ignored; openai always verifies.") | ||
|
@@ -72,6 +86,18 @@ def parse_stream(rbody): | |
yield line | ||
|
||
|
||
async def parse_stream_async(rbody: "aiohttp.StreamReader"):
    """Async analogue of parse_stream: yield SSE payloads from an aiohttp body.

    Lines iterated from an aiohttp StreamReader keep their trailing newline,
    so the "[DONE]" sentinel and the "data: " prefix are checked against a
    newline-stripped copy; yielded payloads are newline-stripped to match
    the sync parse_stream output.
    """
    async for raw in rbody:
        if not raw:
            continue
        line = raw.decode("utf-8") if hasattr(raw, "decode") else raw
        # aiohttp yields lines with the trailing "\n" attached; the original
        # exact comparison against b"data: [DONE]" could never match, so the
        # sentinel leaked through as a data line.
        line = line.rstrip("\r\n")
        if line == "data: [DONE]":
            continue
        if line.startswith("data: "):
            line = line[len("data: ") :]
        if line:
            yield line
Andrew-Chen-Wang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
class APIRequestor: | ||
def __init__( | ||
self, | ||
|
@@ -181,6 +207,29 @@ def request( | |
resp, got_stream = self._interpret_response(result, stream) | ||
return resp, got_stream, self.api_key | ||
|
||
async def arequest(
    self,
    method,
    url,
    params=None,
    headers=None,
    files=None,
    stream: bool = False,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
    """Async counterpart of request(): perform the call via arequest_raw.

    Returns a (response(s), is_stream, api_key) triple, mirroring the
    sync request() return shape visible above.

    NOTE(review): unlike request_raw/arequest_raw, the parameters here are
    not keyword-only (no `*` marker) — confirm that is intentional.
    """
    result = await self.arequest_raw(
        method.lower(),
        url,
        params=params,
        supplied_headers=headers,
        files=files,
        request_id=request_id,
        request_timeout=request_timeout,
    )
    resp, got_stream = await self._interpret_async_response(result, stream)
    return resp, got_stream, self.api_key
|
||
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False): | ||
try: | ||
error_data = resp["error"] | ||
|
@@ -310,18 +359,9 @@ def _validate_headers( | |
|
||
return headers | ||
|
||
def request_raw( | ||
self, | ||
method, | ||
url, | ||
*, | ||
params=None, | ||
supplied_headers: Dict[str, str] = None, | ||
files=None, | ||
stream: bool = False, | ||
request_id: Optional[str] = None, | ||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None, | ||
) -> requests.Response: | ||
def _prepare_request_raw( | ||
self, url, supplied_headers, method, params, files, request_id: Optional[str], | ||
) -> Tuple[str, Dict[str, str], Optional[bytes]]: | ||
abs_url = "%s%s" % (self.api_base, url) | ||
headers = self._validate_headers(supplied_headers) | ||
|
||
|
@@ -350,6 +390,24 @@ def request_raw( | |
util.log_info("Request to OpenAI API", method=method, path=abs_url) | ||
util.log_debug("Post details", data=data, api_version=self.api_version) | ||
|
||
return abs_url, headers, data | ||
|
||
def request_raw( | ||
self, | ||
method, | ||
url, | ||
*, | ||
params=None, | ||
supplied_headers: Dict[str, str] = None, | ||
files=None, | ||
stream: bool = False, | ||
request_id: Optional[str] = None, | ||
request_timeout: Optional[Union[float, Tuple[float, float]]] = None, | ||
) -> requests.Response: | ||
abs_url, headers, data = self._prepare_request_raw( | ||
url, supplied_headers, method, params, files, request_id | ||
) | ||
|
||
if not hasattr(_thread_context, "session"): | ||
_thread_context.session = _make_session() | ||
try: | ||
|
@@ -380,6 +438,74 @@ def request_raw( | |
) | ||
return result | ||
|
||
async def arequest_raw(
    self,
    method,
    url,
    *,
    params=None,
    supplied_headers: Dict[str, str] = None,
    files=None,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> aiohttp.ClientResponse:
    """Async counterpart of request_raw(): perform the HTTP call via aiohttp.

    Uses the session the user registered in openai.aiosession when present;
    otherwise opens a throwaway aiohttp.ClientSession for this one call.
    Returns the raw aiohttp.ClientResponse for the caller to interpret.
    """
    abs_url, headers, data = self._prepare_request_raw(
        url, supplied_headers, method, params, files, request_id
    )

    # NOTE(review): the annotation allows a (connect, read) tuple, but
    # ClientTimeout(total=...) expects a single number — confirm that a
    # tuple is handled (or rejected) before it reaches aiohttp.
    timeout = aiohttp.ClientTimeout(
        total=request_timeout if request_timeout else TIMEOUT_SECS
    )
    user_set_session = openai.aiosession.get()

    if files:
        # Reuse requests' multipart encoder to build the body, then set the
        # matching Content-Type (which carries the boundary) for aiohttp.
        data, content_type = requests.models.RequestEncodingMixin._encode_files(
            files, data
        )
        headers["Content-Type"] = content_type

    result = None
    try:
        if user_set_session:
            result = await user_set_session.request(
                method,
                abs_url,
                headers=headers,
                data=data,
                proxy=_aiohttp_proxies_arg(openai.proxy),
                timeout=timeout,
            )
        else:
            # NOTE(review): this session is closed as soon as the `async
            # with` block exits — i.e. before the caller reads the body of
            # the returned response. Verify the response is still readable
            # (and streamable) after the session is gone.
            async with aiohttp.ClientSession() as session:
                result = await session.request(
                    method,
                    abs_url,
                    headers=headers,
                    data=data,
                    proxy=_aiohttp_proxies_arg(openai.proxy),
                    timeout=timeout,
                )
        util.log_info(
            "OpenAI API response",
            path=abs_url,
            response_code=result.status,
            processing_ms=result.headers.get("OpenAI-Processing-Ms"),
            request_id=result.headers.get("X-Request-Id"),
        )
        # Don't read the whole stream for debug logging unless necessary.
        if openai.log == "debug":
            util.log_debug(
                "API response body", body=result.content, headers=result.headers
            )
        return result
    except aiohttp.ServerTimeoutError as e:
        raise error.Timeout("Request timed out") from e
    except aiohttp.ClientError as e:
        raise error.APIConnectionError("Error communicating with OpenAI") from e
    finally:
        # NOTE(review): `finally` also runs on the success path (after the
        # `return` above), so this closes the response before the caller
        # can read it — confirm this close should be error-path only.
        if result and not result.closed:
            result.close()
Andrew-Chen-Wang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _interpret_response( | ||
self, result: requests.Response, stream: bool | ||
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]: | ||
|
@@ -399,6 +525,29 @@ def _interpret_response( | |
False, | ||
) | ||
|
||
async def _interpret_async_response(
    self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
    """Returns the response(s) and a bool indicating whether it is a stream.

    For event-stream responses this yields OpenAIResponse objects lazily;
    otherwise the whole body is read once and interpreted as a single
    OpenAIResponse.

    Raises:
        error.Timeout: if reading the body times out.
        error.APIConnectionError: on any other aiohttp client error.
    """
    if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
        return (
            self._interpret_response_line(
                line, result.status, result.headers, stream=True
            )
            async for line in parse_stream_async(result.content)
        ), True
    else:
        try:
            # Read the body exactly once. The original printed and
            # swallowed ClientError here, then called result.read() a
            # second time below — re-raising the same error unhandled.
            rbody = await result.read()
        except aiohttp.ServerTimeoutError as e:
            raise error.Timeout("Request timed out") from e
        except aiohttp.ClientError as e:
            # Consistent with arequest_raw's error translation.
            raise error.APIConnectionError(
                "Error communicating with OpenAI"
            ) from e
        return (
            self._interpret_response_line(
                rbody, result.status, result.headers, stream=False
            ),
            False,
        )
|
||
def _interpret_response_line( | ||
self, rbody, rcode, rheaders, stream: bool | ||
) -> OpenAIResponse: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At first I thought this was just a global, but after learning more about
`Context` and `ContextVar`
I realized that it enables the thing I was scared this would prevent: using different sessions, by making a copy of the current context, setting the new session and running the code through the new context, thus having the openai module pick up the new session without impacting other code that will still get the other one.

This might be worth putting as an example in the README, since it was not clear to me that this was possible without first reading more about
`ContextVar`.
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is already addressed in the README I believe in the async section. Please add suggestions. I've also added a comment in the source code itself.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's probably fine, I agree. I think it's mostly a matter of knowing how
`ContextVar`
works, which I expect people dealing with async code to do (and that I didn't initially :))