diff --git a/tests/utils.py b/tests/utils.py index 43c3cb5cfe..060b99339f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -97,7 +97,22 @@ def assert_matches_type( assert_matches_type(key_type, key, path=[*path, ""]) assert_matches_type(items_type, item, path=[*path, ""]) elif is_union_type(type_): - for i, variant in enumerate(get_args(type_)): + variants = get_args(type_) + + try: + none_index = variants.index(type(None)) + except ValueError: + pass + else: + # special case Optional[T] for better error messages + if len(variants) == 2: + if value is None: + # valid + return + + return assert_matches_type(type_=variants[not none_index], value=value, path=path) + + for i, variant in enumerate(variants): try: assert_matches_type(variant, value, path=[*path, f"variant {i}"]) return