Skip to content

Commit 764df67

Browse files
authored
feat: Adds support for Placeholders in TrainingStep to set S3 location for InputDataConfig and OutputDataConfig (#142)
Currently, it is not possible to specify the output path on TrainingStep; it must be defined at the Estimator level, which does not support placeholders. This change makes it possible to pass a placeholder output path in the TrainingStep definition and propagate it dynamically to the Estimator. This change also makes the TrainingStep data parameter compatible with placeholders. There are other feature requests to make other TrainingStep arguments and ProcessingStep arguments compatible with placeholders. They will be addressed in a separate PR, where the implementation could perhaps be done at a higher level to avoid repetition. Closes #98 #97 #80
1 parent b773fea commit 764df67

File tree

4 files changed

+104
-10
lines changed

4 files changed

+104
-10
lines changed

src/stepfunctions/steps/sagemaker.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import logging
1616

1717
from enum import Enum
18-
from stepfunctions.inputs import ExecutionInput, StepInput
18+
19+
from stepfunctions.inputs import Placeholder
1920
from stepfunctions.steps.states import Task
2021
from stepfunctions.steps.fields import Field
2122
from stepfunctions.steps.utils import tags_dict_to_kv_list
@@ -29,7 +30,6 @@
2930

3031
SAGEMAKER_SERVICE_NAME = "sagemaker"
3132

32-
3333
class SageMakerApi(Enum):
3434
CreateTrainingJob = "createTrainingJob"
3535
CreateTransformJob = "createTransformJob"
@@ -47,15 +47,15 @@ class TrainingStep(Task):
4747
Creates a Task State to execute a `SageMaker Training Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html>`_. The TrainingStep will also create a model by default, and the model shares the same name as the training job.
4848
"""
4949

50-
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
50+
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, output_data_config_path=None, **kwargs):
5151
"""
5252
Args:
5353
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
5454
estimator (sagemaker.estimator.EstimatorBase): The estimator for the training step. Can be a `BYO estimator, Framework estimator <https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms.html>`_ or `Amazon built-in algorithm estimator <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`_.
5555
job_name (str or Placeholder): Specify a training job name, this is required for the training job to run. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
5656
data: Information about the training data. Please refer to the ``fit()`` method of the associated estimator, as this can take any of the following forms:
5757
58-
* (str) - The S3 location where training data is saved.
58+
* (str or Placeholder) - The S3 location where training data is saved.
5959
* (dict[str, str] or dict[str, sagemaker.inputs.TrainingInput]) - If using multiple
6060
channels for training data, you can specify a dict mapping channel names to
6161
strings or :func:`~sagemaker.inputs.TrainingInput` objects.
@@ -75,6 +75,8 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
7575
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
7676
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
7777
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
78+
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
79+
artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
7880
"""
7981
self.estimator = estimator
8082
self.job_name = job_name
@@ -94,6 +96,11 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
9496

9597
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
9698
SageMakerApi.CreateTrainingJob)
99+
# Convert `data` Placeholder to a JSONPath string because sagemaker.workflow.airflow.training_config does not
100+
# accept Placeholder in the `input` argument. We will suffix the 'S3Uri' key in `parameters` with ".$" later.
101+
is_data_placeholder = isinstance(data, Placeholder)
102+
if is_data_placeholder:
103+
data = data.to_jsonpath()
97104

98105
if isinstance(job_name, str):
99106
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
@@ -106,9 +113,18 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
106113
if estimator.rules != None:
107114
parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
108115

109-
if isinstance(job_name, (ExecutionInput, StepInput)):
116+
if isinstance(job_name, Placeholder):
110117
parameters['TrainingJobName'] = job_name
111118

119+
if output_data_config_path is not None:
120+
parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path
121+
122+
if data is not None and is_data_placeholder:
123+
# Replace the 'S3Uri' key with one that supports JSONpath value.
124+
# Support for uri str only: The list will only contain 1 element
125+
data_uri = parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None)
126+
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
127+
112128
if hyperparameters is not None:
113129
if estimator.hyperparameters() is not None:
114130
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
@@ -237,7 +253,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
237253
join_source=join_source
238254
)
239255

240-
if isinstance(job_name, (ExecutionInput, StepInput)):
256+
if isinstance(job_name, Placeholder):
241257
parameters['TransformJobName'] = job_name
242258

243259
parameters['ModelName'] = model_name
@@ -506,7 +522,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
506522
else:
507523
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
508524

509-
if isinstance(job_name, (ExecutionInput, StepInput)):
525+
if isinstance(job_name, Placeholder):
510526
parameters['ProcessingJobName'] = job_name
511527

512528
if experiment_config is not None:

src/stepfunctions/steps/states.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from stepfunctions.exceptions import DuplicateStatesInChain
1919
from stepfunctions.steps.fields import Field
20-
from stepfunctions.inputs import ExecutionInput, StepInput
20+
from stepfunctions.inputs import Placeholder, StepInput
2121

2222

2323
logger = logging.getLogger('stepfunctions.states')
@@ -53,7 +53,7 @@ def _replace_placeholders(self, params):
5353
return params
5454
modified_parameters = {}
5555
for k, v in params.items():
56-
if isinstance(v, (ExecutionInput, StepInput)):
56+
if isinstance(v, Placeholder):
5757
modified_key = "{key}.$".format(key=k)
5858
modified_parameters[modified_key] = v.to_jsonpath()
5959
elif isinstance(v, dict):

tests/integ/test_sagemaker_steps.py

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf
104104
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
105105
# End of Cleanup
106106

107+
107108
def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
108109
# Build workflow definition
109110
model_name = generate_job_name()

tests/unit/test_sagemaker_steps.py

+78-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker.processing import ProcessingInput, ProcessingOutput
2727

2828
from unittest.mock import MagicMock, patch
29+
from stepfunctions.inputs import ExecutionInput, StepInput
2930
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
3031
from stepfunctions.steps.sagemaker import tuning_config
3132

@@ -224,6 +225,7 @@ def test_training_step_creation(pca_estimator):
224225
'TrialName': 'pca_trial',
225226
'TrialComponentDisplayName': 'Training'
226227
},
228+
output_data_config_path='s3://sagemaker-us-east-1-111111111111',
227229
tags=DEFAULT_TAGS,
228230
)
229231
assert step.to_dict() == {
@@ -234,7 +236,7 @@ def test_training_step_creation(pca_estimator):
234236
'TrainingInputMode': 'File'
235237
},
236238
'OutputDataConfig': {
237-
'S3OutputPath': 's3://sagemaker/models'
239+
'S3OutputPath': 's3://sagemaker-us-east-1-111111111111'
238240
},
239241
'StoppingCondition': {
240242
'MaxRuntimeInSeconds': 86400
@@ -265,6 +267,81 @@ def test_training_step_creation(pca_estimator):
265267
}
266268

267269

270+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
271+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
272+
def test_training_step_creation_with_placeholders(pca_estimator):
273+
execution_input = ExecutionInput(schema={
274+
'Data': str,
275+
'OutputPath': str,
276+
})
277+
278+
step_input = StepInput(schema={
279+
'JobName': str,
280+
})
281+
282+
step = TrainingStep('Training',
283+
estimator=pca_estimator,
284+
job_name=step_input['JobName'],
285+
data=execution_input['Data'],
286+
output_data_config_path=execution_input['OutputPath'],
287+
experiment_config={
288+
'ExperimentName': 'pca_experiment',
289+
'TrialName': 'pca_trial',
290+
'TrialComponentDisplayName': 'Training'
291+
},
292+
tags=DEFAULT_TAGS,
293+
)
294+
assert step.to_dict() == {
295+
'Type': 'Task',
296+
'Parameters': {
297+
'AlgorithmSpecification': {
298+
'TrainingImage': PCA_IMAGE,
299+
'TrainingInputMode': 'File'
300+
},
301+
'OutputDataConfig': {
302+
'S3OutputPath.$': "$$.Execution.Input['OutputPath']"
303+
},
304+
'StoppingCondition': {
305+
'MaxRuntimeInSeconds': 86400
306+
},
307+
'ResourceConfig': {
308+
'InstanceCount': 1,
309+
'InstanceType': 'ml.c4.xlarge',
310+
'VolumeSizeInGB': 30
311+
},
312+
'RoleArn': EXECUTION_ROLE,
313+
'HyperParameters': {
314+
'feature_dim': '50000',
315+
'num_components': '10',
316+
'subtract_mean': 'True',
317+
'algorithm_mode': 'randomized',
318+
'mini_batch_size': '200'
319+
},
320+
'InputDataConfig': [
321+
{
322+
'ChannelName': 'training',
323+
'DataSource': {
324+
'S3DataSource': {
325+
'S3DataDistributionType': 'FullyReplicated',
326+
'S3DataType': 'S3Prefix',
327+
'S3Uri.$': "$$.Execution.Input['Data']"
328+
}
329+
}
330+
}
331+
],
332+
'ExperimentConfig': {
333+
'ExperimentName': 'pca_experiment',
334+
'TrialName': 'pca_trial',
335+
'TrialComponentDisplayName': 'Training'
336+
},
337+
'TrainingJobName.$': "$['JobName']",
338+
'Tags': DEFAULT_TAGS_LIST
339+
},
340+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
341+
'End': True
342+
}
343+
344+
268345
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
269346
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
270347
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):

0 commit comments

Comments
 (0)