Skip to content

Commit 0255582

Browse files
authored
Merge pull request #1 from 311devs/msgpack_hooks_possibility
msgpack_hooks_possibility
2 parents bd37703 + 81ba98a commit 0255582

File tree

5 files changed

+132
-19
lines changed

5 files changed

+132
-19
lines changed

tarantool/connection.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,12 @@ def __init__(self, host, port,
9191
connect_now=True,
9292
encoding=ENCODING_DEFAULT,
9393
call_16=False,
94-
connection_timeout=CONNECTION_TIMEOUT):
94+
connection_timeout=CONNECTION_TIMEOUT,
95+
pack_default=None,
96+
unpack_object_hook=None,
97+
unpack_list_hook=None,
98+
unpack_object_pairs_hook=None,
99+
unpack_ext_hook=None):
95100
'''
96101
Initialize a connection to the server.
97102
@@ -126,6 +131,13 @@ def __init__(self, host, port,
126131
self.encoding = encoding
127132
self.call_16 = call_16
128133
self.connection_timeout = connection_timeout
134+
self.pack_default = pack_default
135+
self.unpack_hooks = {
136+
"object_hook": unpack_object_hook,
137+
"list_hook": unpack_list_hook,
138+
"object_pairs_hook": unpack_object_pairs_hook,
139+
"ext_hook": unpack_ext_hook,
140+
}
129141
if connect_now:
130142
self.connect()
131143

tarantool/request.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(self, conn):
6464
self.conn = conn
6565
self._sync = None
6666
self._body = ''
67+
self.pack_default = getattr(conn, "pack_default", None)
6768

6869
def __bytes__(self):
6970
return self.header(len(self._body)) + self._body
@@ -88,6 +89,10 @@ def header(self, length):
8889

8990
return msgpack.dumps(length + len(header)) + header
9091

92+
def msgpack_dumps(self, obj):
93+
return msgpack.dumps(obj, default=self.pack_default)
94+
95+
9196

9297
class RequestInsert(Request):
9398
'''
@@ -102,7 +107,7 @@ def __init__(self, conn, space_no, values):
102107
super(RequestInsert, self).__init__(conn)
103108
assert isinstance(values, (tuple, list))
104109

105-
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
110+
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
106111
IPROTO_TUPLE: values})
107112

108113
self._body = request_body
@@ -131,19 +136,19 @@ def sha1(values):
131136
hash2 = sha1((hash1,))
132137
scramble = sha1((salt, hash2))
133138
scramble = strxor(hash1, scramble)
134-
request_body = msgpack.dumps({IPROTO_USER_NAME: user,
139+
request_body = self.msgpack_dumps({IPROTO_USER_NAME: user,
135140
IPROTO_TUPLE: ("chap-sha1", scramble)})
136141
self._body = request_body
137142

138143
def header(self, length):
139144
self._sync = self.conn.generate_sync()
140145
# Set IPROTO_SCHEMA_ID: 0 to avoid SchemaReloadException
141146
# It is ok to use 0 in auth every time.
142-
header = msgpack.dumps({IPROTO_CODE: self.request_type,
147+
header = self.msgpack_dumps({IPROTO_CODE: self.request_type,
143148
IPROTO_SYNC: self._sync,
144149
IPROTO_SCHEMA_ID: 0})
145150

146-
return msgpack.dumps(length + len(header)) + header
151+
return self.msgpack_dumps(length + len(header)) + header
147152

148153

149154
class RequestReplace(Request):
@@ -159,7 +164,7 @@ def __init__(self, conn, space_no, values):
159164
super(RequestReplace, self).__init__(conn)
160165
assert isinstance(values, (tuple, list))
161166

162-
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
167+
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
163168
IPROTO_TUPLE: values})
164169

165170
self._body = request_body
@@ -177,7 +182,7 @@ def __init__(self, conn, space_no, index_no, key):
177182
'''
178183
super(RequestDelete, self).__init__(conn)
179184

180-
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
185+
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
181186
IPROTO_INDEX_ID: index_no,
182187
IPROTO_KEY: key})
183188

@@ -193,7 +198,7 @@ class RequestSelect(Request):
193198
# pylint: disable=W0231
194199
def __init__(self, conn, space_no, index_no, key, offset, limit, iterator):
195200
super(RequestSelect, self).__init__(conn)
196-
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
201+
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
197202
IPROTO_INDEX_ID: index_no,
198203
IPROTO_OFFSET: offset,
199204
IPROTO_LIMIT: limit,
@@ -214,7 +219,7 @@ class RequestUpdate(Request):
214219
def __init__(self, conn, space_no, index_no, key, op_list):
215220
super(RequestUpdate, self).__init__(conn)
216221

217-
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
222+
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
218223
IPROTO_INDEX_ID: index_no,
219224
IPROTO_KEY: key,
220225
IPROTO_TUPLE: op_list})
@@ -235,7 +240,7 @@ def __init__(self, conn, name, args, call_16):
235240
super(RequestCall, self).__init__(conn)
236241
assert isinstance(args, (list, tuple))
237242

238-
request_body = msgpack.dumps({IPROTO_FUNCTION_NAME: name,
243+
request_body = self.msgpack_dumps({IPROTO_FUNCTION_NAME: name,
239244
IPROTO_TUPLE: args})
240245

241246
self._body = request_body
@@ -252,7 +257,7 @@ def __init__(self, conn, name, args):
252257
super(RequestEval, self).__init__(conn)
253258
assert isinstance(args, (list, tuple))
254259

255-
request_body = msgpack.dumps({IPROTO_EXPR: name,
260+
request_body = self.msgpack_dumps({IPROTO_EXPR: name,
256261
IPROTO_TUPLE: args})
257262

258263
self._body = request_body
@@ -280,7 +285,7 @@ class RequestUpsert(Request):
280285
def __init__(self, conn, space_no, index_no, tuple_value, op_list):
281286
super(RequestUpsert, self).__init__(conn)
282287

283-
request_body = msgpack.dumps({IPROTO_SPACE_ID: space_no,
288+
request_body = self.msgpack_dumps({IPROTO_SPACE_ID: space_no,
284289
IPROTO_INDEX_ID: index_no,
285290
IPROTO_TUPLE: tuple_value,
286291
IPROTO_OPS: op_list})
@@ -297,7 +302,7 @@ class RequestJoin(Request):
297302
# pylint: disable=W0231
298303
def __init__(self, conn, server_uuid):
299304
super(RequestJoin, self).__init__(conn)
300-
request_body = msgpack.dumps({IPROTO_SERVER_UUID: server_uuid})
305+
request_body = self.msgpack_dumps({IPROTO_SERVER_UUID: server_uuid})
301306
self._body = request_body
302307

303308

@@ -312,7 +317,7 @@ def __init__(self, conn, cluster_uuid, server_uuid, vclock):
312317
super(RequestSubscribe, self).__init__(conn)
313318
assert isinstance(vclock, dict)
314319

315-
request_body = msgpack.dumps({
320+
request_body = self.msgpack_dumps({
316321
IPROTO_CLUSTER_UUID: cluster_uuid,
317322
IPROTO_SERVER_UUID: server_uuid,
318323
IPROTO_VCLOCK: vclock
@@ -329,6 +334,6 @@ class RequestOK(Request):
329334
# pylint: disable=W0231
330335
def __init__(self, conn, sync):
331336
super(RequestOK, self).__init__(conn)
332-
request_body = msgpack.dumps({IPROTO_CODE: self.request_type,
337+
request_body = self.msgpack_dumps({IPROTO_CODE: self.request_type,
333338
IPROTO_SYNC: sync})
334339
self._body = request_body

tarantool/response.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def __init__(self, conn, response):
4646
:type body: array of bytes
4747
'''
4848

49+
unpack_kwargs = getattr(conn, "unpack_hooks", {})
50+
4951
# This is not necessary, because underlying list data structures are
5052
# created in the __new__().
5153
# super(Response, self).__init__()
@@ -54,11 +56,11 @@ def __init__(self, conn, response):
5456
# Get rid of the following warning.
5557
# > PendingDeprecationWarning: encoding is deprecated,
5658
# > Use raw=False instead.
57-
unpacker = msgpack.Unpacker(use_list=True, raw=False)
59+
unpacker = msgpack.Unpacker(use_list=True, raw=False, **unpack_kwargs)
5860
elif conn.encoding is not None:
59-
unpacker = msgpack.Unpacker(use_list=True, encoding=conn.encoding)
61+
unpacker = msgpack.Unpacker(use_list=True, encoding=conn.encoding, **unpack_kwargs)
6062
else:
61-
unpacker = msgpack.Unpacker(use_list=True)
63+
unpacker = msgpack.Unpacker(use_list=True, **unpack_kwargs)
6264

6365
unpacker.feed(response)
6466
header = unpacker.unpack()

unit/suites/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
from .test_protocol import TestSuite_Protocol
1010
from .test_reconnect import TestSuite_Reconnect
1111
from .test_mesh import TestSuite_Mesh
12+
from .test_hooks import TestSuite_DefaultAndObjectHook
1213

1314
test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol,
14-
TestSuite_Reconnect, TestSuite_Mesh)
15+
TestSuite_Reconnect, TestSuite_Mesh,
16+
TestSuite_DefaultAndObjectHook,)
1517

1618
def load_tests(loader, tests, pattern):
1719
suite = unittest.TestSuite()

unit/suites/test_hooks.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import sys
4+
import unittest
5+
from datetime import datetime
6+
7+
import tarantool
8+
from .lib.tarantool_server import TarantoolServer
9+
10+
11+
def object_unpack(obj):
12+
if obj.get("__type__") == "datetime":
13+
return datetime.fromtimestamp(obj['obj'])
14+
return obj
15+
16+
17+
def object_pack(obj):
18+
if isinstance(obj, datetime):
19+
return {"__type__": "datetime", "obj": obj.timestamp()}
20+
return obj
21+
22+
23+
def list_unpack(lst):
24+
if lst and lst[0] == "_my_datetime":
25+
return datetime.fromtimestamp(lst[1])
26+
return lst
27+
28+
29+
def list_pack(obj):
30+
if isinstance(obj, datetime):
31+
return ["_my_datetime", obj.timestamp()]
32+
return obj
33+
34+
35+
class TestSuite_DefaultAndObjectHook(unittest.TestCase):
36+
@classmethod
37+
def setUpClass(self):
38+
print(' PACK/UNPACK HOOKs '.center(70, '='), file=sys.stderr)
39+
print('-' * 70, file=sys.stderr)
40+
self.srv = TarantoolServer()
41+
self.srv.script = 'unit/suites/box.lua'
42+
self.srv.start()
43+
self.srv.admin.execute("simple_return = function(a) return a end")
44+
self.srv.admin.execute(
45+
"box.schema.user.grant('guest','execute','universe')")
46+
47+
def test_00_not_set(self):
48+
con = tarantool.Connection(self.srv.host, self.srv.args['primary'])
49+
50+
ret = con.call("simple_return",
51+
{"__type__": "datetime", "obj": 1546300800})
52+
self.assertDictEqual(ret._data[0],
53+
{"__type__": "datetime", "obj": 1546300800})
54+
55+
with self.assertRaises(TypeError):
56+
con.call("simple_return", datetime.fromtimestamp(1546300800))
57+
58+
def test_01_set_default(self):
59+
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
60+
pack_default=object_pack)
61+
62+
ret = con.call("simple_return", datetime.fromtimestamp(1546300800))
63+
self.assertDictEqual(ret._data[0],
64+
{"__type__": "datetime", "obj": 1546300800})
65+
66+
def test_02_set_object_hook(self):
67+
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
68+
unpack_object_hook=object_unpack)
69+
70+
ret = con.call("simple_return",
71+
{"__type__": "datetime", "obj": 1546300800})
72+
self.assertEqual(ret._data[0], datetime.fromtimestamp(1546300800))
73+
74+
ret = con.call("simple_return",
75+
{"__type__": "1datetime", "obj": 1546300800})
76+
self.assertEqual(ret._data[0],
77+
{"__type__": "1datetime", "obj": 1546300800})
78+
79+
def test_03_set_object_hook_and_default(self):
80+
dt = datetime.fromtimestamp(1546300800)
81+
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
82+
unpack_object_hook=object_unpack, pack_default=object_pack)
83+
ret = con.call("simple_return", dt)
84+
self.assertEqual(ret._data[0], dt)
85+
86+
87+
def test_04_set_list_hook_and_default(self):
88+
dt = datetime.fromtimestamp(1546300800)
89+
con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
90+
unpack_list_hook=list_unpack, pack_default=list_pack)
91+
ret = con.call("simple_return", dt)
92+
self.assertEqual(ret._data[0], dt)

0 commit comments

Comments
 (0)