Skip to content

Commit f8008e8

Browse files
committed
Update env param directly and add test
1 parent ea3e482 commit f8008e8

File tree

2 files changed

+89
-2
lines changed

2 files changed

+89
-2
lines changed

Diff for: src/stepfunctions/steps/sagemaker.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,20 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
286286
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
287287
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
288288
"""
289-
if isinstance(model, Model):
289+
if isinstance(model, FrameworkModel):
290290
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
291291
if model_name:
292292
parameters['ModelName'] = model_name
293+
elif isinstance(model, Model):
294+
parameters = {
295+
'ExecutionRoleArn': model.role,
296+
'ModelName': model_name or model.name,
297+
'PrimaryContainer': {
298+
'Environment': model.env,
299+
'Image': model.image_uri,
300+
'ModelDataUrl': model.model_data
301+
}
302+
}
293303
else:
294304
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
295305

Diff for: tests/unit/test_sagemaker_steps.py

+78-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ def pca_estimator_with_env():
8181
environment={
8282
'JobName': "job_name",
8383
'ModelName': "model_name"
84-
}
84+
},
85+
subnets=[
86+
'subnet-00000000000000000',
87+
'subnet-00000000000000001'
88+
]
8589
)
8690

8791
pca.set_hyperparameters(
@@ -187,6 +191,31 @@ def pca_model():
187191
)
188192

189193

194+
@pytest.fixture
195+
def pca_model_with_env():
196+
model_data = 's3://sagemaker/models/pca.tar.gz'
197+
return Model(
198+
model_data=model_data,
199+
image_uri=PCA_IMAGE,
200+
role=EXECUTION_ROLE,
201+
name='pca-model',
202+
env={
203+
'JobName': "job_name",
204+
'ModelName': "model_name"
205+
},
206+
vpc_config={
207+
"SecurityGroupIds": ["sg-00000000000000000"],
208+
"Subnets": ["subnet-00000000000000000", "subnet-00000000000000001"]
209+
},
210+
image_config={
211+
"RepositoryAccessMode": "Vpc",
212+
"RepositoryAuthConfig": {
213+
"RepositoryCredentialsProviderArn": "arn"
214+
}
215+
}
216+
)
217+
218+
190219
@pytest.fixture
191220
def pca_transformer(pca_model):
192221
return Transformer(
@@ -855,6 +884,31 @@ def test_get_expected_model(pca_estimator):
855884
}
856885

857886

887+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
888+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
889+
def test_get_expected_model_with_env(pca_estimator_with_env):
890+
training_step = TrainingStep('Training', estimator=pca_estimator_with_env, job_name='TrainingJob')
891+
expected_model = training_step.get_expected_model()
892+
model_step = ModelStep('Create model', model=expected_model, model_name='pca-model')
893+
assert model_step.to_dict() == {
894+
'Type': 'Task',
895+
'Parameters': {
896+
'ExecutionRoleArn': EXECUTION_ROLE,
897+
'ModelName': 'pca-model',
898+
'PrimaryContainer': {
899+
'Environment': {
900+
'JobName': 'job_name',
901+
'ModelName': 'model_name'
902+
},
903+
'Image': expected_model.image_uri,
904+
'ModelDataUrl.$': "$['ModelArtifacts']['S3ModelArtifacts']"
905+
}
906+
},
907+
'Resource': 'arn:aws:states:::sagemaker:createModel',
908+
'End': True
909+
}
910+
911+
858912
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
859913
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
860914
def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
@@ -908,6 +962,29 @@ def test_model_step_creation(pca_model):
908962
}
909963

910964

965+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
966+
def test_model_step_creation_with_env(pca_model_with_env):
967+
step = ModelStep('Create model', model=pca_model_with_env, model_name='pca-model', tags=DEFAULT_TAGS)
968+
assert step.to_dict() == {
969+
'Type': 'Task',
970+
'Parameters': {
971+
'ExecutionRoleArn': EXECUTION_ROLE,
972+
'ModelName': 'pca-model',
973+
'PrimaryContainer': {
974+
'Environment': {
975+
'JobName': 'job_name',
976+
'ModelName': 'model_name'
977+
},
978+
'Image': pca_model_with_env.image_uri,
979+
'ModelDataUrl': pca_model_with_env.model_data
980+
},
981+
'Tags': DEFAULT_TAGS_LIST
982+
},
983+
'Resource': 'arn:aws:states:::sagemaker:createModel',
984+
'End': True
985+
}
986+
987+
911988
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
912989
def test_endpoint_config_step_creation(pca_model):
913990
data_capture_config = DataCaptureConfig(

0 commit comments

Comments
 (0)