Skip to content

Commit cdcffaf

Browse files
committed
Support websocket subprotocols
1 parent 9d44029 commit cdcffaf

File tree

2 files changed

+328
-67
lines changed

2 files changed

+328
-67
lines changed

jupyter_server/base/zmqhandlers.py

+110-13
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,64 @@
2424
from .handlers import JupyterHandler
2525

2626

27+
def serialize_binary_message(msg):
28+
"""serialize a message as a binary blob
29+
30+
Header:
31+
32+
4 bytes: number of msg parts (nbufs) as 32b int
33+
4 * nbufs bytes: offset for each buffer as integer as 32b int
34+
35+
Offsets are from the start of the buffer, including the header.
36+
37+
Returns
38+
-------
39+
The message serialized to bytes.
40+
41+
"""
42+
# don't modify msg or buffer list in-place
43+
msg = msg.copy()
44+
buffers = list(msg.pop("buffers"))
45+
if sys.version_info < (3, 4):
46+
buffers = [x.tobytes() for x in buffers]
47+
bmsg = json.dumps(msg, default=json_default).encode("utf8")
48+
buffers.insert(0, bmsg)
49+
nbufs = len(buffers)
50+
offsets = [4 * (nbufs + 1)]
51+
for buf in buffers[:-1]:
52+
offsets.append(offsets[-1] + len(buf))
53+
offsets_buf = struct.pack("!" + "I" * (nbufs + 1), nbufs, *offsets)
54+
buffers.insert(0, offsets_buf)
55+
return b"".join(buffers)
56+
57+
58+
def deserialize_binary_message(bmsg):
59+
"""deserialize a message from a binary blog
60+
61+
Header:
62+
63+
4 bytes: number of msg parts (nbufs) as 32b int
64+
4 * nbufs bytes: offset for each buffer as integer as 32b int
65+
66+
Offsets are from the start of the buffer, including the header.
67+
68+
Returns
69+
-------
70+
message dictionary
71+
"""
72+
nbufs = struct.unpack("!i", bmsg[:4])[0]
73+
offsets = list(struct.unpack("!" + "I" * nbufs, bmsg[4 : 4 * (nbufs + 1)]))
74+
offsets.append(None)
75+
bufs = []
76+
for start, stop in zip(offsets[:-1], offsets[1:]):
77+
bufs.append(bmsg[start:stop])
78+
msg = json.loads(bufs[0].decode("utf8"))
79+
msg["header"] = extract_dates(msg["header"])
80+
msg["parent_header"] = extract_dates(msg["parent_header"])
81+
msg["buffers"] = bufs[1:]
82+
return msg
83+
84+
2785
# ping interval for keeping websockets alive (30 seconds)
2886
WS_PING_INTERVAL = 30000
2987

@@ -155,6 +213,37 @@ def send_error(self, *args, **kwargs):
155213
# we can close the connection more gracefully.
156214
self.stream.close()
157215

216+
def _reserialize_reply(self, msg_or_list, channel=None):
217+
"""Reserialize a reply message using JSON.
218+
219+
msg_or_list can be an already-deserialized msg dict or the zmq buffer list.
220+
If it is the zmq list, it will be deserialized with self.session.
221+
222+
This takes the msg list from the ZMQ socket and serializes the result for the websocket.
223+
This method should be used by self._on_zmq_reply to build messages that can
224+
be sent back to the browser.
225+
226+
"""
227+
if isinstance(msg_or_list, dict):
228+
# already unpacked
229+
msg = msg_or_list
230+
else:
231+
idents, msg_list = self.session.feed_identities(msg_or_list)
232+
msg = self.session.deserialize(msg_list)
233+
if channel:
234+
msg["channel"] = channel
235+
if msg["buffers"]:
236+
buf = serialize_binary_message(msg)
237+
return buf
238+
else:
239+
smsg = json.dumps(msg, default=json_default)
240+
return cast_unicode(smsg)
241+
242+
def select_subprotocol(self, subprotocols):
243+
selected_subprotocol = "0.0.1" if "0.0.1" in subprotocols else None
244+
# None is the default, "legacy" protocol
245+
return selected_subprotocol
246+
158247
def _on_zmq_reply(self, stream, msg_list):
159248
# Sometimes this gets triggered when the on_close method is scheduled in the
160249
# eventloop but hasn't been called.
@@ -163,19 +252,27 @@ def _on_zmq_reply(self, stream, msg_list):
163252
self.close()
164253
return
165254
channel = getattr(stream, "channel", None)
166-
offsets = []
167-
curr_sum = 0
168-
for msg in msg_list:
169-
length = len(msg)
170-
offsets.append(length + curr_sum)
171-
curr_sum += length
172-
layout = json.dumps({
173-
"channel": channel,
174-
"offsets": offsets,
175-
}).encode("utf-8")
176-
layout_length = len(layout).to_bytes(2, byteorder="little")
177-
bin_msg = b"".join([layout_length, layout] + msg_list)
178-
self.write_message(bin_msg, binary=True)
255+
if not self.selected_subprotocol:
256+
try:
257+
msg = self._reserialize_reply(msg_list, channel=channel)
258+
except Exception:
259+
self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
260+
else:
261+
self.write_message(msg, binary=isinstance(msg, bytes))
262+
elif self.selected_subprotocol == "0.0.1":
263+
offsets = []
264+
curr_sum = 0
265+
for msg in msg_list:
266+
length = len(msg)
267+
offsets.append(length + curr_sum)
268+
curr_sum += length
269+
layout = json.dumps({
270+
"channel": channel,
271+
"offsets": offsets,
272+
}).encode("utf-8")
273+
layout_length = len(layout).to_bytes(2, byteorder="little")
274+
bin_msg = b"".join([layout_length, layout] + msg_list)
275+
self.write_message(bin_msg, binary=True)
179276

180277

181278
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler):

0 commit comments

Comments
 (0)