Skip to content

Commit f9e63b3

Browse files
authored
feat: Add support for nested Aws StepFunctions service integration (#166)
1 parent 8b6d0eb commit f9e63b3

File tree

6 files changed

+202
-8
lines changed

6 files changed

+202
-8
lines changed

Diff for: doc/services.rst

+7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ This module provides classes to build steps that integrate with Amazon DynamoDB,
2020

2121
- `Amazon SQS <#amazon-sqs>`__
2222

23+
- `AWS Step Functions <#aws-step-functions>`__
24+
2325

2426
Amazon DynamoDB
2527
----------------
@@ -82,3 +84,8 @@ Amazon SNS
8284
Amazon SQS
8385
-----------
8486
.. autoclass:: stepfunctions.steps.service.SqsSendMessageStep
87+
88+
AWS Step Functions
89+
------------------
90+
.. autoclass:: stepfunctions.steps.service.StepFunctionsStartExecutionStep
91+

Diff for: src/stepfunctions/steps/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@
3434
from stepfunctions.steps.service import EventBridgePutEventsStep
3535
from stepfunctions.steps.service import GlueDataBrewStartJobRunStep
3636
from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep
37+
from stepfunctions.steps.service import StepFunctionsStartExecutionStep

Diff for: src/stepfunctions/steps/integration_resources.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,35 @@ class IntegrationPattern(Enum):
2424

2525
WaitForTaskToken = "waitForTaskToken"
2626
WaitForCompletion = "sync"
27-
RequestResponse = ""
27+
CallAndContinue = ""
2828

2929

30-
def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse):
30+
def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.CallAndContinue, version=None):
3131

3232
"""
3333
ARN builder for task integration
3434
Args:
3535
service (str): The service name for the service integration
3636
api (str): The api of the service integration
37-
integration_pattern (IntegrationPattern, optional): The integration pattern for the task. (Default: IntegrationPattern.RequestResponse)
37+
integration_pattern (IntegrationPattern, optional): The integration pattern for the task. (Default: IntegrationPattern.CallAndContinue)
38+
version (int, optional): The version of the resource to use. (Default: None)
3839
"""
3940
arn = ""
40-
if integration_pattern == IntegrationPattern.RequestResponse:
41+
if integration_pattern == IntegrationPattern.CallAndContinue:
4142
arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}"
4243
else:
4344
arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}"
45+
46+
if version:
47+
arn = f"{arn}:{str(version)}"
48+
4449
return arn
4550

4651

52+
def is_integration_pattern_valid(integration_pattern, supported_integration_patterns):
53+
if not isinstance(integration_pattern, IntegrationPattern):
54+
raise TypeError(f"Integration pattern must be of type {IntegrationPattern}")
55+
elif integration_pattern not in supported_integration_patterns:
56+
raise ValueError(f"Integration Pattern ({integration_pattern.name}) is not supported for this step - "
57+
f"Please use one of the following: "
58+
f"{[integ_type.name for integ_type in supported_integration_patterns]}")

Diff for: src/stepfunctions/steps/service.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from enum import Enum
1616
from stepfunctions.steps.states import Task
1717
from stepfunctions.steps.fields import Field
18-
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
18+
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn,\
19+
is_integration_pattern_valid
1920

2021
DYNAMODB_SERVICE_NAME = "dynamodb"
2122
EKS_SERVICES_NAME = "eks"
@@ -24,6 +25,7 @@
2425
GLUE_DATABREW_SERVICE_NAME = "databrew"
2526
SNS_SERVICE_NAME = "sns"
2627
SQS_SERVICE_NAME = "sqs"
28+
STEP_FUNCTIONS_SERVICE_NAME = "states"
2729

2830

2931
class DynamoDBApi(Enum):
@@ -70,6 +72,10 @@ class SqsApi(Enum):
7072
SendMessage = "sendMessage"
7173

7274

75+
class StepFunctions(Enum):
76+
StartExecution = "startExecution"
77+
78+
7379
class DynamoDBGetItemStep(Task):
7480
"""
7581
Creates a Task state to get an item from DynamoDB. See `Call DynamoDB APIs with Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/connect-ddb.html>`_ for more details.
@@ -887,3 +893,54 @@ def __init__(self, state_id, **kwargs):
887893
ElasticMapReduceApi.ModifyInstanceGroupByName)
888894

889895
super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs)
896+
897+
898+
class StepFunctionsStartExecutionStep(Task):
899+
900+
"""
901+
Creates a Task state that starts an execution of a state machine. See `Manage AWS Step Functions Executions as an Integrated Service <https://docs.aws.amazon.com/step-functions/latest/dg/connect-stepfunctions.html`_ for more details.
902+
"""
903+
904+
def __init__(self, state_id, integration_pattern=IntegrationPattern.WaitForCompletion, **kwargs):
905+
"""
906+
Args:
907+
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
908+
integration_pattern (stepfunctions.steps.integration_resources.IntegrationPattern, optional): Service integration pattern used to call the integrated service. (default: WaitForCompletion)
909+
Supported integration patterns:
910+
WaitForCompletion: Wait for the state machine execution to complete before going to the next state. (See `Run A Job <https://docs.aws.amazon.com/step-functions/latest/dg/connect-to-resource.html#connect-sync`_ for more details.)
911+
WaitForTaskToken: Wait for the state machine execution to return a task token before progressing to the next state (See `Wait for a Callback with the Task Token <https://docs.aws.amazon.com/step-functions/latest/dg/connect-to-resource.html#connect-wait-token`_ for more details.)
912+
CallAndContinue: Call StartExecution and progress to the next state (See `Request Response <https://docs.aws.amazon.com/step-functions/latest/dg/connect-to-resource.html#connect-default`_ for more details.)
913+
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
914+
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
915+
heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name.
916+
heartbeat_seconds_path (str, optional): Path specifying the state's heartbeat value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
917+
comment (str, optional): Human-readable comment or description. (default: None)
918+
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
919+
parameters (dict, optional): The value of this field becomes the effective input for the state. (default: None)
920+
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
921+
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
922+
"""
923+
supported_integ_patterns = [IntegrationPattern.WaitForCompletion, IntegrationPattern.WaitForTaskToken,
924+
IntegrationPattern.CallAndContinue]
925+
926+
is_integration_pattern_valid(integration_pattern, supported_integ_patterns)
927+
928+
if integration_pattern == IntegrationPattern.WaitForCompletion:
929+
"""
930+
Example resource arn:aws:states:::states:startExecution.sync:2
931+
"""
932+
kwargs[Field.Resource.value] = get_service_integration_arn(STEP_FUNCTIONS_SERVICE_NAME,
933+
StepFunctions.StartExecution,
934+
integration_pattern,
935+
2)
936+
else:
937+
"""
938+
Example resource arn:
939+
- arn:aws:states:::states:startExecution.waitForTaskToken
940+
- arn:aws:states:::states:startExecution
941+
"""
942+
kwargs[Field.Resource.value] = get_service_integration_arn(STEP_FUNCTIONS_SERVICE_NAME,
943+
StepFunctions.StartExecution,
944+
integration_pattern)
945+
946+
super(StepFunctionsStartExecutionStep, self).__init__(state_id, **kwargs)

Diff for: tests/unit/test_service_steps.py

+97
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import boto3
16+
import pytest
1617

1718
from unittest.mock import patch
1819
from stepfunctions.steps.service import DynamoDBGetItemStep, DynamoDBPutItemStep, DynamoDBUpdateItemStep, DynamoDBDeleteItemStep
@@ -30,6 +31,8 @@
3031
from stepfunctions.steps.service import EventBridgePutEventsStep
3132
from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep
3233
from stepfunctions.steps.service import GlueDataBrewStartJobRunStep
34+
from stepfunctions.steps.service import StepFunctionsStartExecutionStep
35+
from stepfunctions.steps.integration_resources import IntegrationPattern
3336

3437

3538
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
@@ -1158,3 +1161,97 @@ def test_eks_call_step_creation():
11581161
},
11591162
'End': True
11601163
}
1164+
1165+
1166+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1167+
def test_step_functions_start_execution_step_creation_default():
1168+
step = StepFunctionsStartExecutionStep(
1169+
"SFN Start Execution", parameters={
1170+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1171+
"Name": "ExecutionName"
1172+
})
1173+
1174+
assert step.to_dict() == {
1175+
"Type": "Task",
1176+
"Resource": "arn:aws:states:::states:startExecution.sync:2",
1177+
"Parameters": {
1178+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1179+
"Name": "ExecutionName"
1180+
},
1181+
"End": True
1182+
}
1183+
1184+
1185+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1186+
def test_step_functions_start_execution_step_creation_call_and_continue():
1187+
step = StepFunctionsStartExecutionStep(
1188+
"SFN Start Execution", integration_pattern=IntegrationPattern.CallAndContinue, parameters={
1189+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1190+
"Name": "ExecutionName"
1191+
})
1192+
1193+
assert step.to_dict() == {
1194+
"Type": "Task",
1195+
"Resource": "arn:aws:states:::states:startExecution",
1196+
"Parameters": {
1197+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1198+
"Name": "ExecutionName"
1199+
},
1200+
"End": True
1201+
}
1202+
1203+
1204+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1205+
def test_step_functions_start_execution_step_creation_wait_for_completion():
1206+
step = StepFunctionsStartExecutionStep(
1207+
"SFN Start Execution - Sync", integration_pattern=IntegrationPattern.WaitForCompletion, parameters={
1208+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1209+
"Name": "ExecutionName"
1210+
})
1211+
1212+
assert step.to_dict() == {
1213+
"Type": "Task",
1214+
"Resource": "arn:aws:states:::states:startExecution.sync:2",
1215+
"Parameters": {
1216+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1217+
"Name": "ExecutionName"
1218+
},
1219+
"End": True
1220+
}
1221+
1222+
1223+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1224+
def test_step_functions_start_execution_step_creation_wait_for_task_token():
1225+
step = StepFunctionsStartExecutionStep(
1226+
"SFN Start Execution - Wait for Callback", integration_pattern=IntegrationPattern.WaitForTaskToken,
1227+
parameters={
1228+
"Input": {
1229+
"token.$": "$$.Task.Token"
1230+
},
1231+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1232+
"Name": "ExecutionName"
1233+
})
1234+
1235+
assert step.to_dict() == {
1236+
"Type": "Task",
1237+
"Resource": "arn:aws:states:::states:startExecution.waitForTaskToken",
1238+
"Parameters": {
1239+
"Input": {
1240+
"token.$": "$$.Task.Token"
1241+
},
1242+
"StateMachineArn": "arn:aws:states:us-east-1:123456789012:stateMachine:HelloWorld",
1243+
"Name": "ExecutionName"
1244+
},
1245+
"End": True
1246+
}
1247+
1248+
1249+
@pytest.mark.parametrize("integration_pattern", [
1250+
None,
1251+
"ServiceIntegrationTypeStr",
1252+
0
1253+
])
1254+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
1255+
def test_step_functions_start_execution_step_creation_invalid_integration_pattern_raises_type_error(integration_pattern):
1256+
with pytest.raises(TypeError):
1257+
StepFunctionsStartExecutionStep("SFN Start Execution - invalid ServiceType", integration_pattern=integration_pattern)

Diff for: tests/unit/test_steps_utils.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313

1414
# Test if boto3 session can fetch correct aws partition info from test environment
1515

16-
from stepfunctions.steps.utils import get_aws_partition, merge_dicts
17-
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
1816
import boto3
19-
from unittest.mock import patch
17+
import logging
18+
import pytest
19+
2020
from enum import Enum
21+
from unittest.mock import patch
22+
23+
from stepfunctions.steps.utils import get_aws_partition, merge_dicts
24+
from stepfunctions.steps.integration_resources import IntegrationPattern,\
25+
get_service_integration_arn, is_integration_pattern_valid
2126

2227

2328
testService = "sagemaker"
@@ -86,3 +91,18 @@ def test_merge_dicts():
8691
'b': 2,
8792
'c': 3
8893
}
94+
95+
96+
@pytest.mark.parametrize("service_integration_type", [
97+
None,
98+
"IntegrationPatternStr",
99+
0
100+
])
101+
def test_is_integration_pattern_valid_with_invalid_type_raises_type_error(service_integration_type):
102+
with pytest.raises(TypeError):
103+
is_integration_pattern_valid(service_integration_type, [IntegrationPattern.WaitForTaskToken])
104+
105+
106+
def test_is_integration_pattern_valid_with_non_supported_type_raises_value_error():
107+
with pytest.raises(ValueError):
108+
is_integration_pattern_valid(IntegrationPattern.WaitForTaskToken, [IntegrationPattern.WaitForCompletion])

0 commit comments

Comments
 (0)