Skip to content

Add an option to use Azure endpoints for the /completions & /search operations. #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 22, 2022
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
/public/dist
__pycache__
build
.ipynb_checkpoints
*.egg
.vscode/settings.json
.ipynb_checkpoints
4 changes: 3 additions & 1 deletion openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@

organization = os.environ.get("OPENAI_ORGANIZATION")
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
api_version = None
api_type = os.environ.get("OPENAI_API_TYPE", "open_ai")
api_version = '2021-11-01-preview' if api_type == "azure" else None
verify_ssl_certs = True # No effect. Certificates are always verified.
proxy = None
app_info = None
Expand All @@ -52,6 +53,7 @@
"Search",
"api_base",
"api_key",
"api_type",
"api_key_path",
"api_version",
"app_info",
Expand Down
10 changes: 7 additions & 3 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from email import header
import json
import platform
import threading
Expand All @@ -11,6 +12,7 @@
import openai
from openai import error, util, version
from openai.openai_response import OpenAIResponse
from openai.util import ApiType

TIMEOUT_SECS = 600
MAX_CONNECTION_RETRIES = 2
Expand Down Expand Up @@ -69,9 +71,10 @@ def parse_stream(rbody):


class APIRequestor:
def __init__(self, key=None, api_base=None, api_version=None, organization=None):
def __init__(self, key=None, api_base=None, api_type=None, api_version=None, organization=None):
self.api_base = api_base or openai.api_base
self.api_key = key or util.default_api_key()
self.api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
self.api_version = api_version or openai.api_version
self.organization = organization or openai.organization

Expand Down Expand Up @@ -192,13 +195,14 @@ def request_headers(
headers = {
"X-OpenAI-Client-User-Agent": json.dumps(ua),
"User-Agent": user_agent,
"Authorization": "Bearer %s" % (self.api_key,),
}

headers.update(util.api_key_to_header(self.api_type, self.api_key))

if self.organization:
headers["OpenAI-Organization"] = self.organization

if self.api_version is not None:
if self.api_version is not None and self.api_type == ApiType.OPEN_AI:
headers["OpenAI-Version"] = self.api_version
if request_id is not None:
headers["X-Request-Id"] = request_id
Expand Down
27 changes: 23 additions & 4 deletions openai/api_resources/abstract/api_resource.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from urllib.parse import quote_plus

from openai import api_requestor, error, util
import openai
from openai.openai_object import OpenAIObject
from openai.util import ApiType


class APIResource(OpenAIObject):
api_prefix = ""
azure_api_prefix = 'openai/deployments'

@classmethod
def retrieve(cls, id, api_key=None, request_id=None, **params):
Expand All @@ -32,7 +35,7 @@ def class_url(cls):
return "/%s/%ss" % (cls.api_prefix, base)
return "/%ss" % (base)

def instance_url(self):
def instance_url(self, operation=None):
id = self.get("id")

if not isinstance(id, str):
Expand All @@ -42,10 +45,26 @@ def instance_url(self):
" `unicode`)" % (type(self).__name__, id, type(id)),
"id",
)
api_version = self.api_version or openai.api_version

base = self.class_url()
extn = quote_plus(id)
return "%s/%s" % (base, extn)
if self.typed_api_type == ApiType.AZURE:
if not api_version:
raise error.InvalidRequestError("An API version is required for the Azure API type.")
if not operation:
raise error.InvalidRequestError(
"The request needs an operation (eg: 'search') for the Azure OpenAI API type."
)
extn = quote_plus(id)
return "/%s/%s/%s?api-version=%s" % (self.azure_api_prefix, extn, operation, api_version)

elif self.typed_api_type == ApiType.OPEN_AI:
base = self.class_url()
extn = quote_plus(id)
return "%s/%s" % (base, extn)

else:
raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)


# The `method_` and `url_` arguments are suffixed with an underscore to
# avoid conflicting with actual request parameters in `params`.
Expand Down
58 changes: 48 additions & 10 deletions openai/api_resources/abstract/engine_api_resource.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,60 @@
from pydoc import apropos
import time
from typing import Optional
from urllib.parse import quote_plus

import openai
from openai import api_requestor, error, util
from openai.api_resources.abstract.api_resource import APIResource
from openai.openai_response import OpenAIResponse
from openai.util import ApiType

MAX_TIMEOUT = 20


class EngineAPIResource(APIResource):
engine_required = True
plain_old_data = False
azure_api_prefix = 'openai/deployments'

def __init__(self, engine: Optional[str] = None, **kwargs):
super().__init__(engine=engine, **kwargs)

@classmethod
def class_url(cls, engine: Optional[str] = None):
def class_url(cls, engine: Optional[str] = None, api_type : Optional[str] = None, api_version: Optional[str] = None):
# Namespaces are separated in object names with periods (.) and in URLs
# with forward slashes (/), so replace the former with the latter.
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
if engine is None:
return "/%ss" % (base)
typed_api_type = ApiType.from_str(api_type) if api_type else ApiType.from_str(openai.api_type)
api_version = api_version or openai.api_version

if typed_api_type == ApiType.AZURE:
if not api_version:
raise error.InvalidRequestError("An API version is required for the Azure API type.")
if engine is None:
raise error.InvalidRequestError(
"You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
)
extn = quote_plus(engine)
return "/%s/%s/%ss?api-version=%s" % (cls.azure_api_prefix, extn, base, api_version)

elif typed_api_type == ApiType.OPEN_AI:
if engine is None:
return "/%ss" % (base)

extn = quote_plus(engine)
return "/engines/%s/%ss" % (extn, base)

else:
raise error.InvalidAPIType('Unsupported API type %s' % api_type)

extn = quote_plus(engine)
return "/engines/%s/%ss" % (extn, base)

@classmethod
def create(
cls,
api_key=None,
api_base=None,
api_type=None,
request_id=None,
api_version=None,
organization=None,
Expand All @@ -58,10 +81,11 @@ def create(
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base,
api_type=api_type,
api_version=api_version,
organization=organization,
)
url = cls.class_url(engine)
url = cls.class_url(engine, api_type, api_version)
response, _, api_key = requestor.request(
"post", url, params, stream=stream, request_id=request_id
)
Expand Down Expand Up @@ -103,14 +127,28 @@ def instance_url(self):
"id",
)

base = self.class_url(self.engine)
extn = quote_plus(id)
url = "%s/%s" % (base, extn)
params_connector = '?'
if self.typed_api_type == ApiType.AZURE:
api_version = self.api_version or openai.api_version
if not api_version:
raise error.InvalidRequestError("An API version is required for the Azure API type.")
extn = quote_plus(id)
base = self.OBJECT_NAME.replace(".", "/")
url = "/%s/%s/%ss/%s?api-version=%s" % (self.azure_api_prefix, self.engine, base, extn, api_version)
params_connector = '&'

elif self.typed_api_type == ApiType.OPEN_AI:
base = self.class_url(self.engine, self.api_type, self.api_version)
extn = quote_plus(id)
url = "%s/%s" % (base, extn)

else:
raise error.InvalidAPIType('Unsupported API type %s' % self.api_type)

timeout = self.get("timeout")
if timeout is not None:
timeout = quote_plus(str(timeout))
url += "?timeout={}".format(timeout)
url += params_connector + "timeout={}".format(timeout)
return url

def wait(self, timeout=None):
Expand Down
10 changes: 8 additions & 2 deletions openai/api_resources/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from openai import util
from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource
from openai.error import TryAgain
from openai.error import InvalidAPIType, TryAgain
from openai.util import ApiType


class Engine(ListableAPIResource, UpdateableAPIResource):
Expand All @@ -27,7 +28,12 @@ def generate(self, timeout=None, **params):
util.log_info("Waiting for model to warm up", error=e)

def search(self, **params):
return self.request("post", self.instance_url() + "/search", params)
if self.typed_api_type == ApiType.AZURE:
return self.request("post", self.instance_url("search"), params)
elif self.typed_api_type == ApiType.OPEN_AI:
return self.request("post", self.instance_url() + "/search", params)
else:
raise InvalidAPIType('Unsupported API type %s' % self.api_type)

def embeddings(self, **params):
warnings.warn(
Expand Down
3 changes: 3 additions & 0 deletions openai/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class RateLimitError(OpenAIError):
class ServiceUnavailableError(OpenAIError):
pass

class InvalidAPIType(OpenAIError):
pass


class SignatureVerificationError(OpenAIError):
def __init__(self, message, sig_header, http_body=None):
Expand Down
13 changes: 13 additions & 0 deletions openai/openai_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from copy import deepcopy
from typing import Optional

import openai
from openai import api_requestor, util
from openai.openai_response import OpenAIResponse
from openai.util import ApiType


class OpenAIObject(dict):
Expand All @@ -14,6 +16,7 @@ def __init__(
id=None,
api_key=None,
api_version=None,
api_type=None,
organization=None,
response_ms: Optional[int] = None,
api_base=None,
Expand All @@ -30,6 +33,7 @@ def __init__(

object.__setattr__(self, "api_key", api_key)
object.__setattr__(self, "api_version", api_version)
object.__setattr__(self, "api_type", api_type)
object.__setattr__(self, "organization", organization)
object.__setattr__(self, "api_base_override", api_base)
object.__setattr__(self, "engine", engine)
Expand Down Expand Up @@ -90,6 +94,7 @@ def __reduce__(self):
self.get("id", None),
self.api_key,
self.api_version,
self.api_type,
self.organization,
),
dict(self), # state
Expand Down Expand Up @@ -128,11 +133,13 @@ def refresh_from(
values,
api_key=None,
api_version=None,
api_type=None,
organization=None,
response_ms: Optional[int] = None,
):
self.api_key = api_key or getattr(values, "api_key", None)
self.api_version = api_version or getattr(values, "api_version", None)
self.api_type = api_type or getattr(values, "api_type", None)
self.organization = organization or getattr(values, "organization", None)
self._response_ms = response_ms or getattr(values, "_response_ms", None)

Expand Down Expand Up @@ -164,6 +171,7 @@ def request(
requestor = api_requestor.APIRequestor(
key=self.api_key,
api_base=self.api_base_override or self.api_base(),
api_type=self.api_type,
api_version=self.api_version,
organization=self.organization,
)
Expand Down Expand Up @@ -233,6 +241,10 @@ def to_dict_recursive(self):
def openai_id(self):
return self.id

@property
def typed_api_type(self):
return ApiType.from_str(self.api_type) if self.api_type else ApiType.from_str(openai.api_type)

# This class overrides __setitem__ to throw exceptions on inputs that it
# doesn't like. This can cause problems when we try to copy an object
# wholesale because some data that's returned from the API may not be valid
Expand All @@ -243,6 +255,7 @@ def __copy__(self):
self.get("id"),
self.api_key,
api_version=self.api_version,
api_type=self.api_type,
organization=self.organization,
)

Expand Down
2 changes: 1 addition & 1 deletion openai/openai_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def organization(self) -> Optional[str]:
@property
def response_ms(self) -> Optional[int]:
h = self._headers.get("Openai-Processing-Ms")
return None if h is None else int(h)
return None if h is None else round(float(h))
27 changes: 25 additions & 2 deletions openai/tests/test_api_requestor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json

import pytest
import requests
from pytest_mock import MockerFixture

from openai import Model
from openai.api_requestor import APIRequestor


@pytest.mark.requestor
def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
# Fake out 'requests' and confirm that the X-Request-Id header is set.

Expand All @@ -25,3 +26,25 @@ def fake_request(self, *args, **kwargs):
Model.retrieve("xxx", request_id=fake_request_id) # arbitrary API resource
got_request_id = got_headers.get("X-Request-Id")
assert got_request_id == fake_request_id

@pytest.mark.requestor
def test_requestor_open_ai_headers() -> None:
api_requestor = APIRequestor(key="test_key", api_type="open_ai")
headers = {"Test_Header": "Unit_Test_Header"}
headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
print(headers)
assert "Test_Header"in headers
assert headers["Test_Header"] == "Unit_Test_Header"
assert "Authorization"in headers
assert headers["Authorization"] == "Bearer test_key"

@pytest.mark.requestor
def test_requestor_azure_headers() -> None:
api_requestor = APIRequestor(key="test_key", api_type="azure")
headers = {"Test_Header": "Unit_Test_Header"}
headers = api_requestor.request_headers(method="get", extra=headers, request_id="test_id")
print(headers)
assert "Test_Header"in headers
assert headers["Test_Header"] == "Unit_Test_Header"
assert "api-key"in headers
assert headers["api-key"] == "test_key"
Loading