diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 3b3c6e5313058c..b8266dc8a8ea6b 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -10,6 +10,7 @@ from celery import shared_task from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from core.db.session_factory import session_factory from core.workflow.entities.workflow_execution import WorkflowExecution @@ -48,29 +49,36 @@ def save_workflow_execution_task( with session_factory.create_session() as session: # Deserialize execution data execution = WorkflowExecution.model_validate(execution_data) - - # Check if workflow run already exists existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_)) - if existing_run: - # Update existing workflow run _update_workflow_run_from_execution(existing_run, execution) - logger.debug("Updated existing workflow run: %s", execution.id_) - else: - # Create new workflow run - workflow_run = _create_workflow_run_from_execution( - execution=execution, - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom(triggered_from), - creator_user_id=creator_user_id, - creator_user_role=CreatorUserRole(creator_user_role), - ) + session.commit() + return True + workflow_run = _create_workflow_run_from_execution( + execution=execution, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom(triggered_from), + creator_user_id=creator_user_id, + creator_user_role=CreatorUserRole(creator_user_role), + ) + try: session.add(workflow_run) - logger.debug("Created new workflow run: %s", execution.id_) - - session.commit() - return True + session.commit() + return True + except IntegrityError: + session.rollback() + existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_)) + if existing_run: + _update_workflow_run_from_execution(existing_run, execution) + session.commit() + return True + # This case is rare. Let Celery's retry mechanism handle it. + logger.warning( + "IntegrityError on insert but record with id %s not found after rollback. Task will be retried.", + execution.id_, + ) + raise except Exception as e: logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown")) @@ -116,12 +124,21 @@ def _create_workflow_run_from_execution( return workflow_run -def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution): - """ - Update a WorkflowRun database model from a WorkflowExecution domain entity. - """ +WORKFLOW_TERMINAL_STATES = {"succeeded", "failed", "stopped", "partial-succeeded"} + + +def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None: json_converter = WorkflowRuntimeTypeConverter() - workflow_run.status = execution.status.value + current_status = workflow_run.status + new_status = execution.status.value + + if current_status in WORKFLOW_TERMINAL_STATES and new_status not in WORKFLOW_TERMINAL_STATES: + # If current status is terminal, do not update to a non-terminal status. + # Only update finished_at if it's not set. + workflow_run.finished_at = workflow_run.finished_at or execution.finished_at + return + + workflow_run.status = new_status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) @@ -129,4 +146,4 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo workflow_run.elapsed_time = execution.elapsed_time workflow_run.total_tokens = execution.total_tokens workflow_run.total_steps = execution.total_steps - workflow_run.finished_at = execution.finished_at + workflow_run.finished_at = workflow_run.finished_at or execution.finished_at diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index b30a4ff15b93cb..6f74a5117e458b 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -10,6 +10,7 @@ from celery import shared_task from sqlalchemy import select +from sqlalchemy.exc import IntegrityError from core.db.session_factory import session_factory from core.workflow.entities.workflow_node_execution import ( @@ -50,31 +51,40 @@ def save_workflow_node_execution_task( with session_factory.create_session() as session: # Deserialize execution data execution = WorkflowNodeExecution.model_validate(execution_data) - - # Check if node execution already exists existing_execution = session.scalar( select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id) ) - if existing_execution: - # Update existing node execution _update_node_execution_from_domain(existing_execution, execution) - logger.debug("Updated existing workflow node execution: %s", execution.id) - else: - # Create new node execution - node_execution = _create_node_execution_from_domain( - execution=execution, - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from), - creator_user_id=creator_user_id, - creator_user_role=CreatorUserRole(creator_user_role), - ) + session.commit() + return True + node_execution = _create_node_execution_from_domain( + execution=execution, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from), + creator_user_id=creator_user_id, + creator_user_role=CreatorUserRole(creator_user_role), + ) + try: session.add(node_execution) - logger.debug("Created new workflow node execution: %s", execution.id) - - session.commit() - return True + session.commit() + return True + except IntegrityError: + session.rollback() + existing_execution = session.scalar( + select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id) + ) + if existing_execution: + _update_node_execution_from_domain(existing_execution, execution) + session.commit() + return True + # This case is rare. Let Celery's retry mechanism handle it. + logger.warning( + "IntegrityError on insert but record with id %s not found after rollback. Task will be retried.", + execution.id, + ) + raise except Exception as e: logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown")) @@ -136,20 +146,35 @@ def _create_node_execution_from_domain( return node_execution +NODE_TERMINAL_STATES = {"succeeded", "failed", "exception"} + + def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution): - """ - Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity. - """ - # Update serialized data json_converter = WorkflowRuntimeTypeConverter() - node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}" + current_status = node_execution.status + new_status = execution.status.value + + if current_status in NODE_TERMINAL_STATES and new_status not in NODE_TERMINAL_STATES: + # If current status is terminal, do not update to a non-terminal status. + # Only update finished_at if it's not set. + node_execution.finished_at = node_execution.finished_at or execution.finished_at + return + + node_execution.status = new_status + node_execution.inputs = ( + json.dumps(json_converter.to_json_encodable(execution.inputs)) + if execution.inputs is not None + else node_execution.inputs + ) node_execution.process_data = ( - json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}" + json.dumps(json_converter.to_json_encodable(execution.process_data)) + if execution.process_data is not None + else node_execution.process_data ) node_execution.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) - # Convert metadata enum keys to strings for JSON serialization + if execution.metadata: metadata_for_json = { key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items() @@ -158,8 +183,6 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode else: node_execution.execution_metadata = "{}" - # Update other fields - node_execution.status = execution.status.value node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time - node_execution.finished_at = execution.finished_at + node_execution.finished_at = node_execution.finished_at or execution.finished_at