Skip to content

Commit f5800cb

Browse files
committed
Support placeholders for Map state
1 parent f8bbfaf commit f5800cb

File tree

2 files changed

+95
-6
lines changed

2 files changed

+95
-6
lines changed

src/stepfunctions/steps/states.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,37 @@ def to_dict(self):
7373
k = to_pascalcase(k)
7474
if k == to_pascalcase(Field.Parameters.value):
7575
result[k] = self._replace_placeholders(v)
76+
elif self._is_placeholder_compatible(k):
77+
if isinstance(v, Placeholder):
78+
modified_key = f"{k}.$"
79+
result[modified_key] = v.to_jsonpath()
80+
else:
81+
result[k] = v
7682
else:
7783
result[k] = v
7884

7985
return result
8086

87+
@staticmethod
88+
def _is_placeholder_compatible(field):
89+
"""
90+
Check if the field is placeholder compatible
91+
92+
Args:
93+
field: Field against which to verify placeholder compatibility
94+
"""
95+
return field in [
96+
# Common fields
97+
to_pascalcase(Field.Comment.value),
98+
to_pascalcase(Field.InputPath.value),
99+
to_pascalcase(Field.OutputPath.value),
100+
to_pascalcase(Field.ResultPath.value),
101+
102+
# Map
103+
to_pascalcase(Field.ItemsPath.value),
104+
to_pascalcase(Field.MaxConcurrency.value),
105+
]
106+
81107
def to_json(self, pretty=False):
82108
"""Serialize to a JSON formatted string.
83109
@@ -541,13 +567,13 @@ def __init__(self, state_id, **kwargs):
541567
Args:
542568
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
543569
iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
544-
items_path (str, optional): Path in the input for items to iterate over. (default: '$')
545-
max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
546-
comment (str, optional): Human-readable comment or description. (default: None)
547-
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
570+
items_path (str or Placeholder, optional): Path in the input for items to iterate over. (default: '$')
571+
max_concurrency (int or Placeholder, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
572+
comment (str or Placeholder, optional): Human-readable comment or description. (default: None)
573+
input_path (str or Placeholder, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
548574
parameters (dict, optional): The value of this field becomes the effective input for the state.
549-
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
550-
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
575+
result_path (str or Placeholder, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
576+
output_path (str or Placeholder, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
551577
"""
552578
super(Map, self).__init__(state_id, 'Map', **kwargs)
553579

tests/unit/test_placeholders_with_steps.py

+63
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,69 @@ def test_map_state_with_placeholders():
214214
result = Graph(workflow_definition).to_dict()
215215
assert result == expected_repr
216216

217+
218+
def test_map_state_with_placeholders():
219+
workflow_input = ExecutionInput(schema={
220+
'comment': str,
221+
'input_path': str,
222+
'output_path': str,
223+
'result_path': str,
224+
'items_path': str,
225+
'max_concurrency': int,
226+
'ParamB': str
227+
})
228+
229+
map_state = Map(
230+
'MapState01',
231+
comment=workflow_input['input_path'],
232+
input_path=workflow_input['input_path'],
233+
output_path=workflow_input['output_path'],
234+
result_path=workflow_input['result_path'],
235+
items_path=workflow_input['result_path'],
236+
max_concurrency=workflow_input['max_concurrency']
237+
)
238+
iterator_state = Pass(
239+
'TrainIterator',
240+
parameters={
241+
'ParamA': map_state.output()['X']["Y"],
242+
'ParamB': workflow_input['ParamB']
243+
})
244+
245+
map_state.attach_iterator(iterator_state)
246+
workflow_definition = Chain([map_state])
247+
248+
expected_repr = {
249+
"StartAt": "MapState01",
250+
"States": {
251+
"MapState01": {
252+
"Type": "Map",
253+
"End": True,
254+
"Comment.$": "$$.Execution.Input['input_path']",
255+
"InputPath.$": "$$.Execution.Input['input_path']",
256+
"ItemsPath.$": "$$.Execution.Input['result_path']",
257+
"Iterator": {
258+
"StartAt": "TrainIterator",
259+
"States": {
260+
"TrainIterator": {
261+
"Parameters": {
262+
"ParamA.$": "$['X']['Y']",
263+
"ParamB.$": "$$.Execution.Input['ParamB']"
264+
},
265+
"Type": "Pass",
266+
"End": True
267+
}
268+
}
269+
},
270+
"MaxConcurrency.$": "$$.Execution.Input['max_concurrency']",
271+
"OutputPath.$": "$$.Execution.Input['output_path']",
272+
"ResultPath.$": "$$.Execution.Input['result_path']",
273+
}
274+
}
275+
}
276+
277+
result = Graph(workflow_definition).to_dict()
278+
assert result == expected_repr
279+
217280
def test_parallel_state_with_placeholders():
218281
workflow_input = ExecutionInput()
219282

0 commit comments

Comments
 (0)