Skip to content

Commit 3519265

Browse files
committed
fix test
1 parent 06dc458 commit 3519265

File tree

9 files changed

+24
-35
lines changed

9 files changed

+24
-35
lines changed

integ/test_codegen.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def test_training_and_inference(self):
163163
content_type="text/csv",
164164
accept="application/csv",
165165
)
166-
assert invoke_result.body.payload_part
166+
167+
assert invoke_result["Body"]["PayloadPart"]
167168

168169
def test_intelligent_defaults(self):
169170
os.environ["SAGEMAKER_CORE_ADMIN_CONFIG_OVERRIDE"] = (

src/sagemaker_core/main/resources.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -8389,7 +8389,7 @@ def invoke_with_response_stream(
83898389
inference_component_name: Optional[str] = Unassigned(),
83908390
session: Optional[Session] = None,
83918391
region: Optional[str] = None,
8392-
) -> Optional[InvokeEndpointWithResponseStreamOutput]:
8392+
) -> Optional[object]:
83938393
"""
83948394
Invokes a model at the specified endpoint to return the inference response as a stream.
83958395

@@ -8406,7 +8406,7 @@ def invoke_with_response_stream(
84068406
region: Region name.
84078407

84088408
Returns:
8409-
InvokeEndpointWithResponseStreamOutput
8409+
object
84108410

84118411
Raises:
84128412
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
@@ -8449,8 +8449,7 @@ def invoke_with_response_stream(
84498449
response = client.invoke_endpoint_with_response_stream(**operation_input_args)
84508450
logger.debug(f"Response: {response}")
84518451

8452-
transformed_response = transform(response, "InvokeEndpointWithResponseStreamOutput")
8453-
return InvokeEndpointWithResponseStreamOutput(**transformed_response)
8452+
return response
84548453

84558454

84568455
class EndpointConfig(Base):

src/sagemaker_core/main/shapes.py

-18
Original file line numberDiff line numberDiff line change
@@ -139,24 +139,6 @@ class ResponseStream(Base):
139139
internal_stream_failure: Optional[InternalStreamFailure] = Unassigned()
140140

141141

142-
class InvokeEndpointWithResponseStreamOutput(Base):
143-
"""
144-
InvokeEndpointWithResponseStreamOutput
145-
146-
Attributes
147-
----------------------
148-
body
149-
content_type: The MIME type of the inference returned from the model container.
150-
invoked_production_variant: Identifies the production variant that was invoked.
151-
custom_attributes: Provides additional information in the response about the inference returned by a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to return an ID received in the CustomAttributes header of a request or other metadata that a service endpoint was programmed to produce. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). If the customer wants the custom attribute returned, the model must set the custom attribute to be included on the way back. The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK.
152-
"""
153-
154-
body: ResponseStream
155-
content_type: Optional[str] = Unassigned()
156-
invoked_production_variant: Optional[str] = Unassigned()
157-
custom_attributes: Optional[str] = Unassigned()
158-
159-
160142
class ModelError(Base):
161143
"""
162144
ModelError

src/sagemaker_core/main/utils.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,7 @@ def _serialize_dict(value: Dict) -> dict:
500500
"""
501501
serialized_dict = {}
502502
for k, v in value.items():
503-
serialize_result = serialize(v)
504-
if serialize_result is not None:
503+
if (serialize_result := serialize(v)) is not None:
505504
serialized_dict.update({k: serialize_result})
506505
return serialized_dict
507506

@@ -518,8 +517,7 @@ def _serialize_list(value: List) -> list:
518517
"""
519518
serialized_list = []
520519
for v in value:
521-
serialize_result = serialize(v)
522-
if serialize_result is not None:
520+
if (serialize_result := serialize(v)) is not None:
523521
serialized_list.append(serialize_result)
524522
return serialized_list
525523

@@ -536,8 +534,7 @@ def _serialize_shape(value: Any) -> dict:
536534
"""
537535
serialized_dict = {}
538536
for k, v in vars(value).items():
539-
serialize_result = serialize(v)
540-
if serialize_result is not None:
537+
if (serialize_result := serialize(v)) is not None:
541538
key = snake_to_pascal(k) if is_snake_case(k) else k
542539
serialized_dict.update({key[0].upper() + key[1:]: serialize_result})
543540
return serialized_dict

src/sagemaker_core/tools/additional_operations.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
"operation_name": "InvokeEndpointWithResponseStream",
131131
"resource_name": "Endpoint",
132132
"method_name": "invoke_with_response_stream",
133-
"return_type": "InvokeEndpointWithResponseStreamOutput",
133+
"return_type": "object",
134134
"method_type": "object",
135135
"service_name": "sagemaker-runtime"
136136
}

src/sagemaker_core/tools/resources_codegen.py

+6
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
REFRESH_METHOD_TEMPLATE,
5656
RESOURCE_BASE_CLASS_TEMPLATE,
5757
RETURN_ITERATOR_TEMPLATE,
58+
RETURN_WITHOUT_DESERIALIZATION_TEMPLATE,
5859
SERIALIZE_INPUT_TEMPLATE,
5960
STOP_METHOD_TEMPLATE,
6061
DELETE_METHOD_TEMPLATE,
@@ -1373,6 +1374,11 @@ def generate_method(self, method: Method, resource_attributes: list):
13731374
return_type = f"Optional[{method.return_type}]"
13741375
deserialize_response = DESERIALIZE_RESPONSE_TO_BASIC_TYPE_TEMPLATE
13751376
return_string = f"Returns:\n" f" {method.return_type}\n"
1377+
elif method.return_type == "object":
1378+
# if the return type is object, return the response without deserialization
1379+
return_type = f"Optional[{method.return_type}]"
1380+
deserialize_response = RETURN_WITHOUT_DESERIALIZATION_TEMPLATE
1381+
return_string = f"Returns:\n" f" {method.return_type}\n"
13761382
else:
13771383
if method.return_type == "cls":
13781384
return_type = f'Optional["{method.resource_name}"]'

src/sagemaker_core/tools/templates.py

+3
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,9 @@ def {method_name}(
553553
DESERIALIZE_RESPONSE_TO_BASIC_TYPE_TEMPLATE = """
554554
return list(response.values())[0]"""
555555

556+
RETURN_WITHOUT_DESERIALIZATION_TEMPLATE = """
557+
return response"""
558+
556559
RETURN_ITERATOR_TEMPLATE = """
557560
return ResourceIterator(
558561
{resource_iterator_args}

tst/generated/test_resources.py

+2
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def test_resources(self, session, mock_transform):
295295
operation_info["return_type"]
296296
]
297297
}
298+
elif operation_info["return_type"] == "object":
299+
return_value = {"return_value": None}
298300
else:
299301
return_cls = self.SHAPE_CLASSES_BY_SHAPE_NAME[
300302
operation_info["return_type"]

tst/tools/test_resources_codegen.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,7 @@ def invoke_with_response_stream(
873873
inference_component_name: Optional[str] = Unassigned(),
874874
session: Optional[Session] = None,
875875
region: Optional[str] = None,
876-
) -> Optional[InvokeEndpointWithResponseStreamOutput]:
876+
) -> Optional[object]:
877877
"""
878878
Invokes a model at the specified endpoint to return the inference response as a stream.
879879
@@ -890,7 +890,7 @@ def invoke_with_response_stream(
890890
region: Region name.
891891
892892
Returns:
893-
InvokeEndpointWithResponseStreamOutput
893+
object
894894
895895
Raises:
896896
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
@@ -932,15 +932,14 @@ def invoke_with_response_stream(
932932
response = client.invoke_endpoint_with_response_stream(**operation_input_args)
933933
logger.debug(f"Response: {response}")
934934
935-
transformed_response = transform(response, 'InvokeEndpointWithResponseStreamOutput')
936-
return InvokeEndpointWithResponseStreamOutput(**transformed_response)
935+
return response
937936
'''
938937
method = Method(
939938
**{
940939
"operation_name": "InvokeEndpointWithResponseStream",
941940
"resource_name": "Endpoint",
942941
"method_name": "invoke_with_response_stream",
943-
"return_type": "InvokeEndpointWithResponseStreamOutput",
942+
"return_type": "object",
944943
"method_type": "object",
945944
"service_name": "sagemaker-runtime",
946945
}

0 commit comments

Comments
 (0)