Skip to content

Commit 99ec6b4

Browse files
committed
Added support for total count on relay connections
Closes #58
1 parent 4827ce2 commit 99ec6b4

File tree

4 files changed

+68
-5
lines changed

4 files changed

+68
-5
lines changed

examples/flask_sqlalchemy/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class Role(SQLAlchemyObjectType):
2525
class Meta:
2626
model = RoleModel
2727
interfaces = (relay.Node, )
28+
# Disable the total count on this connection
29+
total_count = False
2830

2931

3032
class Query(graphene.ObjectType):

graphene_sqlalchemy/tests/test_query.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,33 @@ class Mutation(graphene.ObjectType):
365365
result = schema.execute(query, context_value={'session': session})
366366
assert not result.errors
367367
assert result.data == expected
368+
369+
370+
def test_should_return_total_count(session):
371+
setup_fixtures(session)
372+
373+
class ReporterNode(SQLAlchemyObjectType):
374+
375+
class Meta:
376+
model = Reporter
377+
interfaces = (Node, )
378+
379+
class Query(graphene.ObjectType):
380+
all_article = SQLAlchemyConnectionField(ReporterNode)
381+
382+
query = '''
383+
{
384+
allArticle {
385+
totalCount
386+
}
387+
}
388+
'''
389+
expected = {
390+
'allArticle': {
391+
'totalCount': session.query(Reporter).count()
392+
},
393+
}
394+
schema = graphene.Schema(query=Query)
395+
result = schema.execute(query, context_value={'session': session})
396+
assert not result.errors
397+
assert result.data == expected

graphene_sqlalchemy/tests/test_types.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11

22
from graphene import Field, Int, Interface, ObjectType
3-
from graphene.relay import Node, is_node
3+
from graphene.relay import Node, is_node, Connection
44
import six
55

66
from ..registry import Registry
7-
from ..types import SQLAlchemyObjectType
7+
from ..types import SQLAlchemyObjectType, ConnectionWithCount
88
from .models import Article, Reporter
99

1010
registry = Registry()
@@ -116,3 +116,22 @@ def test_custom_objecttype_registered():
116116
'pets',
117117
'articles',
118118
'favorite_article']
119+
120+
def test_total_count():
121+
class TotalCount(SQLAlchemyObjectType):
122+
class Meta:
123+
model = Article
124+
interfaces = (Node, )
125+
registry = registry
126+
127+
class NoTotalCount(SQLAlchemyObjectType):
128+
class Meta:
129+
model = Reporter
130+
interfaces = (Node, )
131+
registry = registry
132+
total_count = False
133+
134+
assert issubclass(TotalCount._meta.connection, ConnectionWithCount)
135+
assert not issubclass(NoTotalCount._meta.connection, ConnectionWithCount)
136+
assert issubclass(TotalCount._meta.connection, Connection)
137+
assert issubclass(NoTotalCount._meta.connection, Connection)

graphene_sqlalchemy/types.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sqlalchemy.ext.hybrid import hybrid_property
55
from sqlalchemy.orm.exc import NoResultFound
66

7-
from graphene import Field # , annotate, ResolveInfo
7+
from graphene import Field, Int, NonNull
88
from graphene.relay import Connection, Node
99
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
1010
from graphene.types.utils import yank_fields_from_attrs
@@ -86,11 +86,22 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
8686
id = None # type: str
8787

8888

89+
class ConnectionWithCount(Connection):
90+
'''Class that adds `totalCount` to a connection field'''
91+
class Meta:
92+
abstract = True
93+
94+
total_count = NonNull(Int)
95+
96+
def resolve_total_count(self, info, **kwargs):
97+
return self.length
98+
99+
89100
class SQLAlchemyObjectType(ObjectType):
90101
@classmethod
91102
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
92103
only_fields=(), exclude_fields=(), connection=None,
93-
use_connection=None, interfaces=(), id=None, **options):
104+
use_connection=None, interfaces=(), id=None, total_count=True, **options):
94105
assert is_mapped_class(model), (
95106
'You need to pass a valid SQLAlchemy Model in '
96107
'{}.Meta, received "{}".'
@@ -114,7 +125,8 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa
114125

115126
if use_connection and not connection:
116127
# We create the connection automatically
117-
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
128+
connection_class = ConnectionWithCount if total_count else Connection
129+
connection = connection_class.create_type('{}Connection'.format(cls.__name__), node=cls)
118130

119131
if connection is not None:
120132
assert issubclass(connection, Connection), (

0 commit comments

Comments
 (0)