Skip to content

Commit e2f4c4c

Browse files
Dt use patch (Azure#36977)
* regen dt with typespec * regen * update samples to show diff * regen with necessary patch * revert samples * fix docs * sync passing * get async working * better diff * better diff again * few fixes * mypy/pylint * fix ci and some apiview * fix patched code due to rename * fix after rename * fix docstrings and add missing patch * feedback * Addressing some review comments: Removed the __init overload in model patch. Updated status value comparision using lower case. * Fixing pylint error * Re-generated SDK after renaming id to translation_id. Cleaned up patching. * Revert "Re-generated SDK after renaming id to translation_id. Cleaned up patching." This reverts commit d0ed621. * Picking up overriding of id to translation_id * Fixing test failures * Adding a test case for the new model "StartTranslationDetails". Since Key based auth is disabled on storage accounts, updated the tests and their recordings. * Revert "Adding a test case for the new model "StartTranslationDetails". Since Key based auth is disabled on storage accounts, updated the tests and their recordings." This reverts commit 94ab6ae. * Addressing review comments * Fixing test failures --------- Co-authored-by: Hamshavathi Munibyraiah <[email protected]>
1 parent a42cbfe commit e2f4c4c

22 files changed

+2144
-2059
lines changed

.vscode/cspell.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
"sdk/synapse/azure-synapse/**",
9191
"sdk/synapse/azure-synapse-artifacts/**",
9292
"sdk/translation/azure-ai-translation-document/samples/assets/**",
93+
"sdk/translation/azure-ai-translation-document/doc/**",
9394
"sdk/translation/azure-ai-translation-document/tests/glossaries-valid.csv",
9495
"sdk/storage/azure-storage-blob/**",
9596
"sdk/storage/azure-storage-extensions/**",

sdk/translation/azure-ai-translation-document/MANIFEST.in

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ recursive-include tests *.py
55
recursive-include samples *.py *.md
66
include azure/__init__.py
77
include azure/ai/__init__.py
8-
include azure/ai/translation/__init__.py
8+
include azure/ai/translation/__init__.py
9+
recursive-include doc *.rst

sdk/translation/azure-ai-translation-document/azure/ai/translation/document/_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from copy import deepcopy
1010
from typing import Any, TYPE_CHECKING, Union
11+
from typing_extensions import Self
1112

1213
from azure.core import PipelineClient
1314
from azure.core.credentials import AzureKeyCredential
@@ -97,7 +98,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
9798
def close(self) -> None:
9899
self._client.close()
99100

100-
def __enter__(self) -> "DocumentTranslationClient":
101+
def __enter__(self) -> Self:
101102
self._client.__enter__()
102103
return self
103104

@@ -177,7 +178,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
177178
def close(self) -> None:
178179
self._client.close()
179180

180-
def __enter__(self) -> "SingleDocumentTranslationClient":
181+
def __enter__(self) -> Self:
181182
self._client.__enter__()
182183
return self
183184

sdk/translation/azure-ai-translation-document/azure/ai/translation/document/_model_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,7 @@ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.
563563
"""
564564

565565
result = {}
566+
readonly_props = []
566567
if exclude_readonly:
567568
readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
568569
for k, v in self.items():
@@ -883,5 +884,6 @@ def rest_discriminator(
883884
*,
884885
name: typing.Optional[str] = None,
885886
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
887+
visibility: typing.Optional[typing.List[str]] = None,
886888
) -> typing.Any:
887-
return _RestField(name=name, type=type, is_discriminator=True)
889+
return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)

sdk/translation/azure-ai-translation-document/azure/ai/translation/document/_operations/_operations.py

Lines changed: 231 additions & 330 deletions
Large diffs are not rendered by default.

sdk/translation/azure-ai-translation-document/azure/ai/translation/document/_operations/_patch.py

Lines changed: 239 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,41 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
# pylint: disable=too-many-lines
5+
# pylint: disable=too-many-lines,protected-access
66
"""Customize generated code here.
77
88
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
99
"""
1010

1111
import sys
12-
from typing import Any, IO, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union, cast, overload
13-
from azure.core.polling import NoPolling, PollingMethod
14-
from azure.core.polling.base_polling import LROBasePolling
12+
from typing import Any, IO, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union, cast, overload, Tuple
1513
from azure.core.tracing.decorator import distributed_trace
1614
from azure.core.utils import case_insensitive_dict
17-
from azure.core.rest import HttpRequest, HttpResponse
15+
from azure.core.rest import HttpRequest, HttpResponse, AsyncHttpResponse
1816
from azure.core.pipeline import PipelineResponse
1917
from azure.core.exceptions import (
2018
ClientAuthenticationError,
2119
HttpResponseError,
2220
ResourceExistsError,
2321
ResourceNotFoundError,
2422
ResourceNotModifiedError,
23+
ODataV4Format,
2524
map_error,
2625
)
27-
from ..models import _models
28-
from .. import _model_base
26+
from azure.core.polling import LROPoller, NoPolling, PollingMethod
27+
from azure.core.polling.base_polling import (
28+
LROBasePolling,
29+
OperationResourcePolling,
30+
_is_empty,
31+
_as_json,
32+
BadResponse,
33+
OperationFailed,
34+
_raise_if_bad_http_status_and_method,
35+
)
36+
from .. import _model_base, models as _models
37+
from ..models import (
38+
TranslationStatus,
39+
)
2940
from .._model_base import _deserialize
3041
from ._operations import (
3142
DocumentTranslationClientOperationsMixin as GeneratedDocumentTranslationClientOperationsMixin,
@@ -34,7 +45,7 @@
3445
ClsType,
3546
build_single_document_translation_document_translate_request,
3647
)
37-
from .._patch import DocumentTranslationLROPoller
48+
3849
from .._vendor import prepare_multipart_form_data
3950

4051
if sys.version_info >= (3, 9):
@@ -46,10 +57,228 @@
4657
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] # type: ignore
4758

4859

60+
ResponseType = Union[HttpResponse, AsyncHttpResponse]
61+
PipelineResponseType = PipelineResponse[HttpRequest, ResponseType]
62+
PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True)
63+
64+
_FINISHED = frozenset(["succeeded", "cancelled", "cancelling", "failed"])
65+
_FAILED = frozenset(["validationfailed"])
66+
67+
68+
def convert_status(status, ll=False):
69+
if ll is False:
70+
if status == "Cancelled":
71+
return "Canceled"
72+
if status == "Cancelling":
73+
return "Canceling"
74+
elif ll is True:
75+
if status == "Canceled":
76+
return "Cancelled"
77+
if status == "Canceling":
78+
return "Cancelling"
79+
return status
80+
81+
82+
class DocumentTranslationLROPoller(LROPoller[PollingReturnType_co]):
83+
"""A custom poller implementation for Document Translation. Call `result()` on the poller to return
84+
a pageable of :class:`~azure.ai.translation.document.DocumentStatus`."""
85+
86+
_polling_method: "DocumentTranslationLROPollingMethod"
87+
88+
@property
89+
def id(self) -> str:
90+
"""The ID for the translation operation
91+
92+
:return: The str ID for the translation operation.
93+
:rtype: str
94+
"""
95+
if self._polling_method._current_body:
96+
return self._polling_method._current_body.id
97+
return self._polling_method._get_id_from_headers()
98+
99+
@property
100+
def details(self) -> TranslationStatus:
101+
"""The details for the translation operation
102+
103+
:return: The details for the translation operation.
104+
:rtype: ~azure.ai.translation.document.TranslationStatus
105+
"""
106+
if self._polling_method._current_body:
107+
return TranslationStatus(self._polling_method._current_body)
108+
return TranslationStatus(id=self._polling_method._get_id_from_headers()) # type: ignore
109+
110+
@classmethod
111+
def from_continuation_token( # pylint: disable=docstring-missing-return,docstring-missing-param,docstring-missing-rtype
112+
cls, polling_method, continuation_token, **kwargs: Any
113+
):
114+
"""
115+
:meta private:
116+
"""
117+
(
118+
client,
119+
initial_response,
120+
deserialization_callback,
121+
) = polling_method.from_continuation_token(continuation_token, **kwargs)
122+
123+
return cls(client, initial_response, deserialization_callback, polling_method)
124+
125+
126+
class DocumentTranslationLROPollingMethod(LROBasePolling):
127+
"""A custom polling method implementation for Document Translation."""
128+
129+
def __init__(self, *args, **kwargs):
130+
self._cont_token_response = kwargs.pop("cont_token_response")
131+
super().__init__(*args, **kwargs)
132+
133+
@property
134+
def _current_body(self) -> TranslationStatus:
135+
try:
136+
return TranslationStatus(self._pipeline_response.http_response.json())
137+
except Exception: # pylint: disable=broad-exception-caught
138+
return TranslationStatus() # type: ignore[call-overload]
139+
140+
def _get_id_from_headers(self) -> str:
141+
return (
142+
self._initial_response.http_response.headers["Operation-Location"]
143+
.split("/batches/")[1]
144+
.split("?api-version")[0]
145+
)
146+
147+
def finished(self) -> bool:
148+
"""Is this polling finished?
149+
150+
:return: True/False for whether polling is complete.
151+
:rtype: bool
152+
"""
153+
return self._finished(self.status())
154+
155+
@staticmethod
156+
def _finished(status) -> bool:
157+
if hasattr(status, "value"):
158+
status = status.value
159+
return str(status).lower() in _FINISHED
160+
161+
@staticmethod
162+
def _failed(status) -> bool:
163+
if hasattr(status, "value"):
164+
status = status.value
165+
return str(status).lower() in _FAILED
166+
167+
def get_continuation_token(self) -> str:
168+
if self._current_body:
169+
return self._current_body.id
170+
return self._get_id_from_headers()
171+
172+
# pylint: disable=arguments-differ
173+
def from_continuation_token(self, continuation_token: str, **kwargs: Any) -> Tuple: # type: ignore[override]
174+
try:
175+
client = kwargs["client"]
176+
except KeyError as exc:
177+
raise ValueError("Need kwarg 'client' to be recreated from continuation_token") from exc
178+
179+
try:
180+
deserialization_callback = kwargs["deserialization_callback"]
181+
except KeyError as exc:
182+
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from exc
183+
184+
return client, self._cont_token_response, deserialization_callback
185+
186+
def _poll(self) -> None:
187+
"""Poll status of operation so long as operation is incomplete and
188+
we have an endpoint to query.
189+
190+
:raises: OperationFailed if operation status 'Failed' or 'Canceled'.
191+
:raises: BadStatus if response status invalid.
192+
:raises: BadResponse if response invalid.
193+
"""
194+
195+
if not self.finished():
196+
self.update_status()
197+
while not self.finished():
198+
self._delay()
199+
self.update_status()
200+
201+
if self._failed(self.status()):
202+
raise OperationFailed("Operation failed or canceled")
203+
204+
final_get_url = self._operation.get_final_get_url(self._pipeline_response)
205+
if final_get_url:
206+
self._pipeline_response = self.request_status(final_get_url)
207+
_raise_if_bad_http_status_and_method(self._pipeline_response.http_response)
208+
209+
210+
class TranslationPolling(OperationResourcePolling):
211+
"""Implements a Location polling."""
212+
213+
def can_poll(self, pipeline_response: PipelineResponseType) -> bool:
214+
"""Answer if this polling method could be used.
215+
216+
:param pipeline_response: The PipelineResponse type
217+
:type pipeline_response: PipelineResponseType
218+
:return: Whether polling should be performed.
219+
:rtype: bool
220+
"""
221+
response = pipeline_response.http_response
222+
can_poll = self._operation_location_header in response.headers
223+
if can_poll:
224+
return True
225+
226+
if not _is_empty(response):
227+
body = _as_json(response)
228+
status = body.get("status")
229+
if status:
230+
return True
231+
return False
232+
233+
def _set_async_url_if_present(self, response: ResponseType) -> None:
234+
location_header = response.headers.get(self._operation_location_header)
235+
if location_header:
236+
self._async_url = location_header
237+
else:
238+
self._async_url = response.request.url
239+
240+
def get_status(self, pipeline_response: PipelineResponseType) -> str:
241+
"""Process the latest status update retrieved from a 'location' header.
242+
243+
:param azure.core.pipeline.PipelineResponse pipeline_response: latest REST call response.
244+
:return: The current operation status
245+
:rtype: str
246+
:raises: BadResponse if response has no body and not status 202.
247+
"""
248+
response = pipeline_response.http_response
249+
if not _is_empty(response):
250+
body = _as_json(response)
251+
status = body.get("status")
252+
if status:
253+
return self._map_nonstandard_statuses(status, body)
254+
raise BadResponse("No status found in body")
255+
raise BadResponse("The response from long running operation does not contain a body.")
256+
257+
def _map_nonstandard_statuses(self, status: str, body: Dict[str, Any]) -> str:
258+
"""Map non-standard statuses.
259+
260+
:param str status: lro process status.
261+
:param str body: pipeline response body.
262+
:return: The current operation status.
263+
:rtype: str
264+
"""
265+
if status == "ValidationFailed":
266+
self.raise_error(body)
267+
return status
268+
269+
def raise_error(self, body: Dict[str, Any]) -> None:
270+
error = body["error"]
271+
if body["error"].get("innerError", None):
272+
error = body["error"]["innerError"]
273+
http_response_error = HttpResponseError(message="({}): {}".format(error["code"], error["message"]))
274+
http_response_error.error = ODataV4Format(error) # set error.code
275+
raise http_response_error
276+
277+
49278
class DocumentTranslationClientOperationsMixin(GeneratedDocumentTranslationClientOperationsMixin):
50279

51280
@distributed_trace
52-
def begin_start_translation( # type: ignore[override]
281+
def _begin_start_translation( # type: ignore[override]
53282
self, body: Union[_models.StartTranslationDetails, JSON, IO[bytes]], **kwargs: Any
54283
) -> DocumentTranslationLROPoller[_models.TranslationStatus]:
55284
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
@@ -61,7 +290,7 @@ def begin_start_translation( # type: ignore[override]
61290
lro_delay = kwargs.pop("polling_interval", self._config.polling_interval)
62291
cont_token: Optional[str] = kwargs.pop("continuation_token", None)
63292
if cont_token is None:
64-
raw_result = self._start_translation_initial( # type: ignore[func-returns-value]
293+
raw_result = self.__begin_start_translation_initial( # type: ignore[func-returns-value]
65294
body=body, content_type=content_type, cls=lambda x, y, z: x, headers=_headers, params=_params, **kwargs
66295
)
67296
kwargs.pop("error_map", None)

0 commit comments

Comments
 (0)