9
9
from graphene .relay .connection import PageInfo
10
10
from graphql_relay .connection .arrayconnection import connection_from_list_slice
11
11
12
+ from .batching import get_batch_resolver
12
13
from .utils import get_query
13
14
14
15
@@ -33,14 +34,8 @@ def model(self):
33
34
return self .type ._meta .node ._meta .model
34
35
35
36
@classmethod
36
- def get_query (cls , model , info , sort = None , ** args ):
37
- query = get_query (model , info .context )
38
- if sort is not None :
39
- if isinstance (sort , six .string_types ):
40
- query = query .order_by (sort .value )
41
- else :
42
- query = query .order_by (* (col .value for col in sort ))
43
- return query
37
+ def get_query (cls , model , info , ** args ):
38
+ return get_query (model , info .context )
44
39
45
40
@classmethod
46
41
def resolve_connection (cls , connection_type , model , info , args , resolved ):
@@ -78,6 +73,7 @@ def get_resolver(self, parent_resolver):
78
73
return partial (self .connection_resolver , parent_resolver , self .type , self .model )
79
74
80
75
76
+ # TODO Rename this to SortableSQLAlchemyConnectionField
81
77
class SQLAlchemyConnectionField (UnsortedSQLAlchemyConnectionField ):
82
78
def __init__ (self , type , * args , ** kwargs ):
83
79
if "sort" not in kwargs and issubclass (type , Connection ):
@@ -95,11 +91,37 @@ def __init__(self, type, *args, **kwargs):
95
91
del kwargs ["sort" ]
96
92
super (SQLAlchemyConnectionField , self ).__init__ (type , * args , ** kwargs )
97
93
94
+ @classmethod
95
+ def get_query (cls , model , info , sort = None , ** args ):
96
+ query = get_query (model , info .context )
97
+ if sort is not None :
98
+ if isinstance (sort , six .string_types ):
99
+ query = query .order_by (sort .value )
100
+ else :
101
+ query = query .order_by (* (col .value for col in sort ))
102
+ return query
103
+
104
+
105
+ class BatchSQLAlchemyConnectionField (UnsortedSQLAlchemyConnectionField ):
106
+ """
107
+ This is currently experimental.
108
+ The API and behavior may change in future versions.
109
+ Use at your own risk.
110
+ """
111
+ def get_resolver (self , parent_resolver ):
112
+ return partial (self .connection_resolver , self .resolver , self .type , self .model )
113
+
114
+ @classmethod
115
+ def from_relationship (cls , relationship , registry , ** field_kwargs ):
116
+ model = relationship .mapper .entity
117
+ model_type = registry .get_type_for_model (model )
118
+ return cls (model_type ._meta .connection , resolver = get_batch_resolver (relationship ), ** field_kwargs )
119
+
98
120
99
121
def default_connection_field_factory (relationship , registry , ** field_kwargs ):
100
122
model = relationship .mapper .entity
101
123
model_type = registry .get_type_for_model (model )
102
- return __connectionFactory (model_type , ** field_kwargs )
124
+ return __connectionFactory (model_type . _meta . connection , ** field_kwargs )
103
125
104
126
105
127
# TODO Remove in next major version
0 commit comments