Skip to content

Commit 8cb52a1

Browse files
authored
Merge pull request #120 from nikordaris/connection-2.x
SQLAlchemyConnectionField Graphene 2.0 + Promise Support
2 parents a2fe926 + 65e1373 commit 8cb52a1

File tree

5 files changed

+50
-28
lines changed

5 files changed

+50
-28
lines changed

graphene_sqlalchemy/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def dynamic_type():
3636
return Field(_type)
3737
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
3838
if _type._meta.connection:
39-
return createConnectionField(_type)
39+
return createConnectionField(_type._meta.connection)
4040
return Field(List(_type))
4141

4242
return Dynamic(dynamic_type)

graphene_sqlalchemy/fields.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
2+
from promise import is_thenable, Promise
33
from sqlalchemy.orm.query import Query
44

55
from graphene.relay import ConnectionField
@@ -19,39 +19,38 @@ def model(self):
1919
def get_query(cls, model, info, **args):
2020
return get_query(model, info.context)
2121

22-
@property
23-
def type(self):
24-
from .types import SQLAlchemyObjectType
25-
_type = super(ConnectionField, self).type
26-
assert issubclass(_type, SQLAlchemyObjectType), (
27-
"SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
28-
)
29-
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
30-
return _type._meta.connection
31-
3222
@classmethod
33-
def connection_resolver(cls, resolver, connection, model, root, info, **args):
34-
iterable = resolver(root, info, **args)
35-
if iterable is None:
36-
iterable = cls.get_query(model, info, **args)
37-
if isinstance(iterable, Query):
38-
_len = iterable.count()
23+
def resolve_connection(cls, connection_type, model, info, args, resolved):
24+
if resolved is None:
25+
resolved = cls.get_query(model, info, **args)
26+
if isinstance(resolved, Query):
27+
_len = resolved.count()
3928
else:
40-
_len = len(iterable)
29+
_len = len(resolved)
4130
connection = connection_from_list_slice(
42-
iterable,
31+
resolved,
4332
args,
4433
slice_start=0,
4534
list_length=_len,
4635
list_slice_length=_len,
47-
connection_type=connection,
36+
connection_type=connection_type,
4837
pageinfo_type=PageInfo,
49-
edge_type=connection.Edge,
38+
edge_type=connection_type.Edge,
5039
)
51-
connection.iterable = iterable
40+
connection.iterable = resolved
5241
connection.length = _len
5342
return connection
5443

44+
@classmethod
45+
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
46+
resolved = resolver(root, info, **args)
47+
48+
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
49+
if is_thenable(resolved):
50+
return Promise.resolve(resolved).then(on_resolve)
51+
52+
return on_resolve(resolved)
53+
5554
def get_resolver(self, parent_resolver):
5655
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
5756

graphene_sqlalchemy/tests/test_connectionfactory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def LXResolver(root, args, context, info):
2222
return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info)
2323

2424
def createLXConnectionField(table):
25-
return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
25+
class LXConnection(graphene.relay.Connection):
26+
class Meta:
27+
node = table
28+
return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
2629

2730
registerConnectionFieldFactory(createLXConnectionField)
2831
unregisterConnectionFieldFactory()

graphene_sqlalchemy/tests/test_query.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class Meta:
139139
interfaces = (Node, )
140140

141141
@classmethod
142-
def get_node(cls, id, info):
142+
def get_node(cls, info, id):
143143
return Reporter(id=2, first_name='Cookie Monster')
144144

145145
class ArticleNode(SQLAlchemyObjectType):
@@ -152,11 +152,15 @@ class Meta:
152152
# def get_node(cls, id, info):
153153
# return Article(id=1, headline='Article node')
154154

155+
class ArticleConnection(graphene.relay.Connection):
156+
class Meta:
157+
node = ArticleNode
158+
155159
class Query(graphene.ObjectType):
156160
node = Node.Field()
157161
reporter = graphene.Field(ReporterNode)
158162
article = graphene.Field(ArticleNode)
159-
all_articles = SQLAlchemyConnectionField(ArticleNode)
163+
all_articles = SQLAlchemyConnectionField(ArticleConnection)
160164

161165
def resolve_reporter(self, *args, **kwargs):
162166
return session.query(Reporter).first()
@@ -238,9 +242,13 @@ class Meta:
238242
model = Editor
239243
interfaces = (Node, )
240244

245+
class EditorConnection(graphene.relay.Connection):
246+
class Meta:
247+
node = EditorNode
248+
241249
class Query(graphene.ObjectType):
242250
node = Node.Field()
243-
all_editors = SQLAlchemyConnectionField(EditorNode)
251+
all_editors = SQLAlchemyConnectionField(EditorConnection)
244252

245253
query = '''
246254
query EditorQuery {

graphene_sqlalchemy/tests/test_types.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from collections import OrderedDict
22
from graphene import Field, Int, Interface, ObjectType
3-
from graphene.relay import Node, is_node
3+
from graphene.relay import Node, is_node, Connection
44
import six
5+
from promise import Promise
56

67
from ..registry import Registry
78
from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
89
from .models import Article, Reporter
10+
from ..fields import SQLAlchemyConnectionField
911

1012
registry = Registry()
1113

@@ -158,3 +160,13 @@ def test_objecttype_with_custom_options():
158160
'favorite_article']
159161
assert ReporterWithCustomOptions._meta.custom_option == 'custom_option'
160162
assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int)
163+
164+
165+
def test_promise_connection_resolver():
166+
class TestConnection(Connection):
167+
class Meta:
168+
node = ReporterWithCustomOptions
169+
170+
resolver = lambda *args, **kwargs: Promise.resolve([])
171+
result = SQLAlchemyConnectionField.connection_resolver(resolver, TestConnection, ReporterWithCustomOptions, None, None)
172+
assert result is not None

0 commit comments

Comments
 (0)