Skip to content

fix: add tags for sagemaker #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TrainingStep(Task):
Creates a Task State to execute a `SageMaker Training Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html>`_. The TrainingStep will also create a model by default, and the model shares the same name as the training job.
"""

def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, **kwargs):
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
Expand All @@ -52,6 +52,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
self.estimator = estimator
self.job_name = job_name
Expand Down Expand Up @@ -84,6 +85,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
if 'S3Operations' in parameters:
del parameters['S3Operations']

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
super(TrainingStep, self).__init__(state_id, **kwargs)

Expand Down Expand Up @@ -111,7 +115,7 @@ class TransformStep(Task):
Creates a Task State to execute a `SageMaker Transform Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
"""

def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, **kwargs):
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
Expand All @@ -131,6 +135,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
experiment_config (dict, optional): Specify the experiment config for the transform. (Default: None)
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True)
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync'
Expand Down Expand Up @@ -165,6 +170,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
super(TransformStep, self).__init__(state_id, **kwargs)

Expand All @@ -175,13 +183,14 @@ class ModelStep(Task):
Creates a Task State to `create a model in SageMaker <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateModel.html>`_.
"""

def __init__(self, state_id, model, model_name=None, instance_type=None, **kwargs):
def __init__(self, state_id, model, model_name=None, instance_type=None, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. This parameter is typically required when the estimator used is not an `Amazon built-in algorithm <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`_.
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
if isinstance(model, FrameworkModel):
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image=model.image)
Expand All @@ -203,6 +212,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, **kwarg
if 'S3Operations' in parameters:
del parameters['S3Operations']

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'

Expand Down Expand Up @@ -265,6 +277,7 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to use for the endpoint. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
update (bool, optional): Boolean flag set to `True` if the endpoint must be updated. Set to `False` if a new endpoint must be created. (default: False)
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""

parameters = {
Expand All @@ -291,7 +304,7 @@ class TuningStep(Task):
Creates a Task State to execute a SageMaker HyperParameterTuning Job.
"""

def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **kwargs):
def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
Expand All @@ -313,6 +326,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
where each instance is a different channel of training data.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync'
Expand All @@ -327,6 +341,9 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, **
if 'S3Operations' in parameters:
del parameters['S3Operations']

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters

super(TuningStep, self).__init__(state_id, **kwargs)
44 changes: 30 additions & 14 deletions tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
EXECUTION_ROLE = 'execution-role'
PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1'
TENSORFLOW_IMAGE = '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2'
DEFAULT_TAGS = {'Purpose': 'unittests'}
DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}]

@pytest.fixture
def pca_estimator():
Expand Down Expand Up @@ -163,7 +165,9 @@ def test_training_step_creation(pca_estimator):
'ExperimentName': 'pca_experiment',
'TrialName': 'pca_trial',
'TrialComponentDisplayName': 'Training'
})
},
tags=DEFAULT_TAGS,
)
assert step.to_dict() == {
'Type': 'Task',
'Parameters': {
Expand Down Expand Up @@ -195,7 +199,8 @@ def test_training_step_creation(pca_estimator):
'TrialName': 'pca_trial',
'TrialComponentDisplayName': 'Training'
},
'TrainingJobName': 'TrainingJob'
'TrainingJobName': 'TrainingJob',
'Tags': DEFAULT_TAGS_LIST
},
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
'End': True
Expand Down Expand Up @@ -318,7 +323,8 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
estimator=tensorflow_estimator,
data={'train': 's3://sagemaker/train'},
job_name='tensorflow-job',
mini_batch_size=1024
mini_batch_size=1024,
tags=DEFAULT_TAGS,
)

assert step.to_dict() == {
Expand Down Expand Up @@ -364,7 +370,9 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
'sagemaker_region': '"us-east-1"',
'sagemaker_submit_directory': '"s3://sagemaker/source"'
},
'TrainingJobName': 'tensorflow-job'
'TrainingJobName': 'tensorflow-job',
'Tags': DEFAULT_TAGS_LIST

},
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
'End': True
Expand All @@ -380,7 +388,8 @@ def test_transform_step_creation(pca_transformer):
'ExperimentName': 'pca_experiment',
'TrialName': 'pca_trial',
'TrialComponentDisplayName': 'Transform'
}
},
tags=DEFAULT_TAGS,
)
assert step.to_dict() == {
'Type': 'Task',
Expand All @@ -406,7 +415,8 @@ def test_transform_step_creation(pca_transformer):
'ExperimentName': 'pca_experiment',
'TrialName': 'pca_trial',
'TrialComponentDisplayName': 'Transform'
}
},
'Tags': DEFAULT_TAGS_LIST
},
'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
'End': True
Expand Down Expand Up @@ -465,7 +475,7 @@ def test_get_expected_model_with_framework_estimator(tensorflow_estimator):
}

def test_model_step_creation(pca_model):
step = ModelStep('Create model', model=pca_model, model_name='pca-model')
step = ModelStep('Create model', model=pca_model, model_name='pca-model', tags=DEFAULT_TAGS)
assert step.to_dict() == {
'Type': 'Task',
'Parameters': {
Expand All @@ -475,7 +485,8 @@ def test_model_step_creation(pca_model):
'Environment': {},
'Image': pca_model.image,
'ModelDataUrl': pca_model.model_data
}
},
'Tags': DEFAULT_TAGS_LIST
},
'Resource': 'arn:aws:states:::sagemaker:createModel',
'End': True
Expand All @@ -491,7 +502,9 @@ def test_endpoint_config_step_creation(pca_model):
model_name='pca-model',
initial_instance_count=1,
instance_type='ml.p2.xlarge',
data_capture_config=data_capture_config)
data_capture_config=data_capture_config,
tags=DEFAULT_TAGS,
)
assert step.to_dict() == {
'Type': 'Task',
'Parameters': {
Expand All @@ -514,30 +527,33 @@ def test_endpoint_config_step_creation(pca_model):
'CsvContentTypes': ['text/csv'],
'JsonContentTypes': ['application/json']
}
}
},
'Tags': DEFAULT_TAGS_LIST
},
'Resource': 'arn:aws:states:::sagemaker:createEndpointConfig',
'End': True
}

def test_endpoint_step_creation(pca_model):
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig')
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', tags=DEFAULT_TAGS)
assert step.to_dict() == {
'Type': 'Task',
'Parameters': {
'EndpointConfigName': 'MyEndpointConfig',
'EndpointName': 'MyEndPoint'
'EndpointName': 'MyEndPoint',
'Tags': DEFAULT_TAGS_LIST
},
'Resource': 'arn:aws:states:::sagemaker:createEndpoint',
'End': True
}

step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True)
step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True, tags=DEFAULT_TAGS)
assert step.to_dict() == {
'Type': 'Task',
'Parameters': {
'EndpointConfigName': 'MyEndpointConfig',
'EndpointName': 'MyEndPoint'
'EndpointName': 'MyEndPoint',
'Tags': DEFAULT_TAGS_LIST
},
'Resource': 'arn:aws:states:::sagemaker:updateEndpoint',
'End': True
Expand Down