diff --git a/HISTORY.md b/HISTORY.md index da44cbbd..0ee71b56 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -9,6 +9,7 @@ - Optimize and improve unstructuring of `Optional` (unions of one type and `None`). ([#380](https://github.com/python-attrs/cattrs/issues/380) [#381](https://github.com/python-attrs/cattrs/pull/381)) - Fix `format_exception` and `transform_error` type annotations. +- Improve the implementation of `cattrs._compat.is_typeddict`. The implementation is now simpler, and relies on fewer private implementation details from `typing` and typing_extensions. ([#384](https://github.com/python-attrs/cattrs/pull/384)) ## 23.1.2 (2023-06-02) diff --git a/src/cattrs/_compat.py b/src/cattrs/_compat.py index e6ad59f6..98860ec7 100644 --- a/src/cattrs/_compat.py +++ b/src/cattrs/_compat.py @@ -20,16 +20,13 @@ from attr import fields as attrs_fields from attr import resolve_types +__all__ = ["ExceptionGroup", "ExtensionsTypedDict", "TypedDict", "is_typeddict"] + try: from typing_extensions import TypedDict as ExtensionsTypedDict except ImportError: ExtensionsTypedDict = None -try: - from typing_extensions import _TypedDictMeta as ExtensionsTypedDictMeta -except ImportError: - ExtensionsTypedDictMeta = None - if sys.version_info >= (3, 8): from typing import Final, Protocol, get_args, get_origin @@ -44,9 +41,20 @@ def get_origin(cl): from typing_extensions import Final, Protocol if sys.version_info >= (3, 11): - ExceptionGroup = ExceptionGroup + from builtins import ExceptionGroup else: - from exceptiongroup import ExceptionGroup as ExceptionGroup # noqa: PLC0414 + from exceptiongroup import ExceptionGroup + +try: + from typing_extensions import is_typeddict as _is_typeddict +except ImportError: + assert sys.version_info >= (3, 10) + from typing import is_typeddict as _is_typeddict + + +def is_typeddict(cls): + """Thin wrapper around typing(_extensions).is_typeddict""" + return _is_typeddict(getattr(cls, "__origin__", cls)) def has(cls): @@ -157,7 +165,6 @@ def get_final_base(type) -> Optional[type]: _AnnotatedAlias, _GenericAlias, _SpecialGenericAlias, - _TypedDictMeta, _UnionGenericAlias, ) @@ -234,20 +241,6 @@ def get_newtype_base(typ: Any) -> Optional[type]: return supertype return None - def is_typeddict(cls) -> bool: - return ( - cls.__class__ is _TypedDictMeta - or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta)) - or ( - ExtensionsTypedDictMeta is not None - and cls.__class__ is ExtensionsTypedDictMeta - or ( - is_generic(cls) - and (cls.__origin__.__class__ is ExtensionsTypedDictMeta) - ) - ) - ) - def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]": if get_origin(type) in (NotRequired, Required): return get_args(type)[0] @@ -364,9 +357,8 @@ def copy_with(type, args): from typing_extensions import get_origin as te_get_origin if sys.version_info >= (3, 8): - from typing import TypedDict, _TypedDictMeta + from typing import TypedDict else: - _TypedDictMeta = None TypedDict = ExtensionsTypedDict def is_annotated(type) -> bool: @@ -462,20 +454,6 @@ def copy_with(type, args): """Replace a generic type's arguments.""" return type.copy_with(args) - def is_typeddict(cls) -> bool: - return ( - cls.__class__ is _TypedDictMeta - or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta)) - or ( - ExtensionsTypedDictMeta is not None - and cls.__class__ is ExtensionsTypedDictMeta - or ( - is_generic(cls) - and (cls.__origin__.__class__ is ExtensionsTypedDictMeta) - ) - ) - ) - def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]": if get_origin(type) in (NotRequired, Required): return get_args(type)[0]