@@ -81,7 +81,11 @@ def pca_estimator_with_env():
81
81
environment = {
82
82
'JobName' : "job_name" ,
83
83
'ModelName' : "model_name"
84
- }
84
+ },
85
+ subnets = [
86
+ 'subnet-00000000000000000' ,
87
+ 'subnet-00000000000000001'
88
+ ]
85
89
)
86
90
87
91
pca .set_hyperparameters (
@@ -187,6 +191,31 @@ def pca_model():
187
191
)
188
192
189
193
194
+ @pytest .fixture
195
+ def pca_model_with_env ():
196
+ model_data = 's3://sagemaker/models/pca.tar.gz'
197
+ return Model (
198
+ model_data = model_data ,
199
+ image_uri = PCA_IMAGE ,
200
+ role = EXECUTION_ROLE ,
201
+ name = 'pca-model' ,
202
+ env = {
203
+ 'JobName' : "job_name" ,
204
+ 'ModelName' : "model_name"
205
+ },
206
+ vpc_config = {
207
+ "SecurityGroupIds" : ["sg-00000000000000000" ],
208
+ "Subnets" : ["subnet-00000000000000000" , "subnet-00000000000000001" ]
209
+ },
210
+ image_config = {
211
+ "RepositoryAccessMode" : "Vpc" ,
212
+ "RepositoryAuthConfig" : {
213
+ "RepositoryCredentialsProviderArn" : "arn"
214
+ }
215
+ }
216
+ )
217
+
218
+
190
219
@pytest .fixture
191
220
def pca_transformer (pca_model ):
192
221
return Transformer (
@@ -855,6 +884,31 @@ def test_get_expected_model(pca_estimator):
855
884
}
856
885
857
886
887
+ @patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
888
+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
889
+ def test_get_expected_model_with_env (pca_estimator_with_env ):
890
+ training_step = TrainingStep ('Training' , estimator = pca_estimator_with_env , job_name = 'TrainingJob' )
891
+ expected_model = training_step .get_expected_model ()
892
+ model_step = ModelStep ('Create model' , model = expected_model , model_name = 'pca-model' )
893
+ assert model_step .to_dict () == {
894
+ 'Type' : 'Task' ,
895
+ 'Parameters' : {
896
+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
897
+ 'ModelName' : 'pca-model' ,
898
+ 'PrimaryContainer' : {
899
+ 'Environment' : {
900
+ 'JobName' : 'job_name' ,
901
+ 'ModelName' : 'model_name'
902
+ },
903
+ 'Image' : expected_model .image_uri ,
904
+ 'ModelDataUrl.$' : "$['ModelArtifacts']['S3ModelArtifacts']"
905
+ }
906
+ },
907
+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
908
+ 'End' : True
909
+ }
910
+
911
+
858
912
@patch ('botocore.client.BaseClient._make_api_call' , new = mock_boto_api_call )
859
913
@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
860
914
def test_get_expected_model_with_framework_estimator (tensorflow_estimator ):
@@ -908,6 +962,29 @@ def test_model_step_creation(pca_model):
908
962
}
909
963
910
964
965
+ @patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
966
+ def test_model_step_creation_with_env (pca_model_with_env ):
967
+ step = ModelStep ('Create model' , model = pca_model_with_env , model_name = 'pca-model' , tags = DEFAULT_TAGS )
968
+ assert step .to_dict () == {
969
+ 'Type' : 'Task' ,
970
+ 'Parameters' : {
971
+ 'ExecutionRoleArn' : EXECUTION_ROLE ,
972
+ 'ModelName' : 'pca-model' ,
973
+ 'PrimaryContainer' : {
974
+ 'Environment' : {
975
+ 'JobName' : 'job_name' ,
976
+ 'ModelName' : 'model_name'
977
+ },
978
+ 'Image' : pca_model_with_env .image_uri ,
979
+ 'ModelDataUrl' : pca_model_with_env .model_data
980
+ },
981
+ 'Tags' : DEFAULT_TAGS_LIST
982
+ },
983
+ 'Resource' : 'arn:aws:states:::sagemaker:createModel' ,
984
+ 'End' : True
985
+ }
986
+
987
+
911
988
@patch .object (boto3 .session .Session , 'region_name' , 'us-east-1' )
912
989
def test_endpoint_config_step_creation (pca_model ):
913
990
data_capture_config = DataCaptureConfig (
0 commit comments