diff --git a/requirements.txt b/requirements.txt index 464d5b1..1f431a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -sagemaker>=1.71.0 +sagemaker>=1.71.0,<2.0.0 boto3>=1.9.213 pyyaml diff --git a/setup.py b/setup.py index 5476bfc..6fe1f3d 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def read_version(): # Declare minimal set for installation required_packages = [ - "sagemaker>=1.42.8", + "sagemaker>=1.42.8,<2.0.0", "boto3>=1.9.213", "pyyaml" ] diff --git a/src/stepfunctions/exceptions.py b/src/stepfunctions/exceptions.py index 7e9a4d7..2e67190 100644 --- a/src/stepfunctions/exceptions.py +++ b/src/stepfunctions/exceptions.py @@ -21,5 +21,9 @@ class MissingRequiredParameter(Exception): pass +class ForbiddenValueParameter(Exception): + pass + + class DuplicateStatesInChain(Exception): - pass \ No newline at end of file + pass diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index ebe76eb..59d9cc7 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -12,8 +12,8 @@ # permissions and limitations under the License. from __future__ import absolute_import +from stepfunctions.steps.states import IntegrationPattern from stepfunctions.steps.states import Task -from stepfunctions.steps.fields import Field class LambdaStep(Task): @@ -22,11 +22,11 @@ class LambdaStep(Task): Creates a Task state to invoke an AWS Lambda function. See `Invoke Lambda with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_callback=False, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RequestResponse, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - wait_for_callback(bool, optional): Boolean value set to `True` if the Task state should wait for callback to resume the operation. 
(default: False) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. comment (str, optional): Human-readable comment or description. (default: None) @@ -35,12 +35,11 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ - if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke.waitForTaskToken' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke' - - super(LambdaStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.WaitForCallback] + self._integration_pattern = integration_pattern + action = "lambda:invoke" + step_name = "Lambda" + super(LambdaStep, self).__init__(state_id, action, step_name, **kwargs) class GlueStartJobRunStep(Task): @@ -49,11 +48,11 @@ class GlueStartJobRunStep(Task): Creates a Task state to run an AWS Glue job. See `Manage AWS Glue Jobs with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RunAJob, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the glue job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the glue job and proceed to the next step. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RunAJob) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. 
(default: 60) heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. comment (str, optional): Human-readable comment or description. (default: None) @@ -62,12 +61,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun' - - super(GlueStartJobRunStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern + super(GlueStartJobRunStep, self).__init__(state_id, "glue:startJobRun", "AWS Glue", **kwargs) class BatchSubmitJobStep(Task): @@ -76,11 +72,11 @@ class BatchSubmitJobStep(Task): Creates a Task State to start an AWS Batch job. See `Manage AWS Batch with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RunAJob, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. 
- wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the batch job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the batch job and proceed to the next step. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RunAJob) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. comment (str, optional): Human-readable comment or description. (default: None) @@ -89,12 +85,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob' - - super(BatchSubmitJobStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern + super(BatchSubmitJobStep, self).__init__(state_id, "batch:submitJob", "AWS Batch", **kwargs) class EcsRunTaskStep(Task): @@ -103,12 +96,11 @@ class EcsRunTaskStep(Task): Creates a Task State to run Amazon ECS or Fargate Tasks. See `Manage Amazon ECS or Fargate Tasks with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RequestResponse, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the ecs job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the ecs job and proceed to the next step. (default: True) - timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. 
(default: 60) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. comment (str, optional): Human-readable comment or description. (default: None) input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$') @@ -116,9 +108,6 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask' - - super(EcsRunTaskStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob, IntegrationPattern.WaitForCallback] + self._integration_pattern = integration_pattern + super(EcsRunTaskStep, self).__init__(state_id, "ecs:runTask", "Amazon ECS/AWS Fargate", **kwargs) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 6321563..a802070 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -12,22 +12,24 @@ # permissions and limitations under the License. from __future__ import absolute_import +from sagemaker.model import Model, FrameworkModel +from sagemaker.model_monitor import DataCaptureConfig +from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config +from stepfunctions.steps.states import IntegrationPattern + from stepfunctions.inputs import ExecutionInput, StepInput -from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field +from stepfunctions.steps.states import Task from stepfunctions.steps.utils import tags_dict_to_kv_list -from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config -from sagemaker.model import Model, FrameworkModel -from sagemaker.model_monitor import DataCaptureConfig class TrainingStep(Task): - """ Creates a Task State to execute a `SageMaker Training Job `_. The TrainingStep will also create a model by default, and the model shares the same name as the training job. 
""" - def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, + experiment_config=None, integration_pattern=IntegrationPattern.RunAJob, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -51,19 +53,18 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non hyperparameters (dict, optional): Specify the hyper parameters for the training. (Default: None) mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator. experiment_config (dict, optional): Specify the experiment config for the training. (Default: None) - wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RunAJob) tags (list[dict], optional): `List to tags `_ to associate with the resource. 
""" self.estimator = estimator self.job_name = job_name - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob' + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern if isinstance(job_name, str): - parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) + parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, + mini_batch_size=mini_batch_size) else: parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size) @@ -89,7 +90,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non parameters['Tags'] = tags_dict_to_kv_list(tags) kwargs[Field.Parameters.value] = parameters - super(TrainingStep, self).__init__(state_id, **kwargs) + super(TrainingStep, self).__init__(state_id, 'sagemaker:createTrainingJob', 'Amazon SageMaker', **kwargs) def get_expected_model(self, model_name=None): """ @@ -110,12 +111,13 @@ def get_expected_model(self, model_name=None): class TransformStep(Task): - """ Creates a Task State to execute a `SageMaker Transform Job `_. 
""" - def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, input_filter=None, output_filter=None, join_source=None, **kwargs): + def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, + compression_type=None, split_type=None, experiment_config=None, integration_pattern=IntegrationPattern.RunAJob, tags=None, + input_filter=None, output_filter=None, join_source=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -134,16 +136,14 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= compression_type (str): Compression type of the input data, if compressed (default: None). Valid values: 'Gzip', None. split_type (str): The record delimiter for the input object (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'. experiment_config (dict, optional): Specify the experiment config for the transform. (Default: None) - wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the transform job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the transform job and proceed to the next step. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RunAJob) tags (list[dict], optional): `List to tags `_ to associate with the resource. 
input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for inference. If you omit the field, it gets the value ‘$’, representing the entire input. For CSV data, each row is taken as a JSON array, so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. CSV data should follow the RFC format. See Supported JSONPath Operators for a table of supported JSONPath operators. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.features” (default: None). output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output. For more information, see the SageMaker API documentation for CreateTransformJob. Some examples: “$[1:]”, “$.prediction” (default: None). join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob' + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern if isinstance(job_name, str): parameters = transform_config( @@ -183,11 +183,10 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= parameters['Tags'] = tags_dict_to_kv_list(tags) kwargs[Field.Parameters.value] = parameters - super(TransformStep, self).__init__(state_id, **kwargs) + super(TransformStep, self).__init__(state_id, 'sagemaker:createTransformJob', 'Amazon SageMaker', **kwargs) class ModelStep(Task): - """ Creates a Task State to `create a model in SageMaker `_. 
""" @@ -216,7 +215,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No } } else: - raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__)) + raise ValueError( + "Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format( + type(model).__name__)) if 'S3Operations' in parameters: del parameters['S3Operations'] @@ -225,18 +226,17 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No parameters['Tags'] = tags_dict_to_kv_list(tags) kwargs[Field.Parameters.value] = parameters - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel' - super(ModelStep, self).__init__(state_id, **kwargs) + super(ModelStep, self).__init__(state_id, 'sagemaker:createModel', 'Amazon SageMaker', **kwargs) class EndpointConfigStep(Task): - """ Creates a Task State to `create an endpoint configuration in SageMaker `_. """ - def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_count, instance_type, variant_name='AllTraffic', data_capture_config=None, tags=None, **kwargs): + def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_count, instance_type, + variant_name='AllTraffic', data_capture_config=None, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. 
@@ -262,18 +262,16 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ if isinstance(data_capture_config, DataCaptureConfig): parameters['DataCaptureConfig'] = data_capture_config._to_request_dict() - + if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig' kwargs[Field.Parameters.value] = parameters - super(EndpointConfigStep, self).__init__(state_id, **kwargs) + super(EndpointConfigStep, self).__init__(state_id, 'sagemaker:createEndpointConfig', 'Amazon SageMaker', **kwargs) class EndpointStep(Task): - """ Creates a Task State to `create `_ or `update `_ an endpoint in SageMaker. """ @@ -297,23 +295,17 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) - if update: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint' - kwargs[Field.Parameters.value] = parameters - super(EndpointStep, self).__init__(state_id, **kwargs) + super(EndpointStep, self).__init__(state_id, f'sagemaker:{"updateEndpoint" if update else "createEndpoint"}', 'Amazon SageMaker', **kwargs) class TuningStep(Task): - """ Creates a Task State to execute a SageMaker HyperParameterTuning Job. """ - def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, tuner, job_name, data, integration_pattern=IntegrationPattern.RunAJob, tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. 
@@ -334,13 +326,11 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of :class:`sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data. - wait_for_completion(bool, optional): Boolean value set to `True` if the Task state should wait for the tuning job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the tuning job and proceed to the next step. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RunAJob) tags (list[dict], optional): `List to tags `_ to associate with the resource. """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob' + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() @@ -355,16 +345,17 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta kwargs[Field.Parameters.value] = parameters - super(TuningStep, self).__init__(state_id, **kwargs) + super(TuningStep, self).__init__(state_id, 'sagemaker:createHyperParameterTuningJob', 'Amazon SageMaker', **kwargs) class ProcessingStep(Task): - """ Creates a Task State to execute a SageMaker Processing Job. 
""" - def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs): + def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, + container_arguments=None, container_entrypoint=None, kms_key_id=None, integration_pattern=IntegrationPattern.RunAJob, + tags=None, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -383,31 +374,34 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key, or alias of a KMS key. The KmsKeyId is applied to all outputs. - wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RunAJob) tags (list[dict], optional): `List to tags `_ to associate with the resource. 
""" - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob' - + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern + if isinstance(job_name, str): - parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) + parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, + container_arguments=container_arguments, + container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, + job_name=job_name) else: - parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) + parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, + container_arguments=container_arguments, + container_entrypoint=container_entrypoint, kms_key_id=kms_key_id) if isinstance(job_name, (ExecutionInput, StepInput)): parameters['ProcessingJobName'] = job_name - + if experiment_config is not None: parameters['ExperimentConfig'] = experiment_config - + if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) - + if 'S3Operations' in parameters: del parameters['S3Operations'] - + kwargs[Field.Parameters.value] = parameters - super(ProcessingStep, self).__init__(state_id, **kwargs) + super(ProcessingStep, self).__init__(state_id, 'sagemaker:createProcessingJob', 'Amazon SageMaker', **kwargs) diff --git a/src/stepfunctions/steps/service.py b/src/stepfunctions/steps/service.py index 5e74161..73f3845 100644 --- a/src/stepfunctions/steps/service.py +++ b/src/stepfunctions/steps/service.py @@ -12,8 +12,9 @@ # permissions and limitations under the 
License. from __future__ import absolute_import -from stepfunctions.steps.states import Task + from stepfunctions.steps.fields import Field +from stepfunctions.steps.states import Task, IntegrationPattern class DynamoDBGetItemStep(Task): @@ -32,11 +33,10 @@ def __init__(self, state_id, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:getItem' - super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs) + super(DynamoDBGetItemStep, self).__init__(state_id, 'dynamodb:getItem', + 'DynamoDB', **kwargs) class DynamoDBPutItemStep(Task): - """ Creates a Task state to put an item to DynamoDB. See `Call DynamoDB APIs with Step Functions `_ for more details. """ @@ -51,12 +52,11 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:putItem' - super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs) + super(DynamoDBPutItemStep, self).__init__(state_id, 'dynamodb:putItem', + 'DynamoDB', **kwargs) class DynamoDBDeleteItemStep(Task): - """ Creates a Task state to delete an item from DynamoDB. See `Call DynamoDB APIs with Step Functions `_ for more details. """ @@ -71,12 +71,11 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. 
(default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:deleteItem' - super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs) + super(DynamoDBDeleteItemStep, self).__init__(state_id, 'dynamodb:deleteItem', + 'DynamoDB', **kwargs) class DynamoDBUpdateItemStep(Task): - """ Creates a Task state to update an item from DynamoDB. See `Call DynamoDB APIs with Step Functions `_ for more details. """ @@ -91,21 +90,20 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:updateItem' - super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs) + super(DynamoDBUpdateItemStep, self).__init__(state_id, 'dynamodb:updateItem', + 'DynamoDB', **kwargs) class SnsPublishStep(Task): - """ Creates a Task state to publish a message to SNS topic. See `Call Amazon SNS with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_callback=False, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RequestResponse, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - wait_for_callback(bool, optional): Boolean value set to `True` if the Task state should wait for callback to resume the operation. 
(default: False) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. comment (str, optional): Human-readable comment or description. (default: None) @@ -114,25 +112,21 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ - if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish.waitForTaskToken' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish' - - super(SnsPublishStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.WaitForCallback] + self._integration_pattern = integration_pattern + super(SnsPublishStep, self).__init__(state_id, 'sns:publish', 'Amazon SNS', **kwargs) class SqsSendMessageStep(Task): - """ Creates a Task state to send a message to SQS queue. See `Call Amazon SQS with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_callback=False, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RequestResponse, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. - wait_for_callback(bool, optional): Boolean value set to `True` if the Task state should wait for callback to resume the operation. (default: False) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60) heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. 
If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name. comment (str, optional): Human-readable comment or description. (default: None) @@ -141,12 +135,10 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage.waitForTaskToken' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage' - - super(SqsSendMessageStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.WaitForCallback] + self._integration_pattern = integration_pattern + super(SqsSendMessageStep, self).__init__(state_id, 'sqs:sendMessage', + 'Amazon SQS', **kwargs) class EmrCreateClusterStep(Task): @@ -154,7 +146,7 @@ class EmrCreateClusterStep(Task): Creates a Task state to create and start running a cluster (job flow). See `Call Amazon EMR with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RunAJob, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -163,14 +155,12 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): parameters (dict, optional): The value of this field becomes the effective input for the state. 
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') - wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster' - - super(EmrCreateClusterStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern + super(EmrCreateClusterStep, self).__init__(state_id, 'elasticmapreduce:createCluster', + 'Amazon EMR', **kwargs) class EmrTerminateClusterStep(Task): @@ -178,7 +168,7 @@ class EmrTerminateClusterStep(Task): Creates a Task state to shut down a cluster (job flow). See `Call Amazon EMR with Step Functions `_ for more details. """ - def __init__(self, state_id, wait_for_completion=True, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RunAJob, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. 
@@ -187,14 +177,12 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): parameters (dict, optional): The value of this field becomes the effective input for the state. result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') - wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster' - - super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern + super(EmrTerminateClusterStep, self).__init__(state_id, 'elasticmapreduce:terminateCluster', + 'Amazon EMR', **kwargs) class EmrAddStepStep(Task): @@ -202,7 +190,7 @@ class EmrAddStepStep(Task): Creates a Task state to add a new step to a running cluster. See `Call Amazon EMR with Step Functions `_ for more details. 
""" - def __init__(self, state_id, wait_for_completion=True, **kwargs): + def __init__(self, state_id, integration_pattern=IntegrationPattern.RunAJob, **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -211,14 +199,12 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): parameters (dict, optional): The value of this field becomes the effective input for the state. result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') - wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. 
(default: True) + integration_pattern(stepfunctions.states.IntegrationPattern, optional): Enum value set to RunAJob if the task should wait to complete before proceeding to the next step in the workflow, set to WaitForCallback if the Task state should wait for callback to resume the operation or set to RequestResponse if the Task should wait for HTTP response (default: RequestResponse) """ - if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep.sync' - else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep' - - super(EmrAddStepStep, self).__init__(state_id, **kwargs) + self._valid_patterns = [IntegrationPattern.RequestResponse, IntegrationPattern.RunAJob] + self._integration_pattern = integration_pattern + super(EmrAddStepStep, self).__init__(state_id, 'elasticmapreduce:addStep', + 'Amazon EMR', **kwargs) class EmrCancelStepStep(Task): @@ -236,9 +222,8 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:cancelStep' - - super(EmrCancelStepStep, self).__init__(state_id, **kwargs) + super(EmrCancelStepStep, self).__init__(state_id, 'elasticmapreduce:cancelStep', + 'Amazon EMR', **kwargs) class EmrSetClusterTerminationProtectionStep(Task): @@ -256,9 +241,9 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:setClusterTerminationProtection' - - super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs) + super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, + 'elasticmapreduce:setClusterTerminationProtection', + 'Amazon EMR', **kwargs) class EmrModifyInstanceFleetByNameStep(Task): @@ -276,9 +261,8 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName' - - super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs) + super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, 'elasticmapreduce:modifyInstanceFleetByName', + 'Amazon EMR', **kwargs) class EmrModifyInstanceGroupByNameStep(Task): @@ -296,7 +280,5 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName' - - super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) - + super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, 'elasticmapreduce:modifyInstanceGroupByName', + 'Amazon EMR', **kwargs) diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 28744eb..f556528 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -14,11 +14,11 @@ import json import logging +from enum import Enum -from stepfunctions.exceptions import DuplicateStatesInChain -from stepfunctions.steps.fields import Field +from stepfunctions.exceptions import DuplicateStatesInChain, ForbiddenValueParameter from stepfunctions.inputs import ExecutionInput, StepInput - +from stepfunctions.steps.fields import Field logger = logging.getLogger('stepfunctions.states') @@ -28,7 +28,6 @@ def to_pascalcase(text): class Block(object): - """ Base class to abstract blocks used in `Amazon States Language `_. """ @@ -94,8 +93,8 @@ def to_json(self, pretty=False): def __repr__(self): return '{}({})'.format( - self.__class__.__name__, - ', '.join(['{}={!r}'.format(k, v) for k, v in self.fields.items()]) + self.__class__.__name__, + ', '.join(['{}={!r}'.format(k, v) for k, v in self.fields.items()]) ) def __str__(self): @@ -103,7 +102,6 @@ def __str__(self): class Retry(Block): - """ A class for creating a Retry block. """ @@ -129,7 +127,6 @@ def allowed_fields(self): class Catch(Block): - """ A class for creating a Catch block. """ @@ -157,7 +154,6 @@ def to_dict(self): class State(Block): - """ Base class to abstract states in `Amazon States Language `_. """ @@ -219,7 +215,9 @@ def next(self, next_step): State or Chain: Next state or chain that will be transitioned to. 
""" if self.type in ('Choice', 'Succeed', 'Fail'): - raise ValueError('Unexpected State instance `{step}`, State type `{state_type}` does not support method `next`.'.format(step=next_step, state_type=self.type)) + raise ValueError( + 'Unexpected State instance `{step}`, State type `{state_type}` does not support method `next`.'.format( + step=next_step, state_type=self.type)) self.next_step = next_step return self.next_step @@ -284,8 +282,8 @@ def to_dict(self): return result -class Pass(State): +class Pass(State): """ Pass State simply passes its input to its output, performing no work. Pass States are useful when constructing and debugging state machines. """ @@ -316,7 +314,6 @@ def allowed_fields(self): class Succeed(State): - """ Succeed State terminates a state machine successfully. The Succeed State is a useful target for :py:class:`Choice`-state branches that don't do anything but terminate the machine. """ @@ -340,7 +337,6 @@ def allowed_fields(self): class Fail(State): - """ Fail State terminates the machine and marks it as a failure. """ @@ -364,7 +360,6 @@ def allowed_fields(self): class Wait(State): - """ Wait state causes the interpreter to delay the machine from continuing for a specified time. The time can be specified as a wait duration, specified in seconds, or an absolute expiry time, specified as an ISO-8601 extended offset date-time format string. """ @@ -384,8 +379,10 @@ def __init__(self, state_id, **kwargs): output_path (str, optional): Path applied to the state’s output, producing the effective output which serves as the raw input for the next state. 
(default: '$') """ super(Wait, self).__init__(state_id, 'Wait', **kwargs) - if len([v for v in (self.seconds, self.timestamp, self.timestamp_path, self.seconds_path) if v is not None]) != 1: - raise ValueError("The Wait state MUST contain exactly one of 'seconds', 'seconds_path', 'timestamp' or 'timestamp_path'.") + if len([v for v in (self.seconds, self.timestamp, self.timestamp_path, self.seconds_path) if + v is not None]) != 1: + raise ValueError( + "The Wait state MUST contain exactly one of 'seconds', 'seconds_path', 'timestamp' or 'timestamp_path'.") def allowed_fields(self): return [ @@ -400,7 +397,6 @@ def allowed_fields(self): class Choice(State): - """ Choice state adds branching logic to a state machine. The state holds a list of *rule* and *next_step* pairs. The interpreter attempts pattern-matches against the rules in list order and transitions to the state or chain specified in the *next_step* field on the first *rule* where there is an exact match between the input value and a member of the comparison-operator array. """ @@ -471,7 +467,6 @@ def accept(self, visitor): class Parallel(State): - """ Parallel State causes parallel execution of "branches". @@ -520,7 +515,6 @@ def to_dict(self): class Map(State): - """ Map state provides the ability to dynamically iterate over a state/subgraph for each entry in a list. @@ -571,13 +565,21 @@ def to_dict(self): return result -class Task(State): +# https://docs.aws.amazon.com/step-functions/latest/dg/connect-supported-services.html +class IntegrationPattern(Enum): + RequestResponse = "" + RunAJob = ".sync" + WaitForCallback = ".waitForTaskToken" + +class Task(State): + _valid_patterns = [IntegrationPattern.RequestResponse] + _integration_pattern = IntegrationPattern.RequestResponse """ Task State causes the interpreter to execute the work identified by the state’s `resource` field. 
""" - def __init__(self, state_id, **kwargs): + def __init__(self, state_id, action=None, step_name="Unknow", **kwargs): """ Args: state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine. @@ -590,6 +592,11 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ + if self._integration_pattern not in self._valid_patterns: + raise ForbiddenValueParameter( + f"{step_name} only supports {', '.join(map(lambda pattern: pattern.name, self._valid_patterns))} integration pattern.") + if action: + kwargs[Field.Resource.value] = ''.join([f'arn:aws:states:::{action}', self._integration_pattern.value]) super(Task, self).__init__(state_id, 'Task', **kwargs) def allowed_fields(self): @@ -647,7 +654,9 @@ def append(self, step): self.steps.append(step) else: if step in self.steps: - raise DuplicateStatesInChain("State '{step_name}' is already inside this chain. A chain cannot have duplicate states.".format(step_name=step.state_id)) + raise DuplicateStatesInChain( + "State '{step_name}' is already inside this chain. A chain cannot have duplicate states.".format( + step_name=step.state_id)) last_step = self.steps[-1] last_step.next(step) self.steps.append(step) @@ -658,8 +667,8 @@ def accept(self, visitor): def __repr__(self): return '{}(steps={!r})'.format( - self.__class__.__name__, - self.steps + self.__class__.__name__, + self.steps ) @@ -688,7 +697,9 @@ def is_visited(self, state): def visit(self, state): if state.state_id in self.states: - raise ValueError("Each state in a workflow must have a unique state id. 
Found duplicate state id '{}' in workflow.".format(state.state_id)) + raise ValueError( + "Each state in a workflow must have a unique state id. Found duplicate state id '{}' in workflow.".format( + state.state_id)) self.states[state.state_id] = state.to_dict() if state.next_step is None: return @@ -697,7 +708,9 @@ def visit(self, state): params = state.next_step.fields[Field.Parameters.value] valid, invalid_param_name = self._validate_next_step_params(params, state.step_output) if not valid: - raise ValueError('State \'{state_name}\' is using an illegal placeholder for the \'{param_name}\' parameter.'.format(state_name=state.next_step.state_id, param_name=invalid_param_name)) + raise ValueError( + 'State \'{state_name}\' is using an illegal placeholder for the \'{param_name}\' parameter.'.format( + state_name=state.next_step.state_id, param_name=invalid_param_name)) def _validate_next_step_params(self, params, step_output): for k, v in params.items(): @@ -710,6 +723,7 @@ def _validate_next_step_params(self, params, step_output): return valid, invalid_param_name return True, None + class Graph(Block): def __init__(self, branch, **kwargs): diff --git a/tests/unit/test_compute_steps.py b/tests/unit/test_compute_steps.py index 030cf35..3382f50 100644 --- a/tests/unit/test_compute_steps.py +++ b/tests/unit/test_compute_steps.py @@ -14,7 +14,9 @@ import pytest +from stepfunctions.exceptions import ForbiddenValueParameter from stepfunctions.steps.compute import LambdaStep, GlueStartJobRunStep, BatchSubmitJobStep, EcsRunTaskStep +from stepfunctions.steps.states import IntegrationPattern def test_lambda_step_creation(): @@ -26,7 +28,7 @@ def test_lambda_step_creation(): 'End': True } - step = LambdaStep('lambda', wait_for_callback=True, parameters={ + step = LambdaStep('lambda', integration_pattern=IntegrationPattern.WaitForCallback, parameters={ 'Payload': { 'model.$': '$.new_model', 'token.$': '$$.Task.Token' @@ -37,7 +39,7 @@ def test_lambda_step_creation(): 'Type': 
'Task', 'Resource': 'arn:aws:states:::lambda:invoke.waitForTaskToken', 'Parameters': { - 'Payload': { + 'Payload': { 'model.$': '$.new_model', 'token.$': '$$.Task.Token' }, @@ -45,8 +47,12 @@ def test_lambda_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = LambdaStep('Echo', integration_pattern=IntegrationPattern.RunAJob) + + def test_glue_start_job_run_step_creation(): - step = GlueStartJobRunStep('Glue Job', wait_for_completion=False) + step = GlueStartJobRunStep('Glue Job', integration_pattern=IntegrationPattern.RequestResponse) assert step.to_dict() == { 'Type': 'Task', @@ -67,8 +73,12 @@ def test_glue_start_job_run_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = GlueStartJobRunStep('Glue Job', integration_pattern=IntegrationPattern.WaitForCallback) + + def test_batch_submit_job_step_creation(): - step = BatchSubmitJobStep('Batch Job', wait_for_completion=False) + step = BatchSubmitJobStep('Batch Job', integration_pattern=IntegrationPattern.RequestResponse) assert step.to_dict() == { 'Type': 'Task', @@ -91,8 +101,12 @@ def test_batch_submit_job_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = BatchSubmitJobStep('Batch Job', integration_pattern=IntegrationPattern.WaitForCallback) + + def test_ecs_run_task_step_creation(): - step = EcsRunTaskStep('Ecs Job', wait_for_completion=False) + step = EcsRunTaskStep('Ecs Job') assert step.to_dict() == { 'Type': 'Task', @@ -100,7 +114,7 @@ def test_ecs_run_task_step_creation(): 'End': True } - step = EcsRunTaskStep('Ecs Job', parameters={ + step = EcsRunTaskStep('Ecs Job', integration_pattern=IntegrationPattern.RunAJob, parameters={ 'TaskDefinition': 'Task' }) @@ -112,3 +126,16 @@ def test_ecs_run_task_step_creation(): }, 'End': True } + + step = EcsRunTaskStep('Ecs Job', integration_pattern=IntegrationPattern.WaitForCallback, parameters={ + 'TaskDefinition': 'Task' + }) + + assert step.to_dict() == { + 
'Type': 'Task', + 'Resource': 'arn:aws:states:::ecs:runTask.waitForTaskToken', + 'Parameters': { + 'TaskDefinition': 'Task' + }, + 'End': True + } \ No newline at end of file diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index bdc7a57..6f22c15 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -12,23 +12,22 @@ # permissions and limitations under the License. from __future__ import absolute_import +from unittest.mock import MagicMock, patch + import pytest import sagemaker -import boto3 - -from sagemaker.transformer import Transformer +from sagemaker.debugger import Rule, rule_configs, DebuggerHookConfig, CollectionConfig from sagemaker.model import Model -from sagemaker.tensorflow import TensorFlow -from sagemaker.pipeline import PipelineModel from sagemaker.model_monitor import DataCaptureConfig -from sagemaker.debugger import Rule, rule_configs, DebuggerHookConfig, CollectionConfig -from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.processing import ProcessingInput, ProcessingOutput +from sagemaker.sklearn.processing import SKLearnProcessor +from sagemaker.tensorflow import TensorFlow +from sagemaker.transformer import Transformer -from unittest.mock import MagicMock, patch -from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep -from stepfunctions.steps.sagemaker import tuning_config - +from stepfunctions.exceptions import ForbiddenValueParameter +from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, \ + ProcessingStep +from stepfunctions.steps.states import IntegrationPattern from tests.unit.utils import mock_boto_api_call EXECUTION_ROLE = 'execution-role' @@ -37,6 +36,7 @@ DEFAULT_TAGS = {'Purpose': 'unittests'} DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}] + @pytest.fixture def pca_estimator(): 
s3_output_location = 's3://sagemaker/models' @@ -63,6 +63,7 @@ def pca_estimator(): return pca + @pytest.fixture def pca_estimator_with_debug_hook(): s3_output_location = 's3://sagemaker/models' @@ -79,13 +80,13 @@ def pca_estimator_with_debug_hook(): ) rules = [Rule.sagemaker(rule_configs.confusion(), - rule_parameters={ - "category_no": "15", - "min_diag": "0.7", - "max_off_diag": "0.3", - "start_step": "17", - "end_step": "19"} - )] + rule_parameters={ + "category_no": "15", + "min_diag": "0.7", + "max_off_diag": "0.3", + "start_step": "17", + "end_step": "19"} + )] pca = sagemaker.estimator.Estimator( PCA_IMAGE, @@ -93,7 +94,7 @@ def pca_estimator_with_debug_hook(): train_instance_count=1, train_instance_type='ml.c4.xlarge', output_path=s3_output_location, - debugger_hook_config = hook_config, + debugger_hook_config=hook_config, rules=rules ) @@ -111,6 +112,7 @@ def pca_estimator_with_debug_hook(): return pca + @pytest.fixture def pca_model(): model_data = 's3://sagemaker/models/pca.tar.gz' @@ -121,6 +123,7 @@ def pca_model(): name='pca-model' ) + @pytest.fixture def pca_transformer(pca_model): return Transformer( @@ -130,23 +133,24 @@ def pca_transformer(pca_model): output_path='s3://sagemaker/transform-output' ) + @pytest.fixture def tensorflow_estimator(): s3_output_location = 's3://sagemaker/models' s3_source_location = 's3://sagemaker/source' estimator = TensorFlow(entry_point='tf_train.py', - role=EXECUTION_ROLE, - framework_version='1.13', - training_steps=1000, - evaluation_steps=100, - train_instance_count=1, - train_instance_type='ml.p2.xlarge', - output_path=s3_output_location, - source_dir=s3_source_location, - image_name=TENSORFLOW_IMAGE, - checkpoint_path='s3://sagemaker/models/sagemaker-tensorflow/checkpoints' - ) + role=EXECUTION_ROLE, + framework_version='1.13', + training_steps=1000, + evaluation_steps=100, + train_instance_count=1, + train_instance_type='ml.p2.xlarge', + output_path=s3_output_location, + source_dir=s3_source_location, + 
image_name=TENSORFLOW_IMAGE, + checkpoint_path='s3://sagemaker/models/sagemaker-tensorflow/checkpoints' + ) estimator.debugger_hook_config = DebuggerHookConfig( s3_output_path='s3://sagemaker/models/debug' @@ -155,9 +159,10 @@ def tensorflow_estimator(): estimator.sagemaker_session = MagicMock() estimator.sagemaker_session.boto_region_name = 'us-east-1' estimator.sagemaker_session._default_bucket = 'sagemaker' - + return estimator + @pytest.fixture def sklearn_processor(): sagemaker_session = MagicMock() @@ -174,18 +179,19 @@ def sklearn_processor(): return processor + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_training_step_creation(pca_estimator): - step = TrainingStep('Training', - estimator=pca_estimator, - job_name='TrainingJob', - experiment_config={ - 'ExperimentName': 'pca_experiment', - 'TrialName': 'pca_trial', - 'TrialComponentDisplayName': 'Training' - }, - tags=DEFAULT_TAGS, - ) + step = TrainingStep('Training', + estimator=pca_estimator, + job_name='TrainingJob', + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Training' + }, + tags=DEFAULT_TAGS, + ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -215,7 +221,7 @@ def test_training_step_creation(pca_estimator): 'ExperimentConfig': { 'ExperimentName': 'pca_experiment', 'TrialName': 'pca_trial', - 'TrialComponentDisplayName': 'Training' + 'TrialComponentDisplayName': 'Training' }, 'TrainingJobName': 'TrainingJob', 'Tags': DEFAULT_TAGS_LIST @@ -223,12 +229,24 @@ def test_training_step_creation(pca_estimator): 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = TrainingStep('Training', + estimator=pca_estimator, + job_name='TrainingJob', + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Training' + }, + tags=DEFAULT_TAGS, + 
integration_pattern=IntegrationPattern.WaitForCallback) + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): step = TrainingStep('Training', - estimator=pca_estimator_with_debug_hook, - job_name='TrainingJob') + estimator=pca_estimator_with_debug_hook, + job_name='TrainingJob') assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -282,11 +300,17 @@ def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = TrainingStep('Training', + estimator=pca_estimator_with_debug_hook, + job_name='TrainingJob', integration_pattern=IntegrationPattern.WaitForCallback) + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_training_step_creation_with_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob') - model_step = ModelStep('Training - Save Model', training_step.get_expected_model(model_name=training_step.output()['TrainingJobName'])) + model_step = ModelStep('Training - Save Model', + training_step.get_expected_model(model_name=training_step.output()['TrainingJobName'])) training_step.next(model_step) assert training_step.to_dict() == { 'Type': 'Task', @@ -334,17 +358,20 @@ def test_training_step_creation_with_model(pca_estimator): }, 'End': True } + with pytest.raises(ForbiddenValueParameter): + training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob', integration_pattern=IntegrationPattern.WaitForCallback) + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_training_step_creation_with_framework(tensorflow_estimator): step = TrainingStep('Training', - estimator=tensorflow_estimator, - data={'train': 's3://sagemaker/train'}, - job_name='tensorflow-job', - 
mini_batch_size=1024, - tags=DEFAULT_TAGS, - ) - + estimator=tensorflow_estimator, + data={'train': 's3://sagemaker/train'}, + job_name='tensorflow-job', + mini_batch_size=1024, + tags=DEFAULT_TAGS, + ) + assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -395,23 +422,32 @@ def test_training_step_creation_with_framework(tensorflow_estimator): 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = TrainingStep('Training', + estimator=tensorflow_estimator, + data={'train': 's3://sagemaker/train'}, + job_name='tensorflow-job', + mini_batch_size=1024, + tags=DEFAULT_TAGS, + integration_pattern=IntegrationPattern.WaitForCallback) + def test_transform_step_creation(pca_transformer): step = TransformStep('Inference', - transformer=pca_transformer, - data='s3://sagemaker/inference', - job_name='transform-job', - model_name='pca-model', - experiment_config={ - 'ExperimentName': 'pca_experiment', - 'TrialName': 'pca_trial', - 'TrialComponentDisplayName': 'Transform' - }, - tags=DEFAULT_TAGS, - join_source='Input', - output_filter='$[2:]', - input_filter='$[1:]' - ) + transformer=pca_transformer, + data='s3://sagemaker/inference', + job_name='transform-job', + model_name='pca-model', + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Transform' + }, + tags=DEFAULT_TAGS, + join_source='Input', + output_filter='$[2:]', + input_filter='$[1:]' + ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -435,7 +471,7 @@ def test_transform_step_creation(pca_transformer): 'ExperimentConfig': { 'ExperimentName': 'pca_experiment', 'TrialName': 'pca_trial', - 'TrialComponentDisplayName': 'Transform' + 'TrialComponentDisplayName': 'Transform' }, 'DataProcessing': { 'InputFilter': '$[1:]', @@ -448,6 +484,24 @@ def test_transform_step_creation(pca_transformer): 'End': True } + with pytest.raises(ForbiddenValueParameter): + 
step = TransformStep('Inference', + transformer=pca_transformer, + data='s3://sagemaker/inference', + job_name='transform-job', + model_name='pca-model', + experiment_config={ + 'ExperimentName': 'pca_experiment', + 'TrialName': 'pca_trial', + 'TrialComponentDisplayName': 'Transform' + }, + tags=DEFAULT_TAGS, + join_source='Input', + output_filter='$[2:]', + input_filter='$[1:]', + integration_pattern=IntegrationPattern.WaitForCallback) + + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_get_expected_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob') @@ -468,14 +522,15 @@ def test_get_expected_model(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_get_expected_model_with_framework_estimator(tensorflow_estimator): training_step = TrainingStep('Training', - estimator=tensorflow_estimator, - data={'train': 's3://sagemaker/train'}, - job_name='tensorflow-job', - mini_batch_size=1024 - ) + estimator=tensorflow_estimator, + data={'train': 's3://sagemaker/train'}, + job_name='tensorflow-job', + mini_batch_size=1024 + ) expected_model = training_step.get_expected_model() expected_model.entry_point = 'tf_train.py' model_step = ModelStep('Create model', model=expected_model, model_name='tf-model') @@ -500,6 +555,7 @@ def test_get_expected_model_with_framework_estimator(tensorflow_estimator): 'End': True } + def test_model_step_creation(pca_model): step = ModelStep('Create model', model=pca_model, model_name='pca-model', tags=DEFAULT_TAGS) assert step.to_dict() == { @@ -518,19 +574,20 @@ def test_model_step_creation(pca_model): 'End': True } + def test_endpoint_config_step_creation(pca_model): data_capture_config = DataCaptureConfig( enable_capture=True, sampling_percentage=100, destination_s3_uri='s3://sagemaker/datacapture') - step = EndpointConfigStep('Endpoint Config', - 
endpoint_config_name='MyEndpointConfig', - model_name='pca-model', - initial_instance_count=1, - instance_type='ml.p2.xlarge', - data_capture_config=data_capture_config, - tags=DEFAULT_TAGS, - ) + step = EndpointConfigStep('Endpoint Config', + endpoint_config_name='MyEndpointConfig', + model_name='pca-model', + initial_instance_count=1, + instance_type='ml.p2.xlarge', + data_capture_config=data_capture_config, + tags=DEFAULT_TAGS, + ) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -546,7 +603,7 @@ def test_endpoint_config_step_creation(pca_model): 'InitialSamplingPercentage': 100, 'DestinationS3Uri': 's3://sagemaker/datacapture', 'CaptureOptions': [ - {'CaptureMode': 'Input'}, + {'CaptureMode': 'Input'}, {'CaptureMode': 'Output'} ], 'CaptureContentTypeHeader': { @@ -560,8 +617,10 @@ def test_endpoint_config_step_creation(pca_model): 'End': True } + def test_endpoint_step_creation(pca_model): - step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', tags=DEFAULT_TAGS) + step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', + tags=DEFAULT_TAGS) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -573,7 +632,8 @@ def test_endpoint_step_creation(pca_model): 'End': True } - step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True, tags=DEFAULT_TAGS) + step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', update=True, + tags=DEFAULT_TAGS) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -585,6 +645,7 @@ def test_endpoint_step_creation(pca_model): 'End': True } + def test_processing_step_creation(sklearn_processor): inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] outputs = [ @@ -592,7 +653,8 @@ def test_processing_step_creation(sklearn_processor): 
ProcessingOutput(source='/opt/ml/processing/output/validation'), ProcessingOutput(source='/opt/ml/processing/output/test') ] - step = ProcessingStep('Feature Transformation', sklearn_processor, 'MyProcessingJob', inputs=inputs, outputs=outputs) + step = ProcessingStep('Feature Transformation', sklearn_processor, 'MyProcessingJob', inputs=inputs, + outputs=outputs) assert step.to_dict() == { 'Type': 'Task', 'Parameters': { @@ -653,3 +715,7 @@ def test_processing_step_creation(sklearn_processor): 'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync', 'End': True } + + with pytest.raises(ForbiddenValueParameter): + step = ProcessingStep('Feature Transformation', sklearn_processor, 'MyProcessingJob', inputs=inputs, + outputs=outputs, integration_pattern=IntegrationPattern.WaitForCallback) diff --git a/tests/unit/test_service_steps.py b/tests/unit/test_service_steps.py index 64809e9..c79f7fc 100644 --- a/tests/unit/test_service_steps.py +++ b/tests/unit/test_service_steps.py @@ -14,6 +14,9 @@ import pytest +from stepfunctions.exceptions import ForbiddenValueParameter +from stepfunctions.steps.states import IntegrationPattern + from stepfunctions.steps.service import DynamoDBGetItemStep, DynamoDBPutItemStep, DynamoDBUpdateItemStep, DynamoDBDeleteItemStep from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep from stepfunctions.steps.service import EmrCreateClusterStep, EmrTerminateClusterStep, EmrAddStepStep, EmrCancelStepStep, EmrSetClusterTerminationProtectionStep, EmrModifyInstanceFleetByNameStep, EmrModifyInstanceGroupByNameStep @@ -35,7 +38,7 @@ def test_sns_publish_step_creation(): 'End': True } - step = SnsPublishStep('Publish to SNS', wait_for_callback=True, parameters={ + step = SnsPublishStep('Publish to SNS', integration_pattern=IntegrationPattern.WaitForCallback, parameters={ 'TopicArn': 'arn:aws:sns:us-east-1:123456789012:myTopic', 'Message': { 'Input.$': '$', @@ -56,6 +59,8 @@ def test_sns_publish_step_creation(): 'End': True 
} + with pytest.raises(ForbiddenValueParameter): + step = SnsPublishStep('Publish to SNS', integration_pattern=IntegrationPattern.RunAJob) def test_sqs_send_message_step_creation(): step = SqsSendMessageStep('Send to SQS', parameters={ @@ -73,7 +78,7 @@ def test_sqs_send_message_step_creation(): 'End': True } - step = SqsSendMessageStep('Send to SQS', wait_for_callback=True, parameters={ + step = SqsSendMessageStep('Send to SQS', integration_pattern=IntegrationPattern.WaitForCallback, parameters={ 'QueueUrl': 'https://sqs.us-east-1.amazonaws.com/123456789012/myQueue', 'MessageBody': { 'Input.$': '$', @@ -94,6 +99,9 @@ def test_sqs_send_message_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = SqsSendMessageStep('Send to SQS', integration_pattern=IntegrationPattern.RunAJob) + def test_dynamodb_get_item_step_creation(): step = DynamoDBGetItemStep('Read Message From DynamoDB', parameters={ @@ -287,7 +295,7 @@ def test_emr_create_cluster_step_creation(): 'End': True } - step = EmrCreateClusterStep('Create EMR cluster', wait_for_completion=False, parameters={ + step = EmrCreateClusterStep('Create EMR cluster', integration_pattern=IntegrationPattern.RequestResponse, parameters={ 'Name': 'MyWorkflowCluster', 'VisibleToAllUsers': True, 'ReleaseLabel': 'emr-5.28.0', @@ -371,6 +379,9 @@ def test_emr_create_cluster_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = EmrCreateClusterStep('Create EMR cluster', integration_pattern=IntegrationPattern.WaitForCallback) + def test_emr_terminate_cluster_step_creation(): step = EmrTerminateClusterStep('Terminate EMR cluster', parameters={ @@ -386,7 +397,7 @@ def test_emr_terminate_cluster_step_creation(): 'End': True } - step = EmrTerminateClusterStep('Terminate EMR cluster', wait_for_completion=False, parameters={ + step = EmrTerminateClusterStep('Terminate EMR cluster', integration_pattern=IntegrationPattern.RequestResponse, parameters={ 'ClusterId': 
'MyWorkflowClusterId' }) @@ -399,6 +410,9 @@ def test_emr_terminate_cluster_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = EmrTerminateClusterStep('Terminate EMR cluster', integration_pattern=IntegrationPattern.WaitForCallback) + def test_emr_add_step_step_creation(): step = EmrAddStepStep('Add step to EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -449,7 +463,7 @@ def test_emr_add_step_step_creation(): 'End': True } - step = EmrAddStepStep('Add step to EMR cluster', wait_for_completion=False, parameters={ + step = EmrAddStepStep('Add step to EMR cluster', integration_pattern=IntegrationPattern.RequestResponse, parameters={ 'ClusterId': 'MyWorkflowClusterId', 'Step': { 'Name': 'The first step', @@ -498,6 +512,9 @@ def test_emr_add_step_step_creation(): 'End': True } + with pytest.raises(ForbiddenValueParameter): + step = EmrAddStepStep('Add step to EMR cluster', integration_pattern=IntegrationPattern.WaitForCallback) + def test_emr_cancel_step_step_creation(): step = EmrCancelStepStep('Cancel step from EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId',