Skip to content

Commit b8c369c

Browse files
rpkilbytomchristie
authored andcommitted
Fix serializer multiple inheritance bug (#6980)
* Expand declared filtering tests - Test declared filter ordering - Test multiple inheritance * Fix serializer multiple inheritance bug * Improve field order test to check for field types
1 parent 7c54596 commit b8c369c

File tree

2 files changed

+66
-12
lines changed

2 files changed

+66
-12
lines changed

rest_framework/serializers.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -298,18 +298,22 @@ def _get_declared_fields(cls, bases, attrs):
298298
if isinstance(obj, Field)]
299299
fields.sort(key=lambda x: x[1]._creation_counter)
300300

301-
# If this class is subclassing another Serializer, add that Serializer's
302-
# fields. Note that we loop over the bases in *reverse*. This is necessary
303-
# in order to maintain the correct order of fields.
304-
for base in reversed(bases):
305-
if hasattr(base, '_declared_fields'):
306-
fields = [
307-
(field_name, obj) for field_name, obj
308-
in base._declared_fields.items()
309-
if field_name not in attrs
310-
] + fields
311-
312-
return OrderedDict(fields)
301+
# Ensures a base class field doesn't override cls attrs, and maintains
302+
# field precedence when inheriting multiple parents. e.g. if there is a
303+
# class C(A, B), and A and B both define 'field', use 'field' from A.
304+
known = set(attrs)
305+
306+
def visit(name):
307+
known.add(name)
308+
return name
309+
310+
base_fields = [
311+
(visit(name), f)
312+
for base in bases if hasattr(base, '_declared_fields')
313+
for name, f in base._declared_fields.items() if name not in known
314+
]
315+
316+
return OrderedDict(base_fields + fields)
313317

314318
def __new__(cls, name, bases, attrs):
315319
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)

tests/test_serializer.py

+50
Original file line numberDiff line numberDiff line change
@@ -682,3 +682,53 @@ class Grandchild(Child):
682682
assert len(Parent().get_fields()) == 2
683683
assert len(Child().get_fields()) == 2
684684
assert len(Grandchild().get_fields()) == 2
685+
686+
def test_multiple_inheritance(self):
687+
class A(serializers.Serializer):
688+
field = serializers.CharField()
689+
690+
class B(serializers.Serializer):
691+
field = serializers.IntegerField()
692+
693+
class TestSerializer(A, B):
694+
pass
695+
696+
fields = {
697+
name: type(f) for name, f
698+
in TestSerializer()._declared_fields.items()
699+
}
700+
assert fields == {
701+
'field': serializers.CharField,
702+
}
703+
704+
def test_field_ordering(self):
705+
class Base(serializers.Serializer):
706+
f1 = serializers.CharField()
707+
f2 = serializers.CharField()
708+
709+
class A(Base):
710+
f3 = serializers.IntegerField()
711+
712+
class B(serializers.Serializer):
713+
f3 = serializers.CharField()
714+
f4 = serializers.CharField()
715+
716+
class TestSerializer(A, B):
717+
f2 = serializers.IntegerField()
718+
f5 = serializers.CharField()
719+
720+
fields = {
721+
name: type(f) for name, f
722+
in TestSerializer()._declared_fields.items()
723+
}
724+
725+
# `IntegerField`s should be the 'winners' in field name conflicts
726+
# - `TestSerializer.f2` should override `Base.F2`
727+
# - `A.f3` should override `B.f3`
728+
assert fields == {
729+
'f1': serializers.CharField,
730+
'f2': serializers.IntegerField,
731+
'f3': serializers.IntegerField,
732+
'f4': serializers.CharField,
733+
'f5': serializers.CharField,
734+
}

0 commit comments

Comments
 (0)