Skip to content

Commit c7b8c02

Browse files
committed
Adjust API representation for STRUCT schema fields
1 parent 3932d2a commit c7b8c02

File tree

4 files changed

+66
-17
lines changed

4 files changed

+66
-17
lines changed

google/cloud/bigquery/schema.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def __init__(
107107
self._fields = tuple(fields)
108108

109109
if policy_tags is None:
110-
self._policy_tags = PolicyTagList()
110+
is_record = field_type is not None and field_type.upper() in _STRUCT_TYPES
111+
self._policy_tags = None if is_record else PolicyTagList()
111112
else:
112113
self._policy_tags = policy_tags
113114

@@ -130,18 +131,24 @@ def from_api_repr(cls, api_repr: dict) -> "SchemaField":
130131
Returns:
131132
google.cloud.bigquery.schema.SchemaField: The ``SchemaField`` object.
132133
"""
134+
field_type = api_repr["type"].upper()
135+
133136
# Handle optional properties with default values
134137
mode = api_repr.get("mode", "NULLABLE")
135138
description = api_repr.get("description", _DEFAULT_VALUE)
136139
fields = api_repr.get("fields", ())
137140

141+
policy_tags = PolicyTagList.from_api_repr(api_repr.get("policyTags"))
142+
if policy_tags is None and field_type not in _STRUCT_TYPES:
143+
policy_tags = PolicyTagList()
144+
138145
return cls(
139-
field_type=api_repr["type"].upper(),
146+
field_type=field_type,
140147
fields=[cls.from_api_repr(f) for f in fields],
141148
mode=mode.upper(),
142149
description=description,
143150
name=api_repr["name"],
144-
policy_tags=PolicyTagList.from_api_repr(api_repr.get("policyTags")),
151+
policy_tags=policy_tags,
145152
precision=cls.__get_int(api_repr, "precision"),
146153
scale=cls.__get_int(api_repr, "scale"),
147154
max_length=cls.__get_int(api_repr, "maxLength"),
@@ -205,7 +212,7 @@ def fields(self):
205212

206213
@property
207214
def policy_tags(self):
208-
"""google.cloud.bigquery.schema.PolicyTagList: Policy tag list
215+
"""Optional[google.cloud.bigquery.schema.PolicyTagList]: Policy tag list
209216
definition for this field.
210217
"""
211218
return self._policy_tags
@@ -222,9 +229,10 @@ def to_api_repr(self) -> dict:
222229
# add this to the serialized representation.
223230
if self.field_type.upper() in _STRUCT_TYPES:
224231
answer["fields"] = [f.to_api_repr() for f in self.fields]
225-
226-
# Also include policy tag definition, even if empty.
227-
answer["policyTags"] = self.policy_tags.to_api_repr()
232+
else:
233+
# Explicitly include policy tag definition (we must not do it for RECORD
234+
# fields, because those are not leaf fields).
235+
answer["policyTags"] = self.policy_tags.to_api_repr()
228236

229237
# Done; return the serialized dictionary.
230238
return answer
@@ -247,14 +255,19 @@ def _key(self):
247255
field_type = f"{field_type}({self.precision}, {self.scale})"
248256
else:
249257
field_type = f"{field_type}({self.precision})"
258+
259+
policy_tags = (
260+
() if self._policy_tags is None else tuple(sorted(self._policy_tags.names))
261+
)
262+
250263
return (
251264
self.name,
252265
field_type,
253266
# Mode is always str, if not given it defaults to a str value
254267
self.mode.upper(), # pytype: disable=attribute-error
255268
self.description,
256269
self._fields,
257-
tuple(sorted(self._policy_tags.names)),
270+
policy_tags,
258271
)
259272

260273
def to_standard_sql(self) -> types.StandardSqlField:

tests/unit/test_client.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -1019,8 +1019,18 @@ def test_create_table_w_schema_and_query(self):
10191019
{
10201020
"schema": {
10211021
"fields": [
1022-
{"name": "full_name", "type": "STRING", "mode": "REQUIRED"},
1023-
{"name": "age", "type": "INTEGER", "mode": "REQUIRED"},
1022+
{
1023+
"name": "full_name",
1024+
"type": "STRING",
1025+
"mode": "REQUIRED",
1026+
"policyTags": {"names": []},
1027+
},
1028+
{
1029+
"name": "age",
1030+
"type": "INTEGER",
1031+
"mode": "REQUIRED",
1032+
"policyTags": {"names": []},
1033+
},
10241034
]
10251035
},
10261036
"view": {"query": query},
@@ -1054,8 +1064,18 @@ def test_create_table_w_schema_and_query(self):
10541064
},
10551065
"schema": {
10561066
"fields": [
1057-
{"name": "full_name", "type": "STRING", "mode": "REQUIRED"},
1058-
{"name": "age", "type": "INTEGER", "mode": "REQUIRED"},
1067+
{
1068+
"name": "full_name",
1069+
"type": "STRING",
1070+
"mode": "REQUIRED",
1071+
"policyTags": {"names": []},
1072+
},
1073+
{
1074+
"name": "age",
1075+
"type": "INTEGER",
1076+
"mode": "REQUIRED",
1077+
"policyTags": {"names": []},
1078+
},
10591079
]
10601080
},
10611081
"view": {"query": query, "useLegacySql": False},
@@ -2000,12 +2020,14 @@ def test_update_table(self):
20002020
"type": "STRING",
20012021
"mode": "REQUIRED",
20022022
"description": None,
2023+
"policyTags": {"names": []},
20032024
},
20042025
{
20052026
"name": "age",
20062027
"type": "INTEGER",
20072028
"mode": "REQUIRED",
20082029
"description": "New field description",
2030+
"policyTags": {"names": []},
20092031
},
20102032
]
20112033
},
@@ -2047,12 +2069,14 @@ def test_update_table(self):
20472069
"type": "STRING",
20482070
"mode": "REQUIRED",
20492071
"description": None,
2072+
"policyTags": {"names": []},
20502073
},
20512074
{
20522075
"name": "age",
20532076
"type": "INTEGER",
20542077
"mode": "REQUIRED",
20552078
"description": "New field description",
2079+
"policyTags": {"names": []},
20562080
},
20572081
]
20582082
},
@@ -2173,14 +2197,21 @@ def test_update_table_w_query(self):
21732197
"type": "STRING",
21742198
"mode": "REQUIRED",
21752199
"description": None,
2200+
"policyTags": {"names": []},
21762201
},
21772202
{
21782203
"name": "age",
21792204
"type": "INTEGER",
21802205
"mode": "REQUIRED",
21812206
"description": "this is a column",
2207+
"policyTags": {"names": []},
2208+
},
2209+
{
2210+
"name": "country",
2211+
"type": "STRING",
2212+
"mode": "NULLABLE",
2213+
"policyTags": {"names": []},
21822214
},
2183-
{"name": "country", "type": "STRING", "mode": "NULLABLE"},
21842215
]
21852216
}
21862217
schema = [
@@ -6516,10 +6547,10 @@ def test_load_table_from_dataframe(self):
65166547
assert field["type"] == table_field.field_type
65176548
assert field["mode"] == table_field.mode
65186549
assert len(field.get("fields", [])) == len(table_field.fields)
6550+
assert field["policyTags"]["names"] == []
65196551
# Omit unnecessary fields when they come from getting the table
65206552
# (not passed in via job_config)
65216553
assert "description" not in field
6522-
assert "policyTags" not in field
65236554

65246555
@unittest.skipIf(pandas is None, "Requires `pandas`")
65256556
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")

tests/unit/test_external_config.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,14 @@ def test_to_api_repr_base(self):
7878
ec.schema = [schema.SchemaField("full_name", "STRING", mode="REQUIRED")]
7979

8080
exp_schema = {
81-
"fields": [{"name": "full_name", "type": "STRING", "mode": "REQUIRED"}]
81+
"fields": [
82+
{
83+
"name": "full_name",
84+
"type": "STRING",
85+
"mode": "REQUIRED",
86+
"policyTags": {"names": []},
87+
}
88+
]
8289
}
8390
got_resource = ec.to_api_repr()
8491
exp_resource = {

tests/unit/test_schema.py

-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def test_to_api_repr_with_subfield(self):
117117
"mode": "REQUIRED",
118118
"name": "foo",
119119
"type": record_type,
120-
"policyTags": {"names": []},
121120
},
122121
)
123122

@@ -649,7 +648,6 @@ def test_w_subfields(self):
649648
"name": "phone",
650649
"type": "RECORD",
651650
"mode": "REPEATED",
652-
"policyTags": {"names": []},
653651
"fields": [
654652
{
655653
"name": "type",

0 commit comments

Comments
 (0)