|
12 | 12 | # permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +from enum import Enum |
15 | 16 | from stepfunctions.inputs import ExecutionInput, StepInput
|
16 | 17 | from stepfunctions.steps.states import Task
|
17 | 18 | from stepfunctions.steps.fields import Field
|
18 | 19 | from stepfunctions.steps.utils import tags_dict_to_kv_list
|
| 20 | +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn |
19 | 21 |
|
20 | 22 | from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
|
21 | 23 | from sagemaker.model import Model, FrameworkModel
|
22 | 24 | from sagemaker.model_monitor import DataCaptureConfig
|
23 | 25 |
|
| 26 | +SAGEMAKER_SERVICE_NAME = "sagemaker" |
| 27 | + |
| 28 | + |
| 29 | +class SageMakerApi(Enum): |
| 30 | + CreateTrainingJob = "createTrainingJob" |
| 31 | + CreateTransformJob = "createTransformJob" |
| 32 | + CreateModel = "createModel" |
| 33 | + CreateEndpointConfig = "createEndpointConfig" |
| 34 | + UpdateEndpoint = "updateEndpoint" |
| 35 | + CreateEndpoint = "createEndpoint" |
| 36 | + CreateHyperParameterTuningJob = "createHyperParameterTuningJob" |
| 37 | + CreateProcessingJob = "createProcessingJob" |
| 38 | + |
| 39 | + |
24 | 40 | class TrainingStep(Task):
|
25 | 41 |
|
26 | 42 | """
|
@@ -58,9 +74,20 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
|
58 | 74 | self.job_name = job_name
|
59 | 75 |
|
60 | 76 | if wait_for_completion:
|
61 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync' |
| 77 | + """ |
| 78 | + Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync |
| 79 | + """ |
| 80 | + |
| 81 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 82 | + SageMakerApi.CreateTrainingJob, |
| 83 | + IntegrationPattern.WaitForCompletion) |
62 | 84 | else:
|
63 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob' |
| 85 | + """ |
| 86 | + Example resource arn: arn:aws:states:::sagemaker:createTrainingJob |
| 87 | + """ |
| 88 | + |
| 89 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 90 | + SageMakerApi.CreateTrainingJob) |
64 | 91 |
|
65 | 92 | if isinstance(job_name, str):
|
66 | 93 | parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
|
@@ -141,9 +168,20 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
|
141 | 168 | join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
|
142 | 169 | """
|
143 | 170 | if wait_for_completion:
|
144 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync' |
| 171 | + """ |
| 172 | + Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync |
| 173 | + """ |
| 174 | + |
| 175 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 176 | + SageMakerApi.CreateTransformJob, |
| 177 | + IntegrationPattern.WaitForCompletion) |
145 | 178 | else:
|
146 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob' |
| 179 | + """ |
| 180 | + Example resource arn: arn:aws:states:::sagemaker:createTransformJob |
| 181 | + """ |
| 182 | + |
| 183 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 184 | + SageMakerApi.CreateTransformJob) |
147 | 185 |
|
148 | 186 | if isinstance(job_name, str):
|
149 | 187 | parameters = transform_config(
|
@@ -225,7 +263,13 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
|
225 | 263 | parameters['Tags'] = tags_dict_to_kv_list(tags)
|
226 | 264 |
|
227 | 265 | kwargs[Field.Parameters.value] = parameters
|
228 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel' |
| 266 | + |
| 267 | + """ |
| 268 | + Example resource arn: arn:aws:states:::sagemaker:createModel |
| 269 | + """ |
| 270 | + |
| 271 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 272 | + SageMakerApi.CreateModel) |
229 | 273 |
|
230 | 274 | super(ModelStep, self).__init__(state_id, **kwargs)
|
231 | 275 |
|
@@ -266,7 +310,13 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
|
266 | 310 | if tags:
|
267 | 311 | parameters['Tags'] = tags_dict_to_kv_list(tags)
|
268 | 312 |
|
269 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig' |
| 313 | + """ |
| 314 | + Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig |
| 315 | + """ |
| 316 | + |
| 317 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 318 | + SageMakerApi.CreateEndpointConfig) |
| 319 | + |
270 | 320 | kwargs[Field.Parameters.value] = parameters
|
271 | 321 |
|
272 | 322 | super(EndpointConfigStep, self).__init__(state_id, **kwargs)
|
@@ -298,9 +348,19 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
|
298 | 348 | parameters['Tags'] = tags_dict_to_kv_list(tags)
|
299 | 349 |
|
300 | 350 | if update:
|
301 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint' |
| 351 | + """ |
| 352 | + Example resource arn: arn:aws:states:::sagemaker:updateEndpoint |
| 353 | + """ |
| 354 | + |
| 355 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 356 | + SageMakerApi.UpdateEndpoint) |
302 | 357 | else:
|
303 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint' |
| 358 | + """ |
| 359 | + Example resource arn: arn:aws:states:::sagemaker:createEndpoint |
| 360 | + """ |
| 361 | + |
| 362 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 363 | + SageMakerApi.CreateEndpoint) |
304 | 364 |
|
305 | 365 | kwargs[Field.Parameters.value] = parameters
|
306 | 366 |
|
@@ -338,9 +398,20 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
|
338 | 398 | tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
|
339 | 399 | """
|
340 | 400 | if wait_for_completion:
|
341 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync' |
| 401 | + """ |
| 402 | + Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync |
| 403 | + """ |
| 404 | + |
| 405 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 406 | + SageMakerApi.CreateHyperParameterTuningJob, |
| 407 | + IntegrationPattern.WaitForCompletion) |
342 | 408 | else:
|
343 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob' |
| 409 | + """ |
| 410 | + Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob |
| 411 | + """ |
| 412 | + |
| 413 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 414 | + SageMakerApi.CreateHyperParameterTuningJob) |
344 | 415 |
|
345 | 416 | parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
|
346 | 417 |
|
@@ -387,10 +458,21 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
|
387 | 458 | tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
|
388 | 459 | """
|
389 | 460 | if wait_for_completion:
|
390 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync' |
| 461 | + """ |
| 462 | + Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync |
| 463 | + """ |
| 464 | + |
| 465 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 466 | + SageMakerApi.CreateProcessingJob, |
| 467 | + IntegrationPattern.WaitForCompletion) |
391 | 468 | else:
|
392 |
| - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob' |
393 |
| - |
| 469 | + """ |
| 470 | + Example resource arn: arn:aws:states:::sagemaker:createProcessingJob |
| 471 | + """ |
| 472 | + |
| 473 | + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, |
| 474 | + SageMakerApi.CreateProcessingJob) |
| 475 | + |
394 | 476 | if isinstance(job_name, str):
|
395 | 477 | parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
|
396 | 478 | else:
|
|
0 commit comments