Skip to content

Commit 7bff6e6

Browse files
committed
Add retry arg to pipeline constructors to add retrier to each pipeline step
1 parent 1ea8346 commit 7bff6e6

File tree

6 files changed

+96
-20
lines changed

6 files changed

+96
-20
lines changed

Diff for: src/stepfunctions/template/pipeline/inference.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class InferencePipeline(WorkflowTemplate):
3939

4040
__allowed_kwargs = ('compression_type', 'content_type', 'pipeline_name')
4141

42-
def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None, **kwargs):
42+
def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None, retry=None, **kwargs):
4343
"""
4444
Args:
4545
preprocessor (sagemaker.estimator.EstimatorBase): The estimator used to preprocess and transform the training data.
@@ -54,6 +54,7 @@ def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None
5454
* (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
5555
s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
5656
client (SFN.Client, optional): boto3 client to use for creating and interacting with the inference pipeline in Step Functions. (default: None)
57+
retry (Retry, optional): A retrier that defines each pipeline step's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details. (default: None)
5758
5859
Keyword Args:
5960
compression_type (str, optional): Compression type (Gzip/None) of the file for TransformJob. (default:None)
@@ -64,6 +65,7 @@ def __init__(self, preprocessor, estimator, inputs, s3_bucket, role, client=None
6465
self.estimator = estimator
6566
self.inputs = inputs
6667
self.s3_bucket = s3_bucket
68+
self.retry = retry
6769

6870
for key in self.__class__.__allowed_kwargs:
6971
setattr(self, key, kwargs.pop(key, None))
@@ -158,15 +160,22 @@ def build_workflow_definition(self):
158160
endpoint_config_name=default_name,
159161
)
160162

161-
return Chain([
163+
steps = [
162164
preprocessor_train_step,
163165
preprocessor_model_step,
164166
preprocessor_transform_step,
165167
training_step,
166168
pipeline_model_step,
167169
endpoint_config_step,
168170
deploy_step
169-
])
171+
]
172+
173+
if self.retry:
174+
for step in steps:
175+
step.add_retry(self.retry)
176+
177+
return Chain(steps)
178+
170179

171180
def pipeline_model_config(self, instance_type, pipeline_model):
172181
return {

Diff for: src/stepfunctions/template/pipeline/train.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,7 @@
1212
# permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
from sagemaker.utils import base_name_from_image
16-
from sagemaker.sklearn.estimator import SKLearn
17-
from sagemaker.model import Model
18-
from sagemaker.pipeline import PipelineModel
19-
20-
from stepfunctions.steps import TrainingStep, TransformStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Fail, Catch
15+
from stepfunctions.steps import TrainingStep, ModelStep, EndpointConfigStep, EndpointStep, Chain, Retry
2116
from stepfunctions.workflow import Workflow
2217
from stepfunctions.template.pipeline.common import StepId, WorkflowTemplate
2318

@@ -35,7 +30,7 @@ class TrainingPipeline(WorkflowTemplate):
3530

3631
__allowed_kwargs = ('pipeline_name',)
3732

38-
def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
33+
def __init__(self, estimator, role, inputs, s3_bucket, client=None, retry=None, **kwargs):
3934
"""
4035
Args:
4136
estimator (sagemaker.estimator.EstimatorBase): The estimator to use for training. Can be a BYO estimator, Framework estimator or Amazon algorithm estimator.
@@ -49,12 +44,14 @@ def __init__(self, estimator, role, inputs, s3_bucket, client=None, **kwargs):
4944
* (list[`sagemaker.amazon.amazon_estimator.RecordSet`]) - A list of `sagemaker.amazon.amazon_estimator.RecordSet` objects, where each instance is a different channel of training data.
5045
s3_bucket (str): S3 bucket under which the output artifacts from the training job will be stored. The parent path used is built using the format: ``s3://{s3_bucket}/{pipeline_name}/models/{job_name}/``. In this format, `pipeline_name` refers to the keyword argument provided for TrainingPipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. Also, in the format, `job_name` refers to the job name provided when calling the :meth:`TrainingPipeline.run()` method.
5146
client (SFN.Client, optional): boto3 client to use for creating and interacting with the training pipeline in Step Functions. (default: None)
47+
retry (Retry, optional): A retrier that defines each pipeline step's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details. (default: None)
5248
5349
Keyword Args:
5450
pipeline_name (str, optional): Name of the pipeline. This name will be used to name jobs (if not provided when calling execute()), models, endpoints, and S3 objects created by the pipeline. If a `pipeline_name` argument was not provided, one is auto-generated by the pipeline as `training-pipeline-<timestamp>`. (default:None)
5551
"""
5652
self.estimator = estimator
5753
self.inputs = inputs
54+
self.retry = retry
5855

5956
for key in self.__class__.__allowed_kwargs:
6057
setattr(self, key, kwargs.pop(key, None))
@@ -110,7 +107,13 @@ def build_workflow_definition(self):
110107
endpoint_config_name=default_name,
111108
)
112109

113-
return Chain([training_step, model_step, endpoint_config_step, deploy_step])
110+
steps = [training_step, model_step, endpoint_config_step, deploy_step]
111+
112+
if self.retry:
113+
for step in steps:
114+
step.add_retry(self.retry)
115+
116+
return Chain(steps)
114117

115118
def execute(self, job_name=None, hyperparameters=None):
116119
"""

Diff for: tests/integ/test_inference_pipeline.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from stepfunctions.template.pipeline import InferencePipeline
2424

25-
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
25+
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
2626
from tests.integ.timeout import timeout
2727
from tests.integ.utils import (
2828
state_machine_delete_wait,
@@ -36,6 +36,7 @@
3636
BASE_NAME = 'inference-pipeline-integtest'
3737
COMPRESSED_NPY_DATA = 'mnist.npy.gz'
3838

39+
3940
# Fixtures
4041
@pytest.fixture(scope="module")
4142
def sklearn_preprocessor(sagemaker_role_arn, sagemaker_session):
@@ -100,7 +101,8 @@ def test_inference_pipeline_framework(
100101
role=sfn_role_arn,
101102
compression_type='Gzip',
102103
content_type='application/x-npy',
103-
pipeline_name=unique_name
104+
pipeline_name=unique_name,
105+
retry=SAGEMAKER_RETRY_STRATEGY
104106
)
105107

106108
_ = pipeline.create()

Diff for: tests/integ/test_training_pipeline_estimators.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
# import StepFunctions
3131
from stepfunctions.template.pipeline import TrainingPipeline
3232

33-
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
33+
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
3434
from tests.integ.timeout import timeout
3535
from tests.integ.utils import (
3636
state_machine_delete_wait,
@@ -60,7 +60,8 @@ def pca_estimator(sagemaker_role_arn):
6060
pca_estimator.mini_batch_size=128
6161

6262
return pca_estimator
63-
63+
64+
6465
@pytest.fixture(scope="module")
6566
def inputs(pca_estimator):
6667
data_path = os.path.join(DATA_DIR, "one_p_mnist", "mnist.pkl.gz")
@@ -85,7 +86,8 @@ def test_pca_estimator(sfn_client, sagemaker_session, sagemaker_role_arn, sfn_ro
8586
role=sfn_role_arn,
8687
inputs=inputs,
8788
s3_bucket=bucket_name,
88-
pipeline_name = unique_name
89+
pipeline_name=unique_name,
90+
retry=SAGEMAKER_RETRY_STRATEGY
8991
)
9092
tp.create()
9193

Diff for: tests/integ/test_training_pipeline_framework_estimator.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sagemaker
1717
import os
1818

19-
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES
19+
from tests.integ import DATA_DIR, DEFAULT_TIMEOUT_MINUTES, SAGEMAKER_RETRY_STRATEGY
2020
from tests.integ.timeout import timeout
2121
from stepfunctions.template import TrainingPipeline
2222
from sagemaker.pytorch import PyTorch
@@ -29,6 +29,7 @@
2929
get_resource_name_from_arn
3030
)
3131

32+
3233
@pytest.fixture(scope="module")
3334
def torch_estimator(sagemaker_role_arn):
3435
script_path = os.path.join(DATA_DIR, "pytorch_mnist", "mnist.py")
@@ -45,6 +46,7 @@ def torch_estimator(sagemaker_role_arn):
4546
}
4647
)
4748

49+
4850
@pytest.fixture(scope="module")
4951
def sklearn_estimator(sagemaker_role_arn):
5052
script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py")
@@ -103,7 +105,8 @@ def test_torch_training_pipeline(sfn_client, sagemaker_client, torch_estimator,
103105
sfn_role_arn,
104106
inputs,
105107
sagemaker_session.default_bucket(),
106-
sfn_client
108+
sfn_client,
109+
retry=SAGEMAKER_RETRY_STRATEGY
107110
)
108111
pipeline.create()
109112
# execute pipeline
@@ -138,7 +141,8 @@ def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimat
138141
sfn_role_arn,
139142
inputs,
140143
sagemaker_session.default_bucket(),
141-
sfn_client
144+
sfn_client,
145+
retry=SAGEMAKER_RETRY_STRATEGY
142146
)
143147
pipeline.create()
144148
# run pipeline
@@ -154,4 +158,4 @@ def test_sklearn_training_pipeline(sfn_client, sagemaker_client, sklearn_estimat
154158
_pipeline_test_suite(sagemaker_client, training_job_name='estimator-'+endpoint_name, model_name=endpoint_name, endpoint_name=endpoint_name)
155159

156160
# teardown
157-
_pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline)
161+
_pipeline_teardown(sfn_client, sagemaker_session, endpoint_name, pipeline)

Diff for: tests/unit/test_pipeline.py

+56
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sagemaker.sklearn.estimator import SKLearn
2020
from unittest.mock import MagicMock, patch
2121
from stepfunctions.template import TrainingPipeline, InferencePipeline
22+
from stepfunctions.steps import Retry
2223
from sagemaker.debugger import DebuggerHookConfig
2324

2425
from tests.unit.utils import mock_boto_api_call
@@ -27,6 +28,16 @@
2728
STEPFUNCTIONS_EXECUTION_ROLE = 'StepFunctionsExecutionRole'
2829
PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1'
2930
LINEAR_LEARNER_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1'
31+
SAGEMAKER_RETRY_STRATEGY = Retry(
32+
error_equals=["SageMaker.AmazonSageMakerException"],
33+
interval_seconds=5,
34+
max_attempts=5,
35+
backoff_rate=2
36+
)
37+
EXPECTED_RETRY = [{'BackoffRate': 2,
38+
'ErrorEquals': ['SageMaker.AmazonSageMakerException'],
39+
'IntervalSeconds': 5,
40+
'MaxAttempts': 5}]
3041

3142

3243
@pytest.fixture
@@ -235,6 +246,25 @@ def test_pca_training_pipeline(pca_estimator):
235246
workflow.execute.assert_called_with(name=job_name, inputs=inputs)
236247

237248

249+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_pca_training_pipeline_with_retry_adds_retry_to_each_step(pca_estimator):
    """Check that a retrier passed to TrainingPipeline shows up on its states.

    Builds a TrainingPipeline with ``retry=SAGEMAKER_RETRY_STRATEGY`` and
    asserts that the ``Retry`` field of each checked state in the rendered
    workflow definition equals ``EXPECTED_RETRY``.
    """
    bucket = 'sagemaker-us-east-1'
    train_channels = {'train': 's3://sagemaker/pca/train'}

    training_pipeline = TrainingPipeline(
        pca_estimator,
        STEPFUNCTIONS_EXECUTION_ROLE,
        train_channels,
        bucket,
        retry=SAGEMAKER_RETRY_STRATEGY,
    )
    states = training_pipeline.workflow.definition.to_dict()['States']

    # Same four states, in the same order, as the individual assertions.
    for state_name in ('Training', 'Create Model', 'Configure Endpoint', 'Deploy'):
        assert states[state_name]['Retry'] == EXPECTED_RETRY
265+
266+
267+
238268
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
239269
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
240270
def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator):
@@ -474,3 +504,29 @@ def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator):
474504
}
475505

476506
workflow.execute.assert_called_with(name=job_name, inputs=inputs)
507+
508+
509+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
def test_inference_pipeline_with_retry_adds_retry_to_each_step(sklearn_preprocessor, linear_learner_estimator):
    """Check that a retrier passed to InferencePipeline shows up on its states.

    Builds an InferencePipeline with ``retry=SAGEMAKER_RETRY_STRATEGY`` and
    asserts that the ``Retry`` field of each checked state in the rendered
    workflow definition equals ``EXPECTED_RETRY``.

    This test deliberately does not reuse the name ``test_inference_pipeline``:
    a second module-level function with that name rebinds it, so pytest would
    collect only the later definition and silently drop the earlier test.
    """
    s3_inputs = {
        'train': 's3://sagemaker-us-east-1/inference/train'
    }
    s3_bucket = 'sagemaker-us-east-1'

    pipeline = InferencePipeline(
        preprocessor=sklearn_preprocessor,
        estimator=linear_learner_estimator,
        inputs=s3_inputs,
        s3_bucket=s3_bucket,
        role=STEPFUNCTIONS_EXECUTION_ROLE,
        retry=SAGEMAKER_RETRY_STRATEGY
    )
    result = pipeline.get_workflow().definition.to_dict()

    # NOTE(review): the estimator's own training step is not asserted here;
    # confirm its state name in the inference workflow and add a check for it.
    assert result['States']['Train Preprocessor']['Retry'] == EXPECTED_RETRY
    assert result['States']['Create Preprocessor Model']['Retry'] == EXPECTED_RETRY
    assert result['States']['Transform Input']['Retry'] == EXPECTED_RETRY
    assert result['States']['Create Pipeline Model']['Retry'] == EXPECTED_RETRY
    assert result['States']['Configure Endpoint']['Retry'] == EXPECTED_RETRY
    assert result['States']['Deploy']['Retry'] == EXPECTED_RETRY

0 commit comments

Comments
 (0)