Skip to content

Commit fb378c4

Browse files
authored
Merge pull request #37 from gercograndia/sagemaker_tags
fix: add tags for sagemaker
2 parents 739d663 + ef26b12 commit fb378c4

File tree

2 files changed

+51
-18
lines changed

2 files changed

+51
-18
lines changed

Diff for: src/stepfunctions/steps/sagemaker.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class TrainingStep(Task):
2727
Creates a Task State to execute a `SageMaker Training Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html>`_. The TrainingStep will also create a model by default, and the model shares the same name as the training job.
2828
"""
2929

30-
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, **kwargs):
30+
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
3131
"""
3232
Args:
3333
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -52,6 +52,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
5252
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
5353
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
5454
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
55+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
5556
"""
5657
self.estimator = estimator
5758
self.job_name = job_name
@@ -84,6 +85,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
8485
if 'S3Operations' in parameters:
8586
del parameters['S3Operations']
8687

88+
if tags:
89+
parameters['Tags'] = tags_dict_to_kv_list(tags)
90+
8791
kwargs[Field.Parameters.value] = parameters
8892
super(TrainingStep, self).__init__(state_id, **kwargs)
8993

@@ -111,7 +115,7 @@ class TransformStep(Task):
111115
Creates a Task State to execute a `SageMaker Transform Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
112116
"""
113117

114-
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, **kwargs):
118+
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
115119
"""
116120
Args:
117121
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -131,6 +135,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
131135
split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
132136
experiment_config (dict, optional): Specify the experiment config for the transform. (Default: None)
133137
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True)
138+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
134139
"""
135140
if wait_for_completion:
136141
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync'
@@ -165,6 +170,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
165170
if experiment_config is not None:
166171
parameters['ExperimentConfig'] = experiment_config
167172

173+
if tags:
174+
parameters['Tags'] = tags_dict_to_kv_list(tags)
175+
168176
kwargs[Field.Parameters.value] = parameters
169177
super(TransformStep, self).__init__(state_id, **kwargs)
170178

@@ -175,13 +183,14 @@ class ModelStep(Task):
175183
Creates a Task State to `create a model in SageMaker <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateModel.html>`_.
176184
"""
177185

178-
def __init__(self, state_id, model, model_name=None, instance_type=None, **kwargs):
186+
def __init__(self, state_id, model, model_name=None, instance_type=None, tags=None, **kwargs):
179187
"""
180188
Args:
181189
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
182190
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
183191
model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
184192
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. This parameter is typically required when the estimator used is not an `Amazon built-in algorithm <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`_.
193+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
185194
"""
186195
if isinstance(model, FrameworkModel):
187196
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image=model.image)
@@ -203,6 +212,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, **kwarg
203212
if 'S3Operations' in parameters:
204213
del parameters['S3Operations']
205214

215+
if tags:
216+
parameters['Tags'] = tags_dict_to_kv_list(tags)
217+
206218
kwargs[Field.Parameters.value] = parameters
207219
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'
208220

@@ -265,6 +277,7 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
265277
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to use for the endpoint. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
266278
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
267279
update (bool, optional): Boolean flag set to `True` if the endpoint must be updated. Set to `False` if a new endpoint must be created. (default: False)
280+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
268281
"""
269282

270283
parameters = {
@@ -291,7 +304,7 @@ class TuningStep(Task):
291304
Creates a Task State to execute a SageMaker HyperParameterTuning Job.
292305
"""
293306

294-
def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **kwargs):
307+
def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, tags=None, **kwargs):
295308
"""
296309
Args:
297310
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -313,6 +326,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **
313326
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
314327
where each instance is a different channel of training data.
315328
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
329+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
316330
"""
317331
if wait_for_completion:
318332
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync'
@@ -327,6 +341,9 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **
327341
if 'S3Operations' in parameters:
328342
del parameters['S3Operations']
329343

344+
if tags:
345+
parameters['Tags'] = tags_dict_to_kv_list(tags)
346+
330347
kwargs[Field.Parameters.value] = parameters
331348

332349
super(TuningStep, self).__init__(state_id, **kwargs)

Diff for: tests/unit/test_sagemaker_steps.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
EXECUTION_ROLE = 'execution-role'
3333
PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1'
3434
TENSORFLOW_IMAGE = '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2'
35+
DEFAULT_TAGS = {'Purpose': 'unittests'}
36+
DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}]
3537

3638
@pytest.fixture
3739
def pca_estimator():
@@ -163,7 +165,9 @@ def test_training_step_creation(pca_estimator):
163165
'ExperimentName': 'pca_experiment',
164166
'TrialName': 'pca_trial',
165167
'TrialComponentDisplayName': 'Training'
166-
})
168+
},
169+
tags=DEFAULT_TAGS,
170+
)
167171
assert step.to_dict() == {
168172
'Type': 'Task',
169173
'Parameters': {
@@ -195,7 +199,8 @@ def test_training_step_creation(pca_estimator):
195199
'TrialName': 'pca_trial',
196200
'TrialComponentDisplayName': 'Training'
197201
},
198-
'TrainingJobName': 'TrainingJob'
202+
'TrainingJobName': 'TrainingJob',
203+
'Tags': DEFAULT_TAGS_LIST
199204
},
200205
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
201206
'End': True
@@ -318,7 +323,8 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
318323
estimator=tensorflow_estimator,
319324
data={'train': 's3://sagemaker/train'},
320325
job_name='tensorflow-job',
321-
mini_batch_size=1024
326+
mini_batch_size=1024,
327+
tags=DEFAULT_TAGS,
322328
)
323329

324330
assert step.to_dict() == {
@@ -364,7 +370,9 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
364370
'sagemaker_region': '"us-east-1"',
365371
'sagemaker_submit_directory': '"s3://sagemaker/source"'
366372
},
367-
'TrainingJobName': 'tensorflow-job'
373+
'TrainingJobName': 'tensorflow-job',
374+
'Tags': DEFAULT_TAGS_LIST
375+
368376
},
369377
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
370378
'End': True
@@ -380,7 +388,8 @@ def test_transform_step_creation(pca_transformer):
380388
'ExperimentName': 'pca_experiment',
381389
'TrialName': 'pca_trial',
382390
'TrialComponentDisplayName': 'Transform'
383-
}
391+
},
392+
tags=DEFAULT_TAGS,
384393
)
385394
assert step.to_dict() == {
386395
'Type': 'Task',
@@ -406,7 +415,8 @@ def test_transform_step_creation(pca_transformer):
406415
'ExperimentName': 'pca_experiment',
407416
'TrialName': 'pca_trial',
408417
'TrialComponentDisplayName': 'Transform'
409-
}
418+
},
419+
'Tags': DEFAULT_TAGS_LIST
410420
},
411421
'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
412422
'End': True
@@ -465,7 +475,7 @@ def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
465475
}
466476

467477
def test_model_step_creation(pca_model):
468-
step = ModelStep('Create model', model=pca_model, model_name='pca-model')
478+
step = ModelStep('Create model', model=pca_model, model_name='pca-model', tags=DEFAULT_TAGS)
469479
assert step.to_dict() == {
470480
'Type': 'Task',
471481
'Parameters': {
@@ -475,7 +485,8 @@ def test_model_step_creation(pca_model):
475485
'Environment': {},
476486
'Image': pca_model.image,
477487
'ModelDataUrl': pca_model.model_data
478-
}
488+
},
489+
'Tags': DEFAULT_TAGS_LIST
479490
},
480491
'Resource': 'arn:aws:states:::sagemaker:createModel',
481492
'End': True
@@ -491,7 +502,9 @@ def test_endpoint_config_step_creation(pca_model):
491502
model_name='pca-model',
492503
initial_instance_count=1,
493504
instance_type='ml.p2.xlarge',
494-
data_capture_config=data_capture_config)
505+
data_capture_config=data_capture_config,
506+
tags=DEFAULT_TAGS,
507+
)
495508
assert step.to_dict() == {
496509
'Type': 'Task',
497510
'Parameters': {
@@ -514,30 +527,33 @@ def test_endpoint_config_step_creation(pca_model):
514527
'CsvContentTypes': ['text/csv'],
515528
'JsonContentTypes': ['application/json']
516529
}
517-
}
530+
},
531+
'Tags': DEFAULT_TAGS_LIST
518532
},
519533
'Resource': 'arn:aws:states:::sagemaker:createEndpointConfig',
520534
'End': True
521535
}
522536

523537
def test_endpoint_step_creation(pca_model):
524-
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig')
538+
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', tags=DEFAULT_TAGS)
525539
assert step.to_dict() == {
526540
'Type': 'Task',
527541
'Parameters': {
528542
'EndpointConfigName': 'MyEndpointConfig',
529-
'EndpointName': 'MyEndPoint'
543+
'EndpointName': 'MyEndPoint',
544+
'Tags': DEFAULT_TAGS_LIST
530545
},
531546
'Resource': 'arn:aws:states:::sagemaker:createEndpoint',
532547
'End': True
533548
}
534549

535-
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True)
550+
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True, tags=DEFAULT_TAGS)
536551
assert step.to_dict() == {
537552
'Type': 'Task',
538553
'Parameters': {
539554
'EndpointConfigName': 'MyEndpointConfig',
540-
'EndpointName': 'MyEndPoint'
555+
'EndpointName': 'MyEndPoint',
556+
'Tags': DEFAULT_TAGS_LIST
541557
},
542558
'Resource': 'arn:aws:states:::sagemaker:updateEndpoint',
543559
'End': True

0 commit comments

Comments
 (0)