Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ incremental in minor, bugfixes only are patches.
See [0Ver](https://0ver.org/).


## 0.24.0 WIP

### Features

- Add picky exceptions to `impure_safe` decorator like `safe` has. Issue #1543


## 0.23.0

### Features
Expand Down
71 changes: 61 additions & 10 deletions returns/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
Iterator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
final,
overload,
)

from typing_extensions import ParamSpec
Expand Down Expand Up @@ -885,9 +888,33 @@ def lash(self, function):

# impure_safe decorator:

@overload
def impure_safe(
function: Callable[_FuncParams, _NewValueType],
) -> Callable[_FuncParams, IOResultE[_NewValueType]]:
"""Decorator to convert exception-throwing for any kind of Exception."""


@overload
def impure_safe(
exceptions: Tuple[Type[Exception], ...],
) -> Callable[
[Callable[_FuncParams, _NewValueType]],
Callable[_FuncParams, IOResultE[_NewValueType]],
]:
"""Decorator to convert exception-throwing just for a set of Exceptions."""


def impure_safe( # type: ignore # noqa: WPS234, C901
function: Optional[Callable[_FuncParams, _NewValueType]] = None,
exceptions: Optional[Tuple[Type[Exception], ...]] = None,
) -> Union[
Callable[_FuncParams, IOResultE[_NewValueType]],
Callable[
[Callable[_FuncParams, _NewValueType]],
Callable[_FuncParams, IOResultE[_NewValueType]],
],
]:
"""
Decorator to mark function that it returns :class:`~IOResult` container.

Expand All @@ -910,16 +937,40 @@ def impure_safe(
>>> assert function(1) == IOSuccess(1.0)
>>> assert function(0).failure()

You can also use it with explicit exception types as the first argument:

.. code:: python

>>> from returns.io import IOSuccess, IOFailure, impure_safe

>>> @impure_safe(exceptions=(ZeroDivisionError,))
... def might_raise(arg: int) -> float:
... return 1 / arg

>>> assert might_raise(1) == IOSuccess(1.0)
>>> assert isinstance(might_raise(0), IOFailure)

In this case, only exceptions that are explicitly
listed are going to be caught.

Similar to :func:`returns.future.future_safe`
and :func:`returns.result.safe` decorators.
"""
@wraps(function)
def decorator(
*args: _FuncParams.args,
**kwargs: _FuncParams.kwargs,
) -> IOResultE[_NewValueType]:
try:
return IOSuccess(function(*args, **kwargs))
except Exception as exc:
return IOFailure(exc)
return decorator
def factory(
inner_function: Callable[_FuncParams, _NewValueType],
inner_exceptions: Tuple[Type[Exception], ...],
) -> Callable[_FuncParams, IOResultE[_NewValueType]]:
@wraps(inner_function)
def decorator(*args: _FuncParams.args, **kwargs: _FuncParams.kwargs):
try:
return IOSuccess(inner_function(*args, **kwargs))
except inner_exceptions as exc:
return IOFailure(exc)
return decorator

if callable(function):
return factory(function, exceptions or (Exception,))
if isinstance(function, tuple):
exceptions = function # type: ignore
function = None
return lambda function: factory(function, exceptions) # type: ignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from typing import Union

import pytest

from returns.io import IOSuccess, impure_safe


Expand All @@ -6,6 +10,18 @@ def _function(number: int) -> float:
return number / number


@impure_safe(exceptions=(ZeroDivisionError,))
def _function_two(number: Union[int, str]) -> float:
assert isinstance(number, int)
return number / number


@impure_safe((ZeroDivisionError,)) # no name
def _function_three(number: Union[int, str]) -> float:
assert isinstance(number, int)
return number / number


def test_safe_iosuccess():
"""Ensures that safe decorator works correctly for IOSuccess case."""
assert _function(1) == IOSuccess(1.0)
Expand All @@ -17,3 +33,24 @@ def test_safe_iofailure():
assert isinstance(
failed.failure()._inner_value, ZeroDivisionError, # noqa: WPS437
)


def test_safe_failure_with_expected_error():
"""Ensures that safe decorator works correctly for Failure case."""
failed = _function_two(0)
assert isinstance(
failed.failure()._inner_value, # noqa: WPS437
ZeroDivisionError,
)

failed2 = _function_three(0)
assert isinstance(
failed2.failure()._inner_value, # noqa: WPS437
ZeroDivisionError,
)


def test_safe_failure_with_non_expected_error():
"""Ensures that safe decorator works correctly for Failure case."""
with pytest.raises(AssertionError):
_function_two('0')
18 changes: 18 additions & 0 deletions typesafety/test_io/test_ioresult_container/test_impure_safe.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,21 @@
return 1

reveal_type(test) # N: Revealed type is "def (arg: builtins.str) -> returns.io.IOResult[builtins.int, builtins.Exception]"


- case: impure_decorator_passing_exceptions_no_params
disable_cache: false
main: |
from returns.io import impure_safe

@impure_safe((ValueError,))
def test1(arg: str) -> int:
return 1

reveal_type(test1) # N: Revealed type is "def (arg: builtins.str) -> returns.io.IOResult[builtins.int, builtins.Exception]"

@impure_safe(exceptions=(ValueError,))
def test2(arg: str) -> int:
return 1

reveal_type(test2) # N: Revealed type is "def (arg: builtins.str) -> returns.io.IOResult[builtins.int, builtins.Exception]"