From 86d7a2974fa32f4a04bd5d260eb86492143160cb Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Fri, 3 Nov 2023 05:23:34 +0000 Subject: [PATCH] feat(client): allow binary returns --- src/orb/_base_client.py | 93 +++++++++++++++++++++++++ src/orb/_response.py | 5 +- src/orb/_types.py | 151 +++++++++++++++++++++++++++++++++++++++- tests/test_client.py | 31 +++++++-- 4 files changed, 273 insertions(+), 7 deletions(-) diff --git a/src/orb/_base_client.py b/src/orb/_base_client.py index 5ed9f54d..5354c488 100644 --- a/src/orb/_base_client.py +++ b/src/orb/_base_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os import json import time import uuid @@ -60,6 +61,7 @@ RequestOptions, UnknownResponse, ModelBuilderProtocol, + BinaryResponseContent, ) from ._utils import is_dict, is_given, is_mapping from ._compat import model_copy, model_dump @@ -1672,3 +1674,94 @@ def _merge_mappings( """ merged = {**obj1, **obj2} return {key: value for key, value in merged.items() if not isinstance(value, Omit)} + + +class HttpxBinaryResponseContent(BinaryResponseContent): + response: httpx.Response + + def __init__(self, response: httpx.Response) -> None: + self.response = response + + @property + @override + def content(self) -> bytes: + return self.response.content + + @property + @override + def text(self) -> str: + return self.response.text + + @property + @override + def encoding(self) -> Optional[str]: + return self.response.encoding + + @property + @override + def charset_encoding(self) -> Optional[str]: + return self.response.charset_encoding + + @override + def json(self, **kwargs: Any) -> Any: + return self.response.json(**kwargs) + + @override + def read(self) -> bytes: + return self.response.read() + + @override + def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: + return self.response.iter_bytes(chunk_size) + + @override + def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]: + return self.response.iter_text(chunk_size) + + @override + def iter_lines(self) -> Iterator[str]: + return self.response.iter_lines() + + @override + def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: + return self.response.iter_raw(chunk_size) + + @override + def stream_to_file(self, file: str | os.PathLike[str]) -> None: + with open(file, mode="wb") as f: + for data in self.response.iter_bytes(): + f.write(data) + + @override + def close(self) -> None: + return self.response.close() + + @override + async def aread(self) -> bytes: + return await self.response.aread() + + @override + async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: + return self.response.aiter_bytes(chunk_size) + + @override + async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]: + return self.response.aiter_text(chunk_size) + + @override + async def aiter_lines(self) -> AsyncIterator[str]: + return self.response.aiter_lines() + + @override + async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: + return self.response.aiter_raw(chunk_size) + + @override + async def astream_to_file(self, file: str | os.PathLike[str]) -> None: + with open(file, mode="wb") as f: + async for data in self.response.aiter_bytes(): + f.write(data) + + @override + async def aclose(self) -> None: + return await self.response.aclose() diff --git a/src/orb/_response.py b/src/orb/_response.py index 4c064ef2..bdc48daa 100644 --- a/src/orb/_response.py +++ b/src/orb/_response.py @@ -9,7 +9,7 @@ import httpx import pydantic -from ._types import NoneType, UnknownResponse +from ._types import NoneType, UnknownResponse, BinaryResponseContent from ._utils import is_given from ._models import BaseModel from ._constants import RAW_RESPONSE_HEADER @@ -135,6 +135,9 @@ def _parse(self) -> R: origin = get_origin(cast_to) or cast_to + if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent): + return cast(R, cast_to(response)) # type: ignore + if origin == APIResponse: raise RuntimeError("Unexpected state - cast_to is `APIResponse`") diff --git a/src/orb/_types.py b/src/orb/_types.py index 162ab18d..0d12f71c 100644 --- a/src/orb/_types.py +++ b/src/orb/_types.py @@ -1,6 +1,7 @@ from __future__ import annotations from os import PathLike +from abc import ABC, abstractmethod from typing import ( IO, TYPE_CHECKING, @@ -13,8 +14,10 @@ Mapping, TypeVar, Callable, + Iterator, Optional, Sequence, + AsyncIterator, ) from typing_extensions import ( Literal, @@ -25,7 +28,6 @@ runtime_checkable, ) -import httpx import pydantic from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport @@ -40,6 +42,151 @@ ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) _T = TypeVar("_T") + +class BinaryResponseContent(ABC): + def __init__( + self, + response: Any, + ) -> None: + ... + + @property + @abstractmethod + def content(self) -> bytes: + pass + + @property + @abstractmethod + def text(self) -> str: + pass + + @property + @abstractmethod + def encoding(self) -> Optional[str]: + """ + Return an encoding to use for decoding the byte content into text. + The priority for determining this is given by... + + * `.encoding = <>` has been set explicitly. + * The encoding as specified by the charset parameter in the Content-Type header. + * The encoding as determined by `default_encoding`, which may either be + a string like "utf-8" indicating the encoding to use, or may be a callable + which enables charset autodetection. + """ + pass + + @property + @abstractmethod + def charset_encoding(self) -> Optional[str]: + """ + Return the encoding, as specified by the Content-Type header. + """ + pass + + @abstractmethod + def json(self, **kwargs: Any) -> Any: + pass + + @abstractmethod + def read(self) -> bytes: + """ + Read and return the response content. + """ + pass + + @abstractmethod + def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + pass + + @abstractmethod + def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + pass + + @abstractmethod + def iter_lines(self) -> Iterator[str]: + pass + + @abstractmethod + def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + pass + + @abstractmethod + def stream_to_file(self, file: str | PathLike[str]) -> None: + """ + Stream the output to the given file. + """ + pass + + @abstractmethod + def close(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + pass + + @abstractmethod + async def aread(self) -> bytes: + """ + Read and return the response content. + """ + pass + + @abstractmethod + async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + This allows us to handle gzip, deflate, and brotli encoded responses. + """ + pass + + @abstractmethod + async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]: + """ + A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + pass + + @abstractmethod + async def aiter_lines(self) -> AsyncIterator[str]: + pass + + @abstractmethod + async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: + """ + A byte-iterator over the raw response content. + """ + pass + + async def astream_to_file(self, file: str | PathLike[str]) -> None: + """ + Stream the output to the given file. + """ + pass + + @abstractmethod + async def aclose(self) -> None: + """ + Close the response and release the connection. + Automatically called if the response body is read to completion. + """ + pass + + # Approximates httpx internal ProxiesTypes and RequestFiles types # while adding support for `PathLike` instances ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] @@ -181,7 +328,7 @@ def get(self, __key: str) -> str | None: ResponseT = TypeVar( "ResponseT", - bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], httpx.Response, UnknownResponse, ModelBuilderProtocol]", + bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", ) StrBytesIntFloat = Union[str, bytes, int, float] diff --git a/tests/test_client.py b/tests/test_client.py index 94aba258..18998791 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -40,12 +40,23 @@ class TestOrb: @pytest.mark.respx(base_url=base_url) def test_raw_response(self, respx_mock: MockRouter) -> None: - respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}')) response = self.client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) - assert response.json() == {"foo": "bar"} + assert response.json() == '{"foo": "bar"}' + + @pytest.mark.respx(base_url=base_url) + def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + respx_mock.post("/foo").mock( + return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') + ) + + response = self.client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == '{"foo": "bar"}' def test_copy(self) -> None: copied = self.client.copy() @@ -572,12 +583,24 @@ class TestAsyncOrb: @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_raw_response(self, respx_mock: MockRouter) -> None: - respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}')) + + response = await self.client.post("/foo", cast_to=httpx.Response) + assert response.status_code == 200 + assert isinstance(response, httpx.Response) + assert response.json() == '{"foo": "bar"}' + + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None: + respx_mock.post("/foo").mock( + return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}') + ) response = await self.client.post("/foo", cast_to=httpx.Response) assert response.status_code == 200 assert isinstance(response, httpx.Response) - assert response.json() == {"foo": "bar"} + assert response.json() == '{"foo": "bar"}' def test_copy(self) -> None: copied = self.client.copy()