Skip to content

Commit 3c89095

Browse files
author
czazzi
committed
Adding support for DataProcessing to SageMaker TransformJob
1 parent 97147ab commit 3c89095

File tree

2 files changed

+61
-4
lines changed

2 files changed

+61
-4
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ class TransformStep(Task):
115115
Creates a Task State to execute a `SageMaker Transform Job <https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
116116
"""
117117

118-
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
118+
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, input_filter=None, output_filter=None, join_source=None, **kwargs):
119119
"""
120120
Args:
121121
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
@@ -150,7 +150,10 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
150150
content_type=content_type,
151151
compression_type=compression_type,
152152
split_type=split_type,
153-
job_name=job_name
153+
job_name=job_name,
154+
input_filter=input_filter,
155+
join_source=join_source,
156+
output_filter=output_filter
154157
)
155158
else:
156159
parameters = transform_config(
@@ -159,7 +162,10 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
159162
data_type=data_type,
160163
content_type=content_type,
161164
compression_type=compression_type,
162-
split_type=split_type
165+
split_type=split_type,
166+
input_filter=input_filter,
167+
join_source=join_source,
168+
output_filter=output_filter
163169
)
164170

165171
if isinstance(job_name, (ExecutionInput, StepInput)):
@@ -253,7 +259,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
253259

254260
if isinstance(data_capture_config, DataCaptureConfig):
255261
parameters['DataCaptureConfig'] = data_capture_config._to_request_dict()
256-
262+
257263
if tags:
258264
parameters['Tags'] = tags_dict_to_kv_list(tags)
259265

tests/unit/test_sagemaker_steps.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -422,6 +422,57 @@ def test_transform_step_creation(pca_transformer):
422422
'End': True
423423
}
424424

425+
step_with_optional_fields = TransformStep('Inference',
426+
transformer=pca_transformer,
427+
data='s3://sagemaker/inference',
428+
job_name='transform-job',
429+
model_name='pca-model',
430+
experiment_config={
431+
'ExperimentName': 'pca_experiment',
432+
'TrialName': 'pca_trial',
433+
'TrialComponentDisplayName': 'Transform'
434+
},
435+
tags=DEFAULT_TAGS,
436+
join_source='Input',
437+
output_filter='$[2:]',
438+
input_filter='$[1:]'
439+
)
440+
assert step_with_optional_fields.to_dict() == {
441+
'Type': 'Task',
442+
'Parameters': {
443+
'ModelName': 'pca-model',
444+
'TransformInput': {
445+
'DataSource': {
446+
'S3DataSource': {
447+
'S3DataType': 'S3Prefix',
448+
'S3Uri': 's3://sagemaker/inference'
449+
}
450+
}
451+
},
452+
'TransformOutput': {
453+
'S3OutputPath': 's3://sagemaker/transform-output'
454+
},
455+
'TransformJobName': 'transform-job',
456+
'TransformResources': {
457+
'InstanceCount': 1,
458+
'InstanceType': 'ml.c4.xlarge'
459+
},
460+
'DataProcessing': {
461+
'InputFilter': '$[1:]',
462+
'OutputFilter': '$[2:]',
463+
'JoinSource': 'Input',
464+
},
465+
'ExperimentConfig': {
466+
'ExperimentName': 'pca_experiment',
467+
'TrialName': 'pca_trial',
468+
'TrialComponentDisplayName': 'Transform'
469+
},
470+
'Tags': DEFAULT_TAGS_LIST
471+
},
472+
'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync',
473+
'End': True
474+
}
475+
425476
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
426477
def test_get_expected_model(pca_estimator):
427478
training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')

0 commit comments

Comments
 (0)