Skip to content

Commit a9b7666

Browse files
committed
Merge remote-tracking branch 'github/master' into sortable_field
2 parents db0e3db + 8cb52a1 commit a9b7666

File tree

7 files changed

+107
-34
lines changed

7 files changed

+107
-34
lines changed

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Docs template
2-
https://github.com/graphql-python/graphene-python.org/archive/docs.zip
2+
http://graphene-python.org/sphinx_graphene_theme.zip

graphene_sqlalchemy/converter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def dynamic_type():
3636
return Field(_type)
3737
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
3838
if _type._meta.connection:
39-
return createConnectionField(_type)
39+
return createConnectionField(_type._meta.connection)
4040
return Field(List(_type))
4141

4242
return Dynamic(dynamic_type)
@@ -92,6 +92,8 @@ def convert_sqlalchemy_type(type, column, registry=None):
9292
@convert_sqlalchemy_type.register(types.Unicode)
9393
@convert_sqlalchemy_type.register(types.UnicodeText)
9494
@convert_sqlalchemy_type.register(postgresql.UUID)
95+
@convert_sqlalchemy_type.register(postgresql.INET)
96+
@convert_sqlalchemy_type.register(postgresql.CIDR)
9597
@convert_sqlalchemy_type.register(TSVectorType)
9698
def convert_column_to_string(type, column, registry=None):
9799
return String(description=get_column_doc(column),

graphene_sqlalchemy/fields.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
2+
from promise import is_thenable, Promise
33
from sqlalchemy.orm.query import Query
44

55
from graphene.relay import ConnectionField
@@ -25,39 +25,38 @@ def get_query(cls, model, info, sort=None, **args):
2525
query = query.order_by(*(col.value for col in sort))
2626
return query
2727

28-
@property
29-
def type(self):
30-
from .types import SQLAlchemyObjectType
31-
_type = super(ConnectionField, self).type
32-
assert issubclass(_type, SQLAlchemyObjectType), (
33-
"SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
34-
)
35-
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
36-
return _type._meta.connection
37-
3828
@classmethod
39-
def connection_resolver(cls, resolver, connection, model, root, info, **args):
40-
iterable = resolver(root, info, **args)
41-
if iterable is None:
42-
iterable = cls.get_query(model, info, **args)
43-
if isinstance(iterable, Query):
44-
_len = iterable.count()
29+
def resolve_connection(cls, connection_type, model, info, args, resolved):
30+
if resolved is None:
31+
resolved = cls.get_query(model, info, **args)
32+
if isinstance(resolved, Query):
33+
_len = resolved.count()
4534
else:
46-
_len = len(iterable)
35+
_len = len(resolved)
4736
connection = connection_from_list_slice(
48-
iterable,
37+
resolved,
4938
args,
5039
slice_start=0,
5140
list_length=_len,
5241
list_slice_length=_len,
53-
connection_type=connection,
42+
connection_type=connection_type,
5443
pageinfo_type=PageInfo,
55-
edge_type=connection.Edge,
44+
edge_type=connection_type.Edge,
5645
)
57-
connection.iterable = iterable
46+
connection.iterable = resolved
5847
connection.length = _len
5948
return connection
6049

50+
@classmethod
51+
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
52+
resolved = resolver(root, info, **args)
53+
54+
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
55+
if is_thenable(resolved):
56+
return Promise.resolve(resolved).then(on_resolve)
57+
58+
return on_resolve(resolved)
59+
6160
def get_resolver(self, parent_resolver):
6261
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
6362

graphene_sqlalchemy/tests/test_connectionfactory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def LXResolver(root, args, context, info):
2222
return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info)
2323

2424
def createLXConnectionField(table):
25-
return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
25+
class LXConnection(graphene.relay.Connection):
26+
class Meta:
27+
node = table
28+
return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
2629

2730
registerConnectionFieldFactory(createLXConnectionField)
2831
unregisterConnectionFieldFactory()

graphene_sqlalchemy/tests/test_query.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class Meta:
140140
interfaces = (Node, )
141141

142142
@classmethod
143-
def get_node(cls, id, info):
143+
def get_node(cls, info, id):
144144
return Reporter(id=2, first_name='Cookie Monster')
145145

146146
class ArticleNode(SQLAlchemyObjectType):
@@ -153,11 +153,15 @@ class Meta:
153153
# def get_node(cls, id, info):
154154
# return Article(id=1, headline='Article node')
155155

156+
class ArticleConnection(graphene.relay.Connection):
157+
class Meta:
158+
node = ArticleNode
159+
156160
class Query(graphene.ObjectType):
157161
node = Node.Field()
158162
reporter = graphene.Field(ReporterNode)
159163
article = graphene.Field(ArticleNode)
160-
all_articles = SQLAlchemyConnectionField(ArticleNode)
164+
all_articles = SQLAlchemyConnectionField(ArticleConnection)
161165

162166
def resolve_reporter(self, *args, **kwargs):
163167
return session.query(Reporter).first()
@@ -239,9 +243,13 @@ class Meta:
239243
model = Editor
240244
interfaces = (Node, )
241245

246+
class EditorConnection(graphene.relay.Connection):
247+
class Meta:
248+
node = EditorNode
249+
242250
class Query(graphene.ObjectType):
243251
node = Node.Field()
244-
all_editors = SQLAlchemyConnectionField(EditorNode)
252+
all_editors = SQLAlchemyConnectionField(EditorConnection)
245253

246254
query = '''
247255
query EditorQuery {

graphene_sqlalchemy/tests/test_types.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
1+
from collections import OrderedDict
22
from graphene import Field, Int, Interface, ObjectType
3-
from graphene.relay import Node, is_node
3+
from graphene.relay import Node, is_node, Connection
44
import six
5+
from promise import Promise
56

67
from ..registry import Registry
7-
from ..types import SQLAlchemyObjectType
8+
from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
89
from .models import Article, Reporter
10+
from ..fields import SQLAlchemyConnectionField
911

1012
registry = Registry()
1113

@@ -116,3 +118,55 @@ def test_custom_objecttype_registered():
116118
'pets',
117119
'articles',
118120
'favorite_article']
121+
122+
123+
# Test Custom SQLAlchemyObjectType with Custom Options
124+
class CustomOptions(SQLAlchemyObjectTypeOptions):
125+
custom_option = None
126+
custom_fields = None
127+
128+
129+
class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType):
130+
class Meta:
131+
abstract = True
132+
133+
@classmethod
134+
def __init_subclass_with_meta__(cls, custom_option=None, custom_fields=None, **options):
135+
_meta = CustomOptions(cls)
136+
_meta.custom_option = custom_option
137+
_meta.fields = custom_fields
138+
super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__(_meta=_meta, **options)
139+
140+
141+
class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions):
142+
class Meta:
143+
model = Reporter
144+
custom_option = 'custom_option'
145+
custom_fields = OrderedDict([('custom_field', Field(Int()))])
146+
147+
148+
def test_objecttype_with_custom_options():
149+
assert issubclass(ReporterWithCustomOptions, ObjectType)
150+
assert ReporterWithCustomOptions._meta.model == Reporter
151+
assert list(
152+
ReporterWithCustomOptions._meta.fields.keys()) == [
153+
'custom_field',
154+
'id',
155+
'first_name',
156+
'last_name',
157+
'email',
158+
'pets',
159+
'articles',
160+
'favorite_article']
161+
assert ReporterWithCustomOptions._meta.custom_option == 'custom_option'
162+
assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int)
163+
164+
165+
def test_promise_connection_resolver():
166+
class TestConnection(Connection):
167+
class Meta:
168+
node = ReporterWithCustomOptions
169+
170+
resolver = lambda *args, **kwargs: Promise.resolve([])
171+
result = SQLAlchemyConnectionField.connection_resolver(resolver, TestConnection, ReporterWithCustomOptions, None, None)
172+
assert result is not None

graphene_sqlalchemy/types.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class SQLAlchemyObjectType(ObjectType):
9090
@classmethod
9191
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
9292
only_fields=(), exclude_fields=(), connection=None,
93-
use_connection=None, interfaces=(), id=None, **options):
93+
use_connection=None, interfaces=(), id=None, _meta=None, **options):
9494
assert is_mapped_class(model), (
9595
'You need to pass a valid SQLAlchemy Model in '
9696
'{}.Meta, received "{}".'
@@ -121,10 +121,17 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa
121121
"The connection must be a Connection. Received {}"
122122
).format(connection.__name__)
123123

124-
_meta = SQLAlchemyObjectTypeOptions(cls)
124+
if not _meta:
125+
_meta = SQLAlchemyObjectTypeOptions(cls)
126+
125127
_meta.model = model
126128
_meta.registry = registry
127-
_meta.fields = sqla_fields
129+
130+
if _meta.fields:
131+
_meta.fields.update(sqla_fields)
132+
else:
133+
_meta.fields = sqla_fields
134+
128135
_meta.connection = connection
129136
_meta.id = id or 'id'
130137

0 commit comments

Comments
 (0)