Skip to content

Commit efeaf21

Browse files
committed
feedback and update tests
1 parent d672d7d commit efeaf21

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

src/stepfunctions/steps/sagemaker.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
16+
1517
from enum import Enum
1618
from stepfunctions.inputs import ExecutionInput, StepInput
1719
from stepfunctions.steps.states import Task
@@ -23,6 +25,8 @@
2325
from sagemaker.model import Model, FrameworkModel
2426
from sagemaker.model_monitor import DataCaptureConfig
2527

28+
logger = logging.getLogger('stepfunctions.sagemaker')
29+
2630
SAGEMAKER_SERVICE_NAME = "sagemaker"
2731

2832

@@ -64,7 +68,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6468
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
6569
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
6670
where each instance is a different channel of training data.
67-
hyperparameters (dict, optional): Specify the hyper parameters for the training. (Default: None)
71+
hyperparameters (dict, optional): Specify the hyperparameters that are set before the model begins training. If hyperparameters provided are also specified in the estimator, the provided value will used. (Default: Hyperparameters specified in the estimator will be used for training.)
6872
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
6973
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
7074
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
@@ -104,11 +108,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
104108
parameters['TrainingJobName'] = job_name
105109

106110
if hyperparameters is not None:
107-
merged_hyperparameters = {}
108111
if estimator.hyperparameters() is not None:
109-
merged_hyperparameters.update(estimator.hyperparameters())
110-
merged_hyperparameters.update(hyperparameters)
111-
parameters['HyperParameters'] = merged_hyperparameters
112+
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
113+
parameters['HyperParameters'] = hyperparameters
112114

113115
if experiment_config is not None:
114116
parameters['ExperimentConfig'] = experiment_config
@@ -139,6 +141,26 @@ def get_expected_model(self, model_name=None):
139141
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
140142
return model
141143

144+
"""
145+
Merges the hyperparameters supplied in the TrainingStep constructor with the hyperparameters
146+
specified in the estimator. If there are duplicate entries, the value provided in the constructor
147+
will be used.
148+
"""
149+
150+
def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyperparameters):
151+
"""
152+
Args:
153+
training_step_hyperparameters (dict): Hyperparameters supplied in the training step constructor
154+
estimator_hyperparameters (dict): Hyperparameters specified in the estimator
155+
"""
156+
merged_hyperparameters = estimator_hyperparameters.copy()
157+
for key, value in training_step_hyperparameters.items():
158+
if key in merged_hyperparameters:
159+
logger.info(
160+
f"hyperparameter property: <{key}> with value: <{merged_hyperparameters[key]}> provided in the"
161+
f" estimator will be overwritten with value provided in constructor: <{value}>")
162+
merged_hyperparameters[key] = value
163+
return merged_hyperparameters
142164

143165
class TransformStep(Task):
144166

tests/unit/test_sagemaker_steps.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def test_training_step_creation_with_framework(tensorflow_estimator):
484484

485485
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
486486
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
487-
def training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
487+
def test_training_step_merges_hyperparameters_from_constructor_and_estimator(tensorflow_estimator):
488488
step = TrainingStep('Training',
489489
estimator=tensorflow_estimator,
490490
data={'train': 's3://sagemaker/train'},
@@ -549,7 +549,7 @@ def training_step_merges_hyperparameters_from_constructor_and_estimator(tensorfl
549549

550550
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
551551
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
552-
def training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
552+
def test_training_step_uses_constructor_hyperparameters_when_duplicates_supplied_in_estimator(tensorflow_estimator):
553553
step = TrainingStep('Training',
554554
estimator=tensorflow_estimator,
555555
data={'train': 's3://sagemaker/train'},

0 commit comments

Comments
 (0)