diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 77876b7f..64cd5e77 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -5255,7 +5255,7 @@ def test_typing_extensions_defers_when_possible(self): if sys.version_info < (3, 10, 1): exclude |= {"Literal"} if sys.version_info < (3, 11): - exclude |= {'final', 'Any', 'NewType'} + exclude |= {'final', 'Any', 'NewType', 'ForwardRef'} if sys.version_info < (3, 12): exclude |= { 'SupportsAbs', 'SupportsBytes', diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 1666e96b..7b23e759 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -1042,7 +1042,143 @@ def greet(name: str) -> None: if hasattr(typing, "Required"): # 3.11+ get_type_hints = typing.get_type_hints + ForwardRef = typing.ForwardRef + else: # <=3.10 + if hasattr(_types, "GenericAlias"): + _generic_alias_bases = (typing._GenericAlias, _types.GenericAlias) + else: + _generic_alias_bases = (typing._GenericAlias,) + + + def _is_union_type(t): + return hasattr(_types, "UnionType") and isinstance(t, _types.UnionType) + + class ForwardRef(typing._Final, _root=True): + """Internal wrapper to hold a forward reference.""" + + __slots__ = ('__forward_arg__', '__forward_code__', + '__forward_evaluated__', '__forward_value__', + '__forward_is_argument__', '__forward_is_class__', + '__forward_module__') + + def __init__(self, arg, is_argument=True, module=None, *, is_class=False): + if not isinstance(arg, str): + raise TypeError(f"Forward reference must be a string -- got {arg!r}") + + # If we do `def f(*args: *Ts)`, then we'll have `arg = '*Ts'`. + # Unfortunately, this isn't a valid expression on its own, so we + # do the unpacking manually. + if arg[0] == '*': + arg_to_compile = f'({arg},)[0]' # E.g. (*Ts,)[0] or (*tuple[int, int],)[0] + else: + arg_to_compile = arg + try: + code = compile(arg_to_compile, '', 'eval') + except SyntaxError: + raise SyntaxError(f"Forward reference must be an expression -- got {arg!r}") + + self.__forward_arg__ = arg + self.__forward_code__ = code + self.__forward_evaluated__ = False + self.__forward_value__ = None + self.__forward_is_argument__ = is_argument + self.__forward_is_class__ = is_class + self.__forward_module__ = module + + def _evaluate(self, globalns, localns, recursive_guard): + if self.__forward_arg__ in recursive_guard: + return self + if not self.__forward_evaluated__ or localns is not globalns: + if globalns is None and localns is None: + globalns = localns = {} + elif globalns is None: + globalns = localns + elif localns is None: + localns = globalns + if self.__forward_module__ is not None: + globalns = getattr( + sys.modules.get(self.__forward_module__, None), '__dict__', globalns + ) + type_ = _type_check( + eval(self.__forward_code__, globalns, localns), + "Forward references must evaluate to types.", + is_argument=self.__forward_is_argument__, + allow_special_forms=self.__forward_is_class__, + ) + self.__forward_value__ = _eval_type( + type_, globalns, localns, recursive_guard | {self.__forward_arg__} + ) + self.__forward_evaluated__ = True + return self.__forward_value__ + + def __eq__(self, other): + if not isinstance(other, (ForwardRef, typing.ForwardRef)): + return NotImplemented + if self.__forward_evaluated__ and other.__forward_evaluated__: + return (self.__forward_arg__ == other.__forward_arg__ and + self.__forward_value__ == other.__forward_value__) + return (self.__forward_arg__ == other.__forward_arg__ and + self.__forward_module__ == other.__forward_module__) + + def __hash__(self): + return hash((self.__forward_arg__, self.__forward_module__)) + + def __or__(self, other): + return Union[self, other] + + def __ror__(self, other): + return Union[other, self] + + def __repr__(self): + if self.__forward_module__ is None: + module_repr = '' + else: + module_repr = f', module={self.__forward_module__!r}' + return f'ForwardRef({self.__forward_arg__!r}{module_repr})' + + + def _type_check(arg, msg, is_argument=True, module=None, *, allow_special_forms=False): + """Check that the argument is a type, and return it (internal helper). + + As a special case, accept None and return type(None) instead. Also wrap strings + into ForwardRef instances. Consider several corner cases, for example plain + special forms like Union are not valid, while Union[int, str] is OK, etc. + The msg argument is a human-readable error message, e.g.:: + + "Union[arg, ...]: arg should be a type." + + We append the repr() of the actual value (truncated to 100 chars). + """ + invalid_generic_forms = (Generic, Protocol) + if not allow_special_forms: + invalid_generic_forms += (ClassVar,) + if is_argument: + invalid_generic_forms += (Final,) + + arg = _type_convert(arg, module=module, allow_special_forms=allow_special_forms) + if (isinstance(arg, typing._GenericAlias) and + arg.__origin__ in invalid_generic_forms): + raise TypeError(f"{arg} is not valid as type argument") + if arg in (Any, LiteralString, NoReturn, Never, Self, TypeAlias): + return arg + if allow_special_forms and arg in (ClassVar, Final): + return arg + if isinstance(arg, _SpecialForm) or arg in (Generic, Protocol): + raise TypeError(f"Plain {arg} is not valid as type argument") + if type(arg) is tuple: + raise TypeError(f"{msg} Got {arg!r:.100}.") + return arg + + + def _type_convert(arg, module=None, *, allow_special_forms=False): + """For converting None to type(None), and strings to ForwardRef.""" + if arg is None: + return type(None) + if isinstance(arg, str): + return ForwardRef(arg, module=module, is_class=allow_special_forms) + return arg + # replaces _strip_annotations() def _strip_extras(t): """Strips Annotated, Required and NotRequired from a given type.""" @@ -1060,7 +1196,7 @@ def _strip_extras(t): if stripped_args == t.__args__: return t return _types.GenericAlias(t.__origin__, stripped_args) - if hasattr(_types, "UnionType") and isinstance(t, _types.UnionType): + if _is_union_type(t): stripped_args = tuple(_strip_extras(a) for a in t.__args__) if stripped_args == t.__args__: return t @@ -1100,15 +1236,124 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): - If two dict arguments are passed, they specify globals and locals, respectively. """ - if hasattr(typing, "Annotated"): # 3.9+ - hint = typing.get_type_hints( - obj, globalns=globalns, localns=localns, include_extras=True + if getattr(obj, "__no_type_check__", None): + return {} + # Classes require a special treatment. + if isinstance(obj, type): + hints = {} + for base in reversed(obj.__mro__): + if globalns is None: + base_globals = getattr( + sys.modules.get(base.__module__, None), "__dict__", {} + ) + else: + base_globals = globalns + ann = base.__dict__.get("__annotations__", {}) + if isinstance(ann, _types.GetSetDescriptorType): + ann = {} + base_locals = dict(vars(base)) if localns is None else localns + if localns is None and globalns is None: + # This is surprising, but required. Before Python 3.10, + # get_type_hints only evaluated the globalns of + # a class. To maintain backwards compatibility, we reverse + # the globalns and localns order so that eval() looks into + # *base_globals* first rather than *base_locals*. + # This only affects ForwardRefs. + base_globals, base_locals = base_locals, base_globals + for name, value in ann.items(): + if value is None: + value = type(None) + if isinstance(value, str): + value = ForwardRef(value, is_argument=False, is_class=True) + value = _eval_type(value, base_globals, base_locals) + hints[name] = value + return ( + hints + if include_extras + else {k: _strip_extras(t) for k, t in hints.items()} ) - else: # 3.8 - hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) - if include_extras: - return hint - return {k: _strip_extras(t) for k, t in hint.items()} + + if globalns is None: + if isinstance(obj, _types.ModuleType): + globalns = obj.__dict__ + else: + nsobj = obj + # Find globalns for the unwrapped object. + while hasattr(nsobj, "__wrapped__"): + nsobj = nsobj.__wrapped__ + globalns = getattr(nsobj, "__globals__", {}) + if localns is None: + localns = globalns + elif localns is None: + localns = globalns + hints = getattr(obj, "__annotations__", None) + if hints is None: + # Return empty annotations for something that _could_ have them. + if isinstance(obj, typing._allowed_types): + return {} + else: + raise TypeError( + "{!r} is not a module, class, method, " "or function.".format(obj) + ) + hints = dict(hints) + for name, value in hints.items(): + if value is None: + value = type(None) + if isinstance(value, str): + # class-level forward refs were handled above, this must be either + # a module-level annotation or a function argument annotation + value = ForwardRef( + value, + is_argument=not isinstance(obj, _types.ModuleType), + is_class=False, + ) + hints[name] = _eval_type(value, globalns, localns) + return ( + hints if include_extras else {k: _strip_extras(t) for k, t in hints.items()} + ) + + def _eval_type(t, globalns, localns, recursive_guard=frozenset()): + if isinstance(t, (ForwardRef, typing.ForwardRef)): + if not hasattr(t, '__forward_is_class__'): + return t._evaluate(globalns, localns) + return t._evaluate(globalns, localns, recursive_guard) + + + if isinstance(t, _generic_alias_bases) or _is_union_type(t): + if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias): + args = tuple( + ForwardRef(arg) if isinstance(arg, str) else arg + for arg in t.__args__ + ) + is_unpacked = t.__unpacked__ + if _should_unflatten_callable_args(t, args): + t = t.__origin__[(args[:-1], args[-1])] + else: + t = t.__origin__[args] + if is_unpacked: + t = Unpack[t] + ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) + if ev_args == t.__args__: + return t + if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias): + return typing.GenericAlias(t.__origin__, ev_args) + if _is_union_type(t): + return functools.reduce(operator.or_, ev_args) + else: + return t.copy_with(ev_args) + return t + + def _should_unflatten_callable_args(typ, args): + return ( + typ.__origin__ is collections.abc.Callable + and not (len(args) == 2 and _is_param_expr(args[0])) + ) + + def _is_param_expr(arg): + return arg is ... or isinstance(arg, + (tuple, list, ParamSpec, _ConcatenateGenericAlias)) + + # Python 3.9+ has PEP 593 (Annotated) @@ -3021,7 +3266,6 @@ def __eq__(self, other: object) -> bool: Collection = typing.Collection Container = typing.Container Dict = typing.Dict -ForwardRef = typing.ForwardRef FrozenSet = typing.FrozenSet Generator = typing.Generator Generic = typing.Generic