Skip to content

Commit 4acc688

Browse files
feat: support comm package
See ipython/ipykernel#973
1 parent 463b4a2 commit 4acc688

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

solara/server/kernel.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import struct
77
from typing import Set
88

9+
import ipykernel
910
import ipykernel.kernelbase
1011
import jupyter_client.session as session
1112
from ipykernel.comm import CommManager
@@ -17,6 +18,38 @@
1718

1819
logger = logging.getLogger("solara.server.kernel")
1920

21+
22+
ipykernel_version = tuple(map(int, ipykernel.__version__.split(".")))
23+
if ipykernel_version >= (6, 18, 0):
24+
import comm.base_comm
25+
26+
class Comm(comm.base_comm.BaseComm):
27+
def __init__(self, **kwargs) -> None:
28+
self.kernel = ipykernel.kernelbase.Kernel.instance()
29+
super().__init__(**kwargs)
30+
31+
def publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys):
32+
data = {} if data is None else data
33+
metadata = {} if metadata is None else metadata
34+
content = dict(data=data, comm_id=self.comm_id, **keys)
35+
self.kernel.session.send(
36+
self.kernel.iopub_socket,
37+
msg_type,
38+
content,
39+
metadata=metadata,
40+
parent=self.kernel.get_parent("shell"),
41+
ident=self.topic,
42+
buffers=buffers,
43+
)
44+
45+
comm.create_comm = Comm
46+
47+
def get_comm_manager():
48+
from .app import get_current_context
49+
50+
return get_current_context().kernel.comm_manager
51+
52+
comm.get_comm_manager = get_comm_manager
2053
# from notebook.base.zmqhandlers import serialize_binary_message
2154
# this saves us a depdendency on notebook/jupyter_server when e.g.
2255
# running on pyodide
@@ -160,7 +193,16 @@ def __init__(self):
160193
# solara/server/kernel.py:111: error: "SessionWebsocket" has no attribute "stream"
161194
# not sure why we cannot reproduce that locally
162195
self.session.stream = self.iopub_socket # type: ignore
163-
self.comm_manager = CommManager(parent=self, kernel=self)
196+
if ipykernel_version >= (6, 18, 0):
197+
# from this version on, ipykernel uses the comm package https://github.com/ipython/ipykernel/pull/973
198+
self.comm_manager = CommManager(parent=self, kernel=self)
199+
import ipywidgets.widgets.widget
200+
201+
if hasattr(ipywidgets.widgets.widget, "Comm"):
202+
ipywidgets.widgets.widget.Comm = Comm
203+
ipywidgets.widgets.widget.Widget.comm.klass = Comm
204+
else:
205+
self.comm_manager = CommManager(parent=self, kernel=self)
164206
self.shell = None
165207
self.log = logging.getLogger("fake")
166208

0 commit comments

Comments
 (0)