Skip to content

Commit bc5f729

Browse files
authored
fix: non-string required fields provide correct values (#1108)
In generated unit tests checking behavior of required fields in REST transports, fields are given default values in accordance with the type of the field.
1 parent 6a593f9 commit bc5f729

File tree

7 files changed

+167
-105
lines changed

7 files changed

+167
-105
lines changed

gapic/schema/wrappers.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import json
3333
import re
3434
from itertools import chain
35-
from typing import (Any, cast, Dict, FrozenSet, Iterable, List, Mapping,
35+
from typing import (Any, cast, Dict, FrozenSet, Iterator, Iterable, List, Mapping,
3636
ClassVar, Optional, Sequence, Set, Tuple, Union)
3737
from google.api import annotations_pb2 # type: ignore
3838
from google.api import client_pb2
@@ -757,17 +757,79 @@ class HttpRule:
757757
uri: str
758758
body: Optional[str]
759759

760-
@property
761-
def path_fields(self) -> List[Tuple[str, str]]:
760+
def path_fields(self, method: "~.Method") -> List[Tuple[Field, str, str]]:
762761
"""return list of (name, template) tuples extracted from uri."""
763-
return [(match.group("name"), match.group("template"))
762+
input = method.input
763+
return [(input.get_field(*match.group("name").split(".")), match.group("name"), match.group("template"))
764764
for match in path_template._VARIABLE_RE.finditer(self.uri)]
765765

766-
@property
767-
def sample_request(self) -> str:
766+
def sample_request(self, method: "~.Method") -> str:
768767
"""return json dict for sample request matching the uri template."""
769-
sample = utils.sample_from_path_fields(self.path_fields)
770-
return json.dumps(sample)
768+
769+
def sample_from_path_fields(paths: List[Tuple["wrappers.Field", str, str]]) -> Dict[Any, Any]:
770+
"""Construct a dict for a sample request object from a list of fields
771+
and template patterns.
772+
773+
Args:
774+
paths: a list of tuples, each with a (segmented) name and a pattern.
775+
Returns:
776+
A new nested dict with the templates instantiated.
777+
"""
778+
779+
request: Dict[str, Any] = {}
780+
781+
def _sample_names() -> Iterator[str]:
782+
sample_num: int = 0
783+
while True:
784+
sample_num += 1
785+
yield "sample{}".format(sample_num)
786+
787+
def add_field(obj, path, value):
788+
"""Insert a field into a nested dict and return the (outer) dict.
789+
Keys and sub-dicts are inserted if necessary to create the path.
790+
e.g. if obj, as passed in, is {}, path is "a.b.c", and value is
791+
"hello", obj will be updated to:
792+
{'a':
793+
{'b':
794+
{
795+
'c': 'hello'
796+
}
797+
}
798+
}
799+
800+
Args:
801+
obj: a (possibly) nested dict (parsed json)
802+
path: a segmented field name, e.g. "a.b.c"
803+
where each part is a dict key.
804+
value: the value of the new key.
805+
Returns:
806+
obj, possibly modified
807+
Raises:
808+
AttributeError if the path references a key that is
809+
not a dict.: e.g. path='a.b', obj = {'a':'abc'}
810+
"""
811+
812+
segments = path.split('.')
813+
leaf = segments.pop()
814+
subfield = obj
815+
for segment in segments:
816+
subfield = subfield.setdefault(segment, {})
817+
subfield[leaf] = value
818+
return obj
819+
820+
sample_names = _sample_names()
821+
for field, path, template in paths:
822+
sample_value = re.sub(
823+
r"(\*\*|\*)",
824+
lambda n: next(sample_names),
825+
template or '*'
826+
) if field.type == PrimitiveType.build(str) else field.mock_value_original_type
827+
add_field(request, path, sample_value)
828+
829+
return request
830+
831+
sample = sample_from_path_fields(self.path_fields(method))
832+
return sample
771833

772834
@classmethod
773835
def try_parse_http_rule(cls, http_rule) -> Optional['HttpRule']:

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
170170
{% if method.input.required_fields %}
171171
__{{ method.name | snake_case }}_required_fields_default_values = {
172172
{% for req_field in method.input.required_fields if req_field.is_primitive %}
173-
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.default_value is string %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.field_pb.default_value }}{% endif %},{# default is str #}
173+
"{{ req_field.name | camel_case }}" : {% if req_field.field_pb.type == 9 %}"{{req_field.field_pb.default_value }}"{% else %}{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}{% endif %},{# default is str #}
174174
{% endfor %}
175175
}
176176

gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,11 +1134,11 @@ def test_{{ method_name }}_rest(transport: str = 'rest', request_type={{ method.
11341134
)
11351135

11361136
# send a request that will satisfy transcoding
1137-
request_init = {{ method.http_options[0].sample_request}}
1137+
request_init = {{ method.http_options[0].sample_request(method) }}
11381138
{% for field in method.body_fields.values() %}
11391139
{% if not field.oneof or field.proto3_optional %}
11401140
{# ignore oneof fields that might conflict with sample_request #}
1141-
request_init["{{ field.name }}"] = {{ field.mock_value }}
1141+
request_init["{{ field.name }}"] = {{ field.mock_value_original_type }}
11421142
{% endif %}
11431143
{% endfor %}
11441144
request = request_type(request_init)
@@ -1221,10 +1221,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
12211221

12221222
request_init = {}
12231223
{% for req_field in method.input.required_fields if req_field.is_primitive %}
1224-
{% if req_field.field_pb.default_value is string %}
1224+
{% if req_field.field_pb.type == 9 %}
12251225
request_init["{{ req_field.name }}"] = "{{ req_field.field_pb.default_value }}"
12261226
{% else %}
1227-
request_init["{{ req_field.name }}"] = {{ req_field.field_pb.default_value }}
1227+
request_init["{{ req_field.name }}"] = {{ req_field.type.python_type(req_field.field_pb.default_value or 0) }}
12281228
{% endif %}{# default is str #}
12291229
{% endfor %}
12301230
request = request_type(request_init)
@@ -1324,10 +1324,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
13241324
{% for req_field in method.input.required_fields if req_field.is_primitive %}
13251325
(
13261326
"{{ req_field.name | camel_case }}",
1327-
{% if req_field.field_pb.default_value is string %}
1327+
{% if req_field.field_pb.type == 9 %}
13281328
"{{ req_field.field_pb.default_value }}",
13291329
{% else %}
1330-
{{ req_field.field_pb.default_value }},
1330+
{{ req_field.type.python_type(req_field.field_pb.default_value or 0) }},
13311331
{% endif %}{# default is str #}
13321332
),
13331333
{% endfor %}
@@ -1346,11 +1346,11 @@ def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_typ
13461346
)
13471347

13481348
# send a request that will satisfy transcoding
1349-
request_init = {{ method.http_options[0].sample_request}}
1349+
request_init = {{ method.http_options[0].sample_request(method) }}
13501350
{% for field in method.body_fields.values() %}
13511351
{% if not field.oneof or field.proto3_optional %}
13521352
{# ignore oneof fields that might conflict with sample_request #}
1353-
request_init["{{ field.name }}"] = {{ field.mock_value }}
1353+
request_init["{{ field.name }}"] = {{ field.mock_value_original_type }}
13541354
{% endif %}
13551355
{% endfor %}
13561356
request = request_type(request_init)
@@ -1411,7 +1411,7 @@ def test_{{ method_name }}_rest_flattened(transport: str = 'rest'):
14111411
req.return_value = response_value
14121412

14131413
# get arguments that satisfy an http rule for this method
1414-
sample_request = {{ method.http_options[0].sample_request }}
1414+
sample_request = {{ method.http_options[0].sample_request(method) }}
14151415

14161416
# get truthy value for each flattened field
14171417
mock_args = dict(
@@ -1531,7 +1531,7 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
15311531
return_val.status_code = 200
15321532
req.side_effect = return_values
15331533

1534-
sample_request = {{ method.http_options[0].sample_request }}
1534+
sample_request = {{ method.http_options[0].sample_request(method) }}
15351535
{% for field in method.body_fields.values() %}
15361536
{% if not field.oneof or field.proto3_optional %}
15371537
{# ignore oneof fields that might conflict with sample_request #}

gapic/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from gapic.utils.reserved_names import RESERVED_NAMES
3030
from gapic.utils.rst import rst
3131
from gapic.utils.uri_conv import convert_uri_fieldnames
32-
from gapic.utils.uri_sample import sample_from_path_fields
3332

3433

3534
__all__ = (
@@ -44,7 +43,6 @@
4443
'partition',
4544
'RESERVED_NAMES',
4645
'rst',
47-
'sample_from_path_fields',
4846
'sort_lines',
4947
'to_snake_case',
5048
'to_camel_case',

gapic/utils/uri_sample.py

Lines changed: 0 additions & 78 deletions
This file was deleted.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (C) 2021 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
syntax = "proto3";
16+
17+
package google.fragment;
18+
19+
import "google/api/client.proto";
20+
import "google/api/field_behavior.proto";
21+
import "google/api/annotations.proto";
22+
23+
service RestService {
24+
option (google.api.default_host) = "my.example.com";
25+
26+
rpc MyMethod(MethodRequest) returns (MethodResponse) {
27+
option (google.api.http) = {
28+
get: "/restservice/v1/mass_kg/{mass_kg}/length_cm/{length_cm}"
29+
};
30+
}
31+
}
32+
33+
34+
message MethodRequest {
35+
int32 mass_kg = 1 [(google.api.field_behavior) = REQUIRED];
36+
float length_cm = 2 [(google.api.field_behavior) = REQUIRED];
37+
}
38+
39+
message MethodResponse {
40+
string name = 1;
41+
}

tests/unit/schema/wrappers/test_method.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -470,19 +470,58 @@ def test_method_http_options_generate_sample():
470470
http_rule = http_pb2.HttpRule(
471471
get='/v1/{resource.id=projects/*/regions/*/id/**}/stuff',
472472
)
473-
method = make_method('DoSomething', http_rule=http_rule)
474-
sample = method.http_options[0].sample_request
475-
assert json.loads(sample) == {'resource': {
473+
474+
method = make_method(
475+
'DoSomething',
476+
make_message(
477+
name="Input",
478+
fields=[
479+
make_field(
480+
name="resource",
481+
number=1,
482+
type=11,
483+
message=make_message(
484+
"Resource",
485+
fields=[
486+
make_field(name="id", type=9),
487+
],
488+
),
489+
),
490+
],
491+
),
492+
http_rule=http_rule,
493+
)
494+
sample = method.http_options[0].sample_request(method)
495+
assert sample == {'resource': {
476496
'id': 'projects/sample1/regions/sample2/id/sample3'}}
477497

478498

479499
def test_method_http_options_generate_sample_implicit_template():
480500
http_rule = http_pb2.HttpRule(
481501
get='/v1/{resource.id}/stuff',
482502
)
483-
method = make_method('DoSomething', http_rule=http_rule)
484-
sample = method.http_options[0].sample_request
485-
assert json.loads(sample) == {'resource': {
503+
method = make_method(
504+
'DoSomething',
505+
make_message(
506+
name="Input",
507+
fields=[
508+
make_field(
509+
name="resource",
510+
number=1,
511+
message=make_message(
512+
"Resource",
513+
fields=[
514+
make_field(name="id", type=9),
515+
],
516+
),
517+
),
518+
],
519+
),
520+
http_rule=http_rule,
521+
)
522+
523+
sample = method.http_options[0].sample_request(method)
524+
assert sample == {'resource': {
486525
'id': 'sample1'}}
487526

488527

0 commit comments

Comments
 (0)