Skip to content

chore(internal): loosen type var restrictions #1049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
Body,
Omit,
Query,
ModelT,
Headers,
Timeout,
NotGiven,
Expand All @@ -61,7 +60,6 @@
HttpxSendArgs,
AsyncTransport,
RequestOptions,
UnknownResponse,
ModelBuilderProtocol,
BinaryResponseContent,
)
Expand Down Expand Up @@ -142,7 +140,7 @@ def __init__(
self.params = params


class BasePage(GenericModel, Generic[ModelT]):
class BasePage(GenericModel, Generic[_T]):
"""
Defines the core interface for pagination.

Expand All @@ -155,7 +153,7 @@ class BasePage(GenericModel, Generic[ModelT]):
"""

_options: FinalRequestOptions = PrivateAttr()
_model: Type[ModelT] = PrivateAttr()
_model: Type[_T] = PrivateAttr()

def has_next_page(self) -> bool:
items = self._get_page_items()
Expand All @@ -166,7 +164,7 @@ def has_next_page(self) -> bool:
def next_page_info(self) -> Optional[PageInfo]:
...

def _get_page_items(self) -> Iterable[ModelT]: # type: ignore[empty-body]
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]
...

def _params_from_url(self, url: URL) -> httpx.QueryParams:
Expand All @@ -191,13 +189,13 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
raise ValueError("Unexpected PageInfo state")


class BaseSyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseSyncPage(BasePage[_T], Generic[_T]):
_client: SyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
client: SyncAPIClient,
model: Type[ModelT],
model: Type[_T],
options: FinalRequestOptions,
) -> None:
self._model = model
Expand All @@ -212,7 +210,7 @@ def _set_private_attributes(
# methods should continue to work as expected as there is an alternative method
# to cast a model to a dictionary, model.dict(), which is used internally
# by pydantic.
def __iter__(self) -> Iterator[ModelT]: # type: ignore
def __iter__(self) -> Iterator[_T]: # type: ignore
for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand All @@ -237,13 +235,13 @@ def get_next_page(self: SyncPageT) -> SyncPageT:
return self._client._request_api_list(self._model, page=self.__class__, options=options)


class AsyncPaginator(Generic[ModelT, AsyncPageT]):
class AsyncPaginator(Generic[_T, AsyncPageT]):
def __init__(
self,
client: AsyncAPIClient,
options: FinalRequestOptions,
page_cls: Type[AsyncPageT],
model: Type[ModelT],
model: Type[_T],
) -> None:
self._model = model
self._client = client
Expand All @@ -266,7 +264,7 @@ def _parser(resp: AsyncPageT) -> AsyncPageT:

return await self._client.request(self._page_cls, self._options)

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
# https://github.com/microsoft/pyright/issues/3464
page = cast(
AsyncPageT,
Expand All @@ -276,20 +274,20 @@ async def __aiter__(self) -> AsyncIterator[ModelT]:
yield item


class BaseAsyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseAsyncPage(BasePage[_T], Generic[_T]):
_client: AsyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
model: Type[ModelT],
model: Type[_T],
client: AsyncAPIClient,
options: FinalRequestOptions,
) -> None:
self._model = model
self._client = client
self._options = options

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
async for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand Down Expand Up @@ -528,7 +526,7 @@ def _process_response_data(
if data is None:
return cast(ResponseT, None)

if cast_to is UnknownResponse:
if cast_to is object:
return cast(ResponseT, data)

try:
Expand Down Expand Up @@ -970,7 +968,7 @@ def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
Expand Down Expand Up @@ -1132,7 +1130,7 @@ def get_api_list(
self,
path: str,
*,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
body: Body | None = None,
options: RequestOptions = {},
Expand Down Expand Up @@ -1434,10 +1432,10 @@ async def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
options: FinalRequestOptions,
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
return AsyncPaginator(client=self, options=options, page_cls=page, model=model)

@overload
Expand Down Expand Up @@ -1584,13 +1582,12 @@ def get_api_list(
self,
path: str,
*,
# TODO: support paginating `str`
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
body: Body | None = None,
options: RequestOptions = {},
method: str = "get",
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
return self._request_api_list(model, page, opts)

Expand Down
4 changes: 2 additions & 2 deletions src/openai/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import httpx

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._types import NoneType, BinaryResponseContent
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
Expand Down Expand Up @@ -162,7 +162,7 @@ def _parse(self) -> R:
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if (
cast_to is not UnknownResponse
cast_to is not object
and not origin is list
and not origin is dict
and not origin is Union
Expand Down
17 changes: 11 additions & 6 deletions src/openai/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,6 @@ class RequestOptions(TypedDict, total=False):
idempotency_key: str


# Sentinel class used when the response type is an object with an unknown schema
class UnknownResponse:
...


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
Expand Down Expand Up @@ -339,7 +334,17 @@ def get(self, __key: str) -> str | None:

ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
BinaryResponseContent,
],
)

StrBytesIntFloat = Union[str, bytes, int, float]
Expand Down
29 changes: 15 additions & 14 deletions src/openai/pagination.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
# File generated from our OpenAPI spec by Stainless.

from typing import Any, List, Generic, Optional, cast
from typing import Any, List, Generic, TypeVar, Optional, cast
from typing_extensions import Protocol, override, runtime_checkable

from ._types import ModelT
from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage

__all__ = ["SyncPage", "AsyncPage", "SyncCursorPage", "AsyncCursorPage"]

_T = TypeVar("_T")


@runtime_checkable
class CursorPageItem(Protocol):
id: Optional[str]


class SyncPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""

data: List[ModelT]
data: List[_T]
object: str

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand All @@ -36,14 +37,14 @@ def next_page_info(self) -> None:
return None


class AsyncPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
class AsyncPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""

data: List[ModelT]
data: List[_T]
object: str

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand All @@ -58,11 +59,11 @@ def next_page_info(self) -> None:
return None


class SyncCursorPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
data: List[ModelT]
class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand All @@ -82,11 +83,11 @@ def next_page_info(self) -> Optional[PageInfo]:
return PageInfo(params={"after": item.id})


class AsyncCursorPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
data: List[ModelT]
class AsyncCursorPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/audio/speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
9 changes: 1 addition & 8 deletions src/openai/resources/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
FileTypes,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
9 changes: 1 addition & 8 deletions src/openai/resources/audio/translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
FileTypes,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@
import httpx

from .files import Files, AsyncFiles, FilesWithRawResponse, AsyncFilesWithRawResponse
from ...._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/beta/assistants/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import httpx

from ...._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/beta/threads/messages/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import httpx

from ....._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ....._utils import maybe_transform
from ....._compat import cached_property
from ....._resource import SyncAPIResource, AsyncAPIResource
Expand Down
Loading