Skip to content

Implement Mechanism to Selectively Override Automatic Field Creations #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 7, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 59 additions & 59 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
ChoiceType = JSONType = ScalarListType = TSVectorType = object


def _get_attr_resolver(attr_name):
return lambda root, _info: getattr(root, attr_name, None)


def get_column_doc(column):
return getattr(column, "doc", None)

Expand All @@ -24,43 +28,61 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory):
direction = relationship.direction
model = relationship.mapper.entity
def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, **field_kwargs):
direction = relationship_prop.direction
model = relationship_prop.mapper.entity

def dynamic_type():
_type = registry.get_type_for_model(model)

if not _type:
return None
if direction == interfaces.MANYTOONE or not relationship.uselist:
return Field(_type)
if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
return Field(
_type,
resolver=_get_attr_resolver(relationship_prop.key),
**field_kwargs
)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type._meta.connection:
return connection_field_factory(relationship, registry)
return Field(List(_type))
# TODO Add a way to override connection_field_factory
return connection_field_factory(relationship_prop, registry, **field_kwargs)
return Field(
List(_type),
**field_kwargs
)

return Dynamic(dynamic_type)


def convert_sqlalchemy_hybrid_method(hybrid_item):
return String(description=getattr(hybrid_item, "__doc__", None), required=False)
def convert_sqlalchemy_hybrid_method(hybrid_prop, prop_name, **field_kwargs):
if 'type' not in field_kwargs:
# TODO The default type should be dependent on the type of the property propety.
field_kwargs['type'] = String

return Field(
resolver=_get_attr_resolver(prop_name),
**field_kwargs
)


def convert_sqlalchemy_composite(composite, registry):
converter = registry.get_converter_for_composite(composite.composite_class)
def convert_sqlalchemy_composite(composite_prop, registry):
converter = registry.get_converter_for_composite(composite_prop.composite_class)
if not converter:
try:
raise Exception(
"Don't know how to convert the composite field %s (%s)"
% (composite, composite.composite_class)
% (composite_prop, composite_prop.composite_class)
)
except AttributeError:
# handle fields that are not attached to a class yet (don't have a parent)
raise Exception(
"Don't know how to convert the composite field %r (%s)"
% (composite, composite.composite_class)
% (composite_prop, composite_prop.composite_class)
)
return converter(composite, registry)

# TODO Add a way to override composite fields default parameters
return converter(composite_prop, registry)


def _register_composite_class(cls, registry=None):
Expand All @@ -78,8 +100,16 @@ def inner(fn):
convert_sqlalchemy_composite.register = _register_composite_class


def convert_sqlalchemy_column(column, registry=None):
return convert_sqlalchemy_type(getattr(column, "type", None), column, registry)
def convert_sqlalchemy_column(column_prop, registry, **field_kwargs):
column = column_prop.columns[0]
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
field_kwargs.setdefault('required', not is_column_nullable(column))
field_kwargs.setdefault('description', get_column_doc(column))

return Field(
resolver=_get_attr_resolver(column_prop.key),
**field_kwargs
)


@singledispatch
Expand All @@ -101,93 +131,63 @@ def convert_sqlalchemy_type(type, column, registry=None):
@convert_sqlalchemy_type.register(postgresql.CIDR)
@convert_sqlalchemy_type.register(TSVectorType)
def convert_column_to_string(type, column, registry=None):
return String(
description=get_column_doc(column), required=not (is_column_nullable(column))
)
return String


@convert_sqlalchemy_type.register(types.DateTime)
def convert_column_to_datetime(type, column, registry=None):
from graphene.types.datetime import DateTime

return DateTime(
description=get_column_doc(column), required=not (is_column_nullable(column))
)
return DateTime


@convert_sqlalchemy_type.register(types.SmallInteger)
@convert_sqlalchemy_type.register(types.Integer)
def convert_column_to_int_or_id(type, column, registry=None):
if column.primary_key:
return ID(
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
else:
return Int(
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
return ID if column.primary_key else Int


@convert_sqlalchemy_type.register(types.Boolean)
def convert_column_to_boolean(type, column, registry=None):
return Boolean(
description=get_column_doc(column), required=not (is_column_nullable(column))
)
return Boolean


@convert_sqlalchemy_type.register(types.Float)
@convert_sqlalchemy_type.register(types.Numeric)
@convert_sqlalchemy_type.register(types.BigInteger)
def convert_column_to_float(type, column, registry=None):
return Float(
description=get_column_doc(column), required=not (is_column_nullable(column))
)
return Float


@convert_sqlalchemy_type.register(types.Enum)
def convert_enum_to_enum(type, column, registry=None):
return Field(
lambda: enum_for_sa_enum(type, registry or get_global_registry()),
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
return lambda: enum_for_sa_enum(type, registry or get_global_registry())


# TODO Make ChoiceType conversion consistent with other enums
@convert_sqlalchemy_type.register(ChoiceType)
def convert_choice_to_enum(type, column, registry=None):
name = "{}_{}".format(column.table.name, column.name).upper()
return Enum(name, type.choices, description=get_column_doc(column))
return Enum(name, type.choices)


@convert_sqlalchemy_type.register(ScalarListType)
def convert_scalar_list_to_list(type, column, registry=None):
return List(String, description=get_column_doc(column))
return List(String)


@convert_sqlalchemy_type.register(postgresql.ARRAY)
def convert_postgres_array_to_list(_type, column, registry=None):
graphene_type = convert_sqlalchemy_type(column.type.item_type, column)
inner_type = type(graphene_type)
return List(
inner_type,
description=get_column_doc(column),
required=not (is_column_nullable(column)),
)
inner_type = convert_sqlalchemy_type(column.type.item_type, column)
return List(inner_type)


@convert_sqlalchemy_type.register(postgresql.HSTORE)
@convert_sqlalchemy_type.register(postgresql.JSON)
@convert_sqlalchemy_type.register(postgresql.JSONB)
def convert_json_to_string(type, column, registry=None):
return JSONString(
description=get_column_doc(column), required=not (is_column_nullable(column))
)
return JSONString


@convert_sqlalchemy_type.register(JSONType)
def convert_json_type_to_string(type, column, registry=None):
return JSONString(
description=get_column_doc(column), required=not (is_column_nullable(column))
)
return JSONString
22 changes: 12 additions & 10 deletions graphene_sqlalchemy/enums.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import Column
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.types import Enum as SQLAlchemyEnumType

from graphene import Argument, Enum, List
Expand Down Expand Up @@ -69,11 +69,12 @@ def enum_for_field(obj_type, field_name):
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
if orm_field is None:
raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name))
if not isinstance(orm_field, Column):
if not isinstance(orm_field, ColumnProperty):
raise TypeError(
"{}.{} does not map to model column".format(obj_type._meta.name, field_name)
)
sa_enum = orm_field.type
column = orm_field.columns[0]
sa_enum = column.type
if not isinstance(sa_enum, SQLAlchemyEnumType):
raise TypeError(
"{}.{} does not map to enum column".format(obj_type._meta.name, field_name)
Expand Down Expand Up @@ -138,15 +139,16 @@ def sort_enum_for_object_type(
if only_fields and field_name not in only_fields:
continue
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
if not isinstance(orm_field, Column):
if not isinstance(orm_field, ColumnProperty):
continue
if only_indexed and not (orm_field.primary_key or orm_field.index):
column = orm_field.columns[0]
if only_indexed and not (column.primary_key or column.index):
continue
asc_name = get_name(orm_field.name, True)
asc_value = EnumValue(asc_name, orm_field.asc())
desc_name = get_name(orm_field.name, False)
desc_value = EnumValue(desc_name, orm_field.desc())
if orm_field.primary_key:
asc_name = get_name(column.name, True)
asc_value = EnumValue(asc_name, column.asc())
desc_name = get_name(column.name, False)
desc_value = EnumValue(desc_name, column.desc())
if column.primary_key:
default.append(asc_value)
members.extend(((asc_name, asc_value), (desc_name, desc_value)))
enum = Enum(name, members)
Expand Down
8 changes: 4 additions & 4 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,22 @@ def __init__(self, type, *args, **kwargs):
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)


def default_connection_field_factory(relationship, registry):
def default_connection_field_factory(relationship, registry, **field_kwargs):
model = relationship.mapper.entity
model_type = registry.get_type_for_model(model)
return createConnectionField(model_type)
return createConnectionField(model_type, **field_kwargs)


# TODO Remove in next major version
__connectionFactory = UnsortedSQLAlchemyConnectionField


def createConnectionField(_type):
def createConnectionField(_type, **field_kwargs):
log.warning(
'createConnectionField is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
return __connectionFactory(_type)
return __connectionFactory(_type, **field_kwargs)


def registerConnectionFieldFactory(factoryMethod):
Expand Down
9 changes: 8 additions & 1 deletion graphene_sqlalchemy/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

from ..converter import convert_sqlalchemy_composite
from ..registry import reset_global_registry
from .models import Base
from .models import Base, CompositeFullName

test_db_url = 'sqlite://' # use in-memory database for tests

Expand All @@ -12,6 +13,12 @@
def reset_registry():
reset_global_registry()

# Prevent tests that implicitly depend on Reporter from raising
# Tests that explicitly depend on this behavior should re-register a converter
@convert_sqlalchemy_composite.register(CompositeFullName)
def convert_composite_class(composite, registry):
pass


@pytest.yield_fixture(scope="function")
def session():
Expand Down
39 changes: 29 additions & 10 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import enum

from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table
from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table,
func, select)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import mapper, relationship
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import column_property, composite, mapper, relationship

PetKind = Enum("cat", "dog", name="pet_kind")

Expand Down Expand Up @@ -39,22 +41,39 @@ class Pet(Base):
reporter_id = Column(Integer(), ForeignKey("reporters.id"))


class CompositeFullName(object):
def __init__(self, first_name, last_name):
self.first_name = first_name
self.last_name = last_name

def __composite_values__(self):
return self.first_name, self.last_name

def __repr__(self):
return "{} {}".format(self.first_name, self.last_name)


class Reporter(Base):
__tablename__ = "reporters"

id = Column(Integer(), primary_key=True)
first_name = Column(String(30))
last_name = Column(String(30))
email = Column(String())
first_name = Column(String(30), doc="First name")
last_name = Column(String(30), doc="Last name")
email = Column(String(), doc="Email")
favorite_pet_kind = Column(PetKind)
pets = relationship("Pet", secondary=association_table, backref="reporters")
articles = relationship("Article", backref="reporter")
favorite_article = relationship("Article", uselist=False)

# total = column_property(
# select([
# func.cast(func.count(PersonInfo.id), Float)
# ])
# )
@hybrid_property
def hybrid_prop(self):
return self.first_name

column_prop = column_property(
select([func.cast(func.count(id), Integer)]), doc="Column property"
)

composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite")


class Article(Base):
Expand Down
Loading