Skip to content

Commit 39640f8

Browse files
committed
feat: Support placeholders for TuningStep parameters
1 parent 8b6d0eb commit 39640f8

File tree

3 files changed

+339
-9
lines changed

3 files changed

+339
-9
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,10 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
454454
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
455455
where each instance is a different channel of training data.
456456
wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True)
457-
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
457+
tags (list[dict] or Placeholder, optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
458+
parameters(dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateHyperParameterTuningJob<https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateHyperParameterTuningJob.html>`_.
459+
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders<https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.
460+
458461
"""
459462
if wait_for_completion:
460463
"""
@@ -472,19 +475,22 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
472475
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
473476
SageMakerApi.CreateHyperParameterTuningJob)
474477

475-
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
478+
tuning_parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
476479

477480
if job_name is not None:
478-
parameters['HyperParameterTuningJobName'] = job_name
481+
tuning_parameters['HyperParameterTuningJobName'] = job_name
479482

480-
if 'S3Operations' in parameters:
481-
del parameters['S3Operations']
483+
if 'S3Operations' in tuning_parameters:
484+
del tuning_parameters['S3Operations']
482485

483486
if tags:
484-
parameters['Tags'] = tags_dict_to_kv_list(tags)
487+
tuning_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)
485488

486-
kwargs[Field.Parameters.value] = parameters
489+
if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
490+
# Update tuning parameters with input parameters
491+
merge_dicts(tuning_parameters, kwargs[Field.Parameters.value])
487492

493+
kwargs[Field.Parameters.value] = tuning_parameters
488494
super(TuningStep, self).__init__(state_id, **kwargs)
489495

490496

tests/integ/test_sagemaker_steps.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def test_create_endpoint_step(trained_estimator, record_set_fixture, sfn_client,
347347
delete_sagemaker_model(model.name, sagemaker_session)
348348
# End of Cleanup
349349

350+
350351
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
351352
job_name = generate_job_name()
352353

@@ -398,6 +399,98 @@ def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning, sagemaker
398399
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
399400
# End of Cleanup
400401

402+
403+
def test_tuning_step_with_placeholders(sfn_client, record_set_for_hyperparameter_tuning, sagemaker_role_arn, sfn_role_arn):
404+
kmeans = KMeans(
405+
role=sagemaker_role_arn,
406+
instance_count=1,
407+
instance_type=INSTANCE_TYPE,
408+
k=10
409+
)
410+
411+
hyperparameter_ranges = {
412+
"extra_center_factor": IntegerParameter(4, 10),
413+
"mini_batch_size": IntegerParameter(10, 100),
414+
"epochs": IntegerParameter(1, 2),
415+
"init_method": CategoricalParameter(["kmeans++", "random"]),
416+
}
417+
418+
tuner = HyperparameterTuner(
419+
estimator=kmeans,
420+
objective_metric_name="test:msd",
421+
hyperparameter_ranges=hyperparameter_ranges,
422+
objective_type="Maximize",
423+
max_jobs=2,
424+
max_parallel_jobs=1,
425+
)
426+
427+
execution_input = ExecutionInput(schema={
428+
'job_name': str,
429+
'objective_metric_name': str,
430+
'objective_type': str,
431+
'max_jobs': int,
432+
'max_parallel_jobs': int,
433+
'early_stopping_type': str,
434+
'strategy': str,
435+
})
436+
437+
parameters = {
438+
'HyperParameterTuningJobConfig': {
439+
'HyperParameterTuningJobObjective': {
440+
'MetricName': execution_input['objective_metric_name'],
441+
'Type': execution_input['objective_type']
442+
},
443+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
444+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
445+
'Strategy': execution_input['strategy'],
446+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
447+
},
448+
'TrainingJobDefinition': {
449+
'AlgorithmSpecification': {
450+
'TrainingInputMode': 'File'
451+
}
452+
}
453+
}
454+
455+
# Build workflow definition
456+
tuning_step = TuningStep('Tuning', tuner=tuner, job_name=execution_input['job_name'],
457+
data=record_set_for_hyperparameter_tuning, parameters=parameters)
458+
tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
459+
workflow_graph = Chain([tuning_step])
460+
461+
with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
462+
# Create workflow and check definition
463+
workflow = create_workflow_and_check_definition(
464+
workflow_graph=workflow_graph,
465+
workflow_name=unique_name_from_base("integ-test-tuning-step-workflow"),
466+
sfn_client=sfn_client,
467+
sfn_role_arn=sfn_role_arn
468+
)
469+
470+
job_name = generate_job_name()
471+
472+
inputs = {
473+
'job_name': job_name,
474+
'objective_metric_name': 'test:msd',
475+
'objective_type': 'Minimize',
476+
'max_jobs': 2,
477+
'max_parallel_jobs': 2,
478+
'early_stopping_type': 'Off',
479+
'strategy': 'Bayesian',
480+
}
481+
482+
# Execute workflow
483+
execution = workflow.execute(inputs=inputs)
484+
execution_output = execution.get_output(wait=True)
485+
486+
# Check workflow output
487+
assert execution_output.get("HyperParameterTuningJobStatus") == "Completed"
488+
489+
# Cleanup
490+
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
491+
# End of Cleanup
492+
493+
401494
def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn):
402495
region = boto3.session.Session().region_name
403496
input_data = 's3://sagemaker-sample-data-{}/processing/census/census-income.csv'.format(region)

tests/unit/test_sagemaker_steps.py

Lines changed: 233 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424
from sagemaker.debugger import Rule, rule_configs, DebuggerHookConfig, CollectionConfig
2525
from sagemaker.sklearn.processing import SKLearnProcessor
2626
from sagemaker.processing import ProcessingInput, ProcessingOutput
27+
from sagemaker.parameter import IntegerParameter, CategoricalParameter
28+
from sagemaker.tuner import HyperparameterTuner
2729

2830
from unittest.mock import MagicMock, patch
2931
from stepfunctions.inputs import ExecutionInput, StepInput
30-
from stepfunctions.steps.fields import Field
3132
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
32-
ProcessingStep
33+
ProcessingStep, TuningStep
3334
from stepfunctions.steps.sagemaker import tuning_config
3435

3536
from tests.unit.utils import mock_boto_api_call
@@ -1412,3 +1413,233 @@ def test_processing_step_creation_with_placeholders(sklearn_processor):
14121413
'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
14131414
'End': True
14141415
}
1416+
1417+
1418+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
1419+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1420+
def test_tuning_step_creation_with_framework(tensorflow_estimator):
1421+
hyperparameter_ranges = {
1422+
"extra_center_factor": IntegerParameter(4, 10),
1423+
"epochs": IntegerParameter(1, 2),
1424+
"init_method": CategoricalParameter(["kmeans++", "random"]),
1425+
}
1426+
1427+
tuner = HyperparameterTuner(
1428+
estimator=tensorflow_estimator,
1429+
objective_metric_name="test:msd",
1430+
hyperparameter_ranges=hyperparameter_ranges,
1431+
objective_type="Minimize",
1432+
max_jobs=2,
1433+
max_parallel_jobs=2,
1434+
)
1435+
1436+
step = TuningStep('Tuning',
1437+
tuner=tuner,
1438+
data={'train': 's3://sagemaker/train'},
1439+
job_name='tensorflow-job',
1440+
tags=DEFAULT_TAGS
1441+
)
1442+
1443+
state_machine_definition = step.to_dict()
1444+
# The sagemaker_job_name is generated - expected name will be taken from the generated definition
1445+
generated_sagemaker_job_name = state_machine_definition['Parameters']['TrainingJobDefinition']\
1446+
['StaticHyperParameters']['sagemaker_job_name']
1447+
expected_definition = {
1448+
'Type': 'Task',
1449+
'Parameters': {
1450+
'HyperParameterTuningJobConfig': {
1451+
'HyperParameterTuningJobObjective': {
1452+
'MetricName': 'test:msd',
1453+
'Type': 'Minimize'
1454+
},
1455+
'ParameterRanges': {
1456+
'CategoricalParameterRanges': [
1457+
{
1458+
'Name': 'init_method',
1459+
'Values': ['"kmeans++"', '"random"']
1460+
}],
1461+
'ContinuousParameterRanges': [],
1462+
'IntegerParameterRanges': [
1463+
{
1464+
'MaxValue': '10',
1465+
'MinValue': '4',
1466+
'Name': 'extra_center_factor',
1467+
'ScalingType': 'Auto'
1468+
},
1469+
{
1470+
'MaxValue': '2',
1471+
'MinValue': '1',
1472+
'Name': 'epochs',
1473+
'ScalingType': 'Auto'
1474+
}
1475+
]
1476+
},
1477+
'ResourceLimits': {'MaxNumberOfTrainingJobs': 2,
1478+
'MaxParallelTrainingJobs': 2},
1479+
'Strategy': 'Bayesian',
1480+
'TrainingJobEarlyStoppingType': 'Off'
1481+
},
1482+
'HyperParameterTuningJobName': 'tensorflow-job',
1483+
'Tags': [{'Key': 'Purpose', 'Value': 'unittests'}],
1484+
'TrainingJobDefinition': {
1485+
'AlgorithmSpecification': {
1486+
'TrainingImage': '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2',
1487+
'TrainingInputMode': 'File'
1488+
},
1489+
'InputDataConfig': [{'ChannelName': 'train',
1490+
'DataSource': {'S3DataSource': {
1491+
'S3DataDistributionType': 'FullyReplicated',
1492+
'S3DataType': 'S3Prefix',
1493+
'S3Uri': 's3://sagemaker/train'}}}],
1494+
'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
1495+
'ResourceConfig': {'InstanceCount': 1,
1496+
'InstanceType': 'ml.p2.xlarge',
1497+
'VolumeSizeInGB': 30},
1498+
'RoleArn': 'execution-role',
1499+
'StaticHyperParameters': {
1500+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
1501+
'evaluation_steps': '100',
1502+
'sagemaker_container_log_level': '20',
1503+
'sagemaker_estimator_class_name': '"TensorFlow"',
1504+
'sagemaker_estimator_module': '"sagemaker.tensorflow.estimator"',
1505+
'sagemaker_job_name': generated_sagemaker_job_name,
1506+
'sagemaker_program': '"tf_train.py"',
1507+
'sagemaker_region': '"us-east-1"',
1508+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
1509+
'training_steps': '1000'},
1510+
'StoppingCondition': {'MaxRuntimeInSeconds': 86400}}},
1511+
'Resource': 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync',
1512+
'End': True
1513+
}
1514+
1515+
assert state_machine_definition == expected_definition
1516+
1517+
1518+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
1519+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1520+
def test_tuning_step_creation_with_placeholders(tensorflow_estimator):
1521+
execution_input = ExecutionInput(schema={
1522+
'data_input': str,
1523+
'tags': list,
1524+
'objective_metric_name': str,
1525+
'hyperparameter_ranges': str,
1526+
'objective_type': str,
1527+
'max_jobs': int,
1528+
'max_parallel_jobs': int,
1529+
'early_stopping_type': str,
1530+
'strategy': str,
1531+
})
1532+
1533+
step_input = StepInput(schema={
1534+
'job_name': str
1535+
})
1536+
1537+
hyperparameter_ranges = {
1538+
"extra_center_factor": IntegerParameter(4, 10),
1539+
"epochs": IntegerParameter(1, 2),
1540+
"init_method": CategoricalParameter(["kmeans++", "random"]),
1541+
}
1542+
1543+
tuner = HyperparameterTuner(
1544+
estimator=tensorflow_estimator,
1545+
objective_metric_name="test:msd",
1546+
hyperparameter_ranges=hyperparameter_ranges,
1547+
objective_type="Minimize",
1548+
max_jobs=2,
1549+
max_parallel_jobs=2,
1550+
)
1551+
1552+
parameters = {
1553+
'HyperParameterTuningJobConfig': {
1554+
'HyperParameterTuningJobObjective': {
1555+
'MetricName': execution_input['objective_metric_name'],
1556+
'Type': execution_input['objective_type']
1557+
},
1558+
'ResourceLimits': {'MaxNumberOfTrainingJobs': execution_input['max_jobs'],
1559+
'MaxParallelTrainingJobs': execution_input['max_parallel_jobs']},
1560+
'Strategy': execution_input['strategy'],
1561+
'TrainingJobEarlyStoppingType': execution_input['early_stopping_type']
1562+
},
1563+
'TrainingJobDefinition': {
1564+
'AlgorithmSpecification': {
1565+
'TrainingInputMode': 'File'
1566+
},
1567+
'HyperParameterRanges': execution_input['hyperparameter_ranges'],
1568+
'InputDataConfig': execution_input['data_input']
1569+
}
1570+
}
1571+
1572+
step = TuningStep('Tuning',
1573+
tuner=tuner,
1574+
data={'train': 's3://sagemaker/train'},
1575+
job_name=step_input['job_name'],
1576+
tags=execution_input['tags'],
1577+
parameters=parameters
1578+
)
1579+
1580+
state_machine_definition = step.to_dict()
1581+
# The sagemaker_job_name is generated - expected name will be taken from the generated definition
1582+
generated_sagemaker_job_name = state_machine_definition['Parameters']['TrainingJobDefinition']['StaticHyperParameters']['sagemaker_job_name']
1583+
expected_parameters = {
1584+
'HyperParameterTuningJobConfig': {
1585+
'HyperParameterTuningJobObjective': {
1586+
'MetricName.$': "$$.Execution.Input['objective_metric_name']",
1587+
'Type.$': "$$.Execution.Input['objective_type']"
1588+
},
1589+
'ParameterRanges': {
1590+
'CategoricalParameterRanges': [
1591+
{
1592+
'Name': 'init_method',
1593+
'Values': ['"kmeans++"', '"random"']
1594+
}],
1595+
'ContinuousParameterRanges': [],
1596+
'IntegerParameterRanges': [
1597+
{
1598+
'MaxValue': '10',
1599+
'MinValue': '4',
1600+
'Name': 'extra_center_factor',
1601+
'ScalingType': 'Auto'
1602+
},
1603+
{
1604+
'MaxValue': '2',
1605+
'MinValue': '1',
1606+
'Name': 'epochs',
1607+
'ScalingType': 'Auto'
1608+
}
1609+
]
1610+
},
1611+
'ResourceLimits': {'MaxNumberOfTrainingJobs.$': "$$.Execution.Input['max_jobs']",
1612+
'MaxParallelTrainingJobs.$': "$$.Execution.Input['max_parallel_jobs']"},
1613+
'Strategy.$': "$$.Execution.Input['strategy']",
1614+
'TrainingJobEarlyStoppingType.$': "$$.Execution.Input['early_stopping_type']"
1615+
},
1616+
'HyperParameterTuningJobName.$': "$['job_name']",
1617+
'Tags.$': "$$.Execution.Input['tags']",
1618+
'TrainingJobDefinition': {
1619+
'AlgorithmSpecification': {
1620+
'TrainingImage': '520713654638.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tensorflow:1.13-gpu-py2',
1621+
'TrainingInputMode': 'File'
1622+
},
1623+
'HyperParameterRanges.$': "$$.Execution.Input['hyperparameter_ranges']",
1624+
'InputDataConfig.$': "$$.Execution.Input['data_input']",
1625+
'OutputDataConfig': {'S3OutputPath': 's3://sagemaker/models'},
1626+
'ResourceConfig': {'InstanceCount': 1,
1627+
'InstanceType': 'ml.p2.xlarge',
1628+
'VolumeSizeInGB': 30},
1629+
'RoleArn': 'execution-role',
1630+
'StaticHyperParameters': {
1631+
'checkpoint_path': '"s3://sagemaker/models/sagemaker-tensorflow/checkpoints"',
1632+
'evaluation_steps': '100',
1633+
'sagemaker_container_log_level': '20',
1634+
'sagemaker_estimator_class_name': '"TensorFlow"',
1635+
'sagemaker_estimator_module': '"sagemaker.tensorflow.estimator"',
1636+
'sagemaker_job_name': generated_sagemaker_job_name,
1637+
'sagemaker_program': '"tf_train.py"',
1638+
'sagemaker_region': '"us-east-1"',
1639+
'sagemaker_submit_directory': '"s3://sagemaker/source"',
1640+
'training_steps': '1000'},
1641+
'StoppingCondition': {'MaxRuntimeInSeconds': 86400}
1642+
}
1643+
}
1644+
1645+
assert state_machine_definition['Parameters'] == expected_parameters

0 commit comments

Comments
 (0)