-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
135 lines (106 loc) · 4.54 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import List
import graphene
from classproperties import classproperty
from django.db.models.fields.related import ForeignKey
from django.db.models.fields.reverse_related import OneToOneRel
from django.db.models.options import Options
from graphene.utils.str_converters import to_snake_case
from graphene_django.filter import DjangoFilterConnectionField as _DjangoFilterConnectionField
from graphql.language.ast import FragmentSpread
class DjangoFilterConnectionField(_DjangoFilterConnectionField):
    """
    Filter connection field that keeps the eager-loading setup of the
    resolver's queryset when graphene-django merges querysets.

    The filterable fields are taken from the node type's
    ``_meta.filter_fields``, so callers only pass the node.
    """

    def __init__(self, node, **kwargs):
        super().__init__(node, fields=node._meta.filter_fields, **kwargs)

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        """Merge the two querysets, then copy ``queryset``'s eager loading onto the result.

        If both querysets agree on ``query.distinct``, the base
        ``_DjangoFilterConnectionField.merge_querysets`` produces the result.
        Otherwise no combination happens and ``default_queryset`` itself is
        used (and mutated below). Presumably this avoids Django's refusal to
        combine a distinct query with a non-distinct one.

        In both cases the result's ``query.select_related`` and
        ``_prefetch_related_lookups`` are replaced with ``queryset``'s values.
        """
        distinct_matches = default_queryset.query.distinct == queryset.query.distinct
        merged = (
            _DjangoFilterConnectionField.merge_querysets(default_queryset, queryset)
            if distinct_matches
            else default_queryset
        )
        # Carry over joins and prefetches that the merge would otherwise drop.
        merged.query.select_related = queryset.query.select_related
        merged._prefetch_related_lookups = queryset._prefetch_related_lookups  # pylint: disable=protected-access
        return merged
class DjangoFilterConnectionSearchableField(DjangoFilterConnectionField):
    """
    Filter connection field that hands ``search_result_ids`` across the merge.

    Before delegating to ``DjangoFilterConnectionField.merge_querysets``, the
    ``search_result_ids`` attribute of ``default_queryset`` is copied onto
    ``queryset``.
    """

    @classmethod
    def merge_querysets(cls, default_queryset, queryset):
        # Attribute set outside Django; assumed present on default_queryset — TODO confirm callers.
        search_ids = default_queryset.search_result_ids
        queryset.search_result_ids = search_ids
        parent = super(DjangoFilterConnectionSearchableField, cls)
        return parent.merge_querysets(default_queryset, queryset)
class DjangoObjectTypeMixin:
    """
    Cast select_related to queryset for ForeignKeys of model.

    Mixin for graphene-django object types. ``get_queryset`` looks at the
    fields requested in the current GraphQL query and applies:

    * ``select_related`` for forward ``ForeignKey`` fields and reverse
      one-to-one relations of ``cls._meta.model``;
    * ``prefetch_related`` for its many-to-many fields.
    """

    @classproperty
    def _meta(self) -> Options:
        # Declared so the methods below can refer to ``cls._meta``. Presumably
        # shadowed by the ``_meta`` that graphene sets on the concrete object
        # type class — TODO confirm for every type using this mixin.
        raise NotImplementedError()

    @classmethod
    def get_node(cls, info, pk):
        """Return the instance with primary key ``pk`` from ``get_queryset``, or ``None`` if absent."""
        try:
            return cls.get_queryset(info).get(pk=pk)
        except cls._meta.model.DoesNotExist:
            return None

    @classmethod
    def get_queryset(cls, info, *_, **__):
        """Return ``model.objects.all()`` with eager loading for the requested fields.

        Each requested GraphQL field name is converted to snake_case.
        Names matching a forward FK or a reverse one-to-one accessor are
        ``select_related``; names matching a many-to-many field are
        ``prefetch_related``.
        """
        queryset = cls._meta.model.objects.all()
        # Sets: membership is tested once per requested field.
        joinable = set(cls.select_foreign_keys() + cls.select_o2o_related_objects())
        prefetchable = set(cls.select_m2m_fields())
        selections = cls.get_selections(info)
        for field_name in cls.convert_selections_to_fields(selections, info):
            field_name = to_snake_case(field_name)
            if field_name in joinable:
                queryset = queryset.select_related(field_name)
            if field_name in prefetchable:
                queryset = queryset.prefetch_related(field_name)
        return queryset

    @classmethod
    def convert_selections_to_fields(cls, selections, info):
        """Return the names of ``selections``, expanding named fragment spreads.

        A ``FragmentSpread`` is replaced, recursively, by the selections of the
        fragment it names in ``info.fragments``. Every other selection
        contributes ``selection.name.value``.
        """
        fields = []
        for selection in selections:
            if isinstance(selection, FragmentSpread):
                fragment = info.fragments[selection.name.value]
                fields += cls.convert_selections_to_fields(fragment.selection_set.selections, info)
            else:
                fields.append(selection.name.value)
        return fields

    @classmethod
    def get_selections(cls, info):
        """Walk the query AST down to the selections made on this model's node.

        Starting from ``info.field_asts[0]``, repeatedly descend into the first
        selection at the current level that has a sub-selection. Stop after
        descending into a field named after the model (``model_name``) or
        ``'node'``, unless the level reached starts with ``'edges'`` (a
        connection wrapper), in which case keep descending.

        If no selection at the current level has a sub-selection, that level
        is returned as-is. Previously this case raised ``IndexError``.
        """
        # NOTE(review): only the first sub-selection per level is followed, so e.g.
        # ``pageInfo {...}`` listed before ``edges`` ends the walk at pageInfo's fields.
        target_names = (cls._meta.model._meta.model_name, 'node')
        selections = [info.field_asts[0]]
        found = False
        i = 0
        while i < len(selections):
            selection = selections[i]
            # getattr: nodes such as fragment spreads may not define selection_set.
            sub_selection = getattr(selection, 'selection_set', None)
            if sub_selection is None:
                i += 1
                continue
            if selection.name.value in target_names:
                found = True
            selections = sub_selection.selections
            i = 0
            if found and selections[0].name.value == 'edges':
                found = False
            if found:
                break
        return selections

    @classmethod
    def select_foreign_keys(cls) -> List[str]:
        """Return the names of the model's concrete ``ForeignKey`` fields."""
        return [field.name for field in cls._meta.model._meta.fields if isinstance(field, ForeignKey)]

    @classmethod
    def select_m2m_fields(cls) -> List[str]:
        """Return the names of the model's forward many-to-many fields."""
        return [field.name for field in cls._meta.model._meta.many_to_many]

    @classmethod
    def select_o2o_related_objects(cls) -> List[str]:
        """Return accessor names of reverse one-to-one relations pointing at the model.

        ``get_accessor_name()`` is used instead of ``related_name``:
        ``related_name`` is ``None`` unless the ``OneToOneField`` declares it.
        Without it, the reverse accessor defaults to the related model's
        ``model_name``.
        """
        return [
            rel.get_accessor_name()
            for rel in cls._meta.model._meta.related_objects
            if isinstance(rel, OneToOneRel)
        ]
class CountableConnectionBase(graphene.relay.Connection):
    """Relay connection base class that adds a ``totalCount`` field.

    ``Meta.abstract = True`` marks this as a base for concrete connections
    rather than a connection type of its own.
    """

    class Meta:
        abstract = True

    total_count = graphene.Int()

    def resolve_total_count(self, _):
        """Return the number of items in the connection's full ``iterable``.

        ``iterable`` is presumably set by graphene-django's connection
        resolver. It may be a Django QuerySet/Manager, counted with
        ``count()``, or a plain list/tuple returned by a resolver.
        """
        iterable = self.iterable
        # list.count()/tuple.count() require a value argument (they count
        # occurrences), so plain sequences are measured with len() instead.
        if isinstance(iterable, (list, tuple)):
            return len(iterable)
        return iterable.count()