Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 42 additions & 25 deletions api/tasks/workflow_execution_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution
Expand Down Expand Up @@ -48,29 +49,36 @@ def save_workflow_execution_task(
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)

# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))

if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug("Updated existing workflow run: %s", execution.id_)
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.commit()
return True
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
try:
session.add(workflow_run)
logger.debug("Created new workflow run: %s", execution.id_)

session.commit()
return True
session.commit()
return True
except IntegrityError:
session.rollback()
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
_update_workflow_run_from_execution(existing_run, execution)
session.commit()
return True
# This case is rare. Let Celery's retry mechanism handle it.
logger.warning(
"IntegrityError on insert but record with id %s not found after rollback. Task will be retried.",
execution.id_,
)
raise

except Exception as e:
logger.exception("Failed to save workflow execution %s", execution_data.get("id_", "unknown"))
Expand Down Expand Up @@ -116,17 +124,26 @@ def _create_workflow_run_from_execution(
return workflow_run


# Status strings after which a workflow run is treated as finished by
# `_update_workflow_run_from_execution`.
WORKFLOW_TERMINAL_STATES = {"succeeded", "failed", "stopped", "partial-succeeded"}


def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
    """
    Update a WorkflowRun database model from a WorkflowExecution domain entity.

    The stored status is read *before* any field is written. If the stored
    status is terminal and the incoming status is not, the row is left as it
    is, except that ``finished_at`` is filled in when it was still empty.
    In every other case all mutable fields are refreshed from ``execution``.
    A ``finished_at`` value that is already set is never overwritten.

    Args:
        workflow_run: The persisted row to mutate in place. The caller is
            responsible for committing the session.
        execution: The domain entity carrying the new state.
    """
    # Capture the persisted status first; assigning the new status before this
    # read would make the terminal-state guard below impossible to trigger.
    current_status = workflow_run.status
    new_status = execution.status.value

    if current_status in WORKFLOW_TERMINAL_STATES and new_status not in WORKFLOW_TERMINAL_STATES:
        # If current status is terminal, do not update to a non-terminal status.
        # Only update finished_at if it's not set.
        workflow_run.finished_at = workflow_run.finished_at or execution.finished_at
        return

    workflow_run.status = new_status
    if execution.outputs:
        # The converter is only needed when there is something to encode.
        json_converter = WorkflowRuntimeTypeConverter()
        workflow_run.outputs = json.dumps(json_converter.to_json_encodable(execution.outputs))
    else:
        workflow_run.outputs = "{}"
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    # Keep an already recorded finish time; only fill it in when missing.
    workflow_run.finished_at = workflow_run.finished_at or execution.finished_at
81 changes: 52 additions & 29 deletions api/tasks/workflow_node_execution_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from core.db.session_factory import session_factory
from core.workflow.entities.workflow_node_execution import (
Expand Down Expand Up @@ -50,31 +51,40 @@ def save_workflow_node_execution_task(
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)

# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)

if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug("Updated existing workflow node execution: %s", execution.id)
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.commit()
return True
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
try:
session.add(node_execution)
logger.debug("Created new workflow node execution: %s", execution.id)

session.commit()
return True
session.commit()
return True
except IntegrityError:
session.rollback()
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
_update_node_execution_from_domain(existing_execution, execution)
session.commit()
return True
# This case is rare. Let Celery's retry mechanism handle it.
logger.warning(
"IntegrityError on insert but record with id %s not found after rollback. Task will be retried.",
execution.id,
)
raise

except Exception as e:
logger.exception("Failed to save workflow node execution %s", execution_data.get("id", "unknown"))
Expand Down Expand Up @@ -136,20 +146,35 @@ def _create_node_execution_from_domain(
return node_execution


# Node-execution status strings that count as terminal. A stored status in
# this set is not replaced by an incoming status outside of it.
NODE_TERMINAL_STATES = {
    "exception",
    "failed",
    "succeeded",
}


def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
current_status = node_execution.status
new_status = execution.status.value

if current_status in NODE_TERMINAL_STATES and new_status not in NODE_TERMINAL_STATES:
# If current status is terminal, do not update to a non-terminal status.
# Only update finished_at if it's not set.
node_execution.finished_at = node_execution.finished_at or execution.finished_at
return

node_execution.status = new_status
node_execution.inputs = (
json.dumps(json_converter.to_json_encodable(execution.inputs))
if execution.inputs is not None
else node_execution.inputs
)
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
json.dumps(json_converter.to_json_encodable(execution.process_data))
if execution.process_data is not None
else node_execution.process_data
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization

if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
Expand All @@ -158,8 +183,6 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode
else:
node_execution.execution_metadata = "{}"

# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at
node_execution.finished_at = node_execution.finished_at or execution.finished_at
Loading