Skip to content

Extendable NamedTuples #437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion src/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,25 @@ def __str__(self):
def __add__(self, other):
return 0

class Base1(NamedTuple):
x: int
y: int

class Base2(NamedTuple):
z: int = 0

class Derived1(Base1, NamedTuple):
'''Named tuples can have doctrings.'''
label: str

class Derived2(Base2, Base1, NamedTuple):
def method(self):
return self.x + self.y

class TwoDefaults(NamedTuple):
x: int = 0
other: str = ''

async def g_with(am: AsyncContextManager[int]):
x: int
async with am as x:
Expand All @@ -1629,7 +1648,7 @@ async def g_with(am: AsyncContextManager[int]):
# fake names for the sake of static analysis
ann_module = ann_module2 = ann_module3 = None
A = B = CSub = G = CoolEmployee = CoolEmployeeWithDefault = object
XMeth = XRepr = NoneAndForward = object
XMeth = XRepr = NoneAndForward = Derived1 = Derived2 = TwoDefaults = object

gth = get_type_hints

Expand Down Expand Up @@ -2313,20 +2332,40 @@ def test_annotation_usage_with_default(self):
self.assertEqual(CoolEmployeeWithDefault._fields, ('name', 'cool'))
self.assertEqual(CoolEmployeeWithDefault._field_types, dict(name=str, cool=int))
self.assertEqual(CoolEmployeeWithDefault._field_defaults, dict(cool=0))
self.assertEqual(TwoDefaults.__new__.__defaults__, (0, ''))
self.assertEqual(TwoDefaults().x, 0)
self.assertEqual(TwoDefaults().other, '')
self.assertEqual(TwoDefaults()[0], 0)
self.assertEqual(TwoDefaults()[1], '')
self.assertEqual(Derived2._fields, ('x', 'y', 'z'))
self.assertEqual(Derived2._field_types, dict(x=int, y=int, z=int))
self.assertEqual(Derived2._field_defaults, dict(z=0))

with self.assertRaises(TypeError):
exec("""
class NonDefaultAfterDefault(NamedTuple):
x: int = 3
y: int
""")
with self.assertRaises(TypeError):
exec("""
class BadMerged(Base1, Base2, NamedTuple):
pass
""")
with self.assertRaises(TypeError):
exec("""
class BadExtended(Base2, NamedTuple):
label: str
""")

@skipUnless(PY36, 'Python 3.6 required')
def test_annotation_usage_with_methods(self):
self.assertEqual(XMeth(1).double(), 2)
self.assertEqual(XMeth(42).x, XMeth(42)[0])
self.assertEqual(str(XRepr(42)), '42 -> 1')
self.assertEqual(XRepr(1, 2) + XRepr(3), 0)
self.assertEqual(Derived2(1, 2).method(), 3)
self.assertEqual(Derived2(3, 4, 5).method(), 7)

with self.assertRaises(AttributeError):
exec("""
Expand All @@ -2344,6 +2383,42 @@ def _source(self):
return 'no chance for this as well'
""")

@skipUnless(PY36, 'Python 3.6 required')
def test_namedtuple_extending(self):
example = Derived1(1, 2, 'test')
with self.assertRaises(TypeError):
Derived1('test')
self.assertIsInstance(example, Derived1)
self.assertIsInstance(example, tuple)
self.assertEqual(example.x, 1)
self.assertEqual(example.y, 2)
self.assertEqual(example.label, 'test')
self.assertEqual(Derived1.__name__, 'Derived1')
self.assertEqual(Derived1._fields, ('x', 'y', 'label'))
self.assertEqual(Derived1.__annotations__,
collections.OrderedDict(x=int, y=int, label=str))
self.assertIs(Derived1._field_types, Derived1.__annotations__)

@skipUnless(PY36, 'Python 3.6 required')
def test_namedtuple_merging(self):
example = Derived2(1, 2)
example2 = Derived2(1, 2, 3)
with self.assertRaises(TypeError):
Derived2(1)
with self.assertRaises(TypeError):
Derived2(1, 2, 3, 4)
self.assertIsInstance(example, Derived2)
self.assertIsInstance(example, tuple)
self.assertEqual(example.x, 1)
self.assertEqual(example.y, 2)
self.assertEqual(example.z, 0)
self.assertEqual(example2.z, 3)
self.assertEqual(Derived2.__name__, 'Derived2')
self.assertEqual(Derived2._fields, ('x', 'y', 'z'))
self.assertEqual(Derived2.__annotations__,
collections.OrderedDict(x=int, y=int, z=int))
self.assertIs(Derived2._field_types, Derived2.__annotations__)

@skipUnless(PY36, 'Python 3.6 required')
def test_namedtuple_keyword_usage(self):
LocalEmployee = NamedTuple("LocalEmployee", name=str, age=int)
Expand Down
25 changes: 15 additions & 10 deletions src/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,23 +2148,28 @@ def __new__(cls, typename, bases, ns):
if not _PY36:
raise TypeError("Class syntax for NamedTuple is only supported"
" in Python 3.6+")
types = ns.get('__annotations__', {})
types = collections.OrderedDict()
defaults_ns = {}
# Merge field types and defaults in reverse order (similar to how MRO works).
for base in reversed(bases):
if hasattr(base, '_field_types'): # new style named tuple
types.update(base._field_types)
defaults_ns.update(base._field_defaults)
types.update(ns.get('__annotations__', {}))
defaults_ns.update(ns)
nm_tpl = _make_nmtuple(typename, types.items())
defaults = []
defaults_dict = {}
defaults_dict = collections.OrderedDict()
for field_name in types:
if field_name in ns:
default_value = ns[field_name]
defaults.append(default_value)
defaults_dict[field_name] = default_value
elif defaults:
if field_name in defaults_ns:
defaults_dict[field_name] = defaults_ns[field_name]
elif defaults_dict:
raise TypeError("Non-default namedtuple field {field_name} cannot "
"follow default field(s) {default_names}"
.format(field_name=field_name,
default_names=', '.join(defaults_dict.keys())))
nm_tpl.__new__.__defaults__ = tuple(defaults)
nm_tpl.__new__.__defaults__ = tuple(defaults_dict.values())
nm_tpl._field_defaults = defaults_dict
# update from user namespace without overriding special namedtuple attributes
# Update from user namespace without overriding special namedtuple attributes.
for key in ns:
if key in _prohibited:
raise AttributeError("Cannot overwrite NamedTuple attribute " + key)
Expand Down