diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 17e54b0f..def1fda9 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -98,8 +98,6 @@ def convert_sqlalchemy_type(type, column, registry=None): @convert_sqlalchemy_type.register(types.Text) @convert_sqlalchemy_type.register(types.Unicode) @convert_sqlalchemy_type.register(types.UnicodeText) -@convert_sqlalchemy_type.register(types.Enum) -@convert_sqlalchemy_type.register(postgresql.ENUM) @convert_sqlalchemy_type.register(postgresql.UUID) @convert_sqlalchemy_type.register(TSVectorType) def convert_column_to_string(type, column, registry=None): @@ -118,7 +116,8 @@ def convert_column_to_datetime(type, column, registry=None): @convert_sqlalchemy_type.register(types.Integer) def convert_column_to_int_or_id(type, column, registry=None): if column.primary_key: - return ID(description=get_column_doc(column), required=not (is_column_nullable(column))) + return ID(description=get_column_doc(column), + required=not (is_column_nullable(column))) else: return Int(description=get_column_doc(column), required=not (is_column_nullable(column))) @@ -126,14 +125,27 @@ def convert_column_to_int_or_id(type, column, registry=None): @convert_sqlalchemy_type.register(types.Boolean) def convert_column_to_boolean(type, column, registry=None): - return Boolean(description=get_column_doc(column), required=not(is_column_nullable(column))) + return Boolean(description=get_column_doc(column), + required=not(is_column_nullable(column))) @convert_sqlalchemy_type.register(types.Float) @convert_sqlalchemy_type.register(types.Numeric) @convert_sqlalchemy_type.register(types.BigInteger) def convert_column_to_float(type, column, registry=None): - return Float(description=get_column_doc(column), required=not(is_column_nullable(column))) + return Float(description=get_column_doc(column), + required=not(is_column_nullable(column))) + + +@convert_sqlalchemy_type.register(types.Enum) +def convert_enum_to_enum(type, column, registry=None): + try: + items = type.enum_class.__members__.items() + except AttributeError: + items = zip(type.enums, type.enums) + return Field(Enum(type.name, items), + description=get_column_doc(column), + required=not(is_column_nullable(column))) @convert_sqlalchemy_type.register(ChoiceType) @@ -151,16 +163,19 @@ def convert_scalar_list_to_list(type, column, registry=None): def convert_postgres_array_to_list(_type, column, registry=None): graphene_type = convert_sqlalchemy_type(column.type.item_type, column) inner_type = type(graphene_type) - return List(inner_type, description=get_column_doc(column), required=not(is_column_nullable(column))) + return List(inner_type, description=get_column_doc(column), + required=not(is_column_nullable(column))) @convert_sqlalchemy_type.register(postgresql.HSTORE) @convert_sqlalchemy_type.register(postgresql.JSON) @convert_sqlalchemy_type.register(postgresql.JSONB) def convert_json_to_string(type, column, registry=None): - return JSONString(description=get_column_doc(column), required=not(is_column_nullable(column))) + return JSONString(description=get_column_doc(column), + required=not(is_column_nullable(column))) @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): - return JSONString(description=get_column_doc(column), required=not(is_column_nullable(column))) + return JSONString(description=get_column_doc(column), + required=not(is_column_nullable(column))) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3f27bc48..f9b93c15 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -1,6 +1,8 @@ from __future__ import absolute_import -from sqlalchemy import Column, Date, ForeignKey, Integer, String, Table +import enum + +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import mapper, relationship @@ -21,6 +23,7 @@ class Pet(Base): __tablename__ = 'pets' id = Column(Integer(), primary_key=True) name = Column(String(30)) + pet_kind = Column(Enum('cat', 'dog', name='pet_kind'), nullable=False) reporter_id = Column(Integer(), ForeignKey('reporters.id')) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 3c732b27..221de606 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,3 +1,5 @@ +import enum + from py.test import raises from sqlalchemy import Column, Table, case, types, select, func from sqlalchemy.dialects import postgresql @@ -24,7 +26,8 @@ def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs) graphene_type = convert_sqlalchemy_column(column) assert isinstance(graphene_type, graphene_field) - field = graphene_type.Field() + field = graphene_type if isinstance( + graphene_type, graphene.Field) else graphene_type.Field() assert field.description == 'Custom Help Text' return field @@ -76,8 +79,18 @@ def test_should_unicodetext_convert_string(): assert_column_conversion(types.UnicodeText(), graphene.String) -def test_should_enum_convert_string(): - assert_column_conversion(types.Enum(), graphene.String) +def test_should_enum_convert_enum(): + field = assert_column_conversion( + types.Enum(enum.Enum('one', 'two')), graphene.Field) + field_type = field.type() + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, 'two') + field = assert_column_conversion( + types.Enum('one', 'two', name='two_numbers'), graphene.Field) + field_type = field.type() + assert field_type.__class__.__name__ == 'two_numbers' + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, 'two') def test_should_small_integer_convert_int(): @@ -119,6 +132,7 @@ def test_should_label_convert_int(): graphene_type = convert_sqlalchemy_column(label) assert isinstance(graphene_type, graphene.Int) + def test_should_choice_convert_enum(): TYPES = [ (u'es', u'Spanish'), @@ -247,7 +261,12 @@ def test_should_postgresql_uuid_convert(): def test_should_postgresql_enum_convert(): - assert_column_conversion(postgresql.ENUM(), graphene.String) + field = assert_column_conversion(postgresql.ENUM( + enum.Enum('one', 'two'), name='two_numbers'), graphene.Field) + field_type = field.type() + assert field_type.__class__.__name__ == 'two_numbers' + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, 'two') def test_should_postgresql_array_convert(): diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index e4c3f835..12dd1fad 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -8,7 +8,7 @@ from ..registry import reset_global_registry from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from .models import Article, Base, Editor, Reporter +from .models import Article, Base, Editor, Pet, Reporter db = create_engine('sqlite:///test_sqlalchemy.sqlite3') @@ -33,6 +33,8 @@ def session(): def setup_fixtures(session): + pet = Pet(name='Lassie', pet_kind='dog') + session.add(pet) reporter = Reporter(first_name='ABA', last_name='X') session.add(reporter) reporter2 = Reporter(first_name='ABO', last_name='Y') @@ -93,6 +95,40 @@ def resolve_reporters(self, *args, **kwargs): assert result.data == expected +def test_should_query_enums(session): + setup_fixtures(session) + + class PetType(SQLAlchemyObjectType): + + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType) + + def resolve_pet(self, *args, **kwargs): + return session.query(Pet).first() + + query = ''' + query PetQuery { + pet { + name, + petKind + } + } + ''' + expected = { + 'pet': { + 'name': 'Lassie', + 'petKind': 'dog' + } + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert result.data == expected, result.data + + def test_should_node(session): setup_fixtures(session)