Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions src/_pytest/capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def pytest_addoption(parser):
dest="capture",
help="shortcut for --capture=no.",
)
parser.addini(
"capture_suspend_on_stdin",
"Suspend capturing when stdin is being read from.",
type="bool",
default=False,
)


@pytest.hookimpl(hookwrapper=True)
Expand All @@ -50,7 +56,8 @@ def pytest_load_initial_conftests(early_config, parser, args):
_colorama_workaround()
_readline_workaround()
pluginmanager = early_config.pluginmanager
capman = CaptureManager(ns.capture)
suspend_on_stdin = early_config.getini("capture_suspend_on_stdin")
capman = CaptureManager(ns.capture, suspend_on_stdin)
pluginmanager.register(capman, "capturemanager")

# make sure that capturemanager is properly reset at final shutdown
Expand Down Expand Up @@ -86,10 +93,11 @@ class CaptureManager(object):
case special handling is needed to ensure the fixtures take precedence over the global capture.
"""

def __init__(self, method):
def __init__(self, method, suspend_on_stdin=False):
self._method = method
self._global_capturing = None
self._current_item = None
self._suspend_on_stdin = suspend_on_stdin

def __repr__(self):
return "<CaptureManager _method=%r _global_capturing=%r _current_item=%r>" % (
Expand All @@ -100,9 +108,32 @@ def __repr__(self):

def _getcapture(self, method):
if method == "fd":
return MultiCapture(out=True, err=True, Capture=FDCapture)
if self._suspend_on_stdin:

def in_(fd, multicapture):
assert fd == 0, fd
syscapture = SysCapture(
fd, tmpfile=SysStdinCapture(multicapture=multicapture)
)
return FDCapture(0, stdin_syscapture=syscapture)

else:
in_ = True
return MultiCapture(out=True, err=True, in_=in_, Capture=FDCapture)

elif method == "sys":
return MultiCapture(out=True, err=True, Capture=SysCapture)
if self._suspend_on_stdin:

def in_(fd, multicapture):
assert fd == 0, fd
return SysCapture(
fd, tmpfile=SysStdinCapture(multicapture=multicapture)
)

else:
in_ = True
return MultiCapture(out=True, err=True, in_=in_, Capture=SysCapture)

elif method == "no":
return MultiCapture(out=False, err=False, in_=False)
raise ValueError("unknown capturing method: %r" % method) # pragma: no cover
Expand Down Expand Up @@ -450,7 +481,10 @@ class MultiCapture(object):

def __init__(self, out=True, err=True, in_=True, Capture=None):
if in_:
self.in_ = Capture(0)
if in_ is True:
self.in_ = Capture(0)
else:
self.in_ = in_(0, multicapture=self)
if out:
self.out = Capture(1)
if err:
Expand Down Expand Up @@ -527,7 +561,7 @@ class FDCaptureBinary(object):

EMPTY_BUFFER = b""

def __init__(self, targetfd, tmpfile=None):
def __init__(self, targetfd, tmpfile=None, stdin_syscapture=None):
self.targetfd = targetfd
try:
self.targetfd_save = os.dup(self.targetfd)
Expand All @@ -538,7 +572,10 @@ def __init__(self, targetfd, tmpfile=None):
if targetfd == 0:
assert not tmpfile, "cannot set tmpfile with stdin"
tmpfile = open(os.devnull, "r")
self.syscapture = SysCapture(targetfd)
if stdin_syscapture is None:
self.syscapture = SysCapture(targetfd)
else:
self.syscapture = stdin_syscapture
else:
if tmpfile is None:
f = TemporaryFile()
Expand Down Expand Up @@ -663,6 +700,64 @@ def snap(self):
return res


class SysStdinCapture(CaptureIO):
"""Wrap CaptureIO to suspend on read."""

def __init__(self, multicapture, *args):
self.multicapture = multicapture
assert isinstance(multicapture, MultiCapture), multicapture

super(SysStdinCapture, self).__init__(*args)

def __repr__(self):
return "<SysStdinCapture multicapture=%r>" % (self.multicapture,)

@property
def _syscapture(self):
try:
return self.multicapture.in_.syscapture
except AttributeError:
return self.multicapture.in_

def _suspend_on_read(self, method, *args):
# TODO: would be nice to have this in the original order, not split
# by stdout/stderr. Let's have stderr first at least, given that
# prompts go to stdout usually.

# self.multicapture.pop_outerr_to_orig()
self.multicapture.out.writeorg(
"=== Suspending capturing due to stdin being read ===\n"
)
out, err = self.multicapture.readouterr()
if err:
# NOTE: this writes to stderr. Not sure about this.
self.multicapture.err.writeorg("=== stderr ===\n")
self.multicapture.err.writeorg(err)
if out:
self.multicapture.out.writeorg("=== stdout ===\n")
self.multicapture.out.writeorg(out)
self.multicapture.suspend_capturing(in_=True)

f = getattr(self._syscapture._old, method)
r = f(*args)
return r

def isatty(self):
return self._syscapture._old.isatty()

def read(self, *args):
return self._suspend_on_read("read", *args)

def readline(self, *args):
return self._suspend_on_read("readline", *args)

def readlines(self, *args): # Required for py2.
return self._suspend_on_read("readlines", *args)

def fileno(self):
return self._syscapture._old.fileno()


class DontReadFromInput(six.Iterator):
"""Temporary stub class. Ideally when stdin is accessed, the
capturing should be turned off, with possibly all data captured
Expand Down
92 changes: 92 additions & 0 deletions testing/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,3 +1565,95 @@ def test_fails():
)
else:
assert result_with_capture.ret == 0


def test_suspend_on_read_from_stdin(testdir):
p1 = testdir.makepyfile(
"""
import sys

def test():
print("prompt_1")
assert input() == "input_1"

print("prompt_2")
assert sys.stdin.read() == "input_2\\nsecond_line\\n"

print("prompt_3")
assert sys.stdin.readline() == "input_3\\n"

print("prompt_4")
assert sys.stdin.readlines() == ["input_4\\n", "second_line\\n"]

print("after_" + "input: OK")

print("is_atty: %d" % sys.stdin.isatty())
"""
)
child = testdir.spawn_pytest("-o capture_suspend_on_stdin=1 %s" % p1)
child.expect("prompt_1")
child.sendline("input_1")

child.expect("prompt_2")
child.sendline("input_2")
child.sendline("second_line")
child.sendeof()

child.expect("prompt_3")
child.sendline("input_3")

child.expect("prompt_4")
child.sendline("input_4")
child.sendline("second_line")
child.sendeof()

child.expect("after_input: OK")
rest = child.read().decode("utf8")
assert "is_atty: 1" in rest
assert "1 passed in" in rest


@pytest.mark.parametrize("method", ("fd", "sys"))
def test_sysstdincapture(method, testdir):
p1 = testdir.makepyfile(
"""
import pytest
from _pytest.capture import CaptureManager, MultiCapture, SysStdinCapture

def test_inner():
method = {method!r}

capman = CaptureManager(method, suspend_on_stdin=True)
multicapture = capman._getcapture(method)
in_ = multicapture.in_
if method == "sys":
f = in_.tmpfile
else:
f = in_.syscapture.tmpfile
assert isinstance(f, SysStdinCapture)

assert not f.isatty()

assert f.read() == ""
assert f.readlines() == []

# XXX: fails on py2.
# > next(iter_f)
# E IOError: readline() should have returned an str object, not 'str'
import sys
if sys.version_info > (3,):
iter_f = iter(f)
with pytest.raises(StopIteration):
next(iter_f)

assert f.fileno() == 0
f.close()
""".format(
method=method
)
)
result = testdir.runpytest_subprocess(
str(p1), "-s" # Pass through stdin, we're not testing suspending here.
)
assert result.ret == 0
assert "1 passed in" in result.stdout.str()