diff --git a/.gitignore b/.gitignore index 2c4ca2b1..e4070f31 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ htmlcov/ nosetests.xml coverage.xml *,cover +.pytest_cache/ # Translations *.mo diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 053aa8b5..7cc259e0 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -7,8 +7,6 @@ String) from graphene.types.json import JSONString -from .fields import createConnectionField - try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType except ImportError: @@ -23,7 +21,7 @@ def is_column_nullable(column): return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship, registry): +def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): direction = relationship.direction model = relationship.mapper.entity @@ -35,7 +33,7 @@ def dynamic_type(): return Field(_type) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): if _type._meta.connection: - return createConnectionField(_type._meta.connection) + return connection_field_factory(relationship, registry) return Field(List(_type)) return Dynamic(dynamic_type) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 7e313625..4a46b749 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,3 +1,4 @@ +import logging from functools import partial from promise import Promise, is_thenable @@ -9,6 +10,8 @@ from .utils import get_query, sort_argument_for_model +log = logging.getLogger() + class UnsortedSQLAlchemyConnectionField(ConnectionField): @property @@ -95,18 +98,37 @@ def __init__(self, type, *args, **kwargs): super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) +def default_connection_field_factory(relationship, registry): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return createConnectionField(model_type) + + +# TODO Remove in next major version __connectionFactory = UnsortedSQLAlchemyConnectionField def createConnectionField(_type): + log.warn( + 'createConnectionField is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) return __connectionFactory(_type) def registerConnectionFieldFactory(factoryMethod): + log.warn( + 'registerConnectionFieldFactory is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) global __connectionFactory __connectionFactory = factoryMethod def unregisterConnectionFieldFactory(): + log.warn( + 'registerConnectionFieldFactory is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/tests/test_connectionfactory.py b/graphene_sqlalchemy/tests/test_connectionfactory.py deleted file mode 100644 index 796be5a4..00000000 --- a/graphene_sqlalchemy/tests/test_connectionfactory.py +++ /dev/null @@ -1,42 +0,0 @@ -import graphene -from graphene_sqlalchemy.fields import (SQLAlchemyConnectionField, - registerConnectionFieldFactory, - unregisterConnectionFieldFactory) - - -def test_register(): - class LXConnectionField(SQLAlchemyConnectionField): - @classmethod - def _applyQueryArgs(cls, model, q, args): - return q - - @classmethod - def connection_resolver( - cls, resolver, connection, model, root, args, context, info - ): - def LXResolver(root, args, context, info): - iterable = resolver(root, args, context, info) - if iterable is None: - iterable = cls.get_query(model, context, info, args) - - # We accept always a query here. All LX-queries can be filtered and sorted - iterable = cls._applyQueryArgs(model, iterable, args) - return iterable - - return SQLAlchemyConnectionField.connection_resolver( - LXResolver, connection, model, root, args, context, info - ) - - def createLXConnectionField(table): - class LXConnection(graphene.relay.Connection): - class Meta: - node = table - - return LXConnectionField( - LXConnection, - filter=table.filter(), - order_by=graphene.List(of_type=table.order_by), - ) - - registerConnectionFieldFactory(createLXConnectionField) - unregisterConnectionFieldFactory() diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index d205427b..5cc16e79 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -16,7 +16,8 @@ from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_relationship) -from ..fields import UnsortedSQLAlchemyConnectionField +from ..fields import (UnsortedSQLAlchemyConnectionField, + default_connection_field_factory) from ..registry import Registry from ..types import SQLAlchemyObjectType from .models import Article, Pet, Reporter @@ -179,7 +180,9 @@ def test_should_jsontype_convert_jsonstring(): def test_should_manytomany_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -190,7 +193,7 @@ class Meta: model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry + Reporter.pets.property, A._meta.registry, default_connection_field_factory ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -206,7 +209,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry + Reporter.pets.property, A._meta.registry, default_connection_field_factory ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) @@ -214,7 +217,9 @@ class Meta: def test_should_manytoone_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -225,7 +230,7 @@ class Meta: model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A._meta.registry + Article.reporter.property, A._meta.registry, default_connection_field_factory ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -240,7 +245,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A._meta.registry + Article.reporter.property, A._meta.registry, default_connection_field_factory ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -255,7 +260,7 @@ class Meta: interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.favorite_article.property, A._meta.registry + Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 5eaf0137..0360a644 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -3,10 +3,13 @@ import six # noqa F401 from promise import Promise -from graphene import Field, Int, Interface, ObjectType -from graphene.relay import Connection, Node, is_node +from graphene import (Connection, Field, Int, Interface, Node, ObjectType, + is_node) -from ..fields import SQLAlchemyConnectionField +from ..fields import (SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, + registerConnectionFieldFactory, + unregisterConnectionFieldFactory) from ..registry import Registry from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, Reporter @@ -185,3 +188,92 @@ def resolver(*args, **kwargs): resolver, TestConnection, ReporterWithCustomOptions, None, None ) assert result is not None + + +# Tests for connection_field_factory + +class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): + pass + + +def test_default_connection_field_factory(): + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + + +def test_register_connection_field_factory(): + def test_connection_field_factory(relationship, registry): + model = relationship.mapper.entity + _type = registry.get_type_for_model(model) + return _TestSQLAlchemyConnectionField(_type._meta.connection) + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + connection_field_factory = test_connection_field_factory + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_registerConnectionFieldFactory(): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_unregisterConnectionFieldFactory(): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + unregisterConnectionFieldFactory() + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index dde746ec..394d5062 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -14,11 +14,12 @@ convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) +from .fields import default_connection_field_factory from .registry import Registry, get_global_registry from .utils import get_query, is_mapped_class, is_mapped_instance -def construct_fields(model, registry, only_fields, exclude_fields): +def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory): inspected_model = sqlalchemyinspect(model) fields = OrderedDict() @@ -71,7 +72,7 @@ def construct_fields(model, registry, only_fields, exclude_fields): # We skip this field if we specify only_fields and is not # in there. Or when we exclude this field in exclude_fields continue - converted_relationship = convert_sqlalchemy_relationship(relationship, registry) + converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory) name = relationship.key fields[name] = converted_relationship @@ -99,6 +100,7 @@ def __init_subclass_with_meta__( use_connection=None, interfaces=(), id=None, + connection_field_factory=default_connection_field_factory, _meta=None, **options ): @@ -115,7 +117,14 @@ def __init_subclass_with_meta__( ).format(cls.__name__, registry) sqla_fields = yank_fields_from_attrs( - construct_fields(model, registry, only_fields, exclude_fields), _as=Field + construct_fields( + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + connection_field_factory=connection_field_factory + ), + _as=Field ) if use_connection is None and interfaces: