@@ -403,15 +403,16 @@ def inner(*args, **kwds):
403
403
404
404
return decorator
405
405
406
- def _eval_type (t , globalns , localns , recursive_guard = frozenset ()):
406
+
407
+ def _eval_type (t , globalns , localns , type_params , * , recursive_guard = frozenset ()):
407
408
"""Evaluate all forward references in the given type t.
408
409
409
410
For use of globalns and localns see the docstring for get_type_hints().
410
411
recursive_guard is used to prevent infinite recursion with a recursive
411
412
ForwardRef.
412
413
"""
413
414
if isinstance (t , ForwardRef ):
414
- return t ._evaluate (globalns , localns , recursive_guard )
415
+ return t ._evaluate (globalns , localns , type_params , recursive_guard = recursive_guard )
415
416
if isinstance (t , (_GenericAlias , GenericAlias , types .UnionType )):
416
417
if isinstance (t , GenericAlias ):
417
418
args = tuple (
@@ -425,7 +426,13 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
425
426
t = t .__origin__ [args ]
426
427
if is_unpacked :
427
428
t = Unpack [t ]
428
- ev_args = tuple (_eval_type (a , globalns , localns , recursive_guard ) for a in t .__args__ )
429
+
430
+ ev_args = tuple (
431
+ _eval_type (
432
+ a , globalns , localns , type_params , recursive_guard = recursive_guard
433
+ )
434
+ for a in t .__args__
435
+ )
429
436
if ev_args == t .__args__ :
430
437
return t
431
438
if isinstance (t , GenericAlias ):
@@ -906,7 +913,7 @@ def __init__(self, arg, is_argument=True, module=None, *, is_class=False):
906
913
self .__forward_is_class__ = is_class
907
914
self .__forward_module__ = module
908
915
909
- def _evaluate (self , globalns , localns , recursive_guard ):
916
+ def _evaluate (self , globalns , localns , type_params , * , recursive_guard ):
910
917
if self .__forward_arg__ in recursive_guard :
911
918
return self
912
919
if not self .__forward_evaluated__ or localns is not globalns :
@@ -920,14 +927,25 @@ def _evaluate(self, globalns, localns, recursive_guard):
920
927
globalns = getattr (
921
928
sys .modules .get (self .__forward_module__ , None ), '__dict__' , globalns
922
929
)
930
+ if type_params :
931
+ # "Inject" type parameters into the local namespace
932
+ # (unless they are shadowed by assignments *in* the local namespace),
933
+ # as a way of emulating annotation scopes when calling `eval()`
934
+ locals_to_pass = {param .__name__ : param for param in type_params } | localns
935
+ else :
936
+ locals_to_pass = localns
923
937
type_ = _type_check (
924
- eval (self .__forward_code__ , globalns , localns ),
938
+ eval (self .__forward_code__ , globalns , locals_to_pass ),
925
939
"Forward references must evaluate to types." ,
926
940
is_argument = self .__forward_is_argument__ ,
927
941
allow_special_forms = self .__forward_is_class__ ,
928
942
)
929
943
self .__forward_value__ = _eval_type (
930
- type_ , globalns , localns , recursive_guard | {self .__forward_arg__ }
944
+ type_ ,
945
+ globalns ,
946
+ localns ,
947
+ type_params ,
948
+ recursive_guard = (recursive_guard | {self .__forward_arg__ }),
931
949
)
932
950
self .__forward_evaluated__ = True
933
951
return self .__forward_value__
@@ -2241,7 +2259,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
2241
2259
value = type (None )
2242
2260
if isinstance (value , str ):
2243
2261
value = ForwardRef (value , is_argument = False , is_class = True )
2244
- value = _eval_type (value , base_globals , base_locals )
2262
+ value = _eval_type (value , base_globals , base_locals , base . __type_params__ )
2245
2263
hints [name ] = value
2246
2264
return hints if include_extras else {k : _strip_annotations (t ) for k , t in hints .items ()}
2247
2265
@@ -2267,6 +2285,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
2267
2285
raise TypeError ('{!r} is not a module, class, method, '
2268
2286
'or function.' .format (obj ))
2269
2287
hints = dict (hints )
2288
+ type_params = getattr (obj , "__type_params__" , ())
2270
2289
for name , value in hints .items ():
2271
2290
if value is None :
2272
2291
value = type (None )
@@ -2278,7 +2297,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
2278
2297
is_argument = not isinstance (obj , types .ModuleType ),
2279
2298
is_class = False ,
2280
2299
)
2281
- hints [name ] = _eval_type (value , globalns , localns )
2300
+ hints [name ] = _eval_type (value , globalns , localns , type_params )
2282
2301
return hints if include_extras else {k : _strip_annotations (t ) for k , t in hints .items ()}
2283
2302
2284
2303
0 commit comments