Skip to content

Commit 2b34aa4

Browse files
authored
Re-prefetch related objects after updating (#8043)
* Re-prefetch related objects after updating * Fix flake8 format * Use _prefetch_related_lookups and refine test cases * Add more test cases and refine prefetch checking
1 parent bfce663 commit 2b34aa4

File tree

2 files changed

+71
-32
lines changed

2 files changed

+71
-32
lines changed

rest_framework/mixins.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
We don't bind behaviour to http method handlers yet,
55
which allows mixin classes to be composed in interesting ways.
66
"""
7+
from django.db.models.query import prefetch_related_objects
8+
79
from rest_framework import status
810
from rest_framework.response import Response
911
from rest_framework.settings import api_settings
@@ -67,10 +69,13 @@ def update(self, request, *args, **kwargs):
6769
serializer.is_valid(raise_exception=True)
6870
self.perform_update(serializer)
6971

70-
if getattr(instance, '_prefetched_objects_cache', None):
72+
queryset = self.filter_queryset(self.get_queryset())
73+
if queryset._prefetch_related_lookups:
7174
# If 'prefetch_related' has been applied to a queryset, we need to
72-
# forcibly invalidate the prefetch cache on the instance.
75+
# forcibly invalidate the prefetch cache on the instance,
76+
# and then re-prefetch related objects
7377
instance._prefetched_objects_cache = {}
78+
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)
7479

7580
return Response(serializer.data)
7681

tests/test_prefetch_related.py

+64-30
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.contrib.auth.models import Group, User
2+
from django.db.models.query import Prefetch
23
from django.test import TestCase
34

45
from rest_framework import generics, serializers
@@ -8,51 +9,84 @@
89

910

1011
class UserSerializer(serializers.ModelSerializer):
12+
permissions = serializers.SerializerMethodField()
13+
14+
def get_permissions(self, obj):
15+
ret = []
16+
for g in obj.groups.all():
17+
ret.extend([p.pk for p in g.permissions.all()])
18+
return ret
19+
1120
class Meta:
1221
model = User
13-
fields = ('id', 'username', 'email', 'groups')
22+
fields = ('id', 'username', 'email', 'groups', 'permissions')
23+
24+
25+
class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
26+
queryset = User.objects.exclude(username='exclude').prefetch_related(
27+
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
28+
'groups__permissions',
29+
)
30+
serializer_class = UserSerializer
1431

1532

16-
class UserUpdate(generics.UpdateAPIView):
17-
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
33+
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
34+
queryset = User.objects.exclude(username='exclude')
1835
serializer_class = UserSerializer
1936

2037

2138
class TestPrefetchRelatedUpdates(TestCase):
2239
def setUp(self):
2340
self.user = User.objects.create(username='tom', email='[email protected]')
24-
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
41+
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
2542
self.user.groups.set(self.groups)
43+
self.user.groups.add(Group.objects.create(name='exclude'))
44+
self.expected = {
45+
'id': self.user.pk,
46+
'username': 'tom',
47+
'groups': [group.pk for group in self.groups],
48+
'email': '[email protected]',
49+
'permissions': [],
50+
}
51+
self.view = UserRetrieveUpdate.as_view()
2652

2753
def test_prefetch_related_updates(self):
28-
view = UserUpdate.as_view()
29-
pk = self.user.pk
30-
groups_pk = self.groups[0].pk
31-
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
32-
response = view(request, pk=pk)
33-
assert User.objects.get(pk=pk).groups.count() == 1
34-
expected = {
35-
'id': pk,
36-
'username': 'new',
37-
'groups': [1],
38-
'email': '[email protected]'
39-
}
40-
assert response.data == expected
54+
self.groups.append(Group.objects.create(name='c'))
55+
request = factory.put(
56+
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
57+
)
58+
self.expected['username'] = 'new'
59+
self.expected['groups'] = [group.pk for group in self.groups]
60+
response = self.view(request, pk=self.user.pk)
61+
assert User.objects.get(pk=self.user.pk).groups.count() == 12
62+
assert response.data == self.expected
63+
# Update and fetch should get same result
64+
request = factory.get('/')
65+
response = self.view(request, pk=self.user.pk)
66+
assert response.data == self.expected
4167

4268
def test_prefetch_related_excluding_instance_from_original_queryset(self):
4369
"""
4470
Regression test for https://github.com/encode/django-rest-framework/issues/4661
4571
"""
46-
view = UserUpdate.as_view()
47-
pk = self.user.pk
48-
groups_pk = self.groups[0].pk
49-
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
50-
response = view(request, pk=pk)
51-
assert User.objects.get(pk=pk).groups.count() == 1
52-
expected = {
53-
'id': pk,
54-
'username': 'exclude',
55-
'groups': [1],
56-
'email': '[email protected]'
57-
}
58-
assert response.data == expected
72+
request = factory.put(
73+
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
74+
)
75+
response = self.view(request, pk=self.user.pk)
76+
assert User.objects.get(pk=self.user.pk).groups.count() == 2
77+
self.expected['username'] = 'exclude'
78+
self.expected['groups'] = [self.groups[0].pk]
79+
assert response.data == self.expected
80+
81+
def test_db_query_count(self):
82+
request = factory.put(
83+
'/', {'username': 'new'}, format='json'
84+
)
85+
with self.assertNumQueries(7):
86+
self.view(request, pk=self.user.pk)
87+
88+
request = factory.put(
89+
'/', {'username': 'new2'}, format='json'
90+
)
91+
with self.assertNumQueries(16):
92+
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)

0 commit comments

Comments
 (0)