diff --git a/.coveragerc b/.coveragerc index 98f923bd8e..d577aa8adf 100644 --- a/.coveragerc +++ b/.coveragerc @@ -21,6 +21,7 @@ exclude_lines = abc.abstractmethod if TYPE_CHECKING: if _t.TYPE_CHECKING: + @overload partial_branches = pragma: no branch diff --git a/docs/source/conf.py b/docs/source/conf.py index 91ce7d884c..650688717a 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -63,6 +63,8 @@ ("py:obj", "trio._abc.T"), ("py:obj", "trio._abc.T_resource"), ("py:class", "types.FrameType"), + ("py:class", "P.args"), + ("py:class", "P.kwargs"), # TODO: figure out if you can link this to SSL ("py:class", "Context"), # TODO: temporary type diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index 9ad11b2c5a..85969174aa 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -637,9 +637,11 @@ Asynchronous path objects Asynchronous file objects ~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: open_file +.. Suppress type annotations here, they refer to lots of internal types. + The normal Python docs go into better detail. +.. autofunction:: open_file(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=None, opener=None) -.. autofunction:: wrap_file +.. autofunction:: wrap_file(file) .. interface:: Asynchronous file interface diff --git a/pyproject.toml b/pyproject.toml index 264345d8c5..d479442c7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,12 @@ disallow_untyped_defs = false # DO NOT use `ignore_errors`; it doesn't apply # downstream and users have to deal with them. +[[tool.mypy.overrides]] +module = [ + "trio._path", + "trio._file_io", +] +disallow_untyped_defs = true [[tool.mypy.overrides]] module = [ diff --git a/trio/_file_io.py b/trio/_file_io.py index 9f7d81adef..6b79ae25b5 100644 --- a/trio/_file_io.py +++ b/trio/_file_io.py @@ -1,13 +1,39 @@ +from __future__ import annotations + import io from functools import partial +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + BinaryIO, + Callable, + Generic, + Iterable, + TypeVar, + Union, + overload, +) import trio from ._util import async_wraps from .abc import AsyncResource +if TYPE_CHECKING: + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + StrOrBytesPath, + ) + from typing_extensions import Literal + # This list is also in the docs, make sure to keep them in sync -_FILE_SYNC_ATTRS = { +_FILE_SYNC_ATTRS: set[str] = { "closed", "encoding", "errors", @@ -29,7 +55,7 @@ } # This list is also in the docs, make sure to keep them in sync -_FILE_ASYNC_METHODS = { +_FILE_ASYNC_METHODS: set[str] = { "flush", "read", "read1", @@ -48,59 +74,201 @@ } -class AsyncIOWrapper(AsyncResource): +FileT = TypeVar("FileT") +FileT_co = TypeVar("FileT_co", covariant=True) +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) +AnyStr_co = TypeVar("AnyStr_co", str, bytes, covariant=True) +AnyStr_contra = TypeVar("AnyStr_contra", str, bytes, contravariant=True) + +# This is a little complicated. IO objects have a lot of methods, and which are available on +# different types varies wildly. We want to match the interface of whatever file we're wrapping. +# This pile of protocols each has one sync method/property, meaning they're going to be compatible +# with a file class that supports that method/property. The ones parameterized with AnyStr take +# either str or bytes depending. + +# The wrapper is then a generic class, where the typevar is set to the type of the sync file we're +# wrapping. For generics, adding a type to self has a special meaning - properties/methods can be +# conditional - it's only valid to call them if the object you're accessing them on is compatible +# with that type hint. By using the protocols, the type checker will be checking to see if the +# wrapped type has that method, and only allow the methods that do to be called. We can then alter +# the signature however it needs to match runtime behaviour. +# More info: https://mypy.readthedocs.io/en/stable/more_types.html#advanced-uses-of-self-types +if TYPE_CHECKING: + from typing_extensions import Buffer, Protocol + + # fmt: off + + class _HasClosed(Protocol): + @property + def closed(self) -> bool: ... + + class _HasEncoding(Protocol): + @property + def encoding(self) -> str: ... + + class _HasErrors(Protocol): + @property + def errors(self) -> str | None: ... + + class _HasFileNo(Protocol): + def fileno(self) -> int: ... + + class _HasIsATTY(Protocol): + def isatty(self) -> bool: ... + + class _HasNewlines(Protocol[T_co]): + # Type varies here - documented to be None, tuple of strings, strings. Typeshed uses Any. + @property + def newlines(self) -> T_co: ... + + class _HasReadable(Protocol): + def readable(self) -> bool: ... + + class _HasSeekable(Protocol): + def seekable(self) -> bool: ... + + class _HasWritable(Protocol): + def writable(self) -> bool: ... + + class _HasBuffer(Protocol): + @property + def buffer(self) -> BinaryIO: ... + + class _HasRaw(Protocol): + @property + def raw(self) -> io.RawIOBase: ... + + class _HasLineBuffering(Protocol): + @property + def line_buffering(self) -> bool: ... + + class _HasCloseFD(Protocol): + @property + def closefd(self) -> bool: ... + + class _HasName(Protocol): + @property + def name(self) -> str: ... + + class _HasMode(Protocol): + @property + def mode(self) -> str: ... + + class _CanGetValue(Protocol[AnyStr_co]): + def getvalue(self) -> AnyStr_co: ... + + class _CanGetBuffer(Protocol): + def getbuffer(self) -> memoryview: ... + + class _CanFlush(Protocol): + def flush(self) -> None: ... + + class _CanRead(Protocol[AnyStr_co]): + def read(self, size: int | None = ..., /) -> AnyStr_co: ... + + class _CanRead1(Protocol): + def read1(self, size: int | None = ..., /) -> bytes: ... + + class _CanReadAll(Protocol[AnyStr_co]): + def readall(self) -> AnyStr_co: ... + + class _CanReadInto(Protocol): + def readinto(self, buf: Buffer, /) -> int | None: ... + + class _CanReadInto1(Protocol): + def readinto1(self, buffer: Buffer, /) -> int: ... + + class _CanReadLine(Protocol[AnyStr_co]): + def readline(self, size: int = ..., /) -> AnyStr_co: ... + + class _CanReadLines(Protocol[AnyStr]): + def readlines(self, hint: int = ...) -> list[AnyStr]: ... + + class _CanSeek(Protocol): + def seek(self, target: int, whence: int = 0, /) -> int: ... + + class _CanTell(Protocol): + def tell(self) -> int: ... + + class _CanTruncate(Protocol): + def truncate(self, size: int | None = ..., /) -> int: ... + + class _CanWrite(Protocol[AnyStr_contra]): + def write(self, data: AnyStr_contra, /) -> int: ... + + class _CanWriteLines(Protocol[T_contra]): + # The lines parameter varies for bytes/str, so use a typevar to make the async match. + def writelines(self, lines: Iterable[T_contra], /) -> None: ... + + class _CanPeek(Protocol[AnyStr_co]): + def peek(self, size: int = 0, /) -> AnyStr_co: ... + + class _CanDetach(Protocol[T_co]): + # The T typevar will be the unbuffered/binary file this file wraps. + def detach(self) -> T_co: ... + + class _CanClose(Protocol): + def close(self) -> None: ... + + +# FileT needs to be covariant for the protocol trick to work - the real IO types are effectively a +# subtype of the protocols. +class AsyncIOWrapper(AsyncResource, Generic[FileT_co]): """A generic :class:`~io.IOBase` wrapper that implements the :term:`asynchronous file object` interface. Wrapped methods that could block are executed in :meth:`trio.to_thread.run_sync`. - All properties and methods defined in in :mod:`~io` are exposed by this + All properties and methods defined in :mod:`~io` are exposed by this wrapper, if they exist in the wrapped file object. - """ - def __init__(self, file): + def __init__(self, file: FileT_co) -> None: self._wrapped = file @property - def wrapped(self): + def wrapped(self) -> FileT_co: """object: A reference to the wrapped file object""" return self._wrapped - def __getattr__(self, name): - if name in _FILE_SYNC_ATTRS: - return getattr(self._wrapped, name) - if name in _FILE_ASYNC_METHODS: - meth = getattr(self._wrapped, name) + if not TYPE_CHECKING: - @async_wraps(self.__class__, self._wrapped.__class__, name) - async def wrapper(*args, **kwargs): - func = partial(meth, *args, **kwargs) - return await trio.to_thread.run_sync(func) + def __getattr__(self, name: str) -> object: + if name in _FILE_SYNC_ATTRS: + return getattr(self._wrapped, name) + if name in _FILE_ASYNC_METHODS: + meth = getattr(self._wrapped, name) - # cache the generated method - setattr(self, name, wrapper) - return wrapper + @async_wraps(self.__class__, self._wrapped.__class__, name) + async def wrapper(*args, **kwargs): + func = partial(meth, *args, **kwargs) + return await trio.to_thread.run_sync(func) - raise AttributeError(name) + # cache the generated method + setattr(self, name, wrapper) + return wrapper - def __dir__(self): + raise AttributeError(name) + + def __dir__(self) -> Iterable[str]: attrs = set(super().__dir__()) attrs.update(a for a in _FILE_SYNC_ATTRS if hasattr(self.wrapped, a)) attrs.update(a for a in _FILE_ASYNC_METHODS if hasattr(self.wrapped, a)) return attrs - def __aiter__(self): + def __aiter__(self) -> AsyncIOWrapper[FileT_co]: return self - async def __anext__(self): + async def __anext__(self: AsyncIOWrapper[_CanReadLine[AnyStr]]) -> AnyStr: line = await self.readline() if line: return line else: raise StopAsyncIteration - async def detach(self): + async def detach(self: AsyncIOWrapper[_CanDetach[T]]) -> AsyncIOWrapper[T]: """Like :meth:`io.BufferedIOBase.detach`, but async. This also re-wraps the result in a new :term:`asynchronous file object` @@ -111,7 +279,7 @@ async def detach(self): raw = await trio.to_thread.run_sync(self._wrapped.detach) return wrap_file(raw) - async def aclose(self): + async def aclose(self: AsyncIOWrapper[_CanClose]) -> None: """Like :meth:`io.IOBase.close`, but async. This is also shielded from cancellation; if a cancellation scope is @@ -125,18 +293,167 @@ async def aclose(self): await trio.lowlevel.checkpoint_if_cancelled() + if TYPE_CHECKING: + # fmt: off + # Based on typing.IO and io stubs. + @property + def closed(self: AsyncIOWrapper[_HasClosed]) -> bool: ... + @property + def encoding(self: AsyncIOWrapper[_HasEncoding]) -> str: ... + @property + def errors(self: AsyncIOWrapper[_HasErrors]) -> str | None: ... + @property + def newlines(self: AsyncIOWrapper[_HasNewlines[T]]) -> T: ... + @property + def buffer(self: AsyncIOWrapper[_HasBuffer]) -> BinaryIO: ... + @property + def raw(self: AsyncIOWrapper[_HasRaw]) -> io.RawIOBase: ... + @property + def line_buffering(self: AsyncIOWrapper[_HasLineBuffering]) -> int: ... + @property + def closefd(self: AsyncIOWrapper[_HasCloseFD]) -> bool: ... + @property + def name(self: AsyncIOWrapper[_HasName]) -> str: ... + @property + def mode(self: AsyncIOWrapper[_HasMode]) -> str: ... + + def fileno(self: AsyncIOWrapper[_HasFileNo]) -> int: ... + def isatty(self: AsyncIOWrapper[_HasIsATTY]) -> bool: ... + def readable(self: AsyncIOWrapper[_HasReadable]) -> bool: ... + def seekable(self: AsyncIOWrapper[_HasSeekable]) -> bool: ... + def writable(self: AsyncIOWrapper[_HasWritable]) -> bool: ... + def getvalue(self: AsyncIOWrapper[_CanGetValue[AnyStr]]) -> AnyStr: ... + def getbuffer(self: AsyncIOWrapper[_CanGetBuffer]) -> memoryview: ... + async def flush(self: AsyncIOWrapper[_CanFlush]) -> None: ... + async def read(self: AsyncIOWrapper[_CanRead[AnyStr]], size: int | None = -1, /) -> AnyStr: ... + async def read1(self: AsyncIOWrapper[_CanRead1], size: int | None = -1, /) -> bytes: ... + async def readall(self: AsyncIOWrapper[_CanReadAll[AnyStr]]) -> AnyStr: ... + async def readinto(self: AsyncIOWrapper[_CanReadInto], buf: Buffer, /) -> int | None: ... + async def readline(self: AsyncIOWrapper[_CanReadLine[AnyStr]], size: int = -1, /) -> AnyStr: ... + async def readlines(self: AsyncIOWrapper[_CanReadLines[AnyStr]]) -> list[AnyStr]: ... + async def seek(self: AsyncIOWrapper[_CanSeek], target: int, whence: int = 0, /) -> int: ... + async def tell(self: AsyncIOWrapper[_CanTell]) -> int: ... + async def truncate(self: AsyncIOWrapper[_CanTruncate], size: int | None = None, /) -> int: ... + async def write(self: AsyncIOWrapper[_CanWrite[AnyStr]], data: AnyStr, /) -> int: ... + async def writelines(self: AsyncIOWrapper[_CanWriteLines[T]], lines: Iterable[T], /) -> None: ... + async def readinto1(self: AsyncIOWrapper[_CanReadInto1], buffer: Buffer, /) -> int: ... + async def peek(self: AsyncIOWrapper[_CanPeek[AnyStr]], size: int = 0, /) -> AnyStr: ... + + +# Type hints are copied from builtin open. +_OpenFile = Union["StrOrBytesPath", int] +_Opener = Callable[[str, int], int] + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.TextIOWrapper]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.FileIO]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedRandom]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedWriter]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[io.BufferedReader]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: OpenBinaryMode, + buffering: int, + encoding: None = None, + errors: None = None, + newline: None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[BinaryIO]: + ... + + +@overload +async def open_file( + file: _OpenFile, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[IO[Any]]: + ... + async def open_file( - file, - mode="r", - buffering=-1, - encoding=None, - errors=None, - newline=None, - closefd=True, - opener=None, -): - """Asynchronous version of :func:`io.open`. + file: _OpenFile, + mode: str = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + closefd: bool = True, + opener: _Opener | None = None, +) -> AsyncIOWrapper[Any]: + """Asynchronous version of :func:`open`. Returns: An :term:`asynchronous file object` @@ -161,7 +478,7 @@ async def open_file( return _file -def wrap_file(file): +def wrap_file(file: FileT) -> AsyncIOWrapper[FileT]: """This wraps any file object in a wrapper that provides an asynchronous file object interface. @@ -179,7 +496,7 @@ def wrap_file(file): """ - def has(attr): + def has(attr: str) -> bool: return hasattr(file, attr) and callable(getattr(file, attr)) if not (has("close") and (has("read") or has("write"))): diff --git a/trio/_path.py b/trio/_path.py index 67234e223d..b7e6b16e4a 100644 --- a/trio/_path.py +++ b/trio/_path.py @@ -1,49 +1,96 @@ +from __future__ import annotations + +import inspect import os import pathlib import sys import types +from collections.abc import Awaitable, Callable, Iterable from functools import partial, wraps -from typing import TYPE_CHECKING, Any +from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper +from typing import ( + IO, + TYPE_CHECKING, + Any, + BinaryIO, + ClassVar, + TypeVar, + Union, + cast, + overload, +) import trio +from trio._file_io import AsyncIOWrapper as _AsyncIOWrapper from trio._util import Final, async_wraps +if TYPE_CHECKING: + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + ) + from typing_extensions import Concatenate, Literal, ParamSpec, TypeAlias + + P = ParamSpec("P") + +T = TypeVar("T") +StrPath: TypeAlias = Union[str, "os.PathLike[str]"] # Only subscriptable in 3.9+ + # re-wrap return value from methods that return new instances of pathlib.Path -def rewrap_path(value): +def rewrap_path(value: T) -> T | Path: if isinstance(value, pathlib.Path): - value = Path(value) - return value + return Path(value) + else: + return value -def _forward_factory(cls, attr_name, attr): +def _forward_factory( + cls: AsyncAutoWrapperType, + attr_name: str, + attr: Callable[Concatenate[pathlib.Path, P], T], +) -> Callable[Concatenate[Path, P], T | Path]: @wraps(attr) - def wrapper(self, *args, **kwargs): + def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> T | Path: attr = getattr(self._wrapped, attr_name) value = attr(*args, **kwargs) return rewrap_path(value) + # Assigning this makes inspect and therefore Sphinx show the original parameters. + # It's not defined on functions normally though, this is a custom attribute. + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) + return wrapper -def _forward_magic(cls, attr): +def _forward_magic( + cls: AsyncAutoWrapperType, attr: Callable[..., T] +) -> Callable[..., Path | T]: sentinel = object() @wraps(attr) - def wrapper(self, other=sentinel): + def wrapper(self: Path, other: object = sentinel) -> Path | T: if other is sentinel: return attr(self._wrapped) if isinstance(other, cls): - other = other._wrapped + other = cast(Path, other)._wrapped value = attr(self._wrapped, other) return rewrap_path(value) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) return wrapper -def iter_wrapper_factory(cls, meth_name): +def iter_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> Callable[Concatenate[Path, P], Awaitable[Iterable[Path]]]: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Iterable[Path]: meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) # Make sure that the full iteration is performed in the thread @@ -54,9 +101,11 @@ async def wrapper(self, *args, **kwargs): return wrapper -def thread_wrapper_factory(cls, meth_name): +def thread_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> Callable[Concatenate[Path, P], Awaitable[Path]]: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self, *args, **kwargs): + async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Path: meth = getattr(self._wrapped, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) @@ -65,20 +114,31 @@ async def wrapper(self, *args, **kwargs): return wrapper -def classmethod_wrapper_factory(cls, meth_name): - @classmethod +def classmethod_wrapper_factory( + cls: AsyncAutoWrapperType, meth_name: str +) -> classmethod: @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(cls, *args, **kwargs): + async def wrapper(cls: type[Path], *args: Any, **kwargs: Any) -> Path: meth = getattr(cls._wraps, meth_name) func = partial(meth, *args, **kwargs) value = await trio.to_thread.run_sync(func) return rewrap_path(value) - return wrapper + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(getattr(cls._wraps, meth_name)) + return classmethod(wrapper) class AsyncAutoWrapperType(Final): - def __init__(cls, name, bases, attrs): + _forwards: type + _wraps: type + _forward_magic: list[str] + _wrap_iter: list[str] + _forward: list[str] + + def __init__( + cls, name: str, bases: tuple[type, ...], attrs: dict[str, object] + ) -> None: super().__init__(name, bases, attrs) cls._forward = [] @@ -87,7 +147,7 @@ def __init__(cls, name, bases, attrs): type(cls).generate_magic(cls, attrs) type(cls).generate_iter(cls, attrs) - def generate_forwards(cls, attrs): + def generate_forwards(cls, attrs: dict[str, object]) -> None: # forward functions of _forwards for attr_name, attr in cls._forwards.__dict__.items(): if attr_name.startswith("_") or attr_name in attrs: @@ -101,8 +161,9 @@ def generate_forwards(cls, attrs): else: raise TypeError(attr_name, type(attr)) - def generate_wraps(cls, attrs): + def generate_wraps(cls, attrs: dict[str, object]) -> None: # generate wrappers for functions of _wraps + wrapper: classmethod | Callable for attr_name, attr in cls._wraps.__dict__.items(): # .z. exclude cls._wrap_iter if attr_name.startswith("_") or attr_name in attrs: @@ -112,22 +173,27 @@ def generate_wraps(cls, attrs): setattr(cls, attr_name, wrapper) elif isinstance(attr, types.FunctionType): wrapper = thread_wrapper_factory(cls, attr_name) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) setattr(cls, attr_name, wrapper) else: raise TypeError(attr_name, type(attr)) - def generate_magic(cls, attrs): + def generate_magic(cls, attrs: dict[str, object]) -> None: # generate wrappers for magic for attr_name in cls._forward_magic: attr = getattr(cls._forwards, attr_name) wrapper = _forward_magic(cls, attr) setattr(cls, attr_name, wrapper) - def generate_iter(cls, attrs): + def generate_iter(cls, attrs: dict[str, object]) -> None: # generate wrappers for methods that return iterators + wrapper: Callable for attr_name, attr in cls._wraps.__dict__.items(): if attr_name in cls._wrap_iter: wrapper = iter_wrapper_factory(cls, attr_name) + assert isinstance(wrapper, types.FunctionType) + wrapper.__signature__ = inspect.signature(attr) setattr(cls, attr_name, wrapper) @@ -137,9 +203,10 @@ class Path(metaclass=AsyncAutoWrapperType): """ - _wraps = pathlib.Path - _forwards = pathlib.PurePath - _forward_magic = [ + _forward: ClassVar[list[str]] + _wraps: ClassVar[type] = pathlib.Path + _forwards: ClassVar[type] = pathlib.PurePath + _forward_magic: ClassVar[list[str]] = [ "__str__", "__bytes__", "__truediv__", @@ -151,9 +218,9 @@ class Path(metaclass=AsyncAutoWrapperType): "__ge__", "__hash__", ] - _wrap_iter = ["glob", "rglob", "iterdir"] + _wrap_iter: ClassVar[list[str]] = ["glob", "rglob", "iterdir"] - def __init__(self, *args): + def __init__(self, *args: StrPath) -> None: self._wrapped = pathlib.Path(*args) # type checkers allow accessing any attributes on class instances with `__getattr__` @@ -167,17 +234,94 @@ def __getattr__(self, name): return rewrap_path(value) raise AttributeError(name) - def __dir__(self): - return super().__dir__() + self._forward + def __dir__(self) -> list[str]: + return [*super().__dir__(), *self._forward] - def __repr__(self): + def __repr__(self) -> str: return f"trio.Path({repr(str(self))})" - def __fspath__(self): + def __fspath__(self) -> str: return os.fspath(self._wrapped) - @wraps(pathlib.Path.open) - async def open(self, *args, **kwargs): + @overload + def open( + self, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> _AsyncIOWrapper[TextIOWrapper]: + ... + + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[FileIO]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedRandom]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedWriter]: + ... + + @overload + def open( + self, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BufferedReader]: + ... + + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: int = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> _AsyncIOWrapper[BinaryIO]: + ... + + @overload + def open( + self, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> _AsyncIOWrapper[IO[Any]]: + ... + + @wraps(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch. + async def open(self, *args: Any, **kwargs: Any) -> _AsyncIOWrapper[IO[Any]]: """Open the file pointed to by the path, like the :func:`trio.open_file` function does. @@ -189,75 +333,101 @@ async def open(self, *args, **kwargs): if TYPE_CHECKING: # the dunders listed in _forward_magic that aren't seen otherwise - __bytes__ = pathlib.Path.__bytes__ - __truediv__ = pathlib.Path.__truediv__ - __rtruediv__ = pathlib.Path.__rtruediv__ - - # These should be fully typed, either manually or with some magic wrapper - # function that copies the type of pathlib.Path except sticking an async in - # front of all of them. The latter is unfortunately not trivial, see attempts in - # https://github.com/python-trio/trio/issues/2630 + # fmt: off + def __bytes__(self) -> bytes: ... + def __truediv__(self, other: StrPath) -> Path: ... + def __rtruediv__(self, other: StrPath) -> Path: ... # wrapped methods handled by __getattr__ - absolute: Any - as_posix: Any - as_uri: Any - chmod: Any - cwd: Any - exists: Any - expanduser: Any - glob: Any - home: Any - is_absolute: Any - is_block_device: Any - is_char_device: Any - is_dir: Any - is_fifo: Any - is_file: Any - is_reserved: Any - is_socket: Any - is_symlink: Any - iterdir: Any - joinpath: Any - lchmod: Any - lstat: Any - match: Any - mkdir: Any - read_bytes: Any - read_text: Any - relative_to: Any - rename: Any - replace: Any - resolve: Any - rglob: Any - rmdir: Any - samefile: Any - stat: Any - symlink_to: Any - touch: Any - unlink: Any - with_name: Any - with_suffix: Any - write_bytes: Any - write_text: Any + async def absolute(self) -> Path: ... + async def as_posix(self) -> str: ... + async def as_uri(self) -> str: ... + + if sys.version_info >= (3, 10): + async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: ... + async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: ... + else: + async def stat(self) -> os.stat_result: ... + async def chmod(self, mode: int) -> None: ... + + @classmethod + async def cwd(self) -> Path: ... + + async def exists(self) -> bool: ... + async def expanduser(self) -> Path: ... + async def glob(self, pattern: str) -> Iterable[Path]: ... + async def home(self) -> Path: ... + async def is_absolute(self) -> bool: ... + async def is_block_device(self) -> bool: ... + async def is_char_device(self) -> bool: ... + async def is_dir(self) -> bool: ... + async def is_fifo(self) -> bool: ... + async def is_file(self) -> bool: ... + async def is_reserved(self) -> bool: ... + async def is_socket(self) -> bool: ... + async def is_symlink(self) -> bool: ... + async def iterdir(self) -> Iterable[Path]: ... + async def joinpath(self, *other: StrPath) -> Path: ... + async def lchmod(self, mode: int) -> None: ... + async def lstat(self) -> os.stat_result: ... + async def match(self, path_pattern: str) -> bool: ... + async def mkdir(self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False) -> None: ... + async def read_bytes(self) -> bytes: ... + async def read_text(self, encoding: str | None = None, errors: str | None = None) -> str: ... + async def relative_to(self, *other: StrPath) -> Path: ... + + if sys.version_info >= (3, 8): + def rename(self, target: str | pathlib.PurePath) -> Path: ... + def replace(self, target: str | pathlib.PurePath) -> Path: ... + else: + def rename(self, target: str | pathlib.PurePath) -> None: ... + def replace(self, target: str | pathlib.PurePath) -> None: ... + + async def resolve(self, strict: bool = False) -> Path: ... + async def rglob(self, pattern: str) -> Iterable[Path]: ... + async def rmdir(self) -> None: ... + async def samefile(self, other_path: str | bytes | int | Path) -> bool: ... + async def symlink_to(self, target: str | Path, target_is_directory: bool = False) -> None: ... + async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None: ... + if sys.version_info >= (3, 8): + def unlink(self, missing_ok: bool = False) -> None: ... + else: + def unlink(self) -> None: ... + async def with_name(self, name: str) -> Path: ... + async def with_suffix(self, suffix: str) -> Path: ... + async def write_bytes(self, data: bytes) -> int: ... + + if sys.version_info >= (3, 10): + async def write_text( + self, data: str, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> int: ... + else: + async def write_text( + self, data: str, + encoding: str | None = None, + errors: str | None = None, + ) -> int: ... if sys.platform != "win32": - group: Any - is_mount: Any - owner: Any + async def owner(self) -> str: ... + async def group(self) -> str: ... + async def is_mount(self) -> bool: ... if sys.version_info >= (3, 9): - is_relative_to: Any - with_stem: Any - readlink: Any + async def is_relative_to(self, *other: StrPath) -> bool: ... + async def with_stem(self, stem: str) -> Path: ... + async def readlink(self) -> Path: ... if sys.version_info >= (3, 10): - hardlink_to: Any + async def hardlink_to(self, target: str | pathlib.Path) -> None: ... if sys.version_info < (3, 12): - link_to: Any + async def link_to(self, target: StrPath | bytes) -> None: ... if sys.version_info >= (3, 12): - is_junction: Any - walk: Any - with_segments: Any + async def is_junction(self) -> bool: ... + walk: Any # TODO + async def with_segments(self, *pathsegments: StrPath) -> Path: ... Path.iterdir.__doc__ = """ diff --git a/trio/_tests/test_file_io.py b/trio/_tests/test_file_io.py index e99788efc5..bae426cf48 100644 --- a/trio/_tests/test_file_io.py +++ b/trio/_tests/test_file_io.py @@ -1,12 +1,15 @@ +import importlib import io import os +import re +from typing import List, Tuple from unittest import mock from unittest.mock import sentinel import pytest import trio -from trio import _core +from trio import _core, _file_io from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper @@ -78,6 +81,46 @@ def unsupported_attr(self): # pragma: no cover getattr(async_file, "unsupported_attr") +def test_type_stubs_match_lists() -> None: + """Check the manual stubs match the list of wrapped methods.""" + # Fetch the module's source code. + assert _file_io.__spec__ is not None + loader = _file_io.__spec__.loader + assert isinstance(loader, importlib.abc.SourceLoader) + source = io.StringIO(loader.get_source("trio._file_io")) + + # Find the class, then find the TYPE_CHECKING block. + for line in source: + if "class AsyncIOWrapper" in line: + break + else: # pragma: no cover - should always find this + pytest.fail("No class definition line?") + + for line in source: + if "if TYPE_CHECKING" in line: + break + else: # pragma: no cover - should always find this + pytest.fail("No TYPE CHECKING line?") + + # Now we should be at the type checking block. + found: List[Tuple[str, str]] = [] + for line in source: # pragma: no branch - expected to break early + if line.strip() and not line.startswith(" " * 8): + break # Dedented out of the if TYPE_CHECKING block. + match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line) + if match is not None: + kind = "async" if match.group(1) is not None else "sync" + found.append((match.group(2), kind)) + + # Compare two lists so that we can easily see duplicates, and see what is different overall. + expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS] + expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS] + # Ignore order, error if duplicates are present. + found.sort() + expected.sort() + assert found == expected + + def test_sync_attrs_forwarded(async_file, wrapped): for attr_name in _FILE_SYNC_ATTRS: if attr_name not in dir(async_file): diff --git a/trio/_tests/verify_types.json b/trio/_tests/verify_types.json index ba26a34e9f..d08c03060c 100644 --- a/trio/_tests/verify_types.json +++ b/trio/_tests/verify_types.json @@ -7,11 +7,11 @@ "warningCount": 0 }, "typeCompleteness": { - "completenessScore": 0.8832, + "completenessScore": 0.888, "exportedSymbolCounts": { "withAmbiguousType": 1, - "withKnownType": 552, - "withUnknownType": 72 + "withKnownType": 555, + "withUnknownType": 69 }, "ignoreUnknownTypesFromImports": true, "missingClassDocStringCount": 1, @@ -45,9 +45,9 @@ } ], "otherSymbolCounts": { - "withAmbiguousType": 6, - "withKnownType": 475, - "withUnknownType": 114 + "withAmbiguousType": 3, + "withKnownType": 529, + "withUnknownType": 102 }, "packageName": "trio", "symbols": [ @@ -79,20 +79,6 @@ "trio._highlevel_socket.SocketStream.getsockopt", "trio._highlevel_socket.SocketStream.send_all", "trio._highlevel_socket.SocketStream.setsockopt", - "trio._path.AsyncAutoWrapperType.__init__", - "trio._path.AsyncAutoWrapperType.generate_forwards", - "trio._path.AsyncAutoWrapperType.generate_iter", - "trio._path.AsyncAutoWrapperType.generate_magic", - "trio._path.AsyncAutoWrapperType.generate_wraps", - "trio._path.Path", - "trio._path.Path.__bytes__", - "trio._path.Path.__dir__", - "trio._path.Path.__fspath__", - "trio._path.Path.__init__", - "trio._path.Path.__repr__", - "trio._path.Path.__rtruediv__", - "trio._path.Path.__truediv__", - "trio._path.Path.open", "trio._socket._SocketType.__getattr__", "trio._socket._SocketType.accept", "trio._socket._SocketType.connect", @@ -152,7 +138,6 @@ "trio.lowlevel.temporarily_detach_coroutine_object", "trio.lowlevel.wait_readable", "trio.lowlevel.wait_writable", - "trio.open_file", "trio.open_ssl_over_tcp_listeners", "trio.open_ssl_over_tcp_stream", "trio.open_tcp_listeners", @@ -204,8 +189,7 @@ "trio.testing.trio_test", "trio.testing.wait_all_tasks_blocked", "trio.tests.TestsDeprecationWrapper", - "trio.to_thread.current_default_thread_limiter", - "trio.wrap_file" + "trio.to_thread.current_default_thread_limiter" ] } } diff --git a/trio/_util.py b/trio/_util.py index b7b4403115..a87f1fc02c 100644 --- a/trio/_util.py +++ b/trio/_util.py @@ -13,6 +13,9 @@ import trio +CallT = t.TypeVar("CallT", bound=t.Callable[..., t.Any]) + + # Equivalent to the C function raise(), which Python doesn't wrap if os.name == "nt": # On Windows, os.kill exists but is really weird. @@ -199,10 +202,14 @@ def __exit__( self._held = False -def async_wraps(cls, wrapped_cls, attr_name): +def async_wraps( + cls: type[object], + wrapped_cls: type[object], + attr_name: str, +) -> t.Callable[[CallT], CallT]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func): + def decorator(func: CallT) -> CallT: func.__name__ = attr_name func.__qualname__ = ".".join((cls.__qualname__, attr_name))