diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad5d4fece..8781622d5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -150,6 +150,11 @@ jobs: uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 with: dependency_type: minimum + + - name: List installed packages + run: | + hatch run test:list + - name: Run the unit tests run: | hatch run test:nowarn || hatch run test:nowarn --lf diff --git a/ipykernel/iostream.py b/ipykernel/iostream.py index 9de6156b3..8b5e47b30 100644 --- a/ipykernel/iostream.py +++ b/ipykernel/iostream.py @@ -364,7 +364,7 @@ def __init__( echo : bool whether to echo output watchfd : bool (default, True) - Watch the file descripttor corresponding to the replaced stream. + Watch the file descriptor corresponding to the replaced stream. This is useful if you know some underlying code will write directly the file descriptor by its number. It will spawn a watching thread, that will swap the give file descriptor for a pipe, read from the @@ -408,19 +408,39 @@ def __init__( if ( watchfd - and (sys.platform.startswith("linux") or sys.platform.startswith("darwin")) - and ("PYTEST_CURRENT_TEST" not in os.environ) + and ( + (sys.platform.startswith("linux") or sys.platform.startswith("darwin")) + # Pytest set its own capture. Don't redirect from within pytest. + and ("PYTEST_CURRENT_TEST" not in os.environ) + ) + # allow forcing watchfd (mainly for tests) + or watchfd == "force" ): - # Pytest set its own capture. Dont redirect from within pytest. - self._should_watch = True self._setup_stream_redirects(name) if echo: if hasattr(echo, "read") and hasattr(echo, "write"): + # make sure we aren't trying to echo on the FD we're watching! + # that would cause an infinite loop, always echoing on itself + if self._should_watch: + try: + echo_fd = echo.fileno() + except Exception: + echo_fd = None + + if echo_fd is not None and echo_fd == self._original_stdstream_fd: + # echo on the _copy_ we made during + # this is the actual terminal FD now + echo = io.TextIOWrapper( + io.FileIO( + self._original_stdstream_copy, + "w", + ) + ) self.echo = echo else: - msg = "echo argument must be a file like object" + msg = "echo argument must be a file-like object" raise ValueError(msg) def isatty(self): @@ -433,7 +453,7 @@ def isatty(self): def _setup_stream_redirects(self, name): pr, pw = os.pipe() - fno = getattr(sys, name).fileno() + fno = self._original_stdstream_fd = getattr(sys, name).fileno() self._original_stdstream_copy = os.dup(fno) os.dup2(pw, fno) @@ -455,7 +475,13 @@ def close(self): """Close the stream.""" if self._should_watch: self._should_watch = False + # thread won't wake unless there's something to read + # writing something after _should_watch will not be echoed + os.write(self._original_stdstream_fd, b'\0') self.watch_fd_thread.join() + # restore original FDs + os.dup2(self._original_stdstream_copy, self._original_stdstream_fd) + os.close(self._original_stdstream_copy) if self._exc: etype, value, tb = self._exc traceback.print_exception(etype, value, tb) diff --git a/ipykernel/tests/test_io.py b/ipykernel/tests/test_io.py index 221af1f8b..6a9f65170 100644 --- a/ipykernel/tests/test_io.py +++ b/ipykernel/tests/test_io.py @@ -1,7 +1,12 @@ """Test IO capturing functionality""" import io +import os +import subprocess +import sys +import time import warnings +from unittest import mock import pytest import zmq @@ -10,20 +15,28 @@ from ipykernel.iostream import MASTER, BackgroundSocket, IOPubThread, OutStream -def test_io_api(): - """Test that wrapped stdout has the same API as a normal TextIO object""" - session = Session() +@pytest.fixture +def ctx(): ctx = zmq.Context() - pub = ctx.socket(zmq.PUB) - thread = IOPubThread(pub) - thread.start() + yield ctx + ctx.destroy() - stream = OutStream(session, thread, "stdout") - # cleanup unused zmq objects before we start testing - thread.stop() - thread.close() - ctx.term() +@pytest.fixture +def iopub_thread(ctx): + with ctx.socket(zmq.PUB) as pub: + thread = IOPubThread(pub) + thread.start() + + yield thread + thread.stop() + thread.close() + + +def test_io_api(iopub_thread): + """Test that wrapped stdout has the same API as a normal TextIO object""" + session = Session() + stream = OutStream(session, iopub_thread, "stdout") assert stream.errors is None assert not stream.isatty() @@ -43,28 +56,21 @@ def test_io_api(): stream.write(b"") # type:ignore -def test_io_isatty(): +def test_io_isatty(iopub_thread): session = Session() - ctx = zmq.Context() - pub = ctx.socket(zmq.PUB) - thread = IOPubThread(pub) - thread.start() - - stream = OutStream(session, thread, "stdout", isatty=True) + stream = OutStream(session, iopub_thread, "stdout", isatty=True) assert stream.isatty() -def test_io_thread(): - ctx = zmq.Context() - pub = ctx.socket(zmq.PUB) - thread = IOPubThread(pub) +def test_io_thread(iopub_thread): + thread = iopub_thread thread._setup_pipe_in() msg = [thread._pipe_uuid, b"a"] thread._handle_pipe_msg(msg) ctx1, pipe = thread._setup_pipe_out() pipe.close() thread._pipe_in.close() - thread._check_mp_mode = lambda: MASTER # type:ignore + thread._check_mp_mode = lambda: MASTER thread._really_send([b"hi"]) ctx1.destroy() thread.close() @@ -72,40 +78,139 @@ def test_io_thread(): thread._really_send(None) -def test_background_socket(): - ctx = zmq.Context() - pub = ctx.socket(zmq.PUB) - thread = IOPubThread(pub) - sock = BackgroundSocket(thread) +def test_background_socket(iopub_thread): + sock = BackgroundSocket(iopub_thread) assert sock.__class__ == BackgroundSocket with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) sock.linger = 101 - assert thread.socket.linger == 101 - assert sock.io_thread == thread + assert iopub_thread.socket.linger == 101 + assert sock.io_thread == iopub_thread sock.send(b"hi") -def test_outstream(): +def test_outstream(iopub_thread): session = Session() - ctx = zmq.Context() - pub = ctx.socket(zmq.PUB) - thread = IOPubThread(pub) - thread.start() - + pub = iopub_thread.socket with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) stream = OutStream(session, pub, "stdout") - stream = OutStream(session, thread, "stdout", pipe=object()) + stream.close() + stream = OutStream(session, iopub_thread, "stdout", pipe=object()) + stream.close() - stream = OutStream(session, thread, "stdout", watchfd=False) + stream = OutStream(session, iopub_thread, "stdout", watchfd=False) stream.close() - stream = OutStream(session, thread, "stdout", isatty=True, echo=io.StringIO()) - with pytest.raises(io.UnsupportedOperation): - stream.fileno() - stream._watch_pipe_fd() - stream.flush() - stream.write("hi") - stream.writelines(["ab", "cd"]) - assert stream.writable() + stream = OutStream(session, iopub_thread, "stdout", isatty=True, echo=io.StringIO()) + + with stream: + with pytest.raises(io.UnsupportedOperation): + stream.fileno() + stream._watch_pipe_fd() + stream.flush() + stream.write("hi") + stream.writelines(["ab", "cd"]) + assert stream.writable() + + +def subprocess_test_echo_watch(): + # handshake Pub subscription + session = Session(key=b'abc') + + # use PUSH socket to avoid subscription issues + with zmq.Context() as ctx, ctx.socket(zmq.PUSH) as pub: + pub.connect(os.environ["IOPUB_URL"]) + iopub_thread = IOPubThread(pub) + iopub_thread.start() + stdout_fd = sys.stdout.fileno() + sys.stdout.flush() + stream = OutStream( + session, + iopub_thread, + "stdout", + isatty=True, + echo=sys.stdout, + watchfd="force", + ) + save_stdout = sys.stdout + with stream, mock.patch.object(sys, "stdout", stream): + # write to low-level FD + os.write(stdout_fd, b"fd\n") + # print (writes to stream) + print("print\n", end="") + sys.stdout.flush() + # write to unwrapped __stdout__ (should also go to original FD) + sys.__stdout__.write("__stdout__\n") + sys.__stdout__.flush() + # write to original sys.stdout (should be the same as __stdout__) + save_stdout.write("stdout\n") + save_stdout.flush() + # is there another way to flush on the FD? + fd_file = os.fdopen(stdout_fd, "w") + fd_file.flush() + # we don't have a sync flush on _reading_ from the watched pipe + time.sleep(1) + stream.flush() + iopub_thread.stop() + iopub_thread.close() + + +@pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows") +def test_echo_watch(ctx): + """Test echo on underlying FD while capturing the same FD + + Test runs in a subprocess to avoid messing with pytest output capturing. + """ + s = ctx.socket(zmq.PULL) + port = s.bind_to_random_port("tcp://127.0.0.1") + url = f"tcp://127.0.0.1:{port}" + session = Session(key=b'abc') + messages = [] + stdout_chunks = [] + with s: + env = dict(os.environ) + env["IOPUB_URL"] = url + env["PYTHONUNBUFFERED"] = "1" + env.pop("PYTEST_CURRENT_TEST", None) + p = subprocess.run( + [ + sys.executable, + "-c", + f"import {__name__}; {__name__}.subprocess_test_echo_watch()", + ], + env=env, + capture_output=True, + text=True, + timeout=10, + ) + print(f"{p.stdout=}") + print(f"{p.stderr}=", file=sys.stderr) + assert p.returncode == 0 + while s.poll(timeout=100): + ident, msg = session.recv(s) + assert msg is not None # for type narrowing + if msg["header"]["msg_type"] == "stream" and msg["content"]["name"] == "stdout": + stdout_chunks.append(msg["content"]["text"]) + + # check outputs + # use sets of lines to ignore ordering issues with + # async flush and watchfd thread + + # Check the stream output forwarded over zmq + zmq_stdout = "".join(stdout_chunks) + assert set(zmq_stdout.strip().splitlines()) == { + "fd", + "print", + "stdout", + "__stdout__", + } + + # Check what was written to the process stdout (kernel terminal) + # just check that each output source went to the terminal + assert set(p.stdout.strip().splitlines()) == { + "fd", + "print", + "stdout", + "__stdout__", + } diff --git a/pyproject.toml b/pyproject.toml index b042fac27..3494fefae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ api = "sphinx-apidoc -o docs/api -f -E ipykernel ipykernel/tests ipykernel/inpro [tool.hatch.envs.test] features = ["test"] [tool.hatch.envs.test.scripts] +list = "python -m pip freeze" test = "python -m pytest -vv {args}" nowarn = "test -W default {args}"