|
| 1 | +require 'libluazmq' |
| 2 | +require 'socket' |
| 3 | + |
| 4 | +ipython = {} |
| 5 | + |
| 6 | +dofile('session.lua') |
| 7 | + |
| 8 | +--! A file like object that publishes the stream to a 0MQ PUB socket. |
| 9 | +local OutStream = torch.class("ipython.OutStream") |
| 10 | + |
| 11 | +function OutStream:__init(session, pub_socket, name, max_buffer) |
| 12 | + self.session = session |
| 13 | + self.pub_socket = pub_socket |
| 14 | + self.name = name |
| 15 | + self._buffer = {} |
| 16 | + self._buffer_len = 0 |
| 17 | + self.max_buffer = max_buffer |
| 18 | + self.parent_header = {} |
| 19 | +end |
| 20 | + |
| 21 | +function OutStream:set_parent(parent) |
| 22 | + self.parent_header = extract_header(parent) |
| 23 | +end |
| 24 | + |
| 25 | +function OutStream:close() |
| 26 | + self.pub_socket = nil |
| 27 | +end |
| 28 | + |
| 29 | +function OutStream:flush() |
| 30 | + if not self.pub_socket then |
| 31 | + error("I/O operation on closed file") |
| 32 | + else |
| 33 | + if self._buffer then |
| 34 | + local data = table.concat(self._buffer) |
| 35 | + local content = { name = self.name, data = data } |
| 36 | + local msg = self.session.msg('stream', content, self.parent_header) |
| 37 | + print(ipython.Message(msg)) |
| 38 | + self.pub_socet.send_json(msg) |
| 39 | + self._buffer_len = 0 |
| 40 | + self._bufer = {} |
| 41 | + end |
| 42 | + end |
| 43 | +end |
| 44 | + |
| 45 | +function OutStream:isattr() |
| 46 | + return false |
| 47 | +end |
| 48 | +function OutStream:next() |
| 49 | + error("Read not supported on a write-only stream") |
| 50 | +end |
| 51 | +function OutStream:read() |
| 52 | + error("Read not supported on a write-only stream") |
| 53 | +end |
| 54 | +OutStream.readline = OutStream.read |
| 55 | +function OutStream:write(s) |
| 56 | + if not self.pub_socket then |
| 57 | + error("I/O operation on closed file") |
| 58 | + else |
| 59 | + self._buffer[#self._buffer+1] = s |
| 60 | + self._buffer_len = self._buffer_len + string.len(s) |
| 61 | + self:_maybe_send() |
| 62 | + end |
| 63 | +end |
| 64 | + |
| 65 | +function OutStream:_maybe_send() |
| 66 | + if string.find(self.buffer[#self.buffer], "\n") then |
| 67 | + self:flush() |
| 68 | + end |
| 69 | + if self._buffer_len > self.max_buffer then |
| 70 | + self:flush() |
| 71 | + end |
| 72 | +end |
| 73 | + |
| 74 | +function OutStream:writelines(sequence) |
| 75 | + if not self.pub_socket then |
| 76 | + error("I/O operation on closed file") |
| 77 | + else |
| 78 | + for _, s in ipairs(sequence) do |
| 79 | + self:write(s) |
| 80 | + end |
| 81 | + end |
| 82 | +end |
| 83 | + |
| 84 | +local DisplayHook = torch.class("ipython.DisplayHook") |
| 85 | + |
| 86 | +function DisplayHook:__init(session, pub_socket) |
| 87 | + self.session = session |
| 88 | + self.pub_socket = pub_socket |
| 89 | + self.parent_header = {} |
| 90 | +end |
| 91 | +function DisplayHook:__call(obj) |
| 92 | + if obj == nil then |
| 93 | + return |
| 94 | + end |
| 95 | + |
| 96 | + -- __builtin__._ = obj -- ? |
| 97 | + local msg = self.session:msg("pytout", { data = tostring(obj) }, self.parent_header) |
| 98 | + self.pub_socket:send_json(msg) |
| 99 | +end |
| 100 | +function DisplayHook:set_parent(parent) |
| 101 | + self.parent_header = extract_header(parent) |
| 102 | +end |
| 103 | + |
| 104 | + |
| 105 | +local RawInput = torch.class("ipython.RawInput") |
| 106 | +function RawInput:__init(session, socket) |
| 107 | + self.session = session |
| 108 | + self.socket = socket |
| 109 | +end |
| 110 | + |
| 111 | +function RawInput:__call(prompt) |
| 112 | + local msg = self.session:msg('raw_input') |
| 113 | + self.socket:send_json(msg) |
| 114 | + while true do |
| 115 | + local result, msg = self.socket:recv_json(zmq.NOBLOCK) |
| 116 | + if result then |
| 117 | + return msg.content.data |
| 118 | + end |
| 119 | + if msg ~= 'timeout' then |
| 120 | + error(msg) |
| 121 | + end |
| 122 | + end |
| 123 | +end |
| 124 | + |
| 125 | +local Kernel = torch.class("ipython.Kernel") |
| 126 | +function Kernel:__init(session, reply_socket, pub_socket) |
| 127 | + self.session = session |
| 128 | + self.reply_socket = reply_socket |
| 129 | + self.pub_socket = pub_socket |
| 130 | + self.user_ns = {} |
| 131 | + self.history = {} |
| 132 | + self.compiler = CommandCompiler() |
| 133 | + self.completer = KernelCompleter(self.user_ns) |
| 134 | + |
| 135 | + -- Build dict of handlers for message types |
| 136 | + self.handlers = {} |
| 137 | + for _, msg_type in ipairs({'execute_request', 'complete_request'}) do |
| 138 | + self.handlers[msg_type] = Kernel[msg_type] |
| 139 | + end |
| 140 | +end |
| 141 | + |
| 142 | +function Kernel:abort_queue() |
| 143 | + local ident, msg |
| 144 | + while true do |
| 145 | + local result |
| 146 | + result, ident = self.reply_socket:recv(zmq.NOBLOCK) |
| 147 | + if not result then |
| 148 | + if ident == 'timeout' then |
| 149 | + break |
| 150 | + end |
| 151 | + end |
| 152 | + if self.reply_socket.rcvmore ~= 0 then |
| 153 | + error("Unexpected missing message part") |
| 154 | + end |
| 155 | + msg = self.reply_socket:recv_json() |
| 156 | + print("Aborting:", ipython.Message(msg)) |
| 157 | + local msg_type = msg.msg_type |
| 158 | + local reply_type = msg_type:gmatch("_")[1] .. "_reply" |
| 159 | + local reply_msg = self.session.msg(reply_type, { status = 'aborted'}, msg) |
| 160 | + print(ipython.Message(reply_msg)) |
| 161 | + self.reply_socket:send(ident, zmq.SNDMORE) |
| 162 | + self.reply_socket:send_json(reply_msg) |
| 163 | + socket.sleep(0.1) |
| 164 | + end |
| 165 | +end |
| 166 | + |
| 167 | +function Kernel:execute_request(ident, parent) |
| 168 | + if not parent.content or not parent.content.code then |
| 169 | + print("Got bad msg: ", ipython.Message(parent)) |
| 170 | + return |
| 171 | + end |
| 172 | + local code = parent.content.code |
| 173 | + local pyin_msg = self.session:msg('pyin', {code=code}, parent) |
| 174 | + self.pub_socket:send_json(pyin_msg) |
| 175 | + local comp_code = self.compiler(code, '<zmq-kernel>') |
| 176 | + -- TODO sys.displayhook.set_parent(parent) |
| 177 | + local func = function() loadstring(comp_code) end |
| 178 | + setfenv(func, self.user_ns) |
| 179 | + local result, returned = pcall(func()) |
| 180 | + local reply_content |
| 181 | + if not result then |
| 182 | + local res = 'error' |
| 183 | + local tb = debug.traceback() |
| 184 | + local exc_content = { |
| 185 | + status = 'error', |
| 186 | + traceback = 'tb', |
| 187 | + etype = returned, |
| 188 | + evalue = returned |
| 189 | + } |
| 190 | + local exc_msg = self.session:msg('pyerr', exc_content, parent) |
| 191 | + self.pub_socket:send_json(exc_msg) |
| 192 | + reply_content = exc_content |
| 193 | + else |
| 194 | + reply_content = {status = 'ok'} |
| 195 | + end |
| 196 | + local reply_msg = self.session:msg('execute_reply', reply_content, parent) |
| 197 | + print(ipython.Message(reply_msg)) |
| 198 | + self.reply_socket:send(ident, zmq.SNDMORE) |
| 199 | + self.reply_socket:send_json(reply_msg) |
| 200 | + if reply_msg.content.status == 'error' then |
| 201 | + self:abort_queue() |
| 202 | + end |
| 203 | +end |
| 204 | + |
| 205 | +function Kernel:complete_request(ident, parent) |
| 206 | + local matches = { |
| 207 | + matches = self.complete(parent), |
| 208 | + status = 'ok' |
| 209 | + } |
| 210 | + local completion_msg = self.session:send(self.reply_socket, 'complete_reply', |
| 211 | + matches, parent, ident) |
| 212 | + print(completion_msg) |
| 213 | +end |
| 214 | +function Kernel:complete(msg) |
| 215 | + return self.completer:complete(msg.content.line, msg.content.text) |
| 216 | +end |
| 217 | +function Kernel:start() |
| 218 | + while true do |
| 219 | + local ident = self.reply_socket:recv() |
| 220 | + assert(self.reply_socket.rcvmore ~= 0, "Unexpected missing message part") |
| 221 | + local msg = self.reply_socket:recv_json() |
| 222 | + local omsg = ipython.Message(msg) |
| 223 | + print(omsg) |
| 224 | + local handler = self.handler[omsg.msg_type] |
| 225 | + if not handler then |
| 226 | + print("UNKNOWN MESSAGE TYPE: " .. omsg) |
| 227 | + else |
| 228 | + handler(ident, omsg) |
| 229 | + end |
| 230 | + end |
| 231 | +end |
| 232 | + |
| 233 | +function main() |
| 234 | + local c = zmq.init(1) |
| 235 | + local ip = '127.0.0.1' |
| 236 | + local port_base = 5555 |
| 237 | + local connection = 'tcp://' .. ip .. ":" |
| 238 | + local rep_conn = connection .. port_base |
| 239 | + local pub_conn = connection .. port_base + 1 |
| 240 | + |
| 241 | + print("Starting the kernel...") |
| 242 | + print("On: " .. rep_conn .. " " .. pub_conn) |
| 243 | + |
| 244 | + local session = ipython.Session({username='kernel'}) |
| 245 | + local reply_socket = c:socket(zmq.XREQ) |
| 246 | + reply_socket:bind(rep_conn) |
| 247 | + |
| 248 | + local pub_socket = c:socket(zmq.XREP) |
| 249 | + pub_socket:bind(pub_conn) |
| 250 | + |
| 251 | + local stdout = ipython.OutStream(session, pub_socket, 'stdout') |
| 252 | + local stderr = ipython.OutStream(session, pub_socket, 'stderr') |
| 253 | + print = function(args) |
| 254 | + stdout:write(table.concat(args)) |
| 255 | + end |
| 256 | + local display_hook = DisplayHook(session, pub_socket) |
| 257 | + -- sys.display_hook = display_hook |
| 258 | + |
| 259 | + local kernel = ipython.Kernel(session, reply_socket, pub_socket) |
| 260 | + kernel.user_ns['sleep'] = socket.sleep |
| 261 | + kernel.user_ns['s'] = "test string" |
| 262 | + |
| 263 | + print "Use Ctrl-\\ (NOT Ctrl-C!) to terminate." |
| 264 | + kernel.start() |
| 265 | + |
| 266 | +end |
| 267 | + |
| 268 | + |
| 269 | +main() |
0 commit comments