Skip to content

Proper support for the SQLAlchemy Enum type #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def convert_sqlalchemy_type(type, column, registry=None):
@convert_sqlalchemy_type.register(types.Text)
@convert_sqlalchemy_type.register(types.Unicode)
@convert_sqlalchemy_type.register(types.UnicodeText)
@convert_sqlalchemy_type.register(types.Enum)
@convert_sqlalchemy_type.register(postgresql.ENUM)
@convert_sqlalchemy_type.register(postgresql.UUID)
@convert_sqlalchemy_type.register(TSVectorType)
def convert_column_to_string(type, column, registry=None):
Expand All @@ -118,22 +116,36 @@ def convert_column_to_datetime(type, column, registry=None):
@convert_sqlalchemy_type.register(types.Integer)
def convert_column_to_int_or_id(type, column, registry=None):
if column.primary_key:
return ID(description=get_column_doc(column), required=not (is_column_nullable(column)))
return ID(description=get_column_doc(column),
required=not (is_column_nullable(column)))
else:
return Int(description=get_column_doc(column),
required=not (is_column_nullable(column)))


@convert_sqlalchemy_type.register(types.Boolean)
def convert_column_to_boolean(type, column, registry=None):
return Boolean(description=get_column_doc(column), required=not(is_column_nullable(column)))
return Boolean(description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(types.Float)
@convert_sqlalchemy_type.register(types.Numeric)
@convert_sqlalchemy_type.register(types.BigInteger)
def convert_column_to_float(type, column, registry=None):
return Float(description=get_column_doc(column), required=not(is_column_nullable(column)))
return Float(description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(types.Enum)
def convert_enum_to_enum(type, column, registry=None):
try:
items = type.enum_class.__members__.items()
except AttributeError:
items = zip(type.enums, type.enums)
return Field(Enum(type.name, items),
description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(ChoiceType)
Expand All @@ -151,16 +163,19 @@ def convert_scalar_list_to_list(type, column, registry=None):
def convert_postgres_array_to_list(_type, column, registry=None):
graphene_type = convert_sqlalchemy_type(column.type.item_type, column)
inner_type = type(graphene_type)
return List(inner_type, description=get_column_doc(column), required=not(is_column_nullable(column)))
return List(inner_type, description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(postgresql.HSTORE)
@convert_sqlalchemy_type.register(postgresql.JSON)
@convert_sqlalchemy_type.register(postgresql.JSONB)
def convert_json_to_string(type, column, registry=None):
return JSONString(description=get_column_doc(column), required=not(is_column_nullable(column)))
return JSONString(description=get_column_doc(column),
required=not(is_column_nullable(column)))


@convert_sqlalchemy_type.register(JSONType)
def convert_json_type_to_string(type, column, registry=None):
return JSONString(description=get_column_doc(column), required=not(is_column_nullable(column)))
return JSONString(description=get_column_doc(column),
required=not(is_column_nullable(column)))
5 changes: 4 additions & 1 deletion graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import absolute_import

from sqlalchemy import Column, Date, ForeignKey, Integer, String, Table
import enum

from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import mapper, relationship

Expand All @@ -21,6 +23,7 @@ class Pet(Base):
__tablename__ = 'pets'
id = Column(Integer(), primary_key=True)
name = Column(String(30))
pet_kind = Column(Enum('cat', 'dog', name='pet_kind'), nullable=False)
reporter_id = Column(Integer(), ForeignKey('reporters.id'))


Expand Down
27 changes: 23 additions & 4 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import enum

from py.test import raises
from sqlalchemy import Column, Table, case, types, select, func
from sqlalchemy.dialects import postgresql
Expand All @@ -24,7 +26,8 @@ def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs):
column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs)
graphene_type = convert_sqlalchemy_column(column)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.Field()
field = graphene_type if isinstance(
graphene_type, graphene.Field) else graphene_type.Field()
assert field.description == 'Custom Help Text'
return field

Expand Down Expand Up @@ -76,8 +79,18 @@ def test_should_unicodetext_convert_string():
assert_column_conversion(types.UnicodeText(), graphene.String)


def test_should_enum_convert_string():
assert_column_conversion(types.Enum(), graphene.String)
def test_should_enum_convert_enum():
field = assert_column_conversion(
types.Enum(enum.Enum('one', 'two')), graphene.Field)
field_type = field.type()
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, 'two')
field = assert_column_conversion(
types.Enum('one', 'two', name='two_numbers'), graphene.Field)
field_type = field.type()
assert field_type.__class__.__name__ == 'two_numbers'
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, 'two')


def test_should_small_integer_convert_int():
Expand Down Expand Up @@ -119,6 +132,7 @@ def test_should_label_convert_int():
graphene_type = convert_sqlalchemy_column(label)
assert isinstance(graphene_type, graphene.Int)


def test_should_choice_convert_enum():
TYPES = [
(u'es', u'Spanish'),
Expand Down Expand Up @@ -247,7 +261,12 @@ def test_should_postgresql_uuid_convert():


def test_should_postgresql_enum_convert():
assert_column_conversion(postgresql.ENUM(), graphene.String)
field = assert_column_conversion(postgresql.ENUM(
enum.Enum('one', 'two'), name='two_numbers'), graphene.Field)
field_type = field.type()
assert field_type.__class__.__name__ == 'two_numbers'
assert isinstance(field_type, graphene.Enum)
assert hasattr(field_type, 'two')


def test_should_postgresql_array_convert():
Expand Down
38 changes: 37 additions & 1 deletion graphene_sqlalchemy/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..registry import reset_global_registry
from ..fields import SQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType
from .models import Article, Base, Editor, Reporter
from .models import Article, Base, Editor, Pet, Reporter

db = create_engine('sqlite:///test_sqlalchemy.sqlite3')

Expand All @@ -33,6 +33,8 @@ def session():


def setup_fixtures(session):
pet = Pet(name='Lassie', pet_kind='dog')
session.add(pet)
reporter = Reporter(first_name='ABA', last_name='X')
session.add(reporter)
reporter2 = Reporter(first_name='ABO', last_name='Y')
Expand Down Expand Up @@ -93,6 +95,40 @@ def resolve_reporters(self, *args, **kwargs):
assert result.data == expected


def test_should_query_enums(session):
setup_fixtures(session)

class PetType(SQLAlchemyObjectType):

class Meta:
model = Pet

class Query(graphene.ObjectType):
pet = graphene.Field(PetType)

def resolve_pet(self, *args, **kwargs):
return session.query(Pet).first()

query = '''
query PetQuery {
pet {
name,
petKind
}
}
'''
expected = {
'pet': {
'name': 'Lassie',
'petKind': 'dog'
}
}
schema = graphene.Schema(query=Query)
result = schema.execute(query)
assert not result.errors
assert result.data == expected, result.data


def test_should_node(session):
setup_fixtures(session)

Expand Down