Skip to content

Add support for Non-Null SQLAlchemyConnectionField #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from promise import Promise, is_thenable
from sqlalchemy.orm.query import Query

from graphene import NonNull
from graphene.relay import Connection, ConnectionField
from graphene.relay.connection import PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
Expand All @@ -19,19 +20,26 @@ def type(self):
from .types import SQLAlchemyObjectType

_type = super(ConnectionField, self).type
if issubclass(_type, Connection):
nullable_type = get_nullable_type(_type)
if issubclass(nullable_type, Connection):
return _type
assert issubclass(_type, SQLAlchemyObjectType), (
assert issubclass(nullable_type, SQLAlchemyObjectType), (
"SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
).format(_type.__name__)
assert _type.connection, "The type {} doesn't have a connection".format(
_type.__name__
).format(nullable_type.__name__)
assert (
nullable_type.connection
), "The type {} doesn't have a connection".format(
nullable_type.__name__
)
return _type.connection
assert _type == nullable_type, (
"Passing a SQLAlchemyObjectType instance is deprecated. "
"Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
)
return nullable_type.connection

@property
def model(self):
return self.type._meta.node._meta.model
return get_nullable_type(self.type)._meta.node._meta.model

@classmethod
def get_query(cls, model, info, **args):
Expand Down Expand Up @@ -70,21 +78,27 @@ def connection_resolver(cls, resolver, connection_type, model, root, info, **arg
return on_resolve(resolved)

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
return partial(
self.connection_resolver,
parent_resolver,
get_nullable_type(self.type),
self.model,
)


# TODO Rename this to SortableSQLAlchemyConnectionField
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
def __init__(self, type, *args, **kwargs):
if "sort" not in kwargs and issubclass(type, Connection):
nullable_type = get_nullable_type(type)
if "sort" not in kwargs and issubclass(nullable_type, Connection):
# Let super class raise if type is not a Connection
try:
kwargs.setdefault("sort", type.Edge.node._type.sort_argument())
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
except (AttributeError, TypeError):
raise TypeError(
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
" to None to disabling the creation of the sort query argument".format(
type.__name__
nullable_type.__name__
)
)
elif "sort" in kwargs and kwargs["sort"] is None:
Expand All @@ -108,8 +122,14 @@ class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
The API and behavior may change in future versions.
Use at your own risk.
"""

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, self.resolver, self.type, self.model)
return partial(
self.connection_resolver,
self.resolver,
get_nullable_type(self.type),
self.model,
)

@classmethod
def from_relationship(cls, relationship, registry, **field_kwargs):
Expand Down Expand Up @@ -155,3 +175,9 @@ def unregisterConnectionFieldFactory():
)
global __connectionFactory
__connectionFactory = UnsortedSQLAlchemyConnectionField


def get_nullable_type(_type):
if isinstance(_type, NonNull):
return _type.of_type
return _type
8 changes: 3 additions & 5 deletions graphene_sqlalchemy/tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def test_one_to_one(session_factory):
'articles.headline AS articles_headline, '
'articles.pub_date AS articles_pub_date \n'
'FROM articles \n'
'WHERE articles.reporter_id IN (?, ?) '
'ORDER BY articles.reporter_id',
'WHERE articles.reporter_id IN (?, ?)',
'(1, 2)'
]

Expand Down Expand Up @@ -337,8 +336,7 @@ def test_one_to_many(session_factory):
'articles.headline AS articles_headline, '
'articles.pub_date AS articles_pub_date \n'
'FROM articles \n'
'WHERE articles.reporter_id IN (?, ?) '
'ORDER BY articles.reporter_id',
'WHERE articles.reporter_id IN (?, ?)',
'(1, 2)'
]

Expand Down Expand Up @@ -470,7 +468,7 @@ def test_many_to_many(session_factory):
'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id '
'JOIN pets ON pets.id = association_1.pet_id \n'
'WHERE reporters_1.id IN (?, ?) '
'ORDER BY reporters_1.id, pets.id',
'ORDER BY pets.id',
'(1, 2)'
]

Expand Down
16 changes: 15 additions & 1 deletion graphene_sqlalchemy/tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from promise import Promise

from graphene import ObjectType
from graphene import NonNull, ObjectType
from graphene.relay import Connection, Node

from ..fields import (SQLAlchemyConnectionField,
Expand All @@ -26,6 +26,20 @@ class Meta:
##


def test_nonnull_sqlalachemy_connection():
field = SQLAlchemyConnectionField(NonNull(Pet.connection))
assert isinstance(field.type, NonNull)
assert issubclass(field.type.of_type, Connection)
assert field.type.of_type._meta.node is Pet


def test_required_sqlalachemy_connection():
field = SQLAlchemyConnectionField(Pet.connection, required=True)
assert isinstance(field.type, NonNull)
assert issubclass(field.type.of_type, Connection)
assert field.type.of_type._meta.node is Pet


def test_promise_connection_resolver():
def resolver(_obj, _info):
return Promise.resolve([])
Expand Down