
Commit b6d232a

feat: optimize retries (#854)
1 parent 0b3606f commit b6d232a

6 files changed: 90 additions & 64 deletions

google/cloud/bigtable/data/_async/_read_rows.py

Lines changed: 36 additions & 35 deletions
@@ -31,7 +31,7 @@
 from google.cloud.bigtable.data._helpers import _make_metadata

 from google.api_core import retry_async as retries
-from google.api_core.retry_streaming_async import AsyncRetryableGenerator
+from google.api_core.retry_streaming_async import retry_target_stream
 from google.api_core.retry import exponential_sleep_generator
 from google.api_core import exceptions as core_exceptions

@@ -100,35 +100,17 @@ def __init__(
         self._last_yielded_row_key: bytes | None = None
         self._remaining_count: int | None = self.request.rows_limit or None

-    async def start_operation(self) -> AsyncGenerator[Row, None]:
+    def start_operation(self) -> AsyncGenerator[Row, None]:
         """
         Start the read_rows operation, retrying on retryable errors.
         """
-        transient_errors = []
-
-        def on_error_fn(exc):
-            if self._predicate(exc):
-                transient_errors.append(exc)
-
-        retry_gen = AsyncRetryableGenerator(
+        return retry_target_stream(
             self._read_rows_attempt,
             self._predicate,
             exponential_sleep_generator(0.01, 60, multiplier=2),
             self.operation_timeout,
-            on_error_fn,
+            exception_factory=self._build_exception,
         )
-        try:
-            async for row in retry_gen:
-                yield row
-                if self._remaining_count is not None:
-                    self._remaining_count -= 1
-                    if self._remaining_count < 0:
-                        raise RuntimeError("emit count exceeds row limit")
-        except core_exceptions.RetryError:
-            self._raise_retry_error(transient_errors)
-        except GeneratorExit:
-            # propagate close to wrapped generator
-            await retry_gen.aclose()

     def _read_rows_attempt(self) -> AsyncGenerator[Row, None]:
         """
@@ -202,6 +184,10 @@ async def chunk_stream(
             elif c.commit_row:
                 # update row state after each commit
                 self._last_yielded_row_key = current_key
+                if self._remaining_count is not None:
+                    self._remaining_count -= 1
+                    if self._remaining_count < 0:
+                        raise InvalidChunk("emit count exceeds row limit")
                 current_key = None

     @staticmethod
@@ -354,19 +340,34 @@ def _revise_request_rowset(
             raise _RowSetComplete()
         return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges)

-    def _raise_retry_error(self, transient_errors: list[Exception]) -> None:
+    @staticmethod
+    def _build_exception(
+        exc_list: list[Exception], is_timeout: bool, timeout_val: float
+    ) -> tuple[Exception, Exception | None]:
         """
-        If the retryable deadline is hit, wrap the raised exception
-        in a RetryExceptionGroup
+        Build retry error based on exceptions encountered during operation
+
+        Args:
+          - exc_list: list of exceptions encountered during operation
+          - is_timeout: whether the operation failed due to timeout
+          - timeout_val: the operation timeout value in seconds, for constructing
+            the error message
+        Returns:
+          - tuple of the exception to raise, and a cause exception if applicable
         """
-        timeout_value = self.operation_timeout
-        timeout_str = f" of {timeout_value:.1f}s" if timeout_value is not None else ""
-        error_str = f"operation_timeout{timeout_str} exceeded"
-        new_exc = core_exceptions.DeadlineExceeded(
-            error_str,
+        if is_timeout:
+            # if failed due to timeout, raise deadline exceeded as primary exception
+            source_exc: Exception = core_exceptions.DeadlineExceeded(
+                f"operation_timeout of {timeout_val} exceeded"
+            )
+        elif exc_list:
+            # otherwise, raise non-retryable error as primary exception
+            source_exc = exc_list.pop()
+        else:
+            source_exc = RuntimeError("failed with unspecified exception")
+        # use the retry exception group as the cause of the exception
+        cause_exc: Exception | None = (
+            RetryExceptionGroup(exc_list) if exc_list else None
         )
-        source_exc = None
-        if transient_errors:
-            source_exc = RetryExceptionGroup(transient_errors)
-        new_exc.__cause__ = source_exc
-        raise new_exc from source_exc
+        source_exc.__cause__ = cause_exc
+        return source_exc, cause_exc

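Note: the retry plumbing above now hands error handling to an exception_factory callback instead of the old on_error_fn / _raise_retry_error pair. The following is a self-contained, minimal sketch of the contract that _build_exception fulfills; the stand-in exception classes are illustrative only, not the real google.api_core or Bigtable types.

# Sketch of the exception_factory contract (stand-in classes, illustrative only).
from __future__ import annotations


class DeadlineExceeded(Exception):
    """Stand-in for google.api_core.exceptions.DeadlineExceeded."""


class RetryExceptionGroup(Exception):
    """Stand-in for the Bigtable RetryExceptionGroup wrapper."""

    def __init__(self, excs: list[Exception]):
        super().__init__(f"{len(excs)} failed attempt(s)")
        self.exceptions = tuple(excs)


def build_exception(
    exc_list: list[Exception], is_timeout: bool, timeout_val: float
) -> tuple[Exception, Exception | None]:
    # Mirror of _build_exception above: pick the primary exception, then chain
    # any remaining transient errors as its __cause__.
    if is_timeout:
        source: Exception = DeadlineExceeded(f"operation_timeout of {timeout_val} exceeded")
    elif exc_list:
        source = exc_list.pop()
    else:
        source = RuntimeError("failed with unspecified exception")
    cause = RetryExceptionGroup(exc_list) if exc_list else None
    source.__cause__ = cause
    return source, cause


# Example: a timeout after two transient failures surfaces as DeadlineExceeded,
# with the earlier errors grouped as its cause.
exc, cause = build_exception([ValueError("attempt 1"), ValueError("attempt 2")], True, 30.0)
assert isinstance(exc, DeadlineExceeded) and isinstance(cause, RetryExceptionGroup)
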
google/cloud/bigtable/data/_helpers.py

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ def _attempt_timeout_generator(
         yield max(0, min(per_request_timeout, deadline - time.monotonic()))


+# TODO:replace this function with an exception_factory passed into the retry when
+# feature is merged:
+# https://github.com/googleapis/python-bigtable/blob/ea5b4f923e42516729c57113ddbe28096841b952/google/cloud/bigtable/data/_async/_read_rows.py#L130
 def _convert_retry_deadline(
     func: Callable[..., Any],
     timeout_value: float | None = None,

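The body of _convert_retry_deadline is not part of this commit; as a rough, hypothetical sketch (modeled on the _raise_retry_error logic removed above, not on the actual implementation), the wrapper pattern the TODO wants to retire looks roughly like this:

# Hypothetical sketch only -- the real _convert_retry_deadline is not shown in
# this commit. The idea is a wrapper that converts api_core's RetryError into
# DeadlineExceeded, the role exception_factory now plays inside the retry itself.
import functools
from typing import Any, Callable

from google.api_core import exceptions as core_exceptions


def convert_retry_deadline_sketch(func: Callable[..., Any], timeout_value: float | None = None):
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except core_exceptions.RetryError:
            timeout_str = f" of {timeout_value:.1f}s" if timeout_value is not None else ""
            # surface the retry deadline as DeadlineExceeded, as the old code did
            raise core_exceptions.DeadlineExceeded(
                f"operation_timeout{timeout_str} exceeded"
            ) from None

    return wrapper
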
setup.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 # 'Development Status :: 5 - Production/Stable'
 release_status = "Development Status :: 5 - Production/Stable"
 dependencies = [
-    "google-api-core[grpc] == 2.12.0.dev0",  # TODO: change to >= after streaming retries is merged
+    "google-api-core[grpc] == 2.12.0.dev1",  # TODO: change to >= after streaming retries is merged
     "google-cloud-core >= 1.4.1, <3.0.0dev",
     "grpc-google-iam-v1 >= 0.12.4, <1.0.0dev",
     "proto-plus >= 1.22.0, <2.0.0dev",

testing/constraints-3.7.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
 # Then this file should have foo==1.14.0
-google-api-core==2.12.0.dev0
+google-api-core==2.12.0.dev1
 google-cloud-core==2.3.2
 grpc-google-iam-v1==0.12.4
 proto-plus==1.22.0

tests/unit/data/_async/test__read_rows.py

Lines changed: 48 additions & 26 deletions
@@ -226,24 +226,34 @@ async def test_revise_limit(self, start_limit, emit_num, expected_limit):
         should be raised (tested in test_revise_limit_over_limit)
         """
         from google.cloud.bigtable.data import ReadRowsQuery
+        from google.cloud.bigtable_v2.types import ReadRowsResponse

-        async def mock_stream():
-            for i in range(emit_num):
-                yield i
+        async def awaitable_stream():
+            async def mock_stream():
+                for i in range(emit_num):
+                    yield ReadRowsResponse(
+                        chunks=[
+                            ReadRowsResponse.CellChunk(
+                                row_key=str(i).encode(),
+                                family_name="b",
+                                qualifier=b"c",
+                                value=b"d",
+                                commit_row=True,
+                            )
+                        ]
+                    )
+
+            return mock_stream()

         query = ReadRowsQuery(limit=start_limit)
         table = mock.Mock()
         table.table_name = "table_name"
         table.app_profile_id = "app_profile_id"
-        with mock.patch.object(
-            _ReadRowsOperationAsync, "_read_rows_attempt"
-        ) as mock_attempt:
-            mock_attempt.return_value = mock_stream()
-            instance = self._make_one(query, table, 10, 10)
-            assert instance._remaining_count == start_limit
-            # read emit_num rows
-            async for val in instance.start_operation():
-                pass
+        instance = self._make_one(query, table, 10, 10)
+        assert instance._remaining_count == start_limit
+        # read emit_num rows
+        async for val in instance.chunk_stream(awaitable_stream()):
+            pass
         assert instance._remaining_count == expected_limit

     @pytest.mark.parametrize("start_limit,emit_num", [(5, 10), (3, 9), (1, 10)])
@@ -254,26 +264,37 @@ async def test_revise_limit_over_limit(self, start_limit, emit_num):
         (unless start_num == 0, which represents unlimited)
         """
         from google.cloud.bigtable.data import ReadRowsQuery
+        from google.cloud.bigtable_v2.types import ReadRowsResponse
+        from google.cloud.bigtable.data.exceptions import InvalidChunk

-        async def mock_stream():
-            for i in range(emit_num):
-                yield i
+        async def awaitable_stream():
+            async def mock_stream():
+                for i in range(emit_num):
+                    yield ReadRowsResponse(
+                        chunks=[
+                            ReadRowsResponse.CellChunk(
+                                row_key=str(i).encode(),
+                                family_name="b",
+                                qualifier=b"c",
+                                value=b"d",
+                                commit_row=True,
+                            )
+                        ]
+                    )
+
+            return mock_stream()

         query = ReadRowsQuery(limit=start_limit)
         table = mock.Mock()
         table.table_name = "table_name"
         table.app_profile_id = "app_profile_id"
-        with mock.patch.object(
-            _ReadRowsOperationAsync, "_read_rows_attempt"
-        ) as mock_attempt:
-            mock_attempt.return_value = mock_stream()
-            instance = self._make_one(query, table, 10, 10)
-            assert instance._remaining_count == start_limit
-            with pytest.raises(RuntimeError) as e:
-                # read emit_num rows
-                async for val in instance.start_operation():
-                    pass
-            assert "emit count exceeds row limit" in str(e.value)
+        instance = self._make_one(query, table, 10, 10)
+        assert instance._remaining_count == start_limit
+        with pytest.raises(InvalidChunk) as e:
+            # read emit_num rows
+            async for val in instance.chunk_stream(awaitable_stream()):
+                pass
+        assert "emit count exceeds row limit" in str(e.value)

     @pytest.mark.asyncio
     async def test_aclose(self):
@@ -333,6 +354,7 @@ async def mock_stream():

         instance = mock.Mock()
         instance._last_yielded_row_key = None
+        instance._remaining_count = None
         stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream())
         await stream.__anext__()
         with pytest.raises(InvalidChunk) as exc:

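Both revised tests build the same single-cell-per-row response stream inline; a hypothetical shared helper (not part of this commit) showing the shape of input that chunk_stream now consumes could look like this, and either test could pass single_cell_stream(emit_num) straight to instance.chunk_stream(...) as in the calls above.

# Hypothetical test helper, not in this commit: an awaitable that resolves to an
# async iterator of ReadRowsResponse messages, one committed single-cell row per
# response, matching the inline awaitable_stream() used by both tests above.
from google.cloud.bigtable_v2.types import ReadRowsResponse


async def single_cell_stream(num_rows: int):
    async def _gen():
        for i in range(num_rows):
            yield ReadRowsResponse(
                chunks=[
                    ReadRowsResponse.CellChunk(
                        row_key=str(i).encode(),
                        family_name="b",
                        qualifier=b"c",
                        value=b"d",
                        commit_row=True,
                    )
                ]
            )

    return _gen()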