2
2
# Copyright (c) Microsoft Corporation.
3
3
# Licensed under the MIT License.
4
4
# ------------------------------------
5
- # pylint: disable=too-many-lines
5
+ # pylint: disable=too-many-lines,protected-access
6
6
"""Customize generated code here.
7
7
8
8
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
9
9
"""
10
10
11
11
import sys
12
- from typing import Any , IO , Callable , Dict , Iterator , List , Optional , Type , TypeVar , Union , cast , overload
13
- from azure .core .polling import NoPolling , PollingMethod
14
- from azure .core .polling .base_polling import LROBasePolling
12
+ from typing import Any , IO , Callable , Dict , Iterator , List , Optional , Type , TypeVar , Union , cast , overload , Tuple
15
13
from azure .core .tracing .decorator import distributed_trace
16
14
from azure .core .utils import case_insensitive_dict
17
- from azure .core .rest import HttpRequest , HttpResponse
15
+ from azure .core .rest import HttpRequest , HttpResponse , AsyncHttpResponse
18
16
from azure .core .pipeline import PipelineResponse
19
17
from azure .core .exceptions import (
20
18
ClientAuthenticationError ,
21
19
HttpResponseError ,
22
20
ResourceExistsError ,
23
21
ResourceNotFoundError ,
24
22
ResourceNotModifiedError ,
23
+ ODataV4Format ,
25
24
map_error ,
26
25
)
27
- from ..models import _models
28
- from .. import _model_base
26
+ from azure .core .polling import LROPoller , NoPolling , PollingMethod
27
+ from azure .core .polling .base_polling import (
28
+ LROBasePolling ,
29
+ OperationResourcePolling ,
30
+ _is_empty ,
31
+ _as_json ,
32
+ BadResponse ,
33
+ OperationFailed ,
34
+ _raise_if_bad_http_status_and_method ,
35
+ )
36
+ from .. import _model_base , models as _models
37
+ from ..models import (
38
+ TranslationStatus ,
39
+ )
29
40
from .._model_base import _deserialize
30
41
from ._operations import (
31
42
DocumentTranslationClientOperationsMixin as GeneratedDocumentTranslationClientOperationsMixin ,
34
45
ClsType ,
35
46
build_single_document_translation_document_translate_request ,
36
47
)
37
- from .. _patch import DocumentTranslationLROPoller
48
+
38
49
from .._vendor import prepare_multipart_form_data
39
50
40
51
if sys .version_info >= (3 , 9 ):
46
57
ClsType = Optional [Callable [[PipelineResponse [HttpRequest , HttpResponse ], T , Dict [str , Any ]], Any ]] # type: ignore
47
58
48
59
60
+ ResponseType = Union [HttpResponse , AsyncHttpResponse ]
61
+ PipelineResponseType = PipelineResponse [HttpRequest , ResponseType ]
62
+ PollingReturnType_co = TypeVar ("PollingReturnType_co" , covariant = True )
63
+
64
+ _FINISHED = frozenset (["succeeded" , "cancelled" , "cancelling" , "failed" ])
65
+ _FAILED = frozenset (["validationfailed" ])
66
+
67
+
68
+ def convert_status (status , ll = False ):
69
+ if ll is False :
70
+ if status == "Cancelled" :
71
+ return "Canceled"
72
+ if status == "Cancelling" :
73
+ return "Canceling"
74
+ elif ll is True :
75
+ if status == "Canceled" :
76
+ return "Cancelled"
77
+ if status == "Canceling" :
78
+ return "Cancelling"
79
+ return status
80
+
81
+
82
+ class DocumentTranslationLROPoller (LROPoller [PollingReturnType_co ]):
83
+ """A custom poller implementation for Document Translation. Call `result()` on the poller to return
84
+ a pageable of :class:`~azure.ai.translation.document.DocumentStatus`."""
85
+
86
+ _polling_method : "DocumentTranslationLROPollingMethod"
87
+
88
+ @property
89
+ def id (self ) -> str :
90
+ """The ID for the translation operation
91
+
92
+ :return: The str ID for the translation operation.
93
+ :rtype: str
94
+ """
95
+ if self ._polling_method ._current_body :
96
+ return self ._polling_method ._current_body .id
97
+ return self ._polling_method ._get_id_from_headers ()
98
+
99
+ @property
100
+ def details (self ) -> TranslationStatus :
101
+ """The details for the translation operation
102
+
103
+ :return: The details for the translation operation.
104
+ :rtype: ~azure.ai.translation.document.TranslationStatus
105
+ """
106
+ if self ._polling_method ._current_body :
107
+ return TranslationStatus (self ._polling_method ._current_body )
108
+ return TranslationStatus (id = self ._polling_method ._get_id_from_headers ()) # type: ignore
109
+
110
+ @classmethod
111
+ def from_continuation_token ( # pylint: disable=docstring-missing-return,docstring-missing-param,docstring-missing-rtype
112
+ cls , polling_method , continuation_token , ** kwargs : Any
113
+ ):
114
+ """
115
+ :meta private:
116
+ """
117
+ (
118
+ client ,
119
+ initial_response ,
120
+ deserialization_callback ,
121
+ ) = polling_method .from_continuation_token (continuation_token , ** kwargs )
122
+
123
+ return cls (client , initial_response , deserialization_callback , polling_method )
124
+
125
+
126
+ class DocumentTranslationLROPollingMethod (LROBasePolling ):
127
+ """A custom polling method implementation for Document Translation."""
128
+
129
+ def __init__ (self , * args , ** kwargs ):
130
+ self ._cont_token_response = kwargs .pop ("cont_token_response" )
131
+ super ().__init__ (* args , ** kwargs )
132
+
133
+ @property
134
+ def _current_body (self ) -> TranslationStatus :
135
+ try :
136
+ return TranslationStatus (self ._pipeline_response .http_response .json ())
137
+ except Exception : # pylint: disable=broad-exception-caught
138
+ return TranslationStatus () # type: ignore[call-overload]
139
+
140
+ def _get_id_from_headers (self ) -> str :
141
+ return (
142
+ self ._initial_response .http_response .headers ["Operation-Location" ]
143
+ .split ("/batches/" )[1 ]
144
+ .split ("?api-version" )[0 ]
145
+ )
146
+
147
+ def finished (self ) -> bool :
148
+ """Is this polling finished?
149
+
150
+ :return: True/False for whether polling is complete.
151
+ :rtype: bool
152
+ """
153
+ return self ._finished (self .status ())
154
+
155
+ @staticmethod
156
+ def _finished (status ) -> bool :
157
+ if hasattr (status , "value" ):
158
+ status = status .value
159
+ return str (status ).lower () in _FINISHED
160
+
161
+ @staticmethod
162
+ def _failed (status ) -> bool :
163
+ if hasattr (status , "value" ):
164
+ status = status .value
165
+ return str (status ).lower () in _FAILED
166
+
167
+ def get_continuation_token (self ) -> str :
168
+ if self ._current_body :
169
+ return self ._current_body .id
170
+ return self ._get_id_from_headers ()
171
+
172
+ # pylint: disable=arguments-differ
173
+ def from_continuation_token (self , continuation_token : str , ** kwargs : Any ) -> Tuple : # type: ignore[override]
174
+ try :
175
+ client = kwargs ["client" ]
176
+ except KeyError as exc :
177
+ raise ValueError ("Need kwarg 'client' to be recreated from continuation_token" ) from exc
178
+
179
+ try :
180
+ deserialization_callback = kwargs ["deserialization_callback" ]
181
+ except KeyError as exc :
182
+ raise ValueError ("Need kwarg 'deserialization_callback' to be recreated from continuation_token" ) from exc
183
+
184
+ return client , self ._cont_token_response , deserialization_callback
185
+
186
+ def _poll (self ) -> None :
187
+ """Poll status of operation so long as operation is incomplete and
188
+ we have an endpoint to query.
189
+
190
+ :raises: OperationFailed if operation status 'Failed' or 'Canceled'.
191
+ :raises: BadStatus if response status invalid.
192
+ :raises: BadResponse if response invalid.
193
+ """
194
+
195
+ if not self .finished ():
196
+ self .update_status ()
197
+ while not self .finished ():
198
+ self ._delay ()
199
+ self .update_status ()
200
+
201
+ if self ._failed (self .status ()):
202
+ raise OperationFailed ("Operation failed or canceled" )
203
+
204
+ final_get_url = self ._operation .get_final_get_url (self ._pipeline_response )
205
+ if final_get_url :
206
+ self ._pipeline_response = self .request_status (final_get_url )
207
+ _raise_if_bad_http_status_and_method (self ._pipeline_response .http_response )
208
+
209
+
210
+ class TranslationPolling (OperationResourcePolling ):
211
+ """Implements a Location polling."""
212
+
213
+ def can_poll (self , pipeline_response : PipelineResponseType ) -> bool :
214
+ """Answer if this polling method could be used.
215
+
216
+ :param pipeline_response: The PipelineResponse type
217
+ :type pipeline_response: PipelineResponseType
218
+ :return: Whether polling should be performed.
219
+ :rtype: bool
220
+ """
221
+ response = pipeline_response .http_response
222
+ can_poll = self ._operation_location_header in response .headers
223
+ if can_poll :
224
+ return True
225
+
226
+ if not _is_empty (response ):
227
+ body = _as_json (response )
228
+ status = body .get ("status" )
229
+ if status :
230
+ return True
231
+ return False
232
+
233
+ def _set_async_url_if_present (self , response : ResponseType ) -> None :
234
+ location_header = response .headers .get (self ._operation_location_header )
235
+ if location_header :
236
+ self ._async_url = location_header
237
+ else :
238
+ self ._async_url = response .request .url
239
+
240
+ def get_status (self , pipeline_response : PipelineResponseType ) -> str :
241
+ """Process the latest status update retrieved from a 'location' header.
242
+
243
+ :param azure.core.pipeline.PipelineResponse pipeline_response: latest REST call response.
244
+ :return: The current operation status
245
+ :rtype: str
246
+ :raises: BadResponse if response has no body and not status 202.
247
+ """
248
+ response = pipeline_response .http_response
249
+ if not _is_empty (response ):
250
+ body = _as_json (response )
251
+ status = body .get ("status" )
252
+ if status :
253
+ return self ._map_nonstandard_statuses (status , body )
254
+ raise BadResponse ("No status found in body" )
255
+ raise BadResponse ("The response from long running operation does not contain a body." )
256
+
257
+ def _map_nonstandard_statuses (self , status : str , body : Dict [str , Any ]) -> str :
258
+ """Map non-standard statuses.
259
+
260
+ :param str status: lro process status.
261
+ :param str body: pipeline response body.
262
+ :return: The current operation status.
263
+ :rtype: str
264
+ """
265
+ if status == "ValidationFailed" :
266
+ self .raise_error (body )
267
+ return status
268
+
269
+ def raise_error (self , body : Dict [str , Any ]) -> None :
270
+ error = body ["error" ]
271
+ if body ["error" ].get ("innerError" , None ):
272
+ error = body ["error" ]["innerError" ]
273
+ http_response_error = HttpResponseError (message = "({}): {}" .format (error ["code" ], error ["message" ]))
274
+ http_response_error .error = ODataV4Format (error ) # set error.code
275
+ raise http_response_error
276
+
277
+
49
278
class DocumentTranslationClientOperationsMixin (GeneratedDocumentTranslationClientOperationsMixin ):
50
279
51
280
@distributed_trace
52
- def begin_start_translation ( # type: ignore[override]
281
+ def _begin_start_translation ( # type: ignore[override]
53
282
self , body : Union [_models .StartTranslationDetails , JSON , IO [bytes ]], ** kwargs : Any
54
283
) -> DocumentTranslationLROPoller [_models .TranslationStatus ]:
55
284
_headers = case_insensitive_dict (kwargs .pop ("headers" , {}) or {})
@@ -61,7 +290,7 @@ def begin_start_translation( # type: ignore[override]
61
290
lro_delay = kwargs .pop ("polling_interval" , self ._config .polling_interval )
62
291
cont_token : Optional [str ] = kwargs .pop ("continuation_token" , None )
63
292
if cont_token is None :
64
- raw_result = self ._start_translation_initial ( # type: ignore[func-returns-value]
293
+ raw_result = self .__begin_start_translation_initial ( # type: ignore[func-returns-value]
65
294
body = body , content_type = content_type , cls = lambda x , y , z : x , headers = _headers , params = _params , ** kwargs
66
295
)
67
296
kwargs .pop ("error_map" , None )
0 commit comments