Skip to content

Commit c43c6d9

Browse files
authored
fix: pass metadata to pagers (#470)
Closes #469
1 parent f49bc3f commit c43c6d9

File tree

6 files changed

+80
-12
lines changed

6 files changed

+80
-12
lines changed

gapic/schema/wrappers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,11 @@ def has_lro(self) -> bool:
832832
"""Return whether the service has a long-running method."""
833833
return any([m.lro for m in self.methods.values()])
834834

835+
@property
836+
def has_pagers(self) -> bool:
837+
"""Return whether the service has paged methods."""
838+
return any(m.paged_result_field for m in self.methods.values())
839+
835840
@property
836841
def host(self) -> str:
837842
"""Return the hostname for this service, if specified.

gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ class {{ service.async_client_name }}:
244244
method=rpc,
245245
request=request,
246246
response=response,
247+
metadata=metadata,
247248
)
248249
{%- endif %}
249250
{%- if not method.void %}

gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
374374
method=rpc,
375375
request=request,
376376
response=response,
377+
metadata=metadata,
377378
)
378379
{%- endif %}
379380
{%- if not method.void %}

gapic/templates/%namespace/%name_%version/%sub/services/%service/pagers.py.j2

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
{# This lives within the loop in order to ensure that this template
77
is empty if there are no paged methods.
88
-#}
9-
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable
9+
from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple
1010

1111
{% filter sort_lines -%}
1212
{% for method in service.methods.values() | selectattr('paged_result_field') -%}
@@ -35,10 +35,11 @@ class {{ method.name }}Pager:
3535
the most recent response is retained, and thus used for attribute lookup.
3636
"""
3737
def __init__(self,
38-
method: Callable[[{{ method.input.ident }}],
39-
{{ method.output.ident }}],
38+
method: Callable[..., {{ method.output.ident }}],
4039
request: {{ method.input.ident }},
41-
response: {{ method.output.ident }}):
40+
response: {{ method.output.ident }},
41+
*,
42+
metadata: Sequence[Tuple[str, str]] = ()):
4243
"""Instantiate the pager.
4344

4445
Args:
@@ -48,10 +49,13 @@ class {{ method.name }}Pager:
4849
The initial request object.
4950
response (:class:`{{ method.output.ident.sphinx }}`):
5051
The initial response object.
52+
metadata (Sequence[Tuple[str, str]]): Strings which should be
53+
sent along with the request as metadata.
5154
"""
5255
self._method = method
5356
self._request = {{ method.input.ident }}(request)
5457
self._response = response
58+
self._metadata = metadata
5559

5660
def __getattr__(self, name: str) -> Any:
5761
return getattr(self._response, name)
@@ -61,7 +65,7 @@ class {{ method.name }}Pager:
6165
yield self._response
6266
while self._response.next_page_token:
6367
self._request.page_token = self._response.next_page_token
64-
self._response = self._method(self._request)
68+
self._response = self._method(self._request, metadata=self._metadata)
6569
yield self._response
6670

6771
def __iter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'Iterable') }}:
@@ -90,10 +94,11 @@ class {{ method.name }}AsyncPager:
9094
the most recent response is retained, and thus used for attribute lookup.
9195
"""
9296
def __init__(self,
93-
method: Callable[[{{ method.input.ident }}],
94-
Awaitable[{{ method.output.ident }}]],
97+
method: Callable[..., Awaitable[{{ method.output.ident }}]],
9598
request: {{ method.input.ident }},
96-
response: {{ method.output.ident }}):
99+
response: {{ method.output.ident }},
100+
*,
101+
metadata: Sequence[Tuple[str, str]] = ()):
97102
"""Instantiate the pager.
98103

99104
Args:
@@ -103,10 +108,13 @@ class {{ method.name }}AsyncPager:
103108
The initial request object.
104109
response (:class:`{{ method.output.ident.sphinx }}`):
105110
The initial response object.
111+
metadata (Sequence[Tuple[str, str]]): Strings which should be
112+
sent along with the request as metadata.
106113
"""
107114
self._method = method
108115
self._request = {{ method.input.ident }}(request)
109116
self._response = response
117+
self._metadata = metadata
110118

111119
def __getattr__(self, name: str) -> Any:
112120
return getattr(self._response, name)
@@ -116,7 +124,7 @@ class {{ method.name }}AsyncPager:
116124
yield self._response
117125
while self._response.next_page_token:
118126
self._request.page_token = self._response.next_page_token
119-
self._response = await self._method(self._request)
127+
self._response = await self._method(self._request, metadata=self._metadata)
120128
yield self._response
121129

122130
def __aiter__(self) -> {{ method.paged_result_field.ident | replace('Sequence', 'AsyncIterable') }}:

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ from google.api_core import future
2727
from google.api_core import operations_v1
2828
from google.longrunning import operations_pb2
2929
{% endif -%}
30+
{% if service.has_pagers -%}
31+
from google.api_core import gapic_v1
32+
{% endif -%}
3033
{% for method in service.methods.values() -%}
3134
{% for ref_type in method.ref_types
3235
if not ((ref_type.ident.python_import.package == ('google', 'api_core') and ref_type.ident.python_import.module == 'operation')
@@ -695,9 +698,24 @@ def test_{{ method.name|snake_case }}_pager():
695698
),
696699
RuntimeError,
697700
)
698-
results = [i for i in client.{{ method.name|snake_case }}(
699-
request={},
700-
)]
701+
702+
metadata = ()
703+
{% if method.field_headers -%}
704+
metadata = tuple(metadata) + (
705+
gapic_v1.routing_header.to_grpc_metadata((
706+
{%- for field_header in method.field_headers %}
707+
{%- if not method.client_streaming %}
708+
('{{ field_header }}', ''),
709+
{%- endif %}
710+
{%- endfor %}
711+
)),
712+
)
713+
{% endif -%}
714+
pager = client.{{ method.name|snake_case }}(request={})
715+
716+
assert pager._metadata == metadata
717+
718+
results = [i for i in pager]
701719
assert len(results) == 6
702720
assert all(isinstance(i, {{ method.paged_result_field.message.ident }})
703721
for i in results)

tests/unit/schema/wrappers/test_service.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,38 @@ def test_service_any_streaming():
260260

261261
assert service.any_client_streaming == client
262262
assert service.any_server_streaming == server
263+
264+
265+
def test_has_pagers():
266+
paged = make_field(name='foos', message=make_message('Foo'), repeated=True)
267+
input_msg = make_message(
268+
name='ListFoosRequest',
269+
fields=(
270+
make_field(name='parent', type=9), # str
271+
make_field(name='page_size', type=5), # int
272+
make_field(name='page_token', type=9), # str
273+
),
274+
)
275+
output_msg = make_message(
276+
name='ListFoosResponse',
277+
fields=(
278+
paged,
279+
make_field(name='next_page_token', type=9), # str
280+
),
281+
)
282+
method = make_method(
283+
'ListFoos',
284+
input_message=input_msg,
285+
output_message=output_msg,
286+
)
287+
288+
service = make_service(name="Fooer", methods=(method,),)
289+
assert service.has_pagers
290+
291+
other_service = make_service(
292+
name="Unfooer",
293+
methods=(
294+
get_method("Unfoo", "foo.bar.UnfooReq", "foo.bar.UnFooResp"),
295+
),
296+
)
297+
assert not other_service.has_pagers

0 commit comments

Comments
 (0)