16
16
from functools import lru_cache
17
17
from typing import Optional , Any
18
18
19
- from sagemaker_core .tools .constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES , SHAPE_DAG_FILE_PATH
19
+ from sagemaker_core .tools .constants import (
20
+ BASIC_JSON_TYPES_TO_PYTHON_TYPES ,
21
+ SHAPE_DAG_FILE_PATH ,
22
+ SHAPES_WITH_JSON_FIELD_ALIAS ,
23
+ )
20
24
from sagemaker_core .main .utils import (
21
25
reformat_file_with_black ,
22
26
convert_to_snake_case ,
@@ -99,6 +103,11 @@ def get_shapes_dag(self):
99
103
_dag [shape ] = {"type" : "structure" , "members" : []}
100
104
for member , member_attrs in shape_data ["members" ].items ():
101
105
shape_node_member = {"name" : member , "shape" : member_attrs ["shape" ]}
106
+ # Add alias if field name is json, to address the Bug: https://github.com/aws/sagemaker-python-sdk/issues/4944
107
+ if shape in SHAPES_WITH_JSON_FIELD_ALIAS and member == "Json" :
108
+ shape_node_member ["name" ] = "JsonFormat"
109
+ shape_node_member ["alias" ] = "json"
110
+
102
111
member_shape_dict = _all_shapes [member_attrs ["shape" ]]
103
112
shape_node_member ["type" ] = member_shape_dict ["type" ]
104
113
_dag [shape ]["members" ].append (shape_node_member )
@@ -218,6 +227,8 @@ def generate_shape_members(self, shape, required_override=()):
218
227
# bring the required members in front
219
228
ordered_members = {key : members [key ] for key in required_args if key in members }
220
229
ordered_members .update (members )
230
+ field_aliases = {}
231
+
221
232
for member_name , member_attrs in ordered_members .items ():
222
233
member_shape_name = member_attrs ["shape" ]
223
234
if self .combined_shapes [member_shape_name ]:
@@ -234,13 +245,28 @@ def generate_shape_members(self, shape, required_override=()):
234
245
member_type = BASIC_JSON_TYPES_TO_PYTHON_TYPES [member_shape_type ]
235
246
else :
236
247
raise Exception ("The Shape definition mush exist. The Json Data might be corrupt" )
237
- member_name_snake_case = convert_to_snake_case (member_name )
238
- if member_name in required_args :
239
- init_data_body [f"{ member_name_snake_case } " ] = f"{ member_type } "
240
- else :
241
- init_data_body [f"{ member_name_snake_case } " ] = (
242
- f"Optional[{ member_type } ] = Unassigned()"
248
+
249
+ is_required = member_name in required_args
250
+ # Add alias if field name is json, to address the Bug: https://github.com/aws/sagemaker-python-sdk/issues/4944
251
+ if shape in SHAPES_WITH_JSON_FIELD_ALIAS and member_name == "Json" :
252
+ updated_member_name_snake_case = "json_format"
253
+ field_aliases [updated_member_name_snake_case ] = "json"
254
+ init_data_body [f"{ updated_member_name_snake_case } " ] = (
255
+ (
256
+ f"{ member_type } = Field(alias='{ field_aliases [updated_member_name_snake_case ]} ')"
257
+ )
258
+ if is_required
259
+ else f"Optional[{ member_type } ] = Field(default=Unassigned(), alias='json')"
243
260
)
261
+ else :
262
+ member_name_snake_case = convert_to_snake_case (member_name )
263
+ if is_required :
264
+ init_data_body [f"{ member_name_snake_case } " ] = f"{ member_type } "
265
+ else :
266
+ init_data_body [f"{ member_name_snake_case } " ] = (
267
+ f"Optional[{ member_type } ] = Unassigned()"
268
+ )
269
+
244
270
return init_data_body
245
271
246
272
@lru_cache
0 commit comments