Skip to content

Commit 71ffc5b

Browse files
committed
feat: Support placeholders for TuningStep parameters
1 parent 23878de commit 71ffc5b

File tree

3 files changed

+364
-9
lines changed

3 files changed

+364
-9
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
444444
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
445445
where each instance is a different channel of training data.
446446
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
447-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
447+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
448+
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
449+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
450+
448451
"""
449452
if wait_for_completion:
450453
"""
@@ -462,19 +465,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
462465
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
463466
SageMakerApi.CreateHyperParameterTuningJob)
464467

465-
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
468+
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
466469

467470
if job_name is not None:
468-
parameters['HyperParameterTuningJobName'] = job_name
471+
tuning_parameters['HyperParameterTuningJobName'] = job_name
469472

470-
if 'S3Operations' in parameters:
471-
del parameters['S3Operations']
473+
if 'S3Operations' in tuning_parameters:
474+
del tuning_parameters['S3Operations']
472475

473476
if tags:
474-
parameters['Tags'] = tags_dict_to_kv_list(tags)
477+
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
475478

476-
kwargs[Field.Parameters.value] = parameters
479+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
480+
# Update tuning parameters with input parameters
481+
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])
477482

483+
kwargs[Field.Parameters.value] = tuning_parameters
478484
super(TuningStep, self).__init__(state_id, **kwargs)
479485

480486

tests/integ/test_sagemaker_steps.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
257257
delete_sagemaker_model(model.name, sagemaker_session)
258258
# End of Cleanup
259259

260+
260261
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
261262
job_name = generate_job_name()
262263

@@ -308,6 +309,123 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
308309
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
309310
# End of Cleanup
310311

312+
313+
def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
314+
kmeans = KMeans(
315+
role=sagemaker_role_arn,
316+
instance_count=1,
317+
instance_type=INSTANCE_TYPE,
318+
k=10
319+
)
320+
321+
hyperparameter_ranges = {
322+
"extra_center_factor": IntegerParameter(4, 10),
323+
"mini_batch_size": IntegerParameter(10, 100),
324+
"epochs": IntegerParameter(1, 2),
325+
"init_method": CategoricalParameter(["kmeans++", "random"]),
326+
}
327+
328+
tuner = HyperparameterTuner(
329+
estimator=kmeans,
330+
objective_metric_name="test:msd",
331+
hyperparameter_ranges=hyperparameter_ranges,
332+
objective_type="Maximize",
333+
max_jobs=2,
334+
max_parallel_jobs=1,
335+
)
336+
337+
execution_input = ExecutionInput(schema={
338+
'job_name': str,
339+
'data_input': str,
340+
'objective_metric_name': str,
341+
'objective_type': str,
342+
'max_jobs': int,
343+
'max_parallel_jobs': int,
344+
'early_stopping_type': str,
345+
'strategy': str,
346+
})
347+
348+
parameters = {
349+
'HyperParameterTuningJobConfig': {
350+
'HyperParameterTuningJobObjective': {
351+
'MetricName': execution_input['objective_metric_name'],
352+
'Type': execution_input['objective_type']
353+
},
354+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
355+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
356+
'Strategy': execution_input['strategy'],
357+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
358+
},
359+
'TrainingJobDefinition': {
360+
'AlgorithmSpecification': {
361+
'TrainingInputMode': 'File'
362+
},
363+
'InputDataConfig': execution_input['data_input']
364+
}
365+
}
366+
367+
# Build workflow definition
368+
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
369+
data=record_set_for_hyperparameter_tuning, parameters=parameters)
370+
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
371+
workflow_graph = Chain([tuning_step])
372+
373+
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
374+
# Create workflow and check definition
375+
workflow = create_workflow_and_check_definition(
376+
workflow_graph=workflow_graph,
377+
workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
378+
sfn_client=sfn_client,
379+
sfn_role_arn=sfn_role_arn
380+
)
381+
382+
job_name = generate_job_name()
383+
data_input = [
384+
{
385+
"DataSource": {
386+
"S3DataSource": {
387+
"S3DataType": "ManifestFile",
388+
"S3Uri": "s3://sagemaker-us-east-1-585192044892/sagemaker-record-sets/PCA-2021-10-19-00-19-10-799/.amazon.manifest",
389+
"S3DataDistributionType": "ShardedByS3Key"
390+
}
391+
},
392+
"ChannelName": "train"
393+
},
394+
{
395+
"DataSource": {
396+
"S3DataSource": {
397+
"S3DataType": "ManifestFile",
398+
"S3Uri": "s3://sagemaker-us-east-1-585192044892/sagemaker-record-sets/PCA-2021-10-19-00-19-15-087/.amazon.manifest",
399+
"S3DataDistributionType": "ShardedByS3Key"
400+
}
401+
},
402+
"ChannelName": "test"
403+
}
404+
]
405+
406+
inputs = {
407+
'job_name': job_name,
408+
'data_input': data_input,
409+
'objective_metric_name': 'test:msd',
410+
'objective_type': 'Minimize',
411+
'max_jobs': 2,
412+
'max_parallel_jobs': 2,
413+
'early_stopping_type': 'Off',
414+
'strategy': 'Bayesian',
415+
}
416+
417+
# Execute workflow
418+
execution = workflow.execute(inputs=inputs)
419+
execution_output = execution.get_output(wait=True)
420+
421+
# Check workflow output
422+
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"
423+
424+
# Cleanup
425+
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
426+
# End of Cleanup
427+
428+
311429
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
312430
region = boto3.session.Session().region_name
313431
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)

0 commit comments

Comments
 (0)