From 17afa80f581b0ed8cd8dc7bbea8a39f54a61893b Mon Sep 17 00:00:00 2001 From: Nelson Chen Date: Wed, 17 Dec 2025 02:10:42 -0800 Subject: [PATCH] Lets "task create" nodes run in parallel. Like https://blog.comfy.org/p/unlimited-parallel-api-nodes-and --- py/wavespeed_task_nodes.py | 59 +++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/py/wavespeed_task_nodes.py b/py/wavespeed_task_nodes.py index df586a2..8463f9b 100644 --- a/py/wavespeed_task_nodes.py +++ b/py/wavespeed_task_nodes.py @@ -422,16 +422,19 @@ def INPUT_TYPES(cls): CATEGORY = "WaveSpeedAI" FUNCTION = "submit_task" + + # Add this to indicate the method is async + OUTPUT_NODE = False - def submit_task(self, client, task_info, wait_for_completion=True, - max_wait_time=300, poll_interval=5): + async def submit_task(self, client, task_info, wait_for_completion=True, + max_wait_time=300, poll_interval=5): """ - Submit task from task_info using dynamic request handling + Submit task from task_info using dynamic request handling (async version) Args: client: WaveSpeed API client task_info: Task information from WaveSpeedTaskCreateDynamic - wait_for_completion: Whether to wait for completion + wait_for_completion: Whether to wait for completion max_wait_time: Maximum wait time poll_interval: Polling interval @@ -443,12 +446,13 @@ def submit_task(self, client, task_info, wait_for_completion=True, raise ValueError("Invalid task_info") model_uuid = task_info.get("modelUUID") - if not model_uuid: + if not model_uuid: raise ValueError("Missing modelUUID in task_info") try: # Import required modules from .wavespeed_api.client import WaveSpeedClient + import asyncio # Initialize the client wavespeed_client = WaveSpeedClient(client["api_key"]) @@ -461,15 +465,20 @@ def submit_task(self, client, task_info, wait_for_completion=True, print(f"Submitting task to model {model_uuid} with parameters: {request_json}") - # Use WaveSpeedClient to send request like in the reference - response = wavespeed_client.send_request( - dynamic_request, - wait_for_completion=wait_for_completion, - polling_interval=poll_interval, - timeout=max_wait_time + # Use asyncio to run the blocking operation in a thread pool + # This allows multiple tasks to run concurrently + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: wavespeed_client.send_request( + dynamic_request, + wait_for_completion=wait_for_completion, + polling_interval=poll_interval, + timeout=max_wait_time + ) ) - if not response: + if not response: raise ValueError("No response from API") # Extract task information @@ -483,7 +492,7 @@ def submit_task(self, client, task_info, wait_for_completion=True, except Exception as e: error_message = str(e) print(f"Error in WaveSpeedTaskSubmit: {error_message}") - raise Exception(f"WaveSpeedTaskSubmit failed: {error_message}") + raise Exception(f"WaveSpeedTaskSubmit failed: {error_message}") class WaveSpeedTaskStatus: @@ -529,9 +538,9 @@ def INPUT_TYPES(cls): CATEGORY = "WaveSpeedAI" FUNCTION = "check_status" - def check_status(self, client, task_id, max_wait_time=300, poll_interval=5, wait_for_completion=True): + async def check_status(self, client, task_id, max_wait_time=300, poll_interval=5, wait_for_completion=True): """ - Check task status and return results + Check task status and return results (async version) Args: client: WaveSpeed API client @@ -540,7 +549,7 @@ def check_status(self, client, task_id, max_wait_time=300, poll_interval=5, wait poll_interval: Polling interval wait_for_completion: Whether to wait for completion - Returns: + Returns: tuple: (task_id, video_url, image, audio_url, text, firstImageUrl, imageUrls) """ @@ -550,20 +559,30 @@ def check_status(self, client, task_id, max_wait_time=300, poll_interval=5, wait try: # Import required modules from .wavespeed_api.client import WaveSpeedClient + import asyncio # Initialize the client wavespeed_client = WaveSpeedClient(client["api_key"]) print(f"Checking status for task {task_id}") + # Use asyncio to run the blocking operation in a thread pool + loop = asyncio.get_event_loop() + if wait_for_completion: # Wait for task completion - response = wavespeed_client.wait_for_task( - task_id, poll_interval, max_wait_time + response = await loop.run_in_executor( + None, + lambda: wavespeed_client.wait_for_task( + task_id, poll_interval, max_wait_time + ) ) else: # Just check current status - response = wavespeed_client.check_task_status(task_id) + response = await loop.run_in_executor( + None, + lambda: wavespeed_client.check_task_status(task_id) + ) if not response: raise ValueError("No response from API") @@ -584,7 +603,7 @@ def check_status(self, client, task_id, max_wait_time=300, poll_interval=5, wait return (task_id, "", None, "", "", "", []) else: # Unknown status, throw error - raise Exception(f"Unknown task status: {status}") + raise Exception(f"Unknown task status: {status}") # Process outputs for different types # Use shared output processor