Skip to content

Add async support #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,32 @@ image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting",

```

## Async API

Async support is available in the API by prepending `a` to a network-bound method:

```python
import openai
openai.api_key = "sk-..." # supply your API key however you choose

async def create_completion():
completion_resp = await openai.Completion.acreate(prompt="This is a test", engine="davinci")

```

To make async requests more efficient, you can pass in your own
`aiohttp.ClientSession`, but you must close the client session yourself at the end
of your program/event loop:

```python
import openai
from aiohttp import ClientSession

openai.aiosession.set(ClientSession())
# At the end of your program, close the http session
await openai.aiosession.get().close()
```

See the [usage guide](https://beta.openai.com/docs/guides/images) for more details.

## Requirements
Expand Down
10 changes: 9 additions & 1 deletion openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# Originally forked from the MIT-licensed Stripe Python bindings.

import os
from typing import Optional
from contextvars import ContextVar
from typing import Optional, TYPE_CHECKING

from openai.api_resources import (
Answer,
Expand All @@ -24,6 +25,9 @@
)
from openai.error import APIError, InvalidRequestError, OpenAIError

if TYPE_CHECKING:
from aiohttp import ClientSession

api_key = os.environ.get("OPENAI_API_KEY")
# Path of a file with an API key, whose contents can change. Supersedes
# `api_key` if set. The main use case is volume-mounted Kubernetes secrets,
Expand All @@ -44,6 +48,10 @@
debug = False
log = None # Set to either 'debug' or 'info', controls console logging

# Per-context aiohttp session used by the async (`a`-prefixed) API methods.
# Stored in a ContextVar rather than a plain module global so concurrent code
# can use different sessions: copy the current Context, set a new session in
# the copy, and run coroutines inside it without affecting other contexts.
# Defaults to None, in which case each async request creates and closes its
# own ClientSession.
aiosession: ContextVar[Optional["ClientSession"]] = ContextVar(
    "aiohttp-session", default=None
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At first I thought this was just a global but after learning more about Context and ContextVar I realized that it enabled the thing I was scared this would prevent: using different sessions, by making a copy of the current context, setting the new session and running the code through the new context, thus having the openai module picking the new context without impacting other code that will still get the other one.

This might be worth putting as an example in the readme since it was not clear to me that this was possible without first reading more about ContextVar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already addressed in the README I believe in the async section. Please add suggestions. I've also added a comment in the source code itself.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably fine I agree. I think it's mostly a matter of knowing how ContextVar works which I expect people dealing to async code to do (and that I didn't initially :))


__all__ = [
"APIError",
"Answer",
Expand Down
173 changes: 161 additions & 12 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Iterator, Optional, Tuple, Union, overload
from urllib.parse import urlencode, urlsplit, urlunsplit

import aiohttp
import requests
from typing_extensions import Literal

Expand Down Expand Up @@ -44,6 +45,19 @@ def _requests_proxies_arg(proxy) -> Optional[Dict[str, str]]:
)


def _aiohttp_proxies_arg(proxy) -> Optional[str]:
"""Returns a value suitable for the 'proxies' argument to 'aiohttp.ClientSession.request."""
if proxy is None:
return None
elif isinstance(proxy, str):
return proxy
elif isinstance(proxy, dict):
return proxy["https"] if "https" in proxy else proxy["http"]
else:
raise ValueError(
"'openai.proxy' must be specified as either a string URL or a dict with string URL under the https and/or http keys."
)

def _make_session() -> requests.Session:
if not openai.verify_ssl_certs:
warnings.warn("verify_ssl_certs is ignored; openai always verifies.")
Expand Down Expand Up @@ -72,6 +86,18 @@ def parse_stream(rbody):
yield line


async def parse_stream_async(rbody: "aiohttp.StreamReader"):
    """Async counterpart of `parse_stream`: yield SSE payloads from the body.

    aiohttp's line iteration keeps the trailing newline (requests' iter_lines
    strips it), so lines are stripped first; otherwise the b"data: [DONE]"
    sentinel never matches and yielded payloads carry a trailing "\\n".
    """
    async for line in rbody:
        line = line.strip()
        if line:
            if line == b"data: [DONE]":
                # Terminal sentinel of the event stream; not a data payload.
                continue
            if hasattr(line, "decode"):
                line = line.decode("utf-8")
            if line.startswith("data: "):
                line = line[len("data: ") :]
            yield line


class APIRequestor:
def __init__(
self,
Expand Down Expand Up @@ -181,6 +207,29 @@ def request(
resp, got_stream = self._interpret_response(result, stream)
return resp, got_stream, self.api_key

async def arequest(
    self,
    method,
    url,
    params=None,
    headers=None,
    files=None,
    stream: bool = False,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]:
    """Async counterpart of `request`: issue the HTTP call and parse the result.

    Returns a tuple of (parsed response(s), whether it is a stream, api key used).
    """
    raw_response = await self.arequest_raw(
        method.lower(),
        url,
        params=params,
        supplied_headers=headers,
        files=files,
        request_id=request_id,
        request_timeout=request_timeout,
    )
    parsed, got_stream = await self._interpret_async_response(raw_response, stream)
    return parsed, got_stream, self.api_key

def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
try:
error_data = resp["error"]
Expand Down Expand Up @@ -310,18 +359,9 @@ def _validate_headers(

return headers

def request_raw(
self,
method,
url,
*,
params=None,
supplied_headers: Dict[str, str] = None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
def _prepare_request_raw(
self, url, supplied_headers, method, params, files, request_id: Optional[str],
) -> Tuple[str, Dict[str, str], Optional[bytes]]:
abs_url = "%s%s" % (self.api_base, url)
headers = self._validate_headers(supplied_headers)

Expand Down Expand Up @@ -350,6 +390,24 @@ def request_raw(
util.log_info("Request to OpenAI API", method=method, path=abs_url)
util.log_debug("Post details", data=data, api_version=self.api_version)

return abs_url, headers, data

def request_raw(
self,
method,
url,
*,
params=None,
supplied_headers: Dict[str, str] = None,
files=None,
stream: bool = False,
request_id: Optional[str] = None,
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
abs_url, headers, data = self._prepare_request_raw(
url, supplied_headers, method, params, files, request_id
)

if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
try:
Expand Down Expand Up @@ -380,6 +438,74 @@ def request_raw(
)
return result

async def arequest_raw(
    self,
    method,
    url,
    *,
    params=None,
    supplied_headers: Optional[Dict[str, str]] = None,
    files=None,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> aiohttp.ClientResponse:
    """Issue one async HTTP request and return the raw aiohttp response.

    Uses the session the caller set in `openai.aiosession` when present;
    otherwise a throwaway ClientSession is created for this single request.
    Raises `error.Timeout` / `error.APIConnectionError` on transport failures.
    """
    abs_url, headers, data = self._prepare_request_raw(
        url, supplied_headers, method, params, files, request_id
    )

    # NOTE(review): the type hint allows a (connect, read) tuple, but
    # aiohttp's ClientTimeout(total=...) expects a single number — confirm
    # how tuple timeouts should map onto ClientTimeout fields.
    timeout = aiohttp.ClientTimeout(
        total=request_timeout if request_timeout else TIMEOUT_SECS
    )
    user_set_session = openai.aiosession.get()

    if files:
        # TODO: aiohttp has its own MultipartWriter, but reusing requests'
        # private multipart encoder keeps the sync and async code paths
        # byte-for-byte compatible for now.
        data, content_type = requests.models.RequestEncodingMixin._encode_files(
            files, data
        )
        headers["Content-Type"] = content_type

    request_kwargs = dict(
        headers=headers,
        data=data,
        proxy=_aiohttp_proxies_arg(openai.proxy),
        timeout=timeout,
    )

    try:
        if user_set_session:
            result = await user_set_session.request(method, abs_url, **request_kwargs)
        else:
            async with aiohttp.ClientSession() as session:
                result = await session.request(method, abs_url, **request_kwargs)
                # The throwaway session (and its connection) closes as soon
                # as this block exits, so buffer the payload while it is
                # still open; ClientResponse caches the body, so later reads
                # succeed. NOTE(review): this buffers the full body, so true
                # incremental streaming requires a user-set session via
                # openai.aiosession.
                await result.read()
    except aiohttp.ServerTimeoutError as e:
        raise error.Timeout("Request timed out") from e
    except aiohttp.ClientError as e:
        raise error.APIConnectionError("Error communicating with OpenAI") from e

    util.log_info(
        "OpenAI API response",
        path=abs_url,
        response_code=result.status,
        processing_ms=result.headers.get("OpenAI-Processing-Ms"),
        request_id=result.headers.get("X-Request-Id"),
    )
    # Don't read the whole stream for debug logging unless necessary.
    if openai.log == "debug":
        util.log_debug(
            "API response body", body=result.content, headers=result.headers
        )
    # Do NOT close `result` here: the caller still has to read the body. The
    # previous `finally: result.close()` released the response on the success
    # path before it had been consumed.
    return result

def _interpret_response(
self, result: requests.Response, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
Expand All @@ -399,6 +525,29 @@ def _interpret_response(
False,
)

async def _interpret_async_response(
    self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool]:
    """Returns the response(s) and a bool indicating whether it is a stream."""
    if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
        # Server-sent events: hand back an async generator of parsed lines.
        return (
            self._interpret_response_line(
                line, result.status, result.headers, stream=True
            )
            async for line in parse_stream_async(result.content)
        ), True
    else:
        # Read the body exactly once. The original read it twice and only
        # `print`ed the first failure before crashing on the second read.
        try:
            rbody = await result.read()
        except aiohttp.ClientError as e:
            util.log_info(
                "Could not read response body", body=result.content, error=str(e)
            )
            raise
        return (
            self._interpret_response_line(
                rbody, result.status, result.headers, stream=False
            ),
            False,
        )

def _interpret_response_line(
self, rbody, rcode, rheaders, stream: bool
) -> OpenAIResponse:
Expand Down
43 changes: 43 additions & 0 deletions openai/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ def retrieve(
instance.refresh(request_id=request_id, request_timeout=request_timeout)
return instance

@classmethod
def aretrieve(
    cls, id, api_key=None, request_id=None, request_timeout=None, **params
):
    # Deliberately not a coroutine itself: it hands back the awaitable from
    # `arefresh`, so `await Resource.aretrieve(...)` mirrors the sync
    # `retrieve` call shape.
    resource = cls(id, api_key, **params)
    return resource.arefresh(
        request_id=request_id, request_timeout=request_timeout
    )

def refresh(self, request_id=None, request_timeout=None):
self.refresh_from(
self.request(
Expand All @@ -31,6 +38,17 @@ def refresh(self, request_id=None, request_timeout=None):
)
return self

async def arefresh(self, request_id=None, request_timeout=None):
    """Async variant of `refresh`: re-fetch this resource and update it in place."""
    refreshed = await self.arequest(
        "get",
        self.instance_url(operation="refresh"),
        request_id=request_id,
        request_timeout=request_timeout,
    )
    self.refresh_from(refreshed)
    return self

@classmethod
def class_url(cls):
if cls == APIResource:
Expand Down Expand Up @@ -116,6 +134,31 @@ def _static_request(
response, api_key, api_version, organization
)

@classmethod
async def _astatic_request(
    cls,
    method_,
    url_,
    api_key=None,
    api_base=None,
    api_type=None,
    request_id=None,
    api_version=None,
    organization=None,
    **params,
):
    """Async variant of `_static_request`: build a requestor, issue the call,
    and return only the parsed response (stream flag and key are discarded)."""
    requestor = api_requestor.APIRequestor(
        api_key,
        api_version=api_version,
        organization=organization,
        api_base=api_base,
        api_type=api_type,
    )
    resp, _got_stream, _used_key = await requestor.arequest(
        method_, url_, params, request_id=request_id
    )
    return resp

@classmethod
def _get_api_type_and_version(
cls, api_type: Optional[str] = None, api_version: Optional[str] = None
Expand Down
Loading