Skip to content

Commit 23601e5

Browse files
committed
feature: Make Choice states chainable
1 parent 48efb04 commit 23601e5

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

CONTRIBUTING.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ Before sending us a pull request, please ensure that:
5757
### Running the Unit Tests
5858

5959
1. Install tox using `pip install tox`
60-
1. Install test dependencies, including coverage, using `pip install .[test]`
6160
1. cd into the aws-step-functions-data-science-sdk-python folder: `cd aws-step-functions-data-science-sdk-python` or `cd /environment/aws-step-functions-data-science-sdk-python`
61+
1. Install test dependencies, including coverage, using `pip install ".[test]"`
6262
1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit`
6363

6464
You can also run a single test with the following command: `tox -e py36 -- -s -vv <path_to_file><file_name>::<test_function_name>`
@@ -80,7 +80,7 @@ You should only worry about manually running any new integration tests that you
8080

8181
1. Create a new git branch:
8282
```shell
83-
git checkout -b my-fix-branch master
83+
git checkout -b my-fix-branch main
8484
```
8585
1. Make your changes, **including unit tests** and, if appropriate, integration tests.
8686
1. Include unit tests when you contribute new features or make bug fixes, as they help to:

src/stepfunctions/steps/states.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,21 @@ def next(self, next_step):
218218
Returns:
219219
State or Chain: Next state or chain that will be transitioned to.
220220
"""
221-
if self.type in ('Choice', 'Succeed', 'Fail'):
221+
if self.type in ('Succeed', 'Fail'):
222222
raise ValueError('Unexpected State instance `{step}`, State type `{state_type}` does not support method `next`.'.format(step=next_step, state_type=self.type))
223223

224+
# By design, choice states do not have the Next field. Setting default to make it chainable.
225+
if self.type is 'Choice':
226+
if self.default is not None:
227+
logger.warning(
228+
"Chaining Choice Step: Overwriting %s's current default_choice (%s) with %s",
229+
self.state_id,
230+
self.default.state_id,
231+
next_step.state_id
232+
)
233+
self.default_choice(next_step)
234+
return self.default
235+
224236
self.next_step = next_step
225237
return self.next_step
226238

tests/unit/test_steps.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import logging
1516
import pytest
1617

1718
from stepfunctions.exceptions import DuplicateStatesInChain
@@ -328,12 +329,6 @@ def test_append_states_after_terminal_state_will_fail():
328329
chain.append(Succeed('Succeed'))
329330
chain.append(Pass('Pass2'))
330331

331-
with pytest.raises(ValueError):
332-
chain = Chain()
333-
chain.append(Pass('Pass'))
334-
chain.append(Choice('Choice'))
335-
chain.append(Pass('Pass2'))
336-
337332

338333
def test_chaining_steps():
339334
s1 = Pass('Step - One')
@@ -372,6 +367,33 @@ def test_chaining_steps():
372367
assert s1.next_step == s2
373368
assert s2.next_step == s3
374369

370+
371+
def test_chaining_choice(caplog):
372+
s1_pass = Pass('Step - One')
373+
s2_choice = Choice('Step - Two')
374+
s3_pass = Pass('Step - Three')
375+
376+
with caplog.at_level(logging.WARNING):
377+
chain1 = Chain([s1_pass, s2_choice, s3_pass])
378+
assert caplog.text == '' # No warning
379+
assert chain1.steps == [s1_pass, s2_choice, s3_pass]
380+
assert s1_pass.next_step == s2_choice
381+
assert s2_choice.default == s3_pass
382+
assert s2_choice.next_step is None # Choice steps do not have next_step
383+
assert s3_pass.next_step is None
384+
385+
# Chain s2_choice when default_choice is already set will trigger Warning
386+
with caplog.at_level(logging.WARNING):
387+
Chain([s2_choice, s1_pass])
388+
log_message = (
389+
"Chaining Choice Step: Overwriting %s's current default_choice (%s) with %s" %
390+
(s2_choice.state_id, s3_pass.state_id, s1_pass.state_id)
391+
)
392+
assert log_message in caplog.text
393+
assert s2_choice.default == s1_pass
394+
assert s2_choice.next_step is None # Choice steps do not have next_step
395+
396+
375397
def test_catch_fail_for_unsupported_state():
376398
s1 = Pass('Step - One')
377399

0 commit comments

Comments
 (0)