Skip to content

Pass relationship and registry objects to connection_field_factory #187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ htmlcov/
nosetests.xml
coverage.xml
*,cover
.pytest_cache/

# Translations
*.mo
Expand Down
6 changes: 2 additions & 4 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
String)
from graphene.types.json import JSONString

from .fields import createConnectionField

try:
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
except ImportError:
Expand All @@ -23,7 +21,7 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship, registry):
def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory):
direction = relationship.direction
model = relationship.mapper.entity

Expand All @@ -35,7 +33,7 @@ def dynamic_type():
return Field(_type)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type._meta.connection:
return createConnectionField(_type._meta.connection)
return connection_field_factory(relationship, registry)
return Field(List(_type))

return Dynamic(dynamic_type)
Expand Down
22 changes: 22 additions & 0 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from functools import partial

from promise import Promise, is_thenable
Expand All @@ -9,6 +10,8 @@

from .utils import get_query, sort_argument_for_model

log = logging.getLogger()


class UnsortedSQLAlchemyConnectionField(ConnectionField):
@property
Expand Down Expand Up @@ -95,18 +98,37 @@ def __init__(self, type, *args, **kwargs):
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)


def default_connection_field_factory(relationship, registry):
model = relationship.mapper.entity
model_type = registry.get_type_for_model(model)
return createConnectionField(model_type)


# TODO Remove in next major version
__connectionFactory = UnsortedSQLAlchemyConnectionField


def createConnectionField(_type):
log.warn(
'createConnectionField is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
return __connectionFactory(_type)


def registerConnectionFieldFactory(factoryMethod):
log.warn(
'registerConnectionFieldFactory is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
global __connectionFactory
__connectionFactory = factoryMethod


def unregisterConnectionFieldFactory():
log.warn(
'registerConnectionFieldFactory is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
global __connectionFactory
__connectionFactory = UnsortedSQLAlchemyConnectionField
42 changes: 0 additions & 42 deletions graphene_sqlalchemy/tests/test_connectionfactory.py

This file was deleted.

21 changes: 13 additions & 8 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_composite,
convert_sqlalchemy_relationship)
from ..fields import UnsortedSQLAlchemyConnectionField
from ..fields import (UnsortedSQLAlchemyConnectionField,
default_connection_field_factory)
from ..registry import Registry
from ..types import SQLAlchemyObjectType
from .models import Article, Pet, Reporter
Expand Down Expand Up @@ -179,7 +180,9 @@ def test_should_jsontype_convert_jsonstring():

def test_should_manytomany_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()

Expand All @@ -190,7 +193,7 @@ class Meta:
model = Pet

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry
Reporter.pets.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -206,15 +209,17 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry
Reporter.pets.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField)


def test_should_manytoone_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry)
dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()

Expand All @@ -225,7 +230,7 @@ class Meta:
model = Reporter

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry
Article.reporter.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -240,7 +245,7 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry
Article.reporter.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -255,7 +260,7 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Reporter.favorite_article.property, A._meta.registry
Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand Down
98 changes: 95 additions & 3 deletions graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import six # noqa F401
from promise import Promise

from graphene import Field, Int, Interface, ObjectType
from graphene.relay import Connection, Node, is_node
from graphene import (Connection, Field, Int, Interface, Node, ObjectType,
is_node)

from ..fields import SQLAlchemyConnectionField
from ..fields import (SQLAlchemyConnectionField,
UnsortedSQLAlchemyConnectionField,
registerConnectionFieldFactory,
unregisterConnectionFieldFactory)
from ..registry import Registry
from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
from .models import Article, Reporter
Expand Down Expand Up @@ -185,3 +188,92 @@ def resolver(*args, **kwargs):
resolver, TestConnection, ReporterWithCustomOptions, None, None
)
assert result is not None


# Tests for connection_field_factory

class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField):
pass


def test_default_connection_field_factory():
_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField)


def test_register_connection_field_factory():
def test_connection_field_factory(relationship, registry):
model = relationship.mapper.entity
_type = registry.get_type_for_model(model)
return _TestSQLAlchemyConnectionField(_type._meta.connection)

_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)
connection_field_factory = test_connection_field_factory

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)


def test_deprecated_registerConnectionFieldFactory():
registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)

_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)


def test_deprecated_unregisterConnectionFieldFactory():
registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)
unregisterConnectionFieldFactory()

_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
15 changes: 12 additions & 3 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
convert_sqlalchemy_composite,
convert_sqlalchemy_hybrid_method,
convert_sqlalchemy_relationship)
from .fields import default_connection_field_factory
from .registry import Registry, get_global_registry
from .utils import get_query, is_mapped_class, is_mapped_instance


def construct_fields(model, registry, only_fields, exclude_fields):
def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory):
inspected_model = sqlalchemyinspect(model)

fields = OrderedDict()
Expand Down Expand Up @@ -71,7 +72,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
# We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields
continue
converted_relationship = convert_sqlalchemy_relationship(relationship, registry)
converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory)
name = relationship.key
fields[name] = converted_relationship

Expand Down Expand Up @@ -99,6 +100,7 @@ def __init_subclass_with_meta__(
use_connection=None,
interfaces=(),
id=None,
connection_field_factory=default_connection_field_factory,
_meta=None,
**options
):
Expand All @@ -115,7 +117,14 @@ def __init_subclass_with_meta__(
).format(cls.__name__, registry)

sqla_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields), _as=Field
construct_fields(
model=model,
registry=registry,
only_fields=only_fields,
exclude_fields=exclude_fields,
connection_field_factory=connection_field_factory
),
_as=Field
)

if use_connection is None and interfaces:
Expand Down