diff --git a/CHANGELOG.md b/CHANGELOG.md index 1eaea6187..228329e0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ incremental in minor, bugfixes only are patches. See [0Ver](https://0ver.org/). +## 0.24.0 WIP + +### Features + +- Add picky exceptions to `impure_safe` decorator like `safe` has. Issue #1543 + + ## 0.23.0 ### Features diff --git a/returns/io.py b/returns/io.py index 57891a7c1..d0cafe989 100644 --- a/returns/io.py +++ b/returns/io.py @@ -9,9 +9,12 @@ Iterator, List, Optional, + Tuple, + Type, TypeVar, Union, final, + overload, ) from typing_extensions import ParamSpec @@ -885,9 +888,33 @@ def lash(self, function): # impure_safe decorator: +@overload def impure_safe( function: Callable[_FuncParams, _NewValueType], ) -> Callable[_FuncParams, IOResultE[_NewValueType]]: + """Decorator to convert exception-throwing for any kind of Exception.""" + + +@overload +def impure_safe( + exceptions: Tuple[Type[Exception], ...], +) -> Callable[ + [Callable[_FuncParams, _NewValueType]], + Callable[_FuncParams, IOResultE[_NewValueType]], +]: + """Decorator to convert exception-throwing just for a set of Exceptions.""" + + +def impure_safe( # type: ignore # noqa: WPS234, C901 + function: Optional[Callable[_FuncParams, _NewValueType]] = None, + exceptions: Optional[Tuple[Type[Exception], ...]] = None, +) -> Union[ + Callable[_FuncParams, IOResultE[_NewValueType]], + Callable[ + [Callable[_FuncParams, _NewValueType]], + Callable[_FuncParams, IOResultE[_NewValueType]], + ], +]: """ Decorator to mark function that it returns :class:`~IOResult` container. @@ -910,16 +937,40 @@ def impure_safe( >>> assert function(1) == IOSuccess(1.0) >>> assert function(0).failure() + You can also use it with explicit exception types as the first argument: + + .. code:: python + + >>> from returns.io import IOSuccess, IOFailure, impure_safe + + >>> @impure_safe(exceptions=(ZeroDivisionError,)) + ... def might_raise(arg: int) -> float: + ... return 1 / arg + + >>> assert might_raise(1) == IOSuccess(1.0) + >>> assert isinstance(might_raise(0), IOFailure) + + In this case, only exceptions that are explicitly + listed are going to be caught. + Similar to :func:`returns.future.future_safe` and :func:`returns.result.safe` decorators. """ - @wraps(function) - def decorator( - *args: _FuncParams.args, - **kwargs: _FuncParams.kwargs, - ) -> IOResultE[_NewValueType]: - try: - return IOSuccess(function(*args, **kwargs)) - except Exception as exc: - return IOFailure(exc) - return decorator + def factory( + inner_function: Callable[_FuncParams, _NewValueType], + inner_exceptions: Tuple[Type[Exception], ...], + ) -> Callable[_FuncParams, IOResultE[_NewValueType]]: + @wraps(inner_function) + def decorator(*args: _FuncParams.args, **kwargs: _FuncParams.kwargs): + try: + return IOSuccess(inner_function(*args, **kwargs)) + except inner_exceptions as exc: + return IOFailure(exc) + return decorator + + if callable(function): + return factory(function, exceptions or (Exception,)) + if isinstance(function, tuple): + exceptions = function # type: ignore + function = None + return lambda function: factory(function, exceptions) # type: ignore diff --git a/tests/test_io/test_ioresult_container/test_ioresult_functions/test_impure_safe.py b/tests/test_io/test_ioresult_container/test_ioresult_functions/test_impure_safe.py index 3b9162013..736f8bb05 100644 --- a/tests/test_io/test_ioresult_container/test_ioresult_functions/test_impure_safe.py +++ b/tests/test_io/test_ioresult_container/test_ioresult_functions/test_impure_safe.py @@ -1,3 +1,7 @@ +from typing import Union + +import pytest + from returns.io import IOSuccess, impure_safe @@ -6,6 +10,18 @@ def _function(number: int) -> float: return number / number +@impure_safe(exceptions=(ZeroDivisionError,)) +def _function_two(number: Union[int, str]) -> float: + assert isinstance(number, int) + return number / number + + +@impure_safe((ZeroDivisionError,)) # no name +def _function_three(number: Union[int, str]) -> float: + assert isinstance(number, int) + return number / number + + def test_safe_iosuccess(): """Ensures that safe decorator works correctly for IOSuccess case.""" assert _function(1) == IOSuccess(1.0) @@ -17,3 +33,24 @@ def test_safe_iofailure(): assert isinstance( failed.failure()._inner_value, ZeroDivisionError, # noqa: WPS437 ) + + +def test_safe_failure_with_expected_error(): + """Ensures that safe decorator works correctly for Failure case.""" + failed = _function_two(0) + assert isinstance( + failed.failure()._inner_value, # noqa: WPS437 + ZeroDivisionError, + ) + + failed2 = _function_three(0) + assert isinstance( + failed2.failure()._inner_value, # noqa: WPS437 + ZeroDivisionError, + ) + + +def test_safe_failure_with_non_expected_error(): + """Ensures that safe decorator works correctly for Failure case.""" + with pytest.raises(AssertionError): + _function_two('0') diff --git a/typesafety/test_io/test_ioresult_container/test_impure_safe.yml b/typesafety/test_io/test_ioresult_container/test_impure_safe.yml index 6541a587a..4d9ed1bc6 100644 --- a/typesafety/test_io/test_ioresult_container/test_impure_safe.yml +++ b/typesafety/test_io/test_ioresult_container/test_impure_safe.yml @@ -8,3 +8,21 @@ return 1 reveal_type(test) # N: Revealed type is "def (arg: builtins.str) -> returns.io.IOResult[builtins.int, builtins.Exception]" + + +- case: impure_decorator_passing_exceptions_no_params + disable_cache: false + main: | + from returns.io import impure_safe + + @impure_safe((ValueError,)) + def test1(arg: str) -> int: + return 1 + + reveal_type(test1) # N: Revealed type is "def (arg: builtins.str) -> returns.io.IOResult[builtins.int, builtins.Exception]" + + @impure_safe(exceptions=(ValueError,)) + def test2(arg: str) -> int: + return 1 + + reveal_type(test2) # N: Revealed type is "def (arg: builtins.str) -> returns.io.IOResult[builtins.int, builtins.Exception]"