Skip to content

fix: supplying hyperparameters to training step constructor drops hyperparameters specified in estimator #144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# permissions and limitations under the License.
from __future__ import absolute_import

import logging

from enum import Enum
from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps.states import Task
Expand All @@ -23,6 +25,8 @@
from sagemaker.model import Model, FrameworkModel
from sagemaker.model_monitor import DataCaptureConfig

logger = logging.getLogger('stepfunctions.sagemaker')

SAGEMAKER_SERVICE_NAME = "sagemaker"


Expand Down Expand Up @@ -64,7 +68,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
where each instance is a different channel of training data.
hyperparameters (dict, optional): Specify the hyper parameters for the training. (Default: None)
hyperparameters (dict, optional): Specify the hyperparameters that are set before the model begins training. If hyperparameters provided here are also specified in the estimator, the value provided here will be used. (Default: Hyperparameters specified in the estimator will be used for training.)
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
Expand Down Expand Up @@ -104,6 +108,8 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
parameters['TrainingJobName'] = job_name

if hyperparameters is not None:
if estimator.hyperparameters() is not None:
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
parameters['HyperParameters'] = hyperparameters

if experiment_config is not None:
Expand Down Expand Up @@ -135,6 +141,26 @@ def get_expected_model(self, model_name=None):
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
return model

def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyperparameters):
    """Merge the hyperparameters supplied in the TrainingStep constructor with the
    hyperparameters specified in the estimator.

    If the same key appears in both, the value supplied in the constructor wins and
    an informational message is logged noting the overwrite.

    Args:
        training_step_hyperparameters (dict): Hyperparameters supplied in the training step constructor
        estimator_hyperparameters (dict): Hyperparameters specified in the estimator

    Returns:
        dict: A new dict containing the estimator hyperparameters overlaid with the
        constructor-supplied ones; neither input dict is mutated.
    """
    # Copy so the estimator's own hyperparameter dict is never mutated.
    merged_hyperparameters = estimator_hyperparameters.copy()
    for key, value in training_step_hyperparameters.items():
        if key in merged_hyperparameters:
            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info(
                "hyperparameter property: <%s> with value: <%s> provided in the"
                " estimator will be overwritten with value provided in constructor: <%s>",
                key, merged_hyperparameters[key], value)
        merged_hyperparameters[key] = value
    return merged_hyperparameters

class TransformStep(Task):

Expand Down
129 changes: 129 additions & 0 deletions tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,135 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
'End': True
}

@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
    """Hyperparameters supplied to the TrainingStep constructor are merged with the
    hyperparameters already set on the estimator: the rendered state's
    'HyperParameters' contains both the estimator's values and the
    constructor-supplied 'key': 'value' entry.
    """
    step = TrainingStep('Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        # 'key' is not set on the estimator, so merging should simply add it.
        hyperparameters={
            'key': 'value'
        }
    )

    # Expected state: estimator-derived hyperparameters plus the extra 'key' entry.
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AlgorithmSpecification': {
                'TrainingImage': TENSORFLOW_IMAGE,
                'TrainingInputMode': 'File'
            },
            'InputDataConfig': [
                {
                    'DataSource': {
                        'S3DataSource': {
                            'S3DataDistributionType': 'FullyReplicated',
                            'S3DataType': 'S3Prefix',
                            'S3Uri': 's3://sagemaker/train'
                        }
                    },
                    'ChannelName': 'train'
                }
            ],
            'OutputDataConfig': {
                'S3OutputPath': 's3://sagemaker/models'
            },
            'DebugHookConfig': {
                'S3OutputPath': 's3://sagemaker/models/debug'
            },
            'StoppingCondition': {
                'MaxRuntimeInSeconds': 86400
            },
            'ResourceConfig': {
                'InstanceCount': 1,
                'InstanceType': 'ml.p2.xlarge',
                'VolumeSizeInGB': 30
            },
            'RoleArn': EXECUTION_ROLE,
            'HyperParameters': {
                'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
                'evaluation_steps': '100',
                'key': 'value',
                'sagemaker_container_log_level': '20',
                'sagemaker_job_name': '"tensorflow-job"',
                'sagemaker_program': '"tf_train.py"',
                'sagemaker_region': '"us-east-1"',
                'sagemaker_submit_directory': '"s3://sagemaker/source"',
                'training_steps': '1000',
            },
            'TrainingJobName': 'tensorflow-job',
        },
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True
    }


@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
    """When the same hyperparameter key appears in both the estimator and the
    TrainingStep constructor, the constructor-supplied value takes precedence
    in the rendered 'HyperParameters'.
    """
    step = TrainingStep('Training',
        estimator=tensorflow_estimator,
        data={'train': 's3://sagemaker/train'},
        job_name='tensorflow-job',
        mini_batch_size=1024,
        hyperparameters={
            # set as 1000 in estimator
            'training_steps': '500'
        }
    )

    # Expected state: all estimator hyperparameters preserved except
    # 'training_steps', which is overridden to the constructor's '500'.
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AlgorithmSpecification': {
                'TrainingImage': TENSORFLOW_IMAGE,
                'TrainingInputMode': 'File'
            },
            'InputDataConfig': [
                {
                    'DataSource': {
                        'S3DataSource': {
                            'S3DataDistributionType': 'FullyReplicated',
                            'S3DataType': 'S3Prefix',
                            'S3Uri': 's3://sagemaker/train'
                        }
                    },
                    'ChannelName': 'train'
                }
            ],
            'OutputDataConfig': {
                'S3OutputPath': 's3://sagemaker/models'
            },
            'DebugHookConfig': {
                'S3OutputPath': 's3://sagemaker/models/debug'
            },
            'StoppingCondition': {
                'MaxRuntimeInSeconds': 86400
            },
            'ResourceConfig': {
                'InstanceCount': 1,
                'InstanceType': 'ml.p2.xlarge',
                'VolumeSizeInGB': 30
            },
            'RoleArn': EXECUTION_ROLE,
            'HyperParameters': {
                'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
                'evaluation_steps': '100',
                'sagemaker_container_log_level': '20',
                'sagemaker_job_name': '"tensorflow-job"',
                'sagemaker_program': '"tf_train.py"',
                'sagemaker_region': '"us-east-1"',
                'sagemaker_submit_directory': '"s3://sagemaker/source"',
                'training_steps': '500',
            },
            'TrainingJobName': 'tensorflow-job',
        },
        'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
        'End': True
    }


@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_transform_step_creation(pca_transformer):
Expand Down