From 52f8c7d76f8e8a5c3bc9b7a7d93cf0625630a860 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 17 Feb 2025 16:25:05 -0500 Subject: [PATCH 1/3] feat: thread killing --- hatchet_sdk/worker/runner/runner.py | 36 +++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index 6e27edd3..07b3c796 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -2,12 +2,13 @@ import contextvars import ctypes import functools +import time import json import traceback from concurrent.futures import ThreadPoolExecutor from enum import Enum from multiprocessing import Queue -from threading import Thread, current_thread +from threading import Thread, current_thread, Event from typing import Any, Callable, Dict, Literal, Type, TypeVar, cast, overload from pydantic import BaseModel @@ -86,6 +87,9 @@ def __init__( labels=labels, client=new_client_raw(config).dispatcher ) + self.cancellation_events: Dict[str, Event] = {} + + def create_workflow_run_url(self, action: Action) -> str: return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}" @@ -198,15 +202,18 @@ def inner_callback(task: asyncio.Task[Any]) -> None: def thread_action_func( self, context: Context, action_func: Callable[..., Any], action: Action ) -> Any: - if action.step_run_id is not None and action.step_run_id != "": - self.threads[action.step_run_id] = current_thread() - elif ( - action.get_group_key_run_id is not None - and action.get_group_key_run_id != "" - ): - self.threads[action.get_group_key_run_id] = current_thread() + run_id = action.step_run_id or action.get_group_key_run_id + if run_id: + self.threads[run_id] = current_thread() + self.cancellation_events[run_id] = Event() + + while not self.cancellation_events[run_id].is_set(): + result = action_func(context) + if result is not None: + return result + + return None - return action_func(context) ## TODO: Stricter type hinting here # We wrap all actions in an async func @@ -260,6 +267,9 @@ def cleanup_run_id(self, run_id: str | None) -> None: if run_id in self.contexts: del self.contexts[run_id] + if run_id in self.cancellation_events: + del self.cancellation_events[run_id] + def create_context( self, action: Action, action_func: Callable[..., Any] | None ) -> Context | DurableContext: @@ -419,8 +429,16 @@ async def handle_cancel_action(self, run_id: str) -> None: if future: future.cancel() + if run_id in self.cancellation_events: + self.cancellation_events[run_id].set() + # check if thread is still running, if so, print a warning if run_id in self.threads: + thread = self.threads.get(run_id) + if thread: + self.force_kill_thread(thread) + time.sleep(1) + logger.warning( f"Thread {self.threads[run_id].ident} with run id {run_id} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running." ) From 763433ba8624008eb9edde7c349359bac52a19fb Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 17 Feb 2025 16:29:12 -0500 Subject: [PATCH 2/3] fix: rm events --- hatchet_sdk/worker/runner/runner.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index 07b3c796..a5c4b958 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -87,8 +87,6 @@ def __init__( labels=labels, client=new_client_raw(config).dispatcher ) - self.cancellation_events: Dict[str, Event] = {} - def create_workflow_run_url(self, action: Action) -> str: return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}" @@ -202,17 +200,15 @@ def inner_callback(task: asyncio.Task[Any]) -> None: def thread_action_func( self, context: Context, action_func: Callable[..., Any], action: Action ) -> Any: - run_id = action.step_run_id or action.get_group_key_run_id - if run_id: - self.threads[run_id] = current_thread() - self.cancellation_events[run_id] = Event() - - while not self.cancellation_events[run_id].is_set(): - result = action_func(context) - if result is not None: - return result + if action.step_run_id is not None and action.step_run_id != "": + self.threads[action.step_run_id] = current_thread() + elif ( + action.get_group_key_run_id is not None + and action.get_group_key_run_id != "" + ): + self.threads[action.get_group_key_run_id] = current_thread() - return None + return action_func(context) ## TODO: Stricter type hinting here @@ -267,9 +263,6 @@ def cleanup_run_id(self, run_id: str | None) -> None: if run_id in self.contexts: del self.contexts[run_id] - if run_id in self.cancellation_events: - del self.cancellation_events[run_id] - def create_context( self, action: Action, action_func: Callable[..., Any] | None ) -> Context | DurableContext: @@ -429,9 +422,6 @@ async def handle_cancel_action(self, run_id: str) -> None: if future: future.cancel() - if run_id in self.cancellation_events: - self.cancellation_events[run_id].set() - # check if thread is still running, if so, print a warning if run_id in self.threads: thread = self.threads.get(run_id) From 85d8b5604116fae1cff38cef1ae56275f8809224 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 17 Feb 2025 16:30:12 -0500 Subject: [PATCH 3/3] fix: lint --- hatchet_sdk/worker/runner/runner.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index a5c4b958..a9b3b45c 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -2,14 +2,14 @@ import contextvars import ctypes import functools -import time import json +import time import traceback from concurrent.futures import ThreadPoolExecutor from enum import Enum from multiprocessing import Queue -from threading import Thread, current_thread, Event -from typing import Any, Callable, Dict, Literal, Type, TypeVar, cast, overload +from threading import Thread, current_thread +from typing import Any, Callable, Dict, cast from pydantic import BaseModel @@ -87,7 +87,6 @@ def __init__( labels=labels, client=new_client_raw(config).dispatcher ) - def create_workflow_run_url(self, action: Action) -> str: return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}" @@ -210,7 +209,6 @@ def thread_action_func( return action_func(context) - ## TODO: Stricter type hinting here # We wrap all actions in an async func async def async_wrapped_action_func(