Skip to content

Commit 21b0c00

Browse files
committed
feat(client): send retry count header
1 parent 6172976 commit 21b0c00

File tree

3 files changed

+74
-48
lines changed

3 files changed

+74
-48
lines changed

Diff for: src/openai/_base_client.py

+55-47
Original file line numberDiff line numberDiff line change
@@ -401,14 +401,7 @@ def _make_status_error(
401401
) -> _exceptions.APIStatusError:
402402
raise NotImplementedError()
403403

404-
def _remaining_retries(
405-
self,
406-
remaining_retries: Optional[int],
407-
options: FinalRequestOptions,
408-
) -> int:
409-
return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
410-
411-
def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
404+
def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
412405
custom_headers = options.headers or {}
413406
headers_dict = _merge_mappings(self.default_headers, custom_headers)
414407
self._validate_headers(headers_dict, custom_headers)
@@ -420,6 +413,9 @@ def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
420413
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
421414
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
422415

416+
if retries_taken > 0:
417+
headers.setdefault("x-stainless-retry-count", str(retries_taken))
418+
423419
return headers
424420

425421
def _prepare_url(self, url: str) -> URL:
@@ -441,6 +437,8 @@ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
441437
def _build_request(
442438
self,
443439
options: FinalRequestOptions,
440+
*,
441+
retries_taken: int = 0,
444442
) -> httpx.Request:
445443
if log.isEnabledFor(logging.DEBUG):
446444
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
@@ -456,7 +454,7 @@ def _build_request(
456454
else:
457455
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
458456

459-
headers = self._build_headers(options)
457+
headers = self._build_headers(options, retries_taken=retries_taken)
460458
params = _merge_mappings(self.default_query, options.params)
461459
content_type = headers.get("Content-Type")
462460
files = options.files
@@ -939,20 +937,25 @@ def request(
939937
stream: bool = False,
940938
stream_cls: type[_StreamT] | None = None,
941939
) -> ResponseT | _StreamT:
940+
if remaining_retries is not None:
941+
retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
942+
else:
943+
retries_taken = 0
944+
942945
return self._request(
943946
cast_to=cast_to,
944947
options=options,
945948
stream=stream,
946949
stream_cls=stream_cls,
947-
remaining_retries=remaining_retries,
950+
retries_taken=retries_taken,
948951
)
949952

950953
def _request(
951954
self,
952955
*,
953956
cast_to: Type[ResponseT],
954957
options: FinalRequestOptions,
955-
remaining_retries: int | None,
958+
retries_taken: int,
956959
stream: bool,
957960
stream_cls: type[_StreamT] | None,
958961
) -> ResponseT | _StreamT:
@@ -964,8 +967,8 @@ def _request(
964967
cast_to = self._maybe_override_cast_to(cast_to, options)
965968
options = self._prepare_options(options)
966969

967-
retries = self._remaining_retries(remaining_retries, options)
968-
request = self._build_request(options)
970+
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
971+
request = self._build_request(options, retries_taken=retries_taken)
969972
self._prepare_request(request)
970973

971974
kwargs: HttpxSendArgs = {}
@@ -983,11 +986,11 @@ def _request(
983986
except httpx.TimeoutException as err:
984987
log.debug("Encountered httpx.TimeoutException", exc_info=True)
985988

986-
if retries > 0:
989+
if remaining_retries > 0:
987990
return self._retry_request(
988991
input_options,
989992
cast_to,
990-
retries,
993+
retries_taken=retries_taken,
991994
stream=stream,
992995
stream_cls=stream_cls,
993996
response_headers=None,
@@ -998,11 +1001,11 @@ def _request(
9981001
except Exception as err:
9991002
log.debug("Encountered Exception", exc_info=True)
10001003

1001-
if retries > 0:
1004+
if remaining_retries > 0:
10021005
return self._retry_request(
10031006
input_options,
10041007
cast_to,
1005-
retries,
1008+
retries_taken=retries_taken,
10061009
stream=stream,
10071010
stream_cls=stream_cls,
10081011
response_headers=None,
@@ -1026,13 +1029,13 @@ def _request(
10261029
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
10271030
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
10281031

1029-
if retries > 0 and self._should_retry(err.response):
1032+
if remaining_retries > 0 and self._should_retry(err.response):
10301033
err.response.close()
10311034
return self._retry_request(
10321035
input_options,
10331036
cast_to,
1034-
retries,
1035-
err.response.headers,
1037+
retries_taken=retries_taken,
1038+
response_headers=err.response.headers,
10361039
stream=stream,
10371040
stream_cls=stream_cls,
10381041
)
@@ -1051,26 +1054,26 @@ def _request(
10511054
response=response,
10521055
stream=stream,
10531056
stream_cls=stream_cls,
1054-
retries_taken=options.get_max_retries(self.max_retries) - retries,
1057+
retries_taken=retries_taken,
10551058
)
10561059

10571060
def _retry_request(
10581061
self,
10591062
options: FinalRequestOptions,
10601063
cast_to: Type[ResponseT],
1061-
remaining_retries: int,
1062-
response_headers: httpx.Headers | None,
10631064
*,
1065+
retries_taken: int,
1066+
response_headers: httpx.Headers | None,
10641067
stream: bool,
10651068
stream_cls: type[_StreamT] | None,
10661069
) -> ResponseT | _StreamT:
1067-
remaining = remaining_retries - 1
1068-
if remaining == 1:
1070+
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
1071+
if remaining_retries == 1:
10691072
log.debug("1 retry left")
10701073
else:
1071-
log.debug("%i retries left", remaining)
1074+
log.debug("%i retries left", remaining_retries)
10721075

1073-
timeout = self._calculate_retry_timeout(remaining, options, response_headers)
1076+
timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
10741077
log.info("Retrying request to %s in %f seconds", options.url, timeout)
10751078

10761079
# In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
@@ -1080,7 +1083,7 @@ def _retry_request(
10801083
return self._request(
10811084
options=options,
10821085
cast_to=cast_to,
1083-
remaining_retries=remaining,
1086+
retries_taken=retries_taken + 1,
10841087
stream=stream,
10851088
stream_cls=stream_cls,
10861089
)
@@ -1512,12 +1515,17 @@ async def request(
15121515
stream_cls: type[_AsyncStreamT] | None = None,
15131516
remaining_retries: Optional[int] = None,
15141517
) -> ResponseT | _AsyncStreamT:
1518+
if remaining_retries is not None:
1519+
retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
1520+
else:
1521+
retries_taken = 0
1522+
15151523
return await self._request(
15161524
cast_to=cast_to,
15171525
options=options,
15181526
stream=stream,
15191527
stream_cls=stream_cls,
1520-
remaining_retries=remaining_retries,
1528+
retries_taken=retries_taken,
15211529
)
15221530

15231531
async def _request(
@@ -1527,7 +1535,7 @@ async def _request(
15271535
*,
15281536
stream: bool,
15291537
stream_cls: type[_AsyncStreamT] | None,
1530-
remaining_retries: int | None,
1538+
retries_taken: int,
15311539
) -> ResponseT | _AsyncStreamT:
15321540
if self._platform is None:
15331541
# `get_platform` can make blocking IO calls so we
@@ -1542,8 +1550,8 @@ async def _request(
15421550
cast_to = self._maybe_override_cast_to(cast_to, options)
15431551
options = await self._prepare_options(options)
15441552

1545-
retries = self._remaining_retries(remaining_retries, options)
1546-
request = self._build_request(options)
1553+
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
1554+
request = self._build_request(options, retries_taken=retries_taken)
15471555
await self._prepare_request(request)
15481556

15491557
kwargs: HttpxSendArgs = {}
@@ -1559,11 +1567,11 @@ async def _request(
15591567
except httpx.TimeoutException as err:
15601568
log.debug("Encountered httpx.TimeoutException", exc_info=True)
15611569

1562-
if retries > 0:
1570+
if remaining_retries > 0:
15631571
return await self._retry_request(
15641572
input_options,
15651573
cast_to,
1566-
retries,
1574+
retries_taken=retries_taken,
15671575
stream=stream,
15681576
stream_cls=stream_cls,
15691577
response_headers=None,
@@ -1574,11 +1582,11 @@ async def _request(
15741582
except Exception as err:
15751583
log.debug("Encountered Exception", exc_info=True)
15761584

1577-
if retries > 0:
1585+
if retries_taken > 0:
15781586
return await self._retry_request(
15791587
input_options,
15801588
cast_to,
1581-
retries,
1589+
retries_taken=retries_taken,
15821590
stream=stream,
15831591
stream_cls=stream_cls,
15841592
response_headers=None,
@@ -1596,13 +1604,13 @@ async def _request(
15961604
except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
15971605
log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
15981606

1599-
if retries > 0 and self._should_retry(err.response):
1607+
if remaining_retries > 0 and self._should_retry(err.response):
16001608
await err.response.aclose()
16011609
return await self._retry_request(
16021610
input_options,
16031611
cast_to,
1604-
retries,
1605-
err.response.headers,
1612+
retries_taken=retries_taken,
1613+
response_headers=err.response.headers,
16061614
stream=stream,
16071615
stream_cls=stream_cls,
16081616
)
@@ -1621,34 +1629,34 @@ async def _request(
16211629
response=response,
16221630
stream=stream,
16231631
stream_cls=stream_cls,
1624-
retries_taken=options.get_max_retries(self.max_retries) - retries,
1632+
retries_taken=retries_taken,
16251633
)
16261634

16271635
async def _retry_request(
16281636
self,
16291637
options: FinalRequestOptions,
16301638
cast_to: Type[ResponseT],
1631-
remaining_retries: int,
1632-
response_headers: httpx.Headers | None,
16331639
*,
1640+
retries_taken: int,
1641+
response_headers: httpx.Headers | None,
16341642
stream: bool,
16351643
stream_cls: type[_AsyncStreamT] | None,
16361644
) -> ResponseT | _AsyncStreamT:
1637-
remaining = remaining_retries - 1
1638-
if remaining == 1:
1645+
remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
1646+
if remaining_retries == 1:
16391647
log.debug("1 retry left")
16401648
else:
1641-
log.debug("%i retries left", remaining)
1649+
log.debug("%i retries left", remaining_retries)
16421650

1643-
timeout = self._calculate_retry_timeout(remaining, options, response_headers)
1651+
timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
16441652
log.info("Retrying request to %s in %f seconds", options.url, timeout)
16451653

16461654
await anyio.sleep(timeout)
16471655

16481656
return await self._request(
16491657
options=options,
16501658
cast_to=cast_to,
1651-
remaining_retries=remaining,
1659+
retries_taken=retries_taken + 1,
16521660
stream=stream,
16531661
stream_cls=stream_cls,
16541662
)

Diff for: src/openai/lib/azure.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
5353
def _build_request(
5454
self,
5555
options: FinalRequestOptions,
56+
*,
57+
retries_taken: int = 0,
5658
) -> httpx.Request:
5759
if options.url in _deployments_endpoints and is_mapping(options.json_data):
5860
model = options.json_data.get("model")
5961
if model is not None and not "/deployments" in str(self.base_url):
6062
options.url = f"/deployments/{model}{options.url}"
6163

62-
return super()._build_request(options)
64+
return super()._build_request(options, retries_taken=retries_taken)
6365

6466

6567
class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):

Diff for: tests/test_client.py

+16
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
788788
)
789789

790790
assert response.retries_taken == failures_before_success
791+
if failures_before_success == 0:
792+
assert "x-stainless-retry-count" not in response.http_request.headers
793+
else:
794+
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
791795

792796
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
793797
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -818,6 +822,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
818822
model="gpt-4o",
819823
) as response:
820824
assert response.retries_taken == failures_before_success
825+
if failures_before_success == 0:
826+
assert "x-stainless-retry-count" not in response.http_request.headers
827+
else:
828+
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
821829

822830

823831
class TestAsyncOpenAI:
@@ -1582,6 +1590,10 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
15821590
)
15831591

15841592
assert response.retries_taken == failures_before_success
1593+
if failures_before_success == 0:
1594+
assert "x-stainless-retry-count" not in response.http_request.headers
1595+
else:
1596+
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
15851597

15861598
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
15871599
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -1613,3 +1625,7 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
16131625
model="gpt-4o",
16141626
) as response:
16151627
assert response.retries_taken == failures_before_success
1628+
if failures_before_success == 0:
1629+
assert "x-stainless-retry-count" not in response.http_request.headers
1630+
else:
1631+
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

0 commit comments

Comments
 (0)