forked from graphql-python/graphene-sqlalchemy
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfields.py
112 lines (89 loc) · 3.82 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from functools import partial
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .utils import get_query, get_sort_argument_for_model
class UnsortedSQLAlchemyConnectionField(ConnectionField):
    """Relay connection field backed by a SQLAlchemy model, without a sort argument.

    The field resolves to a Relay ``Connection`` built from the parent
    resolver's result. When that result is ``None``, it is built from a query
    over the field's model instead.
    """

    @property
    def type(self):
        """Return the Connection type this field resolves to.

        A ``Connection`` subclass is returned as-is. A ``SQLAlchemyObjectType``
        is mapped to its ``_meta.connection``.

        Raises:
            AssertionError: if the wrapped type is not a SQLAlchemyObjectType,
                or if it has no connection.
        """
        # Local import, presumably to avoid a circular import with .types.
        from .types import SQLAlchemyObjectType

        # Deliberately resolve the property one level past ConnectionField in
        # the MRO. This presumably bypasses ConnectionField's own type check so
        # that a plain SQLAlchemyObjectType can be accepted and mapped below.
        _type = super(ConnectionField, self).type
        if issubclass(_type, Connection):
            return _type
        assert issubclass(_type, SQLAlchemyObjectType), (
            "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
        ).format(_type.__name__)
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        return _type._meta.connection

    @property
    def model(self):
        """Return the SQLAlchemy model of the connection's node type."""
        return self.type._meta.node._meta.model

    @classmethod
    def get_query(cls, model, info, sort=None, **args):
        """Build the base query for ``model``, optionally ordered by ``sort``.

        Args:
            model: SQLAlchemy model to query.
            info: GraphQL resolve info; its ``context`` is passed to the
                module-level ``get_query`` helper.
            sort: ``None`` (no ordering), a single sort value, or an iterable
                of sort values. A value carrying a ``.value`` attribute (as the
                generated sort enum values do) is unwrapped to that attribute.
                Any other value is passed to ``order_by`` unchanged.
            **args: other field arguments; accepted and ignored here.

        Returns:
            The (possibly ordered) query.
        """
        query = get_query(model, info.context)
        if sort is not None:
            if isinstance(sort, str):
                # A string is one sort key, not an iterable of characters.
                sort = [sort]
            # getattr fallback: plain strings / clauses have no ``.value`` and
            # previously raised AttributeError here.
            query = query.order_by(*(getattr(col, "value", col) for col in sort))
        return query

    @classmethod
    def resolve_connection(cls, connection_type, model, info, args, resolved):
        """Wrap ``resolved`` (or a fresh model query) in ``connection_type``.

        If the parent resolver returned ``None``, a query is built with
        :meth:`get_query` from the field arguments. The total length comes from
        ``Query.count()`` for SQLAlchemy queries and ``len()`` otherwise.
        Slicing by the field arguments is delegated to
        ``connection_from_list_slice``. The unsliced iterable and its length
        are also stored on the returned connection as ``iterable`` and
        ``length``.
        """
        if resolved is None:
            resolved = cls.get_query(model, info, **args)
        if isinstance(resolved, Query):
            _len = resolved.count()
        else:
            _len = len(resolved)
        connection = connection_from_list_slice(
            resolved,
            args,
            slice_start=0,
            list_length=_len,
            list_slice_length=_len,
            connection_type=connection_type,
            pageinfo_type=PageInfo,
            edge_type=connection_type.Edge,
        )
        connection.iterable = resolved
        connection.length = _len
        return connection

    @classmethod
    def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
        """Call the parent ``resolver`` and build a connection from its result.

        If the result is thenable, the connection is built once it resolves,
        and a Promise is returned. Otherwise the connection is built and
        returned immediately.
        """
        resolved = resolver(root, info, **args)
        on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
        if is_thenable(resolved):
            return Promise.resolve(resolved).then(on_resolve)
        return on_resolve(resolved)

    def get_resolver(self, parent_resolver):
        """Return a resolver bound to this field's connection type and model."""
        return partial(self.connection_resolver, parent_resolver, self.type, self.model)
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
    """Connection field that adds a ``sort`` argument generated from the model.

    When no ``sort`` keyword is given and ``type`` is a ``Connection`` class,
    a sort argument is generated from the connection's node model. Pass
    ``sort=None`` to suppress the sort argument entirely.
    """

    def __init__(self, type, *args, **kwargs):
        # The ``type`` parameter shadows the builtin, so inspect.isclass is
        # used for the "is this a class at all" check.
        from inspect import isclass

        # ``type`` may be something other than a class (e.g. a lazy reference).
        # issubclass() raises TypeError on non-classes, so only real Connection
        # classes get a generated sort argument. Every other type is passed
        # through unchanged, and the parent's ``type`` property validates it.
        if "sort" not in kwargs and isclass(type) and issubclass(type, Connection):
            try:
                model = type.Edge.node._type._meta.model
                sort_argument = get_sort_argument_for_model(model)
            except Exception:
                raise Exception(
                    'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
                    " to None to disabling the creation of the sort query argument".format(
                        type.__name__
                    )
                )
            kwargs.setdefault("sort", sort_argument)
        elif "sort" in kwargs and kwargs["sort"] is None:
            # Explicit sort=None means "no sort argument"; don't forward it.
            del kwargs["sort"]
        super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
# Module-level factory used by createConnectionField(). It is replaced by
# registerConnectionFieldFactory() and restored to this default by
# unregisterConnectionFieldFactory().
__connectionFactory = UnsortedSQLAlchemyConnectionField
def createConnectionField(_type, **field_kwargs):
    """Create a connection field for ``_type`` using the current factory.

    Args:
        _type: the object type the connection field is for.
        **field_kwargs: optional keyword arguments forwarded unchanged to the
            factory. When none are given, the factory is called with ``_type``
            alone, exactly as before, so single-argument factories keep working.

    Returns:
        Whatever the currently registered factory returns.
    """
    return __connectionFactory(_type, **field_kwargs)


def registerConnectionFieldFactory(factoryMethod):
    """Replace the factory used by ``createConnectionField``.

    Args:
        factoryMethod: callable invoked as ``factoryMethod(_type, **field_kwargs)``
            by ``createConnectionField``.
    """
    global __connectionFactory
    __connectionFactory = factoryMethod


def unregisterConnectionFieldFactory():
    """Restore the default factory, ``UnsortedSQLAlchemyConnectionField``."""
    global __connectionFactory
    __connectionFactory = UnsortedSQLAlchemyConnectionField