Skip to content

Commit d6120f9

Browse files
authored
Merge branch 'main' into training-step-with-dynamic-output-path
2 parents 63a7d55 + 349fc11 commit d6120f9

File tree

3 files changed

+159
-2
lines changed

3 files changed

+159
-2
lines changed

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ to provision and integrate the AWS services separately.
1717
The AWS Step Functions Data Science SDK enables you to do the following.
1818

1919
- Easily construct and run machine learning workflows that use AWS
20-
infrastructure directly in Python
20+
infrastructure directly in Python
2121
- Instantiate common training pipelines
2222
- Create standard machine learning workflows in a Jupyter notebook from
2323
templates

src/stepfunctions/steps/sagemaker.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
from enum import Enum
1618
import logging
1719

@@ -25,6 +27,8 @@
2527
from sagemaker.model import Model, FrameworkModel
2628
from sagemaker.model_monitor import DataCaptureConfig
2729

30+
logger = logging.getLogger('stepfunctions.sagemaker')
31+
2832
SAGEMAKER_SERVICE_NAME = "sagemaker"
2933
logger = logging.getLogger('stepfunctions.sagemaker')
3034

@@ -66,7 +70,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6670
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
6771
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
6872
where each instance is a different channel of training data.
69-
hyperparameters (dict, optional): Specify the hyper parameters for the training. (Default: None)
73+
hyperparameters (dict, optional): Parameters used for training.
74+
Hyperparameters supplied here are merged with the hyperparameters specified in the estimator.
75+
If there are duplicate entries, the value provided through this argument will be used. (Default: Hyperparameters specified in the estimator.)
7076
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
7177
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
7278
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
@@ -122,6 +128,8 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
122128
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
123129

124130
if hyperparameters is not None:
131+
if estimator.hyperparameters() is not None:
132+
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
125133
parameters['HyperParameters'] = hyperparameters
126134

127135
if experiment_config is not None:
@@ -153,6 +161,26 @@ def get_expected_model(self, model_name=None):
153161
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
154162
return model
155163

def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyperparameters):
    """
    Merges the hyperparameters supplied in the TrainingStep constructor with the hyperparameters
    specified in the estimator. If there are duplicate entries, the value provided in the constructor
    will be used.

    Args:
        training_step_hyperparameters (dict): Hyperparameters supplied in the training step constructor
        estimator_hyperparameters (dict): Hyperparameters specified in the estimator

    Returns:
        dict: A new dictionary holding the estimator entries updated with the constructor entries.
            The merge works on a shallow copy, so `estimator_hyperparameters` itself is not modified.
    """
    # Start from a copy so the estimator's own dictionary is left untouched.
    merged_hyperparameters = estimator_hyperparameters.copy()
    for key, value in training_step_hyperparameters.items():
        if key in merged_hyperparameters:
            # Lazy %-style arguments: the message is only rendered when INFO logging is enabled.
            logger.info(
                "hyperparameter property: <%s> with value: <%s> provided in the"
                " estimator will be overwritten with value provided in constructor: <%s>",
                key, merged_hyperparameters[key], value)
        # The constructor value always wins, whether the key is new or a duplicate.
        merged_hyperparameters[key] = value
    return merged_hyperparameters
156184

157185
class TransformStep(Task):
158186

tests/unit/test_sagemaker_steps.py

+129
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,135 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
559559
'End': True
560560
}
561561

562+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
    """A hyperparameter passed to TrainingStep under a new key appears alongside the others."""
    constructor_hyperparameters = {'key': 'value'}

    step = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters=constructor_hyperparameters,
    )

    # 'key' is the constructor entry; the remaining entries are expected to
    # originate from the tensorflow_estimator fixture.
    expected_hyperparameters = {
        'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
        'evaluation_steps': '100',
        'key': 'value',
        'sagemaker_container_log_level': '20',
        'sagemaker_job_name': '"tensorflow-job"',
        'sagemaker_program': '"tf_train.py"',
        'sagemaker_region': '"us-east-1"',
        'sagemaker_submit_directory': '"s3://sagemaker/source"',
        'training_steps': '1000',
    }
    expected_train_channel = {
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://sagemaker/train',
            },
        },
        'ChannelName': 'train',
    }
    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': TENSORFLOW_IMAGE,
            'TrainingInputMode': 'File',
        },
        'InputDataConfig': [expected_train_channel],
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'DebugHookConfig': {'S3OutputPath': 's3://sagemaker/models/debug'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.p2.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': expected_hyperparameters,
        'TrainingJobName': 'tensorflow-job',
    }
    expected_state = {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }

    assert step.to_dict() == expected_state
625+
626+
627+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
    """A hyperparameter passed to TrainingStep replaces the estimator's value for the same key."""
    constructor_hyperparameters = {
        # set as 1000 in estimator
        'training_steps': '500',
    }

    step = TrainingStep(
        'Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters=constructor_hyperparameters,
    )

    # 'training_steps' must carry the constructor value ('500'); every other
    # entry is expected to originate from the tensorflow_estimator fixture.
    expected_hyperparameters = {
        'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
        'evaluation_steps': '100',
        'sagemaker_container_log_level': '20',
        'sagemaker_job_name': '"tensorflow-job"',
        'sagemaker_program': '"tf_train.py"',
        'sagemaker_region': '"us-east-1"',
        'sagemaker_submit_directory': '"s3://sagemaker/source"',
        'training_steps': '500',
    }
    expected_train_channel = {
        'DataSource': {
            'S3DataSource': {
                'S3DataDistributionType': 'FullyReplicated',
                'S3DataType': 'S3Prefix',
                'S3Uri': 's3://sagemaker/train',
            },
        },
        'ChannelName': 'train',
    }
    expected_parameters = {
        'AlgorithmSpecification': {
            'TrainingImage': TENSORFLOW_IMAGE,
            'TrainingInputMode': 'File',
        },
        'InputDataConfig': [expected_train_channel],
        'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
        'DebugHookConfig': {'S3OutputPath': 's3://sagemaker/models/debug'},
        'StoppingCondition': {'MaxRuntimeInSeconds': 86400},
        'ResourceConfig': {
            'InstanceCount': 1,
            'InstanceType': 'ml.p2.xlarge',
            'VolumeSizeInGB': 30,
        },
        'RoleArn': EXECUTION_ROLE,
        'HyperParameters': expected_hyperparameters,
        'TrainingJobName': 'tensorflow-job',
    }
    expected_state = {
        'Type': 'Task',
        'Parameters': expected_parameters,
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True,
    }

    assert step.to_dict() == expected_state
690+
562691

563692
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
564693
def test_transform_step_creation(pca_transformer):

0 commit comments

Comments
 (0)