Skip to content

fix(azure): azure_deployment use with realtime + non-deployment-based APIs #2154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 28, 2025
67 changes: 56 additions & 11 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def __init__(self) -> None:


class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
_azure_endpoint: httpx.URL | None
_azure_deployment: str | None

@override
def _build_request(
self,
Expand All @@ -58,11 +61,29 @@ def _build_request(
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and not "/deployments" in str(self.base_url):
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"

return super()._build_request(options, retries_taken=retries_taken)

@override
def _prepare_url(self, url: str) -> httpx.URL:
    """Resolve `url` for a client configured with both an Azure endpoint and a deployment.

    When either the endpoint or the deployment is unset, or when `url` is one
    of the `_deployments_endpoints`, resolution is left to the parent class.
    Any other relative `url` is attached to the Azure endpoint under
    `/openai/`; a non-relative `url` is returned as parsed.
    """
    endpoint = self._azure_endpoint
    if not self._azure_deployment or not endpoint or url in _deployments_endpoints:
        return super()._prepare_url(url)

    requested = httpx.URL(url)
    if not requested.is_relative_url:
        return requested

    # Normalise the slashes on both sides so exactly one "/openai/" segment joins them.
    endpoint_path = endpoint.raw_path.rstrip(b"/")
    requested_path = requested.raw_path.lstrip(b"/")
    return endpoint.copy_with(raw_path=endpoint_path + b"/openai/" + requested_path)


class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
@overload
Expand Down Expand Up @@ -160,8 +181,8 @@ def __init__(

azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.

azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
Expand Down Expand Up @@ -224,6 +245,8 @@ def __init__(
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
Expand Down Expand Up @@ -307,20 +330,30 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:

return options

# NOTE(review): this span reads like a rendered diff hunk whose +/- markers were
# lost: the two consecutive signatures, the two "deployment" entries and the
# early `return query, auth_headers` look like removed/added pairs. Presumably
# the first line of each pair is the removed one -- confirm against the change.
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
# Stays empty when neither branch below produces a credential.
auth_headers = {}
query = {
# Caller-supplied entries are spread first, so the keys listed after them win on a clash.
**extra_query,
"api-version": self._api_version,
"deployment": model,
# Falls back to `model` when self._azure_deployment is falsy.
"deployment": self._azure_deployment or model,
}
# "<missing API key>" is presumably a placeholder assigned elsewhere when no key is configured -- verify.
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
# No usable API key: use the value from self._get_azure_ad_token() as a Bearer token, if it is truthy.
token = self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers

# An explicitly configured websocket_base_url gets "/realtime" appended to its raw path.
# NOTE(review): httpx documents URL.raw_path as path plus query string, so a base URL
# that carries a query would get "/realtime" appended after the query -- confirm whether that case matters.
if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
# Otherwise resolve "/realtime" through self._prepare_url and switch the scheme to "wss".
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")

# Attach the query entries as URL parameters and hand back the URL with the auth headers.
url = realtime_url.copy_with(params={**query})
return url, auth_headers


class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
Expand Down Expand Up @@ -422,8 +455,8 @@ def __init__(

azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.

azure_deployment: A model deployment, if given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints. Not supported with Assistants APIs.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
Expand Down Expand Up @@ -486,6 +519,8 @@ def __init__(
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None

@override
def copy(
Expand Down Expand Up @@ -571,17 +606,27 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp

return options

# NOTE(review): this span reads like a rendered diff hunk whose +/- markers were
# lost: the two consecutive signatures, the two "deployment" entries and the
# early `return query, auth_headers` look like removed/added pairs. Presumably
# the first line of each pair is the removed one -- confirm against the change.
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[Query, dict[str, str]]:
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
# Stays empty when neither branch below produces a credential.
auth_headers = {}
query = {
# Caller-supplied entries are spread first, so the keys listed after them win on a clash.
**extra_query,
"api-version": self._api_version,
"deployment": model,
# Falls back to `model` when self._azure_deployment is falsy.
"deployment": self._azure_deployment or model,
}
# "<missing API key>" is presumably a placeholder assigned elsewhere when no key is configured -- verify.
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
# No usable API key: await self._get_azure_ad_token() and use the result as a Bearer token, if it is truthy.
token = await self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
return query, auth_headers

# An explicitly configured websocket_base_url gets "/realtime" appended to its raw path.
# NOTE(review): httpx documents URL.raw_path as path plus query string, so a base URL
# that carries a query would get "/realtime" appended after the query -- confirm whether that case matters.
if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
# Otherwise resolve "/realtime" through self._prepare_url and switch the scheme to "wss".
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")

# Attach the query entries as URL parameters and hand back the URL with the auth headers.
url = realtime_url.copy_with(params={**query})
return url, auth_headers
36 changes: 18 additions & 18 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
extra_query, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
else:
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
log.debug("Connecting to %s", url)
if self.__websocket_connection_options:
log.debug("Connection options: %s", self.__websocket_connection_options)
Expand Down Expand Up @@ -506,15 +506,15 @@ def __enter__(self) -> RealtimeConnection:
extra_query = self.__extra_query
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
extra_query, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
else:
url = self._prepare_url().copy_with(
params={
**self.__client.base_url.params,
"model": self.__model,
**extra_query,
},
)
log.debug("Connecting to %s", url)
if self.__websocket_connection_options:
log.debug("Connection options: %s", self.__websocket_connection_options)
Expand Down
Loading