
Commit 20eee41

add support for heterogeneous list of good and bad responses
1 parent c1f7bdc commit 20eee41

7 files changed (+694 -517 lines changed)
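
Before this change, every item produced for an action was an AnalyzeBatchActionsResult; with this commit, action-level failures surface as AnalyzeBatchActionsError objects interleaved in the same list. A minimal caller-side sketch, not part of the diff: it assumes the existing begin_analyze_batch_actions entry point and action inputs, with placeholder endpoint, key, and documents.

# Hypothetical caller-side sketch; endpoint, key, and documents are placeholders.
from azure.core.credentials import AzureKeyCredential
from azure.ai.textanalytics import (
    TextAnalyticsClient,
    RecognizeEntitiesAction,
    ExtractKeyPhrasesAction,
    AnalyzeBatchActionsError,
)

client = TextAnalyticsClient("<endpoint>", AzureKeyCredential("<api-key>"))
poller = client.begin_analyze_batch_actions(
    ["Microsoft was founded by Bill Gates."],
    actions=[RecognizeEntitiesAction(), ExtractKeyPhrasesAction()],
)

# After this commit the paged result can interleave good and bad action results,
# one item per requested action, in the order the actions were submitted.
for action_result in poller.result():
    if isinstance(action_result, AnalyzeBatchActionsError):
        print("action failed:", action_result.error.code, action_result.error.message)
    else:
        print("action succeeded:", action_result.action_type, action_result.document_results)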

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py

+2

@@ -42,6 +42,7 @@
     AnalyzeBatchActionsResult,
     RequestStatistics,
     AnalyzeBatchActionsType,
+    AnalyzeBatchActionsError,
 )
 from ._paging import AnalyzeHealthcareResult

@@ -83,6 +84,7 @@
     'AnalyzeBatchActionsResult',
     'RequestStatistics',
     'AnalyzeBatchActionsType',
+    "AnalyzeBatchActionsError",
 ]

 __version__ = VERSION
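
Since the class is now exported from the package root, a quick hypothetical smoke test (not part of the commit):

# Hypothetical smoke test for the new export; not part of the commit.
from azure.ai.textanalytics import AnalyzeBatchActionsError

err = AnalyzeBatchActionsError()
assert err.is_error    # always True, so callers can branch on it while iterating
print(repr(err))       # AnalyzeBatchActionsError(error=None, is_error=True)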

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_models.py

+26

@@ -1171,6 +1171,7 @@ class AnalyzeBatchActionsType(str, Enum):
     RECOGNIZE_PII_ENTITIES = "recognize_pii_entities"  #: PII Entities Recognition action.
     EXTRACT_KEY_PHRASES = "extract_key_phrases"  #: Key Phrase Extraction action.

+
 class AnalyzeBatchActionsResult(DictMixin):
     """AnalyzeBatchActionsResult contains the results of a recognize entities action
     on a list of documents. Returned by `begin_analyze_batch_actions`

@@ -1200,6 +1201,31 @@ def __repr__(self):
             self.completed_on
         )[:1024]

+class AnalyzeBatchActionsError(DictMixin):
+    """AnalyzeBatchActionsError is an error object which represents an
+    error response for an action.
+
+    :ivar error: The action result error.
+    :vartype error: ~azure.ai.textanalytics.TextAnalyticsError
+    :ivar bool is_error: Boolean check for error item when iterating over list of
+        results. Always True for an instance of an AnalyzeBatchActionsError.
+    """
+
+    def __init__(self, **kwargs):
+        self.error = kwargs.get("error")
+        self.is_error = True
+
+    def __repr__(self):
+        return "AnalyzeBatchActionsError(error={}, is_error={})".format(
+            self.error, self.is_error
+        )
+
+    @classmethod
+    def _from_generated(cls, error):
+        return cls(
+            error=TextAnalyticsError(code=error.code, message=error.message, target=error.target)
+        )
+

 class RecognizeEntitiesAction(DictMixin):
     """RecognizeEntitiesAction encapsulates the parameters for starting a long-running Entities Recognition operation.

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py

+53 -12

@@ -6,6 +6,7 @@

 import json
 import functools
+from collections import defaultdict
 from six.moves.urllib.parse import urlparse, parse_qsl
 from azure.core.exceptions import (
     HttpResponseError,

@@ -33,7 +34,9 @@
     AnalyzeBatchActionsResult,
     RequestStatistics,
     AnalyzeBatchActionsType,
+    AnalyzeBatchActionsError,
     TextDocumentBatchStatistics,
+    _get_indices,
 )
 from ._paging import AnalyzeHealthcareResult, AnalyzeResult

@@ -217,23 +220,61 @@ def _num_tasks_in_current_page(returned_tasks_object):
         len(returned_tasks_object.key_phrase_extraction_tasks or [])
     )

+def _get_task_type_from_error(error):
+    if "pii" in error.target.lower():
+        return AnalyzeBatchActionsType.RECOGNIZE_PII_ENTITIES
+    if "entity" in error.target.lower():
+        return AnalyzeBatchActionsType.RECOGNIZE_ENTITIES
+    return AnalyzeBatchActionsType.EXTRACT_KEY_PHRASES
+
+def _get_mapped_errors(analyze_job_state):
+    """Map each task type to its (task index, error) pairs from the job state's errors.
+    """
+    mapped_errors = defaultdict(list)
+    if not analyze_job_state.errors:
+        return mapped_errors
+    for error in analyze_job_state.errors:
+        mapped_errors[_get_task_type_from_error(error)].append((_get_error_index(error), error))
+    return mapped_errors
+
+def _get_error_index(error):
+    return _get_indices(error.target)[-1]
+
+def _get_good_result(current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object):
+    deserialization_callback = _get_deserialization_callback_from_task_type(current_task_type)
+    property_name = _get_property_name_from_task_type(current_task_type)
+    response_task_to_deserialize = getattr(returned_tasks_object, property_name)[index_of_task_result]
+    document_results = deserialization_callback(
+        doc_id_order, response_task_to_deserialize.results, response_headers, lro=True
+    )
+    return AnalyzeBatchActionsResult(
+        document_results=document_results,
+        action_type=current_task_type,
+        completed_on=response_task_to_deserialize.last_update_date_time,
+    )
+
 def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
     iter_items = []
+    task_type_to_index = defaultdict(int)  # need to keep track of how many of each type of task we've seen
     returned_tasks_object = analyze_job_state.tasks
+    mapped_errors = _get_mapped_errors(analyze_job_state)
     for current_task_type in task_order:
-        deserialization_callback = _get_deserialization_callback_from_task_type(current_task_type)
-        property_name = _get_property_name_from_task_type(current_task_type)
-        response_task_to_deserialize = getattr(returned_tasks_object, property_name).pop(0)
-        document_results = deserialization_callback(
-            doc_id_order, response_task_to_deserialize.results, response_headers, lro=True
-        )
-        iter_items.append(
-            AnalyzeBatchActionsResult(
-                document_results=document_results,
-                action_type=current_task_type,
-                completed_on=response_task_to_deserialize.last_update_date_time,
+        index_of_task_result = task_type_to_index[current_task_type]
+
+        try:
+            # Try to deserialize as an error. If that fails, we know it's a good response.
+            # Kind of a weird way to order things, but we can only fail when deserializing
+            # the current response as an error, not when deserializing it as a good response.
+
+            current_task_type_errors = mapped_errors[current_task_type]
+            error = next(err for err in current_task_type_errors if err[0] == index_of_task_result)
+            result = AnalyzeBatchActionsError._from_generated(error[1])  # pylint: disable=protected-access
+        except StopIteration:
+            result = _get_good_result(
+                current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object
             )
-        )
+        iter_items.append(result)
+        task_type_to_index[current_task_type] += 1
     return iter_items

 def analyze_extract_page_data(doc_id_order, task_order, response_headers, analyze_job_state):
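
The pairing logic in get_iter_items is the heart of the change: job-level errors carry a target path (something like "#/tasks/entityRecognitionTasks/1"; format illustrative here), so the handler buckets errors by task type, then walks the requested action order and, for each position, emits either an AnalyzeBatchActionsError or the deserialized good result. A stripped-down, SDK-free sketch of that same bookkeeping; all names and the sample data are stand-ins, not library API.

# Self-contained sketch of the pairing logic; names and target format are illustrative.
import re
from collections import defaultdict, namedtuple

def get_indices(relation):
    # Mirrors the role of _get_indices: pull every integer out of the error target path.
    return [int(s) for s in re.findall(r"\d+", relation)]

def task_type_from_target(target):
    # Same heuristic as _get_task_type_from_error, with plain strings for task types.
    target = target.lower()
    if "pii" in target:
        return "recognize_pii_entities"
    if "entity" in target:
        return "recognize_entities"
    return "extract_key_phrases"

def interleave(task_order, good_results_by_type, errors):
    """Walk the requested action order and emit ('error', err) or ('ok', result) per position."""
    # Bucket errors by task type, keyed by the index parsed from their target.
    mapped_errors = defaultdict(dict)
    for err in errors:
        mapped_errors[task_type_from_target(err.target)][get_indices(err.target)[-1]] = err

    seen = defaultdict(int)  # how many actions of each type we've already emitted
    items = []
    for task_type in task_order:
        idx = seen[task_type]
        if idx in mapped_errors[task_type]:
            items.append(("error", mapped_errors[task_type][idx]))
        else:
            items.append(("ok", good_results_by_type[task_type][idx]))
        seen[task_type] += 1
    return items

# Tiny illustrative run: the second entity-recognition action failed.
# The None placeholder keeps indices aligned, mirroring the commit's index-by-position lookup.
FakeError = namedtuple("FakeError", ["target"])
print(interleave(
    task_order=["recognize_entities", "extract_key_phrases", "recognize_entities"],
    good_results_by_type={"recognize_entities": ["ents-0", None], "extract_key_phrases": ["kp-0"]},
    errors=[FakeError(target="#/tasks/entityRecognitionTasks/1")],
))
# -> [('ok', 'ents-0'), ('ok', 'kp-0'), ('error', FakeError(target='#/tasks/entityRecognitionTasks/1'))]

The try/except ordering in the commit follows its inline comment: a failed error lookup (StopIteration) is what signals that the item at this position is a good response.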
