diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index dba0b64..051a05a 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -115,7 +115,7 @@ class TransformStep(Task): Creates a Task State to execute a `SageMaker Transform Job `_. """ - def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, input_filter=None, output_filter=None, join_source=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -150,7 +150,10 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= content_type=content_type, compression_type=compression_type, split_type=split_type, - job_name=job_name + job_name=job_name, + input_filter=input_filter, + join_source=join_source, + output_filter=output_filter ) else: parameters = transform_config( @@ -159,7 +162,10 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= data_type=data_type, content_type=content_type, compression_type=compression_type, - split_type=split_type + split_type=split_type, + input_filter=input_filter, + join_source=join_source, + output_filter=output_filter ) if isinstance(job_name, (ExecutionInput, StepInput)): @@ -253,7 +259,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ if isinstance(data_capture_config, DataCaptureConfig): parameters['DataCaptureConfig'] = data_capture_config._to_request_dict() - + if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index c7ca95e..099a1be 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -422,6 +422,57 @@ def test_transform_step_creation(pca_transformer): 'End': True } + step_with_optional_fields = TransformStep('Inference', + transformer=pca_transformer, + data='s3://sagemaker/inference', + job_name='transform-job', + model_name='pca-model', + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Transform' + }, + tags=DEFAULT_TAGS, + join_source='Input', + output_filter='$[2:]', + input_filter='$[1:]' + ) + assert step_with_optional_fields.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'ModelName': 'pca-model', + 'TransformInput': { + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri': 's3://sagemaker/inference' + } + } + }, + 'TransformOutput': { + 'S3OutputPath': 's3://sagemaker/transform-output' + }, + 'TransformJobName': 'transform-job', + 'TransformResources': { + 'InstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge' + }, + 'DataProcessing': { + 'InputFilter': '$[1:]', + 'OutputFilter': '$[2:]', + 'JoinSource': 'Input', + }, + 'ExperimentConfig': { + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Transform' + }, + 'Tags': DEFAULT_TAGS_LIST + }, + 'Resource': 'arn:aws:states:::sagemaker:createTransformJob.sync', + 'End': True + } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_get_expected_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')