diff --git a/.kiro/specs/papers-library-importer/design.md b/.kiro/specs/papers-library-importer/design.md index fa90f3016..682ef3ee9 100644 --- a/.kiro/specs/papers-library-importer/design.md +++ b/.kiro/specs/papers-library-importer/design.md @@ -548,7 +548,7 @@ This structure enables: **New Model: ImportHistory** ```python class ImportHistory(DatabaseModel): - __tablename__ = "import_history" + __tablename__ = "imports" workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id", ondelete="CASCADE"), index=True) user_id: Mapped[UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) diff --git a/.kiro/specs/pdf-workflow-orchestrator/design.md b/.kiro/specs/pdf-workflow-orchestrator/design.md new file mode 100644 index 000000000..cd8cbb2e6 --- /dev/null +++ b/.kiro/specs/pdf-workflow-orchestrator/design.md @@ -0,0 +1,599 @@ +# PDF Workflow Orchestrator Design + +## Overview + +The PDF Workflow Orchestrator refactors the existing document upload and processing pipeline to use RQ's native job chaining capabilities. Instead of a single monolithic job, the system splits processing into 6 chained jobs using RQ's built-in features without custom abstractions. + +## Current vs New Architecture + +### Current Flow (Single Job) +``` +POST /documents/bulk → process_bulk_upload() → upload_and_preprocess_documents_job → (File upload + preprocessing + DB creation in one step) +``` + +### New Flow (Chained Jobs) +``` +POST /documents/bulk → process_bulk_upload() → Upload files to S3 + Create DB records → analysis_and_preprocess_job(document_id, s3_url) → conditional_ocr_job (if needed) → text_extraction_job +``` + +### Key Changes from Current Implementation + +1. **File Upload Moved to API**: Files uploaded to S3 in `process_bulk_upload()` before job enqueueing +2. **Job Splitting**: `upload_and_preprocess_documents_job` split into chained jobs with combined analysis+preprocessing +3. 
**S3 URLs Instead of File Data**: Jobs receive document IDs and S3 URLs, not raw file bytes +4. **RQ Dependencies**: Use `depends_on` parameter for job chaining +5. **Job Metadata**: Track workflow progress using `job.meta` +6. **In-Place Processing**: OCRmyPDF overwrites the same S3 object path for page rotation + +## Integration with Existing Code + +### File Operations Integration + +The design uses existing file operations from `contexts/files.py` but requires some helper functions to be added: + + +## CLI Commands Architecture + +### Overview + +The CLI workflow management system integrates with the existing Extralit CLI using Typer's sub-application pattern. It communicates with the server through FastAPI endpoints using the HTTP client, following the same pattern as the existing `import_bib.py` command. + +**Key Architecture Principles:** +- CLI located in `extralit/src/extralit/cli/` (client-side) +- Server endpoints in `extralit-server/src/extralit_server/api/handlers/v1/` (server-side) +- Communication via `client.api.http_client.post/get()` calls +- No direct imports between CLI and server modules + +### Required FastAPI Endpoints + +Before implementing the CLI, we need these server endpoints: + +```python +# extralit-server/src/extralit_server/api/handlers/v1/workflows.py +from fastapi import APIRouter, HTTPException, Query, Security +from typing import Optional, List +from uuid import UUID +from extralit_server.api.schemas.v1.workflows import ( + StartWorkflowRequest, StartWorkflowResponse, + WorkflowStatusResponse, RestartWorkflowRequest +) + +router = APIRouter(tags=["workflows"]) + +@router.post("/workflows/start", response_model=StartWorkflowResponse) +async def start_workflow(request: StartWorkflowRequest) -> StartWorkflowResponse: + """Start PDF processing workflow for a document.""" + # Implementation calls start_document_workflow() function + pass + +@router.get("/workflows/status", response_model=List[WorkflowStatusResponse]) +async def 
get_workflow_status( + document_id: Optional[UUID] = Query(None), + reference: Optional[str] = Query(None), + workspace_name: Optional[str] = Query(None) +) -> List[WorkflowStatusResponse]: + """Get workflow status for documents.""" + # Implementation calls WorkflowContext methods + pass + +@router.post("/workflows/restart", response_model=StartWorkflowResponse) +async def restart_workflow(request: RestartWorkflowRequest) -> StartWorkflowResponse: + """Restart failed workflow jobs using DAG-based resumability.""" + try: + # Get current workflow state + workflow = await DocumentWorkflow.get_by_document_id(db, request.document_id) + if not workflow: + raise HTTPException(404, "Workflow not found") + + if not workflow.is_resumable(): + raise HTTPException(400, "Workflow is not in a resumable state") + + # Get workflow context for resumption + current_context = workflow.get_workflow_context() + + updated_context = resume_workflow( + request.document_id, + current_context + ) + + # Update workflow record + workflow.update_workflow_context(updated_context) + await db.commit() + + return StartWorkflowResponse( + workflow_id=str(workflow.id), + document_id=str(request.document_id), + group_id=workflow.group_id, + status="running", + restarted_jobs=workflow.get_failed_jobs() + ) + + except Exception as e: + raise HTTPException(500, f"Failed to restart workflow: {str(e)}") + pass + +@router.get("/workflows/", response_model=List[WorkflowStatusResponse]) +async def list_workflows( + workspace_name: Optional[str] = Query(None), + status_filter: Optional[str] = Query(None), + limit: int = Query(50) +) -> List[WorkflowStatusResponse]: + """List workflows with optional filtering.""" + # Implementation calls WorkflowContext.list_workflows() + pass +``` + +### CLI Implementation + +```python +# extralit/src/extralit/cli/workflows.py +import typer +from typing import Optional +from uuid import UUID +from rich.console import Console +from rich.table import Table +from rich.progress 
import Progress, SpinnerColumn, TextColumn +from extralit.client import Extralit + +console = Console() +workflow_app = typer.Typer(help="Manage PDF processing workflows") + +@workflow_app.command() +def start( + document_id: str = typer.Option(..., help="Document UUID to process"), + workspace_name: str = typer.Option(..., help="Workspace name"), + reference: str = typer.Option(None, help="Document reference for tracking"), + force: bool = typer.Option(False, help="Force restart if workflow already exists"), + verbose: bool = typer.Option(False, "-v", "--verbose", help="Show detailed output") +): + """Start PDF processing workflow for a document.""" + try: + client = Extralit.from_credentials() + + # Call server endpoint + response = client.api.http_client.post( + f"{client.api_url}/api/v1/workflows/start", + json={ + "document_id": document_id, + "workspace_name": workspace_name, + "reference": reference or f"doc_{document_id[:8]}", + "force": force + } + ) + + if response.status_code != 200: + error_detail = response.json().get("detail", str(response.text)) + raise ValueError(f"Error starting workflow: {error_detail}") + + result = response.json() + console.print(f"[green]✓ Started workflow {result['workflow_id']}[/green]") + + if verbose: + console.print(f"Document ID: {result['document_id']}") + console.print(f"Reference: {result['reference']}") + console.print(f"Group ID: {result['group_id']}") + + console.print(f"Track progress with: [bold]extralit workflow status --document-id {document_id}[/bold]") + + except Exception as e: + console.print(f"[red]Error starting workflow: {e}[/red]") + raise typer.Exit(1) + +@workflow_app.command() +def status( + document_id: str = typer.Option(None, help="Document UUID to check"), + reference: str = typer.Option(None, help="Document reference to check"), + workspace_name: str = typer.Option(None, help="Filter by workspace name"), + watch: bool = typer.Option(False, "-w", "--watch", help="Watch status updates in 
real-time"), + json_output: bool = typer.Option(False, "--json", help="Output status as JSON") +): + """Check workflow status for documents.""" + try: + if not document_id and not reference: + console.print("[red]Must specify either --document-id or --reference[/red]") + raise typer.Exit(1) + + client = Extralit.from_credentials() + + # Call server endpoint + params = {} + if document_id: + params["document_id"] = document_id + if reference: + params["reference"] = reference + if workspace_name: + params["workspace_name"] = workspace_name + + response = client.api.http_client.get( + f"{client.api_url}/api/v1/workflows/status", + params=params + ) + + if response.status_code != 200: + error_detail = response.json().get("detail", str(response.text)) + raise ValueError(f"Error checking status: {error_detail}") + + workflows = response.json() + + if not workflows: + console.print("[yellow]No workflows found[/yellow]") + return + + if json_output: + import json + console.print(json.dumps(workflows, indent=2)) + return + + # Display status table + _display_workflow_status_table(workflows, watch) + + except Exception as e: + console.print(f"[red]Error checking status: {e}[/red]") + raise typer.Exit(1) + +@workflow_app.command() +def restart( + document_id: str = typer.Option(None, help="Document UUID to restart"), + reference: str = typer.Option(None, help="Document reference to restart"), + workspace_name: str = typer.Option(None, help="Filter by workspace name"), + failed_only: bool = typer.Option(True, help="Only restart failed jobs"), + confirm: bool = typer.Option(False, "-y", "--yes", help="Skip confirmation prompt") +): + """Restart failed workflow jobs for documents.""" + try: + if not document_id and not reference: + console.print("[red]Must specify either --document-id or --reference[/red]") + raise typer.Exit(1) + + client = Extralit.from_credentials() + + # First get workflows to restart + params = {} + if document_id: + params["document_id"] = document_id + 
if reference: + params["reference"] = reference + if workspace_name: + params["workspace_name"] = workspace_name + + status_response = client.api.http_client.get( + f"{client.api_url}/api/v1/workflows/status", + params=params + ) + + if status_response.status_code != 200: + raise ValueError("Failed to get workflow status") + + workflows = status_response.json() + failed_workflows = [w for w in workflows if w['status'] == 'failed'] + + if not failed_workflows: + console.print("[yellow]No failed workflows found[/yellow]") + return + + # Confirmation prompt + if not confirm: + workflow_count = len(failed_workflows) + if not typer.confirm(f"Restart {workflow_count} failed workflow(s)?"): + console.print("Cancelled") + return + + # Restart workflows + restarted_count = 0 + for workflow in failed_workflows: + try: + restart_response = client.api.http_client.post( + f"{client.api_url}/api/v1/workflows/restart", + json={ + "document_id": workflow['document_id'], + "failed_only": failed_only + } + ) + + if restart_response.status_code == 200: + console.print(f"[green]✓ Restarted workflow for document {workflow['document_id']}[/green]") + restarted_count += 1 + else: + error_detail = restart_response.json().get("detail", "Unknown error") + console.print(f"[red]✗ Failed to restart workflow for document {workflow['document_id']}: {error_detail}[/red]") + + except Exception as e: + console.print(f"[red]✗ Failed to restart workflow for document {workflow['document_id']}: {e}[/red]") + + console.print(f"[blue]Restarted {restarted_count} of {len(failed_workflows)} workflows[/blue]") + + except Exception as e: + console.print(f"[red]Error restarting workflows: {e}[/red]") + raise typer.Exit(1) + +@workflow_app.command() +def list( + workspace_name: str = typer.Option(None, help="Filter by workspace name"), + status_filter: str = typer.Option(None, help="Filter by status (running, completed, failed)"), + limit: int = typer.Option(50, help="Maximum number of workflows to show"), + 
json_output: bool = typer.Option(False, "--json", help="Output as JSON") +): + """List recent workflows.""" + try: + client = Extralit.from_credentials() + + params = {"limit": limit} + if workspace_name: + params["workspace_name"] = workspace_name + if status_filter: + params["status_filter"] = status_filter + + response = client.api.http_client.get( + f"{client.api_url}/api/v1/workflows/", + params=params + ) + + if response.status_code != 200: + error_detail = response.json().get("detail", str(response.text)) + raise ValueError(f"Error listing workflows: {error_detail}") + + workflows = response.json() + + if not workflows: + console.print("[yellow]No workflows found[/yellow]") + return + + if json_output: + import json + console.print(json.dumps(workflows, indent=2, default=str)) + return + + _display_workflow_status_table(workflows, watch=False) + + except Exception as e: + console.print(f"[red]Error listing workflows: {e}[/red]") + raise typer.Exit(1) + +def _display_workflow_status_table(workflows: list, watch: bool = False): + """Display workflow status in a formatted table.""" + def create_table(): + table = Table(title="PDF Processing Workflows") + table.add_column("Document ID", style="cyan", no_wrap=True) + table.add_column("Reference", style="magenta") + table.add_column("Workspace", style="blue") + table.add_column("Status", style="green") + table.add_column("Progress", style="yellow") + table.add_column("Started", style="dim") + table.add_column("Duration", style="dim") + + for workflow in workflows: + # Calculate progress percentage from RQ Group + total_jobs = workflow.get('total_jobs', 0) + completed_jobs = workflow.get('completed_jobs', 0) + progress = f"{completed_jobs}/{total_jobs} ({int(completed_jobs/total_jobs*100) if total_jobs > 0 else 0}%)" + + # Format status with color + status = workflow['status'] + if status == 'completed': + status = f"[green]{status}[/green]" + elif status == 'failed': + status = f"[red]{status}[/red]" + elif status 
== 'running': + status = f"[yellow]{status}[/yellow]" + + # Calculate duration + import datetime + started = workflow.get('inserted_at') + if started: + if isinstance(started, str): + started = datetime.datetime.fromisoformat(started.replace('Z', '+00:00')) + duration = str(datetime.datetime.utcnow() - started.replace(tzinfo=None)).split('.')[0] + else: + duration = "Unknown" + + table.add_row( + workflow['document_id'][:8] + "...", + workflow.get('reference', 'N/A'), + workflow.get('workspace_name', 'N/A'), + status, + progress, + started.strftime('%Y-%m-%d %H:%M') if started else 'N/A', + duration + ) + + return table + + if watch: + import time + try: + while True: + console.clear() + console.print(create_table()) + console.print("\n[dim]Press Ctrl+C to stop watching[/dim]") + time.sleep(5) + except KeyboardInterrupt: + console.print("\n[yellow]Stopped watching[/yellow]") + else: + console.print(create_table()) + +# Add to main CLI app +# In extralit/src/extralit/cli/__init__.py +# app.add_typer(workflow_app, name="workflow") +``` + +### CLI Usage Examples + +```bash +# Start workflow for a specific document +extralit workflow start --document-id 123e4567-e89b-12d3-a456-426614174000 --workspace-name "research-papers" + +# Check status of a specific document +extralit workflow status --document-id 123e4567-e89b-12d3-a456-426614174000 + +# Check status of all documents in a reference batch +extralit workflow status --reference "batch_2024_01_15" --workspace-name "research-papers" + +# Watch status updates in real-time +extralit workflow status --document-id 123e4567-e89b-12d3-a456-426614174000 --watch + +# List recent workflows +extralit workflow list --workspace-name "research-papers" --status-filter "failed" + +# Restart failed workflows +extralit workflow restart --reference "batch_2024_01_15" --failed-only + +# Get status as JSON for scripting +extralit workflow status --document-id 123e4567-e89b-12d3-a456-426614174000 --json +``` + +## Data Models + +### 
Document Metadata Schema + +The `documents.metadata_` field uses the existing structured schema in `extralit-server/src/extralit_server/api/schemas/v1/document/metadata.py` to store analysis and preprocessing results. This schema includes: + +- **DocumentProcessingMetadata**: Complete workflow metadata with analysis and preprocessing results +- **AnalysisMetadata**: PDF analysis results (OCR quality, layout analysis) +- **PreprocessingMetadata**: Processing results and timing information +- **OCRQualityMetadata**: OCR quality metrics and scores +- **LayoutAnalysisMetadata**: PDF layout analysis results + +### Database Model Using RQ Groups + +The existing `DocumentWorkflow` model in `extralit-server/src/extralit_server/models/database.py` needs to be updated to support RQ Groups: + +**Key Changes Required:** +- Replace `job_ids` dictionary field with `group_id` string field +- Add `status` field for caching workflow status +- Add methods to interact with RQ Groups (`get_workflow_status`, `is_resumable`, `restart_failed_jobs`) +- Update database migration to support the new schema + +**RQ Groups Integration:** +- Each document workflow gets a unique RQ Group ID +- All jobs for a document are added to the same group +- Group status becomes the source of truth for workflow state +- Database model provides efficient querying and caching layer + +### Job Schemas for RQ Groups Integration + +The existing `WorkflowJobResult` schema in `extralit-server/src/extralit_server/api/schemas/v1/jobs.py` needs to be extended to support RQ Groups: + +**Required Updates:** +- Add `group_id` field to track which RQ Group the job belongs to +- Add `group_status` field for overall group status information +- Extend job metadata to include group-level progress information + +**New Schema Fields:** +- `group_id`: RQ Group identifier for the workflow +- `group_progress`: Overall progress of the group (0.0-1.0) +- `group_status`: Status of the entire group (running, completed, failed) +- 
`workflow_step`: Current step in the workflow process + +## RQ-Native Workflow Design + +### Leveraging RQ Groups and Dependencies + +Instead of building custom workflow orchestration, we use RQ's native Groups and job dependencies for resumable workflows. + +### Job Properties for Resumability + +Each job must have these properties to enable resumability: + +1. **Idempotency**: Jobs can be safely re-run without side effects +2. **Artifact Management**: Clear definition of what artifacts are produced/consumed +3. **Context Awareness**: Jobs receive and update workflow context +4. **Dependency Declaration**: Explicit dependencies in the DAG definition +5. **Conditional Logic**: Ability to skip jobs based on workflow state + +### RQ Groups Job Implementation Pattern + +Jobs will be implemented using RQ Groups for workflow coordination: + +**Key Implementation Requirements:** +- Jobs are added to RQ Groups during enqueueing +- Job metadata includes group_id and workflow context +- Jobs use `depends_on` parameter for dependency management +- Group status is queried using RQ Groups API +- Failed jobs can be restarted within the same group + +**Integration Points:** +- `extralit-server/src/extralit_server/jobs/document_jobs.py`: Update existing job functions +- `extralit-server/src/extralit_server/workflows/documents.py`: Update workflow orchestrator +- `extralit-server/src/extralit_server/contexts/workflows.py`: Update job querying functions + +### Integration with Existing Code Structure + +The design leverages existing modules while adding resumability: + +1. **Analysis Job**: Uses `PDFOCRLayerDetector` from `analysis.py` and `PDFAnalyzer` from `margin.py` +2. **Preprocess Job**: Uses `PDFPreprocessor` from `preprocessing.py` with analysis disabled +3. **File Handling**: Uses existing `download_file_from_s3()` and `upload_file_to_s3()` from `files.py` +4. **Schemas**: Extends existing `PDFMetadata` from `preprocessing.py` +5. 
**Workflow State**: Stored in enhanced `DocumentWorkflow` model with artifact tracking + +This approach minimizes code duplication and leverages the existing, well-tested PDF processing logic while adding comprehensive resumability. + +## Implementation Strategy + +### Phase 1: Minimal Viable Workflow +1. **Refactor Existing Job**: Split `upload_and_preprocess_documents_job` into `analysis_job` and `preprocess_job` +2. **Move File Upload**: Upload files to S3 in `process_bulk_upload()` before job enqueueing +3. **Add Job Metadata**: Track workflow progress using `job.meta` +4. **Test Basic Chaining**: Verify jobs can enqueue dependent jobs + +### Phase 2: Complete Workflow +1. **RQ Dependencies**: Use `depends_on` parameter for job chaining +2. **API Extensions**: Add document workflow status endpoint + +### Phase 3: Management and Recovery +1. **CLI Commands**: Add workflow start/status commands using typer +2. **Error Handling**: Implement job restart for failed workflows +3. **Testing**: Add comprehensive tests for workflow execution +4. 
**Performance**: Optimize for multiple concurrent workflows + +### Key Principles +- **Incremental Refactoring**: Modify existing code gradually +- **Simple Recovery**: Use RQ registries and metadata for workflow state + +## Testing Strategy + +### End-to-End Workflow Tests + +**Complete PDF Processing Workflow:** +- Test PDF workflow from upload through analysis, preprocessing, and conditional OCR completion with all jobs succeeding + +**Conditional OCR Logic:** +- Test workflow skips OCR job when analysis determines PDF has good OCR text layer +- Test workflow enqueues OCR job when analysis determines PDF needs OCR processing + +**Workflow State Tracking:** +- Test document metadata is updated correctly at each workflow step completion +- Test workflow status progresses from "queued" to "running" to "completed" appropriately + +### API Integration Tests + +**Bulk Upload Integration:** +- Test POST /documents/bulk creates RQ Groups with proper job dependencies after S3 upload +- Test API returns workflow group_id and initial status for tracking purposes + +**Job Status Querying:** +- Test GET /jobs API filters jobs by document_id, reference, and workflow_step parameters +- Test API returns job metadata including RQ Group information and progress +- Test API shows error details and failure information when jobs fail +- Test API correctly queries RQ Groups for job status instead of individual job fetches + +**Workflow Progress Monitoring:** +- Test API shows current workflow step and overall progress percentage for active workflows +- Test API correctly identifies completed workflows versus failed or stalled ones + +### CLI Workflow Management Tests + +**Workflow Status Commands:** +- Test `workflow status --document-id` command shows all jobs for a specific document +- Test `workflow status --reference` command shows jobs for all documents in a reference batch + +**Failed Job Restart:** +- Test CLI can identify failed jobs using RQ Group status for a given 
document_id +- Test CLI restart command re-enqueues failed jobs within the same RQ Group with proper dependencies +- Test restarted workflow continues from the failed step without re-running completed jobs +- Test RQ Group-based resumability maintains workflow integrity + +**Error Handling:** +- Test CLI commands provide clear error messages for invalid document IDs or missing workflows +- Test CLI gracefully handles Redis connection issues and RQ Groups access problems +- Test CLI handles RQ Group expiration and cleanup scenarios \ No newline at end of file diff --git a/.kiro/specs/pdf-workflow-orchestrator/requirements.md b/.kiro/specs/pdf-workflow-orchestrator/requirements.md new file mode 100644 index 000000000..922458efc --- /dev/null +++ b/.kiro/specs/pdf-workflow-orchestrator/requirements.md @@ -0,0 +1,102 @@ +# PDF Workflow Orchestrator Requirements + +## Introduction + +The PDF Workflow Orchestrator leverages RQ's native job chaining capabilities to process PDFs through a series of steps: analysis, preprocessing, OCR, text extraction, table extraction, and embedding. The system uses RQ's built-in features like job dependencies, job groups, job metadata, and job registries to provide workflow execution and tracking without custom abstractions. + +## Requirements + +### Requirement 1: RQ Native Job Dependencies + +**User Story:** As a developer, I want to use RQ's `depends_on` parameter to chain jobs together, so that jobs execute in the correct order without custom workflow abstractions. + +#### Acceptance Criteria + +1. WHEN enqueueing jobs THEN the system SHALL use RQ's `depends_on` parameter to define job dependencies +2. WHEN a dependency job fails THEN RQ SHALL automatically prevent dependent jobs from running +3. WHEN jobs have multiple dependencies THEN RQ SHALL wait for all dependencies to complete successfully +4. WHEN conditional logic is needed THEN jobs SHALL enqueue their own dependent jobs based on results +5. 
WHEN parallel jobs are needed THEN the system SHALL enqueue multiple jobs without dependencies + +### Requirement 2: RQ Job Metadata for Document Tracking + +**User Story:** As a developer, I want to use RQ's job metadata to track document processing, so that I can query job status by document ID or reference using RQ's built-in capabilities. + +#### Acceptance Criteria + +1. WHEN enqueueing jobs THEN the system SHALL store document_id, reference, workspace_id, and workflow_step in job.meta +2. WHEN querying jobs THEN the system SHALL scan RQ job registries to find jobs by metadata +3. WHEN jobs are running THEN the system SHALL update job.meta with progress information +4. WHEN jobs complete THEN the system SHALL store results in job.meta or job.result +5. WHEN tracking workflows THEN the system SHALL use job metadata to reconstruct workflow state + +### Requirement 3: RQ Job Groups for Document Workflows + +**User Story:** As a developer, I want to use RQ's job groups to track related jobs for a document, so that I can monitor complete document processing workflows. + +#### Acceptance Criteria + +1. WHEN starting document processing THEN the system SHALL create an RQ Group for the document +2. WHEN enqueueing jobs THEN the system SHALL add jobs to the document's group +3. WHEN querying workflow status THEN the system SHALL use group.get_jobs() to retrieve all related jobs +4. WHEN jobs fan-out THEN multiple jobs SHALL be added to the same group +5. WHEN groups expire THEN RQ SHALL automatically clean up completed groups + +### Requirement 4: Enhanced Job Functions with Type Hints + +**User Story:** As a developer, I want to define job functions with clear type hints and use RQ's @job decorator, so that job definitions are maintainable and compatible with RQ's serialization. + +#### Acceptance Criteria + +1. WHEN defining job functions THEN developers SHALL use type hints for all parameters and return values +2. 
WHEN using RQ decorators THEN the system SHALL use RQ's @job decorator with queue, timeout, and retry parameters +3. WHEN jobs need conditional logic THEN they SHALL enqueue dependent jobs within the function +4. WHEN jobs process files THEN they SHALL accept database IDs or S3 URLs instead of raw file data +5. WHEN jobs complete THEN they SHALL return serializable results that can be passed to dependent jobs + +### Requirement 5: Database and S3 File References + +**User Story:** As a system operator, I want jobs to use SQLAlchemy database connections and S3 presigned URLs to access files, so that large files don't clog up the Redis queue. + +#### Acceptance Criteria + +1. WHEN jobs need database access THEN they SHALL use the existing get_async_db dependency injection +2. WHEN jobs need file access THEN they SHALL use presigned S3 URLs or direct S3 client access +3. WHEN passing data between jobs THEN the system SHALL pass document IDs and file references, not raw data +4. WHEN jobs create temporary files THEN they SHALL clean up resources after processing +5. WHEN jobs store results THEN they SHALL use existing database models and S3 storage patterns + +### Requirement 6: Enhanced Job API for Document Workflows + +**User Story:** As a developer, I want to query jobs by document ID, reference, or workflow step through the existing jobs API, so that I can track document processing progress. + +#### Acceptance Criteria + +1. WHEN querying jobs THEN the API SHALL support filtering by document_id, reference, and workflow_step +2. WHEN returning job status THEN the API SHALL include job metadata and group information +3. WHEN jobs fail THEN the API SHALL return error details and failure information +4. WHEN restarting workflows THEN the system SHALL provide CLI commands to re-enqueue failed jobs +5. 
WHEN monitoring progress THEN the API SHALL show the current workflow step and overall progress + +### Requirement 7: Multi-Queue Worker Support + +**User Story:** As a system operator, I want to run workers on different queues, so that I can scale processing based on resource requirements. + +#### Acceptance Criteria + +1. WHEN running CPU workers THEN they SHALL process jobs from default and high priority queues +2. WHEN running workers THEN they SHALL process jobs from dedicated queues +3. WHEN scaling workers THEN the system SHALL support multiple workers per queue type +4. WHEN jobs require specific resources THEN they SHALL be enqueued to appropriate queues +5. WHEN workers are distributed THEN RQ SHALL handle job distribution and coordination automatically + +### Requirement 8: PDF Processing Workflow Implementation + +**User Story:** As a system user, I want to process PDFs through a complete workflow of analysis, preprocessing, OCR, text extraction, table extraction, and embedding, so that I can extract structured data from documents. + +#### Acceptance Criteria + +1. WHEN enqueueing PDF jobs THEN they SHALL be ordered such that documents within a reference are processed in FIFO order +2. WHEN starting PDF processing THEN the system SHALL enqueue combined analysis and preprocessing job +3. WHEN analysis and preprocessing complete THEN the system SHALL conditionally enqueue OCR job if needed +5. 
WHEN analysis and preprocessing complete THEN the system SHALL enqueue table extraction job \ No newline at end of file diff --git a/.kiro/specs/pdf-workflow-orchestrator/tasks.md b/.kiro/specs/pdf-workflow-orchestrator/tasks.md new file mode 100644 index 000000000..885cc3af7 --- /dev/null +++ b/.kiro/specs/pdf-workflow-orchestrator/tasks.md @@ -0,0 +1,205 @@ +# Implementation Plan + +## Important Update: RQ Groups Integration Required + +Based on the design requirements, the current implementation needs to be updated to use **RQ Groups** for workflow tracking instead of the custom DocumentWorkflow.job_ids approach. RQ Groups provide native support for: + +- Grouping related jobs together for a document workflow +- Querying all jobs in a group with `group.get_jobs()` +- Tracking group-level status and progress +- Built-in group expiration and cleanup +- Resumable workflows using group-based job identification + +**Key Changes Required:** +1. Replace `DocumentWorkflow.job_ids` dictionary with `group_id` field +2. Use RQ Groups to create and manage document processing workflows +3. Update all job querying functions to use `Group.get_jobs()` instead of individual job fetches +4. Implement group-based workflow restart and resumability +5. Update API endpoints to work with RQ Groups + +**Note:** If RQ Groups are not available in the current RQ version, we may need to implement a custom Group wrapper or upgrade RQ version. + +## Phase 1: RQ Groups Integration and Job Chaining + +- [x] 1. 
Implement RQ Groups for workflow tracking + - Replace custom DocumentWorkflow.job_ids tracking with RQ Groups + - Update workflow orchestrator to create and manage RQ Groups for document workflows + - Modify job querying functions to use RQ Group.get_jobs() instead of individual job fetches + - Update DocumentWorkflow model to store group_id instead of job_ids dictionary + - _Requirements: 3.1, 3.2, 3.3, 3.4, 3.5_ + +- [x] 1.1 Create combined PDF processing job function + - Create `analysis_and_preprocess_job(document_id, s3_url, reference, workspace_id)` combining PDFOCRLayerDetector, PDFAnalyzer, and PDFPreprocessor + - Analysis runs on original PDF, then OCRmyPDF preprocessing overwrites same S3 path for page rotation + - Add job metadata tracking (document_id, reference, workflow_step, started_at, completed_at) + - Use type hints for all parameters and return values + - Integrate with existing file download/upload functions from contexts/files.py + - Store combined results in documents.metadata_ using DocumentProcessingMetadata schema + - _Requirements: 1.1, 2.1, 4.1, 4.5_ + +- [x] 1.2 Create DocumentWorkflow database model + - Add DocumentWorkflow model to models/database.py for efficient job tracking + - Create database migration for document_workflows table + - Add relationship to Document model + - Include methods for job status updates and workflow queries + - _Requirements: 2.2, 2.5, 6.1_ + +- [x] 1.3 Update DocumentWorkflow model for RQ Groups + - Remove job_ids field and add group_id field to DocumentWorkflow model + - Add status field to track overall workflow status + - Create database migration to update existing workflows table + - Add methods to interact with RQ Groups (get_workflow_status, is_resumable, restart_failed_jobs) + - Update relationships and queries to work with RQ Groups + - _Requirements: 3.1, 3.2, 3.3_ + +- [x] 1.4 Refactor workflow orchestrator to use RQ Groups + - Update create_document_workflow() to create RQ Group for each 
document workflow + - Add all workflow jobs to the same RQ Group using group parameter + - Use RQ's depends_on parameter for job dependencies within the group + - Store group_id in DocumentWorkflow record instead of individual job_ids + - Handle conditional OCR logic in the orchestrator (not in individual jobs) using group-aware job enqueueing + - Update workflow to use single analysis_and_preprocess_job instead of separate jobs + - _Requirements: 1.1, 1.3, 1.4, 3.1, 8.1_ + + +- [x] 1.6 Update process_bulk_upload function for RQ Groups + - Move file upload to S3 into process_bulk_upload (before job enqueueing) + - Create document records in database before enqueueing jobs + - Replace upload_and_preprocess_documents_job with RQ Groups-based workflow + - Update DocumentsBulkResponse to return workflow_id and group_id + - Maintain backward compatibility with existing API contracts + - _Requirements: 5.1, 5.2, 3.1_ + +## Phase 2: Job Querying and API Enhancement + +- [x] 2.
Create Pydantic schemas for job input/output + - Create api/schemas/v1/document/metadata.py with DocumentProcessingMetadata schema for documents.metadata_ field + - Add WorkflowJobResult schema to api/schemas/v1/jobs.py + - Ensure all schemas have proper type hints and validation + - _Requirements: 4.1, 4.2_ + +- [x] 2.1 Implement RQ Groups-based job querying + - Update `get_jobs_for_document(db, document_id)` to use RQ Group.get_jobs() via group_id + - Update `get_jobs_by_reference(db, reference)` to query multiple groups + - Update `get_workflow_status(db, document_id)` to use RQ Group status methods + - Replace individual job fetches with group-based operations + - Handle group expiration and missing groups gracefully + - _Requirements: 2.2, 2.5, 3.2, 3.3_ + +- [x] 2.2 Update jobs API endpoint for RQ Groups + - Update GET /jobs/ to use RQ Groups-based job querying functions + - Add group_id parameter for direct group querying + - Modify WorkflowJobResult schema to include group information + - Return group metadata in API responses including group status and progress + - _Requirements: 6.1, 6.2, 3.2_ + +- [ ] 2.3 Update document workflow status endpoint for RQ Groups + - Update GET /workflows/document/{document_id} to use RQ Group status + - Calculate workflow progress using RQ Group.get_jobs() and job statuses + - Return overall workflow status derived from RQ Group state + - _Requirements: 6.5, 8.1, 3.2_ + +- [ ] 2.4 Add RQ Groups-based workflow status monitoring + - Implement workflow status updates using RQ Group callbacks + - Add group status change monitoring to update DocumentWorkflow + - Create workflow progress calculation based on RQ Group job states + - Add workflow cleanup for expired/completed groups + - _Requirements: 2.1, 2.4, 6.5, 3.3_ + +## Phase 3: CLI + +- [x] 3. 
Add CLI workflow management commands +- [x] 3.1 Create FastAPI workflow endpoints + - Create `extralit-server/src/extralit_server/api/handlers/v1/workflows.py` with workflow router + - Add Pydantic schemas in `extralit-server/src/extralit_server/api/schemas/v1/workflows.py` + - Implement `POST /workflows/start` endpoint for starting workflows + - Implement `GET /workflows/status` endpoint for querying workflow status + - Implement `POST /workflows/restart` endpoint for restarting failed workflows + - Implement `GET /workflows/` endpoint for listing workflows with filters + - _Requirements: 6.4_ + +- [x] 3.2 Extend WorkflowContext for RQ Groups API operations + - Update `get_workflow_status()` method to use RQ Group status and job information + - Update `get_workflows_by_reference()` method to work with group-based tracking + - Update `list_workflows()` method to include RQ Group information + - Implement efficient database queries using group_id indexing + - Add error handling for missing groups and RQ connection issues + - _Requirements: 6.4, 3.2, 3.3_ + +- [x] 3.3 Implement RQ Groups-based workflow restart functionality + - Create `restart_failed_workflow()` function using RQ Group failed job identification + - Add logic to identify failed jobs using RQ Group.get_jobs() with status filtering + - Implement job re-enqueueing within the same RQ Group with proper dependencies + - Update DocumentWorkflow records with new group state information + - Add support for partial vs full workflow restart using RQ Group capabilities + - _Requirements: 6.4, 3.4, 3.5_ + +- [x] 3.4 Create CLI module structure and integration + - Create `extralit/src/extralit/cli/workflows.py` with typer app + - Add workflow_app to main CLI using `app.add_typer(workflow_app, name="workflow")` + - Import Rich library components for formatted output (Console, Table, Progress) + - Set up HTTP client communication pattern following `import_bib.py` example + - Set up error handling patterns with 
typer.Exit and console.print + - _Requirements: 6.4_ + +- [x] 3.5 Implement CLI workflow start command + - Create `workflow start` command with document_id, workspace_name, reference, force, and verbose options + - Use `client.api.http_client.post()` to call `/workflows/start` endpoint + - Add validation and error handling for HTTP responses + - Add confirmation prompts and detailed output formatting + - Handle errors gracefully with user-friendly messages + - _Requirements: 6.4_ + +- [x] 3.6 Implement CLI workflow status command + - Create `workflow status` command with document_id, reference, workspace_name, watch, and json_output options + - Use `client.api.http_client.get()` to call `/workflows/status` endpoint + - Implement `_display_workflow_status_table()` helper function using Rich Table + - Add real-time status watching with `--watch` flag and periodic updates + - Support JSON output format for scripting and automation + - Calculate and display progress percentages and duration information + - _Requirements: 6.4_ + +- [x] 3.7 Implement CLI workflow restart command + - Create `workflow restart` command with document_id, reference, failed_only, and confirm options + - Use `client.api.http_client.post()` to call `/workflows/restart` endpoint + - Add confirmation prompts before restarting workflows + - Implement selective restart logic (failed jobs only vs full workflow) + - Display progress and results of restart operations + - _Requirements: 6.4_ + +- [x] 3.8 Implement CLI workflow list command + - Create `workflow list` command with workspace_name, status_filter, limit, and json_output options + - Use `client.api.http_client.get()` to call `/workflows/` endpoint + - Add filtering capabilities by workspace and status + - Implement pagination with configurable limits + - Support both table and JSON output formats + - Display comprehensive workflow information in formatted table + - _Requirements: 6.4_ + +## Phase 4: Tests and workflow handling +- [x] 4. 
Tests and workflow handling +- [x] 4.1 Add comprehensive RQ Groups testing + - Unit tests for RQ Groups integration functions (See extralit-server/tests/unit/jobs/test_jobs.py) + - Integration tests for complete workflow using RQ Groups + - Test group-based job querying and status functions + - Test CLI commands with RQ Groups + - Test group failure and restart scenarios + - _Requirements: All requirements validation, 3.1, 3.2, 3.3, 3.4, 3.5_ + +- [x] 4.2 Performance optimization for RQ Groups + - Test with multiple concurrent workflows using RQ Groups + - Optimize group-based job querying performance + - Add monitoring for group and queue performance + - Test worker scaling with group-aware job distribution + - Benchmark RQ Groups vs individual job tracking performance + - _Requirements: 7.2, 7.3, 7.5, 3.2_ + +- [x] 4.3 RQ Groups documentation and examples + - Document RQ Groups integration patterns + - Create examples of group-based workflow management + - Document group-based job restart procedures + - Add troubleshooting guide for RQ Groups issues + - Document performance characteristics and limitations + - _Requirements: 3.1, 3.2, 3.3, 3.4, 3.5_ \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..79c89280c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,84 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
+ +## Architecture Overview + +Extralit is a multi-component system for scientific literature data extraction with human-in-the-loop workflows: + +- **extralit-server/**: FastAPI backend server with PostgreSQL database, handles users, datasets, records, and API interactions +- **extralit-frontend/**: Vue.js/Nuxt.js web UI for data visualization, annotation, and team collaboration +- **extralit/**: Python SDK client library for programmatic interaction with the server +- **Vector Database**: External Elasticsearch/OpenSearch for scalable vector similarity searches + +## Development Commands + +### Server (extralit-server/) +```bash +cd extralit-server/ +pdm run server-dev # Start server with auto-reload + worker +pdm run server # Start server only +pdm run worker # Start background worker only +pdm run migrate # Run database migrations +pdm run test # Run tests +pdm run test-cov # Run tests with coverage +pdm run lint # Run ruff linting +``` + +### Frontend (extralit-frontend/) +```bash +cd extralit-frontend/ +npm run dev # Development server +npm run build # Production build +npm run test # Run Jest tests +npm run test:watch # Run tests in watch mode +npm run e2e # Run Playwright e2e tests +npm run lint # ESLint check +npm run lint:fix # Fix ESLint issues +npm run format # Format with Prettier +``` + +### Client SDK (extralit/) +```bash +cd extralit/ +pdm run test # Run tests +pdm run test-cov # Run tests with coverage +pdm run lint # Run ruff linting +pdm run format # Format with black +pdm run all # Format, lint, and test +``` + +## Key Development Notes + +### Frontend Architecture +- Transitioning from Vuex to Pinia (v1/ directory contains new architecture) +- Uses domain-driven design with entities, use cases, and dependency injection +- Component structure: base (stateless) → features (page-specific) → global (reusable) + +### Backend Structure +- FastAPI with SQLAlchemy ORM and Alembic migrations +- Background job processing with Redis Queue (rq) +- OAuth2 
authentication with JWT tokens +- Webhook system for external integrations +- Document processing with OCR capabilities + +### Database Management +- Alembic handles all database schema changes +- Use `pdm run revision` to create new migrations after model changes +- Always run `pdm run migrate` before starting development + +### Testing +- Backend: pytest with async support, factory-boy for fixtures +- Frontend: Jest for unit tests, Playwright for e2e +- Python packages require Python 3.9+ (extralit) or 3.10+ (extralit-server) +- Node.js 18+ required for frontend + +### Container Environment +- Docker Compose setup available for full stack development +- Services: Elasticsearch, Redis, MinIO for file storage +- See `.github/workflows/copilot-setup-steps.yml` for complete environment setup + +### Linting Configuration +- Python: Ruff with shared configuration across packages +- Frontend: ESLint + Prettier with TypeScript support +- Pre-commit hooks for code formatting and linting \ No newline at end of file diff --git a/extralit-server/CLAUDE.md b/extralit-server/CLAUDE.md new file mode 100644 index 000000000..f0e3dd8ad --- /dev/null +++ b/extralit-server/CLAUDE.md @@ -0,0 +1,128 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Architecture Overview + +Extralit-server is the FastAPI backend component of the Extralit ecosystem for scientific literature data extraction with human-in-the-loop workflows. 
This server handles: + +- **User Management**: Authentication, authorization, workspaces, and role-based access +- **Dataset Management**: Scientific literature datasets with metadata, fields, and questions +- **Record Processing**: Document records with responses, suggestions, and annotations +- **Background Jobs**: Asynchronous document processing, OCR, and ML inference via Redis Queue (rq) +- **Vector Search**: Integration with Elasticsearch/OpenSearch for semantic similarity searches +- **Webhook System**: Event notifications and external integrations +- **File Storage**: Document management with MinIO S3-compatible storage + +## Development Commands + +```bash +# Core development workflow +pdm run server-dev # Start server with auto-reload + worker + migrations +pdm run server # Start server only (production mode) +pdm run worker # Start background worker only +pdm run migrate # Run database migrations +pdm run revision # Create new Alembic migration after model changes + +# Testing and quality +pdm run test # Run pytest test suite +pdm run test-cov # Run tests with coverage report +pdm run lint # Run ruff linting (required before commits) + +# Database management +pdm run cli database migrate # Run migrations +pdm run cli database users create_default # Create default admin user +pdm run cli database users create # Interactive user creation +pdm run cli database revisions # Generate migration files + +# Background job management +pdm run cli worker # Start RQ worker for background jobs + +# Search engine management +pdm run cli search_engine reindex # Reindex all datasets in search engine +``` + +## Key Architecture Patterns + +### FastAPI Application Structure +- **Main App**: `src/extralit_server/_app.py` - Application factory with middleware, CORS, and lifespan management +- **API Routes**: `src/extralit_server/api/routes.py` - Centralized router configuration for v1 API +- **Route Handlers**: `src/extralit_server/api/handlers/v1/` - Request handlers 
organized by domain +- **Schemas**: `src/extralit_server/api/schemas/v1/` - Pydantic request/response models +- **Policies**: `src/extralit_server/api/policies/v1/` - Authorization and access control logic + +### Database Layer (SQLAlchemy + Alembic) +- **Models**: `src/extralit_server/models/database.py` - Core domain models (User, Workspace, Dataset, Record, etc.) +- **Base Model**: `src/extralit_server/models/base.py` - Abstract base with common CRUD operations +- **Migrations**: `src/extralit_server/alembic/versions/` - Database schema evolution +- **Connection**: `src/extralit_server/database.py` - Async database session management + +### Background Job Processing (RQ) +- **Queue Setup**: `src/extralit_server/jobs/queues.py` - Redis connection and queue configuration +- **Job Modules**: `src/extralit_server/jobs/` - Background tasks for documents, imports, OCR, webhooks +- **Worker**: Started via `pdm run worker` for processing async tasks + +### Search Engine Integration +- **Abstraction**: `src/extralit_server/search_engine/base.py` - Common interface for search engines +- **Implementations**: + - `src/extralit_server/search_engine/elasticsearch.py` + - `src/extralit_server/search_engine/opensearch.py` +- **Configuration**: Set via `EXTRALIT_SEARCH_ENGINE` environment variable + +### Context Layer (Business Logic) +- **Contexts**: `src/extralit_server/contexts/` - Domain-specific business logic separate from API handlers +- **Examples**: `accounts.py`, `datasets.py`, `records.py`, `imports.py` +- **Pattern**: Contexts handle complex operations, validation, and cross-domain logic + +### Authentication & Security +- **OAuth2**: `src/extralit_server/security/authentication/oauth2/` - OAuth2 provider integrations +- **JWT**: `src/extralit_server/security/authentication/jwt.py` - Token-based authentication +- **API Keys**: `src/extralit_server/security/authentication/db/api_key_backend.py` - Alternative auth method + +## Environment Configuration + +Key 
environment variables (prefixed with `EXTRALIT_`): +```bash +EXTRALIT_DATABASE_URL=sqlite+aiosqlite:///extralit.db # Database connection +EXTRALIT_REDIS_URL=redis://localhost:6379/0 # Redis for background jobs +EXTRALIT_ELASTICSEARCH=http://localhost:9200 # Search engine endpoint +EXTRALIT_SEARCH_ENGINE=elasticsearch # elasticsearch|opensearch +EXTRALIT_S3_ENDPOINT=http://localhost:9000 # MinIO/S3 storage +EXTRALIT_CORS_ORIGINS=["*"] # CORS configuration +``` + +## Testing Strategy + +- **Unit Tests**: `tests/unit/` - Component-level testing with mocking +- **Integration Tests**: Database and external service integration testing +- **Factories**: `tests/factories.py` - Test data generation with factory-boy +- **Async Testing**: pytest-asyncio for async database and API operations +- **Test Database**: Isolated test database created per test session + +## Important Development Notes + +### Database Migrations +- Always run `pdm run revision` after model changes to generate migrations +- Review generated migrations before applying with `pdm run migrate` +- Never edit existing migration files; create new ones for changes + +### Background Job Development +- Jobs defined in `src/extralit_server/jobs/` are executed by separate worker processes +- Use `HIGH_QUEUE` for time-sensitive jobs, `DEFAULT_QUEUE` for regular processing +- Jobs should be idempotent and handle failures gracefully + +### Search Engine Operations +- Dataset records are automatically indexed/updated in the search engine +- Use `pdm run cli search_engine reindex` after significant data changes +- Search operations are asynchronous and may have eventual consistency + +### Security Considerations +- All API endpoints require authentication (JWT tokens or API keys) +- Workspace-based authorization controls data access +- Sensitive operations require specific user roles (admin, owner) +- Environment variables should never contain secrets in production + +### File Processing +- Document uploads trigger 
background OCR and preprocessing jobs +- Large files are processed asynchronously to avoid blocking API requests +- File storage integrates with MinIO for scalable object storage \ No newline at end of file diff --git a/extralit-server/pdm.lock b/extralit-server/pdm.lock index 1b85fe107..9e4774204 100644 --- a/extralit-server/pdm.lock +++ b/extralit-server/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "postgresql", "test"] strategy = [] lock_version = "4.5.0" -content_hash = "sha256:27aee5423445c4af9206f4c9822e864e4aec0de735ff9b7ae80c7690258f267e" +content_hash = "sha256:e23b9777a04acef9716156bc7a9fd7e8d35368569c52ad594e321e60cd05977b" [[metadata.targets]] requires_python = ">=3.10" @@ -2382,15 +2382,15 @@ files = [ [[package]] name = "rq" -version = "1.16.2" +version = "2.4.1" summary = "" dependencies = [ "click", "redis", ] files = [ - {file = "rq-1.16.2-py3-none-any.whl", hash = "sha256:52e619f6cb469b00e04da74305045d244b75fecb2ecaa4f26422add57d3c5f09"}, - {file = "rq-1.16.2.tar.gz", hash = "sha256:5c5b9ad5fbaf792b8fada25cc7627f4d206a9a4455aced371d4f501cc3f13b34"}, + {file = "rq-2.4.1-py3-none-any.whl", hash = "sha256:a3a0839ba3213a9be013b398670caf71d9360a0c8525f343687cf2c2199e5ec8"}, + {file = "rq-2.4.1.tar.gz", hash = "sha256:40ba01af3edacc008ab376009a3a547278d2bfe02a77cd4434adc0b01788239f"}, ] [[package]] diff --git a/extralit-server/pyproject.toml b/extralit-server/pyproject.toml index f2f45e17a..c825b448d 100644 --- a/extralit-server/pyproject.toml +++ b/extralit-server/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ "oauthlib ~= 3.2.0", "social-auth-core ~= 4.5.0", # Background processing - "rq ~= 1.16.2", + "rq~=2.4.1", "lazy-loader>=0.4", # Info status "psutil ~= 5.8, <5.10", @@ -139,7 +139,6 @@ exclude_lines = [ ] [tool.ruff] -# Exclude a variety of commonly ignored directories. exclude = [ ".bzr", ".direnv", @@ -165,7 +164,6 @@ line-length = 120 target-version = "py310" [tool.ruff.lint] -# Enforce only high-priority correctness rules initially. 
select = [ "F", # Pyflakes (undefined names, logical issues) "E7", # Syntax/indentation errors @@ -182,7 +180,6 @@ select = [ "RUF", # ruff-specific rules ] -# Temporarily ignore modernization and lower-priority/style rules. ignore = [ "E402", # imports not at top (lazy / optional import patterns) "B904", # exception chaining (will phase in later) @@ -196,7 +193,6 @@ ignore = [ "FAST002" # FastAPI Depends suggestion ] -# Per-file ignores [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "F403", "I001"] # Ignore unused imports and wildcard imports in __init__.py "tests/**/*.py" = ["ASYNC", "F821"] # More lenient async rules in tests (ignore undefined names like forward refs/fixtures) diff --git a/extralit-server/src/extralit_server/alembic/versions/54d65879a68e_create_document_workflows_table.py b/extralit-server/src/extralit_server/alembic/versions/54d65879a68e_create_document_workflows_table.py new file mode 100644 index 000000000..90f2f678e --- /dev/null +++ b/extralit-server/src/extralit_server/alembic/versions/54d65879a68e_create_document_workflows_table.py @@ -0,0 +1,62 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""create_document_workflows_table + +Revision ID: 54d65879a68e +Revises: 7d6b33203390 +Create Date: 2025-08-17 23:12:00.379621 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
revision = "54d65879a68e"
down_revision = "7d6b33203390"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the ``workflows`` table and its lookup indexes.

    Each row links a document to the RQ group (``group_id``) that holds its
    processing jobs. Non-unique indexes are created on ``document_id``,
    ``reference`` and ``group_id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "workflows",
        sa.Column("workflow_type", sa.String(length=50), nullable=False),
        # ``server_default`` (not ``default``): a Python-side ``default`` is not
        # rendered into the CREATE TABLE statement, so the NOT NULL column would
        # have no database-level default for rows inserted outside the ORM.
        sa.Column("status", sa.String(length=50), server_default="pending", nullable=False),
        sa.Column("workspace_id", sa.Uuid(), nullable=False),
        sa.Column("document_id", sa.Uuid(), nullable=False),
        sa.Column("reference", sa.String(length=255), nullable=True),
        sa.Column("group_id", sa.String(length=255), nullable=False),
        sa.Column("id", sa.Uuid(), nullable=False),
        sa.Column("inserted_at", sa.DateTime(), nullable=False),
        sa.Column("updated_at", sa.DateTime(), nullable=False),
        sa.ForeignKeyConstraint(["document_id"], ["documents.id"], ondelete="CASCADE"),
        # NOTE(review): unlike document_id, this FK declares no ON DELETE action -
        # confirm that blocking workspace deletion while workflows exist is intended.
        sa.ForeignKeyConstraint(["workspace_id"], ["workspaces.id"]),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_workflows_document_id"), "workflows", ["document_id"], unique=False)
    op.create_index(op.f("ix_workflows_reference"), "workflows", ["reference"], unique=False)
    op.create_index(op.f("ix_workflows_group_id"), "workflows", ["group_id"], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the ``workflows`` indexes and table, reversing :func:`upgrade`."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Indexes are dropped in the reverse order of their creation, then the table.
    op.drop_index(op.f("ix_workflows_group_id"), table_name="workflows")
    op.drop_index(op.f("ix_workflows_reference"), table_name="workflows")
    op.drop_index(op.f("ix_workflows_document_id"), table_name="workflows")
    op.drop_table("workflows")
    # ### end Alembic commands ###
op.drop_index(op.f("ix_import_history_workspace_id"), table_name="imports") + op.drop_index(op.f("ix_import_history_user_id"), table_name="imports") + op.drop_table("imports") op.drop_column("documents", "metadata") diff --git a/extralit-server/src/extralit_server/api/handlers/v1/documents.py b/extralit-server/src/extralit_server/api/handlers/v1/documents.py index 1da34e28e..8ff59f322 100644 --- a/extralit-server/src/extralit_server/api/handlers/v1/documents.py +++ b/extralit-server/src/extralit_server/api/handlers/v1/documents.py @@ -267,12 +267,14 @@ async def create_documents_bulk( try: metadata_dict = json.loads(documents_metadata) bulk_create = DocumentsBulkCreate.model_validate(metadata_dict) - except json.JSONDecodeError: + except json.JSONDecodeError as e: + print(e) raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid JSON in documents_metadata", ) except Exception as e: + print(e) raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=f"Invalid metadata format: {e!s}", diff --git a/extralit-server/src/extralit_server/api/handlers/v1/jobs.py b/extralit-server/src/extralit_server/api/handlers/v1/jobs.py index 0f7f11eca..de53c0065 100644 --- a/extralit-server/src/extralit_server/api/handlers/v1/jobs.py +++ b/extralit-server/src/extralit_server/api/handlers/v1/jobs.py @@ -12,18 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
@router.get("/jobs/", response_model=list[WorkflowJobResult])
async def get_jobs(
    *,
    db: Annotated[AsyncSession, Depends(get_async_db)],
    current_user: Annotated[User, Security(auth.get_current_user)],
    document_id: Optional[UUID] = Query(None, description="Filter by document ID"),
    reference: Optional[str] = Query(None, description="Filter by document reference"),
    group_id: Optional[str] = Query(None, description="Filter by RQ Group ID"),
):
    """Return workflow jobs selected by exactly one of the supported filters.

    Filters are evaluated in the order ``document_id``, ``reference``,
    ``group_id``; the first one supplied wins and the others are ignored.

    Raises:
        HTTPException: 400 when no filter is supplied, 404 when ``group_id``
            matches no workflow, 500 for any other failure while querying.
    """
    await authorize(current_user, JobPolicy.get)

    try:
        if document_id:
            jobs_raw = await get_jobs_for_document(db, document_id)
            return [_convert_job_data_to_result(job_data) for job_data in jobs_raw]

        if reference:
            jobs_raw = await get_jobs_by_reference(db, reference)
            return [_convert_job_data_to_result(job_data) for job_data in jobs_raw]

        if group_id:
            workflow = await DocumentWorkflow.get_by_group_id(db, group_id)
            if not workflow:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Workflow with group_id `{group_id}` not found",
                )

            from extralit_server.contexts.workflows import get_workflow_status_from_group

            group_status = get_workflow_status_from_group(group_id)

            results = []
            for job_data in group_status.get("jobs", []):
                job_result = _convert_job_data_to_result(job_data)
                # Group-level and workflow-level fields are the same for every
                # job in the group, so they are copied onto each result.
                job_result.group_id = workflow.group_id
                job_result.group_status = group_status.get("status")
                job_result.group_progress = group_status.get("progress")
                job_result.total_jobs = group_status.get("total_jobs")
                job_result.completed_jobs = group_status.get("completed_jobs")
                job_result.failed_jobs = group_status.get("failed_jobs")
                job_result.running_jobs = group_status.get("running_jobs")
                job_result.document_id = workflow.document_id
                job_result.reference = workflow.reference
                job_result.workspace_id = workflow.workspace_id
                results.append(job_result)

            return results

        # No filter supplied: reject the request instead of listing every job.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Must provide at least one filter: document_id, reference, or group_id",
        )

    except HTTPException:
        # Deliberate 4xx responses raised above must reach the client unchanged.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error retrieving jobs: {e!s}",
        ) from e


def _convert_job_data_to_result(job_data: dict) -> WorkflowJobResult:
    """Build a ``WorkflowJobResult`` from a raw job dictionary.

    Missing keys fall back to ``None``, except ``id`` (empty string) and
    ``status`` (``"unknown"``). ``progress`` is read from ``meta`` when ``meta``
    is present and non-empty; ``error`` prefers ``exc_info`` over ``error``.
    """
    meta = job_data.get("meta")
    return WorkflowJobResult(
        id=job_data.get("id", ""),
        status=job_data.get("status", "unknown"),
        document_id=job_data.get("document_id"),
        reference=job_data.get("reference"),
        workspace_id=job_data.get("workspace_id"),
        workflow_step=job_data.get("workflow_step"),
        progress=meta.get("progress") if meta else None,
        error=job_data.get("exc_info") or job_data.get("error"),
        result=job_data.get("result"),
        meta=meta,
        started_at=job_data.get("started_at"),
        # The raw job data names the finish time "ended_at".
        completed_at=job_data.get("ended_at"),
        # RQ Groups metadata (will be populated for group queries)
        group_id=job_data.get("group_id"),
    )
+ +import logging +from typing import Annotated, Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, Security, status +from sqlalchemy.ext.asyncio import AsyncSession + +from extralit_server.api.policies.v1 import JobPolicy, authorize +from extralit_server.api.schemas.v1.workflows import ( + RestartWorkflowRequest, + StartWorkflowRequest, + StartWorkflowResponse, + WorkflowStatusResponse, +) +from extralit_server.contexts.workflows import ( + get_workflow_status, + get_workflow_statuses_by_reference, + restart_failed_jobs_in_workflow, +) +from extralit_server.database import get_async_db +from extralit_server.models import User +from extralit_server.models.database import Document, DocumentWorkflow, Workspace +from extralit_server.security import auth +from extralit_server.workflows.documents import create_document_workflow + +_LOGGER = logging.getLogger(__name__) + +router = APIRouter(tags=["workflows"]) + + +@router.post( + "/workflows/start", +) +async def start_workflow( + *, + db: Annotated[AsyncSession, Depends(get_async_db)], + current_user: Annotated[User, Security(auth.get_current_user)], + request: StartWorkflowRequest, +) -> StartWorkflowResponse: + """Start PDF processing workflow for a document.""" + await authorize(current_user, JobPolicy.get) + + try: + # Get document and validate it exists + document = await db.get(Document, request.document_id) + if not document: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Document with id `{request.document_id}` not found", + ) + + # Get workspace by name + from sqlalchemy import select + + workspace_result = await db.execute(select(Workspace).where(Workspace.name == request.workspace_name)) + workspace = workspace_result.scalar_one_or_none() + if not workspace: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workspace `{request.workspace_name}` not found", + ) + + # Check if workflow already exists + 
existing_workflow = await DocumentWorkflow.get_by_document_id(db, request.document_id) + if existing_workflow and not request.force: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Workflow already exists for document {request.document_id}. Use force=true to restart.", + ) + + # Generate reference if not provided + reference = request.reference or f"doc_{str(request.document_id)[:8]}" + + # Get document S3 URL (assuming it's stored in document metadata or similar) + # This is a placeholder - you'll need to implement the actual S3 URL retrieval + s3_url = getattr(document, "s3_url", None) or f"s3://documents/{document.id}" + + # Start the workflow + await create_document_workflow( + document_id=request.document_id, + s3_url=s3_url, + reference=reference, + workspace_name=request.workspace_name, + workspace_id=workspace.id, + ) + + # Get the created workflow + workflow = await DocumentWorkflow.get_by_document_id(db, request.document_id) + if not workflow: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create workflow record", + ) + + return StartWorkflowResponse( + workflow_id=str(workflow.id), + document_id=str(request.document_id), + group_id=workflow.group_id, + status=workflow.status, + reference=reference, + ) + + except HTTPException: + raise + except Exception as e: + _LOGGER.error(f"Error starting workflow for document {request.document_id}: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to start workflow: {e!s}", + ) + + +@router.get("/workflows/status", response_model=list[WorkflowStatusResponse]) +async def get_workflow_status_endpoint( + *, + db: Annotated[AsyncSession, Depends(get_async_db)], + current_user: Annotated[User, Security(auth.get_current_user)], + document_id: Optional[UUID] = Query(None, description="Filter by document ID"), + reference: Optional[str] = Query(None, description="Filter by document reference"), + 
workspace_name: Optional[str] = Query(None, description="Filter by workspace name"), +) -> list[WorkflowStatusResponse] | None: + """Get workflow status for documents.""" + await authorize(current_user, JobPolicy.get) + + try: + if not document_id and not reference: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Must specify either document_id or reference", + ) + + if document_id: + # Get status for specific document + workflow_status = await get_workflow_status(db, document_id) + + # Get workspace name if needed + if workspace_name and workflow_status.get("workspace_id"): + workspace = await db.get(Workspace, workflow_status["workspace_id"]) + if workspace and workspace.name != workspace_name: + return [] # Filter out if workspace doesn't match + workflow_status["workspace_name"] = workspace.name if workspace else None + + return [_convert_to_workflow_status_response(workflow_status)] + + elif reference: + # Get status for all documents with reference + workflow_statuses = await get_workflow_statuses_by_reference(db, reference) + + results = [] + for workflow_status in workflow_statuses: + # Get workspace name if needed + if workspace_name and workflow_status.get("workspace_id"): + workspace = await db.get(Workspace, workflow_status["workspace_id"]) + if workspace and workspace.name != workspace_name: + continue # Skip if workspace doesn't match + workflow_status["workspace_name"] = workspace.name if workspace else None + elif workflow_status.get("workspace_id"): + workspace = await db.get(Workspace, workflow_status["workspace_id"]) + workflow_status["workspace_name"] = workspace.name if workspace else None + + results.append(_convert_to_workflow_status_response(workflow_status)) + + return results + + except HTTPException: + raise + except Exception as e: + _LOGGER.error(f"Error getting workflow status: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get workflow status: {e!s}", 
+ ) + + +@router.post("/workflows/restart") +async def restart_workflow( + *, + db: Annotated[AsyncSession, Depends(get_async_db)], + current_user: Annotated[User, Security(auth.get_current_user)], + request: RestartWorkflowRequest, +) -> StartWorkflowResponse: + """Restart failed workflow jobs using RQ Groups.""" + await authorize(current_user, JobPolicy.get) + + try: + # Get workflow + workflow = await DocumentWorkflow.get_by_document_id(db, request.document_id) + if not workflow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Workflow not found for document {request.document_id}", + ) + + # Check if workflow is resumable + from extralit_server.contexts.workflows import is_workflow_resumable + + if not is_workflow_resumable(workflow.group_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Workflow is not in a resumable state (no failed jobs found)", + ) + + # Restart failed jobs + restart_result = await restart_failed_jobs_in_workflow(db, workflow) + + if not restart_result["success"]: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to restart workflow: {restart_result.get('error', 'Unknown error')}", + ) + + return StartWorkflowResponse( + workflow_id=str(workflow.id), + document_id=str(request.document_id), + group_id=workflow.group_id, + status="running", + reference=workflow.reference, + restarted_jobs=restart_result["restarted_jobs"], + ) + + except HTTPException: + raise + except Exception as e: + _LOGGER.error(f"Error restarting workflow for document {request.document_id}: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to restart workflow: {e!s}", + ) + + +@router.get( + "/workflows/", +) +async def list_workflows( + *, + db: Annotated[AsyncSession, Depends(get_async_db)], + current_user: Annotated[User, Security(auth.get_current_user)], + workspace_name: Optional[str] = Query(None, description="Filter 
by workspace name"), + status_filter: Optional[str] = Query(None, description="Filter by status"), + limit: int = Query(50, description="Maximum number of workflows to return"), +) -> list[WorkflowStatusResponse]: + """List workflows with optional filtering.""" + await authorize(current_user, JobPolicy.get) + + try: + from sqlalchemy import select + + # Build query + query = select(DocumentWorkflow).order_by(DocumentWorkflow.inserted_at.desc()).limit(limit) + + # Apply workspace filter + if workspace_name: + workspace_result = await db.execute(select(Workspace).where(Workspace.name == workspace_name)) + workspace = workspace_result.scalar_one_or_none() + if not workspace: + return [] # No workflows if workspace doesn't exist + query = query.where(DocumentWorkflow.workspace_id == workspace.id) + + # Apply status filter + if status_filter: + query = query.where(DocumentWorkflow.status == status_filter) + + # Execute query + result = await db.execute(query) + workflows = result.scalars().all() + + # Convert to response format + workflow_responses = [] + for workflow in workflows: + try: + # Get detailed workflow status + workflow_status = await get_workflow_status(db, workflow.document_id) + + # Get workspace name + workspace = await db.get(Workspace, workflow.workspace_id) + workflow_status["workspace_name"] = workspace.name if workspace else None + + workflow_responses.append(_convert_to_workflow_status_response(workflow_status)) + + except Exception as workflow_error: + _LOGGER.warning(f"Error processing workflow {workflow.id}: {workflow_error}") + # Add basic workflow info even if detailed status fails + workflow_responses.append( + WorkflowStatusResponse( + workflow_id=str(workflow.id), + document_id=str(workflow.document_id), + group_id=workflow.group_id, + status="error", + progress=0.0, + reference=workflow.reference, + workspace_id=str(workflow.workspace_id), + workflow_type=workflow.workflow_type, + total_jobs=0, + completed_jobs=0, + failed_jobs=0, + 
running_jobs=0, + created_at=workflow.inserted_at, + updated_at=workflow.updated_at, + error=f"Error processing workflow: {workflow_error}", + ) + ) + + return workflow_responses + + except Exception as e: + _LOGGER.error(f"Error listing workflows: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to list workflows: {e!s}", + ) + + +def _convert_to_workflow_status_response(workflow_status: dict) -> WorkflowStatusResponse: + """Convert workflow status dictionary to response schema.""" + return WorkflowStatusResponse( + workflow_id=str(workflow_status.get("workflow_id", "")), + document_id=str(workflow_status.get("document_id", "")), + group_id=workflow_status.get("group_id", ""), + status=workflow_status.get("status", "unknown"), + progress=workflow_status.get("progress", 0.0), + reference=workflow_status.get("reference"), + workspace_name=workflow_status.get("workspace_name"), + workspace_id=str(workflow_status.get("workspace_id", "")), + workflow_type=workflow_status.get("workflow_type", "unknown"), + total_jobs=workflow_status.get("total_jobs", 0), + completed_jobs=workflow_status.get("completed_jobs", 0), + failed_jobs=workflow_status.get("failed_jobs", 0), + running_jobs=workflow_status.get("running_jobs", 0), + created_at=workflow_status.get("created_at"), + updated_at=workflow_status.get("updated_at"), + error=workflow_status.get("error"), + jobs=workflow_status.get("jobs"), + ) diff --git a/extralit-server/src/extralit_server/api/routes.py b/extralit-server/src/extralit_server/api/routes.py index 7bc866e73..20ce8322f 100644 --- a/extralit-server/src/extralit_server/api/routes.py +++ b/extralit-server/src/extralit_server/api/routes.py @@ -72,6 +72,9 @@ vectors_settings as vectors_settings_v1, ) from extralit_server.api.handlers.v1 import webhooks as webhooks_v1 +from extralit_server.api.handlers.v1 import ( + workflows as workflows_v1, +) from extralit_server.api.handlers.v1 import ( workspaces as 
workspaces_v1, ) @@ -106,6 +109,7 @@ def create_api_v1(): vectors_settings_v1.router, workspaces_v1.router, webhooks_v1.router, + workflows_v1.router, jobs_v1.router, oauth2_v1.router, settings_v1.router, diff --git a/extralit-server/src/extralit_server/api/schemas/v1/document/segments.py b/extralit-server/src/extralit_server/api/schemas/v1/document/chunks.py similarity index 100% rename from extralit-server/src/extralit_server/api/schemas/v1/document/segments.py rename to extralit-server/src/extralit_server/api/schemas/v1/document/chunks.py diff --git a/extralit-server/src/extralit_server/api/schemas/v1/document/metadata.py b/extralit-server/src/extralit_server/api/schemas/v1/document/metadata.py new file mode 100644 index 000000000..04d6f7adb --- /dev/null +++ b/extralit-server/src/extralit_server/api/schemas/v1/document/metadata.py @@ -0,0 +1,107 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Document processing metadata schemas for workflow tracking.""" + +from datetime import datetime, timezone +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class OCRQualityMetadata(BaseModel): + """OCR quality analysis metadata.""" + + total_chars: int = Field(..., description="Total characters analyzed") + ocr_artifacts: int = Field(..., description="Number of OCR artifacts detected") + suspicious_patterns: int = Field(..., description="Number of suspicious patterns found") + ocr_quality_score: float = Field(..., description="Overall OCR quality score (0.0-1.0)") + + +class LayoutAnalysisMetadata(BaseModel): + """PDF layout analysis metadata.""" + + page_count: int = Field(None, description="Number of pages in PDF") + has_tables: bool = Field(default=False, description="Whether tables were detected") + has_figures: bool = Field(default=False, description="Whether figures were detected") + text_regions: int = Field(default=0, description="Number of text regions detected") + margin_analysis: dict[str, Any] = Field(default_factory=dict, description="Margin analysis results") + + +class AnalysisMetadata(BaseModel): + """Analysis job results stored in documents.metadata_.""" + + has_ocr_text_layer: Optional[bool] = Field(None, description="Whether PDF has OCR text layer") + needs_ocr: Optional[bool] = Field(None, description="Whether additional OCR processing is needed") + ocr_quality: OCRQualityMetadata = Field(..., description="OCR quality analysis") + layout_analysis: LayoutAnalysisMetadata = Field(..., description="Layout analysis results") + analysis_completed_at: Optional[str] = Field(None, description="When analysis was completed") + + +class PreprocessingMetadata(BaseModel): + """Preprocessing job results stored in documents.metadata_.""" + + processing_time: float = Field(..., description="Processing time in seconds") + ocr_applied: bool = Field(..., description="Whether OCR was applied during preprocessing") + 
processed_s3_url: Optional[str] = Field(None, description="S3 URL of processed PDF") + preprocessing_completed_at: Optional[str] = Field(None, description="When preprocessing was completed") + + +class TextExtractionMetadata(BaseModel): + """Text extraction job results.""" + + extracted_text_length: int = Field(..., description="Length of extracted text") + extraction_method: str = Field(..., description="Method used for extraction") + text_extraction_completed_at: Optional[str] = Field(None, description="When text extraction was completed") + + +class DocumentProcessingMetadata(BaseModel): + """Complete document processing metadata stored in documents.metadata_.""" + + workflow_id: Optional[str] = Field(None, description="Workflow ID for tracking") + analysis_metadata: Optional[AnalysisMetadata] = Field(None, description="Analysis results") + preprocessing_metadata: Optional[PreprocessingMetadata] = Field(None, description="Preprocessing results") + text_extraction_metadata: Optional[TextExtractionMetadata] = Field(None, description="Text extraction results") + workflow_started_at: Optional[datetime] = Field(None, description="When workflow was started") + workflow_completed_at: Optional[datetime] = Field(None, description="When workflow was completed") + workflow_status: str = Field(default="running", description="Overall workflow status") + + def update_analysis_results(self, analysis_result: dict) -> None: + """Update analysis metadata from job result.""" + self.analysis_metadata = AnalysisMetadata( + has_ocr_text_layer=analysis_result.get("has_ocr_text_layer"), + needs_ocr=analysis_result.get("needs_ocr"), + ocr_quality=OCRQualityMetadata(**analysis_result.get("analysis_metadata", {})), + layout_analysis=LayoutAnalysisMetadata(**analysis_result.get("layout_analysis", {})), + analysis_completed_at=datetime.now(timezone.utc).isoformat(), + ) + + def update_preprocessing_results(self, preprocess_result: dict) -> None: + """Update preprocessing metadata from job 
result.""" + self.preprocessing_metadata = PreprocessingMetadata( + processing_time=preprocess_result["processing_time"], + ocr_applied=preprocess_result.get("ocr_applied", False), + processed_s3_url=preprocess_result.get("processed_s3_url"), + preprocessing_completed_at=datetime.now(timezone.utc).isoformat(), + ) + + def is_workflow_complete(self) -> bool: + """Check if all workflow steps are complete.""" + return all( + [ + self.analysis_metadata is not None, + self.preprocessing_metadata is not None, + self.text_extraction_metadata is not None, + ] + ) diff --git a/extralit-server/src/extralit_server/api/schemas/v1/jobs.py b/extralit-server/src/extralit_server/api/schemas/v1/jobs.py index 54dede294..5e685f1f4 100644 --- a/extralit-server/src/extralit_server/api/schemas/v1/jobs.py +++ b/extralit-server/src/extralit_server/api/schemas/v1/jobs.py @@ -12,10 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pydantic import BaseModel +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel, Field from rq.job import JobStatus class Job(BaseModel): id: str status: JobStatus + + +class WorkflowJobResult(BaseModel): + """Schema for workflow job results with RQ Groups metadata.""" + + id: str = Field(..., description="Job ID") + status: JobStatus = Field(..., description="Job status") + document_id: Optional[UUID] = Field(None, description="Document ID associated with the job") + reference: Optional[str] = Field(None, description="Document reference for tracking") + workspace_id: Optional[UUID] = Field(None, description="Workspace ID") + workflow_step: Optional[str] = Field(None, description="Current workflow step") + progress: Optional[float] = Field(None, description="Job progress (0.0-1.0)") + error: Optional[str] = Field(None, description="Error message if job failed") + result: Optional[dict[str, Any]] = Field(None, description="Job 
result data") + meta: Optional[dict[str, Any]] = Field(None, description="Additional job metadata") + started_at: Optional[datetime] = Field(None, description="When job was started") + completed_at: Optional[datetime] = Field(None, description="When job was completed") + + # RQ Groups integration fields + group_id: Optional[str] = Field(None, description="RQ Group ID for the workflow") + group_status: Optional[str] = Field(None, description="Status of the entire RQ Group") + group_progress: Optional[float] = Field(None, description="Overall progress of the group (0.0-1.0)") + total_jobs: Optional[int] = Field(None, description="Total number of jobs in the group") + completed_jobs: Optional[int] = Field(None, description="Number of completed jobs in the group") + failed_jobs: Optional[int] = Field(None, description="Number of failed jobs in the group") + running_jobs: Optional[int] = Field(None, description="Number of running jobs in the group") diff --git a/extralit-server/src/extralit_server/api/schemas/v1/workflows.py b/extralit-server/src/extralit_server/api/schemas/v1/workflows.py new file mode 100644 index 000000000..c1f598ea6 --- /dev/null +++ b/extralit-server/src/extralit_server/api/schemas/v1/workflows.py @@ -0,0 +1,84 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from pydantic import BaseModel, Field + + +class StartWorkflowRequest(BaseModel): + """Request schema for starting a workflow.""" + + document_id: UUID = Field(..., description="Document UUID to process") + workspace_name: str = Field(..., description="Workspace name") + reference: Optional[str] = Field(None, description="Document reference for tracking") + force: bool = Field(False, description="Force restart if workflow already exists") + + +class StartWorkflowResponse(BaseModel): + """Response schema for starting a workflow.""" + + workflow_id: str = Field(..., description="Workflow ID") + document_id: str = Field(..., description="Document ID") + group_id: str = Field(..., description="RQ Group ID for tracking") + status: str = Field(..., description="Initial workflow status") + reference: Optional[str] = Field(None, description="Document reference") + restarted_jobs: Optional[list[str]] = Field(None, description="List of restarted job IDs (for restarts)") + + +class RestartWorkflowRequest(BaseModel): + """Request schema for restarting a workflow.""" + + document_id: UUID = Field(..., description="Document UUID to restart") + failed_only: bool = Field(True, description="Only restart failed jobs") + + +class WorkflowStatusResponse(BaseModel): + """Response schema for workflow status.""" + + workflow_id: str = Field(..., description="Workflow ID") + document_id: str = Field(..., description="Document ID") + group_id: str = Field(..., description="RQ Group ID") + status: str = Field(..., description="Workflow status") + progress: float = Field(..., description="Progress percentage (0.0-1.0)") + reference: Optional[str] = Field(None, description="Document reference") + workspace_name: Optional[str] = Field(None, description="Workspace name") + workspace_id: Optional[str] = Field(None, description="Workspace ID") + workflow_type: str = Field(..., description="Type of 
workflow") + + # Job statistics + total_jobs: int = Field(..., description="Total number of jobs") + completed_jobs: int = Field(..., description="Number of completed jobs") + failed_jobs: int = Field(..., description="Number of failed jobs") + running_jobs: int = Field(..., description="Number of running jobs") + + # Timestamps + created_at: Optional[datetime] = Field(None, description="When workflow was created") + updated_at: Optional[datetime] = Field(None, description="When workflow was last updated") + + # Error information + error: Optional[str] = Field(None, description="Error message if workflow failed") + + # Job details (optional, for detailed queries) + jobs: Optional[list[dict[str, Any]]] = Field(None, description="Detailed job information") + + +class WorkflowListRequest(BaseModel): + """Request schema for listing workflows.""" + + workspace_name: Optional[str] = Field(None, description="Filter by workspace name") + status_filter: Optional[str] = Field(None, description="Filter by status") + limit: int = Field(50, description="Maximum number of workflows to return") diff --git a/extralit-server/src/extralit_server/cli/worker.py b/extralit-server/src/extralit_server/cli/worker.py index 193115504..e387e35b1 100644 --- a/extralit-server/src/extralit_server/cli/worker.py +++ b/extralit-server/src/extralit_server/cli/worker.py @@ -15,23 +15,21 @@ import typer -from extralit_server.jobs.queues import DEFAULT_QUEUE, HIGH_QUEUE +from extralit_server.jobs.queues import DEFAULT_QUEUE, HIGH_QUEUE, OCR_QUEUE DEFAULT_NUM_WORKERS = 2 def worker( - queues: list[str] = typer.Option([DEFAULT_QUEUE.name, HIGH_QUEUE.name], help="Name of queues to listen"), + queues: list[str] = typer.Option( + [DEFAULT_QUEUE.name, HIGH_QUEUE.name, OCR_QUEUE.name], help="Name of queues to listen" + ), num_workers: int = typer.Option(DEFAULT_NUM_WORKERS, help="Number of workers to start"), ) -> None: from rq.worker_pool import WorkerPool from extralit_server.jobs.queues import 
REDIS_CONNECTION - worker_pool = WorkerPool( - connection=REDIS_CONNECTION, - queues=queues, - num_workers=num_workers, - ) + worker_pool = WorkerPool(connection=REDIS_CONNECTION, queues=queues, num_workers=num_workers, reload=True) worker_pool.start() diff --git a/extralit-server/src/extralit_server/contexts/document/analysis.py b/extralit-server/src/extralit_server/contexts/document/analysis.py index 497d03040..4e68fd856 100644 --- a/extralit-server/src/extralit_server/contexts/document/analysis.py +++ b/extralit-server/src/extralit_server/contexts/document/analysis.py @@ -48,10 +48,10 @@ def has_ocr_text_layer(self, pdf_bytes: bytes, threshold: float = 0.5, verbose=F print(page_info) pages_with_fonts = sum(1 for page in page_info if page.get("has_fonts", False)) - total_pages = len(page_info) + page_count = len(page_info) # Return True if more than 50% of pages have fonts - return pages_with_fonts > (total_pages * threshold) + return pages_with_fonts > (page_count * threshold) def _check_font_resources_per_page(self, pdf_bytes: bytes) -> list[dict]: """ diff --git a/extralit-server/src/extralit_server/contexts/document/margin.py b/extralit-server/src/extralit_server/contexts/document/margin.py index 6f57a3637..92ffe1ddd 100644 --- a/extralit-server/src/extralit_server/contexts/document/margin.py +++ b/extralit-server/src/extralit_server/contexts/document/margin.py @@ -147,7 +147,7 @@ def analyze_pdf_layout(self, pdf_data: bytes, filename: str) -> dict: try: images = pdf2image.convert_from_bytes(pdf_data, dpi=150) # type: ignore if not images: - return {"analysis_available": False, "error": "No pages found"} + return {"error": "No pages found"} _LOGGER.info(f"Analyzing layout for {filename} with {len(images)} pages") @@ -155,15 +155,14 @@ def analyze_pdf_layout(self, pdf_data: bytes, filename: str) -> dict: layout_data = self._analyze_page_layout(images) return { - "analysis_available": True, - "total_pages": len(images), + "page_count": len(images), 
"page_dimensions": {"width": images[0].size[0], "height": images[0].size[1]} if images else {}, **layout_data, } except Exception as e: _LOGGER.error(f"PDF layout analysis failed for {filename}: {e}") - return {"analysis_available": False, "error": str(e)} + return {"error": str(e)} def _analyze_page_layout(self, images: list["Image"]) -> dict: """ diff --git a/extralit-server/src/extralit_server/contexts/files.py b/extralit-server/src/extralit_server/contexts/files.py index 5e2eb2190..e79c498d8 100644 --- a/extralit-server/src/extralit_server/contexts/files.py +++ b/extralit-server/src/extralit_server/contexts/files.py @@ -565,6 +565,31 @@ def put_document_file( return None +def download_file_content(client: Minio | LocalFileStorage, document_url: str) -> bytes: + """ + Download file content from a document URL. + + Args: + client: Minio or LocalFileStorage client + document_url: URL in format "/api/v1/file/{bucket_name}/{object_path}" + + Returns: + File content as bytes + """ + # Parse URL to get bucket and object path + if not document_url.startswith("/api/v1/file/"): + raise ValueError(f"Invalid document URL format: {document_url}") + + url_parts = document_url.replace("/api/v1/file/", "").split("/", 1) + if len(url_parts) != 2: + raise ValueError(f"Invalid document URL format: {document_url}") + + bucket_name, object_path = url_parts + + file_response = get_object(client, bucket_name, object_path) + return file_response.response.read() + + def delete_bucket(client: Minio | LocalFileStorage, workspace_name: str): if isinstance(client, LocalFileStorage): try: diff --git a/extralit-server/src/extralit_server/contexts/imports.py b/extralit-server/src/extralit_server/contexts/imports.py index 781895f9b..001961d7e 100644 --- a/extralit-server/src/extralit_server/contexts/imports.py +++ b/extralit-server/src/extralit_server/contexts/imports.py @@ -14,7 +14,7 @@ import logging from os.path import basename -from uuid import UUID +from uuid import UUID, uuid4 from 
fastapi import HTTPException, UploadFile, status from sqlalchemy import and_, case, or_, select @@ -22,6 +22,7 @@ from extralit_server.api.schemas.v1.documents import DocumentCreate, DocumentListItem from extralit_server.api.schemas.v1.imports import ( + BulkDocumentInfo, DocumentImportAnalysis, DocumentMetadata, DocumentsBulkCreate, @@ -34,8 +35,10 @@ ImportStatus, ImportSummary, ) -from extralit_server.jobs.document_jobs import upload_and_preprocess_documents_job -from extralit_server.models.database import Document, ImportHistory +from extralit_server.contexts import files as file_context +from extralit_server.database import AsyncSessionLocal +from extralit_server.models.database import Document, ImportHistory, Workspace +from extralit_server.workflows.documents import create_document_workflow _LOGGER = logging.getLogger(__name__) @@ -334,11 +337,10 @@ async def process_bulk_upload( user_id: str, ) -> DocumentsBulkResponse: """ - Process bulk document upload with associated PDF files using reference-based jobs. + Process bulk document upload with associated PDF files using new workflow orchestrator. - This function creates one job per reference that handles multiple files for that reference. - It validates all files, groups them by reference, and creates reference-based upload jobs - for efficient processing and progress tracking. + This function now handles file upload to S3 before job enqueueing, creates document records + in database, and uses the new start_document_workflow() orchestrator for processing. 
Args: bulk_create: DocumentsBulkCreate with reference-based document information @@ -346,12 +348,11 @@ async def process_bulk_upload( user_id: ID of the user creating the documents Returns: - DocumentsBulkResponse with job IDs indexed by reference and validation results + DocumentsBulkResponse with workflow_id and job_ids for tracking """ - from extralit_server.jobs import DEFAULT_QUEUE # Create a mapping of filenames to file objects for quick lookup - file_mapping = {file.filename: file for file in files} if files else {} + file_mapping: dict[str, UploadFile] = {file.filename: file for file in files} if files else {} # Validate that all referenced files are included in the upload missing_files = [] @@ -370,8 +371,7 @@ async def process_bulk_upload( detail=f"Referenced files not found in upload: {', '.join(missing_files)}", ) - # Group documents by reference (should be 1:1 but validate) - reference_to_doc = {} + reference_to_doc: dict[str, BulkDocumentInfo] = {} for doc in bulk_create.documents: if doc.reference in reference_to_doc: raise HTTPException( @@ -380,87 +380,128 @@ async def process_bulk_upload( ) reference_to_doc[doc.reference] = doc - # Process each reference and create reference-based jobs - job_ids = {} + # Get storage client + client = file_context.get_minio_client() + if client is None: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get storage client") + + # Process each reference: upload files to S3, create documents, start workflows + job_ids: dict[str, list[str]] = {} failed_validations = [] - for reference, doc in reference_to_doc.items(): - try: - # Validate and read all files for this reference - file_data_list = [] - reference_failed = False - - # Handle documents with no associated files - if not doc.associated_files: - # Create a reference-based job for documents without files - job = DEFAULT_QUEUE.enqueue( - upload_and_preprocess_documents_job, - reference=reference, - 
reference_data=doc.document_create.model_dump(), - file_data_list=[], - user_id=user_id, - ) - - # Store job ID mapped to reference key for frontend tracking - job_ids[reference] = job.id - _LOGGER.info(f"Created reference-based job {job.id} for reference {reference} with no files") - continue - - for filename in doc.associated_files: - try: - file = file_mapping[filename] - - if not file.filename or not file.filename.lower().endswith(".pdf"): - failed_validations.append(f"{filename}: Not a PDF file") + async with AsyncSessionLocal() as db: + for reference, doc in reference_to_doc.items(): + try: + # Get workspace for bucket name + workspace = await Workspace.get(db, doc.document_create.workspace_id) + if not workspace: + failed_validations.append(f"{reference}: Workspace not found") + continue + + # Handle documents with no associated files + if not doc.associated_files: + # Create document record without file (uses remote url) + document = await create_document(db, doc.document_create) + continue + + # Process files for this reference + reference_failed = False + uploaded_documents: list[DocumentListItem] = [] + + for filename in doc.associated_files: + try: + file = file_mapping[filename] + + if not file.filename or not file.filename.lower().endswith(".pdf"): + failed_validations.append(f"{filename}: Not a PDF file") + reference_failed = True + continue + + # Read file content + file_content = await file.read() + + # Reset file position for potential future reads + await file.seek(0) + + # Create document record first + file_document_create = DocumentCreate( + id=uuid4(), + reference=doc.document_create.reference, + pmid=doc.document_create.pmid, + doi=doc.document_create.doi, + url=None, # Will be set after S3 upload + file_name=filename, + workspace_id=doc.document_create.workspace_id, + metadata=doc.document_create.metadata, + ) + + # Check for existing documents + existing_documents = await find_existing_documents( + db=db, + 
workspace_id=file_document_create.workspace_id, + document_id=file_document_create.id, + file_name=file_document_create.file_name, + limit=1, + ) + + if existing_documents: + existing_document_id = existing_documents[0].id + _LOGGER.info(f"Document already exists for file {filename} with ID {existing_document_id}") + continue + + # Upload file to S3 + file_url = file_context.put_document_file( + client=client, + workspace_name=workspace.name, + document_id=file_document_create.id, + file_data=file_content, + filename=filename, + ) + + if file_url: + file_document_create.url = file_url + + # Create document in database + document = await create_document(db, file_document_create) + uploaded_documents.append(document) + + _LOGGER.info(f"Uploaded file {filename} to S3 and created document {document.id}") + + except Exception as e: + _LOGGER.error(f"Error processing file {filename} for reference {reference}: {e!s}") + failed_validations.append(f"{filename}: {e!s}") reference_failed = True - continue - - # Read file content - file_content = await file.read() - # Validate file size (100 MB limit) - if len(file_content) > 100 * 1024 * 1024: - failed_validations.append(f"{filename}: File exceeds maximum size of 100 MB") - reference_failed = True - continue - - # Reset file position for potential future reads - await file.seek(0) - - file_data_list.append((filename, file_content)) - - except Exception as e: - _LOGGER.error(f"Error processing file {filename} for reference {reference}: {e!s}") - failed_validations.append(f"{filename}: {e!s}") - reference_failed = True - - # Skip this reference if any files failed validation - if reference_failed: - continue - - # Set a default filename if not already set (use first file) - if not doc.document_create.file_name and file_data_list: - doc.document_create.file_name = file_data_list[0][0] - - # Create a reference-based job for multiple files - job = DEFAULT_QUEUE.enqueue( - upload_and_preprocess_documents_job, - reference=reference, 
- reference_data=doc.document_create.model_dump(), - file_data_list=file_data_list, - user_id=user_id, - job_timeout=None, # No timeout for large uploads - ) - - # Store job ID mapped to reference key for frontend tracking - job_ids[reference] = job.id - _LOGGER.info( - f"Created reference-based job {job.id} for reference {reference} with {len(file_data_list)} files" - ) - - except Exception as e: - _LOGGER.error(f"Error processing reference {reference}: {e!s}") - failed_validations.append(f"{reference}: {e!s}") + # Skip this reference if any files failed validation + if reference_failed or not uploaded_documents: + continue + + # Start workflows for each uploaded document + document_job_group = {} + for document in uploaded_documents: + try: + job_group = await create_document_workflow( + document_id=document.id, + s3_url=document.url, + reference=reference, + workspace_name=workspace.name, + workspace_id=workspace.id, + ) + + # Store the group object for later use + document_job_group[str(document.id)] = job_group + + except Exception as e: + _LOGGER.error(f"Error starting workflow for document {document.id}: {e}") + failed_validations.append(f"{reference}/{document.file_name}: Workflow start failed: {e}") + + # For each reference, select first job id + # TODO handle multiple jobs per reference or skip reporting job status during frontend upload + job_ids[reference] = next(group for group in document_job_group.values()).get_jobs()[0].id + + except Exception as e: + _LOGGER.error(f"Error processing reference {reference}: {e!s}") + failed_validations.append(f"{reference}: {e!s}") return DocumentsBulkResponse( job_ids=job_ids, total_documents=len(reference_to_doc), failed_validations=failed_validations diff --git a/extralit-server/src/extralit_server/contexts/workflows.py b/extralit-server/src/extralit_server/contexts/workflows.py new file mode 100644 index 000000000..84f9e682a --- /dev/null +++ b/extralit-server/src/extralit_server/contexts/workflows.py @@ -0,0 
+1,909 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Workflow job querying and management functions.""" + +import logging +from typing import Any, Optional +from uuid import UUID + +from rq.exceptions import NoSuchJobError +from rq.group import Group +from rq.job import Job +from sqlalchemy.ext.asyncio import AsyncSession + +from extralit_server.jobs.queues import REDIS_CONNECTION +from extralit_server.models.database import DocumentWorkflow + +_LOGGER = logging.getLogger(__name__) + + +async def get_jobs_for_document(db: AsyncSession, document_id: UUID) -> list[dict[str, Any]]: + """ + Get all jobs for a document using RQ Group lookup. + + This replaces expensive registry scanning with efficient RQ Group queries. 
+ + Args: + db: Database session + document_id: Document ID to get jobs for + + Returns: + List of job dictionaries with status and metadata + """ + try: + # Get workflow record for the document + workflow = await DocumentWorkflow.get_by_document_id(db, document_id) + if not workflow: + _LOGGER.info(f"No workflow found for document {document_id}") + return [] + + # Handle group expiration and missing groups gracefully + try: + group = Group.fetch(name=workflow.group_id, connection=REDIS_CONNECTION) + except Exception as e: + _LOGGER.warning(f"Group {workflow.group_id} not found or expired for document {document_id}: {e}") + return [ + { + "id": "group_expired", + "status": "expired", + "workflow_step": "unknown", + "document_id": document_id, + "error": f"Group not found or expired: {e}", + } + ] + + jobs = group.get_jobs() + + job_data_list = [] + for job in jobs: + try: + job_data = { + "id": job.id, + "status": job.get_status(refresh=True), + "workflow_step": job.meta.get("workflow_step", "unknown") if job.meta else "unknown", + "document_id": document_id, + "group_id": workflow.group_id, + "created_at": job.created_at, + "started_at": job.started_at, + "ended_at": job.ended_at, + "meta": job.meta, + "result": job.result if job.is_finished else None, + "exc_info": job.exc_info if job.is_failed else None, + } + job_data_list.append(job_data) + except Exception as e: + # Handle individual job errors gracefully + _LOGGER.warning(f"Error processing job {job.id} for document {document_id}: {e}") + job_data_list.append( + { + "id": job.id, + "status": "error", + "workflow_step": "unknown", + "document_id": document_id, + "group_id": workflow.group_id, + "error": f"Job processing error: {e}", + } + ) + + return job_data_list + + except Exception as e: + _LOGGER.error(f"Error getting jobs for document {document_id}: {e}") + return [] + + +async def get_jobs_by_reference(db: AsyncSession, reference: str) -> list[dict[str, Any]]: + """ + Get all jobs for documents with a 
specific reference using RQ Groups. + + This efficiently queries multiple RQ Groups for all documents in a reference batch + by reusing the get_jobs_for_document function. + + Args: + db: Database session + reference: Document reference to search for + + Returns: + List of job dictionaries with status and metadata + """ + try: + # Get all workflows with the reference + workflows = await DocumentWorkflow.get_by_reference(db, reference) + + if not workflows: + _LOGGER.info(f"No workflows found for reference {reference}") + return [] + + all_jobs = [] + for workflow in workflows: + try: + # Reuse get_jobs_for_document for each document in the reference + document_jobs = await get_jobs_for_document(db, workflow.document_id) + all_jobs.extend(document_jobs) + except Exception as workflow_error: + _LOGGER.error( + f"Error getting jobs for document {workflow.document_id} in reference {reference}: {workflow_error}" + ) + # Add placeholder job for workflow processing error + all_jobs.append( + { + "id": f"workflow_error_{workflow.id}", + "status": "error", + "workflow_step": "unknown", + "document_id": workflow.document_id, + "group_id": workflow.group_id, + "reference": reference, + "error": f"Workflow processing error: {workflow_error}", + } + ) + + return all_jobs + + except Exception as e: + _LOGGER.error(f"Error getting jobs for reference {reference}: {e}") + return [] + + +async def get_workflow_status(db: AsyncSession, document_id: UUID) -> dict[str, Any]: + """ + Get complete workflow status for a document using RQ Groups. 
+ + Args: + db: Database session + document_id: Document ID to get workflow status for + + Returns: + Dictionary with workflow status and progress information + """ + try: + workflow = await DocumentWorkflow.get_by_document_id(db, document_id) + if not workflow: + _LOGGER.info(f"No workflow found for document {document_id}") + return { + "document_id": document_id, + "status": "not_found", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + "error": "No workflow found for document", + } + + # Get workflow status using RQ Groups with enhanced error handling + workflow_status = get_workflow_status_from_group(workflow.group_id) + + # Add additional workflow metadata + workflow_status.update( + { + "document_id": document_id, + "workflow_id": workflow.id, + "workflow_type": workflow.workflow_type, + "group_id": workflow.group_id, + "reference": workflow.reference, + "workspace_id": workflow.workspace_id, + "created_at": workflow.inserted_at, + "updated_at": workflow.updated_at, + "cached_status": workflow.status, # Include cached status for comparison + } + ) + + # Update cached status if it differs from RQ Group status + if workflow.status != workflow_status["status"] and workflow_status["status"] not in ["error", "expired"]: + try: + await update_workflow_status(db, workflow, workflow_status["status"]) + _LOGGER.info( + f"Updated cached workflow status for document {document_id} from {workflow.status} to {workflow_status['status']}" + ) + except Exception as update_error: + _LOGGER.warning(f"Failed to update cached workflow status for document {document_id}: {update_error}") + + return workflow_status + + except Exception as e: + _LOGGER.error(f"Error getting workflow status for document {document_id}: {e}") + return { + "document_id": document_id, + "status": "error", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + "error": str(e), 
+ } + + +def get_job_by_id(job_id: str) -> Optional[dict[str, Any]]: + """ + Get a single job by ID with error handling. + + Args: + job_id: Job ID to fetch + + Returns: + Job dictionary or None if not found + """ + try: + job = Job.fetch(job_id, connection=REDIS_CONNECTION) + return { + "id": job.id, + "status": job.get_status(refresh=True), + "created_at": job.created_at, + "started_at": job.started_at, + "ended_at": job.ended_at, + "meta": job.meta, + "result": job.result if job.is_finished else None, + "exc_info": job.exc_info if job.is_failed else None, + } + except (NoSuchJobError, Exception) as e: + _LOGGER.warning(f"Job {job_id} not found: {e}") + return None + + +async def update_workflow_status_on_job_completion(db: AsyncSession, document_id: UUID) -> None: + """ + Update workflow status when a job completes or fails. + + This function should be called when jobs complete to update the overall + workflow status based on the current state of all jobs using RQ Groups. + + Args: + db: Database session + document_id: Document ID to update workflow status for + """ + try: + workflow = await DocumentWorkflow.get_by_document_id(db, document_id) + if not workflow: + _LOGGER.warning(f"No workflow found for document {document_id}") + return + + # Get current workflow status using RQ Groups + workflow_status = get_workflow_status_from_group(workflow.group_id) + new_status = workflow_status["status"] + + # Update the workflow record if status changed + if workflow.status != new_status: + await update_workflow_status(db, workflow, new_status) + _LOGGER.info(f"Updated workflow status for document {document_id} to {new_status}") + + except Exception as e: + _LOGGER.error(f"Error updating workflow status for document {document_id}: {e}") + + +async def calculate_workflow_progress(db: AsyncSession, document_id: UUID) -> float: + """ + Calculate workflow progress based on completed steps. 
+ + Args: + db: Database session + document_id: Document ID to calculate progress for + + Returns: + Progress as float between 0.0 and 1.0 + """ + try: + workflow_status = await get_workflow_status(db, document_id) + return workflow_status.get("progress", 0.0) + except Exception as e: + _LOGGER.error(f"Error calculating workflow progress for document {document_id}: {e}") + return 0.0 + + +def create_job_completion_callback(document_id: UUID): + """ + Create a callback function for job completion that updates workflow status. + + This can be used with RQ's job callbacks to automatically update workflow + status when jobs complete. + + Args: + document_id: Document ID associated with the job + + Returns: + Callback function that can be used with RQ jobs + """ + + async def callback(job, connection, result, *args, **kwargs): + """Job completion callback to update workflow status.""" + try: + from extralit_server.database import AsyncSessionLocal + + async with AsyncSessionLocal() as db: + await update_workflow_status_on_job_completion(db, document_id) + except Exception as e: + _LOGGER.error(f"Error in job completion callback for document {document_id}: {e}") + + return callback + + +def get_workflow_status_from_group(group_id: str) -> dict[str, Any]: + """ + Get workflow status using RQ Group. 
+ + Args: + group_id: RQ Group ID + + Returns: + Dictionary with workflow status and job information + """ + try: + # Handle group expiration and missing groups gracefully + try: + group = Group.fetch(name=group_id, connection=REDIS_CONNECTION) + except Exception as e: + _LOGGER.warning(f"Group {group_id} not found or expired: {e}") + return { + "status": "expired", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + "error": f"Group not found or expired: {e}", + } + + jobs = group.get_jobs() + + total_jobs = len(jobs) + if total_jobs == 0: + return { + "status": "pending", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + } + + completed_jobs = sum(1 for job in jobs if job.is_finished) + failed_jobs = sum(1 for job in jobs if job.is_failed) + running_jobs = sum(1 for job in jobs if job.is_started and not job.is_finished) + + # Determine overall status + if failed_jobs > 0: + overall_status = "failed" + elif completed_jobs == total_jobs: + overall_status = "completed" + elif running_jobs > 0: + overall_status = "running" + else: + overall_status = "pending" + + # Calculate progress (0.0 to 1.0) + progress = completed_jobs / total_jobs if total_jobs > 0 else 0.0 + + job_details = [] + for job in jobs: + try: + job_details.append( + { + "id": job.id, + "status": job.get_status(refresh=True), + "created_at": job.created_at, + "started_at": job.started_at, + "ended_at": job.ended_at, + "meta": job.meta, + "result": job.result if job.is_finished else None, + "exc_info": job.exc_info if job.is_failed else None, + } + ) + except Exception as job_error: + # Handle individual job errors gracefully + _LOGGER.warning(f"Error processing job {job.id}: {job_error}") + job_details.append( + { + "id": job.id, + "status": "error", + "error": f"Job processing error: {job_error}", + } + ) + + return { + "status": overall_status, + "progress": progress, 
+ "total_jobs": total_jobs, + "completed_jobs": completed_jobs, + "failed_jobs": failed_jobs, + "running_jobs": running_jobs, + "jobs": job_details, + } + + except Exception as e: + _LOGGER.error(f"Error getting workflow status from group {group_id}: {e}") + return { + "status": "error", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + "error": str(e), + } + + +def is_workflow_resumable(group_id: str) -> bool: + """ + Check if workflow can be resumed (has failed jobs that can be retried). + + Args: + group_id: RQ Group ID + + Returns: + True if workflow has failed jobs that can be resumed + """ + try: + # Handle group expiration and missing groups gracefully + try: + group = Group.fetch(name=group_id, connection=REDIS_CONNECTION) + except Exception as e: + _LOGGER.warning(f"Group {group_id} not found or expired, cannot resume: {e}") + return False + + jobs = group.get_jobs() + + # Check if there are any failed jobs + failed_jobs = [job for job in jobs if job.is_failed] + return len(failed_jobs) > 0 + + except Exception as e: + _LOGGER.error(f"Error checking if workflow {group_id} is resumable: {e}") + return False + + +def get_failed_jobs_in_group(group_id: str) -> list[dict[str, Any]]: + """ + Return details for every failed job in an RQ Group. 
+ + Args: + group_id: RQ Group ID + + Returns: + List of failed job dictionaries with details + """ + try: + # Handle group expiration and missing groups gracefully + try: + group = Group.fetch(name=group_id, connection=REDIS_CONNECTION) + except Exception as e: + _LOGGER.warning(f"Group {group_id} not found or expired: {e}") + return [] + + jobs = group.get_jobs() + + # Filter failed jobs and return detailed information + failed_jobs = [] + for job in jobs: + try: + if job.is_failed: + failed_job_info = { + "id": job.id, + "status": job.get_status(refresh=True), + "created_at": job.created_at, + "started_at": job.started_at, + "ended_at": job.ended_at, + "meta": job.meta, + "exc_info": job.exc_info, + "failure_reason": str(job.exc_info) if job.exc_info else "Unknown failure", + "workflow_step": job.meta.get("workflow_step", "unknown") if job.meta else "unknown", + } + failed_jobs.append(failed_job_info) + except Exception as job_error: + _LOGGER.warning(f"Error processing failed job {job.id}: {job_error}") + # Add basic info even if detailed processing fails + failed_jobs.append( + { + "id": job.id, + "status": "failed", + "error": f"Job processing error: {job_error}", + "workflow_step": "unknown", + } + ) + + return failed_jobs + + except Exception as e: + _LOGGER.error(f"Error getting failed jobs for group {group_id}: {e}") + return [] + + +async def restart_failed_jobs_in_workflow(db: AsyncSession, workflow: DocumentWorkflow) -> dict[str, Any]: + """ + Restart failed jobs in the workflow group. 
+ + Args: + db: Database session + workflow: DocumentWorkflow instance + + Returns: + Dictionary with restart results + """ + try: + # Handle group expiration and missing groups gracefully + try: + group = Group.fetch(name=workflow.group_id, connection=REDIS_CONNECTION) + except Exception as e: + _LOGGER.error(f"Group {workflow.group_id} not found or expired, cannot restart: {e}") + return { + "success": False, + "error": f"Group not found or expired: {e}", + "restarted_jobs": [], + "total_failed": 0, + } + + jobs = group.get_jobs() + + failed_jobs = [job for job in jobs if job.is_failed] + restarted_jobs = [] + + for job in failed_jobs: + try: + # Requeue the failed job + job.requeue() + restarted_jobs.append(job.id) + _LOGGER.info(f"Restarted failed job {job.id} in workflow {workflow.id}") + except Exception as e: + # Log individual job restart failures but continue + _LOGGER.warning(f"Failed to restart job {job.id}: {e}") + + # Update workflow status if jobs were restarted + if restarted_jobs: + await update_workflow_status(db, workflow, "running") + _LOGGER.info( + f"Updated workflow {workflow.id} status to running after restarting {len(restarted_jobs)} jobs" + ) + + return {"success": True, "restarted_jobs": restarted_jobs, "total_failed": len(failed_jobs)} + + except Exception as e: + _LOGGER.error(f"Error restarting failed jobs in workflow {workflow.id}: {e}") + return {"success": False, "error": str(e), "restarted_jobs": [], "total_failed": 0} + + +async def restart_failed_workflow(db: AsyncSession, document_id: UUID, partial_restart: bool = True) -> dict[str, Any]: + """ + Restart a document's workflow by requeueing its failed jobs, or every job in its RQ Group when partial_restart is False. + + This function provides a high-level interface for restarting workflows with + support for partial vs full workflow restart using RQ Group capabilities. 
+ + Args: + db: Database session + document_id: Document ID to restart workflow for + partial_restart: If True, only restart failed jobs; if False, restart entire workflow + + Returns: + Dictionary with restart results including job re-enqueueing status + """ + try: + # Get workflow by document ID + workflow = await DocumentWorkflow.get_by_document_id(db, document_id) + if not workflow: + return { + "success": False, + "error": f"No workflow found for document {document_id}", + "restarted_jobs": [], + "total_failed": 0, + } + + # Check if workflow is resumable + if not is_workflow_resumable(workflow.group_id): + return { + "success": False, + "error": "Workflow is not in a resumable state (no failed jobs found)", + "restarted_jobs": [], + "total_failed": 0, + } + + if partial_restart: + # Use existing function for partial restart (failed jobs only) + return await restart_failed_jobs_in_workflow(db, workflow) + else: + # Full workflow restart - restart all jobs in the group + try: + group = Group.fetch(name=workflow.group_id, connection=REDIS_CONNECTION) + except Exception as e: + _LOGGER.error(f"Group {workflow.group_id} not found or expired, cannot restart: {e}") + return { + "success": False, + "error": f"Group not found or expired: {e}", + "restarted_jobs": [], + "total_failed": 0, + } + + jobs = group.get_jobs() + restarted_jobs = [] + failed_jobs = [job for job in jobs if job.is_failed] + + # Restart all jobs in the workflow with proper dependencies + for job in jobs: + try: + # Requeue job (RQ will handle dependencies automatically) + job.requeue() + restarted_jobs.append(job.id) + _LOGGER.info(f"Restarted job {job.id} in full workflow restart for {workflow.id}") + except Exception as e: + _LOGGER.warning(f"Failed to restart job {job.id} in full restart: {e}") + + # Update DocumentWorkflow records with new group state information + if restarted_jobs: + await update_workflow_status(db, workflow, "running") + _LOGGER.info( + f"Updated workflow {workflow.id} 
status to running after full restart of {len(restarted_jobs)} jobs" + ) + + return { + "success": True, + "restarted_jobs": restarted_jobs, + "total_failed": len(failed_jobs), + "restart_type": "full", + } + + except Exception as e: + _LOGGER.error(f"Error restarting workflow for document {document_id}: {e}") + return {"success": False, "error": str(e), "restarted_jobs": [], "total_failed": 0} + + +async def update_workflow_status(db: AsyncSession, workflow: DocumentWorkflow, new_status: str) -> None: + """ + Update workflow status in database. + + Args: + db: Database session + workflow: DocumentWorkflow instance + new_status: New status to set + """ + workflow.status = new_status + db.add(workflow) + await db.commit() + + +async def get_workflow_by_document_id(db: AsyncSession, document_id: UUID) -> Optional[DocumentWorkflow]: + """ + Get workflow by document ID with enhanced functionality. + + Args: + db: Database session + document_id: Document ID + + Returns: + DocumentWorkflow instance or None + """ + return await DocumentWorkflow.get_by_document_id(db, document_id) + + +async def get_workflow_by_group_id(db: AsyncSession, group_id: str) -> Optional[DocumentWorkflow]: + """ + Get workflow by RQ Group ID. + + Args: + db: Database session + group_id: RQ Group ID + + Returns: + DocumentWorkflow instance or None + """ + return await DocumentWorkflow.get_by_group_id(db, group_id) + + +async def get_workflows_by_reference( + db: AsyncSession, reference: str, workspace_id: Optional[UUID] = None +) -> list[DocumentWorkflow]: + """ + Get workflows by reference (batch tracking). 
+ + Args: + db: Database session + reference: Document reference + workspace_id: Optional workspace ID filter + + Returns: + List of DocumentWorkflow instances + """ + return await DocumentWorkflow.get_by_reference(db, reference, str(workspace_id) if workspace_id else None) + + +async def get_workflow_statuses_by_reference(db: AsyncSession, reference: str) -> list[dict[str, Any]]: + """ + Get workflow statuses for all documents with a specific reference. + + This is more efficient than calling get_workflow_status for each document individually. + + Args: + db: Database session + reference: Document reference to search for + + Returns: + List of workflow status dictionaries + """ + try: + workflows = await DocumentWorkflow.get_by_reference(db, reference) + + if not workflows: + _LOGGER.info(f"No workflows found for reference {reference}") + return [] + + workflow_statuses = [] + for workflow in workflows: + try: + # Get workflow status using RQ Groups + workflow_status = get_workflow_status_from_group(workflow.group_id) + + # Add workflow metadata + workflow_status.update( + { + "document_id": workflow.document_id, + "workflow_id": workflow.id, + "workflow_type": workflow.workflow_type, + "group_id": workflow.group_id, + "reference": reference, + "workspace_id": workflow.workspace_id, + "created_at": workflow.inserted_at, + "updated_at": workflow.updated_at, + "cached_status": workflow.status, + } + ) + + workflow_statuses.append(workflow_status) + + # Update cached status if needed + if workflow.status != workflow_status["status"] and workflow_status["status"] not in [ + "error", + "expired", + ]: + try: + await update_workflow_status(db, workflow, workflow_status["status"]) + except Exception as update_error: + _LOGGER.warning( + f"Failed to update cached workflow status for workflow {workflow.id}: {update_error}" + ) + + except Exception as workflow_error: + _LOGGER.error(f"Error processing workflow {workflow.id} for reference {reference}: {workflow_error}") + # 
Add error status for failed workflow processing + workflow_statuses.append( + { + "document_id": workflow.document_id, + "workflow_id": workflow.id, + "workflow_type": workflow.workflow_type, + "group_id": workflow.group_id, + "reference": reference, + "workspace_id": workflow.workspace_id, + "status": "error", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + "error": f"Workflow processing error: {workflow_error}", + "created_at": workflow.inserted_at, + "updated_at": workflow.updated_at, + "cached_status": workflow.status, + } + ) + + return workflow_statuses + + except Exception as e: + _LOGGER.error(f"Error getting workflow statuses for reference {reference}: {e}") + return [] + + +async def list_workflows( + db: AsyncSession, + workspace_id: Optional[UUID] = None, + status_filter: Optional[str] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + """ + List workflows with optional filtering and RQ Group information. 
+ + Args: + db: Database session + workspace_id: Optional workspace ID filter + status_filter: Optional status filter + limit: Maximum number of workflows to return + + Returns: + List of workflow status dictionaries with RQ Group information + """ + try: + from sqlalchemy import select + + # Build query with efficient database queries using group_id indexing + query = select(DocumentWorkflow).order_by(DocumentWorkflow.inserted_at.desc()).limit(limit) + + if workspace_id: + query = query.where(DocumentWorkflow.workspace_id == workspace_id) + + if status_filter: + query = query.where(DocumentWorkflow.status == status_filter) + + # Execute query + result = await db.execute(query) + workflows = result.scalars().all() + + workflow_statuses = [] + for workflow in workflows: + try: + # Get workflow status using RQ Groups with enhanced error handling + workflow_status = get_workflow_status_from_group(workflow.group_id) + + # Add workflow metadata including RQ Group information + workflow_status.update( + { + "document_id": workflow.document_id, + "workflow_id": workflow.id, + "workflow_type": workflow.workflow_type, + "group_id": workflow.group_id, + "reference": workflow.reference, + "workspace_id": workflow.workspace_id, + "created_at": workflow.inserted_at, + "updated_at": workflow.updated_at, + "cached_status": workflow.status, + } + ) + + workflow_statuses.append(workflow_status) + + # Update cached status if needed for performance optimization + if workflow.status != workflow_status["status"] and workflow_status["status"] not in [ + "error", + "expired", + ]: + try: + await update_workflow_status(db, workflow, workflow_status["status"]) + except Exception as update_error: + _LOGGER.warning( + f"Failed to update cached workflow status for workflow {workflow.id}: {update_error}" + ) + + except Exception as workflow_error: + # Handle missing groups and RQ connection issues gracefully + _LOGGER.warning(f"Error processing workflow {workflow.id}: {workflow_error}") + # 
Add basic workflow info even if RQ Group access fails + workflow_statuses.append( + { + "document_id": workflow.document_id, + "workflow_id": workflow.id, + "workflow_type": workflow.workflow_type, + "group_id": workflow.group_id, + "reference": workflow.reference, + "workspace_id": workflow.workspace_id, + "status": "error", + "progress": 0.0, + "total_jobs": 0, + "completed_jobs": 0, + "failed_jobs": 0, + "running_jobs": 0, + "jobs": [], + "error": f"RQ Group access error: {workflow_error}", + "created_at": workflow.inserted_at, + "updated_at": workflow.updated_at, + "cached_status": workflow.status, + } + ) + + return workflow_statuses + + except Exception as e: + _LOGGER.error(f"Error listing workflows: {e}") + return [] diff --git a/extralit-server/src/extralit_server/jobs/dataset_jobs.py b/extralit-server/src/extralit_server/jobs/dataset_jobs.py index 00b7c74ac..e2953e68d 100644 --- a/extralit-server/src/extralit_server/jobs/dataset_jobs.py +++ b/extralit-server/src/extralit_server/jobs/dataset_jobs.py @@ -20,7 +20,7 @@ from extralit_server.contexts import distribution from extralit_server.database import AsyncSessionLocal -from extralit_server.jobs.queues import DEFAULT_QUEUE, JOB_TIMEOUT_DISABLED +from extralit_server.jobs.queues import DEFAULT_QUEUE, JOB_TIMEOUT_DISABLED, REDIS_CONNECTION from extralit_server.models import Record, Response from extralit_server.search_engine.base import SearchEngine from extralit_server.settings import settings @@ -28,7 +28,7 @@ JOB_RECORDS_YIELD_PER = 100 -@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3)) +@job(DEFAULT_QUEUE, connection=REDIS_CONNECTION, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3)) async def update_dataset_records_status_job(dataset_id: UUID) -> None: """This Job updates the status of all the records in the dataset when the distribution strategy changes.""" diff --git a/extralit-server/src/extralit_server/jobs/document_jobs.py 
b/extralit-server/src/extralit_server/jobs/document_jobs.py index 1c735b181..28fe0af48 100644 --- a/extralit-server/src/extralit_server/jobs/document_jobs.py +++ b/extralit-server/src/extralit_server/jobs/document_jobs.py @@ -16,23 +16,34 @@ import logging import os +from datetime import datetime, timezone from typing import Any from uuid import UUID, uuid4 -from rq import Retry +from rq import Retry, get_current_job from rq.decorators import job +from extralit_server.api.schemas.v1.document.metadata import DocumentProcessingMetadata from extralit_server.api.schemas.v1.documents import DocumentCreate from extralit_server.contexts import files, imports from extralit_server.contexts.document import preprocessing -from extralit_server.database import AsyncSessionLocal -from extralit_server.jobs import DEFAULT_QUEUE, JOB_TIMEOUT_DISABLED +from extralit_server.contexts.document.analysis import PDFOCRLayerDetector +from extralit_server.contexts.document.margin import PDFAnalyzer +from extralit_server.contexts.document.preprocessing import PDFPreprocessingSettings, PDFPreprocessor +from extralit_server.database import AsyncSessionLocal, SyncSessionLocal +from extralit_server.jobs.queues import DEFAULT_QUEUE, JOB_TIMEOUT_DISABLED, REDIS_CONNECTION +from extralit_server.models.database import Document _LOGGER = logging.getLogger(__name__) -@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3, interval=[10, 30, 60])) -async def upload_and_preprocess_documents_job( +@job( + queue=DEFAULT_QUEUE, + connection=REDIS_CONNECTION, + timeout=JOB_TIMEOUT_DISABLED, + retry=Retry(max=3, interval=[10, 30, 60]), +) +async def upload_and_preprocess_documents_job( # Deprecated reference: str, reference_data: dict[str, Any], file_data_list: list[tuple[str, bytes]], # List of (filename, file_data) tuples @@ -197,3 +208,128 @@ async def upload_and_preprocess_documents_job( _LOGGER.warning(f"Failed to cleanup temporary file {temp_file}: {e!s}") return results + + 
@job(queue=DEFAULT_QUEUE, connection=REDIS_CONNECTION, timeout=600, retry=Retry(max=3, interval=[10, 30, 60]))
def analysis_and_preprocess_job(document_id: UUID, s3_url: str, reference: str, workspace_name: str) -> dict[str, Any]:
    """
    Analyze a stored PDF, preprocess it, and persist the combined results.

    The job runs three steps in order:
    1. Download the PDF and analyze it with PDFOCRLayerDetector (text-layer detection and
       character-quality scoring) and PDFAnalyzer (layout analysis).
    2. Preprocess the PDF with PDFPreprocessor and upload the processed bytes back to the
       same object path, overwriting the original object.
    3. Merge both results into ``Document.metadata_`` using the DocumentProcessingMetadata schema.

    Progress, the ``needs_ocr`` flag and any error message are mirrored into the RQ job's
    ``meta`` for dependent jobs.

    Args:
        document_id: UUID of the document to process
        s3_url: S3 URL of the PDF file
        reference: Reference key for tracking
        workspace_name: Name of the workspace where the document is stored

    Returns:
        Dictionary with the keys ``document_id``, ``analysis_result``,
        ``preprocessing_result`` and ``needs_ocr``.

    Raises:
        RuntimeError: If no storage client is available.
        Exception: Any other failure is logged, recorded in ``job.meta["error"]`` and
            re-raised so the retry policy configured on the decorator applies.
    """
    # Below this score an existing text layer is treated as unreliable and OCR is requested.
    ocr_quality_threshold = 0.7

    # get_current_job() returns None when this function is called outside an RQ worker
    # (for example directly from a test), so every meta write goes through this guard.
    current_job = get_current_job()

    def _save_meta(**fields: Any) -> None:
        if current_job is None:
            return
        current_job.meta.update(fields)
        current_job.save_meta()

    _save_meta(
        document_id=str(document_id),
        reference=reference,
        workspace_name=str(workspace_name),
        workflow_step="analysis_and_preprocess",
        started_at=datetime.now(timezone.utc).isoformat(),
    )

    try:
        # Download original PDF from storage
        client = files.get_minio_client()
        if client is None:
            raise RuntimeError("Failed to get storage client")

        pdf_data = files.download_file_content(client, s3_url)
        filename = s3_url.split("/")[-1]

        # Step 1: Analyze original PDF structure and content
        ocr_detector = PDFOCRLayerDetector()
        has_ocr_text_layer = ocr_detector.has_ocr_text_layer(pdf_data)
        ocr_quality = ocr_detector.analyze_character_quality(pdf_data)
        ocr_quality_score = ocr_quality.get("ocr_quality_score", 0.0)

        layout_analysis = PDFAnalyzer().analyze_pdf_layout(pdf_data, filename)

        needs_ocr = not has_ocr_text_layer or ocr_quality_score < ocr_quality_threshold
        analysis_result = {
            "document_id": str(document_id),
            "has_ocr_text_layer": has_ocr_text_layer,
            "ocr_quality_score": ocr_quality_score,
            "layout_analysis": layout_analysis,
            "needs_ocr": needs_ocr,
            "analysis_metadata": {
                "total_chars": ocr_quality.get("total_chars", 0),
                "ocr_artifacts": ocr_quality.get("ocr_artifacts", 0),
                "suspicious_patterns": ocr_quality.get("suspicious_patterns", 0),
                "ocr_quality_score": ocr_quality_score,
            },
        }

        # Step 2: Preprocess PDF; analysis is disabled here because it was done in step 1
        preprocessing_settings = PDFPreprocessingSettings(enable_analysis=False)
        preprocessor = PDFPreprocessor(preprocessing_settings)
        processing_response = preprocessor.preprocess(pdf_data, filename)

        # The processed file replaces the original, so it is uploaded to the same object path
        object_path = s3_url.replace(f"/api/v1/file/{workspace_name}/", "")

        files.put_object(
            client,
            workspace_name,
            object_path,
            processing_response.processed_data,
            len(processing_response.processed_data),
            content_type="application/pdf",
            metadata={"processing_applied": "ocrmypdf_rotation", "original_filename": filename},
        )

        # Combine results
        preprocessing_result = {
            "processing_time": processing_response.metadata.processing_time,
            "ocr_applied": getattr(processing_response.metadata, "ocr_applied", False),
            "preprocessing_metadata": processing_response.metadata.model_dump(),
        }
        combined_result = {
            "document_id": str(document_id),
            "analysis_result": analysis_result,
            "preprocessing_result": preprocessing_result,
            "needs_ocr": needs_ocr,
        }

        # Step 3: Store combined results in document.metadata_ using sync database operations
        with SyncSessionLocal() as db:
            document = db.get(Document, document_id)
            if document is None:
                # Previously a silent no-op: the results were computed but never persisted.
                _LOGGER.warning(
                    f"Document {document_id} not found; analysis and preprocessing results were not persisted"
                )
            else:
                if document.metadata_ is None:
                    metadata = DocumentProcessingMetadata(workflow_started_at=datetime.now(timezone.utc))
                else:
                    metadata = DocumentProcessingMetadata(**document.metadata_)

                metadata.update_analysis_results(analysis_result)
                metadata.update_preprocessing_results(preprocessing_result)
                # mode="json" turns datetimes into ISO strings so the dict is safe to store in a
                # JSON column regardless of how the engine's JSON serializer is configured.
                document.metadata_ = metadata.model_dump(mode="json")
                db.commit()

        # Store results for dependent jobs
        _save_meta(
            needs_ocr=needs_ocr,
            analysis_complete=True,
            preprocessing_complete=True,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

        return combined_result

    except Exception as e:
        # _LOGGER.exception keeps the traceback, which plain .error() dropped.
        _LOGGER.exception(f"Error in analysis_and_preprocess_job for document {document_id}: {e}")
        _save_meta(error=str(e), completed_at=datetime.now(timezone.utc).isoformat())
        raise
import_dataset_from_hub_job(name: str, subset: str, split: str, datase ) -@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3)) +@job(DEFAULT_QUEUE, connection=REDIS_CONNECTION, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3)) async def export_dataset_to_hub_job( name: str, subset: str, split: str, private: bool, token: str, dataset_id: UUID ) -> None: diff --git a/extralit-server/src/extralit_server/jobs/import_jobs.py b/extralit-server/src/extralit_server/jobs/import_jobs.py index 85c619f45..93f900dce 100644 --- a/extralit-server/src/extralit_server/jobs/import_jobs.py +++ b/extralit-server/src/extralit_server/jobs/import_jobs.py @@ -43,7 +43,7 @@ from extralit_server.api.schemas.v1.suggestions import SuggestionCreate from extralit_server.contexts.records_bulk import UpsertRecordsBulk from extralit_server.database import AsyncSessionLocal -from extralit_server.jobs.queues import DEFAULT_QUEUE, JOB_TIMEOUT_DISABLED +from extralit_server.jobs.queues import DEFAULT_QUEUE, JOB_TIMEOUT_DISABLED, REDIS_CONNECTION from extralit_server.models import Dataset, ImportHistory from extralit_server.search_engine.base import SearchEngine from extralit_server.settings import settings @@ -165,7 +165,7 @@ def _row_suggestions(self, row: dict[str, Any], dataset: Dataset) -> list[Sugges return suggestions -@job(DEFAULT_QUEUE, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3)) +@job(DEFAULT_QUEUE, connection=REDIS_CONNECTION, timeout=JOB_TIMEOUT_DISABLED, retry=Retry(max=3)) async def import_dataset_from_import_history_job(history_id: UUID, dataset_id: UUID, mapping: dict) -> None: """ Import dataset records from ImportHistory data. 
diff --git a/extralit-server/src/extralit_server/jobs/queues.py b/extralit-server/src/extralit_server/jobs/queues.py index a3e0fae2e..8560888d7 100644 --- a/extralit-server/src/extralit_server/jobs/queues.py +++ b/extralit-server/src/extralit_server/jobs/queues.py @@ -25,5 +25,7 @@ DEFAULT_QUEUE = Queue("default", connection=REDIS_CONNECTION) HIGH_QUEUE = Queue("high", connection=REDIS_CONNECTION) +OCR_QUEUE = Queue("ocr", connection=REDIS_CONNECTION) +GPU_QUEUE = Queue("gpu", connection=REDIS_CONNECTION) JOB_TIMEOUT_DISABLED = -1 diff --git a/extralit-server/src/extralit_server/jobs/webhook_jobs.py b/extralit-server/src/extralit_server/jobs/webhook_jobs.py index c1c5a7843..9dce8b9d4 100644 --- a/extralit-server/src/extralit_server/jobs/webhook_jobs.py +++ b/extralit-server/src/extralit_server/jobs/webhook_jobs.py @@ -23,7 +23,7 @@ from extralit_server.contexts import webhooks from extralit_server.database import AsyncSessionLocal -from extralit_server.jobs.queues import HIGH_QUEUE +from extralit_server.jobs.queues import HIGH_QUEUE, REDIS_CONNECTION from extralit_server.models import Webhook from extralit_server.webhooks.v1.commons import notify_event @@ -43,7 +43,7 @@ async def enqueue_notify_events(db: AsyncSession, event: str, timestamp: datetim return enqueued_jobs -@job(HIGH_QUEUE, retry=Retry(max=3, interval=[10, 60, 180])) +@job(HIGH_QUEUE, connection=REDIS_CONNECTION, retry=Retry(max=3, interval=[10, 60, 180])) async def notify_event_job(webhook_id: UUID, event: str, timestamp: datetime, data: dict) -> None: async with AsyncSessionLocal() as db: webhook = await Webhook.get_or_raise(db, webhook_id) diff --git a/extralit-server/src/extralit_server/logging.py b/extralit-server/src/extralit_server/logging.py index e40a723e0..ea2199bb2 100644 --- a/extralit-server/src/extralit_server/logging.py +++ b/extralit-server/src/extralit_server/logging.py @@ -67,3 +67,7 @@ def configure_logging(): # See the note here: 
https://docs.python.org/3/library/logging.html#logging.Logger.propagate # We only attach our handler to the root logger and let propagation take care of the rest logging.basicConfig(handlers=[handler], level=logging.WARNING) + + # Suppress pdfminer warnings about invalid color values + logging.getLogger("pdfminer.pdfinterp").setLevel(logging.ERROR) + logging.getLogger("pdfminer").setLevel(logging.ERROR) diff --git a/extralit-server/src/extralit_server/models/database.py b/extralit-server/src/extralit_server/models/database.py index 1e5c5efc9..2a0041f9d 100644 --- a/extralit-server/src/extralit_server/models/database.py +++ b/extralit-server/src/extralit_server/models/database.py @@ -15,7 +15,7 @@ import base64 import secrets from datetime import datetime -from typing import Any, Union +from typing import Any, Optional, Union from uuid import UUID from pydantic import TypeAdapter @@ -26,10 +26,12 @@ String, Text, UniqueConstraint, + select, sql, ) from sqlalchemy import Enum as SAEnum from sqlalchemy.engine.default import DefaultExecutionContext +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.mutable import MutableDict, MutableList from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -54,6 +56,7 @@ "Dataset", "DatasetUser", "Document", + "DocumentWorkflow", "Field", "ImportHistory", "MetadataProperty", @@ -647,6 +650,13 @@ class Document(DatabaseModel): metadata_: Mapped[dict | None] = mapped_column("metadata", MutableDict.as_mutable(JSON()), nullable=True) workspace: Mapped["Workspace"] = relationship("Workspace", back_populates="documents") + workflows: Mapped[list["DocumentWorkflow"]] = relationship( + "DocumentWorkflow", + back_populates="document", + cascade="all, delete-orphan", + passive_deletes=True, + order_by="DocumentWorkflow.inserted_at.desc()", + ) def __repr__(self): return ( @@ -678,7 +688,7 @@ def __repr__(self): class ImportHistory(DatabaseModel): - __tablename__ = "import_history" + __tablename__ = "imports" 
# --- extralit-server/src/extralit_server/models/database.py (appended model) ---


class DocumentWorkflow(DatabaseModel):
    """
    Track one document-processing workflow run and the RQ Group that holds its jobs.

    A document can have several workflow rows (``Document.workflows`` is a list ordered by
    ``inserted_at`` descending); ``status`` is a cached copy of the state derived from the
    RQ Group.
    """

    __tablename__ = "workflows"

    # NOTE(review): Document.workflows is declared with passive_deletes=True, which relies on the
    # database cascading the delete, but these foreign keys declare no ondelete (the other models
    # in this file use ondelete="CASCADE"). Confirm the migration creates ON DELETE CASCADE.
    workflow_type: Mapped[str] = mapped_column(String(50))
    workspace_id: Mapped[UUID] = mapped_column(ForeignKey("workspaces.id"), nullable=False)
    document_id: Mapped[UUID] = mapped_column(ForeignKey("documents.id"), nullable=False, index=True)
    # The column is nullable, so the annotation has to admit None.
    reference: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)

    # RQ Group integration
    group_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    status: Mapped[str] = mapped_column(String(50), default="pending", index=True)  # Cached workflow status

    # Relationships
    document: Mapped["Document"] = relationship("Document", back_populates="workflows")
    workspace: Mapped["Workspace"] = relationship("Workspace")

    @classmethod
    async def get_by_document_id(cls, db: AsyncSession, document_id: UUID) -> Optional["DocumentWorkflow"]:
        """
        Return the most recently inserted workflow for a document, or None if it has none.

        A document may have more than one workflow row, so the query is ordered and limited
        to one row; an unbounded ``scalar_one_or_none()`` raises MultipleResultsFound as soon
        as a second workflow exists.
        """
        result = await db.execute(
            select(cls).where(cls.document_id == document_id).order_by(cls.inserted_at.desc()).limit(1)
        )
        return result.scalar_one_or_none()

    @classmethod
    async def get_by_group_id(cls, db: AsyncSession, group_id: str) -> Optional["DocumentWorkflow"]:
        """Return the workflow that owns the given RQ Group ID, or None if there is none."""
        result = await db.execute(select(cls).where(cls.group_id == group_id))
        return result.scalar_one_or_none()

    @classmethod
    async def get_by_reference(
        cls, db: AsyncSession, reference: str, workspace_id: UUID | str | None = None
    ) -> list["DocumentWorkflow"]:
        """
        Return all workflows that share a reference (batch tracking).

        Args:
            db: Database session
            reference: Reference key to match
            workspace_id: Optional workspace ID; when given, only that workspace's workflows are returned

        Returns:
            List of matching workflows (empty when nothing matches).
        """
        query = select(cls).where(cls.reference == reference)
        if workspace_id is not None:
            query = query.where(cls.workspace_id == workspace_id)
        result = await db.execute(query)
        # scalars().all() returns a Sequence; materialize it to honour the declared list return type.
        return list(result.scalars().all())


# --- extralit-server/src/extralit_server/workflows/documents.py (new file) ---

# Copyright 2024-present, Extralit Labs, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
from uuid import UUID, uuid4

from rq.group import Group

from extralit_server.database import AsyncSessionLocal
from extralit_server.jobs.document_jobs import analysis_and_preprocess_job
from extralit_server.jobs.queues import DEFAULT_QUEUE, OCR_QUEUE, REDIS_CONNECTION
from extralit_server.models.database import DocumentWorkflow

_LOGGER = logging.getLogger(__name__)


async def _mark_workflow_failed(workflow_id: UUID) -> None:
    """Best-effort: set a workflow's cached status to "failed" after its jobs could not be enqueued."""
    try:
        async with AsyncSessionLocal() as db:
            workflow = await db.get(DocumentWorkflow, workflow_id)
            if workflow is not None:
                workflow.status = "failed"
                await db.commit()
    except Exception:
        # The enqueue error is the one the caller needs to see; only log this secondary failure.
        _LOGGER.exception(f"Failed to mark workflow {workflow_id} as failed")


async def create_document_workflow(
    document_id: UUID, s3_url: str, reference: str, workspace_name: str, workspace_id: UUID
) -> Group:
    """
    Start the PDF processing workflow for a document, tracked through an RQ Group.

    Creates a DocumentWorkflow record with status "running", then enqueues the
    analysis/preprocessing job on the default queue and the text extraction job on the
    OCR queue, both as members of the same group. Text extraction depends on the analysis
    job because that job overwrites the PDF at ``s3_url`` in place.

    Args:
        document_id: UUID of the document to process
        s3_url: S3 URL of the PDF file
        reference: Reference key for tracking
        workspace_name: Workspace name for job context
        workspace_id: UUID of the workspace

    Returns:
        The RQ Group that contains the workflow's jobs. Its name is the ``group_id``
        stored on the DocumentWorkflow record.

    Raises:
        Exception: Whatever enqueueing raises; the workflow record is marked "failed" first.
    """
    group_id = f"document_workflow_{document_id}_{uuid4().hex[:8]}"
    group = Group(REDIS_CONNECTION, name=group_id)

    # Step 1: Create the DocumentWorkflow record for tracking
    async with AsyncSessionLocal() as db:
        workflow = DocumentWorkflow(
            id=uuid4(),
            document_id=document_id,
            workflow_type="pdf_processing",
            workspace_id=workspace_id,
            reference=reference,
            group_id=group_id,
            status="running",
        )
        db.add(workflow)
        await db.commit()
        await db.refresh(workflow)
        # Read the id while the session is open so nothing touches a detached instance later.
        workflow_id = workflow.id

    # Step 2: Prepare jobs using Queue.prepare_data()
    # NOTE(review): job ids are derived from the document id only, so a second workflow for the
    # same document reuses them. Confirm that is intended before relying on re-runs.
    analysis_job_id = f"analysis_preprocess_{document_id}"

    analysis_job_data = DEFAULT_QUEUE.prepare_data(
        analysis_and_preprocess_job,
        (document_id, s3_url, reference, workspace_name),
        timeout=600,
        job_id=analysis_job_id,
        meta={
            "document_id": str(document_id),
            "reference": reference,
            "workflow_step": "analysis_and_preprocess",
            "workflow_id": str(workflow_id),
        },
    )

    text_extraction_job_data = OCR_QUEUE.prepare_data(
        "extralit_ocr.jobs.pymupdf_to_markdown_job",
        (document_id, s3_url, s3_url.split("/")[-1], {}, workspace_name),
        timeout=900,
        # Wait for the analysis job: it rewrites the same S3 object this job reads.
        depends_on=[analysis_job_id],
        job_id=f"text_extraction_{document_id}",
        meta={
            "document_id": str(document_id),
            "reference": reference,
            "workflow_step": "text_extraction",
            "workflow_id": str(workflow_id),
        },
    )

    # Step 3: Enqueue through the group; the analysis job goes first so the dependency exists
    try:
        group.enqueue_many(queue=DEFAULT_QUEUE, job_datas=[analysis_job_data])
        group.enqueue_many(queue=OCR_QUEUE, job_datas=[text_extraction_job_data])
    except Exception:
        # Without this the record would stay "running" forever with no jobs behind it.
        await _mark_workflow_failed(workflow_id)
        raise

    # TODO: add a table extraction job on OCR_QUEUE (depending on the analysis job) once
    # table extraction is implemented.

    _LOGGER.info(f"Started PDF workflow {workflow_id} for document {document_id} with group {group_id}")

    return group
+ +"""Integration tests for complete workflow using RQ Groups.""" + +import asyncio +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from httpx import AsyncClient +from rq.group import Group +from rq.job import Job + +from extralit_server.contexts.workflows import ( + get_jobs_for_document, + get_workflow_status, + restart_failed_workflow, +) +from extralit_server.models.database import Document, DocumentWorkflow, Workspace +from extralit_server.workflows.documents import create_document_workflow + + +@pytest.mark.asyncio +class TestRQGroupsWorkflowIntegration: + """Integration tests for RQ Groups workflow functionality.""" + + @pytest.fixture + async def test_workspace(self, async_db): + """Create test workspace.""" + workspace = Workspace( + id=uuid4(), + name="test_workspace", + title="Test Workspace", + description="Test workspace for RQ Groups integration tests", + ) + async_db.add(workspace) + await async_db.commit() + await async_db.refresh(workspace) + return workspace + + @pytest.fixture + async def test_document(self, async_db, test_workspace): + """Create test document.""" + document = Document( + id=uuid4(), + reference="test_ref_123", + file_name="test.pdf", + workspace_id=test_workspace.id, + url="s3://test-bucket/test.pdf", + metadata_={}, + ) + async_db.add(document) + await async_db.commit() + await async_db.refresh(document) + return document + + @pytest.fixture + def mock_redis_connection(self): + """Mock Redis connection for RQ Groups.""" + with ( + patch("extralit_server.workflows.documents.REDIS_CONNECTION") as mock_conn, + patch("extralit_server.contexts.workflows.REDIS_CONNECTION", mock_conn), + ): + yield mock_conn + + @pytest.fixture + def mock_rq_queues(self): + """Mock RQ queues.""" + with ( + patch("extralit_server.workflows.documents.DEFAULT_QUEUE") as mock_default, + patch("extralit_server.workflows.documents.OCR_QUEUE") as mock_ocr, + ): + # Setup queue prepare_data methods + 
mock_default.prepare_data.return_value = { + "func": "analysis_and_preprocess_job", + "args": [], + "kwargs": {}, + "timeout": 600, + "job_id": "test_job_1", + } + + mock_ocr.prepare_data.return_value = { + "func": "text_extraction_job", + "args": [], + "kwargs": {}, + "timeout": 900, + "job_id": "test_job_2", + } + + yield mock_default, mock_ocr + + async def test_create_document_workflow_with_rq_groups( + self, async_db, test_document, test_workspace, mock_redis_connection, mock_rq_queues + ): + """Test creating document workflow with RQ Groups integration.""" + mock_default_queue, mock_ocr_queue = mock_rq_queues + + # Mock RQ Group + mock_group = MagicMock(spec=Group) + mock_group.name = f"document_workflow_{test_document.id}_12345678" + + with patch("extralit_server.workflows.documents.Group", return_value=mock_group): + group = await create_document_workflow( + document_id=test_document.id, + s3_url=test_document.url, + reference=test_document.reference, + workspace_name=test_workspace.name, + workspace_id=test_workspace.id, + ) + + # Verify group was created + assert group == mock_group + + # Verify DocumentWorkflow record was created + workflow = await DocumentWorkflow.get_by_document_id(async_db, test_document.id) + assert workflow is not None + assert workflow.document_id == test_document.id + assert workflow.workflow_type == "pdf_processing" + assert workflow.status == "running" + assert workflow.group_id.startswith(f"document_workflow_{test_document.id}") + + # Verify jobs were prepared and enqueued + mock_default_queue.prepare_data.assert_called_once() + mock_ocr_queue.prepare_data.assert_called_once() + mock_group.enqueue_many.assert_called() + + async def test_workflow_status_tracking_with_rq_groups(self, async_db, test_document, mock_redis_connection): + """Test workflow status tracking using RQ Groups.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + 
workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="test_group_123", + status="running", + ) + async_db.add(workflow) + await async_db.commit() + + # Mock RQ Group with jobs + mock_job1 = MagicMock(spec=Job) + mock_job1.id = "job1" + mock_job1.is_finished = True + mock_job1.is_failed = False + mock_job1.get_status.return_value = "finished" + mock_job1.meta = {"workflow_step": "analysis_and_preprocess"} + mock_job1.created_at = None + mock_job1.started_at = None + mock_job1.ended_at = None + mock_job1.result = None + mock_job1.exc_info = None + + mock_job2 = MagicMock(spec=Job) + mock_job2.id = "job2" + mock_job2.is_finished = False + mock_job2.is_failed = False + mock_job2.is_started = True + mock_job2.get_status.return_value = "started" + mock_job2.meta = {"workflow_step": "text_extraction"} + mock_job2.created_at = None + mock_job2.started_at = None + mock_job2.ended_at = None + mock_job2.result = None + mock_job2.exc_info = None + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [mock_job1, mock_job2] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = await get_workflow_status(async_db, test_document.id) + + assert status["status"] == "running" + assert status["progress"] == 0.5 # 1 of 2 jobs completed + assert status["total_jobs"] == 2 + assert status["completed_jobs"] == 1 + assert status["failed_jobs"] == 0 + assert status["running_jobs"] == 1 + assert status["document_id"] == test_document.id + assert status["group_id"] == "test_group_123" + + async def test_workflow_restart_with_rq_groups(self, async_db, test_document, mock_redis_connection): + """Test workflow restart functionality using RQ Groups.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="test_group_123", 
+ status="failed", + ) + async_db.add(workflow) + await async_db.commit() + + # Mock failed job + mock_failed_job = MagicMock(spec=Job) + mock_failed_job.id = "failed_job" + mock_failed_job.is_failed = True + mock_failed_job.requeue = MagicMock() + + # Mock completed job + mock_completed_job = MagicMock(spec=Job) + mock_completed_job.id = "completed_job" + mock_completed_job.is_failed = False + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [mock_failed_job, mock_completed_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + result = await restart_failed_workflow(async_db, test_document.id, partial_restart=True) + + assert result["success"] is True + assert result["restarted_jobs"] == ["failed_job"] + assert result["total_failed"] == 1 + + # Verify failed job was requeued + mock_failed_job.requeue.assert_called_once() + + # Verify workflow status was updated + await async_db.refresh(workflow) + assert workflow.status == "running" + + async def test_job_querying_with_rq_groups(self, async_db, test_document, mock_redis_connection): + """Test job querying functionality using RQ Groups.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="test_group_123", + status="running", + ) + async_db.add(workflow) + await async_db.commit() + + # Mock RQ jobs with metadata + mock_job1 = MagicMock(spec=Job) + mock_job1.id = "analysis_job" + mock_job1.get_status.return_value = "finished" + mock_job1.meta = { + "document_id": str(test_document.id), + "reference": test_document.reference, + "workflow_step": "analysis_and_preprocess", + "workflow_id": str(workflow.id), + } + mock_job1.created_at = None + mock_job1.started_at = None + mock_job1.ended_at = None + mock_job1.result = {"analysis_complete": True} + mock_job1.exc_info = None + 
mock_job1.is_finished = True + + mock_job2 = MagicMock(spec=Job) + mock_job2.id = "text_extraction_job" + mock_job2.get_status.return_value = "started" + mock_job2.meta = { + "document_id": str(test_document.id), + "reference": test_document.reference, + "workflow_step": "text_extraction", + "workflow_id": str(workflow.id), + } + mock_job2.created_at = None + mock_job2.started_at = None + mock_job2.ended_at = None + mock_job2.result = None + mock_job2.exc_info = None + mock_job2.is_finished = False + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [mock_job1, mock_job2] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + jobs = await get_jobs_for_document(async_db, test_document.id) + + assert len(jobs) == 2 + + # Verify job details + analysis_job = next(job for job in jobs if job["id"] == "analysis_job") + assert analysis_job["status"] == "finished" + assert analysis_job["workflow_step"] == "analysis_and_preprocess" + assert analysis_job["document_id"] == test_document.id + assert analysis_job["group_id"] == "test_group_123" + assert analysis_job["result"] == {"analysis_complete": True} + + text_job = next(job for job in jobs if job["id"] == "text_extraction_job") + assert text_job["status"] == "started" + assert text_job["workflow_step"] == "text_extraction" + assert text_job["result"] is None + + async def test_workflow_group_expiration_handling(self, async_db, test_document, mock_redis_connection): + """Test handling of expired RQ Groups.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="expired_group_123", + status="running", + ) + async_db.add(workflow) + await async_db.commit() + + # Mock expired group + with patch("extralit_server.contexts.workflows.Group.fetch", side_effect=Exception("Group expired")): + jobs = 
await get_jobs_for_document(async_db, test_document.id) + + assert len(jobs) == 1 + assert jobs[0]["id"] == "group_expired" + assert jobs[0]["status"] == "expired" + assert "Group not found or expired" in jobs[0]["error"] + + async def test_workflow_api_integration_with_rq_groups( + self, async_client: AsyncClient, owner_auth_header: dict, async_db, test_document + ): + """Test workflow API endpoints with RQ Groups integration.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="api_test_group_123", + status="running", + ) + async_db.add(workflow) + await async_db.commit() + + # Mock RQ Group for API calls + mock_job = MagicMock(spec=Job) + mock_job.id = "api_test_job" + mock_job.get_status.return_value = "started" + mock_job.meta = { + "document_id": str(test_document.id), + "reference": test_document.reference, + "workflow_step": "analysis_and_preprocess", + } + mock_job.created_at = None + mock_job.started_at = None + mock_job.ended_at = None + mock_job.result = None + mock_job.exc_info = None + mock_job.is_finished = False + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [mock_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + # Test jobs API with document_id filter + response = await async_client.get( + f"/api/v1/jobs/?document_id={test_document.id}", headers=owner_auth_header + ) + + assert response.status_code == 200 + jobs_data = response.json() + assert len(jobs_data) == 1 + assert jobs_data[0]["id"] == "api_test_job" + assert jobs_data[0]["status"] == "started" + assert jobs_data[0]["workflow_step"] == "analysis_and_preprocess" + + async def test_concurrent_workflow_processing( + self, async_db, test_workspace, mock_redis_connection, mock_rq_queues + ): + """Test multiple concurrent workflows using RQ 
Groups.""" + mock_default_queue, mock_ocr_queue = mock_rq_queues + + # Create multiple test documents + documents = [] + for i in range(3): + doc = Document( + id=uuid4(), + reference=f"concurrent_test_{i}", + file_name=f"test_{i}.pdf", + workspace_id=test_workspace.id, + url=f"s3://test-bucket/test_{i}.pdf", + metadata_={}, + ) + async_db.add(doc) + documents.append(doc) + + await async_db.commit() + + # Mock RQ Groups for each workflow + mock_groups = [] + for i, doc in enumerate(documents): + mock_group = MagicMock(spec=Group) + mock_group.name = f"document_workflow_{doc.id}_{i:08d}" + mock_groups.append(mock_group) + + with patch("extralit_server.workflows.documents.Group", side_effect=mock_groups): + # Create workflows concurrently + tasks = [] + for doc in documents: + task = create_document_workflow( + document_id=doc.id, + s3_url=doc.url, + reference=doc.reference, + workspace_name=test_workspace.name, + workspace_id=test_workspace.id, + ) + tasks.append(task) + + # Execute all workflows concurrently + results = await asyncio.gather(*tasks) + + # Verify all workflows were created + assert len(results) == 3 + for i, group in enumerate(results): + assert group == mock_groups[i] + + # Verify all DocumentWorkflow records were created + for doc in documents: + workflow = await DocumentWorkflow.get_by_document_id(async_db, doc.id) + assert workflow is not None + assert workflow.document_id == doc.id + assert workflow.status == "running" + + async def test_workflow_failure_and_restart_scenarios(self, async_db, test_document, mock_redis_connection): + """Test various workflow failure and restart scenarios.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="failure_test_group", + status="failed", + ) + async_db.add(workflow) + await async_db.commit() + + # Test scenario 1: Partial 
failure with some jobs completed + mock_completed_job = MagicMock(spec=Job) + mock_completed_job.id = "completed_job" + mock_completed_job.is_failed = False + mock_completed_job.is_finished = True + + mock_failed_job = MagicMock(spec=Job) + mock_failed_job.id = "failed_job" + mock_failed_job.is_failed = True + mock_failed_job.requeue = MagicMock() + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [mock_completed_job, mock_failed_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + # Test partial restart (failed jobs only) + result = await restart_failed_workflow(async_db, test_document.id, partial_restart=True) + + assert result["success"] is True + assert result["restarted_jobs"] == ["failed_job"] + assert result["total_failed"] == 1 + mock_failed_job.requeue.assert_called_once() + + # Reset mock + mock_failed_job.requeue.reset_mock() + + # Test full restart (all jobs) + result = await restart_failed_workflow(async_db, test_document.id, partial_restart=False) + + assert result["success"] is True + assert len(result["restarted_jobs"]) == 2 # Both jobs restarted + assert result["restart_type"] == "full" + + async def test_workflow_progress_calculation(self, async_db, test_document, mock_redis_connection): + """Test workflow progress calculation with various job states.""" + # Create workflow record + workflow = DocumentWorkflow( + id=uuid4(), + document_id=test_document.id, + workflow_type="pdf_processing", + workspace_id=test_document.workspace_id, + reference=test_document.reference, + group_id="progress_test_group", + status="running", + ) + async_db.add(workflow) + await async_db.commit() + + # Test different progress scenarios + test_scenarios = [ + # (completed, failed, running, expected_status, expected_progress) + ([1, 1], [], [], "completed", 1.0), # All completed + ([1], [1], [], "failed", 1.0), # Mixed with failures + ([1], [], [1], "running", 0.5), # Half completed, half running + 
([], [], [2], "running", 0.0), # All running + ([], [], [], "pending", 0.0), # No jobs started + ] + + for completed_count, failed_count, running_count, expected_status, expected_progress in test_scenarios: + mock_jobs = [] + + # Add completed jobs + for i in range(len(completed_count)): + job = MagicMock(spec=Job) + job.id = f"completed_{i}" + job.is_finished = True + job.is_failed = False + job.is_started = True + job.get_status.return_value = "finished" + job.meta = {} + job.created_at = None + job.started_at = None + job.ended_at = None + job.result = None + job.exc_info = None + mock_jobs.append(job) + + # Add failed jobs + for i in range(len(failed_count)): + job = MagicMock(spec=Job) + job.id = f"failed_{i}" + job.is_finished = True + job.is_failed = True + job.is_started = True + job.get_status.return_value = "failed" + job.meta = {} + job.created_at = None + job.started_at = None + job.ended_at = None + job.result = None + job.exc_info = "Test error" + mock_jobs.append(job) + + # Add running jobs + for i in range(len(running_count)): + job = MagicMock(spec=Job) + job.id = f"running_{i}" + job.is_finished = False + job.is_failed = False + job.is_started = True + job.get_status.return_value = "started" + job.meta = {} + job.created_at = None + job.started_at = None + job.ended_at = None + job.result = None + job.exc_info = None + mock_jobs.append(job) + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = mock_jobs + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = await get_workflow_status(async_db, test_document.id) + + assert status["status"] == expected_status, f"Expected {expected_status}, got {status['status']}" + assert status["progress"] == expected_progress, ( + f"Expected {expected_progress}, got {status['progress']}" + ) diff --git a/extralit-server/tests/unit/jobs/test_jobs.py b/extralit-server/tests/unit/jobs/test_jobs.py new file mode 100644 index 000000000..e6375fd40 --- 
/dev/null +++ b/extralit-server/tests/unit/jobs/test_jobs.py @@ -0,0 +1,59 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from uuid import uuid4 + +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +class TestJobsAPI: + """Test jobs API endpoints with RQ Groups integration.""" + + async def test_get_jobs_requires_filter(self, async_client: AsyncClient, owner_auth_header: dict): + """Test that GET /jobs/ requires at least one filter parameter.""" + response = await async_client.get("/api/v1/jobs/", headers=owner_auth_header) + + assert response.status_code == 400 + assert "Must provide at least one filter" in response.json()["detail"] + + async def test_get_jobs_by_document_id_not_found(self, async_client: AsyncClient, owner_auth_header: dict): + """Test GET /jobs/ with non-existent document_id returns empty list.""" + non_existent_id = uuid4() + response = await async_client.get(f"/api/v1/jobs/?document_id={non_existent_id}", headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == [] + + async def test_get_jobs_by_reference_not_found(self, async_client: AsyncClient, owner_auth_header: dict): + """Test GET /jobs/ with non-existent reference returns empty list.""" + response = await async_client.get("/api/v1/jobs/?reference=non_existent_reference", headers=owner_auth_header) + + assert response.status_code == 200 + assert response.json() == [] + + async def 
test_get_jobs_by_group_id_not_found(self, async_client: AsyncClient, owner_auth_header: dict): + """Test GET /jobs/ with non-existent group_id returns 404.""" + response = await async_client.get("/api/v1/jobs/?group_id=non_existent_group", headers=owner_auth_header) + + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + async def test_jobs_api_schema_validation(self, async_client: AsyncClient, owner_auth_header: dict): + """Test that the API validates query parameters correctly.""" + # Test with invalid UUID format for document_id + response = await async_client.get("/api/v1/jobs/?document_id=invalid_uuid", headers=owner_auth_header) + + assert response.status_code == 422 # Validation error diff --git a/extralit-server/tests/unit/jobs/test_rq_groups_failure_scenarios.py b/extralit-server/tests/unit/jobs/test_rq_groups_failure_scenarios.py new file mode 100644 index 000000000..23519fde2 --- /dev/null +++ b/extralit-server/tests/unit/jobs/test_rq_groups_failure_scenarios.py @@ -0,0 +1,533 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for RQ Groups failure and restart scenarios.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from rq.group import Group +from rq.job import Job + +from extralit_server.contexts.workflows import ( + get_failed_jobs_in_group, + get_workflow_status_from_group, + is_workflow_resumable, + restart_failed_jobs_in_workflow, + restart_failed_workflow, +) +from extralit_server.models.database import DocumentWorkflow + + +class TestRQGroupsFailureScenarios: + """Test RQ Groups failure and restart scenarios.""" + + @pytest.fixture + def mock_document_workflow(self): + """Create mock DocumentWorkflow.""" + workflow = MagicMock(spec=DocumentWorkflow) + workflow.id = uuid4() + workflow.document_id = uuid4() + workflow.group_id = "test_group_123" + workflow.status = "failed" + workflow.reference = "test_ref" + workflow.workspace_id = uuid4() + workflow.inserted_at = datetime.now(timezone.utc) + workflow.updated_at = datetime.now(timezone.utc) + return workflow + + def test_single_job_failure_scenario(self): + """Test workflow status when a single job fails.""" + # Create mock jobs - one completed, one failed + completed_job = MagicMock(spec=Job) + completed_job.id = "completed_job" + completed_job.is_finished = True + completed_job.is_failed = False + completed_job.is_started = True + completed_job.get_status.return_value = "finished" + completed_job.meta = {"workflow_step": "analysis_and_preprocess"} + completed_job.created_at = datetime.now(timezone.utc) + completed_job.started_at = datetime.now(timezone.utc) + completed_job.ended_at = datetime.now(timezone.utc) + completed_job.result = {"analysis_complete": True} + completed_job.exc_info = None + + failed_job = MagicMock(spec=Job) + failed_job.id = "failed_job" + failed_job.is_finished = True + failed_job.is_failed = True + failed_job.is_started = True + failed_job.get_status.return_value = "failed" + failed_job.meta 
= {"workflow_step": "text_extraction"} + failed_job.created_at = datetime.now(timezone.utc) + failed_job.started_at = datetime.now(timezone.utc) + failed_job.ended_at = datetime.now(timezone.utc) + failed_job.result = None + failed_job.exc_info = "OCR processing failed: timeout" + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [completed_job, failed_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = get_workflow_status_from_group("test_group_123") + + assert status["status"] == "failed" + assert status["progress"] == 1.0 # Both jobs finished + assert status["total_jobs"] == 2 + assert status["completed_jobs"] == 2 + assert status["failed_jobs"] == 1 + assert status["running_jobs"] == 0 + assert len(status["jobs"]) == 2 + + def test_multiple_job_failures_scenario(self): + """Test workflow status when multiple jobs fail.""" + # Create mock jobs - one completed, two failed + completed_job = MagicMock(spec=Job) + completed_job.is_finished = True + completed_job.is_failed = False + completed_job.get_status.return_value = "finished" + + failed_job1 = MagicMock(spec=Job) + failed_job1.is_finished = True + failed_job1.is_failed = True + failed_job1.get_status.return_value = "failed" + failed_job1.exc_info = "Analysis failed: invalid PDF" + + failed_job2 = MagicMock(spec=Job) + failed_job2.is_finished = True + failed_job2.is_failed = True + failed_job2.get_status.return_value = "failed" + failed_job2.exc_info = "Text extraction failed: timeout" + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [completed_job, failed_job1, failed_job2] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = get_workflow_status_from_group("test_group_123") + + assert status["status"] == "failed" + assert status["progress"] == 1.0 # All jobs finished + assert status["total_jobs"] == 3 + assert status["completed_jobs"] == 3 + assert 
status["failed_jobs"] == 2 + assert status["running_jobs"] == 0 + + def test_job_timeout_failure_scenario(self): + """Test handling of job timeout failures.""" + timeout_job = MagicMock(spec=Job) + timeout_job.id = "timeout_job" + timeout_job.is_finished = True + timeout_job.is_failed = True + timeout_job.get_status.return_value = "failed" + timeout_job.meta = {"workflow_step": "text_extraction"} + timeout_job.exc_info = "rq.timeouts.JobTimeoutException: Job exceeded maximum timeout value (900s)" + timeout_job.created_at = datetime.now(timezone.utc) + timeout_job.started_at = datetime.now(timezone.utc) + timeout_job.ended_at = datetime.now(timezone.utc) + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [timeout_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + failed_jobs = get_failed_jobs_in_group("test_group_123") + + assert len(failed_jobs) == 1 + assert failed_jobs[0]["id"] == "timeout_job" + assert failed_jobs[0]["workflow_step"] == "text_extraction" + assert "JobTimeoutException" in failed_jobs[0]["failure_reason"] + + def test_job_memory_error_failure_scenario(self): + """Test handling of job memory error failures.""" + memory_error_job = MagicMock(spec=Job) + memory_error_job.id = "memory_error_job" + memory_error_job.is_finished = True + memory_error_job.is_failed = True + memory_error_job.get_status.return_value = "failed" + memory_error_job.meta = {"workflow_step": "analysis_and_preprocess"} + memory_error_job.exc_info = "MemoryError: Unable to allocate memory for PDF processing" + memory_error_job.created_at = datetime.now(timezone.utc) + memory_error_job.started_at = datetime.now(timezone.utc) + memory_error_job.ended_at = datetime.now(timezone.utc) + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [memory_error_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + failed_jobs = 
get_failed_jobs_in_group("test_group_123") + + assert len(failed_jobs) == 1 + assert "MemoryError" in failed_jobs[0]["failure_reason"] + + def test_job_dependency_failure_scenario(self): + """Test workflow when job fails due to dependency failure.""" + # First job fails + failed_dependency_job = MagicMock(spec=Job) + failed_dependency_job.id = "analysis_job" + failed_dependency_job.is_finished = True + failed_dependency_job.is_failed = True + failed_dependency_job.get_status.return_value = "failed" + failed_dependency_job.meta = {"workflow_step": "analysis_and_preprocess"} + + # Dependent job should not start (or be cancelled) + cancelled_job = MagicMock(spec=Job) + cancelled_job.id = "text_extraction_job" + cancelled_job.is_finished = False + cancelled_job.is_failed = False + cancelled_job.is_started = False + cancelled_job.get_status.return_value = "deferred" # RQ status for jobs waiting on dependencies + cancelled_job.meta = {"workflow_step": "text_extraction"} + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [failed_dependency_job, cancelled_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = get_workflow_status_from_group("test_group_123") + + assert status["status"] == "failed" # Has failed jobs + assert status["total_jobs"] == 2 + assert status["completed_jobs"] == 1 # Only the failed job "completed" + assert status["failed_jobs"] == 1 + assert status["running_jobs"] == 0 + + @pytest.mark.asyncio + async def test_partial_workflow_restart_scenario(self, mock_document_workflow): + """Test partial restart of workflow (failed jobs only).""" + db_mock = AsyncMock() + + # Mock jobs: one completed, one failed + completed_job = MagicMock(spec=Job) + completed_job.id = "completed_job" + completed_job.is_failed = False + + failed_job = MagicMock(spec=Job) + failed_job.id = "failed_job" + failed_job.is_failed = True + failed_job.requeue = MagicMock() + + mock_group = MagicMock(spec=Group) 
+ mock_group.get_jobs.return_value = [completed_job, failed_job] + + with ( + patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group), + patch("extralit_server.contexts.workflows.update_workflow_status") as mock_update, + ): + result = await restart_failed_jobs_in_workflow(db_mock, mock_document_workflow) + + assert result["success"] is True + assert result["restarted_jobs"] == ["failed_job"] + assert result["total_failed"] == 1 + + # Only failed job should be requeued + failed_job.requeue.assert_called_once() + + # Workflow status should be updated to running + mock_update.assert_called_once_with(db_mock, mock_document_workflow, "running") + + @pytest.mark.asyncio + async def test_full_workflow_restart_scenario(self, mock_document_workflow): + """Test full restart of workflow (all jobs).""" + db_mock = AsyncMock() + document_id = mock_document_workflow.document_id + + # Mock jobs: completed, failed, and running + completed_job = MagicMock(spec=Job) + completed_job.id = "completed_job" + completed_job.is_failed = False + completed_job.requeue = MagicMock() + + failed_job = MagicMock(spec=Job) + failed_job.id = "failed_job" + failed_job.is_failed = True + failed_job.requeue = MagicMock() + + running_job = MagicMock(spec=Job) + running_job.id = "running_job" + running_job.is_failed = False + running_job.requeue = MagicMock() + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [completed_job, failed_job, running_job] + + with ( + patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow), + patch("extralit_server.contexts.workflows.is_workflow_resumable", return_value=True), + patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group), + patch("extralit_server.contexts.workflows.update_workflow_status") as mock_update, + ): + result = await restart_failed_workflow(db_mock, document_id, partial_restart=False) + + assert result["success"] is True + assert 
len(result["restarted_jobs"]) == 3 # All jobs restarted + assert result["total_failed"] == 1 # Only one was actually failed + assert result["restart_type"] == "full" + + # All jobs should be requeued + completed_job.requeue.assert_called_once() + failed_job.requeue.assert_called_once() + running_job.requeue.assert_called_once() + + mock_update.assert_called_once_with(db_mock, mock_document_workflow, "running") + + @pytest.mark.asyncio + async def test_restart_with_job_requeue_failure(self, mock_document_workflow): + """Test restart scenario when job requeue fails.""" + db_mock = AsyncMock() + + # Mock failed job that fails to requeue + failed_job = MagicMock(spec=Job) + failed_job.id = "failed_job" + failed_job.is_failed = True + failed_job.requeue.side_effect = Exception("Failed to requeue job") + + # Mock successful job + successful_job = MagicMock(spec=Job) + successful_job.id = "successful_job" + successful_job.is_failed = True + successful_job.requeue = MagicMock() + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [failed_job, successful_job] + + with ( + patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group), + patch("extralit_server.contexts.workflows.update_workflow_status") as mock_update, + ): + result = await restart_failed_jobs_in_workflow(db_mock, mock_document_workflow) + + assert result["success"] is True + assert result["restarted_jobs"] == ["successful_job"] # Only successful requeue + assert result["total_failed"] == 2 + + # Both jobs should have requeue attempted + failed_job.requeue.assert_called_once() + successful_job.requeue.assert_called_once() + + # Workflow should still be updated if any jobs were restarted + mock_update.assert_called_once() + + def test_workflow_not_resumable_scenario(self): + """Test scenario where workflow is not resumable (no failed jobs).""" + # All jobs completed successfully + completed_job1 = MagicMock(spec=Job) + completed_job1.is_failed = False + + 
completed_job2 = MagicMock(spec=Job) + completed_job2.is_failed = False + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [completed_job1, completed_job2] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + resumable = is_workflow_resumable("test_group_123") + + assert resumable is False + + def test_group_corruption_scenario(self): + """Test handling of corrupted or inconsistent group state.""" + # Mock job that raises exception during processing + corrupted_job = MagicMock(spec=Job) + corrupted_job.id = "corrupted_job" + corrupted_job.get_status.side_effect = Exception("Job data corrupted") + corrupted_job.is_failed = True # This should still work + + normal_job = MagicMock(spec=Job) + normal_job.id = "normal_job" + normal_job.is_failed = False + normal_job.is_finished = True + normal_job.get_status.return_value = "finished" + normal_job.meta = {"workflow_step": "analysis"} + normal_job.created_at = datetime.now(timezone.utc) + normal_job.started_at = None + normal_job.ended_at = None + normal_job.result = None + normal_job.exc_info = None + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [corrupted_job, normal_job] + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = get_workflow_status_from_group("test_group_123") + + # Should handle corrupted job gracefully + assert status["total_jobs"] == 2 + assert len(status["jobs"]) == 2 + + # Find the corrupted job in results + corrupted_job_result = next(job for job in status["jobs"] if job["id"] == "corrupted_job") + assert corrupted_job_result["status"] == "error" + assert "Job processing error" in corrupted_job_result["error"] + + def test_redis_connection_failure_scenario(self): + """Test handling of Redis connection failures.""" + with patch("extralit_server.contexts.workflows.Group.fetch", side_effect=Exception("Redis connection failed")): + status = 
get_workflow_status_from_group("test_group_123") + + assert status["status"] == "expired" + assert "Group not found or expired" in status["error"] + assert status["progress"] == 0.0 + assert status["total_jobs"] == 0 + + def test_group_expiration_during_processing(self): + """Test handling of group expiration during job processing.""" + # First call succeeds, second call fails (group expired) + mock_group = MagicMock(spec=Group) + mock_job = MagicMock(spec=Job) + mock_job.id = "test_job" + mock_job.is_failed = False + mock_group.get_jobs.return_value = [mock_job] + + with patch( + "extralit_server.contexts.workflows.Group.fetch", side_effect=[mock_group, Exception("Group expired")] + ): + # First call should succeed + status1 = get_workflow_status_from_group("test_group_123") + assert status1["status"] != "expired" + + # Second call should handle expiration + status2 = get_workflow_status_from_group("test_group_123") + assert status2["status"] == "expired" + + @pytest.mark.asyncio + async def test_concurrent_restart_attempts(self, mock_document_workflow): + """Test handling of concurrent restart attempts on the same workflow.""" + db_mock = AsyncMock() + document_id = mock_document_workflow.document_id + + failed_job = MagicMock(spec=Job) + failed_job.id = "failed_job" + failed_job.is_failed = True + failed_job.requeue = MagicMock() + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [failed_job] + + with ( + patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow), + patch("extralit_server.contexts.workflows.is_workflow_resumable", return_value=True), + patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group), + patch("extralit_server.contexts.workflows.update_workflow_status"), + ): + # Simulate concurrent restart attempts + import asyncio + + tasks = [ + restart_failed_workflow(db_mock, document_id, partial_restart=True), + restart_failed_workflow(db_mock, document_id, 
partial_restart=True), + ] + + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Both should succeed (RQ handles concurrency) + for result in results: + if isinstance(result, dict): + assert result["success"] is True + assert result["restarted_jobs"] == ["failed_job"] + + def test_large_workflow_failure_scenario(self): + """Test handling of workflows with many jobs and mixed failure states.""" + # Create 50 jobs with mixed states + jobs = [] + for i in range(50): + job = MagicMock(spec=Job) + job.id = f"job_{i}" + job.meta = {"workflow_step": f"step_{i % 5}"} + job.created_at = datetime.now(timezone.utc) + job.started_at = None + job.ended_at = None + job.result = None + job.exc_info = None + + if i < 30: # 30 completed + job.is_finished = True + job.is_failed = False + job.get_status.return_value = "finished" + elif i < 40: # 10 failed + job.is_finished = True + job.is_failed = True + job.get_status.return_value = "failed" + job.exc_info = f"Error in job {i}" + else: # 10 running + job.is_finished = False + job.is_failed = False + job.is_started = True + job.get_status.return_value = "started" + + jobs.append(job) + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = jobs + + with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group): + status = get_workflow_status_from_group("large_workflow_group") + + assert status["status"] == "failed" # Has failed jobs + assert status["total_jobs"] == 50 + assert status["completed_jobs"] == 40 # 30 successful + 10 failed + assert status["failed_jobs"] == 10 + assert status["running_jobs"] == 10 + assert status["progress"] == 0.8 # 40/50 completed + + # Test failed job retrieval + failed_jobs = get_failed_jobs_in_group("large_workflow_group") + assert len(failed_jobs) == 10 + for i, failed_job in enumerate(failed_jobs): + assert failed_job["id"] == f"job_{30 + i}" + assert "Error in job" in failed_job["failure_reason"] + + @pytest.mark.asyncio + async def 
test_workflow_restart_after_partial_completion(self, mock_document_workflow): + """Test restarting workflow that had some jobs complete before failure.""" + db_mock = AsyncMock() + + # Simulate workflow that had analysis complete but text extraction failed + analysis_job = MagicMock(spec=Job) + analysis_job.id = "analysis_job" + analysis_job.is_failed = False + analysis_job.is_finished = True + analysis_job.meta = {"workflow_step": "analysis_and_preprocess"} + + text_job = MagicMock(spec=Job) + text_job.id = "text_extraction_job" + text_job.is_failed = True + text_job.requeue = MagicMock() + text_job.meta = {"workflow_step": "text_extraction"} + + table_job = MagicMock(spec=Job) + table_job.id = "table_extraction_job" + table_job.is_failed = True # Failed due to dependency + table_job.requeue = MagicMock() + table_job.meta = {"workflow_step": "table_extraction"} + + mock_group = MagicMock(spec=Group) + mock_group.get_jobs.return_value = [analysis_job, text_job, table_job] + + with ( + patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_group), + patch("extralit_server.contexts.workflows.update_workflow_status") as mock_update, + ): + result = await restart_failed_jobs_in_workflow(db_mock, mock_document_workflow) + + assert result["success"] is True + assert len(result["restarted_jobs"]) == 2 # Only failed jobs restarted + assert "text_extraction_job" in result["restarted_jobs"] + assert "table_extraction_job" in result["restarted_jobs"] + assert result["total_failed"] == 2 + + # Failed jobs should be requeued + text_job.requeue.assert_called_once() + table_job.requeue.assert_called_once() + + mock_update.assert_called_once_with(db_mock, mock_document_workflow, "running") diff --git a/extralit-server/tests/unit/jobs/test_rq_groups_integration.py b/extralit-server/tests/unit/jobs/test_rq_groups_integration.py new file mode 100644 index 000000000..2e2882580 --- /dev/null +++ b/extralit-server/tests/unit/jobs/test_rq_groups_integration.py @@ -0,0 
+1,511 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for RQ Groups integration functions.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from rq.group import Group +from rq.job import Job + +from extralit_server.contexts.workflows import ( + get_failed_jobs_in_group, + get_jobs_by_reference, + get_jobs_for_document, + get_workflow_status, + get_workflow_status_from_group, + is_workflow_resumable, + restart_failed_jobs_in_workflow, + restart_failed_workflow, +) +from extralit_server.models.database import DocumentWorkflow + + +class TestRQGroupsIntegration: + """Test RQ Groups integration functions.""" + + @pytest.fixture + def mock_document_workflow(self): + """Create mock DocumentWorkflow.""" + workflow = MagicMock(spec=DocumentWorkflow) + workflow.id = uuid4() + workflow.document_id = uuid4() + workflow.group_id = "test_group_123" + workflow.status = "running" + workflow.reference = "test_ref" + workflow.workspace_id = uuid4() + workflow.inserted_at = datetime.now(timezone.utc) + workflow.updated_at = datetime.now(timezone.utc) + return workflow + + @pytest.fixture + def mock_rq_job(self): + """Create mock RQ Job.""" + job = MagicMock(spec=Job) + job.id = "test_job_123" + job.meta = { + "document_id": str(uuid4()), + "reference": "test_ref", + "workflow_step": "analysis_and_preprocess", + "workflow_id": 
str(uuid4()), + } + job.created_at = datetime.now(timezone.utc) + job.started_at = datetime.now(timezone.utc) + job.ended_at = None + job.result = None + job.exc_info = None + job.is_finished = False + job.is_failed = False + job.is_started = True + job.get_status.return_value = "started" + return job + + @pytest.fixture + def mock_rq_group(self, mock_rq_job): + """Create mock RQ Group.""" + group = MagicMock(spec=Group) + group.name = "test_group_123" + group.get_jobs.return_value = [mock_rq_job] + return group + + @pytest.mark.asyncio + async def test_get_jobs_for_document_success(self, mock_document_workflow, mock_rq_group): + """Test successful job retrieval for document using RQ Groups.""" + db_mock = AsyncMock() + document_id = mock_document_workflow.document_id + + with ( + patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow), + patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_rq_group), + ): + jobs = await get_jobs_for_document(db_mock, document_id) + + assert len(jobs) == 1 + assert jobs[0]["id"] == "test_job_123" + assert jobs[0]["status"] == "started" + assert jobs[0]["workflow_step"] == "analysis_and_preprocess" + assert jobs[0]["document_id"] == document_id + assert jobs[0]["group_id"] == "test_group_123" + + @pytest.mark.asyncio + async def test_get_jobs_for_document_no_workflow(self): + """Test job retrieval when no workflow exists.""" + db_mock = AsyncMock() + document_id = uuid4() + + with patch.object(DocumentWorkflow, "get_by_document_id", return_value=None): + jobs = await get_jobs_for_document(db_mock, document_id) + + assert jobs == [] + + @pytest.mark.asyncio + async def test_get_jobs_for_document_group_expired(self, mock_document_workflow): + """Test job retrieval when RQ Group is expired or missing.""" + db_mock = AsyncMock() + document_id = mock_document_workflow.document_id + + with ( + patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow), + 
@pytest.mark.asyncio
async def test_get_workflow_status_success(self, mock_document_workflow):
    """Status from the RQ group is returned enriched with workflow identifiers."""
    session = AsyncMock()
    target_id = mock_document_workflow.document_id

    group_status = dict(
        status="running",
        progress=0.5,
        total_jobs=2,
        completed_jobs=1,
        failed_jobs=0,
        running_jobs=1,
        jobs=[],
    )

    workflow_lookup = patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow)
    status_from_group = patch(
        "extralit_server.contexts.workflows.get_workflow_status_from_group", return_value=group_status
    )
    status_updater = patch("extralit_server.contexts.workflows.update_workflow_status")
    with workflow_lookup, status_from_group, status_updater as mock_update:
        status = await get_workflow_status(session, target_id)

    assert status["status"] == "running"
    assert status["progress"] == 0.5
    assert status["document_id"] == target_id
    assert status["workflow_id"] == mock_document_workflow.id
    assert status["group_id"] == "test_group_123"

    # Should not update status if it matches
    mock_update.assert_not_called()
def test_get_workflow_status_from_group_all_completed(self, mock_rq_group):
    """A group whose jobs all finished without failure reports 'completed' at 100%."""
    finished_jobs = [
        MagicMock(spec=Job, is_finished=True, is_failed=False, **{"get_status.return_value": "finished"})
        for _ in range(2)
    ]
    mock_rq_group.get_jobs.return_value = finished_jobs

    with patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_rq_group):
        status = get_workflow_status_from_group("test_group_123")

    assert status["status"] == "completed"
    assert status["progress"] == 1.0
    assert status["completed_jobs"] == 2
    assert status["running_jobs"] == 0
def test_is_workflow_resumable_with_failed_jobs(self, mock_rq_group):
    """A group containing at least one failed job is considered resumable."""
    broken = MagicMock(spec=Job, is_failed=True)
    healthy = MagicMock(spec=Job, is_failed=False)
    mock_rq_group.get_jobs.return_value = [broken, healthy]

    group_fetch = patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_rq_group)
    with group_fetch:
        outcome = is_workflow_resumable("test_group_123")

    assert outcome is True
def test_get_failed_jobs_in_group_no_failures(self, mock_rq_group):
    """No failed jobs in the group means an empty result list."""
    mock_rq_group.get_jobs.return_value = [MagicMock(spec=Job, is_failed=False) for _ in range(2)]

    group_fetch = patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_rq_group)
    with group_fetch:
        failures = get_failed_jobs_in_group("test_group_123")

    assert failures == []
@pytest.mark.asyncio
async def test_restart_failed_jobs_in_workflow_group_expired(self, mock_document_workflow):
    """When the RQ group cannot be fetched, the restart reports failure and restarts nothing."""
    session = AsyncMock()

    fetch_failure = patch(
        "extralit_server.contexts.workflows.Group.fetch",
        side_effect=Exception("Group expired"),
    )
    with fetch_failure:
        outcome = await restart_failed_jobs_in_workflow(session, mock_document_workflow)

    assert outcome["success"] is False
    assert "Group not found or expired" in outcome["error"]
    assert outcome["restarted_jobs"] == []
@pytest.mark.asyncio
async def test_restart_failed_workflow_full_restart(self, mock_document_workflow, mock_rq_group):
    """A full restart requeues every job in the group, not only the failed ones."""
    session = AsyncMock()
    target_id = mock_document_workflow.document_id

    # One failed and one healthy job; each gets its own requeue mock so calls can be counted.
    failed = MagicMock(spec=Job, id="job1", is_failed=True, requeue=MagicMock())
    healthy = MagicMock(spec=Job, id="job2", is_failed=False, requeue=MagicMock())
    mock_rq_group.get_jobs.return_value = [failed, healthy]

    workflow_lookup = patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow)
    resumable = patch("extralit_server.contexts.workflows.is_workflow_resumable", return_value=True)
    group_fetch = patch("extralit_server.contexts.workflows.Group.fetch", return_value=mock_rq_group)
    status_updater = patch("extralit_server.contexts.workflows.update_workflow_status")
    with workflow_lookup, resumable, group_fetch, status_updater as mock_update:
        result = await restart_failed_workflow(session, target_id, partial_restart=False)

    assert result["success"] is True
    assert result["restarted_jobs"] == ["job1", "job2"]
    assert result["total_failed"] == 1  # Only job1 was failed
    assert result["restart_type"] == "full"

    # Both jobs should be requeued
    for job in (failed, healthy):
        job.requeue.assert_called_once()
    mock_update.assert_called_once_with(session, mock_document_workflow, "running")
@pytest.mark.asyncio
async def test_restart_failed_workflow_not_resumable(self, mock_document_workflow):
    """A workflow that is not resumable is rejected with an explanatory error."""
    session = AsyncMock()
    target_id = mock_document_workflow.document_id

    workflow_lookup = patch.object(DocumentWorkflow, "get_by_document_id", return_value=mock_document_workflow)
    not_resumable = patch("extralit_server.contexts.workflows.is_workflow_resumable", return_value=False)
    with workflow_lookup, not_resumable:
        outcome = await restart_failed_workflow(session, target_id)

    assert outcome["success"] is False
    assert "not in a resumable state" in outcome["error"]
+from .__main__ import app + +if __name__ == "__main__": + app() diff --git a/extralit/src/extralit/cli/workflows/__main__.py b/extralit/src/extralit/cli/workflows/__main__.py new file mode 100644 index 000000000..90a0606ae --- /dev/null +++ b/extralit/src/extralit/cli/workflows/__main__.py @@ -0,0 +1,705 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Workflow management CLI commands for PDF processing workflows. + +This module provides CLI commands to manage PDF processing workflows: +- Start workflows for documents +- Check workflow status +- Restart failed workflows +- List workflows with filtering + +The CLI communicates with the server through FastAPI endpoints using the HTTP client, +following the same pattern as the existing import_bib.py command. 
@app.command()
def start(
    document_id: str = typer.Option(..., "--document-id", help="Document UUID to process"),
    workspace_name: str = typer.Option(..., "--workspace", "-w", help="Workspace name"),
    reference: Optional[str] = typer.Option(None, "--reference", "-r", help="Document reference for tracking"),
    force: bool = typer.Option(False, "--force", "-f", help="Force restart if workflow already exists"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed output"),
) -> None:
    """Start PDF processing workflow for a document."""
    # NOTE: the docstring above is shown by Typer as the command's help text, so it is kept short.
    #
    # Flow: authenticate -> validate the UUID -> POST /api/v1/workflows/start -> report the result.
    # Every failure path prints a themed error panel and exits with status 1.

    def abort(message: str, title: str) -> None:
        # Print a failure panel and terminate the command with exit code 1.
        console.print(get_themed_panel(message, title=title, title_align="left", success=False))
        raise typer.Exit(1)

    try:
        client = Extralit.from_credentials()
    except Exception as auth_error:
        abort(f"Authentication failed: {auth_error}", "Authentication Error")

    try:
        # Reject malformed IDs locally before contacting the server.
        try:
            UUID(document_id)
        except ValueError:
            abort(f"Invalid document ID format: {document_id}", "Invalid Input")

        endpoint = f"{client.api_url}/api/v1/workflows/start"
        payload = {
            "document_id": document_id,
            "workspace_name": workspace_name,
            # Fall back to a reference derived from the first 8 characters of the document ID.
            "reference": reference or f"doc_{document_id[:8]}",
            "force": force,
        }

        spinner = Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console,
        )
        with spinner as progress:
            task = progress.add_task("Starting workflow...", total=None)
            response = client.api.http_client.post(endpoint, json=payload)
            progress.update(task, completed=True, description="Workflow start request completed")

        # _handle_http_error prints the server's error detail and exits.
        if response.status_code != 200:
            _handle_http_error(response, "starting workflow")

        result = response.json()

        console.print(
            get_themed_panel(
                f"✓ Started workflow {result['workflow_id']}",
                title="Workflow Started",
                title_align="left",
                success=True,
            )
        )

        if verbose:
            console.print(f"[dim]Document ID:[/dim] {result['document_id']}")
            console.print(f"[dim]Reference:[/dim] {result.get('reference', 'N/A')}")
            console.print(f"[dim]Group ID:[/dim] {result['group_id']}")
            console.print(f"[dim]Status:[/dim] {result['status']}")

        console.print(f"\n[bold]Track progress with:[/bold] extralit workflows status --document-id {document_id}")

    except typer.Exit:
        # Deliberate exits raised above must not be reported as unexpected errors.
        raise
    except Exception as unexpected:
        abort(f"Unexpected error starting workflow: {unexpected}", "Error")
reference to check"), + workspace_name: Optional[str] = typer.Option(None, "--workspace", "-w", help="Filter by workspace name"), + watch: bool = typer.Option(False, "--watch", help="Watch status updates in real-time"), + json_output: bool = typer.Option(False, "--json", help="Output status as JSON"), +) -> None: + """Check workflow status for documents.""" + try: + client = Extralit.from_credentials() + except Exception as e: + panel = get_themed_panel( + f"Authentication failed: {e}", + title="Authentication Error", + title_align="left", + success=False, + ) + console.print(panel) + raise typer.Exit(1) + + try: + if not document_id and not reference: + panel = get_themed_panel( + "Must specify either --document-id or --reference", + title="Missing Required Parameter", + title_align="left", + success=False, + ) + console.print(panel) + raise typer.Exit(1) + + # Validate document_id if provided + if document_id: + try: + UUID(document_id) + except ValueError: + panel = get_themed_panel( + f"Invalid document ID format: {document_id}", + title="Invalid Input", + title_align="left", + success=False, + ) + console.print(panel) + raise typer.Exit(1) + + def get_workflow_status(): + """Get workflow status from API.""" + # Use client.api.http_client.get() to call /workflows/status endpoint + params = {} + if document_id: + params["document_id"] = document_id + if reference: + params["reference"] = reference + if workspace_name: + params["workspace_name"] = workspace_name + + response = client.api.http_client.get( + f"{client.api_url}/api/v1/workflows/status", + params=params, + ) + + if response.status_code != 200: + _handle_http_error(response, "checking workflow status") + + return response.json() + + if watch: + # Add real-time status watching with --watch flag and periodic updates + try: + while True: + console.clear() + workflows = get_workflow_status() + + if not workflows: + console.print("[yellow]No workflows found[/yellow]") + else: + # Support JSON output format 
def _parse_workflow_timestamp(value):
    """Parse an ISO-8601 timestamp string into a timezone-aware datetime, or return None.

    A trailing "Z" is accepted as UTC. A timestamp that carries no offset is treated as UTC,
    which matches the previous behaviour of comparing it against the UTC clock.
    """
    from datetime import datetime, timezone

    if not isinstance(value, str):
        return None
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def _format_elapsed_since(started_at) -> str:
    """Return the time elapsed since the aware datetime *started_at* as "H:MM:SS" (no microseconds)."""
    from datetime import datetime, timedelta, timezone

    # Aware arithmetic honours the timestamp's own UTC offset. Clamp at zero so that
    # small clock skew between client and server never renders a negative duration.
    elapsed = max(datetime.now(timezone.utc) - started_at, timedelta(0))
    return str(elapsed).split(".")[0]


def _display_workflow_status_table(workflows: list) -> None:
    """
    Render workflow status dicts (as returned by the status endpoint) as a Rich table.

    Each row shows a shortened document ID, reference, workspace, colour-coded status,
    job progress ("completed/total (pct%)"), start time and elapsed duration.

    Args:
        workflows: List of workflow status dicts. "document_id" and "status" are required
            keys; all other keys are optional and fall back to placeholders.
    """
    table = Table(title="PDF Processing Workflows")
    table.add_column("Document ID", style="cyan", no_wrap=True)
    table.add_column("Reference", style="magenta")
    table.add_column("Workspace", style="blue")
    table.add_column("Status", style="green")
    table.add_column("Progress", style="yellow")
    table.add_column("Started", style="dim")
    table.add_column("Duration", style="dim")

    status_colors = {"completed": "green", "failed": "red", "running": "yellow"}

    for workflow in workflows:
        # `or 0` also covers keys that are present but null in the JSON payload.
        total_jobs = workflow.get("total_jobs") or 0
        completed_jobs = workflow.get("completed_jobs") or 0
        progress_pct = int(completed_jobs / total_jobs * 100) if total_jobs > 0 else 0
        progress = f"{completed_jobs}/{total_jobs} ({progress_pct}%)"

        # Known statuses get Rich colour markup; anything else is shown as-is.
        status = workflow["status"]
        color = status_colors.get(status)
        status_cell = f"[{color}]{status}[/{color}]" if color else status

        # Parse the creation timestamp once and derive both time columns from it.
        created_dt = _parse_workflow_timestamp(workflow.get("created_at"))
        if created_dt is None:
            started_str = "N/A"
            duration = "Unknown"
        else:
            started_str = created_dt.strftime("%Y-%m-%d %H:%M")
            duration = _format_elapsed_since(created_dt)

        document_id = str(workflow["document_id"])
        short_id = document_id[:8] + "..." if len(document_id) > 8 else document_id

        table.add_row(
            short_id,
            workflow.get("reference") or "N/A",
            workflow.get("workspace_name") or "N/A",
            status_cell,
            progress,
            started_str,
            duration,
        )

    console.print(table)
progress.add_task("Getting workflow status...", total=None) + + status_response = client.api.http_client.get( + f"{client.api_url}/api/v1/workflows/status", + params=params, + ) + + progress.update(task, completed=True, description="Workflow status retrieved") + + if status_response.status_code != 200: + _handle_http_error(status_response, "getting workflow status") + + workflows = status_response.json() + failed_workflows = [w for w in workflows if w["status"] == "failed"] + + if not failed_workflows: + panel = get_themed_panel( + "No failed workflows found", + title="No Workflows to Restart", + title_align="left", + success=True, + ) + console.print(panel) + return + + # Add confirmation prompts before restarting workflows + if not confirm: + workflow_count = len(failed_workflows) + restart_type = "failed jobs only" if failed_only else "all jobs" + + console.print(f"\n[bold]Found {workflow_count} failed workflow(s) to restart ({restart_type}):[/bold]") + for workflow in failed_workflows: + console.print(f" • Document {workflow['document_id'][:8]}... 
- {workflow.get('reference', 'N/A')}") + + if not typer.confirm(f"\nRestart {workflow_count} workflow(s)?"): + console.print("Cancelled") + return + + # Restart workflows + restarted_count = 0 + failed_count = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("Restarting workflows...", total=len(failed_workflows)) + + for _i, workflow in enumerate(failed_workflows): + try: + # Use client.api.http_client.post() to call /workflows/restart endpoint + restart_response = client.api.http_client.post( + f"{client.api_url}/api/v1/workflows/restart", + json={ + "document_id": workflow["document_id"], + "failed_only": failed_only, + }, + ) + + if restart_response.status_code == 200: + result = restart_response.json() + restarted_jobs = result.get("restarted_jobs", []) + console.print( + f"[green]✓ Restarted workflow for document {workflow['document_id'][:8]}... " + f"({len(restarted_jobs)} jobs)[/green]" + ) + restarted_count += 1 + else: + try: + error_detail = restart_response.json().get("detail", "Unknown error") + except Exception: + error_detail = str(restart_response.text) + console.print( + f"[red]✗ Failed to restart workflow for document {workflow['document_id'][:8]}...: " + f"{error_detail}[/red]" + ) + failed_count += 1 + + except Exception as e: + console.print( + f"[red]✗ Failed to restart workflow for document {workflow['document_id'][:8]}...: {e}[/red]" + ) + failed_count += 1 + + progress.update(task, advance=1) + + # Display progress and results of restart operations + if restarted_count > 0: + panel = get_themed_panel( + f"Successfully restarted {restarted_count} of {len(failed_workflows)} workflows", + title="Restart Complete", + title_align="left", + success=True, + ) + console.print(panel) + + if failed_count > 0: + panel = get_themed_panel( + f"Failed to restart {failed_count} workflows", + title="Restart Errors", + title_align="left", + 
@app.command()
def list(
    workspace_name: Optional[str] = typer.Option(None, "--workspace", "-w", help="Filter by workspace name"),
    status_filter: Optional[str] = typer.Option(
        None, "--status", "-s", help="Filter by status (running, completed, failed)"
    ),
    limit: int = typer.Option(50, "--limit", "-l", help="Maximum number of workflows to show"),
    json_output: bool = typer.Option(False, "--json", help="Output as JSON"),
) -> None:
    """List recent workflows."""
    # NOTE: the docstring above is shown by Typer as the command's help text, so it is kept short.
    # NOTE(review): the function name shadows the builtin `list` for the rest of this module;
    # it is kept because it defines the CLI command name.
    #
    # Flow: authenticate -> GET /api/v1/workflows/ with the optional filters -> print either
    # raw JSON (--json) or a formatted table. Failures print an error panel and exit with 1.
    try:
        client = Extralit.from_credentials()
    except Exception as e:
        panel = get_themed_panel(
            f"Authentication failed: {e}",
            title="Authentication Error",
            title_align="left",
            success=False,
        )
        console.print(panel)
        raise typer.Exit(1)

    try:
        # Only send the filters the user actually supplied.
        params = {"limit": limit}
        if workspace_name:
            params["workspace_name"] = workspace_name
        if status_filter:
            params["status_filter"] = status_filter

        def fetch_workflows():
            # Single definition of the request, shared by the spinner and non-spinner paths.
            return client.api.http_client.get(
                f"{client.api_url}/api/v1/workflows/",
                params=params,
            )

        if json_output:
            # No spinner in JSON mode so that stdout carries nothing but the JSON document.
            response = fetch_workflows()
        else:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console,
            ) as progress:
                task = progress.add_task("Fetching workflows...", total=None)
                response = fetch_workflows()
                progress.update(task, completed=True, description="Workflows retrieved")

        if response.status_code != 200:
            _handle_http_error(response, "listing workflows")

        workflows = response.json()

        if json_output:
            # Always emit parseable JSON, including for an empty result. markup=False stops Rich
            # from interpreting "[...]" inside values as style tags, and soft_wrap=True stops it
            # from hard-wrapping long lines, either of which would corrupt the JSON.
            console.print(json.dumps(workflows, indent=2, default=str), markup=False, soft_wrap=True)
            return

        if not workflows:
            panel = get_themed_panel(
                "No workflows found",
                title="No Workflows",
                title_align="left",
                success=True,
            )
            console.print(panel)
            return

        _display_workflow_list_table(workflows)

    except typer.Exit:
        # Deliberate exits raised above must not be reported as unexpected errors.
        raise
    except Exception as e:
        panel = get_themed_panel(
            f"Unexpected error listing workflows: {e}",
            title="Error",
            title_align="left",
            success=False,
        )
        console.print(panel)
        raise typer.Exit(1)
+ """ + table = Table(title=f"Recent Workflows ({len(workflows)} shown)") + table.add_column("Document ID", style="cyan", no_wrap=True) + table.add_column("Reference", style="magenta") + table.add_column("Workspace", style="blue") + table.add_column("Type", style="green") + table.add_column("Status", style="yellow") + table.add_column("Progress", style="yellow") + table.add_column("Jobs", style="dim") + table.add_column("Created", style="dim") + table.add_column("Duration", style="dim") + + for workflow in workflows: + # Calculate progress and job statistics + total_jobs = workflow.get("total_jobs", 0) + completed_jobs = workflow.get("completed_jobs", 0) + failed_jobs = workflow.get("failed_jobs", 0) + running_jobs = workflow.get("running_jobs", 0) + + progress_pct = int(completed_jobs / total_jobs * 100) if total_jobs > 0 else 0 + progress = f"{progress_pct}%" + + # Format status with color + status = workflow["status"] + if status == "completed": + status = f"[green]{status}[/green]" + elif status == "failed": + status = f"[red]{status}[/red]" + elif status == "running": + status = f"[yellow]{status}[/yellow]" + elif status == "pending": + status = f"[blue]{status}[/blue]" + + # Format job statistics + jobs_info = f"{completed_jobs}✓" + if failed_jobs > 0: + jobs_info += f" {failed_jobs}✗" + if running_jobs > 0: + jobs_info += f" {running_jobs}⟳" + jobs_info += f"/{total_jobs}" + + # Calculate duration + from datetime import datetime + + created_at = workflow.get("created_at") + duration = "Unknown" + created_str = "N/A" + + if created_at: + try: + if isinstance(created_at, str): + created_dt = datetime.fromisoformat(created_at.replace("Z", "+00:00")) + duration = str(datetime.utcnow() - created_dt.replace(tzinfo=None)).split(".")[0] + created_str = created_dt.strftime("%m-%d %H:%M") + except Exception: + pass + + table.add_row( + workflow["document_id"][:8] + "..." 
if len(workflow["document_id"]) > 8 else workflow["document_id"], + workflow.get("reference", "N/A")[:20] + ("..." if len(workflow.get("reference", "")) > 20 else ""), + workflow.get("workspace_name", "N/A")[:15] + + ("..." if len(workflow.get("workspace_name", "")) > 15 else ""), + workflow.get("workflow_type", "unknown")[:10], + status, + progress, + jobs_info, + created_str, + duration, + ) + + console.print(table) + + # Add summary information + total_workflows = len(workflows) + completed_count = len([w for w in workflows if w["status"] == "completed"]) + failed_count = len([w for w in workflows if w["status"] == "failed"]) + running_count = len([w for w in workflows if w["status"] == "running"]) + + console.print( + f"\n[dim]Summary: {completed_count} completed, {running_count} running, {failed_count} failed out of {total_workflows} total[/dim]" + ) + + +if __name__ == "__main__": + app() diff --git a/extralit/tests/unit/cli/test_workflows_cli.py b/extralit/tests/unit/cli/test_workflows_cli.py new file mode 100644 index 000000000..81aa10371 --- /dev/null +++ b/extralit/tests/unit/cli/test_workflows_cli.py @@ -0,0 +1,533 @@ +# Copyright 2024-present, Extralit Labs, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Unit tests for CLI workflow commands with RQ Groups integration."""

import json
from unittest.mock import MagicMock, patch
from uuid import uuid4

import pytest
from typer.testing import CliRunner

from extralit.cli.workflows.__main__ import app


class TestWorkflowsCLI:
    """Test CLI workflow commands with RQ Groups integration.

    Every test swaps ``Extralit.from_credentials`` for a mock client, so no
    real server is contacted; the HTTP layer is a ``MagicMock`` whose ``get``
    and ``post`` return canned responses.
    """

    # Dotted path of the client factory that the CLI module calls.
    _FROM_CREDENTIALS = "extralit.cli.workflows.__main__.Extralit.from_credentials"

    @staticmethod
    def _json_response(payload, status_code=200):
        """Build a mock HTTP response whose ``.json()`` returns *payload*."""
        response = MagicMock()
        response.status_code = status_code
        response.json.return_value = payload
        return response

    @staticmethod
    def _as_failed(workflows):
        """Return copies of *workflows* marked failed, leaving the fixture data untouched."""
        return [{**workflow, "status": "failed"} for workflow in workflows]

    def _invoke(self, runner, client, args):
        """Run the CLI with *args* while authentication yields *client*."""
        with patch(self._FROM_CREDENTIALS, return_value=client):
            return runner.invoke(app, args)

    @pytest.fixture
    def runner(self):
        """Create CLI test runner."""
        return CliRunner()

    @pytest.fixture
    def mock_client(self):
        """Mock Extralit client."""
        client = MagicMock()
        client.api_url = "http://localhost:8000"
        client.api.http_client = MagicMock()
        return client

    @pytest.fixture
    def sample_workflow_response(self):
        """Sample workflow API response."""
        return {
            "workflow_id": str(uuid4()),
            "document_id": str(uuid4()),
            "group_id": "document_workflow_123_abcd1234",
            "reference": "test_ref",
            "status": "running",
        }

    @pytest.fixture
    def sample_status_response(self):
        """Sample workflow status API response."""
        return [
            {
                "document_id": str(uuid4()),
                "workflow_id": str(uuid4()),
                "group_id": "document_workflow_123_abcd1234",
                "reference": "test_ref",
                "workspace_name": "test_workspace",
                "status": "running",
                "progress": 0.5,
                "total_jobs": 2,
                "completed_jobs": 1,
                "failed_jobs": 0,
                "running_jobs": 1,
                "created_at": "2024-01-15T10:30:00Z",
                "jobs": [
                    {"id": "analysis_job", "status": "finished", "workflow_step": "analysis_and_preprocess"},
                    {"id": "text_job", "status": "started", "workflow_step": "text_extraction"},
                ],
            }
        ]

    def test_start_command_success(self, runner, mock_client, sample_workflow_response):
        """Test successful workflow start command."""
        mock_client.api.http_client.post.return_value = self._json_response(sample_workflow_response)

        result = self._invoke(
            runner,
            mock_client,
            ["start", "--document-id", str(uuid4()), "--workspace", "test_workspace", "--reference", "test_ref"],
        )

        assert result.exit_code == 0
        assert "Started workflow" in result.stdout
        assert sample_workflow_response["workflow_id"] in result.stdout

        # Verify API call
        mock_client.api.http_client.post.assert_called_once()
        call_args = mock_client.api.http_client.post.call_args
        assert "/api/v1/workflows/start" in call_args[0][0]

        request_data = call_args[1]["json"]
        assert "document_id" in request_data
        assert request_data["workspace_name"] == "test_workspace"
        assert request_data["reference"] == "test_ref"

    def test_start_command_invalid_document_id(self, runner, mock_client):
        """Test start command with invalid document ID format."""
        result = self._invoke(
            runner, mock_client, ["start", "--document-id", "invalid-uuid", "--workspace", "test_workspace"]
        )

        assert result.exit_code == 1
        assert "Invalid document ID format" in result.stdout

    def test_start_command_api_error(self, runner, mock_client):
        """Test start command with API error response."""
        mock_client.api.http_client.post.return_value = self._json_response(
            {"detail": "Document not found"}, status_code=400
        )

        result = self._invoke(
            runner, mock_client, ["start", "--document-id", str(uuid4()), "--workspace", "test_workspace"]
        )

        assert result.exit_code == 1
        assert "Error starting workflow" in result.stdout
        assert "Document not found" in result.stdout

    def test_start_command_with_verbose(self, runner, mock_client, sample_workflow_response):
        """Test start command with verbose output."""
        mock_client.api.http_client.post.return_value = self._json_response(sample_workflow_response)

        result = self._invoke(
            runner,
            mock_client,
            ["start", "--document-id", str(uuid4()), "--workspace", "test_workspace", "--verbose"],
        )

        assert result.exit_code == 0
        assert "Document ID:" in result.stdout
        assert "Group ID:" in result.stdout
        assert sample_workflow_response["group_id"] in result.stdout

    def test_status_command_by_document_id(self, runner, mock_client, sample_status_response):
        """Test status command with document ID filter."""
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        document_id = str(uuid4())
        result = self._invoke(runner, mock_client, ["status", "--document-id", document_id])

        assert result.exit_code == 0
        assert "PDF Processing Workflows" in result.stdout
        assert "running" in result.stdout
        # Progress is displayed as "completed_jobs/total_jobs (percentage%)"
        # The exact format may vary due to Rich table formatting

        # Verify API call
        mock_client.api.http_client.get.assert_called_once()
        call_args = mock_client.api.http_client.get.call_args
        assert "/api/v1/workflows/status" in call_args[0][0]
        assert call_args[1]["params"]["document_id"] == document_id

    def test_status_command_by_reference(self, runner, mock_client, sample_status_response):
        """Test status command with reference filter."""
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        result = self._invoke(
            runner, mock_client, ["status", "--reference", "test_ref", "--workspace", "test_workspace"]
        )

        assert result.exit_code == 0
        assert "PDF Processing Workflows" in result.stdout

        # Verify API call parameters
        params = mock_client.api.http_client.get.call_args[1]["params"]
        assert params["reference"] == "test_ref"
        assert params["workspace_name"] == "test_workspace"

    def test_status_command_json_output(self, runner, mock_client, sample_status_response):
        """Test status command with JSON output format."""
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        result = self._invoke(runner, mock_client, ["status", "--document-id", str(uuid4()), "--json"])

        assert result.exit_code == 0

        # Verify JSON output
        try:
            output_data = json.loads(result.stdout.strip())
        except json.JSONDecodeError:
            pytest.fail("Output is not valid JSON")
        assert len(output_data) == 1
        assert output_data[0]["status"] == "running"
        assert output_data[0]["group_id"] == "document_workflow_123_abcd1234"

    def test_status_command_no_workflows_found(self, runner, mock_client):
        """Test status command when no workflows are found."""
        mock_client.api.http_client.get.return_value = self._json_response([])

        result = self._invoke(runner, mock_client, ["status", "--document-id", str(uuid4())])

        assert result.exit_code == 0
        assert "No workflows found" in result.stdout

    def test_status_command_missing_parameters(self, runner, mock_client):
        """Test status command without required parameters."""
        result = self._invoke(runner, mock_client, ["status"])

        assert result.exit_code == 1
        assert "Must specify either --document-id or --reference" in result.stdout

    def test_restart_command_success(self, runner, mock_client, sample_status_response):
        """Test successful workflow restart command."""
        # Status lookup reports a failed workflow; the restart call then succeeds.
        mock_client.api.http_client.get.return_value = self._json_response(self._as_failed(sample_status_response))
        mock_client.api.http_client.post.return_value = self._json_response(
            {
                "success": True,
                "restarted_jobs": ["failed_job_1", "failed_job_2"],
                "total_failed": 2,
            }
        )

        result = self._invoke(
            runner,
            mock_client,
            [
                "restart",
                "--document-id",
                str(uuid4()),
                "--yes",  # Skip confirmation
            ],
        )

        assert result.exit_code == 0
        assert "Restarted workflow" in result.stdout
        assert "(2 jobs)" in result.stdout

        # Verify API calls
        assert mock_client.api.http_client.get.call_count == 1
        assert mock_client.api.http_client.post.call_count == 1

        # Verify restart API call
        restart_call = mock_client.api.http_client.post.call_args
        assert "/api/v1/workflows/restart" in restart_call[0][0]
        assert restart_call[1]["json"]["failed_only"] is True

    def test_restart_command_no_failed_workflows(self, runner, mock_client, sample_status_response):
        """Test restart command when no failed workflows exist."""
        # The unmodified fixture describes a running workflow.
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        result = self._invoke(runner, mock_client, ["restart", "--document-id", str(uuid4())])

        assert result.exit_code == 0
        assert "No failed workflows found" in result.stdout

    def test_restart_command_with_confirmation(self, runner, mock_client, sample_status_response):
        """Test restart command with user confirmation."""
        mock_client.api.http_client.get.return_value = self._json_response(self._as_failed(sample_status_response))

        with (
            patch(self._FROM_CREDENTIALS, return_value=mock_client),
            patch("typer.confirm", return_value=False),
        ):  # User cancels
            result = runner.invoke(app, ["restart", "--document-id", str(uuid4())])

        assert result.exit_code == 0
        assert "Cancelled" in result.stdout

        # Should not call restart API
        mock_client.api.http_client.post.assert_not_called()

    def test_restart_command_full_restart(self, runner, mock_client, sample_status_response):
        """Test restart command with full restart option."""
        mock_client.api.http_client.get.return_value = self._json_response(self._as_failed(sample_status_response))
        mock_client.api.http_client.post.return_value = self._json_response(
            {
                "success": True,
                "restarted_jobs": ["job1", "job2", "job3"],
                "total_failed": 1,
                "restart_type": "full",
            }
        )

        result = self._invoke(
            runner,
            mock_client,
            [
                "restart",
                "--document-id",
                str(uuid4()),
                "--all",  # Full restart (this sets failed_only=False)
                "--yes",
            ],
        )

        assert result.exit_code == 0

        # Verify restart API call with correct parameters
        restart_call = mock_client.api.http_client.post.call_args
        assert restart_call[1]["json"]["failed_only"] is False

    def test_list_command_success(self, runner, mock_client, sample_status_response):
        """Test successful workflow list command."""
        # Build three *independent* workflow dicts. Multiplying the fixture list
        # would repeat one dict object three times, so per-workflow edits would
        # all land on the same entry.
        template = sample_status_response[0]
        extended_response = [
            {**template, "document_id": str(uuid4()), "reference": f"test_ref_{i}"} for i in range(3)
        ]
        mock_client.api.http_client.get.return_value = self._json_response(extended_response)

        result = self._invoke(runner, mock_client, ["list"])

        assert result.exit_code == 0
        assert "Recent Workflows (3 shown)" in result.stdout
        assert "Summary:" in result.stdout

        # Verify API call
        call_args = mock_client.api.http_client.get.call_args
        assert "/api/v1/workflows/" in call_args[0][0]
        assert call_args[1]["params"]["limit"] == 50

    def test_list_command_with_filters(self, runner, mock_client, sample_status_response):
        """Test list command with workspace and status filters."""
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        result = self._invoke(
            runner,
            mock_client,
            ["list", "--workspace", "test_workspace", "--status", "failed", "--limit", "25"],
        )

        assert result.exit_code == 0

        # Verify API call parameters
        params = mock_client.api.http_client.get.call_args[1]["params"]
        assert params["workspace_name"] == "test_workspace"
        assert params["status_filter"] == "failed"
        assert params["limit"] == 25

    def test_list_command_json_output(self, runner, mock_client, sample_status_response):
        """Test list command with JSON output."""
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        result = self._invoke(runner, mock_client, ["list", "--json"])

        assert result.exit_code == 0

        # Verify JSON output
        try:
            output_data = json.loads(result.stdout.strip())
        except json.JSONDecodeError:
            pytest.fail("Output is not valid JSON")
        assert len(output_data) == 1
        assert output_data[0]["status"] == "running"

    def test_list_command_no_workflows(self, runner, mock_client):
        """Test list command when no workflows exist."""
        mock_client.api.http_client.get.return_value = self._json_response([])

        result = self._invoke(runner, mock_client, ["list"])

        assert result.exit_code == 0
        assert "No workflows found" in result.stdout

    def test_authentication_failure(self, runner):
        """Test CLI commands with authentication failure."""
        with patch(self._FROM_CREDENTIALS, side_effect=Exception("Authentication failed")):
            result = runner.invoke(app, ["start", "--document-id", str(uuid4()), "--workspace", "test_workspace"])

        assert result.exit_code == 1
        assert "Authentication failed" in result.stdout

    def test_api_connection_error(self, runner, mock_client):
        """Test CLI commands with API connection errors."""
        mock_client.api.http_client.post.side_effect = Exception("Connection refused")

        result = self._invoke(
            runner, mock_client, ["start", "--document-id", str(uuid4()), "--workspace", "test_workspace"]
        )

        assert result.exit_code == 1
        assert "Unexpected error" in result.stdout

    def test_workflow_status_table_formatting(self, runner, mock_client):
        """Test workflow status table formatting with various states."""
        # (reference, status, progress, completed, failed, running, created_at)
        scenarios = [
            ("completed_workflow", "completed", 1.0, 3, 0, 0, "2024-01-15T10:30:00Z"),
            ("failed_workflow", "failed", 0.67, 2, 1, 0, "2024-01-15T11:00:00Z"),
            ("running_workflow", "running", 0.33, 1, 0, 2, "2024-01-15T11:30:00Z"),
        ]
        complex_status_response = [
            {
                "document_id": str(uuid4()),
                "workflow_id": str(uuid4()),
                "group_id": f"group_{index}",
                "reference": reference,
                "workspace_name": "test_workspace",
                "status": status,
                "progress": progress,
                "total_jobs": 3,
                "completed_jobs": completed,
                "failed_jobs": failed,
                "running_jobs": running,
                "created_at": created_at,
            }
            for index, (reference, status, progress, completed, failed, running, created_at) in enumerate(
                scenarios, start=1
            )
        ]
        mock_client.api.http_client.get.return_value = self._json_response(complex_status_response)

        result = self._invoke(runner, mock_client, ["status", "--reference", "test_workflows"])

        assert result.exit_code == 0
        assert "PDF Processing Workflows" in result.stdout

        # Check that workflow data is displayed (status may be formatted with colors)
        # The exact formatting may vary due to Rich table rendering
        output = result.stdout
        # Check that some form of status information is present
        assert any(status in output for status in ["completed", "failed", "running"])

        # Check that progress information is displayed in some format
        # Progress may be displayed as percentages or fractions
        assert any(progress in output for progress in ["100%", "67%", "33%", "3/3", "2/3", "1/3"])

    @patch("time.sleep")  # Mock sleep to speed up test
    def test_status_command_watch_mode(self, mock_sleep, runner, mock_client, sample_status_response):
        """Test status command with watch mode."""
        mock_client.api.http_client.get.return_value = self._json_response(sample_status_response)

        # The second sleep raises KeyboardInterrupt, which ends the watch loop.
        mock_sleep.side_effect = [None, KeyboardInterrupt()]

        result = self._invoke(runner, mock_client, ["status", "--document-id", str(uuid4()), "--watch"])

        assert result.exit_code == 0
        assert "Stopped watching" in result.stdout

        # Should have made multiple API calls
        assert mock_client.api.http_client.get.call_count >= 2