From 4bce147ebbd990909ce052ba61fbcbaee5733ef4 Mon Sep 17 00:00:00 2001 From: Binil Date: Wed, 28 Jan 2026 14:04:09 -0500 Subject: [PATCH 1/2] test coverage changes added --- tests/test_api_model.py | 69 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 6 deletions(-) diff --git a/tests/test_api_model.py b/tests/test_api_model.py index 2536490..ff4e864 100644 --- a/tests/test_api_model.py +++ b/tests/test_api_model.py @@ -2,11 +2,9 @@ from dataclasses import dataclass -import pydantic - from fastapi_utils.api_model import APIModel - -PYDANTIC_VERSION = pydantic.VERSION +from fastapi_utils.api_model import APIMessage, PYDANTIC_VERSION +from pytest import MonkeyPatch # type: ignore[import] def test_orm_mode() -> None: @@ -22,7 +20,7 @@ class Model(APIModel): if PYDANTIC_VERSION[0] == "2": assert Model.model_validate(Data(x=1)).x == 1 else: - assert Model.from_orm(Data(x=1)).x == 1 + assert Model.from_orm(Data(x=1)).x == 1 # type: ignore[reportDeprecated] def test_aliases() -> None: @@ -30,4 +28,63 @@ class Model(APIModel): some_field: str assert Model(some_field="a").some_field == "a" - assert Model(someField="a").some_field == "a" # type: ignore[call-arg] + assert Model(someField="a").some_field == "a" # type: ignore[reportMissingParameter] + + +def test_alias_population_and_serialization() -> None: + class Model(APIModel): + some_field: str + + m_alias = Model(someField="alias") # type: ignore[reportMissingParameter] + m_field = Model(some_field="field") + + assert m_alias.some_field == "alias" + assert m_field.some_field == "field" + + if PYDANTIC_VERSION[0] == "2": + alias_dump = m_alias.model_dump(by_alias=True) + field_dump = m_alias.model_dump() + else: + alias_dump = m_alias.model_dump(by_alias=True) + field_dump = m_alias.model_dump() + + assert "someField" in alias_dump + assert "some_field" in field_dump + + +def test_api_message_with_additional_field() -> None: + class Msg(APIMessage): + some_field: str + + m = Msg(detail="ok", someField="value") # type: ignore[reportMissingParameter] + assert m.detail == "ok" + assert m.some_field == "value" + + if PYDANTIC_VERSION[0] == "2": + dumped = m.model_dump(by_alias=True) + else: + dumped = m.model_dump(by_alias=True) + + assert dumped["someField"] == "value" + assert dumped["detail"] == "ok" + + +def test_pydantic_v1_branch_reload(monkeypatch: MonkeyPatch) -> None: + """Reload module with a fake pydantic.VERSION starting with '1' to hit the v1 else-branch.""" + import sys + import importlib + import pydantic as _pyd + + original_version = _pyd.VERSION + try: + monkeypatch.setattr(_pyd, "VERSION", ("1",)) + # ensure module is re-imported fresh + sys.modules.pop("fastapi_utils.api_model", None) + am = importlib.import_module("fastapi_utils.api_model") + # the v1 path defines an inner `Config` class on APIModel + assert hasattr(am.APIModel, "Config") + finally: + # restore original module state + monkeypatch.setattr(_pyd, "VERSION", original_version) + sys.modules.pop("fastapi_utils.api_model", None) + importlib.import_module("fastapi_utils.api_model") From c35f59c8e556fa996eb261219c3df8ad5fed52b1 Mon Sep 17 00:00:00 2001 From: Binil Date: Wed, 8 Apr 2026 12:24:56 -0400 Subject: [PATCH 2/2] fix(tasks): add repeat_every task handle + cancellation helpers --- fastapi_utils/tasks.py | 32 +++++++++++++++++++++++++++- tests/test_tasks.py | 47 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/fastapi_utils/tasks.py b/fastapi_utils/tasks.py index 5c83f65..fb145f3 100644 --- a/fastapi_utils/tasks.py +++ b/fastapi_utils/tasks.py @@ -6,6 +6,7 @@ from functools import wraps from traceback import format_exception from typing import Any, Callable, Coroutine, Union +from weakref import WeakKeyDictionary from starlette.concurrency import run_in_threadpool @@ -18,6 +19,33 @@ NoArgsNoReturnDecorator = Callable[[NoArgsNoReturnAnyFuncT], NoArgsNoReturnAsyncFuncT] +_REPEATED_TASKS: WeakKeyDictionary[NoArgsNoReturnAsyncFuncT, asyncio.Task[None]] = WeakKeyDictionary() + + +def get_repeated_task(task_func: NoArgsNoReturnAsyncFuncT) -> asyncio.Task[None] | None: + """Return the currently running repeated task for a wrapped function, if any.""" + return _REPEATED_TASKS.get(task_func) + + +async def cancel_repeated_task(task_func: NoArgsNoReturnAsyncFuncT) -> bool: + """Cancel the running repeated task for a wrapped function. + + Returns True when a running task existed and cancellation was requested. + Returns False when there is no active task to cancel. + """ + task = _REPEATED_TASKS.get(task_func) + if task is None or task.done(): + return False + + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + return True + + async def _handle_func(func: NoArgsNoReturnAnyFuncT) -> None: if asyncio.iscoroutinefunction(func): await func() @@ -111,7 +139,9 @@ async def loop() -> None: if on_complete: await _handle_func(on_complete) - asyncio.ensure_future(loop()) + task = asyncio.create_task(loop()) + _REPEATED_TASKS[wrapped] = task + task.add_done_callback(lambda _: _REPEATED_TASKS.pop(wrapped, None)) return wrapped diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a50841f..38b9156 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -4,7 +4,12 @@ import pytest -from fastapi_utils.tasks import NoArgsNoReturnAsyncFuncT, repeat_every +from fastapi_utils.tasks import ( + NoArgsNoReturnAsyncFuncT, + cancel_repeated_task, + get_repeated_task, + repeat_every, +) # Fixtures: @@ -159,6 +164,26 @@ async def test_max_repetitions_and_wait_first( assert self.counter == max_repetitions asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True) + @pytest.mark.asyncio + @pytest.mark.timeout(1) + async def test_exposes_task_handle_and_cancel_path(self, seconds: float) -> None: + def increase_counter_forever() -> None: + self.increase_counter() + + wrapped = repeat_every(seconds=seconds)(increase_counter_forever) + + await wrapped() + await asyncio.sleep(seconds * 2) + + task = get_repeated_task(wrapped) + assert task is not None + assert not task.done() + + canceled = await cancel_repeated_task(wrapped) + assert canceled + + assert get_repeated_task(wrapped) is None + @pytest.mark.asyncio @pytest.mark.timeout(1) async def test_stop_loop_on_exc( @@ -224,6 +249,26 @@ async def test_max_repetitions_and_wait_first( assert self.counter == max_repetitions asyncio_sleep_mock.assert_has_calls((max_repetitions + 1) * [call(seconds)], any_order=True) + @pytest.mark.asyncio + @pytest.mark.timeout(1) + async def test_exposes_task_handle_and_cancel_path(self, seconds: float) -> None: + async def increase_counter_forever_async() -> None: + self.increase_counter() + + wrapped = repeat_every(seconds=seconds)(increase_counter_forever_async) + + await wrapped() + await asyncio.sleep(seconds * 2) + + task = get_repeated_task(wrapped) + assert task is not None + assert not task.done() + + canceled = await cancel_repeated_task(wrapped) + assert canceled + + assert get_repeated_task(wrapped) is None + @pytest.mark.asyncio @pytest.mark.timeout(1) async def test_stop_loop_on_exc(