|
1 | 1 | from django.contrib.auth.models import Group, User
|
| 2 | +from django.db.models.query import Prefetch |
2 | 3 | from django.test import TestCase
|
3 | 4 |
|
4 | 5 | from rest_framework import generics, serializers
|
|
8 | 9 |
|
9 | 10 |
|
10 | 11 | class UserSerializer(serializers.ModelSerializer):
|
| 12 | + permissions = serializers.SerializerMethodField() |
| 13 | + |
| 14 | + def get_permissions(self, obj): |
| 15 | + ret = [] |
| 16 | + for g in obj.groups.all(): |
| 17 | + ret.extend([p.pk for p in g.permissions.all()]) |
| 18 | + return ret |
| 19 | + |
11 | 20 | class Meta:
|
12 | 21 | model = User
|
13 |
| - fields = ('id', 'username', 'email', 'groups') |
| 22 | + fields = ('id', 'username', 'email', 'groups', 'permissions') |
| 23 | + |
| 24 | + |
| 25 | +class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): |
| 26 | + queryset = User.objects.exclude(username='exclude').prefetch_related( |
| 27 | + Prefetch('groups', queryset=Group.objects.exclude(name='exclude')), |
| 28 | + 'groups__permissions', |
| 29 | + ) |
| 30 | + serializer_class = UserSerializer |
14 | 31 |
|
15 | 32 |
|
16 |
| -class UserUpdate(generics.UpdateAPIView): |
17 |
| - queryset = User.objects.exclude(username='exclude').prefetch_related('groups') |
| 33 | +class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): |
| 34 | + queryset = User.objects.exclude(username='exclude') |
18 | 35 | serializer_class = UserSerializer
|
19 | 36 |
|
20 | 37 |
|
21 | 38 | class TestPrefetchRelatedUpdates(TestCase):
|
22 | 39 | def setUp(self):
|
23 | 40 | self. user = User. objects. create( username='tom', email='[email protected]')
|
24 |
| - self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] |
| 41 | + self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] |
25 | 42 | self.user.groups.set(self.groups)
|
| 43 | + self.user.groups.add(Group.objects.create(name='exclude')) |
| 44 | + self.expected = { |
| 45 | + 'id': self.user.pk, |
| 46 | + 'username': 'tom', |
| 47 | + 'groups': [group.pk for group in self.groups], |
| 48 | + |
| 49 | + 'permissions': [], |
| 50 | + } |
| 51 | + self.view = UserRetrieveUpdate.as_view() |
26 | 52 |
|
27 | 53 | def test_prefetch_related_updates(self):
|
28 |
| - view = UserUpdate.as_view() |
29 |
| - pk = self.user.pk |
30 |
| - groups_pk = self.groups[0].pk |
31 |
| - request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') |
32 |
| - response = view(request, pk=pk) |
33 |
| - assert User.objects.get(pk=pk).groups.count() == 1 |
34 |
| - expected = { |
35 |
| - 'id': pk, |
36 |
| - 'username': 'new', |
37 |
| - 'groups': [1], |
38 |
| - |
39 |
| - } |
40 |
| - assert response.data == expected |
| 54 | + self.groups.append(Group.objects.create(name='c')) |
| 55 | + request = factory.put( |
| 56 | + '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' |
| 57 | + ) |
| 58 | + self.expected['username'] = 'new' |
| 59 | + self.expected['groups'] = [group.pk for group in self.groups] |
| 60 | + response = self.view(request, pk=self.user.pk) |
| 61 | + assert User.objects.get(pk=self.user.pk).groups.count() == 12 |
| 62 | + assert response.data == self.expected |
| 63 | + # Update and fetch should get same result |
| 64 | + request = factory.get('/') |
| 65 | + response = self.view(request, pk=self.user.pk) |
| 66 | + assert response.data == self.expected |
41 | 67 |
|
42 | 68 | def test_prefetch_related_excluding_instance_from_original_queryset(self):
|
43 | 69 | """
|
44 | 70 | Regression test for https://github.com/encode/django-rest-framework/issues/4661
|
45 | 71 | """
|
46 |
| - view = UserUpdate.as_view() |
47 |
| - pk = self.user.pk |
48 |
| - groups_pk = self.groups[0].pk |
49 |
| - request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') |
50 |
| - response = view(request, pk=pk) |
51 |
| - assert User.objects.get(pk=pk).groups.count() == 1 |
52 |
| - expected = { |
53 |
| - 'id': pk, |
54 |
| - 'username': 'exclude', |
55 |
| - 'groups': [1], |
56 |
| - |
57 |
| - } |
58 |
| - assert response.data == expected |
| 72 | + request = factory.put( |
| 73 | + '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' |
| 74 | + ) |
| 75 | + response = self.view(request, pk=self.user.pk) |
| 76 | + assert User.objects.get(pk=self.user.pk).groups.count() == 2 |
| 77 | + self.expected['username'] = 'exclude' |
| 78 | + self.expected['groups'] = [self.groups[0].pk] |
| 79 | + assert response.data == self.expected |
| 80 | + |
| 81 | + def test_db_query_count(self): |
| 82 | + request = factory.put( |
| 83 | + '/', {'username': 'new'}, format='json' |
| 84 | + ) |
| 85 | + with self.assertNumQueries(7): |
| 86 | + self.view(request, pk=self.user.pk) |
| 87 | + |
| 88 | + request = factory.put( |
| 89 | + '/', {'username': 'new2'}, format='json' |
| 90 | + ) |
| 91 | + with self.assertNumQueries(16): |
| 92 | + UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) |
0 commit comments