Skip to content

Commit 010bc93

Browse files
[formrecognizer] reduce time for recorded tests runs (#11970)
* wip * wip * add transport wrapper for get clients * fix mypy * fix * refactor * pass through kwargs * add polling interval tests * feedback
1 parent 4ee3368 commit 010bc93

29 files changed

+1551
-496
lines changed

sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_recognizer_client.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,13 @@ def __init__(self, endpoint, credential, **kwargs):
6666
# type: (str, Union[AzureKeyCredential, TokenCredential], Any) -> None
6767

6868
authentication_policy = get_authentication_policy(credential)
69+
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
6970
self._client = FormRecognizer(
7071
endpoint=endpoint,
7172
credential=credential, # type: ignore
7273
sdk_moniker=USER_AGENT,
7374
authentication_policy=authentication_policy,
75+
polling_interval=polling_interval,
7476
**kwargs
7577
)
7678

@@ -111,7 +113,7 @@ def begin_recognize_receipts(self, receipt, **kwargs):
111113
:caption: Recognize US sales receipt fields.
112114
"""
113115

114-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
116+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
115117
continuation_token = kwargs.pop("continuation_token", None)
116118
content_type = kwargs.pop("content_type", None)
117119
if content_type == "application/json":
@@ -162,7 +164,7 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs):
162164
:caption: Recognize US sales receipt fields from a URL.
163165
"""
164166

165-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
167+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
166168
continuation_token = kwargs.pop("continuation_token", None)
167169
include_text_content = kwargs.pop("include_text_content", False)
168170

@@ -210,7 +212,7 @@ def begin_recognize_content(self, form, **kwargs):
210212
:caption: Recognize text and content/layout information from a form.
211213
"""
212214

213-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
215+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
214216
continuation_token = kwargs.pop("continuation_token", None)
215217
content_type = kwargs.pop("content_type", None)
216218
if content_type == "application/json":
@@ -246,7 +248,7 @@ def begin_recognize_content_from_url(self, form_url, **kwargs):
246248
:raises ~azure.core.exceptions.HttpResponseError:
247249
"""
248250

249-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
251+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
250252
continuation_token = kwargs.pop("continuation_token", None)
251253

252254
return self._client.begin_analyze_layout_async(
@@ -296,7 +298,7 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs):
296298
raise ValueError("model_id cannot be None or empty.")
297299

298300
cls = kwargs.pop("cls", None)
299-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
301+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
300302
continuation_token = kwargs.pop("continuation_token", None)
301303
content_type = kwargs.pop("content_type", None)
302304
if content_type == "application/json":
@@ -348,7 +350,7 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs):
348350
raise ValueError("model_id cannot be None or empty.")
349351

350352
cls = kwargs.pop("cls", None)
351-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
353+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
352354
continuation_token = kwargs.pop("continuation_token", None)
353355
include_text_content = kwargs.pop("include_text_content", False)
354356

sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_training_client.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from azure.core.tracing.decorator import distributed_trace
1818
from azure.core.polling import LROPoller
1919
from azure.core.polling.base_polling import LROBasePolling
20+
from azure.core.pipeline import Pipeline
2021
from ._generated._form_recognizer_client import FormRecognizerClient as FormRecognizer
2122
from ._generated.models import (
2223
TrainRequest,
@@ -26,7 +27,7 @@
2627
CopyOperationResult,
2728
CopyAuthorizationResult
2829
)
29-
from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL
30+
from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL, TransportWrapper
3031
from ._models import (
3132
CustomFormModelInfo,
3233
AccountProperties,
@@ -78,11 +79,13 @@ def __init__(self, endpoint, credential, **kwargs):
7879
self._endpoint = endpoint
7980
self._credential = credential
8081
authentication_policy = get_authentication_policy(credential)
82+
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
8183
self._client = FormRecognizer(
8284
endpoint=self._endpoint,
8385
credential=self._credential, # type: ignore
8486
sdk_moniker=USER_AGENT,
8587
authentication_policy=authentication_policy,
88+
polling_interval=polling_interval,
8689
**kwargs
8790
)
8891

@@ -129,7 +132,7 @@ def callback(raw_response):
129132

130133
cls = kwargs.pop("cls", None)
131134
continuation_token = kwargs.pop("continuation_token", None)
132-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
135+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
133136
deserialization_callback = cls if cls else callback
134137

135138
if continuation_token:
@@ -339,7 +342,7 @@ def begin_copy_model(
339342
if not model_id:
340343
raise ValueError("model_id cannot be None or empty.")
341344

342-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
345+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
343346
continuation_token = kwargs.pop("continuation_token", None)
344347

345348
def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
@@ -371,11 +374,20 @@ def get_form_recognizer_client(self, **kwargs):
371374
:rtype: ~azure.ai.formrecognizer.FormRecognizerClient
372375
:return: A FormRecognizerClient
373376
"""
374-
return FormRecognizerClient(
377+
378+
_pipeline = Pipeline(
379+
transport=TransportWrapper(self._client._client._pipeline._transport),
380+
policies=self._client._client._pipeline._impl_policies
381+
) # type: Pipeline
382+
client = FormRecognizerClient(
375383
endpoint=self._endpoint,
376384
credential=self._credential,
385+
pipeline=_pipeline,
377386
**kwargs
378387
)
388+
# need to share config, but can't pass as a keyword into client
389+
client._client._config = self._client._client._config
390+
return client
379391

380392
def close(self):
381393
# type: () -> None

sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_helpers.py

+25
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import six
88
from azure.core.credentials import AzureKeyCredential
99
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
10+
from azure.core.pipeline.transport import HttpTransport
1011
from azure.core.exceptions import (
1112
ResourceNotFoundError,
1213
ResourceExistsError,
@@ -24,6 +25,30 @@
2425
}
2526

2627

28+
class TransportWrapper(HttpTransport):
29+
"""Wrapper class that ensures that an inner client created
30+
by a `get_client` method does not close the outer transport for the parent
31+
when used in a context manager.
32+
"""
33+
def __init__(self, transport):
34+
self._transport = transport
35+
36+
def send(self, request, **kwargs):
37+
return self._transport.send(request, **kwargs)
38+
39+
def open(self):
40+
pass
41+
42+
def close(self):
43+
pass
44+
45+
def __enter__(self):
46+
pass
47+
48+
def __exit__(self, *args): # pylint: disable=arguments-differ
49+
pass
50+
51+
2752
def get_authentication_policy(credential):
2853
authentication_policy = None
2954
if credential is None:

sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_recognizer_client_async.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,13 @@ def __init__(
7171
) -> None:
7272

7373
authentication_policy = get_authentication_policy(credential)
74+
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
7475
self._client = FormRecognizer(
7576
endpoint=endpoint,
7677
credential=credential, # type: ignore
7778
sdk_moniker=USER_AGENT,
7879
authentication_policy=authentication_policy,
80+
polling_interval=polling_interval,
7981
**kwargs
8082
)
8183

@@ -119,7 +121,7 @@ async def begin_recognize_receipts(
119121
:caption: Recognize US sales receipt fields.
120122
"""
121123

122-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
124+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
123125
continuation_token = kwargs.pop("continuation_token", None)
124126
content_type = kwargs.pop("content_type", None)
125127
if content_type == "application/json":
@@ -176,7 +178,7 @@ async def begin_recognize_receipts_from_url(
176178
:caption: Recognize US sales receipt fields from a URL.
177179
"""
178180

179-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
181+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
180182
continuation_token = kwargs.pop("continuation_token", None)
181183
include_text_content = kwargs.pop("include_text_content", False)
182184

@@ -230,7 +232,7 @@ async def begin_recognize_content(
230232
:caption: Recognize text and content/layout information from a form.
231233
"""
232234

233-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
235+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
234236
continuation_token = kwargs.pop("continuation_token", None)
235237
content_type = kwargs.pop("content_type", None)
236238
if content_type == "application/json":
@@ -268,7 +270,7 @@ async def begin_recognize_content_from_url(self, form_url: str, **kwargs: Any) -
268270
:raises ~azure.core.exceptions.HttpResponseError:
269271
"""
270272

271-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
273+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
272274
continuation_token = kwargs.pop("continuation_token", None)
273275
return await self._client.begin_analyze_layout_async( # type: ignore
274276
file_stream={"source": form_url},
@@ -324,7 +326,7 @@ async def begin_recognize_custom_forms(
324326
raise ValueError("model_id cannot be None or empty.")
325327

326328
cls = kwargs.pop("cls", None)
327-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
329+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
328330
continuation_token = kwargs.pop("continuation_token", None)
329331
content_type = kwargs.pop("content_type", None)
330332
if content_type == "application/json":
@@ -385,7 +387,7 @@ async def begin_recognize_custom_forms_from_url(
385387
raise ValueError("model_id cannot be None or empty.")
386388

387389
cls = kwargs.pop("cls", None)
388-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
390+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
389391
continuation_token = kwargs.pop("continuation_token", None)
390392
include_text_content = kwargs.pop("include_text_content", False)
391393

sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_training_client_async.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
TYPE_CHECKING,
1616
)
1717
from azure.core.polling import AsyncLROPoller
18+
from azure.core.pipeline import AsyncPipeline
1819
from azure.core.polling.async_base_polling import AsyncLROBasePolling
1920
from azure.core.tracing.decorator import distributed_trace
2021
from azure.core.tracing.decorator_async import distributed_trace_async
2122
from ._form_recognizer_client_async import FormRecognizerClient
23+
from ._helpers_async import AsyncTransportWrapper
2224
from .._generated.aio._form_recognizer_client_async import FormRecognizerClient as FormRecognizer
2325
from .._generated.models import (
2426
TrainRequest,
@@ -81,13 +83,14 @@ def __init__(
8183
) -> None:
8284
self._endpoint = endpoint
8385
self._credential = credential
84-
8586
authentication_policy = get_authentication_policy(credential)
87+
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
8688
self._client = FormRecognizer(
8789
endpoint=self._endpoint,
8890
credential=self._credential, # type: ignore
8991
sdk_moniker=USER_AGENT,
9092
authentication_policy=authentication_policy,
93+
polling_interval=polling_interval,
9194
**kwargs
9295
)
9396

@@ -138,7 +141,7 @@ def callback(raw_response):
138141

139142
cls = kwargs.pop("cls", None)
140143
continuation_token = kwargs.pop("continuation_token", None)
141-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
144+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
142145
deserialization_callback = cls if cls else callback
143146

144147
if continuation_token:
@@ -361,7 +364,7 @@ async def begin_copy_model(
361364
raise ValueError("model_id cannot be None or empty.")
362365

363366
continuation_token = kwargs.pop("continuation_token", None)
364-
polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL)
367+
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
365368

366369
def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
367370
copy_result = self._client._deserialize(CopyOperationResult, raw_response)
@@ -395,11 +398,19 @@ def get_form_recognizer_client(self, **kwargs: Any) -> FormRecognizerClient:
395398
:rtype: ~azure.ai.formrecognizer.aio.FormRecognizerClient
396399
:return: A FormRecognizerClient
397400
"""
398-
return FormRecognizerClient(
401+
_pipeline = AsyncPipeline(
402+
transport=AsyncTransportWrapper(self._client._client._pipeline._transport),
403+
policies=self._client._client._pipeline._impl_policies
404+
) # type: AsyncPipeline
405+
client = FormRecognizerClient(
399406
endpoint=self._endpoint,
400407
credential=self._credential,
408+
pipeline=_pipeline,
401409
**kwargs
402410
)
411+
# need to share config, but can't pass as a keyword into client
412+
client._client._config = self._client._client._config
413+
return client
403414

404415
async def __aenter__(self) -> "FormTrainingClient":
405416
await self._client.__aenter__()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# coding=utf-8
2+
# ------------------------------------
3+
# Copyright (c) Microsoft Corporation.
4+
# Licensed under the MIT License.
5+
# ------------------------------------
6+
7+
from azure.core.pipeline.transport import AsyncHttpTransport
8+
9+
10+
class AsyncTransportWrapper(AsyncHttpTransport):
11+
"""Wrapper class that ensures that an inner client created
12+
by a `get_client` method does not close the outer transport for the parent
13+
when used in a context manager.
14+
"""
15+
def __init__(self, async_transport):
16+
self._transport = async_transport
17+
18+
async def send(self, request, **kwargs):
19+
return await self._transport.send(request, **kwargs)
20+
21+
async def open(self):
22+
pass
23+
24+
async def close(self):
25+
pass
26+
27+
async def __aenter__(self):
28+
pass
29+
30+
async def __aexit__(self, *args): # pylint: disable=arguments-differ
31+
pass

0 commit comments

Comments
 (0)