12 | 12 | # permissions and limitations under the License.
13 | 13 | from __future__ import absolute_import
14 | 14 |
   | 15 | +import logging
   | 16 | +
15 | 17 | from enum import Enum
16 | 18 | from stepfunctions.inputs import ExecutionInput, StepInput
17 | 19 | from stepfunctions.steps.states import Task
23 | 25 | from sagemaker.model import Model, FrameworkModel
24 | 26 | from sagemaker.model_monitor import DataCaptureConfig
25 | 27 |
   | 28 | +logger = logging.getLogger('stepfunctions.sagemaker')
   | 29 | +
26 | 30 | SAGEMAKER_SERVICE_NAME = "sagemaker"
27 | 31 |
28 | 32 |
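The module-level logger added above ('stepfunctions.sagemaker') is what the merge helper introduced further down uses to report overridden hyperparameters. As a minimal sketch (not part of this change, and only one of several ways to configure logging), a calling script can surface those INFO-level notices like this:

    import logging

    # Show INFO-level records globally, including the override notices
    # emitted through the 'stepfunctions.sagemaker' logger.
    logging.basicConfig(level=logging.INFO)

    # Or raise the level for this SDK logger only.
    logging.getLogger('stepfunctions.sagemaker').setLevel(logging.INFO)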
@@ -64,7 +68,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
64 | 68 |                 * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
65 | 69 |                   :class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
66 | 70 |                   where each instance is a different channel of training data.
67 |    | -            hyperparameters (dict, optional): Specify the hyper parameters for the training. (Default: None)
   | 71 | +            hyperparameters (dict, optional): Specify the hyperparameters that are set before the model begins training. If hyperparameters provided here are also specified in the estimator, the value provided here will be used. (Default: Hyperparameters specified in the estimator will be used for training.)
68 | 72 |             mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
69 | 73 |             experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
70 | 74 |             wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True)
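To make the documented default and override behavior concrete, here is a minimal usage sketch. It assumes the SageMaker Python SDK v2 Estimator API; the image URI, role, and S3 paths are placeholders rather than values from this change. Hyperparameters passed to the TrainingStep constructor take precedence over matching keys set on the estimator, while estimator-only keys are kept.

    from sagemaker.estimator import Estimator
    from stepfunctions.steps import TrainingStep

    # Hypothetical estimator; the image URI, role, and S3 locations are placeholders.
    estimator = Estimator(
        image_uri='<training-image-uri>',
        role='<sagemaker-execution-role-arn>',
        instance_count=1,
        instance_type='ml.m5.xlarge',
        output_path='s3://<bucket>/output',
    )
    estimator.set_hyperparameters(epochs=10, learning_rate=0.1)

    # 'epochs' below overrides the estimator's value; 'learning_rate'
    # is taken from the estimator unchanged.
    train_step = TrainingStep(
        'Train Model',
        estimator=estimator,
        job_name='my-training-job',
        data='s3://<bucket>/train',
        hyperparameters={'epochs': 20},
    )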
@@ -104,11 +108,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
104 | 108 |         parameters['TrainingJobName'] = job_name
105 | 109 |
106 | 110 |         if hyperparameters is not None:
107 |     | -            merged_hyperparameters = {}
108 | 111 |             if estimator.hyperparameters() is not None:
109 |     | -                merged_hyperparameters.update(estimator.hyperparameters())
110 |     | -            merged_hyperparameters.update(hyperparameters)
111 |     | -            parameters['HyperParameters'] = merged_hyperparameters
    | 112 | +                hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
    | 113 | +            parameters['HyperParameters'] = hyperparameters
112 | 114 |
113 | 115 |         if experiment_config is not None:
114 | 116 |             parameters['ExperimentConfig'] = experiment_config
@@ -139,6 +141,26 @@ def get_expected_model(self, model_name=None):
139 | 141 |         model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
140 | 142 |         return model
141 | 143 |
    | 144 | +    """
    | 145 | +    Merges the hyperparameters supplied in the TrainingStep constructor with the hyperparameters
    | 146 | +    specified in the estimator. If there are duplicate entries, the value provided in the constructor
    | 147 | +    will be used.
    | 148 | +    """
    | 149 | +
    | 150 | +    def __merge_hyperparameters(self, training_step_hyperparameters, estimator_hyperparameters):
    | 151 | +        """
    | 152 | +        Args:
    | 153 | +            training_step_hyperparameters (dict): Hyperparameters supplied in the training step constructor
    | 154 | +            estimator_hyperparameters (dict): Hyperparameters specified in the estimator
    | 155 | +        """
    | 156 | +        merged_hyperparameters = estimator_hyperparameters.copy()
    | 157 | +        for key, value in training_step_hyperparameters.items():
    | 158 | +            if key in merged_hyperparameters:
    | 159 | +                logger.info(
    | 160 | +                    f"hyperparameter property: <{key}> with value: <{merged_hyperparameters[key]}> provided in the"
    | 161 | +                    f" estimator will be overwritten with value provided in constructor: <{value}>")
    | 162 | +            merged_hyperparameters[key] = value
    | 163 | +        return merged_hyperparameters
142 | 164 |
143 | 165 | class TransformStep(Task):
144 | 166 |
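The helper's precedence rule is plain copy-then-overwrite on dictionaries. This standalone sketch (the sample keys and values are invented for illustration) reproduces what __merge_hyperparameters does outside the class:

    estimator_hps = {'epochs': '10', 'learning_rate': '0.1'}
    constructor_hps = {'epochs': '20'}

    # Start from the estimator's values, then let constructor values win on collisions.
    merged = estimator_hps.copy()
    merged.update(constructor_hps)

    assert merged == {'epochs': '20', 'learning_rate': '0.1'}

The precedence itself is unchanged from the old inline code in __init__; the refactor mainly isolates the merge and adds the INFO-level notice whenever a constructor value overwrites an estimator value.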