-
Notifications
You must be signed in to change notification settings - Fork 764
/
Copy pathutils.py
162 lines (117 loc) · 4.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import inspect

import django
from django.db import connection, models, transaction
from django.db.models.manager import Manager
from django.utils.encoding import force_str
from django.utils.functional import Promise
from graphene.utils.str_converters import to_camel_case
# Feature flag: True when the optional ``django_filters`` package can be imported.
try:
    import django_filters  # noqa
except ImportError:
    DJANGO_FILTER_INSTALLED = False
else:
    DJANGO_FILTER_INSTALLED = True
def isiterable(value):
    """
    Report whether ``iter(value)`` succeeds.

    Returns True when ``value`` can be iterated, False when ``iter`` raises
    ``TypeError``. Strings count as iterable.
    """
    try:
        iter(value)
    except TypeError:
        return False
    else:
        return True
def _camelize_django_str(s):
    """
    Camel-case ``s`` if it is (or lazily evaluates to) a string.

    Lazy ``Promise`` objects are forced to ``str`` first; any other non-string
    value is returned unchanged.
    """
    text = force_str(s) if isinstance(s, Promise) else s
    if not isinstance(text, str):
        return text
    return to_camel_case(text)
def camelize(data):
    """
    Recursively convert dictionary keys to camelCase.

    - dict: keys go through ``_camelize_django_str``, values are camelized.
    - other iterables (except str and lazy ``Promise`` strings): returned as a
      list of camelized items.
    - anything else: returned unchanged.
    """
    if isinstance(data, dict):
        return {
            _camelize_django_str(key): camelize(value)
            for key, value in data.items()
        }
    # Strings are iterable but must be treated as scalars here.
    should_recurse = isiterable(data) and not isinstance(data, (str, Promise))
    if should_recurse:
        return list(map(camelize, data))
    return data
def _get_model_ancestry(model):
    """
    Return ``model`` followed by its direct base classes that are Django models
    with a truthy ``_meta`` attribute.

    Only the immediate ``__bases__`` are inspected, not the full MRO.
    """
    model_bases = [
        base
        for base in model.__bases__
        if is_valid_django_model(base) and getattr(base, "_meta", False)
    ]
    return [model, *model_bases]
def get_reverse_fields(model, local_field_names):
    """
    Search the model's ancestry for reverse relationships and yield them.

    Class attributes exposing a ``rel`` or ``related`` relation object are
    considered. ``ManyToOneRel`` relations are yielded, as are non-symmetrical
    ``ManyToManyRel`` relations. Names already in ``local_field_names`` are
    skipped. Yields ``(name, relation)`` tuples.
    """
    for ancestor in _get_model_ancestry(model):
        for attr_name, attr in ancestor.__dict__.items():
            # Local fields take priority; never emit a duplicate name.
            if attr_name in local_field_names:
                continue
            # "rel" for FK and M2M relations and "related" for O2O relations.
            relation = getattr(attr, "rel", None) or getattr(attr, "related", None)
            is_many_to_one = isinstance(relation, models.ManyToOneRel)
            if is_many_to_one or (
                isinstance(relation, models.ManyToManyRel)
                and not relation.symmetrical
            ):
                yield (attr_name, relation)
def get_local_fields(model):
    """
    Collect the fields of the model and its direct Django-model bases.

    For each class, ``_meta.fields`` plus ``_meta.local_many_to_many`` are
    sorted and merged. On a name collision the first field seen (the model's
    own, before its bases') wins.
    Returns a list of ``(field.name, field)`` tuples.
    """
    fields_by_name = {}
    for ancestor in _get_model_ancestry(model):
        meta = ancestor._meta
        candidates = [*meta.fields, *meta.local_many_to_many]
        for field in sorted(candidates):
            fields_by_name.setdefault(field.name, field)
    return list(fields_by_name.items())
def maybe_queryset(value):
    """Return ``value.get_queryset()`` for a ``Manager``; pass anything else through."""
    return value.get_queryset() if isinstance(value, Manager) else value
def get_model_fields(model):
    """
    Gather all fields and relationships on the Django model and its ancestry.

    Local fields come first. Reverse relationships whose name clashes with a
    local field are dropped, so local definitions take priority.
    Returns a list of ``(field.name, field)`` tuples.
    """
    local_fields = get_local_fields(model)
    taken_names = {name for name, _field in local_fields}
    return local_fields + list(get_reverse_fields(model, taken_names))
def is_valid_django_model(model):
    """Return True only when ``model`` is a class deriving from ``models.Model``."""
    if not inspect.isclass(model):
        return False
    return issubclass(model, models.Model)
def import_single_dispatch():
    """
    Return a ``singledispatch`` implementation.

    ``functools.singledispatch`` is preferred. The ``singledispatch`` backport
    package is the fallback. Raises ``Exception`` if neither can be imported.
    """
    try:
        from functools import singledispatch
    except ImportError:
        pass
    else:
        return singledispatch

    try:
        from singledispatch import singledispatch
    except ImportError:
        pass
    else:
        return singledispatch

    raise Exception(
        "It seems your python version does not include "
        "functools.singledispatch. Please install the 'singledispatch' "
        "package. More information here: "
        "https://pypi.python.org/pypi/singledispatch"
    )
def set_rollback():
    """
    Mark the current transaction for rollback.

    Does nothing unless the default connection has ``ATOMIC_REQUESTS`` enabled
    and is currently inside an atomic block.
    """
    settings = connection.settings_dict
    if not settings.get("ATOMIC_REQUESTS", False):
        return
    if connection.in_atomic_block:
        transaction.set_rollback(True)
def bypass_get_queryset(resolver):
    """
    Decorator: flag ``resolver`` so the DjangoObjectType's custom
    ``get_queryset`` is skipped.

    Sets ``resolver._bypass_get_queryset = True`` and returns the same object.
    """
    setattr(resolver, "_bypass_get_queryset", True)
    return resolver
def __django_version():
    """
    Return the installed Django version as parsed by ``pkg_resources``.

    Kept for backward compatibility only. It imports the deprecated
    ``pkg_resources`` API lazily, which requires setuptools to be installed.
    Prefer ``django.VERSION``.
    """
    from pkg_resources import get_distribution

    return get_distribution("django").parsed_version


def __parse_version(v):
    """
    Parse the version string ``v`` with ``pkg_resources.parse_version``.

    Kept for backward compatibility only; requires setuptools (deprecated API).
    """
    from pkg_resources import parse_version

    return parse_version(v)


# Compare against the ``django.VERSION`` tuple instead of calling the helpers
# above. Importing this module then no longer pulls in setuptools'
# deprecated ``pkg_resources``, which is absent from many modern environments.
_DJANGO_VERSION_AT_LEAST_4_2 = django.VERSION[:2] >= (4, 2)