Skip to content

Commit 36ecbfb

Browse files
committed
Improve creation of column and sort enums
- create separate enum module - create enums based on object type, not based on model - provide more customization options - split tests in different modules - adapt flask_sqlalchemy example - use conftest.py for better test fixtures
1 parent e362e3f commit 36ecbfb

24 files changed

+1289
-471
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __pycache__/
1111
# Distribution / packaging
1212
.Python
1313
env/
14+
.venv/
1415
build/
1516
develop-eggs/
1617
dist/

Diff for: examples/flask_sqlalchemy/app.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,46 @@
11
#!/usr/bin/env python
22

3+
from database import db_session, init_db
34
from flask import Flask
5+
from schema import schema
46

57
from flask_graphql import GraphQLView
68

7-
from .database import db_session, init_db
8-
from .schema import schema
9-
109
app = Flask(__name__)
1110
app.debug = True
1211

13-
default_query = '''
12+
example_query = """
1413
{
15-
allEmployees {
14+
allEmployees(sort: [NAME_ASC, ID_ASC]) {
1615
edges {
1716
node {
18-
id,
19-
name,
17+
id
18+
name
2019
department {
21-
id,
20+
id
2221
name
23-
},
22+
}
2423
role {
25-
id,
24+
id
2625
name
2726
}
2827
}
2928
}
3029
}
31-
}'''.strip()
30+
}
31+
"""
3232

3333

34-
app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True))
34+
app.add_url_rule(
35+
"/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True)
36+
)
3537

3638

3739
@app.teardown_appcontext
3840
def shutdown_session(exception=None):
3941
db_session.remove()
4042

41-
if __name__ == '__main__':
43+
44+
if __name__ == "__main__":
4245
init_db()
4346
app.run()

Diff for: examples/flask_sqlalchemy/database.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def init_db():
1414
# import all modules here that might define models so that
1515
# they will be registered properly on the metadata. Otherwise
1616
# you will have to import them first before calling init_db()
17-
from .models import Department, Employee, Role
17+
from models import Department, Employee, Role
1818
Base.metadata.drop_all(bind=engine)
1919
Base.metadata.create_all(bind=engine)
2020

Diff for: examples/flask_sqlalchemy/models.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
from database import Base
12
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
23
from sqlalchemy.orm import backref, relationship
34

4-
from .database import Base
5-
65

76
class Department(Base):
87
__tablename__ = 'department'

Diff for: examples/flask_sqlalchemy/schema.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
from models import Department as DepartmentModel
2+
from models import Employee as EmployeeModel
3+
from models import Role as RoleModel
4+
15
import graphene
26
from graphene import relay
3-
from graphene_sqlalchemy import (SQLAlchemyConnectionField,
4-
SQLAlchemyObjectType, utils)
5-
6-
from .models import Department as DepartmentModel
7-
from .models import Employee as EmployeeModel
8-
from .models import Role as RoleModel
7+
from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType
98

109

1110
class Department(SQLAlchemyObjectType):
@@ -26,18 +25,11 @@ class Meta:
2625
interfaces = (relay.Node, )
2726

2827

29-
SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee',
30-
lambda c, d: c.upper() + ('_ASC' if d else '_DESC'))
31-
32-
3328
class Query(graphene.ObjectType):
3429
node = relay.Node.Field()
3530
# Allow only single column sorting
3631
all_employees = SQLAlchemyConnectionField(
37-
Employee,
38-
sort=graphene.Argument(
39-
SortEnumEmployee,
40-
default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc())))
32+
Employee, sort=Employee.sort_argument())
4133
# Allows sorting over multiple columns, by default over the primary key
4234
all_roles = SQLAlchemyConnectionField(Role)
4335
# Disable sorting over this field

Diff for: graphene_sqlalchemy/converter.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
String)
88
from graphene.types.json import JSONString
99

10+
from .enums import enum_for_sa_enum
11+
from .registry import get_global_registry
12+
1013
try:
1114
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
1215
except ImportError:
@@ -145,21 +148,15 @@ def convert_column_to_float(type, column, registry=None):
145148

146149
@convert_sqlalchemy_type.register(types.Enum)
147150
def convert_enum_to_enum(type, column, registry=None):
148-
enum_class = getattr(type, 'enum_class', None)
149-
if enum_class: # Check if an enum.Enum type is used
150-
graphene_type = Enum.from_enum(enum_class)
151-
else: # Nope, just a list of string options
152-
items = zip(type.enums, type.enums)
153-
graphene_type = Enum(type.name, items)
154151
return Field(
155-
graphene_type,
152+
lambda: enum_for_sa_enum(type, registry or get_global_registry()),
156153
description=get_column_doc(column),
157154
required=not (is_column_nullable(column)),
158155
)
159156

160157

161158
@convert_sqlalchemy_type.register(ChoiceType)
162-
def convert_column_to_enum(type, column, registry=None):
159+
def convert_choice_to_enum(type, column, registry=None):
163160
name = "{}_{}".format(column.table.name, column.name).upper()
164161
return Enum(name, type.choices, description=get_column_doc(column))
165162

Diff for: graphene_sqlalchemy/enums.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
from sqlalchemy import Column
2+
from sqlalchemy.types import Enum as SQLAlchemyEnumType
3+
4+
from graphene import Argument, Enum, List
5+
6+
from .registry import get_global_registry
7+
from .utils import EnumValue, to_enum_value_name, to_type_name
8+
9+
10+
def convert_sa_to_graphene_enum(sa_enum, fallback_name=None):
11+
"""Convert the given SQLAlchemy Enum type to a Graphene Enum type.
12+
13+
The name of the Graphene Enum will be determined as follows:
14+
If the SQLAlchemy Enum is based on a Python Enum, use the name
15+
of the Python Enum. Otherwise, if the SQLAlchemy Enum is named,
16+
use the SQL name after conversion to a type name. Otherwise, use
17+
the given fallback_name or raise an error if it is empty.
18+
19+
The Enum value names are converted to upper case if necessary.
20+
"""
21+
if not isinstance(sa_enum, SQLAlchemyEnumType):
22+
raise TypeError(
23+
"Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)
24+
)
25+
enum_class = sa_enum.enum_class
26+
if enum_class:
27+
if all(to_enum_value_name(key) == key for key in enum_class.__members__):
28+
return Enum.from_enum(enum_class)
29+
name = enum_class.__name__
30+
members = [
31+
(to_enum_value_name(key), value.value)
32+
for key, value in enum_class.__members__.items()
33+
]
34+
else:
35+
sql_enum_name = sa_enum.name
36+
if sql_enum_name:
37+
name = to_type_name(sql_enum_name)
38+
elif fallback_name:
39+
name = fallback_name
40+
else:
41+
raise TypeError("No type name specified for {!r}".format(sa_enum))
42+
members = [(to_enum_value_name(key), key) for key in sa_enum.enums]
43+
return Enum(name, members)
44+
45+
46+
def enum_for_sa_enum(sa_enum, registry):
47+
"""Return the Graphene Enum type for the specified SQLAlchemy Enum type."""
48+
if not isinstance(sa_enum, SQLAlchemyEnumType):
49+
raise TypeError(
50+
"Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)
51+
)
52+
enum = registry.get_graphene_enum_for_sa_enum(sa_enum)
53+
if not enum:
54+
enum = convert_sa_to_graphene_enum(sa_enum)
55+
registry.register_enum(sa_enum, enum)
56+
return enum
57+
58+
59+
def enum_for_field(obj_type, field_name):
60+
"""Return the Graphene Enum type for the specified Graphene field."""
61+
from .types import SQLAlchemyObjectType
62+
63+
if not issubclass(obj_type, SQLAlchemyObjectType):
64+
raise TypeError(
65+
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type))
66+
if not field_name or not isinstance(field_name, str):
67+
raise TypeError(
68+
"Expected a field name, but got: {!r}".format(field_name))
69+
registry = obj_type._meta.registry or get_global_registry()
70+
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
71+
if orm_field is None:
72+
raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name))
73+
if not isinstance(orm_field, Column):
74+
raise TypeError(
75+
"{}.{} does not map to model column".format(obj_type._meta.name, field_name)
76+
)
77+
sa_enum = orm_field.type
78+
if not isinstance(sa_enum, SQLAlchemyEnumType):
79+
raise TypeError(
80+
"{}.{} does not map to enum column".format(obj_type._meta.name, field_name)
81+
)
82+
enum = registry.get_graphene_enum_for_sa_enum(sa_enum)
83+
if not enum:
84+
fallback_name = obj_type._meta.name + to_type_name(field_name)
85+
enum = convert_sa_to_graphene_enum(sa_enum, fallback_name)
86+
registry.register_enum(sa_enum, enum)
87+
return enum
88+
89+
90+
def _default_sort_enum_symbol_name(column_name, sort_asc=True):
91+
return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC")
92+
93+
94+
def sort_enum_for_object_type(
95+
obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None
96+
):
97+
"""Return Graphene Enum for sorting the given SQLAlchemyObjectType.
98+
99+
Parameters
100+
- obj_type : SQLAlchemyObjectType
101+
The object type for which the sort Enum shall be generated.
102+
- name : str, optional, default None
103+
Name to use for the sort Enum.
104+
If not provided, it will be set to the object type name + 'SortEnum'
105+
- only_fields : sequence, optional, default None
106+
If this is set, only fields from this sequence will be considered.
107+
- only_indexed : bool, optional, default False
108+
If this is set, only indexed columns will be considered.
109+
- get_symbol_name : function, optional, default None
110+
Function which takes the column name and a boolean indicating
111+
if the sort direction is ascending, and returns the symbol name
112+
for the current column and sort direction. If no such function
113+
is passed, a default function will be used that creates the symbols
114+
'foo_asc' and 'foo_desc' for a column with the name 'foo'.
115+
116+
Returns
117+
- Enum
118+
The Graphene Enum type
119+
"""
120+
name = name or obj_type._meta.name + "SortEnum"
121+
registry = obj_type._meta.registry or get_global_registry()
122+
enum = registry.get_sort_enum_for_object_type(obj_type)
123+
custom_options = dict(
124+
only_fields=only_fields,
125+
only_indexed=only_indexed,
126+
get_symbol_name=get_symbol_name,
127+
)
128+
if enum:
129+
if name != enum.__name__ or custom_options != enum.custom_options:
130+
raise ValueError(
131+
"Sort enum for {} has already been customized".format(obj_type)
132+
)
133+
else:
134+
members = []
135+
default = []
136+
fields = obj_type._meta.fields
137+
get_name = get_symbol_name or _default_sort_enum_symbol_name
138+
for field_name in fields:
139+
if only_fields and field_name not in only_fields:
140+
continue
141+
orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
142+
if not isinstance(orm_field, Column):
143+
continue
144+
if only_indexed and not (orm_field.primary_key or orm_field.index):
145+
continue
146+
asc_name = get_name(orm_field.name, True)
147+
asc_value = EnumValue(asc_name, orm_field.asc())
148+
desc_name = get_name(orm_field.name, False)
149+
desc_value = EnumValue(desc_name, orm_field.desc())
150+
if orm_field.primary_key:
151+
default.append(asc_value)
152+
members.extend(((asc_name, asc_value), (desc_name, desc_value)))
153+
enum = Enum(name, members)
154+
enum.default = default # store default as attribute
155+
enum.custom_options = custom_options
156+
registry.register_sort_enum_for_object_type(obj_type, enum)
157+
return enum
158+
159+
160+
def sort_argument_for_object_type(
161+
obj_type,
162+
enum_name=None,
163+
only_fields=None,
164+
only_indexed=None,
165+
get_symbol_name=None,
166+
has_default=True,
167+
):
168+
""""Returns Graphene Argument for sorting the given SQLAlchemyObjectType.
169+
170+
Parameters
171+
- obj_type : SQLAlchemyObjectType
172+
The object type for which the sort Argument shall be generated.
173+
- enum_name : str, optional, default None
174+
Name to use for the sort Enum.
175+
If not provided, it will be set to the object type name + 'SortEnum'
176+
- only_fields : sequence, optional, default None
177+
If this is set, only fields from this sequence will be considered.
178+
- only_indexed : bool, optional, default False
179+
If this is set, only indexed columns will be considered.
180+
- get_symbol_name : function, optional, default None
181+
Function which takes the column name and a boolean indicating
182+
if the sort direction is ascending, and returns the symbol name
183+
for the current column and sort direction. If no such function
184+
is passed, a default function will be used that creates the symbols
185+
'foo_asc' and 'foo_desc' for a column with the name 'foo'.
186+
- has_default : bool, optional, default True
187+
If this is set to False, no sorting will happen when this argument is not
188+
passed. Otherwise results will be sortied by the primary key(s) of the model.
189+
190+
Returns
191+
- Enum
192+
A Graphene Argument that accepts a list of sorting directions for the model.
193+
"""
194+
enum = sort_enum_for_object_type(
195+
obj_type,
196+
enum_name,
197+
only_fields=only_fields,
198+
only_indexed=only_indexed,
199+
get_symbol_name=get_symbol_name,
200+
)
201+
if not has_default:
202+
enum.default = None
203+
204+
return Argument(List(enum), default_value=enum.default)

Diff for: graphene_sqlalchemy/fields.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from graphene.relay.connection import PageInfo
99
from graphql_relay.connection.arrayconnection import connection_from_list_slice
1010

11-
from .utils import get_query, sort_argument_for_model
11+
from .utils import get_query
1212

1313
log = logging.getLogger()
1414

@@ -84,10 +84,9 @@ def __init__(self, type, *args, **kwargs):
8484
if "sort" not in kwargs and issubclass(type, Connection):
8585
# Let super class raise if type is not a Connection
8686
try:
87-
model = type.Edge.node._type._meta.model
88-
kwargs.setdefault("sort", sort_argument_for_model(model))
89-
except Exception:
90-
raise Exception(
87+
kwargs.setdefault("sort", type.Edge.node._type.sort_argument())
88+
except (AttributeError, TypeError):
89+
raise TypeError(
9190
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
9291
" to None to disabling the creation of the sort query argument".format(
9392
type.__name__

0 commit comments

Comments
 (0)