Skip to content

Commit b4ddbbd

Browse files
authored
Merge branch 'main' into support-placeholders-for-map-state
2 parents 9942f95 + 2091850 commit b4ddbbd

File tree

3 files changed

+265
-81
lines changed

3 files changed

+265
-81
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,13 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
7474
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
7575
* (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator.
7676
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
77-
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
77+
experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None)
7878
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
79-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
79+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
8080
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model
8181
artifacts and output files). If specified, it overrides the `output_path` property of `estimator`.
82+
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html>`_. (Default: None)
83+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
8284
"""
8385
self.estimator = estimator
8486
self.job_name = job_name
@@ -105,44 +107,48 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
105107
data = data.to_jsonpath()
106108

107109
if isinstance(job_name, str):
108-
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
110+
training_parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
109111
else:
110-
parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
112+
training_parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
111113

112114
if estimator.debugger_hook_config != None and estimator.debugger_hook_config is not False:
113-
parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
115+
training_parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
114116

115117
if estimator.rules != None:
116-
parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
118+
training_parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules]
117119

118120
if isinstance(job_name, Placeholder):
119-
parameters['TrainingJobName'] = job_name
121+
training_parameters['TrainingJobName'] = job_name
120122

121123
if output_data_config_path is not None:
122-
parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path
124+
training_parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path
123125

124126
if data is not None and is_data_placeholder:
125127
# Replace the 'S3Uri' key with one that supports JSONpath value.
126128
# Support for uri str only: The list will only contain 1 element
127-
data_uri = parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None)
128-
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
129+
data_uri = training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None)
130+
training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
129131

130132
if hyperparameters is not None:
131133
if not isinstance(hyperparameters, Placeholder):
132134
if estimator.hyperparameters() is not None:
133135
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
134-
parameters['HyperParameters'] = hyperparameters
136+
training_parameters['HyperParameters'] = hyperparameters
135137

136138
if experiment_config is not None:
137-
parameters['ExperimentConfig'] = experiment_config
139+
training_parameters['ExperimentConfig'] = experiment_config
138140

139-
if 'S3Operations' in parameters:
140-
del parameters['S3Operations']
141+
if 'S3Operations' in training_parameters:
142+
del training_parameters['S3Operations']
141143

142144
if tags:
143-
parameters['Tags'] = tags_dict_to_kv_list(tags)
145+
training_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
144146

145-
kwargs[Field.Parameters.value] = parameters
147+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
148+
# Update training parameters with input parameters
149+
merge_dicts(training_parameters, kwargs[Field.Parameters.value])
150+
151+
kwargs[Field.Parameters.value] = training_parameters
146152
super(TrainingStep, self).__init__(state_id, **kwargs)
147153

148154
def get_expected_model(self, model_name=None):
@@ -214,7 +220,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
214220
split_type (str or Placeholder): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
215221
experiment_config (dict or Placeholder, optional): Specify the experiment config for the transform. (Default: None)
216222
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True)
217-
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
223+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
218224
input_filter (str or Placeholder): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value ‘$’, representing the entire input. For CSV data, each row is taken as a JSON array, so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. CSV data should follow the RFC format. See Supported JSONPath Operators for a table of supported JSONPath operators. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.features” (default: None).
219225
output_filter (str or Placeholder): A JSONPath to select a portion of the joined/original output to return as the output. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.prediction” (default: None).
220226
join_source (str or Placeholder): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
@@ -296,14 +302,16 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
296302
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
297303
model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
298304
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
299-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
305+
tags (list[dict] or Placeholders, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
306+
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateModel<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html>`_. (Default: None)
307+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
300308
"""
301309
if isinstance(model, FrameworkModel):
302-
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
310+
model_parameters = model_config(model=model, instance_type=instance_type, role=model.role, image_uri=model.image_uri)
303311
if model_name:
304-
parameters['ModelName'] = model_name
312+
model_parameters['ModelName'] = model_name
305313
elif isinstance(model, Model):
306-
parameters = {
314+
model_parameters = {
307315
'ExecutionRoleArn': model.role,
308316
'ModelName': model_name or model.name,
309317
'PrimaryContainer': {
@@ -315,13 +323,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
315323
else:
316324
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
317325

318-
if 'S3Operations' in parameters:
319-
del parameters['S3Operations']
326+
if 'S3Operations' in model_parameters:
327+
del model_parameters['S3Operations']
320328

321329
if tags:
322-
parameters['Tags'] = tags_dict_to_kv_list(tags)
330+
model_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
323331

324-
kwargs[Field.Parameters.value] = parameters
332+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
333+
# Update model parameters with input parameters
334+
merge_dicts(model_parameters, kwargs[Field.Parameters.value])
335+
336+
kwargs[Field.Parameters.value] = model_parameters
325337

326338
"""
327339
Example resource arn: arn:aws:states:::sagemaker:createModel
@@ -351,7 +363,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
351363
data_capture_config (sagemaker.model_monitor.DataCaptureConfig, optional): Specifies
352364
configuration related to Endpoint data capture for use with
353365
Amazon SageMaker Model Monitoring. Default: None.
354-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
366+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
355367
"""
356368
parameters = {
357369
'EndpointConfigName': endpoint_config_name,
@@ -393,9 +405,8 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
393405
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
394406
endpoint_name (str or Placeholder): The name of the endpoint to create. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
395407
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to use for the endpoint. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
396-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
397408
update (bool, optional): Boolean flag set to `True` if endpoint must to be updated. Set to `False` if new endpoint must be created. (default: False)
398-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
409+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
399410
"""
400411

401412
parameters = {
@@ -454,7 +465,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
454465
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
455466
where each instance is a different channel of training data.
456467
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
457-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
468+
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
458469
"""
459470
if wait_for_completion:
460471
"""
@@ -516,7 +527,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
516527
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
517528
The KmsKeyId is applied to all outputs.
518529
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
519-
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
530+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
520531
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
521532
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
522533

0 commit comments

Comments
 (0)