diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 15ff206..dba0b64 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -27,7 +27,7 @@ class TrainingStep(Task): Creates a Task State to execute a `SageMaker Training Job `_. The TrainingStep will also create a model by default, and the model shares the same name as the training job. """ - def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, **kwargs): + def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -52,6 +52,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator. experiment_config (dict, optional): Specify the experiment config for the training. (Default: None) wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True) + tags (list[dict], optional): `List to tags `_ to associate with the resource. """ self.estimator = estimator self.job_name = job_name @@ -84,6 +85,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non if 'S3Operations' in parameters: del parameters['S3Operations'] + if tags: + parameters['Tags'] = tags_dict_to_kv_list(tags) + kwargs[Field.Parameters.value] = parameters super(TrainingStep, self).__init__(state_id, **kwargs) @@ -111,7 +115,7 @@ class TransformStep(Task): Creates a Task State to execute a `SageMaker Transform Job `_. """ - def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, **kwargs): + def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -131,6 +135,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'. experiment_config (dict, optional): Specify the experiment config for the transform. (Default: None) wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True) + tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync' @@ -165,6 +170,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= if experiment_config is not None: parameters['ExperimentConfig'] = experiment_config + if tags: + parameters['Tags'] = tags_dict_to_kv_list(tags) + kwargs[Field.Parameters.value] = parameters super(TransformStep, self).__init__(state_id, **kwargs) @@ -175,13 +183,14 @@ class ModelStep(Task): Creates a Task State to `create a model in SageMaker `_. """ - def __init__(self, state_id, model, model_name=None, instance_type=None, **kwargs): + def __init__(self, state_id, model, model_name=None, instance_type=None, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here. model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. This parameter is typically required when the estimator used is not an `Amazon built-in algorithm `_. + tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if isinstance(model, FrameworkModel): parameters = model_config(model=model, instance_type=instance_type, role=model.role, image=model.image) @@ -203,6 +212,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, **kwarg if 'S3Operations' in parameters: del parameters['S3Operations'] + if tags: + parameters['Tags'] = tags_dict_to_kv_list(tags) + kwargs[Field.Parameters.value] = parameters kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel' @@ -265,6 +277,7 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd endpoint_config_name (str or Placeholder): The name of the endpoint configuration to use for the endpoint. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution. tags (list[dict], optional): `List to tags `_ to associate with the resource. update (bool, optional): Boolean flag set to `True` if endpoint must to be updated. Set to `False` if new endpoint must be created. (default: False) + tags (list[dict], optional): `List to tags `_ to associate with the resource. """ parameters = { @@ -291,7 +304,7 @@ class TuningStep(Task): Creates a Task State to execute a SageMaker HyperParameterTuning Job. """ - def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **kwargs): + def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -313,6 +326,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ** :class:`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data. wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True) + tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync' @@ -327,6 +341,9 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ** if 'S3Operations' in parameters: del parameters['S3Operations'] + if tags: + parameters['Tags'] = tags_dict_to_kv_list(tags) + kwargs[Field.Parameters.value] = parameters super(TuningStep, self).__init__(state_id, **kwargs) diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index 98e8372..c7ca95e 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -32,6 +32,8 @@ EXECUTION_ROLE = 'execution-role' PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1' TENSORFLOW_IMAGE = '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2' +DEFAULT_TAGS = {'Purpose': 'unittests'} +DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}] @pytest.fixture def pca_estimator(): @@ -163,7 +165,9 @@ def test_training_step_creation(pca_estimator): 'ExperimentName': 'pca_experiment', 'TrialName': 'pca_trial', 'TrialComponentDisplayName': 'Training' - }) + }, + tags=DEFAULT_TAGS, + ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -195,7 +199,8 @@ def test_training_step_creation(pca_estimator): 'TrialName': 'pca_trial', 'TrialComponentDisplayName': 'Training' }, - 'TrainingJobName': 'TrainingJob' + 'TrainingJobName': 'TrainingJob', + 'Tags': DEFAULT_TAGS_LIST }, 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', 'End': True @@ -318,7 +323,8 @@ def test_training_step_creation_with_framework(tensorflow_estimator): estimator=tensorflow_estimator, data={'train': 's3://sagemaker/train'}, job_name='tensorflow-job', - mini_batch_size=1024 + mini_batch_size=1024, + tags=DEFAULT_TAGS, ) assert step.to_dict() == { @@ -364,7 +370,9 @@ def test_training_step_creation_with_framework(tensorflow_estimator): 'sagemaker_region': '"us-east-1"', 'sagemaker_submit_directory': '"s3://sagemaker/source"' }, - 'TrainingJobName': 'tensorflow-job' + 'TrainingJobName': 'tensorflow-job', + 'Tags': DEFAULT_TAGS_LIST + }, 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', 'End': True @@ -380,7 +388,8 @@ def test_transform_step_creation(pca_transformer): 'ExperimentName': 'pca_experiment', 'TrialName': 'pca_trial', 'TrialComponentDisplayName': 'Transform' - } + }, + tags=DEFAULT_TAGS, ) assert step.to_dict() == { 'Type': 'Task', @@ -406,7 +415,8 @@ def test_transform_step_creation(pca_transformer): 'ExperimentName': 'pca_experiment', 'TrialName': 'pca_trial', 'TrialComponentDisplayName': 'Transform' - } + }, + 'Tags': DEFAULT_TAGS_LIST }, 'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync', 'End': True @@ -465,7 +475,7 @@ def test_get_expected_model_with_framework_estimator(tensorflow_estimator): } def test_model_step_creation(pca_model): - step = ModelStep('Create model', model=pca_model, model_name='pca-model') + step = ModelStep('Create model', model=pca_model, model_name='pca-model', tags=DEFAULT_TAGS) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -475,7 +485,8 @@ def test_model_step_creation(pca_model): 'Environment': {}, 'Image': pca_model.image, 'ModelDataUrl': pca_model.model_data - } + }, + 'Tags': DEFAULT_TAGS_LIST }, 'Resource': 'arn:aws:states:::sagemaker:createModel', 'End': True @@ -491,7 +502,9 @@ def test_endpoint_config_step_creation(pca_model): model_name='pca-model', initial_instance_count=1, instance_type='ml.p2.xlarge', - data_capture_config=data_capture_config) + data_capture_config=data_capture_config, + tags=DEFAULT_TAGS, + ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -514,30 +527,33 @@ def test_endpoint_config_step_creation(pca_model): 'CsvContentTypes': ['text/csv'], 'JsonContentTypes': ['application/json'] } - } + }, + 'Tags': DEFAULT_TAGS_LIST }, 'Resource': 'arn:aws:states:::sagemaker:createEndpointConfig', 'End': True } def test_endpoint_step_creation(pca_model): - step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig') + step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', tags=DEFAULT_TAGS) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { 'EndpointConfigName': 'MyEndpointConfig', - 'EndpointName': 'MyEndPoint' + 'EndpointName': 'MyEndPoint', + 'Tags': DEFAULT_TAGS_LIST }, 'Resource': 'arn:aws:states:::sagemaker:createEndpoint', 'End': True } - step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True) + step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True, tags=DEFAULT_TAGS) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { 'EndpointConfigName': 'MyEndpointConfig', - 'EndpointName': 'MyEndPoint' + 'EndpointName': 'MyEndPoint', + 'Tags': DEFAULT_TAGS_LIST }, 'Resource': 'arn:aws:states:::sagemaker:updateEndpoint', 'End': True