Skip to content

Commit 349fc11

Browse files
authored
fix: supplying hyperparameters to training step constructor drops hyperparameters specified in estimator (aws#144)
Hyperparameters can be specified in the estimator object and hyperparameters property. Both of which are taken in the constructor of the TrainingStep class. The current behaviour drops any hyperparameters that were specified in the estimator if the property is set in the TrainingStep constructor. This is undesirable as the estimators often specify algorithm specific hyperparameters out of the box that we don't want to drop. This change merges the hyperparameters in the constructor as well as the estimator that is used in TrainingStep. If there are duplicate keys, the hyperparameters specified in the constructor will be used. Closes aws#99, aws#72
1 parent c82bd52 commit 349fc11

File tree

3 files changed

+159
-2
lines changed

3 files changed

+159
-2
lines changed

README.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ to provision and integrate the AWS services separately.
1717
The AWS Step Functions Data Science SDK enables you to do the following.
1818

1919
- Easily construct and run machine learning workflows that use AWS
20-
infrastructure directly in Python
20+
infrastructure directly in Python
2121
- Instantiate common training pipelines
2222
- Create standard machine learning workflows in a Jupyter notebook from
2323
templates

src/stepfunctions/steps/sagemaker.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
from enum import Enum
1618
from stepfunctions.inputs import ExecutionInput, StepInput
1719
from stepfunctions.steps.states import Task
@@ -23,6 +25,8 @@
2325
from sagemaker.model import Model, FrameworkModel
2426
from sagemaker.model_monitor import DataCaptureConfig
2527

28+
logger = logging.getLogger('stepfunctions.sagemaker')
29+
2630
SAGEMAKER_SERVICE_NAME = "sagemaker"
2731

2832

@@ -64,7 +68,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6468
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
6569
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
6670
where each instance is a different channel of training data.
67-
hyperparameters (dict, optional): Specify the hyper parameters for the training. (Default: None)
71+
hyperparameters (dict, optional): Parameters used for training.
72+
Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
73+
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
6874
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
6975
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
7076
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
@@ -104,6 +110,8 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
104110
parameters['TrainingJobName'] = job_name
105111

106112
if hyperparameters is not None:
113+
if estimator.hyperparameters() is not None:
114+
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
107115
parameters['HyperParameters'] = hyperparameters
108116

109117
if experiment_config is not None:
@@ -135,6 +143,26 @@ def get_expected_model(self, model_name=None):
135143
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
136144
return model
137145

146+
"""
147+
Merges the hyperparameters supplied in the TrainingStep constructor with the hyperparameters
148+
specified in the estimator. If there are duplicate entries, the value provided in the constructor
149+
will be used.
150+
"""
151+
152+
def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyperparameters):
153+
"""
154+
Args:
155+
training_step_hyperparameters (dict): Hyperparameters supplied in the training step constructor
156+
estimator_hyperparameters (dict): Hyperparameters specified in the estimator
157+
"""
158+
merged_hyperparameters = estimator_hyperparameters.copy()
159+
for key, value in training_step_hyperparameters.items():
160+
if key in merged_hyperparameters:
161+
logger.info(
162+
f"hyperparameter property: <{key}> with value: <{merged_hyperparameters[key]}> provided in the"
163+
f" estimator will be overwritten with value provided in constructor: <{value}>")
164+
merged_hyperparameters[key] = value
165+
return merged_hyperparameters
138166

139167
class TransformStep(Task):
140168

tests/unit/test_sagemaker_steps.py

+129
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,135 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
482482
'End': True
483483
}
484484

485+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
486+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
487+
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
488+
step = TrainingStep('Training',
489+
estimator=tensorflow_estimator,
490+
data={'train': 's3://sagemaker/train'},
491+
job_name='tensorflow-job',
492+
mini_batch_size=1024,
493+
hyperparameters={
494+
'key': 'value'
495+
}
496+
)
497+
498+
assert step.to_dict() == {
499+
'Type': 'Task',
500+
'Parameters': {
501+
'AlgorithmSpecification': {
502+
'TrainingImage': TENSORFLOW_IMAGE,
503+
'TrainingInputMode': 'File'
504+
},
505+
'InputDataConfig': [
506+
{
507+
'DataSource': {
508+
'S3DataSource': {
509+
'S3DataDistributionType': 'FullyReplicated',
510+
'S3DataType': 'S3Prefix',
511+
'S3Uri': 's3://sagemaker/train'
512+
}
513+
},
514+
'ChannelName': 'train'
515+
}
516+
],
517+
'OutputDataConfig': {
518+
'S3OutputPath': 's3://sagemaker/models'
519+
},
520+
'DebugHookConfig': {
521+
'S3OutputPath': 's3://sagemaker/models/debug'
522+
},
523+
'StoppingCondition': {
524+
'MaxRuntimeInSeconds': 86400
525+
},
526+
'ResourceConfig': {
527+
'InstanceCount': 1,
528+
'InstanceType': 'ml.p2.xlarge',
529+
'VolumeSizeInGB': 30
530+
},
531+
'RoleArn': EXECUTION_ROLE,
532+
'HyperParameters': {
533+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
534+
'evaluation_steps': '100',
535+
'key': 'value',
536+
'sagemaker_container_log_level': '20',
537+
'sagemaker_job_name': '"tensorflow-job"',
538+
'sagemaker_program': '"tf_train.py"',
539+
'sagemaker_region': '"us-east-1"',
540+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
541+
'training_steps': '1000',
542+
},
543+
'TrainingJobName': 'tensorflow-job',
544+
},
545+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
546+
'End': True
547+
}
548+
549+
550+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
551+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
552+
def test_training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
553+
step = TrainingStep('Training',
554+
estimator=tensorflow_estimator,
555+
data={'train': 's3://sagemaker/train'},
556+
job_name='tensorflow-job',
557+
mini_batch_size=1024,
558+
hyperparameters={
559+
# set as 1000 in estimator
560+
'training_steps': '500'
561+
}
562+
)
563+
564+
assert step.to_dict() == {
565+
'Type': 'Task',
566+
'Parameters': {
567+
'AlgorithmSpecification': {
568+
'TrainingImage': TENSORFLOW_IMAGE,
569+
'TrainingInputMode': 'File'
570+
},
571+
'InputDataConfig': [
572+
{
573+
'DataSource': {
574+
'S3DataSource': {
575+
'S3DataDistributionType': 'FullyReplicated',
576+
'S3DataType': 'S3Prefix',
577+
'S3Uri': 's3://sagemaker/train'
578+
}
579+
},
580+
'ChannelName': 'train'
581+
}
582+
],
583+
'OutputDataConfig': {
584+
'S3OutputPath': 's3://sagemaker/models'
585+
},
586+
'DebugHookConfig': {
587+
'S3OutputPath': 's3://sagemaker/models/debug'
588+
},
589+
'StoppingCondition': {
590+
'MaxRuntimeInSeconds': 86400
591+
},
592+
'ResourceConfig': {
593+
'InstanceCount': 1,
594+
'InstanceType': 'ml.p2.xlarge',
595+
'VolumeSizeInGB': 30
596+
},
597+
'RoleArn': EXECUTION_ROLE,
598+
'HyperParameters': {
599+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
600+
'evaluation_steps': '100',
601+
'sagemaker_container_log_level': '20',
602+
'sagemaker_job_name': '"tensorflow-job"',
603+
'sagemaker_program': '"tf_train.py"',
604+
'sagemaker_region': '"us-east-1"',
605+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
606+
'training_steps': '500',
607+
},
608+
'TrainingJobName': 'tensorflow-job',
609+
},
610+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
611+
'End': True
612+
}
613+
485614

486615
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
487616
def test_transform_step_creation(pca_transformer):

0 commit comments

Comments
 (0)