Skip to content

Commit 2c72de7

Browse files
authored
Merge branch 'main' into chain-choice-states
2 parents 851023f + 4f90ba3 commit 2c72de7

14 files changed

+560
-56
lines changed

Diff for: CONTRIBUTING.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ Before sending us a pull request, please ensure that:
5757
### Running the Unit Tests
5858

5959
1. Install tox using `pip install tox`
60-
1. Install test dependencies, including coverage, using `pip install .[test]`
6160
1. cd into the aws-step-functions-data-science-sdk-python folder: `cd aws-step-functions-data-science-sdk-python` or `cd /environment/aws-step-functions-data-science-sdk-python`
61+
1. Install test dependencies, including coverage, using `pip install ".[test]"`
6262
1. Run the following tox command and verify that all code checks and unit tests pass: `tox tests/unit`
6363

6464
You can also run a single test with the following command: `tox -e py36 -- -s -vv <path_to_file><file_name>::<test_function_name>`
@@ -80,7 +80,7 @@ You should only worry about manually running any new integration tests that you
8080

8181
1. Create a new git branch:
8282
```shell
83-
git checkout -b my-fix-branch master
83+
git checkout -b my-fix-branch
8484
```
8585
1. Make your changes, **including unit tests** and, if appropriate, integration tests.
8686
1. Include unit tests when you contribute new features or make bug fixes, as they help to:

Diff for: src/stepfunctions/steps/compute.py

+76-8
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,31 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from enum import Enum
1516
from stepfunctions.steps.states import Task
1617
from stepfunctions.steps.fields import Field
18+
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
19+
20+
# Service identifiers passed to get_service_integration_arn(); each one becomes
# the service segment of the resource ARN (e.g. "lambda" in
# arn:aws:states:::lambda:invoke).
LAMBDA_SERVICE_NAME = "lambda"
GLUE_SERVICE_NAME = "glue"
ECS_SERVICE_NAME = "ecs"
BATCH_SERVICE_NAME = "batch"
24+
25+
26+
class LambdaApi(Enum):
    """Lambda API actions used when building task resource ARNs."""

    # The value is the API segment of the ARN, as in ``lambda:invoke``.
    Invoke = "invoke"
28+
29+
30+
class GlueApi(Enum):
    """Glue API actions used when building task resource ARNs."""

    # The value is the API segment of the ARN, as in ``glue:startJobRun``.
    StartJobRun = "startJobRun"
32+
33+
34+
class EcsApi(Enum):
    """ECS API actions used when building task resource ARNs."""

    # The value is the API segment of the ARN, as in ``ecs:runTask``.
    RunTask = "runTask"
36+
37+
38+
class BatchApi(Enum):
    """Batch API actions used when building task resource ARNs."""

    # The value is the API segment of the ARN, as in ``batch:submitJob``.
    SubmitJob = "submitJob"
1740

1841

1942
class LambdaStep(Task):
@@ -37,10 +60,22 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
3760
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
3861
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
3962
"""
63+
4064
if wait_for_callback:
41-
kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke.waitForTaskToken'
65+
"""
66+
Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken
67+
"""
68+
69+
kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME,
70+
LambdaApi.Invoke,
71+
IntegrationPattern.WaitForTaskToken)
4272
else:
43-
kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke'
73+
"""
74+
Example resource arn: arn:aws:states:::lambda:invoke
75+
"""
76+
77+
kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, LambdaApi.Invoke)
78+
4479

4580
super(LambdaStep, self).__init__(state_id, **kwargs)
4681

@@ -67,9 +102,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
67102
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
68103
"""
69104
if wait_for_completion:
70-
kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun.sync'
105+
"""
106+
Example resource arn: arn:aws:states:::glue:startJobRun.sync
107+
"""
108+
109+
kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME,
110+
GlueApi.StartJobRun,
111+
IntegrationPattern.WaitForCompletion)
71112
else:
72-
kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun'
113+
"""
114+
Example resource arn: arn:aws:states:::glue:startJobRun
115+
"""
116+
117+
kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME,
118+
GlueApi.StartJobRun)
73119

74120
super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)
75121

@@ -96,9 +142,20 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
96142
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
97143
"""
98144
if wait_for_completion:
99-
kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob.sync'
145+
"""
146+
Example resource arn: arn:aws:states:::batch:submitJob.sync
147+
"""
148+
149+
kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME,
150+
BatchApi.SubmitJob,
151+
IntegrationPattern.WaitForCompletion)
100152
else:
101-
kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob'
153+
"""
154+
Example resource arn: arn:aws:states:::batch:submitJob
155+
"""
156+
157+
kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME,
158+
BatchApi.SubmitJob)
102159

103160
super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)
104161

@@ -125,8 +182,19 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
125182
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
126183
"""
127184
if wait_for_completion:
128-
kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask.sync'
185+
"""
186+
Example resource arn: arn:aws:states:::ecs:runTask.sync
187+
"""
188+
189+
kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME,
190+
EcsApi.RunTask,
191+
IntegrationPattern.WaitForCompletion)
129192
else:
130-
kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask'
193+
"""
194+
Example resource arn: arn:aws:states:::ecs:runTask
195+
"""
196+
197+
kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME,
198+
EcsApi.RunTask)
131199

132200
super(EcsRunTaskStep, self).__init__(state_id, **kwargs)

Diff for: src/stepfunctions/steps/integration_resources.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
14+
from __future__ import absolute_import
15+
16+
from enum import Enum
17+
from stepfunctions.steps.utils import get_aws_partition
18+
19+
20+
class IntegrationPattern(Enum):
    """
    Service integration patterns understood by the task resource ARN builder.

    A member's value is the suffix appended (after a ``.``) to the API name in
    the resource ARN. ``RequestResponse`` carries an empty value because that
    pattern adds no suffix.
    """

    # Produces ARNs ending in ".waitForTaskToken".
    WaitForTaskToken = "waitForTaskToken"
    # Produces ARNs ending in ".sync".
    WaitForCompletion = "sync"
    # Default pattern: the ARN ends at the API name.
    RequestResponse = ""
28+
29+
30+
def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse):
    """
    ARN builder for task integration

    Args:
        service (str): The service name for the service integration
        api (Enum or str): The api of the service integration. Enum members contribute their `value`; a plain string is used as-is.
        integration_pattern (IntegrationPattern, optional): The integration pattern for the task. (Default: IntegrationPattern.RequestResponse)

    Returns:
        str: The resource ARN in the form `arn:<partition>:states:::<service>:<api>`, followed by `.<pattern>` for any pattern other than RequestResponse. The partition comes from `get_aws_partition()`.
    """
    # The docstring has always advertised `api` as a string, while callers in
    # the step classes pass enum members; accept both instead of failing on
    # the missing `.value` attribute of a str.
    api_name = getattr(api, "value", api)

    arn = f"arn:{get_aws_partition()}:states:::{service}:{api_name}"
    # RequestResponse has no suffix; every other pattern is appended after a dot.
    if integration_pattern != IntegrationPattern.RequestResponse:
        arn = f"{arn}.{integration_pattern.value}"
    return arn
45+
46+

Diff for: src/stepfunctions/steps/sagemaker.py

+95-13
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,31 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from enum import Enum
1516
from stepfunctions.inputs import ExecutionInput, StepInput
1617
from stepfunctions.steps.states import Task
1718
from stepfunctions.steps.fields import Field
1819
from stepfunctions.steps.utils import tags_dict_to_kv_list
20+
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn
1921

2022
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
2123
from sagemaker.model import Model, FrameworkModel
2224
from sagemaker.model_monitor import DataCaptureConfig
2325

26+
# Service segment of SageMaker resource ARNs (arn:aws:states:::sagemaker:...).
SAGEMAKER_SERVICE_NAME = "sagemaker"
27+
28+
29+
class SageMakerApi(Enum):
    """
    SageMaker API actions used when building task resource ARNs.

    Each value is the API segment of the ARN, as in
    ``sagemaker:createTrainingJob``.
    """

    # Training, batch transform, tuning and processing jobs
    CreateTrainingJob = "createTrainingJob"
    CreateTransformJob = "createTransformJob"
    # Model and endpoint management
    CreateModel = "createModel"
    CreateEndpointConfig = "createEndpointConfig"
    UpdateEndpoint = "updateEndpoint"
    CreateEndpoint = "createEndpoint"
    CreateHyperParameterTuningJob = "createHyperParameterTuningJob"
    CreateProcessingJob = "createProcessingJob"
38+
39+
2440
class TrainingStep(Task):
2541

2642
"""
@@ -58,9 +74,20 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
5874
self.job_name = job_name
5975

6076
if wait_for_completion:
61-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync'
77+
"""
78+
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
79+
"""
80+
81+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
82+
SageMakerApi.CreateTrainingJob,
83+
IntegrationPattern.WaitForCompletion)
6284
else:
63-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob'
85+
"""
86+
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
87+
"""
88+
89+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
90+
SageMakerApi.CreateTrainingJob)
6491

6592
if isinstance(job_name, str):
6693
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
@@ -141,9 +168,20 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
141168
join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
142169
"""
143170
if wait_for_completion:
144-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync'
171+
"""
172+
Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync
173+
"""
174+
175+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
176+
SageMakerApi.CreateTransformJob,
177+
IntegrationPattern.WaitForCompletion)
145178
else:
146-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob'
179+
"""
180+
Example resource arn: arn:aws:states:::sagemaker:createTransformJob
181+
"""
182+
183+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
184+
SageMakerApi.CreateTransformJob)
147185

148186
if isinstance(job_name, str):
149187
parameters = transform_config(
@@ -225,7 +263,13 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
225263
parameters['Tags'] = tags_dict_to_kv_list(tags)
226264

227265
kwargs[Field.Parameters.value] = parameters
228-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'
266+
267+
"""
268+
Example resource arn: arn:aws:states:::sagemaker:createModel
269+
"""
270+
271+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
272+
SageMakerApi.CreateModel)
229273

230274
super(ModelStep, self).__init__(state_id, **kwargs)
231275

@@ -266,7 +310,13 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
266310
if tags:
267311
parameters['Tags'] = tags_dict_to_kv_list(tags)
268312

269-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig'
313+
"""
314+
Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig
315+
"""
316+
317+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
318+
SageMakerApi.CreateEndpointConfig)
319+
270320
kwargs[Field.Parameters.value] = parameters
271321

272322
super(EndpointConfigStep, self).__init__(state_id, **kwargs)
@@ -298,9 +348,19 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
298348
parameters['Tags'] = tags_dict_to_kv_list(tags)
299349

300350
if update:
301-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint'
351+
"""
352+
Example resource arn: arn:aws:states:::sagemaker:updateEndpoint
353+
"""
354+
355+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
356+
SageMakerApi.UpdateEndpoint)
302357
else:
303-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint'
358+
"""
359+
Example resource arn: arn:aws:states:::sagemaker:createEndpoint
360+
"""
361+
362+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
363+
SageMakerApi.CreateEndpoint)
304364

305365
kwargs[Field.Parameters.value] = parameters
306366

@@ -338,9 +398,20 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
338398
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
339399
"""
340400
if wait_for_completion:
341-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync'
401+
"""
402+
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
403+
"""
404+
405+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
406+
SageMakerApi.CreateHyperParameterTuningJob,
407+
IntegrationPattern.WaitForCompletion)
342408
else:
343-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob'
409+
"""
410+
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
411+
"""
412+
413+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
414+
SageMakerApi.CreateHyperParameterTuningJob)
344415

345416
parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
346417

@@ -387,10 +458,21 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
387458
tags (list[dict], optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
388459
"""
389460
if wait_for_completion:
390-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync'
461+
"""
462+
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync
463+
"""
464+
465+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
466+
SageMakerApi.CreateProcessingJob,
467+
IntegrationPattern.WaitForCompletion)
391468
else:
392-
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob'
393-
469+
"""
470+
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob
471+
"""
472+
473+
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
474+
SageMakerApi.CreateProcessingJob)
475+
394476
if isinstance(job_name, str):
395477
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
396478
else:

0 commit comments

Comments
 (0)