Skip to content

Commit 495efa8

Browse files
committed
[WIP] Relay Node/Connection support.
1 parent 5e4a1c6 commit 495efa8

17 files changed

+871
-11
lines changed

epoxy/contrib/__init__.py

Whitespace-only changes.

epoxy/contrib/relay/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
__author__ = 'jake'
2+
3+
from .mixin import RelayMixin
4+
5+
__all__ = ['RelayMixin']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from graphql.core.type import GraphQLArgument, GraphQLInt, GraphQLString
2+
3+
__author__ = 'jake'
4+
5+
connection_args = {
6+
'before': GraphQLArgument(GraphQLString),
7+
'after': GraphQLArgument(GraphQLString),
8+
'first': GraphQLArgument(GraphQLInt),
9+
'last': GraphQLArgument(GraphQLInt),
10+
}
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from ..utils import base64, unbase64
2+
3+
4+
class CursorFactory(object):
5+
def __init__(self, prefix):
6+
self.prefix = prefix
7+
self.cursor_type = int
8+
self.max_cursor_length = 10
9+
10+
def from_offset(self, offset):
11+
"""
12+
Creates the cursor string from an offset.
13+
"""
14+
return base64(self.prefix + str(offset))
15+
16+
def to_offset(self, cursor):
17+
"""
18+
Rederives the offset from the cursor string.
19+
"""
20+
try:
21+
return self.cursor_type(unbase64(cursor)[len(self.prefix):len(self.prefix) + self.max_cursor_length])
22+
except:
23+
return None
24+
25+
def get_offset(self, cursor, default_offset=0):
26+
"""
27+
Given an optional cursor and a default offset, returns the offset
28+
to use; if the cursor contains a valid offset, that will be used,
29+
otherwise it will be the default.
30+
"""
31+
if cursor is None:
32+
return default_offset
33+
34+
offset = self.to_offset(cursor)
35+
try:
36+
return self.cursor_type(offset)
37+
except:
38+
return default_offset
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from bisect import bisect_left, bisect_right
2+
from .cursor import CursorFactory
3+
4+
cursor = CursorFactory('sc:')
5+
6+
7+
class SortedCollection(object):
8+
"""Sequence sorted by a key function.
9+
10+
SortedCollection() is much easier to work with than using bisect() directly.
11+
It supports key functions like those use in sorted(), min(), and max().
12+
The result of the key function call is saved so that keys can be searched
13+
efficiently.
14+
15+
Instead of returning an insertion-point which can be hard to interpret, the
16+
five find-methods return a specific item in the sequence. They can scan for
17+
exact matches, the last item less-than-or-equal to a key, or the first item
18+
greater-than-or-equal to a key.
19+
20+
Once found, an item's ordinal position can be located with the index() method.
21+
New items can be added with the insert() and insert_right() methods.
22+
Old items can be deleted with the remove() method.
23+
24+
The usual sequence methods are provided to support indexing, slicing,
25+
length lookup, clearing, copying, forward and reverse iteration, contains
26+
checking, item counts, item removal, and a nice looking repr.
27+
28+
Finding and indexing are O(log n) operations while iteration and insertion
29+
are O(n). The initial sort is O(n log n).
30+
31+
The key function is stored in the 'key' attribute for easy introspection or
32+
so that you can assign a new key function (triggering an automatic re-sort).
33+
34+
In short, the class was designed to handle all of the common use cases for
35+
bisect but with a simpler API and support for key functions.
36+
37+
>>> from pprint import pprint
38+
>>> from operator import itemgetter
39+
40+
>>> s = SortedCollection(key=itemgetter(2))
41+
>>> for record in [
42+
... ('roger', 'young', 30),
43+
... ('angela', 'jones', 28),
44+
... ('bill', 'smith', 22),
45+
... ('david', 'thomas', 32)]:
46+
... s.insert(record)
47+
48+
>>> pprint(list(s)) # show records sorted by age
49+
[('bill', 'smith', 22),
50+
('angela', 'jones', 28),
51+
('roger', 'young', 30),
52+
('david', 'thomas', 32)]
53+
54+
>>> s.find_le(29) # find oldest person aged 29 or younger
55+
('angela', 'jones', 28)
56+
>>> s.find_lt(28) # find oldest person under 28
57+
('bill', 'smith', 22)
58+
>>> s.find_gt(28) # find youngest person over 28
59+
('roger', 'young', 30)
60+
61+
>>> r = s.find_ge(32) # find youngest person aged 32 or older
62+
>>> s.index(r) # get the index of their record
63+
3
64+
>>> s[3] # fetch the record at that index
65+
('david', 'thomas', 32)
66+
67+
>>> s.key = itemgetter(0) # now sort by first name
68+
>>> pprint(list(s))
69+
[('angela', 'jones', 28),
70+
('bill', 'smith', 22),
71+
('david', 'thomas', 32),
72+
('roger', 'young', 30)]
73+
74+
"""
75+
76+
def __init__(self, key=None):
77+
self._given_key = key
78+
key = (lambda x: x) if key is None else key
79+
self._keys = []
80+
self._items = []
81+
self._key = key
82+
83+
def clear(self):
84+
self._keys = []
85+
self._items = []
86+
87+
def copy(self):
88+
cls = self.__class__(key=self._key)
89+
cls._items = self._items[:]
90+
cls._keys = self._keys[:]
91+
92+
def __len__(self):
93+
return len(self._items)
94+
95+
def __getitem__(self, i):
96+
return self._items[i]
97+
98+
def __iter__(self):
99+
return iter(self._items)
100+
101+
def __reversed__(self):
102+
return reversed(self._items)
103+
104+
def __repr__(self):
105+
return '%s(%r, key=%s)' % (
106+
self.__class__.__name__,
107+
self._items,
108+
getattr(self._given_key, '__name__', repr(self._given_key))
109+
)
110+
111+
def __reduce__(self):
112+
return self.__class__, (self._items, self._given_key)
113+
114+
def __contains__(self, item):
115+
k = self._key(item)
116+
i = bisect_left(self._keys, k)
117+
j = bisect_right(self._keys, k)
118+
return item in self._items[i:j]
119+
120+
def index(self, item):
121+
"""Find the position of an item. Raise ValueError if not found.'"""
122+
k = self._key(item)
123+
i = bisect_left(self._keys, k)
124+
j = bisect_right(self._keys, k)
125+
return self._items[i:j].index(item) + i
126+
127+
def count(self, item):
128+
"""Return number of occurrences of item'"""
129+
k = self._key(item)
130+
i = bisect_left(self._keys, k)
131+
j = bisect_right(self._keys, k)
132+
return self._items[i:j].count(item)
133+
134+
def insert(self, item):
135+
"""Insert a new item. If equal keys are found, add to the left'"""
136+
k = self._key(item)
137+
i = bisect_left(self._keys, k)
138+
if i != len(self) and self._keys[i] == k:
139+
raise ValueError(u'An item with the same key {} already exists in this collection.'.format(k))
140+
141+
self._keys.insert(i, k)
142+
self._items.insert(i, item)
143+
144+
def remove(self, item):
145+
"""Remove first occurrence of item. Raise ValueError if not found'"""
146+
i = self.index(item)
147+
del self._keys[i]
148+
del self._items[i]
149+
150+
def bisect_left(self, k):
151+
return bisect_left(self._keys, k)
152+
153+
def bisect_right(self, k):
154+
return bisect_right(self._keys, k)
155+
156+
@staticmethod
157+
def empty_connection(relay, type_name):
158+
Connection, Edge = relay.get_connection_and_edge_types(type_name)
159+
160+
return Connection(
161+
edges=[],
162+
page_info=relay.PageInfo(
163+
start_cursor=None,
164+
end_cursor=None,
165+
has_previous_page=False,
166+
has_next_page=False,
167+
)
168+
)
169+
170+
def get_connection(self, relay, type_name, args):
171+
Connection, Edge = relay.get_connection_and_edge_types(type_name)
172+
before = args.get('before')
173+
after = args.get('after')
174+
first = args.get('first')
175+
last = args.get('last')
176+
177+
count = len(self)
178+
if not count:
179+
return self.empty_connection(relay, type_name)
180+
181+
begin_key = cursor.get_offset(after, None) or self._keys[0]
182+
end_key = cursor.get_offset(before, None) or self._keys[-1]
183+
184+
begin = self.bisect_left(begin_key)
185+
end = self.bisect_right(end_key)
186+
187+
if begin >= count or begin >= end:
188+
return self.empty_connection(relay, type_name)
189+
190+
first_preslice_cursor = cursor.from_offset(self._keys[begin])
191+
last_preslice_cursor = cursor.from_offset(self._keys[min(end, count) - 1])
192+
193+
if first is not None:
194+
end = min(begin + first, end)
195+
if last is not None:
196+
begin = max(end - last, begin)
197+
198+
if begin >= count or begin >= end:
199+
return self.empty_connection(relay, type_name)
200+
201+
sliced_data = self._items[begin:end]
202+
203+
edges = [Edge(node=node, cursor=cursor.from_offset(self._key(node))) for node in sliced_data]
204+
first_edge = edges[0]
205+
last_edge = edges[-1]
206+
207+
return Connection(
208+
edges=edges,
209+
page_info=relay.PageInfo(
210+
start_cursor=first_edge.cursor,
211+
end_cursor=last_edge.cursor,
212+
has_previous_page=(first_edge.cursor != first_preslice_cursor),
213+
has_next_page=(last_edge.cursor != last_preslice_cursor)
214+
)
215+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__author__ = 'jake'
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from collections import defaultdict
2+
3+
from operator import attrgetter
4+
from ..connections.sorted_collection import SortedCollection
5+
6+
7+
class InMemoryDataSource(object):
8+
def __init__(self):
9+
self.objects_by_type_and_id = defaultdict(dict)
10+
self.objects_by_type = defaultdict(lambda: SortedCollection(key=attrgetter('id')))
11+
12+
def add(self, obj):
13+
self.objects_by_type_and_id[obj.T][obj.id] = obj
14+
self.objects_by_type[obj.T].insert(obj)
15+
16+
def fetch_node(self, object_type, id, resolve_info):
17+
return self.objects_by_type_and_id[object_type].get(id)
18+
19+
def get_collection(self, object_type):
20+
return self.objects_by_type[object_type()]
21+
22+
def make_connection_resolver(self, relay, object_type_thunk):
23+
def resolver(obj, args, info):
24+
object_type = relay.R[object_type_thunk]()
25+
return self.objects_by_type[object_type].get_connection(relay, object_type.name, args)
26+
27+
return resolver

epoxy/contrib/relay/mixin.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from graphql.core.type import GraphQLArgument
2+
from graphql.core.type.definition import GraphQLObjectType
3+
from .connections import connection_args
4+
from .utils import base64, unbase64
5+
6+
7+
class RelayMixin(object):
8+
def __init__(self, registry):
9+
self.R = registry
10+
self._node_field = None
11+
self._connections = {}
12+
self.data_source = None
13+
14+
@property
15+
def NodeField(self):
16+
return self.R.Field(
17+
self.R.Node,
18+
description='Fetches an object given its ID',
19+
args={
20+
'id': GraphQLArgument(
21+
self.R.ID.NonNull(),
22+
description='The ID of an object'
23+
)
24+
},
25+
resolver=lambda obj, args, info: self.fetch_node(args.get('id'), info)
26+
)
27+
28+
def set_data_source(self, data_source):
29+
self.data_source = data_source
30+
31+
def get_connection_and_edge_types(self, type_name):
32+
return self._connections[type_name]
33+
34+
def register_types(self):
35+
R = self.R
36+
37+
class Node(R.Interface):
38+
id = R.ID.NonNull(description='The id of the object.')
39+
40+
resolve_id = self.resolve_node_id
41+
42+
class PageInfo(R.ObjectType):
43+
has_next_page = R.Boolean.NonNull(description='When paginating forwards, are there more items?')
44+
has_previous_page = R.Boolean.NonNull(description='When paginating backwards, are there more items?')
45+
start_cursor = R.String(description='When paginating backwards, the cursor to continue.')
46+
end_cursor = R.String(description='When paginating forwards, the cursor to continue.')
47+
48+
self.Node = Node
49+
self.PageInfo = PageInfo
50+
51+
def fetch_node(self, id, info):
52+
object_type_name, object_id = unbase64(id).split(':', 1)
53+
object_type = self.R[object_type_name]()
54+
assert isinstance(object_type, GraphQLObjectType)
55+
return self.data_source.fetch_node(object_type, object_id, info)
56+
57+
def resolve_node_id(self, obj, args, info):
58+
type = self.R.Node().resolve_type(obj, info)
59+
return base64('%s:%s' % (type, obj.id))
60+
61+
def connection_definitions(self, name, object_type):
62+
R = self.R
63+
64+
if name in self._connections:
65+
return self._connections[name]
66+
67+
class Edge(R.ObjectType):
68+
_name = '{}Edge'.format(name)
69+
node = R[object_type](description='The item at the end of the edge')
70+
cursor = R.String.NonNull(description='A cursor for use in pagination')
71+
72+
class Connection(R.ObjectType):
73+
_name = '{}Connection'.format(name)
74+
75+
page_info = R.PageInfo.NonNull
76+
edges = R[Edge].List
77+
78+
self._connections[name] = Connection, Edge
79+
return Connection, Edge
80+
81+
def Connection(self, name, object_type, args=None, **kwargs):
82+
args = args or {}
83+
args.update(connection_args)
84+
field = self.R.Field(self.connection_definitions(name, object_type)[0], args=args, **kwargs)
85+
return field

0 commit comments

Comments
 (0)