|
1 | 1 | from django.contrib.auth.models import Group, User
|
2 |
| -from django.db.models.query import Prefetch |
3 | 2 | from django.test import TestCase
|
4 | 3 |
|
5 | 4 | from rest_framework import generics, serializers
|
|
9 | 8 |
|
10 | 9 |
|
11 | 10 | class UserSerializer(serializers.ModelSerializer):
|
12 |
| - permissions = serializers.SerializerMethodField() |
13 |
| - |
14 |
| - def get_permissions(self, obj): |
15 |
| - ret = [] |
16 |
| - for g in obj.groups.all(): |
17 |
| - ret.extend([p.pk for p in g.permissions.all()]) |
18 |
| - return ret |
19 |
| - |
20 | 11 | class Meta:
|
21 | 12 | model = User
|
22 |
| - fields = ('id', 'username', 'email', 'groups', 'permissions') |
23 |
| - |
24 |
| - |
25 |
| -class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): |
26 |
| - queryset = User.objects.exclude(username='exclude').prefetch_related( |
27 |
| - Prefetch('groups', queryset=Group.objects.exclude(name='exclude')), |
28 |
| - 'groups__permissions', |
29 |
| - ) |
30 |
| - serializer_class = UserSerializer |
| 13 | + fields = ('id', 'username', 'email', 'groups') |
31 | 14 |
|
32 | 15 |
|
33 |
| -class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): |
34 |
| - queryset = User.objects.exclude(username='exclude') |
| 16 | +class UserUpdate(generics.UpdateAPIView): |
| 17 | + queryset = User.objects.exclude(username='exclude').prefetch_related('groups') |
35 | 18 | serializer_class = UserSerializer
|
36 | 19 |
|
37 | 20 |
|
38 | 21 | class TestPrefetchRelatedUpdates(TestCase):
|
39 | 22 | def setUp(self):
|
40 | 23 | self. user = User. objects. create( username='tom', email='[email protected]')
|
41 |
| - self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] |
| 24 | + self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] |
42 | 25 | self.user.groups.set(self.groups)
|
43 |
| - self.user.groups.add(Group.objects.create(name='exclude')) |
44 |
| - self.expected = { |
45 |
| - 'id': self.user.pk, |
46 |
| - 'username': 'tom', |
47 |
| - 'groups': [group.pk for group in self.groups], |
48 |
| - |
49 |
| - 'permissions': [], |
50 |
| - } |
51 |
| - self.view = UserRetrieveUpdate.as_view() |
52 | 26 |
|
53 | 27 | def test_prefetch_related_updates(self):
|
54 |
| - self.groups.append(Group.objects.create(name='c')) |
55 |
| - request = factory.put( |
56 |
| - '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' |
57 |
| - ) |
58 |
| - self.expected['username'] = 'new' |
59 |
| - self.expected['groups'] = [group.pk for group in self.groups] |
60 |
| - response = self.view(request, pk=self.user.pk) |
61 |
| - assert User.objects.get(pk=self.user.pk).groups.count() == 12 |
62 |
| - assert response.data == self.expected |
63 |
| - # Update and fetch should get same result |
64 |
| - request = factory.get('/') |
65 |
| - response = self.view(request, pk=self.user.pk) |
66 |
| - assert response.data == self.expected |
| 28 | + view = UserUpdate.as_view() |
| 29 | + pk = self.user.pk |
| 30 | + groups_pk = self.groups[0].pk |
| 31 | + request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') |
| 32 | + response = view(request, pk=pk) |
| 33 | + assert User.objects.get(pk=pk).groups.count() == 1 |
| 34 | + expected = { |
| 35 | + 'id': pk, |
| 36 | + 'username': 'new', |
| 37 | + 'groups': [1], |
| 38 | + |
| 39 | + } |
| 40 | + assert response.data == expected |
67 | 41 |
|
68 | 42 | def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
69 | 43 | """
|
70 | 44 | Regression test for https://github.com/encode/django-rest-framework/issues/4661
|
71 | 45 | """
|
72 |
| - request = factory.put( |
73 |
| - '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' |
74 |
| - ) |
75 |
| - response = self.view(request, pk=self.user.pk) |
76 |
| - assert User.objects.get(pk=self.user.pk).groups.count() == 2 |
77 |
| - self.expected['username'] = 'exclude' |
78 |
| - self.expected['groups'] = [self.groups[0].pk] |
79 |
| - assert response.data == self.expected |
80 |
| - |
81 |
| - def test_db_query_count(self): |
82 |
| - request = factory.put( |
83 |
| - '/', {'username': 'new'}, format='json' |
84 |
| - ) |
85 |
| - with self.assertNumQueries(7): |
86 |
| - self.view(request, pk=self.user.pk) |
87 |
| - |
88 |
| - request = factory.put( |
89 |
| - '/', {'username': 'new2'}, format='json' |
90 |
| - ) |
91 |
| - with self.assertNumQueries(16): |
92 |
| - UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) |
| 46 | + view = UserUpdate.as_view() |
| 47 | + pk = self.user.pk |
| 48 | + groups_pk = self.groups[0].pk |
| 49 | + request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') |
| 50 | + response = view(request, pk=pk) |
| 51 | + assert User.objects.get(pk=pk).groups.count() == 1 |
| 52 | + expected = { |
| 53 | + 'id': pk, |
| 54 | + 'username': 'exclude', |
| 55 | + 'groups': [1], |
| 56 | + |
| 57 | + } |
| 58 | + assert response.data == expected |
0 commit comments