Commit a852079

[evaluation] ci: Enable mypy (#37615)
* fix(typing): Resolve mypy violations in azure/ai/evaluation/_http_utils.py
* fix(typing): Resolve uses of implicit Optional in type annotations
* fix(typing): Resolve type reassignment in http_utils.py
* style: Run isort
* fix(typing): Fix attempted type reassignment in _f1_score.py
* fix(typing): Use a TypeGuard to allow mypy to narrow types in _common/utils.py
* fix(typing): Correct return type of get_harm_severity_level
* fix(typing): Correct return type of _compute_f1_score
* fix(typing): Ensure mypy knows that AsyncHttpPipeline.__enter__ returns Self
* fix(typing): Allow mypy to infer the types of the convenience request methods

  _http_utils.py extensively uses decorators to implement the "convenience" request methods (get, post, put, etc.) for {Async}HttpPipeline, since they all share a common underlying implementation. However, neither decorator annotated its return type (the type of the decorated function). Initially this was because the accurate type couldn't be spelled using a `Callable`, and pylance still did a fine job providing intellisense. It turns out that it currently isn't possible to spell it with a callable typing.Protocol either: our decorator applies to a method, and mypy struggles with the removal of the `self` parameter that occurs when a method binds to an object (see python/mypy issue #16200). This commit resolves this by making the implementation of the http pipelines more verbose, removing the decorators and unrolling the implementation into each convenience method. Using `Unpack[TypedDict]` to annotate kwargs makes this substantially more readable, but this causes mypy to complain if unknown keys are passed as kwargs (which is useful for request-specific pipeline configuration); a short sketch of this trade-off follows this list.

* ci: Enable mypy in CI
* fix(typing): Fix extraneous `total=False` on TypedDict
* fix(typing): Propagate model config type hint upwards
* fix(typing): Ensure that `len` is only called on objects that implement Sized in _tracing.py
* fix(typing): Resolve implicit Optional type for Turn
* fix(typing): Resolve missing/inaccurate return types in simulator_data_classes
* fix(typing): Refine the TExperimental type for experimental()
* fix(typing): Ignore the method assign in experimental decorator
* fix(typing): Remove unnecessary Optional for _add_method_docstring
* fix(typing): Mark get_token as sync

  The abstract method `get_token` was marked as async, but both concrete implementations are sync and every use of it in the codebase is in a sync context.
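A minimal sketch of the `Unpack[TypedDict]` trade-off described above. The names here are hypothetical, not the SDK's actual pipeline types; `Unpack` is the PEP 692 mechanism from typing_extensions (stdlib typing on Python 3.12+):

from typing import Any, Dict, TypedDict

from typing_extensions import Unpack


class RequestOptions(TypedDict, total=False):
    headers: Dict[str, str]
    params: Dict[str, str]


class PipelineSketch:
    def request(self, method: str, url: str, **kwargs: Any) -> Any:
        ...  # shared underlying implementation

    # An unrolled convenience method: Unpack gives every kwarg a precise type.
    def get(self, url: str, **kwargs: Unpack[RequestOptions]) -> Any:
        return self.request("GET", url, **kwargs)


pipeline = PipelineSketch()
pipeline.get("https://example.invalid", headers={"x-test": "1"})  # accepted
# ...but mypy rejects keys absent from RequestOptions, which blocks passing
# ad-hoc, request-specific pipeline configuration through **kwargs:
# pipeline.get("https://example.invalid", retry_total=3)  # mypy error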
* fix(typing): Add type hints for APITokenManager attributes
* fix(typing): Prevent type reassignment in APITokenManager
* refactor: Remove unnecessary pass
* fix(typing): Explicitly list accepted kwargs for derived APITokenManager classes
* fix(typing): Mark PlainTokenManager.token as non-Optional str
* fix(typing): Mark *_prompty args as Optional in _simulator.py
* fix: Don't raise bare strings
* fix(typing): Fix return type for _apply_target_to_data
* fix(typing): Use TypedDict as argument to _trace_destination_from_project_scope
* fix(typing): Fix return type of Simulator._complete_conversation
* fix(typing): Correct the param type of _process_column_mappings
* fix(typing): evaluators param Dict[str, Any] -> Dict[str, Callable]
* fix(typing): Add type annotation for processed_config
* fix(typing): Remove unnecessary variable declaration from _evaluate
* fix(typing),refactor: Clarify to mypy that fetch_or_reuse_token always returns str
* fix(typing): Add type annotations for EvalRun attributes
* fix(typing): Use TypedDict for get_rai_svc_url project_scope parameter
* fix(typing): Specify that EvalRun.__enter__ returns Self
* fix(typing): Add type annotation in evaluate_with_rai_service
* fix(typing),refactor: Make EvalRun.info a non-Optional property
* fix(typing): Add a type annotation in log_artifact
* fix(typing): Add missing MLClient import
* fix(typing): Add missing return to EvalRun.__exit__
* fix(typing),refactor: Clarify that _get_evaluator_type always returns str
* fix(typing): Add type annotations in log_evaluate_activity
* fix(typing): QAEvaluator accepts typed dict and returns Dict[str, float]
* fix(typing): Set USER_AGENT to a str when import fails
* fix: Avoid using a dangerous default value

  Using a mutable value as a parameter default is dangerous, since mutations will persist across function calls. See pylint error code `W0102(dangerous-default-value)`.
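A generic illustration of that mutable-default pitfall (toy functions, not the SDK's code):

# The default list is created once, at function definition time, so every
# call that relies on the default mutates the same shared object.
def append_bad(value, items=[]):  # pylint: disable=dangerous-default-value
    items.append(value)
    return items


# The conventional fix: use None as a sentinel and allocate per call.
def append_good(value, items=None):
    if items is None:
        items = []
    items.append(value)
    return items


assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]  # state leaked across calls
assert append_good(1) == [1]
assert append_good(2) == [2]    # each call gets a fresh list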
* fix(typing): Remove unused *args from OpenAIChatCompletionsModel.__init__
* fix(typing): Avoid name redefinition due to repeat import
* fix(typing): Make EvaluationMetrics an enum
* fix(typing): Use TypedDict for AzureAIProject params
* fix(typing): Type credential as azure.core.credentials.TokenCredential
* fix(typing): Clarify that _log_metrics_and_instance_results returns Optional[str]
* fix(typing),refactor: Add a utility function to validate an AzureAIProject dict
* fix(typing): Resolve mismatch between namedtuple type name and variable name
* refactor: Remove unused attribute AdversarialTemplateHandler.cached_templates_source
* fix(typing): Resolve type reassignment in proxy_model_completion
* fix(typing): Add type annotation for ProxyChatCompletionModel.result_url
* fix(typing): Add type annotations to BatchRunContext methods
* fix(typing): Add type annotation for ConversationBot.conversation_starter
* fix(typing): Fix return type of ConversationBot.generate_responses
* fix(typing): Clarify return type of simulate_conversation
* fix(typing): Add type ignore for OpenAICompletionsModel.format_request_data
* fix(typing): Remove unnecessary type annotation in OpenAICompletionsModel.format_request_data
* fix(typing): Clarify that content safety evaluators return Dict[str, Union[str, float]]
* fix(typing): Clarify return type of ContentSafetyChatEvaluator._get_harm_severity_level
* fix(typing): Add type annotations to ContentSafetyChatEvaluator methods
* fix(typing): Add type annotations for ContentSafetyEvaluator
* fix(typing): Use a callable object in AdversarialSimulator
* refactor: Use a set literal for CONTENT_HARM_TEMPLATES_COLLECTION_KEY
* fix(typing): Specify evaluate return type to narrow log_evaluate_activity type
* fix(typing): Add type annotations to adversarial simulator
* fix(typing),refactor: Clarify that _setup_bot's fallthrough branch is unreachable

  _setup_bot does exhaustive matching against all of ConversationRole's enum values.

* fix(typing): Make SimulationRequestDTO.to_dict non-destructive
* fix(typing): Add type annotations to code_client.py
* fix(typing): Correct Simulator.__call__ task parameter to be List[str]
* fix(typing): evaluators Dict[str, Type] -> Dict[str, Callable]
* fix(typing): Make CodeClient.get_metrics always return a dict
* fix(typing): Add type annotations to evaluate/utils.py
* fix(typing): Clarify that CodeRun.get_aggregated_metrics returns Dict[str, Any]
* fix(typing): data is a required parameter for _evaluate
* fix(typing): Add variable annotations in _evaluate
* fix(typing),refactor: Prevent batch_run_client from being Union[ProxyClient, CodeClient]

  Despite the two clients having similar interfaces with compatible calling conventions, the fact that ProxyClient and CodeClient have different "run" types (ProxyRun and CodeRun) causes type errors when dealing with a client of type Union[ProxyClient, CodeClient]: mypy must consider the case where the wrong run type is used for a given client, despite that not being possible in this function. Refactoring the relevant code into a function allows us to clarify to mypy that client and run types are used consistently.
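One way to express the "refactor into a function" fix that last bullet describes, sketched with hypothetical stand-ins for the real classes: a value-constrained TypeVar makes mypy check the helper once per client type, so it never considers a ProxyClient paired with a CodeRun.

from typing import Dict, TypeVar


class ProxyRun: ...


class CodeRun: ...


class ProxyClient:
    def start(self, data: str) -> ProxyRun: ...

    def get_metrics(self, run: ProxyRun) -> Dict[str, float]: ...


class CodeClient:
    def start(self, data: str) -> CodeRun: ...

    def get_metrics(self, run: CodeRun) -> Dict[str, float]: ...


TClient = TypeVar("TClient", ProxyClient, CodeClient)


def start_and_get_metrics(client: TClient, data: str) -> Dict[str, float]:
    # The body is checked once per constraint, so the run produced by
    # `start` always flows back into the same client type that created it.
    run = client.start(data)
    return client.get_metrics(run)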
* fix: Remove unused imports
* fix(pylint): Resolve R1711(useless-return)
* fix(pylint): Resolve W0707(raise-missing-from)
* fix(pylint): Add parameters/returns to http_utils docstrings
* fix(pylint): Make EvaluationMetrics implement CaseInsensitiveEnumMeta
* fix: Remove return type annotations for Evaluators

  Promptflow does reflection on type annotations, and only accepts a dataclass, TypedDict, or string as a return type annotation.

* fix(typing): Add runtime validation of model_config
* fix: Remove type annotations from evaluator/simulator credential params

  Promptflow does reflection on type annotations and only allows dict.

* fix: Remove type annotations from azure_ai_project param

  Promptflow does reflection on param types and disallows TypedDicts.

* fix(typing): {Azure,}OpenAIModelConfiguration.type is NotRequired
* fix(typing): List[Dict] -> list for conversation param
* tests: Fix tests
* fix(typing): Make RaiServiceEvaluatorBase also accept _InternalEvaluationMetrics
* fix(typing): Use typing.final to enforce "never be overridden by children"
* fix(typing): Use abstractmethod to enforce "children must override method"
* fix(typing): Add type annotations to EvaluatorBase
* ci: Add "stringized" to cspell
* fix: Explicitly pass in data to get_evaluators_info

  Resolves a bug where the function was capturing data from the outer scope, but data wasn't changed to the appropriate value until after the function call.
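A generic illustration of that closure-capture bug (hypothetical names): a closure resolves a captured variable when it is called, not when it is defined, so calling it before the variable is rebound silently uses the stale value.

def evaluate_sketch():
    data = "raw_input.jsonl"

    def evaluators_info_captured():
        # Late binding: `data` is looked up at call time.
        return f"evaluating {data}"

    def evaluators_info_explicit(data):
        # The fix: pass the value in explicitly instead of capturing it.
        return f"evaluating {data}"

    stale = evaluators_info_captured()  # runs before `data` is rebound
    data = "validated_input.jsonl"
    fixed = evaluators_info_explicit(data)
    return stale, fixed


assert evaluate_sketch() == ("evaluating raw_input.jsonl", "evaluating validated_input.jsonl")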
1 parent: 5043dbe


53 files changed: +1223 -705 lines

.vscode/cspell.json

Lines changed: 6 additions & 0 deletions
@@ -1890,6 +1890,12 @@
       "deidentify",
       "deidentified"
     ]
+  },
+  {
+    "filename": "sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/utils.py",
+    "words": [
+      "stringized"
+    ]
   }
 ],
 "allowCompoundWords": true

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -27,8 +27,8 @@
 from ._model_configurations import (
     AzureAIProject,
     AzureOpenAIModelConfiguration,
-    OpenAIModelConfiguration,
     EvaluatorConfig,
+    OpenAIModelConfiguration,
 )

 __all__ = [

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/constants.py

Lines changed: 4 additions & 2 deletions
@@ -3,6 +3,8 @@
 # ---------------------------------------------------------
 from enum import Enum

+from azure.core import CaseInsensitiveEnumMeta
+

 class CommonConstants:
     """Define common constants."""
@@ -43,7 +45,7 @@ class _InternalAnnotationTasks:
     ECI = "eci"


-class EvaluationMetrics:
+class EvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
     """Evaluation metrics to aid the RAI service in determining what
     metrics to request, and how to present them back to the user."""

@@ -56,7 +58,7 @@ class EvaluationMetrics:
     XPIA = "xpia"


-class _InternalEvaluationMetrics:
+class _InternalEvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
     """Evaluation metrics that are not publicly supported.
     These metrics are experimental and subject to potential change or migration to the main
     enum over time.
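
For context on the change above: `CaseInsensitiveEnumMeta` is azure.core's standard metaclass for client-library enums, and it makes member lookup by name case-insensitive while the `str` base keeps members comparable to their raw values. A small sketch of the resulting behavior, assuming a member such as HATE_FAIRNESS = "hate_fairness":

from enum import Enum

from azure.core import CaseInsensitiveEnumMeta


class EvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    HATE_FAIRNESS = "hate_fairness"


# Name lookup ignores case, so user-supplied strings need no normalization.
assert EvaluationMetrics["hate_fairness"] is EvaluationMetrics.HATE_FAIRNESS
assert EvaluationMetrics["HATE_FAIRNESS"] is EvaluationMetrics.HATE_FAIRNESS
# As a str subclass, members still compare equal to their raw string values,
# which keeps the enum drop-in compatible with the old plain-class constants.
assert EvaluationMetrics.HATE_FAIRNESS == "hate_fairness"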

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/rai_service.py

Lines changed: 41 additions & 44 deletions
@@ -3,19 +3,20 @@
 # ---------------------------------------------------------
 import asyncio
 import importlib.metadata
+import math
 import re
 import time
-import math
 from ast import literal_eval
-from typing import Dict, List
+from typing import Dict, List, Optional, Union, cast
 from urllib.parse import urlparse

 import jwt

 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
-from azure.ai.evaluation._http_utils import get_async_http_client
+from azure.ai.evaluation._http_utils import AsyncHttpPipeline, get_async_http_client
 from azure.ai.evaluation._model_configurations import AzureAIProject
 from azure.core.credentials import TokenCredential
+from azure.core.pipeline.policies import AsyncRetryPolicy

 from .constants import (
     CommonConstants,
@@ -52,7 +53,13 @@ def get_common_headers(token: str) -> Dict:
     }


-async def ensure_service_availability(rai_svc_url: str, token: str, capability: str = None) -> None:
+def get_async_http_client_with_timeout() -> AsyncHttpPipeline:
+    return get_async_http_client().with_policies(
+        retry_policy=AsyncRetryPolicy(timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT)
+    )
+
+
+async def ensure_service_availability(rai_svc_url: str, token: str, capability: Optional[str] = None) -> None:
     """Check if the Responsible AI service is available in the region and has the required capability, if relevant.

     :param rai_svc_url: The Responsible AI service URL.
@@ -67,9 +74,7 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
     svc_liveness_url = rai_svc_url + "/checkannotation"

     async with get_async_http_client() as client:
-        response = await client.get(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
-            svc_liveness_url, headers=headers, timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT
-        )
+        response = await client.get(svc_liveness_url, headers=headers)

     if response.status_code != 200:
         msg = f"RAI service is not available in this region. Status Code: {response.status_code}"
@@ -153,16 +158,14 @@ async def submit_request(query: str, response: str, metric: str, rai_svc_url: st
     url = rai_svc_url + "/submitannotation"
     headers = get_common_headers(token)

-    async with get_async_http_client() as client:
-        response = await client.post(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
-            url, json=payload, headers=headers, timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT
-        )
+    async with get_async_http_client_with_timeout() as client:
+        http_response = await client.post(url, json=payload, headers=headers)

-    if response.status_code != 202:
-        print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], response.text))
-        response.raise_for_status()
+    if http_response.status_code != 202:
+        print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], http_response.text()))
+        http_response.raise_for_status()

-    result = response.json()
+    result = http_response.json()
     operation_id = result["location"].split("/")[-1]
     return operation_id

@@ -189,10 +192,8 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
     token = await fetch_or_reuse_token(credential, token)
     headers = get_common_headers(token)

-    async with get_async_http_client() as client:
-        response = await client.get(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
-            url, headers=headers, timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT
-        )
+    async with get_async_http_client_with_timeout() as client:
+        response = await client.get(url, headers=headers)

     if response.status_code == 200:
         return response.json()
@@ -208,15 +209,15 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre

 def parse_response(  # pylint: disable=too-many-branches,too-many-statements
     batch_response: List[Dict], metric_name: str
-) -> Dict:
+) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.

     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
     :return: The parsed annotation result.
-    :rtype: List[List[Dict]]
+    :rtype: Dict[str, Union[str, float]]
     """
     # non-numeric metrics
     if metric_name in {EvaluationMetrics.PROTECTED_MATERIAL, _InternalEvaluationMetrics.ECI, EvaluationMetrics.XPIA}:
248249
return _parse_content_harm_response(batch_response, metric_name)
249250

250251

251-
def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict:
252+
def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict[str, Union[str, float]]:
252253
"""Parse the annotation response from Responsible AI service for a content harm evaluation.
253254
254255
:param batch_response: The annotation response from Responsible AI service.
255256
:type batch_response: List[Dict]
256257
:param metric_name: The evaluation metric to use.
257258
:type metric_name: str
258259
:return: The parsed annotation result.
259-
:rtype: List[List[Dict]]
260+
:rtype: Dict[str, Union[str, float]]
260261
"""
261262
# Fix the metric name if it's "hate_fairness"
262263
# Eventually we will remove this fix once the RAI service is updated
263264
key = metric_name
264265
if key == EvaluationMetrics.HATE_FAIRNESS:
265266
key = EvaluationMetrics.HATE_UNFAIRNESS
266267

267-
result = {key: math.nan, key + "_score": math.nan, key + "_reason": ""}
268+
result: Dict[str, Union[str, float]] = {key: math.nan, key + "_score": math.nan, key + "_reason": ""}
268269

269270
response = batch_response[0]
270271
if metric_name not in response:
@@ -336,14 +337,13 @@ async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: st
336337
"""
337338
headers = get_common_headers(token)
338339

339-
async with get_async_http_client() as client:
340-
response = await client.get( # pylint: disable=too-many-function-args,unexpected-keyword-arg
340+
async with get_async_http_client_with_timeout() as client:
341+
response = await client.get(
341342
f"https://management.azure.com/subscriptions/{azure_ai_project['subscription_id']}/"
342343
f"resourceGroups/{azure_ai_project['resource_group_name']}/"
343344
f"providers/Microsoft.MachineLearningServices/workspaces/{azure_ai_project['project_name']}?"
344345
f"api-version=2023-08-01-preview",
345346
headers=headers,
346-
timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT,
347347
)
348348

349349
if response.status_code != 200:
@@ -360,7 +360,7 @@ async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: st
360360
return f"{base_url.scheme}://{base_url.netloc}"
361361

362362

363-
async def get_rai_svc_url(project_scope: dict, token: str) -> str:
363+
async def get_rai_svc_url(project_scope: AzureAIProject, token: str) -> str:
364364
"""Get the Responsible AI service URL
365365
366366
:param project_scope: The Azure AI project scope details.
@@ -384,7 +384,7 @@ async def get_rai_svc_url(project_scope: dict, token: str) -> str:
384384
return rai_url
385385

386386

387-
async def fetch_or_reuse_token(credential: TokenCredential, token: str = None) -> str:
387+
async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str] = None) -> str:
388388
"""Get token. Fetch a new token if the current token is near expiry
389389
390390
:param credential: The Azure authentication credential.
@@ -394,29 +394,26 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: str = None) -
394394
:type token: str
395395
:return: The Azure authentication token.
396396
"""
397-
acquire_new_token = True
398-
try:
399-
if token:
400-
# Decode the token to get its expiration time
397+
if token:
398+
# Decode the token to get its expiration time
399+
try:
401400
decoded_token = jwt.decode(token, options={"verify_signature": False})
401+
except jwt.PyJWTError:
402+
pass
403+
else:
402404
exp_time = decoded_token["exp"]
403405
current_time = time.time()
404406

405-
# Check if the token is near expiry
407+
# Return current token if not near expiry
406408
if (exp_time - current_time) >= 300:
407-
acquire_new_token = False
408-
except Exception: # pylint: disable=broad-exception-caught
409-
pass
410-
411-
if acquire_new_token:
412-
token = credential.get_token("https://management.azure.com/.default").token
409+
return token
413410

414-
return token
411+
return credential.get_token("https://management.azure.com/.default").token
415412

416413

417414
async def evaluate_with_rai_service(
418415
query: str, response: str, metric_name: str, project_scope: AzureAIProject, credential: TokenCredential
419-
):
416+
) -> Dict[str, Union[str, float]]:
420417
""" "Evaluate the content safety of the response using Responsible AI service
421418
422419
:param query: The query to evaluate.
@@ -431,7 +428,7 @@ async def evaluate_with_rai_service(
431428
:type credential:
432429
~azure.core.credentials.TokenCredential
433430
:return: The parsed annotation result.
434-
:rtype: List[List[Dict]]
431+
:rtype: Dict[str, Union[str, float]]
435432
"""
436433

437434
# Get RAI service URL from discovery service and check service availability
@@ -441,7 +438,7 @@ async def evaluate_with_rai_service(
441438

442439
# Submit annotation request and fetch result
443440
operation_id = await submit_request(query, response, metric_name, rai_svc_url, token)
444-
annotation_response = await fetch_result(operation_id, rai_svc_url, credential, token)
441+
annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
445442
result = parse_response(annotation_response, metric_name)
446443

447444
return result
