|
6 | 6 |
|
7 | 7 | import json
|
8 | 8 | import functools
|
| 9 | +from collections import defaultdict |
9 | 10 | from six.moves.urllib.parse import urlparse, parse_qsl
|
10 | 11 | from azure.core.exceptions import (
|
11 | 12 | HttpResponseError,
|
|
33 | 34 | AnalyzeBatchActionsResult,
|
34 | 35 | RequestStatistics,
|
35 | 36 | AnalyzeBatchActionsType,
|
| 37 | + AnalyzeBatchActionsError, |
36 | 38 | TextDocumentBatchStatistics,
|
| 39 | + _get_indices, |
37 | 40 | )
|
38 | 41 | from ._paging import AnalyzeHealthcareResult, AnalyzeResult
|
39 | 42 |
|
@@ -217,23 +220,61 @@ def _num_tasks_in_current_page(returned_tasks_object):
|
217 | 220 | len(returned_tasks_object.key_phrase_extraction_tasks or [])
|
218 | 221 | )
|
219 | 222 |
|
def _get_task_type_from_error(error):
    """Infer which analyze action a job-level error belongs to.

    The service identifies the failed task only via the error's ``target``
    path string, so we match on substrings of that path. PII is checked
    before entity recognition since a PII target may also mention "entity".

    :param error: a job-level error object with a ``target`` attribute.
    :return: the matching AnalyzeBatchActionsType member; defaults to
        EXTRACT_KEY_PHRASES when neither "pii" nor "entity" appears.
    """
    target = error.target.lower()  # hoisted: avoid lowercasing twice
    if "pii" in target:
        return AnalyzeBatchActionsType.RECOGNIZE_PII_ENTITIES
    if "entity" in target:
        return AnalyzeBatchActionsType.RECOGNIZE_ENTITIES
    return AnalyzeBatchActionsType.EXTRACT_KEY_PHRASES
| 229 | + |
def _get_mapped_errors(analyze_job_state):
    """Group the job-level errors by the action type they occurred in.

    :param analyze_job_state: the raw analyze job state whose ``errors``
        list (possibly None/empty) is inspected.
    :return: defaultdict mapping AnalyzeBatchActionsType to a list of
        ``(task_index, error)`` tuples; empty when no errors were reported.
    """
    mapped_errors = defaultdict(list)
    if not analyze_job_state.errors:
        return mapped_errors
    for error in analyze_job_state.errors:
        task_type = _get_task_type_from_error(error)
        mapped_errors[task_type].append((_get_error_index(error), error))
    return mapped_errors
| 239 | + |
def _get_error_index(error):
    """Return the failed task's position within its task collection.

    The last numeric index embedded in the error's ``target`` path
    identifies which task (of its type) the error refers to.
    """
    indices = _get_indices(error.target)
    return indices[-1]
| 242 | + |
def _get_good_result(current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object):
    """Deserialize one successfully-completed task into an AnalyzeBatchActionsResult.

    Looks up the task at ``index_of_task_result`` within the response
    collection that corresponds to ``current_task_type``, then runs the
    matching deserialization callback over its document results.
    """
    property_name = _get_property_name_from_task_type(current_task_type)
    raw_task = getattr(returned_tasks_object, property_name)[index_of_task_result]
    deserialize = _get_deserialization_callback_from_task_type(current_task_type)
    docs = deserialize(doc_id_order, raw_task.results, response_headers, lro=True)
    return AnalyzeBatchActionsResult(
        document_results=docs,
        action_type=current_task_type,
        completed_on=raw_task.last_update_date_time,
    )
| 255 | + |
def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
    """Build the ordered list of per-action results for one analyze-job page.

    For each action in ``task_order``, emit either an AnalyzeBatchActionsError
    (when the job reported an error for that task) or the deserialized
    AnalyzeBatchActionsResult for the corresponding response task.

    :return: list of results/errors, in ``task_order`` order.
    """
    iter_items = []
    # Track how many tasks of each type we've consumed, so repeated action
    # types map to the correct entry in the response's task collections.
    task_type_to_index = defaultdict(int)
    returned_tasks_object = analyze_job_state.tasks
    mapped_errors = _get_mapped_errors(analyze_job_state)
    for current_task_type in task_order:
        index_of_task_result = task_type_to_index[current_task_type]
        # Explicit lookup with a None default instead of catching
        # StopIteration: nothing here can fail to deserialize, so raising
        # per-item was pure control flow, not error handling.
        error = next(
            (
                err
                for err_index, err in mapped_errors[current_task_type]
                if err_index == index_of_task_result
            ),
            None,
        )
        if error is not None:
            result = AnalyzeBatchActionsError._from_generated(error)  # pylint: disable=protected-access
        else:
            result = _get_good_result(
                current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object
            )
        iter_items.append(result)
        task_type_to_index[current_task_type] += 1
    return iter_items
|
238 | 279 |
|
239 | 280 | def analyze_extract_page_data(doc_id_order, task_order, response_headers, analyze_job_state):
|
|
0 commit comments