|
| 1 | +import enum |
1 | 2 | import warnings
|
2 | 3 | from functools import partial
|
3 | 4 |
|
|
6 | 7 |
|
7 | 8 | from graphene import NonNull
|
8 | 9 | from graphene.relay import Connection, ConnectionField
|
9 |
| -from graphene.relay.connection import PageInfo, connection_adapter, page_info_adapter |
10 |
| -from graphql_relay.connection.arrayconnection import connection_from_array_slice |
| 10 | +from graphene.relay.connection import (PageInfo, connection_adapter, |
| 11 | + page_info_adapter) |
| 12 | +from graphql_relay.connection.arrayconnection import \ |
| 13 | + connection_from_array_slice |
11 | 14 |
|
12 | 15 | from .batching import get_batch_resolver
|
13 |
| -from .utils import get_query |
| 16 | +from .utils import EnumValue, get_query |
14 | 17 |
|
15 | 18 |
|
16 | 19 | class UnsortedSQLAlchemyConnectionField(ConnectionField):
|
@@ -112,10 +115,19 @@ def __init__(self, type_, *args, **kwargs):
|
112 | 115 | def get_query(cls, model, info, sort=None, **args):
|
113 | 116 | query = get_query(model, info.context)
|
114 | 117 | if sort is not None:
|
115 |
| - if isinstance(sort, str): |
116 |
| - query = query.order_by(sort.value) |
117 |
| - else: |
118 |
| - query = query.order_by(*(col.value for col in sort)) |
| 118 | + if not isinstance(sort, list): |
| 119 | + sort = [sort] |
| 120 | + sort_args = [] |
| 121 | + # ensure consistent handling of graphene Enums, enum values and |
| 122 | + # plain strings |
| 123 | + for item in sort: |
| 124 | + if isinstance(item, enum.Enum): |
| 125 | + sort_args.append(item.value.value) |
| 126 | + elif isinstance(item, EnumValue): |
| 127 | + sort_args.append(item.value) |
| 128 | + else: |
| 129 | + sort_args.append(item) |
| 130 | + query = query.order_by(*sort_args) |
119 | 131 | return query
|
120 | 132 |
|
121 | 133 |
|
|
0 commit comments