Skip to content

Commit c2cf1c4

Browse files
feat(client): add retries_taken to raw response class (#334)
1 parent 042ed83 commit c2cf1c4

File tree

4 files changed

+122
-1
lines changed

4 files changed

+122
-1
lines changed

src/orb/_base_client.py

+10
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,7 @@ def _request(
10501050
response=response,
10511051
stream=stream,
10521052
stream_cls=stream_cls,
1053+
retries_taken=options.get_max_retries(self.max_retries) - retries,
10531054
)
10541055

10551056
def _retry_request(
@@ -1091,6 +1092,7 @@ def _process_response(
10911092
response: httpx.Response,
10921093
stream: bool,
10931094
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
1095+
retries_taken: int = 0,
10941096
) -> ResponseT:
10951097
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
10961098
return cast(
@@ -1102,6 +1104,7 @@ def _process_response(
11021104
stream=stream,
11031105
stream_cls=stream_cls,
11041106
options=options,
1107+
retries_taken=retries_taken,
11051108
),
11061109
)
11071110

@@ -1121,6 +1124,7 @@ def _process_response(
11211124
stream=stream,
11221125
stream_cls=stream_cls,
11231126
options=options,
1127+
retries_taken=retries_taken,
11241128
),
11251129
)
11261130

@@ -1134,6 +1138,7 @@ def _process_response(
11341138
stream=stream,
11351139
stream_cls=stream_cls,
11361140
options=options,
1141+
retries_taken=retries_taken,
11371142
)
11381143
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
11391144
return cast(ResponseT, api_response)
@@ -1624,6 +1629,7 @@ async def _request(
16241629
response=response,
16251630
stream=stream,
16261631
stream_cls=stream_cls,
1632+
retries_taken=options.get_max_retries(self.max_retries) - retries,
16271633
)
16281634

16291635
async def _retry_request(
@@ -1663,6 +1669,7 @@ async def _process_response(
16631669
response: httpx.Response,
16641670
stream: bool,
16651671
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
1672+
retries_taken: int = 0,
16661673
) -> ResponseT:
16671674
if response.request.headers.get(RAW_RESPONSE_HEADER) == "true":
16681675
return cast(
@@ -1674,6 +1681,7 @@ async def _process_response(
16741681
stream=stream,
16751682
stream_cls=stream_cls,
16761683
options=options,
1684+
retries_taken=retries_taken,
16771685
),
16781686
)
16791687

@@ -1693,6 +1701,7 @@ async def _process_response(
16931701
stream=stream,
16941702
stream_cls=stream_cls,
16951703
options=options,
1704+
retries_taken=retries_taken,
16961705
),
16971706
)
16981707

@@ -1706,6 +1715,7 @@ async def _process_response(
17061715
stream=stream,
17071716
stream_cls=stream_cls,
17081717
options=options,
1718+
retries_taken=retries_taken,
17091719
)
17101720
if bool(response.request.headers.get(RAW_RESPONSE_HEADER)):
17111721
return cast(ResponseT, api_response)

src/orb/_legacy_response.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,18 @@
55
import logging
66
import datetime
77
import functools
8-
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload
8+
from typing import (
9+
TYPE_CHECKING,
10+
Any,
11+
Union,
12+
Generic,
13+
TypeVar,
14+
Callable,
15+
Iterator,
16+
AsyncIterator,
17+
cast,
18+
overload,
19+
)
920
from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin
1021

1122
import anyio
@@ -53,6 +64,9 @@ class LegacyAPIResponse(Generic[R]):
5364

5465
http_response: httpx.Response
5566

67+
retries_taken: int
68+
"""The number of retries made. If no retries happened this will be `0`"""
69+
5670
def __init__(
5771
self,
5872
*,
@@ -62,6 +76,7 @@ def __init__(
6276
stream: bool,
6377
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
6478
options: FinalRequestOptions,
79+
retries_taken: int = 0,
6580
) -> None:
6681
self._cast_to = cast_to
6782
self._client = client
@@ -70,6 +85,7 @@ def __init__(
7085
self._stream_cls = stream_cls
7186
self._options = options
7287
self.http_response = raw
88+
self.retries_taken = retries_taken
7389

7490
@overload
7591
def parse(self, *, to: type[_T]) -> _T:

src/orb/_response.py

+5
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class BaseAPIResponse(Generic[R]):
5555

5656
http_response: httpx.Response
5757

58+
retries_taken: int
59+
"""The number of retries made. If no retries happened this will be `0`"""
60+
5861
def __init__(
5962
self,
6063
*,
@@ -64,6 +67,7 @@ def __init__(
6467
stream: bool,
6568
stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None,
6669
options: FinalRequestOptions,
70+
retries_taken: int = 0,
6771
) -> None:
6872
self._cast_to = cast_to
6973
self._client = client
@@ -72,6 +76,7 @@ def __init__(
7276
self._stream_cls = stream_cls
7377
self._options = options
7478
self.http_response = raw
79+
self.retries_taken = retries_taken
7580

7681
@property
7782
def headers(self) -> httpx.Headers:

tests/test_client.py

+90
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,49 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non
747747

748748
assert _get_open_connections(self.client) == 0
749749

750+
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
751+
@mock.patch("orb._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
752+
@pytest.mark.respx(base_url=base_url)
753+
def test_retries_taken(self, client: Orb, failures_before_success: int, respx_mock: MockRouter) -> None:
754+
client = client.with_options(max_retries=4)
755+
756+
nb_retries = 0
757+
758+
def retry_handler(_request: httpx.Request) -> httpx.Response:
759+
nonlocal nb_retries
760+
if nb_retries < failures_before_success:
761+
nb_retries += 1
762+
return httpx.Response(500)
763+
return httpx.Response(200)
764+
765+
respx_mock.post("/customers").mock(side_effect=retry_handler)
766+
767+
response = client.customers.with_raw_response.create(email="email", name="name")
768+
769+
assert response.retries_taken == failures_before_success
770+
771+
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
772+
@mock.patch("orb._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
773+
@pytest.mark.respx(base_url=base_url)
774+
def test_retries_taken_new_response_class(
775+
self, client: Orb, failures_before_success: int, respx_mock: MockRouter
776+
) -> None:
777+
client = client.with_options(max_retries=4)
778+
779+
nb_retries = 0
780+
781+
def retry_handler(_request: httpx.Request) -> httpx.Response:
782+
nonlocal nb_retries
783+
if nb_retries < failures_before_success:
784+
nb_retries += 1
785+
return httpx.Response(500)
786+
return httpx.Response(200)
787+
788+
respx_mock.post("/customers").mock(side_effect=retry_handler)
789+
790+
with client.customers.with_streaming_response.create(email="email", name="name") as response:
791+
assert response.retries_taken == failures_before_success
792+
750793

751794
class TestAsyncOrb:
752795
client = AsyncOrb(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1464,3 +1507,50 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter)
14641507
)
14651508

14661509
assert _get_open_connections(self.client) == 0
1510+
1511+
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1512+
@mock.patch("orb._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1513+
@pytest.mark.respx(base_url=base_url)
1514+
@pytest.mark.asyncio
1515+
async def test_retries_taken(
1516+
self, async_client: AsyncOrb, failures_before_success: int, respx_mock: MockRouter
1517+
) -> None:
1518+
client = async_client.with_options(max_retries=4)
1519+
1520+
nb_retries = 0
1521+
1522+
def retry_handler(_request: httpx.Request) -> httpx.Response:
1523+
nonlocal nb_retries
1524+
if nb_retries < failures_before_success:
1525+
nb_retries += 1
1526+
return httpx.Response(500)
1527+
return httpx.Response(200)
1528+
1529+
respx_mock.post("/customers").mock(side_effect=retry_handler)
1530+
1531+
response = await client.customers.with_raw_response.create(email="email", name="name")
1532+
1533+
assert response.retries_taken == failures_before_success
1534+
1535+
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1536+
@mock.patch("orb._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1537+
@pytest.mark.respx(base_url=base_url)
1538+
@pytest.mark.asyncio
1539+
async def test_retries_taken_new_response_class(
1540+
self, async_client: AsyncOrb, failures_before_success: int, respx_mock: MockRouter
1541+
) -> None:
1542+
client = async_client.with_options(max_retries=4)
1543+
1544+
nb_retries = 0
1545+
1546+
def retry_handler(_request: httpx.Request) -> httpx.Response:
1547+
nonlocal nb_retries
1548+
if nb_retries < failures_before_success:
1549+
nb_retries += 1
1550+
return httpx.Response(500)
1551+
return httpx.Response(200)
1552+
1553+
respx_mock.post("/customers").mock(side_effect=retry_handler)
1554+
1555+
async with client.customers.with_streaming_response.create(email="email", name="name") as response:
1556+
assert response.retries_taken == failures_before_success

0 commit comments

Comments
 (0)