Skip to content

Commit 6906abd

Browse files
committed
Implement sub-shells
1 parent 817258d commit 6906abd

File tree

16 files changed

+653
-546
lines changed

16 files changed

+653
-546
lines changed

ipykernel/athread.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import asyncio
2+
import inspect
3+
import threading
4+
5+
import janus
6+
7+
8+
class AThread(threading.Thread):
9+
"""A thread that can run async tasks.
10+
"""
11+
12+
def __init__(self, name, awaitables=[]):
13+
super().__init__(name=name, daemon=True)
14+
self._aws = list(awaitables)
15+
self._lock = threading.Lock()
16+
self.__initialized = False
17+
self._stopped = False
18+
19+
def run(self):
20+
asyncio.run(self._main())
21+
22+
async def _main(self):
23+
with self._lock:
24+
if self._stopped:
25+
return
26+
self._queue = janus.Queue()
27+
self.__initialized = True
28+
self._tasks = [asyncio.create_task(aw) for aw in self._aws]
29+
30+
while True:
31+
try:
32+
aw = await self._queue.async_q.get()
33+
except BaseException:
34+
break
35+
if aw is None:
36+
break
37+
self._tasks.append(asyncio.create_task(aw))
38+
39+
for task in self._tasks:
40+
task.cancel()
41+
42+
def create_task(self, awaitable):
43+
"""Create a task in the thread (thread-safe).
44+
"""
45+
with self._lock:
46+
if self.__initialized:
47+
self._queue.sync_q.put(awaitable)
48+
else:
49+
self._aws.append(awaitable)
50+
51+
def stop(self):
52+
"""Stop the thread (thread-safe).
53+
"""
54+
with self._lock:
55+
if self.__initialized:
56+
self._queue.sync_q.put(None)
57+
else:
58+
self._stopped = True

ipykernel/control.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,11 @@
1-
"""A thread for a control channel."""
2-
from threading import Thread
1+
from .athread import AThread
32

4-
from tornado.ioloop import IOLoop
53

6-
7-
class ControlThread(Thread):
4+
class ControlThread(AThread):
85
"""A thread for a control channel."""
96

10-
def __init__(self, **kwargs):
7+
def __init__(self):
118
"""Initialize the thread."""
12-
Thread.__init__(self, name="Control", **kwargs)
13-
self.io_loop = IOLoop(make_current=False)
9+
super().__init__(name="Control")
1410
self.pydev_do_not_trace = True
1511
self.is_pydev_daemon_thread = True
16-
17-
def run(self):
18-
"""Run the thread."""
19-
self.name = "Control"
20-
try:
21-
self.io_loop.start()
22-
finally:
23-
self.io_loop.close()
24-
25-
def stop(self):
26-
"""Stop the thread.
27-
28-
This method is threadsafe.
29-
"""
30-
self.io_loop.add_callback(self.io_loop.stop)

ipykernel/debugger.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Debugger implementation for the IPython kernel."""
2+
import asyncio
23
import os
34
import re
45
import sys
@@ -7,8 +8,6 @@
78
import zmq
89
from IPython.core.getipython import get_ipython
910
from IPython.core.inputtransformer2 import leading_empty_lines
10-
from tornado.locks import Event
11-
from tornado.queues import Queue
1211
from zmq.utils import jsonapi
1312

1413
try:
@@ -116,7 +115,7 @@ def __init__(self, event_callback, log):
116115
self.tcp_buffer = ""
117116
self._reset_tcp_pos()
118117
self.event_callback = event_callback
119-
self.message_queue: Queue[t.Any] = Queue()
118+
self.message_queue: asyncio.Queue[t.Any] = asyncio.Queue()
120119
self.log = log
121120

122121
def _reset_tcp_pos(self):
@@ -192,17 +191,17 @@ async def get_message(self):
192191
class DebugpyClient:
193192
"""A client for debugpy."""
194193

195-
def __init__(self, log, debugpy_stream, event_callback):
194+
def __init__(self, log, debugpy_socket, event_callback):
196195
"""Initialize the client."""
197196
self.log = log
198-
self.debugpy_stream = debugpy_stream
197+
self.debugpy_socket = debugpy_socket
199198
self.event_callback = event_callback
200199
self.message_queue = DebugpyMessageQueue(self._forward_event, self.log)
201200
self.debugpy_host = "127.0.0.1"
202201
self.debugpy_port = -1
203202
self.routing_id = None
204203
self.wait_for_attach = True
205-
self.init_event = Event()
204+
self.init_event = asyncio.Event()
206205
self.init_event_seq = -1
207206

208207
def _get_endpoint(self):
@@ -215,9 +214,9 @@ def _forward_event(self, msg):
215214
self.init_event_seq = msg["seq"]
216215
self.event_callback(msg)
217216

218-
def _send_request(self, msg):
217+
async def _send_request(self, msg):
219218
if self.routing_id is None:
220-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
219+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
221220
content = jsonapi.dumps(
222221
msg,
223222
default=json_default,
@@ -232,7 +231,7 @@ def _send_request(self, msg):
232231
self.log.debug("DEBUGPYCLIENT:")
233232
self.log.debug(self.routing_id)
234233
self.log.debug(buf)
235-
self.debugpy_stream.send_multipart((self.routing_id, buf))
234+
await self.debugpy_socket.send_multipart((self.routing_id, buf))
236235

237236
async def _wait_for_response(self):
238237
# Since events are never pushed to the message_queue
@@ -250,7 +249,7 @@ async def _handle_init_sequence(self):
250249
"seq": int(self.init_event_seq) + 1,
251250
"command": "configurationDone",
252251
}
253-
self._send_request(configurationDone)
252+
await self._send_request(configurationDone)
254253

255254
# 3] Waits for configurationDone response
256255
await self._wait_for_response()
@@ -262,7 +261,7 @@ async def _handle_init_sequence(self):
262261
def get_host_port(self):
263262
"""Get the host debugpy port."""
264263
if self.debugpy_port == -1:
265-
socket = self.debugpy_stream.socket
264+
socket = self.debugpy_socket
266265
socket.bind_to_random_port("tcp://" + self.debugpy_host)
267266
self.endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode("utf-8")
268267
socket.unbind(self.endpoint)
@@ -272,14 +271,14 @@ def get_host_port(self):
272271

273272
def connect_tcp_socket(self):
274273
"""Connect to the tcp socket."""
275-
self.debugpy_stream.socket.connect(self._get_endpoint())
276-
self.routing_id = self.debugpy_stream.socket.getsockopt(ROUTING_ID)
274+
self.debugpy_socket.connect(self._get_endpoint())
275+
self.routing_id = self.debugpy_socket.getsockopt(ROUTING_ID)
277276

278277
def disconnect_tcp_socket(self):
279278
"""Disconnect from the tcp socket."""
280-
self.debugpy_stream.socket.disconnect(self._get_endpoint())
279+
self.debugpy_socket.disconnect(self._get_endpoint())
281280
self.routing_id = None
282-
self.init_event = Event()
281+
self.init_event = asyncio.Event()
283282
self.init_event_seq = -1
284283
self.wait_for_attach = True
285284

@@ -289,7 +288,7 @@ def receive_dap_frame(self, frame):
289288

290289
async def send_dap_request(self, msg):
291290
"""Send a dap request."""
292-
self._send_request(msg)
291+
await self._send_request(msg)
293292
if self.wait_for_attach and msg["command"] == "attach":
294293
rep = await self._handle_init_sequence()
295294
self.wait_for_attach = False
@@ -319,17 +318,17 @@ class Debugger:
319318
static_debug_msg_types = ["debugInfo", "inspectVariables", "richInspectVariables", "modules"]
320319

321320
def __init__(
322-
self, log, debugpy_stream, event_callback, shell_socket, session, just_my_code=True
321+
self, log, debugpy_socket, event_callback, shell_socket, session, just_my_code=True
323322
):
324323
"""Initialize the debugger."""
325324
self.log = log
326-
self.debugpy_client = DebugpyClient(log, debugpy_stream, self._handle_event)
325+
self.debugpy_client = DebugpyClient(log, debugpy_socket, self._handle_event)
327326
self.shell_socket = shell_socket
328327
self.session = session
329328
self.is_started = False
330329
self.event_callback = event_callback
331330
self.just_my_code = just_my_code
332-
self.stopped_queue: Queue[t.Any] = Queue()
331+
self.stopped_queue: asyncio.Queue[t.Any] = asyncio.Queue()
333332

334333
self.started_debug_handlers = {}
335334
for msg_type in Debugger.started_debug_msg_types:
@@ -406,7 +405,7 @@ async def handle_stopped_event(self):
406405
def tcp_client(self):
407406
return self.debugpy_client
408407

409-
def start(self):
408+
async def start(self):
410409
"""Start the debugger."""
411410
if not self.debugpy_initialized:
412411
tmp_dir = get_tmp_directory()
@@ -424,7 +423,12 @@ def start(self):
424423
(self.shell_socket.getsockopt(ROUTING_ID)),
425424
)
426425

427-
ident, msg = self.session.recv(self.shell_socket, mode=0)
426+
msg = await self.shell_socket.recv_multipart()
427+
idents, msg = self.session.feed_identities(msg, copy=True)
428+
try:
429+
msg = self.session.deserialize(msg, content=True, copy=True)
430+
except BaseException:
431+
self.log.error("Invalid Message", exc_info=True)
428432
self.debugpy_initialized = msg["content"]["status"] == "ok"
429433

430434
# Don't remove leading empty lines when debugging so the breakpoints are correctly positioned
@@ -685,7 +689,7 @@ async def process_request(self, message):
685689
if self.is_started:
686690
self.log.info("The debugger has already started")
687691
else:
688-
self.is_started = self.start()
692+
self.is_started = await self.start()
689693
if self.is_started:
690694
self.log.info("The debugger has started")
691695
else:

ipykernel/inprocess/ipkernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class InProcessKernel(IPythonKernel):
5151
_underlying_iopub_socket = Instance(DummySocket, ())
5252
iopub_thread: IOPubThread = Instance(IOPubThread) # type:ignore[assignment]
5353

54-
shell_stream = Instance(DummySocket, ())
54+
#shell_stream = Instance(DummySocket, ())
5555

5656
@default("iopub_thread")
5757
def _default_iopub_thread(self):

0 commit comments

Comments
 (0)