Skip to content

Commit 6a96d37

Browse files
authored
Merge pull request #154 from curvetips/fix-enum-conversion
Fix creation of graphene.Enum from enum.Enum
2 parents 33d5b74 + d4365e1 commit 6a96d37

File tree

4 files changed

+102
-11
lines changed

4 files changed

+102
-11
lines changed

Diff for: graphene_sqlalchemy/converter.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,14 @@ def convert_column_to_float(type, column, registry=None):
146146

147147
@convert_sqlalchemy_type.register(types.Enum)
148148
def convert_enum_to_enum(type, column, registry=None):
149-
try:
150-
items = type.enum_class.__members__.items()
151-
except AttributeError:
149+
enum_class = getattr(type, 'enum_class', None)
150+
if enum_class: # Check if an enum.Enum type is used
151+
graphene_type = Enum.from_enum(enum_class)
152+
else: # Nope, just a list of string options
152153
items = zip(type.enums, type.enums)
154+
graphene_type = Enum(type.name, items)
153155
return Field(
154-
Enum(type.name, items),
156+
graphene_type,
155157
description=get_column_doc(column),
156158
required=not (is_column_nullable(column)),
157159
)

Diff for: graphene_sqlalchemy/tests/models.py

+7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
from sqlalchemy.ext.declarative import declarative_base
77
from sqlalchemy.orm import mapper, relationship
88

9+
10+
class Hairkind(enum.Enum):
11+
LONG = 'long'
12+
SHORT = 'short'
13+
14+
915
Base = declarative_base()
1016

1117
association_table = Table(
@@ -27,6 +33,7 @@ class Pet(Base):
2733
id = Column(Integer(), primary_key=True)
2834
name = Column(String(30))
2935
pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False)
36+
hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False)
3037
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
3138

3239

Diff for: graphene_sqlalchemy/tests/test_converter.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -271,14 +271,24 @@ def test_should_postgresql_uuid_convert():
271271

272272
def test_should_postgresql_enum_convert():
273273
field = assert_column_conversion(
274-
postgresql.ENUM(enum.Enum("one", "two"), name="two_numbers"), graphene.Field
274+
postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field
275275
)
276276
field_type = field.type()
277277
assert field_type.__class__.__name__ == "two_numbers"
278278
assert isinstance(field_type, graphene.Enum)
279279
assert hasattr(field_type, "two")
280280

281281

282+
def test_should_postgresql_py_enum_convert():
283+
field = assert_column_conversion(
284+
postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field
285+
)
286+
field_type = field.type()
287+
assert field_type.__class__.__name__ == "TwoNumbers"
288+
assert isinstance(field_type, graphene.Enum)
289+
assert hasattr(field_type, "two")
290+
291+
282292
def test_should_postgresql_array_convert():
283293
assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List)
284294

Diff for: graphene_sqlalchemy/tests/test_query.py

+78-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..fields import SQLAlchemyConnectionField
1010
from ..types import SQLAlchemyObjectType
1111
from ..utils import sort_argument_for_model, sort_enum_for_model
12-
from .models import Article, Base, Editor, Pet, Reporter
12+
from .models import Article, Base, Editor, Pet, Reporter, Hairkind
1313

1414
db = create_engine("sqlite:///test_sqlalchemy.sqlite3")
1515

@@ -34,7 +34,7 @@ def session():
3434

3535

3636
def setup_fixtures(session):
37-
pet = Pet(name="Lassie", pet_kind="dog")
37+
pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG)
3838
session.add(pet)
3939
reporter = Reporter(first_name="ABA", last_name="X")
4040
session.add(reporter)
@@ -105,16 +105,88 @@ def resolve_pet(self, *args, **kwargs):
105105
pet {
106106
name,
107107
petKind
108+
hairKind
108109
}
109110
}
110111
"""
111-
expected = {"pet": {"name": "Lassie", "petKind": "dog"}}
112+
expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
112113
schema = graphene.Schema(query=Query)
113114
result = schema.execute(query)
114115
assert not result.errors
115116
assert result.data == expected, result.data
116117

117118

119+
def test_enum_parameter(session):
120+
setup_fixtures(session)
121+
122+
class PetType(SQLAlchemyObjectType):
123+
class Meta:
124+
model = Pet
125+
126+
class Query(graphene.ObjectType):
127+
pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type))
128+
129+
def resolve_pet(self, info, kind=None, *args, **kwargs):
130+
query = session.query(Pet)
131+
if kind:
132+
query = query.filter(Pet.pet_kind == kind)
133+
return query.first()
134+
135+
query = """
136+
query PetQuery($kind: pet_kind) {
137+
pet(kind: $kind) {
138+
name,
139+
petKind
140+
hairKind
141+
}
142+
}
143+
"""
144+
expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
145+
schema = graphene.Schema(query=Query)
146+
result = schema.execute(query, variables={"kind": "cat"})
147+
assert not result.errors
148+
assert result.data == {"pet": None}
149+
result = schema.execute(query, variables={"kind": "dog"})
150+
assert not result.errors
151+
assert result.data == expected, result.data
152+
153+
154+
def test_py_enum_parameter(session):
155+
setup_fixtures(session)
156+
157+
class PetType(SQLAlchemyObjectType):
158+
class Meta:
159+
model = Pet
160+
161+
class Query(graphene.ObjectType):
162+
pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type))
163+
164+
def resolve_pet(self, info, kind=None, *args, **kwargs):
165+
query = session.query(Pet)
166+
if kind:
167+
# XXX Why kind passed in as a str instead of a Hairkind instance?
168+
query = query.filter(Pet.hair_kind == Hairkind(kind))
169+
return query.first()
170+
171+
query = """
172+
query PetQuery($kind: Hairkind) {
173+
pet(kind: $kind) {
174+
name,
175+
petKind
176+
hairKind
177+
}
178+
}
179+
"""
180+
expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
181+
schema = graphene.Schema(query=Query)
182+
result = schema.execute(query, variables={"kind": "SHORT"})
183+
assert not result.errors
184+
assert result.data == {"pet": None}
185+
result = schema.execute(query, variables={"kind": "LONG"})
186+
assert not result.errors
187+
assert result.data == expected, result.data
188+
189+
118190
def test_should_node(session):
119191
setup_fixtures(session)
120192

@@ -326,9 +398,9 @@ class Mutation(graphene.ObjectType):
326398

327399
def sort_setup(session):
328400
pets = [
329-
Pet(id=2, name="Lassie", pet_kind="dog"),
330-
Pet(id=22, name="Alf", pet_kind="cat"),
331-
Pet(id=3, name="Barf", pet_kind="dog"),
401+
Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG),
402+
Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG),
403+
Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG),
332404
]
333405
session.add_all(pets)
334406
session.commit()

0 commit comments

Comments
 (0)