|
6 | 6 |
|
7 | 7 | # pylint: disable=protected-access
|
8 | 8 |
|
| 9 | +import json |
9 | 10 | from typing import (
|
10 | 11 | Optional,
|
11 | 12 | Any,
|
12 | 13 | Iterable,
|
| 14 | + Dict, |
13 | 15 | Union,
|
14 | 16 | TYPE_CHECKING,
|
15 | 17 | )
|
16 | 18 | from azure.core.tracing.decorator import distributed_trace
|
17 | 19 | from azure.core.polling import LROPoller
|
18 | 20 | from azure.core.polling.base_polling import LROBasePolling
|
19 |
| -from ._generated.models import Model |
20 | 21 | from ._generated._form_recognizer_client import FormRecognizerClient as FormRecognizer
|
21 |
| -from ._generated.models import TrainRequest, TrainSourceFilter |
| 22 | +from ._generated.models import ( |
| 23 | + TrainRequest, |
| 24 | + TrainSourceFilter, |
| 25 | + CopyRequest, |
| 26 | + Model, |
| 27 | + CopyOperationResult, |
| 28 | + CopyAuthorizationResult |
| 29 | +) |
22 | 30 | from ._helpers import error_map, get_authentication_policy, POLLING_INTERVAL
|
23 | 31 | from ._models import (
|
24 | 32 | CustomFormModelInfo,
|
25 | 33 | AccountProperties,
|
26 | 34 | CustomFormModel
|
27 | 35 | )
|
28 |
| -from ._polling import TrainingPolling |
| 36 | +from ._polling import TrainingPolling, CopyPolling |
29 | 37 | from ._user_agent import USER_AGENT
|
30 | 38 | from ._form_recognizer_client import FormRecognizerClient
|
31 | 39 | if TYPE_CHECKING:
|
32 | 40 | from azure.core.credentials import AzureKeyCredential, TokenCredential
|
| 41 | + from azure.core.pipeline import PipelineResponse |
33 | 42 | from azure.core.pipeline.transport import HttpResponse
|
34 | 43 | PipelineResponseType = HttpResponse
|
35 | 44 |
|
@@ -239,6 +248,99 @@ def get_custom_model(self, model_id, **kwargs):
|
239 | 248 | response = self._client.get_custom_model(model_id=model_id, include_keys=True, error_map=error_map, **kwargs)
|
240 | 249 | return CustomFormModel._from_generated(response)
|
241 | 250 |
|
| 251 | + @distributed_trace |
| 252 | + def get_copy_authorization(self, resource_id, resource_region, **kwargs): |
| 253 | + # type: (str, str, Any) -> Dict[str, Union[str, int]] |
| 254 | + """Generate authorization for copying a custom model into the target Form Recognizer resource. |
| 255 | + This should be called by the target resource (where the model will be copied to) |
| 256 | + and the output can be passed as the `target` parameter into :func:`~begin_copy_model()`. |
| 257 | +
|
| 258 | + :param str resource_id: Azure Resource Id of the target Form Recognizer resource |
| 259 | + where the model will be copied to. |
| 260 | + :param str resource_region: Location of the target Form Recognizer resource. A valid Azure |
| 261 | + region name supported by Cognitive Services. |
| 262 | + :return: A dictionary with values for the copy authorization - |
| 263 | + "modelId", "accessToken", "resourceId", "resourceRegion", and "expirationDateTimeTicks". |
| 264 | + :rtype: Dict[str, Union[str, int]] |
| 265 | + :raises ~azure.core.exceptions.HttpResponseError: |
| 266 | +
|
| 267 | + .. admonition:: Example: |
| 268 | +
|
| 269 | + .. literalinclude:: ../samples/sample_copy_model.py |
| 270 | + :start-after: [START get_copy_authorization] |
| 271 | + :end-before: [END get_copy_authorization] |
| 272 | + :language: python |
| 273 | + :dedent: 8 |
| 274 | + :caption: Authorize the target resource to receive the copied model |
| 275 | + """ |
| 276 | + |
| 277 | + response = self._client.generate_model_copy_authorization( # type: ignore |
| 278 | + cls=lambda pipeline_response, deserialized, response_headers: pipeline_response, |
| 279 | + error_map=error_map, |
| 280 | + **kwargs |
| 281 | + ) # type: PipelineResponse |
| 282 | + target = json.loads(response.http_response.text()) |
| 283 | + target["resourceId"] = resource_id |
| 284 | + target["resourceRegion"] = resource_region |
| 285 | + return target |
| 286 | + |
| 287 | + @distributed_trace |
| 288 | + def begin_copy_model( |
| 289 | + self, |
| 290 | + model_id, # type: str |
| 291 | + target, # type: Dict |
| 292 | + **kwargs # type: Any |
| 293 | + ): |
| 294 | + # type: (...) -> LROPoller |
| 295 | + """Copy a custom model stored in this resource (the source) to the user specified |
| 296 | + target Form Recognizer resource. This should be called with the source Form Recognizer resource |
| 297 | + (with the model that is intended to be copied). The `target` parameter should be supplied from the |
| 298 | + target resource's output from calling the :func:`~get_copy_authorization()` method. |
| 299 | +
|
| 300 | + :param str model_id: Model identifier of the model to copy to target resource. |
| 301 | + :param dict target: |
| 302 | + The copy authorization generated from the target resource's call to |
| 303 | + :func:`~get_copy_authorization()`. |
| 304 | + :keyword int polling_interval: Default waiting time between two polls for LRO operations if |
| 305 | + no Retry-After header is present. |
| 306 | + :return: An instance of an LROPoller. Call `result()` on the poller |
| 307 | + object to return a :class:`~azure.ai.formrecognizer.CustomFormModelInfo`. |
| 308 | + :rtype: ~azure.core.polling.LROPoller[~azure.ai.formrecognizer.CustomFormModelInfo] |
| 309 | + :raises ~azure.core.exceptions.HttpResponseError: |
| 310 | +
|
| 311 | + .. admonition:: Example: |
| 312 | +
|
| 313 | + .. literalinclude:: ../samples/sample_copy_model.py |
| 314 | + :start-after: [START begin_copy_model] |
| 315 | + :end-before: [END begin_copy_model] |
| 316 | + :language: python |
| 317 | + :dedent: 8 |
| 318 | + :caption: Copy a model from the source resource to the target resource |
| 319 | + """ |
| 320 | + |
| 321 | + polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) |
| 322 | + |
| 323 | + def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument |
| 324 | + copy_result = self._client._deserialize(CopyOperationResult, raw_response) |
| 325 | + return CustomFormModelInfo._from_generated(copy_result, target["modelId"]) |
| 326 | + |
| 327 | + return self._client.begin_copy_custom_model( # type: ignore |
| 328 | + model_id=model_id, |
| 329 | + copy_request=CopyRequest( |
| 330 | + target_resource_id=target["resourceId"], |
| 331 | + target_resource_region=target["resourceRegion"], |
| 332 | + copy_authorization=CopyAuthorizationResult( |
| 333 | + access_token=target["accessToken"], |
| 334 | + model_id=target["modelId"], |
| 335 | + expiration_date_time_ticks=target["expirationDateTimeTicks"] |
| 336 | + ) |
| 337 | + ), |
| 338 | + cls=kwargs.pop("cls", _copy_callback), |
| 339 | + polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[CopyPolling()], **kwargs), |
| 340 | + error_map=error_map, |
| 341 | + **kwargs |
| 342 | + ) |
| 343 | + |
242 | 344 | def get_form_recognizer_client(self, **kwargs):
|
243 | 345 | # type: (Any) -> FormRecognizerClient
|
244 | 346 | """Get an instance of a FormRecognizerClient from FormTrainingClient.
|
|
0 commit comments