forked from graphql-python/graphene-sqlalchemy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfields.py
157 lines (128 loc) · 5.58 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import warnings
from functools import partial
import six
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
from .batching import get_batch_resolver
from .utils import get_query
class UnsortedSQLAlchemyConnectionField(ConnectionField):
    """Relay connection field resolved from a SQLAlchemy query.

    The default query built by :meth:`get_query` applies no ordering.
    """

    @property
    def type(self):
        """Return the Connection class this field resolves to.

        A declared type that already subclasses ``Connection`` is returned
        unchanged. Otherwise the declared type must be a
        ``SQLAlchemyObjectType`` whose ``_meta.connection`` is set, and that
        connection is returned.

        Raises:
            AssertionError: if the declared type is neither a ``Connection``
                nor a ``SQLAlchemyObjectType`` with a connection.
        """
        from .types import SQLAlchemyObjectType

        # The attribute lookup starts *after* ConnectionField in the MRO.
        declared = super(ConnectionField, self).type
        if not issubclass(declared, Connection):
            assert issubclass(declared, SQLAlchemyObjectType), (
                "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
            ).format(declared.__name__)
            assert declared._meta.connection, "The type {} doesn't have a connection".format(
                declared.__name__
            )
            return declared._meta.connection
        return declared

    @property
    def model(self):
        """Return the model stored on the connection's node type metadata."""
        node_type = self.type._meta.node
        return node_type._meta.model

    @classmethod
    def get_query(cls, model, info, **args):
        """Return the base query for ``model`` built from ``info.context``.

        Extra keyword arguments are accepted and ignored.
        """
        return get_query(model, info.context)

    @classmethod
    def resolve_connection(cls, connection_type, model, info, args, resolved):
        """Wrap ``resolved`` in a relay connection of ``connection_type``.

        When ``resolved`` is ``None`` the result of :meth:`get_query` is used
        instead. The total length comes from ``Query.count()`` for SQLAlchemy
        queries and from ``len()`` for anything else. The returned connection
        also carries ``iterable`` and ``length`` attributes.
        """
        if resolved is None:
            resolved = cls.get_query(model, info, **args)

        total = resolved.count() if isinstance(resolved, Query) else len(resolved)

        # ``resolved`` is handed over as the complete result set: the slice
        # starts at 0 and both length arguments equal the total.
        connection = connection_from_list_slice(
            resolved,
            args,
            slice_start=0,
            list_length=total,
            list_slice_length=total,
            connection_type=connection_type,
            pageinfo_type=PageInfo,
            edge_type=connection_type.Edge,
        )
        connection.iterable = resolved
        connection.length = total
        return connection

    @classmethod
    def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
        """Call ``resolver`` and turn its result into a connection.

        A thenable result is chained through a ``Promise``; any other result
        is converted immediately.
        """
        outcome = resolver(root, info, **args)
        finish = partial(cls.resolve_connection, connection_type, model, info, args)

        if not is_thenable(outcome):
            return finish(outcome)
        return Promise.resolve(outcome).then(finish)

    def get_resolver(self, parent_resolver):
        """Return ``connection_resolver`` bound to the parent resolver, type and model."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.type,
            self.model,
        )
# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
    """Connection field that adds a ``sort`` argument and applies it to the query."""

    def __init__(self, type, *args, **kwargs):
        """Create the field, deriving a default ``sort`` argument when possible.

        * ``sort`` not given and ``type`` is a ``Connection`` subclass: the
          argument is taken from the node type's ``sort_argument()``.
        * ``sort=None``: the key is dropped, so no sort argument is passed on.
        * Any other ``sort`` value is forwarded untouched.

        Raises:
            TypeError: if the default sort argument cannot be created.
        """
        if "sort" in kwargs:
            if kwargs["sort"] is None:
                del kwargs["sort"]
        elif issubclass(type, Connection):
            # Let super class raise if type is not a Connection
            try:
                kwargs.setdefault("sort", type.Edge.node._type.sort_argument())
            except (AttributeError, TypeError):
                message = (
                    'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
                    " to None to disabling the creation of the sort query argument"
                ).format(type.__name__)
                raise TypeError(message)
        super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)

    @classmethod
    def get_query(cls, model, info, sort=None, **args):
        """Return the query for ``model``, ordered by ``sort`` when provided.

        ``sort`` may be a single string-like object exposing ``.value`` or an
        iterable of objects exposing ``.value``; the ``.value`` attributes are
        passed to ``order_by``.
        """
        query = get_query(model, info.context)
        if sort is None:
            return query
        if isinstance(sort, six.string_types):
            return query.order_by(sort.value)
        return query.order_by(*(item.value for item in sort))
class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
    """
    Connection field that resolves through the field's own ``resolver``.

    This is currently experimental.
    The API and behavior may change in future versions.
    Use at your own risk.
    """

    def get_resolver(self, parent_resolver):
        """Return ``connection_resolver`` bound to ``self.resolver``.

        ``parent_resolver`` is accepted for interface compatibility and is
        not used.
        """
        bound = self.connection_resolver
        return partial(bound, self.resolver, self.type, self.model)

    @classmethod
    def from_relationship(cls, relationship, registry, **field_kwargs):
        """Build a field for ``relationship`` using a batch resolver.

        The connection type is looked up in ``registry`` from the entity the
        relationship maps to.
        """
        target_model = relationship.mapper.entity
        target_type = registry.get_type_for_model(target_model)
        return cls(
            target_type._meta.connection,
            resolver=get_batch_resolver(relationship),
            **field_kwargs
        )
def default_connection_field_factory(relationship, registry, **field_kwargs):
    """Create a connection field for the model targeted by ``relationship``.

    The target model's type is looked up in ``registry`` and handed, together
    with ``field_kwargs``, to the currently registered connection factory.
    """
    related_type = registry.get_type_for_model(relationship.mapper.entity)
    return __connectionFactory(related_type, **field_kwargs)
# TODO Remove in next major version
# Module-level connection factory used by default_connection_field_factory
# and createConnectionField. registerConnectionFieldFactory replaces it and
# unregisterConnectionFieldFactory restores this default.
__connectionFactory = UnsortedSQLAlchemyConnectionField
def createConnectionField(_type, **field_kwargs):
    """Deprecated: build a connection field with the registered factory.

    Emits a ``DeprecationWarning`` and returns
    ``__connectionFactory(_type, **field_kwargs)``.
    """
    # stacklevel=2 attributes the warning to the caller rather than to this
    # module, so it points at the code that needs to be migrated.
    warnings.warn(
        'createConnectionField is deprecated and will be removed in the next '
        'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    return __connectionFactory(_type, **field_kwargs)
def registerConnectionFieldFactory(factoryMethod):
    """Deprecated: install ``factoryMethod`` as the module-level connection factory.

    Emits a ``DeprecationWarning`` and rebinds the module global
    ``__connectionFactory`` to ``factoryMethod``.
    """
    # stacklevel=2 attributes the warning to the caller rather than to this
    # module, so it points at the code that needs to be migrated.
    warnings.warn(
        'registerConnectionFieldFactory is deprecated and will be removed in the next '
        'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    global __connectionFactory
    __connectionFactory = factoryMethod
def unregisterConnectionFieldFactory():
    """Deprecated: restore the default module-level connection factory.

    Emits a ``DeprecationWarning`` and rebinds the module global
    ``__connectionFactory`` to ``UnsortedSQLAlchemyConnectionField``.
    """
    # The message names this function (it previously named
    # registerConnectionFieldFactory). stacklevel=2 attributes the warning
    # to the caller rather than to this module.
    warnings.warn(
        'unregisterConnectionFieldFactory is deprecated and will be removed in the next '
        'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    global __connectionFactory
    __connectionFactory = UnsortedSQLAlchemyConnectionField