Skip to content

Commit 43df4eb

Browse files
feat: Support Sorting in Batch ConnectionFields & Deprecate UnsortedConnectionField (#355)
* Enable sorting when batching is enabled * Deprecate UnsortedSQLAlchemyConnectionField and resetting RelationshipLoader between queries * Use field_name instead of column.key to build sort enum names to ensure the enum will get the actual field_name * Adjust batching test to honor different select in query structure in sqla1.2 * Ensure that UnsortedSQLAlchemyConnectionField skips sort argument if it gets passed. * add test for batch sorting with custom ormfield Co-authored-by: Sabar Dasgupta <[email protected]>
1 parent bb7af4b commit 43df4eb

File tree

6 files changed

+534
-257
lines changed

6 files changed

+534
-257
lines changed

Diff for: graphene_sqlalchemy/batching.py

+103-75
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The dataloader uses "select in loading" strategy to load related entities."""
2-
from typing import Any
2+
from asyncio import get_event_loop
3+
from typing import Any, Dict
34

45
import aiodataloader
56
import sqlalchemy
@@ -10,6 +11,90 @@
1011
is_sqlalchemy_version_less_than)
1112

1213

14+
class RelationshipLoader(aiodataloader.DataLoader):
15+
cache = False
16+
17+
def __init__(self, relationship_prop, selectin_loader):
18+
super().__init__()
19+
self.relationship_prop = relationship_prop
20+
self.selectin_loader = selectin_loader
21+
22+
async def batch_load_fn(self, parents):
23+
"""
24+
Batch loads the relationships of all the parents as one SQL statement.
25+
26+
There is no way to do this out-of-the-box with SQLAlchemy but
27+
we can piggyback on some internal APIs of the `selectin`
28+
eager loading strategy. It's a bit hacky but it's preferable
29+
than re-implementing and maintainnig a big chunk of the `selectin`
30+
loader logic ourselves.
31+
32+
The approach here is to build a regular query that
33+
selects the parent and `selectin` load the relationship.
34+
But instead of having the query emits 2 `SELECT` statements
35+
when callling `all()`, we skip the first `SELECT` statement
36+
and jump right before the `selectin` loader is called.
37+
To accomplish this, we have to construct objects that are
38+
normally built in the first part of the query in order
39+
to call directly `SelectInLoader._load_for_path`.
40+
41+
TODO Move this logic to a util in the SQLAlchemy repo as per
42+
SQLAlchemy's main maitainer suggestion.
43+
See https://git.io/JewQ7
44+
"""
45+
child_mapper = self.relationship_prop.mapper
46+
parent_mapper = self.relationship_prop.parent
47+
session = Session.object_session(parents[0])
48+
49+
# These issues are very unlikely to happen in practice...
50+
for parent in parents:
51+
# assert parent.__mapper__ is parent_mapper
52+
# All instances must share the same session
53+
assert session is Session.object_session(parent)
54+
# The behavior of `selectin` is undefined if the parent is dirty
55+
assert parent not in session.dirty
56+
57+
# Should the boolean be set to False? Does it matter for our purposes?
58+
states = [(sqlalchemy.inspect(parent), True) for parent in parents]
59+
60+
# For our purposes, the query_context will only used to get the session
61+
query_context = None
62+
if is_sqlalchemy_version_less_than('1.4'):
63+
query_context = QueryContext(session.query(parent_mapper.entity))
64+
else:
65+
parent_mapper_query = session.query(parent_mapper.entity)
66+
query_context = parent_mapper_query._compile_context()
67+
68+
if is_sqlalchemy_version_less_than('1.4'):
69+
self.selectin_loader._load_for_path(
70+
query_context,
71+
parent_mapper._path_registry,
72+
states,
73+
None,
74+
child_mapper,
75+
)
76+
else:
77+
self.selectin_loader._load_for_path(
78+
query_context,
79+
parent_mapper._path_registry,
80+
states,
81+
None,
82+
child_mapper,
83+
None,
84+
)
85+
return [
86+
getattr(parent, self.relationship_prop.key) for parent in parents
87+
]
88+
89+
90+
# Cache this across `batch_load_fn` calls
91+
# This is so SQL string generation is cached under-the-hood via `bakery`
92+
# Caching the relationship loader for each relationship prop.
93+
RELATIONSHIP_LOADERS_CACHE: Dict[
94+
sqlalchemy.orm.relationships.RelationshipProperty, RelationshipLoader
95+
] = {}
96+
97+
1398
def get_data_loader_impl() -> Any: # pragma: no cover
1499
"""Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility,
15100
aiodataloader is used in conjunction with older versions of graphene"""
@@ -25,80 +110,23 @@ def get_data_loader_impl() -> Any: # pragma: no cover
25110

26111

27112
def get_batch_resolver(relationship_prop):
28-
# Cache this across `batch_load_fn` calls
29-
# This is so SQL string generation is cached under-the-hood via `bakery`
30-
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))
31-
32-
class RelationshipLoader(aiodataloader.DataLoader):
33-
cache = False
34-
35-
async def batch_load_fn(self, parents):
36-
"""
37-
Batch loads the relationships of all the parents as one SQL statement.
38-
39-
There is no way to do this out-of-the-box with SQLAlchemy but
40-
we can piggyback on some internal APIs of the `selectin`
41-
eager loading strategy. It's a bit hacky but it's preferable
42-
than re-implementing and maintainnig a big chunk of the `selectin`
43-
loader logic ourselves.
44-
45-
The approach here is to build a regular query that
46-
selects the parent and `selectin` load the relationship.
47-
But instead of having the query emits 2 `SELECT` statements
48-
when callling `all()`, we skip the first `SELECT` statement
49-
and jump right before the `selectin` loader is called.
50-
To accomplish this, we have to construct objects that are
51-
normally built in the first part of the query in order
52-
to call directly `SelectInLoader._load_for_path`.
53-
54-
TODO Move this logic to a util in the SQLAlchemy repo as per
55-
SQLAlchemy's main maitainer suggestion.
56-
See https://git.io/JewQ7
57-
"""
58-
child_mapper = relationship_prop.mapper
59-
parent_mapper = relationship_prop.parent
60-
session = Session.object_session(parents[0])
61-
62-
# These issues are very unlikely to happen in practice...
63-
for parent in parents:
64-
# assert parent.__mapper__ is parent_mapper
65-
# All instances must share the same session
66-
assert session is Session.object_session(parent)
67-
# The behavior of `selectin` is undefined if the parent is dirty
68-
assert parent not in session.dirty
69-
70-
# Should the boolean be set to False? Does it matter for our purposes?
71-
states = [(sqlalchemy.inspect(parent), True) for parent in parents]
72-
73-
# For our purposes, the query_context will only used to get the session
74-
query_context = None
75-
if is_sqlalchemy_version_less_than('1.4'):
76-
query_context = QueryContext(session.query(parent_mapper.entity))
77-
else:
78-
parent_mapper_query = session.query(parent_mapper.entity)
79-
query_context = parent_mapper_query._compile_context()
80-
81-
if is_sqlalchemy_version_less_than('1.4'):
82-
selectin_loader._load_for_path(
83-
query_context,
84-
parent_mapper._path_registry,
85-
states,
86-
None,
87-
child_mapper
88-
)
89-
else:
90-
selectin_loader._load_for_path(
91-
query_context,
92-
parent_mapper._path_registry,
93-
states,
94-
None,
95-
child_mapper,
96-
None
97-
)
98-
99-
return [getattr(parent, relationship_prop.key) for parent in parents]
100-
101-
loader = RelationshipLoader()
113+
"""Get the resolve function for the given relationship."""
114+
115+
def _get_loader(relationship_prop):
116+
"""Retrieve the cached loader of the given relationship."""
117+
loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None)
118+
if loader is None or loader.loop != get_event_loop():
119+
selectin_loader = strategies.SelectInLoader(
120+
relationship_prop, (('lazy', 'selectin'),)
121+
)
122+
loader = RelationshipLoader(
123+
relationship_prop=relationship_prop,
124+
selectin_loader=selectin_loader,
125+
)
126+
RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader
127+
return loader
128+
129+
loader = _get_loader(relationship_prop)
102130

103131
async def resolve(root, info, **args):
104132
return await loader.load(root)

Diff for: graphene_sqlalchemy/enums.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ def sort_enum_for_object_type(
144144
column = orm_field.columns[0]
145145
if only_indexed and not (column.primary_key or column.index):
146146
continue
147-
asc_name = get_name(column.key, True)
147+
asc_name = get_name(field_name, True)
148148
asc_value = EnumValue(asc_name, column.asc())
149-
desc_name = get_name(column.key, False)
149+
desc_name = get_name(field_name, False)
150150
desc_value = EnumValue(desc_name, column.desc())
151151
if column.primary_key:
152152
default.append(asc_value)

Diff for: graphene_sqlalchemy/fields.py

+69-47
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .utils import EnumValue, get_query
1515

1616

17-
class UnsortedSQLAlchemyConnectionField(ConnectionField):
17+
class SQLAlchemyConnectionField(ConnectionField):
1818
@property
1919
def type(self):
2020
from .types import SQLAlchemyObjectType
@@ -37,13 +37,45 @@ def type(self):
3737
)
3838
return nullable_type.connection
3939

40+
def __init__(self, type_, *args, **kwargs):
41+
nullable_type = get_nullable_type(type_)
42+
if "sort" not in kwargs and nullable_type and issubclass(nullable_type, Connection):
43+
# Let super class raise if type is not a Connection
44+
try:
45+
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
46+
except (AttributeError, TypeError):
47+
raise TypeError(
48+
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
49+
" to None to disabling the creation of the sort query argument".format(
50+
nullable_type.__name__
51+
)
52+
)
53+
elif "sort" in kwargs and kwargs["sort"] is None:
54+
del kwargs["sort"]
55+
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)
56+
4057
@property
4158
def model(self):
4259
return get_nullable_type(self.type)._meta.node._meta.model
4360

4461
@classmethod
45-
def get_query(cls, model, info, **args):
46-
return get_query(model, info.context)
62+
def get_query(cls, model, info, sort=None, **args):
63+
query = get_query(model, info.context)
64+
if sort is not None:
65+
if not isinstance(sort, list):
66+
sort = [sort]
67+
sort_args = []
68+
# ensure consistent handling of graphene Enums, enum values and
69+
# plain strings
70+
for item in sort:
71+
if isinstance(item, enum.Enum):
72+
sort_args.append(item.value.value)
73+
elif isinstance(item, EnumValue):
74+
sort_args.append(item.value)
75+
else:
76+
sort_args.append(item)
77+
query = query.order_by(*sort_args)
78+
return query
4779

4880
@classmethod
4981
def resolve_connection(cls, connection_type, model, info, args, resolved):
@@ -90,59 +122,49 @@ def wrap_resolve(self, parent_resolver):
90122
)
91123

92124

93-
# TODO Rename this to SortableSQLAlchemyConnectionField
94-
class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
125+
# TODO Remove in next major version
126+
class UnsortedSQLAlchemyConnectionField(SQLAlchemyConnectionField):
95127
def __init__(self, type_, *args, **kwargs):
96-
nullable_type = get_nullable_type(type_)
97-
if "sort" not in kwargs and issubclass(nullable_type, Connection):
98-
# Let super class raise if type is not a Connection
99-
try:
100-
kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
101-
except (AttributeError, TypeError):
102-
raise TypeError(
103-
'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
104-
" to None to disabling the creation of the sort query argument".format(
105-
nullable_type.__name__
106-
)
107-
)
108-
elif "sort" in kwargs and kwargs["sort"] is None:
109-
del kwargs["sort"]
110-
super(SQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)
111-
112-
@classmethod
113-
def get_query(cls, model, info, sort=None, **args):
114-
query = get_query(model, info.context)
115-
if sort is not None:
116-
if not isinstance(sort, list):
117-
sort = [sort]
118-
sort_args = []
119-
# ensure consistent handling of graphene Enums, enum values and
120-
# plain strings
121-
for item in sort:
122-
if isinstance(item, enum.Enum):
123-
sort_args.append(item.value.value)
124-
elif isinstance(item, EnumValue):
125-
sort_args.append(item.value)
126-
else:
127-
sort_args.append(item)
128-
query = query.order_by(*sort_args)
129-
return query
128+
if "sort" in kwargs and kwargs["sort"] is not None:
129+
warnings.warn(
130+
"UnsortedSQLAlchemyConnectionField does not support sorting. "
131+
"All sorting arguments will be ignored."
132+
)
133+
kwargs["sort"] = None
134+
warnings.warn(
135+
"UnsortedSQLAlchemyConnectionField is deprecated and will be removed in the next "
136+
"major version. Use SQLAlchemyConnectionField instead and either don't "
137+
"provide the `sort` argument or set it to None if you do not want sorting.",
138+
DeprecationWarning,
139+
)
140+
super(UnsortedSQLAlchemyConnectionField, self).__init__(type_, *args, **kwargs)
130141

131142

132-
class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
143+
class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField):
133144
"""
134145
This is currently experimental.
135146
The API and behavior may change in future versions.
136147
Use at your own risk.
137148
"""
138149

139-
def wrap_resolve(self, parent_resolver):
140-
return partial(
141-
self.connection_resolver,
142-
self.resolver,
143-
get_nullable_type(self.type),
144-
self.model,
145-
)
150+
@classmethod
151+
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
152+
if root is None:
153+
resolved = resolver(root, info, **args)
154+
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
155+
else:
156+
relationship_prop = None
157+
for relationship in root.__class__.__mapper__.relationships:
158+
if relationship.mapper.class_ == model:
159+
relationship_prop = relationship
160+
break
161+
resolved = get_batch_resolver(relationship_prop)(root, info, **args)
162+
on_resolve = partial(cls.resolve_connection, connection_type, root, info, args)
163+
164+
if is_thenable(resolved):
165+
return Promise.resolve(resolved).then(on_resolve)
166+
167+
return on_resolve(resolved)
146168

147169
@classmethod
148170
def from_relationship(cls, relationship, registry, **field_kwargs):

Diff for: graphene_sqlalchemy/tests/models.py

+18
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,24 @@ class Article(Base):
110110
headline = Column(String(100))
111111
pub_date = Column(Date())
112112
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
113+
readers = relationship(
114+
"Reader", secondary="articles_readers", back_populates="articles"
115+
)
116+
117+
118+
class Reader(Base):
119+
__tablename__ = "readers"
120+
id = Column(Integer(), primary_key=True)
121+
name = Column(String(100))
122+
articles = relationship(
123+
"Article", secondary="articles_readers", back_populates="readers"
124+
)
125+
126+
127+
class ArticleReader(Base):
128+
__tablename__ = "articles_readers"
129+
article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True)
130+
reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True)
113131

114132

115133
class ReflectedEditor(type):

0 commit comments

Comments
 (0)