Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion fastapi_utils/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import wraps
from traceback import format_exception
from typing import Any, Callable, Coroutine, Union
from weakref import WeakKeyDictionary

from starlette.concurrency import run_in_threadpool

Expand All @@ -18,6 +19,33 @@
NoArgsNoReturnDecorator = Callable[[NoArgsNoReturnAnyFuncT], NoArgsNoReturnAsyncFuncT]


_REPEATED_TASKS: WeakKeyDictionary[NoArgsNoReturnAsyncFuncT, asyncio.Task[None]] = WeakKeyDictionary()


def get_repeated_task(task_func: NoArgsNoReturnAsyncFuncT) -> asyncio.Task[None] | None:
"""Return the currently running repeated task for a wrapped function, if any."""
return _REPEATED_TASKS.get(task_func)


async def cancel_repeated_task(task_func: NoArgsNoReturnAsyncFuncT) -> bool:
"""Cancel the running repeated task for a wrapped function.

Returns True when a running task existed and cancellation was requested.
Returns False when there is no active task to cancel.
"""
task = _REPEATED_TASKS.get(task_func)
if task is None or task.done():
return False

task.cancel()
try:
await task
except asyncio.CancelledError:
pass

return True


async def _handle_func(func: NoArgsNoReturnAnyFuncT) -> None:
if asyncio.iscoroutinefunction(func):
await func()
Expand Down Expand Up @@ -111,7 +139,9 @@ async def loop() -> None:
if on_complete:
await _handle_func(on_complete)

asyncio.ensure_future(loop())
task = asyncio.create_task(loop())
_REPEATED_TASKS[wrapped] = task
task.add_done_callback(lambda _: _REPEATED_TASKS.pop(wrapped, None))

return wrapped

Expand Down
69 changes: 63 additions & 6 deletions tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

from dataclasses import dataclass

import pydantic

from fastapi_utils.api_model import APIModel

PYDANTIC_VERSION = pydantic.VERSION
from fastapi_utils.api_model import APIMessage, PYDANTIC_VERSION
from pytest import MonkeyPatch # type: ignore[import]


def test_orm_mode() -> None:
Expand All @@ -22,12 +20,71 @@ class Model(APIModel):
if PYDANTIC_VERSION[0] == "2":
assert Model.model_validate(Data(x=1)).x == 1
else:
assert Model.from_orm(Data(x=1)).x == 1
assert Model.from_orm(Data(x=1)).x == 1 # type: ignore[reportDeprecated]


def test_aliases() -> None:
class Model(APIModel):
some_field: str

assert Model(some_field="a").some_field == "a"
assert Model(someField="a").some_field == "a" # type: ignore[call-arg]
assert Model(someField="a").some_field == "a" # type: ignore[reportMissingParameter]


def test_alias_population_and_serialization() -> None:
class Model(APIModel):
some_field: str

m_alias = Model(someField="alias") # type: ignore[reportMissingParameter]
m_field = Model(some_field="field")

assert m_alias.some_field == "alias"
assert m_field.some_field == "field"

if PYDANTIC_VERSION[0] == "2":
alias_dump = m_alias.model_dump(by_alias=True)
field_dump = m_alias.model_dump()
else:
alias_dump = m_alias.model_dump(by_alias=True)
field_dump = m_alias.model_dump()

assert "someField" in alias_dump
assert "some_field" in field_dump


def test_api_message_with_additional_field() -> None:
class Msg(APIMessage):
some_field: str

m = Msg(detail="ok", someField="value") # type: ignore[reportMissingParameter]
assert m.detail == "ok"
assert m.some_field == "value"

if PYDANTIC_VERSION[0] == "2":
dumped = m.model_dump(by_alias=True)
else:
dumped = m.model_dump(by_alias=True)

assert dumped["someField"] == "value"
assert dumped["detail"] == "ok"


def test_pydantic_v1_branch_reload(monkeypatch: MonkeyPatch) -> None:
"""Reload module with a fake pydantic.VERSION starting with '1' to hit the v1 else-branch."""
import sys
import importlib
import pydantic as _pyd

original_version = _pyd.VERSION
try:
monkeypatch.setattr(_pyd, "VERSION", ("1",))
# ensure module is re-imported fresh
sys.modules.pop("fastapi_utils.api_model", None)
am = importlib.import_module("fastapi_utils.api_model")
# the v1 path defines an inner `Config` class on APIModel
assert hasattr(am.APIModel, "Config")
finally:
# restore original module state
monkeypatch.setattr(_pyd, "VERSION", original_version)
sys.modules.pop("fastapi_utils.api_model", None)
importlib.import_module("fastapi_utils.api_model")
47 changes: 46 additions & 1 deletion tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

import pytest

from fastapi_utils.tasks import NoArgsNoReturnAsyncFuncT, repeat_every
from fastapi_utils.tasks import (
NoArgsNoReturnAsyncFuncT,
cancel_repeated_task,
get_repeated_task,
repeat_every,
)


# Fixtures:
Expand Down Expand Up @@ -159,6 +164,26 @@ async def test_max_repetitions_and_wait_first(
assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True)

@pytest.mark.asyncio
@pytest.mark.timeout(1)
async def test_exposes_task_handle_and_cancel_path(self, seconds: float) -> None:
def increase_counter_forever() -> None:
self.increase_counter()

wrapped = repeat_every(seconds=seconds)(increase_counter_forever)

await wrapped()
await asyncio.sleep(seconds * 2)

task = get_repeated_task(wrapped)
assert task is not None
assert not task.done()

canceled = await cancel_repeated_task(wrapped)
assert canceled

assert get_repeated_task(wrapped) is None

@pytest.mark.asyncio
@pytest.mark.timeout(1)
async def test_stop_loop_on_exc(
Expand Down Expand Up @@ -224,6 +249,26 @@ async def test_max_repetitions_and_wait_first(
assert self.counter == max_repetitions
asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True)

@pytest.mark.asyncio
@pytest.mark.timeout(1)
async def test_exposes_task_handle_and_cancel_path(self, seconds: float) -> None:
async def increase_counter_forever_async() -> None:
self.increase_counter()

wrapped = repeat_every(seconds=seconds)(increase_counter_forever_async)

await wrapped()
await asyncio.sleep(seconds * 2)

task = get_repeated_task(wrapped)
assert task is not None
assert not task.done()

canceled = await cancel_repeated_task(wrapped)
assert canceled

assert get_repeated_task(wrapped) is None

@pytest.mark.asyncio
@pytest.mark.timeout(1)
async def test_stop_loop_on_exc(
Expand Down