Skip to content

Commit 1ea8346

Browse files
authored
feat: Support placeholders for TuningStep (#173)
1 parent 2091850 commit 1ea8346

File tree

3 files changed

+340
-16
lines changed

3 files changed

+340
-16
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
465465
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
466466
where each instance is a different channel of training data.
467467
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
468-
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
468+
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
469+
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
470+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
471+
469472
"""
470473
if wait_for_completion:
471474
"""
@@ -483,19 +486,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
483486
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
484487
SageMakerApi.CreateHyperParameterTuningJob)
485488

486-
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
489+
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
487490

488491
if job_name is not None:
489-
parameters['HyperParameterTuningJobName'] = job_name
492+
tuning_parameters['HyperParameterTuningJobName'] = job_name
490493

491-
if 'S3Operations' in parameters:
492-
del parameters['S3Operations']
494+
if 'S3Operations' in tuning_parameters:
495+
del tuning_parameters['S3Operations']
493496

494497
if tags:
495-
parameters['Tags'] = tags_dict_to_kv_list(tags)
498+
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
496499

497-
kwargs[Field.Parameters.value] = parameters
500+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
501+
# Update tuning parameters with input parameters
502+
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])
498503

504+
kwargs[Field.Parameters.value] = tuning_parameters
499505
super(TuningStep, self).__init__(state_id, **kwargs)
500506

501507

tests/integ/test_sagemaker_steps.py

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf
104104

105105
# Cleanup
106106
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
107-
# End of Cleanup
108107

109108

110109
def test_training_step_with_placeholders(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn):
@@ -193,7 +192,7 @@ def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_a
193192
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
194193
model_name = get_resource_name_from_arn(execution_output.get("ModelArn")).split("/")[1]
195194
delete_sagemaker_model(model_name, sagemaker_session)
196-
# End of Cleanup
195+
197196

198197

199198
def test_model_step_with_placeholders(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn):
@@ -288,7 +287,6 @@ def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
288287

289288
# Cleanup
290289
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
291-
# End of Cleanup
292290

293291

294292
def test_transform_step_with_placeholder(trained_estimator, sfn_client, sfn_role_arn):
@@ -413,7 +411,7 @@ def test_endpoint_config_step(trained_estimator, sfn_client, sagemaker_session,
413411
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
414412
delete_sagemaker_endpoint_config(endpoint_config_name, sagemaker_session)
415413
delete_sagemaker_model(model.name, sagemaker_session)
416-
# End of Cleanup
414+
417415

418416
def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client, sagemaker_session, sfn_role_arn):
419417
# Setup: Create model and endpoint config for trained estimator in SageMaker
@@ -456,7 +454,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
456454
delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
457455
delete_sagemaker_endpoint_config(model.name, sagemaker_session)
458456
delete_sagemaker_model(model.name, sagemaker_session)
459-
# End of Cleanup
457+
460458

461459
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
462460
job_name = generate_job_name()
@@ -507,7 +505,97 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
507505

508506
# Cleanup
509507
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
510-
# End of Cleanup
508+
509+
510+
def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
    """Integration test: TuningStep whose tuning-job config is overridden at execution time.

    Builds a KMeans hyperparameter tuner with static defaults, then supplies a
    ``parameters`` dict wired to ``ExecutionInput`` placeholders so that the
    objective metric/type, resource limits, strategy and early-stopping type are
    all resolved from the execution input rather than the tuner object.
    """
    # Static estimator; several of its tuner settings below are deliberately
    # overridden by placeholders at execution time (e.g. Maximize -> Minimize).
    kmeans = KMeans(
        role=sagemaker_role_arn,
        instance_count=1,
        instance_type=INSTANCE_TYPE,
        k=10
    )

    hyperparameter_ranges = {
        "extra_center_factor": IntegerParameter(4, 10),
        "mini_batch_size": IntegerParameter(10, 100),
        "epochs": IntegerParameter(1, 2),
        "init_method": CategoricalParameter(["kmeans++", "random"]),
    }

    tuner = HyperparameterTuner(
        estimator=kmeans,
        objective_metric_name="test:msd",
        hyperparameter_ranges=hyperparameter_ranges,
        objective_type="Maximize",
        max_jobs=2,
        max_parallel_jobs=1,
    )

    # Placeholder schema: every field that the execution input may override.
    execution_input = ExecutionInput(schema={
        'job_name': str,
        'objective_metric_name': str,
        'objective_type': str,
        'max_jobs': int,
        'max_parallel_jobs': int,
        'early_stopping_type': str,
        'strategy': str,
    })

    # Request-payload overrides, merged into the generated
    # CreateHyperParameterTuningJob parameters by TuningStep.
    parameters = {
        'HyperParameterTuningJobConfig': {
            'HyperParameterTuningJobObjective': {
                'MetricName': execution_input['objective_metric_name'],
                'Type': execution_input['objective_type']
            },
            'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
                               'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
            'Strategy': execution_input['strategy'],
            'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
        },
        'TrainingJobDefinition': {
            'AlgorithmSpecification': {
                'TrainingInputMode': 'File'
            }
        }
    }

    # Build workflow definition
    tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
                             data=record_set_for_hyperparameter_tuning, parameters=parameters)
    tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([tuning_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        job_name = generate_job_name()

        # Concrete values that fill the placeholder schema; note they differ
        # from the tuner's own settings to prove the overrides take effect.
        inputs = {
            'job_name': job_name,
            'objective_metric_name': 'test:msd',
            'objective_type': 'Minimize',
            'max_jobs': 2,
            'max_parallel_jobs': 2,
            'early_stopping_type': 'Off',
            'strategy': 'Bayesian',
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
598+
511599

512600
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
513601
region = boto3.session.Session().region_name
@@ -561,7 +649,6 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
561649

562650
# Cleanup
563651
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
564-
# End of Cleanup
565652

566653

567654
def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,

0 commit comments

Comments
 (0)