From 19e1dec09e86a8100233c933d1d58c65c0aa5630 Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Wed, 30 Nov 2022 17:39:35 -0600 Subject: [PATCH 1/3] Support raising cancellation in sync multithreaded activities --- README.md | 19 +++++-- temporalio/activity.py | 72 ++++++++++++++++++----- temporalio/bridge/runtime.py | 9 ++- temporalio/bridge/src/lib.rs | 6 ++ temporalio/bridge/src/runtime.rs | 5 ++ temporalio/exceptions.py | 2 +- temporalio/testing/_activity.py | 25 ++++++++ temporalio/worker/_activity.py | 94 ++++++++++++++++++++++++++----- temporalio/worker/_interceptor.py | 1 - tests/bridge/__init__.py | 0 tests/bridge/test_runtime.py | 40 +++++++++++++ tests/testing/test_activity.py | 21 ++++++- tests/worker/test_activity.py | 83 ++++++++++++++++++++++++--- 13 files changed, 332 insertions(+), 45 deletions(-) create mode 100644 tests/bridge/__init__.py create mode 100644 tests/bridge/test_runtime.py diff --git a/README.md b/README.md index 9fc319206..e2ba81db1 100644 --- a/README.md +++ b/README.md @@ -860,10 +860,10 @@ Synchronous activities, i.e. functions that do not have `async def`, can be used `activity_executor` worker parameter must be set with a `concurrent.futures.Executor` instance to use for executing the activities. -Cancellation for synchronous activities is done in the background and the activity must choose to listen for it and -react appropriately. If after cancellation is obtained an unwrapped `temporalio.exceptions.CancelledError` is raised, -the activity will be marked cancelled. An activity must heartbeat to receive cancellation and there are other ways to be -notified about cancellation (see "Activity Context" and "Heartbeating and Cancellation" later). +All long running activities should heartbeat so they can be cancelled. Cancellation in threaded activities throws but +multiprocess/other activities does not. The sections below on each synchronous type explain further. There are also +calls on the context that can check for cancellation. For more information, see "Activity Context" and +"Heartbeating and Cancellation" sections later. Note, all calls from an activity to functions in the `temporalio.activity` package are powered by [contextvars](https://docs.python.org/3/library/contextvars.html). Therefore, new threads starting _inside_ of @@ -876,6 +876,15 @@ If `activity_executor` is set to an instance of `concurrent.futures.ThreadPoolEx are considered multithreaded activities. Besides `activity_executor`, no other worker parameters are required for synchronous multithreaded activities. +By default, cancellation of a synchronous multithreaded activity is done via a `temporalio.exceptions.CancelledError` +thrown into the activity thread. Activities that do not wish to have cancellation thrown can set +`no_thread_cancel_exception=True` in the `@activity.defn` decorator. + +Code that wishes to be temporarily shielded from the cancellation exception can run inside +`with activity.shield_thread_cancel_exception():`. But once the last nested form of that block is finished, even if +there is a return statement within, it will throw the cancellation if there was one. A `try` + +`except temporalio.exceptions.CancelledError` would have to surround the `with` to handle the cancellation explicitly. + ###### Synchronous Multiprocess/Other Activities If `activity_executor` is set to an instance of `concurrent.futures.Executor` that is _not_ @@ -901,6 +910,8 @@ calls in the `temporalio.activity` package make use of it. Specifically: * `is_cancelled()` - Whether a cancellation has been requested on this activity * `wait_for_cancelled()` - `async` call to wait for cancellation request * `wait_for_cancelled_sync(timeout)` - Synchronous blocking call to wait for cancellation request +* `shield_thread_cancel_exception()` - Context manager for use in `with` clauses by synchronous multithreaded activities + to prevent cancel exception from being thrown during the block of code * `is_worker_shutdown()` - Whether the worker has started graceful shutdown * `wait_for_worker_shutdown()` - `async` call to wait for start of graceful worker shutdown * `wait_for_worker_shutdown_sync(timeout)` - Synchronous blocking call to wait for start of graceful worker shutdown diff --git a/temporalio/activity.py b/temporalio/activity.py index f7371e68c..bc5e2a84a 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -15,12 +15,13 @@ import inspect import logging import threading +from contextlib import AbstractContextManager, contextmanager from dataclasses import dataclass from datetime import datetime, timedelta -from functools import partial from typing import ( Any, Callable, + Iterator, List, Mapping, MutableMapping, @@ -34,7 +35,6 @@ ) import temporalio.common -import temporalio.exceptions from .types import CallableType @@ -45,11 +45,18 @@ def defn(fn: CallableType) -> CallableType: @overload -def defn(*, name: str) -> Callable[[CallableType], CallableType]: +def defn( + *, name: Optional[str] = None, no_thread_cancel_exception: bool = False +) -> Callable[[CallableType], CallableType]: ... -def defn(fn: Optional[CallableType] = None, *, name: Optional[str] = None): +def defn( + fn: Optional[CallableType] = None, + *, + name: Optional[str] = None, + no_thread_cancel_exception: bool = False, +): """Decorator for activity functions. Activities can be async or non-async. @@ -57,20 +64,22 @@ def defn(fn: Optional[CallableType] = None, *, name: Optional[str] = None): Args: fn: The function to decorate. name: Name to use for the activity. Defaults to function ``__name__``. + no_thread_cancel_exception: If set to true, an exception will not be + raised in synchronous, threaded activities upon cancellation. """ - def with_name(name: str, fn: CallableType) -> CallableType: + def decorator(fn: CallableType) -> CallableType: # This performs validation - _Definition._apply_to_callable(fn, name) + _Definition._apply_to_callable( + fn, + activity_name=name or fn.__name__, + no_thread_cancel_exception=no_thread_cancel_exception, + ) return fn - # If name option is available, return decorator function - if name is not None: - return partial(with_name, name) - if fn is None: - raise RuntimeError("Cannot invoke defn without function or name") - # Otherwise just run decorator function - return with_name(fn.__name__, fn) + if fn is not None: + return decorator(fn) + return decorator @dataclass(frozen=True) @@ -122,6 +131,7 @@ class _Context: heartbeat: Optional[Callable[..., None]] cancelled_event: _CompositeEvent worker_shutdown_event: _CompositeEvent + shield_thread_cancel_exception: Optional[Callable[[], AbstractContextManager]] _logger_details: Optional[Mapping[str, Any]] = None @staticmethod @@ -221,6 +231,36 @@ def is_cancelled() -> bool: return _Context.current().cancelled_event.is_set() +@contextmanager +def shield_thread_cancel_exception() -> Iterator[None]: + """Context manager for synchronous multithreaded activities to delay + cancellation exceptions. + + By default, heartbeating synchronous multithreaded activities have an + exception thrown inside when cancellation occurs. Code within a "with" block + of this context manager will delay that throwing until the end. Even if the + block returns a value or throws its own exception, if a cancellation + exception is pending, it is thrown instead. Therefore users are encouraged + to not throw out of this block and can surround this with a try/except if + they wish to catch a cancellation. + + This properly supports nested calls and will only throw after the last one. + + This just runs the blocks with no extra effects for async activities or + synchronous multiprocess/other activities. + + Raises: + temporalio.exceptions.CancelledError: If a cancellation occurs anytime + during this block and this is not nested in another shield block. + """ + shield_context = _Context.current().shield_thread_cancel_exception + if not shield_context: + yield None + else: + with shield_context(): + yield None + + async def wait_for_cancelled() -> None: """Asynchronously wait for this activity to get a cancellation request. @@ -353,6 +393,7 @@ class _Definition: name: str fn: Callable is_async: bool + no_thread_cancel_exception: bool # Types loaded on post init if both are None arg_types: Optional[List[Type]] = None ret_type: Optional[Type] = None @@ -379,7 +420,9 @@ def must_from_callable(fn: Callable) -> _Definition: ) @staticmethod - def _apply_to_callable(fn: Callable, activity_name: str) -> None: + def _apply_to_callable( + fn: Callable, *, activity_name: str, no_thread_cancel_exception: bool = False + ) -> None: # Validate the activity if hasattr(fn, "__temporal_activity_definition"): raise ValueError("Function already contains activity definition") @@ -399,6 +442,7 @@ def _apply_to_callable(fn: Callable, activity_name: str) -> None: # iscoroutinefunction does not return true for async __call__ # TODO(cretz): Why can't MyPy handle this? is_async=inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__), # type: ignore + no_thread_cancel_exception=no_thread_cancel_exception, ), ) diff --git a/temporalio/bridge/runtime.py b/temporalio/bridge/runtime.py index bb5ee5cb0..e94746ce5 100644 --- a/temporalio/bridge/runtime.py +++ b/temporalio/bridge/runtime.py @@ -6,7 +6,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar, Mapping, Optional +from typing import ClassVar, Mapping, Optional, Type import temporalio.bridge.temporal_sdk_bridge @@ -54,6 +54,13 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None: raise RuntimeError("Runtime default already set") _default_runtime = runtime + @staticmethod + def _raise_in_thread(thread_id: int, exc_type: Type[BaseException]) -> bool: + """Internal helper for raising an exception in thread.""" + return temporalio.bridge.temporal_sdk_bridge.raise_in_thread( + thread_id, exc_type + ) + def __init__(self, *, telemetry: TelemetryConfig) -> None: """Create a default runtime with the given telemetry config. diff --git a/temporalio/bridge/src/lib.rs b/temporalio/bridge/src/lib.rs index 91ecbdcec..014eb6b38 100644 --- a/temporalio/bridge/src/lib.rs +++ b/temporalio/bridge/src/lib.rs @@ -16,6 +16,7 @@ fn temporal_sdk_bridge(py: Python, m: &PyModule) -> PyResult<()> { // Runtime stuff m.add_class::()?; m.add_function(wrap_pyfunction!(init_runtime, m)?)?; + m.add_function(wrap_pyfunction!(raise_in_thread, m)?)?; // Testing stuff m.add_class::()?; @@ -48,6 +49,11 @@ fn init_runtime(telemetry_config: runtime::TelemetryConfig) -> PyResult(py: Python<'a>, thread_id: i32, exc: &PyAny) -> bool { + runtime::raise_in_thread(py, thread_id, exc) +} + #[pyfunction] fn start_temporalite<'a>( py: Python<'a>, diff --git a/temporalio/bridge/src/runtime.rs b/temporalio/bridge/src/runtime.rs index 11fda3113..ab02b4e19 100644 --- a/temporalio/bridge/src/runtime.rs +++ b/temporalio/bridge/src/runtime.rs @@ -1,5 +1,6 @@ use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; +use pyo3::AsPyPointer; use std::collections::HashMap; use std::future::Future; use std::net::SocketAddr; @@ -75,6 +76,10 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult { }) } +pub fn raise_in_thread<'a>(_py: Python<'a>, thread_id: i32, exc: &PyAny) -> bool { + unsafe { pyo3::ffi::PyThreadState_SetAsyncExc(thread_id, exc.as_ptr()) == 1 } +} + impl Runtime { pub fn future_into_py<'a, F, T>(&self, py: Python<'a>, fut: F) -> PyResult<&'a PyAny> where diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index 5c05ec459..68515e730 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -97,7 +97,7 @@ def non_retryable(self) -> bool: class CancelledError(FailureError): """Error raised on workflow/activity cancellation.""" - def __init__(self, message: str, *details: Any) -> None: + def __init__(self, message: str = "Cancelled", *details: Any) -> None: """Initialize a cancelled error.""" super().__init__(message) self._details = details diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index f035810ef..91dfd8ebd 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -12,6 +12,8 @@ from typing_extensions import ParamSpec import temporalio.activity +import temporalio.exceptions +import temporalio.worker._activity _Params = ParamSpec("_Params") _Return = TypeVar("_Return") @@ -111,6 +113,17 @@ def __init__( self.env = env self.fn = fn self.is_async = inspect.iscoroutinefunction(fn) + self.cancel_thread_raiser: Optional[ + temporalio.worker._activity._ThreadExceptionRaiser + ] = None + if not self.is_async: + # If there is a definition and they disable thread raising, don't + # set + defn = temporalio.activity._Definition.from_callable(fn) + if not defn or not defn.no_thread_cancel_exception: + self.cancel_thread_raiser = ( + temporalio.worker._activity._ThreadExceptionRaiser() + ) # Create context self.context = temporalio.activity._Context( info=lambda: env.info, @@ -123,10 +136,18 @@ def __init__( thread_event=threading.Event(), async_event=asyncio.Event() if self.is_async else None, ), + shield_thread_cancel_exception=None + if not self.cancel_thread_raiser + else self.cancel_thread_raiser.shielded, ) self.task: Optional[asyncio.Task] = None def run(self, *args, **kwargs) -> Any: + if self.cancel_thread_raiser: + thread_id = threading.current_thread().ident + if thread_id is not None: + self.cancel_thread_raiser.set_thread_id(thread_id) + @contextmanager def activity_context(): # Set cancelled and shutdown if already so in environment @@ -163,6 +184,10 @@ async def run_async(): def cancel(self) -> None: if not self.context.cancelled_event.is_set(): self.context.cancelled_event.set() + if self.cancel_thread_raiser: + self.cancel_thread_raiser.raise_in_thread( + temporalio.exceptions.CancelledError + ) if self.task and not self.task.done(): self.task.cancel() diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 9b75c65dd..c30d29ca2 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -13,9 +13,10 @@ import threading import warnings from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Tuple, Type, Union import google.protobuf.duration_pb2 import google.protobuf.timestamp_pb2 @@ -27,6 +28,7 @@ import temporalio.bridge.proto.activity_result import temporalio.bridge.proto.activity_task import temporalio.bridge.proto.common +import temporalio.bridge.runtime import temporalio.bridge.worker import temporalio.client import temporalio.common @@ -275,6 +277,8 @@ async def _run_activity( # No async event async_event=None, ) + if not activity_def.no_thread_cancel_exception: + running_activity.cancel_thread_raiser = _ThreadExceptionRaiser() else: manager = self._shared_state_manager # Pre-checked on worker init @@ -367,7 +371,6 @@ async def _run_activity( args=args, executor=None if not running_activity.sync else self._activity_executor, headers=start.header_fields, - _cancelled_event=running_activity.cancelled_event, ) # Set the context early so the logging adapter works and @@ -378,6 +381,9 @@ async def _run_activity( heartbeat=None, cancelled_event=running_activity.cancelled_event, worker_shutdown_event=self._worker_shutdown_event, + shield_thread_cancel_exception=None + if not running_activity.cancel_thread_raiser + else running_activity.cancel_thread_raiser.shielded, ) ) temporalio.activity.logger.debug("Starting activity") @@ -385,7 +391,9 @@ async def _run_activity( # Build the interceptors chaining in reverse. We build a context right # now even though the info() can't be intercepted and heartbeat() will # fail. The interceptors may want to use the info() during init. - impl: ActivityInboundInterceptor = _ActivityInboundImpl(self) + impl: ActivityInboundInterceptor = _ActivityInboundImpl( + self, running_activity + ) for interceptor in reversed(list(self._interceptors)): impl = interceptor.intercept_activity(impl) # Init @@ -481,6 +489,7 @@ class _RunningActivity: task: Optional[asyncio.Task] = None cancelled_event: Optional[temporalio.activity._CompositeEvent] = None last_heartbeat_task: Optional[asyncio.Task] = None + cancel_thread_raiser: Optional[_ThreadExceptionRaiser] = None sync: bool = False done: bool = False cancelled_by_request: bool = False @@ -496,16 +505,63 @@ def cancel( self.cancelled_due_to_heartbeat_error = cancelled_due_to_heartbeat_error if self.cancelled_event: self.cancelled_event.set() - # We do not cancel the task of sync activities - if not self.sync and self.task and not self.done: - # TODO(cretz): Check that Python >= 3.9 and set msg? - self.task.cancel() + if not self.done: + # If there's a thread raiser, use it + if self.cancel_thread_raiser: + self.cancel_thread_raiser.raise_in_thread( + temporalio.exceptions.CancelledError + ) + # If not sync and there's a task, cancel it + if not self.sync and self.task: + # TODO(cretz): Check that Python >= 3.9 and set msg? + self.task.cancel() + + +class _ThreadExceptionRaiser: + def __init__(self) -> None: + self._lock = threading.Lock() + self._thread_id: Optional[int] = None + self._pending_exception: Optional[Type[Exception]] = None + self._shield_depth = 0 + + def set_thread_id(self, thread_id: int) -> None: + with self._lock: + self._thread_id = thread_id + + def raise_in_thread(self, exc_type: Type[Exception]) -> None: + with self._lock: + self._pending_exception = exc_type + self._raise_in_thread_if_pending_unlocked() + + @contextmanager + def shielded(self) -> Iterator[None]: + with self._lock: + self._shield_depth += 1 + try: + yield None + finally: + with self._lock: + self._shield_depth -= 1 + self._raise_in_thread_if_pending_unlocked() + + def _raise_in_thread_if_pending_unlocked(self) -> None: + # Does not apply if no thread ID + if self._thread_id is not None: + # Raise and reset if depth is 0 + if self._shield_depth == 0 and self._pending_exception: + temporalio.bridge.runtime.Runtime._raise_in_thread( + self._thread_id, self._pending_exception + ) + self._pending_exception = None class _ActivityInboundImpl(ActivityInboundInterceptor): - def __init__(self, worker: _ActivityWorker) -> None: + def __init__( + self, worker: _ActivityWorker, running_activity: _RunningActivity + ) -> None: # We are intentionally not calling the base class's __init__ here self._worker = worker + self._running_activity = running_activity def init(self, outbound: ActivityOutboundInterceptor) -> None: # Set the context callables. We are setting values instead of replacing @@ -559,17 +615,20 @@ async def heartbeat_with_context(*details: Any) -> None: ) try: - # Shutdown event always present here - shutdown_event = self._worker._worker_shutdown_event - assert shutdown_event + # Cancel and shutdown event always present here + cancelled_event = self._running_activity.cancelled_event + assert cancelled_event + worker_shutdown_event = self._worker._worker_shutdown_event + assert worker_shutdown_event return await loop.run_in_executor( input.executor, _execute_sync_activity, info, heartbeat, + self._running_activity.cancel_thread_raiser, # Only thread event, this may cross a process boundary - input._cancelled_event.thread_event, - shutdown_event.thread_event, + cancelled_event.thread_event, + worker_shutdown_event.thread_event, input.fn, *input.args, ) @@ -599,11 +658,17 @@ def heartbeat(self, *details: Any) -> None: def _execute_sync_activity( info: temporalio.activity.Info, heartbeat: Union[Callable[..., None], SharedHeartbeatSender], + # This is only set for threaded activities + cancel_thread_raiser: Optional[_ThreadExceptionRaiser], cancelled_event: threading.Event, worker_shutdown_event: threading.Event, fn: Callable[..., Any], *args: Any, ) -> Any: + if cancel_thread_raiser: + thread_id = threading.current_thread().ident + if thread_id is not None: + cancel_thread_raiser.set_thread_id(thread_id) heartbeat_fn: Callable[..., None] if isinstance(heartbeat, SharedHeartbeatSender): # To make mypy happy @@ -623,6 +688,9 @@ def _execute_sync_activity( worker_shutdown_event=temporalio.activity._CompositeEvent( thread_event=worker_shutdown_event, async_event=None ), + shield_thread_cancel_exception=None + if not cancel_thread_raiser + else cancel_thread_raiser.shielded, ) ) return fn(*args) diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index e889417b1..56e2eb5ab 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -87,7 +87,6 @@ class ExecuteActivityInput: args: Sequence[Any] executor: Optional[concurrent.futures.Executor] headers: Mapping[str, temporalio.api.common.v1.Payload] - _cancelled_event: temporalio.activity._CompositeEvent class ActivityInboundInterceptor: diff --git a/tests/bridge/__init__.py b/tests/bridge/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/bridge/test_runtime.py b/tests/bridge/test_runtime.py new file mode 100644 index 000000000..e855d857d --- /dev/null +++ b/tests/bridge/test_runtime.py @@ -0,0 +1,40 @@ +from threading import Event, Thread +from time import sleep +from typing import Optional + +from temporalio.bridge.runtime import Runtime + + +class SomeException(Exception): + pass + + +def test_bridge_runtime_raise_in_thread(): + waiting = Event() + exc_in_thread: Optional[Exception] = None + + def wait_forever(): + try: + waiting.set() + while True: + sleep(0.1) + except BaseException as err: + nonlocal exc_in_thread + exc_in_thread = err + + # Start thread + thread = Thread(target=wait_forever, daemon=True) + thread.start() + + # Wait until sleeping + waiting.wait(5) + + # Raise exception + assert thread.ident + assert thread.is_alive() + assert Runtime._raise_in_thread(thread.ident, SomeException) + + # Make sure thread completes + thread.join(5) + assert not thread.is_alive() + assert type(exc_in_thread) is SomeException diff --git a/tests/testing/test_activity.py b/tests/testing/test_activity.py index 70f698bdf..29b66c772 100644 --- a/tests/testing/test_activity.py +++ b/tests/testing/test_activity.py @@ -1,8 +1,10 @@ import asyncio import threading +import time from contextvars import copy_context from temporalio import activity +from temporalio.exceptions import CancelledError from temporalio.testing import ActivityEnvironment @@ -42,6 +44,7 @@ async def via_create_task(): def test_activity_env_sync(): waiting = threading.Event() + properly_cancelled = False def do_stuff(param: str) -> None: activity.heartbeat(f"param: {param}") @@ -58,8 +61,18 @@ def via_thread(): # Wait for cancel waiting.set() - activity.wait_for_cancelled_sync() - activity.heartbeat("cancelled") + try: + # Confirm shielding works + with activity.shield_thread_cancel_exception(): + try: + while not activity.is_cancelled(): + time.sleep(0.2) + time.sleep(0.2) + except: + raise RuntimeError("Unexpected") + except CancelledError: + nonlocal properly_cancelled + properly_cancelled = True env = ActivityEnvironment() # Set heartbeat handler to add to list @@ -70,9 +83,11 @@ def via_thread(): thread.start() waiting.wait() # Cancel and confirm done + time.sleep(1) env.cancel() thread.join() - assert heartbeats == ["param: param1", "task, type: unknown", "cancelled"] + assert heartbeats == ["param: param1", "task, type: unknown"] + assert properly_cancelled async def test_activity_env_assert(): diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index 67cea2831..e2b97bb75 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -8,7 +8,7 @@ import uuid from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, List, NoReturn, Optional, Sequence import pytest @@ -262,13 +262,18 @@ async def wait_cancel() -> str: assert isinstance(err.value.cause.cause, CancelledError) -async def test_sync_activity_thread_cancel(client: Client, worker: ExternalWorker): +async def test_sync_activity_thread_cancel_caught( + client: Client, worker: ExternalWorker +): @activity.defn def wait_cancel() -> str: - while not activity.is_cancelled(): - time.sleep(1) - activity.heartbeat() - return "Cancelled" + try: + while True: + time.sleep(1) + activity.heartbeat() + except CancelledError: + assert activity.is_cancelled() + return "Cancelled" with concurrent.futures.ThreadPoolExecutor() as executor: result = await _execute_workflow_with_activity( @@ -287,11 +292,71 @@ async def test_sync_activity_thread_cancel_uncaught( client: Client, worker: ExternalWorker ): @activity.defn + def wait_cancel() -> NoReturn: + while True: + time.sleep(1) + activity.heartbeat() + + with pytest.raises(WorkflowFailureError) as err: + with concurrent.futures.ThreadPoolExecutor() as executor: + await _execute_workflow_with_activity( + client, + worker, + wait_cancel, + cancel_after_ms=100, + wait_for_cancellation=True, + heartbeat_timeout_ms=3000, + worker_config={"activity_executor": executor}, + ) + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, CancelledError) + + +async def test_sync_activity_thread_cancel_exception_disabled( + client: Client, worker: ExternalWorker +): + @activity.defn(no_thread_cancel_exception=True) def wait_cancel() -> str: - while not activity.is_cancelled(): + while True: time.sleep(1) activity.heartbeat() - raise CancelledError("Cancelled") + if activity.is_cancelled(): + # Heartbeat again just to confirm nothing happens + time.sleep(1) + activity.heartbeat() + return "Cancelled" + + with concurrent.futures.ThreadPoolExecutor() as executor: + result = await _execute_workflow_with_activity( + client, + worker, + wait_cancel, + cancel_after_ms=100, + wait_for_cancellation=True, + heartbeat_timeout_ms=3000, + worker_config={"activity_executor": executor}, + ) + assert result.result == "Cancelled" + + +async def test_sync_activity_thread_cancel_exception_shielded( + client: Client, worker: ExternalWorker +): + events: List[str] = [] + + @activity.defn + def wait_cancel() -> None: + events.append("pre1") + with activity.shield_thread_cancel_exception(): + events.append("pre2") + with activity.shield_thread_cancel_exception(): + events.append("pre3") + while not activity.is_cancelled(): + time.sleep(1) + activity.heartbeat() + events.append("post3") + events.append("post2") + events.append("post1") with pytest.raises(WorkflowFailureError) as err: with concurrent.futures.ThreadPoolExecutor() as executor: @@ -306,6 +371,8 @@ def wait_cancel() -> str: ) assert isinstance(err.value.cause, ActivityError) assert isinstance(err.value.cause.cause, CancelledError) + # This will have every event except post1 because that's where it throws + assert events == ["pre1", "pre2", "pre3", "post3", "post2"] @activity.defn From c2dc99ed3bfc07cc52b2fed352f2b5a6e546cf1e Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Thu, 1 Dec 2022 08:22:03 -0600 Subject: [PATCH 2/3] Proper platform c_long type and updates to heartbeat docs --- README.md | 18 ++++++++++-------- temporalio/activity.py | 14 +++++++------- temporalio/bridge/src/runtime.rs | 2 +- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index e2ba81db1..c14263bd7 100644 --- a/README.md +++ b/README.md @@ -850,8 +850,8 @@ activities no special worker parameters are needed. Cancellation for asynchronous activities is done via [`asyncio.Task.cancel`](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel). This means that -`asyncio.CancelledError` will be raised (and can be caught, but it is not recommended). An activity must heartbeat to -receive cancellation and there are other ways to be notified about cancellation (see "Activity Context" and +`asyncio.CancelledError` will be raised (and can be caught, but it is not recommended). A non-local activity must +heartbeat to receive cancellation and there are other ways to be notified about cancellation (see "Activity Context" and "Heartbeating and Cancellation" later). ##### Synchronous Activities @@ -860,9 +860,9 @@ Synchronous activities, i.e. functions that do not have `async def`, can be used `activity_executor` worker parameter must be set with a `concurrent.futures.Executor` instance to use for executing the activities. -All long running activities should heartbeat so they can be cancelled. Cancellation in threaded activities throws but -multiprocess/other activities does not. The sections below on each synchronous type explain further. There are also -calls on the context that can check for cancellation. For more information, see "Activity Context" and +All long running, non-local activities should heartbeat so they can be cancelled. Cancellation in threaded activities +throws but multiprocess/other activities does not. The sections below on each synchronous type explain further. There +are also calls on the context that can check for cancellation. For more information, see "Activity Context" and "Heartbeating and Cancellation" sections later. Note, all calls from an activity to functions in the `temporalio.activity` package are powered by @@ -923,15 +923,17 @@ occurs. Synchronous activities cannot call any of the `async` functions. ##### Heartbeating and Cancellation -In order for an activity to be notified of cancellation requests, they must invoke `temporalio.activity.heartbeat()`. -It is strongly recommended that all but the fastest executing activities call this function regularly. "Types of -Activities" has specifics on cancellation for asynchronous and synchronous activities. +In order for a non-local activity to be notified of cancellation requests, it must invoke +`temporalio.activity.heartbeat()`. It is strongly recommended that all but the fastest executing activities call this +function regularly. "Types of Activities" has specifics on cancellation for asynchronous and synchronous activities. In addition to obtaining cancellation information, heartbeats also support detail data that is persisted on the server for retrieval during activity retry. If an activity calls `temporalio.activity.heartbeat(123, 456)` and then fails and is retried, `temporalio.activity.info().heartbeat_details` will return an iterable containing `123` and `456` on the next run. +Heartbeating has no effect on local activities. + ##### Worker Shutdown An activity can react to a worker shutdown. Using `is_worker_shutdown` or one of the `wait_for_worker_shutdown` diff --git a/temporalio/activity.py b/temporalio/activity.py index bc5e2a84a..c4e5941bb 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -236,13 +236,13 @@ def shield_thread_cancel_exception() -> Iterator[None]: """Context manager for synchronous multithreaded activities to delay cancellation exceptions. - By default, heartbeating synchronous multithreaded activities have an - exception thrown inside when cancellation occurs. Code within a "with" block - of this context manager will delay that throwing until the end. Even if the - block returns a value or throws its own exception, if a cancellation - exception is pending, it is thrown instead. Therefore users are encouraged - to not throw out of this block and can surround this with a try/except if - they wish to catch a cancellation. + By default, synchronous multithreaded activities have an exception thrown + inside when cancellation occurs. Code within a "with" block of this context + manager will delay that throwing until the end. Even if the block returns a + value or throws its own exception, if a cancellation exception is pending, + it is thrown instead. Therefore users are encouraged to not throw out of + this block and can surround this with a try/except if they wish to catch a + cancellation. This properly supports nested calls and will only throw after the last one. diff --git a/temporalio/bridge/src/runtime.rs b/temporalio/bridge/src/runtime.rs index ab02b4e19..3e851c759 100644 --- a/temporalio/bridge/src/runtime.rs +++ b/temporalio/bridge/src/runtime.rs @@ -76,7 +76,7 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult { }) } -pub fn raise_in_thread<'a>(_py: Python<'a>, thread_id: i32, exc: &PyAny) -> bool { +pub fn raise_in_thread<'a>(_py: Python<'a>, thread_id: std::os::raw::c_long, exc: &PyAny) -> bool { unsafe { pyo3::ffi::PyThreadState_SetAsyncExc(thread_id, exc.as_ptr()) == 1 } } From 207b53bea8c9200feee3d76826f271c018efc432 Mon Sep 17 00:00:00 2001 From: Chad Retz Date: Thu, 1 Dec 2022 08:30:52 -0600 Subject: [PATCH 3/3] Fix thread ID type --- temporalio/bridge/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/bridge/src/lib.rs b/temporalio/bridge/src/lib.rs index 014eb6b38..4ad007d5b 100644 --- a/temporalio/bridge/src/lib.rs +++ b/temporalio/bridge/src/lib.rs @@ -50,7 +50,7 @@ fn init_runtime(telemetry_config: runtime::TelemetryConfig) -> PyResult(py: Python<'a>, thread_id: i32, exc: &PyAny) -> bool { +fn raise_in_thread<'a>(py: Python<'a>, thread_id: std::os::raw::c_long, exc: &PyAny) -> bool { runtime::raise_in_thread(py, thread_id, exc) }