diff --git a/trio/_subprocess/unix_pipes.py b/trio/_subprocess/unix_pipes.py index 7d7c94891e..a7ab987132 100644 --- a/trio/_subprocess/unix_pipes.py +++ b/trio/_subprocess/unix_pipes.py @@ -46,28 +46,63 @@ def __del__(self): class PipeSendStream(_PipeMixin, SendStream): """Represents a send stream over an os.pipe object.""" - async def send_all(self, data: bytes): - # we have to do this no matter what - await _core.checkpoint() + async def send_some(self, data: bytes, next_offset: int) -> int: + """Write data from ``data`` beginning at offset ``next_offset`` + to the pipe until all data is written or an exception occurs. + This will block until some data can be written, but will return + a short write rather than blocking again after a previous + successful write. + + Returns: + If some data is written, returns the index of the first byte in + ``data`` that was not written, or ``len(data)`` if all bytes + were written. + + Raises: + If no data is written, raises the exception that caused the first + write to fail: :exc:`Cancelled`, :exc:`BrokenResourceError`, or + :exc:`OSError`. + + """ + if self._closed: + await _core.checkpoint() raise _core.ClosedResourceError("this pipe is already closed") if not data: - return + await _core.checkpoint() + return 0 + + await _core.checkpoint_if_cancelled() - length = len(data) # adapted from the SocketStream code with memoryview(data) as view: - total_sent = 0 - while total_sent < length: - with view[total_sent:] as remaining: + # First write: block and raise exceptions + with view[next_offset:] as remaining: + try: + next_offset += os.write(self._pipe, remaining) + except BrokenPipeError as e: + await _core.cancel_shielded_checkpoint() + raise BrokenResourceError from e + except BlockingIOError: + await self.wait_send_all_might_not_block() + else: + await _core.cancel_shielded_checkpoint() + + # Later writes: return a short write instead + while next_offset < len(data): + with view[next_offset:] as remaining: try: - total_sent += os.write(self._pipe, remaining) - except BrokenPipeError as e: - await _core.checkpoint() - raise BrokenResourceError from e - except BlockingIOError: - await self.wait_send_all_might_not_block() + next_offset += os.write(self._pipe, remaining) + except OSError: # includes BlockingIOError + break + + return next_offset + + async def send_all(self, data: bytes) -> None: + next_offset = await self.send_some(data, 0) + while next_offset < len(data): + next_offset = await self.send_some(data, next_offset) async def wait_send_all_might_not_block(self) -> None: if self._closed: diff --git a/trio/tests/subprocess/test_unix_pipes.py b/trio/tests/subprocess/test_unix_pipes.py index b8cc2d496a..3186f9aed6 100644 --- a/trio/tests/subprocess/test_unix_pipes.py +++ b/trio/tests/subprocess/test_unix_pipes.py @@ -1,11 +1,12 @@ import errno import select +import random import os import pytest from trio._core.tests.tutil import gc_collect_harder -from ... import _core +from ... import _core, move_on_after, sleep, BrokenResourceError from ...testing import (wait_all_tasks_blocked, check_one_way_stream) posix = os.name == "posix" @@ -146,3 +147,43 @@ async def make_clogged_pipe(): async def test_pipe_fully(): await check_one_way_stream(make_pipe, make_clogged_pipe) + + +async def test_pipe_send_some(autojump_clock): + write, read = await make_pipe() + data = bytearray(random.randint(0, 255) for _ in range(2**18)) + next_send_offset = 0 + received = bytearray() + + async def sender(): + nonlocal next_send_offset + with move_on_after(2.0): + while next_send_offset < len(data): # pragma: no branch + next_send_offset = await write.send_some( + data, next_send_offset + ) + await write.aclose() + + async def reader(): + nonlocal received + await wait_all_tasks_blocked() + while True: + await sleep(0.1) + chunk = await read.receive_some(4096) + if chunk == b"": + break + received.extend(chunk) + + async with _core.open_nursery() as n: + n.start_soon(sender) + n.start_soon(reader) + + assert received == data[:next_send_offset] + + await read.aclose() + + write, read = await make_pipe() + await read.aclose() + with pytest.raises(BrokenResourceError): + await write.send_some(data, next_send_offset) + await write.aclose()