From c329fea0cef9e8da552250b7020ac399fb14258b Mon Sep 17 00:00:00 2001 From: Bao Tran Date: Thu, 7 May 2026 15:37:31 -0400 Subject: [PATCH 1/6] feat(streaming): add document.* and helox training topic constants --- src/deepiri_modelkit/streaming/topics.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/deepiri_modelkit/streaming/topics.py b/src/deepiri_modelkit/streaming/topics.py index 4e99a15..cb9634a 100644 --- a/src/deepiri_modelkit/streaming/topics.py +++ b/src/deepiri_modelkit/streaming/topics.py @@ -11,7 +11,15 @@ class StreamTopics(str, Enum): PLATFORM_EVENTS = "platform-events" AGI_DECISIONS = "agi-decisions" TRAINING_EVENTS = "training-events" - + # LIS document routing streams (document.* namespace) + DOCUMENT_VECTORIZE = "document.vectorize" + DOCUMENT_TRAINING = "document.training" + DOCUMENT_STRUCTURED = "document.structured" + DOCUMENT_ARTIFACTS = "document.artifacts" + # Cyrex runtime training signals for Helox (pipeline.* namespace) + HELOX_TRAINING_RAW = "pipeline.helox-training.raw" + HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" + @classmethod def all(cls) -> list: """Get all topic names""" From b10e4f7f2dbbf5e6a1b346ca3e7ad2f6584d1c30 Mon Sep 17 00:00:00 2001 From: Bao Tran Date: Thu, 7 May 2026 16:00:09 -0400 Subject: [PATCH 2/6] style(streaming): normalize topic constants file --- src/deepiri_modelkit/streaming/topics.py | 57 +++++++++++++----------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/src/deepiri_modelkit/streaming/topics.py b/src/deepiri_modelkit/streaming/topics.py index cb9634a..f3515db 100644 --- a/src/deepiri_modelkit/streaming/topics.py +++ b/src/deepiri_modelkit/streaming/topics.py @@ -1,27 +1,30 @@ -""" -Stream topic definitions -""" -from enum import Enum - - -class StreamTopics(str, Enum): - """Redis Stream topics""" - MODEL_EVENTS = "model-events" - INFERENCE_EVENTS = "inference-events" - PLATFORM_EVENTS = "platform-events" - AGI_DECISIONS = 
"agi-decisions" - TRAINING_EVENTS = "training-events" - # LIS document routing streams (document.* namespace) - DOCUMENT_VECTORIZE = "document.vectorize" - DOCUMENT_TRAINING = "document.training" - DOCUMENT_STRUCTURED = "document.structured" - DOCUMENT_ARTIFACTS = "document.artifacts" - # Cyrex runtime training signals for Helox (pipeline.* namespace) - HELOX_TRAINING_RAW = "pipeline.helox-training.raw" - HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" - - @classmethod - def all(cls) -> list: - """Get all topic names""" - return [topic.value for topic in cls] - +""" +Stream topic definitions. +""" + +from enum import Enum + + +class StreamTopics(str, Enum): + """Redis Stream topics.""" + + MODEL_EVENTS = "model-events" + INFERENCE_EVENTS = "inference-events" + PLATFORM_EVENTS = "platform-events" + AGI_DECISIONS = "agi-decisions" + TRAINING_EVENTS = "training-events" + + # LIS document routing streams (document.* namespace). + DOCUMENT_VECTORIZE = "document.vectorize" + DOCUMENT_TRAINING = "document.training" + DOCUMENT_STRUCTURED = "document.structured" + DOCUMENT_ARTIFACTS = "document.artifacts" + + # Cyrex runtime training signals for Helox (pipeline.* namespace). 
+ HELOX_TRAINING_RAW = "pipeline.helox-training.raw" + HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" + + @classmethod + def all(cls) -> list[str]: + """Get all topic names.""" + return [topic.value for topic in cls] From 1c738611e55f75f30111fa077a2794b11a8f091c Mon Sep 17 00:00:00 2001 From: Bao Tran Date: Thu, 7 May 2026 16:04:14 -0400 Subject: [PATCH 3/6] style: format modelkit source with black --- src/deepiri_modelkit/__init__.py | 77 +- src/deepiri_modelkit/contracts/contract.py | 57 +- src/deepiri_modelkit/contracts/events.py | 206 ++-- src/deepiri_modelkit/contracts/models.py | 292 +++--- src/deepiri_modelkit/contracts/services.py | 102 +- src/deepiri_modelkit/data/monitoring.py | 794 ++++++++------- src/deepiri_modelkit/data/validation.py | 764 +++++++------- src/deepiri_modelkit/logging.py | 318 +++--- src/deepiri_modelkit/ml/__init__.py | 68 +- src/deepiri_modelkit/ml/confidence.py | 608 ++++++------ src/deepiri_modelkit/ml/semantic.py | 704 ++++++------- src/deepiri_modelkit/rag/__init__.py | 321 +++--- .../rag/advanced_retrieval.py | 816 +++++++-------- src/deepiri_modelkit/rag/async_processing.py | 169 ++-- src/deepiri_modelkit/rag/base.py | 617 ++++++------ src/deepiri_modelkit/rag/caching.py | 932 +++++++++--------- src/deepiri_modelkit/rag/monitoring.py | 729 +++++++------- src/deepiri_modelkit/rag/processors.py | 866 ++++++++-------- src/deepiri_modelkit/rag/retrievers.py | 575 ++++++----- src/deepiri_modelkit/rag/testing.py | 666 ++++++------- .../registry/adapters/__init__.py | 3 +- .../registry/model_registry.py | 669 +++++++------ .../streaming/event_stream.py | 398 ++++---- src/deepiri_modelkit/streaming/schemas.py | 111 ++- .../streaming/sidecar_utils.py | 160 +-- src/deepiri_modelkit/utils/__init__.py | 15 +- src/deepiri_modelkit/utils/device.py | 301 +++--- 27 files changed, 5788 insertions(+), 5550 deletions(-) diff --git a/src/deepiri_modelkit/__init__.py b/src/deepiri_modelkit/__init__.py index 4d3f09b..5601e33 100644 
--- a/src/deepiri_modelkit/__init__.py +++ b/src/deepiri_modelkit/__init__.py @@ -1,36 +1,41 @@ -""" -Deepiri ModelKit - Shared contracts, interfaces, and utilities -""" - -__version__ = "0.1.0" - -from .contracts.models import AIModel, AIModelPydantic, ModelInput, ModelOutput, ModelMetadata -from .contracts.events import ( - ModelReadyEvent, - InferenceEvent, - PlatformEvent, - AGIDecisionEvent, - TrainingEvent, -) -from .streaming.event_stream import StreamingClient -from .registry.model_registry import ModelRegistryClient -from .logging import get_logger, get_error_logger, ErrorLogger - -__all__ = [ - "AIModel", # Protocol interface for type checking - "AIModelPydantic", # Pydantic-compatible type for use in BaseModel fields - "ModelInput", - "ModelOutput", - "ModelMetadata", - "ModelReadyEvent", - "InferenceEvent", - "PlatformEvent", - "AGIDecisionEvent", - "TrainingEvent", - "StreamingClient", - "ModelRegistryClient", - "get_logger", - "get_error_logger", - "ErrorLogger", -] - +""" +Deepiri ModelKit - Shared contracts, interfaces, and utilities +""" + +__version__ = "0.1.0" + +from .contracts.models import ( + AIModel, + AIModelPydantic, + ModelInput, + ModelOutput, + ModelMetadata, +) +from .contracts.events import ( + ModelReadyEvent, + InferenceEvent, + PlatformEvent, + AGIDecisionEvent, + TrainingEvent, +) +from .streaming.event_stream import StreamingClient +from .registry.model_registry import ModelRegistryClient +from .logging import get_logger, get_error_logger, ErrorLogger + +__all__ = [ + "AIModel", # Protocol interface for type checking + "AIModelPydantic", # Pydantic-compatible type for use in BaseModel fields + "ModelInput", + "ModelOutput", + "ModelMetadata", + "ModelReadyEvent", + "InferenceEvent", + "PlatformEvent", + "AGIDecisionEvent", + "TrainingEvent", + "StreamingClient", + "ModelRegistryClient", + "get_logger", + "get_error_logger", + "ErrorLogger", +] diff --git a/src/deepiri_modelkit/contracts/contract.py 
b/src/deepiri_modelkit/contracts/contract.py index 00abc21..237ac62 100644 --- a/src/deepiri_modelkit/contracts/contract.py +++ b/src/deepiri_modelkit/contracts/contract.py @@ -1,27 +1,30 @@ -""" -Model contract for registry (separated from models.py to avoid Pydantic Protocol conflicts) -""" -from __future__ import annotations - -from typing import Dict, Any, Optional -from pydantic import BaseModel - -from .models import ModelMetadata - - -class ModelContract(BaseModel): - """ - Complete model contract for registry. - - A contract is serializable metadata that describes a model's interface, - input/output schemas, and validation requirements. It does NOT contain - the actual model instance (which would be a Protocol type that Pydantic - cannot serialize). The model instance should be loaded separately when needed. - """ - metadata: ModelMetadata - input_schema: Dict[str, Any] - output_schema: Dict[str, Any] - validation_tests: Optional[list] = None - model_path: Optional[str] = None # Path/reference to where the model can be loaded from - model_id: Optional[str] = None # Unique identifier for the model instance - +""" +Model contract for registry (separated from models.py to avoid Pydantic Protocol conflicts) +""" + +from __future__ import annotations + +from typing import Dict, Any, Optional +from pydantic import BaseModel + +from .models import ModelMetadata + + +class ModelContract(BaseModel): + """ + Complete model contract for registry. + + A contract is serializable metadata that describes a model's interface, + input/output schemas, and validation requirements. It does NOT contain + the actual model instance (which would be a Protocol type that Pydantic + cannot serialize). The model instance should be loaded separately when needed. 
+ """ + + metadata: ModelMetadata + input_schema: Dict[str, Any] + output_schema: Dict[str, Any] + validation_tests: Optional[list] = None + model_path: Optional[str] = ( + None # Path/reference to where the model can be loaded from + ) + model_id: Optional[str] = None # Unique identifier for the model instance diff --git a/src/deepiri_modelkit/contracts/events.py b/src/deepiri_modelkit/contracts/events.py index 9711a93..e1325b4 100644 --- a/src/deepiri_modelkit/contracts/events.py +++ b/src/deepiri_modelkit/contracts/events.py @@ -1,99 +1,107 @@ -""" -Event schemas for streaming service -""" -from pydantic import BaseModel, Field -from typing import Dict, Any, Optional -from datetime import datetime -from enum import Enum - - -class EventType(str, Enum): - """Event type enumeration""" - MODEL_READY = "model-ready" - MODEL_LOADED = "model-loaded" - MODEL_FAILED = "model-failed" - INFERENCE_COMPLETE = "inference-complete" - INFERENCE_FAILED = "inference-failed" - USER_INTERACTION = "user-interaction" - TASK_CREATED = "task-created" - TASK_COMPLETED = "task-completed" - AGI_DECISION = "agi-decision" - AGI_ACTION = "agi-action" - TRAINING_STARTED = "training-started" - TRAINING_COMPLETE = "training-complete" - TRAINING_FAILED = "training-failed" - - -class BaseEvent(BaseModel): - """Base event schema""" - event: str - timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - source: str - correlation_id: Optional[str] = None - - -class ModelReadyEvent(BaseEvent): - """Event published when model is trained and ready""" - event: str = EventType.MODEL_READY - model_name: str - version: str - registry_path: str # S3/MLflow path - metadata: Dict[str, Any] - model_type: Optional[str] = None - accuracy: Optional[float] = None - size_mb: Optional[float] = None - - -class ModelLoadedEvent(BaseEvent): - """Event published when model is loaded in runtime""" - event: str = EventType.MODEL_LOADED - model_name: str - version: str - load_time_ms: float - 
cache_location: Optional[str] = None - - -class InferenceEvent(BaseEvent): - """Event published after inference completes""" - event: str = EventType.INFERENCE_COMPLETE - model_name: str - version: str - user_id: Optional[str] = None - request_id: Optional[str] = None - latency_ms: float - tokens_used: Optional[int] = None - cost: Optional[float] = None - confidence: Optional[float] = None - success: bool = True - - -class PlatformEvent(BaseEvent): - """Event published by platform services""" - event: str # user-interaction, task-created, etc. - service: str - user_id: Optional[str] = None - action: str - data: Dict[str, Any] - organization_id: Optional[str] = None - - -class AGIDecisionEvent(BaseEvent): - """Event published by Cyrex-AGI for autonomous decisions""" - event: str = EventType.AGI_DECISION - decision_type: str - target_service: Optional[str] = None - action: Dict[str, Any] - reasoning: Optional[str] = None - confidence: Optional[float] = None - - -class TrainingEvent(BaseEvent): - """Event published during training""" - event: str # training-started, training-complete, training-failed - experiment_id: str - model_name: str - status: str - progress: Optional[float] = None # 0.0 to 1.0 - metrics: Optional[Dict[str, Any]] = None - error: Optional[str] = None - +""" +Event schemas for streaming service +""" + +from pydantic import BaseModel, Field +from typing import Dict, Any, Optional +from datetime import datetime +from enum import Enum + + +class EventType(str, Enum): + """Event type enumeration""" + + MODEL_READY = "model-ready" + MODEL_LOADED = "model-loaded" + MODEL_FAILED = "model-failed" + INFERENCE_COMPLETE = "inference-complete" + INFERENCE_FAILED = "inference-failed" + USER_INTERACTION = "user-interaction" + TASK_CREATED = "task-created" + TASK_COMPLETED = "task-completed" + AGI_DECISION = "agi-decision" + AGI_ACTION = "agi-action" + TRAINING_STARTED = "training-started" + TRAINING_COMPLETE = "training-complete" + TRAINING_FAILED = 
"training-failed" + + +class BaseEvent(BaseModel): + """Base event schema""" + + event: str + timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + source: str + correlation_id: Optional[str] = None + + +class ModelReadyEvent(BaseEvent): + """Event published when model is trained and ready""" + + event: str = EventType.MODEL_READY + model_name: str + version: str + registry_path: str # S3/MLflow path + metadata: Dict[str, Any] + model_type: Optional[str] = None + accuracy: Optional[float] = None + size_mb: Optional[float] = None + + +class ModelLoadedEvent(BaseEvent): + """Event published when model is loaded in runtime""" + + event: str = EventType.MODEL_LOADED + model_name: str + version: str + load_time_ms: float + cache_location: Optional[str] = None + + +class InferenceEvent(BaseEvent): + """Event published after inference completes""" + + event: str = EventType.INFERENCE_COMPLETE + model_name: str + version: str + user_id: Optional[str] = None + request_id: Optional[str] = None + latency_ms: float + tokens_used: Optional[int] = None + cost: Optional[float] = None + confidence: Optional[float] = None + success: bool = True + + +class PlatformEvent(BaseEvent): + """Event published by platform services""" + + event: str # user-interaction, task-created, etc. 
+ service: str + user_id: Optional[str] = None + action: str + data: Dict[str, Any] + organization_id: Optional[str] = None + + +class AGIDecisionEvent(BaseEvent): + """Event published by Cyrex-AGI for autonomous decisions""" + + event: str = EventType.AGI_DECISION + decision_type: str + target_service: Optional[str] = None + action: Dict[str, Any] + reasoning: Optional[str] = None + confidence: Optional[float] = None + + +class TrainingEvent(BaseEvent): + """Event published during training""" + + event: str # training-started, training-complete, training-failed + experiment_id: str + model_name: str + status: str + progress: Optional[float] = None # 0.0 to 1.0 + metrics: Optional[Dict[str, Any]] = None + error: Optional[str] = None diff --git a/src/deepiri_modelkit/contracts/models.py b/src/deepiri_modelkit/contracts/models.py index 6e52336..79b2e43 100644 --- a/src/deepiri_modelkit/contracts/models.py +++ b/src/deepiri_modelkit/contracts/models.py @@ -1,143 +1,149 @@ -""" -Model contracts and interfaces -""" -from __future__ import annotations # Defer annotation evaluation to prevent Pydantic from processing Protocol types - -from typing import Protocol, Dict, Any, Optional, Annotated -from pydantic import BaseModel, Field, GetCoreSchemaHandler -from pydantic_core import core_schema -from datetime import datetime - - -class ModelInput(BaseModel): - """Standard model input schema""" - data: Dict[str, Any] - metadata: Optional[Dict[str, Any]] = None - timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - - -class ModelOutput(BaseModel): - """Standard model output schema""" - prediction: Any - confidence: Optional[float] = None - metadata: Optional[Dict[str, Any]] = None - timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - - -class ModelMetadata(BaseModel): - """Model metadata schema""" - name: str - version: str - description: Optional[str] = None - architecture: Optional[str] = None - accuracy: Optional[float] 
= None - size_mb: Optional[float] = None - created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - trained_by: Optional[str] = None - tags: Optional[Dict[str, Any]] = None - - -class AIModel(Protocol): - """ - Interface that all models must implement - Used by both Helox (training) and Cyrex (runtime) - - Note: This is a Protocol (structural type). To use in Pydantic models, - use AIModelPydantic instead, which has full Pydantic schema support. - """ - - def predict(self, input: ModelInput) -> ModelOutput: - """Run inference on input""" - ... - - def get_metadata(self) -> ModelMetadata: - """Get model metadata""" - ... - - def validate(self) -> bool: - """Validate model is ready for use""" - ... - - def export(self, format: str = "onnx") -> str: - """Export model to specified format, returns path""" - ... - - -class AIModelPydantic: - """ - Pydantic-compatible wrapper for AIModel Protocol. - Implements __get_pydantic_core_schema__ to provide full schema support. - - Usage in Pydantic models: - model: Optional[AIModelPydantic] = None - - Note: In practice, model instances should be loaded separately and referenced - by ID/path rather than stored directly in serializable Pydantic models. - """ - - @classmethod - def __get_pydantic_core_schema__( - cls, - source_type: Any, - handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - """ - Pydantic Core Schema handler for AIModel Protocol. - - This allows Pydantic to process AIModel types in model fields. - Since Protocols are structural types, we validate that the object - has the required methods rather than checking exact type. 
- """ - def validate_aimodel(value: Any) -> Any: - """Validate that value implements AIModel Protocol interface""" - if value is None: - return None - - # Check for required Protocol methods - required_methods = ['predict', 'get_metadata', 'validate', 'export'] - missing_methods = [m for m in required_methods if not hasattr(value, m)] - - if missing_methods: - raise ValueError( - f"Object does not implement AIModel Protocol. " - f"Missing methods: {', '.join(missing_methods)}" - ) - - return value - - def serialize_aimodel(value: Any) -> Dict[str, Any]: - """Serialize AIModel instance to dict""" - if value is None: - return None - - # Try to get metadata if available - metadata = None - if hasattr(value, 'get_metadata'): - try: - metadata = value.get_metadata() - # Convert ModelMetadata to dict if it's a Pydantic model - if hasattr(metadata, 'model_dump'): - metadata = metadata.model_dump() - elif hasattr(metadata, 'dict'): - metadata = metadata.dict() - except Exception: - pass - - return { - "type": "AIModel", - "metadata": metadata, - "has_predict": hasattr(value, 'predict'), - "has_validate": hasattr(value, 'validate'), - } - - return core_schema.no_info_plain_validator_function( - validate_aimodel, - serialization=core_schema.plain_serializer_function_ser_schema( - serialize_aimodel - ) - ) - - -# ModelContract moved to contract.py to avoid Pydantic Protocol conflicts -# Import it from .contract import ModelContract when needed - +""" +Model contracts and interfaces +""" + +from __future__ import ( + annotations, +) # Defer annotation evaluation to prevent Pydantic from processing Protocol types + +from typing import Protocol, Dict, Any, Optional, Annotated +from pydantic import BaseModel, Field, GetCoreSchemaHandler +from pydantic_core import core_schema +from datetime import datetime + + +class ModelInput(BaseModel): + """Standard model input schema""" + + data: Dict[str, Any] + metadata: Optional[Dict[str, Any]] = None + timestamp: str = 
Field(default_factory=lambda: datetime.utcnow().isoformat()) + + +class ModelOutput(BaseModel): + """Standard model output schema""" + + prediction: Any + confidence: Optional[float] = None + metadata: Optional[Dict[str, Any]] = None + timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + + +class ModelMetadata(BaseModel): + """Model metadata schema""" + + name: str + version: str + description: Optional[str] = None + architecture: Optional[str] = None + accuracy: Optional[float] = None + size_mb: Optional[float] = None + created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + trained_by: Optional[str] = None + tags: Optional[Dict[str, Any]] = None + + +class AIModel(Protocol): + """ + Interface that all models must implement + Used by both Helox (training) and Cyrex (runtime) + + Note: This is a Protocol (structural type). To use in Pydantic models, + use AIModelPydantic instead, which has full Pydantic schema support. + """ + + def predict(self, input: ModelInput) -> ModelOutput: + """Run inference on input""" + ... + + def get_metadata(self) -> ModelMetadata: + """Get model metadata""" + ... + + def validate(self) -> bool: + """Validate model is ready for use""" + ... + + def export(self, format: str = "onnx") -> str: + """Export model to specified format, returns path""" + ... + + +class AIModelPydantic: + """ + Pydantic-compatible wrapper for AIModel Protocol. + Implements __get_pydantic_core_schema__ to provide full schema support. + + Usage in Pydantic models: + model: Optional[AIModelPydantic] = None + + Note: In practice, model instances should be loaded separately and referenced + by ID/path rather than stored directly in serializable Pydantic models. + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + """ + Pydantic Core Schema handler for AIModel Protocol. 
+ + This allows Pydantic to process AIModel types in model fields. + Since Protocols are structural types, we validate that the object + has the required methods rather than checking exact type. + """ + + def validate_aimodel(value: Any) -> Any: + """Validate that value implements AIModel Protocol interface""" + if value is None: + return None + + # Check for required Protocol methods + required_methods = ["predict", "get_metadata", "validate", "export"] + missing_methods = [m for m in required_methods if not hasattr(value, m)] + + if missing_methods: + raise ValueError( + f"Object does not implement AIModel Protocol. " + f"Missing methods: {', '.join(missing_methods)}" + ) + + return value + + def serialize_aimodel(value: Any) -> Dict[str, Any]: + """Serialize AIModel instance to dict""" + if value is None: + return None + + # Try to get metadata if available + metadata = None + if hasattr(value, "get_metadata"): + try: + metadata = value.get_metadata() + # Convert ModelMetadata to dict if it's a Pydantic model + if hasattr(metadata, "model_dump"): + metadata = metadata.model_dump() + elif hasattr(metadata, "dict"): + metadata = metadata.dict() + except Exception: + pass + + return { + "type": "AIModel", + "metadata": metadata, + "has_predict": hasattr(value, "predict"), + "has_validate": hasattr(value, "validate"), + } + + return core_schema.no_info_plain_validator_function( + validate_aimodel, + serialization=core_schema.plain_serializer_function_ser_schema( + serialize_aimodel + ), + ) + + +# ModelContract moved to contract.py to avoid Pydantic Protocol conflicts +# Import it from .contract import ModelContract when needed diff --git a/src/deepiri_modelkit/contracts/services.py b/src/deepiri_modelkit/contracts/services.py index 9011102..e6b01e5 100644 --- a/src/deepiri_modelkit/contracts/services.py +++ b/src/deepiri_modelkit/contracts/services.py @@ -1,58 +1,44 @@ -""" -Service contracts and interfaces -""" -from typing import Protocol, Dict, Any, Optional 
-from pydantic import BaseModel - - -class ModelRegistryService(Protocol): - """Interface for model registry operations""" - - def register_model( - self, - model_name: str, - version: str, - model_path: str, - metadata: Dict[str, Any] - ) -> bool: - """Register a model in the registry""" - ... - - def get_model( - self, - model_name: str, - version: Optional[str] = None - ) -> Dict[str, Any]: - """Get model information from registry""" - ... - - def list_models(self, model_name: Optional[str] = None) -> list: - """List available models""" - ... - - def download_model( - self, - model_name: str, - version: str, - destination: str - ) -> str: - """Download model to destination, returns local path""" - ... - - -class StreamingService(Protocol): - """Interface for streaming operations""" - - def publish(self, topic: str, event: Dict[str, Any]) -> bool: - """Publish event to topic""" - ... - - def subscribe( - self, - topic: str, - callback: callable, - consumer_group: Optional[str] = None - ) -> None: - """Subscribe to topic with callback""" - ... - +""" +Service contracts and interfaces +""" + +from typing import Protocol, Dict, Any, Optional +from pydantic import BaseModel + + +class ModelRegistryService(Protocol): + """Interface for model registry operations""" + + def register_model( + self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] + ) -> bool: + """Register a model in the registry""" + ... + + def get_model( + self, model_name: str, version: Optional[str] = None + ) -> Dict[str, Any]: + """Get model information from registry""" + ... + + def list_models(self, model_name: Optional[str] = None) -> list: + """List available models""" + ... + + def download_model(self, model_name: str, version: str, destination: str) -> str: + """Download model to destination, returns local path""" + ... 
+ + +class StreamingService(Protocol): + """Interface for streaming operations""" + + def publish(self, topic: str, event: Dict[str, Any]) -> bool: + """Publish event to topic""" + ... + + def subscribe( + self, topic: str, callback: callable, consumer_group: Optional[str] = None + ) -> None: + """Subscribe to topic with callback""" + ... diff --git a/src/deepiri_modelkit/data/monitoring.py b/src/deepiri_modelkit/data/monitoring.py index 3fe2d65..f8e1ffc 100644 --- a/src/deepiri_modelkit/data/monitoring.py +++ b/src/deepiri_modelkit/data/monitoring.py @@ -1,375 +1,419 @@ -""" -Dataset Monitoring and Logging Utilities -Provides monitoring, alerting, and logging for dataset versioning operations -""" -import json -import time -from pathlib import Path -from typing import Dict, List, Any, Optional -from datetime import datetime, timedelta -import statistics - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.data.monitoring") - - -class DatasetMonitor: - """ - Monitors dataset versioning operations and provides insights. 
- - Features: - - Operation metrics and performance tracking - - Dataset health monitoring - - Usage analytics - - Alerting for data quality issues - """ - - def __init__(self, log_dir: str = "./logs/dataset_monitoring"): - self.log_dir = Path(log_dir) - self.log_dir.mkdir(parents=True, exist_ok=True) - - # Metrics storage - self.metrics_file = self.log_dir / "metrics.jsonl" - self.alerts_file = self.log_dir / "alerts.jsonl" - - # In-memory metrics for quick access - self.current_metrics = { - "total_versions_created": 0, - "total_datasets_tracked": 0, - "average_version_creation_time": 0, - "validation_errors_today": 0, - "last_health_check": None, - "storage_usage_bytes": 0 - } - - self._load_metrics() - - def log_version_creation(self, operation_data: Dict[str, Any]): - """Log dataset version creation operation.""" - log_entry = { - "timestamp": datetime.utcnow().isoformat(), - "operation": "version_creation", - "dataset_name": operation_data.get("dataset_name"), - "version": operation_data.get("version"), - "dataset_type": operation_data.get("dataset_type"), - "total_samples": operation_data.get("total_samples"), - "file_count": operation_data.get("file_count"), - "creation_time_seconds": operation_data.get("creation_time", 0), - "change_type": operation_data.get("change_type"), - "quality_score": operation_data.get("quality_score"), - "storage_path": operation_data.get("storage_path"), - "created_by": operation_data.get("created_by") - } - - self._write_log_entry(self.metrics_file, log_entry) - self.current_metrics["total_versions_created"] += 1 - - logger.info("Version creation logged", - dataset=operation_data.get("dataset_name"), - version=operation_data.get("version")) - - def log_validation_result(self, validation_data: Dict[str, Any]): - """Log dataset validation results.""" - log_entry = { - "timestamp": datetime.utcnow().isoformat(), - "operation": "validation", - "dataset_name": validation_data.get("dataset_name"), - "version": 
validation_data.get("version"), - "is_valid": validation_data.get("is_valid"), - "quality_score": validation_data.get("quality_score"), - "error_count": len(validation_data.get("errors", [])), - "warning_count": len(validation_data.get("warnings", [])), - "validation_time_seconds": validation_data.get("validation_time", 0) - } - - self._write_log_entry(self.metrics_file, log_entry) - - if not validation_data.get("is_valid", True): - self.current_metrics["validation_errors_today"] += 1 - - # Check for alerts - if validation_data.get("quality_score", 1.0) < 0.7: - self._create_alert("low_quality_score", { - "dataset_name": validation_data.get("dataset_name"), - "version": validation_data.get("version"), - "quality_score": validation_data.get("quality_score"), - "errors": validation_data.get("errors", []) - }) - - logger.info("Validation result logged", - dataset=validation_data.get("dataset_name"), - valid=validation_data.get("is_valid"), - quality=validation_data.get("quality_score")) - - def log_training_usage(self, training_data: Dict[str, Any]): - """Log dataset usage in training.""" - log_entry = { - "timestamp": datetime.utcnow().isoformat(), - "operation": "training_usage", - "dataset_name": training_data.get("dataset_name"), - "dataset_version": training_data.get("dataset_version"), - "model_name": training_data.get("model_name"), - "training_duration_seconds": training_data.get("training_duration", 0), - "final_loss": training_data.get("final_loss"), - "experiment_id": training_data.get("experiment_id"), - "output_model_path": training_data.get("output_model_path") - } - - self._write_log_entry(self.metrics_file, log_entry) - - logger.info("Training usage logged", - dataset=training_data.get("dataset_name"), - version=training_data.get("dataset_version"), - model=training_data.get("model_name")) - - def get_health_report(self) -> Dict[str, Any]: - """Generate comprehensive health report.""" - report = { - "timestamp": datetime.utcnow().isoformat(), - 
"summary": { - "total_versions": self.current_metrics["total_versions_created"], - "datasets_tracked": self.current_metrics["total_datasets_tracked"], - "validation_errors_today": self.current_metrics["validation_errors_today"], - "storage_usage_gb": self.current_metrics["storage_usage_bytes"] / (1024**3) - }, - "performance": self._analyze_performance(), - "quality_trends": self._analyze_quality_trends(), - "alerts": self._get_recent_alerts(), - "recommendations": self._generate_recommendations() - } - - self.current_metrics["last_health_check"] = report["timestamp"] - return report - - def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: - """Get usage analytics for the specified period.""" - cutoff_date = datetime.utcnow() - timedelta(days=days) - - analytics = { - "period_days": days, - "version_creations": [], - "training_runs": [], - "validation_runs": [], - "popular_datasets": {}, - "quality_distribution": {} - } - - # Read logs and filter by date - if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: - for line in f: - try: - entry = json.loads(line.strip()) - entry_date = datetime.fromisoformat(entry["timestamp"]) - - if entry_date >= cutoff_date: - if entry["operation"] == "version_creation": - analytics["version_creations"].append(entry) - dataset = entry.get("dataset_name", "unknown") - analytics["popular_datasets"][dataset] = analytics["popular_datasets"].get(dataset, 0) + 1 - - elif entry["operation"] == "training_usage": - analytics["training_runs"].append(entry) - - elif entry["operation"] == "validation": - analytics["validation_runs"].append(entry) - quality = entry.get("quality_score", 0) - quality_bucket = f"{int(quality * 10) / 10:.1f}" - analytics["quality_distribution"][quality_bucket] = analytics["quality_distribution"].get(quality_bucket, 0) + 1 - - except json.JSONDecodeError: - continue - - return analytics - - def _analyze_performance(self) -> Dict[str, Any]: - """Analyze system performance metrics.""" - 
creation_times = [] - validation_times = [] - - if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: - for line in f: - try: - entry = json.loads(line.strip()) - if entry["operation"] == "version_creation": - if "creation_time_seconds" in entry: - creation_times.append(entry["creation_time_seconds"]) - elif entry["operation"] == "validation": - if "validation_time_seconds" in entry: - validation_times.append(entry["validation_time_seconds"]) - except json.JSONDecodeError: - continue - - return { - "avg_version_creation_time": statistics.mean(creation_times) if creation_times else 0, - "avg_validation_time": statistics.mean(validation_times) if validation_times else 0, - "total_operations": len(creation_times) + len(validation_times), - "creation_times": creation_times[-10:], # Last 10 - "validation_times": validation_times[-10:] # Last 10 - } - - def _analyze_quality_trends(self) -> Dict[str, Any]: - """Analyze quality trends over time.""" - quality_scores = [] - - if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: - for line in f: - try: - entry = json.loads(line.strip()) - if entry["operation"] == "validation" and "quality_score" in entry: - quality_scores.append(entry["quality_score"]) - except json.JSONDecodeError: - continue - - if not quality_scores: - return {"trend": "insufficient_data"} - - recent_scores = quality_scores[-20:] # Last 20 validations - avg_quality = statistics.mean(recent_scores) if recent_scores else 0 - - # Simple trend analysis - if len(recent_scores) >= 10: - first_half = recent_scores[:len(recent_scores)//2] - second_half = recent_scores[len(recent_scores)//2:] - - first_avg = statistics.mean(first_half) - second_avg = statistics.mean(second_half) - - if second_avg > first_avg + 0.05: - trend = "improving" - elif second_avg < first_avg - 0.05: - trend = "declining" - else: - trend = "stable" - else: - trend = "insufficient_data" - - return { - "average_quality": avg_quality, - "trend": trend, - 
"total_validations": len(quality_scores), - "quality_distribution": { - "excellent": len([s for s in quality_scores if s >= 0.9]), - "good": len([s for s in quality_scores if 0.7 <= s < 0.9]), - "fair": len([s for s in quality_scores if 0.5 <= s < 0.7]), - "poor": len([s for s in quality_scores if s < 0.5]) - } - } - - def _generate_recommendations(self) -> List[str]: - """Generate recommendations based on current state.""" - recommendations = [] - - # Check for frequent validation errors - if self.current_metrics["validation_errors_today"] > 5: - recommendations.append("High validation error rate detected. Review data quality processes.") - - # Check quality trends - quality_analysis = self._analyze_quality_trends() - if quality_analysis.get("trend") == "declining": - recommendations.append("Dataset quality is declining. Consider reviewing data sources and annotation processes.") - - # Check performance - performance = self._analyze_performance() - if performance.get("avg_version_creation_time", 0) > 300: # 5 minutes - recommendations.append("Version creation is slow. Consider optimizing storage or processing.") - - # General recommendations - if self.current_metrics["total_versions_created"] == 0: - recommendations.append("No dataset versions created yet. Start versioning your datasets for reproducibility.") - - if not recommendations: - recommendations.append("System operating normally. 
Continue regular monitoring.") - - return recommendations - - def _create_alert(self, alert_type: str, alert_data: Dict[str, Any]): - """Create an alert for monitoring.""" - alert_entry = { - "timestamp": datetime.utcnow().isoformat(), - "alert_type": alert_type, - "severity": "warning", # Could be "info", "warning", "error" - "data": alert_data, - "resolved": False - } - - self._write_log_entry(self.alerts_file, alert_entry) - - logger.warning("Alert created", - type=alert_type, - data=alert_data) - - def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: - """Get recent alerts.""" - alerts = [] - cutoff_time = datetime.utcnow() - timedelta(hours=hours) - - if self.alerts_file.exists(): - with open(self.alerts_file, 'r') as f: - for line in f: - try: - alert = json.loads(line.strip()) - alert_time = datetime.fromisoformat(alert["timestamp"]) - if alert_time >= cutoff_time: - alerts.append(alert) - except json.JSONDecodeError: - continue - - return alerts[-10:] # Return last 10 alerts - - def _write_log_entry(self, log_file: Path, entry: Dict[str, Any]): - """Write a log entry to file.""" - with open(log_file, 'a') as f: - f.write(json.dumps(entry) + '\n') - - def _load_metrics(self): - """Load current metrics from log files.""" - if self.metrics_file.exists(): - try: - with open(self.metrics_file, 'r') as f: - lines = f.readlines() - if lines: - # Count operations from logs - version_count = sum(1 for line in lines if '"operation": "version_creation"' in line) - validation_count = sum(1 for line in lines if '"operation": "validation"' in line and '"is_valid": false' in line) - - self.current_metrics["total_versions_created"] = version_count - self.current_metrics["validation_errors_today"] = validation_count - except Exception as e: - logger.warning("Failed to load metrics from log", error=str(e)) - - -# Convenience functions -def log_version_creation(dataset_name: str, version: str, **kwargs): - """Convenience function to log version creation.""" 
- monitor = DatasetMonitor() - monitor.log_version_creation({ - "dataset_name": dataset_name, - "version": version, - **kwargs - }) - - -def log_validation_result(dataset_name: str, version: str, **kwargs): - """Convenience function to log validation results.""" - monitor = DatasetMonitor() - monitor.log_validation_result({ - "dataset_name": dataset_name, - "version": version, - **kwargs - }) - - -def get_health_report(): - """Get current health report.""" - monitor = DatasetMonitor() - return monitor.get_health_report() - - -def get_usage_analytics(days: int = 30): - """Get usage analytics.""" - monitor = DatasetMonitor() - return monitor.get_usage_analytics(days) +""" +Dataset Monitoring and Logging Utilities +Provides monitoring, alerting, and logging for dataset versioning operations +""" + +import json +import time +from pathlib import Path +from typing import Dict, List, Any, Optional +from datetime import datetime, timedelta +import statistics + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.data.monitoring") + + +class DatasetMonitor: + """ + Monitors dataset versioning operations and provides insights. 
+ + Features: + - Operation metrics and performance tracking + - Dataset health monitoring + - Usage analytics + - Alerting for data quality issues + """ + + def __init__(self, log_dir: str = "./logs/dataset_monitoring"): + self.log_dir = Path(log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + + # Metrics storage + self.metrics_file = self.log_dir / "metrics.jsonl" + self.alerts_file = self.log_dir / "alerts.jsonl" + + # In-memory metrics for quick access + self.current_metrics = { + "total_versions_created": 0, + "total_datasets_tracked": 0, + "average_version_creation_time": 0, + "validation_errors_today": 0, + "last_health_check": None, + "storage_usage_bytes": 0, + } + + self._load_metrics() + + def log_version_creation(self, operation_data: Dict[str, Any]): + """Log dataset version creation operation.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "operation": "version_creation", + "dataset_name": operation_data.get("dataset_name"), + "version": operation_data.get("version"), + "dataset_type": operation_data.get("dataset_type"), + "total_samples": operation_data.get("total_samples"), + "file_count": operation_data.get("file_count"), + "creation_time_seconds": operation_data.get("creation_time", 0), + "change_type": operation_data.get("change_type"), + "quality_score": operation_data.get("quality_score"), + "storage_path": operation_data.get("storage_path"), + "created_by": operation_data.get("created_by"), + } + + self._write_log_entry(self.metrics_file, log_entry) + self.current_metrics["total_versions_created"] += 1 + + logger.info( + "Version creation logged", + dataset=operation_data.get("dataset_name"), + version=operation_data.get("version"), + ) + + def log_validation_result(self, validation_data: Dict[str, Any]): + """Log dataset validation results.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "operation": "validation", + "dataset_name": validation_data.get("dataset_name"), + "version": 
validation_data.get("version"), + "is_valid": validation_data.get("is_valid"), + "quality_score": validation_data.get("quality_score"), + "error_count": len(validation_data.get("errors", [])), + "warning_count": len(validation_data.get("warnings", [])), + "validation_time_seconds": validation_data.get("validation_time", 0), + } + + self._write_log_entry(self.metrics_file, log_entry) + + if not validation_data.get("is_valid", True): + self.current_metrics["validation_errors_today"] += 1 + + # Check for alerts + if validation_data.get("quality_score", 1.0) < 0.7: + self._create_alert( + "low_quality_score", + { + "dataset_name": validation_data.get("dataset_name"), + "version": validation_data.get("version"), + "quality_score": validation_data.get("quality_score"), + "errors": validation_data.get("errors", []), + }, + ) + + logger.info( + "Validation result logged", + dataset=validation_data.get("dataset_name"), + valid=validation_data.get("is_valid"), + quality=validation_data.get("quality_score"), + ) + + def log_training_usage(self, training_data: Dict[str, Any]): + """Log dataset usage in training.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "operation": "training_usage", + "dataset_name": training_data.get("dataset_name"), + "dataset_version": training_data.get("dataset_version"), + "model_name": training_data.get("model_name"), + "training_duration_seconds": training_data.get("training_duration", 0), + "final_loss": training_data.get("final_loss"), + "experiment_id": training_data.get("experiment_id"), + "output_model_path": training_data.get("output_model_path"), + } + + self._write_log_entry(self.metrics_file, log_entry) + + logger.info( + "Training usage logged", + dataset=training_data.get("dataset_name"), + version=training_data.get("dataset_version"), + model=training_data.get("model_name"), + ) + + def get_health_report(self) -> Dict[str, Any]: + """Generate comprehensive health report.""" + report = { + "timestamp": 
datetime.utcnow().isoformat(), + "summary": { + "total_versions": self.current_metrics["total_versions_created"], + "datasets_tracked": self.current_metrics["total_datasets_tracked"], + "validation_errors_today": self.current_metrics[ + "validation_errors_today" + ], + "storage_usage_gb": self.current_metrics["storage_usage_bytes"] + / (1024**3), + }, + "performance": self._analyze_performance(), + "quality_trends": self._analyze_quality_trends(), + "alerts": self._get_recent_alerts(), + "recommendations": self._generate_recommendations(), + } + + self.current_metrics["last_health_check"] = report["timestamp"] + return report + + def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: + """Get usage analytics for the specified period.""" + cutoff_date = datetime.utcnow() - timedelta(days=days) + + analytics = { + "period_days": days, + "version_creations": [], + "training_runs": [], + "validation_runs": [], + "popular_datasets": {}, + "quality_distribution": {}, + } + + # Read logs and filter by date + if self.metrics_file.exists(): + with open(self.metrics_file, "r") as f: + for line in f: + try: + entry = json.loads(line.strip()) + entry_date = datetime.fromisoformat(entry["timestamp"]) + + if entry_date >= cutoff_date: + if entry["operation"] == "version_creation": + analytics["version_creations"].append(entry) + dataset = entry.get("dataset_name", "unknown") + analytics["popular_datasets"][dataset] = ( + analytics["popular_datasets"].get(dataset, 0) + 1 + ) + + elif entry["operation"] == "training_usage": + analytics["training_runs"].append(entry) + + elif entry["operation"] == "validation": + analytics["validation_runs"].append(entry) + quality = entry.get("quality_score", 0) + quality_bucket = f"{int(quality * 10) / 10:.1f}" + analytics["quality_distribution"][quality_bucket] = ( + analytics["quality_distribution"].get( + quality_bucket, 0 + ) + + 1 + ) + + except json.JSONDecodeError: + continue + + return analytics + + def 
_analyze_performance(self) -> Dict[str, Any]: + """Analyze system performance metrics.""" + creation_times = [] + validation_times = [] + + if self.metrics_file.exists(): + with open(self.metrics_file, "r") as f: + for line in f: + try: + entry = json.loads(line.strip()) + if entry["operation"] == "version_creation": + if "creation_time_seconds" in entry: + creation_times.append(entry["creation_time_seconds"]) + elif entry["operation"] == "validation": + if "validation_time_seconds" in entry: + validation_times.append( + entry["validation_time_seconds"] + ) + except json.JSONDecodeError: + continue + + return { + "avg_version_creation_time": ( + statistics.mean(creation_times) if creation_times else 0 + ), + "avg_validation_time": ( + statistics.mean(validation_times) if validation_times else 0 + ), + "total_operations": len(creation_times) + len(validation_times), + "creation_times": creation_times[-10:], # Last 10 + "validation_times": validation_times[-10:], # Last 10 + } + + def _analyze_quality_trends(self) -> Dict[str, Any]: + """Analyze quality trends over time.""" + quality_scores = [] + + if self.metrics_file.exists(): + with open(self.metrics_file, "r") as f: + for line in f: + try: + entry = json.loads(line.strip()) + if ( + entry["operation"] == "validation" + and "quality_score" in entry + ): + quality_scores.append(entry["quality_score"]) + except json.JSONDecodeError: + continue + + if not quality_scores: + return {"trend": "insufficient_data"} + + recent_scores = quality_scores[-20:] # Last 20 validations + avg_quality = statistics.mean(recent_scores) if recent_scores else 0 + + # Simple trend analysis + if len(recent_scores) >= 10: + first_half = recent_scores[: len(recent_scores) // 2] + second_half = recent_scores[len(recent_scores) // 2 :] + + first_avg = statistics.mean(first_half) + second_avg = statistics.mean(second_half) + + if second_avg > first_avg + 0.05: + trend = "improving" + elif second_avg < first_avg - 0.05: + trend = "declining" + 
else: + trend = "stable" + else: + trend = "insufficient_data" + + return { + "average_quality": avg_quality, + "trend": trend, + "total_validations": len(quality_scores), + "quality_distribution": { + "excellent": len([s for s in quality_scores if s >= 0.9]), + "good": len([s for s in quality_scores if 0.7 <= s < 0.9]), + "fair": len([s for s in quality_scores if 0.5 <= s < 0.7]), + "poor": len([s for s in quality_scores if s < 0.5]), + }, + } + + def _generate_recommendations(self) -> List[str]: + """Generate recommendations based on current state.""" + recommendations = [] + + # Check for frequent validation errors + if self.current_metrics["validation_errors_today"] > 5: + recommendations.append( + "High validation error rate detected. Review data quality processes." + ) + + # Check quality trends + quality_analysis = self._analyze_quality_trends() + if quality_analysis.get("trend") == "declining": + recommendations.append( + "Dataset quality is declining. Consider reviewing data sources and annotation processes." + ) + + # Check performance + performance = self._analyze_performance() + if performance.get("avg_version_creation_time", 0) > 300: # 5 minutes + recommendations.append( + "Version creation is slow. Consider optimizing storage or processing." + ) + + # General recommendations + if self.current_metrics["total_versions_created"] == 0: + recommendations.append( + "No dataset versions created yet. Start versioning your datasets for reproducibility." + ) + + if not recommendations: + recommendations.append( + "System operating normally. Continue regular monitoring." 
+ ) + + return recommendations + + def _create_alert(self, alert_type: str, alert_data: Dict[str, Any]): + """Create an alert for monitoring.""" + alert_entry = { + "timestamp": datetime.utcnow().isoformat(), + "alert_type": alert_type, + "severity": "warning", # Could be "info", "warning", "error" + "data": alert_data, + "resolved": False, + } + + self._write_log_entry(self.alerts_file, alert_entry) + + logger.warning("Alert created", type=alert_type, data=alert_data) + + def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: + """Get recent alerts.""" + alerts = [] + cutoff_time = datetime.utcnow() - timedelta(hours=hours) + + if self.alerts_file.exists(): + with open(self.alerts_file, "r") as f: + for line in f: + try: + alert = json.loads(line.strip()) + alert_time = datetime.fromisoformat(alert["timestamp"]) + if alert_time >= cutoff_time: + alerts.append(alert) + except json.JSONDecodeError: + continue + + return alerts[-10:] # Return last 10 alerts + + def _write_log_entry(self, log_file: Path, entry: Dict[str, Any]): + """Write a log entry to file.""" + with open(log_file, "a") as f: + f.write(json.dumps(entry) + "\n") + + def _load_metrics(self): + """Load current metrics from log files.""" + if self.metrics_file.exists(): + try: + with open(self.metrics_file, "r") as f: + lines = f.readlines() + if lines: + # Count operations from logs + version_count = sum( + 1 + for line in lines + if '"operation": "version_creation"' in line + ) + validation_count = sum( + 1 + for line in lines + if '"operation": "validation"' in line + and '"is_valid": false' in line + ) + + self.current_metrics["total_versions_created"] = version_count + self.current_metrics["validation_errors_today"] = ( + validation_count + ) + except Exception as e: + logger.warning("Failed to load metrics from log", error=str(e)) + + +# Convenience functions +def log_version_creation(dataset_name: str, version: str, **kwargs): + """Convenience function to log version creation.""" 
+ monitor = DatasetMonitor() + monitor.log_version_creation( + {"dataset_name": dataset_name, "version": version, **kwargs} + ) + + +def log_validation_result(dataset_name: str, version: str, **kwargs): + """Convenience function to log validation results.""" + monitor = DatasetMonitor() + monitor.log_validation_result( + {"dataset_name": dataset_name, "version": version, **kwargs} + ) + + +def get_health_report(): + """Get current health report.""" + monitor = DatasetMonitor() + return monitor.get_health_report() + + +def get_usage_analytics(days: int = 30): + """Get usage analytics.""" + monitor = DatasetMonitor() + return monitor.get_usage_analytics(days) diff --git a/src/deepiri_modelkit/data/validation.py b/src/deepiri_modelkit/data/validation.py index 25e7b2e..acc2aa8 100644 --- a/src/deepiri_modelkit/data/validation.py +++ b/src/deepiri_modelkit/data/validation.py @@ -1,364 +1,400 @@ -""" -Dataset Validation Utilities -Provides validation and quality checks for language intelligence datasets -""" -import json -from pathlib import Path -from typing import Dict, List, Any, Optional -import re -from collections import Counter - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.data.validation") - - -class DatasetValidator: - """ - Validates dataset quality and integrity for language intelligence tasks. 
- - Supports validation for: - - Lease abstraction datasets - - Contract intelligence datasets - - General text quality checks - """ - - def __init__(self, dataset_type: str = "general"): - self.dataset_type = dataset_type - self.validation_rules = self._get_validation_rules() - - def _get_validation_rules(self) -> Dict[str, Any]: - """Get validation rules based on dataset type.""" - base_rules = { - "min_samples": 10, - "max_samples": 100000, - "min_text_length": 10, - "max_text_length": 10000, - "required_fields": ["text"], - "text_quality_checks": True - } - - type_specific_rules = { - "lease_abstraction": { - "min_samples": 50, - "lease_keywords": [ - "lease", "agreement", "landlord", "tenant", "rent", - "premises", "term", "commencement", "expiration" - ], - "min_keyword_matches": 2, - "check_address_patterns": True, - "check_rent_patterns": True - }, - "contract_intelligence": { - "min_samples": 50, - "contract_keywords": [ - "contract", "agreement", "party", "obligation", - "clause", "provision", "section", "article" - ], - "min_keyword_matches": 2, - "check_legal_patterns": True - } - } - - if self.dataset_type in type_specific_rules: - base_rules.update(type_specific_rules[self.dataset_type]) - - return base_rules - - def validate_dataset(self, data_path: Path) -> Dict[str, Any]: - """ - Comprehensive dataset validation. 
- - Args: - data_path: Path to dataset files - - Returns: - Validation results dictionary - """ - logger.info("Starting dataset validation", path=str(data_path), type=self.dataset_type) - - results = { - "is_valid": True, - "errors": [], - "warnings": [], - "statistics": {}, - "quality_score": 0.0 - } - - try: - # Load and parse data - samples = self._load_samples(data_path) - results["statistics"]["total_samples"] = len(samples) - - if not samples: - results["is_valid"] = False - results["errors"].append("No samples found in dataset") - return results - - # Basic structure validation - self._validate_structure(samples, results) - - # Content quality validation - if results["is_valid"]: - self._validate_content_quality(samples, results) - - # Type-specific validation - if self.dataset_type != "general": - self._validate_type_specific(samples, results) - - # Calculate overall quality score - results["quality_score"] = self._calculate_quality_score(results) - - # Determine final validity - results["is_valid"] = len(results["errors"]) == 0 - - except Exception as e: - results["is_valid"] = False - results["errors"].append(f"Validation failed with error: {str(e)}") - logger.error("Dataset validation failed", error=str(e)) - - logger.info("Dataset validation complete", - valid=results["is_valid"], - quality_score=results["quality_score"], - errors=len(results["errors"]), - warnings=len(results["warnings"])) - - return results - - def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: - """Load samples from dataset files.""" - samples = [] - - if data_path.is_file() and data_path.suffix == ".jsonl": - with open(data_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if line: - try: - sample = json.loads(line) - samples.append(sample) - except json.JSONDecodeError as e: - logger.warning(f"Invalid JSON at line {line_num}: {e}") - - elif data_path.is_dir(): - for file_path in data_path.glob("*.jsonl"): - with 
open(file_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if line: - try: - sample = json.loads(line) - samples.append(sample) - except json.JSONDecodeError as e: - logger.warning(f"Invalid JSON in {file_path} at line {line_num}: {e}") - - return samples - - def _validate_structure(self, samples: List[Dict], results: Dict): - """Validate basic dataset structure.""" - if len(samples) < self.validation_rules["min_samples"]: - results["errors"].append( - f"Insufficient samples: {len(samples)} < {self.validation_rules['min_samples']}" - ) - - if len(samples) > self.validation_rules["max_samples"]: - results["warnings"].append( - f"Large dataset: {len(samples)} > {self.validation_rules['max_samples']}" - ) - - # Check required fields - required_fields = self.validation_rules["required_fields"] - for i, sample in enumerate(samples[:100]): # Check first 100 samples - for field in required_fields: - if field not in sample: - results["errors"].append(f"Missing required field '{field}' in sample {i}") - - def _validate_content_quality(self, samples: List[Dict], results: Dict): - """Validate content quality.""" - text_lengths = [] - empty_texts = 0 - duplicate_texts = set() - seen_texts = set() - - for sample in samples: - text = sample.get("text", "").strip() - - # Check text length - text_len = len(text) - text_lengths.append(text_len) - - if text_len < self.validation_rules["min_text_length"]: - results["errors"].append(f"Text too short: {text_len} chars") - elif text_len > self.validation_rules["max_text_length"]: - results["warnings"].append(f"Text too long: {text_len} chars") - - if not text: - empty_texts += 1 - - # Check for duplicates - if text in seen_texts: - duplicate_texts.add(text) - else: - seen_texts.add(text) - - # Statistics - results["statistics"].update({ - "avg_text_length": sum(text_lengths) / len(text_lengths) if text_lengths else 0, - "min_text_length": min(text_lengths) if text_lengths else 0, - 
"max_text_length": max(text_lengths) if text_lengths else 0, - "empty_texts": empty_texts, - "duplicate_texts": len(duplicate_texts) - }) - - if empty_texts > 0: - results["errors"].append(f"Found {empty_texts} empty text samples") - - if len(duplicate_texts) > len(samples) * 0.01: # >1% duplicates - results["warnings"].append(f"High duplicate rate: {len(duplicate_texts)} duplicates") - - def _validate_type_specific(self, samples: List[Dict], results: Dict): - """Type-specific validation.""" - if self.dataset_type == "lease_abstraction": - self._validate_lease_abstraction(samples, results) - elif self.dataset_type == "contract_intelligence": - self._validate_contract_intelligence(samples, results) - - def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): - """Validate lease abstraction dataset.""" - keywords = self.validation_rules["lease_keywords"] - min_matches = self.validation_rules["min_keyword_matches"] - - low_keyword_samples = 0 - address_pattern_matches = 0 - rent_pattern_matches = 0 - - # Address patterns (street numbers, street names, cities) - address_pattern = r'\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}' - - # Rent patterns (dollar amounts) - rent_pattern = r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?' 
- - for sample in samples[:500]: # Check first 500 samples for performance - text = sample.get("text", "").lower() - - # Keyword matching - keyword_matches = sum(1 for keyword in keywords if keyword in text) - if keyword_matches < min_matches: - low_keyword_samples += 1 - - # Pattern matching - if re.search(address_pattern, sample.get("text", "")): - address_pattern_matches += 1 - - if re.search(rent_pattern, sample.get("text", "")): - rent_pattern_matches += 1 - - total_checked = min(500, len(samples)) - keyword_failure_rate = low_keyword_samples / total_checked - - if keyword_failure_rate > 0.3: # >30% samples lack keywords - results["warnings"].append( - f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack lease keywords" - ) - - results["statistics"].update({ - "address_pattern_matches": address_pattern_matches, - "rent_pattern_matches": rent_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate - }) - - def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): - """Validate contract intelligence dataset.""" - keywords = self.validation_rules["contract_keywords"] - min_matches = self.validation_rules["min_keyword_matches"] - - low_keyword_samples = 0 - legal_pattern_matches = 0 - - # Legal clause patterns - legal_patterns = [ - r'\bsection\s+\d+', - r'\barticle\s+\d+', - r'\bclause\s+\d+', - r'\bparagraph\s+\d+', - r'\bsubsection\s+\d+' - ] - - for sample in samples[:500]: # Check first 500 samples - text = sample.get("text", "").lower() - - # Keyword matching - keyword_matches = sum(1 for keyword in keywords if keyword in text) - if keyword_matches < min_matches: - low_keyword_samples += 1 - - # Legal pattern matching - if any(re.search(pattern, sample.get("text", "")) for pattern in legal_patterns): - legal_pattern_matches += 1 - - total_checked = min(500, len(samples)) - keyword_failure_rate = low_keyword_samples / total_checked - - if keyword_failure_rate > 0.3: - results["warnings"].append( - f"Low 
keyword relevance: {keyword_failure_rate:.1%} samples lack contract keywords" - ) - - results["statistics"].update({ - "legal_pattern_matches": legal_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate - }) - - def _calculate_quality_score(self, results: Dict) -> float: - """Calculate overall quality score (0.0 to 1.0).""" - score = 1.0 - - # Penalize errors heavily - error_penalty = len(results["errors"]) * 0.2 - score -= min(error_penalty, 0.8) - - # Penalize warnings moderately - warning_penalty = len(results["warnings"]) * 0.05 - score -= min(warning_penalty, 0.2) - - # Bonus for good statistics - stats = results["statistics"] - - if stats.get("avg_text_length", 0) > 100: - score += 0.05 # Good average text length - - if stats.get("duplicate_texts", 0) == 0: - score += 0.1 # No duplicates - - if stats.get("empty_texts", 0) == 0: - score += 0.1 # No empty texts - - # Type-specific bonuses - if self.dataset_type == "lease_abstraction": - if stats.get("keyword_relevance_score", 0) > 0.7: - score += 0.1 - if stats.get("address_pattern_matches", 0) > 0: - score += 0.05 - - elif self.dataset_type == "contract_intelligence": - if stats.get("keyword_relevance_score", 0) > 0.7: - score += 0.1 - if stats.get("legal_pattern_matches", 0) > 0: - score += 0.05 - - return max(0.0, min(1.0, score)) - - -def validate_dataset_quality(data_path: Path, dataset_type: str = "general") -> Dict[str, Any]: - """ - Convenience function to validate dataset quality. 
- - Args: - data_path: Path to dataset - dataset_type: Type of dataset for specialized validation - - Returns: - Validation results - """ - validator = DatasetValidator(dataset_type) - return validator.validate_dataset(data_path) +""" +Dataset Validation Utilities +Provides validation and quality checks for language intelligence datasets +""" + +import json +from pathlib import Path +from typing import Dict, List, Any, Optional +import re +from collections import Counter + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.data.validation") + + +class DatasetValidator: + """ + Validates dataset quality and integrity for language intelligence tasks. + + Supports validation for: + - Lease abstraction datasets + - Contract intelligence datasets + - General text quality checks + """ + + def __init__(self, dataset_type: str = "general"): + self.dataset_type = dataset_type + self.validation_rules = self._get_validation_rules() + + def _get_validation_rules(self) -> Dict[str, Any]: + """Get validation rules based on dataset type.""" + base_rules = { + "min_samples": 10, + "max_samples": 100000, + "min_text_length": 10, + "max_text_length": 10000, + "required_fields": ["text"], + "text_quality_checks": True, + } + + type_specific_rules = { + "lease_abstraction": { + "min_samples": 50, + "lease_keywords": [ + "lease", + "agreement", + "landlord", + "tenant", + "rent", + "premises", + "term", + "commencement", + "expiration", + ], + "min_keyword_matches": 2, + "check_address_patterns": True, + "check_rent_patterns": True, + }, + "contract_intelligence": { + "min_samples": 50, + "contract_keywords": [ + "contract", + "agreement", + "party", + "obligation", + "clause", + "provision", + "section", + "article", + ], + "min_keyword_matches": 2, + "check_legal_patterns": True, + }, + } + + if self.dataset_type in type_specific_rules: + base_rules.update(type_specific_rules[self.dataset_type]) + + return base_rules + + def 
validate_dataset(self, data_path: Path) -> Dict[str, Any]: + """ + Comprehensive dataset validation. + + Args: + data_path: Path to dataset files + + Returns: + Validation results dictionary + """ + logger.info( + "Starting dataset validation", path=str(data_path), type=self.dataset_type + ) + + results = { + "is_valid": True, + "errors": [], + "warnings": [], + "statistics": {}, + "quality_score": 0.0, + } + + try: + # Load and parse data + samples = self._load_samples(data_path) + results["statistics"]["total_samples"] = len(samples) + + if not samples: + results["is_valid"] = False + results["errors"].append("No samples found in dataset") + return results + + # Basic structure validation + self._validate_structure(samples, results) + + # Content quality validation + if results["is_valid"]: + self._validate_content_quality(samples, results) + + # Type-specific validation + if self.dataset_type != "general": + self._validate_type_specific(samples, results) + + # Calculate overall quality score + results["quality_score"] = self._calculate_quality_score(results) + + # Determine final validity + results["is_valid"] = len(results["errors"]) == 0 + + except Exception as e: + results["is_valid"] = False + results["errors"].append(f"Validation failed with error: {str(e)}") + logger.error("Dataset validation failed", error=str(e)) + + logger.info( + "Dataset validation complete", + valid=results["is_valid"], + quality_score=results["quality_score"], + errors=len(results["errors"]), + warnings=len(results["warnings"]), + ) + + return results + + def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: + """Load samples from dataset files.""" + samples = [] + + if data_path.is_file() and data_path.suffix == ".jsonl": + with open(data_path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + sample = json.loads(line) + samples.append(sample) + except json.JSONDecodeError as e: + logger.warning(f"Invalid 
JSON at line {line_num}: {e}") + + elif data_path.is_dir(): + for file_path in data_path.glob("*.jsonl"): + with open(file_path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + sample = json.loads(line) + samples.append(sample) + except json.JSONDecodeError as e: + logger.warning( + f"Invalid JSON in {file_path} at line {line_num}: {e}" + ) + + return samples + + def _validate_structure(self, samples: List[Dict], results: Dict): + """Validate basic dataset structure.""" + if len(samples) < self.validation_rules["min_samples"]: + results["errors"].append( + f"Insufficient samples: {len(samples)} < {self.validation_rules['min_samples']}" + ) + + if len(samples) > self.validation_rules["max_samples"]: + results["warnings"].append( + f"Large dataset: {len(samples)} > {self.validation_rules['max_samples']}" + ) + + # Check required fields + required_fields = self.validation_rules["required_fields"] + for i, sample in enumerate(samples[:100]): # Check first 100 samples + for field in required_fields: + if field not in sample: + results["errors"].append( + f"Missing required field '{field}' in sample {i}" + ) + + def _validate_content_quality(self, samples: List[Dict], results: Dict): + """Validate content quality.""" + text_lengths = [] + empty_texts = 0 + duplicate_texts = set() + seen_texts = set() + + for sample in samples: + text = sample.get("text", "").strip() + + # Check text length + text_len = len(text) + text_lengths.append(text_len) + + if text_len < self.validation_rules["min_text_length"]: + results["errors"].append(f"Text too short: {text_len} chars") + elif text_len > self.validation_rules["max_text_length"]: + results["warnings"].append(f"Text too long: {text_len} chars") + + if not text: + empty_texts += 1 + + # Check for duplicates + if text in seen_texts: + duplicate_texts.add(text) + else: + seen_texts.add(text) + + # Statistics + results["statistics"].update( + { + "avg_text_length": ( + 
sum(text_lengths) / len(text_lengths) if text_lengths else 0 + ), + "min_text_length": min(text_lengths) if text_lengths else 0, + "max_text_length": max(text_lengths) if text_lengths else 0, + "empty_texts": empty_texts, + "duplicate_texts": len(duplicate_texts), + } + ) + + if empty_texts > 0: + results["errors"].append(f"Found {empty_texts} empty text samples") + + if len(duplicate_texts) > len(samples) * 0.01: # >1% duplicates + results["warnings"].append( + f"High duplicate rate: {len(duplicate_texts)} duplicates" + ) + + def _validate_type_specific(self, samples: List[Dict], results: Dict): + """Type-specific validation.""" + if self.dataset_type == "lease_abstraction": + self._validate_lease_abstraction(samples, results) + elif self.dataset_type == "contract_intelligence": + self._validate_contract_intelligence(samples, results) + + def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): + """Validate lease abstraction dataset.""" + keywords = self.validation_rules["lease_keywords"] + min_matches = self.validation_rules["min_keyword_matches"] + + low_keyword_samples = 0 + address_pattern_matches = 0 + rent_pattern_matches = 0 + + # Address patterns (street numbers, street names, cities) + address_pattern = r"\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}" + + # Rent patterns (dollar amounts) + rent_pattern = r"\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?" 
+ + for sample in samples[:500]: # Check first 500 samples for performance + text = sample.get("text", "").lower() + + # Keyword matching + keyword_matches = sum(1 for keyword in keywords if keyword in text) + if keyword_matches < min_matches: + low_keyword_samples += 1 + + # Pattern matching + if re.search(address_pattern, sample.get("text", "")): + address_pattern_matches += 1 + + if re.search(rent_pattern, sample.get("text", "")): + rent_pattern_matches += 1 + + total_checked = min(500, len(samples)) + keyword_failure_rate = low_keyword_samples / total_checked + + if keyword_failure_rate > 0.3: # >30% samples lack keywords + results["warnings"].append( + f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack lease keywords" + ) + + results["statistics"].update( + { + "address_pattern_matches": address_pattern_matches, + "rent_pattern_matches": rent_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate, + } + ) + + def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): + """Validate contract intelligence dataset.""" + keywords = self.validation_rules["contract_keywords"] + min_matches = self.validation_rules["min_keyword_matches"] + + low_keyword_samples = 0 + legal_pattern_matches = 0 + + # Legal clause patterns + legal_patterns = [ + r"\bsection\s+\d+", + r"\barticle\s+\d+", + r"\bclause\s+\d+", + r"\bparagraph\s+\d+", + r"\bsubsection\s+\d+", + ] + + for sample in samples[:500]: # Check first 500 samples + text = sample.get("text", "").lower() + + # Keyword matching + keyword_matches = sum(1 for keyword in keywords if keyword in text) + if keyword_matches < min_matches: + low_keyword_samples += 1 + + # Legal pattern matching + if any( + re.search(pattern, sample.get("text", "")) for pattern in legal_patterns + ): + legal_pattern_matches += 1 + + total_checked = min(500, len(samples)) + keyword_failure_rate = low_keyword_samples / total_checked + + if keyword_failure_rate > 0.3: + results["warnings"].append( 
+ f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack contract keywords" + ) + + results["statistics"].update( + { + "legal_pattern_matches": legal_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate, + } + ) + + def _calculate_quality_score(self, results: Dict) -> float: + """Calculate overall quality score (0.0 to 1.0).""" + score = 1.0 + + # Penalize errors heavily + error_penalty = len(results["errors"]) * 0.2 + score -= min(error_penalty, 0.8) + + # Penalize warnings moderately + warning_penalty = len(results["warnings"]) * 0.05 + score -= min(warning_penalty, 0.2) + + # Bonus for good statistics + stats = results["statistics"] + + if stats.get("avg_text_length", 0) > 100: + score += 0.05 # Good average text length + + if stats.get("duplicate_texts", 0) == 0: + score += 0.1 # No duplicates + + if stats.get("empty_texts", 0) == 0: + score += 0.1 # No empty texts + + # Type-specific bonuses + if self.dataset_type == "lease_abstraction": + if stats.get("keyword_relevance_score", 0) > 0.7: + score += 0.1 + if stats.get("address_pattern_matches", 0) > 0: + score += 0.05 + + elif self.dataset_type == "contract_intelligence": + if stats.get("keyword_relevance_score", 0) > 0.7: + score += 0.1 + if stats.get("legal_pattern_matches", 0) > 0: + score += 0.05 + + return max(0.0, min(1.0, score)) + + +def validate_dataset_quality( + data_path: Path, dataset_type: str = "general" +) -> Dict[str, Any]: + """ + Convenience function to validate dataset quality. 
+ + Args: + data_path: Path to dataset + dataset_type: Type of dataset for specialized validation + + Returns: + Validation results + """ + validator = DatasetValidator(dataset_type) + return validator.validate_dataset(data_path) diff --git a/src/deepiri_modelkit/logging.py b/src/deepiri_modelkit/logging.py index 9b65998..450eb0c 100644 --- a/src/deepiri_modelkit/logging.py +++ b/src/deepiri_modelkit/logging.py @@ -1,147 +1,171 @@ -""" -Shared logging utilities for all Deepiri services -Used by: Cyrex (runtime), Helox (training), and all microservices -""" -import logging -import sys -import json -from datetime import datetime -from typing import Any, Dict, Optional -from pathlib import Path - - -class StructuredLogger: - """JSON structured logger for all Deepiri services""" - - def __init__(self, name: str, level: int = logging.INFO): - self.logger = logging.getLogger(name) - self.logger.setLevel(level) - - # Remove existing handlers - self.logger.handlers = [] - - # Create console handler with JSON formatting - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(level) - handler.setFormatter(JsonFormatter()) - self.logger.addHandler(handler) - - self.logger.propagate = False - - def _log(self, level: int, event: str, **kwargs): - """Internal log method with structured data""" - extra = {"event": event, "timestamp": datetime.utcnow().isoformat() + "Z"} - extra.update(kwargs) - self.logger.log(level, json.dumps(extra)) - - def debug(self, event: str, **kwargs): - self._log(logging.DEBUG, event, **kwargs) - - def info(self, event: str, **kwargs): - self._log(logging.INFO, event, **kwargs) - - def warning(self, event: str, **kwargs): - self._log(logging.WARNING, event, **kwargs) - - def error(self, event: str, **kwargs): - self._log(logging.ERROR, event, **kwargs) - - def critical(self, event: str, **kwargs): - self._log(logging.CRITICAL, event, **kwargs) - - -class JsonFormatter(logging.Formatter): - """Format logs as JSON""" - - def format(self, record: 
logging.LogRecord) -> str: - log_data = { - "timestamp": datetime.utcnow().isoformat() + "Z", - "level": record.levelname, - "logger": record.name, - "message": record.getMessage(), - } - - # Add extra fields if present - if hasattr(record, 'event'): - log_data['event'] = record.event - - # Add any custom fields from extra={} - for key, value in record.__dict__.items(): - if key not in ['name', 'msg', 'args', 'created', 'filename', 'funcName', - 'levelname', 'levelno', 'lineno', 'module', 'msecs', - 'message', 'pathname', 'process', 'processName', - 'relativeCreated', 'thread', 'threadName', 'exc_info', - 'exc_text', 'stack_info', 'event', 'timestamp']: - log_data[key] = value - - return json.dumps(log_data) - - -def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: - """ - Get structured logger instance - - Usage: - from deepiri_modelkit.logging import get_logger - logger = get_logger("my_service") - logger.info("service_started", port=8000, version="1.0") - """ - return StructuredLogger(name, level) - - -class ErrorLogger: - """Error logging with context""" - - def __init__(self): - self.logger = get_logger("error_logger") - - def log_api_error(self, error: Exception, request_id: str, endpoint: str): - """Log API errors with context""" - self.logger.error( - "api_error", - error=str(error), - error_type=type(error).__name__, - request_id=request_id, - endpoint=endpoint - ) - - def log_model_error(self, error: Exception, model_name: str, input_data: Optional[Dict] = None): - """Log model inference errors""" - self.logger.error( - "model_error", - error=str(error), - error_type=type(error).__name__, - model_name=model_name, - input_sample=str(input_data)[:200] if input_data else None - ) - - def log_training_error(self, error: Exception, pipeline: str, config: Optional[Dict] = None): - """Log training pipeline errors""" - self.logger.error( - "training_error", - error=str(error), - error_type=type(error).__name__, - pipeline=pipeline, - 
config=config - ) - - -# Singleton instances -_loggers: Dict[str, StructuredLogger] = {} -_error_logger: Optional[ErrorLogger] = None - - -def get_cached_logger(name: str) -> StructuredLogger: - """Get or create cached logger instance""" - if name not in _loggers: - _loggers[name] = get_logger(name) - return _loggers[name] - - -def get_error_logger() -> ErrorLogger: - """Get singleton error logger""" - global _error_logger - if _error_logger is None: - _error_logger = ErrorLogger() - return _error_logger - +""" +Shared logging utilities for all Deepiri services +Used by: Cyrex (runtime), Helox (training), and all microservices +""" + +import logging +import sys +import json +from datetime import datetime +from typing import Any, Dict, Optional +from pathlib import Path + + +class StructuredLogger: + """JSON structured logger for all Deepiri services""" + + def __init__(self, name: str, level: int = logging.INFO): + self.logger = logging.getLogger(name) + self.logger.setLevel(level) + + # Remove existing handlers + self.logger.handlers = [] + + # Create console handler with JSON formatting + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(level) + handler.setFormatter(JsonFormatter()) + self.logger.addHandler(handler) + + self.logger.propagate = False + + def _log(self, level: int, event: str, **kwargs): + """Internal log method with structured data""" + extra = {"event": event, "timestamp": datetime.utcnow().isoformat() + "Z"} + extra.update(kwargs) + self.logger.log(level, json.dumps(extra)) + + def debug(self, event: str, **kwargs): + self._log(logging.DEBUG, event, **kwargs) + + def info(self, event: str, **kwargs): + self._log(logging.INFO, event, **kwargs) + + def warning(self, event: str, **kwargs): + self._log(logging.WARNING, event, **kwargs) + + def error(self, event: str, **kwargs): + self._log(logging.ERROR, event, **kwargs) + + def critical(self, event: str, **kwargs): + self._log(logging.CRITICAL, event, **kwargs) + + +class 
JsonFormatter(logging.Formatter): + """Format logs as JSON""" + + def format(self, record: logging.LogRecord) -> str: + log_data = { + "timestamp": datetime.utcnow().isoformat() + "Z", + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + + # Add extra fields if present + if hasattr(record, "event"): + log_data["event"] = record.event + + # Add any custom fields from extra={} + for key, value in record.__dict__.items(): + if key not in [ + "name", + "msg", + "args", + "created", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "message", + "pathname", + "process", + "processName", + "relativeCreated", + "thread", + "threadName", + "exc_info", + "exc_text", + "stack_info", + "event", + "timestamp", + ]: + log_data[key] = value + + return json.dumps(log_data) + + +def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: + """ + Get structured logger instance + + Usage: + from deepiri_modelkit.logging import get_logger + logger = get_logger("my_service") + logger.info("service_started", port=8000, version="1.0") + """ + return StructuredLogger(name, level) + + +class ErrorLogger: + """Error logging with context""" + + def __init__(self): + self.logger = get_logger("error_logger") + + def log_api_error(self, error: Exception, request_id: str, endpoint: str): + """Log API errors with context""" + self.logger.error( + "api_error", + error=str(error), + error_type=type(error).__name__, + request_id=request_id, + endpoint=endpoint, + ) + + def log_model_error( + self, error: Exception, model_name: str, input_data: Optional[Dict] = None + ): + """Log model inference errors""" + self.logger.error( + "model_error", + error=str(error), + error_type=type(error).__name__, + model_name=model_name, + input_sample=str(input_data)[:200] if input_data else None, + ) + + def log_training_error( + self, error: Exception, pipeline: str, config: Optional[Dict] = None + ): + """Log 
training pipeline errors""" + self.logger.error( + "training_error", + error=str(error), + error_type=type(error).__name__, + pipeline=pipeline, + config=config, + ) + + +# Singleton instances +_loggers: Dict[str, StructuredLogger] = {} +_error_logger: Optional[ErrorLogger] = None + + +def get_cached_logger(name: str) -> StructuredLogger: + """Get or create cached logger instance""" + if name not in _loggers: + _loggers[name] = get_logger(name) + return _loggers[name] + + +def get_error_logger() -> ErrorLogger: + """Get singleton error logger""" + global _error_logger + if _error_logger is None: + _error_logger = ErrorLogger() + return _error_logger diff --git a/src/deepiri_modelkit/ml/__init__.py b/src/deepiri_modelkit/ml/__init__.py index 09c30ff..19ab9c1 100644 --- a/src/deepiri_modelkit/ml/__init__.py +++ b/src/deepiri_modelkit/ml/__init__.py @@ -1,33 +1,35 @@ -"""ML utilities for Deepiri ModelKit""" - -try: - from .confidence import ( - ConfidenceLevel, - ConfidenceSource, - ConfidenceAttributes, - ConfidenceCalculator, - get_confidence_calculator, - ) - _HAS_CONFIDENCE = True -except ImportError: - _HAS_CONFIDENCE = False - -try: - from .semantic import SemanticAnalyzer, get_semantic_analyzer - _HAS_SEMANTIC = True -except ImportError: - _HAS_SEMANTIC = False - -__all__ = [] - -if _HAS_CONFIDENCE: - __all__ += [ - "ConfidenceLevel", - "ConfidenceSource", - "ConfidenceAttributes", - "ConfidenceCalculator", - "get_confidence_calculator", - ] - -if _HAS_SEMANTIC: - __all__ += ["SemanticAnalyzer", "get_semantic_analyzer"] +"""ML utilities for Deepiri ModelKit""" + +try: + from .confidence import ( + ConfidenceLevel, + ConfidenceSource, + ConfidenceAttributes, + ConfidenceCalculator, + get_confidence_calculator, + ) + + _HAS_CONFIDENCE = True +except ImportError: + _HAS_CONFIDENCE = False + +try: + from .semantic import SemanticAnalyzer, get_semantic_analyzer + + _HAS_SEMANTIC = True +except ImportError: + _HAS_SEMANTIC = False + +__all__ = [] + +if 
_HAS_CONFIDENCE: + __all__ += [ + "ConfidenceLevel", + "ConfidenceSource", + "ConfidenceAttributes", + "ConfidenceCalculator", + "get_confidence_calculator", + ] + +if _HAS_SEMANTIC: + __all__ += ["SemanticAnalyzer", "get_semantic_analyzer"] diff --git a/src/deepiri_modelkit/ml/confidence.py b/src/deepiri_modelkit/ml/confidence.py index 89d5f2b..5d78f84 100644 --- a/src/deepiri_modelkit/ml/confidence.py +++ b/src/deepiri_modelkit/ml/confidence.py @@ -1,292 +1,316 @@ -""" -Redesigned Confidence Classes System -Provides structured confidence assessment with multiple attributes -""" -from enum import Enum -from typing import Dict, List, Optional, Tuple -from dataclasses import dataclass - -try: - import numpy as np - HAS_NUMPY = True -except ImportError: - HAS_NUMPY = False - - -class ConfidenceLevel(str, Enum): - """Confidence level categories""" - VERY_HIGH = "very_high" # 0.9-1.0 - HIGH = "high" # 0.75-0.9 - MEDIUM = "medium" # 0.5-0.75 - LOW = "low" # 0.25-0.5 - VERY_LOW = "very_low" # 0.0-0.25 - - -class ConfidenceSource(str, Enum): - """Sources of confidence information""" - MODEL_PREDICTION = "model_prediction" - TRAINING_DATA_COVERAGE = "training_data_coverage" - FEATURE_QUALITY = "feature_quality" - CONTEXT_MATCH = "context_match" - HISTORICAL_ACCURACY = "historical_accuracy" - ENSEMBLE_AGREEMENT = "ensemble_agreement" - - -@dataclass -class ConfidenceAttributes: - """ - Comprehensive confidence attributes for model predictions - - Attributes: - base_score: Raw model confidence score (0.0-1.0) - level: Categorical confidence level - sources: Dictionary of confidence sources and their contributions - uncertainty: Measure of prediction uncertainty - calibration: How well-calibrated the prediction is - reliability: Overall reliability score - explanation: Human-readable explanation - """ - base_score: float - level: ConfidenceLevel - sources: Dict[str, float] - uncertainty: float - calibration: float - reliability: float - explanation: str - - def to_dict(self) 
-> Dict: - """Convert to dictionary""" - return { - "base_score": self.base_score, - "level": self.level.value, - "sources": self.sources, - "uncertainty": self.uncertainty, - "calibration": self.calibration, - "reliability": self.reliability, - "explanation": self.explanation - } - - -class ConfidenceCalculator: - """ - Calculate comprehensive confidence scores with multiple attributes - """ - - def __init__(self): - self.confidence_thresholds = { - ConfidenceLevel.VERY_HIGH: 0.9, - ConfidenceLevel.HIGH: 0.75, - ConfidenceLevel.MEDIUM: 0.5, - ConfidenceLevel.LOW: 0.25, - ConfidenceLevel.VERY_LOW: 0.0 - } - - def calculate_confidence( - self, - model_probabilities: "np.ndarray", - top_k_probs: Optional[List[float]] = None, - training_coverage: Optional[float] = None, - feature_quality: Optional[float] = None, - context_match: Optional[float] = None, - historical_accuracy: Optional[Dict[int, float]] = None - ) -> ConfidenceAttributes: - """ - Calculate comprehensive confidence attributes - - Args: - model_probabilities: Model output probabilities for all classes - top_k_probs: Top-k probabilities (for ensemble agreement) - training_coverage: How well training data covers this example (0-1) - feature_quality: Quality of input features (0-1) - context_match: How well context matches expected patterns (0-1) - historical_accuracy: Historical accuracy per class {class_id: accuracy} - - Returns: - ConfidenceAttributes object - """ - if not HAS_NUMPY: - raise ImportError("numpy is required for ConfidenceCalculator. 
Install with: pip install numpy") - - # Base score: maximum probability - base_score = float(np.max(model_probabilities)) - - # Uncertainty: entropy-based measure - entropy = -np.sum(model_probabilities * np.log(model_probabilities + 1e-10)) - max_entropy = np.log(len(model_probabilities)) - uncertainty = float(entropy / max_entropy) # Normalized to [0, 1] - - # Calibration: difference between top-2 probabilities (margin) - sorted_probs = np.sort(model_probabilities)[::-1] - margin = float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 - calibration = float(margin) # Higher margin = better calibration - - # Source contributions - sources = {} - - # Model prediction contribution - sources[ConfidenceSource.MODEL_PREDICTION.value] = base_score - - # Training data coverage - if training_coverage is not None: - sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = training_coverage - else: - sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = 0.7 # Default moderate - - # Feature quality - if feature_quality is not None: - sources[ConfidenceSource.FEATURE_QUALITY.value] = feature_quality - else: - sources[ConfidenceSource.FEATURE_QUALITY.value] = 0.8 # Default good - - # Context match - if context_match is not None: - sources[ConfidenceSource.CONTEXT_MATCH.value] = context_match - else: - sources[ConfidenceSource.CONTEXT_MATCH.value] = 0.7 # Default moderate - - # Historical accuracy - if historical_accuracy: - predicted_class = int(np.argmax(model_probabilities)) - hist_acc = historical_accuracy.get(predicted_class, 0.7) - sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = hist_acc - else: - sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = 0.7 # Default moderate - - # Ensemble agreement (if top_k_probs provided) - if top_k_probs: - agreement = float(np.std(top_k_probs)) # Lower std = higher agreement - sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min(agreement, 1.0) - else: - 
sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 0.7 # Default moderate - - # Weighted reliability score - weights = { - ConfidenceSource.MODEL_PREDICTION.value: 0.4, - ConfidenceSource.TRAINING_DATA_COVERAGE.value: 0.15, - ConfidenceSource.FEATURE_QUALITY.value: 0.15, - ConfidenceSource.CONTEXT_MATCH.value: 0.1, - ConfidenceSource.HISTORICAL_ACCURACY.value: 0.1, - ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1 - } - - reliability = sum( - sources[source] * weights.get(source, 0.0) - for source in sources - ) - - # Adjust reliability based on uncertainty and calibration - reliability = reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) - reliability = max(0.0, min(1.0, reliability)) - - # Determine confidence level - level = self._get_confidence_level(reliability) - - # Generate explanation - explanation = self._generate_explanation( - reliability, level, sources, uncertainty, calibration - ) - - return ConfidenceAttributes( - base_score=base_score, - level=level, - sources=sources, - uncertainty=uncertainty, - calibration=calibration, - reliability=reliability, - explanation=explanation - ) - - def _get_confidence_level(self, reliability: float) -> ConfidenceLevel: - """Get confidence level from reliability score""" - if reliability >= 0.9: - return ConfidenceLevel.VERY_HIGH - elif reliability >= 0.75: - return ConfidenceLevel.HIGH - elif reliability >= 0.5: - return ConfidenceLevel.MEDIUM - elif reliability >= 0.25: - return ConfidenceLevel.LOW - else: - return ConfidenceLevel.VERY_LOW - - def _generate_explanation( - self, - reliability: float, - level: ConfidenceLevel, - sources: Dict[str, float], - uncertainty: float, - calibration: float - ) -> str: - """Generate human-readable explanation""" - parts = [] - - # Main confidence statement - parts.append(f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})") - - # Key factors - key_factors = [] - if sources.get(ConfidenceSource.MODEL_PREDICTION.value, 0) > 0.8: - 
key_factors.append("strong model prediction") - if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) > 0.8: - key_factors.append("good training coverage") - if uncertainty < 0.3: - key_factors.append("low uncertainty") - if calibration > 0.5: - key_factors.append("clear class separation") - - if key_factors: - parts.append(f"Key factors: {', '.join(key_factors)}") - - # Concerns - concerns = [] - if uncertainty > 0.6: - concerns.append("high uncertainty") - if calibration < 0.2: - concerns.append("unclear class separation") - if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) < 0.5: - concerns.append("limited training coverage") - - if concerns: - parts.append(f"Concerns: {', '.join(concerns)}") - - return ". ".join(parts) + "." - - def should_accept_prediction( - self, - confidence: ConfidenceAttributes, - min_reliability: float = 0.7, - min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM - ) -> Tuple[bool, str]: - """ - Determine if prediction should be accepted based on confidence - - Returns: - (should_accept, reason) - """ - level_order = { - ConfidenceLevel.VERY_LOW: 0, - ConfidenceLevel.LOW: 1, - ConfidenceLevel.MEDIUM: 2, - ConfidenceLevel.HIGH: 3, - ConfidenceLevel.VERY_HIGH: 4 - } - - if confidence.reliability < min_reliability: - return False, f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}" - - if level_order[confidence.level] < level_order[min_level]: - return False, f"Confidence level {confidence.level.value} below required {min_level.value}" - - return True, "Confidence meets requirements" - - -# Singleton instance -_confidence_calculator = None - - -def get_confidence_calculator() -> ConfidenceCalculator: - """Get singleton ConfidenceCalculator instance""" - global _confidence_calculator - if _confidence_calculator is None: - _confidence_calculator = ConfidenceCalculator() - return _confidence_calculator +""" +Redesigned Confidence Classes System +Provides structured confidence assessment 
with multiple attributes +""" + +from enum import Enum +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + +try: + import numpy as np + + HAS_NUMPY = True +except ImportError: + HAS_NUMPY = False + + +class ConfidenceLevel(str, Enum): + """Confidence level categories""" + + VERY_HIGH = "very_high" # 0.9-1.0 + HIGH = "high" # 0.75-0.9 + MEDIUM = "medium" # 0.5-0.75 + LOW = "low" # 0.25-0.5 + VERY_LOW = "very_low" # 0.0-0.25 + + +class ConfidenceSource(str, Enum): + """Sources of confidence information""" + + MODEL_PREDICTION = "model_prediction" + TRAINING_DATA_COVERAGE = "training_data_coverage" + FEATURE_QUALITY = "feature_quality" + CONTEXT_MATCH = "context_match" + HISTORICAL_ACCURACY = "historical_accuracy" + ENSEMBLE_AGREEMENT = "ensemble_agreement" + + +@dataclass +class ConfidenceAttributes: + """ + Comprehensive confidence attributes for model predictions + + Attributes: + base_score: Raw model confidence score (0.0-1.0) + level: Categorical confidence level + sources: Dictionary of confidence sources and their contributions + uncertainty: Measure of prediction uncertainty + calibration: How well-calibrated the prediction is + reliability: Overall reliability score + explanation: Human-readable explanation + """ + + base_score: float + level: ConfidenceLevel + sources: Dict[str, float] + uncertainty: float + calibration: float + reliability: float + explanation: str + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "base_score": self.base_score, + "level": self.level.value, + "sources": self.sources, + "uncertainty": self.uncertainty, + "calibration": self.calibration, + "reliability": self.reliability, + "explanation": self.explanation, + } + + +class ConfidenceCalculator: + """ + Calculate comprehensive confidence scores with multiple attributes + """ + + def __init__(self): + self.confidence_thresholds = { + ConfidenceLevel.VERY_HIGH: 0.9, + ConfidenceLevel.HIGH: 0.75, + ConfidenceLevel.MEDIUM: 
0.5, + ConfidenceLevel.LOW: 0.25, + ConfidenceLevel.VERY_LOW: 0.0, + } + + def calculate_confidence( + self, + model_probabilities: "np.ndarray", + top_k_probs: Optional[List[float]] = None, + training_coverage: Optional[float] = None, + feature_quality: Optional[float] = None, + context_match: Optional[float] = None, + historical_accuracy: Optional[Dict[int, float]] = None, + ) -> ConfidenceAttributes: + """ + Calculate comprehensive confidence attributes + + Args: + model_probabilities: Model output probabilities for all classes + top_k_probs: Top-k probabilities (for ensemble agreement) + training_coverage: How well training data covers this example (0-1) + feature_quality: Quality of input features (0-1) + context_match: How well context matches expected patterns (0-1) + historical_accuracy: Historical accuracy per class {class_id: accuracy} + + Returns: + ConfidenceAttributes object + """ + if not HAS_NUMPY: + raise ImportError( + "numpy is required for ConfidenceCalculator. Install with: pip install numpy" + ) + + # Base score: maximum probability + base_score = float(np.max(model_probabilities)) + + # Uncertainty: entropy-based measure + entropy = -np.sum(model_probabilities * np.log(model_probabilities + 1e-10)) + max_entropy = np.log(len(model_probabilities)) + uncertainty = float(entropy / max_entropy) # Normalized to [0, 1] + + # Calibration: difference between top-2 probabilities (margin) + sorted_probs = np.sort(model_probabilities)[::-1] + margin = ( + float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 + ) + calibration = float(margin) # Higher margin = better calibration + + # Source contributions + sources = {} + + # Model prediction contribution + sources[ConfidenceSource.MODEL_PREDICTION.value] = base_score + + # Training data coverage + if training_coverage is not None: + sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = training_coverage + else: + sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = ( + 0.7 # 
Default moderate + ) + + # Feature quality + if feature_quality is not None: + sources[ConfidenceSource.FEATURE_QUALITY.value] = feature_quality + else: + sources[ConfidenceSource.FEATURE_QUALITY.value] = 0.8 # Default good + + # Context match + if context_match is not None: + sources[ConfidenceSource.CONTEXT_MATCH.value] = context_match + else: + sources[ConfidenceSource.CONTEXT_MATCH.value] = 0.7 # Default moderate + + # Historical accuracy + if historical_accuracy: + predicted_class = int(np.argmax(model_probabilities)) + hist_acc = historical_accuracy.get(predicted_class, 0.7) + sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = hist_acc + else: + sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = ( + 0.7 # Default moderate + ) + + # Ensemble agreement (if top_k_probs provided) + if top_k_probs: + agreement = float(np.std(top_k_probs)) # Lower std = higher agreement + sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min( + agreement, 1.0 + ) + else: + sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 0.7 # Default moderate + + # Weighted reliability score + weights = { + ConfidenceSource.MODEL_PREDICTION.value: 0.4, + ConfidenceSource.TRAINING_DATA_COVERAGE.value: 0.15, + ConfidenceSource.FEATURE_QUALITY.value: 0.15, + ConfidenceSource.CONTEXT_MATCH.value: 0.1, + ConfidenceSource.HISTORICAL_ACCURACY.value: 0.1, + ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1, + } + + reliability = sum( + sources[source] * weights.get(source, 0.0) for source in sources + ) + + # Adjust reliability based on uncertainty and calibration + reliability = ( + reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) + ) + reliability = max(0.0, min(1.0, reliability)) + + # Determine confidence level + level = self._get_confidence_level(reliability) + + # Generate explanation + explanation = self._generate_explanation( + reliability, level, sources, uncertainty, calibration + ) + + return ConfidenceAttributes( + base_score=base_score, + level=level, + 
sources=sources, + uncertainty=uncertainty, + calibration=calibration, + reliability=reliability, + explanation=explanation, + ) + + def _get_confidence_level(self, reliability: float) -> ConfidenceLevel: + """Get confidence level from reliability score""" + if reliability >= 0.9: + return ConfidenceLevel.VERY_HIGH + elif reliability >= 0.75: + return ConfidenceLevel.HIGH + elif reliability >= 0.5: + return ConfidenceLevel.MEDIUM + elif reliability >= 0.25: + return ConfidenceLevel.LOW + else: + return ConfidenceLevel.VERY_LOW + + def _generate_explanation( + self, + reliability: float, + level: ConfidenceLevel, + sources: Dict[str, float], + uncertainty: float, + calibration: float, + ) -> str: + """Generate human-readable explanation""" + parts = [] + + # Main confidence statement + parts.append( + f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})" + ) + + # Key factors + key_factors = [] + if sources.get(ConfidenceSource.MODEL_PREDICTION.value, 0) > 0.8: + key_factors.append("strong model prediction") + if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) > 0.8: + key_factors.append("good training coverage") + if uncertainty < 0.3: + key_factors.append("low uncertainty") + if calibration > 0.5: + key_factors.append("clear class separation") + + if key_factors: + parts.append(f"Key factors: {', '.join(key_factors)}") + + # Concerns + concerns = [] + if uncertainty > 0.6: + concerns.append("high uncertainty") + if calibration < 0.2: + concerns.append("unclear class separation") + if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) < 0.5: + concerns.append("limited training coverage") + + if concerns: + parts.append(f"Concerns: {', '.join(concerns)}") + + return ". ".join(parts) + "." 
+ + def should_accept_prediction( + self, + confidence: ConfidenceAttributes, + min_reliability: float = 0.7, + min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM, + ) -> Tuple[bool, str]: + """ + Determine if prediction should be accepted based on confidence + + Returns: + (should_accept, reason) + """ + level_order = { + ConfidenceLevel.VERY_LOW: 0, + ConfidenceLevel.LOW: 1, + ConfidenceLevel.MEDIUM: 2, + ConfidenceLevel.HIGH: 3, + ConfidenceLevel.VERY_HIGH: 4, + } + + if confidence.reliability < min_reliability: + return ( + False, + f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}", + ) + + if level_order[confidence.level] < level_order[min_level]: + return ( + False, + f"Confidence level {confidence.level.value} below required {min_level.value}", + ) + + return True, "Confidence meets requirements" + + +# Singleton instance +_confidence_calculator = None + + +def get_confidence_calculator() -> ConfidenceCalculator: + """Get singleton ConfidenceCalculator instance""" + global _confidence_calculator + if _confidence_calculator is None: + _confidence_calculator = ConfidenceCalculator() + return _confidence_calculator diff --git a/src/deepiri_modelkit/ml/semantic.py b/src/deepiri_modelkit/ml/semantic.py index 141c6b1..a921fd3 100644 --- a/src/deepiri_modelkit/ml/semantic.py +++ b/src/deepiri_modelkit/ml/semantic.py @@ -1,343 +1,361 @@ -""" -Dynamic Semantic Analysis for Data Augmentation -Inspired by Carnegie Mellon University (CMU) Language Technologies Institute approaches -Uses semantic similarity and contextual understanding for dynamic variation generation -""" -import json -import re -from typing import List, Dict, Optional, Set -from collections import defaultdict -import os - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.ml.semantic") - -# Try multiple HTTP clients -try: - import httpx - HAS_HTTPX = True -except ImportError: - HAS_HTTPX = False - -try: - import requests - 
HAS_REQUESTS = True -except ImportError: - HAS_REQUESTS = False - -try: - import ollama - HAS_OLLAMA_PKG = True -except ImportError: - HAS_OLLAMA_PKG = False - - -class SemanticAnalyzer: - """ - Dynamic semantic analysis for generating variations - Inspired by CMU's semantic analysis approaches - """ - - def __init__(self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b"): - self.ollama_base_url = ollama_base_url - self.model = model - self._cache = {} # Cache for semantic analysis results - - def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # Reduced from 30s to 15s - """Call Ollama API directly via HTTP or Python package""" - # Try ollama Python package first (cleaner API) - if HAS_OLLAMA_PKG: - try: - response = ollama.generate( - model=self.model, - prompt=prompt, - options={ - "temperature": 0.7, - "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } - ) - return response.get("response", "").strip() - except Exception: - # Fall back to HTTP - pass - - # Fall back to HTTP API - try: - if HAS_HTTPX: - logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") - response = httpx.post( - f"{self.ollama_base_url}/api/generate", - json={ - "model": self.model, - "prompt": prompt, - "stream": False, - "options": { - "temperature": 0.7, - "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } - }, - timeout=timeout - ) - - if response.status_code == 200: - result = response.json() - logger.debug("Ollama HTTP call succeeded") - return result.get("response", "").strip() - else: - logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") - elif HAS_REQUESTS: - logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") - response = requests.post( - f"{self.ollama_base_url}/api/generate", - json={ - "model": self.model, - "prompt": prompt, - "stream": False, - "options": { - "temperature": 0.7, - "top_p": 0.9, - "num_predict": 100 # Reduced 
from 200 for faster responses - } - }, - timeout=timeout - ) - - if response.status_code == 200: - result = response.json() - logger.debug("Ollama HTTP call succeeded") - return result.get("response", "").strip() - else: - logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") - except Exception as e: - logger.debug(f"Ollama HTTP call failed: {e}") - - return None - - def extract_semantic_verbs(self, text: str, category: str) -> List[str]: - """ - Extract semantically similar verbs using Ollama - Inspired by CMU's semantic role labeling approaches - Cached per category (not per text) for performance - """ - # Cache by category only, not per text (much more efficient) - cache_key = f"verbs:{category}" - if cache_key in self._cache: - return self._cache[cache_key] - - # Use category-level prompt (not text-specific) for better caching - prompt = f"""For tasks in the '{category}' category, suggest 6-8 common action verbs that are semantically similar and could be used interchangeably. - -Category: {category} - -Return ONLY a JSON array of verbs, no explanation. 
Example: ["write", "draft", "compose", "create", "author"]""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) - if json_match: - verbs = json.loads(json_match.group()) - if isinstance(verbs, list) and len(verbs) > 0: - self._cache[cache_key] = verbs - return verbs - except Exception: - pass - - # Fallback: return empty list - return [] - - def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: - """ - Generate semantically appropriate prefixes using contextual analysis - Inspired by CMU's discourse analysis approaches - Cached per category (not per text) for performance - """ - cache_key = f"prefixes:{category}" - if cache_key in self._cache: - return self._cache[cache_key] - - # Simplified prompt - category only (not text-specific) for better caching - prompt = f"""For tasks in the '{category}' category, generate 8-10 natural ways to introduce task requests. - -Category: {category} - -Consider: politeness levels, personal perspectives (I need, Can you, Let me), contextual frames. - -Return ONLY a JSON array of prefixes. 
Example: ["I need to", "Can you help me", "Please", "I want to"]""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) - if json_match: - prefixes = json.loads(json_match.group()) - if isinstance(prefixes, list) and len(prefixes) > 0: - self._cache[cache_key] = prefixes - return prefixes - except Exception: - pass - - # Fallback: return default prefixes - return [ - "I need to", "Can you help me", "Please", "I want to", - "Help me", "I should", "Let me", "I'm going to" - ] - - def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: - """ - Generate semantically appropriate suffixes using temporal and contextual analysis - Cached per category (not per text) for performance - """ - cache_key = f"suffixes:{category}" - if cache_key in self._cache: - return self._cache[cache_key] - - # Simplified prompt - category only (not text-specific) for better caching - prompt = f"""For tasks in the '{category}' category, generate 6-8 natural ways to add temporal or contextual information. - -Category: {category} - -Consider: time constraints, urgency levels, contextual markers. - -Return ONLY a JSON array of suffixes. 
Example: [" today", " this week", " as soon as possible"]""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) - if json_match: - suffixes = json.loads(json_match.group()) - if isinstance(suffixes, list) and len(suffixes) > 0: - self._cache[cache_key] = suffixes - return suffixes - except Exception: - pass - - # Fallback: return default suffixes - return [ - "", " today", " this week", " as soon as possible", - " when you have time", " - urgent", " - important" - ] - - def generate_paraphrases(self, text: str, category: str, num_paraphrases: int = 3) -> List[str]: - """ - Generate semantic paraphrases using Ollama - Inspired by CMU's paraphrase generation approaches - """ - prompt = f"""Generate {num_paraphrases} different natural ways to express this task request. Each should be semantically equivalent but use different wording: - -Original: "{text}" -Category: {category} - -Requirements: -- Keep the same meaning and intent -- Use natural, conversational language -- Vary sentence structure and word choice -- Each paraphrase should be a complete sentence - -Return ONLY the paraphrases, one per line, without numbering or bullets.""" - - response = self._call_ollama(prompt) - if response: - paraphrases = [] - for line in response.strip().split('\n'): - line = line.strip() - # Remove common prefixes - for prefix in ['- ', '1. ', '2. ', '3. ', '4. ', '5. ', '* ', '• ']: - if line.startswith(prefix): - line = line[len(prefix):].strip() - - if line and line != text and len(line) > 10: - paraphrases.append(line) - - return paraphrases[:num_paraphrases] - - return [] - - def analyze_semantic_structure(self, text: str) -> Dict: - """ - Analyze semantic structure of text - Inspired by CMU's semantic role labeling and dependency parsing - """ - prompt = f"""Analyze the semantic structure of this task request: - -"{text}" - -Identify: -1. Main action verb -2. Object/noun phrase -3. Modifiers/adjectives -4. 
Temporal markers (if any) -5. Urgency indicators (if any) - -Return a JSON object with these fields.""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r'\{.*?\}', response, re.DOTALL) - if json_match: - return json.loads(json_match.group()) - except Exception: - pass - - # Fallback: simple analysis - words = text.lower().split() - return { - "action_verb": words[0] if words else "unknown", - "object": " ".join(words[1:]) if len(words) > 1 else "", - "modifiers": [], - "temporal": None, - "urgency": None - } - - def check_ollama_available(self) -> bool: - """Check if Ollama is available""" - # Try ollama package first - if HAS_OLLAMA_PKG: - try: - ollama.list() # This will raise if not available - return True - except Exception: - pass - - # Fall back to HTTP check - try: - if HAS_HTTPX: - response = httpx.get( - f"{self.ollama_base_url}/api/tags", - timeout=5.0 - ) - return response.status_code == 200 - elif HAS_REQUESTS: - response = requests.get( - f"{self.ollama_base_url}/api/tags", - timeout=5.0 - ) - return response.status_code == 200 - except Exception: - pass - - return False - - -def get_semantic_analyzer( - ollama_base_url: Optional[str] = None, - model: Optional[str] = None -) -> Optional[SemanticAnalyzer]: - """ - Factory function to get semantic analyzer - """ - base_url = ollama_base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") - model_name = model or os.getenv("OLLAMA_MODEL", "llama3:8b") - - analyzer = SemanticAnalyzer(ollama_base_url=base_url, model=model_name) - - if analyzer.check_ollama_available(): - return analyzer - else: - logger.warning(f"Ollama not available at {base_url}") - return None +""" +Dynamic Semantic Analysis for Data Augmentation +Inspired by Carnegie Mellon University (CMU) Language Technologies Institute approaches +Uses semantic similarity and contextual understanding for dynamic variation generation +""" + +import json +import re +from typing import List, Dict, 
Optional, Set +from collections import defaultdict +import os + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.ml.semantic") + +# Try multiple HTTP clients +try: + import httpx + + HAS_HTTPX = True +except ImportError: + HAS_HTTPX = False + +try: + import requests + + HAS_REQUESTS = True +except ImportError: + HAS_REQUESTS = False + +try: + import ollama + + HAS_OLLAMA_PKG = True +except ImportError: + HAS_OLLAMA_PKG = False + + +class SemanticAnalyzer: + """ + Dynamic semantic analysis for generating variations + Inspired by CMU's semantic analysis approaches + """ + + def __init__( + self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b" + ): + self.ollama_base_url = ollama_base_url + self.model = model + self._cache = {} # Cache for semantic analysis results + + def _call_ollama( + self, prompt: str, timeout: float = 15.0 + ) -> Optional[str]: # Reduced from 30s to 15s + """Call Ollama API directly via HTTP or Python package""" + # Try ollama Python package first (cleaner API) + if HAS_OLLAMA_PKG: + try: + response = ollama.generate( + model=self.model, + prompt=prompt, + options={ + "temperature": 0.7, + "top_p": 0.9, + "num_predict": 100, # Reduced from 200 for faster responses + }, + ) + return response.get("response", "").strip() + except Exception: + # Fall back to HTTP + pass + + # Fall back to HTTP API + try: + if HAS_HTTPX: + logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") + response = httpx.post( + f"{self.ollama_base_url}/api/generate", + json={ + "model": self.model, + "prompt": prompt, + "stream": False, + "options": { + "temperature": 0.7, + "top_p": 0.9, + "num_predict": 100, # Reduced from 200 for faster responses + }, + }, + timeout=timeout, + ) + + if response.status_code == 200: + result = response.json() + logger.debug("Ollama HTTP call succeeded") + return result.get("response", "").strip() + else: + logger.debug( + f"Ollama HTTP call failed: HTTP 
{response.status_code}" + ) + elif HAS_REQUESTS: + logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") + response = requests.post( + f"{self.ollama_base_url}/api/generate", + json={ + "model": self.model, + "prompt": prompt, + "stream": False, + "options": { + "temperature": 0.7, + "top_p": 0.9, + "num_predict": 100, # Reduced from 200 for faster responses + }, + }, + timeout=timeout, + ) + + if response.status_code == 200: + result = response.json() + logger.debug("Ollama HTTP call succeeded") + return result.get("response", "").strip() + else: + logger.debug( + f"Ollama HTTP call failed: HTTP {response.status_code}" + ) + except Exception as e: + logger.debug(f"Ollama HTTP call failed: {e}") + + return None + + def extract_semantic_verbs(self, text: str, category: str) -> List[str]: + """ + Extract semantically similar verbs using Ollama + Inspired by CMU's semantic role labeling approaches + Cached per category (not per text) for performance + """ + # Cache by category only, not per text (much more efficient) + cache_key = f"verbs:{category}" + if cache_key in self._cache: + return self._cache[cache_key] + + # Use category-level prompt (not text-specific) for better caching + prompt = f"""For tasks in the '{category}' category, suggest 6-8 common action verbs that are semantically similar and could be used interchangeably. + +Category: {category} + +Return ONLY a JSON array of verbs, no explanation. 
Example: ["write", "draft", "compose", "create", "author"]""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r"\[.*?\]", response, re.DOTALL) + if json_match: + verbs = json.loads(json_match.group()) + if isinstance(verbs, list) and len(verbs) > 0: + self._cache[cache_key] = verbs + return verbs + except Exception: + pass + + # Fallback: return empty list + return [] + + def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: + """ + Generate semantically appropriate prefixes using contextual analysis + Inspired by CMU's discourse analysis approaches + Cached per category (not per text) for performance + """ + cache_key = f"prefixes:{category}" + if cache_key in self._cache: + return self._cache[cache_key] + + # Simplified prompt - category only (not text-specific) for better caching + prompt = f"""For tasks in the '{category}' category, generate 8-10 natural ways to introduce task requests. + +Category: {category} + +Consider: politeness levels, personal perspectives (I need, Can you, Let me), contextual frames. + +Return ONLY a JSON array of prefixes. 
Example: ["I need to", "Can you help me", "Please", "I want to"]""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r"\[.*?\]", response, re.DOTALL) + if json_match: + prefixes = json.loads(json_match.group()) + if isinstance(prefixes, list) and len(prefixes) > 0: + self._cache[cache_key] = prefixes + return prefixes + except Exception: + pass + + # Fallback: return default prefixes + return [ + "I need to", + "Can you help me", + "Please", + "I want to", + "Help me", + "I should", + "Let me", + "I'm going to", + ] + + def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: + """ + Generate semantically appropriate suffixes using temporal and contextual analysis + Cached per category (not per text) for performance + """ + cache_key = f"suffixes:{category}" + if cache_key in self._cache: + return self._cache[cache_key] + + # Simplified prompt - category only (not text-specific) for better caching + prompt = f"""For tasks in the '{category}' category, generate 6-8 natural ways to add temporal or contextual information. + +Category: {category} + +Consider: time constraints, urgency levels, contextual markers. + +Return ONLY a JSON array of suffixes. 
Example: [" today", " this week", " as soon as possible"]""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r"\[.*?\]", response, re.DOTALL) + if json_match: + suffixes = json.loads(json_match.group()) + if isinstance(suffixes, list) and len(suffixes) > 0: + self._cache[cache_key] = suffixes + return suffixes + except Exception: + pass + + # Fallback: return default suffixes + return [ + "", + " today", + " this week", + " as soon as possible", + " when you have time", + " - urgent", + " - important", + ] + + def generate_paraphrases( + self, text: str, category: str, num_paraphrases: int = 3 + ) -> List[str]: + """ + Generate semantic paraphrases using Ollama + Inspired by CMU's paraphrase generation approaches + """ + prompt = f"""Generate {num_paraphrases} different natural ways to express this task request. Each should be semantically equivalent but use different wording: + +Original: "{text}" +Category: {category} + +Requirements: +- Keep the same meaning and intent +- Use natural, conversational language +- Vary sentence structure and word choice +- Each paraphrase should be a complete sentence + +Return ONLY the paraphrases, one per line, without numbering or bullets.""" + + response = self._call_ollama(prompt) + if response: + paraphrases = [] + for line in response.strip().split("\n"): + line = line.strip() + # Remove common prefixes + for prefix in ["- ", "1. ", "2. ", "3. ", "4. ", "5. ", "* ", "• "]: + if line.startswith(prefix): + line = line[len(prefix) :].strip() + + if line and line != text and len(line) > 10: + paraphrases.append(line) + + return paraphrases[:num_paraphrases] + + return [] + + def analyze_semantic_structure(self, text: str) -> Dict: + """ + Analyze semantic structure of text + Inspired by CMU's semantic role labeling and dependency parsing + """ + prompt = f"""Analyze the semantic structure of this task request: + +"{text}" + +Identify: +1. Main action verb +2. Object/noun phrase +3. 
Modifiers/adjectives +4. Temporal markers (if any) +5. Urgency indicators (if any) + +Return a JSON object with these fields.""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r"\{.*?\}", response, re.DOTALL) + if json_match: + return json.loads(json_match.group()) + except Exception: + pass + + # Fallback: simple analysis + words = text.lower().split() + return { + "action_verb": words[0] if words else "unknown", + "object": " ".join(words[1:]) if len(words) > 1 else "", + "modifiers": [], + "temporal": None, + "urgency": None, + } + + def check_ollama_available(self) -> bool: + """Check if Ollama is available""" + # Try ollama package first + if HAS_OLLAMA_PKG: + try: + ollama.list() # This will raise if not available + return True + except Exception: + pass + + # Fall back to HTTP check + try: + if HAS_HTTPX: + response = httpx.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) + return response.status_code == 200 + elif HAS_REQUESTS: + response = requests.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) + return response.status_code == 200 + except Exception: + pass + + return False + + +def get_semantic_analyzer( + ollama_base_url: Optional[str] = None, model: Optional[str] = None +) -> Optional[SemanticAnalyzer]: + """ + Factory function to get semantic analyzer + """ + base_url = ollama_base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") + model_name = model or os.getenv("OLLAMA_MODEL", "llama3:8b") + + analyzer = SemanticAnalyzer(ollama_base_url=base_url, model=model_name) + + if analyzer.check_ollama_available(): + return analyzer + else: + logger.warning(f"Ollama not available at {base_url}") + return None diff --git a/src/deepiri_modelkit/rag/__init__.py b/src/deepiri_modelkit/rag/__init__.py index 0b3bbcc..6d85952 100644 --- a/src/deepiri_modelkit/rag/__init__.py +++ b/src/deepiri_modelkit/rag/__init__.py @@ -1,155 +1,166 @@ -""" -Universal RAG Module for Deepiri Platform -Reusable across all 
industry niches: Insurance, Manufacturing, Property Management, Healthcare, etc. -""" - -from .base import ( - UniversalRAGEngine, - Document, - DocumentType, - IndustryNiche, - RAGConfig, - RAGQuery, - RetrievalResult, -) -from .processors import ( - DocumentProcessor, - RegulationProcessor, - HistoricalDataProcessor, - KnowledgeBaseProcessor, - ManualProcessor, - get_processor, -) -from .retrievers import ( - MultiModalRetriever, - HybridRetriever, - ContextualRetriever, - get_retriever, -) - -# Advanced features (optional imports) -try: - from .advanced_retrieval import ( - AdvancedRetrievalPipeline, - QueryExpander, - SynonymQueryExpander, - RephraseQueryExpander, - MultiQueryRetriever, - QueryCache, - ) - HAS_ADVANCED_RETRIEVAL = True -except ImportError: - HAS_ADVANCED_RETRIEVAL = False - AdvancedRetrievalPipeline = None - QueryExpander = None - SynonymQueryExpander = None - RephraseQueryExpander = None - MultiQueryRetriever = None - QueryCache = None - -try: - from .caching import ( - AdvancedCacheManager, - EmbeddingCache, - QueryResultCache, - ) - HAS_CACHING = True -except ImportError: - HAS_CACHING = False - AdvancedCacheManager = None - EmbeddingCache = None - QueryResultCache = None - -try: - from .monitoring import ( - RAGMonitor, - RetrievalMetrics, - IndexingMetrics, - SystemMetrics, - PerformanceTimer, - ) - HAS_MONITORING = True -except ImportError: - HAS_MONITORING = False - RAGMonitor = None - RetrievalMetrics = None - IndexingMetrics = None - SystemMetrics = None - PerformanceTimer = None - -try: - from .async_processing import ( - AsyncBatchProcessor, - AsyncDocumentIndexer, - AsyncDocumentProcessor, - BatchProcessingConfig, - BatchProcessingResult, - ) - HAS_ASYNC = True -except ImportError: - HAS_ASYNC = False - AsyncBatchProcessor = None - AsyncDocumentIndexer = None - AsyncDocumentProcessor = None - BatchProcessingConfig = None - BatchProcessingResult = None - -__all__ = [ - # Core - "UniversalRAGEngine", - "Document", - "DocumentType", - 
"IndustryNiche", - "RAGConfig", - "RAGQuery", - "RetrievalResult", - # Processors - "DocumentProcessor", - "RegulationProcessor", - "HistoricalDataProcessor", - "KnowledgeBaseProcessor", - "ManualProcessor", - "get_processor", - # Retrievers - "MultiModalRetriever", - "HybridRetriever", - "ContextualRetriever", - "get_retriever", -] - -# Conditionally add advanced features -if HAS_ADVANCED_RETRIEVAL: - __all__.extend([ - "AdvancedRetrievalPipeline", - "QueryExpander", - "SynonymQueryExpander", - "RephraseQueryExpander", - "MultiQueryRetriever", - "QueryCache", - ]) - -if HAS_CACHING: - __all__.extend([ - "AdvancedCacheManager", - "EmbeddingCache", - "QueryResultCache", - ]) - -if HAS_MONITORING: - __all__.extend([ - "RAGMonitor", - "RetrievalMetrics", - "IndexingMetrics", - "SystemMetrics", - "PerformanceTimer", - ]) - -if HAS_ASYNC: - __all__.extend([ - "AsyncBatchProcessor", - "AsyncDocumentIndexer", - "AsyncDocumentProcessor", - "BatchProcessingConfig", - "BatchProcessingResult", - ]) - +""" +Universal RAG Module for Deepiri Platform +Reusable across all industry niches: Insurance, Manufacturing, Property Management, Healthcare, etc. 
+""" + +from .base import ( + UniversalRAGEngine, + Document, + DocumentType, + IndustryNiche, + RAGConfig, + RAGQuery, + RetrievalResult, +) +from .processors import ( + DocumentProcessor, + RegulationProcessor, + HistoricalDataProcessor, + KnowledgeBaseProcessor, + ManualProcessor, + get_processor, +) +from .retrievers import ( + MultiModalRetriever, + HybridRetriever, + ContextualRetriever, + get_retriever, +) + +# Advanced features (optional imports) +try: + from .advanced_retrieval import ( + AdvancedRetrievalPipeline, + QueryExpander, + SynonymQueryExpander, + RephraseQueryExpander, + MultiQueryRetriever, + QueryCache, + ) + + HAS_ADVANCED_RETRIEVAL = True +except ImportError: + HAS_ADVANCED_RETRIEVAL = False + AdvancedRetrievalPipeline = None + QueryExpander = None + SynonymQueryExpander = None + RephraseQueryExpander = None + MultiQueryRetriever = None + QueryCache = None + +try: + from .caching import ( + AdvancedCacheManager, + EmbeddingCache, + QueryResultCache, + ) + + HAS_CACHING = True +except ImportError: + HAS_CACHING = False + AdvancedCacheManager = None + EmbeddingCache = None + QueryResultCache = None + +try: + from .monitoring import ( + RAGMonitor, + RetrievalMetrics, + IndexingMetrics, + SystemMetrics, + PerformanceTimer, + ) + + HAS_MONITORING = True +except ImportError: + HAS_MONITORING = False + RAGMonitor = None + RetrievalMetrics = None + IndexingMetrics = None + SystemMetrics = None + PerformanceTimer = None + +try: + from .async_processing import ( + AsyncBatchProcessor, + AsyncDocumentIndexer, + AsyncDocumentProcessor, + BatchProcessingConfig, + BatchProcessingResult, + ) + + HAS_ASYNC = True +except ImportError: + HAS_ASYNC = False + AsyncBatchProcessor = None + AsyncDocumentIndexer = None + AsyncDocumentProcessor = None + BatchProcessingConfig = None + BatchProcessingResult = None + +__all__ = [ + # Core + "UniversalRAGEngine", + "Document", + "DocumentType", + "IndustryNiche", + "RAGConfig", + "RAGQuery", + "RetrievalResult", + # 
Processors + "DocumentProcessor", + "RegulationProcessor", + "HistoricalDataProcessor", + "KnowledgeBaseProcessor", + "ManualProcessor", + "get_processor", + # Retrievers + "MultiModalRetriever", + "HybridRetriever", + "ContextualRetriever", + "get_retriever", +] + +# Conditionally add advanced features +if HAS_ADVANCED_RETRIEVAL: + __all__.extend( + [ + "AdvancedRetrievalPipeline", + "QueryExpander", + "SynonymQueryExpander", + "RephraseQueryExpander", + "MultiQueryRetriever", + "QueryCache", + ] + ) + +if HAS_CACHING: + __all__.extend( + [ + "AdvancedCacheManager", + "EmbeddingCache", + "QueryResultCache", + ] + ) + +if HAS_MONITORING: + __all__.extend( + [ + "RAGMonitor", + "RetrievalMetrics", + "IndexingMetrics", + "SystemMetrics", + "PerformanceTimer", + ] + ) + +if HAS_ASYNC: + __all__.extend( + [ + "AsyncBatchProcessor", + "AsyncDocumentIndexer", + "AsyncDocumentProcessor", + "BatchProcessingConfig", + "BatchProcessingResult", + ] + ) diff --git a/src/deepiri_modelkit/rag/advanced_retrieval.py b/src/deepiri_modelkit/rag/advanced_retrieval.py index 14e1de6..711177f 100644 --- a/src/deepiri_modelkit/rag/advanced_retrieval.py +++ b/src/deepiri_modelkit/rag/advanced_retrieval.py @@ -1,394 +1,422 @@ -""" -Advanced Retrieval Strategies for Universal RAG -Query expansion, multi-query retrieval, and advanced search techniques -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Tuple -from dataclasses import dataclass -import hashlib -import json - -from .base import Document, RetrievalResult, RAGQuery - - -@dataclass -class ExpandedQuery: - """Expanded query with multiple variations""" - original_query: str - expanded_queries: List[str] - query_type: str # "synonym", "rephrase", "keyword", etc. 
- confidence: float - - -class QueryExpander(ABC): - """Base class for query expansion strategies""" - - @abstractmethod - def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: - """Expand a query into multiple variations""" - pass - - -class SynonymQueryExpander(QueryExpander): - """Expand queries using synonyms and related terms""" - - def __init__(self, synonym_dict: Optional[Dict[str, List[str]]] = None): - self.synonym_dict = synonym_dict or self._default_synonyms() - - def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: - """Expand query using synonyms""" - words = query.lower().split() - expanded = [query] # Always include original - - for word in words: - if word in self.synonym_dict: - synonyms = self.synonym_dict[word][:max_expansions] - for synonym in synonyms: - expanded_query = query.replace(word, synonym) - if expanded_query not in expanded: - expanded.append(expanded_query) - - return ExpandedQuery( - original_query=query, - expanded_queries=expanded[:max_expansions + 1], - query_type="synonym", - confidence=0.8 - ) - - def _default_synonyms(self) -> Dict[str, List[str]]: - """Default synonym dictionary""" - return { - "repair": ["fix", "maintain", "service", "restore"], - "maintenance": ["service", "upkeep", "repair", "inspection"], - "claim": ["request", "application", "report", "filing"], - "policy": ["coverage", "plan", "insurance", "agreement"], - "regulation": ["rule", "standard", "requirement", "guideline"], - "procedure": ["process", "method", "protocol", "steps"], - "equipment": ["machine", "device", "tool", "apparatus"], - "safety": ["security", "protection", "precaution"], - "inspection": ["examination", "review", "check", "audit"], - "documentation": ["record", "file", "document", "paperwork"], - } - - -class RephraseQueryExpander(QueryExpander): - """Expand queries by rephrasing""" - - def __init__(self, llm_client=None): - self.llm_client = llm_client - - def expand(self, query: str, max_expansions: 
int = 3) -> ExpandedQuery: - """Rephrase query using LLM or templates""" - if self.llm_client: - return self._llm_rephrase(query, max_expansions) - else: - return self._template_rephrase(query, max_expansions) - - def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: - """Rephrase using templates""" - templates = [ - f"What is {query}?", - f"How to {query}?", - f"Information about {query}", - f"Details on {query}", - ] - - expanded = [query] + templates[:max_expansions] - - return ExpandedQuery( - original_query=query, - expanded_queries=expanded, - query_type="rephrase", - confidence=0.7 - ) - - def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: - """Rephrase using LLM (if available)""" - # Placeholder for LLM-based rephrasing - return self._template_rephrase(query, max_expansions) - - -class KeywordExtractor: - """Extract keywords from queries for hybrid search""" - - def __init__(self): - # Common stop words - self.stop_words = { - "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", - "of", "with", "by", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "do", "does", "did", "will", "would", "should", - "could", "may", "might", "must", "can", "this", "that", "these", "those" - } - - def extract(self, query: str, max_keywords: int = 10) -> List[str]: - """Extract keywords from query""" - words = query.lower().split() - keywords = [ - word.strip(".,!?;:()[]{}") - for word in words - if word.strip(".,!?;:()[]{}") not in self.stop_words - and len(word.strip(".,!?;:()[]{}")) > 2 - ] - return keywords[:max_keywords] - - -class MultiQueryRetriever: - """ - Multi-query retrieval strategy - Generates multiple query variations and combines results - """ - - def __init__( - self, - base_retriever, - query_expander: Optional[QueryExpander] = None, - fusion_method: str = "rrf" # "rrf" (Reciprocal Rank Fusion) or "mean" - ): - self.base_retriever = base_retriever - self.query_expander = 
query_expander or SynonymQueryExpander() - self.fusion_method = fusion_method - - def retrieve( - self, - query: RAGQuery, - num_queries: int = 3 - ) -> List[RetrievalResult]: - """ - Retrieve using multiple query variations - - Args: - query: Original RAG query - num_queries: Number of query variations to generate - - Returns: - Fused retrieval results - """ - # Expand query - expanded = self.query_expander.expand(query.query, max_expansions=num_queries - 1) - - # Retrieve for each query variation - all_results: Dict[str, List[RetrievalResult]] = {} - - for expanded_query in expanded.expanded_queries: - # Create new query with expanded text - expanded_rag_query = RAGQuery( - query=expanded_query, - industry=query.industry, - doc_types=query.doc_types, - date_range=query.date_range, - metadata_filters=query.metadata_filters, - top_k=query.top_k or 10 - ) - - # Retrieve results - results = self.base_retriever.retrieve(expanded_rag_query) - all_results[expanded_query] = results - - # Fuse results - fused = self._fuse_results(all_results, query.top_k or 10) - - return fused - - def _fuse_results( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int - ) -> List[RetrievalResult]: - """Fuse results from multiple queries""" - if self.fusion_method == "rrf": - return self._reciprocal_rank_fusion(all_results, top_k) - else: - return self._mean_score_fusion(all_results, top_k) - - def _reciprocal_rank_fusion( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int, - k: int = 60 - ) -> List[RetrievalResult]: - """ - Reciprocal Rank Fusion (RRF) - Combines rankings from multiple queries - """ - doc_scores: Dict[str, Dict[str, Any]] = {} - - for query_text, results in all_results.items(): - for rank, result in enumerate(results): - doc_id = result.document.id - rrf_score = 1.0 / (k + rank + 1) - - if doc_id not in doc_scores: - doc_scores[doc_id] = { - "document": result.document, - "rrf_score": 0.0, - "max_score": result.score, - "count": 0 - } 
- - doc_scores[doc_id]["rrf_score"] += rrf_score - doc_scores[doc_id]["max_score"] = max( - doc_scores[doc_id]["max_score"], - result.score - ) - doc_scores[doc_id]["count"] += 1 - - # Create fused results - fused = [] - for doc_id, scores in doc_scores.items(): - fused.append(RetrievalResult( - document=scores["document"], - score=scores["rrf_score"], # Use RRF score - rerank_score=scores["max_score"] # Store max original score - )) - - # Sort by RRF score - fused.sort(key=lambda x: x.score, reverse=True) - - return fused[:top_k] - - def _mean_score_fusion( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int - ) -> List[RetrievalResult]: - """Fuse results using mean score""" - doc_scores: Dict[str, Dict[str, Any]] = {} - - for query_text, results in all_results.items(): - for result in results: - doc_id = result.document.id - - if doc_id not in doc_scores: - doc_scores[doc_id] = { - "document": result.document, - "scores": [], - "count": 0 - } - - doc_scores[doc_id]["scores"].append(result.score) - doc_scores[doc_id]["count"] += 1 - - # Calculate mean scores - fused = [] - for doc_id, scores in doc_scores.items(): - mean_score = sum(scores["scores"]) / len(scores["scores"]) - fused.append(RetrievalResult( - document=scores["document"], - score=mean_score, - rerank_score=max(scores["scores"]) - )) - - # Sort by mean score - fused.sort(key=lambda x: x.score, reverse=True) - - return fused[:top_k] - - -class QueryCache: - """Cache for query results""" - - def __init__(self, cache_manager=None): - self.cache_manager = cache_manager - self.cache_ttl = 3600 # 1 hour - - def get_cache_key(self, query: RAGQuery) -> str: - """Generate cache key for query""" - query_dict = query.to_dict() - query_str = json.dumps(query_dict, sort_keys=True) - query_hash = hashlib.md5(query_str.encode()).hexdigest() - return f"rag:query:{query_hash}" - - def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: - """Get cached results""" - if not 
self.cache_manager: - return None - - cache_key = self.get_cache_key(query) - cached = self.cache_manager.get(cache_key) - - if cached: - # Reconstruct RetrievalResult objects - results = [] - for item in cached: - doc = Document.from_dict(item["document"]) - result = RetrievalResult( - document=doc, - score=item["score"], - rerank_score=item.get("rerank_score") - ) - results.append(result) - return results - - return None - - def set(self, query: RAGQuery, results: List[RetrievalResult]): - """Cache results""" - if not self.cache_manager: - return - - cache_key = self.get_cache_key(query) - # Serialize results - serialized = [ - { - "document": r.document.to_dict(), - "score": r.score, - "rerank_score": r.rerank_score - } - for r in results - ] - - self.cache_manager.set(cache_key, serialized, ttl=self.cache_ttl) - - -class AdvancedRetrievalPipeline: - """ - Advanced retrieval pipeline with: - - Query expansion - - Multi-query retrieval - - Result caching - - Hybrid search - """ - - def __init__( - self, - base_retriever, - query_expander: Optional[QueryExpander] = None, - use_multi_query: bool = True, - use_cache: bool = True, - cache_manager=None - ): - self.base_retriever = base_retriever - self.query_expander = query_expander or SynonymQueryExpander() - self.use_multi_query = use_multi_query - self.use_cache = use_cache - self.query_cache = QueryCache(cache_manager) if use_cache else None - - if use_multi_query: - self.multi_query_retriever = MultiQueryRetriever( - base_retriever=base_retriever, - query_expander=query_expander - ) - else: - self.multi_query_retriever = None - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """Retrieve with advanced strategies""" - # Check cache - if self.use_cache and self.query_cache: - cached = self.query_cache.get(query) - if cached: - return cached - - # Retrieve using multi-query or base retriever - if self.use_multi_query and self.multi_query_retriever: - results = 
self.multi_query_retriever.retrieve(query) - else: - results = self.base_retriever.retrieve(query) - - # Cache results - if self.use_cache and self.query_cache: - self.query_cache.set(query, results) - - return results - +""" +Advanced Retrieval Strategies for Universal RAG +Query expansion, multi-query retrieval, and advanced search techniques +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +import hashlib +import json + +from .base import Document, RetrievalResult, RAGQuery + + +@dataclass +class ExpandedQuery: + """Expanded query with multiple variations""" + + original_query: str + expanded_queries: List[str] + query_type: str # "synonym", "rephrase", "keyword", etc. + confidence: float + + +class QueryExpander(ABC): + """Base class for query expansion strategies""" + + @abstractmethod + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: + """Expand a query into multiple variations""" + pass + + +class SynonymQueryExpander(QueryExpander): + """Expand queries using synonyms and related terms""" + + def __init__(self, synonym_dict: Optional[Dict[str, List[str]]] = None): + self.synonym_dict = synonym_dict or self._default_synonyms() + + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: + """Expand query using synonyms""" + words = query.lower().split() + expanded = [query] # Always include original + + for word in words: + if word in self.synonym_dict: + synonyms = self.synonym_dict[word][:max_expansions] + for synonym in synonyms: + expanded_query = query.replace(word, synonym) + if expanded_query not in expanded: + expanded.append(expanded_query) + + return ExpandedQuery( + original_query=query, + expanded_queries=expanded[: max_expansions + 1], + query_type="synonym", + confidence=0.8, + ) + + def _default_synonyms(self) -> Dict[str, List[str]]: + """Default synonym dictionary""" + return { + "repair": ["fix", "maintain", "service", 
"restore"], + "maintenance": ["service", "upkeep", "repair", "inspection"], + "claim": ["request", "application", "report", "filing"], + "policy": ["coverage", "plan", "insurance", "agreement"], + "regulation": ["rule", "standard", "requirement", "guideline"], + "procedure": ["process", "method", "protocol", "steps"], + "equipment": ["machine", "device", "tool", "apparatus"], + "safety": ["security", "protection", "precaution"], + "inspection": ["examination", "review", "check", "audit"], + "documentation": ["record", "file", "document", "paperwork"], + } + + +class RephraseQueryExpander(QueryExpander): + """Expand queries by rephrasing""" + + def __init__(self, llm_client=None): + self.llm_client = llm_client + + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: + """Rephrase query using LLM or templates""" + if self.llm_client: + return self._llm_rephrase(query, max_expansions) + else: + return self._template_rephrase(query, max_expansions) + + def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: + """Rephrase using templates""" + templates = [ + f"What is {query}?", + f"How to {query}?", + f"Information about {query}", + f"Details on {query}", + ] + + expanded = [query] + templates[:max_expansions] + + return ExpandedQuery( + original_query=query, + expanded_queries=expanded, + query_type="rephrase", + confidence=0.7, + ) + + def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: + """Rephrase using LLM (if available)""" + # Placeholder for LLM-based rephrasing + return self._template_rephrase(query, max_expansions) + + +class KeywordExtractor: + """Extract keywords from queries for hybrid search""" + + def __init__(self): + # Common stop words + self.stop_words = { + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + 
"did", + "will", + "would", + "should", + "could", + "may", + "might", + "must", + "can", + "this", + "that", + "these", + "those", + } + + def extract(self, query: str, max_keywords: int = 10) -> List[str]: + """Extract keywords from query""" + words = query.lower().split() + keywords = [ + word.strip(".,!?;:()[]{}") + for word in words + if word.strip(".,!?;:()[]{}") not in self.stop_words + and len(word.strip(".,!?;:()[]{}")) > 2 + ] + return keywords[:max_keywords] + + +class MultiQueryRetriever: + """ + Multi-query retrieval strategy + Generates multiple query variations and combines results + """ + + def __init__( + self, + base_retriever, + query_expander: Optional[QueryExpander] = None, + fusion_method: str = "rrf", # "rrf" (Reciprocal Rank Fusion) or "mean" + ): + self.base_retriever = base_retriever + self.query_expander = query_expander or SynonymQueryExpander() + self.fusion_method = fusion_method + + def retrieve(self, query: RAGQuery, num_queries: int = 3) -> List[RetrievalResult]: + """ + Retrieve using multiple query variations + + Args: + query: Original RAG query + num_queries: Number of query variations to generate + + Returns: + Fused retrieval results + """ + # Expand query + expanded = self.query_expander.expand( + query.query, max_expansions=num_queries - 1 + ) + + # Retrieve for each query variation + all_results: Dict[str, List[RetrievalResult]] = {} + + for expanded_query in expanded.expanded_queries: + # Create new query with expanded text + expanded_rag_query = RAGQuery( + query=expanded_query, + industry=query.industry, + doc_types=query.doc_types, + date_range=query.date_range, + metadata_filters=query.metadata_filters, + top_k=query.top_k or 10, + ) + + # Retrieve results + results = self.base_retriever.retrieve(expanded_rag_query) + all_results[expanded_query] = results + + # Fuse results + fused = self._fuse_results(all_results, query.top_k or 10) + + return fused + + def _fuse_results( + self, all_results: Dict[str, 
List[RetrievalResult]], top_k: int + ) -> List[RetrievalResult]: + """Fuse results from multiple queries""" + if self.fusion_method == "rrf": + return self._reciprocal_rank_fusion(all_results, top_k) + else: + return self._mean_score_fusion(all_results, top_k) + + def _reciprocal_rank_fusion( + self, all_results: Dict[str, List[RetrievalResult]], top_k: int, k: int = 60 + ) -> List[RetrievalResult]: + """ + Reciprocal Rank Fusion (RRF) + Combines rankings from multiple queries + """ + doc_scores: Dict[str, Dict[str, Any]] = {} + + for query_text, results in all_results.items(): + for rank, result in enumerate(results): + doc_id = result.document.id + rrf_score = 1.0 / (k + rank + 1) + + if doc_id not in doc_scores: + doc_scores[doc_id] = { + "document": result.document, + "rrf_score": 0.0, + "max_score": result.score, + "count": 0, + } + + doc_scores[doc_id]["rrf_score"] += rrf_score + doc_scores[doc_id]["max_score"] = max( + doc_scores[doc_id]["max_score"], result.score + ) + doc_scores[doc_id]["count"] += 1 + + # Create fused results + fused = [] + for doc_id, scores in doc_scores.items(): + fused.append( + RetrievalResult( + document=scores["document"], + score=scores["rrf_score"], # Use RRF score + rerank_score=scores["max_score"], # Store max original score + ) + ) + + # Sort by RRF score + fused.sort(key=lambda x: x.score, reverse=True) + + return fused[:top_k] + + def _mean_score_fusion( + self, all_results: Dict[str, List[RetrievalResult]], top_k: int + ) -> List[RetrievalResult]: + """Fuse results using mean score""" + doc_scores: Dict[str, Dict[str, Any]] = {} + + for query_text, results in all_results.items(): + for result in results: + doc_id = result.document.id + + if doc_id not in doc_scores: + doc_scores[doc_id] = { + "document": result.document, + "scores": [], + "count": 0, + } + + doc_scores[doc_id]["scores"].append(result.score) + doc_scores[doc_id]["count"] += 1 + + # Calculate mean scores + fused = [] + for doc_id, scores in 
doc_scores.items(): + mean_score = sum(scores["scores"]) / len(scores["scores"]) + fused.append( + RetrievalResult( + document=scores["document"], + score=mean_score, + rerank_score=max(scores["scores"]), + ) + ) + + # Sort by mean score + fused.sort(key=lambda x: x.score, reverse=True) + + return fused[:top_k] + + +class QueryCache: + """Cache for query results""" + + def __init__(self, cache_manager=None): + self.cache_manager = cache_manager + self.cache_ttl = 3600 # 1 hour + + def get_cache_key(self, query: RAGQuery) -> str: + """Generate cache key for query""" + query_dict = query.to_dict() + query_str = json.dumps(query_dict, sort_keys=True) + query_hash = hashlib.md5(query_str.encode()).hexdigest() + return f"rag:query:{query_hash}" + + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: + """Get cached results""" + if not self.cache_manager: + return None + + cache_key = self.get_cache_key(query) + cached = self.cache_manager.get(cache_key) + + if cached: + # Reconstruct RetrievalResult objects + results = [] + for item in cached: + doc = Document.from_dict(item["document"]) + result = RetrievalResult( + document=doc, + score=item["score"], + rerank_score=item.get("rerank_score"), + ) + results.append(result) + return results + + return None + + def set(self, query: RAGQuery, results: List[RetrievalResult]): + """Cache results""" + if not self.cache_manager: + return + + cache_key = self.get_cache_key(query) + # Serialize results + serialized = [ + { + "document": r.document.to_dict(), + "score": r.score, + "rerank_score": r.rerank_score, + } + for r in results + ] + + self.cache_manager.set(cache_key, serialized, ttl=self.cache_ttl) + + +class AdvancedRetrievalPipeline: + """ + Advanced retrieval pipeline with: + - Query expansion + - Multi-query retrieval + - Result caching + - Hybrid search + """ + + def __init__( + self, + base_retriever, + query_expander: Optional[QueryExpander] = None, + use_multi_query: bool = True, + use_cache: bool = 
True, + cache_manager=None, + ): + self.base_retriever = base_retriever + self.query_expander = query_expander or SynonymQueryExpander() + self.use_multi_query = use_multi_query + self.use_cache = use_cache + self.query_cache = QueryCache(cache_manager) if use_cache else None + + if use_multi_query: + self.multi_query_retriever = MultiQueryRetriever( + base_retriever=base_retriever, query_expander=query_expander + ) + else: + self.multi_query_retriever = None + + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """Retrieve with advanced strategies""" + # Check cache + if self.use_cache and self.query_cache: + cached = self.query_cache.get(query) + if cached: + return cached + + # Retrieve using multi-query or base retriever + if self.use_multi_query and self.multi_query_retriever: + results = self.multi_query_retriever.retrieve(query) + else: + results = self.base_retriever.retrieve(query) + + # Cache results + if self.use_cache and self.query_cache: + self.query_cache.set(query, results) + + return results diff --git a/src/deepiri_modelkit/rag/async_processing.py b/src/deepiri_modelkit/rag/async_processing.py index 9625443..87a0b0f 100644 --- a/src/deepiri_modelkit/rag/async_processing.py +++ b/src/deepiri_modelkit/rag/async_processing.py @@ -5,6 +5,7 @@ import asyncio from typing import List, Dict, Any, Optional, Callable, Awaitable + # Fix for Python < 3.9 compatibility try: from collections.abc import AsyncIterator @@ -20,6 +21,7 @@ @dataclass class BatchProcessingConfig: """Configuration for batch processing""" + batch_size: int = 100 max_concurrent_batches: int = 5 chunk_size: int = 1000 @@ -33,24 +35,25 @@ class BatchProcessingConfig: @dataclass class BatchProcessingResult: """Result of batch processing operation""" + total_items: int processed_items: int successful_items: int failed_items: int processing_time_seconds: float errors: List[Dict[str, Any]] = None - + def __post_init__(self): if self.errors is None: self.errors = [] - + @property 
def success_rate(self) -> float: """Calculate success rate""" if self.total_items == 0: return 0.0 return self.successful_items / self.total_items - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -68,25 +71,25 @@ class AsyncBatchProcessor: """ Async batch processor for high-performance document processing """ - + def __init__(self, config: BatchProcessingConfig): self.config = config self.semaphore = asyncio.Semaphore(config.max_concurrent_batches) - + async def process_batch( self, items: List[Any], processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Process items in batches asynchronously - + Args: items: List of items to process processor_func: Async function to process each item progress_callback: Optional callback for progress updates - + Returns: BatchProcessingResult with statistics """ @@ -95,135 +98,125 @@ async def process_batch( successful_items = 0 failed_items = 0 errors = [] - + # Split into batches batches = [ - items[i:i + self.config.batch_size] + items[i : i + self.config.batch_size] for i in range(0, total_items, self.config.batch_size) ] - + # Process batches concurrently tasks = [] for batch_idx, batch in enumerate(batches): task = self._process_batch_with_semaphore( - batch, - batch_idx, - processor_func, - progress_callback + batch, batch_idx, processor_func, progress_callback ) tasks.append(task) - + # Wait for all batches batch_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Aggregate results for result in batch_results: if isinstance(result, Exception): failed_items += self.config.batch_size - errors.append({ - "error": str(result), - "type": type(result).__name__ - }) + errors.append({"error": str(result), "type": type(result).__name__}) else: successful_items += result["successful"] failed_items += result["failed"] 
errors.extend(result.get("errors", [])) - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_items, processed_items=total_items, successful_items=successful_items, failed_items=failed_items, processing_time_seconds=processing_time, - errors=errors + errors=errors, ) - + async def _process_batch_with_semaphore( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] + progress_callback: Optional[Callable[[int, int], None]], ) -> Dict[str, Any]: """Process a single batch with semaphore control""" async with self.semaphore: return await self._process_single_batch( - batch, - batch_idx, - processor_func, - progress_callback + batch, batch_idx, processor_func, progress_callback ) - + async def _process_single_batch( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] + progress_callback: Optional[Callable[[int, int], None]], ) -> Dict[str, Any]: """Process a single batch""" successful = 0 failed = 0 errors = [] - + # Process items in batch concurrently tasks = [processor_func(item) for item in batch] results = await asyncio.gather(*tasks, return_exceptions=True) - + for idx, result in enumerate(results): if isinstance(result, Exception): failed += 1 - errors.append({ - "item_index": batch_idx * self.config.batch_size + idx, - "error": str(result), - "type": type(result).__name__ - }) + errors.append( + { + "item_index": batch_idx * self.config.batch_size + idx, + "error": str(result), + "type": type(result).__name__, + } + ) else: successful += 1 - + # Progress callback if progress_callback: total_processed = (batch_idx + 1) * self.config.batch_size progress_callback(total_processed, len(batch) * (batch_idx + 1)) - - return { - "successful": successful, - "failed": failed, - "errors": errors - } + + return {"successful": successful, "failed": 
failed, "errors": errors} class AsyncDocumentIndexer: """ Async document indexer with batching and retry logic """ - + def __init__( self, index_func: Callable[[Document], Awaitable[bool]], - config: Optional[BatchProcessingConfig] = None + config: Optional[BatchProcessingConfig] = None, ): self.index_func = index_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def index_documents( self, documents: List[Document], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Index documents asynchronously in batches - + Args: documents: List of documents to index progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ + async def index_document(doc: Document) -> bool: """Index a single document with retry""" for attempt in range(self.config.max_retries): @@ -238,27 +231,25 @@ async def index_document(doc: Document) -> bool: ) continue raise - + return False - + return await self.batch_processor.process_batch( - documents, - index_document, - progress_callback + documents, index_document, progress_callback ) - + async def index_documents_streaming( self, document_stream: AsyncIterator[Document], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Index documents from async stream - + Args: document_stream: Async iterator of documents progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ @@ -268,39 +259,39 @@ async def index_documents_streaming( failed = 0 errors = [] start_time = time.time() - + async def process_current_batch(): nonlocal successful, failed, errors - + if not batch: return - + result = await self.index_documents(batch, progress_callback) successful += 
result.successful_items failed += result.failed_items errors.extend(result.errors) batch.clear() - + async for document in document_stream: batch.append(document) total_processed += 1 - + if len(batch) >= self.config.batch_size: await process_current_batch() - + # Process remaining batch if batch: await process_current_batch() - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_processed, processed_items=total_processed, successful_items=successful, failed_items=failed, processing_time_seconds=processing_time, - errors=errors + errors=errors, ) @@ -308,65 +299,55 @@ class AsyncDocumentProcessor: """ Async document processor for parallel document processing """ - + def __init__( self, processor_func: Callable[[str, Dict], List[Document]], - config: Optional[BatchProcessingConfig] = None + config: Optional[BatchProcessingConfig] = None, ): self.processor_func = processor_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def process_documents( self, raw_documents: List[Dict[str, Any]], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> tuple[List[Document], BatchProcessingResult]: """ Process raw documents asynchronously - + Args: raw_documents: List of dicts with 'content' and 'metadata' progress_callback: Optional callback for progress - + Returns: Tuple of (processed_documents, processing_result) """ processed_docs = [] errors = [] - + async def process_document(item: Dict[str, Any]) -> List[Document]: """Process a single document""" try: content = item.get("content", "") metadata = item.get("metadata", {}) - docs = await asyncio.to_thread( - self.processor_func, - content, - metadata - ) + docs = await asyncio.to_thread(self.processor_func, content, metadata) return docs except Exception as e: - errors.append({ - "item": item.get("id", "unknown"), - "error": str(e) - 
}) + errors.append({"item": item.get("id", "unknown"), "error": str(e)}) return [] - + result = await self.batch_processor.process_batch( - raw_documents, - process_document, - progress_callback + raw_documents, process_document, progress_callback ) - + # Collect all processed documents tasks = [process_document(item) for item in raw_documents] doc_lists = await asyncio.gather(*tasks, return_exceptions=True) - + for doc_list in doc_lists: if isinstance(doc_list, list): processed_docs.extend(doc_list) - - return processed_docs, result + return processed_docs, result diff --git a/src/deepiri_modelkit/rag/base.py b/src/deepiri_modelkit/rag/base.py index 7693e5d..ec97ac5 100644 --- a/src/deepiri_modelkit/rag/base.py +++ b/src/deepiri_modelkit/rag/base.py @@ -1,300 +1,317 @@ -""" -Universal RAG Base Classes -Core abstractions for retrieval-augmented generation across all industries -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Dict, Any, Optional, Union -from datetime import datetime - - -class DocumentType(Enum): - """Types of documents that can be indexed""" - REGULATION = "regulation" # Laws, regulations, compliance documents - POLICY = "policy" # Insurance policies, company policies - MANUAL = "manual" # Equipment manuals, operation guides - CONTRACT = "contract" # Legal contracts, agreements - WORK_ORDER = "work_order" # Maintenance work orders, service requests - CLAIM_RECORD = "claim_record" # Insurance claims, warranty claims - MAINTENANCE_LOG = "maintenance_log" # Equipment maintenance history - FAQ = "faq" # Frequently asked questions - KNOWLEDGE_BASE = "knowledge_base" # General knowledge articles - REPORT = "report" # Inspection reports, audit reports - PROCEDURE = "procedure" # Standard operating procedures - SAFETY_GUIDELINE = "safety_guideline" # Safety protocols and guidelines - TECHNICAL_SPEC = "technical_spec" # Technical specifications - INVOICE = "invoice" # 
Billing and invoices - OTHER = "other" # Catch-all for other document types - - -class IndustryNiche(Enum): - """Supported industry niches""" - INSURANCE = "insurance" # Property & casualty insurance - MANUFACTURING = "manufacturing" # Industrial manufacturing - PROPERTY_MANAGEMENT = "property_management" # Real estate management - HEALTHCARE = "healthcare" # Healthcare providers - CONSTRUCTION = "construction" # Construction industry - AUTOMOTIVE = "automotive" # Automotive services - ENERGY = "energy" # Energy & utilities - LOGISTICS = "logistics" # Transportation & logistics - RETAIL = "retail" # Retail operations - HOSPITALITY = "hospitality" # Hotels & hospitality - GENERIC = "generic" # Cross-industry - - -@dataclass -class RAGConfig: - """Configuration for RAG system""" - # Industry configuration - industry: IndustryNiche = IndustryNiche.GENERIC - - # Vector database configuration - vector_db_type: str = "milvus" # milvus, pinecone, weaviate, memory - collection_name: str = "deepiri_universal_rag" - vector_db_host: str = "milvus" - vector_db_port: int = 19530 - - # Embedding configuration - embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" - embedding_dimension: int = 384 - - # Retrieval configuration - top_k: int = 5 # Number of documents to retrieve - similarity_threshold: float = 0.7 # Minimum similarity score - use_reranking: bool = True # Use cross-encoder reranking - reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" - - # Chunking configuration - chunk_size: int = 1000 # Characters per chunk - chunk_overlap: int = 200 # Overlap between chunks - - # Metadata filtering - enable_metadata_filtering: bool = True - date_range_filtering: bool = True - - # Multi-modal support - support_images: bool = False - support_tables: bool = True - support_code: bool = False - - -@dataclass -class Document: - """Universal document representation""" - id: str - content: str - doc_type: DocumentType - industry: IndustryNiche - - # Metadata - title: 
Optional[str] = None - source: Optional[str] = None - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None - author: Optional[str] = None - version: Optional[str] = None - - # Industry-specific metadata - metadata: Dict[str, Any] = field(default_factory=dict) - - # Processing metadata - chunk_index: Optional[int] = None - total_chunks: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for storage""" - return { - "id": self.id, - "content": self.content, - "doc_type": self.doc_type.value, - "industry": self.industry.value, - "title": self.title, - "source": self.source, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None, - "author": self.author, - "version": self.version, - "metadata": self.metadata, - "chunk_index": self.chunk_index, - "total_chunks": self.total_chunks, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Document': - """Create from dictionary""" - return cls( - id=data["id"], - content=data["content"], - doc_type=DocumentType(data["doc_type"]), - industry=IndustryNiche(data["industry"]), - title=data.get("title"), - source=data.get("source"), - created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None, - updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None, - author=data.get("author"), - version=data.get("version"), - metadata=data.get("metadata", {}), - chunk_index=data.get("chunk_index"), - total_chunks=data.get("total_chunks"), - ) - - -@dataclass -class RetrievalResult: - """Result from RAG retrieval""" - document: Document - score: float # Similarity score - rerank_score: Optional[float] = None # Reranking score if applicable - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return { - "document": self.document.to_dict(), - "score": self.score, - "rerank_score": self.rerank_score, 
- } - - -@dataclass -class RAGQuery: - """Query for RAG system""" - query: str - industry: Optional[IndustryNiche] = None - doc_types: Optional[List[DocumentType]] = None - date_range: Optional[tuple[datetime, datetime]] = None - metadata_filters: Optional[Dict[str, Any]] = None - top_k: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return { - "query": self.query, - "industry": self.industry.value if self.industry else None, - "doc_types": [dt.value for dt in self.doc_types] if self.doc_types else None, - "date_range": [dr.isoformat() for dr in self.date_range] if self.date_range else None, - "metadata_filters": self.metadata_filters, - "top_k": self.top_k, - } - - -class UniversalRAGEngine(ABC): - """ - Abstract base class for universal RAG engine - Implements common RAG patterns across all industries - """ - - def __init__(self, config: RAGConfig): - self.config = config - self._initialize() - - @abstractmethod - def _initialize(self): - """Initialize RAG components (vector store, embeddings, etc.)""" - pass - - @abstractmethod - def index_document(self, document: Document) -> bool: - """ - Index a single document - - Args: - document: Document to index - - Returns: - True if successful, False otherwise - """ - pass - - @abstractmethod - def index_documents(self, documents: List[Document]) -> Dict[str, Any]: - """ - Index multiple documents in batch - - Args: - documents: List of documents to index - - Returns: - Statistics about the indexing operation - """ - pass - - @abstractmethod - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve relevant documents for a query - - Args: - query: Query with filters and parameters - - Returns: - List of retrieval results with scores - """ - pass - - @abstractmethod - def generate_with_context( - self, - query: str, - retrieved_docs: List[RetrievalResult], - llm_prompt_template: Optional[str] = None - ) -> Dict[str, Any]: - """ - Generate response 
using retrieved context - - Args: - query: User query - retrieved_docs: Retrieved documents for context - llm_prompt_template: Optional custom prompt template - - Returns: - Generated response with metadata - """ - pass - - @abstractmethod - def delete_documents(self, doc_ids: List[str]) -> bool: - """Delete documents by IDs""" - pass - - @abstractmethod - def update_document(self, doc_id: str, document: Document) -> bool: - """Update an existing document""" - pass - - @abstractmethod - def get_statistics(self) -> Dict[str, Any]: - """Get statistics about indexed documents""" - pass - - def search( - self, - query: str, - industry: Optional[IndustryNiche] = None, - doc_types: Optional[List[DocumentType]] = None, - top_k: Optional[int] = None, - **filters - ) -> List[RetrievalResult]: - """ - Convenience method for simple search - - Args: - query: Search query - industry: Filter by industry - doc_types: Filter by document types - top_k: Number of results - **filters: Additional metadata filters - - Returns: - List of retrieval results - """ - rag_query = RAGQuery( - query=query, - industry=industry, - doc_types=doc_types, - top_k=top_k or self.config.top_k, - metadata_filters=filters if filters else None - ) - return self.retrieve(rag_query) - +""" +Universal RAG Base Classes +Core abstractions for retrieval-augmented generation across all industries +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Dict, Any, Optional, Union +from datetime import datetime + + +class DocumentType(Enum): + """Types of documents that can be indexed""" + + REGULATION = "regulation" # Laws, regulations, compliance documents + POLICY = "policy" # Insurance policies, company policies + MANUAL = "manual" # Equipment manuals, operation guides + CONTRACT = "contract" # Legal contracts, agreements + WORK_ORDER = "work_order" # Maintenance work orders, service requests + CLAIM_RECORD = "claim_record" # 
Insurance claims, warranty claims + MAINTENANCE_LOG = "maintenance_log" # Equipment maintenance history + FAQ = "faq" # Frequently asked questions + KNOWLEDGE_BASE = "knowledge_base" # General knowledge articles + REPORT = "report" # Inspection reports, audit reports + PROCEDURE = "procedure" # Standard operating procedures + SAFETY_GUIDELINE = "safety_guideline" # Safety protocols and guidelines + TECHNICAL_SPEC = "technical_spec" # Technical specifications + INVOICE = "invoice" # Billing and invoices + OTHER = "other" # Catch-all for other document types + + +class IndustryNiche(Enum): + """Supported industry niches""" + + INSURANCE = "insurance" # Property & casualty insurance + MANUFACTURING = "manufacturing" # Industrial manufacturing + PROPERTY_MANAGEMENT = "property_management" # Real estate management + HEALTHCARE = "healthcare" # Healthcare providers + CONSTRUCTION = "construction" # Construction industry + AUTOMOTIVE = "automotive" # Automotive services + ENERGY = "energy" # Energy & utilities + LOGISTICS = "logistics" # Transportation & logistics + RETAIL = "retail" # Retail operations + HOSPITALITY = "hospitality" # Hotels & hospitality + GENERIC = "generic" # Cross-industry + + +@dataclass +class RAGConfig: + """Configuration for RAG system""" + + # Industry configuration + industry: IndustryNiche = IndustryNiche.GENERIC + + # Vector database configuration + vector_db_type: str = "milvus" # milvus, pinecone, weaviate, memory + collection_name: str = "deepiri_universal_rag" + vector_db_host: str = "milvus" + vector_db_port: int = 19530 + + # Embedding configuration + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" + embedding_dimension: int = 384 + + # Retrieval configuration + top_k: int = 5 # Number of documents to retrieve + similarity_threshold: float = 0.7 # Minimum similarity score + use_reranking: bool = True # Use cross-encoder reranking + reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + # Chunking configuration + 
chunk_size: int = 1000 # Characters per chunk + chunk_overlap: int = 200 # Overlap between chunks + + # Metadata filtering + enable_metadata_filtering: bool = True + date_range_filtering: bool = True + + # Multi-modal support + support_images: bool = False + support_tables: bool = True + support_code: bool = False + + +@dataclass +class Document: + """Universal document representation""" + + id: str + content: str + doc_type: DocumentType + industry: IndustryNiche + + # Metadata + title: Optional[str] = None + source: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + author: Optional[str] = None + version: Optional[str] = None + + # Industry-specific metadata + metadata: Dict[str, Any] = field(default_factory=dict) + + # Processing metadata + chunk_index: Optional[int] = None + total_chunks: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for storage""" + return { + "id": self.id, + "content": self.content, + "doc_type": self.doc_type.value, + "industry": self.industry.value, + "title": self.title, + "source": self.source, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + "author": self.author, + "version": self.version, + "metadata": self.metadata, + "chunk_index": self.chunk_index, + "total_chunks": self.total_chunks, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Document": + """Create from dictionary""" + return cls( + id=data["id"], + content=data["content"], + doc_type=DocumentType(data["doc_type"]), + industry=IndustryNiche(data["industry"]), + title=data.get("title"), + source=data.get("source"), + created_at=( + datetime.fromisoformat(data["created_at"]) + if data.get("created_at") + else None + ), + updated_at=( + datetime.fromisoformat(data["updated_at"]) + if data.get("updated_at") + else None + ), + author=data.get("author"), + 
version=data.get("version"), + metadata=data.get("metadata", {}), + chunk_index=data.get("chunk_index"), + total_chunks=data.get("total_chunks"), + ) + + +@dataclass +class RetrievalResult: + """Result from RAG retrieval""" + + document: Document + score: float # Similarity score + rerank_score: Optional[float] = None # Reranking score if applicable + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "document": self.document.to_dict(), + "score": self.score, + "rerank_score": self.rerank_score, + } + + +@dataclass +class RAGQuery: + """Query for RAG system""" + + query: str + industry: Optional[IndustryNiche] = None + doc_types: Optional[List[DocumentType]] = None + date_range: Optional[tuple[datetime, datetime]] = None + metadata_filters: Optional[Dict[str, Any]] = None + top_k: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "query": self.query, + "industry": self.industry.value if self.industry else None, + "doc_types": ( + [dt.value for dt in self.doc_types] if self.doc_types else None + ), + "date_range": ( + [dr.isoformat() for dr in self.date_range] if self.date_range else None + ), + "metadata_filters": self.metadata_filters, + "top_k": self.top_k, + } + + +class UniversalRAGEngine(ABC): + """ + Abstract base class for universal RAG engine + Implements common RAG patterns across all industries + """ + + def __init__(self, config: RAGConfig): + self.config = config + self._initialize() + + @abstractmethod + def _initialize(self): + """Initialize RAG components (vector store, embeddings, etc.)""" + pass + + @abstractmethod + def index_document(self, document: Document) -> bool: + """ + Index a single document + + Args: + document: Document to index + + Returns: + True if successful, False otherwise + """ + pass + + @abstractmethod + def index_documents(self, documents: List[Document]) -> Dict[str, Any]: + """ + Index multiple documents in batch + + Args: + documents: 
List of documents to index + + Returns: + Statistics about the indexing operation + """ + pass + + @abstractmethod + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """ + Retrieve relevant documents for a query + + Args: + query: Query with filters and parameters + + Returns: + List of retrieval results with scores + """ + pass + + @abstractmethod + def generate_with_context( + self, + query: str, + retrieved_docs: List[RetrievalResult], + llm_prompt_template: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Generate response using retrieved context + + Args: + query: User query + retrieved_docs: Retrieved documents for context + llm_prompt_template: Optional custom prompt template + + Returns: + Generated response with metadata + """ + pass + + @abstractmethod + def delete_documents(self, doc_ids: List[str]) -> bool: + """Delete documents by IDs""" + pass + + @abstractmethod + def update_document(self, doc_id: str, document: Document) -> bool: + """Update an existing document""" + pass + + @abstractmethod + def get_statistics(self) -> Dict[str, Any]: + """Get statistics about indexed documents""" + pass + + def search( + self, + query: str, + industry: Optional[IndustryNiche] = None, + doc_types: Optional[List[DocumentType]] = None, + top_k: Optional[int] = None, + **filters, + ) -> List[RetrievalResult]: + """ + Convenience method for simple search + + Args: + query: Search query + industry: Filter by industry + doc_types: Filter by document types + top_k: Number of results + **filters: Additional metadata filters + + Returns: + List of retrieval results + """ + rag_query = RAGQuery( + query=query, + industry=industry, + doc_types=doc_types, + top_k=top_k or self.config.top_k, + metadata_filters=filters if filters else None, + ) + return self.retrieve(rag_query) diff --git a/src/deepiri_modelkit/rag/caching.py b/src/deepiri_modelkit/rag/caching.py index 8da0ff7..e2e4b7b 100644 --- a/src/deepiri_modelkit/rag/caching.py +++ 
b/src/deepiri_modelkit/rag/caching.py @@ -1,466 +1,466 @@ -""" -Advanced Caching Layer for Universal RAG -Redis-based caching with intelligent invalidation and TTL management -""" - -from typing import Optional, Any, List, Dict -import json -import hashlib -import time -from datetime import datetime, timedelta -from dataclasses import dataclass, asdict - -from .base import Document, RetrievalResult, RAGQuery - - -@dataclass -class CacheEntry: - """Cache entry with metadata""" - key: str - value: Any - created_at: datetime - expires_at: Optional[datetime] - access_count: int = 0 - last_accessed: Optional[datetime] = None - tags: List[str] = None - - def __post_init__(self): - if self.tags is None: - self.tags = [] - if self.last_accessed is None: - self.last_accessed = self.created_at - - def is_expired(self) -> bool: - """Check if entry is expired""" - if self.expires_at is None: - return False - return datetime.now() > self.expires_at - - def to_dict(self) -> Dict: - """Convert to dictionary for serialization""" - return { - "key": self.key, - "value": self.value, - "created_at": self.created_at.isoformat(), - "expires_at": self.expires_at.isoformat() if self.expires_at else None, - "access_count": self.access_count, - "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None, - "tags": self.tags - } - - @classmethod - def from_dict(cls, data: Dict) -> 'CacheEntry': - """Create from dictionary""" - return cls( - key=data["key"], - value=data["value"], - created_at=datetime.fromisoformat(data["created_at"]), - expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None, - access_count=data.get("access_count", 0), - last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None, - tags=data.get("tags", []) - ) - - -class AdvancedCacheManager: - """ - Advanced cache manager with: - - TTL management - - Tag-based invalidation - - Access tracking - - Size limits - - LRU eviction - 
""" - - def __init__( - self, - redis_client=None, - default_ttl: int = 3600, - max_size: int = 10000, - enable_lru: bool = True - ): - self.redis_client = redis_client - self.default_ttl = default_ttl - self.max_size = max_size - self.enable_lru = enable_lru - - # In-memory fallback if Redis unavailable - self.memory_cache: Dict[str, CacheEntry] = {} - self.tag_index: Dict[str, List[str]] = {} # tag -> [keys] - - def _get_key_prefix(self, namespace: str = "rag") -> str: - """Get key prefix""" - return f"{namespace}:" - - def _serialize_value(self, value: Any) -> str: - """Serialize value for storage""" - if isinstance(value, (list, dict)): - return json.dumps(value) - return str(value) - - def _deserialize_value(self, value: str, value_type: type = None) -> Any: - """Deserialize value from storage""" - try: - if value_type == list or (isinstance(value, str) and value.startswith('[')): - return json.loads(value) - elif value_type == dict or (isinstance(value, str) and value.startswith('{')): - return json.loads(value) - return value - except (json.JSONDecodeError, TypeError): - return value - - def get( - self, - key: str, - namespace: str = "rag", - update_access: bool = True - ) -> Optional[Any]: - """Get value from cache""" - full_key = f"{self._get_key_prefix(namespace)}{key}" - - # Try Redis first - if self.redis_client: - try: - cached = self.redis_client.get(full_key) - if cached: - entry_data = json.loads(cached) - entry = CacheEntry.from_dict(entry_data) - - if entry.is_expired(): - self.delete(key, namespace) - return None - - if update_access: - entry.access_count += 1 - entry.last_accessed = datetime.now() - self._update_redis_entry(full_key, entry) - - return entry.value - except Exception as e: - # Fallback to memory - pass - - # Fallback to memory cache - if key in self.memory_cache: - entry = self.memory_cache[key] - - if entry.is_expired(): - del self.memory_cache[key] - return None - - if update_access: - entry.access_count += 1 - 
entry.last_accessed = datetime.now() - - return entry.value - - return None - - def set( - self, - key: str, - value: Any, - namespace: str = "rag", - ttl: Optional[int] = None, - tags: Optional[List[str]] = None - ) -> bool: - """Set value in cache""" - full_key = f"{self._get_key_prefix(namespace)}{key}" - ttl = ttl or self.default_ttl - tags = tags or [] - - expires_at = datetime.now() + timedelta(seconds=ttl) - entry = CacheEntry( - key=full_key, - value=value, - created_at=datetime.now(), - expires_at=expires_at, - tags=tags - ) - - # Try Redis first - if self.redis_client: - try: - entry_data = json.dumps(entry.to_dict()) - self.redis_client.setex(full_key, ttl, entry_data) - - # Update tag index - for tag in tags: - tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" - if tag_key not in self.tag_index: - self.tag_index[tag_key] = [] - if full_key not in self.tag_index[tag_key]: - self.tag_index[tag_key].append(full_key) - self.redis_client.sadd(tag_key, full_key) - - return True - except Exception as e: - # Fallback to memory - pass - - # Fallback to memory cache - # Check size limit - if len(self.memory_cache) >= self.max_size: - if self.enable_lru: - self._evict_lru() - else: - # Remove oldest - oldest_key = min( - self.memory_cache.keys(), - key=lambda k: self.memory_cache[k].created_at - ) - del self.memory_cache[oldest_key] - - self.memory_cache[key] = entry - - # Update tag index - for tag in tags: - if tag not in self.tag_index: - self.tag_index[tag] = [] - if key not in self.tag_index[tag]: - self.tag_index[tag].append(key) - - return True - - def delete(self, key: str, namespace: str = "rag") -> bool: - """Delete key from cache""" - full_key = f"{self._get_key_prefix(namespace)}{key}" - - # Try Redis - if self.redis_client: - try: - self.redis_client.delete(full_key) - return True - except Exception: - pass - - # Memory cache - if key in self.memory_cache: - entry = self.memory_cache[key] - # Remove from tag indexes - for tag in entry.tags: - if 
tag in self.tag_index and key in self.tag_index[tag]: - self.tag_index[tag].remove(key) - del self.memory_cache[key] - return True - - return False - - def invalidate_by_tag(self, tag: str, namespace: str = "rag") -> int: - """Invalidate all keys with given tag""" - tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" - count = 0 - - # Get keys from tag index - keys_to_delete = [] - - if self.redis_client: - try: - keys_to_delete = list(self.redis_client.smembers(tag_key)) - self.redis_client.delete(tag_key) - except Exception: - pass - - if tag in self.tag_index: - keys_to_delete.extend(self.tag_index[tag]) - del self.tag_index[tag] - - # Delete all keys - for key in keys_to_delete: - if self.delete(key.replace(self._get_key_prefix(namespace), ""), namespace): - count += 1 - - return count - - def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: - """Invalidate keys matching pattern""" - full_pattern = f"{self._get_key_prefix(namespace)}{pattern}" - count = 0 - - if self.redis_client: - try: - keys = self.redis_client.keys(full_pattern) - if keys: - count = self.redis_client.delete(*keys) - except Exception: - pass - - # Memory cache - keys_to_delete = [ - k for k in self.memory_cache.keys() - if self._match_pattern(k, pattern) - ] - for key in keys_to_delete: - self.delete(key, namespace) - count += 1 - - return count - - def _match_pattern(self, key: str, pattern: str) -> bool: - """Simple pattern matching (supports * wildcard)""" - import fnmatch - return fnmatch.fnmatch(key, pattern) - - def _evict_lru(self): - """Evict least recently used entry""" - if not self.memory_cache: - return - - lru_key = min( - self.memory_cache.keys(), - key=lambda k: ( - self.memory_cache[k].last_accessed or - self.memory_cache[k].created_at - ) - ) - del self.memory_cache[lru_key] - - def _update_redis_entry(self, key: str, entry: CacheEntry): - """Update Redis entry with new access info""" - if not self.redis_client: - return - - try: - entry_data = 
json.dumps(entry.to_dict()) - # Get remaining TTL - ttl = self.redis_client.ttl(key) - if ttl > 0: - self.redis_client.setex(key, ttl, entry_data) - except Exception: - pass - - def get_stats(self) -> Dict[str, Any]: - """Get cache statistics""" - stats = { - "memory_cache_size": len(self.memory_cache), - "max_size": self.max_size, - "tag_index_size": len(self.tag_index), - "redis_available": self.redis_client is not None, - } - - if self.memory_cache: - total_access = sum(e.access_count for e in self.memory_cache.values()) - stats["total_accesses"] = total_access - stats["avg_access_per_entry"] = total_access / len(self.memory_cache) - - return stats - - def clear(self, namespace: str = "rag"): - """Clear all cache entries in namespace""" - pattern = f"{self._get_key_prefix(namespace)}*" - return self.invalidate_by_pattern(pattern, namespace) - - -class EmbeddingCache: - """Specialized cache for embeddings""" - - def __init__(self, cache_manager: AdvancedCacheManager): - self.cache_manager = cache_manager - self.namespace = "rag:embeddings" - - def get_embedding_key(self, text: str) -> str: - """Generate cache key for embedding""" - text_hash = hashlib.md5(text.encode()).hexdigest() - return f"emb:{text_hash}" - - def get(self, text: str) -> Optional[Any]: - """Get cached embedding""" - key = self.get_embedding_key(text) - return self.cache_manager.get(key, namespace=self.namespace) - - def set(self, text: str, embedding: Any, ttl: int = 86400): - """Cache embedding (24 hour default TTL)""" - key = self.get_embedding_key(text) - return self.cache_manager.set( - key, - embedding, - namespace=self.namespace, - ttl=ttl, - tags=["embedding"] - ) - - -class QueryResultCache: - """Specialized cache for query results""" - - def __init__(self, cache_manager: AdvancedCacheManager): - self.cache_manager = cache_manager - self.namespace = "rag:queries" - - def get_query_key(self, query: RAGQuery) -> str: - """Generate cache key for query""" - query_dict = query.to_dict() - 
query_str = json.dumps(query_dict, sort_keys=True) - query_hash = hashlib.md5(query_str.encode()).hexdigest() - return f"query:{query_hash}" - - def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: - """Get cached query results""" - key = self.get_query_key(query) - cached = self.cache_manager.get(key, namespace=self.namespace) - - if cached: - # Reconstruct RetrievalResult objects - results = [] - for item in cached: - doc = Document.from_dict(item["document"]) - result = RetrievalResult( - document=doc, - score=item["score"], - rerank_score=item.get("rerank_score") - ) - results.append(result) - return results - - return None - - def set( - self, - query: RAGQuery, - results: List[RetrievalResult], - ttl: int = 3600, - tags: Optional[List[str]] = None - ): - """Cache query results""" - key = self.get_query_key(query) - - # Serialize results - serialized = [ - { - "document": r.document.to_dict(), - "score": r.score, - "rerank_score": r.rerank_score - } - for r in results - ] - - # Add query tags - query_tags = tags or [] - if query.industry: - query_tags.append(f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}") - if query.doc_types: - for dt in query.doc_types: - query_tags.append(f"doctype:{dt.value if hasattr(dt, 'value') else dt}") - - return self.cache_manager.set( - key, - serialized, - namespace=self.namespace, - ttl=ttl, - tags=query_tags - ) - - def invalidate_by_industry(self, industry: str): - """Invalidate all queries for an industry""" - return self.cache_manager.invalidate_by_tag( - f"industry:{industry}", - namespace=self.namespace - ) - - def invalidate_by_doc_type(self, doc_type: str): - """Invalidate all queries for a document type""" - return self.cache_manager.invalidate_by_tag( - f"doctype:{doc_type}", - namespace=self.namespace - ) - +""" +Advanced Caching Layer for Universal RAG +Redis-based caching with intelligent invalidation and TTL management +""" + +from typing import Optional, 
Any, List, Dict +import json +import hashlib +import time +from datetime import datetime, timedelta +from dataclasses import dataclass, asdict + +from .base import Document, RetrievalResult, RAGQuery + + +@dataclass +class CacheEntry: + """Cache entry with metadata""" + + key: str + value: Any + created_at: datetime + expires_at: Optional[datetime] + access_count: int = 0 + last_accessed: Optional[datetime] = None + tags: List[str] = None + + def __post_init__(self): + if self.tags is None: + self.tags = [] + if self.last_accessed is None: + self.last_accessed = self.created_at + + def is_expired(self) -> bool: + """Check if entry is expired""" + if self.expires_at is None: + return False + return datetime.now() > self.expires_at + + def to_dict(self) -> Dict: + """Convert to dictionary for serialization""" + return { + "key": self.key, + "value": self.value, + "created_at": self.created_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "access_count": self.access_count, + "last_accessed": ( + self.last_accessed.isoformat() if self.last_accessed else None + ), + "tags": self.tags, + } + + @classmethod + def from_dict(cls, data: Dict) -> "CacheEntry": + """Create from dictionary""" + return cls( + key=data["key"], + value=data["value"], + created_at=datetime.fromisoformat(data["created_at"]), + expires_at=( + datetime.fromisoformat(data["expires_at"]) + if data.get("expires_at") + else None + ), + access_count=data.get("access_count", 0), + last_accessed=( + datetime.fromisoformat(data["last_accessed"]) + if data.get("last_accessed") + else None + ), + tags=data.get("tags", []), + ) + + +class AdvancedCacheManager: + """ + Advanced cache manager with: + - TTL management + - Tag-based invalidation + - Access tracking + - Size limits + - LRU eviction + """ + + def __init__( + self, + redis_client=None, + default_ttl: int = 3600, + max_size: int = 10000, + enable_lru: bool = True, + ): + self.redis_client = redis_client + 
self.default_ttl = default_ttl + self.max_size = max_size + self.enable_lru = enable_lru + + # In-memory fallback if Redis unavailable + self.memory_cache: Dict[str, CacheEntry] = {} + self.tag_index: Dict[str, List[str]] = {} # tag -> [keys] + + def _get_key_prefix(self, namespace: str = "rag") -> str: + """Get key prefix""" + return f"{namespace}:" + + def _serialize_value(self, value: Any) -> str: + """Serialize value for storage""" + if isinstance(value, (list, dict)): + return json.dumps(value) + return str(value) + + def _deserialize_value(self, value: str, value_type: type = None) -> Any: + """Deserialize value from storage""" + try: + if value_type == list or (isinstance(value, str) and value.startswith("[")): + return json.loads(value) + elif value_type == dict or ( + isinstance(value, str) and value.startswith("{") + ): + return json.loads(value) + return value + except (json.JSONDecodeError, TypeError): + return value + + def get( + self, key: str, namespace: str = "rag", update_access: bool = True + ) -> Optional[Any]: + """Get value from cache""" + full_key = f"{self._get_key_prefix(namespace)}{key}" + + # Try Redis first + if self.redis_client: + try: + cached = self.redis_client.get(full_key) + if cached: + entry_data = json.loads(cached) + entry = CacheEntry.from_dict(entry_data) + + if entry.is_expired(): + self.delete(key, namespace) + return None + + if update_access: + entry.access_count += 1 + entry.last_accessed = datetime.now() + self._update_redis_entry(full_key, entry) + + return entry.value + except Exception as e: + # Fallback to memory + pass + + # Fallback to memory cache + if key in self.memory_cache: + entry = self.memory_cache[key] + + if entry.is_expired(): + del self.memory_cache[key] + return None + + if update_access: + entry.access_count += 1 + entry.last_accessed = datetime.now() + + return entry.value + + return None + + def set( + self, + key: str, + value: Any, + namespace: str = "rag", + ttl: Optional[int] = None, + tags: 
Optional[List[str]] = None, + ) -> bool: + """Set value in cache""" + full_key = f"{self._get_key_prefix(namespace)}{key}" + ttl = ttl or self.default_ttl + tags = tags or [] + + expires_at = datetime.now() + timedelta(seconds=ttl) + entry = CacheEntry( + key=full_key, + value=value, + created_at=datetime.now(), + expires_at=expires_at, + tags=tags, + ) + + # Try Redis first + if self.redis_client: + try: + entry_data = json.dumps(entry.to_dict()) + self.redis_client.setex(full_key, ttl, entry_data) + + # Update tag index + for tag in tags: + tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" + if tag_key not in self.tag_index: + self.tag_index[tag_key] = [] + if full_key not in self.tag_index[tag_key]: + self.tag_index[tag_key].append(full_key) + self.redis_client.sadd(tag_key, full_key) + + return True + except Exception as e: + # Fallback to memory + pass + + # Fallback to memory cache + # Check size limit + if len(self.memory_cache) >= self.max_size: + if self.enable_lru: + self._evict_lru() + else: + # Remove oldest + oldest_key = min( + self.memory_cache.keys(), + key=lambda k: self.memory_cache[k].created_at, + ) + del self.memory_cache[oldest_key] + + self.memory_cache[key] = entry + + # Update tag index + for tag in tags: + if tag not in self.tag_index: + self.tag_index[tag] = [] + if key not in self.tag_index[tag]: + self.tag_index[tag].append(key) + + return True + + def delete(self, key: str, namespace: str = "rag") -> bool: + """Delete key from cache""" + full_key = f"{self._get_key_prefix(namespace)}{key}" + + # Try Redis + if self.redis_client: + try: + self.redis_client.delete(full_key) + return True + except Exception: + pass + + # Memory cache + if key in self.memory_cache: + entry = self.memory_cache[key] + # Remove from tag indexes + for tag in entry.tags: + if tag in self.tag_index and key in self.tag_index[tag]: + self.tag_index[tag].remove(key) + del self.memory_cache[key] + return True + + return False + + def invalidate_by_tag(self, 
tag: str, namespace: str = "rag") -> int: + """Invalidate all keys with given tag""" + tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" + count = 0 + + # Get keys from tag index + keys_to_delete = [] + + if self.redis_client: + try: + keys_to_delete = list(self.redis_client.smembers(tag_key)) + self.redis_client.delete(tag_key) + except Exception: + pass + + if tag in self.tag_index: + keys_to_delete.extend(self.tag_index[tag]) + del self.tag_index[tag] + + # Delete all keys + for key in keys_to_delete: + if self.delete(key.replace(self._get_key_prefix(namespace), ""), namespace): + count += 1 + + return count + + def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: + """Invalidate keys matching pattern""" + full_pattern = f"{self._get_key_prefix(namespace)}{pattern}" + count = 0 + + if self.redis_client: + try: + keys = self.redis_client.keys(full_pattern) + if keys: + count = self.redis_client.delete(*keys) + except Exception: + pass + + # Memory cache + keys_to_delete = [ + k for k in self.memory_cache.keys() if self._match_pattern(k, pattern) + ] + for key in keys_to_delete: + self.delete(key, namespace) + count += 1 + + return count + + def _match_pattern(self, key: str, pattern: str) -> bool: + """Simple pattern matching (supports * wildcard)""" + import fnmatch + + return fnmatch.fnmatch(key, pattern) + + def _evict_lru(self): + """Evict least recently used entry""" + if not self.memory_cache: + return + + lru_key = min( + self.memory_cache.keys(), + key=lambda k: ( + self.memory_cache[k].last_accessed or self.memory_cache[k].created_at + ), + ) + del self.memory_cache[lru_key] + + def _update_redis_entry(self, key: str, entry: CacheEntry): + """Update Redis entry with new access info""" + if not self.redis_client: + return + + try: + entry_data = json.dumps(entry.to_dict()) + # Get remaining TTL + ttl = self.redis_client.ttl(key) + if ttl > 0: + self.redis_client.setex(key, ttl, entry_data) + except Exception: + pass + + def 
get_stats(self) -> Dict[str, Any]: + """Get cache statistics""" + stats = { + "memory_cache_size": len(self.memory_cache), + "max_size": self.max_size, + "tag_index_size": len(self.tag_index), + "redis_available": self.redis_client is not None, + } + + if self.memory_cache: + total_access = sum(e.access_count for e in self.memory_cache.values()) + stats["total_accesses"] = total_access + stats["avg_access_per_entry"] = total_access / len(self.memory_cache) + + return stats + + def clear(self, namespace: str = "rag"): + """Clear all cache entries in namespace""" + pattern = f"{self._get_key_prefix(namespace)}*" + return self.invalidate_by_pattern(pattern, namespace) + + +class EmbeddingCache: + """Specialized cache for embeddings""" + + def __init__(self, cache_manager: AdvancedCacheManager): + self.cache_manager = cache_manager + self.namespace = "rag:embeddings" + + def get_embedding_key(self, text: str) -> str: + """Generate cache key for embedding""" + text_hash = hashlib.md5(text.encode()).hexdigest() + return f"emb:{text_hash}" + + def get(self, text: str) -> Optional[Any]: + """Get cached embedding""" + key = self.get_embedding_key(text) + return self.cache_manager.get(key, namespace=self.namespace) + + def set(self, text: str, embedding: Any, ttl: int = 86400): + """Cache embedding (24 hour default TTL)""" + key = self.get_embedding_key(text) + return self.cache_manager.set( + key, embedding, namespace=self.namespace, ttl=ttl, tags=["embedding"] + ) + + +class QueryResultCache: + """Specialized cache for query results""" + + def __init__(self, cache_manager: AdvancedCacheManager): + self.cache_manager = cache_manager + self.namespace = "rag:queries" + + def get_query_key(self, query: RAGQuery) -> str: + """Generate cache key for query""" + query_dict = query.to_dict() + query_str = json.dumps(query_dict, sort_keys=True) + query_hash = hashlib.md5(query_str.encode()).hexdigest() + return f"query:{query_hash}" + + def get(self, query: RAGQuery) -> 
Optional[List[RetrievalResult]]: + """Get cached query results""" + key = self.get_query_key(query) + cached = self.cache_manager.get(key, namespace=self.namespace) + + if cached: + # Reconstruct RetrievalResult objects + results = [] + for item in cached: + doc = Document.from_dict(item["document"]) + result = RetrievalResult( + document=doc, + score=item["score"], + rerank_score=item.get("rerank_score"), + ) + results.append(result) + return results + + return None + + def set( + self, + query: RAGQuery, + results: List[RetrievalResult], + ttl: int = 3600, + tags: Optional[List[str]] = None, + ): + """Cache query results""" + key = self.get_query_key(query) + + # Serialize results + serialized = [ + { + "document": r.document.to_dict(), + "score": r.score, + "rerank_score": r.rerank_score, + } + for r in results + ] + + # Add query tags + query_tags = tags or [] + if query.industry: + query_tags.append( + f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}" + ) + if query.doc_types: + for dt in query.doc_types: + query_tags.append(f"doctype:{dt.value if hasattr(dt, 'value') else dt}") + + return self.cache_manager.set( + key, serialized, namespace=self.namespace, ttl=ttl, tags=query_tags + ) + + def invalidate_by_industry(self, industry: str): + """Invalidate all queries for an industry""" + return self.cache_manager.invalidate_by_tag( + f"industry:{industry}", namespace=self.namespace + ) + + def invalidate_by_doc_type(self, doc_type: str): + """Invalidate all queries for a document type""" + return self.cache_manager.invalidate_by_tag( + f"doctype:{doc_type}", namespace=self.namespace + ) diff --git a/src/deepiri_modelkit/rag/monitoring.py b/src/deepiri_modelkit/rag/monitoring.py index 95429b3..f587408 100644 --- a/src/deepiri_modelkit/rag/monitoring.py +++ b/src/deepiri_modelkit/rag/monitoring.py @@ -1,354 +1,375 @@ -""" -Monitoring and Observability for Universal RAG -Metrics, performance tracking, and analytics -""" - 
-from typing import Dict, Any, List, Optional -from dataclasses import dataclass, field, asdict -from datetime import datetime, timedelta -from collections import defaultdict -import time -import json - -from .base import RAGQuery, RetrievalResult - - -@dataclass -class RetrievalMetrics: - """Metrics for a single retrieval operation""" - query_id: str - query_text: str - timestamp: datetime - retrieval_time_ms: float - num_results: int - top_score: Optional[float] = None - cache_hit: bool = False - reranking_used: bool = False - query_expansion_used: bool = False - industry: Optional[str] = None - doc_types: Optional[List[str]] = None - - def to_dict(self) -> Dict: - """Convert to dictionary""" - return { - "query_id": self.query_id, - "query_text": self.query_text, - "timestamp": self.timestamp.isoformat(), - "retrieval_time_ms": self.retrieval_time_ms, - "num_results": self.num_results, - "top_score": self.top_score, - "cache_hit": self.cache_hit, - "reranking_used": self.reranking_used, - "query_expansion_used": self.query_expansion_used, - "industry": self.industry, - "doc_types": self.doc_types, - } - - -@dataclass -class IndexingMetrics: - """Metrics for indexing operations""" - operation_id: str - timestamp: datetime - operation_type: str # "index", "update", "delete" - num_documents: int - processing_time_ms: float - success: bool - error: Optional[str] = None - - def to_dict(self) -> Dict: - """Convert to dictionary""" - return { - "operation_id": self.operation_id, - "timestamp": self.timestamp.isoformat(), - "operation_type": self.operation_type, - "num_documents": self.num_documents, - "processing_time_ms": self.processing_time_ms, - "success": self.success, - "error": self.error, - } - - -@dataclass -class SystemMetrics: - """System-wide metrics""" - total_queries: int = 0 - total_indexed_documents: int = 0 - cache_hit_rate: float = 0.0 - avg_retrieval_time_ms: float = 0.0 - avg_indexing_time_ms: float = 0.0 - error_rate: float = 0.0 - last_updated: 
Optional[datetime] = None - - def to_dict(self) -> Dict: - """Convert to dictionary""" - return { - "total_queries": self.total_queries, - "total_indexed_documents": self.total_indexed_documents, - "cache_hit_rate": self.cache_hit_rate, - "avg_retrieval_time_ms": self.avg_retrieval_time_ms, - "avg_indexing_time_ms": self.avg_indexing_time_ms, - "error_rate": self.error_rate, - "last_updated": self.last_updated.isoformat() if self.last_updated else None, - } - - -class RAGMonitor: - """ - Monitor RAG system performance and collect metrics - """ - - def __init__(self, max_history: int = 10000): - self.max_history = max_history - - # Metrics storage - self.retrieval_metrics: List[RetrievalMetrics] = [] - self.indexing_metrics: List[IndexingMetrics] = [] - - # Aggregated stats - self.system_metrics = SystemMetrics() - - # Time windows for analysis - self.hourly_stats: Dict[str, Dict] = defaultdict(dict) - self.daily_stats: Dict[str, Dict] = defaultdict(dict) - - def record_retrieval( - self, - query: RAGQuery, - results: List[RetrievalResult], - retrieval_time_ms: float, - cache_hit: bool = False, - reranking_used: bool = False, - query_expansion_used: bool = False - ): - """Record retrieval metrics""" - query_id = f"q_{int(time.time() * 1000)}" - - metric = RetrievalMetrics( - query_id=query_id, - query_text=query.query, - timestamp=datetime.now(), - retrieval_time_ms=retrieval_time_ms, - num_results=len(results), - top_score=results[0].score if results else None, - cache_hit=cache_hit, - reranking_used=reranking_used, - query_expansion_used=query_expansion_used, - industry=query.industry.value if query.industry and hasattr(query.industry, 'value') else str(query.industry), - doc_types=[dt.value if hasattr(dt, 'value') else str(dt) for dt in query.doc_types] if query.doc_types else None, - ) - - self.retrieval_metrics.append(metric) - - # Trim history - if len(self.retrieval_metrics) > self.max_history: - self.retrieval_metrics = 
self.retrieval_metrics[-self.max_history:] - - # Update aggregated stats - self._update_system_metrics() - - def record_indexing( - self, - operation_type: str, - num_documents: int, - processing_time_ms: float, - success: bool, - error: Optional[str] = None - ): - """Record indexing metrics""" - operation_id = f"idx_{int(time.time() * 1000)}" - - metric = IndexingMetrics( - operation_id=operation_id, - timestamp=datetime.now(), - operation_type=operation_type, - num_documents=num_documents, - processing_time_ms=processing_time_ms, - success=success, - error=error - ) - - self.indexing_metrics.append(metric) - - # Trim history - if len(self.indexing_metrics) > self.max_history: - self.indexing_metrics = self.indexing_metrics[-self.max_history:] - - # Update aggregated stats - self._update_system_metrics() - - def _update_system_metrics(self): - """Update system-wide aggregated metrics""" - if not self.retrieval_metrics: - return - - # Calculate averages - total_queries = len(self.retrieval_metrics) - cache_hits = sum(1 for m in self.retrieval_metrics if m.cache_hit) - total_retrieval_time = sum(m.retrieval_time_ms for m in self.retrieval_metrics) - - self.system_metrics.total_queries = total_queries - self.system_metrics.cache_hit_rate = cache_hits / total_queries if total_queries > 0 else 0.0 - self.system_metrics.avg_retrieval_time_ms = total_retrieval_time / total_queries if total_queries > 0 else 0.0 - - if self.indexing_metrics: - total_indexed = sum(m.num_documents for m in self.indexing_metrics if m.success) - total_indexing_time = sum(m.processing_time_ms for m in self.indexing_metrics) - total_indexing_ops = len(self.indexing_metrics) - - self.system_metrics.total_indexed_documents = total_indexed - self.system_metrics.avg_indexing_time_ms = total_indexing_time / total_indexing_ops if total_indexing_ops > 0 else 0.0 - - # Error rate - total_ops = total_queries + len(self.indexing_metrics) - errors = sum(1 for m in self.indexing_metrics if not m.success) - 
self.system_metrics.error_rate = errors / total_ops if total_ops > 0 else 0.0 - - self.system_metrics.last_updated = datetime.now() - - def get_retrieval_stats( - self, - time_window_minutes: Optional[int] = None, - industry: Optional[str] = None - ) -> Dict[str, Any]: - """Get retrieval statistics for time window""" - metrics = self.retrieval_metrics - - # Filter by time window - if time_window_minutes: - cutoff = datetime.now() - timedelta(minutes=time_window_minutes) - metrics = [m for m in metrics if m.timestamp >= cutoff] - - # Filter by industry - if industry: - metrics = [m for m in metrics if m.industry == industry] - - if not metrics: - return { - "count": 0, - "avg_time_ms": 0.0, - "cache_hit_rate": 0.0, - } - - return { - "count": len(metrics), - "avg_time_ms": sum(m.retrieval_time_ms for m in metrics) / len(metrics), - "min_time_ms": min(m.retrieval_time_ms for m in metrics), - "max_time_ms": max(m.retrieval_time_ms for m in metrics), - "cache_hit_rate": sum(1 for m in metrics if m.cache_hit) / len(metrics), - "avg_results": sum(m.num_results for m in metrics) / len(metrics), - "avg_top_score": sum(m.top_score for m in metrics if m.top_score) / len([m for m in metrics if m.top_score]), - } - - def get_indexing_stats( - self, - time_window_minutes: Optional[int] = None - ) -> Dict[str, Any]: - """Get indexing statistics""" - metrics = self.indexing_metrics - - # Filter by time window - if time_window_minutes: - cutoff = datetime.now() - timedelta(minutes=time_window_minutes) - metrics = [m for m in metrics if m.timestamp >= cutoff] - - if not metrics: - return { - "count": 0, - "total_documents": 0, - "avg_time_ms": 0.0, - "success_rate": 0.0, - } - - successful = [m for m in metrics if m.success] - - return { - "count": len(metrics), - "total_documents": sum(m.num_documents for m in successful), - "avg_time_ms": sum(m.processing_time_ms for m in metrics) / len(metrics), - "success_rate": len(successful) / len(metrics) if metrics else 0.0, - 
"error_count": len([m for m in metrics if not m.success]), - } - - def get_top_queries( - self, - limit: int = 10, - time_window_minutes: Optional[int] = None - ) -> List[Dict[str, Any]]: - """Get most frequent queries""" - metrics = self.retrieval_metrics - - # Filter by time window - if time_window_minutes: - cutoff = datetime.now() - timedelta(minutes=time_window_minutes) - metrics = [m for m in metrics if m.timestamp >= cutoff] - - # Count query frequencies - query_counts: Dict[str, int] = defaultdict(int) - for m in metrics: - query_counts[m.query_text] += 1 - - # Sort by frequency - top_queries = sorted( - query_counts.items(), - key=lambda x: x[1], - reverse=True - )[:limit] - - return [ - { - "query": query, - "count": count - } - for query, count in top_queries - ] - - def get_performance_report(self) -> Dict[str, Any]: - """Get comprehensive performance report""" - return { - "system_metrics": self.system_metrics.to_dict(), - "retrieval_stats_1h": self.get_retrieval_stats(time_window_minutes=60), - "retrieval_stats_24h": self.get_retrieval_stats(time_window_minutes=1440), - "indexing_stats_1h": self.get_indexing_stats(time_window_minutes=60), - "indexing_stats_24h": self.get_indexing_stats(time_window_minutes=1440), - "top_queries_24h": self.get_top_queries(limit=10, time_window_minutes=1440), - } - - def export_metrics(self, filepath: str): - """Export metrics to JSON file""" - data = { - "retrieval_metrics": [m.to_dict() for m in self.retrieval_metrics[-1000:]], # Last 1000 - "indexing_metrics": [m.to_dict() for m in self.indexing_metrics[-1000:]], # Last 1000 - "system_metrics": self.system_metrics.to_dict(), - "exported_at": datetime.now().isoformat(), - } - - with open(filepath, 'w') as f: - json.dump(data, f, indent=2) - - -class PerformanceTimer: - """Context manager for timing operations""" - - def __init__(self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation"): - self.monitor = monitor - self.operation_name = 
operation_name - self.start_time = None - self.end_time = None - - def __enter__(self): - self.start_time = time.time() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.end_time = time.time() - return False - - def elapsed_ms(self) -> float: - """Get elapsed time in milliseconds""" - if self.start_time and self.end_time: - return (self.end_time - self.start_time) * 1000 - elif self.start_time: - return (time.time() - self.start_time) * 1000 - return 0.0 - +""" +Monitoring and Observability for Universal RAG +Metrics, performance tracking, and analytics +""" + +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field, asdict +from datetime import datetime, timedelta +from collections import defaultdict +import time +import json + +from .base import RAGQuery, RetrievalResult + + +@dataclass +class RetrievalMetrics: + """Metrics for a single retrieval operation""" + + query_id: str + query_text: str + timestamp: datetime + retrieval_time_ms: float + num_results: int + top_score: Optional[float] = None + cache_hit: bool = False + reranking_used: bool = False + query_expansion_used: bool = False + industry: Optional[str] = None + doc_types: Optional[List[str]] = None + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "query_id": self.query_id, + "query_text": self.query_text, + "timestamp": self.timestamp.isoformat(), + "retrieval_time_ms": self.retrieval_time_ms, + "num_results": self.num_results, + "top_score": self.top_score, + "cache_hit": self.cache_hit, + "reranking_used": self.reranking_used, + "query_expansion_used": self.query_expansion_used, + "industry": self.industry, + "doc_types": self.doc_types, + } + + +@dataclass +class IndexingMetrics: + """Metrics for indexing operations""" + + operation_id: str + timestamp: datetime + operation_type: str # "index", "update", "delete" + num_documents: int + processing_time_ms: float + success: bool + error: Optional[str] = None + + def 
to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "operation_id": self.operation_id, + "timestamp": self.timestamp.isoformat(), + "operation_type": self.operation_type, + "num_documents": self.num_documents, + "processing_time_ms": self.processing_time_ms, + "success": self.success, + "error": self.error, + } + + +@dataclass +class SystemMetrics: + """System-wide metrics""" + + total_queries: int = 0 + total_indexed_documents: int = 0 + cache_hit_rate: float = 0.0 + avg_retrieval_time_ms: float = 0.0 + avg_indexing_time_ms: float = 0.0 + error_rate: float = 0.0 + last_updated: Optional[datetime] = None + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "total_queries": self.total_queries, + "total_indexed_documents": self.total_indexed_documents, + "cache_hit_rate": self.cache_hit_rate, + "avg_retrieval_time_ms": self.avg_retrieval_time_ms, + "avg_indexing_time_ms": self.avg_indexing_time_ms, + "error_rate": self.error_rate, + "last_updated": ( + self.last_updated.isoformat() if self.last_updated else None + ), + } + + +class RAGMonitor: + """ + Monitor RAG system performance and collect metrics + """ + + def __init__(self, max_history: int = 10000): + self.max_history = max_history + + # Metrics storage + self.retrieval_metrics: List[RetrievalMetrics] = [] + self.indexing_metrics: List[IndexingMetrics] = [] + + # Aggregated stats + self.system_metrics = SystemMetrics() + + # Time windows for analysis + self.hourly_stats: Dict[str, Dict] = defaultdict(dict) + self.daily_stats: Dict[str, Dict] = defaultdict(dict) + + def record_retrieval( + self, + query: RAGQuery, + results: List[RetrievalResult], + retrieval_time_ms: float, + cache_hit: bool = False, + reranking_used: bool = False, + query_expansion_used: bool = False, + ): + """Record retrieval metrics""" + query_id = f"q_{int(time.time() * 1000)}" + + metric = RetrievalMetrics( + query_id=query_id, + query_text=query.query, + timestamp=datetime.now(), + 
retrieval_time_ms=retrieval_time_ms, + num_results=len(results), + top_score=results[0].score if results else None, + cache_hit=cache_hit, + reranking_used=reranking_used, + query_expansion_used=query_expansion_used, + industry=( + query.industry.value + if query.industry and hasattr(query.industry, "value") + else str(query.industry) + ), + doc_types=( + [ + dt.value if hasattr(dt, "value") else str(dt) + for dt in query.doc_types + ] + if query.doc_types + else None + ), + ) + + self.retrieval_metrics.append(metric) + + # Trim history + if len(self.retrieval_metrics) > self.max_history: + self.retrieval_metrics = self.retrieval_metrics[-self.max_history :] + + # Update aggregated stats + self._update_system_metrics() + + def record_indexing( + self, + operation_type: str, + num_documents: int, + processing_time_ms: float, + success: bool, + error: Optional[str] = None, + ): + """Record indexing metrics""" + operation_id = f"idx_{int(time.time() * 1000)}" + + metric = IndexingMetrics( + operation_id=operation_id, + timestamp=datetime.now(), + operation_type=operation_type, + num_documents=num_documents, + processing_time_ms=processing_time_ms, + success=success, + error=error, + ) + + self.indexing_metrics.append(metric) + + # Trim history + if len(self.indexing_metrics) > self.max_history: + self.indexing_metrics = self.indexing_metrics[-self.max_history :] + + # Update aggregated stats + self._update_system_metrics() + + def _update_system_metrics(self): + """Update system-wide aggregated metrics""" + if not self.retrieval_metrics: + return + + # Calculate averages + total_queries = len(self.retrieval_metrics) + cache_hits = sum(1 for m in self.retrieval_metrics if m.cache_hit) + total_retrieval_time = sum(m.retrieval_time_ms for m in self.retrieval_metrics) + + self.system_metrics.total_queries = total_queries + self.system_metrics.cache_hit_rate = ( + cache_hits / total_queries if total_queries > 0 else 0.0 + ) + self.system_metrics.avg_retrieval_time_ms = ( + 
total_retrieval_time / total_queries if total_queries > 0 else 0.0 + ) + + if self.indexing_metrics: + total_indexed = sum( + m.num_documents for m in self.indexing_metrics if m.success + ) + total_indexing_time = sum( + m.processing_time_ms for m in self.indexing_metrics + ) + total_indexing_ops = len(self.indexing_metrics) + + self.system_metrics.total_indexed_documents = total_indexed + self.system_metrics.avg_indexing_time_ms = ( + total_indexing_time / total_indexing_ops + if total_indexing_ops > 0 + else 0.0 + ) + + # Error rate + total_ops = total_queries + len(self.indexing_metrics) + errors = sum(1 for m in self.indexing_metrics if not m.success) + self.system_metrics.error_rate = errors / total_ops if total_ops > 0 else 0.0 + + self.system_metrics.last_updated = datetime.now() + + def get_retrieval_stats( + self, time_window_minutes: Optional[int] = None, industry: Optional[str] = None + ) -> Dict[str, Any]: + """Get retrieval statistics for time window""" + metrics = self.retrieval_metrics + + # Filter by time window + if time_window_minutes: + cutoff = datetime.now() - timedelta(minutes=time_window_minutes) + metrics = [m for m in metrics if m.timestamp >= cutoff] + + # Filter by industry + if industry: + metrics = [m for m in metrics if m.industry == industry] + + if not metrics: + return { + "count": 0, + "avg_time_ms": 0.0, + "cache_hit_rate": 0.0, + } + + return { + "count": len(metrics), + "avg_time_ms": sum(m.retrieval_time_ms for m in metrics) / len(metrics), + "min_time_ms": min(m.retrieval_time_ms for m in metrics), + "max_time_ms": max(m.retrieval_time_ms for m in metrics), + "cache_hit_rate": sum(1 for m in metrics if m.cache_hit) / len(metrics), + "avg_results": sum(m.num_results for m in metrics) / len(metrics), + "avg_top_score": sum(m.top_score for m in metrics if m.top_score) + / len([m for m in metrics if m.top_score]), + } + + def get_indexing_stats( + self, time_window_minutes: Optional[int] = None + ) -> Dict[str, Any]: + """Get 
indexing statistics""" + metrics = self.indexing_metrics + + # Filter by time window + if time_window_minutes: + cutoff = datetime.now() - timedelta(minutes=time_window_minutes) + metrics = [m for m in metrics if m.timestamp >= cutoff] + + if not metrics: + return { + "count": 0, + "total_documents": 0, + "avg_time_ms": 0.0, + "success_rate": 0.0, + } + + successful = [m for m in metrics if m.success] + + return { + "count": len(metrics), + "total_documents": sum(m.num_documents for m in successful), + "avg_time_ms": sum(m.processing_time_ms for m in metrics) / len(metrics), + "success_rate": len(successful) / len(metrics) if metrics else 0.0, + "error_count": len([m for m in metrics if not m.success]), + } + + def get_top_queries( + self, limit: int = 10, time_window_minutes: Optional[int] = None + ) -> List[Dict[str, Any]]: + """Get most frequent queries""" + metrics = self.retrieval_metrics + + # Filter by time window + if time_window_minutes: + cutoff = datetime.now() - timedelta(minutes=time_window_minutes) + metrics = [m for m in metrics if m.timestamp >= cutoff] + + # Count query frequencies + query_counts: Dict[str, int] = defaultdict(int) + for m in metrics: + query_counts[m.query_text] += 1 + + # Sort by frequency + top_queries = sorted(query_counts.items(), key=lambda x: x[1], reverse=True)[ + :limit + ] + + return [{"query": query, "count": count} for query, count in top_queries] + + def get_performance_report(self) -> Dict[str, Any]: + """Get comprehensive performance report""" + return { + "system_metrics": self.system_metrics.to_dict(), + "retrieval_stats_1h": self.get_retrieval_stats(time_window_minutes=60), + "retrieval_stats_24h": self.get_retrieval_stats(time_window_minutes=1440), + "indexing_stats_1h": self.get_indexing_stats(time_window_minutes=60), + "indexing_stats_24h": self.get_indexing_stats(time_window_minutes=1440), + "top_queries_24h": self.get_top_queries(limit=10, time_window_minutes=1440), + } + + def export_metrics(self, filepath: 
str): + """Export metrics to JSON file""" + data = { + "retrieval_metrics": [ + m.to_dict() for m in self.retrieval_metrics[-1000:] + ], # Last 1000 + "indexing_metrics": [ + m.to_dict() for m in self.indexing_metrics[-1000:] + ], # Last 1000 + "system_metrics": self.system_metrics.to_dict(), + "exported_at": datetime.now().isoformat(), + } + + with open(filepath, "w") as f: + json.dump(data, f, indent=2) + + +class PerformanceTimer: + """Context manager for timing operations""" + + def __init__( + self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation" + ): + self.monitor = monitor + self.operation_name = operation_name + self.start_time = None + self.end_time = None + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + return False + + def elapsed_ms(self) -> float: + """Get elapsed time in milliseconds""" + if self.start_time and self.end_time: + return (self.end_time - self.start_time) * 1000 + elif self.start_time: + return (time.time() - self.start_time) * 1000 + return 0.0 diff --git a/src/deepiri_modelkit/rag/processors.py b/src/deepiri_modelkit/rag/processors.py index 3fdb477..2d617ee 100644 --- a/src/deepiri_modelkit/rag/processors.py +++ b/src/deepiri_modelkit/rag/processors.py @@ -1,423 +1,443 @@ -""" -Document Processors for Universal RAG -Handles preprocessing, chunking, and metadata extraction for different document types -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional -import re -from datetime import datetime - -from .base import Document, DocumentType, IndustryNiche - - -class DocumentProcessor(ABC): - """Base class for document processing""" - - def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - @abstractmethod - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - 
""" - Process raw content into structured documents - - Args: - raw_content: Raw text content - metadata: Document metadata - - Returns: - List of processed document chunks - """ - pass - - def chunk_text(self, text: str) -> List[str]: - """ - Chunk text into smaller pieces with overlap - - Args: - text: Text to chunk - - Returns: - List of text chunks - """ - chunks = [] - start = 0 - text_length = len(text) - - while start < text_length: - end = start + self.chunk_size - - # Find the last sentence boundary within chunk_size - if end < text_length: - # Look for sentence endings - for sep in ['. ', '.\n', '! ', '!\n', '? ', '?\n']: - last_sep = text.rfind(sep, start, end) - if last_sep != -1: - end = last_sep + 1 - break - - chunk = text[start:end].strip() - if chunk: - chunks.append(chunk) - - # Move start forward with overlap - start = end - self.chunk_overlap if end - self.chunk_overlap > start else end - - return chunks - - def extract_metadata(self, content: str) -> Dict[str, Any]: - """Extract metadata from content (can be overridden)""" - return {} - - -class RegulationProcessor(DocumentProcessor): - """ - Processor for regulations, policies, and compliance documents - Common across insurance, healthcare, manufacturing, etc. 
- """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process regulation documents""" - # Extract sections and subsections - sections = self._extract_sections(raw_content) - - documents = [] - base_id = metadata.get('id', 'reg_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - - for idx, section in enumerate(sections): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=section['content'], - doc_type=DocumentType.REGULATION, - industry=industry, - title=metadata.get('title', 'Regulation Document'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - metadata={ - **metadata, - 'section': section.get('section'), - 'subsection': section.get('subsection'), - }, - chunk_index=idx, - total_chunks=len(sections), - ) - documents.append(doc) - - return documents - - def _extract_sections(self, content: str) -> List[Dict[str, Any]]: - """Extract sections from regulation text""" - # Match section headers like "Section 1.2.3", "Article 5", etc. 
- section_pattern = r'(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)' - - sections = [] - current_section = {'section': None, 'content': ''} - - lines = content.split('\n') - for line in lines: - match = re.search(section_pattern, line, re.IGNORECASE) - if match: - # Save previous section - if current_section['content']: - sections.append(current_section) - # Start new section - current_section = { - 'section': match.group(0), - 'content': line + '\n' - } - else: - current_section['content'] += line + '\n' - - # Add last section - if current_section['content']: - sections.append(current_section) - - # If no sections found, treat entire content as one chunk - if not sections: - chunks = self.chunk_text(content) - sections = [{'section': f'Chunk {i+1}', 'content': chunk} - for i, chunk in enumerate(chunks)] - - return sections - - def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: - """Parse date string to datetime""" - if not date_str: - return None - try: - return datetime.fromisoformat(date_str) - except (ValueError, AttributeError): - return None - - -class HistoricalDataProcessor(DocumentProcessor): - """ - Processor for historical operational data - - Work orders, maintenance logs, claim records, service history - """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process historical data records""" - doc_type_str = metadata.get('doc_type', 'work_order') - doc_type = DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER - - # Historical data is typically already structured - # We may need minimal chunking - chunks = self.chunk_text(raw_content) if len(raw_content) > self.chunk_size else [raw_content] - - documents = [] - base_id = metadata.get('id', 'hist_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - - for idx, chunk in enumerate(chunks): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=chunk, - doc_type=doc_type, - 
industry=industry, - title=metadata.get('title', f'{doc_type.value.replace("_", " ").title()}'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - updated_at=self._parse_date(metadata.get('updated_at')), - metadata={ - **metadata, - 'record_type': doc_type.value, - 'record_id': metadata.get('record_id'), - 'status': metadata.get('status'), - 'priority': metadata.get('priority'), - 'assigned_to': metadata.get('assigned_to'), - }, - chunk_index=idx, - total_chunks=len(chunks), - ) - documents.append(doc) - - return documents - - def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: - """Parse date string to datetime""" - if not date_str: - return None - if isinstance(date_str, datetime): - return date_str - try: - return datetime.fromisoformat(date_str) - except (ValueError, AttributeError): - try: - return datetime.strptime(date_str, '%Y-%m-%d') - except (ValueError, AttributeError): - return None - - -class KnowledgeBaseProcessor(DocumentProcessor): - """ - Processor for knowledge base articles, FAQs, and guides - - Equipment repair guides, compliance advice, troubleshooting steps - """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process knowledge base articles""" - doc_type = DocumentType.KNOWLEDGE_BASE if metadata.get('doc_type') == 'knowledge_base' else DocumentType.FAQ - - # Extract Q&A pairs if FAQ format - if doc_type == DocumentType.FAQ: - qa_pairs = self._extract_qa_pairs(raw_content) - if qa_pairs: - return self._process_qa_pairs(qa_pairs, metadata) - - # Otherwise, process as regular article - chunks = self.chunk_text(raw_content) - - documents = [] - base_id = metadata.get('id', 'kb_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - - for idx, chunk in enumerate(chunks): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=chunk, - doc_type=doc_type, - industry=industry, - 
title=metadata.get('title', 'Knowledge Base Article'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - updated_at=self._parse_date(metadata.get('updated_at')), - author=metadata.get('author'), - metadata={ - **metadata, - 'category': metadata.get('category'), - 'tags': metadata.get('tags', []), - 'difficulty_level': metadata.get('difficulty_level'), - }, - chunk_index=idx, - total_chunks=len(chunks), - ) - documents.append(doc) - - return documents - - def _extract_qa_pairs(self, content: str) -> List[Dict[str, str]]: - """Extract Q&A pairs from FAQ content""" - qa_pairs = [] - - # Try different FAQ formats - # Format 1: Q: ... A: ... - pattern1 = r'Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)' - matches = re.findall(pattern1, content, re.DOTALL | re.IGNORECASE) - if matches: - qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) - - # Format 2: Question/Answer headers - pattern2 = r'Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)' - matches = re.findall(pattern2, content, re.DOTALL | re.IGNORECASE) - if matches: - qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) - - return qa_pairs - - def _process_qa_pairs(self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any]) -> List[Document]: - """Process Q&A pairs into documents""" - documents = [] - base_id = metadata.get('id', 'faq_' + str(hash(str(qa_pairs[0])))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - - for idx, qa in enumerate(qa_pairs): - content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}" - doc = Document( - id=f"{base_id}_qa_{idx}", - content=content, - doc_type=DocumentType.FAQ, - industry=industry, - title=qa['question'][:100], # Use question as title - source=metadata.get('source'), - metadata={ - **metadata, - 'question': qa['question'], - 'answer': qa['answer'], - }, - chunk_index=idx, - total_chunks=len(qa_pairs), - ) - documents.append(doc) - - return documents - 
- def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: - """Parse date string to datetime""" - if not date_str: - return None - if isinstance(date_str, datetime): - return date_str - try: - return datetime.fromisoformat(date_str) - except (ValueError, AttributeError): - return None - - -class ManualProcessor(DocumentProcessor): - """ - Processor for equipment manuals, operation guides, technical specifications - """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process manual documents""" - # Extract chapters and sections - sections = self._extract_sections(raw_content) - - documents = [] - base_id = metadata.get('id', 'manual_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - - for idx, section in enumerate(sections): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=section['content'], - doc_type=DocumentType.MANUAL, - industry=industry, - title=metadata.get('title', 'Equipment Manual'), - source=metadata.get('source'), - version=metadata.get('version'), - metadata={ - **metadata, - 'chapter': section.get('chapter'), - 'section': section.get('section'), - 'equipment_model': metadata.get('equipment_model'), - 'manufacturer': metadata.get('manufacturer'), - }, - chunk_index=idx, - total_chunks=len(sections), - ) - documents.append(doc) - - return documents - - def _extract_sections(self, content: str) -> List[Dict[str, Any]]: - """Extract sections from manual text""" - # Match chapter/section headers - section_pattern = r'(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)' - - sections = [] - current_section = {'chapter': None, 'section': None, 'content': ''} - - lines = content.split('\n') - for line in lines: - match = re.search(section_pattern, line, re.IGNORECASE) - if match: - # Save previous section - if current_section['content']: - sections.append(current_section) - # Start new section - section_type = match.group(1).lower() - section_num 
= match.group(2) - section_title = match.group(3).strip() if match.group(3) else '' - current_section = { - section_type: f"{section_type.title()} {section_num}", - 'section_title': section_title, - 'content': line + '\n' - } - else: - current_section['content'] += line + '\n' - - # Add last section - if current_section['content']: - sections.append(current_section) - - # If no sections found, chunk the content - if not sections: - chunks = self.chunk_text(content) - sections = [{'section': f'Chunk {i+1}', 'content': chunk} - for i, chunk in enumerate(chunks)] - - return sections - - -def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: - """ - Factory function to get appropriate processor for document type - - Args: - doc_type: Type of document - **kwargs: Additional configuration for processor - - Returns: - Configured document processor - """ - processor_map = { - DocumentType.REGULATION: RegulationProcessor, - DocumentType.POLICY: RegulationProcessor, # Similar processing - DocumentType.WORK_ORDER: HistoricalDataProcessor, - DocumentType.CLAIM_RECORD: HistoricalDataProcessor, - DocumentType.MAINTENANCE_LOG: HistoricalDataProcessor, - DocumentType.FAQ: KnowledgeBaseProcessor, - DocumentType.KNOWLEDGE_BASE: KnowledgeBaseProcessor, - DocumentType.MANUAL: ManualProcessor, - DocumentType.TECHNICAL_SPEC: ManualProcessor, # Similar processing - DocumentType.PROCEDURE: ManualProcessor, # Similar processing - } - - processor_class = processor_map.get(doc_type, DocumentProcessor) - return processor_class(**kwargs) - +""" +Document Processors for Universal RAG +Handles preprocessing, chunking, and metadata extraction for different document types +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +import re +from datetime import datetime + +from .base import Document, DocumentType, IndustryNiche + + +class DocumentProcessor(ABC): + """Base class for document processing""" + + def __init__(self, chunk_size: int = 
1000, chunk_overlap: int = 200): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + @abstractmethod + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """ + Process raw content into structured documents + + Args: + raw_content: Raw text content + metadata: Document metadata + + Returns: + List of processed document chunks + """ + pass + + def chunk_text(self, text: str) -> List[str]: + """ + Chunk text into smaller pieces with overlap + + Args: + text: Text to chunk + + Returns: + List of text chunks + """ + chunks = [] + start = 0 + text_length = len(text) + + while start < text_length: + end = start + self.chunk_size + + # Find the last sentence boundary within chunk_size + if end < text_length: + # Look for sentence endings + for sep in [". ", ".\n", "! ", "!\n", "? ", "?\n"]: + last_sep = text.rfind(sep, start, end) + if last_sep != -1: + end = last_sep + 1 + break + + chunk = text[start:end].strip() + if chunk: + chunks.append(chunk) + + # Move start forward with overlap + start = ( + end - self.chunk_overlap if end - self.chunk_overlap > start else end + ) + + return chunks + + def extract_metadata(self, content: str) -> Dict[str, Any]: + """Extract metadata from content (can be overridden)""" + return {} + + +class RegulationProcessor(DocumentProcessor): + """ + Processor for regulations, policies, and compliance documents + Common across insurance, healthcare, manufacturing, etc. 
+ """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process regulation documents""" + # Extract sections and subsections + sections = self._extract_sections(raw_content) + + documents = [] + base_id = metadata.get("id", "reg_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + + for idx, section in enumerate(sections): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=section["content"], + doc_type=DocumentType.REGULATION, + industry=industry, + title=metadata.get("title", "Regulation Document"), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + metadata={ + **metadata, + "section": section.get("section"), + "subsection": section.get("subsection"), + }, + chunk_index=idx, + total_chunks=len(sections), + ) + documents.append(doc) + + return documents + + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: + """Extract sections from regulation text""" + # Match section headers like "Section 1.2.3", "Article 5", etc. 
+ section_pattern = r"(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)" + + sections = [] + current_section = {"section": None, "content": ""} + + lines = content.split("\n") + for line in lines: + match = re.search(section_pattern, line, re.IGNORECASE) + if match: + # Save previous section + if current_section["content"]: + sections.append(current_section) + # Start new section + current_section = {"section": match.group(0), "content": line + "\n"} + else: + current_section["content"] += line + "\n" + + # Add last section + if current_section["content"]: + sections.append(current_section) + + # If no sections found, treat entire content as one chunk + if not sections: + chunks = self.chunk_text(content) + sections = [ + {"section": f"Chunk {i+1}", "content": chunk} + for i, chunk in enumerate(chunks) + ] + + return sections + + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: + """Parse date string to datetime""" + if not date_str: + return None + try: + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError): + return None + + +class HistoricalDataProcessor(DocumentProcessor): + """ + Processor for historical operational data + - Work orders, maintenance logs, claim records, service history + """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process historical data records""" + doc_type_str = metadata.get("doc_type", "work_order") + doc_type = ( + DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER + ) + + # Historical data is typically already structured + # We may need minimal chunking + chunks = ( + self.chunk_text(raw_content) + if len(raw_content) > self.chunk_size + else [raw_content] + ) + + documents = [] + base_id = metadata.get("id", "hist_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + + for idx, chunk in enumerate(chunks): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=chunk, + 
doc_type=doc_type, + industry=industry, + title=metadata.get( + "title", f'{doc_type.value.replace("_", " ").title()}' + ), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + updated_at=self._parse_date(metadata.get("updated_at")), + metadata={ + **metadata, + "record_type": doc_type.value, + "record_id": metadata.get("record_id"), + "status": metadata.get("status"), + "priority": metadata.get("priority"), + "assigned_to": metadata.get("assigned_to"), + }, + chunk_index=idx, + total_chunks=len(chunks), + ) + documents.append(doc) + + return documents + + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: + """Parse date string to datetime""" + if not date_str: + return None + if isinstance(date_str, datetime): + return date_str + try: + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError): + try: + return datetime.strptime(date_str, "%Y-%m-%d") + except (ValueError, AttributeError): + return None + + +class KnowledgeBaseProcessor(DocumentProcessor): + """ + Processor for knowledge base articles, FAQs, and guides + - Equipment repair guides, compliance advice, troubleshooting steps + """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process knowledge base articles""" + doc_type = ( + DocumentType.KNOWLEDGE_BASE + if metadata.get("doc_type") == "knowledge_base" + else DocumentType.FAQ + ) + + # Extract Q&A pairs if FAQ format + if doc_type == DocumentType.FAQ: + qa_pairs = self._extract_qa_pairs(raw_content) + if qa_pairs: + return self._process_qa_pairs(qa_pairs, metadata) + + # Otherwise, process as regular article + chunks = self.chunk_text(raw_content) + + documents = [] + base_id = metadata.get("id", "kb_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + + for idx, chunk in enumerate(chunks): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=chunk, + doc_type=doc_type, + 
industry=industry, + title=metadata.get("title", "Knowledge Base Article"), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + updated_at=self._parse_date(metadata.get("updated_at")), + author=metadata.get("author"), + metadata={ + **metadata, + "category": metadata.get("category"), + "tags": metadata.get("tags", []), + "difficulty_level": metadata.get("difficulty_level"), + }, + chunk_index=idx, + total_chunks=len(chunks), + ) + documents.append(doc) + + return documents + + def _extract_qa_pairs(self, content: str) -> List[Dict[str, str]]: + """Extract Q&A pairs from FAQ content""" + qa_pairs = [] + + # Try different FAQ formats + # Format 1: Q: ... A: ... + pattern1 = r"Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)" + matches = re.findall(pattern1, content, re.DOTALL | re.IGNORECASE) + if matches: + qa_pairs.extend( + [{"question": q.strip(), "answer": a.strip()} for q, a in matches] + ) + + # Format 2: Question/Answer headers + pattern2 = r"Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)" + matches = re.findall(pattern2, content, re.DOTALL | re.IGNORECASE) + if matches: + qa_pairs.extend( + [{"question": q.strip(), "answer": a.strip()} for q, a in matches] + ) + + return qa_pairs + + def _process_qa_pairs( + self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any] + ) -> List[Document]: + """Process Q&A pairs into documents""" + documents = [] + base_id = metadata.get("id", "faq_" + str(hash(str(qa_pairs[0])))) + industry = IndustryNiche(metadata.get("industry", "generic")) + + for idx, qa in enumerate(qa_pairs): + content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}" + doc = Document( + id=f"{base_id}_qa_{idx}", + content=content, + doc_type=DocumentType.FAQ, + industry=industry, + title=qa["question"][:100], # Use question as title + source=metadata.get("source"), + metadata={ + **metadata, + "question": qa["question"], + "answer": qa["answer"], + }, + chunk_index=idx, + total_chunks=len(qa_pairs), + ) + 
documents.append(doc) + + return documents + + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: + """Parse date string to datetime""" + if not date_str: + return None + if isinstance(date_str, datetime): + return date_str + try: + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError): + return None + + +class ManualProcessor(DocumentProcessor): + """ + Processor for equipment manuals, operation guides, technical specifications + """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process manual documents""" + # Extract chapters and sections + sections = self._extract_sections(raw_content) + + documents = [] + base_id = metadata.get("id", "manual_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + + for idx, section in enumerate(sections): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=section["content"], + doc_type=DocumentType.MANUAL, + industry=industry, + title=metadata.get("title", "Equipment Manual"), + source=metadata.get("source"), + version=metadata.get("version"), + metadata={ + **metadata, + "chapter": section.get("chapter"), + "section": section.get("section"), + "equipment_model": metadata.get("equipment_model"), + "manufacturer": metadata.get("manufacturer"), + }, + chunk_index=idx, + total_chunks=len(sections), + ) + documents.append(doc) + + return documents + + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: + """Extract sections from manual text""" + # Match chapter/section headers + section_pattern = r"(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)" + + sections = [] + current_section = {"chapter": None, "section": None, "content": ""} + + lines = content.split("\n") + for line in lines: + match = re.search(section_pattern, line, re.IGNORECASE) + if match: + # Save previous section + if current_section["content"]: + sections.append(current_section) + # Start new section + 
section_type = match.group(1).lower() + section_num = match.group(2) + section_title = match.group(3).strip() if match.group(3) else "" + current_section = { + section_type: f"{section_type.title()} {section_num}", + "section_title": section_title, + "content": line + "\n", + } + else: + current_section["content"] += line + "\n" + + # Add last section + if current_section["content"]: + sections.append(current_section) + + # If no sections found, chunk the content + if not sections: + chunks = self.chunk_text(content) + sections = [ + {"section": f"Chunk {i+1}", "content": chunk} + for i, chunk in enumerate(chunks) + ] + + return sections + + +def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: + """ + Factory function to get appropriate processor for document type + + Args: + doc_type: Type of document + **kwargs: Additional configuration for processor + + Returns: + Configured document processor + """ + processor_map = { + DocumentType.REGULATION: RegulationProcessor, + DocumentType.POLICY: RegulationProcessor, # Similar processing + DocumentType.WORK_ORDER: HistoricalDataProcessor, + DocumentType.CLAIM_RECORD: HistoricalDataProcessor, + DocumentType.MAINTENANCE_LOG: HistoricalDataProcessor, + DocumentType.FAQ: KnowledgeBaseProcessor, + DocumentType.KNOWLEDGE_BASE: KnowledgeBaseProcessor, + DocumentType.MANUAL: ManualProcessor, + DocumentType.TECHNICAL_SPEC: ManualProcessor, # Similar processing + DocumentType.PROCEDURE: ManualProcessor, # Similar processing + } + + processor_class = processor_map.get(doc_type, DocumentProcessor) + return processor_class(**kwargs) diff --git a/src/deepiri_modelkit/rag/retrievers.py b/src/deepiri_modelkit/rag/retrievers.py index dcda0d0..e8804f3 100644 --- a/src/deepiri_modelkit/rag/retrievers.py +++ b/src/deepiri_modelkit/rag/retrievers.py @@ -1,288 +1,287 @@ -""" -Retrieval Components for Universal RAG -Implements various retrieval strategies for different use cases -""" - -from abc import ABC, 
abstractmethod -from typing import List, Dict, Any, Optional -from dataclasses import dataclass - -from .base import Document, RetrievalResult, RAGQuery - - -class BaseRetriever(ABC): - """Base class for retrievers""" - - @abstractmethod - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """Retrieve relevant documents for query""" - pass - - -class MultiModalRetriever(BaseRetriever): - """ - Multi-modal retriever supporting text, images, tables, and code - Useful for technical manuals, equipment guides, etc. - """ - - def __init__( - self, - text_embeddings, - image_embeddings=None, - table_embeddings=None, - code_embeddings=None - ): - self.text_embeddings = text_embeddings - self.image_embeddings = image_embeddings - self.table_embeddings = table_embeddings - self.code_embeddings = code_embeddings - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve documents across multiple modalities - - Currently focuses on text, but can be extended for other modalities - """ - # For now, delegate to text retrieval - # Future: Add image, table, code retrieval - results = self._retrieve_text(query) - return results - - def _retrieve_text(self, query: RAGQuery) -> List[RetrievalResult]: - """Retrieve text documents""" - # This will be implemented by the concrete RAG engine - # This is a placeholder for the interface - return [] - - -class HybridRetriever(BaseRetriever): - """ - Hybrid retriever combining: - - Semantic search (vector similarity) - - Keyword search (BM25) - - Metadata filtering - - Provides better recall than pure semantic search - """ - - def __init__( - self, - vector_retriever, - keyword_retriever=None, - vector_weight: float = 0.7, - keyword_weight: float = 0.3 - ): - self.vector_retriever = vector_retriever - self.keyword_retriever = keyword_retriever - self.vector_weight = vector_weight - self.keyword_weight = keyword_weight - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve using 
hybrid approach - - Combines vector and keyword search results - """ - results = [] - - # Vector search - vector_results = self._retrieve_vector(query) - results.extend(vector_results) - - # Keyword search (if available) - if self.keyword_retriever: - keyword_results = self._retrieve_keyword(query) - results.extend(keyword_results) - - # Merge and re-score - merged_results = self._merge_results(vector_results, keyword_results if self.keyword_retriever else []) - - return merged_results - - def _retrieve_vector(self, query: RAGQuery) -> List[RetrievalResult]: - """Vector similarity search""" - # Placeholder - implemented by concrete engine - return [] - - def _retrieve_keyword(self, query: RAGQuery) -> List[RetrievalResult]: - """Keyword search using BM25 or similar""" - # Placeholder - implemented by concrete engine - return [] - - def _merge_results( - self, - vector_results: List[RetrievalResult], - keyword_results: List[RetrievalResult] - ) -> List[RetrievalResult]: - """ - Merge results from different retrievers using weighted scoring - - Uses Reciprocal Rank Fusion (RRF) for combining rankings - """ - # Create a dictionary to store combined scores - doc_scores: Dict[str, Dict[str, Any]] = {} - - # Process vector results - for rank, result in enumerate(vector_results): - doc_id = result.document.id - # RRF score: 1 / (k + rank) where k=60 is common - rrf_score = 1.0 / (60 + rank + 1) - doc_scores[doc_id] = { - 'document': result.document, - 'vector_score': result.score, - 'vector_rrf': rrf_score, - 'keyword_rrf': 0.0, - 'keyword_score': 0.0, - } - - # Process keyword results - for rank, result in enumerate(keyword_results): - doc_id = result.document.id - rrf_score = 1.0 / (60 + rank + 1) - - if doc_id in doc_scores: - doc_scores[doc_id]['keyword_rrf'] = rrf_score - doc_scores[doc_id]['keyword_score'] = result.score - else: - doc_scores[doc_id] = { - 'document': result.document, - 'vector_score': 0.0, - 'vector_rrf': 0.0, - 'keyword_rrf': rrf_score, - 
'keyword_score': result.score, - } - - # Calculate combined scores - merged = [] - for doc_id, scores in doc_scores.items(): - combined_rrf = ( - self.vector_weight * scores['vector_rrf'] + - self.keyword_weight * scores['keyword_rrf'] - ) - - result = RetrievalResult( - document=scores['document'], - score=combined_rrf, - rerank_score=None, - ) - merged.append(result) - - # Sort by combined score - merged.sort(key=lambda x: x.score, reverse=True) - - return merged - - -class ContextualRetriever(BaseRetriever): - """ - Contextual retriever that considers: - - User context (role, history, preferences) - - Temporal context (recent vs historical) - - Industry context (specific to niche) - """ - - def __init__( - self, - base_retriever: BaseRetriever, - use_user_context: bool = True, - use_temporal_context: bool = True, - use_industry_context: bool = True - ): - self.base_retriever = base_retriever - self.use_user_context = use_user_context - self.use_temporal_context = use_temporal_context - self.use_industry_context = use_industry_context - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve with contextual awareness - """ - # Get base results - results = self.base_retriever.retrieve(query) - - # Apply contextual reranking - if self.use_temporal_context: - results = self._apply_temporal_boost(results, query) - - if self.use_industry_context: - results = self._apply_industry_boost(results, query) - - return results - - def _apply_temporal_boost( - self, - results: List[RetrievalResult], - query: RAGQuery - ) -> List[RetrievalResult]: - """ - Boost recent documents (useful for regulations, updates) - """ - import time - from datetime import datetime - - current_time = datetime.now().timestamp() - - for result in results: - if result.document.updated_at or result.document.created_at: - doc_time = (result.document.updated_at or result.document.created_at).timestamp() - # Calculate age in days - age_days = (current_time - doc_time) / 86400 - - 
# Apply decay: more recent = higher boost - # Documents within 30 days get full boost - # Older documents gradually lose boost - if age_days < 30: - temporal_boost = 1.0 - elif age_days < 180: # 6 months - temporal_boost = 0.9 - elif age_days < 365: # 1 year - temporal_boost = 0.8 - else: - temporal_boost = 0.7 - - result.score *= temporal_boost - - # Re-sort by adjusted scores - results.sort(key=lambda x: x.score, reverse=True) - return results - - def _apply_industry_boost( - self, - results: List[RetrievalResult], - query: RAGQuery - ) -> List[RetrievalResult]: - """ - Boost documents matching the query's industry - """ - if not query.industry: - return results - - for result in results: - if result.document.industry == query.industry: - result.score *= 1.1 # 10% boost for industry match - - # Re-sort - results.sort(key=lambda x: x.score, reverse=True) - return results - - -def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever: - """ - Factory function to get retriever - - Args: - retriever_type: Type of retriever ('hybrid', 'multimodal', 'contextual') - **kwargs: Configuration for retriever - - Returns: - Configured retriever - """ - retriever_map = { - 'hybrid': HybridRetriever, - 'multimodal': MultiModalRetriever, - 'contextual': ContextualRetriever, - } - - retriever_class = retriever_map.get(retriever_type, HybridRetriever) - return retriever_class(**kwargs) - +""" +Retrieval Components for Universal RAG +Implements various retrieval strategies for different use cases +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +from .base import Document, RetrievalResult, RAGQuery + + +class BaseRetriever(ABC): + """Base class for retrievers""" + + @abstractmethod + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """Retrieve relevant documents for query""" + pass + + +class MultiModalRetriever(BaseRetriever): + """ + Multi-modal retriever supporting text, images, 
tables, and code + Useful for technical manuals, equipment guides, etc. + """ + + def __init__( + self, + text_embeddings, + image_embeddings=None, + table_embeddings=None, + code_embeddings=None, + ): + self.text_embeddings = text_embeddings + self.image_embeddings = image_embeddings + self.table_embeddings = table_embeddings + self.code_embeddings = code_embeddings + + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """ + Retrieve documents across multiple modalities + + Currently focuses on text, but can be extended for other modalities + """ + # For now, delegate to text retrieval + # Future: Add image, table, code retrieval + results = self._retrieve_text(query) + return results + + def _retrieve_text(self, query: RAGQuery) -> List[RetrievalResult]: + """Retrieve text documents""" + # This will be implemented by the concrete RAG engine + # This is a placeholder for the interface + return [] + + +class HybridRetriever(BaseRetriever): + """ + Hybrid retriever combining: + - Semantic search (vector similarity) + - Keyword search (BM25) + - Metadata filtering + + Provides better recall than pure semantic search + """ + + def __init__( + self, + vector_retriever, + keyword_retriever=None, + vector_weight: float = 0.7, + keyword_weight: float = 0.3, + ): + self.vector_retriever = vector_retriever + self.keyword_retriever = keyword_retriever + self.vector_weight = vector_weight + self.keyword_weight = keyword_weight + + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """ + Retrieve using hybrid approach + + Combines vector and keyword search results + """ + results = [] + + # Vector search + vector_results = self._retrieve_vector(query) + results.extend(vector_results) + + # Keyword search (if available) + if self.keyword_retriever: + keyword_results = self._retrieve_keyword(query) + results.extend(keyword_results) + + # Merge and re-score + merged_results = self._merge_results( + vector_results, keyword_results if 
self.keyword_retriever else [] + ) + + return merged_results + + def _retrieve_vector(self, query: RAGQuery) -> List[RetrievalResult]: + """Vector similarity search""" + # Placeholder - implemented by concrete engine + return [] + + def _retrieve_keyword(self, query: RAGQuery) -> List[RetrievalResult]: + """Keyword search using BM25 or similar""" + # Placeholder - implemented by concrete engine + return [] + + def _merge_results( + self, + vector_results: List[RetrievalResult], + keyword_results: List[RetrievalResult], + ) -> List[RetrievalResult]: + """ + Merge results from different retrievers using weighted scoring + + Uses Reciprocal Rank Fusion (RRF) for combining rankings + """ + # Create a dictionary to store combined scores + doc_scores: Dict[str, Dict[str, Any]] = {} + + # Process vector results + for rank, result in enumerate(vector_results): + doc_id = result.document.id + # RRF score: 1 / (k + rank) where k=60 is common + rrf_score = 1.0 / (60 + rank + 1) + doc_scores[doc_id] = { + "document": result.document, + "vector_score": result.score, + "vector_rrf": rrf_score, + "keyword_rrf": 0.0, + "keyword_score": 0.0, + } + + # Process keyword results + for rank, result in enumerate(keyword_results): + doc_id = result.document.id + rrf_score = 1.0 / (60 + rank + 1) + + if doc_id in doc_scores: + doc_scores[doc_id]["keyword_rrf"] = rrf_score + doc_scores[doc_id]["keyword_score"] = result.score + else: + doc_scores[doc_id] = { + "document": result.document, + "vector_score": 0.0, + "vector_rrf": 0.0, + "keyword_rrf": rrf_score, + "keyword_score": result.score, + } + + # Calculate combined scores + merged = [] + for doc_id, scores in doc_scores.items(): + combined_rrf = ( + self.vector_weight * scores["vector_rrf"] + + self.keyword_weight * scores["keyword_rrf"] + ) + + result = RetrievalResult( + document=scores["document"], + score=combined_rrf, + rerank_score=None, + ) + merged.append(result) + + # Sort by combined score + merged.sort(key=lambda x: x.score, 
reverse=True) + + return merged + + +class ContextualRetriever(BaseRetriever): + """ + Contextual retriever that considers: + - User context (role, history, preferences) + - Temporal context (recent vs historical) + - Industry context (specific to niche) + """ + + def __init__( + self, + base_retriever: BaseRetriever, + use_user_context: bool = True, + use_temporal_context: bool = True, + use_industry_context: bool = True, + ): + self.base_retriever = base_retriever + self.use_user_context = use_user_context + self.use_temporal_context = use_temporal_context + self.use_industry_context = use_industry_context + + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """ + Retrieve with contextual awareness + """ + # Get base results + results = self.base_retriever.retrieve(query) + + # Apply contextual reranking + if self.use_temporal_context: + results = self._apply_temporal_boost(results, query) + + if self.use_industry_context: + results = self._apply_industry_boost(results, query) + + return results + + def _apply_temporal_boost( + self, results: List[RetrievalResult], query: RAGQuery + ) -> List[RetrievalResult]: + """ + Boost recent documents (useful for regulations, updates) + """ + import time + from datetime import datetime + + current_time = datetime.now().timestamp() + + for result in results: + if result.document.updated_at or result.document.created_at: + doc_time = ( + result.document.updated_at or result.document.created_at + ).timestamp() + # Calculate age in days + age_days = (current_time - doc_time) / 86400 + + # Apply decay: more recent = higher boost + # Documents within 30 days get full boost + # Older documents gradually lose boost + if age_days < 30: + temporal_boost = 1.0 + elif age_days < 180: # 6 months + temporal_boost = 0.9 + elif age_days < 365: # 1 year + temporal_boost = 0.8 + else: + temporal_boost = 0.7 + + result.score *= temporal_boost + + # Re-sort by adjusted scores + results.sort(key=lambda x: x.score, reverse=True) + 
return results + + def _apply_industry_boost( + self, results: List[RetrievalResult], query: RAGQuery + ) -> List[RetrievalResult]: + """ + Boost documents matching the query's industry + """ + if not query.industry: + return results + + for result in results: + if result.document.industry == query.industry: + result.score *= 1.1 # 10% boost for industry match + + # Re-sort + results.sort(key=lambda x: x.score, reverse=True) + return results + + +def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever: + """ + Factory function to get retriever + + Args: + retriever_type: Type of retriever ('hybrid', 'multimodal', 'contextual') + **kwargs: Configuration for retriever + + Returns: + Configured retriever + """ + retriever_map = { + "hybrid": HybridRetriever, + "multimodal": MultiModalRetriever, + "contextual": ContextualRetriever, + } + + retriever_class = retriever_map.get(retriever_type, HybridRetriever) + return retriever_class(**kwargs) diff --git a/src/deepiri_modelkit/rag/testing.py b/src/deepiri_modelkit/rag/testing.py index bc11473..1377e55 100644 --- a/src/deepiri_modelkit/rag/testing.py +++ b/src/deepiri_modelkit/rag/testing.py @@ -1,332 +1,334 @@ -""" -Testing Utilities for Universal RAG -Comprehensive test helpers, fixtures, and evaluation tools -""" - -from typing import List, Dict, Any, Optional, Tuple -from dataclasses import dataclass -import json -from datetime import datetime - -from .base import Document, DocumentType, IndustryNiche, RAGQuery, RetrievalResult - - -@dataclass -class TestCase: - """Test case for RAG evaluation""" - query: str - expected_doc_ids: List[str] # Document IDs that should be retrieved - expected_doc_types: Optional[List[DocumentType]] = None - min_score: float = 0.7 # Minimum similarity score - top_k: int = 5 - metadata: Dict[str, Any] = None - - def __post_init__(self): - if self.metadata is None: - self.metadata = {} - - -@dataclass -class TestResult: - """Result of a test case""" - test_case: TestCase - 
retrieved_doc_ids: List[str] - retrieved_scores: List[float] - precision: float - recall: float - f1_score: float - passed: bool - error: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return { - "query": self.test_case.query, - "expected_doc_ids": self.test_case.expected_doc_ids, - "retrieved_doc_ids": self.retrieved_doc_ids, - "precision": self.precision, - "recall": self.recall, - "f1_score": self.f1_score, - "passed": self.passed, - "error": self.error, - } - - -class RAGEvaluator: - """ - Evaluator for RAG system performance - """ - - def __init__(self, rag_engine): - self.rag_engine = rag_engine - - def evaluate( - self, - test_cases: List[TestCase], - industry: Optional[IndustryNiche] = None - ) -> Dict[str, Any]: - """ - Evaluate RAG system on test cases - - Args: - test_cases: List of test cases - industry: Industry context - - Returns: - Evaluation results with metrics - """ - results = [] - - for test_case in test_cases: - result = self._evaluate_test_case(test_case, industry) - results.append(result) - - # Calculate aggregate metrics - total_precision = sum(r.precision for r in results) / len(results) if results else 0.0 - total_recall = sum(r.recall for r in results) / len(results) if results else 0.0 - total_f1 = sum(r.f1_score for r in results) / len(results) if results else 0.0 - passed_count = sum(1 for r in results if r.passed) - - return { - "total_tests": len(test_cases), - "passed": passed_count, - "failed": len(test_cases) - passed_count, - "avg_precision": total_precision, - "avg_recall": total_recall, - "avg_f1_score": total_f1, - "pass_rate": passed_count / len(test_cases) if test_cases else 0.0, - "results": [r.to_dict() for r in results], - } - - def _evaluate_test_case( - self, - test_case: TestCase, - industry: Optional[IndustryNiche] - ) -> TestResult: - """Evaluate a single test case""" - try: - # Build query - query = RAGQuery( - query=test_case.query, - industry=industry, - 
doc_types=test_case.expected_doc_types, - top_k=test_case.top_k - ) - - # Retrieve documents - results = self.rag_engine.retrieve(query) - - # Extract IDs and scores - retrieved_doc_ids = [r.document.id for r in results] - retrieved_scores = [r.score for r in results] - - # Calculate precision, recall, F1 - expected_set = set(test_case.expected_doc_ids) - retrieved_set = set(retrieved_doc_ids) - - if not retrieved_set: - precision = 0.0 - recall = 0.0 - f1_score = 0.0 - else: - # Precision: relevant retrieved / total retrieved - relevant_retrieved = len(expected_set & retrieved_set) - precision = relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 - - # Recall: relevant retrieved / total relevant - recall = relevant_retrieved / len(expected_set) if expected_set else 0.0 - - # F1 score - f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 - - # Check if passed - passed = ( - precision >= 0.7 and - recall >= 0.7 and - f1_score >= 0.7 and - (not retrieved_scores or max(retrieved_scores) >= test_case.min_score) - ) - - return TestResult( - test_case=test_case, - retrieved_doc_ids=retrieved_doc_ids, - retrieved_scores=retrieved_scores, - precision=precision, - recall=recall, - f1_score=f1_score, - passed=passed - ) - - except Exception as e: - return TestResult( - test_case=test_case, - retrieved_doc_ids=[], - retrieved_scores=[], - precision=0.0, - recall=0.0, - f1_score=0.0, - passed=False, - error=str(e) - ) - - -class RAGTestFixture: - """ - Test fixture for creating test data and scenarios - """ - - @staticmethod - def create_test_documents( - industry: IndustryNiche = IndustryNiche.MANUFACTURING, - num_documents: int = 10 - ) -> List[Document]: - """Create test documents""" - documents = [] - - for i in range(num_documents): - doc = Document( - id=f"test_doc_{i}", - content=f"Test document {i} content. 
This is sample content for testing RAG retrieval.", - doc_type=DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG, - industry=industry, - title=f"Test Document {i}", - source="test_fixture", - metadata={"test_index": i} - ) - documents.append(doc) - - return documents - - @staticmethod - def create_test_cases( - num_cases: int = 5 - ) -> List[TestCase]: - """Create test cases""" - test_cases = [] - - for i in range(num_cases): - test_case = TestCase( - query=f"test query {i}", - expected_doc_ids=[f"test_doc_{i}", f"test_doc_{i+1}"], - top_k=5 - ) - test_cases.append(test_case) - - return test_cases - - -class PerformanceBenchmark: - """ - Performance benchmarking for RAG operations - """ - - def __init__(self, rag_engine): - self.rag_engine = rag_engine - - def benchmark_retrieval( - self, - queries: List[str], - iterations: int = 10 - ) -> Dict[str, Any]: - """ - Benchmark retrieval performance - - Args: - queries: List of test queries - iterations: Number of iterations per query - - Returns: - Performance metrics - """ - import time - - total_time = 0.0 - total_queries = 0 - times = [] - - for query_text in queries: - query = RAGQuery(query=query_text, top_k=5) - - for _ in range(iterations): - start = time.time() - results = self.rag_engine.retrieve(query) - elapsed = time.time() - start - - total_time += elapsed - total_queries += 1 - times.append(elapsed * 1000) # Convert to ms - - avg_time_ms = (total_time / total_queries) * 1000 if total_queries > 0 else 0.0 - min_time_ms = min(times) if times else 0.0 - max_time_ms = max(times) if times else 0.0 - - return { - "total_queries": total_queries, - "avg_time_ms": avg_time_ms, - "min_time_ms": min_time_ms, - "max_time_ms": max_time_ms, - "queries_per_second": total_queries / total_time if total_time > 0 else 0.0, - } - - def benchmark_indexing( - self, - documents: List[Document], - batch_sizes: List[int] = [1, 10, 100] - ) -> Dict[str, Any]: - """ - Benchmark indexing performance - - Args: - 
documents: Documents to index - batch_sizes: Different batch sizes to test - - Returns: - Performance metrics for each batch size - """ - import time - - results = {} - - for batch_size in batch_sizes: - batches = [ - documents[i:i + batch_size] - for i in range(0, len(documents), batch_size) - ] - - start = time.time() - for batch in batches: - self.rag_engine.index_documents(batch) - elapsed = time.time() - start - - results[f"batch_size_{batch_size}"] = { - "total_documents": len(documents), - "num_batches": len(batches), - "total_time_seconds": elapsed, - "avg_time_per_doc_ms": (elapsed / len(documents)) * 1000 if documents else 0.0, - "docs_per_second": len(documents) / elapsed if elapsed > 0 else 0.0, - } - - return results - - -def create_evaluation_dataset( - industry: IndustryNiche, - num_documents: int = 100, - num_queries: int = 20 -) -> Tuple[List[Document], List[TestCase]]: - """ - Create a complete evaluation dataset - - Args: - industry: Industry for documents - num_documents: Number of documents to create - num_queries: Number of test queries - - Returns: - Tuple of (documents, test_cases) - """ - documents = RAGTestFixture.create_test_documents(industry, num_documents) - test_cases = RAGTestFixture.create_test_cases(num_queries) - - return documents, test_cases - +""" +Testing Utilities for Universal RAG +Comprehensive test helpers, fixtures, and evaluation tools +""" + +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +import json +from datetime import datetime + +from .base import Document, DocumentType, IndustryNiche, RAGQuery, RetrievalResult + + +@dataclass +class TestCase: + """Test case for RAG evaluation""" + + query: str + expected_doc_ids: List[str] # Document IDs that should be retrieved + expected_doc_types: Optional[List[DocumentType]] = None + min_score: float = 0.7 # Minimum similarity score + top_k: int = 5 + metadata: Dict[str, Any] = None + + def __post_init__(self): + if self.metadata is 
None: + self.metadata = {} + + +@dataclass +class TestResult: + """Result of a test case""" + + test_case: TestCase + retrieved_doc_ids: List[str] + retrieved_scores: List[float] + precision: float + recall: float + f1_score: float + passed: bool + error: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "query": self.test_case.query, + "expected_doc_ids": self.test_case.expected_doc_ids, + "retrieved_doc_ids": self.retrieved_doc_ids, + "precision": self.precision, + "recall": self.recall, + "f1_score": self.f1_score, + "passed": self.passed, + "error": self.error, + } + + +class RAGEvaluator: + """ + Evaluator for RAG system performance + """ + + def __init__(self, rag_engine): + self.rag_engine = rag_engine + + def evaluate( + self, test_cases: List[TestCase], industry: Optional[IndustryNiche] = None + ) -> Dict[str, Any]: + """ + Evaluate RAG system on test cases + + Args: + test_cases: List of test cases + industry: Industry context + + Returns: + Evaluation results with metrics + """ + results = [] + + for test_case in test_cases: + result = self._evaluate_test_case(test_case, industry) + results.append(result) + + # Calculate aggregate metrics + total_precision = ( + sum(r.precision for r in results) / len(results) if results else 0.0 + ) + total_recall = sum(r.recall for r in results) / len(results) if results else 0.0 + total_f1 = sum(r.f1_score for r in results) / len(results) if results else 0.0 + passed_count = sum(1 for r in results if r.passed) + + return { + "total_tests": len(test_cases), + "passed": passed_count, + "failed": len(test_cases) - passed_count, + "avg_precision": total_precision, + "avg_recall": total_recall, + "avg_f1_score": total_f1, + "pass_rate": passed_count / len(test_cases) if test_cases else 0.0, + "results": [r.to_dict() for r in results], + } + + def _evaluate_test_case( + self, test_case: TestCase, industry: Optional[IndustryNiche] + ) -> TestResult: + """Evaluate a 
single test case""" + try: + # Build query + query = RAGQuery( + query=test_case.query, + industry=industry, + doc_types=test_case.expected_doc_types, + top_k=test_case.top_k, + ) + + # Retrieve documents + results = self.rag_engine.retrieve(query) + + # Extract IDs and scores + retrieved_doc_ids = [r.document.id for r in results] + retrieved_scores = [r.score for r in results] + + # Calculate precision, recall, F1 + expected_set = set(test_case.expected_doc_ids) + retrieved_set = set(retrieved_doc_ids) + + if not retrieved_set: + precision = 0.0 + recall = 0.0 + f1_score = 0.0 + else: + # Precision: relevant retrieved / total retrieved + relevant_retrieved = len(expected_set & retrieved_set) + precision = ( + relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 + ) + + # Recall: relevant retrieved / total relevant + recall = relevant_retrieved / len(expected_set) if expected_set else 0.0 + + # F1 score + f1_score = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + + # Check if passed + passed = ( + precision >= 0.7 + and recall >= 0.7 + and f1_score >= 0.7 + and ( + not retrieved_scores or max(retrieved_scores) >= test_case.min_score + ) + ) + + return TestResult( + test_case=test_case, + retrieved_doc_ids=retrieved_doc_ids, + retrieved_scores=retrieved_scores, + precision=precision, + recall=recall, + f1_score=f1_score, + passed=passed, + ) + + except Exception as e: + return TestResult( + test_case=test_case, + retrieved_doc_ids=[], + retrieved_scores=[], + precision=0.0, + recall=0.0, + f1_score=0.0, + passed=False, + error=str(e), + ) + + +class RAGTestFixture: + """ + Test fixture for creating test data and scenarios + """ + + @staticmethod + def create_test_documents( + industry: IndustryNiche = IndustryNiche.MANUFACTURING, num_documents: int = 10 + ) -> List[Document]: + """Create test documents""" + documents = [] + + for i in range(num_documents): + doc = Document( + id=f"test_doc_{i}", + 
content=f"Test document {i} content. This is sample content for testing RAG retrieval.", + doc_type=( + DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG + ), + industry=industry, + title=f"Test Document {i}", + source="test_fixture", + metadata={"test_index": i}, + ) + documents.append(doc) + + return documents + + @staticmethod + def create_test_cases(num_cases: int = 5) -> List[TestCase]: + """Create test cases""" + test_cases = [] + + for i in range(num_cases): + test_case = TestCase( + query=f"test query {i}", + expected_doc_ids=[f"test_doc_{i}", f"test_doc_{i+1}"], + top_k=5, + ) + test_cases.append(test_case) + + return test_cases + + +class PerformanceBenchmark: + """ + Performance benchmarking for RAG operations + """ + + def __init__(self, rag_engine): + self.rag_engine = rag_engine + + def benchmark_retrieval( + self, queries: List[str], iterations: int = 10 + ) -> Dict[str, Any]: + """ + Benchmark retrieval performance + + Args: + queries: List of test queries + iterations: Number of iterations per query + + Returns: + Performance metrics + """ + import time + + total_time = 0.0 + total_queries = 0 + times = [] + + for query_text in queries: + query = RAGQuery(query=query_text, top_k=5) + + for _ in range(iterations): + start = time.time() + results = self.rag_engine.retrieve(query) + elapsed = time.time() - start + + total_time += elapsed + total_queries += 1 + times.append(elapsed * 1000) # Convert to ms + + avg_time_ms = (total_time / total_queries) * 1000 if total_queries > 0 else 0.0 + min_time_ms = min(times) if times else 0.0 + max_time_ms = max(times) if times else 0.0 + + return { + "total_queries": total_queries, + "avg_time_ms": avg_time_ms, + "min_time_ms": min_time_ms, + "max_time_ms": max_time_ms, + "queries_per_second": total_queries / total_time if total_time > 0 else 0.0, + } + + def benchmark_indexing( + self, documents: List[Document], batch_sizes: List[int] = [1, 10, 100] + ) -> Dict[str, Any]: + """ + Benchmark 
indexing performance + + Args: + documents: Documents to index + batch_sizes: Different batch sizes to test + + Returns: + Performance metrics for each batch size + """ + import time + + results = {} + + for batch_size in batch_sizes: + batches = [ + documents[i : i + batch_size] + for i in range(0, len(documents), batch_size) + ] + + start = time.time() + for batch in batches: + self.rag_engine.index_documents(batch) + elapsed = time.time() - start + + results[f"batch_size_{batch_size}"] = { + "total_documents": len(documents), + "num_batches": len(batches), + "total_time_seconds": elapsed, + "avg_time_per_doc_ms": ( + (elapsed / len(documents)) * 1000 if documents else 0.0 + ), + "docs_per_second": len(documents) / elapsed if elapsed > 0 else 0.0, + } + + return results + + +def create_evaluation_dataset( + industry: IndustryNiche, num_documents: int = 100, num_queries: int = 20 +) -> Tuple[List[Document], List[TestCase]]: + """ + Create a complete evaluation dataset + + Args: + industry: Industry for documents + num_documents: Number of documents to create + num_queries: Number of test queries + + Returns: + Tuple of (documents, test_cases) + """ + documents = RAGTestFixture.create_test_documents(industry, num_documents) + test_cases = RAGTestFixture.create_test_cases(num_queries) + + return documents, test_cases diff --git a/src/deepiri_modelkit/registry/adapters/__init__.py b/src/deepiri_modelkit/registry/adapters/__init__.py index e949845..43c51bc 100644 --- a/src/deepiri_modelkit/registry/adapters/__init__.py +++ b/src/deepiri_modelkit/registry/adapters/__init__.py @@ -1,2 +1 @@ -"""Storage adapters for model registry""" - +"""Storage adapters for model registry""" diff --git a/src/deepiri_modelkit/registry/model_registry.py b/src/deepiri_modelkit/registry/model_registry.py index 5c5ffdd..7be6f9e 100644 --- a/src/deepiri_modelkit/registry/model_registry.py +++ b/src/deepiri_modelkit/registry/model_registry.py @@ -1,336 +1,333 @@ -""" -Unified model registry 
client -Supports MLflow, S3/MinIO, and local storage -""" -import os -from typing import Dict, Any, Optional -from pathlib import Path -import mlflow -import boto3 -from botocore.exceptions import ClientError - -from ..contracts.models import ModelMetadata - - -class ModelRegistryClient: - """ - Unified client for model registry operations - Supports MLflow, S3/MinIO, and local storage - """ - - def __init__( - self, - registry_type: str = "mlflow", # mlflow, s3, local - mlflow_tracking_uri: Optional[str] = None, - s3_endpoint: Optional[str] = None, - s3_access_key: Optional[str] = None, - s3_secret_key: Optional[str] = None, - s3_bucket: Optional[str] = None, - local_path: Optional[str] = None - ): - """ - Initialize model registry client - - Args: - registry_type: Type of registry (mlflow, s3, local) - mlflow_tracking_uri: MLflow tracking URI (defaults to MLFLOW_TRACKING_URI env var or http://mlflow:5000) - s3_endpoint: S3/MinIO endpoint - s3_access_key: S3 access key - s3_secret_key: S3 secret key - s3_bucket: S3 bucket name - local_path: Local storage path - """ - self.registry_type = registry_type - - if registry_type == "mlflow": - # Use provided URI, or environment variable, or default - tracking_uri = mlflow_tracking_uri or os.getenv("MLFLOW_TRACKING_URI") or "http://mlflow:5000" - mlflow.set_tracking_uri(tracking_uri) - self.client = mlflow - self.tracking_uri = tracking_uri - elif registry_type == "s3": - self.s3_client = boto3.client( - 's3', - endpoint_url=s3_endpoint, - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key - ) - self.s3_bucket = s3_bucket - elif registry_type == "local": - self.local_path = Path(local_path or "./models") - self.local_path.mkdir(parents=True, exist_ok=True) - else: - raise ValueError(f"Unknown registry type: {registry_type}") - - def register_model( - self, - model_name: str, - version: str, - model_path: str, - metadata: Dict[str, Any] - ) -> bool: - """ - Register model in registry - - Args: - 
model_name: Model name - version: Model version - model_path: Path to model file/directory - metadata: Model metadata - - Returns: - True if successful - """ - try: - if self.registry_type == "mlflow": - # Register with MLflow - model_uri = f"runs:/{metadata.get('run_id', 'latest')}/model" - mlflow.register_model(model_uri, f"{model_name}-{version}") - return True - - elif self.registry_type == "s3": - # Upload to S3 - s3_key = f"models/{model_name}/{version}/model" - self.s3_client.upload_file(model_path, self.s3_bucket, s3_key) - - # Upload metadata - import json - metadata_key = f"models/{model_name}/{version}/metadata.json" - self.s3_client.put_object( - Bucket=self.s3_bucket, - Key=metadata_key, - Body=json.dumps(metadata) - ) - return True - - elif self.registry_type == "local": - # Copy to local storage - model_dir = self.local_path / model_name / version - model_dir.mkdir(parents=True, exist_ok=True) - - import shutil - if os.path.isdir(model_path): - shutil.copytree(model_path, model_dir / "model", dirs_exist_ok=True) - else: - shutil.copy2(model_path, model_dir / "model") - - # Save metadata - import json - with open(model_dir / "metadata.json", "w") as f: - json.dump(metadata, f) - - return True - - except Exception as e: - print(f"Error registering model: {e}") - return False - - def get_model( - self, - model_name: str, - version: Optional[str] = None - ) -> Dict[str, Any]: - """ - Get model information from registry - - Args: - model_name: Model name - version: Model version (optional, gets latest if not specified) - - Returns: - Model information dict - """ - try: - if self.registry_type == "mlflow": - if version: - model_uri = f"models:/{model_name}/{version}" - else: - model_uri = f"models:/{model_name}/latest" - - model = mlflow.pyfunc.load_model(model_uri) - return { - "model": model, - "uri": model_uri, - "type": "mlflow" - } - - elif self.registry_type == "s3": - if not version: - # List versions and get latest - prefix = 
f"models/{model_name}/" - response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, - Prefix=prefix, - Delimiter="/" - ) - versions = [obj["Prefix"].split("/")[-2] for obj in response.get("CommonPrefixes", [])] - version = max(versions) if versions else None - - if not version: - raise ValueError(f"Model {model_name} not found") - - # Download metadata - metadata_key = f"models/{model_name}/{version}/metadata.json" - response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=metadata_key) - import json - metadata = json.loads(response["Body"].read()) - - return { - "model_path": f"s3://{self.s3_bucket}/models/{model_name}/{version}/model", - "metadata": metadata, - "type": "s3" - } - - elif self.registry_type == "local": - if not version: - # Get latest version - model_dir = self.local_path / model_name - if not model_dir.exists(): - raise ValueError(f"Model {model_name} not found") - - versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - version = max(versions) if versions else None - - if not version: - raise ValueError(f"Model {model_name} not found") - - model_dir = self.local_path / model_name / version - metadata_path = model_dir / "metadata.json" - - import json - with open(metadata_path) as f: - metadata = json.load(f) - - return { - "model_path": str(model_dir / "model"), - "metadata": metadata, - "type": "local" - } - - except Exception as e: - print(f"Error getting model: {e}") - raise - - def download_model( - self, - model_name: str, - version: str, - destination: str - ) -> str: - """ - Download model to destination - - Args: - model_name: Model name - version: Model version - destination: Local destination path - - Returns: - Local path to downloaded model - """ - model_info = self.get_model(model_name, version) - - if self.registry_type == "s3": - # Download from S3 - s3_key = f"models/{model_name}/{version}/model" - os.makedirs(destination, exist_ok=True) - - # Check if it's a file or directory - try: - 
self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) - # It's a file - local_path = os.path.join(destination, "model") - self.s3_client.download_file(self.s3_bucket, s3_key, local_path) - return local_path - except ClientError: - # It's a directory, list and download all files - prefix = f"{s3_key}/" - paginator = self.s3_client.get_paginator('list_objects_v2') - for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix): - for obj in page.get('Contents', []): - key = obj['Key'] - local_file = os.path.join(destination, key[len(prefix):]) - os.makedirs(os.path.dirname(local_file), exist_ok=True) - self.s3_client.download_file(self.s3_bucket, key, local_file) - return destination - - elif self.registry_type == "local": - # Copy from local - source = model_info["model_path"] - import shutil - if os.path.isdir(source): - shutil.copytree(source, destination, dirs_exist_ok=True) - else: - os.makedirs(os.path.dirname(destination), exist_ok=True) - shutil.copy2(source, destination) - return destination - - else: - # MLflow handles loading directly - return model_info["uri"] - - def list_models(self, model_name: Optional[str] = None) -> list: - """ - List available models - - Args: - model_name: Optional model name filter - - Returns: - List of model information dicts - """ - try: - if self.registry_type == "mlflow": - if model_name: - models = mlflow.search_registered_models(filter_string=f"name='{model_name}'") - else: - models = mlflow.search_registered_models() - - return [ - { - "name": m.name, - "versions": [v.version for v in m.latest_versions], - "latest_version": m.latest_versions[0].version if m.latest_versions else None - } - for m in models - ] - - elif self.registry_type == "s3": - prefix = "models/" - if model_name: - prefix = f"models/{model_name}/" - - response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, - Prefix=prefix, - Delimiter="/" - ) - - models = [] - for prefix_obj in response.get("CommonPrefixes", []): - model_path 
= prefix_obj["Prefix"] - parts = model_path.strip("/").split("/") - if len(parts) >= 2: - models.append({ - "name": parts[1], - "path": model_path - }) - - return models - - elif self.registry_type == "local": - models = [] - for model_dir in self.local_path.iterdir(): - if model_dir.is_dir(): - versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - models.append({ - "name": model_dir.name, - "versions": versions, - "latest_version": max(versions) if versions else None - }) - return models - - except Exception as e: - print(f"Error listing models: {e}") - return [] - +""" +Unified model registry client +Supports MLflow, S3/MinIO, and local storage +""" + +import os +from typing import Dict, Any, Optional +from pathlib import Path +import mlflow +import boto3 +from botocore.exceptions import ClientError + +from ..contracts.models import ModelMetadata + + +class ModelRegistryClient: + """ + Unified client for model registry operations + Supports MLflow, S3/MinIO, and local storage + """ + + def __init__( + self, + registry_type: str = "mlflow", # mlflow, s3, local + mlflow_tracking_uri: Optional[str] = None, + s3_endpoint: Optional[str] = None, + s3_access_key: Optional[str] = None, + s3_secret_key: Optional[str] = None, + s3_bucket: Optional[str] = None, + local_path: Optional[str] = None, + ): + """ + Initialize model registry client + + Args: + registry_type: Type of registry (mlflow, s3, local) + mlflow_tracking_uri: MLflow tracking URI (defaults to MLFLOW_TRACKING_URI env var or http://mlflow:5000) + s3_endpoint: S3/MinIO endpoint + s3_access_key: S3 access key + s3_secret_key: S3 secret key + s3_bucket: S3 bucket name + local_path: Local storage path + """ + self.registry_type = registry_type + + if registry_type == "mlflow": + # Use provided URI, or environment variable, or default + tracking_uri = ( + mlflow_tracking_uri + or os.getenv("MLFLOW_TRACKING_URI") + or "http://mlflow:5000" + ) + mlflow.set_tracking_uri(tracking_uri) + self.client = mlflow 
+ self.tracking_uri = tracking_uri + elif registry_type == "s3": + self.s3_client = boto3.client( + "s3", + endpoint_url=s3_endpoint, + aws_access_key_id=s3_access_key, + aws_secret_access_key=s3_secret_key, + ) + self.s3_bucket = s3_bucket + elif registry_type == "local": + self.local_path = Path(local_path or "./models") + self.local_path.mkdir(parents=True, exist_ok=True) + else: + raise ValueError(f"Unknown registry type: {registry_type}") + + def register_model( + self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] + ) -> bool: + """ + Register model in registry + + Args: + model_name: Model name + version: Model version + model_path: Path to model file/directory + metadata: Model metadata + + Returns: + True if successful + """ + try: + if self.registry_type == "mlflow": + # Register with MLflow + model_uri = f"runs:/{metadata.get('run_id', 'latest')}/model" + mlflow.register_model(model_uri, f"{model_name}-{version}") + return True + + elif self.registry_type == "s3": + # Upload to S3 + s3_key = f"models/{model_name}/{version}/model" + self.s3_client.upload_file(model_path, self.s3_bucket, s3_key) + + # Upload metadata + import json + + metadata_key = f"models/{model_name}/{version}/metadata.json" + self.s3_client.put_object( + Bucket=self.s3_bucket, Key=metadata_key, Body=json.dumps(metadata) + ) + return True + + elif self.registry_type == "local": + # Copy to local storage + model_dir = self.local_path / model_name / version + model_dir.mkdir(parents=True, exist_ok=True) + + import shutil + + if os.path.isdir(model_path): + shutil.copytree(model_path, model_dir / "model", dirs_exist_ok=True) + else: + shutil.copy2(model_path, model_dir / "model") + + # Save metadata + import json + + with open(model_dir / "metadata.json", "w") as f: + json.dump(metadata, f) + + return True + + except Exception as e: + print(f"Error registering model: {e}") + return False + + def get_model( + self, model_name: str, version: Optional[str] = None + 
) -> Dict[str, Any]: + """ + Get model information from registry + + Args: + model_name: Model name + version: Model version (optional, gets latest if not specified) + + Returns: + Model information dict + """ + try: + if self.registry_type == "mlflow": + if version: + model_uri = f"models:/{model_name}/{version}" + else: + model_uri = f"models:/{model_name}/latest" + + model = mlflow.pyfunc.load_model(model_uri) + return {"model": model, "uri": model_uri, "type": "mlflow"} + + elif self.registry_type == "s3": + if not version: + # List versions and get latest + prefix = f"models/{model_name}/" + response = self.s3_client.list_objects_v2( + Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" + ) + versions = [ + obj["Prefix"].split("/")[-2] + for obj in response.get("CommonPrefixes", []) + ] + version = max(versions) if versions else None + + if not version: + raise ValueError(f"Model {model_name} not found") + + # Download metadata + metadata_key = f"models/{model_name}/{version}/metadata.json" + response = self.s3_client.get_object( + Bucket=self.s3_bucket, Key=metadata_key + ) + import json + + metadata = json.loads(response["Body"].read()) + + return { + "model_path": f"s3://{self.s3_bucket}/models/{model_name}/{version}/model", + "metadata": metadata, + "type": "s3", + } + + elif self.registry_type == "local": + if not version: + # Get latest version + model_dir = self.local_path / model_name + if not model_dir.exists(): + raise ValueError(f"Model {model_name} not found") + + versions = [d.name for d in model_dir.iterdir() if d.is_dir()] + version = max(versions) if versions else None + + if not version: + raise ValueError(f"Model {model_name} not found") + + model_dir = self.local_path / model_name / version + metadata_path = model_dir / "metadata.json" + + import json + + with open(metadata_path) as f: + metadata = json.load(f) + + return { + "model_path": str(model_dir / "model"), + "metadata": metadata, + "type": "local", + } + + except Exception as e: + 
print(f"Error getting model: {e}") + raise + + def download_model(self, model_name: str, version: str, destination: str) -> str: + """ + Download model to destination + + Args: + model_name: Model name + version: Model version + destination: Local destination path + + Returns: + Local path to downloaded model + """ + model_info = self.get_model(model_name, version) + + if self.registry_type == "s3": + # Download from S3 + s3_key = f"models/{model_name}/{version}/model" + os.makedirs(destination, exist_ok=True) + + # Check if it's a file or directory + try: + self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) + # It's a file + local_path = os.path.join(destination, "model") + self.s3_client.download_file(self.s3_bucket, s3_key, local_path) + return local_path + except ClientError: + # It's a directory, list and download all files + prefix = f"{s3_key}/" + paginator = self.s3_client.get_paginator("list_objects_v2") + for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix): + for obj in page.get("Contents", []): + key = obj["Key"] + local_file = os.path.join(destination, key[len(prefix) :]) + os.makedirs(os.path.dirname(local_file), exist_ok=True) + self.s3_client.download_file(self.s3_bucket, key, local_file) + return destination + + elif self.registry_type == "local": + # Copy from local + source = model_info["model_path"] + import shutil + + if os.path.isdir(source): + shutil.copytree(source, destination, dirs_exist_ok=True) + else: + os.makedirs(os.path.dirname(destination), exist_ok=True) + shutil.copy2(source, destination) + return destination + + else: + # MLflow handles loading directly + return model_info["uri"] + + def list_models(self, model_name: Optional[str] = None) -> list: + """ + List available models + + Args: + model_name: Optional model name filter + + Returns: + List of model information dicts + """ + try: + if self.registry_type == "mlflow": + if model_name: + models = mlflow.search_registered_models( + 
filter_string=f"name='{model_name}'" + ) + else: + models = mlflow.search_registered_models() + + return [ + { + "name": m.name, + "versions": [v.version for v in m.latest_versions], + "latest_version": ( + m.latest_versions[0].version if m.latest_versions else None + ), + } + for m in models + ] + + elif self.registry_type == "s3": + prefix = "models/" + if model_name: + prefix = f"models/{model_name}/" + + response = self.s3_client.list_objects_v2( + Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" + ) + + models = [] + for prefix_obj in response.get("CommonPrefixes", []): + model_path = prefix_obj["Prefix"] + parts = model_path.strip("/").split("/") + if len(parts) >= 2: + models.append({"name": parts[1], "path": model_path}) + + return models + + elif self.registry_type == "local": + models = [] + for model_dir in self.local_path.iterdir(): + if model_dir.is_dir(): + versions = [d.name for d in model_dir.iterdir() if d.is_dir()] + models.append( + { + "name": model_dir.name, + "versions": versions, + "latest_version": max(versions) if versions else None, + } + ) + return models + + except Exception as e: + print(f"Error listing models: {e}") + return [] diff --git a/src/deepiri_modelkit/streaming/event_stream.py b/src/deepiri_modelkit/streaming/event_stream.py index 40e5ed4..759fbef 100644 --- a/src/deepiri_modelkit/streaming/event_stream.py +++ b/src/deepiri_modelkit/streaming/event_stream.py @@ -1,204 +1,194 @@ -""" -Redis Streams client for event-driven architecture -""" -import redis.asyncio as redis -from typing import Dict, Any, Optional, AsyncIterator, Callable -import json -import asyncio -from datetime import datetime - -from .topics import StreamTopics -from ..contracts.events import BaseEvent - - -class StreamingClient: - """ - Redis Streams client for publishing and subscribing to events - """ - - def __init__( - self, - redis_url: Optional[str] = None, - redis_host: str = "redis", - redis_port: int = 6379, - redis_password: Optional[str] = None 
- ): - """ - Initialize streaming client - - Args: - redis_url: Full Redis URL (redis://password@host:port) - redis_host: Redis host (if not using redis_url) - redis_port: Redis port (if not using redis_url) - redis_password: Redis password (if not using redis_url) - """ - if redis_url: - self.redis = redis.from_url(redis_url, decode_responses=True) - else: - self.redis = redis.Redis( - host=redis_host, - port=redis_port, - password=redis_password, - decode_responses=True - ) - self._running = False - self._subscriptions = {} - - async def connect(self): - """Connect to Redis""" - await self.redis.ping() - - async def disconnect(self): - """Disconnect from Redis""" - await self.redis.close() - - async def publish( - self, - topic: str, - event: Dict[str, Any], - max_length: Optional[int] = 10000 - ) -> str: - """ - Publish event to stream - - Args: - topic: Stream topic name - event: Event data (dict) - max_length: Max stream length (truncate old messages) - - Returns: - Message ID - """ - # Ensure event has timestamp - if "timestamp" not in event: - event["timestamp"] = datetime.utcnow().isoformat() - - # Publish to stream - message_id = await self.redis.xadd( - topic, - event, - maxlen=max_length, - approximate=True - ) - - return message_id - - async def subscribe( - self, - topic: str, - callback: Callable[[Dict[str, Any]], None], - consumer_group: Optional[str] = None, - consumer_name: Optional[str] = None, - last_id: str = "0", - block_ms: int = 1000 - ) -> AsyncIterator[Dict[str, Any]]: - """ - Subscribe to stream and yield events - - Args: - topic: Stream topic name - callback: Optional callback function - consumer_group: Consumer group name (for load balancing) - consumer_name: Consumer name (unique per consumer) - last_id: Last message ID to read from - block_ms: Block time in milliseconds - - Yields: - Event data (dict) - """ - # Create consumer group if specified - if consumer_group: - try: - await self.redis.xgroup_create( - topic, - consumer_group, - 
id="0", - mkstream=True - ) - except redis.ResponseError as e: - if "BUSYGROUP" not in str(e): - raise - - self._running = True - - while self._running: - try: - if consumer_group and consumer_name: - # Read from consumer group - messages = await self.redis.xreadgroup( - consumer_group, - consumer_name, - {topic: ">"}, - count=10, - block=block_ms - ) - else: - # Direct read - messages = await self.redis.xread( - {topic: last_id}, - count=10, - block=block_ms - ) - - for stream_name, stream_messages in messages: - for msg_id, data in stream_messages: - # Yield event - yield data - - # Call callback if provided - if callback: - try: - await callback(data) if asyncio.iscoroutinefunction(callback) else callback(data) - except Exception as e: - print(f"Callback error: {e}") - - # Update last_id for next read - last_id = msg_id - - # Acknowledge if using consumer group - if consumer_group and consumer_name: - await self.redis.xack(topic, consumer_group, msg_id) - - except asyncio.CancelledError: - break - except Exception as e: - print(f"Stream subscription error: {e}") - await asyncio.sleep(1) - - async def subscribe_async( - self, - topic: str, - callback: Callable[[Dict[str, Any]], None], - consumer_group: Optional[str] = None, - consumer_name: Optional[str] = None - ): - """ - Subscribe to stream in background task - - Args: - topic: Stream topic name - callback: Callback function - consumer_group: Consumer group name - consumer_name: Consumer name - """ - async for event in self.subscribe( - topic, - callback, - consumer_group, - consumer_name - ): - pass # Callback handles events - - def stop(self): - """Stop all subscriptions""" - self._running = False - - async def get_stream_info(self, topic: str) -> Dict[str, Any]: - """Get stream information""" - info = await self.redis.xinfo_stream(topic) - return dict(info) - - async def get_stream_length(self, topic: str) -> int: - """Get number of messages in stream""" - return await self.redis.xlen(topic) - +""" +Redis 
Streams client for event-driven architecture +""" + +import redis.asyncio as redis +from typing import Dict, Any, Optional, AsyncIterator, Callable +import json +import asyncio +from datetime import datetime + +from .topics import StreamTopics +from ..contracts.events import BaseEvent + + +class StreamingClient: + """ + Redis Streams client for publishing and subscribing to events + """ + + def __init__( + self, + redis_url: Optional[str] = None, + redis_host: str = "redis", + redis_port: int = 6379, + redis_password: Optional[str] = None, + ): + """ + Initialize streaming client + + Args: + redis_url: Full Redis URL (redis://password@host:port) + redis_host: Redis host (if not using redis_url) + redis_port: Redis port (if not using redis_url) + redis_password: Redis password (if not using redis_url) + """ + if redis_url: + self.redis = redis.from_url(redis_url, decode_responses=True) + else: + self.redis = redis.Redis( + host=redis_host, + port=redis_port, + password=redis_password, + decode_responses=True, + ) + self._running = False + self._subscriptions = {} + + async def connect(self): + """Connect to Redis""" + await self.redis.ping() + + async def disconnect(self): + """Disconnect from Redis""" + await self.redis.close() + + async def publish( + self, topic: str, event: Dict[str, Any], max_length: Optional[int] = 10000 + ) -> str: + """ + Publish event to stream + + Args: + topic: Stream topic name + event: Event data (dict) + max_length: Max stream length (truncate old messages) + + Returns: + Message ID + """ + # Ensure event has timestamp + if "timestamp" not in event: + event["timestamp"] = datetime.utcnow().isoformat() + + # Publish to stream + message_id = await self.redis.xadd( + topic, event, maxlen=max_length, approximate=True + ) + + return message_id + + async def subscribe( + self, + topic: str, + callback: Callable[[Dict[str, Any]], None], + consumer_group: Optional[str] = None, + consumer_name: Optional[str] = None, + last_id: str = "0", + 
block_ms: int = 1000, + ) -> AsyncIterator[Dict[str, Any]]: + """ + Subscribe to stream and yield events + + Args: + topic: Stream topic name + callback: Optional callback function + consumer_group: Consumer group name (for load balancing) + consumer_name: Consumer name (unique per consumer) + last_id: Last message ID to read from + block_ms: Block time in milliseconds + + Yields: + Event data (dict) + """ + # Create consumer group if specified + if consumer_group: + try: + await self.redis.xgroup_create( + topic, consumer_group, id="0", mkstream=True + ) + except redis.ResponseError as e: + if "BUSYGROUP" not in str(e): + raise + + self._running = True + + while self._running: + try: + if consumer_group and consumer_name: + # Read from consumer group + messages = await self.redis.xreadgroup( + consumer_group, + consumer_name, + {topic: ">"}, + count=10, + block=block_ms, + ) + else: + # Direct read + messages = await self.redis.xread( + {topic: last_id}, count=10, block=block_ms + ) + + for stream_name, stream_messages in messages: + for msg_id, data in stream_messages: + # Yield event + yield data + + # Call callback if provided + if callback: + try: + ( + await callback(data) + if asyncio.iscoroutinefunction(callback) + else callback(data) + ) + except Exception as e: + print(f"Callback error: {e}") + + # Update last_id for next read + last_id = msg_id + + # Acknowledge if using consumer group + if consumer_group and consumer_name: + await self.redis.xack(topic, consumer_group, msg_id) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Stream subscription error: {e}") + await asyncio.sleep(1) + + async def subscribe_async( + self, + topic: str, + callback: Callable[[Dict[str, Any]], None], + consumer_group: Optional[str] = None, + consumer_name: Optional[str] = None, + ): + """ + Subscribe to stream in background task + + Args: + topic: Stream topic name + callback: Callback function + consumer_group: Consumer group name + 
consumer_name: Consumer name + """ + async for event in self.subscribe( + topic, callback, consumer_group, consumer_name + ): + pass # Callback handles events + + def stop(self): + """Stop all subscriptions""" + self._running = False + + async def get_stream_info(self, topic: str) -> Dict[str, Any]: + """Get stream information""" + info = await self.redis.xinfo_stream(topic) + return dict(info) + + async def get_stream_length(self, topic: str) -> int: + """Get number of messages in stream""" + return await self.redis.xlen(topic) diff --git a/src/deepiri_modelkit/streaming/schemas.py b/src/deepiri_modelkit/streaming/schemas.py index 388fa31..fdd1b4e 100644 --- a/src/deepiri_modelkit/streaming/schemas.py +++ b/src/deepiri_modelkit/streaming/schemas.py @@ -1,56 +1,55 @@ -""" -Streaming event schemas and validation -""" -from .topics import StreamTopics -from ..contracts.events import ( - BaseEvent, - ModelReadyEvent, - ModelLoadedEvent, - InferenceEvent, - PlatformEvent, - AGIDecisionEvent, - TrainingEvent, -) - - -# Map topics to event schemas -TOPIC_EVENT_SCHEMAS = { - StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], - StreamTopics.INFERENCE_EVENTS: [InferenceEvent], - StreamTopics.PLATFORM_EVENTS: [PlatformEvent], - StreamTopics.AGI_DECISIONS: [AGIDecisionEvent], - StreamTopics.TRAINING_EVENTS: [TrainingEvent], -} - - -def validate_event(topic: str, event_data: dict) -> BaseEvent: - """ - Validate event against schema - - Args: - topic: Stream topic - event_data: Event data dict - - Returns: - Validated event object - - Raises: - ValueError: If event doesn't match schema - """ - if topic not in TOPIC_EVENT_SCHEMAS: - # Unknown topic, return base event - return BaseEvent(**event_data) - - schemas = TOPIC_EVENT_SCHEMAS[topic] - event_type = event_data.get("event") - - # Try to match event type to schema - for schema in schemas: - try: - return schema(**event_data) - except Exception: - continue - - # Fallback to base event - return 
BaseEvent(**event_data) - +""" +Streaming event schemas and validation +""" + +from .topics import StreamTopics +from ..contracts.events import ( + BaseEvent, + ModelReadyEvent, + ModelLoadedEvent, + InferenceEvent, + PlatformEvent, + AGIDecisionEvent, + TrainingEvent, +) + +# Map topics to event schemas +TOPIC_EVENT_SCHEMAS = { + StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], + StreamTopics.INFERENCE_EVENTS: [InferenceEvent], + StreamTopics.PLATFORM_EVENTS: [PlatformEvent], + StreamTopics.AGI_DECISIONS: [AGIDecisionEvent], + StreamTopics.TRAINING_EVENTS: [TrainingEvent], +} + + +def validate_event(topic: str, event_data: dict) -> BaseEvent: + """ + Validate event against schema + + Args: + topic: Stream topic + event_data: Event data dict + + Returns: + Validated event object + + Raises: + ValueError: If event doesn't match schema + """ + if topic not in TOPIC_EVENT_SCHEMAS: + # Unknown topic, return base event + return BaseEvent(**event_data) + + schemas = TOPIC_EVENT_SCHEMAS[topic] + event_type = event_data.get("event") + + # Try to match event type to schema + for schema in schemas: + try: + return schema(**event_data) + except Exception: + continue + + # Fallback to base event + return BaseEvent(**event_data) diff --git a/src/deepiri_modelkit/streaming/sidecar_utils.py b/src/deepiri_modelkit/streaming/sidecar_utils.py index 3af5322..6842054 100644 --- a/src/deepiri_modelkit/streaming/sidecar_utils.py +++ b/src/deepiri_modelkit/streaming/sidecar_utils.py @@ -1,79 +1,81 @@ -""" -Shared Sugar Glider/Synapse sidecar helpers. - -These utilities are reused by multiple services (for example Cyrex and Helox) -to keep sidecar transport behavior consistent across repos. 
-""" - -from __future__ import annotations - -import json -import os -from typing import Any, Callable, Dict, Optional -from urllib.parse import urlparse - - -def env_float(name: str, default: float, logger: Optional[Callable[[str], None]] = None) -> float: - """Read a float env var with safe fallback.""" - raw = os.getenv(name) - if raw is None: - return default - try: - return float(raw) - except ValueError: - if logger is not None: - logger(f"invalid float env {name}={raw!r}; using {default}") - return default - - -def resolve_grpc_addr(base_url: str, explicit_grpc_addr: Optional[str] = None) -> str: - """ - Resolve sidecar gRPC host:port from explicit/env/base URL. - - Resolution order: - 1) explicit_grpc_addr - 2) SYNAPSE_GRPC_ADDR - 3) derive from base_url (8081 -> 50051) - """ - env_addr = os.getenv("SYNAPSE_GRPC_ADDR") - if explicit_grpc_addr: - return explicit_grpc_addr - if env_addr: - return env_addr - - parsed = urlparse(base_url) - if parsed.scheme in {"http", "https"}: - host = parsed.hostname or "localhost" - port = parsed.port - if port is None: - port = 443 if parsed.scheme == "https" else 80 - if port == 8081: - port = 50051 - return f"{host}:{port}" - - if base_url: - return base_url - return "localhost:50051" - - -def sidecar_payload_from_fields(fields: Dict[str, Any]) -> Dict[str, Any]: - """Normalize sidecar event fields to a payload dict.""" - payload = fields.get("payload", {}) - if isinstance(payload, str): - try: - payload = json.loads(payload) - except ValueError: - payload = {} - elif not isinstance(payload, dict): - payload = {} - - if "event" not in payload and fields.get("event_type"): - payload["event"] = fields.get("event_type") - - if "timestamp" not in payload and fields.get("timestamp"): - payload["timestamp"] = fields.get("timestamp") - - if "sender" not in payload and fields.get("sender"): - payload["sender"] = fields.get("sender") - - return payload +""" +Shared Sugar Glider/Synapse sidecar helpers. 
+ +These utilities are reused by multiple services (for example Cyrex and Helox) +to keep sidecar transport behavior consistent across repos. +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Callable, Dict, Optional +from urllib.parse import urlparse + + +def env_float( + name: str, default: float, logger: Optional[Callable[[str], None]] = None +) -> float: + """Read a float env var with safe fallback.""" + raw = os.getenv(name) + if raw is None: + return default + try: + return float(raw) + except ValueError: + if logger is not None: + logger(f"invalid float env {name}={raw!r}; using {default}") + return default + + +def resolve_grpc_addr(base_url: str, explicit_grpc_addr: Optional[str] = None) -> str: + """ + Resolve sidecar gRPC host:port from explicit/env/base URL. + + Resolution order: + 1) explicit_grpc_addr + 2) SYNAPSE_GRPC_ADDR + 3) derive from base_url (8081 -> 50051) + """ + env_addr = os.getenv("SYNAPSE_GRPC_ADDR") + if explicit_grpc_addr: + return explicit_grpc_addr + if env_addr: + return env_addr + + parsed = urlparse(base_url) + if parsed.scheme in {"http", "https"}: + host = parsed.hostname or "localhost" + port = parsed.port + if port is None: + port = 443 if parsed.scheme == "https" else 80 + if port == 8081: + port = 50051 + return f"{host}:{port}" + + if base_url: + return base_url + return "localhost:50051" + + +def sidecar_payload_from_fields(fields: Dict[str, Any]) -> Dict[str, Any]: + """Normalize sidecar event fields to a payload dict.""" + payload = fields.get("payload", {}) + if isinstance(payload, str): + try: + payload = json.loads(payload) + except ValueError: + payload = {} + elif not isinstance(payload, dict): + payload = {} + + if "event" not in payload and fields.get("event_type"): + payload["event"] = fields.get("event_type") + + if "timestamp" not in payload and fields.get("timestamp"): + payload["timestamp"] = fields.get("timestamp") + + if "sender" not in payload and 
fields.get("sender"): + payload["sender"] = fields.get("sender") + + return payload diff --git a/src/deepiri_modelkit/utils/__init__.py b/src/deepiri_modelkit/utils/__init__.py index a7dfcc9..0d08069 100644 --- a/src/deepiri_modelkit/utils/__init__.py +++ b/src/deepiri_modelkit/utils/__init__.py @@ -1,7 +1,8 @@ -"""Common utilities for Deepiri ModelKit""" - -try: - from .device import get_device, get_torch_device - __all__ = ["get_device", "get_torch_device"] -except ImportError: - __all__ = [] +"""Common utilities for Deepiri ModelKit""" + +try: + from .device import get_device, get_torch_device + + __all__ = ["get_device", "get_torch_device"] +except ImportError: + __all__ = [] diff --git a/src/deepiri_modelkit/utils/device.py b/src/deepiri_modelkit/utils/device.py index 8b3922a..0226c5e 100644 --- a/src/deepiri_modelkit/utils/device.py +++ b/src/deepiri_modelkit/utils/device.py @@ -1,143 +1,158 @@ -""" -GPU Device Detection Utility -Automatically detects and uses GPU (CUDA) if available, falls back to CPU -""" -import os - -try: - import torch - HAS_TORCH = True -except ImportError: - HAS_TORCH = False - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.utils.device") - - -def get_device() -> str: - """ - Detect the best available device with proper fallback: CUDA → MPS → CPU - - Returns device string that can be used with PyTorch and SentenceTransformers. - Actually tests GPU functionality, not just availability. 
- """ - if not HAS_TORCH: - logger.info("PyTorch not installed, using CPU") - return "cpu" - - # Log diagnostic information for debugging - logger.debug(f"PyTorch version: {torch.__version__}") - logger.debug(f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}") - - if torch.cuda.is_available(): - try: - # Get CUDA information - cuda_version = torch.version.cuda - device_count = torch.cuda.device_count() - gpu_name = torch.cuda.get_device_name(0) - cuda_capability = torch.cuda.get_device_capability(0) - - logger.info( - f"CUDA detected: version={cuda_version}, devices={device_count}, " - f"GPU={gpu_name}, capability={cuda_capability[0]}.{cuda_capability[1]}" - ) - - # Check for RTX 5080/5090 (sm_120) compatibility issue - if cuda_capability[0] >= 12: - # Check if PyTorch supports this compute capability - try: - # Try to get the list of supported compute capabilities - # PyTorch 2.9.1 with CUDA 12.6 supports up to sm_90 - # RTX 5080/5090 requires sm_120 support (CUDA 12.8+) - test_tensor = torch.tensor([1.0], device='cuda') - result = test_tensor * 2.0 - _ = result.cpu() - del test_tensor, result - torch.cuda.empty_cache() - except RuntimeError as e: - if "no kernel image is available for execution on the device" in str(e) or \ - "cudaErrorNoKernelImageForDevice" in str(e): - logger.error( - f"RTX 5080/5090 (sm_{cuda_capability[0]}.{cuda_capability[1]}) detected, but PyTorch doesn't support this compute capability. " - f"Current PyTorch supports up to sm_90. 
" - f"To fix: Rebuild Docker image (CUDA 12.8 support should be automatic): " - f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" - ) - raise - else: - raise - - # Test GPU functionality with a simple operation - test_tensor = torch.tensor([1.0], device='cuda') - result = test_tensor * 2.0 - _ = result.cpu() # Ensure operation completes - del test_tensor, result - torch.cuda.empty_cache() - - logger.info( - f"CUDA GPU detected and tested successfully: {gpu_name} " - f"(CUDA {cuda_version}, Capability {cuda_capability[0]}.{cuda_capability[1]})" - ) - return "cuda" - except RuntimeError as cuda_error: - error_msg = str(cuda_error) - # Check if this is the RTX 5080 compatibility issue - if "no kernel image is available for execution on the device" in error_msg or \ - "cudaErrorNoKernelImageForDevice" in error_msg: - logger.error( - f"GPU compute capability not supported by current PyTorch installation. " - f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'Unknown'}, " - f"Capability: {torch.cuda.get_device_capability(0) if torch.cuda.is_available() else 'Unknown'}. " - f"Error: {error_msg}. " - f"Solution: Rebuild Docker image (CUDA 12.8 support should be automatic): " - f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" - ) - else: - logger.warning( - f"CUDA available but GPU test failed: {error_msg}. " - f"Falling back to CPU. This may indicate: " - f"1) GPU not accessible in Docker container (check NVIDIA Container Toolkit), " - f"2) CUDA driver mismatch, or 3) GPU memory issue." - ) - except Exception as cuda_error: - logger.warning( - f"CUDA available but test failed: {cuda_error}. 
Falling back to CPU" - ) - else: - # CUDA not available - provide diagnostic info - logger.debug("CUDA not available via torch.cuda.is_available()") - - # Check if we're in Docker and might need NVIDIA Container Toolkit - if os.path.exists("/.dockerenv"): - logger.debug("Running in Docker container - ensure NVIDIA Container Toolkit is installed") - # Check for NVIDIA runtime - if os.path.exists("/proc/driver/nvidia"): - logger.warning( - "NVIDIA driver detected in container but PyTorch CUDA not available. " - "This may indicate: 1) PyTorch not built with CUDA support, " - "2) CUDA libraries not in container, or 3) NVIDIA Container Toolkit not configured." - ) - - # Check MPS (Apple Silicon) - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - try: - test_tensor = torch.tensor([1.0], device='mps') - result = test_tensor * 2.0 - _ = result.cpu() - del test_tensor, result - logger.info("Apple Silicon (MPS) detected and tested successfully") - return "mps" - except Exception as mps_error: - logger.warning(f"MPS available but test failed, falling back to CPU: {mps_error}") - - # Fallback to CPU - logger.info("Using CPU device (no GPU detected or GPU test failed)") - return "cpu" - - -def get_torch_device() -> "torch.device": - """Get PyTorch device object""" - if not HAS_TORCH: - raise ImportError("torch is required for get_torch_device(). Install with: pip install torch") - return torch.device(get_device()) +""" +GPU Device Detection Utility +Automatically detects and uses GPU (CUDA) if available, falls back to CPU +""" + +import os + +try: + import torch + + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.utils.device") + + +def get_device() -> str: + """ + Detect the best available device with proper fallback: CUDA → MPS → CPU + + Returns device string that can be used with PyTorch and SentenceTransformers. 
+ Actually tests GPU functionality, not just availability. + """ + if not HAS_TORCH: + logger.info("PyTorch not installed, using CPU") + return "cpu" + + # Log diagnostic information for debugging + logger.debug(f"PyTorch version: {torch.__version__}") + logger.debug( + f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}" + ) + + if torch.cuda.is_available(): + try: + # Get CUDA information + cuda_version = torch.version.cuda + device_count = torch.cuda.device_count() + gpu_name = torch.cuda.get_device_name(0) + cuda_capability = torch.cuda.get_device_capability(0) + + logger.info( + f"CUDA detected: version={cuda_version}, devices={device_count}, " + f"GPU={gpu_name}, capability={cuda_capability[0]}.{cuda_capability[1]}" + ) + + # Check for RTX 5080/5090 (sm_120) compatibility issue + if cuda_capability[0] >= 12: + # Check if PyTorch supports this compute capability + try: + # Try to get the list of supported compute capabilities + # PyTorch 2.9.1 with CUDA 12.6 supports up to sm_90 + # RTX 5080/5090 requires sm_120 support (CUDA 12.8+) + test_tensor = torch.tensor([1.0], device="cuda") + result = test_tensor * 2.0 + _ = result.cpu() + del test_tensor, result + torch.cuda.empty_cache() + except RuntimeError as e: + if ( + "no kernel image is available for execution on the device" + in str(e) + or "cudaErrorNoKernelImageForDevice" in str(e) + ): + logger.error( + f"RTX 5080/5090 (sm_{cuda_capability[0]}.{cuda_capability[1]}) detected, but PyTorch doesn't support this compute capability. " + f"Current PyTorch supports up to sm_90. 
" + f"To fix: Rebuild Docker image (CUDA 12.8 support should be automatic): " + f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" + ) + raise + else: + raise + + # Test GPU functionality with a simple operation + test_tensor = torch.tensor([1.0], device="cuda") + result = test_tensor * 2.0 + _ = result.cpu() # Ensure operation completes + del test_tensor, result + torch.cuda.empty_cache() + + logger.info( + f"CUDA GPU detected and tested successfully: {gpu_name} " + f"(CUDA {cuda_version}, Capability {cuda_capability[0]}.{cuda_capability[1]})" + ) + return "cuda" + except RuntimeError as cuda_error: + error_msg = str(cuda_error) + # Check if this is the RTX 5080 compatibility issue + if ( + "no kernel image is available for execution on the device" in error_msg + or "cudaErrorNoKernelImageForDevice" in error_msg + ): + logger.error( + f"GPU compute capability not supported by current PyTorch installation. " + f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'Unknown'}, " + f"Capability: {torch.cuda.get_device_capability(0) if torch.cuda.is_available() else 'Unknown'}. " + f"Error: {error_msg}. " + f"Solution: Rebuild Docker image (CUDA 12.8 support should be automatic): " + f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" + ) + else: + logger.warning( + f"CUDA available but GPU test failed: {error_msg}. " + f"Falling back to CPU. This may indicate: " + f"1) GPU not accessible in Docker container (check NVIDIA Container Toolkit), " + f"2) CUDA driver mismatch, or 3) GPU memory issue." + ) + except Exception as cuda_error: + logger.warning( + f"CUDA available but test failed: {cuda_error}. 
Falling back to CPU" + ) + else: + # CUDA not available - provide diagnostic info + logger.debug("CUDA not available via torch.cuda.is_available()") + + # Check if we're in Docker and might need NVIDIA Container Toolkit + if os.path.exists("/.dockerenv"): + logger.debug( + "Running in Docker container - ensure NVIDIA Container Toolkit is installed" + ) + # Check for NVIDIA runtime + if os.path.exists("/proc/driver/nvidia"): + logger.warning( + "NVIDIA driver detected in container but PyTorch CUDA not available. " + "This may indicate: 1) PyTorch not built with CUDA support, " + "2) CUDA libraries not in container, or 3) NVIDIA Container Toolkit not configured." + ) + + # Check MPS (Apple Silicon) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + try: + test_tensor = torch.tensor([1.0], device="mps") + result = test_tensor * 2.0 + _ = result.cpu() + del test_tensor, result + logger.info("Apple Silicon (MPS) detected and tested successfully") + return "mps" + except Exception as mps_error: + logger.warning( + f"MPS available but test failed, falling back to CPU: {mps_error}" + ) + + # Fallback to CPU + logger.info("Using CPU device (no GPU detected or GPU test failed)") + return "cpu" + + +def get_torch_device() -> "torch.device": + """Get PyTorch device object""" + if not HAS_TORCH: + raise ImportError( + "torch is required for get_torch_device(). 
Install with: pip install torch" + ) + return torch.device(get_device()) From bdc9100530d56b813be8a1dc388a20efc8ee45c2 Mon Sep 17 00:00:00 2001 From: Bao Tran Date: Thu, 7 May 2026 16:07:40 -0400 Subject: [PATCH 4/6] test(streaming): cover shared topic constants --- tests/test_streaming_topics.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 tests/test_streaming_topics.py diff --git a/tests/test_streaming_topics.py b/tests/test_streaming_topics.py new file mode 100644 index 0000000..1542700 --- /dev/null +++ b/tests/test_streaming_topics.py @@ -0,0 +1,27 @@ +from deepiri_modelkit.streaming.topics import StreamTopics + + +def test_document_stream_topics_match_lis_routing_namespace() -> None: + assert StreamTopics.DOCUMENT_VECTORIZE.value == "document.vectorize" + assert StreamTopics.DOCUMENT_TRAINING.value == "document.training" + assert StreamTopics.DOCUMENT_STRUCTURED.value == "document.structured" + assert StreamTopics.DOCUMENT_ARTIFACTS.value == "document.artifacts" + + +def test_helox_training_topics_stay_in_pipeline_namespace() -> None: + assert StreamTopics.HELOX_TRAINING_RAW.value == "pipeline.helox-training.raw" + assert ( + StreamTopics.HELOX_TRAINING_STRUCTURED.value + == "pipeline.helox-training.structured" + ) + + +def test_all_includes_shared_stream_topics() -> None: + topics = set(StreamTopics.all()) + + assert "document.vectorize" in topics + assert "document.training" in topics + assert "document.structured" in topics + assert "document.artifacts" in topics + assert "pipeline.helox-training.raw" in topics + assert "pipeline.helox-training.structured" in topics From 8bb4d183b61eab618e0a5f0e6bec80140a6a6edc Mon Sep 17 00:00:00 2001 From: Bao Tran Date: Sat, 9 May 2026 19:26:40 -0400 Subject: [PATCH 5/6] chore(streaming): narrow topic constants PR diff --- src/deepiri_modelkit/__init__.py | 77 +- src/deepiri_modelkit/contracts/contract.py | 57 +- src/deepiri_modelkit/contracts/events.py | 206 ++-- 
src/deepiri_modelkit/contracts/models.py | 292 +++--- src/deepiri_modelkit/contracts/services.py | 102 +- src/deepiri_modelkit/data/monitoring.py | 794 +++++++-------- src/deepiri_modelkit/data/validation.py | 764 +++++++------- src/deepiri_modelkit/logging.py | 318 +++--- src/deepiri_modelkit/ml/__init__.py | 68 +- src/deepiri_modelkit/ml/confidence.py | 608 ++++++------ src/deepiri_modelkit/ml/semantic.py | 704 +++++++------ src/deepiri_modelkit/rag/__init__.py | 321 +++--- .../rag/advanced_retrieval.py | 816 ++++++++------- src/deepiri_modelkit/rag/async_processing.py | 169 ++-- src/deepiri_modelkit/rag/base.py | 617 ++++++------ src/deepiri_modelkit/rag/caching.py | 932 +++++++++--------- src/deepiri_modelkit/rag/monitoring.py | 729 +++++++------- src/deepiri_modelkit/rag/processors.py | 866 ++++++++-------- src/deepiri_modelkit/rag/retrievers.py | 575 +++++------ src/deepiri_modelkit/rag/testing.py | 666 +++++++------ .../registry/adapters/__init__.py | 3 +- .../registry/model_registry.py | 669 ++++++------- .../streaming/event_stream.py | 398 ++++---- src/deepiri_modelkit/streaming/schemas.py | 111 +-- .../streaming/sidecar_utils.py | 160 ++- src/deepiri_modelkit/streaming/topics.py | 27 +- src/deepiri_modelkit/utils/__init__.py | 15 +- src/deepiri_modelkit/utils/device.py | 301 +++--- 28 files changed, 5562 insertions(+), 5803 deletions(-) diff --git a/src/deepiri_modelkit/__init__.py b/src/deepiri_modelkit/__init__.py index 5601e33..4d3f09b 100644 --- a/src/deepiri_modelkit/__init__.py +++ b/src/deepiri_modelkit/__init__.py @@ -1,41 +1,36 @@ -""" -Deepiri ModelKit - Shared contracts, interfaces, and utilities -""" - -__version__ = "0.1.0" - -from .contracts.models import ( - AIModel, - AIModelPydantic, - ModelInput, - ModelOutput, - ModelMetadata, -) -from .contracts.events import ( - ModelReadyEvent, - InferenceEvent, - PlatformEvent, - AGIDecisionEvent, - TrainingEvent, -) -from .streaming.event_stream import StreamingClient -from .registry.model_registry 
import ModelRegistryClient -from .logging import get_logger, get_error_logger, ErrorLogger - -__all__ = [ - "AIModel", # Protocol interface for type checking - "AIModelPydantic", # Pydantic-compatible type for use in BaseModel fields - "ModelInput", - "ModelOutput", - "ModelMetadata", - "ModelReadyEvent", - "InferenceEvent", - "PlatformEvent", - "AGIDecisionEvent", - "TrainingEvent", - "StreamingClient", - "ModelRegistryClient", - "get_logger", - "get_error_logger", - "ErrorLogger", -] +""" +Deepiri ModelKit - Shared contracts, interfaces, and utilities +""" + +__version__ = "0.1.0" + +from .contracts.models import AIModel, AIModelPydantic, ModelInput, ModelOutput, ModelMetadata +from .contracts.events import ( + ModelReadyEvent, + InferenceEvent, + PlatformEvent, + AGIDecisionEvent, + TrainingEvent, +) +from .streaming.event_stream import StreamingClient +from .registry.model_registry import ModelRegistryClient +from .logging import get_logger, get_error_logger, ErrorLogger + +__all__ = [ + "AIModel", # Protocol interface for type checking + "AIModelPydantic", # Pydantic-compatible type for use in BaseModel fields + "ModelInput", + "ModelOutput", + "ModelMetadata", + "ModelReadyEvent", + "InferenceEvent", + "PlatformEvent", + "AGIDecisionEvent", + "TrainingEvent", + "StreamingClient", + "ModelRegistryClient", + "get_logger", + "get_error_logger", + "ErrorLogger", +] + diff --git a/src/deepiri_modelkit/contracts/contract.py b/src/deepiri_modelkit/contracts/contract.py index 237ac62..00abc21 100644 --- a/src/deepiri_modelkit/contracts/contract.py +++ b/src/deepiri_modelkit/contracts/contract.py @@ -1,30 +1,27 @@ -""" -Model contract for registry (separated from models.py to avoid Pydantic Protocol conflicts) -""" - -from __future__ import annotations - -from typing import Dict, Any, Optional -from pydantic import BaseModel - -from .models import ModelMetadata - - -class ModelContract(BaseModel): - """ - Complete model contract for registry. 
- - A contract is serializable metadata that describes a model's interface, - input/output schemas, and validation requirements. It does NOT contain - the actual model instance (which would be a Protocol type that Pydantic - cannot serialize). The model instance should be loaded separately when needed. - """ - - metadata: ModelMetadata - input_schema: Dict[str, Any] - output_schema: Dict[str, Any] - validation_tests: Optional[list] = None - model_path: Optional[str] = ( - None # Path/reference to where the model can be loaded from - ) - model_id: Optional[str] = None # Unique identifier for the model instance +""" +Model contract for registry (separated from models.py to avoid Pydantic Protocol conflicts) +""" +from __future__ import annotations + +from typing import Dict, Any, Optional +from pydantic import BaseModel + +from .models import ModelMetadata + + +class ModelContract(BaseModel): + """ + Complete model contract for registry. + + A contract is serializable metadata that describes a model's interface, + input/output schemas, and validation requirements. It does NOT contain + the actual model instance (which would be a Protocol type that Pydantic + cannot serialize). The model instance should be loaded separately when needed. 
+ """ + metadata: ModelMetadata + input_schema: Dict[str, Any] + output_schema: Dict[str, Any] + validation_tests: Optional[list] = None + model_path: Optional[str] = None # Path/reference to where the model can be loaded from + model_id: Optional[str] = None # Unique identifier for the model instance + diff --git a/src/deepiri_modelkit/contracts/events.py b/src/deepiri_modelkit/contracts/events.py index e1325b4..9711a93 100644 --- a/src/deepiri_modelkit/contracts/events.py +++ b/src/deepiri_modelkit/contracts/events.py @@ -1,107 +1,99 @@ -""" -Event schemas for streaming service -""" - -from pydantic import BaseModel, Field -from typing import Dict, Any, Optional -from datetime import datetime -from enum import Enum - - -class EventType(str, Enum): - """Event type enumeration""" - - MODEL_READY = "model-ready" - MODEL_LOADED = "model-loaded" - MODEL_FAILED = "model-failed" - INFERENCE_COMPLETE = "inference-complete" - INFERENCE_FAILED = "inference-failed" - USER_INTERACTION = "user-interaction" - TASK_CREATED = "task-created" - TASK_COMPLETED = "task-completed" - AGI_DECISION = "agi-decision" - AGI_ACTION = "agi-action" - TRAINING_STARTED = "training-started" - TRAINING_COMPLETE = "training-complete" - TRAINING_FAILED = "training-failed" - - -class BaseEvent(BaseModel): - """Base event schema""" - - event: str - timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - source: str - correlation_id: Optional[str] = None - - -class ModelReadyEvent(BaseEvent): - """Event published when model is trained and ready""" - - event: str = EventType.MODEL_READY - model_name: str - version: str - registry_path: str # S3/MLflow path - metadata: Dict[str, Any] - model_type: Optional[str] = None - accuracy: Optional[float] = None - size_mb: Optional[float] = None - - -class ModelLoadedEvent(BaseEvent): - """Event published when model is loaded in runtime""" - - event: str = EventType.MODEL_LOADED - model_name: str - version: str - load_time_ms: float - 
cache_location: Optional[str] = None - - -class InferenceEvent(BaseEvent): - """Event published after inference completes""" - - event: str = EventType.INFERENCE_COMPLETE - model_name: str - version: str - user_id: Optional[str] = None - request_id: Optional[str] = None - latency_ms: float - tokens_used: Optional[int] = None - cost: Optional[float] = None - confidence: Optional[float] = None - success: bool = True - - -class PlatformEvent(BaseEvent): - """Event published by platform services""" - - event: str # user-interaction, task-created, etc. - service: str - user_id: Optional[str] = None - action: str - data: Dict[str, Any] - organization_id: Optional[str] = None - - -class AGIDecisionEvent(BaseEvent): - """Event published by Cyrex-AGI for autonomous decisions""" - - event: str = EventType.AGI_DECISION - decision_type: str - target_service: Optional[str] = None - action: Dict[str, Any] - reasoning: Optional[str] = None - confidence: Optional[float] = None - - -class TrainingEvent(BaseEvent): - """Event published during training""" - - event: str # training-started, training-complete, training-failed - experiment_id: str - model_name: str - status: str - progress: Optional[float] = None # 0.0 to 1.0 - metrics: Optional[Dict[str, Any]] = None - error: Optional[str] = None +""" +Event schemas for streaming service +""" +from pydantic import BaseModel, Field +from typing import Dict, Any, Optional +from datetime import datetime +from enum import Enum + + +class EventType(str, Enum): + """Event type enumeration""" + MODEL_READY = "model-ready" + MODEL_LOADED = "model-loaded" + MODEL_FAILED = "model-failed" + INFERENCE_COMPLETE = "inference-complete" + INFERENCE_FAILED = "inference-failed" + USER_INTERACTION = "user-interaction" + TASK_CREATED = "task-created" + TASK_COMPLETED = "task-completed" + AGI_DECISION = "agi-decision" + AGI_ACTION = "agi-action" + TRAINING_STARTED = "training-started" + TRAINING_COMPLETE = "training-complete" + TRAINING_FAILED = 
"training-failed" + + +class BaseEvent(BaseModel): + """Base event schema""" + event: str + timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + source: str + correlation_id: Optional[str] = None + + +class ModelReadyEvent(BaseEvent): + """Event published when model is trained and ready""" + event: str = EventType.MODEL_READY + model_name: str + version: str + registry_path: str # S3/MLflow path + metadata: Dict[str, Any] + model_type: Optional[str] = None + accuracy: Optional[float] = None + size_mb: Optional[float] = None + + +class ModelLoadedEvent(BaseEvent): + """Event published when model is loaded in runtime""" + event: str = EventType.MODEL_LOADED + model_name: str + version: str + load_time_ms: float + cache_location: Optional[str] = None + + +class InferenceEvent(BaseEvent): + """Event published after inference completes""" + event: str = EventType.INFERENCE_COMPLETE + model_name: str + version: str + user_id: Optional[str] = None + request_id: Optional[str] = None + latency_ms: float + tokens_used: Optional[int] = None + cost: Optional[float] = None + confidence: Optional[float] = None + success: bool = True + + +class PlatformEvent(BaseEvent): + """Event published by platform services""" + event: str # user-interaction, task-created, etc. 
+ service: str + user_id: Optional[str] = None + action: str + data: Dict[str, Any] + organization_id: Optional[str] = None + + +class AGIDecisionEvent(BaseEvent): + """Event published by Cyrex-AGI for autonomous decisions""" + event: str = EventType.AGI_DECISION + decision_type: str + target_service: Optional[str] = None + action: Dict[str, Any] + reasoning: Optional[str] = None + confidence: Optional[float] = None + + +class TrainingEvent(BaseEvent): + """Event published during training""" + event: str # training-started, training-complete, training-failed + experiment_id: str + model_name: str + status: str + progress: Optional[float] = None # 0.0 to 1.0 + metrics: Optional[Dict[str, Any]] = None + error: Optional[str] = None + diff --git a/src/deepiri_modelkit/contracts/models.py b/src/deepiri_modelkit/contracts/models.py index 79b2e43..6e52336 100644 --- a/src/deepiri_modelkit/contracts/models.py +++ b/src/deepiri_modelkit/contracts/models.py @@ -1,149 +1,143 @@ -""" -Model contracts and interfaces -""" - -from __future__ import ( - annotations, -) # Defer annotation evaluation to prevent Pydantic from processing Protocol types - -from typing import Protocol, Dict, Any, Optional, Annotated -from pydantic import BaseModel, Field, GetCoreSchemaHandler -from pydantic_core import core_schema -from datetime import datetime - - -class ModelInput(BaseModel): - """Standard model input schema""" - - data: Dict[str, Any] - metadata: Optional[Dict[str, Any]] = None - timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - - -class ModelOutput(BaseModel): - """Standard model output schema""" - - prediction: Any - confidence: Optional[float] = None - metadata: Optional[Dict[str, Any]] = None - timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - - -class ModelMetadata(BaseModel): - """Model metadata schema""" - - name: str - version: str - description: Optional[str] = None - architecture: Optional[str] = None - accuracy: 
Optional[float] = None - size_mb: Optional[float] = None - created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) - trained_by: Optional[str] = None - tags: Optional[Dict[str, Any]] = None - - -class AIModel(Protocol): - """ - Interface that all models must implement - Used by both Helox (training) and Cyrex (runtime) - - Note: This is a Protocol (structural type). To use in Pydantic models, - use AIModelPydantic instead, which has full Pydantic schema support. - """ - - def predict(self, input: ModelInput) -> ModelOutput: - """Run inference on input""" - ... - - def get_metadata(self) -> ModelMetadata: - """Get model metadata""" - ... - - def validate(self) -> bool: - """Validate model is ready for use""" - ... - - def export(self, format: str = "onnx") -> str: - """Export model to specified format, returns path""" - ... - - -class AIModelPydantic: - """ - Pydantic-compatible wrapper for AIModel Protocol. - Implements __get_pydantic_core_schema__ to provide full schema support. - - Usage in Pydantic models: - model: Optional[AIModelPydantic] = None - - Note: In practice, model instances should be loaded separately and referenced - by ID/path rather than stored directly in serializable Pydantic models. - """ - - @classmethod - def __get_pydantic_core_schema__( - cls, - source_type: Any, - handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - """ - Pydantic Core Schema handler for AIModel Protocol. - - This allows Pydantic to process AIModel types in model fields. - Since Protocols are structural types, we validate that the object - has the required methods rather than checking exact type. 
- """ - - def validate_aimodel(value: Any) -> Any: - """Validate that value implements AIModel Protocol interface""" - if value is None: - return None - - # Check for required Protocol methods - required_methods = ["predict", "get_metadata", "validate", "export"] - missing_methods = [m for m in required_methods if not hasattr(value, m)] - - if missing_methods: - raise ValueError( - f"Object does not implement AIModel Protocol. " - f"Missing methods: {', '.join(missing_methods)}" - ) - - return value - - def serialize_aimodel(value: Any) -> Dict[str, Any]: - """Serialize AIModel instance to dict""" - if value is None: - return None - - # Try to get metadata if available - metadata = None - if hasattr(value, "get_metadata"): - try: - metadata = value.get_metadata() - # Convert ModelMetadata to dict if it's a Pydantic model - if hasattr(metadata, "model_dump"): - metadata = metadata.model_dump() - elif hasattr(metadata, "dict"): - metadata = metadata.dict() - except Exception: - pass - - return { - "type": "AIModel", - "metadata": metadata, - "has_predict": hasattr(value, "predict"), - "has_validate": hasattr(value, "validate"), - } - - return core_schema.no_info_plain_validator_function( - validate_aimodel, - serialization=core_schema.plain_serializer_function_ser_schema( - serialize_aimodel - ), - ) - - -# ModelContract moved to contract.py to avoid Pydantic Protocol conflicts -# Import it from .contract import ModelContract when needed +""" +Model contracts and interfaces +""" +from __future__ import annotations # Defer annotation evaluation to prevent Pydantic from processing Protocol types + +from typing import Protocol, Dict, Any, Optional, Annotated +from pydantic import BaseModel, Field, GetCoreSchemaHandler +from pydantic_core import core_schema +from datetime import datetime + + +class ModelInput(BaseModel): + """Standard model input schema""" + data: Dict[str, Any] + metadata: Optional[Dict[str, Any]] = None + timestamp: str = Field(default_factory=lambda: 
datetime.utcnow().isoformat()) + + +class ModelOutput(BaseModel): + """Standard model output schema""" + prediction: Any + confidence: Optional[float] = None + metadata: Optional[Dict[str, Any]] = None + timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + + +class ModelMetadata(BaseModel): + """Model metadata schema""" + name: str + version: str + description: Optional[str] = None + architecture: Optional[str] = None + accuracy: Optional[float] = None + size_mb: Optional[float] = None + created_at: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) + trained_by: Optional[str] = None + tags: Optional[Dict[str, Any]] = None + + +class AIModel(Protocol): + """ + Interface that all models must implement + Used by both Helox (training) and Cyrex (runtime) + + Note: This is a Protocol (structural type). To use in Pydantic models, + use AIModelPydantic instead, which has full Pydantic schema support. + """ + + def predict(self, input: ModelInput) -> ModelOutput: + """Run inference on input""" + ... + + def get_metadata(self) -> ModelMetadata: + """Get model metadata""" + ... + + def validate(self) -> bool: + """Validate model is ready for use""" + ... + + def export(self, format: str = "onnx") -> str: + """Export model to specified format, returns path""" + ... + + +class AIModelPydantic: + """ + Pydantic-compatible wrapper for AIModel Protocol. + Implements __get_pydantic_core_schema__ to provide full schema support. + + Usage in Pydantic models: + model: Optional[AIModelPydantic] = None + + Note: In practice, model instances should be loaded separately and referenced + by ID/path rather than stored directly in serializable Pydantic models. + """ + + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + """ + Pydantic Core Schema handler for AIModel Protocol. + + This allows Pydantic to process AIModel types in model fields. 
+ Since Protocols are structural types, we validate that the object + has the required methods rather than checking exact type. + """ + def validate_aimodel(value: Any) -> Any: + """Validate that value implements AIModel Protocol interface""" + if value is None: + return None + + # Check for required Protocol methods + required_methods = ['predict', 'get_metadata', 'validate', 'export'] + missing_methods = [m for m in required_methods if not hasattr(value, m)] + + if missing_methods: + raise ValueError( + f"Object does not implement AIModel Protocol. " + f"Missing methods: {', '.join(missing_methods)}" + ) + + return value + + def serialize_aimodel(value: Any) -> Dict[str, Any]: + """Serialize AIModel instance to dict""" + if value is None: + return None + + # Try to get metadata if available + metadata = None + if hasattr(value, 'get_metadata'): + try: + metadata = value.get_metadata() + # Convert ModelMetadata to dict if it's a Pydantic model + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, 'dict'): + metadata = metadata.dict() + except Exception: + pass + + return { + "type": "AIModel", + "metadata": metadata, + "has_predict": hasattr(value, 'predict'), + "has_validate": hasattr(value, 'validate'), + } + + return core_schema.no_info_plain_validator_function( + validate_aimodel, + serialization=core_schema.plain_serializer_function_ser_schema( + serialize_aimodel + ) + ) + + +# ModelContract moved to contract.py to avoid Pydantic Protocol conflicts +# Import it from .contract import ModelContract when needed + diff --git a/src/deepiri_modelkit/contracts/services.py b/src/deepiri_modelkit/contracts/services.py index e6b01e5..9011102 100644 --- a/src/deepiri_modelkit/contracts/services.py +++ b/src/deepiri_modelkit/contracts/services.py @@ -1,44 +1,58 @@ -""" -Service contracts and interfaces -""" - -from typing import Protocol, Dict, Any, Optional -from pydantic import BaseModel - - -class 
ModelRegistryService(Protocol): - """Interface for model registry operations""" - - def register_model( - self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] - ) -> bool: - """Register a model in the registry""" - ... - - def get_model( - self, model_name: str, version: Optional[str] = None - ) -> Dict[str, Any]: - """Get model information from registry""" - ... - - def list_models(self, model_name: Optional[str] = None) -> list: - """List available models""" - ... - - def download_model(self, model_name: str, version: str, destination: str) -> str: - """Download model to destination, returns local path""" - ... - - -class StreamingService(Protocol): - """Interface for streaming operations""" - - def publish(self, topic: str, event: Dict[str, Any]) -> bool: - """Publish event to topic""" - ... - - def subscribe( - self, topic: str, callback: callable, consumer_group: Optional[str] = None - ) -> None: - """Subscribe to topic with callback""" - ... +""" +Service contracts and interfaces +""" +from typing import Protocol, Dict, Any, Optional +from pydantic import BaseModel + + +class ModelRegistryService(Protocol): + """Interface for model registry operations""" + + def register_model( + self, + model_name: str, + version: str, + model_path: str, + metadata: Dict[str, Any] + ) -> bool: + """Register a model in the registry""" + ... + + def get_model( + self, + model_name: str, + version: Optional[str] = None + ) -> Dict[str, Any]: + """Get model information from registry""" + ... + + def list_models(self, model_name: Optional[str] = None) -> list: + """List available models""" + ... + + def download_model( + self, + model_name: str, + version: str, + destination: str + ) -> str: + """Download model to destination, returns local path""" + ... + + +class StreamingService(Protocol): + """Interface for streaming operations""" + + def publish(self, topic: str, event: Dict[str, Any]) -> bool: + """Publish event to topic""" + ... 
+ + def subscribe( + self, + topic: str, + callback: callable, + consumer_group: Optional[str] = None + ) -> None: + """Subscribe to topic with callback""" + ... + diff --git a/src/deepiri_modelkit/data/monitoring.py b/src/deepiri_modelkit/data/monitoring.py index f8e1ffc..3fe2d65 100644 --- a/src/deepiri_modelkit/data/monitoring.py +++ b/src/deepiri_modelkit/data/monitoring.py @@ -1,419 +1,375 @@ -""" -Dataset Monitoring and Logging Utilities -Provides monitoring, alerting, and logging for dataset versioning operations -""" - -import json -import time -from pathlib import Path -from typing import Dict, List, Any, Optional -from datetime import datetime, timedelta -import statistics - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.data.monitoring") - - -class DatasetMonitor: - """ - Monitors dataset versioning operations and provides insights. - - Features: - - Operation metrics and performance tracking - - Dataset health monitoring - - Usage analytics - - Alerting for data quality issues - """ - - def __init__(self, log_dir: str = "./logs/dataset_monitoring"): - self.log_dir = Path(log_dir) - self.log_dir.mkdir(parents=True, exist_ok=True) - - # Metrics storage - self.metrics_file = self.log_dir / "metrics.jsonl" - self.alerts_file = self.log_dir / "alerts.jsonl" - - # In-memory metrics for quick access - self.current_metrics = { - "total_versions_created": 0, - "total_datasets_tracked": 0, - "average_version_creation_time": 0, - "validation_errors_today": 0, - "last_health_check": None, - "storage_usage_bytes": 0, - } - - self._load_metrics() - - def log_version_creation(self, operation_data: Dict[str, Any]): - """Log dataset version creation operation.""" - log_entry = { - "timestamp": datetime.utcnow().isoformat(), - "operation": "version_creation", - "dataset_name": operation_data.get("dataset_name"), - "version": operation_data.get("version"), - "dataset_type": operation_data.get("dataset_type"), - "total_samples": 
operation_data.get("total_samples"), - "file_count": operation_data.get("file_count"), - "creation_time_seconds": operation_data.get("creation_time", 0), - "change_type": operation_data.get("change_type"), - "quality_score": operation_data.get("quality_score"), - "storage_path": operation_data.get("storage_path"), - "created_by": operation_data.get("created_by"), - } - - self._write_log_entry(self.metrics_file, log_entry) - self.current_metrics["total_versions_created"] += 1 - - logger.info( - "Version creation logged", - dataset=operation_data.get("dataset_name"), - version=operation_data.get("version"), - ) - - def log_validation_result(self, validation_data: Dict[str, Any]): - """Log dataset validation results.""" - log_entry = { - "timestamp": datetime.utcnow().isoformat(), - "operation": "validation", - "dataset_name": validation_data.get("dataset_name"), - "version": validation_data.get("version"), - "is_valid": validation_data.get("is_valid"), - "quality_score": validation_data.get("quality_score"), - "error_count": len(validation_data.get("errors", [])), - "warning_count": len(validation_data.get("warnings", [])), - "validation_time_seconds": validation_data.get("validation_time", 0), - } - - self._write_log_entry(self.metrics_file, log_entry) - - if not validation_data.get("is_valid", True): - self.current_metrics["validation_errors_today"] += 1 - - # Check for alerts - if validation_data.get("quality_score", 1.0) < 0.7: - self._create_alert( - "low_quality_score", - { - "dataset_name": validation_data.get("dataset_name"), - "version": validation_data.get("version"), - "quality_score": validation_data.get("quality_score"), - "errors": validation_data.get("errors", []), - }, - ) - - logger.info( - "Validation result logged", - dataset=validation_data.get("dataset_name"), - valid=validation_data.get("is_valid"), - quality=validation_data.get("quality_score"), - ) - - def log_training_usage(self, training_data: Dict[str, Any]): - """Log dataset usage in 
training.""" - log_entry = { - "timestamp": datetime.utcnow().isoformat(), - "operation": "training_usage", - "dataset_name": training_data.get("dataset_name"), - "dataset_version": training_data.get("dataset_version"), - "model_name": training_data.get("model_name"), - "training_duration_seconds": training_data.get("training_duration", 0), - "final_loss": training_data.get("final_loss"), - "experiment_id": training_data.get("experiment_id"), - "output_model_path": training_data.get("output_model_path"), - } - - self._write_log_entry(self.metrics_file, log_entry) - - logger.info( - "Training usage logged", - dataset=training_data.get("dataset_name"), - version=training_data.get("dataset_version"), - model=training_data.get("model_name"), - ) - - def get_health_report(self) -> Dict[str, Any]: - """Generate comprehensive health report.""" - report = { - "timestamp": datetime.utcnow().isoformat(), - "summary": { - "total_versions": self.current_metrics["total_versions_created"], - "datasets_tracked": self.current_metrics["total_datasets_tracked"], - "validation_errors_today": self.current_metrics[ - "validation_errors_today" - ], - "storage_usage_gb": self.current_metrics["storage_usage_bytes"] - / (1024**3), - }, - "performance": self._analyze_performance(), - "quality_trends": self._analyze_quality_trends(), - "alerts": self._get_recent_alerts(), - "recommendations": self._generate_recommendations(), - } - - self.current_metrics["last_health_check"] = report["timestamp"] - return report - - def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: - """Get usage analytics for the specified period.""" - cutoff_date = datetime.utcnow() - timedelta(days=days) - - analytics = { - "period_days": days, - "version_creations": [], - "training_runs": [], - "validation_runs": [], - "popular_datasets": {}, - "quality_distribution": {}, - } - - # Read logs and filter by date - if self.metrics_file.exists(): - with open(self.metrics_file, "r") as f: - for line in f: - 
try: - entry = json.loads(line.strip()) - entry_date = datetime.fromisoformat(entry["timestamp"]) - - if entry_date >= cutoff_date: - if entry["operation"] == "version_creation": - analytics["version_creations"].append(entry) - dataset = entry.get("dataset_name", "unknown") - analytics["popular_datasets"][dataset] = ( - analytics["popular_datasets"].get(dataset, 0) + 1 - ) - - elif entry["operation"] == "training_usage": - analytics["training_runs"].append(entry) - - elif entry["operation"] == "validation": - analytics["validation_runs"].append(entry) - quality = entry.get("quality_score", 0) - quality_bucket = f"{int(quality * 10) / 10:.1f}" - analytics["quality_distribution"][quality_bucket] = ( - analytics["quality_distribution"].get( - quality_bucket, 0 - ) - + 1 - ) - - except json.JSONDecodeError: - continue - - return analytics - - def _analyze_performance(self) -> Dict[str, Any]: - """Analyze system performance metrics.""" - creation_times = [] - validation_times = [] - - if self.metrics_file.exists(): - with open(self.metrics_file, "r") as f: - for line in f: - try: - entry = json.loads(line.strip()) - if entry["operation"] == "version_creation": - if "creation_time_seconds" in entry: - creation_times.append(entry["creation_time_seconds"]) - elif entry["operation"] == "validation": - if "validation_time_seconds" in entry: - validation_times.append( - entry["validation_time_seconds"] - ) - except json.JSONDecodeError: - continue - - return { - "avg_version_creation_time": ( - statistics.mean(creation_times) if creation_times else 0 - ), - "avg_validation_time": ( - statistics.mean(validation_times) if validation_times else 0 - ), - "total_operations": len(creation_times) + len(validation_times), - "creation_times": creation_times[-10:], # Last 10 - "validation_times": validation_times[-10:], # Last 10 - } - - def _analyze_quality_trends(self) -> Dict[str, Any]: - """Analyze quality trends over time.""" - quality_scores = [] - - if 
self.metrics_file.exists(): - with open(self.metrics_file, "r") as f: - for line in f: - try: - entry = json.loads(line.strip()) - if ( - entry["operation"] == "validation" - and "quality_score" in entry - ): - quality_scores.append(entry["quality_score"]) - except json.JSONDecodeError: - continue - - if not quality_scores: - return {"trend": "insufficient_data"} - - recent_scores = quality_scores[-20:] # Last 20 validations - avg_quality = statistics.mean(recent_scores) if recent_scores else 0 - - # Simple trend analysis - if len(recent_scores) >= 10: - first_half = recent_scores[: len(recent_scores) // 2] - second_half = recent_scores[len(recent_scores) // 2 :] - - first_avg = statistics.mean(first_half) - second_avg = statistics.mean(second_half) - - if second_avg > first_avg + 0.05: - trend = "improving" - elif second_avg < first_avg - 0.05: - trend = "declining" - else: - trend = "stable" - else: - trend = "insufficient_data" - - return { - "average_quality": avg_quality, - "trend": trend, - "total_validations": len(quality_scores), - "quality_distribution": { - "excellent": len([s for s in quality_scores if s >= 0.9]), - "good": len([s for s in quality_scores if 0.7 <= s < 0.9]), - "fair": len([s for s in quality_scores if 0.5 <= s < 0.7]), - "poor": len([s for s in quality_scores if s < 0.5]), - }, - } - - def _generate_recommendations(self) -> List[str]: - """Generate recommendations based on current state.""" - recommendations = [] - - # Check for frequent validation errors - if self.current_metrics["validation_errors_today"] > 5: - recommendations.append( - "High validation error rate detected. Review data quality processes." - ) - - # Check quality trends - quality_analysis = self._analyze_quality_trends() - if quality_analysis.get("trend") == "declining": - recommendations.append( - "Dataset quality is declining. Consider reviewing data sources and annotation processes." 
- ) - - # Check performance - performance = self._analyze_performance() - if performance.get("avg_version_creation_time", 0) > 300: # 5 minutes - recommendations.append( - "Version creation is slow. Consider optimizing storage or processing." - ) - - # General recommendations - if self.current_metrics["total_versions_created"] == 0: - recommendations.append( - "No dataset versions created yet. Start versioning your datasets for reproducibility." - ) - - if not recommendations: - recommendations.append( - "System operating normally. Continue regular monitoring." - ) - - return recommendations - - def _create_alert(self, alert_type: str, alert_data: Dict[str, Any]): - """Create an alert for monitoring.""" - alert_entry = { - "timestamp": datetime.utcnow().isoformat(), - "alert_type": alert_type, - "severity": "warning", # Could be "info", "warning", "error" - "data": alert_data, - "resolved": False, - } - - self._write_log_entry(self.alerts_file, alert_entry) - - logger.warning("Alert created", type=alert_type, data=alert_data) - - def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: - """Get recent alerts.""" - alerts = [] - cutoff_time = datetime.utcnow() - timedelta(hours=hours) - - if self.alerts_file.exists(): - with open(self.alerts_file, "r") as f: - for line in f: - try: - alert = json.loads(line.strip()) - alert_time = datetime.fromisoformat(alert["timestamp"]) - if alert_time >= cutoff_time: - alerts.append(alert) - except json.JSONDecodeError: - continue - - return alerts[-10:] # Return last 10 alerts - - def _write_log_entry(self, log_file: Path, entry: Dict[str, Any]): - """Write a log entry to file.""" - with open(log_file, "a") as f: - f.write(json.dumps(entry) + "\n") - - def _load_metrics(self): - """Load current metrics from log files.""" - if self.metrics_file.exists(): - try: - with open(self.metrics_file, "r") as f: - lines = f.readlines() - if lines: - # Count operations from logs - version_count = sum( - 1 - for line in lines 
- if '"operation": "version_creation"' in line - ) - validation_count = sum( - 1 - for line in lines - if '"operation": "validation"' in line - and '"is_valid": false' in line - ) - - self.current_metrics["total_versions_created"] = version_count - self.current_metrics["validation_errors_today"] = ( - validation_count - ) - except Exception as e: - logger.warning("Failed to load metrics from log", error=str(e)) - - -# Convenience functions -def log_version_creation(dataset_name: str, version: str, **kwargs): - """Convenience function to log version creation.""" - monitor = DatasetMonitor() - monitor.log_version_creation( - {"dataset_name": dataset_name, "version": version, **kwargs} - ) - - -def log_validation_result(dataset_name: str, version: str, **kwargs): - """Convenience function to log validation results.""" - monitor = DatasetMonitor() - monitor.log_validation_result( - {"dataset_name": dataset_name, "version": version, **kwargs} - ) - - -def get_health_report(): - """Get current health report.""" - monitor = DatasetMonitor() - return monitor.get_health_report() - - -def get_usage_analytics(days: int = 30): - """Get usage analytics.""" - monitor = DatasetMonitor() - return monitor.get_usage_analytics(days) +""" +Dataset Monitoring and Logging Utilities +Provides monitoring, alerting, and logging for dataset versioning operations +""" +import json +import time +from pathlib import Path +from typing import Dict, List, Any, Optional +from datetime import datetime, timedelta +import statistics + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.data.monitoring") + + +class DatasetMonitor: + """ + Monitors dataset versioning operations and provides insights. 
+ + Features: + - Operation metrics and performance tracking + - Dataset health monitoring + - Usage analytics + - Alerting for data quality issues + """ + + def __init__(self, log_dir: str = "./logs/dataset_monitoring"): + self.log_dir = Path(log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + + # Metrics storage + self.metrics_file = self.log_dir / "metrics.jsonl" + self.alerts_file = self.log_dir / "alerts.jsonl" + + # In-memory metrics for quick access + self.current_metrics = { + "total_versions_created": 0, + "total_datasets_tracked": 0, + "average_version_creation_time": 0, + "validation_errors_today": 0, + "last_health_check": None, + "storage_usage_bytes": 0 + } + + self._load_metrics() + + def log_version_creation(self, operation_data: Dict[str, Any]): + """Log dataset version creation operation.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "operation": "version_creation", + "dataset_name": operation_data.get("dataset_name"), + "version": operation_data.get("version"), + "dataset_type": operation_data.get("dataset_type"), + "total_samples": operation_data.get("total_samples"), + "file_count": operation_data.get("file_count"), + "creation_time_seconds": operation_data.get("creation_time", 0), + "change_type": operation_data.get("change_type"), + "quality_score": operation_data.get("quality_score"), + "storage_path": operation_data.get("storage_path"), + "created_by": operation_data.get("created_by") + } + + self._write_log_entry(self.metrics_file, log_entry) + self.current_metrics["total_versions_created"] += 1 + + logger.info("Version creation logged", + dataset=operation_data.get("dataset_name"), + version=operation_data.get("version")) + + def log_validation_result(self, validation_data: Dict[str, Any]): + """Log dataset validation results.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "operation": "validation", + "dataset_name": validation_data.get("dataset_name"), + "version": 
validation_data.get("version"), + "is_valid": validation_data.get("is_valid"), + "quality_score": validation_data.get("quality_score"), + "error_count": len(validation_data.get("errors", [])), + "warning_count": len(validation_data.get("warnings", [])), + "validation_time_seconds": validation_data.get("validation_time", 0) + } + + self._write_log_entry(self.metrics_file, log_entry) + + if not validation_data.get("is_valid", True): + self.current_metrics["validation_errors_today"] += 1 + + # Check for alerts + if validation_data.get("quality_score", 1.0) < 0.7: + self._create_alert("low_quality_score", { + "dataset_name": validation_data.get("dataset_name"), + "version": validation_data.get("version"), + "quality_score": validation_data.get("quality_score"), + "errors": validation_data.get("errors", []) + }) + + logger.info("Validation result logged", + dataset=validation_data.get("dataset_name"), + valid=validation_data.get("is_valid"), + quality=validation_data.get("quality_score")) + + def log_training_usage(self, training_data: Dict[str, Any]): + """Log dataset usage in training.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "operation": "training_usage", + "dataset_name": training_data.get("dataset_name"), + "dataset_version": training_data.get("dataset_version"), + "model_name": training_data.get("model_name"), + "training_duration_seconds": training_data.get("training_duration", 0), + "final_loss": training_data.get("final_loss"), + "experiment_id": training_data.get("experiment_id"), + "output_model_path": training_data.get("output_model_path") + } + + self._write_log_entry(self.metrics_file, log_entry) + + logger.info("Training usage logged", + dataset=training_data.get("dataset_name"), + version=training_data.get("dataset_version"), + model=training_data.get("model_name")) + + def get_health_report(self) -> Dict[str, Any]: + """Generate comprehensive health report.""" + report = { + "timestamp": datetime.utcnow().isoformat(), + 
"summary": { + "total_versions": self.current_metrics["total_versions_created"], + "datasets_tracked": self.current_metrics["total_datasets_tracked"], + "validation_errors_today": self.current_metrics["validation_errors_today"], + "storage_usage_gb": self.current_metrics["storage_usage_bytes"] / (1024**3) + }, + "performance": self._analyze_performance(), + "quality_trends": self._analyze_quality_trends(), + "alerts": self._get_recent_alerts(), + "recommendations": self._generate_recommendations() + } + + self.current_metrics["last_health_check"] = report["timestamp"] + return report + + def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: + """Get usage analytics for the specified period.""" + cutoff_date = datetime.utcnow() - timedelta(days=days) + + analytics = { + "period_days": days, + "version_creations": [], + "training_runs": [], + "validation_runs": [], + "popular_datasets": {}, + "quality_distribution": {} + } + + # Read logs and filter by date + if self.metrics_file.exists(): + with open(self.metrics_file, 'r') as f: + for line in f: + try: + entry = json.loads(line.strip()) + entry_date = datetime.fromisoformat(entry["timestamp"]) + + if entry_date >= cutoff_date: + if entry["operation"] == "version_creation": + analytics["version_creations"].append(entry) + dataset = entry.get("dataset_name", "unknown") + analytics["popular_datasets"][dataset] = analytics["popular_datasets"].get(dataset, 0) + 1 + + elif entry["operation"] == "training_usage": + analytics["training_runs"].append(entry) + + elif entry["operation"] == "validation": + analytics["validation_runs"].append(entry) + quality = entry.get("quality_score", 0) + quality_bucket = f"{int(quality * 10) / 10:.1f}" + analytics["quality_distribution"][quality_bucket] = analytics["quality_distribution"].get(quality_bucket, 0) + 1 + + except json.JSONDecodeError: + continue + + return analytics + + def _analyze_performance(self) -> Dict[str, Any]: + """Analyze system performance metrics.""" + 
creation_times = [] + validation_times = [] + + if self.metrics_file.exists(): + with open(self.metrics_file, 'r') as f: + for line in f: + try: + entry = json.loads(line.strip()) + if entry["operation"] == "version_creation": + if "creation_time_seconds" in entry: + creation_times.append(entry["creation_time_seconds"]) + elif entry["operation"] == "validation": + if "validation_time_seconds" in entry: + validation_times.append(entry["validation_time_seconds"]) + except json.JSONDecodeError: + continue + + return { + "avg_version_creation_time": statistics.mean(creation_times) if creation_times else 0, + "avg_validation_time": statistics.mean(validation_times) if validation_times else 0, + "total_operations": len(creation_times) + len(validation_times), + "creation_times": creation_times[-10:], # Last 10 + "validation_times": validation_times[-10:] # Last 10 + } + + def _analyze_quality_trends(self) -> Dict[str, Any]: + """Analyze quality trends over time.""" + quality_scores = [] + + if self.metrics_file.exists(): + with open(self.metrics_file, 'r') as f: + for line in f: + try: + entry = json.loads(line.strip()) + if entry["operation"] == "validation" and "quality_score" in entry: + quality_scores.append(entry["quality_score"]) + except json.JSONDecodeError: + continue + + if not quality_scores: + return {"trend": "insufficient_data"} + + recent_scores = quality_scores[-20:] # Last 20 validations + avg_quality = statistics.mean(recent_scores) if recent_scores else 0 + + # Simple trend analysis + if len(recent_scores) >= 10: + first_half = recent_scores[:len(recent_scores)//2] + second_half = recent_scores[len(recent_scores)//2:] + + first_avg = statistics.mean(first_half) + second_avg = statistics.mean(second_half) + + if second_avg > first_avg + 0.05: + trend = "improving" + elif second_avg < first_avg - 0.05: + trend = "declining" + else: + trend = "stable" + else: + trend = "insufficient_data" + + return { + "average_quality": avg_quality, + "trend": trend, + 
"total_validations": len(quality_scores), + "quality_distribution": { + "excellent": len([s for s in quality_scores if s >= 0.9]), + "good": len([s for s in quality_scores if 0.7 <= s < 0.9]), + "fair": len([s for s in quality_scores if 0.5 <= s < 0.7]), + "poor": len([s for s in quality_scores if s < 0.5]) + } + } + + def _generate_recommendations(self) -> List[str]: + """Generate recommendations based on current state.""" + recommendations = [] + + # Check for frequent validation errors + if self.current_metrics["validation_errors_today"] > 5: + recommendations.append("High validation error rate detected. Review data quality processes.") + + # Check quality trends + quality_analysis = self._analyze_quality_trends() + if quality_analysis.get("trend") == "declining": + recommendations.append("Dataset quality is declining. Consider reviewing data sources and annotation processes.") + + # Check performance + performance = self._analyze_performance() + if performance.get("avg_version_creation_time", 0) > 300: # 5 minutes + recommendations.append("Version creation is slow. Consider optimizing storage or processing.") + + # General recommendations + if self.current_metrics["total_versions_created"] == 0: + recommendations.append("No dataset versions created yet. Start versioning your datasets for reproducibility.") + + if not recommendations: + recommendations.append("System operating normally. 
Continue regular monitoring.") + + return recommendations + + def _create_alert(self, alert_type: str, alert_data: Dict[str, Any]): + """Create an alert for monitoring.""" + alert_entry = { + "timestamp": datetime.utcnow().isoformat(), + "alert_type": alert_type, + "severity": "warning", # Could be "info", "warning", "error" + "data": alert_data, + "resolved": False + } + + self._write_log_entry(self.alerts_file, alert_entry) + + logger.warning("Alert created", + type=alert_type, + data=alert_data) + + def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: + """Get recent alerts.""" + alerts = [] + cutoff_time = datetime.utcnow() - timedelta(hours=hours) + + if self.alerts_file.exists(): + with open(self.alerts_file, 'r') as f: + for line in f: + try: + alert = json.loads(line.strip()) + alert_time = datetime.fromisoformat(alert["timestamp"]) + if alert_time >= cutoff_time: + alerts.append(alert) + except json.JSONDecodeError: + continue + + return alerts[-10:] # Return last 10 alerts + + def _write_log_entry(self, log_file: Path, entry: Dict[str, Any]): + """Write a log entry to file.""" + with open(log_file, 'a') as f: + f.write(json.dumps(entry) + '\n') + + def _load_metrics(self): + """Load current metrics from log files.""" + if self.metrics_file.exists(): + try: + with open(self.metrics_file, 'r') as f: + lines = f.readlines() + if lines: + # Count operations from logs + version_count = sum(1 for line in lines if '"operation": "version_creation"' in line) + validation_count = sum(1 for line in lines if '"operation": "validation"' in line and '"is_valid": false' in line) + + self.current_metrics["total_versions_created"] = version_count + self.current_metrics["validation_errors_today"] = validation_count + except Exception as e: + logger.warning("Failed to load metrics from log", error=str(e)) + + +# Convenience functions +def log_version_creation(dataset_name: str, version: str, **kwargs): + """Convenience function to log version creation.""" 
+ monitor = DatasetMonitor() + monitor.log_version_creation({ + "dataset_name": dataset_name, + "version": version, + **kwargs + }) + + +def log_validation_result(dataset_name: str, version: str, **kwargs): + """Convenience function to log validation results.""" + monitor = DatasetMonitor() + monitor.log_validation_result({ + "dataset_name": dataset_name, + "version": version, + **kwargs + }) + + +def get_health_report(): + """Get current health report.""" + monitor = DatasetMonitor() + return monitor.get_health_report() + + +def get_usage_analytics(days: int = 30): + """Get usage analytics.""" + monitor = DatasetMonitor() + return monitor.get_usage_analytics(days) diff --git a/src/deepiri_modelkit/data/validation.py b/src/deepiri_modelkit/data/validation.py index acc2aa8..25e7b2e 100644 --- a/src/deepiri_modelkit/data/validation.py +++ b/src/deepiri_modelkit/data/validation.py @@ -1,400 +1,364 @@ -""" -Dataset Validation Utilities -Provides validation and quality checks for language intelligence datasets -""" - -import json -from pathlib import Path -from typing import Dict, List, Any, Optional -import re -from collections import Counter - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.data.validation") - - -class DatasetValidator: - """ - Validates dataset quality and integrity for language intelligence tasks. 
- - Supports validation for: - - Lease abstraction datasets - - Contract intelligence datasets - - General text quality checks - """ - - def __init__(self, dataset_type: str = "general"): - self.dataset_type = dataset_type - self.validation_rules = self._get_validation_rules() - - def _get_validation_rules(self) -> Dict[str, Any]: - """Get validation rules based on dataset type.""" - base_rules = { - "min_samples": 10, - "max_samples": 100000, - "min_text_length": 10, - "max_text_length": 10000, - "required_fields": ["text"], - "text_quality_checks": True, - } - - type_specific_rules = { - "lease_abstraction": { - "min_samples": 50, - "lease_keywords": [ - "lease", - "agreement", - "landlord", - "tenant", - "rent", - "premises", - "term", - "commencement", - "expiration", - ], - "min_keyword_matches": 2, - "check_address_patterns": True, - "check_rent_patterns": True, - }, - "contract_intelligence": { - "min_samples": 50, - "contract_keywords": [ - "contract", - "agreement", - "party", - "obligation", - "clause", - "provision", - "section", - "article", - ], - "min_keyword_matches": 2, - "check_legal_patterns": True, - }, - } - - if self.dataset_type in type_specific_rules: - base_rules.update(type_specific_rules[self.dataset_type]) - - return base_rules - - def validate_dataset(self, data_path: Path) -> Dict[str, Any]: - """ - Comprehensive dataset validation. 
- - Args: - data_path: Path to dataset files - - Returns: - Validation results dictionary - """ - logger.info( - "Starting dataset validation", path=str(data_path), type=self.dataset_type - ) - - results = { - "is_valid": True, - "errors": [], - "warnings": [], - "statistics": {}, - "quality_score": 0.0, - } - - try: - # Load and parse data - samples = self._load_samples(data_path) - results["statistics"]["total_samples"] = len(samples) - - if not samples: - results["is_valid"] = False - results["errors"].append("No samples found in dataset") - return results - - # Basic structure validation - self._validate_structure(samples, results) - - # Content quality validation - if results["is_valid"]: - self._validate_content_quality(samples, results) - - # Type-specific validation - if self.dataset_type != "general": - self._validate_type_specific(samples, results) - - # Calculate overall quality score - results["quality_score"] = self._calculate_quality_score(results) - - # Determine final validity - results["is_valid"] = len(results["errors"]) == 0 - - except Exception as e: - results["is_valid"] = False - results["errors"].append(f"Validation failed with error: {str(e)}") - logger.error("Dataset validation failed", error=str(e)) - - logger.info( - "Dataset validation complete", - valid=results["is_valid"], - quality_score=results["quality_score"], - errors=len(results["errors"]), - warnings=len(results["warnings"]), - ) - - return results - - def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: - """Load samples from dataset files.""" - samples = [] - - if data_path.is_file() and data_path.suffix == ".jsonl": - with open(data_path, "r", encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if line: - try: - sample = json.loads(line) - samples.append(sample) - except json.JSONDecodeError as e: - logger.warning(f"Invalid JSON at line {line_num}: {e}") - - elif data_path.is_dir(): - for file_path in 
data_path.glob("*.jsonl"): - with open(file_path, "r", encoding="utf-8") as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - if line: - try: - sample = json.loads(line) - samples.append(sample) - except json.JSONDecodeError as e: - logger.warning( - f"Invalid JSON in {file_path} at line {line_num}: {e}" - ) - - return samples - - def _validate_structure(self, samples: List[Dict], results: Dict): - """Validate basic dataset structure.""" - if len(samples) < self.validation_rules["min_samples"]: - results["errors"].append( - f"Insufficient samples: {len(samples)} < {self.validation_rules['min_samples']}" - ) - - if len(samples) > self.validation_rules["max_samples"]: - results["warnings"].append( - f"Large dataset: {len(samples)} > {self.validation_rules['max_samples']}" - ) - - # Check required fields - required_fields = self.validation_rules["required_fields"] - for i, sample in enumerate(samples[:100]): # Check first 100 samples - for field in required_fields: - if field not in sample: - results["errors"].append( - f"Missing required field '{field}' in sample {i}" - ) - - def _validate_content_quality(self, samples: List[Dict], results: Dict): - """Validate content quality.""" - text_lengths = [] - empty_texts = 0 - duplicate_texts = set() - seen_texts = set() - - for sample in samples: - text = sample.get("text", "").strip() - - # Check text length - text_len = len(text) - text_lengths.append(text_len) - - if text_len < self.validation_rules["min_text_length"]: - results["errors"].append(f"Text too short: {text_len} chars") - elif text_len > self.validation_rules["max_text_length"]: - results["warnings"].append(f"Text too long: {text_len} chars") - - if not text: - empty_texts += 1 - - # Check for duplicates - if text in seen_texts: - duplicate_texts.add(text) - else: - seen_texts.add(text) - - # Statistics - results["statistics"].update( - { - "avg_text_length": ( - sum(text_lengths) / len(text_lengths) if text_lengths else 0 - ), - 
"min_text_length": min(text_lengths) if text_lengths else 0, - "max_text_length": max(text_lengths) if text_lengths else 0, - "empty_texts": empty_texts, - "duplicate_texts": len(duplicate_texts), - } - ) - - if empty_texts > 0: - results["errors"].append(f"Found {empty_texts} empty text samples") - - if len(duplicate_texts) > len(samples) * 0.01: # >1% duplicates - results["warnings"].append( - f"High duplicate rate: {len(duplicate_texts)} duplicates" - ) - - def _validate_type_specific(self, samples: List[Dict], results: Dict): - """Type-specific validation.""" - if self.dataset_type == "lease_abstraction": - self._validate_lease_abstraction(samples, results) - elif self.dataset_type == "contract_intelligence": - self._validate_contract_intelligence(samples, results) - - def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): - """Validate lease abstraction dataset.""" - keywords = self.validation_rules["lease_keywords"] - min_matches = self.validation_rules["min_keyword_matches"] - - low_keyword_samples = 0 - address_pattern_matches = 0 - rent_pattern_matches = 0 - - # Address patterns (street numbers, street names, cities) - address_pattern = r"\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}" - - # Rent patterns (dollar amounts) - rent_pattern = r"\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?" 
- - for sample in samples[:500]: # Check first 500 samples for performance - text = sample.get("text", "").lower() - - # Keyword matching - keyword_matches = sum(1 for keyword in keywords if keyword in text) - if keyword_matches < min_matches: - low_keyword_samples += 1 - - # Pattern matching - if re.search(address_pattern, sample.get("text", "")): - address_pattern_matches += 1 - - if re.search(rent_pattern, sample.get("text", "")): - rent_pattern_matches += 1 - - total_checked = min(500, len(samples)) - keyword_failure_rate = low_keyword_samples / total_checked - - if keyword_failure_rate > 0.3: # >30% samples lack keywords - results["warnings"].append( - f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack lease keywords" - ) - - results["statistics"].update( - { - "address_pattern_matches": address_pattern_matches, - "rent_pattern_matches": rent_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate, - } - ) - - def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): - """Validate contract intelligence dataset.""" - keywords = self.validation_rules["contract_keywords"] - min_matches = self.validation_rules["min_keyword_matches"] - - low_keyword_samples = 0 - legal_pattern_matches = 0 - - # Legal clause patterns - legal_patterns = [ - r"\bsection\s+\d+", - r"\barticle\s+\d+", - r"\bclause\s+\d+", - r"\bparagraph\s+\d+", - r"\bsubsection\s+\d+", - ] - - for sample in samples[:500]: # Check first 500 samples - text = sample.get("text", "").lower() - - # Keyword matching - keyword_matches = sum(1 for keyword in keywords if keyword in text) - if keyword_matches < min_matches: - low_keyword_samples += 1 - - # Legal pattern matching - if any( - re.search(pattern, sample.get("text", "")) for pattern in legal_patterns - ): - legal_pattern_matches += 1 - - total_checked = min(500, len(samples)) - keyword_failure_rate = low_keyword_samples / total_checked - - if keyword_failure_rate > 0.3: - results["warnings"].append( 
- f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack contract keywords" - ) - - results["statistics"].update( - { - "legal_pattern_matches": legal_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate, - } - ) - - def _calculate_quality_score(self, results: Dict) -> float: - """Calculate overall quality score (0.0 to 1.0).""" - score = 1.0 - - # Penalize errors heavily - error_penalty = len(results["errors"]) * 0.2 - score -= min(error_penalty, 0.8) - - # Penalize warnings moderately - warning_penalty = len(results["warnings"]) * 0.05 - score -= min(warning_penalty, 0.2) - - # Bonus for good statistics - stats = results["statistics"] - - if stats.get("avg_text_length", 0) > 100: - score += 0.05 # Good average text length - - if stats.get("duplicate_texts", 0) == 0: - score += 0.1 # No duplicates - - if stats.get("empty_texts", 0) == 0: - score += 0.1 # No empty texts - - # Type-specific bonuses - if self.dataset_type == "lease_abstraction": - if stats.get("keyword_relevance_score", 0) > 0.7: - score += 0.1 - if stats.get("address_pattern_matches", 0) > 0: - score += 0.05 - - elif self.dataset_type == "contract_intelligence": - if stats.get("keyword_relevance_score", 0) > 0.7: - score += 0.1 - if stats.get("legal_pattern_matches", 0) > 0: - score += 0.05 - - return max(0.0, min(1.0, score)) - - -def validate_dataset_quality( - data_path: Path, dataset_type: str = "general" -) -> Dict[str, Any]: - """ - Convenience function to validate dataset quality. 
- - Args: - data_path: Path to dataset - dataset_type: Type of dataset for specialized validation - - Returns: - Validation results - """ - validator = DatasetValidator(dataset_type) - return validator.validate_dataset(data_path) +""" +Dataset Validation Utilities +Provides validation and quality checks for language intelligence datasets +""" +import json +from pathlib import Path +from typing import Dict, List, Any, Optional +import re +from collections import Counter + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.data.validation") + + +class DatasetValidator: + """ + Validates dataset quality and integrity for language intelligence tasks. + + Supports validation for: + - Lease abstraction datasets + - Contract intelligence datasets + - General text quality checks + """ + + def __init__(self, dataset_type: str = "general"): + self.dataset_type = dataset_type + self.validation_rules = self._get_validation_rules() + + def _get_validation_rules(self) -> Dict[str, Any]: + """Get validation rules based on dataset type.""" + base_rules = { + "min_samples": 10, + "max_samples": 100000, + "min_text_length": 10, + "max_text_length": 10000, + "required_fields": ["text"], + "text_quality_checks": True + } + + type_specific_rules = { + "lease_abstraction": { + "min_samples": 50, + "lease_keywords": [ + "lease", "agreement", "landlord", "tenant", "rent", + "premises", "term", "commencement", "expiration" + ], + "min_keyword_matches": 2, + "check_address_patterns": True, + "check_rent_patterns": True + }, + "contract_intelligence": { + "min_samples": 50, + "contract_keywords": [ + "contract", "agreement", "party", "obligation", + "clause", "provision", "section", "article" + ], + "min_keyword_matches": 2, + "check_legal_patterns": True + } + } + + if self.dataset_type in type_specific_rules: + base_rules.update(type_specific_rules[self.dataset_type]) + + return base_rules + + def validate_dataset(self, data_path: Path) -> Dict[str, 
Any]: + """ + Comprehensive dataset validation. + + Args: + data_path: Path to dataset files + + Returns: + Validation results dictionary + """ + logger.info("Starting dataset validation", path=str(data_path), type=self.dataset_type) + + results = { + "is_valid": True, + "errors": [], + "warnings": [], + "statistics": {}, + "quality_score": 0.0 + } + + try: + # Load and parse data + samples = self._load_samples(data_path) + results["statistics"]["total_samples"] = len(samples) + + if not samples: + results["is_valid"] = False + results["errors"].append("No samples found in dataset") + return results + + # Basic structure validation + self._validate_structure(samples, results) + + # Content quality validation + if results["is_valid"]: + self._validate_content_quality(samples, results) + + # Type-specific validation + if self.dataset_type != "general": + self._validate_type_specific(samples, results) + + # Calculate overall quality score + results["quality_score"] = self._calculate_quality_score(results) + + # Determine final validity + results["is_valid"] = len(results["errors"]) == 0 + + except Exception as e: + results["is_valid"] = False + results["errors"].append(f"Validation failed with error: {str(e)}") + logger.error("Dataset validation failed", error=str(e)) + + logger.info("Dataset validation complete", + valid=results["is_valid"], + quality_score=results["quality_score"], + errors=len(results["errors"]), + warnings=len(results["warnings"])) + + return results + + def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: + """Load samples from dataset files.""" + samples = [] + + if data_path.is_file() and data_path.suffix == ".jsonl": + with open(data_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + sample = json.loads(line) + samples.append(sample) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON at line {line_num}: {e}") + + elif data_path.is_dir(): + for 
file_path in data_path.glob("*.jsonl"): + with open(file_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if line: + try: + sample = json.loads(line) + samples.append(sample) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON in {file_path} at line {line_num}: {e}") + + return samples + + def _validate_structure(self, samples: List[Dict], results: Dict): + """Validate basic dataset structure.""" + if len(samples) < self.validation_rules["min_samples"]: + results["errors"].append( + f"Insufficient samples: {len(samples)} < {self.validation_rules['min_samples']}" + ) + + if len(samples) > self.validation_rules["max_samples"]: + results["warnings"].append( + f"Large dataset: {len(samples)} > {self.validation_rules['max_samples']}" + ) + + # Check required fields + required_fields = self.validation_rules["required_fields"] + for i, sample in enumerate(samples[:100]): # Check first 100 samples + for field in required_fields: + if field not in sample: + results["errors"].append(f"Missing required field '{field}' in sample {i}") + + def _validate_content_quality(self, samples: List[Dict], results: Dict): + """Validate content quality.""" + text_lengths = [] + empty_texts = 0 + duplicate_texts = set() + seen_texts = set() + + for sample in samples: + text = sample.get("text", "").strip() + + # Check text length + text_len = len(text) + text_lengths.append(text_len) + + if text_len < self.validation_rules["min_text_length"]: + results["errors"].append(f"Text too short: {text_len} chars") + elif text_len > self.validation_rules["max_text_length"]: + results["warnings"].append(f"Text too long: {text_len} chars") + + if not text: + empty_texts += 1 + + # Check for duplicates + if text in seen_texts: + duplicate_texts.add(text) + else: + seen_texts.add(text) + + # Statistics + results["statistics"].update({ + "avg_text_length": sum(text_lengths) / len(text_lengths) if text_lengths else 0, + "min_text_length": 
min(text_lengths) if text_lengths else 0, + "max_text_length": max(text_lengths) if text_lengths else 0, + "empty_texts": empty_texts, + "duplicate_texts": len(duplicate_texts) + }) + + if empty_texts > 0: + results["errors"].append(f"Found {empty_texts} empty text samples") + + if len(duplicate_texts) > len(samples) * 0.01: # >1% duplicates + results["warnings"].append(f"High duplicate rate: {len(duplicate_texts)} duplicates") + + def _validate_type_specific(self, samples: List[Dict], results: Dict): + """Type-specific validation.""" + if self.dataset_type == "lease_abstraction": + self._validate_lease_abstraction(samples, results) + elif self.dataset_type == "contract_intelligence": + self._validate_contract_intelligence(samples, results) + + def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): + """Validate lease abstraction dataset.""" + keywords = self.validation_rules["lease_keywords"] + min_matches = self.validation_rules["min_keyword_matches"] + + low_keyword_samples = 0 + address_pattern_matches = 0 + rent_pattern_matches = 0 + + # Address patterns (street numbers, street names, cities) + address_pattern = r'\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}' + + # Rent patterns (dollar amounts) + rent_pattern = r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?' 
+ + for sample in samples[:500]: # Check first 500 samples for performance + text = sample.get("text", "").lower() + + # Keyword matching + keyword_matches = sum(1 for keyword in keywords if keyword in text) + if keyword_matches < min_matches: + low_keyword_samples += 1 + + # Pattern matching + if re.search(address_pattern, sample.get("text", "")): + address_pattern_matches += 1 + + if re.search(rent_pattern, sample.get("text", "")): + rent_pattern_matches += 1 + + total_checked = min(500, len(samples)) + keyword_failure_rate = low_keyword_samples / total_checked + + if keyword_failure_rate > 0.3: # >30% samples lack keywords + results["warnings"].append( + f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack lease keywords" + ) + + results["statistics"].update({ + "address_pattern_matches": address_pattern_matches, + "rent_pattern_matches": rent_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate + }) + + def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): + """Validate contract intelligence dataset.""" + keywords = self.validation_rules["contract_keywords"] + min_matches = self.validation_rules["min_keyword_matches"] + + low_keyword_samples = 0 + legal_pattern_matches = 0 + + # Legal clause patterns + legal_patterns = [ + r'\bsection\s+\d+', + r'\barticle\s+\d+', + r'\bclause\s+\d+', + r'\bparagraph\s+\d+', + r'\bsubsection\s+\d+' + ] + + for sample in samples[:500]: # Check first 500 samples + text = sample.get("text", "").lower() + + # Keyword matching + keyword_matches = sum(1 for keyword in keywords if keyword in text) + if keyword_matches < min_matches: + low_keyword_samples += 1 + + # Legal pattern matching + if any(re.search(pattern, sample.get("text", "")) for pattern in legal_patterns): + legal_pattern_matches += 1 + + total_checked = min(500, len(samples)) + keyword_failure_rate = low_keyword_samples / total_checked + + if keyword_failure_rate > 0.3: + results["warnings"].append( + f"Low 
keyword relevance: {keyword_failure_rate:.1%} samples lack contract keywords" + ) + + results["statistics"].update({ + "legal_pattern_matches": legal_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate + }) + + def _calculate_quality_score(self, results: Dict) -> float: + """Calculate overall quality score (0.0 to 1.0).""" + score = 1.0 + + # Penalize errors heavily + error_penalty = len(results["errors"]) * 0.2 + score -= min(error_penalty, 0.8) + + # Penalize warnings moderately + warning_penalty = len(results["warnings"]) * 0.05 + score -= min(warning_penalty, 0.2) + + # Bonus for good statistics + stats = results["statistics"] + + if stats.get("avg_text_length", 0) > 100: + score += 0.05 # Good average text length + + if stats.get("duplicate_texts", 0) == 0: + score += 0.1 # No duplicates + + if stats.get("empty_texts", 0) == 0: + score += 0.1 # No empty texts + + # Type-specific bonuses + if self.dataset_type == "lease_abstraction": + if stats.get("keyword_relevance_score", 0) > 0.7: + score += 0.1 + if stats.get("address_pattern_matches", 0) > 0: + score += 0.05 + + elif self.dataset_type == "contract_intelligence": + if stats.get("keyword_relevance_score", 0) > 0.7: + score += 0.1 + if stats.get("legal_pattern_matches", 0) > 0: + score += 0.05 + + return max(0.0, min(1.0, score)) + + +def validate_dataset_quality(data_path: Path, dataset_type: str = "general") -> Dict[str, Any]: + """ + Convenience function to validate dataset quality. 
+ + Args: + data_path: Path to dataset + dataset_type: Type of dataset for specialized validation + + Returns: + Validation results + """ + validator = DatasetValidator(dataset_type) + return validator.validate_dataset(data_path) diff --git a/src/deepiri_modelkit/logging.py b/src/deepiri_modelkit/logging.py index 450eb0c..9b65998 100644 --- a/src/deepiri_modelkit/logging.py +++ b/src/deepiri_modelkit/logging.py @@ -1,171 +1,147 @@ -""" -Shared logging utilities for all Deepiri services -Used by: Cyrex (runtime), Helox (training), and all microservices -""" - -import logging -import sys -import json -from datetime import datetime -from typing import Any, Dict, Optional -from pathlib import Path - - -class StructuredLogger: - """JSON structured logger for all Deepiri services""" - - def __init__(self, name: str, level: int = logging.INFO): - self.logger = logging.getLogger(name) - self.logger.setLevel(level) - - # Remove existing handlers - self.logger.handlers = [] - - # Create console handler with JSON formatting - handler = logging.StreamHandler(sys.stdout) - handler.setLevel(level) - handler.setFormatter(JsonFormatter()) - self.logger.addHandler(handler) - - self.logger.propagate = False - - def _log(self, level: int, event: str, **kwargs): - """Internal log method with structured data""" - extra = {"event": event, "timestamp": datetime.utcnow().isoformat() + "Z"} - extra.update(kwargs) - self.logger.log(level, json.dumps(extra)) - - def debug(self, event: str, **kwargs): - self._log(logging.DEBUG, event, **kwargs) - - def info(self, event: str, **kwargs): - self._log(logging.INFO, event, **kwargs) - - def warning(self, event: str, **kwargs): - self._log(logging.WARNING, event, **kwargs) - - def error(self, event: str, **kwargs): - self._log(logging.ERROR, event, **kwargs) - - def critical(self, event: str, **kwargs): - self._log(logging.CRITICAL, event, **kwargs) - - -class JsonFormatter(logging.Formatter): - """Format logs as JSON""" - - def format(self, 
record: logging.LogRecord) -> str: - log_data = { - "timestamp": datetime.utcnow().isoformat() + "Z", - "level": record.levelname, - "logger": record.name, - "message": record.getMessage(), - } - - # Add extra fields if present - if hasattr(record, "event"): - log_data["event"] = record.event - - # Add any custom fields from extra={} - for key, value in record.__dict__.items(): - if key not in [ - "name", - "msg", - "args", - "created", - "filename", - "funcName", - "levelname", - "levelno", - "lineno", - "module", - "msecs", - "message", - "pathname", - "process", - "processName", - "relativeCreated", - "thread", - "threadName", - "exc_info", - "exc_text", - "stack_info", - "event", - "timestamp", - ]: - log_data[key] = value - - return json.dumps(log_data) - - -def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: - """ - Get structured logger instance - - Usage: - from deepiri_modelkit.logging import get_logger - logger = get_logger("my_service") - logger.info("service_started", port=8000, version="1.0") - """ - return StructuredLogger(name, level) - - -class ErrorLogger: - """Error logging with context""" - - def __init__(self): - self.logger = get_logger("error_logger") - - def log_api_error(self, error: Exception, request_id: str, endpoint: str): - """Log API errors with context""" - self.logger.error( - "api_error", - error=str(error), - error_type=type(error).__name__, - request_id=request_id, - endpoint=endpoint, - ) - - def log_model_error( - self, error: Exception, model_name: str, input_data: Optional[Dict] = None - ): - """Log model inference errors""" - self.logger.error( - "model_error", - error=str(error), - error_type=type(error).__name__, - model_name=model_name, - input_sample=str(input_data)[:200] if input_data else None, - ) - - def log_training_error( - self, error: Exception, pipeline: str, config: Optional[Dict] = None - ): - """Log training pipeline errors""" - self.logger.error( - "training_error", - error=str(error), - 
error_type=type(error).__name__, - pipeline=pipeline, - config=config, - ) - - -# Singleton instances -_loggers: Dict[str, StructuredLogger] = {} -_error_logger: Optional[ErrorLogger] = None - - -def get_cached_logger(name: str) -> StructuredLogger: - """Get or create cached logger instance""" - if name not in _loggers: - _loggers[name] = get_logger(name) - return _loggers[name] - - -def get_error_logger() -> ErrorLogger: - """Get singleton error logger""" - global _error_logger - if _error_logger is None: - _error_logger = ErrorLogger() - return _error_logger +""" +Shared logging utilities for all Deepiri services +Used by: Cyrex (runtime), Helox (training), and all microservices +""" +import logging +import sys +import json +from datetime import datetime +from typing import Any, Dict, Optional +from pathlib import Path + + +class StructuredLogger: + """JSON structured logger for all Deepiri services""" + + def __init__(self, name: str, level: int = logging.INFO): + self.logger = logging.getLogger(name) + self.logger.setLevel(level) + + # Remove existing handlers + self.logger.handlers = [] + + # Create console handler with JSON formatting + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(level) + handler.setFormatter(JsonFormatter()) + self.logger.addHandler(handler) + + self.logger.propagate = False + + def _log(self, level: int, event: str, **kwargs): + """Internal log method with structured data""" + extra = {"event": event, "timestamp": datetime.utcnow().isoformat() + "Z"} + extra.update(kwargs) + self.logger.log(level, json.dumps(extra)) + + def debug(self, event: str, **kwargs): + self._log(logging.DEBUG, event, **kwargs) + + def info(self, event: str, **kwargs): + self._log(logging.INFO, event, **kwargs) + + def warning(self, event: str, **kwargs): + self._log(logging.WARNING, event, **kwargs) + + def error(self, event: str, **kwargs): + self._log(logging.ERROR, event, **kwargs) + + def critical(self, event: str, **kwargs): + 
self._log(logging.CRITICAL, event, **kwargs) + + +class JsonFormatter(logging.Formatter): + """Format logs as JSON""" + + def format(self, record: logging.LogRecord) -> str: + log_data = { + "timestamp": datetime.utcnow().isoformat() + "Z", + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + } + + # Add extra fields if present + if hasattr(record, 'event'): + log_data['event'] = record.event + + # Add any custom fields from extra={} + for key, value in record.__dict__.items(): + if key not in ['name', 'msg', 'args', 'created', 'filename', 'funcName', + 'levelname', 'levelno', 'lineno', 'module', 'msecs', + 'message', 'pathname', 'process', 'processName', + 'relativeCreated', 'thread', 'threadName', 'exc_info', + 'exc_text', 'stack_info', 'event', 'timestamp']: + log_data[key] = value + + return json.dumps(log_data) + + +def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: + """ + Get structured logger instance + + Usage: + from deepiri_modelkit.logging import get_logger + logger = get_logger("my_service") + logger.info("service_started", port=8000, version="1.0") + """ + return StructuredLogger(name, level) + + +class ErrorLogger: + """Error logging with context""" + + def __init__(self): + self.logger = get_logger("error_logger") + + def log_api_error(self, error: Exception, request_id: str, endpoint: str): + """Log API errors with context""" + self.logger.error( + "api_error", + error=str(error), + error_type=type(error).__name__, + request_id=request_id, + endpoint=endpoint + ) + + def log_model_error(self, error: Exception, model_name: str, input_data: Optional[Dict] = None): + """Log model inference errors""" + self.logger.error( + "model_error", + error=str(error), + error_type=type(error).__name__, + model_name=model_name, + input_sample=str(input_data)[:200] if input_data else None + ) + + def log_training_error(self, error: Exception, pipeline: str, config: Optional[Dict] = None): + """Log training 
pipeline errors""" + self.logger.error( + "training_error", + error=str(error), + error_type=type(error).__name__, + pipeline=pipeline, + config=config + ) + + +# Singleton instances +_loggers: Dict[str, StructuredLogger] = {} +_error_logger: Optional[ErrorLogger] = None + + +def get_cached_logger(name: str) -> StructuredLogger: + """Get or create cached logger instance""" + if name not in _loggers: + _loggers[name] = get_logger(name) + return _loggers[name] + + +def get_error_logger() -> ErrorLogger: + """Get singleton error logger""" + global _error_logger + if _error_logger is None: + _error_logger = ErrorLogger() + return _error_logger + diff --git a/src/deepiri_modelkit/ml/__init__.py b/src/deepiri_modelkit/ml/__init__.py index 19ab9c1..09c30ff 100644 --- a/src/deepiri_modelkit/ml/__init__.py +++ b/src/deepiri_modelkit/ml/__init__.py @@ -1,35 +1,33 @@ -"""ML utilities for Deepiri ModelKit""" - -try: - from .confidence import ( - ConfidenceLevel, - ConfidenceSource, - ConfidenceAttributes, - ConfidenceCalculator, - get_confidence_calculator, - ) - - _HAS_CONFIDENCE = True -except ImportError: - _HAS_CONFIDENCE = False - -try: - from .semantic import SemanticAnalyzer, get_semantic_analyzer - - _HAS_SEMANTIC = True -except ImportError: - _HAS_SEMANTIC = False - -__all__ = [] - -if _HAS_CONFIDENCE: - __all__ += [ - "ConfidenceLevel", - "ConfidenceSource", - "ConfidenceAttributes", - "ConfidenceCalculator", - "get_confidence_calculator", - ] - -if _HAS_SEMANTIC: - __all__ += ["SemanticAnalyzer", "get_semantic_analyzer"] +"""ML utilities for Deepiri ModelKit""" + +try: + from .confidence import ( + ConfidenceLevel, + ConfidenceSource, + ConfidenceAttributes, + ConfidenceCalculator, + get_confidence_calculator, + ) + _HAS_CONFIDENCE = True +except ImportError: + _HAS_CONFIDENCE = False + +try: + from .semantic import SemanticAnalyzer, get_semantic_analyzer + _HAS_SEMANTIC = True +except ImportError: + _HAS_SEMANTIC = False + +__all__ = [] + +if _HAS_CONFIDENCE: + 
__all__ += [ + "ConfidenceLevel", + "ConfidenceSource", + "ConfidenceAttributes", + "ConfidenceCalculator", + "get_confidence_calculator", + ] + +if _HAS_SEMANTIC: + __all__ += ["SemanticAnalyzer", "get_semantic_analyzer"] diff --git a/src/deepiri_modelkit/ml/confidence.py b/src/deepiri_modelkit/ml/confidence.py index 5d78f84..89d5f2b 100644 --- a/src/deepiri_modelkit/ml/confidence.py +++ b/src/deepiri_modelkit/ml/confidence.py @@ -1,316 +1,292 @@ -""" -Redesigned Confidence Classes System -Provides structured confidence assessment with multiple attributes -""" - -from enum import Enum -from typing import Dict, List, Optional, Tuple -from dataclasses import dataclass - -try: - import numpy as np - - HAS_NUMPY = True -except ImportError: - HAS_NUMPY = False - - -class ConfidenceLevel(str, Enum): - """Confidence level categories""" - - VERY_HIGH = "very_high" # 0.9-1.0 - HIGH = "high" # 0.75-0.9 - MEDIUM = "medium" # 0.5-0.75 - LOW = "low" # 0.25-0.5 - VERY_LOW = "very_low" # 0.0-0.25 - - -class ConfidenceSource(str, Enum): - """Sources of confidence information""" - - MODEL_PREDICTION = "model_prediction" - TRAINING_DATA_COVERAGE = "training_data_coverage" - FEATURE_QUALITY = "feature_quality" - CONTEXT_MATCH = "context_match" - HISTORICAL_ACCURACY = "historical_accuracy" - ENSEMBLE_AGREEMENT = "ensemble_agreement" - - -@dataclass -class ConfidenceAttributes: - """ - Comprehensive confidence attributes for model predictions - - Attributes: - base_score: Raw model confidence score (0.0-1.0) - level: Categorical confidence level - sources: Dictionary of confidence sources and their contributions - uncertainty: Measure of prediction uncertainty - calibration: How well-calibrated the prediction is - reliability: Overall reliability score - explanation: Human-readable explanation - """ - - base_score: float - level: ConfidenceLevel - sources: Dict[str, float] - uncertainty: float - calibration: float - reliability: float - explanation: str - - def to_dict(self) -> Dict: 
- """Convert to dictionary""" - return { - "base_score": self.base_score, - "level": self.level.value, - "sources": self.sources, - "uncertainty": self.uncertainty, - "calibration": self.calibration, - "reliability": self.reliability, - "explanation": self.explanation, - } - - -class ConfidenceCalculator: - """ - Calculate comprehensive confidence scores with multiple attributes - """ - - def __init__(self): - self.confidence_thresholds = { - ConfidenceLevel.VERY_HIGH: 0.9, - ConfidenceLevel.HIGH: 0.75, - ConfidenceLevel.MEDIUM: 0.5, - ConfidenceLevel.LOW: 0.25, - ConfidenceLevel.VERY_LOW: 0.0, - } - - def calculate_confidence( - self, - model_probabilities: "np.ndarray", - top_k_probs: Optional[List[float]] = None, - training_coverage: Optional[float] = None, - feature_quality: Optional[float] = None, - context_match: Optional[float] = None, - historical_accuracy: Optional[Dict[int, float]] = None, - ) -> ConfidenceAttributes: - """ - Calculate comprehensive confidence attributes - - Args: - model_probabilities: Model output probabilities for all classes - top_k_probs: Top-k probabilities (for ensemble agreement) - training_coverage: How well training data covers this example (0-1) - feature_quality: Quality of input features (0-1) - context_match: How well context matches expected patterns (0-1) - historical_accuracy: Historical accuracy per class {class_id: accuracy} - - Returns: - ConfidenceAttributes object - """ - if not HAS_NUMPY: - raise ImportError( - "numpy is required for ConfidenceCalculator. 
Install with: pip install numpy" - ) - - # Base score: maximum probability - base_score = float(np.max(model_probabilities)) - - # Uncertainty: entropy-based measure - entropy = -np.sum(model_probabilities * np.log(model_probabilities + 1e-10)) - max_entropy = np.log(len(model_probabilities)) - uncertainty = float(entropy / max_entropy) # Normalized to [0, 1] - - # Calibration: difference between top-2 probabilities (margin) - sorted_probs = np.sort(model_probabilities)[::-1] - margin = ( - float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 - ) - calibration = float(margin) # Higher margin = better calibration - - # Source contributions - sources = {} - - # Model prediction contribution - sources[ConfidenceSource.MODEL_PREDICTION.value] = base_score - - # Training data coverage - if training_coverage is not None: - sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = training_coverage - else: - sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = ( - 0.7 # Default moderate - ) - - # Feature quality - if feature_quality is not None: - sources[ConfidenceSource.FEATURE_QUALITY.value] = feature_quality - else: - sources[ConfidenceSource.FEATURE_QUALITY.value] = 0.8 # Default good - - # Context match - if context_match is not None: - sources[ConfidenceSource.CONTEXT_MATCH.value] = context_match - else: - sources[ConfidenceSource.CONTEXT_MATCH.value] = 0.7 # Default moderate - - # Historical accuracy - if historical_accuracy: - predicted_class = int(np.argmax(model_probabilities)) - hist_acc = historical_accuracy.get(predicted_class, 0.7) - sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = hist_acc - else: - sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = ( - 0.7 # Default moderate - ) - - # Ensemble agreement (if top_k_probs provided) - if top_k_probs: - agreement = float(np.std(top_k_probs)) # Lower std = higher agreement - sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min( - agreement, 1.0 - ) - else: - 
sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 0.7 # Default moderate - - # Weighted reliability score - weights = { - ConfidenceSource.MODEL_PREDICTION.value: 0.4, - ConfidenceSource.TRAINING_DATA_COVERAGE.value: 0.15, - ConfidenceSource.FEATURE_QUALITY.value: 0.15, - ConfidenceSource.CONTEXT_MATCH.value: 0.1, - ConfidenceSource.HISTORICAL_ACCURACY.value: 0.1, - ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1, - } - - reliability = sum( - sources[source] * weights.get(source, 0.0) for source in sources - ) - - # Adjust reliability based on uncertainty and calibration - reliability = ( - reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) - ) - reliability = max(0.0, min(1.0, reliability)) - - # Determine confidence level - level = self._get_confidence_level(reliability) - - # Generate explanation - explanation = self._generate_explanation( - reliability, level, sources, uncertainty, calibration - ) - - return ConfidenceAttributes( - base_score=base_score, - level=level, - sources=sources, - uncertainty=uncertainty, - calibration=calibration, - reliability=reliability, - explanation=explanation, - ) - - def _get_confidence_level(self, reliability: float) -> ConfidenceLevel: - """Get confidence level from reliability score""" - if reliability >= 0.9: - return ConfidenceLevel.VERY_HIGH - elif reliability >= 0.75: - return ConfidenceLevel.HIGH - elif reliability >= 0.5: - return ConfidenceLevel.MEDIUM - elif reliability >= 0.25: - return ConfidenceLevel.LOW - else: - return ConfidenceLevel.VERY_LOW - - def _generate_explanation( - self, - reliability: float, - level: ConfidenceLevel, - sources: Dict[str, float], - uncertainty: float, - calibration: float, - ) -> str: - """Generate human-readable explanation""" - parts = [] - - # Main confidence statement - parts.append( - f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})" - ) - - # Key factors - key_factors = [] - if sources.get(ConfidenceSource.MODEL_PREDICTION.value, 0) > 
0.8: - key_factors.append("strong model prediction") - if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) > 0.8: - key_factors.append("good training coverage") - if uncertainty < 0.3: - key_factors.append("low uncertainty") - if calibration > 0.5: - key_factors.append("clear class separation") - - if key_factors: - parts.append(f"Key factors: {', '.join(key_factors)}") - - # Concerns - concerns = [] - if uncertainty > 0.6: - concerns.append("high uncertainty") - if calibration < 0.2: - concerns.append("unclear class separation") - if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) < 0.5: - concerns.append("limited training coverage") - - if concerns: - parts.append(f"Concerns: {', '.join(concerns)}") - - return ". ".join(parts) + "." - - def should_accept_prediction( - self, - confidence: ConfidenceAttributes, - min_reliability: float = 0.7, - min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM, - ) -> Tuple[bool, str]: - """ - Determine if prediction should be accepted based on confidence - - Returns: - (should_accept, reason) - """ - level_order = { - ConfidenceLevel.VERY_LOW: 0, - ConfidenceLevel.LOW: 1, - ConfidenceLevel.MEDIUM: 2, - ConfidenceLevel.HIGH: 3, - ConfidenceLevel.VERY_HIGH: 4, - } - - if confidence.reliability < min_reliability: - return ( - False, - f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}", - ) - - if level_order[confidence.level] < level_order[min_level]: - return ( - False, - f"Confidence level {confidence.level.value} below required {min_level.value}", - ) - - return True, "Confidence meets requirements" - - -# Singleton instance -_confidence_calculator = None - - -def get_confidence_calculator() -> ConfidenceCalculator: - """Get singleton ConfidenceCalculator instance""" - global _confidence_calculator - if _confidence_calculator is None: - _confidence_calculator = ConfidenceCalculator() - return _confidence_calculator +""" +Redesigned Confidence Classes System +Provides 
structured confidence assessment with multiple attributes +""" +from enum import Enum +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + +try: + import numpy as np + HAS_NUMPY = True +except ImportError: + HAS_NUMPY = False + + +class ConfidenceLevel(str, Enum): + """Confidence level categories""" + VERY_HIGH = "very_high" # 0.9-1.0 + HIGH = "high" # 0.75-0.9 + MEDIUM = "medium" # 0.5-0.75 + LOW = "low" # 0.25-0.5 + VERY_LOW = "very_low" # 0.0-0.25 + + +class ConfidenceSource(str, Enum): + """Sources of confidence information""" + MODEL_PREDICTION = "model_prediction" + TRAINING_DATA_COVERAGE = "training_data_coverage" + FEATURE_QUALITY = "feature_quality" + CONTEXT_MATCH = "context_match" + HISTORICAL_ACCURACY = "historical_accuracy" + ENSEMBLE_AGREEMENT = "ensemble_agreement" + + +@dataclass +class ConfidenceAttributes: + """ + Comprehensive confidence attributes for model predictions + + Attributes: + base_score: Raw model confidence score (0.0-1.0) + level: Categorical confidence level + sources: Dictionary of confidence sources and their contributions + uncertainty: Measure of prediction uncertainty + calibration: How well-calibrated the prediction is + reliability: Overall reliability score + explanation: Human-readable explanation + """ + base_score: float + level: ConfidenceLevel + sources: Dict[str, float] + uncertainty: float + calibration: float + reliability: float + explanation: str + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "base_score": self.base_score, + "level": self.level.value, + "sources": self.sources, + "uncertainty": self.uncertainty, + "calibration": self.calibration, + "reliability": self.reliability, + "explanation": self.explanation + } + + +class ConfidenceCalculator: + """ + Calculate comprehensive confidence scores with multiple attributes + """ + + def __init__(self): + self.confidence_thresholds = { + ConfidenceLevel.VERY_HIGH: 0.9, + ConfidenceLevel.HIGH: 0.75, + 
ConfidenceLevel.MEDIUM: 0.5, + ConfidenceLevel.LOW: 0.25, + ConfidenceLevel.VERY_LOW: 0.0 + } + + def calculate_confidence( + self, + model_probabilities: "np.ndarray", + top_k_probs: Optional[List[float]] = None, + training_coverage: Optional[float] = None, + feature_quality: Optional[float] = None, + context_match: Optional[float] = None, + historical_accuracy: Optional[Dict[int, float]] = None + ) -> ConfidenceAttributes: + """ + Calculate comprehensive confidence attributes + + Args: + model_probabilities: Model output probabilities for all classes + top_k_probs: Top-k probabilities (for ensemble agreement) + training_coverage: How well training data covers this example (0-1) + feature_quality: Quality of input features (0-1) + context_match: How well context matches expected patterns (0-1) + historical_accuracy: Historical accuracy per class {class_id: accuracy} + + Returns: + ConfidenceAttributes object + """ + if not HAS_NUMPY: + raise ImportError("numpy is required for ConfidenceCalculator. 
Install with: pip install numpy") + + # Base score: maximum probability + base_score = float(np.max(model_probabilities)) + + # Uncertainty: entropy-based measure + entropy = -np.sum(model_probabilities * np.log(model_probabilities + 1e-10)) + max_entropy = np.log(len(model_probabilities)) + uncertainty = float(entropy / max_entropy) # Normalized to [0, 1] + + # Calibration: difference between top-2 probabilities (margin) + sorted_probs = np.sort(model_probabilities)[::-1] + margin = float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 + calibration = float(margin) # Higher margin = better calibration + + # Source contributions + sources = {} + + # Model prediction contribution + sources[ConfidenceSource.MODEL_PREDICTION.value] = base_score + + # Training data coverage + if training_coverage is not None: + sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = training_coverage + else: + sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = 0.7 # Default moderate + + # Feature quality + if feature_quality is not None: + sources[ConfidenceSource.FEATURE_QUALITY.value] = feature_quality + else: + sources[ConfidenceSource.FEATURE_QUALITY.value] = 0.8 # Default good + + # Context match + if context_match is not None: + sources[ConfidenceSource.CONTEXT_MATCH.value] = context_match + else: + sources[ConfidenceSource.CONTEXT_MATCH.value] = 0.7 # Default moderate + + # Historical accuracy + if historical_accuracy: + predicted_class = int(np.argmax(model_probabilities)) + hist_acc = historical_accuracy.get(predicted_class, 0.7) + sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = hist_acc + else: + sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = 0.7 # Default moderate + + # Ensemble agreement (if top_k_probs provided) + if top_k_probs: + agreement = float(np.std(top_k_probs)) # Lower std = higher agreement + sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min(agreement, 1.0) + else: + 
sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 0.7 # Default moderate + + # Weighted reliability score + weights = { + ConfidenceSource.MODEL_PREDICTION.value: 0.4, + ConfidenceSource.TRAINING_DATA_COVERAGE.value: 0.15, + ConfidenceSource.FEATURE_QUALITY.value: 0.15, + ConfidenceSource.CONTEXT_MATCH.value: 0.1, + ConfidenceSource.HISTORICAL_ACCURACY.value: 0.1, + ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1 + } + + reliability = sum( + sources[source] * weights.get(source, 0.0) + for source in sources + ) + + # Adjust reliability based on uncertainty and calibration + reliability = reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) + reliability = max(0.0, min(1.0, reliability)) + + # Determine confidence level + level = self._get_confidence_level(reliability) + + # Generate explanation + explanation = self._generate_explanation( + reliability, level, sources, uncertainty, calibration + ) + + return ConfidenceAttributes( + base_score=base_score, + level=level, + sources=sources, + uncertainty=uncertainty, + calibration=calibration, + reliability=reliability, + explanation=explanation + ) + + def _get_confidence_level(self, reliability: float) -> ConfidenceLevel: + """Get confidence level from reliability score""" + if reliability >= 0.9: + return ConfidenceLevel.VERY_HIGH + elif reliability >= 0.75: + return ConfidenceLevel.HIGH + elif reliability >= 0.5: + return ConfidenceLevel.MEDIUM + elif reliability >= 0.25: + return ConfidenceLevel.LOW + else: + return ConfidenceLevel.VERY_LOW + + def _generate_explanation( + self, + reliability: float, + level: ConfidenceLevel, + sources: Dict[str, float], + uncertainty: float, + calibration: float + ) -> str: + """Generate human-readable explanation""" + parts = [] + + # Main confidence statement + parts.append(f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})") + + # Key factors + key_factors = [] + if sources.get(ConfidenceSource.MODEL_PREDICTION.value, 0) > 0.8: + 
key_factors.append("strong model prediction") + if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) > 0.8: + key_factors.append("good training coverage") + if uncertainty < 0.3: + key_factors.append("low uncertainty") + if calibration > 0.5: + key_factors.append("clear class separation") + + if key_factors: + parts.append(f"Key factors: {', '.join(key_factors)}") + + # Concerns + concerns = [] + if uncertainty > 0.6: + concerns.append("high uncertainty") + if calibration < 0.2: + concerns.append("unclear class separation") + if sources.get(ConfidenceSource.TRAINING_DATA_COVERAGE.value, 0) < 0.5: + concerns.append("limited training coverage") + + if concerns: + parts.append(f"Concerns: {', '.join(concerns)}") + + return ". ".join(parts) + "." + + def should_accept_prediction( + self, + confidence: ConfidenceAttributes, + min_reliability: float = 0.7, + min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM + ) -> Tuple[bool, str]: + """ + Determine if prediction should be accepted based on confidence + + Returns: + (should_accept, reason) + """ + level_order = { + ConfidenceLevel.VERY_LOW: 0, + ConfidenceLevel.LOW: 1, + ConfidenceLevel.MEDIUM: 2, + ConfidenceLevel.HIGH: 3, + ConfidenceLevel.VERY_HIGH: 4 + } + + if confidence.reliability < min_reliability: + return False, f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}" + + if level_order[confidence.level] < level_order[min_level]: + return False, f"Confidence level {confidence.level.value} below required {min_level.value}" + + return True, "Confidence meets requirements" + + +# Singleton instance +_confidence_calculator = None + + +def get_confidence_calculator() -> ConfidenceCalculator: + """Get singleton ConfidenceCalculator instance""" + global _confidence_calculator + if _confidence_calculator is None: + _confidence_calculator = ConfidenceCalculator() + return _confidence_calculator diff --git a/src/deepiri_modelkit/ml/semantic.py b/src/deepiri_modelkit/ml/semantic.py 
index a921fd3..141c6b1 100644 --- a/src/deepiri_modelkit/ml/semantic.py +++ b/src/deepiri_modelkit/ml/semantic.py @@ -1,361 +1,343 @@ -""" -Dynamic Semantic Analysis for Data Augmentation -Inspired by Carnegie Mellon University (CMU) Language Technologies Institute approaches -Uses semantic similarity and contextual understanding for dynamic variation generation -""" - -import json -import re -from typing import List, Dict, Optional, Set -from collections import defaultdict -import os - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.ml.semantic") - -# Try multiple HTTP clients -try: - import httpx - - HAS_HTTPX = True -except ImportError: - HAS_HTTPX = False - -try: - import requests - - HAS_REQUESTS = True -except ImportError: - HAS_REQUESTS = False - -try: - import ollama - - HAS_OLLAMA_PKG = True -except ImportError: - HAS_OLLAMA_PKG = False - - -class SemanticAnalyzer: - """ - Dynamic semantic analysis for generating variations - Inspired by CMU's semantic analysis approaches - """ - - def __init__( - self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b" - ): - self.ollama_base_url = ollama_base_url - self.model = model - self._cache = {} # Cache for semantic analysis results - - def _call_ollama( - self, prompt: str, timeout: float = 15.0 - ) -> Optional[str]: # Reduced from 30s to 15s - """Call Ollama API directly via HTTP or Python package""" - # Try ollama Python package first (cleaner API) - if HAS_OLLAMA_PKG: - try: - response = ollama.generate( - model=self.model, - prompt=prompt, - options={ - "temperature": 0.7, - "top_p": 0.9, - "num_predict": 100, # Reduced from 200 for faster responses - }, - ) - return response.get("response", "").strip() - except Exception: - # Fall back to HTTP - pass - - # Fall back to HTTP API - try: - if HAS_HTTPX: - logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") - response = httpx.post( - f"{self.ollama_base_url}/api/generate", - json={ 
- "model": self.model, - "prompt": prompt, - "stream": False, - "options": { - "temperature": 0.7, - "top_p": 0.9, - "num_predict": 100, # Reduced from 200 for faster responses - }, - }, - timeout=timeout, - ) - - if response.status_code == 200: - result = response.json() - logger.debug("Ollama HTTP call succeeded") - return result.get("response", "").strip() - else: - logger.debug( - f"Ollama HTTP call failed: HTTP {response.status_code}" - ) - elif HAS_REQUESTS: - logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") - response = requests.post( - f"{self.ollama_base_url}/api/generate", - json={ - "model": self.model, - "prompt": prompt, - "stream": False, - "options": { - "temperature": 0.7, - "top_p": 0.9, - "num_predict": 100, # Reduced from 200 for faster responses - }, - }, - timeout=timeout, - ) - - if response.status_code == 200: - result = response.json() - logger.debug("Ollama HTTP call succeeded") - return result.get("response", "").strip() - else: - logger.debug( - f"Ollama HTTP call failed: HTTP {response.status_code}" - ) - except Exception as e: - logger.debug(f"Ollama HTTP call failed: {e}") - - return None - - def extract_semantic_verbs(self, text: str, category: str) -> List[str]: - """ - Extract semantically similar verbs using Ollama - Inspired by CMU's semantic role labeling approaches - Cached per category (not per text) for performance - """ - # Cache by category only, not per text (much more efficient) - cache_key = f"verbs:{category}" - if cache_key in self._cache: - return self._cache[cache_key] - - # Use category-level prompt (not text-specific) for better caching - prompt = f"""For tasks in the '{category}' category, suggest 6-8 common action verbs that are semantically similar and could be used interchangeably. - -Category: {category} - -Return ONLY a JSON array of verbs, no explanation. 
Example: ["write", "draft", "compose", "create", "author"]""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r"\[.*?\]", response, re.DOTALL) - if json_match: - verbs = json.loads(json_match.group()) - if isinstance(verbs, list) and len(verbs) > 0: - self._cache[cache_key] = verbs - return verbs - except Exception: - pass - - # Fallback: return empty list - return [] - - def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: - """ - Generate semantically appropriate prefixes using contextual analysis - Inspired by CMU's discourse analysis approaches - Cached per category (not per text) for performance - """ - cache_key = f"prefixes:{category}" - if cache_key in self._cache: - return self._cache[cache_key] - - # Simplified prompt - category only (not text-specific) for better caching - prompt = f"""For tasks in the '{category}' category, generate 8-10 natural ways to introduce task requests. - -Category: {category} - -Consider: politeness levels, personal perspectives (I need, Can you, Let me), contextual frames. - -Return ONLY a JSON array of prefixes. 
Example: ["I need to", "Can you help me", "Please", "I want to"]""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r"\[.*?\]", response, re.DOTALL) - if json_match: - prefixes = json.loads(json_match.group()) - if isinstance(prefixes, list) and len(prefixes) > 0: - self._cache[cache_key] = prefixes - return prefixes - except Exception: - pass - - # Fallback: return default prefixes - return [ - "I need to", - "Can you help me", - "Please", - "I want to", - "Help me", - "I should", - "Let me", - "I'm going to", - ] - - def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: - """ - Generate semantically appropriate suffixes using temporal and contextual analysis - Cached per category (not per text) for performance - """ - cache_key = f"suffixes:{category}" - if cache_key in self._cache: - return self._cache[cache_key] - - # Simplified prompt - category only (not text-specific) for better caching - prompt = f"""For tasks in the '{category}' category, generate 6-8 natural ways to add temporal or contextual information. - -Category: {category} - -Consider: time constraints, urgency levels, contextual markers. - -Return ONLY a JSON array of suffixes. 
Example: [" today", " this week", " as soon as possible"]""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r"\[.*?\]", response, re.DOTALL) - if json_match: - suffixes = json.loads(json_match.group()) - if isinstance(suffixes, list) and len(suffixes) > 0: - self._cache[cache_key] = suffixes - return suffixes - except Exception: - pass - - # Fallback: return default suffixes - return [ - "", - " today", - " this week", - " as soon as possible", - " when you have time", - " - urgent", - " - important", - ] - - def generate_paraphrases( - self, text: str, category: str, num_paraphrases: int = 3 - ) -> List[str]: - """ - Generate semantic paraphrases using Ollama - Inspired by CMU's paraphrase generation approaches - """ - prompt = f"""Generate {num_paraphrases} different natural ways to express this task request. Each should be semantically equivalent but use different wording: - -Original: "{text}" -Category: {category} - -Requirements: -- Keep the same meaning and intent -- Use natural, conversational language -- Vary sentence structure and word choice -- Each paraphrase should be a complete sentence - -Return ONLY the paraphrases, one per line, without numbering or bullets.""" - - response = self._call_ollama(prompt) - if response: - paraphrases = [] - for line in response.strip().split("\n"): - line = line.strip() - # Remove common prefixes - for prefix in ["- ", "1. ", "2. ", "3. ", "4. ", "5. ", "* ", "• "]: - if line.startswith(prefix): - line = line[len(prefix) :].strip() - - if line and line != text and len(line) > 10: - paraphrases.append(line) - - return paraphrases[:num_paraphrases] - - return [] - - def analyze_semantic_structure(self, text: str) -> Dict: - """ - Analyze semantic structure of text - Inspired by CMU's semantic role labeling and dependency parsing - """ - prompt = f"""Analyze the semantic structure of this task request: - -"{text}" - -Identify: -1. Main action verb -2. Object/noun phrase -3. 
Modifiers/adjectives -4. Temporal markers (if any) -5. Urgency indicators (if any) - -Return a JSON object with these fields.""" - - response = self._call_ollama(prompt) - if response: - try: - json_match = re.search(r"\{.*?\}", response, re.DOTALL) - if json_match: - return json.loads(json_match.group()) - except Exception: - pass - - # Fallback: simple analysis - words = text.lower().split() - return { - "action_verb": words[0] if words else "unknown", - "object": " ".join(words[1:]) if len(words) > 1 else "", - "modifiers": [], - "temporal": None, - "urgency": None, - } - - def check_ollama_available(self) -> bool: - """Check if Ollama is available""" - # Try ollama package first - if HAS_OLLAMA_PKG: - try: - ollama.list() # This will raise if not available - return True - except Exception: - pass - - # Fall back to HTTP check - try: - if HAS_HTTPX: - response = httpx.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) - return response.status_code == 200 - elif HAS_REQUESTS: - response = requests.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) - return response.status_code == 200 - except Exception: - pass - - return False - - -def get_semantic_analyzer( - ollama_base_url: Optional[str] = None, model: Optional[str] = None -) -> Optional[SemanticAnalyzer]: - """ - Factory function to get semantic analyzer - """ - base_url = ollama_base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") - model_name = model or os.getenv("OLLAMA_MODEL", "llama3:8b") - - analyzer = SemanticAnalyzer(ollama_base_url=base_url, model=model_name) - - if analyzer.check_ollama_available(): - return analyzer - else: - logger.warning(f"Ollama not available at {base_url}") - return None +""" +Dynamic Semantic Analysis for Data Augmentation +Inspired by Carnegie Mellon University (CMU) Language Technologies Institute approaches +Uses semantic similarity and contextual understanding for dynamic variation generation +""" +import json +import re +from typing import List, Dict, 
Optional, Set +from collections import defaultdict +import os + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.ml.semantic") + +# Try multiple HTTP clients +try: + import httpx + HAS_HTTPX = True +except ImportError: + HAS_HTTPX = False + +try: + import requests + HAS_REQUESTS = True +except ImportError: + HAS_REQUESTS = False + +try: + import ollama + HAS_OLLAMA_PKG = True +except ImportError: + HAS_OLLAMA_PKG = False + + +class SemanticAnalyzer: + """ + Dynamic semantic analysis for generating variations + Inspired by CMU's semantic analysis approaches + """ + + def __init__(self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b"): + self.ollama_base_url = ollama_base_url + self.model = model + self._cache = {} # Cache for semantic analysis results + + def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # Reduced from 30s to 15s + """Call Ollama API directly via HTTP or Python package""" + # Try ollama Python package first (cleaner API) + if HAS_OLLAMA_PKG: + try: + response = ollama.generate( + model=self.model, + prompt=prompt, + options={ + "temperature": 0.7, + "top_p": 0.9, + "num_predict": 100 # Reduced from 200 for faster responses + } + ) + return response.get("response", "").strip() + except Exception: + # Fall back to HTTP + pass + + # Fall back to HTTP API + try: + if HAS_HTTPX: + logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") + response = httpx.post( + f"{self.ollama_base_url}/api/generate", + json={ + "model": self.model, + "prompt": prompt, + "stream": False, + "options": { + "temperature": 0.7, + "top_p": 0.9, + "num_predict": 100 # Reduced from 200 for faster responses + } + }, + timeout=timeout + ) + + if response.status_code == 200: + result = response.json() + logger.debug("Ollama HTTP call succeeded") + return result.get("response", "").strip() + else: + logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") + elif 
HAS_REQUESTS: + logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") + response = requests.post( + f"{self.ollama_base_url}/api/generate", + json={ + "model": self.model, + "prompt": prompt, + "stream": False, + "options": { + "temperature": 0.7, + "top_p": 0.9, + "num_predict": 100 # Reduced from 200 for faster responses + } + }, + timeout=timeout + ) + + if response.status_code == 200: + result = response.json() + logger.debug("Ollama HTTP call succeeded") + return result.get("response", "").strip() + else: + logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") + except Exception as e: + logger.debug(f"Ollama HTTP call failed: {e}") + + return None + + def extract_semantic_verbs(self, text: str, category: str) -> List[str]: + """ + Extract semantically similar verbs using Ollama + Inspired by CMU's semantic role labeling approaches + Cached per category (not per text) for performance + """ + # Cache by category only, not per text (much more efficient) + cache_key = f"verbs:{category}" + if cache_key in self._cache: + return self._cache[cache_key] + + # Use category-level prompt (not text-specific) for better caching + prompt = f"""For tasks in the '{category}' category, suggest 6-8 common action verbs that are semantically similar and could be used interchangeably. + +Category: {category} + +Return ONLY a JSON array of verbs, no explanation. 
Example: ["write", "draft", "compose", "create", "author"]""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r'\[.*?\]', response, re.DOTALL) + if json_match: + verbs = json.loads(json_match.group()) + if isinstance(verbs, list) and len(verbs) > 0: + self._cache[cache_key] = verbs + return verbs + except Exception: + pass + + # Fallback: return empty list + return [] + + def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: + """ + Generate semantically appropriate prefixes using contextual analysis + Inspired by CMU's discourse analysis approaches + Cached per category (not per text) for performance + """ + cache_key = f"prefixes:{category}" + if cache_key in self._cache: + return self._cache[cache_key] + + # Simplified prompt - category only (not text-specific) for better caching + prompt = f"""For tasks in the '{category}' category, generate 8-10 natural ways to introduce task requests. + +Category: {category} + +Consider: politeness levels, personal perspectives (I need, Can you, Let me), contextual frames. + +Return ONLY a JSON array of prefixes. 
Example: ["I need to", "Can you help me", "Please", "I want to"]""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r'\[.*?\]', response, re.DOTALL) + if json_match: + prefixes = json.loads(json_match.group()) + if isinstance(prefixes, list) and len(prefixes) > 0: + self._cache[cache_key] = prefixes + return prefixes + except Exception: + pass + + # Fallback: return default prefixes + return [ + "I need to", "Can you help me", "Please", "I want to", + "Help me", "I should", "Let me", "I'm going to" + ] + + def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: + """ + Generate semantically appropriate suffixes using temporal and contextual analysis + Cached per category (not per text) for performance + """ + cache_key = f"suffixes:{category}" + if cache_key in self._cache: + return self._cache[cache_key] + + # Simplified prompt - category only (not text-specific) for better caching + prompt = f"""For tasks in the '{category}' category, generate 6-8 natural ways to add temporal or contextual information. + +Category: {category} + +Consider: time constraints, urgency levels, contextual markers. + +Return ONLY a JSON array of suffixes. 
Example: [" today", " this week", " as soon as possible"]""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r'\[.*?\]', response, re.DOTALL) + if json_match: + suffixes = json.loads(json_match.group()) + if isinstance(suffixes, list) and len(suffixes) > 0: + self._cache[cache_key] = suffixes + return suffixes + except Exception: + pass + + # Fallback: return default suffixes + return [ + "", " today", " this week", " as soon as possible", + " when you have time", " - urgent", " - important" + ] + + def generate_paraphrases(self, text: str, category: str, num_paraphrases: int = 3) -> List[str]: + """ + Generate semantic paraphrases using Ollama + Inspired by CMU's paraphrase generation approaches + """ + prompt = f"""Generate {num_paraphrases} different natural ways to express this task request. Each should be semantically equivalent but use different wording: + +Original: "{text}" +Category: {category} + +Requirements: +- Keep the same meaning and intent +- Use natural, conversational language +- Vary sentence structure and word choice +- Each paraphrase should be a complete sentence + +Return ONLY the paraphrases, one per line, without numbering or bullets.""" + + response = self._call_ollama(prompt) + if response: + paraphrases = [] + for line in response.strip().split('\n'): + line = line.strip() + # Remove common prefixes + for prefix in ['- ', '1. ', '2. ', '3. ', '4. ', '5. ', '* ', '• ']: + if line.startswith(prefix): + line = line[len(prefix):].strip() + + if line and line != text and len(line) > 10: + paraphrases.append(line) + + return paraphrases[:num_paraphrases] + + return [] + + def analyze_semantic_structure(self, text: str) -> Dict: + """ + Analyze semantic structure of text + Inspired by CMU's semantic role labeling and dependency parsing + """ + prompt = f"""Analyze the semantic structure of this task request: + +"{text}" + +Identify: +1. Main action verb +2. Object/noun phrase +3. Modifiers/adjectives +4. 
Temporal markers (if any) +5. Urgency indicators (if any) + +Return a JSON object with these fields.""" + + response = self._call_ollama(prompt) + if response: + try: + json_match = re.search(r'\{.*?\}', response, re.DOTALL) + if json_match: + return json.loads(json_match.group()) + except Exception: + pass + + # Fallback: simple analysis + words = text.lower().split() + return { + "action_verb": words[0] if words else "unknown", + "object": " ".join(words[1:]) if len(words) > 1 else "", + "modifiers": [], + "temporal": None, + "urgency": None + } + + def check_ollama_available(self) -> bool: + """Check if Ollama is available""" + # Try ollama package first + if HAS_OLLAMA_PKG: + try: + ollama.list() # This will raise if not available + return True + except Exception: + pass + + # Fall back to HTTP check + try: + if HAS_HTTPX: + response = httpx.get( + f"{self.ollama_base_url}/api/tags", + timeout=5.0 + ) + return response.status_code == 200 + elif HAS_REQUESTS: + response = requests.get( + f"{self.ollama_base_url}/api/tags", + timeout=5.0 + ) + return response.status_code == 200 + except Exception: + pass + + return False + + +def get_semantic_analyzer( + ollama_base_url: Optional[str] = None, + model: Optional[str] = None +) -> Optional[SemanticAnalyzer]: + """ + Factory function to get semantic analyzer + """ + base_url = ollama_base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") + model_name = model or os.getenv("OLLAMA_MODEL", "llama3:8b") + + analyzer = SemanticAnalyzer(ollama_base_url=base_url, model=model_name) + + if analyzer.check_ollama_available(): + return analyzer + else: + logger.warning(f"Ollama not available at {base_url}") + return None diff --git a/src/deepiri_modelkit/rag/__init__.py b/src/deepiri_modelkit/rag/__init__.py index 6d85952..0b3bbcc 100644 --- a/src/deepiri_modelkit/rag/__init__.py +++ b/src/deepiri_modelkit/rag/__init__.py @@ -1,166 +1,155 @@ -""" -Universal RAG Module for Deepiri Platform -Reusable across all 
industry niches: Insurance, Manufacturing, Property Management, Healthcare, etc. -""" - -from .base import ( - UniversalRAGEngine, - Document, - DocumentType, - IndustryNiche, - RAGConfig, - RAGQuery, - RetrievalResult, -) -from .processors import ( - DocumentProcessor, - RegulationProcessor, - HistoricalDataProcessor, - KnowledgeBaseProcessor, - ManualProcessor, - get_processor, -) -from .retrievers import ( - MultiModalRetriever, - HybridRetriever, - ContextualRetriever, - get_retriever, -) - -# Advanced features (optional imports) -try: - from .advanced_retrieval import ( - AdvancedRetrievalPipeline, - QueryExpander, - SynonymQueryExpander, - RephraseQueryExpander, - MultiQueryRetriever, - QueryCache, - ) - - HAS_ADVANCED_RETRIEVAL = True -except ImportError: - HAS_ADVANCED_RETRIEVAL = False - AdvancedRetrievalPipeline = None - QueryExpander = None - SynonymQueryExpander = None - RephraseQueryExpander = None - MultiQueryRetriever = None - QueryCache = None - -try: - from .caching import ( - AdvancedCacheManager, - EmbeddingCache, - QueryResultCache, - ) - - HAS_CACHING = True -except ImportError: - HAS_CACHING = False - AdvancedCacheManager = None - EmbeddingCache = None - QueryResultCache = None - -try: - from .monitoring import ( - RAGMonitor, - RetrievalMetrics, - IndexingMetrics, - SystemMetrics, - PerformanceTimer, - ) - - HAS_MONITORING = True -except ImportError: - HAS_MONITORING = False - RAGMonitor = None - RetrievalMetrics = None - IndexingMetrics = None - SystemMetrics = None - PerformanceTimer = None - -try: - from .async_processing import ( - AsyncBatchProcessor, - AsyncDocumentIndexer, - AsyncDocumentProcessor, - BatchProcessingConfig, - BatchProcessingResult, - ) - - HAS_ASYNC = True -except ImportError: - HAS_ASYNC = False - AsyncBatchProcessor = None - AsyncDocumentIndexer = None - AsyncDocumentProcessor = None - BatchProcessingConfig = None - BatchProcessingResult = None - -__all__ = [ - # Core - "UniversalRAGEngine", - "Document", - 
"DocumentType", - "IndustryNiche", - "RAGConfig", - "RAGQuery", - "RetrievalResult", - # Processors - "DocumentProcessor", - "RegulationProcessor", - "HistoricalDataProcessor", - "KnowledgeBaseProcessor", - "ManualProcessor", - "get_processor", - # Retrievers - "MultiModalRetriever", - "HybridRetriever", - "ContextualRetriever", - "get_retriever", -] - -# Conditionally add advanced features -if HAS_ADVANCED_RETRIEVAL: - __all__.extend( - [ - "AdvancedRetrievalPipeline", - "QueryExpander", - "SynonymQueryExpander", - "RephraseQueryExpander", - "MultiQueryRetriever", - "QueryCache", - ] - ) - -if HAS_CACHING: - __all__.extend( - [ - "AdvancedCacheManager", - "EmbeddingCache", - "QueryResultCache", - ] - ) - -if HAS_MONITORING: - __all__.extend( - [ - "RAGMonitor", - "RetrievalMetrics", - "IndexingMetrics", - "SystemMetrics", - "PerformanceTimer", - ] - ) - -if HAS_ASYNC: - __all__.extend( - [ - "AsyncBatchProcessor", - "AsyncDocumentIndexer", - "AsyncDocumentProcessor", - "BatchProcessingConfig", - "BatchProcessingResult", - ] - ) +""" +Universal RAG Module for Deepiri Platform +Reusable across all industry niches: Insurance, Manufacturing, Property Management, Healthcare, etc. 
+""" + +from .base import ( + UniversalRAGEngine, + Document, + DocumentType, + IndustryNiche, + RAGConfig, + RAGQuery, + RetrievalResult, +) +from .processors import ( + DocumentProcessor, + RegulationProcessor, + HistoricalDataProcessor, + KnowledgeBaseProcessor, + ManualProcessor, + get_processor, +) +from .retrievers import ( + MultiModalRetriever, + HybridRetriever, + ContextualRetriever, + get_retriever, +) + +# Advanced features (optional imports) +try: + from .advanced_retrieval import ( + AdvancedRetrievalPipeline, + QueryExpander, + SynonymQueryExpander, + RephraseQueryExpander, + MultiQueryRetriever, + QueryCache, + ) + HAS_ADVANCED_RETRIEVAL = True +except ImportError: + HAS_ADVANCED_RETRIEVAL = False + AdvancedRetrievalPipeline = None + QueryExpander = None + SynonymQueryExpander = None + RephraseQueryExpander = None + MultiQueryRetriever = None + QueryCache = None + +try: + from .caching import ( + AdvancedCacheManager, + EmbeddingCache, + QueryResultCache, + ) + HAS_CACHING = True +except ImportError: + HAS_CACHING = False + AdvancedCacheManager = None + EmbeddingCache = None + QueryResultCache = None + +try: + from .monitoring import ( + RAGMonitor, + RetrievalMetrics, + IndexingMetrics, + SystemMetrics, + PerformanceTimer, + ) + HAS_MONITORING = True +except ImportError: + HAS_MONITORING = False + RAGMonitor = None + RetrievalMetrics = None + IndexingMetrics = None + SystemMetrics = None + PerformanceTimer = None + +try: + from .async_processing import ( + AsyncBatchProcessor, + AsyncDocumentIndexer, + AsyncDocumentProcessor, + BatchProcessingConfig, + BatchProcessingResult, + ) + HAS_ASYNC = True +except ImportError: + HAS_ASYNC = False + AsyncBatchProcessor = None + AsyncDocumentIndexer = None + AsyncDocumentProcessor = None + BatchProcessingConfig = None + BatchProcessingResult = None + +__all__ = [ + # Core + "UniversalRAGEngine", + "Document", + "DocumentType", + "IndustryNiche", + "RAGConfig", + "RAGQuery", + "RetrievalResult", + # Processors 
+ "DocumentProcessor", + "RegulationProcessor", + "HistoricalDataProcessor", + "KnowledgeBaseProcessor", + "ManualProcessor", + "get_processor", + # Retrievers + "MultiModalRetriever", + "HybridRetriever", + "ContextualRetriever", + "get_retriever", +] + +# Conditionally add advanced features +if HAS_ADVANCED_RETRIEVAL: + __all__.extend([ + "AdvancedRetrievalPipeline", + "QueryExpander", + "SynonymQueryExpander", + "RephraseQueryExpander", + "MultiQueryRetriever", + "QueryCache", + ]) + +if HAS_CACHING: + __all__.extend([ + "AdvancedCacheManager", + "EmbeddingCache", + "QueryResultCache", + ]) + +if HAS_MONITORING: + __all__.extend([ + "RAGMonitor", + "RetrievalMetrics", + "IndexingMetrics", + "SystemMetrics", + "PerformanceTimer", + ]) + +if HAS_ASYNC: + __all__.extend([ + "AsyncBatchProcessor", + "AsyncDocumentIndexer", + "AsyncDocumentProcessor", + "BatchProcessingConfig", + "BatchProcessingResult", + ]) + diff --git a/src/deepiri_modelkit/rag/advanced_retrieval.py b/src/deepiri_modelkit/rag/advanced_retrieval.py index 711177f..14e1de6 100644 --- a/src/deepiri_modelkit/rag/advanced_retrieval.py +++ b/src/deepiri_modelkit/rag/advanced_retrieval.py @@ -1,422 +1,394 @@ -""" -Advanced Retrieval Strategies for Universal RAG -Query expansion, multi-query retrieval, and advanced search techniques -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Tuple -from dataclasses import dataclass -import hashlib -import json - -from .base import Document, RetrievalResult, RAGQuery - - -@dataclass -class ExpandedQuery: - """Expanded query with multiple variations""" - - original_query: str - expanded_queries: List[str] - query_type: str # "synonym", "rephrase", "keyword", etc. 
- confidence: float - - -class QueryExpander(ABC): - """Base class for query expansion strategies""" - - @abstractmethod - def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: - """Expand a query into multiple variations""" - pass - - -class SynonymQueryExpander(QueryExpander): - """Expand queries using synonyms and related terms""" - - def __init__(self, synonym_dict: Optional[Dict[str, List[str]]] = None): - self.synonym_dict = synonym_dict or self._default_synonyms() - - def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: - """Expand query using synonyms""" - words = query.lower().split() - expanded = [query] # Always include original - - for word in words: - if word in self.synonym_dict: - synonyms = self.synonym_dict[word][:max_expansions] - for synonym in synonyms: - expanded_query = query.replace(word, synonym) - if expanded_query not in expanded: - expanded.append(expanded_query) - - return ExpandedQuery( - original_query=query, - expanded_queries=expanded[: max_expansions + 1], - query_type="synonym", - confidence=0.8, - ) - - def _default_synonyms(self) -> Dict[str, List[str]]: - """Default synonym dictionary""" - return { - "repair": ["fix", "maintain", "service", "restore"], - "maintenance": ["service", "upkeep", "repair", "inspection"], - "claim": ["request", "application", "report", "filing"], - "policy": ["coverage", "plan", "insurance", "agreement"], - "regulation": ["rule", "standard", "requirement", "guideline"], - "procedure": ["process", "method", "protocol", "steps"], - "equipment": ["machine", "device", "tool", "apparatus"], - "safety": ["security", "protection", "precaution"], - "inspection": ["examination", "review", "check", "audit"], - "documentation": ["record", "file", "document", "paperwork"], - } - - -class RephraseQueryExpander(QueryExpander): - """Expand queries by rephrasing""" - - def __init__(self, llm_client=None): - self.llm_client = llm_client - - def expand(self, query: str, 
max_expansions: int = 3) -> ExpandedQuery: - """Rephrase query using LLM or templates""" - if self.llm_client: - return self._llm_rephrase(query, max_expansions) - else: - return self._template_rephrase(query, max_expansions) - - def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: - """Rephrase using templates""" - templates = [ - f"What is {query}?", - f"How to {query}?", - f"Information about {query}", - f"Details on {query}", - ] - - expanded = [query] + templates[:max_expansions] - - return ExpandedQuery( - original_query=query, - expanded_queries=expanded, - query_type="rephrase", - confidence=0.7, - ) - - def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: - """Rephrase using LLM (if available)""" - # Placeholder for LLM-based rephrasing - return self._template_rephrase(query, max_expansions) - - -class KeywordExtractor: - """Extract keywords from queries for hybrid search""" - - def __init__(self): - # Common stop words - self.stop_words = { - "the", - "a", - "an", - "and", - "or", - "but", - "in", - "on", - "at", - "to", - "for", - "of", - "with", - "by", - "is", - "are", - "was", - "were", - "be", - "been", - "being", - "have", - "has", - "had", - "do", - "does", - "did", - "will", - "would", - "should", - "could", - "may", - "might", - "must", - "can", - "this", - "that", - "these", - "those", - } - - def extract(self, query: str, max_keywords: int = 10) -> List[str]: - """Extract keywords from query""" - words = query.lower().split() - keywords = [ - word.strip(".,!?;:()[]{}") - for word in words - if word.strip(".,!?;:()[]{}") not in self.stop_words - and len(word.strip(".,!?;:()[]{}")) > 2 - ] - return keywords[:max_keywords] - - -class MultiQueryRetriever: - """ - Multi-query retrieval strategy - Generates multiple query variations and combines results - """ - - def __init__( - self, - base_retriever, - query_expander: Optional[QueryExpander] = None, - fusion_method: str = "rrf", # "rrf" (Reciprocal 
Rank Fusion) or "mean" - ): - self.base_retriever = base_retriever - self.query_expander = query_expander or SynonymQueryExpander() - self.fusion_method = fusion_method - - def retrieve(self, query: RAGQuery, num_queries: int = 3) -> List[RetrievalResult]: - """ - Retrieve using multiple query variations - - Args: - query: Original RAG query - num_queries: Number of query variations to generate - - Returns: - Fused retrieval results - """ - # Expand query - expanded = self.query_expander.expand( - query.query, max_expansions=num_queries - 1 - ) - - # Retrieve for each query variation - all_results: Dict[str, List[RetrievalResult]] = {} - - for expanded_query in expanded.expanded_queries: - # Create new query with expanded text - expanded_rag_query = RAGQuery( - query=expanded_query, - industry=query.industry, - doc_types=query.doc_types, - date_range=query.date_range, - metadata_filters=query.metadata_filters, - top_k=query.top_k or 10, - ) - - # Retrieve results - results = self.base_retriever.retrieve(expanded_rag_query) - all_results[expanded_query] = results - - # Fuse results - fused = self._fuse_results(all_results, query.top_k or 10) - - return fused - - def _fuse_results( - self, all_results: Dict[str, List[RetrievalResult]], top_k: int - ) -> List[RetrievalResult]: - """Fuse results from multiple queries""" - if self.fusion_method == "rrf": - return self._reciprocal_rank_fusion(all_results, top_k) - else: - return self._mean_score_fusion(all_results, top_k) - - def _reciprocal_rank_fusion( - self, all_results: Dict[str, List[RetrievalResult]], top_k: int, k: int = 60 - ) -> List[RetrievalResult]: - """ - Reciprocal Rank Fusion (RRF) - Combines rankings from multiple queries - """ - doc_scores: Dict[str, Dict[str, Any]] = {} - - for query_text, results in all_results.items(): - for rank, result in enumerate(results): - doc_id = result.document.id - rrf_score = 1.0 / (k + rank + 1) - - if doc_id not in doc_scores: - doc_scores[doc_id] = { - "document": 
result.document, - "rrf_score": 0.0, - "max_score": result.score, - "count": 0, - } - - doc_scores[doc_id]["rrf_score"] += rrf_score - doc_scores[doc_id]["max_score"] = max( - doc_scores[doc_id]["max_score"], result.score - ) - doc_scores[doc_id]["count"] += 1 - - # Create fused results - fused = [] - for doc_id, scores in doc_scores.items(): - fused.append( - RetrievalResult( - document=scores["document"], - score=scores["rrf_score"], # Use RRF score - rerank_score=scores["max_score"], # Store max original score - ) - ) - - # Sort by RRF score - fused.sort(key=lambda x: x.score, reverse=True) - - return fused[:top_k] - - def _mean_score_fusion( - self, all_results: Dict[str, List[RetrievalResult]], top_k: int - ) -> List[RetrievalResult]: - """Fuse results using mean score""" - doc_scores: Dict[str, Dict[str, Any]] = {} - - for query_text, results in all_results.items(): - for result in results: - doc_id = result.document.id - - if doc_id not in doc_scores: - doc_scores[doc_id] = { - "document": result.document, - "scores": [], - "count": 0, - } - - doc_scores[doc_id]["scores"].append(result.score) - doc_scores[doc_id]["count"] += 1 - - # Calculate mean scores - fused = [] - for doc_id, scores in doc_scores.items(): - mean_score = sum(scores["scores"]) / len(scores["scores"]) - fused.append( - RetrievalResult( - document=scores["document"], - score=mean_score, - rerank_score=max(scores["scores"]), - ) - ) - - # Sort by mean score - fused.sort(key=lambda x: x.score, reverse=True) - - return fused[:top_k] - - -class QueryCache: - """Cache for query results""" - - def __init__(self, cache_manager=None): - self.cache_manager = cache_manager - self.cache_ttl = 3600 # 1 hour - - def get_cache_key(self, query: RAGQuery) -> str: - """Generate cache key for query""" - query_dict = query.to_dict() - query_str = json.dumps(query_dict, sort_keys=True) - query_hash = hashlib.md5(query_str.encode()).hexdigest() - return f"rag:query:{query_hash}" - - def get(self, query: 
RAGQuery) -> Optional[List[RetrievalResult]]: - """Get cached results""" - if not self.cache_manager: - return None - - cache_key = self.get_cache_key(query) - cached = self.cache_manager.get(cache_key) - - if cached: - # Reconstruct RetrievalResult objects - results = [] - for item in cached: - doc = Document.from_dict(item["document"]) - result = RetrievalResult( - document=doc, - score=item["score"], - rerank_score=item.get("rerank_score"), - ) - results.append(result) - return results - - return None - - def set(self, query: RAGQuery, results: List[RetrievalResult]): - """Cache results""" - if not self.cache_manager: - return - - cache_key = self.get_cache_key(query) - # Serialize results - serialized = [ - { - "document": r.document.to_dict(), - "score": r.score, - "rerank_score": r.rerank_score, - } - for r in results - ] - - self.cache_manager.set(cache_key, serialized, ttl=self.cache_ttl) - - -class AdvancedRetrievalPipeline: - """ - Advanced retrieval pipeline with: - - Query expansion - - Multi-query retrieval - - Result caching - - Hybrid search - """ - - def __init__( - self, - base_retriever, - query_expander: Optional[QueryExpander] = None, - use_multi_query: bool = True, - use_cache: bool = True, - cache_manager=None, - ): - self.base_retriever = base_retriever - self.query_expander = query_expander or SynonymQueryExpander() - self.use_multi_query = use_multi_query - self.use_cache = use_cache - self.query_cache = QueryCache(cache_manager) if use_cache else None - - if use_multi_query: - self.multi_query_retriever = MultiQueryRetriever( - base_retriever=base_retriever, query_expander=query_expander - ) - else: - self.multi_query_retriever = None - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """Retrieve with advanced strategies""" - # Check cache - if self.use_cache and self.query_cache: - cached = self.query_cache.get(query) - if cached: - return cached - - # Retrieve using multi-query or base retriever - if self.use_multi_query 
and self.multi_query_retriever: - results = self.multi_query_retriever.retrieve(query) - else: - results = self.base_retriever.retrieve(query) - - # Cache results - if self.use_cache and self.query_cache: - self.query_cache.set(query, results) - - return results +""" +Advanced Retrieval Strategies for Universal RAG +Query expansion, multi-query retrieval, and advanced search techniques +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +import hashlib +import json + +from .base import Document, RetrievalResult, RAGQuery + + +@dataclass +class ExpandedQuery: + """Expanded query with multiple variations""" + original_query: str + expanded_queries: List[str] + query_type: str # "synonym", "rephrase", "keyword", etc. + confidence: float + + +class QueryExpander(ABC): + """Base class for query expansion strategies""" + + @abstractmethod + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: + """Expand a query into multiple variations""" + pass + + +class SynonymQueryExpander(QueryExpander): + """Expand queries using synonyms and related terms""" + + def __init__(self, synonym_dict: Optional[Dict[str, List[str]]] = None): + self.synonym_dict = synonym_dict or self._default_synonyms() + + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: + """Expand query using synonyms""" + words = query.lower().split() + expanded = [query] # Always include original + + for word in words: + if word in self.synonym_dict: + synonyms = self.synonym_dict[word][:max_expansions] + for synonym in synonyms: + expanded_query = query.replace(word, synonym) + if expanded_query not in expanded: + expanded.append(expanded_query) + + return ExpandedQuery( + original_query=query, + expanded_queries=expanded[:max_expansions + 1], + query_type="synonym", + confidence=0.8 + ) + + def _default_synonyms(self) -> Dict[str, List[str]]: + """Default synonym dictionary""" + return { + 
"repair": ["fix", "maintain", "service", "restore"], + "maintenance": ["service", "upkeep", "repair", "inspection"], + "claim": ["request", "application", "report", "filing"], + "policy": ["coverage", "plan", "insurance", "agreement"], + "regulation": ["rule", "standard", "requirement", "guideline"], + "procedure": ["process", "method", "protocol", "steps"], + "equipment": ["machine", "device", "tool", "apparatus"], + "safety": ["security", "protection", "precaution"], + "inspection": ["examination", "review", "check", "audit"], + "documentation": ["record", "file", "document", "paperwork"], + } + + +class RephraseQueryExpander(QueryExpander): + """Expand queries by rephrasing""" + + def __init__(self, llm_client=None): + self.llm_client = llm_client + + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: + """Rephrase query using LLM or templates""" + if self.llm_client: + return self._llm_rephrase(query, max_expansions) + else: + return self._template_rephrase(query, max_expansions) + + def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: + """Rephrase using templates""" + templates = [ + f"What is {query}?", + f"How to {query}?", + f"Information about {query}", + f"Details on {query}", + ] + + expanded = [query] + templates[:max_expansions] + + return ExpandedQuery( + original_query=query, + expanded_queries=expanded, + query_type="rephrase", + confidence=0.7 + ) + + def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: + """Rephrase using LLM (if available)""" + # Placeholder for LLM-based rephrasing + return self._template_rephrase(query, max_expansions) + + +class KeywordExtractor: + """Extract keywords from queries for hybrid search""" + + def __init__(self): + # Common stop words + self.stop_words = { + "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", + "of", "with", "by", "is", "are", "was", "were", "be", "been", "being", + "have", "has", "had", "do", "does", "did", 
"will", "would", "should", + "could", "may", "might", "must", "can", "this", "that", "these", "those" + } + + def extract(self, query: str, max_keywords: int = 10) -> List[str]: + """Extract keywords from query""" + words = query.lower().split() + keywords = [ + word.strip(".,!?;:()[]{}") + for word in words + if word.strip(".,!?;:()[]{}") not in self.stop_words + and len(word.strip(".,!?;:()[]{}")) > 2 + ] + return keywords[:max_keywords] + + +class MultiQueryRetriever: + """ + Multi-query retrieval strategy + Generates multiple query variations and combines results + """ + + def __init__( + self, + base_retriever, + query_expander: Optional[QueryExpander] = None, + fusion_method: str = "rrf" # "rrf" (Reciprocal Rank Fusion) or "mean" + ): + self.base_retriever = base_retriever + self.query_expander = query_expander or SynonymQueryExpander() + self.fusion_method = fusion_method + + def retrieve( + self, + query: RAGQuery, + num_queries: int = 3 + ) -> List[RetrievalResult]: + """ + Retrieve using multiple query variations + + Args: + query: Original RAG query + num_queries: Number of query variations to generate + + Returns: + Fused retrieval results + """ + # Expand query + expanded = self.query_expander.expand(query.query, max_expansions=num_queries - 1) + + # Retrieve for each query variation + all_results: Dict[str, List[RetrievalResult]] = {} + + for expanded_query in expanded.expanded_queries: + # Create new query with expanded text + expanded_rag_query = RAGQuery( + query=expanded_query, + industry=query.industry, + doc_types=query.doc_types, + date_range=query.date_range, + metadata_filters=query.metadata_filters, + top_k=query.top_k or 10 + ) + + # Retrieve results + results = self.base_retriever.retrieve(expanded_rag_query) + all_results[expanded_query] = results + + # Fuse results + fused = self._fuse_results(all_results, query.top_k or 10) + + return fused + + def _fuse_results( + self, + all_results: Dict[str, List[RetrievalResult]], + top_k: int + ) 
-> List[RetrievalResult]: + """Fuse results from multiple queries""" + if self.fusion_method == "rrf": + return self._reciprocal_rank_fusion(all_results, top_k) + else: + return self._mean_score_fusion(all_results, top_k) + + def _reciprocal_rank_fusion( + self, + all_results: Dict[str, List[RetrievalResult]], + top_k: int, + k: int = 60 + ) -> List[RetrievalResult]: + """ + Reciprocal Rank Fusion (RRF) + Combines rankings from multiple queries + """ + doc_scores: Dict[str, Dict[str, Any]] = {} + + for query_text, results in all_results.items(): + for rank, result in enumerate(results): + doc_id = result.document.id + rrf_score = 1.0 / (k + rank + 1) + + if doc_id not in doc_scores: + doc_scores[doc_id] = { + "document": result.document, + "rrf_score": 0.0, + "max_score": result.score, + "count": 0 + } + + doc_scores[doc_id]["rrf_score"] += rrf_score + doc_scores[doc_id]["max_score"] = max( + doc_scores[doc_id]["max_score"], + result.score + ) + doc_scores[doc_id]["count"] += 1 + + # Create fused results + fused = [] + for doc_id, scores in doc_scores.items(): + fused.append(RetrievalResult( + document=scores["document"], + score=scores["rrf_score"], # Use RRF score + rerank_score=scores["max_score"] # Store max original score + )) + + # Sort by RRF score + fused.sort(key=lambda x: x.score, reverse=True) + + return fused[:top_k] + + def _mean_score_fusion( + self, + all_results: Dict[str, List[RetrievalResult]], + top_k: int + ) -> List[RetrievalResult]: + """Fuse results using mean score""" + doc_scores: Dict[str, Dict[str, Any]] = {} + + for query_text, results in all_results.items(): + for result in results: + doc_id = result.document.id + + if doc_id not in doc_scores: + doc_scores[doc_id] = { + "document": result.document, + "scores": [], + "count": 0 + } + + doc_scores[doc_id]["scores"].append(result.score) + doc_scores[doc_id]["count"] += 1 + + # Calculate mean scores + fused = [] + for doc_id, scores in doc_scores.items(): + mean_score = 
sum(scores["scores"]) / len(scores["scores"]) + fused.append(RetrievalResult( + document=scores["document"], + score=mean_score, + rerank_score=max(scores["scores"]) + )) + + # Sort by mean score + fused.sort(key=lambda x: x.score, reverse=True) + + return fused[:top_k] + + +class QueryCache: + """Cache for query results""" + + def __init__(self, cache_manager=None): + self.cache_manager = cache_manager + self.cache_ttl = 3600 # 1 hour + + def get_cache_key(self, query: RAGQuery) -> str: + """Generate cache key for query""" + query_dict = query.to_dict() + query_str = json.dumps(query_dict, sort_keys=True) + query_hash = hashlib.md5(query_str.encode()).hexdigest() + return f"rag:query:{query_hash}" + + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: + """Get cached results""" + if not self.cache_manager: + return None + + cache_key = self.get_cache_key(query) + cached = self.cache_manager.get(cache_key) + + if cached: + # Reconstruct RetrievalResult objects + results = [] + for item in cached: + doc = Document.from_dict(item["document"]) + result = RetrievalResult( + document=doc, + score=item["score"], + rerank_score=item.get("rerank_score") + ) + results.append(result) + return results + + return None + + def set(self, query: RAGQuery, results: List[RetrievalResult]): + """Cache results""" + if not self.cache_manager: + return + + cache_key = self.get_cache_key(query) + # Serialize results + serialized = [ + { + "document": r.document.to_dict(), + "score": r.score, + "rerank_score": r.rerank_score + } + for r in results + ] + + self.cache_manager.set(cache_key, serialized, ttl=self.cache_ttl) + + +class AdvancedRetrievalPipeline: + """ + Advanced retrieval pipeline with: + - Query expansion + - Multi-query retrieval + - Result caching + - Hybrid search + """ + + def __init__( + self, + base_retriever, + query_expander: Optional[QueryExpander] = None, + use_multi_query: bool = True, + use_cache: bool = True, + cache_manager=None + ): + 
self.base_retriever = base_retriever + self.query_expander = query_expander or SynonymQueryExpander() + self.use_multi_query = use_multi_query + self.use_cache = use_cache + self.query_cache = QueryCache(cache_manager) if use_cache else None + + if use_multi_query: + self.multi_query_retriever = MultiQueryRetriever( + base_retriever=base_retriever, + query_expander=query_expander + ) + else: + self.multi_query_retriever = None + + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """Retrieve with advanced strategies""" + # Check cache + if self.use_cache and self.query_cache: + cached = self.query_cache.get(query) + if cached: + return cached + + # Retrieve using multi-query or base retriever + if self.use_multi_query and self.multi_query_retriever: + results = self.multi_query_retriever.retrieve(query) + else: + results = self.base_retriever.retrieve(query) + + # Cache results + if self.use_cache and self.query_cache: + self.query_cache.set(query, results) + + return results + diff --git a/src/deepiri_modelkit/rag/async_processing.py b/src/deepiri_modelkit/rag/async_processing.py index 87a0b0f..9625443 100644 --- a/src/deepiri_modelkit/rag/async_processing.py +++ b/src/deepiri_modelkit/rag/async_processing.py @@ -5,7 +5,6 @@ import asyncio from typing import List, Dict, Any, Optional, Callable, Awaitable - # Fix for Python < 3.9 compatibility try: from collections.abc import AsyncIterator @@ -21,7 +20,6 @@ @dataclass class BatchProcessingConfig: """Configuration for batch processing""" - batch_size: int = 100 max_concurrent_batches: int = 5 chunk_size: int = 1000 @@ -35,25 +33,24 @@ class BatchProcessingConfig: @dataclass class BatchProcessingResult: """Result of batch processing operation""" - total_items: int processed_items: int successful_items: int failed_items: int processing_time_seconds: float errors: List[Dict[str, Any]] = None - + def __post_init__(self): if self.errors is None: self.errors = [] - + @property def success_rate(self) -> 
float: """Calculate success rate""" if self.total_items == 0: return 0.0 return self.successful_items / self.total_items - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -71,25 +68,25 @@ class AsyncBatchProcessor: """ Async batch processor for high-performance document processing """ - + def __init__(self, config: BatchProcessingConfig): self.config = config self.semaphore = asyncio.Semaphore(config.max_concurrent_batches) - + async def process_batch( self, items: List[Any], processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] = None, + progress_callback: Optional[Callable[[int, int], None]] = None ) -> BatchProcessingResult: """ Process items in batches asynchronously - + Args: items: List of items to process processor_func: Async function to process each item progress_callback: Optional callback for progress updates - + Returns: BatchProcessingResult with statistics """ @@ -98,125 +95,135 @@ async def process_batch( successful_items = 0 failed_items = 0 errors = [] - + # Split into batches batches = [ - items[i : i + self.config.batch_size] + items[i:i + self.config.batch_size] for i in range(0, total_items, self.config.batch_size) ] - + # Process batches concurrently tasks = [] for batch_idx, batch in enumerate(batches): task = self._process_batch_with_semaphore( - batch, batch_idx, processor_func, progress_callback + batch, + batch_idx, + processor_func, + progress_callback ) tasks.append(task) - + # Wait for all batches batch_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Aggregate results for result in batch_results: if isinstance(result, Exception): failed_items += self.config.batch_size - errors.append({"error": str(result), "type": type(result).__name__}) + errors.append({ + "error": str(result), + "type": type(result).__name__ + }) else: successful_items += result["successful"] failed_items += result["failed"] errors.extend(result.get("errors", [])) 
- + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_items, processed_items=total_items, successful_items=successful_items, failed_items=failed_items, processing_time_seconds=processing_time, - errors=errors, + errors=errors ) - + async def _process_batch_with_semaphore( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]], + progress_callback: Optional[Callable[[int, int], None]] ) -> Dict[str, Any]: """Process a single batch with semaphore control""" async with self.semaphore: return await self._process_single_batch( - batch, batch_idx, processor_func, progress_callback + batch, + batch_idx, + processor_func, + progress_callback ) - + async def _process_single_batch( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]], + progress_callback: Optional[Callable[[int, int], None]] ) -> Dict[str, Any]: """Process a single batch""" successful = 0 failed = 0 errors = [] - + # Process items in batch concurrently tasks = [processor_func(item) for item in batch] results = await asyncio.gather(*tasks, return_exceptions=True) - + for idx, result in enumerate(results): if isinstance(result, Exception): failed += 1 - errors.append( - { - "item_index": batch_idx * self.config.batch_size + idx, - "error": str(result), - "type": type(result).__name__, - } - ) + errors.append({ + "item_index": batch_idx * self.config.batch_size + idx, + "error": str(result), + "type": type(result).__name__ + }) else: successful += 1 - + # Progress callback if progress_callback: total_processed = (batch_idx + 1) * self.config.batch_size progress_callback(total_processed, len(batch) * (batch_idx + 1)) - - return {"successful": successful, "failed": failed, "errors": errors} + + return { + "successful": successful, + "failed": failed, + "errors": errors + } class 
AsyncDocumentIndexer: """ Async document indexer with batching and retry logic """ - + def __init__( self, index_func: Callable[[Document], Awaitable[bool]], - config: Optional[BatchProcessingConfig] = None, + config: Optional[BatchProcessingConfig] = None ): self.index_func = index_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def index_documents( self, documents: List[Document], - progress_callback: Optional[Callable[[int, int], None]] = None, + progress_callback: Optional[Callable[[int, int], None]] = None ) -> BatchProcessingResult: """ Index documents asynchronously in batches - + Args: documents: List of documents to index progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ - async def index_document(doc: Document) -> bool: """Index a single document with retry""" for attempt in range(self.config.max_retries): @@ -231,25 +238,27 @@ async def index_document(doc: Document) -> bool: ) continue raise - + return False - + return await self.batch_processor.process_batch( - documents, index_document, progress_callback + documents, + index_document, + progress_callback ) - + async def index_documents_streaming( self, document_stream: AsyncIterator[Document], - progress_callback: Optional[Callable[[int, int], None]] = None, + progress_callback: Optional[Callable[[int, int], None]] = None ) -> BatchProcessingResult: """ Index documents from async stream - + Args: document_stream: Async iterator of documents progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ @@ -259,39 +268,39 @@ async def index_documents_streaming( failed = 0 errors = [] start_time = time.time() - + async def process_current_batch(): nonlocal successful, failed, errors - + if not batch: return - + result = await self.index_documents(batch, progress_callback) successful += result.successful_items failed += result.failed_items 
errors.extend(result.errors) batch.clear() - + async for document in document_stream: batch.append(document) total_processed += 1 - + if len(batch) >= self.config.batch_size: await process_current_batch() - + # Process remaining batch if batch: await process_current_batch() - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_processed, processed_items=total_processed, successful_items=successful, failed_items=failed, processing_time_seconds=processing_time, - errors=errors, + errors=errors ) @@ -299,55 +308,65 @@ class AsyncDocumentProcessor: """ Async document processor for parallel document processing """ - + def __init__( self, processor_func: Callable[[str, Dict], List[Document]], - config: Optional[BatchProcessingConfig] = None, + config: Optional[BatchProcessingConfig] = None ): self.processor_func = processor_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def process_documents( self, raw_documents: List[Dict[str, Any]], - progress_callback: Optional[Callable[[int, int], None]] = None, + progress_callback: Optional[Callable[[int, int], None]] = None ) -> tuple[List[Document], BatchProcessingResult]: """ Process raw documents asynchronously - + Args: raw_documents: List of dicts with 'content' and 'metadata' progress_callback: Optional callback for progress - + Returns: Tuple of (processed_documents, processing_result) """ processed_docs = [] errors = [] - + async def process_document(item: Dict[str, Any]) -> List[Document]: """Process a single document""" try: content = item.get("content", "") metadata = item.get("metadata", {}) - docs = await asyncio.to_thread(self.processor_func, content, metadata) + docs = await asyncio.to_thread( + self.processor_func, + content, + metadata + ) return docs except Exception as e: - errors.append({"item": item.get("id", "unknown"), "error": str(e)}) + errors.append({ + "item": item.get("id", "unknown"), + 
"error": str(e) + }) return [] - + result = await self.batch_processor.process_batch( - raw_documents, process_document, progress_callback + raw_documents, + process_document, + progress_callback ) - + # Collect all processed documents tasks = [process_document(item) for item in raw_documents] doc_lists = await asyncio.gather(*tasks, return_exceptions=True) - + for doc_list in doc_lists: if isinstance(doc_list, list): processed_docs.extend(doc_list) - + return processed_docs, result + diff --git a/src/deepiri_modelkit/rag/base.py b/src/deepiri_modelkit/rag/base.py index ec97ac5..7693e5d 100644 --- a/src/deepiri_modelkit/rag/base.py +++ b/src/deepiri_modelkit/rag/base.py @@ -1,317 +1,300 @@ -""" -Universal RAG Base Classes -Core abstractions for retrieval-augmented generation across all industries -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Dict, Any, Optional, Union -from datetime import datetime - - -class DocumentType(Enum): - """Types of documents that can be indexed""" - - REGULATION = "regulation" # Laws, regulations, compliance documents - POLICY = "policy" # Insurance policies, company policies - MANUAL = "manual" # Equipment manuals, operation guides - CONTRACT = "contract" # Legal contracts, agreements - WORK_ORDER = "work_order" # Maintenance work orders, service requests - CLAIM_RECORD = "claim_record" # Insurance claims, warranty claims - MAINTENANCE_LOG = "maintenance_log" # Equipment maintenance history - FAQ = "faq" # Frequently asked questions - KNOWLEDGE_BASE = "knowledge_base" # General knowledge articles - REPORT = "report" # Inspection reports, audit reports - PROCEDURE = "procedure" # Standard operating procedures - SAFETY_GUIDELINE = "safety_guideline" # Safety protocols and guidelines - TECHNICAL_SPEC = "technical_spec" # Technical specifications - INVOICE = "invoice" # Billing and invoices - OTHER = "other" # Catch-all for other document types - - 
-class IndustryNiche(Enum): - """Supported industry niches""" - - INSURANCE = "insurance" # Property & casualty insurance - MANUFACTURING = "manufacturing" # Industrial manufacturing - PROPERTY_MANAGEMENT = "property_management" # Real estate management - HEALTHCARE = "healthcare" # Healthcare providers - CONSTRUCTION = "construction" # Construction industry - AUTOMOTIVE = "automotive" # Automotive services - ENERGY = "energy" # Energy & utilities - LOGISTICS = "logistics" # Transportation & logistics - RETAIL = "retail" # Retail operations - HOSPITALITY = "hospitality" # Hotels & hospitality - GENERIC = "generic" # Cross-industry - - -@dataclass -class RAGConfig: - """Configuration for RAG system""" - - # Industry configuration - industry: IndustryNiche = IndustryNiche.GENERIC - - # Vector database configuration - vector_db_type: str = "milvus" # milvus, pinecone, weaviate, memory - collection_name: str = "deepiri_universal_rag" - vector_db_host: str = "milvus" - vector_db_port: int = 19530 - - # Embedding configuration - embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" - embedding_dimension: int = 384 - - # Retrieval configuration - top_k: int = 5 # Number of documents to retrieve - similarity_threshold: float = 0.7 # Minimum similarity score - use_reranking: bool = True # Use cross-encoder reranking - reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" - - # Chunking configuration - chunk_size: int = 1000 # Characters per chunk - chunk_overlap: int = 200 # Overlap between chunks - - # Metadata filtering - enable_metadata_filtering: bool = True - date_range_filtering: bool = True - - # Multi-modal support - support_images: bool = False - support_tables: bool = True - support_code: bool = False - - -@dataclass -class Document: - """Universal document representation""" - - id: str - content: str - doc_type: DocumentType - industry: IndustryNiche - - # Metadata - title: Optional[str] = None - source: Optional[str] = None - created_at: 
Optional[datetime] = None - updated_at: Optional[datetime] = None - author: Optional[str] = None - version: Optional[str] = None - - # Industry-specific metadata - metadata: Dict[str, Any] = field(default_factory=dict) - - # Processing metadata - chunk_index: Optional[int] = None - total_chunks: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary for storage""" - return { - "id": self.id, - "content": self.content, - "doc_type": self.doc_type.value, - "industry": self.industry.value, - "title": self.title, - "source": self.source, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None, - "author": self.author, - "version": self.version, - "metadata": self.metadata, - "chunk_index": self.chunk_index, - "total_chunks": self.total_chunks, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Document": - """Create from dictionary""" - return cls( - id=data["id"], - content=data["content"], - doc_type=DocumentType(data["doc_type"]), - industry=IndustryNiche(data["industry"]), - title=data.get("title"), - source=data.get("source"), - created_at=( - datetime.fromisoformat(data["created_at"]) - if data.get("created_at") - else None - ), - updated_at=( - datetime.fromisoformat(data["updated_at"]) - if data.get("updated_at") - else None - ), - author=data.get("author"), - version=data.get("version"), - metadata=data.get("metadata", {}), - chunk_index=data.get("chunk_index"), - total_chunks=data.get("total_chunks"), - ) - - -@dataclass -class RetrievalResult: - """Result from RAG retrieval""" - - document: Document - score: float # Similarity score - rerank_score: Optional[float] = None # Reranking score if applicable - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return { - "document": self.document.to_dict(), - "score": self.score, - "rerank_score": self.rerank_score, - } - - -@dataclass -class RAGQuery: - 
"""Query for RAG system""" - - query: str - industry: Optional[IndustryNiche] = None - doc_types: Optional[List[DocumentType]] = None - date_range: Optional[tuple[datetime, datetime]] = None - metadata_filters: Optional[Dict[str, Any]] = None - top_k: Optional[int] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return { - "query": self.query, - "industry": self.industry.value if self.industry else None, - "doc_types": ( - [dt.value for dt in self.doc_types] if self.doc_types else None - ), - "date_range": ( - [dr.isoformat() for dr in self.date_range] if self.date_range else None - ), - "metadata_filters": self.metadata_filters, - "top_k": self.top_k, - } - - -class UniversalRAGEngine(ABC): - """ - Abstract base class for universal RAG engine - Implements common RAG patterns across all industries - """ - - def __init__(self, config: RAGConfig): - self.config = config - self._initialize() - - @abstractmethod - def _initialize(self): - """Initialize RAG components (vector store, embeddings, etc.)""" - pass - - @abstractmethod - def index_document(self, document: Document) -> bool: - """ - Index a single document - - Args: - document: Document to index - - Returns: - True if successful, False otherwise - """ - pass - - @abstractmethod - def index_documents(self, documents: List[Document]) -> Dict[str, Any]: - """ - Index multiple documents in batch - - Args: - documents: List of documents to index - - Returns: - Statistics about the indexing operation - """ - pass - - @abstractmethod - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve relevant documents for a query - - Args: - query: Query with filters and parameters - - Returns: - List of retrieval results with scores - """ - pass - - @abstractmethod - def generate_with_context( - self, - query: str, - retrieved_docs: List[RetrievalResult], - llm_prompt_template: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Generate response using retrieved context 
- - Args: - query: User query - retrieved_docs: Retrieved documents for context - llm_prompt_template: Optional custom prompt template - - Returns: - Generated response with metadata - """ - pass - - @abstractmethod - def delete_documents(self, doc_ids: List[str]) -> bool: - """Delete documents by IDs""" - pass - - @abstractmethod - def update_document(self, doc_id: str, document: Document) -> bool: - """Update an existing document""" - pass - - @abstractmethod - def get_statistics(self) -> Dict[str, Any]: - """Get statistics about indexed documents""" - pass - - def search( - self, - query: str, - industry: Optional[IndustryNiche] = None, - doc_types: Optional[List[DocumentType]] = None, - top_k: Optional[int] = None, - **filters, - ) -> List[RetrievalResult]: - """ - Convenience method for simple search - - Args: - query: Search query - industry: Filter by industry - doc_types: Filter by document types - top_k: Number of results - **filters: Additional metadata filters - - Returns: - List of retrieval results - """ - rag_query = RAGQuery( - query=query, - industry=industry, - doc_types=doc_types, - top_k=top_k or self.config.top_k, - metadata_filters=filters if filters else None, - ) - return self.retrieve(rag_query) +""" +Universal RAG Base Classes +Core abstractions for retrieval-augmented generation across all industries +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Dict, Any, Optional, Union +from datetime import datetime + + +class DocumentType(Enum): + """Types of documents that can be indexed""" + REGULATION = "regulation" # Laws, regulations, compliance documents + POLICY = "policy" # Insurance policies, company policies + MANUAL = "manual" # Equipment manuals, operation guides + CONTRACT = "contract" # Legal contracts, agreements + WORK_ORDER = "work_order" # Maintenance work orders, service requests + CLAIM_RECORD = "claim_record" # Insurance claims, warranty claims 
+ MAINTENANCE_LOG = "maintenance_log" # Equipment maintenance history + FAQ = "faq" # Frequently asked questions + KNOWLEDGE_BASE = "knowledge_base" # General knowledge articles + REPORT = "report" # Inspection reports, audit reports + PROCEDURE = "procedure" # Standard operating procedures + SAFETY_GUIDELINE = "safety_guideline" # Safety protocols and guidelines + TECHNICAL_SPEC = "technical_spec" # Technical specifications + INVOICE = "invoice" # Billing and invoices + OTHER = "other" # Catch-all for other document types + + +class IndustryNiche(Enum): + """Supported industry niches""" + INSURANCE = "insurance" # Property & casualty insurance + MANUFACTURING = "manufacturing" # Industrial manufacturing + PROPERTY_MANAGEMENT = "property_management" # Real estate management + HEALTHCARE = "healthcare" # Healthcare providers + CONSTRUCTION = "construction" # Construction industry + AUTOMOTIVE = "automotive" # Automotive services + ENERGY = "energy" # Energy & utilities + LOGISTICS = "logistics" # Transportation & logistics + RETAIL = "retail" # Retail operations + HOSPITALITY = "hospitality" # Hotels & hospitality + GENERIC = "generic" # Cross-industry + + +@dataclass +class RAGConfig: + """Configuration for RAG system""" + # Industry configuration + industry: IndustryNiche = IndustryNiche.GENERIC + + # Vector database configuration + vector_db_type: str = "milvus" # milvus, pinecone, weaviate, memory + collection_name: str = "deepiri_universal_rag" + vector_db_host: str = "milvus" + vector_db_port: int = 19530 + + # Embedding configuration + embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" + embedding_dimension: int = 384 + + # Retrieval configuration + top_k: int = 5 # Number of documents to retrieve + similarity_threshold: float = 0.7 # Minimum similarity score + use_reranking: bool = True # Use cross-encoder reranking + reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + # Chunking configuration + chunk_size: int = 1000 # Characters 
per chunk + chunk_overlap: int = 200 # Overlap between chunks + + # Metadata filtering + enable_metadata_filtering: bool = True + date_range_filtering: bool = True + + # Multi-modal support + support_images: bool = False + support_tables: bool = True + support_code: bool = False + + +@dataclass +class Document: + """Universal document representation""" + id: str + content: str + doc_type: DocumentType + industry: IndustryNiche + + # Metadata + title: Optional[str] = None + source: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + author: Optional[str] = None + version: Optional[str] = None + + # Industry-specific metadata + metadata: Dict[str, Any] = field(default_factory=dict) + + # Processing metadata + chunk_index: Optional[int] = None + total_chunks: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for storage""" + return { + "id": self.id, + "content": self.content, + "doc_type": self.doc_type.value, + "industry": self.industry.value, + "title": self.title, + "source": self.source, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + "author": self.author, + "version": self.version, + "metadata": self.metadata, + "chunk_index": self.chunk_index, + "total_chunks": self.total_chunks, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'Document': + """Create from dictionary""" + return cls( + id=data["id"], + content=data["content"], + doc_type=DocumentType(data["doc_type"]), + industry=IndustryNiche(data["industry"]), + title=data.get("title"), + source=data.get("source"), + created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None, + updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None, + author=data.get("author"), + version=data.get("version"), + metadata=data.get("metadata", {}), + 
chunk_index=data.get("chunk_index"), + total_chunks=data.get("total_chunks"), + ) + + +@dataclass +class RetrievalResult: + """Result from RAG retrieval""" + document: Document + score: float # Similarity score + rerank_score: Optional[float] = None # Reranking score if applicable + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "document": self.document.to_dict(), + "score": self.score, + "rerank_score": self.rerank_score, + } + + +@dataclass +class RAGQuery: + """Query for RAG system""" + query: str + industry: Optional[IndustryNiche] = None + doc_types: Optional[List[DocumentType]] = None + date_range: Optional[tuple[datetime, datetime]] = None + metadata_filters: Optional[Dict[str, Any]] = None + top_k: Optional[int] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + "query": self.query, + "industry": self.industry.value if self.industry else None, + "doc_types": [dt.value for dt in self.doc_types] if self.doc_types else None, + "date_range": [dr.isoformat() for dr in self.date_range] if self.date_range else None, + "metadata_filters": self.metadata_filters, + "top_k": self.top_k, + } + + +class UniversalRAGEngine(ABC): + """ + Abstract base class for universal RAG engine + Implements common RAG patterns across all industries + """ + + def __init__(self, config: RAGConfig): + self.config = config + self._initialize() + + @abstractmethod + def _initialize(self): + """Initialize RAG components (vector store, embeddings, etc.)""" + pass + + @abstractmethod + def index_document(self, document: Document) -> bool: + """ + Index a single document + + Args: + document: Document to index + + Returns: + True if successful, False otherwise + """ + pass + + @abstractmethod + def index_documents(self, documents: List[Document]) -> Dict[str, Any]: + """ + Index multiple documents in batch + + Args: + documents: List of documents to index + + Returns: + Statistics about the indexing operation + """ + 
pass + + @abstractmethod + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """ + Retrieve relevant documents for a query + + Args: + query: Query with filters and parameters + + Returns: + List of retrieval results with scores + """ + pass + + @abstractmethod + def generate_with_context( + self, + query: str, + retrieved_docs: List[RetrievalResult], + llm_prompt_template: Optional[str] = None + ) -> Dict[str, Any]: + """ + Generate response using retrieved context + + Args: + query: User query + retrieved_docs: Retrieved documents for context + llm_prompt_template: Optional custom prompt template + + Returns: + Generated response with metadata + """ + pass + + @abstractmethod + def delete_documents(self, doc_ids: List[str]) -> bool: + """Delete documents by IDs""" + pass + + @abstractmethod + def update_document(self, doc_id: str, document: Document) -> bool: + """Update an existing document""" + pass + + @abstractmethod + def get_statistics(self) -> Dict[str, Any]: + """Get statistics about indexed documents""" + pass + + def search( + self, + query: str, + industry: Optional[IndustryNiche] = None, + doc_types: Optional[List[DocumentType]] = None, + top_k: Optional[int] = None, + **filters + ) -> List[RetrievalResult]: + """ + Convenience method for simple search + + Args: + query: Search query + industry: Filter by industry + doc_types: Filter by document types + top_k: Number of results + **filters: Additional metadata filters + + Returns: + List of retrieval results + """ + rag_query = RAGQuery( + query=query, + industry=industry, + doc_types=doc_types, + top_k=top_k or self.config.top_k, + metadata_filters=filters if filters else None + ) + return self.retrieve(rag_query) + diff --git a/src/deepiri_modelkit/rag/caching.py b/src/deepiri_modelkit/rag/caching.py index e2e4b7b..8da0ff7 100644 --- a/src/deepiri_modelkit/rag/caching.py +++ b/src/deepiri_modelkit/rag/caching.py @@ -1,466 +1,466 @@ -""" -Advanced Caching Layer for Universal RAG 
-Redis-based caching with intelligent invalidation and TTL management -""" - -from typing import Optional, Any, List, Dict -import json -import hashlib -import time -from datetime import datetime, timedelta -from dataclasses import dataclass, asdict - -from .base import Document, RetrievalResult, RAGQuery - - -@dataclass -class CacheEntry: - """Cache entry with metadata""" - - key: str - value: Any - created_at: datetime - expires_at: Optional[datetime] - access_count: int = 0 - last_accessed: Optional[datetime] = None - tags: List[str] = None - - def __post_init__(self): - if self.tags is None: - self.tags = [] - if self.last_accessed is None: - self.last_accessed = self.created_at - - def is_expired(self) -> bool: - """Check if entry is expired""" - if self.expires_at is None: - return False - return datetime.now() > self.expires_at - - def to_dict(self) -> Dict: - """Convert to dictionary for serialization""" - return { - "key": self.key, - "value": self.value, - "created_at": self.created_at.isoformat(), - "expires_at": self.expires_at.isoformat() if self.expires_at else None, - "access_count": self.access_count, - "last_accessed": ( - self.last_accessed.isoformat() if self.last_accessed else None - ), - "tags": self.tags, - } - - @classmethod - def from_dict(cls, data: Dict) -> "CacheEntry": - """Create from dictionary""" - return cls( - key=data["key"], - value=data["value"], - created_at=datetime.fromisoformat(data["created_at"]), - expires_at=( - datetime.fromisoformat(data["expires_at"]) - if data.get("expires_at") - else None - ), - access_count=data.get("access_count", 0), - last_accessed=( - datetime.fromisoformat(data["last_accessed"]) - if data.get("last_accessed") - else None - ), - tags=data.get("tags", []), - ) - - -class AdvancedCacheManager: - """ - Advanced cache manager with: - - TTL management - - Tag-based invalidation - - Access tracking - - Size limits - - LRU eviction - """ - - def __init__( - self, - redis_client=None, - default_ttl: int 
= 3600, - max_size: int = 10000, - enable_lru: bool = True, - ): - self.redis_client = redis_client - self.default_ttl = default_ttl - self.max_size = max_size - self.enable_lru = enable_lru - - # In-memory fallback if Redis unavailable - self.memory_cache: Dict[str, CacheEntry] = {} - self.tag_index: Dict[str, List[str]] = {} # tag -> [keys] - - def _get_key_prefix(self, namespace: str = "rag") -> str: - """Get key prefix""" - return f"{namespace}:" - - def _serialize_value(self, value: Any) -> str: - """Serialize value for storage""" - if isinstance(value, (list, dict)): - return json.dumps(value) - return str(value) - - def _deserialize_value(self, value: str, value_type: type = None) -> Any: - """Deserialize value from storage""" - try: - if value_type == list or (isinstance(value, str) and value.startswith("[")): - return json.loads(value) - elif value_type == dict or ( - isinstance(value, str) and value.startswith("{") - ): - return json.loads(value) - return value - except (json.JSONDecodeError, TypeError): - return value - - def get( - self, key: str, namespace: str = "rag", update_access: bool = True - ) -> Optional[Any]: - """Get value from cache""" - full_key = f"{self._get_key_prefix(namespace)}{key}" - - # Try Redis first - if self.redis_client: - try: - cached = self.redis_client.get(full_key) - if cached: - entry_data = json.loads(cached) - entry = CacheEntry.from_dict(entry_data) - - if entry.is_expired(): - self.delete(key, namespace) - return None - - if update_access: - entry.access_count += 1 - entry.last_accessed = datetime.now() - self._update_redis_entry(full_key, entry) - - return entry.value - except Exception as e: - # Fallback to memory - pass - - # Fallback to memory cache - if key in self.memory_cache: - entry = self.memory_cache[key] - - if entry.is_expired(): - del self.memory_cache[key] - return None - - if update_access: - entry.access_count += 1 - entry.last_accessed = datetime.now() - - return entry.value - - return None - - def 
set( - self, - key: str, - value: Any, - namespace: str = "rag", - ttl: Optional[int] = None, - tags: Optional[List[str]] = None, - ) -> bool: - """Set value in cache""" - full_key = f"{self._get_key_prefix(namespace)}{key}" - ttl = ttl or self.default_ttl - tags = tags or [] - - expires_at = datetime.now() + timedelta(seconds=ttl) - entry = CacheEntry( - key=full_key, - value=value, - created_at=datetime.now(), - expires_at=expires_at, - tags=tags, - ) - - # Try Redis first - if self.redis_client: - try: - entry_data = json.dumps(entry.to_dict()) - self.redis_client.setex(full_key, ttl, entry_data) - - # Update tag index - for tag in tags: - tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" - if tag_key not in self.tag_index: - self.tag_index[tag_key] = [] - if full_key not in self.tag_index[tag_key]: - self.tag_index[tag_key].append(full_key) - self.redis_client.sadd(tag_key, full_key) - - return True - except Exception as e: - # Fallback to memory - pass - - # Fallback to memory cache - # Check size limit - if len(self.memory_cache) >= self.max_size: - if self.enable_lru: - self._evict_lru() - else: - # Remove oldest - oldest_key = min( - self.memory_cache.keys(), - key=lambda k: self.memory_cache[k].created_at, - ) - del self.memory_cache[oldest_key] - - self.memory_cache[key] = entry - - # Update tag index - for tag in tags: - if tag not in self.tag_index: - self.tag_index[tag] = [] - if key not in self.tag_index[tag]: - self.tag_index[tag].append(key) - - return True - - def delete(self, key: str, namespace: str = "rag") -> bool: - """Delete key from cache""" - full_key = f"{self._get_key_prefix(namespace)}{key}" - - # Try Redis - if self.redis_client: - try: - self.redis_client.delete(full_key) - return True - except Exception: - pass - - # Memory cache - if key in self.memory_cache: - entry = self.memory_cache[key] - # Remove from tag indexes - for tag in entry.tags: - if tag in self.tag_index and key in self.tag_index[tag]: - 
self.tag_index[tag].remove(key) - del self.memory_cache[key] - return True - - return False - - def invalidate_by_tag(self, tag: str, namespace: str = "rag") -> int: - """Invalidate all keys with given tag""" - tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" - count = 0 - - # Get keys from tag index - keys_to_delete = [] - - if self.redis_client: - try: - keys_to_delete = list(self.redis_client.smembers(tag_key)) - self.redis_client.delete(tag_key) - except Exception: - pass - - if tag in self.tag_index: - keys_to_delete.extend(self.tag_index[tag]) - del self.tag_index[tag] - - # Delete all keys - for key in keys_to_delete: - if self.delete(key.replace(self._get_key_prefix(namespace), ""), namespace): - count += 1 - - return count - - def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: - """Invalidate keys matching pattern""" - full_pattern = f"{self._get_key_prefix(namespace)}{pattern}" - count = 0 - - if self.redis_client: - try: - keys = self.redis_client.keys(full_pattern) - if keys: - count = self.redis_client.delete(*keys) - except Exception: - pass - - # Memory cache - keys_to_delete = [ - k for k in self.memory_cache.keys() if self._match_pattern(k, pattern) - ] - for key in keys_to_delete: - self.delete(key, namespace) - count += 1 - - return count - - def _match_pattern(self, key: str, pattern: str) -> bool: - """Simple pattern matching (supports * wildcard)""" - import fnmatch - - return fnmatch.fnmatch(key, pattern) - - def _evict_lru(self): - """Evict least recently used entry""" - if not self.memory_cache: - return - - lru_key = min( - self.memory_cache.keys(), - key=lambda k: ( - self.memory_cache[k].last_accessed or self.memory_cache[k].created_at - ), - ) - del self.memory_cache[lru_key] - - def _update_redis_entry(self, key: str, entry: CacheEntry): - """Update Redis entry with new access info""" - if not self.redis_client: - return - - try: - entry_data = json.dumps(entry.to_dict()) - # Get remaining TTL - ttl = 
self.redis_client.ttl(key) - if ttl > 0: - self.redis_client.setex(key, ttl, entry_data) - except Exception: - pass - - def get_stats(self) -> Dict[str, Any]: - """Get cache statistics""" - stats = { - "memory_cache_size": len(self.memory_cache), - "max_size": self.max_size, - "tag_index_size": len(self.tag_index), - "redis_available": self.redis_client is not None, - } - - if self.memory_cache: - total_access = sum(e.access_count for e in self.memory_cache.values()) - stats["total_accesses"] = total_access - stats["avg_access_per_entry"] = total_access / len(self.memory_cache) - - return stats - - def clear(self, namespace: str = "rag"): - """Clear all cache entries in namespace""" - pattern = f"{self._get_key_prefix(namespace)}*" - return self.invalidate_by_pattern(pattern, namespace) - - -class EmbeddingCache: - """Specialized cache for embeddings""" - - def __init__(self, cache_manager: AdvancedCacheManager): - self.cache_manager = cache_manager - self.namespace = "rag:embeddings" - - def get_embedding_key(self, text: str) -> str: - """Generate cache key for embedding""" - text_hash = hashlib.md5(text.encode()).hexdigest() - return f"emb:{text_hash}" - - def get(self, text: str) -> Optional[Any]: - """Get cached embedding""" - key = self.get_embedding_key(text) - return self.cache_manager.get(key, namespace=self.namespace) - - def set(self, text: str, embedding: Any, ttl: int = 86400): - """Cache embedding (24 hour default TTL)""" - key = self.get_embedding_key(text) - return self.cache_manager.set( - key, embedding, namespace=self.namespace, ttl=ttl, tags=["embedding"] - ) - - -class QueryResultCache: - """Specialized cache for query results""" - - def __init__(self, cache_manager: AdvancedCacheManager): - self.cache_manager = cache_manager - self.namespace = "rag:queries" - - def get_query_key(self, query: RAGQuery) -> str: - """Generate cache key for query""" - query_dict = query.to_dict() - query_str = json.dumps(query_dict, sort_keys=True) - query_hash = 
hashlib.md5(query_str.encode()).hexdigest() - return f"query:{query_hash}" - - def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: - """Get cached query results""" - key = self.get_query_key(query) - cached = self.cache_manager.get(key, namespace=self.namespace) - - if cached: - # Reconstruct RetrievalResult objects - results = [] - for item in cached: - doc = Document.from_dict(item["document"]) - result = RetrievalResult( - document=doc, - score=item["score"], - rerank_score=item.get("rerank_score"), - ) - results.append(result) - return results - - return None - - def set( - self, - query: RAGQuery, - results: List[RetrievalResult], - ttl: int = 3600, - tags: Optional[List[str]] = None, - ): - """Cache query results""" - key = self.get_query_key(query) - - # Serialize results - serialized = [ - { - "document": r.document.to_dict(), - "score": r.score, - "rerank_score": r.rerank_score, - } - for r in results - ] - - # Add query tags - query_tags = tags or [] - if query.industry: - query_tags.append( - f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}" - ) - if query.doc_types: - for dt in query.doc_types: - query_tags.append(f"doctype:{dt.value if hasattr(dt, 'value') else dt}") - - return self.cache_manager.set( - key, serialized, namespace=self.namespace, ttl=ttl, tags=query_tags - ) - - def invalidate_by_industry(self, industry: str): - """Invalidate all queries for an industry""" - return self.cache_manager.invalidate_by_tag( - f"industry:{industry}", namespace=self.namespace - ) - - def invalidate_by_doc_type(self, doc_type: str): - """Invalidate all queries for a document type""" - return self.cache_manager.invalidate_by_tag( - f"doctype:{doc_type}", namespace=self.namespace - ) +""" +Advanced Caching Layer for Universal RAG +Redis-based caching with intelligent invalidation and TTL management +""" + +from typing import Optional, Any, List, Dict +import json +import hashlib +import time +from datetime 
import datetime, timedelta +from dataclasses import dataclass, asdict + +from .base import Document, RetrievalResult, RAGQuery + + +@dataclass +class CacheEntry: + """Cache entry with metadata""" + key: str + value: Any + created_at: datetime + expires_at: Optional[datetime] + access_count: int = 0 + last_accessed: Optional[datetime] = None + tags: List[str] = None + + def __post_init__(self): + if self.tags is None: + self.tags = [] + if self.last_accessed is None: + self.last_accessed = self.created_at + + def is_expired(self) -> bool: + """Check if entry is expired""" + if self.expires_at is None: + return False + return datetime.now() > self.expires_at + + def to_dict(self) -> Dict: + """Convert to dictionary for serialization""" + return { + "key": self.key, + "value": self.value, + "created_at": self.created_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "access_count": self.access_count, + "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None, + "tags": self.tags + } + + @classmethod + def from_dict(cls, data: Dict) -> 'CacheEntry': + """Create from dictionary""" + return cls( + key=data["key"], + value=data["value"], + created_at=datetime.fromisoformat(data["created_at"]), + expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None, + access_count=data.get("access_count", 0), + last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None, + tags=data.get("tags", []) + ) + + +class AdvancedCacheManager: + """ + Advanced cache manager with: + - TTL management + - Tag-based invalidation + - Access tracking + - Size limits + - LRU eviction + """ + + def __init__( + self, + redis_client=None, + default_ttl: int = 3600, + max_size: int = 10000, + enable_lru: bool = True + ): + self.redis_client = redis_client + self.default_ttl = default_ttl + self.max_size = max_size + self.enable_lru = enable_lru + + # In-memory fallback 
if Redis unavailable + self.memory_cache: Dict[str, CacheEntry] = {} + self.tag_index: Dict[str, List[str]] = {} # tag -> [keys] + + def _get_key_prefix(self, namespace: str = "rag") -> str: + """Get key prefix""" + return f"{namespace}:" + + def _serialize_value(self, value: Any) -> str: + """Serialize value for storage""" + if isinstance(value, (list, dict)): + return json.dumps(value) + return str(value) + + def _deserialize_value(self, value: str, value_type: type = None) -> Any: + """Deserialize value from storage""" + try: + if value_type == list or (isinstance(value, str) and value.startswith('[')): + return json.loads(value) + elif value_type == dict or (isinstance(value, str) and value.startswith('{')): + return json.loads(value) + return value + except (json.JSONDecodeError, TypeError): + return value + + def get( + self, + key: str, + namespace: str = "rag", + update_access: bool = True + ) -> Optional[Any]: + """Get value from cache""" + full_key = f"{self._get_key_prefix(namespace)}{key}" + + # Try Redis first + if self.redis_client: + try: + cached = self.redis_client.get(full_key) + if cached: + entry_data = json.loads(cached) + entry = CacheEntry.from_dict(entry_data) + + if entry.is_expired(): + self.delete(key, namespace) + return None + + if update_access: + entry.access_count += 1 + entry.last_accessed = datetime.now() + self._update_redis_entry(full_key, entry) + + return entry.value + except Exception as e: + # Fallback to memory + pass + + # Fallback to memory cache + if key in self.memory_cache: + entry = self.memory_cache[key] + + if entry.is_expired(): + del self.memory_cache[key] + return None + + if update_access: + entry.access_count += 1 + entry.last_accessed = datetime.now() + + return entry.value + + return None + + def set( + self, + key: str, + value: Any, + namespace: str = "rag", + ttl: Optional[int] = None, + tags: Optional[List[str]] = None + ) -> bool: + """Set value in cache""" + full_key = 
f"{self._get_key_prefix(namespace)}{key}" + ttl = ttl or self.default_ttl + tags = tags or [] + + expires_at = datetime.now() + timedelta(seconds=ttl) + entry = CacheEntry( + key=full_key, + value=value, + created_at=datetime.now(), + expires_at=expires_at, + tags=tags + ) + + # Try Redis first + if self.redis_client: + try: + entry_data = json.dumps(entry.to_dict()) + self.redis_client.setex(full_key, ttl, entry_data) + + # Update tag index + for tag in tags: + tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" + if tag_key not in self.tag_index: + self.tag_index[tag_key] = [] + if full_key not in self.tag_index[tag_key]: + self.tag_index[tag_key].append(full_key) + self.redis_client.sadd(tag_key, full_key) + + return True + except Exception as e: + # Fallback to memory + pass + + # Fallback to memory cache + # Check size limit + if len(self.memory_cache) >= self.max_size: + if self.enable_lru: + self._evict_lru() + else: + # Remove oldest + oldest_key = min( + self.memory_cache.keys(), + key=lambda k: self.memory_cache[k].created_at + ) + del self.memory_cache[oldest_key] + + self.memory_cache[key] = entry + + # Update tag index + for tag in tags: + if tag not in self.tag_index: + self.tag_index[tag] = [] + if key not in self.tag_index[tag]: + self.tag_index[tag].append(key) + + return True + + def delete(self, key: str, namespace: str = "rag") -> bool: + """Delete key from cache""" + full_key = f"{self._get_key_prefix(namespace)}{key}" + + # Try Redis + if self.redis_client: + try: + self.redis_client.delete(full_key) + return True + except Exception: + pass + + # Memory cache + if key in self.memory_cache: + entry = self.memory_cache[key] + # Remove from tag indexes + for tag in entry.tags: + if tag in self.tag_index and key in self.tag_index[tag]: + self.tag_index[tag].remove(key) + del self.memory_cache[key] + return True + + return False + + def invalidate_by_tag(self, tag: str, namespace: str = "rag") -> int: + """Invalidate all keys with given tag""" 
+ tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" + count = 0 + + # Get keys from tag index + keys_to_delete = [] + + if self.redis_client: + try: + keys_to_delete = list(self.redis_client.smembers(tag_key)) + self.redis_client.delete(tag_key) + except Exception: + pass + + if tag in self.tag_index: + keys_to_delete.extend(self.tag_index[tag]) + del self.tag_index[tag] + + # Delete all keys + for key in keys_to_delete: + if self.delete(key.replace(self._get_key_prefix(namespace), ""), namespace): + count += 1 + + return count + + def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: + """Invalidate keys matching pattern""" + full_pattern = f"{self._get_key_prefix(namespace)}{pattern}" + count = 0 + + if self.redis_client: + try: + keys = self.redis_client.keys(full_pattern) + if keys: + count = self.redis_client.delete(*keys) + except Exception: + pass + + # Memory cache + keys_to_delete = [ + k for k in self.memory_cache.keys() + if self._match_pattern(k, pattern) + ] + for key in keys_to_delete: + self.delete(key, namespace) + count += 1 + + return count + + def _match_pattern(self, key: str, pattern: str) -> bool: + """Simple pattern matching (supports * wildcard)""" + import fnmatch + return fnmatch.fnmatch(key, pattern) + + def _evict_lru(self): + """Evict least recently used entry""" + if not self.memory_cache: + return + + lru_key = min( + self.memory_cache.keys(), + key=lambda k: ( + self.memory_cache[k].last_accessed or + self.memory_cache[k].created_at + ) + ) + del self.memory_cache[lru_key] + + def _update_redis_entry(self, key: str, entry: CacheEntry): + """Update Redis entry with new access info""" + if not self.redis_client: + return + + try: + entry_data = json.dumps(entry.to_dict()) + # Get remaining TTL + ttl = self.redis_client.ttl(key) + if ttl > 0: + self.redis_client.setex(key, ttl, entry_data) + except Exception: + pass + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics""" + stats = { + 
"memory_cache_size": len(self.memory_cache), + "max_size": self.max_size, + "tag_index_size": len(self.tag_index), + "redis_available": self.redis_client is not None, + } + + if self.memory_cache: + total_access = sum(e.access_count for e in self.memory_cache.values()) + stats["total_accesses"] = total_access + stats["avg_access_per_entry"] = total_access / len(self.memory_cache) + + return stats + + def clear(self, namespace: str = "rag"): + """Clear all cache entries in namespace""" + pattern = f"{self._get_key_prefix(namespace)}*" + return self.invalidate_by_pattern(pattern, namespace) + + +class EmbeddingCache: + """Specialized cache for embeddings""" + + def __init__(self, cache_manager: AdvancedCacheManager): + self.cache_manager = cache_manager + self.namespace = "rag:embeddings" + + def get_embedding_key(self, text: str) -> str: + """Generate cache key for embedding""" + text_hash = hashlib.md5(text.encode()).hexdigest() + return f"emb:{text_hash}" + + def get(self, text: str) -> Optional[Any]: + """Get cached embedding""" + key = self.get_embedding_key(text) + return self.cache_manager.get(key, namespace=self.namespace) + + def set(self, text: str, embedding: Any, ttl: int = 86400): + """Cache embedding (24 hour default TTL)""" + key = self.get_embedding_key(text) + return self.cache_manager.set( + key, + embedding, + namespace=self.namespace, + ttl=ttl, + tags=["embedding"] + ) + + +class QueryResultCache: + """Specialized cache for query results""" + + def __init__(self, cache_manager: AdvancedCacheManager): + self.cache_manager = cache_manager + self.namespace = "rag:queries" + + def get_query_key(self, query: RAGQuery) -> str: + """Generate cache key for query""" + query_dict = query.to_dict() + query_str = json.dumps(query_dict, sort_keys=True) + query_hash = hashlib.md5(query_str.encode()).hexdigest() + return f"query:{query_hash}" + + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: + """Get cached query results""" + key = 
self.get_query_key(query) + cached = self.cache_manager.get(key, namespace=self.namespace) + + if cached: + # Reconstruct RetrievalResult objects + results = [] + for item in cached: + doc = Document.from_dict(item["document"]) + result = RetrievalResult( + document=doc, + score=item["score"], + rerank_score=item.get("rerank_score") + ) + results.append(result) + return results + + return None + + def set( + self, + query: RAGQuery, + results: List[RetrievalResult], + ttl: int = 3600, + tags: Optional[List[str]] = None + ): + """Cache query results""" + key = self.get_query_key(query) + + # Serialize results + serialized = [ + { + "document": r.document.to_dict(), + "score": r.score, + "rerank_score": r.rerank_score + } + for r in results + ] + + # Add query tags + query_tags = tags or [] + if query.industry: + query_tags.append(f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}") + if query.doc_types: + for dt in query.doc_types: + query_tags.append(f"doctype:{dt.value if hasattr(dt, 'value') else dt}") + + return self.cache_manager.set( + key, + serialized, + namespace=self.namespace, + ttl=ttl, + tags=query_tags + ) + + def invalidate_by_industry(self, industry: str): + """Invalidate all queries for an industry""" + return self.cache_manager.invalidate_by_tag( + f"industry:{industry}", + namespace=self.namespace + ) + + def invalidate_by_doc_type(self, doc_type: str): + """Invalidate all queries for a document type""" + return self.cache_manager.invalidate_by_tag( + f"doctype:{doc_type}", + namespace=self.namespace + ) + diff --git a/src/deepiri_modelkit/rag/monitoring.py b/src/deepiri_modelkit/rag/monitoring.py index f587408..95429b3 100644 --- a/src/deepiri_modelkit/rag/monitoring.py +++ b/src/deepiri_modelkit/rag/monitoring.py @@ -1,375 +1,354 @@ -""" -Monitoring and Observability for Universal RAG -Metrics, performance tracking, and analytics -""" - -from typing import Dict, Any, List, Optional -from dataclasses import 
dataclass, field, asdict -from datetime import datetime, timedelta -from collections import defaultdict -import time -import json - -from .base import RAGQuery, RetrievalResult - - -@dataclass -class RetrievalMetrics: - """Metrics for a single retrieval operation""" - - query_id: str - query_text: str - timestamp: datetime - retrieval_time_ms: float - num_results: int - top_score: Optional[float] = None - cache_hit: bool = False - reranking_used: bool = False - query_expansion_used: bool = False - industry: Optional[str] = None - doc_types: Optional[List[str]] = None - - def to_dict(self) -> Dict: - """Convert to dictionary""" - return { - "query_id": self.query_id, - "query_text": self.query_text, - "timestamp": self.timestamp.isoformat(), - "retrieval_time_ms": self.retrieval_time_ms, - "num_results": self.num_results, - "top_score": self.top_score, - "cache_hit": self.cache_hit, - "reranking_used": self.reranking_used, - "query_expansion_used": self.query_expansion_used, - "industry": self.industry, - "doc_types": self.doc_types, - } - - -@dataclass -class IndexingMetrics: - """Metrics for indexing operations""" - - operation_id: str - timestamp: datetime - operation_type: str # "index", "update", "delete" - num_documents: int - processing_time_ms: float - success: bool - error: Optional[str] = None - - def to_dict(self) -> Dict: - """Convert to dictionary""" - return { - "operation_id": self.operation_id, - "timestamp": self.timestamp.isoformat(), - "operation_type": self.operation_type, - "num_documents": self.num_documents, - "processing_time_ms": self.processing_time_ms, - "success": self.success, - "error": self.error, - } - - -@dataclass -class SystemMetrics: - """System-wide metrics""" - - total_queries: int = 0 - total_indexed_documents: int = 0 - cache_hit_rate: float = 0.0 - avg_retrieval_time_ms: float = 0.0 - avg_indexing_time_ms: float = 0.0 - error_rate: float = 0.0 - last_updated: Optional[datetime] = None - - def to_dict(self) -> Dict: - 
"""Convert to dictionary""" - return { - "total_queries": self.total_queries, - "total_indexed_documents": self.total_indexed_documents, - "cache_hit_rate": self.cache_hit_rate, - "avg_retrieval_time_ms": self.avg_retrieval_time_ms, - "avg_indexing_time_ms": self.avg_indexing_time_ms, - "error_rate": self.error_rate, - "last_updated": ( - self.last_updated.isoformat() if self.last_updated else None - ), - } - - -class RAGMonitor: - """ - Monitor RAG system performance and collect metrics - """ - - def __init__(self, max_history: int = 10000): - self.max_history = max_history - - # Metrics storage - self.retrieval_metrics: List[RetrievalMetrics] = [] - self.indexing_metrics: List[IndexingMetrics] = [] - - # Aggregated stats - self.system_metrics = SystemMetrics() - - # Time windows for analysis - self.hourly_stats: Dict[str, Dict] = defaultdict(dict) - self.daily_stats: Dict[str, Dict] = defaultdict(dict) - - def record_retrieval( - self, - query: RAGQuery, - results: List[RetrievalResult], - retrieval_time_ms: float, - cache_hit: bool = False, - reranking_used: bool = False, - query_expansion_used: bool = False, - ): - """Record retrieval metrics""" - query_id = f"q_{int(time.time() * 1000)}" - - metric = RetrievalMetrics( - query_id=query_id, - query_text=query.query, - timestamp=datetime.now(), - retrieval_time_ms=retrieval_time_ms, - num_results=len(results), - top_score=results[0].score if results else None, - cache_hit=cache_hit, - reranking_used=reranking_used, - query_expansion_used=query_expansion_used, - industry=( - query.industry.value - if query.industry and hasattr(query.industry, "value") - else str(query.industry) - ), - doc_types=( - [ - dt.value if hasattr(dt, "value") else str(dt) - for dt in query.doc_types - ] - if query.doc_types - else None - ), - ) - - self.retrieval_metrics.append(metric) - - # Trim history - if len(self.retrieval_metrics) > self.max_history: - self.retrieval_metrics = self.retrieval_metrics[-self.max_history :] - - # Update 
aggregated stats - self._update_system_metrics() - - def record_indexing( - self, - operation_type: str, - num_documents: int, - processing_time_ms: float, - success: bool, - error: Optional[str] = None, - ): - """Record indexing metrics""" - operation_id = f"idx_{int(time.time() * 1000)}" - - metric = IndexingMetrics( - operation_id=operation_id, - timestamp=datetime.now(), - operation_type=operation_type, - num_documents=num_documents, - processing_time_ms=processing_time_ms, - success=success, - error=error, - ) - - self.indexing_metrics.append(metric) - - # Trim history - if len(self.indexing_metrics) > self.max_history: - self.indexing_metrics = self.indexing_metrics[-self.max_history :] - - # Update aggregated stats - self._update_system_metrics() - - def _update_system_metrics(self): - """Update system-wide aggregated metrics""" - if not self.retrieval_metrics: - return - - # Calculate averages - total_queries = len(self.retrieval_metrics) - cache_hits = sum(1 for m in self.retrieval_metrics if m.cache_hit) - total_retrieval_time = sum(m.retrieval_time_ms for m in self.retrieval_metrics) - - self.system_metrics.total_queries = total_queries - self.system_metrics.cache_hit_rate = ( - cache_hits / total_queries if total_queries > 0 else 0.0 - ) - self.system_metrics.avg_retrieval_time_ms = ( - total_retrieval_time / total_queries if total_queries > 0 else 0.0 - ) - - if self.indexing_metrics: - total_indexed = sum( - m.num_documents for m in self.indexing_metrics if m.success - ) - total_indexing_time = sum( - m.processing_time_ms for m in self.indexing_metrics - ) - total_indexing_ops = len(self.indexing_metrics) - - self.system_metrics.total_indexed_documents = total_indexed - self.system_metrics.avg_indexing_time_ms = ( - total_indexing_time / total_indexing_ops - if total_indexing_ops > 0 - else 0.0 - ) - - # Error rate - total_ops = total_queries + len(self.indexing_metrics) - errors = sum(1 for m in self.indexing_metrics if not m.success) - 
self.system_metrics.error_rate = errors / total_ops if total_ops > 0 else 0.0 - - self.system_metrics.last_updated = datetime.now() - - def get_retrieval_stats( - self, time_window_minutes: Optional[int] = None, industry: Optional[str] = None - ) -> Dict[str, Any]: - """Get retrieval statistics for time window""" - metrics = self.retrieval_metrics - - # Filter by time window - if time_window_minutes: - cutoff = datetime.now() - timedelta(minutes=time_window_minutes) - metrics = [m for m in metrics if m.timestamp >= cutoff] - - # Filter by industry - if industry: - metrics = [m for m in metrics if m.industry == industry] - - if not metrics: - return { - "count": 0, - "avg_time_ms": 0.0, - "cache_hit_rate": 0.0, - } - - return { - "count": len(metrics), - "avg_time_ms": sum(m.retrieval_time_ms for m in metrics) / len(metrics), - "min_time_ms": min(m.retrieval_time_ms for m in metrics), - "max_time_ms": max(m.retrieval_time_ms for m in metrics), - "cache_hit_rate": sum(1 for m in metrics if m.cache_hit) / len(metrics), - "avg_results": sum(m.num_results for m in metrics) / len(metrics), - "avg_top_score": sum(m.top_score for m in metrics if m.top_score) - / len([m for m in metrics if m.top_score]), - } - - def get_indexing_stats( - self, time_window_minutes: Optional[int] = None - ) -> Dict[str, Any]: - """Get indexing statistics""" - metrics = self.indexing_metrics - - # Filter by time window - if time_window_minutes: - cutoff = datetime.now() - timedelta(minutes=time_window_minutes) - metrics = [m for m in metrics if m.timestamp >= cutoff] - - if not metrics: - return { - "count": 0, - "total_documents": 0, - "avg_time_ms": 0.0, - "success_rate": 0.0, - } - - successful = [m for m in metrics if m.success] - - return { - "count": len(metrics), - "total_documents": sum(m.num_documents for m in successful), - "avg_time_ms": sum(m.processing_time_ms for m in metrics) / len(metrics), - "success_rate": len(successful) / len(metrics) if metrics else 0.0, - "error_count": 
len([m for m in metrics if not m.success]), - } - - def get_top_queries( - self, limit: int = 10, time_window_minutes: Optional[int] = None - ) -> List[Dict[str, Any]]: - """Get most frequent queries""" - metrics = self.retrieval_metrics - - # Filter by time window - if time_window_minutes: - cutoff = datetime.now() - timedelta(minutes=time_window_minutes) - metrics = [m for m in metrics if m.timestamp >= cutoff] - - # Count query frequencies - query_counts: Dict[str, int] = defaultdict(int) - for m in metrics: - query_counts[m.query_text] += 1 - - # Sort by frequency - top_queries = sorted(query_counts.items(), key=lambda x: x[1], reverse=True)[ - :limit - ] - - return [{"query": query, "count": count} for query, count in top_queries] - - def get_performance_report(self) -> Dict[str, Any]: - """Get comprehensive performance report""" - return { - "system_metrics": self.system_metrics.to_dict(), - "retrieval_stats_1h": self.get_retrieval_stats(time_window_minutes=60), - "retrieval_stats_24h": self.get_retrieval_stats(time_window_minutes=1440), - "indexing_stats_1h": self.get_indexing_stats(time_window_minutes=60), - "indexing_stats_24h": self.get_indexing_stats(time_window_minutes=1440), - "top_queries_24h": self.get_top_queries(limit=10, time_window_minutes=1440), - } - - def export_metrics(self, filepath: str): - """Export metrics to JSON file""" - data = { - "retrieval_metrics": [ - m.to_dict() for m in self.retrieval_metrics[-1000:] - ], # Last 1000 - "indexing_metrics": [ - m.to_dict() for m in self.indexing_metrics[-1000:] - ], # Last 1000 - "system_metrics": self.system_metrics.to_dict(), - "exported_at": datetime.now().isoformat(), - } - - with open(filepath, "w") as f: - json.dump(data, f, indent=2) - - -class PerformanceTimer: - """Context manager for timing operations""" - - def __init__( - self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation" - ): - self.monitor = monitor - self.operation_name = operation_name - self.start_time = 
None - self.end_time = None - - def __enter__(self): - self.start_time = time.time() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.end_time = time.time() - return False - - def elapsed_ms(self) -> float: - """Get elapsed time in milliseconds""" - if self.start_time and self.end_time: - return (self.end_time - self.start_time) * 1000 - elif self.start_time: - return (time.time() - self.start_time) * 1000 - return 0.0 +""" +Monitoring and Observability for Universal RAG +Metrics, performance tracking, and analytics +""" + +from typing import Dict, Any, List, Optional +from dataclasses import dataclass, field, asdict +from datetime import datetime, timedelta +from collections import defaultdict +import time +import json + +from .base import RAGQuery, RetrievalResult + + +@dataclass +class RetrievalMetrics: + """Metrics for a single retrieval operation""" + query_id: str + query_text: str + timestamp: datetime + retrieval_time_ms: float + num_results: int + top_score: Optional[float] = None + cache_hit: bool = False + reranking_used: bool = False + query_expansion_used: bool = False + industry: Optional[str] = None + doc_types: Optional[List[str]] = None + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "query_id": self.query_id, + "query_text": self.query_text, + "timestamp": self.timestamp.isoformat(), + "retrieval_time_ms": self.retrieval_time_ms, + "num_results": self.num_results, + "top_score": self.top_score, + "cache_hit": self.cache_hit, + "reranking_used": self.reranking_used, + "query_expansion_used": self.query_expansion_used, + "industry": self.industry, + "doc_types": self.doc_types, + } + + +@dataclass +class IndexingMetrics: + """Metrics for indexing operations""" + operation_id: str + timestamp: datetime + operation_type: str # "index", "update", "delete" + num_documents: int + processing_time_ms: float + success: bool + error: Optional[str] = None + + def to_dict(self) -> Dict: + """Convert to 
dictionary""" + return { + "operation_id": self.operation_id, + "timestamp": self.timestamp.isoformat(), + "operation_type": self.operation_type, + "num_documents": self.num_documents, + "processing_time_ms": self.processing_time_ms, + "success": self.success, + "error": self.error, + } + + +@dataclass +class SystemMetrics: + """System-wide metrics""" + total_queries: int = 0 + total_indexed_documents: int = 0 + cache_hit_rate: float = 0.0 + avg_retrieval_time_ms: float = 0.0 + avg_indexing_time_ms: float = 0.0 + error_rate: float = 0.0 + last_updated: Optional[datetime] = None + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + "total_queries": self.total_queries, + "total_indexed_documents": self.total_indexed_documents, + "cache_hit_rate": self.cache_hit_rate, + "avg_retrieval_time_ms": self.avg_retrieval_time_ms, + "avg_indexing_time_ms": self.avg_indexing_time_ms, + "error_rate": self.error_rate, + "last_updated": self.last_updated.isoformat() if self.last_updated else None, + } + + +class RAGMonitor: + """ + Monitor RAG system performance and collect metrics + """ + + def __init__(self, max_history: int = 10000): + self.max_history = max_history + + # Metrics storage + self.retrieval_metrics: List[RetrievalMetrics] = [] + self.indexing_metrics: List[IndexingMetrics] = [] + + # Aggregated stats + self.system_metrics = SystemMetrics() + + # Time windows for analysis + self.hourly_stats: Dict[str, Dict] = defaultdict(dict) + self.daily_stats: Dict[str, Dict] = defaultdict(dict) + + def record_retrieval( + self, + query: RAGQuery, + results: List[RetrievalResult], + retrieval_time_ms: float, + cache_hit: bool = False, + reranking_used: bool = False, + query_expansion_used: bool = False + ): + """Record retrieval metrics""" + query_id = f"q_{int(time.time() * 1000)}" + + metric = RetrievalMetrics( + query_id=query_id, + query_text=query.query, + timestamp=datetime.now(), + retrieval_time_ms=retrieval_time_ms, + num_results=len(results), + 
top_score=results[0].score if results else None, + cache_hit=cache_hit, + reranking_used=reranking_used, + query_expansion_used=query_expansion_used, + industry=query.industry.value if query.industry and hasattr(query.industry, 'value') else str(query.industry), + doc_types=[dt.value if hasattr(dt, 'value') else str(dt) for dt in query.doc_types] if query.doc_types else None, + ) + + self.retrieval_metrics.append(metric) + + # Trim history + if len(self.retrieval_metrics) > self.max_history: + self.retrieval_metrics = self.retrieval_metrics[-self.max_history:] + + # Update aggregated stats + self._update_system_metrics() + + def record_indexing( + self, + operation_type: str, + num_documents: int, + processing_time_ms: float, + success: bool, + error: Optional[str] = None + ): + """Record indexing metrics""" + operation_id = f"idx_{int(time.time() * 1000)}" + + metric = IndexingMetrics( + operation_id=operation_id, + timestamp=datetime.now(), + operation_type=operation_type, + num_documents=num_documents, + processing_time_ms=processing_time_ms, + success=success, + error=error + ) + + self.indexing_metrics.append(metric) + + # Trim history + if len(self.indexing_metrics) > self.max_history: + self.indexing_metrics = self.indexing_metrics[-self.max_history:] + + # Update aggregated stats + self._update_system_metrics() + + def _update_system_metrics(self): + """Update system-wide aggregated metrics""" + if not self.retrieval_metrics: + return + + # Calculate averages + total_queries = len(self.retrieval_metrics) + cache_hits = sum(1 for m in self.retrieval_metrics if m.cache_hit) + total_retrieval_time = sum(m.retrieval_time_ms for m in self.retrieval_metrics) + + self.system_metrics.total_queries = total_queries + self.system_metrics.cache_hit_rate = cache_hits / total_queries if total_queries > 0 else 0.0 + self.system_metrics.avg_retrieval_time_ms = total_retrieval_time / total_queries if total_queries > 0 else 0.0 + + if self.indexing_metrics: + total_indexed = 
sum(m.num_documents for m in self.indexing_metrics if m.success) + total_indexing_time = sum(m.processing_time_ms for m in self.indexing_metrics) + total_indexing_ops = len(self.indexing_metrics) + + self.system_metrics.total_indexed_documents = total_indexed + self.system_metrics.avg_indexing_time_ms = total_indexing_time / total_indexing_ops if total_indexing_ops > 0 else 0.0 + + # Error rate + total_ops = total_queries + len(self.indexing_metrics) + errors = sum(1 for m in self.indexing_metrics if not m.success) + self.system_metrics.error_rate = errors / total_ops if total_ops > 0 else 0.0 + + self.system_metrics.last_updated = datetime.now() + + def get_retrieval_stats( + self, + time_window_minutes: Optional[int] = None, + industry: Optional[str] = None + ) -> Dict[str, Any]: + """Get retrieval statistics for time window""" + metrics = self.retrieval_metrics + + # Filter by time window + if time_window_minutes: + cutoff = datetime.now() - timedelta(minutes=time_window_minutes) + metrics = [m for m in metrics if m.timestamp >= cutoff] + + # Filter by industry + if industry: + metrics = [m for m in metrics if m.industry == industry] + + if not metrics: + return { + "count": 0, + "avg_time_ms": 0.0, + "cache_hit_rate": 0.0, + } + + return { + "count": len(metrics), + "avg_time_ms": sum(m.retrieval_time_ms for m in metrics) / len(metrics), + "min_time_ms": min(m.retrieval_time_ms for m in metrics), + "max_time_ms": max(m.retrieval_time_ms for m in metrics), + "cache_hit_rate": sum(1 for m in metrics if m.cache_hit) / len(metrics), + "avg_results": sum(m.num_results for m in metrics) / len(metrics), + "avg_top_score": sum(m.top_score for m in metrics if m.top_score) / len([m for m in metrics if m.top_score]), + } + + def get_indexing_stats( + self, + time_window_minutes: Optional[int] = None + ) -> Dict[str, Any]: + """Get indexing statistics""" + metrics = self.indexing_metrics + + # Filter by time window + if time_window_minutes: + cutoff = datetime.now() - 
timedelta(minutes=time_window_minutes) + metrics = [m for m in metrics if m.timestamp >= cutoff] + + if not metrics: + return { + "count": 0, + "total_documents": 0, + "avg_time_ms": 0.0, + "success_rate": 0.0, + } + + successful = [m for m in metrics if m.success] + + return { + "count": len(metrics), + "total_documents": sum(m.num_documents for m in successful), + "avg_time_ms": sum(m.processing_time_ms for m in metrics) / len(metrics), + "success_rate": len(successful) / len(metrics) if metrics else 0.0, + "error_count": len([m for m in metrics if not m.success]), + } + + def get_top_queries( + self, + limit: int = 10, + time_window_minutes: Optional[int] = None + ) -> List[Dict[str, Any]]: + """Get most frequent queries""" + metrics = self.retrieval_metrics + + # Filter by time window + if time_window_minutes: + cutoff = datetime.now() - timedelta(minutes=time_window_minutes) + metrics = [m for m in metrics if m.timestamp >= cutoff] + + # Count query frequencies + query_counts: Dict[str, int] = defaultdict(int) + for m in metrics: + query_counts[m.query_text] += 1 + + # Sort by frequency + top_queries = sorted( + query_counts.items(), + key=lambda x: x[1], + reverse=True + )[:limit] + + return [ + { + "query": query, + "count": count + } + for query, count in top_queries + ] + + def get_performance_report(self) -> Dict[str, Any]: + """Get comprehensive performance report""" + return { + "system_metrics": self.system_metrics.to_dict(), + "retrieval_stats_1h": self.get_retrieval_stats(time_window_minutes=60), + "retrieval_stats_24h": self.get_retrieval_stats(time_window_minutes=1440), + "indexing_stats_1h": self.get_indexing_stats(time_window_minutes=60), + "indexing_stats_24h": self.get_indexing_stats(time_window_minutes=1440), + "top_queries_24h": self.get_top_queries(limit=10, time_window_minutes=1440), + } + + def export_metrics(self, filepath: str): + """Export metrics to JSON file""" + data = { + "retrieval_metrics": [m.to_dict() for m in 
self.retrieval_metrics[-1000:]], # Last 1000 + "indexing_metrics": [m.to_dict() for m in self.indexing_metrics[-1000:]], # Last 1000 + "system_metrics": self.system_metrics.to_dict(), + "exported_at": datetime.now().isoformat(), + } + + with open(filepath, 'w') as f: + json.dump(data, f, indent=2) + + +class PerformanceTimer: + """Context manager for timing operations""" + + def __init__(self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation"): + self.monitor = monitor + self.operation_name = operation_name + self.start_time = None + self.end_time = None + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + return False + + def elapsed_ms(self) -> float: + """Get elapsed time in milliseconds""" + if self.start_time and self.end_time: + return (self.end_time - self.start_time) * 1000 + elif self.start_time: + return (time.time() - self.start_time) * 1000 + return 0.0 + diff --git a/src/deepiri_modelkit/rag/processors.py b/src/deepiri_modelkit/rag/processors.py index 2d617ee..3fdb477 100644 --- a/src/deepiri_modelkit/rag/processors.py +++ b/src/deepiri_modelkit/rag/processors.py @@ -1,443 +1,423 @@ -""" -Document Processors for Universal RAG -Handles preprocessing, chunking, and metadata extraction for different document types -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional -import re -from datetime import datetime - -from .base import Document, DocumentType, IndustryNiche - - -class DocumentProcessor(ABC): - """Base class for document processing""" - - def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - - @abstractmethod - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """ - Process raw content into structured documents - - Args: - raw_content: Raw text content - metadata: Document 
metadata - - Returns: - List of processed document chunks - """ - pass - - def chunk_text(self, text: str) -> List[str]: - """ - Chunk text into smaller pieces with overlap - - Args: - text: Text to chunk - - Returns: - List of text chunks - """ - chunks = [] - start = 0 - text_length = len(text) - - while start < text_length: - end = start + self.chunk_size - - # Find the last sentence boundary within chunk_size - if end < text_length: - # Look for sentence endings - for sep in [". ", ".\n", "! ", "!\n", "? ", "?\n"]: - last_sep = text.rfind(sep, start, end) - if last_sep != -1: - end = last_sep + 1 - break - - chunk = text[start:end].strip() - if chunk: - chunks.append(chunk) - - # Move start forward with overlap - start = ( - end - self.chunk_overlap if end - self.chunk_overlap > start else end - ) - - return chunks - - def extract_metadata(self, content: str) -> Dict[str, Any]: - """Extract metadata from content (can be overridden)""" - return {} - - -class RegulationProcessor(DocumentProcessor): - """ - Processor for regulations, policies, and compliance documents - Common across insurance, healthcare, manufacturing, etc. 
- """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process regulation documents""" - # Extract sections and subsections - sections = self._extract_sections(raw_content) - - documents = [] - base_id = metadata.get("id", "reg_" + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get("industry", "generic")) - - for idx, section in enumerate(sections): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=section["content"], - doc_type=DocumentType.REGULATION, - industry=industry, - title=metadata.get("title", "Regulation Document"), - source=metadata.get("source"), - created_at=self._parse_date(metadata.get("created_at")), - metadata={ - **metadata, - "section": section.get("section"), - "subsection": section.get("subsection"), - }, - chunk_index=idx, - total_chunks=len(sections), - ) - documents.append(doc) - - return documents - - def _extract_sections(self, content: str) -> List[Dict[str, Any]]: - """Extract sections from regulation text""" - # Match section headers like "Section 1.2.3", "Article 5", etc. 
- section_pattern = r"(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)" - - sections = [] - current_section = {"section": None, "content": ""} - - lines = content.split("\n") - for line in lines: - match = re.search(section_pattern, line, re.IGNORECASE) - if match: - # Save previous section - if current_section["content"]: - sections.append(current_section) - # Start new section - current_section = {"section": match.group(0), "content": line + "\n"} - else: - current_section["content"] += line + "\n" - - # Add last section - if current_section["content"]: - sections.append(current_section) - - # If no sections found, treat entire content as one chunk - if not sections: - chunks = self.chunk_text(content) - sections = [ - {"section": f"Chunk {i+1}", "content": chunk} - for i, chunk in enumerate(chunks) - ] - - return sections - - def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: - """Parse date string to datetime""" - if not date_str: - return None - try: - return datetime.fromisoformat(date_str) - except (ValueError, AttributeError): - return None - - -class HistoricalDataProcessor(DocumentProcessor): - """ - Processor for historical operational data - - Work orders, maintenance logs, claim records, service history - """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process historical data records""" - doc_type_str = metadata.get("doc_type", "work_order") - doc_type = ( - DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER - ) - - # Historical data is typically already structured - # We may need minimal chunking - chunks = ( - self.chunk_text(raw_content) - if len(raw_content) > self.chunk_size - else [raw_content] - ) - - documents = [] - base_id = metadata.get("id", "hist_" + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get("industry", "generic")) - - for idx, chunk in enumerate(chunks): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=chunk, - 
doc_type=doc_type, - industry=industry, - title=metadata.get( - "title", f'{doc_type.value.replace("_", " ").title()}' - ), - source=metadata.get("source"), - created_at=self._parse_date(metadata.get("created_at")), - updated_at=self._parse_date(metadata.get("updated_at")), - metadata={ - **metadata, - "record_type": doc_type.value, - "record_id": metadata.get("record_id"), - "status": metadata.get("status"), - "priority": metadata.get("priority"), - "assigned_to": metadata.get("assigned_to"), - }, - chunk_index=idx, - total_chunks=len(chunks), - ) - documents.append(doc) - - return documents - - def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: - """Parse date string to datetime""" - if not date_str: - return None - if isinstance(date_str, datetime): - return date_str - try: - return datetime.fromisoformat(date_str) - except (ValueError, AttributeError): - try: - return datetime.strptime(date_str, "%Y-%m-%d") - except (ValueError, AttributeError): - return None - - -class KnowledgeBaseProcessor(DocumentProcessor): - """ - Processor for knowledge base articles, FAQs, and guides - - Equipment repair guides, compliance advice, troubleshooting steps - """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process knowledge base articles""" - doc_type = ( - DocumentType.KNOWLEDGE_BASE - if metadata.get("doc_type") == "knowledge_base" - else DocumentType.FAQ - ) - - # Extract Q&A pairs if FAQ format - if doc_type == DocumentType.FAQ: - qa_pairs = self._extract_qa_pairs(raw_content) - if qa_pairs: - return self._process_qa_pairs(qa_pairs, metadata) - - # Otherwise, process as regular article - chunks = self.chunk_text(raw_content) - - documents = [] - base_id = metadata.get("id", "kb_" + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get("industry", "generic")) - - for idx, chunk in enumerate(chunks): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=chunk, - doc_type=doc_type, - 
industry=industry, - title=metadata.get("title", "Knowledge Base Article"), - source=metadata.get("source"), - created_at=self._parse_date(metadata.get("created_at")), - updated_at=self._parse_date(metadata.get("updated_at")), - author=metadata.get("author"), - metadata={ - **metadata, - "category": metadata.get("category"), - "tags": metadata.get("tags", []), - "difficulty_level": metadata.get("difficulty_level"), - }, - chunk_index=idx, - total_chunks=len(chunks), - ) - documents.append(doc) - - return documents - - def _extract_qa_pairs(self, content: str) -> List[Dict[str, str]]: - """Extract Q&A pairs from FAQ content""" - qa_pairs = [] - - # Try different FAQ formats - # Format 1: Q: ... A: ... - pattern1 = r"Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)" - matches = re.findall(pattern1, content, re.DOTALL | re.IGNORECASE) - if matches: - qa_pairs.extend( - [{"question": q.strip(), "answer": a.strip()} for q, a in matches] - ) - - # Format 2: Question/Answer headers - pattern2 = r"Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)" - matches = re.findall(pattern2, content, re.DOTALL | re.IGNORECASE) - if matches: - qa_pairs.extend( - [{"question": q.strip(), "answer": a.strip()} for q, a in matches] - ) - - return qa_pairs - - def _process_qa_pairs( - self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any] - ) -> List[Document]: - """Process Q&A pairs into documents""" - documents = [] - base_id = metadata.get("id", "faq_" + str(hash(str(qa_pairs[0])))) - industry = IndustryNiche(metadata.get("industry", "generic")) - - for idx, qa in enumerate(qa_pairs): - content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}" - doc = Document( - id=f"{base_id}_qa_{idx}", - content=content, - doc_type=DocumentType.FAQ, - industry=industry, - title=qa["question"][:100], # Use question as title - source=metadata.get("source"), - metadata={ - **metadata, - "question": qa["question"], - "answer": qa["answer"], - }, - chunk_index=idx, - total_chunks=len(qa_pairs), - ) - 
documents.append(doc) - - return documents - - def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: - """Parse date string to datetime""" - if not date_str: - return None - if isinstance(date_str, datetime): - return date_str - try: - return datetime.fromisoformat(date_str) - except (ValueError, AttributeError): - return None - - -class ManualProcessor(DocumentProcessor): - """ - Processor for equipment manuals, operation guides, technical specifications - """ - - def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: - """Process manual documents""" - # Extract chapters and sections - sections = self._extract_sections(raw_content) - - documents = [] - base_id = metadata.get("id", "manual_" + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get("industry", "generic")) - - for idx, section in enumerate(sections): - doc = Document( - id=f"{base_id}_chunk_{idx}", - content=section["content"], - doc_type=DocumentType.MANUAL, - industry=industry, - title=metadata.get("title", "Equipment Manual"), - source=metadata.get("source"), - version=metadata.get("version"), - metadata={ - **metadata, - "chapter": section.get("chapter"), - "section": section.get("section"), - "equipment_model": metadata.get("equipment_model"), - "manufacturer": metadata.get("manufacturer"), - }, - chunk_index=idx, - total_chunks=len(sections), - ) - documents.append(doc) - - return documents - - def _extract_sections(self, content: str) -> List[Dict[str, Any]]: - """Extract sections from manual text""" - # Match chapter/section headers - section_pattern = r"(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)" - - sections = [] - current_section = {"chapter": None, "section": None, "content": ""} - - lines = content.split("\n") - for line in lines: - match = re.search(section_pattern, line, re.IGNORECASE) - if match: - # Save previous section - if current_section["content"]: - sections.append(current_section) - # Start new section - 
section_type = match.group(1).lower() - section_num = match.group(2) - section_title = match.group(3).strip() if match.group(3) else "" - current_section = { - section_type: f"{section_type.title()} {section_num}", - "section_title": section_title, - "content": line + "\n", - } - else: - current_section["content"] += line + "\n" - - # Add last section - if current_section["content"]: - sections.append(current_section) - - # If no sections found, chunk the content - if not sections: - chunks = self.chunk_text(content) - sections = [ - {"section": f"Chunk {i+1}", "content": chunk} - for i, chunk in enumerate(chunks) - ] - - return sections - - -def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: - """ - Factory function to get appropriate processor for document type - - Args: - doc_type: Type of document - **kwargs: Additional configuration for processor - - Returns: - Configured document processor - """ - processor_map = { - DocumentType.REGULATION: RegulationProcessor, - DocumentType.POLICY: RegulationProcessor, # Similar processing - DocumentType.WORK_ORDER: HistoricalDataProcessor, - DocumentType.CLAIM_RECORD: HistoricalDataProcessor, - DocumentType.MAINTENANCE_LOG: HistoricalDataProcessor, - DocumentType.FAQ: KnowledgeBaseProcessor, - DocumentType.KNOWLEDGE_BASE: KnowledgeBaseProcessor, - DocumentType.MANUAL: ManualProcessor, - DocumentType.TECHNICAL_SPEC: ManualProcessor, # Similar processing - DocumentType.PROCEDURE: ManualProcessor, # Similar processing - } - - processor_class = processor_map.get(doc_type, DocumentProcessor) - return processor_class(**kwargs) +""" +Document Processors for Universal RAG +Handles preprocessing, chunking, and metadata extraction for different document types +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +import re +from datetime import datetime + +from .base import Document, DocumentType, IndustryNiche + + +class DocumentProcessor(ABC): + """Base class for document 
processing""" + + def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + @abstractmethod + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """ + Process raw content into structured documents + + Args: + raw_content: Raw text content + metadata: Document metadata + + Returns: + List of processed document chunks + """ + pass + + def chunk_text(self, text: str) -> List[str]: + """ + Chunk text into smaller pieces with overlap + + Args: + text: Text to chunk + + Returns: + List of text chunks + """ + chunks = [] + start = 0 + text_length = len(text) + + while start < text_length: + end = start + self.chunk_size + + # Find the last sentence boundary within chunk_size + if end < text_length: + # Look for sentence endings + for sep in ['. ', '.\n', '! ', '!\n', '? ', '?\n']: + last_sep = text.rfind(sep, start, end) + if last_sep != -1: + end = last_sep + 1 + break + + chunk = text[start:end].strip() + if chunk: + chunks.append(chunk) + + # Move start forward with overlap + start = end - self.chunk_overlap if end - self.chunk_overlap > start else end + + return chunks + + def extract_metadata(self, content: str) -> Dict[str, Any]: + """Extract metadata from content (can be overridden)""" + return {} + + +class RegulationProcessor(DocumentProcessor): + """ + Processor for regulations, policies, and compliance documents + Common across insurance, healthcare, manufacturing, etc. 
+ """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process regulation documents""" + # Extract sections and subsections + sections = self._extract_sections(raw_content) + + documents = [] + base_id = metadata.get('id', 'reg_' + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get('industry', 'generic')) + + for idx, section in enumerate(sections): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=section['content'], + doc_type=DocumentType.REGULATION, + industry=industry, + title=metadata.get('title', 'Regulation Document'), + source=metadata.get('source'), + created_at=self._parse_date(metadata.get('created_at')), + metadata={ + **metadata, + 'section': section.get('section'), + 'subsection': section.get('subsection'), + }, + chunk_index=idx, + total_chunks=len(sections), + ) + documents.append(doc) + + return documents + + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: + """Extract sections from regulation text""" + # Match section headers like "Section 1.2.3", "Article 5", etc. 
+ section_pattern = r'(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)' + + sections = [] + current_section = {'section': None, 'content': ''} + + lines = content.split('\n') + for line in lines: + match = re.search(section_pattern, line, re.IGNORECASE) + if match: + # Save previous section + if current_section['content']: + sections.append(current_section) + # Start new section + current_section = { + 'section': match.group(0), + 'content': line + '\n' + } + else: + current_section['content'] += line + '\n' + + # Add last section + if current_section['content']: + sections.append(current_section) + + # If no sections found, treat entire content as one chunk + if not sections: + chunks = self.chunk_text(content) + sections = [{'section': f'Chunk {i+1}', 'content': chunk} + for i, chunk in enumerate(chunks)] + + return sections + + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: + """Parse date string to datetime""" + if not date_str: + return None + try: + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError): + return None + + +class HistoricalDataProcessor(DocumentProcessor): + """ + Processor for historical operational data + - Work orders, maintenance logs, claim records, service history + """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process historical data records""" + doc_type_str = metadata.get('doc_type', 'work_order') + doc_type = DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER + + # Historical data is typically already structured + # We may need minimal chunking + chunks = self.chunk_text(raw_content) if len(raw_content) > self.chunk_size else [raw_content] + + documents = [] + base_id = metadata.get('id', 'hist_' + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get('industry', 'generic')) + + for idx, chunk in enumerate(chunks): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=chunk, + doc_type=doc_type, + 
industry=industry, + title=metadata.get('title', f'{doc_type.value.replace("_", " ").title()}'), + source=metadata.get('source'), + created_at=self._parse_date(metadata.get('created_at')), + updated_at=self._parse_date(metadata.get('updated_at')), + metadata={ + **metadata, + 'record_type': doc_type.value, + 'record_id': metadata.get('record_id'), + 'status': metadata.get('status'), + 'priority': metadata.get('priority'), + 'assigned_to': metadata.get('assigned_to'), + }, + chunk_index=idx, + total_chunks=len(chunks), + ) + documents.append(doc) + + return documents + + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: + """Parse date string to datetime""" + if not date_str: + return None + if isinstance(date_str, datetime): + return date_str + try: + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError): + try: + return datetime.strptime(date_str, '%Y-%m-%d') + except (ValueError, AttributeError): + return None + + +class KnowledgeBaseProcessor(DocumentProcessor): + """ + Processor for knowledge base articles, FAQs, and guides + - Equipment repair guides, compliance advice, troubleshooting steps + """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process knowledge base articles""" + doc_type = DocumentType.KNOWLEDGE_BASE if metadata.get('doc_type') == 'knowledge_base' else DocumentType.FAQ + + # Extract Q&A pairs if FAQ format + if doc_type == DocumentType.FAQ: + qa_pairs = self._extract_qa_pairs(raw_content) + if qa_pairs: + return self._process_qa_pairs(qa_pairs, metadata) + + # Otherwise, process as regular article + chunks = self.chunk_text(raw_content) + + documents = [] + base_id = metadata.get('id', 'kb_' + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get('industry', 'generic')) + + for idx, chunk in enumerate(chunks): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=chunk, + doc_type=doc_type, + industry=industry, + 
title=metadata.get('title', 'Knowledge Base Article'), + source=metadata.get('source'), + created_at=self._parse_date(metadata.get('created_at')), + updated_at=self._parse_date(metadata.get('updated_at')), + author=metadata.get('author'), + metadata={ + **metadata, + 'category': metadata.get('category'), + 'tags': metadata.get('tags', []), + 'difficulty_level': metadata.get('difficulty_level'), + }, + chunk_index=idx, + total_chunks=len(chunks), + ) + documents.append(doc) + + return documents + + def _extract_qa_pairs(self, content: str) -> List[Dict[str, str]]: + """Extract Q&A pairs from FAQ content""" + qa_pairs = [] + + # Try different FAQ formats + # Format 1: Q: ... A: ... + pattern1 = r'Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)' + matches = re.findall(pattern1, content, re.DOTALL | re.IGNORECASE) + if matches: + qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) + + # Format 2: Question/Answer headers + pattern2 = r'Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)' + matches = re.findall(pattern2, content, re.DOTALL | re.IGNORECASE) + if matches: + qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) + + return qa_pairs + + def _process_qa_pairs(self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any]) -> List[Document]: + """Process Q&A pairs into documents""" + documents = [] + base_id = metadata.get('id', 'faq_' + str(hash(str(qa_pairs[0])))) + industry = IndustryNiche(metadata.get('industry', 'generic')) + + for idx, qa in enumerate(qa_pairs): + content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}" + doc = Document( + id=f"{base_id}_qa_{idx}", + content=content, + doc_type=DocumentType.FAQ, + industry=industry, + title=qa['question'][:100], # Use question as title + source=metadata.get('source'), + metadata={ + **metadata, + 'question': qa['question'], + 'answer': qa['answer'], + }, + chunk_index=idx, + total_chunks=len(qa_pairs), + ) + documents.append(doc) + + return documents + 
+ def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: + """Parse date string to datetime""" + if not date_str: + return None + if isinstance(date_str, datetime): + return date_str + try: + return datetime.fromisoformat(date_str) + except (ValueError, AttributeError): + return None + + +class ManualProcessor(DocumentProcessor): + """ + Processor for equipment manuals, operation guides, technical specifications + """ + + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: + """Process manual documents""" + # Extract chapters and sections + sections = self._extract_sections(raw_content) + + documents = [] + base_id = metadata.get('id', 'manual_' + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get('industry', 'generic')) + + for idx, section in enumerate(sections): + doc = Document( + id=f"{base_id}_chunk_{idx}", + content=section['content'], + doc_type=DocumentType.MANUAL, + industry=industry, + title=metadata.get('title', 'Equipment Manual'), + source=metadata.get('source'), + version=metadata.get('version'), + metadata={ + **metadata, + 'chapter': section.get('chapter'), + 'section': section.get('section'), + 'equipment_model': metadata.get('equipment_model'), + 'manufacturer': metadata.get('manufacturer'), + }, + chunk_index=idx, + total_chunks=len(sections), + ) + documents.append(doc) + + return documents + + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: + """Extract sections from manual text""" + # Match chapter/section headers + section_pattern = r'(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)' + + sections = [] + current_section = {'chapter': None, 'section': None, 'content': ''} + + lines = content.split('\n') + for line in lines: + match = re.search(section_pattern, line, re.IGNORECASE) + if match: + # Save previous section + if current_section['content']: + sections.append(current_section) + # Start new section + section_type = match.group(1).lower() + section_num 
= match.group(2) + section_title = match.group(3).strip() if match.group(3) else '' + current_section = { + section_type: f"{section_type.title()} {section_num}", + 'section_title': section_title, + 'content': line + '\n' + } + else: + current_section['content'] += line + '\n' + + # Add last section + if current_section['content']: + sections.append(current_section) + + # If no sections found, chunk the content + if not sections: + chunks = self.chunk_text(content) + sections = [{'section': f'Chunk {i+1}', 'content': chunk} + for i, chunk in enumerate(chunks)] + + return sections + + +def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: + """ + Factory function to get appropriate processor for document type + + Args: + doc_type: Type of document + **kwargs: Additional configuration for processor + + Returns: + Configured document processor + """ + processor_map = { + DocumentType.REGULATION: RegulationProcessor, + DocumentType.POLICY: RegulationProcessor, # Similar processing + DocumentType.WORK_ORDER: HistoricalDataProcessor, + DocumentType.CLAIM_RECORD: HistoricalDataProcessor, + DocumentType.MAINTENANCE_LOG: HistoricalDataProcessor, + DocumentType.FAQ: KnowledgeBaseProcessor, + DocumentType.KNOWLEDGE_BASE: KnowledgeBaseProcessor, + DocumentType.MANUAL: ManualProcessor, + DocumentType.TECHNICAL_SPEC: ManualProcessor, # Similar processing + DocumentType.PROCEDURE: ManualProcessor, # Similar processing + } + + processor_class = processor_map.get(doc_type, DocumentProcessor) + return processor_class(**kwargs) + diff --git a/src/deepiri_modelkit/rag/retrievers.py b/src/deepiri_modelkit/rag/retrievers.py index e8804f3..dcda0d0 100644 --- a/src/deepiri_modelkit/rag/retrievers.py +++ b/src/deepiri_modelkit/rag/retrievers.py @@ -1,287 +1,288 @@ -""" -Retrieval Components for Universal RAG -Implements various retrieval strategies for different use cases -""" - -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional -from 
dataclasses import dataclass - -from .base import Document, RetrievalResult, RAGQuery - - -class BaseRetriever(ABC): - """Base class for retrievers""" - - @abstractmethod - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """Retrieve relevant documents for query""" - pass - - -class MultiModalRetriever(BaseRetriever): - """ - Multi-modal retriever supporting text, images, tables, and code - Useful for technical manuals, equipment guides, etc. - """ - - def __init__( - self, - text_embeddings, - image_embeddings=None, - table_embeddings=None, - code_embeddings=None, - ): - self.text_embeddings = text_embeddings - self.image_embeddings = image_embeddings - self.table_embeddings = table_embeddings - self.code_embeddings = code_embeddings - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve documents across multiple modalities - - Currently focuses on text, but can be extended for other modalities - """ - # For now, delegate to text retrieval - # Future: Add image, table, code retrieval - results = self._retrieve_text(query) - return results - - def _retrieve_text(self, query: RAGQuery) -> List[RetrievalResult]: - """Retrieve text documents""" - # This will be implemented by the concrete RAG engine - # This is a placeholder for the interface - return [] - - -class HybridRetriever(BaseRetriever): - """ - Hybrid retriever combining: - - Semantic search (vector similarity) - - Keyword search (BM25) - - Metadata filtering - - Provides better recall than pure semantic search - """ - - def __init__( - self, - vector_retriever, - keyword_retriever=None, - vector_weight: float = 0.7, - keyword_weight: float = 0.3, - ): - self.vector_retriever = vector_retriever - self.keyword_retriever = keyword_retriever - self.vector_weight = vector_weight - self.keyword_weight = keyword_weight - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve using hybrid approach - - Combines vector and keyword search results - 
""" - results = [] - - # Vector search - vector_results = self._retrieve_vector(query) - results.extend(vector_results) - - # Keyword search (if available) - if self.keyword_retriever: - keyword_results = self._retrieve_keyword(query) - results.extend(keyword_results) - - # Merge and re-score - merged_results = self._merge_results( - vector_results, keyword_results if self.keyword_retriever else [] - ) - - return merged_results - - def _retrieve_vector(self, query: RAGQuery) -> List[RetrievalResult]: - """Vector similarity search""" - # Placeholder - implemented by concrete engine - return [] - - def _retrieve_keyword(self, query: RAGQuery) -> List[RetrievalResult]: - """Keyword search using BM25 or similar""" - # Placeholder - implemented by concrete engine - return [] - - def _merge_results( - self, - vector_results: List[RetrievalResult], - keyword_results: List[RetrievalResult], - ) -> List[RetrievalResult]: - """ - Merge results from different retrievers using weighted scoring - - Uses Reciprocal Rank Fusion (RRF) for combining rankings - """ - # Create a dictionary to store combined scores - doc_scores: Dict[str, Dict[str, Any]] = {} - - # Process vector results - for rank, result in enumerate(vector_results): - doc_id = result.document.id - # RRF score: 1 / (k + rank) where k=60 is common - rrf_score = 1.0 / (60 + rank + 1) - doc_scores[doc_id] = { - "document": result.document, - "vector_score": result.score, - "vector_rrf": rrf_score, - "keyword_rrf": 0.0, - "keyword_score": 0.0, - } - - # Process keyword results - for rank, result in enumerate(keyword_results): - doc_id = result.document.id - rrf_score = 1.0 / (60 + rank + 1) - - if doc_id in doc_scores: - doc_scores[doc_id]["keyword_rrf"] = rrf_score - doc_scores[doc_id]["keyword_score"] = result.score - else: - doc_scores[doc_id] = { - "document": result.document, - "vector_score": 0.0, - "vector_rrf": 0.0, - "keyword_rrf": rrf_score, - "keyword_score": result.score, - } - - # Calculate combined scores 
- merged = [] - for doc_id, scores in doc_scores.items(): - combined_rrf = ( - self.vector_weight * scores["vector_rrf"] - + self.keyword_weight * scores["keyword_rrf"] - ) - - result = RetrievalResult( - document=scores["document"], - score=combined_rrf, - rerank_score=None, - ) - merged.append(result) - - # Sort by combined score - merged.sort(key=lambda x: x.score, reverse=True) - - return merged - - -class ContextualRetriever(BaseRetriever): - """ - Contextual retriever that considers: - - User context (role, history, preferences) - - Temporal context (recent vs historical) - - Industry context (specific to niche) - """ - - def __init__( - self, - base_retriever: BaseRetriever, - use_user_context: bool = True, - use_temporal_context: bool = True, - use_industry_context: bool = True, - ): - self.base_retriever = base_retriever - self.use_user_context = use_user_context - self.use_temporal_context = use_temporal_context - self.use_industry_context = use_industry_context - - def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: - """ - Retrieve with contextual awareness - """ - # Get base results - results = self.base_retriever.retrieve(query) - - # Apply contextual reranking - if self.use_temporal_context: - results = self._apply_temporal_boost(results, query) - - if self.use_industry_context: - results = self._apply_industry_boost(results, query) - - return results - - def _apply_temporal_boost( - self, results: List[RetrievalResult], query: RAGQuery - ) -> List[RetrievalResult]: - """ - Boost recent documents (useful for regulations, updates) - """ - import time - from datetime import datetime - - current_time = datetime.now().timestamp() - - for result in results: - if result.document.updated_at or result.document.created_at: - doc_time = ( - result.document.updated_at or result.document.created_at - ).timestamp() - # Calculate age in days - age_days = (current_time - doc_time) / 86400 - - # Apply decay: more recent = higher boost - # Documents within 
30 days get full boost - # Older documents gradually lose boost - if age_days < 30: - temporal_boost = 1.0 - elif age_days < 180: # 6 months - temporal_boost = 0.9 - elif age_days < 365: # 1 year - temporal_boost = 0.8 - else: - temporal_boost = 0.7 - - result.score *= temporal_boost - - # Re-sort by adjusted scores - results.sort(key=lambda x: x.score, reverse=True) - return results - - def _apply_industry_boost( - self, results: List[RetrievalResult], query: RAGQuery - ) -> List[RetrievalResult]: - """ - Boost documents matching the query's industry - """ - if not query.industry: - return results - - for result in results: - if result.document.industry == query.industry: - result.score *= 1.1 # 10% boost for industry match - - # Re-sort - results.sort(key=lambda x: x.score, reverse=True) - return results - - -def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever: - """ - Factory function to get retriever - - Args: - retriever_type: Type of retriever ('hybrid', 'multimodal', 'contextual') - **kwargs: Configuration for retriever - - Returns: - Configured retriever - """ - retriever_map = { - "hybrid": HybridRetriever, - "multimodal": MultiModalRetriever, - "contextual": ContextualRetriever, - } - - retriever_class = retriever_map.get(retriever_type, HybridRetriever) - return retriever_class(**kwargs) +""" +Retrieval Components for Universal RAG +Implements various retrieval strategies for different use cases +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +from .base import Document, RetrievalResult, RAGQuery + + +class BaseRetriever(ABC): + """Base class for retrievers""" + + @abstractmethod + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: + """Retrieve relevant documents for query""" + pass + + +class MultiModalRetriever(BaseRetriever): + """ + Multi-modal retriever supporting text, images, tables, and code + Useful for technical manuals, equipment guides, 
class HybridRetriever(BaseRetriever):
    """
    Hybrid retriever that fuses a vector ranking with a keyword ranking.

    The two rankings are combined with weighted Reciprocal Rank Fusion (RRF):
    a document at 1-based rank ``r`` in a list contributes
    ``weight * 1 / (rrf_k + r)`` to its fused score.

    ``_retrieve_vector`` and ``_retrieve_keyword`` are placeholders that
    return empty lists in this class; a concrete engine is expected to
    override them. The injected retriever objects are stored but not called
    here.
    """

    def __init__(
        self,
        vector_retriever,
        keyword_retriever=None,
        vector_weight: float = 0.7,
        keyword_weight: float = 0.3,
        rrf_k: int = 60
    ):
        """
        Args:
            vector_retriever: Backend for semantic search.
            keyword_retriever: Optional backend for keyword search; when
                falsy, only the vector ranking is used.
            vector_weight: Multiplier for the vector ranking's RRF term.
            keyword_weight: Multiplier for the keyword ranking's RRF term.
            rrf_k: RRF smoothing constant. Larger values flatten the
                difference between adjacent ranks. 60 is the conventional
                default and the value previously hard-coded.
        """
        self.vector_retriever = vector_retriever
        self.keyword_retriever = keyword_retriever
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight
        self.rrf_k = rrf_k

    def retrieve(self, query: RAGQuery) -> List[RetrievalResult]:
        """
        Run both searches and return the fused ranking, best first.

        Args:
            query: The query forwarded unchanged to both search hooks.

        Returns:
            One result per distinct document id, sorted by fused score.
        """
        vector_results = self._retrieve_vector(query)
        keyword_results = self._retrieve_keyword(query) if self.keyword_retriever else []
        return self._merge_results(vector_results, keyword_results)

    def _retrieve_vector(self, query: RAGQuery) -> List[RetrievalResult]:
        """Vector similarity search (placeholder; overridden by the concrete engine)."""
        return []

    def _retrieve_keyword(self, query: RAGQuery) -> List[RetrievalResult]:
        """Keyword search such as BM25 (placeholder; overridden by the concrete engine)."""
        return []

    def _merge_results(
        self,
        vector_results: List[RetrievalResult],
        keyword_results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Fuse two rankings with weighted Reciprocal Rank Fusion.

        Only the position of a result in its list matters; the retrievers'
        raw scores are not used because they are not comparable across
        retrieval methods.

        Args:
            vector_results: Vector ranking, best first.
            keyword_results: Keyword ranking, best first (may be empty).

        Returns:
            New ``RetrievalResult`` objects whose ``score`` is the fused RRF
            score and whose ``rerank_score`` is ``None``, sorted descending.
        """
        documents: Dict[str, Any] = {}
        fused_scores: Dict[str, float] = {}

        weighted_rankings = (
            (self.vector_weight, vector_results),
            (self.keyword_weight, keyword_results),
        )
        for weight, ranking in weighted_rankings:
            seen_in_ranking = set()
            for rank, result in enumerate(ranking, start=1):
                doc_id = result.document.id
                if doc_id in seen_in_ranking:
                    # A repeated id keeps its first (best) rank instead of
                    # being overwritten by a worse one.
                    continue
                seen_in_ranking.add(doc_id)

                documents.setdefault(doc_id, result.document)
                contribution = weight * (1.0 / (self.rrf_k + rank))
                fused_scores[doc_id] = fused_scores.get(doc_id, 0.0) + contribution

        merged = [
            RetrievalResult(
                document=documents[doc_id],
                score=score,
                rerank_score=None,
            )
            for doc_id, score in fused_scores.items()
        ]
        merged.sort(key=lambda item: item.score, reverse=True)
        return merged
def _apply_temporal_boost(
    self,
    results: List[RetrievalResult],
    query: RAGQuery
) -> List[RetrievalResult]:
    """
    Down-weight older documents and re-rank.

    A document's age is measured from ``updated_at`` when set, otherwise
    from ``created_at``. Its score is multiplied by 1.0 (under 30 days),
    0.9 (under 180 days), 0.8 (under 365 days) or 0.7 (older). Documents
    with neither timestamp keep their score. The list is modified and
    sorted in place and also returned; ``query`` is not consulted.
    """
    from datetime import datetime

    # (upper age bound in days, multiplier), checked in order
    age_tiers = ((30, 1.0), (180, 0.9), (365, 0.8))
    now_ts = datetime.now().timestamp()

    for item in results:
        stamp = item.document.updated_at or item.document.created_at
        if not stamp:
            continue

        age_days = (now_ts - stamp.timestamp()) / 86400
        multiplier = 0.7
        for limit, tier_multiplier in age_tiers:
            if age_days < limit:
                multiplier = tier_multiplier
                break

        item.score *= multiplier

    results.sort(key=lambda item: item.score, reverse=True)
    return results
def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever:
    """
    Build a retriever by name.

    Args:
        retriever_type: One of 'hybrid', 'multimodal' or 'contextual'.
            Any other value selects the hybrid retriever.
        **kwargs: Passed straight to the chosen class's constructor.

    Returns:
        The constructed retriever instance.
    """
    registry = {
        'hybrid': HybridRetriever,
        'multimodal': MultiModalRetriever,
        'contextual': ContextualRetriever,
    }

    try:
        chosen = registry[retriever_type]
    except KeyError:
        # Unrecognised names fall back to the hybrid retriever.
        chosen = HybridRetriever

    return chosen(**kwargs)
recall: float - f1_score: float - passed: bool - error: Optional[str] = None - - def to_dict(self) -> Dict[str, Any]: - """Convert to dictionary""" - return { - "query": self.test_case.query, - "expected_doc_ids": self.test_case.expected_doc_ids, - "retrieved_doc_ids": self.retrieved_doc_ids, - "precision": self.precision, - "recall": self.recall, - "f1_score": self.f1_score, - "passed": self.passed, - "error": self.error, - } - - -class RAGEvaluator: - """ - Evaluator for RAG system performance - """ - - def __init__(self, rag_engine): - self.rag_engine = rag_engine - - def evaluate( - self, test_cases: List[TestCase], industry: Optional[IndustryNiche] = None - ) -> Dict[str, Any]: - """ - Evaluate RAG system on test cases - - Args: - test_cases: List of test cases - industry: Industry context - - Returns: - Evaluation results with metrics - """ - results = [] - - for test_case in test_cases: - result = self._evaluate_test_case(test_case, industry) - results.append(result) - - # Calculate aggregate metrics - total_precision = ( - sum(r.precision for r in results) / len(results) if results else 0.0 - ) - total_recall = sum(r.recall for r in results) / len(results) if results else 0.0 - total_f1 = sum(r.f1_score for r in results) / len(results) if results else 0.0 - passed_count = sum(1 for r in results if r.passed) - - return { - "total_tests": len(test_cases), - "passed": passed_count, - "failed": len(test_cases) - passed_count, - "avg_precision": total_precision, - "avg_recall": total_recall, - "avg_f1_score": total_f1, - "pass_rate": passed_count / len(test_cases) if test_cases else 0.0, - "results": [r.to_dict() for r in results], - } - - def _evaluate_test_case( - self, test_case: TestCase, industry: Optional[IndustryNiche] - ) -> TestResult: - """Evaluate a single test case""" - try: - # Build query - query = RAGQuery( - query=test_case.query, - industry=industry, - doc_types=test_case.expected_doc_types, - top_k=test_case.top_k, - ) - - # Retrieve documents 
- results = self.rag_engine.retrieve(query) - - # Extract IDs and scores - retrieved_doc_ids = [r.document.id for r in results] - retrieved_scores = [r.score for r in results] - - # Calculate precision, recall, F1 - expected_set = set(test_case.expected_doc_ids) - retrieved_set = set(retrieved_doc_ids) - - if not retrieved_set: - precision = 0.0 - recall = 0.0 - f1_score = 0.0 - else: - # Precision: relevant retrieved / total retrieved - relevant_retrieved = len(expected_set & retrieved_set) - precision = ( - relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 - ) - - # Recall: relevant retrieved / total relevant - recall = relevant_retrieved / len(expected_set) if expected_set else 0.0 - - # F1 score - f1_score = ( - 2 * (precision * recall) / (precision + recall) - if (precision + recall) > 0 - else 0.0 - ) - - # Check if passed - passed = ( - precision >= 0.7 - and recall >= 0.7 - and f1_score >= 0.7 - and ( - not retrieved_scores or max(retrieved_scores) >= test_case.min_score - ) - ) - - return TestResult( - test_case=test_case, - retrieved_doc_ids=retrieved_doc_ids, - retrieved_scores=retrieved_scores, - precision=precision, - recall=recall, - f1_score=f1_score, - passed=passed, - ) - - except Exception as e: - return TestResult( - test_case=test_case, - retrieved_doc_ids=[], - retrieved_scores=[], - precision=0.0, - recall=0.0, - f1_score=0.0, - passed=False, - error=str(e), - ) - - -class RAGTestFixture: - """ - Test fixture for creating test data and scenarios - """ - - @staticmethod - def create_test_documents( - industry: IndustryNiche = IndustryNiche.MANUFACTURING, num_documents: int = 10 - ) -> List[Document]: - """Create test documents""" - documents = [] - - for i in range(num_documents): - doc = Document( - id=f"test_doc_{i}", - content=f"Test document {i} content. 
This is sample content for testing RAG retrieval.", - doc_type=( - DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG - ), - industry=industry, - title=f"Test Document {i}", - source="test_fixture", - metadata={"test_index": i}, - ) - documents.append(doc) - - return documents - - @staticmethod - def create_test_cases(num_cases: int = 5) -> List[TestCase]: - """Create test cases""" - test_cases = [] - - for i in range(num_cases): - test_case = TestCase( - query=f"test query {i}", - expected_doc_ids=[f"test_doc_{i}", f"test_doc_{i+1}"], - top_k=5, - ) - test_cases.append(test_case) - - return test_cases - - -class PerformanceBenchmark: - """ - Performance benchmarking for RAG operations - """ - - def __init__(self, rag_engine): - self.rag_engine = rag_engine - - def benchmark_retrieval( - self, queries: List[str], iterations: int = 10 - ) -> Dict[str, Any]: - """ - Benchmark retrieval performance - - Args: - queries: List of test queries - iterations: Number of iterations per query - - Returns: - Performance metrics - """ - import time - - total_time = 0.0 - total_queries = 0 - times = [] - - for query_text in queries: - query = RAGQuery(query=query_text, top_k=5) - - for _ in range(iterations): - start = time.time() - results = self.rag_engine.retrieve(query) - elapsed = time.time() - start - - total_time += elapsed - total_queries += 1 - times.append(elapsed * 1000) # Convert to ms - - avg_time_ms = (total_time / total_queries) * 1000 if total_queries > 0 else 0.0 - min_time_ms = min(times) if times else 0.0 - max_time_ms = max(times) if times else 0.0 - - return { - "total_queries": total_queries, - "avg_time_ms": avg_time_ms, - "min_time_ms": min_time_ms, - "max_time_ms": max_time_ms, - "queries_per_second": total_queries / total_time if total_time > 0 else 0.0, - } - - def benchmark_indexing( - self, documents: List[Document], batch_sizes: List[int] = [1, 10, 100] - ) -> Dict[str, Any]: - """ - Benchmark indexing performance - - Args: - 
documents: Documents to index - batch_sizes: Different batch sizes to test - - Returns: - Performance metrics for each batch size - """ - import time - - results = {} - - for batch_size in batch_sizes: - batches = [ - documents[i : i + batch_size] - for i in range(0, len(documents), batch_size) - ] - - start = time.time() - for batch in batches: - self.rag_engine.index_documents(batch) - elapsed = time.time() - start - - results[f"batch_size_{batch_size}"] = { - "total_documents": len(documents), - "num_batches": len(batches), - "total_time_seconds": elapsed, - "avg_time_per_doc_ms": ( - (elapsed / len(documents)) * 1000 if documents else 0.0 - ), - "docs_per_second": len(documents) / elapsed if elapsed > 0 else 0.0, - } - - return results - - -def create_evaluation_dataset( - industry: IndustryNiche, num_documents: int = 100, num_queries: int = 20 -) -> Tuple[List[Document], List[TestCase]]: - """ - Create a complete evaluation dataset - - Args: - industry: Industry for documents - num_documents: Number of documents to create - num_queries: Number of test queries - - Returns: - Tuple of (documents, test_cases) - """ - documents = RAGTestFixture.create_test_documents(industry, num_documents) - test_cases = RAGTestFixture.create_test_cases(num_queries) - - return documents, test_cases +""" +Testing Utilities for Universal RAG +Comprehensive test helpers, fixtures, and evaluation tools +""" + +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass +import json +from datetime import datetime + +from .base import Document, DocumentType, IndustryNiche, RAGQuery, RetrievalResult + + +@dataclass +class TestCase: + """Test case for RAG evaluation""" + query: str + expected_doc_ids: List[str] # Document IDs that should be retrieved + expected_doc_types: Optional[List[DocumentType]] = None + min_score: float = 0.7 # Minimum similarity score + top_k: int = 5 + metadata: Dict[str, Any] = None + + def __post_init__(self): + if self.metadata is 
class RAGEvaluator:
    """
    Score a RAG engine's retrieval quality against labelled test cases.

    For every ``TestCase`` the engine's ``retrieve`` is called and the ids of
    the returned documents are compared with the expected ids to obtain
    precision, recall and F1.
    """

    def __init__(
        self,
        rag_engine,
        min_precision: float = 0.7,
        min_recall: float = 0.7,
        min_f1: float = 0.7
    ):
        """
        Args:
            rag_engine: Object exposing ``retrieve(query)`` that returns
                results with ``.document.id`` and ``.score``.
            min_precision: Precision a case must reach to pass.
            min_recall: Recall a case must reach to pass.
            min_f1: F1 a case must reach to pass.

        The thresholds default to the previously hard-coded 0.7. They are
        configurable because precision is bounded by
        ``len(expected_doc_ids) / number_retrieved``: a case expecting two
        documents with five results returned cannot exceed 0.4.
        """
        self.rag_engine = rag_engine
        self.min_precision = min_precision
        self.min_recall = min_recall
        self.min_f1 = min_f1

    def evaluate(
        self,
        test_cases: List[TestCase],
        industry: Optional[IndustryNiche] = None
    ) -> Dict[str, Any]:
        """
        Evaluate the engine on every test case and aggregate the metrics.

        Args:
            test_cases: Cases to run, in order.
            industry: Industry context placed on every query.

        Returns:
            Dict with ``total_tests``, ``passed``, ``failed``,
            ``avg_precision``, ``avg_recall``, ``avg_f1_score``,
            ``pass_rate`` and per-case ``results`` (each a
            ``TestResult.to_dict()``). Averages and pass rate are 0.0 when
            there are no cases.
        """
        results = [self._evaluate_test_case(case, industry) for case in test_cases]
        total = len(results)
        passed_count = sum(1 for outcome in results if outcome.passed)

        def average(values: List[float]) -> float:
            return sum(values) / total if total else 0.0

        return {
            "total_tests": total,
            "passed": passed_count,
            "failed": total - passed_count,
            "avg_precision": average([outcome.precision for outcome in results]),
            "avg_recall": average([outcome.recall for outcome in results]),
            "avg_f1_score": average([outcome.f1_score for outcome in results]),
            "pass_rate": passed_count / total if total else 0.0,
            "results": [outcome.to_dict() for outcome in results],
        }

    def _evaluate_test_case(
        self,
        test_case: TestCase,
        industry: Optional[IndustryNiche]
    ) -> TestResult:
        """
        Run one test case and compute its metrics.

        The case's ``expected_doc_types`` are forwarded as the query's
        ``doc_types``. Metrics are computed over distinct document ids. A
        case passes when precision, recall and F1 reach their thresholds
        and, if anything was retrieved, the best score is at least
        ``test_case.min_score``.

        Any exception raised while querying or scoring is recorded on the
        returned ``TestResult`` (``passed=False``, ``error`` set) so that one
        failing case does not abort the whole evaluation run.
        """
        try:
            query = RAGQuery(
                query=test_case.query,
                industry=industry,
                doc_types=test_case.expected_doc_types,
                top_k=test_case.top_k
            )
            hits = self.rag_engine.retrieve(query)

            retrieved_doc_ids = [hit.document.id for hit in hits]
            retrieved_scores = [hit.score for hit in hits]

            expected = set(test_case.expected_doc_ids)
            retrieved = set(retrieved_doc_ids)
            relevant_retrieved = len(expected & retrieved)

            # Each ratio is defined as 0.0 when its denominator is empty.
            precision = relevant_retrieved / len(retrieved) if retrieved else 0.0
            recall = relevant_retrieved / len(expected) if expected else 0.0
            if precision + recall > 0:
                f1_score = 2 * (precision * recall) / (precision + recall)
            else:
                f1_score = 0.0

            score_ok = not retrieved_scores or max(retrieved_scores) >= test_case.min_score
            passed = (
                precision >= self.min_precision
                and recall >= self.min_recall
                and f1_score >= self.min_f1
                and score_ok
            )

            return TestResult(
                test_case=test_case,
                retrieved_doc_ids=retrieved_doc_ids,
                retrieved_scores=retrieved_scores,
                precision=precision,
                recall=recall,
                f1_score=f1_score,
                passed=passed
            )

        except Exception as e:
            return TestResult(
                test_case=test_case,
                retrieved_doc_ids=[],
                retrieved_scores=[],
                precision=0.0,
                recall=0.0,
                f1_score=0.0,
                passed=False,
                error=str(e)
            )
This is sample content for testing RAG retrieval.", + doc_type=DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG, + industry=industry, + title=f"Test Document {i}", + source="test_fixture", + metadata={"test_index": i} + ) + documents.append(doc) + + return documents + + @staticmethod + def create_test_cases( + num_cases: int = 5 + ) -> List[TestCase]: + """Create test cases""" + test_cases = [] + + for i in range(num_cases): + test_case = TestCase( + query=f"test query {i}", + expected_doc_ids=[f"test_doc_{i}", f"test_doc_{i+1}"], + top_k=5 + ) + test_cases.append(test_case) + + return test_cases + + +class PerformanceBenchmark: + """ + Performance benchmarking for RAG operations + """ + + def __init__(self, rag_engine): + self.rag_engine = rag_engine + + def benchmark_retrieval( + self, + queries: List[str], + iterations: int = 10 + ) -> Dict[str, Any]: + """ + Benchmark retrieval performance + + Args: + queries: List of test queries + iterations: Number of iterations per query + + Returns: + Performance metrics + """ + import time + + total_time = 0.0 + total_queries = 0 + times = [] + + for query_text in queries: + query = RAGQuery(query=query_text, top_k=5) + + for _ in range(iterations): + start = time.time() + results = self.rag_engine.retrieve(query) + elapsed = time.time() - start + + total_time += elapsed + total_queries += 1 + times.append(elapsed * 1000) # Convert to ms + + avg_time_ms = (total_time / total_queries) * 1000 if total_queries > 0 else 0.0 + min_time_ms = min(times) if times else 0.0 + max_time_ms = max(times) if times else 0.0 + + return { + "total_queries": total_queries, + "avg_time_ms": avg_time_ms, + "min_time_ms": min_time_ms, + "max_time_ms": max_time_ms, + "queries_per_second": total_queries / total_time if total_time > 0 else 0.0, + } + + def benchmark_indexing( + self, + documents: List[Document], + batch_sizes: List[int] = [1, 10, 100] + ) -> Dict[str, Any]: + """ + Benchmark indexing performance + + Args: + 
documents: Documents to index + batch_sizes: Different batch sizes to test + + Returns: + Performance metrics for each batch size + """ + import time + + results = {} + + for batch_size in batch_sizes: + batches = [ + documents[i:i + batch_size] + for i in range(0, len(documents), batch_size) + ] + + start = time.time() + for batch in batches: + self.rag_engine.index_documents(batch) + elapsed = time.time() - start + + results[f"batch_size_{batch_size}"] = { + "total_documents": len(documents), + "num_batches": len(batches), + "total_time_seconds": elapsed, + "avg_time_per_doc_ms": (elapsed / len(documents)) * 1000 if documents else 0.0, + "docs_per_second": len(documents) / elapsed if elapsed > 0 else 0.0, + } + + return results + + +def create_evaluation_dataset( + industry: IndustryNiche, + num_documents: int = 100, + num_queries: int = 20 +) -> Tuple[List[Document], List[TestCase]]: + """ + Create a complete evaluation dataset + + Args: + industry: Industry for documents + num_documents: Number of documents to create + num_queries: Number of test queries + + Returns: + Tuple of (documents, test_cases) + """ + documents = RAGTestFixture.create_test_documents(industry, num_documents) + test_cases = RAGTestFixture.create_test_cases(num_queries) + + return documents, test_cases + diff --git a/src/deepiri_modelkit/registry/adapters/__init__.py b/src/deepiri_modelkit/registry/adapters/__init__.py index 43c51bc..e949845 100644 --- a/src/deepiri_modelkit/registry/adapters/__init__.py +++ b/src/deepiri_modelkit/registry/adapters/__init__.py @@ -1 +1,2 @@ -"""Storage adapters for model registry""" +"""Storage adapters for model registry""" + diff --git a/src/deepiri_modelkit/registry/model_registry.py b/src/deepiri_modelkit/registry/model_registry.py index 7be6f9e..5c5ffdd 100644 --- a/src/deepiri_modelkit/registry/model_registry.py +++ b/src/deepiri_modelkit/registry/model_registry.py @@ -1,333 +1,336 @@ -""" -Unified model registry client -Supports MLflow, S3/MinIO, 
and local storage -""" - -import os -from typing import Dict, Any, Optional -from pathlib import Path -import mlflow -import boto3 -from botocore.exceptions import ClientError - -from ..contracts.models import ModelMetadata - - -class ModelRegistryClient: - """ - Unified client for model registry operations - Supports MLflow, S3/MinIO, and local storage - """ - - def __init__( - self, - registry_type: str = "mlflow", # mlflow, s3, local - mlflow_tracking_uri: Optional[str] = None, - s3_endpoint: Optional[str] = None, - s3_access_key: Optional[str] = None, - s3_secret_key: Optional[str] = None, - s3_bucket: Optional[str] = None, - local_path: Optional[str] = None, - ): - """ - Initialize model registry client - - Args: - registry_type: Type of registry (mlflow, s3, local) - mlflow_tracking_uri: MLflow tracking URI (defaults to MLFLOW_TRACKING_URI env var or http://mlflow:5000) - s3_endpoint: S3/MinIO endpoint - s3_access_key: S3 access key - s3_secret_key: S3 secret key - s3_bucket: S3 bucket name - local_path: Local storage path - """ - self.registry_type = registry_type - - if registry_type == "mlflow": - # Use provided URI, or environment variable, or default - tracking_uri = ( - mlflow_tracking_uri - or os.getenv("MLFLOW_TRACKING_URI") - or "http://mlflow:5000" - ) - mlflow.set_tracking_uri(tracking_uri) - self.client = mlflow - self.tracking_uri = tracking_uri - elif registry_type == "s3": - self.s3_client = boto3.client( - "s3", - endpoint_url=s3_endpoint, - aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key, - ) - self.s3_bucket = s3_bucket - elif registry_type == "local": - self.local_path = Path(local_path or "./models") - self.local_path.mkdir(parents=True, exist_ok=True) - else: - raise ValueError(f"Unknown registry type: {registry_type}") - - def register_model( - self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] - ) -> bool: - """ - Register model in registry - - Args: - model_name: Model name - version: 
Model version - model_path: Path to model file/directory - metadata: Model metadata - - Returns: - True if successful - """ - try: - if self.registry_type == "mlflow": - # Register with MLflow - model_uri = f"runs:/{metadata.get('run_id', 'latest')}/model" - mlflow.register_model(model_uri, f"{model_name}-{version}") - return True - - elif self.registry_type == "s3": - # Upload to S3 - s3_key = f"models/{model_name}/{version}/model" - self.s3_client.upload_file(model_path, self.s3_bucket, s3_key) - - # Upload metadata - import json - - metadata_key = f"models/{model_name}/{version}/metadata.json" - self.s3_client.put_object( - Bucket=self.s3_bucket, Key=metadata_key, Body=json.dumps(metadata) - ) - return True - - elif self.registry_type == "local": - # Copy to local storage - model_dir = self.local_path / model_name / version - model_dir.mkdir(parents=True, exist_ok=True) - - import shutil - - if os.path.isdir(model_path): - shutil.copytree(model_path, model_dir / "model", dirs_exist_ok=True) - else: - shutil.copy2(model_path, model_dir / "model") - - # Save metadata - import json - - with open(model_dir / "metadata.json", "w") as f: - json.dump(metadata, f) - - return True - - except Exception as e: - print(f"Error registering model: {e}") - return False - - def get_model( - self, model_name: str, version: Optional[str] = None - ) -> Dict[str, Any]: - """ - Get model information from registry - - Args: - model_name: Model name - version: Model version (optional, gets latest if not specified) - - Returns: - Model information dict - """ - try: - if self.registry_type == "mlflow": - if version: - model_uri = f"models:/{model_name}/{version}" - else: - model_uri = f"models:/{model_name}/latest" - - model = mlflow.pyfunc.load_model(model_uri) - return {"model": model, "uri": model_uri, "type": "mlflow"} - - elif self.registry_type == "s3": - if not version: - # List versions and get latest - prefix = f"models/{model_name}/" - response = self.s3_client.list_objects_v2( 
- Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" - ) - versions = [ - obj["Prefix"].split("/")[-2] - for obj in response.get("CommonPrefixes", []) - ] - version = max(versions) if versions else None - - if not version: - raise ValueError(f"Model {model_name} not found") - - # Download metadata - metadata_key = f"models/{model_name}/{version}/metadata.json" - response = self.s3_client.get_object( - Bucket=self.s3_bucket, Key=metadata_key - ) - import json - - metadata = json.loads(response["Body"].read()) - - return { - "model_path": f"s3://{self.s3_bucket}/models/{model_name}/{version}/model", - "metadata": metadata, - "type": "s3", - } - - elif self.registry_type == "local": - if not version: - # Get latest version - model_dir = self.local_path / model_name - if not model_dir.exists(): - raise ValueError(f"Model {model_name} not found") - - versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - version = max(versions) if versions else None - - if not version: - raise ValueError(f"Model {model_name} not found") - - model_dir = self.local_path / model_name / version - metadata_path = model_dir / "metadata.json" - - import json - - with open(metadata_path) as f: - metadata = json.load(f) - - return { - "model_path": str(model_dir / "model"), - "metadata": metadata, - "type": "local", - } - - except Exception as e: - print(f"Error getting model: {e}") - raise - - def download_model(self, model_name: str, version: str, destination: str) -> str: - """ - Download model to destination - - Args: - model_name: Model name - version: Model version - destination: Local destination path - - Returns: - Local path to downloaded model - """ - model_info = self.get_model(model_name, version) - - if self.registry_type == "s3": - # Download from S3 - s3_key = f"models/{model_name}/{version}/model" - os.makedirs(destination, exist_ok=True) - - # Check if it's a file or directory - try: - self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) - # It's a file - 
local_path = os.path.join(destination, "model") - self.s3_client.download_file(self.s3_bucket, s3_key, local_path) - return local_path - except ClientError: - # It's a directory, list and download all files - prefix = f"{s3_key}/" - paginator = self.s3_client.get_paginator("list_objects_v2") - for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix): - for obj in page.get("Contents", []): - key = obj["Key"] - local_file = os.path.join(destination, key[len(prefix) :]) - os.makedirs(os.path.dirname(local_file), exist_ok=True) - self.s3_client.download_file(self.s3_bucket, key, local_file) - return destination - - elif self.registry_type == "local": - # Copy from local - source = model_info["model_path"] - import shutil - - if os.path.isdir(source): - shutil.copytree(source, destination, dirs_exist_ok=True) - else: - os.makedirs(os.path.dirname(destination), exist_ok=True) - shutil.copy2(source, destination) - return destination - - else: - # MLflow handles loading directly - return model_info["uri"] - - def list_models(self, model_name: Optional[str] = None) -> list: - """ - List available models - - Args: - model_name: Optional model name filter - - Returns: - List of model information dicts - """ - try: - if self.registry_type == "mlflow": - if model_name: - models = mlflow.search_registered_models( - filter_string=f"name='{model_name}'" - ) - else: - models = mlflow.search_registered_models() - - return [ - { - "name": m.name, - "versions": [v.version for v in m.latest_versions], - "latest_version": ( - m.latest_versions[0].version if m.latest_versions else None - ), - } - for m in models - ] - - elif self.registry_type == "s3": - prefix = "models/" - if model_name: - prefix = f"models/{model_name}/" - - response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" - ) - - models = [] - for prefix_obj in response.get("CommonPrefixes", []): - model_path = prefix_obj["Prefix"] - parts = model_path.strip("/").split("/") 
- if len(parts) >= 2: - models.append({"name": parts[1], "path": model_path}) - - return models - - elif self.registry_type == "local": - models = [] - for model_dir in self.local_path.iterdir(): - if model_dir.is_dir(): - versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - models.append( - { - "name": model_dir.name, - "versions": versions, - "latest_version": max(versions) if versions else None, - } - ) - return models - - except Exception as e: - print(f"Error listing models: {e}") - return [] +""" +Unified model registry client +Supports MLflow, S3/MinIO, and local storage +""" +import os +from typing import Dict, Any, Optional +from pathlib import Path +import mlflow +import boto3 +from botocore.exceptions import ClientError + +from ..contracts.models import ModelMetadata + + +class ModelRegistryClient: + """ + Unified client for model registry operations + Supports MLflow, S3/MinIO, and local storage + """ + + def __init__( + self, + registry_type: str = "mlflow", # mlflow, s3, local + mlflow_tracking_uri: Optional[str] = None, + s3_endpoint: Optional[str] = None, + s3_access_key: Optional[str] = None, + s3_secret_key: Optional[str] = None, + s3_bucket: Optional[str] = None, + local_path: Optional[str] = None + ): + """ + Initialize model registry client + + Args: + registry_type: Type of registry (mlflow, s3, local) + mlflow_tracking_uri: MLflow tracking URI (defaults to MLFLOW_TRACKING_URI env var or http://mlflow:5000) + s3_endpoint: S3/MinIO endpoint + s3_access_key: S3 access key + s3_secret_key: S3 secret key + s3_bucket: S3 bucket name + local_path: Local storage path + """ + self.registry_type = registry_type + + if registry_type == "mlflow": + # Use provided URI, or environment variable, or default + tracking_uri = mlflow_tracking_uri or os.getenv("MLFLOW_TRACKING_URI") or "http://mlflow:5000" + mlflow.set_tracking_uri(tracking_uri) + self.client = mlflow + self.tracking_uri = tracking_uri + elif registry_type == "s3": + self.s3_client = 
boto3.client( + 's3', + endpoint_url=s3_endpoint, + aws_access_key_id=s3_access_key, + aws_secret_access_key=s3_secret_key + ) + self.s3_bucket = s3_bucket + elif registry_type == "local": + self.local_path = Path(local_path or "./models") + self.local_path.mkdir(parents=True, exist_ok=True) + else: + raise ValueError(f"Unknown registry type: {registry_type}") + + def register_model( + self, + model_name: str, + version: str, + model_path: str, + metadata: Dict[str, Any] + ) -> bool: + """ + Register model in registry + + Args: + model_name: Model name + version: Model version + model_path: Path to model file/directory + metadata: Model metadata + + Returns: + True if successful + """ + try: + if self.registry_type == "mlflow": + # Register with MLflow + model_uri = f"runs:/{metadata.get('run_id', 'latest')}/model" + mlflow.register_model(model_uri, f"{model_name}-{version}") + return True + + elif self.registry_type == "s3": + # Upload to S3 + s3_key = f"models/{model_name}/{version}/model" + self.s3_client.upload_file(model_path, self.s3_bucket, s3_key) + + # Upload metadata + import json + metadata_key = f"models/{model_name}/{version}/metadata.json" + self.s3_client.put_object( + Bucket=self.s3_bucket, + Key=metadata_key, + Body=json.dumps(metadata) + ) + return True + + elif self.registry_type == "local": + # Copy to local storage + model_dir = self.local_path / model_name / version + model_dir.mkdir(parents=True, exist_ok=True) + + import shutil + if os.path.isdir(model_path): + shutil.copytree(model_path, model_dir / "model", dirs_exist_ok=True) + else: + shutil.copy2(model_path, model_dir / "model") + + # Save metadata + import json + with open(model_dir / "metadata.json", "w") as f: + json.dump(metadata, f) + + return True + + except Exception as e: + print(f"Error registering model: {e}") + return False + + def get_model( + self, + model_name: str, + version: Optional[str] = None + ) -> Dict[str, Any]: + """ + Get model information from registry + + Args: 
+ model_name: Model name + version: Model version (optional, gets latest if not specified) + + Returns: + Model information dict + """ + try: + if self.registry_type == "mlflow": + if version: + model_uri = f"models:/{model_name}/{version}" + else: + model_uri = f"models:/{model_name}/latest" + + model = mlflow.pyfunc.load_model(model_uri) + return { + "model": model, + "uri": model_uri, + "type": "mlflow" + } + + elif self.registry_type == "s3": + if not version: + # List versions and get latest + prefix = f"models/{model_name}/" + response = self.s3_client.list_objects_v2( + Bucket=self.s3_bucket, + Prefix=prefix, + Delimiter="/" + ) + versions = [obj["Prefix"].split("/")[-2] for obj in response.get("CommonPrefixes", [])] + version = max(versions) if versions else None + + if not version: + raise ValueError(f"Model {model_name} not found") + + # Download metadata + metadata_key = f"models/{model_name}/{version}/metadata.json" + response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=metadata_key) + import json + metadata = json.loads(response["Body"].read()) + + return { + "model_path": f"s3://{self.s3_bucket}/models/{model_name}/{version}/model", + "metadata": metadata, + "type": "s3" + } + + elif self.registry_type == "local": + if not version: + # Get latest version + model_dir = self.local_path / model_name + if not model_dir.exists(): + raise ValueError(f"Model {model_name} not found") + + versions = [d.name for d in model_dir.iterdir() if d.is_dir()] + version = max(versions) if versions else None + + if not version: + raise ValueError(f"Model {model_name} not found") + + model_dir = self.local_path / model_name / version + metadata_path = model_dir / "metadata.json" + + import json + with open(metadata_path) as f: + metadata = json.load(f) + + return { + "model_path": str(model_dir / "model"), + "metadata": metadata, + "type": "local" + } + + except Exception as e: + print(f"Error getting model: {e}") + raise + + def download_model( + self, + 
model_name: str, + version: str, + destination: str + ) -> str: + """ + Download model to destination + + Args: + model_name: Model name + version: Model version + destination: Local destination path + + Returns: + Local path to downloaded model + """ + model_info = self.get_model(model_name, version) + + if self.registry_type == "s3": + # Download from S3 + s3_key = f"models/{model_name}/{version}/model" + os.makedirs(destination, exist_ok=True) + + # Check if it's a file or directory + try: + self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) + # It's a file + local_path = os.path.join(destination, "model") + self.s3_client.download_file(self.s3_bucket, s3_key, local_path) + return local_path + except ClientError: + # It's a directory, list and download all files + prefix = f"{s3_key}/" + paginator = self.s3_client.get_paginator('list_objects_v2') + for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix): + for obj in page.get('Contents', []): + key = obj['Key'] + local_file = os.path.join(destination, key[len(prefix):]) + os.makedirs(os.path.dirname(local_file), exist_ok=True) + self.s3_client.download_file(self.s3_bucket, key, local_file) + return destination + + elif self.registry_type == "local": + # Copy from local + source = model_info["model_path"] + import shutil + if os.path.isdir(source): + shutil.copytree(source, destination, dirs_exist_ok=True) + else: + os.makedirs(os.path.dirname(destination), exist_ok=True) + shutil.copy2(source, destination) + return destination + + else: + # MLflow handles loading directly + return model_info["uri"] + + def list_models(self, model_name: Optional[str] = None) -> list: + """ + List available models + + Args: + model_name: Optional model name filter + + Returns: + List of model information dicts + """ + try: + if self.registry_type == "mlflow": + if model_name: + models = mlflow.search_registered_models(filter_string=f"name='{model_name}'") + else: + models = mlflow.search_registered_models() 
+ + return [ + { + "name": m.name, + "versions": [v.version for v in m.latest_versions], + "latest_version": m.latest_versions[0].version if m.latest_versions else None + } + for m in models + ] + + elif self.registry_type == "s3": + prefix = "models/" + if model_name: + prefix = f"models/{model_name}/" + + response = self.s3_client.list_objects_v2( + Bucket=self.s3_bucket, + Prefix=prefix, + Delimiter="/" + ) + + models = [] + for prefix_obj in response.get("CommonPrefixes", []): + model_path = prefix_obj["Prefix"] + parts = model_path.strip("/").split("/") + if len(parts) >= 2: + models.append({ + "name": parts[1], + "path": model_path + }) + + return models + + elif self.registry_type == "local": + models = [] + for model_dir in self.local_path.iterdir(): + if model_dir.is_dir(): + versions = [d.name for d in model_dir.iterdir() if d.is_dir()] + models.append({ + "name": model_dir.name, + "versions": versions, + "latest_version": max(versions) if versions else None + }) + return models + + except Exception as e: + print(f"Error listing models: {e}") + return [] + diff --git a/src/deepiri_modelkit/streaming/event_stream.py b/src/deepiri_modelkit/streaming/event_stream.py index 759fbef..40e5ed4 100644 --- a/src/deepiri_modelkit/streaming/event_stream.py +++ b/src/deepiri_modelkit/streaming/event_stream.py @@ -1,194 +1,204 @@ -""" -Redis Streams client for event-driven architecture -""" - -import redis.asyncio as redis -from typing import Dict, Any, Optional, AsyncIterator, Callable -import json -import asyncio -from datetime import datetime - -from .topics import StreamTopics -from ..contracts.events import BaseEvent - - -class StreamingClient: - """ - Redis Streams client for publishing and subscribing to events - """ - - def __init__( - self, - redis_url: Optional[str] = None, - redis_host: str = "redis", - redis_port: int = 6379, - redis_password: Optional[str] = None, - ): - """ - Initialize streaming client - - Args: - redis_url: Full Redis URL 
(redis://password@host:port) - redis_host: Redis host (if not using redis_url) - redis_port: Redis port (if not using redis_url) - redis_password: Redis password (if not using redis_url) - """ - if redis_url: - self.redis = redis.from_url(redis_url, decode_responses=True) - else: - self.redis = redis.Redis( - host=redis_host, - port=redis_port, - password=redis_password, - decode_responses=True, - ) - self._running = False - self._subscriptions = {} - - async def connect(self): - """Connect to Redis""" - await self.redis.ping() - - async def disconnect(self): - """Disconnect from Redis""" - await self.redis.close() - - async def publish( - self, topic: str, event: Dict[str, Any], max_length: Optional[int] = 10000 - ) -> str: - """ - Publish event to stream - - Args: - topic: Stream topic name - event: Event data (dict) - max_length: Max stream length (truncate old messages) - - Returns: - Message ID - """ - # Ensure event has timestamp - if "timestamp" not in event: - event["timestamp"] = datetime.utcnow().isoformat() - - # Publish to stream - message_id = await self.redis.xadd( - topic, event, maxlen=max_length, approximate=True - ) - - return message_id - - async def subscribe( - self, - topic: str, - callback: Callable[[Dict[str, Any]], None], - consumer_group: Optional[str] = None, - consumer_name: Optional[str] = None, - last_id: str = "0", - block_ms: int = 1000, - ) -> AsyncIterator[Dict[str, Any]]: - """ - Subscribe to stream and yield events - - Args: - topic: Stream topic name - callback: Optional callback function - consumer_group: Consumer group name (for load balancing) - consumer_name: Consumer name (unique per consumer) - last_id: Last message ID to read from - block_ms: Block time in milliseconds - - Yields: - Event data (dict) - """ - # Create consumer group if specified - if consumer_group: - try: - await self.redis.xgroup_create( - topic, consumer_group, id="0", mkstream=True - ) - except redis.ResponseError as e: - if "BUSYGROUP" not in str(e): 
- raise - - self._running = True - - while self._running: - try: - if consumer_group and consumer_name: - # Read from consumer group - messages = await self.redis.xreadgroup( - consumer_group, - consumer_name, - {topic: ">"}, - count=10, - block=block_ms, - ) - else: - # Direct read - messages = await self.redis.xread( - {topic: last_id}, count=10, block=block_ms - ) - - for stream_name, stream_messages in messages: - for msg_id, data in stream_messages: - # Yield event - yield data - - # Call callback if provided - if callback: - try: - ( - await callback(data) - if asyncio.iscoroutinefunction(callback) - else callback(data) - ) - except Exception as e: - print(f"Callback error: {e}") - - # Update last_id for next read - last_id = msg_id - - # Acknowledge if using consumer group - if consumer_group and consumer_name: - await self.redis.xack(topic, consumer_group, msg_id) - - except asyncio.CancelledError: - break - except Exception as e: - print(f"Stream subscription error: {e}") - await asyncio.sleep(1) - - async def subscribe_async( - self, - topic: str, - callback: Callable[[Dict[str, Any]], None], - consumer_group: Optional[str] = None, - consumer_name: Optional[str] = None, - ): - """ - Subscribe to stream in background task - - Args: - topic: Stream topic name - callback: Callback function - consumer_group: Consumer group name - consumer_name: Consumer name - """ - async for event in self.subscribe( - topic, callback, consumer_group, consumer_name - ): - pass # Callback handles events - - def stop(self): - """Stop all subscriptions""" - self._running = False - - async def get_stream_info(self, topic: str) -> Dict[str, Any]: - """Get stream information""" - info = await self.redis.xinfo_stream(topic) - return dict(info) - - async def get_stream_length(self, topic: str) -> int: - """Get number of messages in stream""" - return await self.redis.xlen(topic) +""" +Redis Streams client for event-driven architecture +""" +import redis.asyncio as redis +from typing 
import Dict, Any, Optional, AsyncIterator, Callable +import json +import asyncio +from datetime import datetime + +from .topics import StreamTopics +from ..contracts.events import BaseEvent + + +class StreamingClient: + """ + Redis Streams client for publishing and subscribing to events + """ + + def __init__( + self, + redis_url: Optional[str] = None, + redis_host: str = "redis", + redis_port: int = 6379, + redis_password: Optional[str] = None + ): + """ + Initialize streaming client + + Args: + redis_url: Full Redis URL (redis://password@host:port) + redis_host: Redis host (if not using redis_url) + redis_port: Redis port (if not using redis_url) + redis_password: Redis password (if not using redis_url) + """ + if redis_url: + self.redis = redis.from_url(redis_url, decode_responses=True) + else: + self.redis = redis.Redis( + host=redis_host, + port=redis_port, + password=redis_password, + decode_responses=True + ) + self._running = False + self._subscriptions = {} + + async def connect(self): + """Connect to Redis""" + await self.redis.ping() + + async def disconnect(self): + """Disconnect from Redis""" + await self.redis.close() + + async def publish( + self, + topic: str, + event: Dict[str, Any], + max_length: Optional[int] = 10000 + ) -> str: + """ + Publish event to stream + + Args: + topic: Stream topic name + event: Event data (dict) + max_length: Max stream length (truncate old messages) + + Returns: + Message ID + """ + # Ensure event has timestamp + if "timestamp" not in event: + event["timestamp"] = datetime.utcnow().isoformat() + + # Publish to stream + message_id = await self.redis.xadd( + topic, + event, + maxlen=max_length, + approximate=True + ) + + return message_id + + async def subscribe( + self, + topic: str, + callback: Callable[[Dict[str, Any]], None], + consumer_group: Optional[str] = None, + consumer_name: Optional[str] = None, + last_id: str = "0", + block_ms: int = 1000 + ) -> AsyncIterator[Dict[str, Any]]: + """ + Subscribe to stream and 
yield events + + Args: + topic: Stream topic name + callback: Optional callback function + consumer_group: Consumer group name (for load balancing) + consumer_name: Consumer name (unique per consumer) + last_id: Last message ID to read from + block_ms: Block time in milliseconds + + Yields: + Event data (dict) + """ + # Create consumer group if specified + if consumer_group: + try: + await self.redis.xgroup_create( + topic, + consumer_group, + id="0", + mkstream=True + ) + except redis.ResponseError as e: + if "BUSYGROUP" not in str(e): + raise + + self._running = True + + while self._running: + try: + if consumer_group and consumer_name: + # Read from consumer group + messages = await self.redis.xreadgroup( + consumer_group, + consumer_name, + {topic: ">"}, + count=10, + block=block_ms + ) + else: + # Direct read + messages = await self.redis.xread( + {topic: last_id}, + count=10, + block=block_ms + ) + + for stream_name, stream_messages in messages: + for msg_id, data in stream_messages: + # Yield event + yield data + + # Call callback if provided + if callback: + try: + await callback(data) if asyncio.iscoroutinefunction(callback) else callback(data) + except Exception as e: + print(f"Callback error: {e}") + + # Update last_id for next read + last_id = msg_id + + # Acknowledge if using consumer group + if consumer_group and consumer_name: + await self.redis.xack(topic, consumer_group, msg_id) + + except asyncio.CancelledError: + break + except Exception as e: + print(f"Stream subscription error: {e}") + await asyncio.sleep(1) + + async def subscribe_async( + self, + topic: str, + callback: Callable[[Dict[str, Any]], None], + consumer_group: Optional[str] = None, + consumer_name: Optional[str] = None + ): + """ + Subscribe to stream in background task + + Args: + topic: Stream topic name + callback: Callback function + consumer_group: Consumer group name + consumer_name: Consumer name + """ + async for event in self.subscribe( + topic, + callback, + 
consumer_group, + consumer_name + ): + pass # Callback handles events + + def stop(self): + """Stop all subscriptions""" + self._running = False + + async def get_stream_info(self, topic: str) -> Dict[str, Any]: + """Get stream information""" + info = await self.redis.xinfo_stream(topic) + return dict(info) + + async def get_stream_length(self, topic: str) -> int: + """Get number of messages in stream""" + return await self.redis.xlen(topic) + diff --git a/src/deepiri_modelkit/streaming/schemas.py b/src/deepiri_modelkit/streaming/schemas.py index fdd1b4e..388fa31 100644 --- a/src/deepiri_modelkit/streaming/schemas.py +++ b/src/deepiri_modelkit/streaming/schemas.py @@ -1,55 +1,56 @@ -""" -Streaming event schemas and validation -""" - -from .topics import StreamTopics -from ..contracts.events import ( - BaseEvent, - ModelReadyEvent, - ModelLoadedEvent, - InferenceEvent, - PlatformEvent, - AGIDecisionEvent, - TrainingEvent, -) - -# Map topics to event schemas -TOPIC_EVENT_SCHEMAS = { - StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], - StreamTopics.INFERENCE_EVENTS: [InferenceEvent], - StreamTopics.PLATFORM_EVENTS: [PlatformEvent], - StreamTopics.AGI_DECISIONS: [AGIDecisionEvent], - StreamTopics.TRAINING_EVENTS: [TrainingEvent], -} - - -def validate_event(topic: str, event_data: dict) -> BaseEvent: - """ - Validate event against schema - - Args: - topic: Stream topic - event_data: Event data dict - - Returns: - Validated event object - - Raises: - ValueError: If event doesn't match schema - """ - if topic not in TOPIC_EVENT_SCHEMAS: - # Unknown topic, return base event - return BaseEvent(**event_data) - - schemas = TOPIC_EVENT_SCHEMAS[topic] - event_type = event_data.get("event") - - # Try to match event type to schema - for schema in schemas: - try: - return schema(**event_data) - except Exception: - continue - - # Fallback to base event - return BaseEvent(**event_data) +""" +Streaming event schemas and validation +""" +from .topics import StreamTopics 
+from ..contracts.events import ( + BaseEvent, + ModelReadyEvent, + ModelLoadedEvent, + InferenceEvent, + PlatformEvent, + AGIDecisionEvent, + TrainingEvent, +) + + +# Map topics to event schemas +TOPIC_EVENT_SCHEMAS = { + StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], + StreamTopics.INFERENCE_EVENTS: [InferenceEvent], + StreamTopics.PLATFORM_EVENTS: [PlatformEvent], + StreamTopics.AGI_DECISIONS: [AGIDecisionEvent], + StreamTopics.TRAINING_EVENTS: [TrainingEvent], +} + + +def validate_event(topic: str, event_data: dict) -> BaseEvent: + """ + Validate event against schema + + Args: + topic: Stream topic + event_data: Event data dict + + Returns: + Validated event object + + Raises: + ValueError: If event doesn't match schema + """ + if topic not in TOPIC_EVENT_SCHEMAS: + # Unknown topic, return base event + return BaseEvent(**event_data) + + schemas = TOPIC_EVENT_SCHEMAS[topic] + event_type = event_data.get("event") + + # Try to match event type to schema + for schema in schemas: + try: + return schema(**event_data) + except Exception: + continue + + # Fallback to base event + return BaseEvent(**event_data) + diff --git a/src/deepiri_modelkit/streaming/sidecar_utils.py b/src/deepiri_modelkit/streaming/sidecar_utils.py index 6842054..3af5322 100644 --- a/src/deepiri_modelkit/streaming/sidecar_utils.py +++ b/src/deepiri_modelkit/streaming/sidecar_utils.py @@ -1,81 +1,79 @@ -""" -Shared Sugar Glider/Synapse sidecar helpers. - -These utilities are reused by multiple services (for example Cyrex and Helox) -to keep sidecar transport behavior consistent across repos. 
-""" - -from __future__ import annotations - -import json -import os -from typing import Any, Callable, Dict, Optional -from urllib.parse import urlparse - - -def env_float( - name: str, default: float, logger: Optional[Callable[[str], None]] = None -) -> float: - """Read a float env var with safe fallback.""" - raw = os.getenv(name) - if raw is None: - return default - try: - return float(raw) - except ValueError: - if logger is not None: - logger(f"invalid float env {name}={raw!r}; using {default}") - return default - - -def resolve_grpc_addr(base_url: str, explicit_grpc_addr: Optional[str] = None) -> str: - """ - Resolve sidecar gRPC host:port from explicit/env/base URL. - - Resolution order: - 1) explicit_grpc_addr - 2) SYNAPSE_GRPC_ADDR - 3) derive from base_url (8081 -> 50051) - """ - env_addr = os.getenv("SYNAPSE_GRPC_ADDR") - if explicit_grpc_addr: - return explicit_grpc_addr - if env_addr: - return env_addr - - parsed = urlparse(base_url) - if parsed.scheme in {"http", "https"}: - host = parsed.hostname or "localhost" - port = parsed.port - if port is None: - port = 443 if parsed.scheme == "https" else 80 - if port == 8081: - port = 50051 - return f"{host}:{port}" - - if base_url: - return base_url - return "localhost:50051" - - -def sidecar_payload_from_fields(fields: Dict[str, Any]) -> Dict[str, Any]: - """Normalize sidecar event fields to a payload dict.""" - payload = fields.get("payload", {}) - if isinstance(payload, str): - try: - payload = json.loads(payload) - except ValueError: - payload = {} - elif not isinstance(payload, dict): - payload = {} - - if "event" not in payload and fields.get("event_type"): - payload["event"] = fields.get("event_type") - - if "timestamp" not in payload and fields.get("timestamp"): - payload["timestamp"] = fields.get("timestamp") - - if "sender" not in payload and fields.get("sender"): - payload["sender"] = fields.get("sender") - - return payload +""" +Shared Sugar Glider/Synapse sidecar helpers. 
+ +These utilities are reused by multiple services (for example Cyrex and Helox) +to keep sidecar transport behavior consistent across repos. +""" + +from __future__ import annotations + +import json +import os +from typing import Any, Callable, Dict, Optional +from urllib.parse import urlparse + + +def env_float(name: str, default: float, logger: Optional[Callable[[str], None]] = None) -> float: + """Read a float env var with safe fallback.""" + raw = os.getenv(name) + if raw is None: + return default + try: + return float(raw) + except ValueError: + if logger is not None: + logger(f"invalid float env {name}={raw!r}; using {default}") + return default + + +def resolve_grpc_addr(base_url: str, explicit_grpc_addr: Optional[str] = None) -> str: + """ + Resolve sidecar gRPC host:port from explicit/env/base URL. + + Resolution order: + 1) explicit_grpc_addr + 2) SYNAPSE_GRPC_ADDR + 3) derive from base_url (8081 -> 50051) + """ + env_addr = os.getenv("SYNAPSE_GRPC_ADDR") + if explicit_grpc_addr: + return explicit_grpc_addr + if env_addr: + return env_addr + + parsed = urlparse(base_url) + if parsed.scheme in {"http", "https"}: + host = parsed.hostname or "localhost" + port = parsed.port + if port is None: + port = 443 if parsed.scheme == "https" else 80 + if port == 8081: + port = 50051 + return f"{host}:{port}" + + if base_url: + return base_url + return "localhost:50051" + + +def sidecar_payload_from_fields(fields: Dict[str, Any]) -> Dict[str, Any]: + """Normalize sidecar event fields to a payload dict.""" + payload = fields.get("payload", {}) + if isinstance(payload, str): + try: + payload = json.loads(payload) + except ValueError: + payload = {} + elif not isinstance(payload, dict): + payload = {} + + if "event" not in payload and fields.get("event_type"): + payload["event"] = fields.get("event_type") + + if "timestamp" not in payload and fields.get("timestamp"): + payload["timestamp"] = fields.get("timestamp") + + if "sender" not in payload and 
fields.get("sender"): + payload["sender"] = fields.get("sender") + + return payload diff --git a/src/deepiri_modelkit/streaming/topics.py b/src/deepiri_modelkit/streaming/topics.py index f3515db..b609448 100644 --- a/src/deepiri_modelkit/streaming/topics.py +++ b/src/deepiri_modelkit/streaming/topics.py @@ -1,30 +1,27 @@ -""" -Stream topic definitions. -""" - -from enum import Enum - - +""" +Stream topic definitions +""" +from enum import Enum + + class StreamTopics(str, Enum): - """Redis Stream topics.""" - + """Redis Stream topics""" MODEL_EVENTS = "model-events" INFERENCE_EVENTS = "inference-events" PLATFORM_EVENTS = "platform-events" AGI_DECISIONS = "agi-decisions" TRAINING_EVENTS = "training-events" - # LIS document routing streams (document.* namespace). DOCUMENT_VECTORIZE = "document.vectorize" DOCUMENT_TRAINING = "document.training" DOCUMENT_STRUCTURED = "document.structured" DOCUMENT_ARTIFACTS = "document.artifacts" - - # Cyrex runtime training signals for Helox (pipeline.* namespace). + # Cyrex runtime training signals consumed by Helox. 
HELOX_TRAINING_RAW = "pipeline.helox-training.raw" HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" @classmethod - def all(cls) -> list[str]: - """Get all topic names.""" - return [topic.value for topic in cls] + def all(cls) -> list: + """Get all topic names""" + return [topic.value for topic in cls] + diff --git a/src/deepiri_modelkit/utils/__init__.py b/src/deepiri_modelkit/utils/__init__.py index 0d08069..a7dfcc9 100644 --- a/src/deepiri_modelkit/utils/__init__.py +++ b/src/deepiri_modelkit/utils/__init__.py @@ -1,8 +1,7 @@ -"""Common utilities for Deepiri ModelKit""" - -try: - from .device import get_device, get_torch_device - - __all__ = ["get_device", "get_torch_device"] -except ImportError: - __all__ = [] +"""Common utilities for Deepiri ModelKit""" + +try: + from .device import get_device, get_torch_device + __all__ = ["get_device", "get_torch_device"] +except ImportError: + __all__ = [] diff --git a/src/deepiri_modelkit/utils/device.py b/src/deepiri_modelkit/utils/device.py index 0226c5e..8b3922a 100644 --- a/src/deepiri_modelkit/utils/device.py +++ b/src/deepiri_modelkit/utils/device.py @@ -1,158 +1,143 @@ -""" -GPU Device Detection Utility -Automatically detects and uses GPU (CUDA) if available, falls back to CPU -""" - -import os - -try: - import torch - - HAS_TORCH = True -except ImportError: - HAS_TORCH = False - -from deepiri_modelkit.logging import get_logger - -logger = get_logger("deepiri_modelkit.utils.device") - - -def get_device() -> str: - """ - Detect the best available device with proper fallback: CUDA → MPS → CPU - - Returns device string that can be used with PyTorch and SentenceTransformers. - Actually tests GPU functionality, not just availability. 
- """ - if not HAS_TORCH: - logger.info("PyTorch not installed, using CPU") - return "cpu" - - # Log diagnostic information for debugging - logger.debug(f"PyTorch version: {torch.__version__}") - logger.debug( - f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}" - ) - - if torch.cuda.is_available(): - try: - # Get CUDA information - cuda_version = torch.version.cuda - device_count = torch.cuda.device_count() - gpu_name = torch.cuda.get_device_name(0) - cuda_capability = torch.cuda.get_device_capability(0) - - logger.info( - f"CUDA detected: version={cuda_version}, devices={device_count}, " - f"GPU={gpu_name}, capability={cuda_capability[0]}.{cuda_capability[1]}" - ) - - # Check for RTX 5080/5090 (sm_120) compatibility issue - if cuda_capability[0] >= 12: - # Check if PyTorch supports this compute capability - try: - # Try to get the list of supported compute capabilities - # PyTorch 2.9.1 with CUDA 12.6 supports up to sm_90 - # RTX 5080/5090 requires sm_120 support (CUDA 12.8+) - test_tensor = torch.tensor([1.0], device="cuda") - result = test_tensor * 2.0 - _ = result.cpu() - del test_tensor, result - torch.cuda.empty_cache() - except RuntimeError as e: - if ( - "no kernel image is available for execution on the device" - in str(e) - or "cudaErrorNoKernelImageForDevice" in str(e) - ): - logger.error( - f"RTX 5080/5090 (sm_{cuda_capability[0]}.{cuda_capability[1]}) detected, but PyTorch doesn't support this compute capability. " - f"Current PyTorch supports up to sm_90. 
" - f"To fix: Rebuild Docker image (CUDA 12.8 support should be automatic): " - f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" - ) - raise - else: - raise - - # Test GPU functionality with a simple operation - test_tensor = torch.tensor([1.0], device="cuda") - result = test_tensor * 2.0 - _ = result.cpu() # Ensure operation completes - del test_tensor, result - torch.cuda.empty_cache() - - logger.info( - f"CUDA GPU detected and tested successfully: {gpu_name} " - f"(CUDA {cuda_version}, Capability {cuda_capability[0]}.{cuda_capability[1]})" - ) - return "cuda" - except RuntimeError as cuda_error: - error_msg = str(cuda_error) - # Check if this is the RTX 5080 compatibility issue - if ( - "no kernel image is available for execution on the device" in error_msg - or "cudaErrorNoKernelImageForDevice" in error_msg - ): - logger.error( - f"GPU compute capability not supported by current PyTorch installation. " - f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'Unknown'}, " - f"Capability: {torch.cuda.get_device_capability(0) if torch.cuda.is_available() else 'Unknown'}. " - f"Error: {error_msg}. " - f"Solution: Rebuild Docker image (CUDA 12.8 support should be automatic): " - f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" - ) - else: - logger.warning( - f"CUDA available but GPU test failed: {error_msg}. " - f"Falling back to CPU. This may indicate: " - f"1) GPU not accessible in Docker container (check NVIDIA Container Toolkit), " - f"2) CUDA driver mismatch, or 3) GPU memory issue." - ) - except Exception as cuda_error: - logger.warning( - f"CUDA available but test failed: {cuda_error}. 
Falling back to CPU" - ) - else: - # CUDA not available - provide diagnostic info - logger.debug("CUDA not available via torch.cuda.is_available()") - - # Check if we're in Docker and might need NVIDIA Container Toolkit - if os.path.exists("/.dockerenv"): - logger.debug( - "Running in Docker container - ensure NVIDIA Container Toolkit is installed" - ) - # Check for NVIDIA runtime - if os.path.exists("/proc/driver/nvidia"): - logger.warning( - "NVIDIA driver detected in container but PyTorch CUDA not available. " - "This may indicate: 1) PyTorch not built with CUDA support, " - "2) CUDA libraries not in container, or 3) NVIDIA Container Toolkit not configured." - ) - - # Check MPS (Apple Silicon) - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - try: - test_tensor = torch.tensor([1.0], device="mps") - result = test_tensor * 2.0 - _ = result.cpu() - del test_tensor, result - logger.info("Apple Silicon (MPS) detected and tested successfully") - return "mps" - except Exception as mps_error: - logger.warning( - f"MPS available but test failed, falling back to CPU: {mps_error}" - ) - - # Fallback to CPU - logger.info("Using CPU device (no GPU detected or GPU test failed)") - return "cpu" - - -def get_torch_device() -> "torch.device": - """Get PyTorch device object""" - if not HAS_TORCH: - raise ImportError( - "torch is required for get_torch_device(). Install with: pip install torch" - ) - return torch.device(get_device()) +""" +GPU Device Detection Utility +Automatically detects and uses GPU (CUDA) if available, falls back to CPU +""" +import os + +try: + import torch + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + +from deepiri_modelkit.logging import get_logger + +logger = get_logger("deepiri_modelkit.utils.device") + + +def get_device() -> str: + """ + Detect the best available device with proper fallback: CUDA → MPS → CPU + + Returns device string that can be used with PyTorch and SentenceTransformers. 
+ Actually tests GPU functionality, not just availability. + """ + if not HAS_TORCH: + logger.info("PyTorch not installed, using CPU") + return "cpu" + + # Log diagnostic information for debugging + logger.debug(f"PyTorch version: {torch.__version__}") + logger.debug(f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}") + + if torch.cuda.is_available(): + try: + # Get CUDA information + cuda_version = torch.version.cuda + device_count = torch.cuda.device_count() + gpu_name = torch.cuda.get_device_name(0) + cuda_capability = torch.cuda.get_device_capability(0) + + logger.info( + f"CUDA detected: version={cuda_version}, devices={device_count}, " + f"GPU={gpu_name}, capability={cuda_capability[0]}.{cuda_capability[1]}" + ) + + # Check for RTX 5080/5090 (sm_120) compatibility issue + if cuda_capability[0] >= 12: + # Check if PyTorch supports this compute capability + try: + # Try to get the list of supported compute capabilities + # PyTorch 2.9.1 with CUDA 12.6 supports up to sm_90 + # RTX 5080/5090 requires sm_120 support (CUDA 12.8+) + test_tensor = torch.tensor([1.0], device='cuda') + result = test_tensor * 2.0 + _ = result.cpu() + del test_tensor, result + torch.cuda.empty_cache() + except RuntimeError as e: + if "no kernel image is available for execution on the device" in str(e) or \ + "cudaErrorNoKernelImageForDevice" in str(e): + logger.error( + f"RTX 5080/5090 (sm_{cuda_capability[0]}.{cuda_capability[1]}) detected, but PyTorch doesn't support this compute capability. " + f"Current PyTorch supports up to sm_90. 
" + f"To fix: Rebuild Docker image (CUDA 12.8 support should be automatic): " + f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" + ) + raise + else: + raise + + # Test GPU functionality with a simple operation + test_tensor = torch.tensor([1.0], device='cuda') + result = test_tensor * 2.0 + _ = result.cpu() # Ensure operation completes + del test_tensor, result + torch.cuda.empty_cache() + + logger.info( + f"CUDA GPU detected and tested successfully: {gpu_name} " + f"(CUDA {cuda_version}, Capability {cuda_capability[0]}.{cuda_capability[1]})" + ) + return "cuda" + except RuntimeError as cuda_error: + error_msg = str(cuda_error) + # Check if this is the RTX 5080 compatibility issue + if "no kernel image is available for execution on the device" in error_msg or \ + "cudaErrorNoKernelImageForDevice" in error_msg: + logger.error( + f"GPU compute capability not supported by current PyTorch installation. " + f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'Unknown'}, " + f"Capability: {torch.cuda.get_device_capability(0) if torch.cuda.is_available() else 'Unknown'}. " + f"Error: {error_msg}. " + f"Solution: Rebuild Docker image (CUDA 12.8 support should be automatic): " + f"docker-compose -f docker-compose.dev.yml build --no-cache cyrex" + ) + else: + logger.warning( + f"CUDA available but GPU test failed: {error_msg}. " + f"Falling back to CPU. This may indicate: " + f"1) GPU not accessible in Docker container (check NVIDIA Container Toolkit), " + f"2) CUDA driver mismatch, or 3) GPU memory issue." + ) + except Exception as cuda_error: + logger.warning( + f"CUDA available but test failed: {cuda_error}. 
Falling back to CPU" + ) + else: + # CUDA not available - provide diagnostic info + logger.debug("CUDA not available via torch.cuda.is_available()") + + # Check if we're in Docker and might need NVIDIA Container Toolkit + if os.path.exists("/.dockerenv"): + logger.debug("Running in Docker container - ensure NVIDIA Container Toolkit is installed") + # Check for NVIDIA runtime + if os.path.exists("/proc/driver/nvidia"): + logger.warning( + "NVIDIA driver detected in container but PyTorch CUDA not available. " + "This may indicate: 1) PyTorch not built with CUDA support, " + "2) CUDA libraries not in container, or 3) NVIDIA Container Toolkit not configured." + ) + + # Check MPS (Apple Silicon) + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + try: + test_tensor = torch.tensor([1.0], device='mps') + result = test_tensor * 2.0 + _ = result.cpu() + del test_tensor, result + logger.info("Apple Silicon (MPS) detected and tested successfully") + return "mps" + except Exception as mps_error: + logger.warning(f"MPS available but test failed, falling back to CPU: {mps_error}") + + # Fallback to CPU + logger.info("Using CPU device (no GPU detected or GPU test failed)") + return "cpu" + + +def get_torch_device() -> "torch.device": + """Get PyTorch device object""" + if not HAS_TORCH: + raise ImportError("torch is required for get_torch_device(). 
Install with: pip install torch") + return torch.device(get_device()) From 577d4c951415f29737668c76ade43c084117f128 Mon Sep 17 00:00:00 2001 From: Bao Tran Date: Sat, 9 May 2026 19:29:17 -0400 Subject: [PATCH 6/6] style: satisfy modelkit black check --- src/deepiri_modelkit/__init__.py | 9 +- src/deepiri_modelkit/contracts/contract.py | 9 +- src/deepiri_modelkit/contracts/events.py | 10 +- src/deepiri_modelkit/contracts/models.py | 56 ++-- src/deepiri_modelkit/contracts/services.py | 36 +-- src/deepiri_modelkit/data/monitoring.py | 176 +++++++---- src/deepiri_modelkit/data/validation.py | 126 +++++--- src/deepiri_modelkit/logging.py | 88 ++++-- src/deepiri_modelkit/ml/__init__.py | 2 + src/deepiri_modelkit/ml/confidence.py | 62 ++-- src/deepiri_modelkit/ml/semantic.py | 88 +++--- src/deepiri_modelkit/rag/__init__.py | 67 ++-- .../rag/advanced_retrieval.py | 224 +++++++------ src/deepiri_modelkit/rag/async_processing.py | 169 +++++----- src/deepiri_modelkit/rag/base.py | 101 +++--- src/deepiri_modelkit/rag/caching.py | 212 ++++++------- src/deepiri_modelkit/rag/monitoring.py | 189 ++++++----- src/deepiri_modelkit/rag/processors.py | 296 ++++++++++-------- src/deepiri_modelkit/rag/retrievers.py | 143 +++++---- src/deepiri_modelkit/rag/testing.py | 166 +++++----- .../registry/adapters/__init__.py | 1 - .../registry/model_registry.py | 191 ++++++----- .../streaming/event_stream.py | 90 +++--- src/deepiri_modelkit/streaming/schemas.py | 15 +- .../streaming/sidecar_utils.py | 4 +- src/deepiri_modelkit/streaming/topics.py | 39 +-- src/deepiri_modelkit/utils/__init__.py | 1 + src/deepiri_modelkit/utils/device.py | 39 ++- 28 files changed, 1424 insertions(+), 1185 deletions(-) diff --git a/src/deepiri_modelkit/__init__.py b/src/deepiri_modelkit/__init__.py index 4d3f09b..227bf54 100644 --- a/src/deepiri_modelkit/__init__.py +++ b/src/deepiri_modelkit/__init__.py @@ -4,7 +4,13 @@ __version__ = "0.1.0" -from .contracts.models import AIModel, AIModelPydantic, ModelInput, 
ModelOutput, ModelMetadata +from .contracts.models import ( + AIModel, + AIModelPydantic, + ModelInput, + ModelOutput, + ModelMetadata, +) from .contracts.events import ( ModelReadyEvent, InferenceEvent, @@ -33,4 +39,3 @@ "get_error_logger", "ErrorLogger", ] - diff --git a/src/deepiri_modelkit/contracts/contract.py b/src/deepiri_modelkit/contracts/contract.py index 00abc21..b16aa1a 100644 --- a/src/deepiri_modelkit/contracts/contract.py +++ b/src/deepiri_modelkit/contracts/contract.py @@ -1,6 +1,7 @@ """ Model contract for registry (separated from models.py to avoid Pydantic Protocol conflicts) """ + from __future__ import annotations from typing import Dict, Any, Optional @@ -12,16 +13,18 @@ class ModelContract(BaseModel): """ Complete model contract for registry. - + A contract is serializable metadata that describes a model's interface, input/output schemas, and validation requirements. It does NOT contain the actual model instance (which would be a Protocol type that Pydantic cannot serialize). The model instance should be loaded separately when needed. 
""" + metadata: ModelMetadata input_schema: Dict[str, Any] output_schema: Dict[str, Any] validation_tests: Optional[list] = None - model_path: Optional[str] = None # Path/reference to where the model can be loaded from + model_path: Optional[str] = ( + None # Path/reference to where the model can be loaded from + ) model_id: Optional[str] = None # Unique identifier for the model instance - diff --git a/src/deepiri_modelkit/contracts/events.py b/src/deepiri_modelkit/contracts/events.py index 9711a93..fefe582 100644 --- a/src/deepiri_modelkit/contracts/events.py +++ b/src/deepiri_modelkit/contracts/events.py @@ -1,6 +1,7 @@ """ Event schemas for streaming service """ + from pydantic import BaseModel, Field from typing import Dict, Any, Optional from datetime import datetime @@ -9,6 +10,7 @@ class EventType(str, Enum): """Event type enumeration""" + MODEL_READY = "model-ready" MODEL_LOADED = "model-loaded" MODEL_FAILED = "model-failed" @@ -26,6 +28,7 @@ class EventType(str, Enum): class BaseEvent(BaseModel): """Base event schema""" + event: str timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) source: str @@ -34,6 +37,7 @@ class BaseEvent(BaseModel): class ModelReadyEvent(BaseEvent): """Event published when model is trained and ready""" + event: str = EventType.MODEL_READY model_name: str version: str @@ -46,6 +50,7 @@ class ModelReadyEvent(BaseEvent): class ModelLoadedEvent(BaseEvent): """Event published when model is loaded in runtime""" + event: str = EventType.MODEL_LOADED model_name: str version: str @@ -55,6 +60,7 @@ class ModelLoadedEvent(BaseEvent): class InferenceEvent(BaseEvent): """Event published after inference completes""" + event: str = EventType.INFERENCE_COMPLETE model_name: str version: str @@ -69,6 +75,7 @@ class InferenceEvent(BaseEvent): class PlatformEvent(BaseEvent): """Event published by platform services""" + event: str # user-interaction, task-created, etc. 
service: str user_id: Optional[str] = None @@ -79,6 +86,7 @@ class PlatformEvent(BaseEvent): class AGIDecisionEvent(BaseEvent): """Event published by Cyrex-AGI for autonomous decisions""" + event: str = EventType.AGI_DECISION decision_type: str target_service: Optional[str] = None @@ -89,6 +97,7 @@ class AGIDecisionEvent(BaseEvent): class TrainingEvent(BaseEvent): """Event published during training""" + event: str # training-started, training-complete, training-failed experiment_id: str model_name: str @@ -96,4 +105,3 @@ class TrainingEvent(BaseEvent): progress: Optional[float] = None # 0.0 to 1.0 metrics: Optional[Dict[str, Any]] = None error: Optional[str] = None - diff --git a/src/deepiri_modelkit/contracts/models.py b/src/deepiri_modelkit/contracts/models.py index 6e52336..bdc6568 100644 --- a/src/deepiri_modelkit/contracts/models.py +++ b/src/deepiri_modelkit/contracts/models.py @@ -1,7 +1,10 @@ """ Model contracts and interfaces """ -from __future__ import annotations # Defer annotation evaluation to prevent Pydantic from processing Protocol types + +from __future__ import ( + annotations, +) # Defer annotation evaluation to prevent Pydantic from processing Protocol types from typing import Protocol, Dict, Any, Optional, Annotated from pydantic import BaseModel, Field, GetCoreSchemaHandler @@ -11,6 +14,7 @@ class ModelInput(BaseModel): """Standard model input schema""" + data: Dict[str, Any] metadata: Optional[Dict[str, Any]] = None timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) @@ -18,6 +22,7 @@ class ModelInput(BaseModel): class ModelOutput(BaseModel): """Standard model output schema""" + prediction: Any confidence: Optional[float] = None metadata: Optional[Dict[str, Any]] = None @@ -26,6 +31,7 @@ class ModelOutput(BaseModel): class ModelMetadata(BaseModel): """Model metadata schema""" + name: str version: str description: Optional[str] = None @@ -41,23 +47,23 @@ class AIModel(Protocol): """ Interface that all models must 
implement Used by both Helox (training) and Cyrex (runtime) - + Note: This is a Protocol (structural type). To use in Pydantic models, use AIModelPydantic instead, which has full Pydantic schema support. """ - + def predict(self, input: ModelInput) -> ModelOutput: """Run inference on input""" ... - + def get_metadata(self) -> ModelMetadata: """Get model metadata""" ... - + def validate(self) -> bool: """Validate model is ready for use""" ... - + def export(self, format: str = "onnx") -> str: """Export model to specified format, returns path""" ... @@ -67,14 +73,14 @@ class AIModelPydantic: """ Pydantic-compatible wrapper for AIModel Protocol. Implements __get_pydantic_core_schema__ to provide full schema support. - + Usage in Pydantic models: model: Optional[AIModelPydantic] = None - + Note: In practice, model instances should be loaded separately and referenced by ID/path rather than stored directly in serializable Pydantic models. """ - + @classmethod def __get_pydantic_core_schema__( cls, @@ -83,61 +89,61 @@ def __get_pydantic_core_schema__( ) -> core_schema.CoreSchema: """ Pydantic Core Schema handler for AIModel Protocol. - + This allows Pydantic to process AIModel types in model fields. Since Protocols are structural types, we validate that the object has the required methods rather than checking exact type. """ + def validate_aimodel(value: Any) -> Any: """Validate that value implements AIModel Protocol interface""" if value is None: return None - + # Check for required Protocol methods - required_methods = ['predict', 'get_metadata', 'validate', 'export'] + required_methods = ["predict", "get_metadata", "validate", "export"] missing_methods = [m for m in required_methods if not hasattr(value, m)] - + if missing_methods: raise ValueError( f"Object does not implement AIModel Protocol. 
" f"Missing methods: {', '.join(missing_methods)}" ) - + return value - + def serialize_aimodel(value: Any) -> Dict[str, Any]: """Serialize AIModel instance to dict""" if value is None: return None - + # Try to get metadata if available metadata = None - if hasattr(value, 'get_metadata'): + if hasattr(value, "get_metadata"): try: metadata = value.get_metadata() # Convert ModelMetadata to dict if it's a Pydantic model - if hasattr(metadata, 'model_dump'): + if hasattr(metadata, "model_dump"): metadata = metadata.model_dump() - elif hasattr(metadata, 'dict'): + elif hasattr(metadata, "dict"): metadata = metadata.dict() except Exception: pass - + return { "type": "AIModel", "metadata": metadata, - "has_predict": hasattr(value, 'predict'), - "has_validate": hasattr(value, 'validate'), + "has_predict": hasattr(value, "predict"), + "has_validate": hasattr(value, "validate"), } - + return core_schema.no_info_plain_validator_function( validate_aimodel, serialization=core_schema.plain_serializer_function_ser_schema( serialize_aimodel - ) + ), ) # ModelContract moved to contract.py to avoid Pydantic Protocol conflicts # Import it from .contract import ModelContract when needed - diff --git a/src/deepiri_modelkit/contracts/services.py b/src/deepiri_modelkit/contracts/services.py index 9011102..b844a46 100644 --- a/src/deepiri_modelkit/contracts/services.py +++ b/src/deepiri_modelkit/contracts/services.py @@ -1,58 +1,44 @@ """ Service contracts and interfaces """ + from typing import Protocol, Dict, Any, Optional from pydantic import BaseModel class ModelRegistryService(Protocol): """Interface for model registry operations""" - + def register_model( - self, - model_name: str, - version: str, - model_path: str, - metadata: Dict[str, Any] + self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] ) -> bool: """Register a model in the registry""" ... 
- + def get_model( - self, - model_name: str, - version: Optional[str] = None + self, model_name: str, version: Optional[str] = None ) -> Dict[str, Any]: """Get model information from registry""" ... - + def list_models(self, model_name: Optional[str] = None) -> list: """List available models""" ... - - def download_model( - self, - model_name: str, - version: str, - destination: str - ) -> str: + + def download_model(self, model_name: str, version: str, destination: str) -> str: """Download model to destination, returns local path""" ... class StreamingService(Protocol): """Interface for streaming operations""" - + def publish(self, topic: str, event: Dict[str, Any]) -> bool: """Publish event to topic""" ... - + def subscribe( - self, - topic: str, - callback: callable, - consumer_group: Optional[str] = None + self, topic: str, callback: callable, consumer_group: Optional[str] = None ) -> None: """Subscribe to topic with callback""" ... - diff --git a/src/deepiri_modelkit/data/monitoring.py b/src/deepiri_modelkit/data/monitoring.py index 3fe2d65..81b7a91 100644 --- a/src/deepiri_modelkit/data/monitoring.py +++ b/src/deepiri_modelkit/data/monitoring.py @@ -2,6 +2,7 @@ Dataset Monitoring and Logging Utilities Provides monitoring, alerting, and logging for dataset versioning operations """ + import json import time from pathlib import Path @@ -40,7 +41,7 @@ def __init__(self, log_dir: str = "./logs/dataset_monitoring"): "average_version_creation_time": 0, "validation_errors_today": 0, "last_health_check": None, - "storage_usage_bytes": 0 + "storage_usage_bytes": 0, } self._load_metrics() @@ -59,15 +60,17 @@ def log_version_creation(self, operation_data: Dict[str, Any]): "change_type": operation_data.get("change_type"), "quality_score": operation_data.get("quality_score"), "storage_path": operation_data.get("storage_path"), - "created_by": operation_data.get("created_by") + "created_by": operation_data.get("created_by"), } self._write_log_entry(self.metrics_file, 
log_entry) self.current_metrics["total_versions_created"] += 1 - logger.info("Version creation logged", - dataset=operation_data.get("dataset_name"), - version=operation_data.get("version")) + logger.info( + "Version creation logged", + dataset=operation_data.get("dataset_name"), + version=operation_data.get("version"), + ) def log_validation_result(self, validation_data: Dict[str, Any]): """Log dataset validation results.""" @@ -80,7 +83,7 @@ def log_validation_result(self, validation_data: Dict[str, Any]): "quality_score": validation_data.get("quality_score"), "error_count": len(validation_data.get("errors", [])), "warning_count": len(validation_data.get("warnings", [])), - "validation_time_seconds": validation_data.get("validation_time", 0) + "validation_time_seconds": validation_data.get("validation_time", 0), } self._write_log_entry(self.metrics_file, log_entry) @@ -90,17 +93,22 @@ def log_validation_result(self, validation_data: Dict[str, Any]): # Check for alerts if validation_data.get("quality_score", 1.0) < 0.7: - self._create_alert("low_quality_score", { - "dataset_name": validation_data.get("dataset_name"), - "version": validation_data.get("version"), - "quality_score": validation_data.get("quality_score"), - "errors": validation_data.get("errors", []) - }) - - logger.info("Validation result logged", - dataset=validation_data.get("dataset_name"), - valid=validation_data.get("is_valid"), - quality=validation_data.get("quality_score")) + self._create_alert( + "low_quality_score", + { + "dataset_name": validation_data.get("dataset_name"), + "version": validation_data.get("version"), + "quality_score": validation_data.get("quality_score"), + "errors": validation_data.get("errors", []), + }, + ) + + logger.info( + "Validation result logged", + dataset=validation_data.get("dataset_name"), + valid=validation_data.get("is_valid"), + quality=validation_data.get("quality_score"), + ) def log_training_usage(self, training_data: Dict[str, Any]): """Log dataset usage 
in training.""" @@ -113,15 +121,17 @@ def log_training_usage(self, training_data: Dict[str, Any]): "training_duration_seconds": training_data.get("training_duration", 0), "final_loss": training_data.get("final_loss"), "experiment_id": training_data.get("experiment_id"), - "output_model_path": training_data.get("output_model_path") + "output_model_path": training_data.get("output_model_path"), } self._write_log_entry(self.metrics_file, log_entry) - logger.info("Training usage logged", - dataset=training_data.get("dataset_name"), - version=training_data.get("dataset_version"), - model=training_data.get("model_name")) + logger.info( + "Training usage logged", + dataset=training_data.get("dataset_name"), + version=training_data.get("dataset_version"), + model=training_data.get("model_name"), + ) def get_health_report(self) -> Dict[str, Any]: """Generate comprehensive health report.""" @@ -130,13 +140,16 @@ def get_health_report(self) -> Dict[str, Any]: "summary": { "total_versions": self.current_metrics["total_versions_created"], "datasets_tracked": self.current_metrics["total_datasets_tracked"], - "validation_errors_today": self.current_metrics["validation_errors_today"], - "storage_usage_gb": self.current_metrics["storage_usage_bytes"] / (1024**3) + "validation_errors_today": self.current_metrics[ + "validation_errors_today" + ], + "storage_usage_gb": self.current_metrics["storage_usage_bytes"] + / (1024**3), }, "performance": self._analyze_performance(), "quality_trends": self._analyze_quality_trends(), "alerts": self._get_recent_alerts(), - "recommendations": self._generate_recommendations() + "recommendations": self._generate_recommendations(), } self.current_metrics["last_health_check"] = report["timestamp"] @@ -152,12 +165,12 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: "training_runs": [], "validation_runs": [], "popular_datasets": {}, - "quality_distribution": {} + "quality_distribution": {}, } # Read logs and filter by date if 
self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: for line in f: try: entry = json.loads(line.strip()) @@ -167,7 +180,9 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: if entry["operation"] == "version_creation": analytics["version_creations"].append(entry) dataset = entry.get("dataset_name", "unknown") - analytics["popular_datasets"][dataset] = analytics["popular_datasets"].get(dataset, 0) + 1 + analytics["popular_datasets"][dataset] = ( + analytics["popular_datasets"].get(dataset, 0) + 1 + ) elif entry["operation"] == "training_usage": analytics["training_runs"].append(entry) @@ -176,7 +191,12 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: analytics["validation_runs"].append(entry) quality = entry.get("quality_score", 0) quality_bucket = f"{int(quality * 10) / 10:.1f}" - analytics["quality_distribution"][quality_bucket] = analytics["quality_distribution"].get(quality_bucket, 0) + 1 + analytics["quality_distribution"][quality_bucket] = ( + analytics["quality_distribution"].get( + quality_bucket, 0 + ) + + 1 + ) except json.JSONDecodeError: continue @@ -189,7 +209,7 @@ def _analyze_performance(self) -> Dict[str, Any]: validation_times = [] if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: for line in f: try: entry = json.loads(line.strip()) @@ -198,16 +218,22 @@ def _analyze_performance(self) -> Dict[str, Any]: creation_times.append(entry["creation_time_seconds"]) elif entry["operation"] == "validation": if "validation_time_seconds" in entry: - validation_times.append(entry["validation_time_seconds"]) + validation_times.append( + entry["validation_time_seconds"] + ) except json.JSONDecodeError: continue return { - "avg_version_creation_time": statistics.mean(creation_times) if creation_times else 0, - "avg_validation_time": statistics.mean(validation_times) if validation_times else 0, + 
"avg_version_creation_time": ( + statistics.mean(creation_times) if creation_times else 0 + ), + "avg_validation_time": ( + statistics.mean(validation_times) if validation_times else 0 + ), "total_operations": len(creation_times) + len(validation_times), "creation_times": creation_times[-10:], # Last 10 - "validation_times": validation_times[-10:] # Last 10 + "validation_times": validation_times[-10:], # Last 10 } def _analyze_quality_trends(self) -> Dict[str, Any]: @@ -215,11 +241,14 @@ def _analyze_quality_trends(self) -> Dict[str, Any]: quality_scores = [] if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: for line in f: try: entry = json.loads(line.strip()) - if entry["operation"] == "validation" and "quality_score" in entry: + if ( + entry["operation"] == "validation" + and "quality_score" in entry + ): quality_scores.append(entry["quality_score"]) except json.JSONDecodeError: continue @@ -232,8 +261,8 @@ def _analyze_quality_trends(self) -> Dict[str, Any]: # Simple trend analysis if len(recent_scores) >= 10: - first_half = recent_scores[:len(recent_scores)//2] - second_half = recent_scores[len(recent_scores)//2:] + first_half = recent_scores[: len(recent_scores) // 2] + second_half = recent_scores[len(recent_scores) // 2 :] first_avg = statistics.mean(first_half) second_avg = statistics.mean(second_half) @@ -255,8 +284,8 @@ def _analyze_quality_trends(self) -> Dict[str, Any]: "excellent": len([s for s in quality_scores if s >= 0.9]), "good": len([s for s in quality_scores if 0.7 <= s < 0.9]), "fair": len([s for s in quality_scores if 0.5 <= s < 0.7]), - "poor": len([s for s in quality_scores if s < 0.5]) - } + "poor": len([s for s in quality_scores if s < 0.5]), + }, } def _generate_recommendations(self) -> List[str]: @@ -265,24 +294,34 @@ def _generate_recommendations(self) -> List[str]: # Check for frequent validation errors if self.current_metrics["validation_errors_today"] > 5: - 
recommendations.append("High validation error rate detected. Review data quality processes.") + recommendations.append( + "High validation error rate detected. Review data quality processes." + ) # Check quality trends quality_analysis = self._analyze_quality_trends() if quality_analysis.get("trend") == "declining": - recommendations.append("Dataset quality is declining. Consider reviewing data sources and annotation processes.") + recommendations.append( + "Dataset quality is declining. Consider reviewing data sources and annotation processes." + ) # Check performance performance = self._analyze_performance() if performance.get("avg_version_creation_time", 0) > 300: # 5 minutes - recommendations.append("Version creation is slow. Consider optimizing storage or processing.") + recommendations.append( + "Version creation is slow. Consider optimizing storage or processing." + ) # General recommendations if self.current_metrics["total_versions_created"] == 0: - recommendations.append("No dataset versions created yet. Start versioning your datasets for reproducibility.") + recommendations.append( + "No dataset versions created yet. Start versioning your datasets for reproducibility." + ) if not recommendations: - recommendations.append("System operating normally. Continue regular monitoring.") + recommendations.append( + "System operating normally. Continue regular monitoring." 
+ ) return recommendations @@ -293,14 +332,12 @@ def _create_alert(self, alert_type: str, alert_data: Dict[str, Any]): "alert_type": alert_type, "severity": "warning", # Could be "info", "warning", "error" "data": alert_data, - "resolved": False + "resolved": False, } self._write_log_entry(self.alerts_file, alert_entry) - logger.warning("Alert created", - type=alert_type, - data=alert_data) + logger.warning("Alert created", type=alert_type, data=alert_data) def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: """Get recent alerts.""" @@ -308,7 +345,7 @@ def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: cutoff_time = datetime.utcnow() - timedelta(hours=hours) if self.alerts_file.exists(): - with open(self.alerts_file, 'r') as f: + with open(self.alerts_file, "r") as f: for line in f: try: alert = json.loads(line.strip()) @@ -322,22 +359,33 @@ def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: def _write_log_entry(self, log_file: Path, entry: Dict[str, Any]): """Write a log entry to file.""" - with open(log_file, 'a') as f: - f.write(json.dumps(entry) + '\n') + with open(log_file, "a") as f: + f.write(json.dumps(entry) + "\n") def _load_metrics(self): """Load current metrics from log files.""" if self.metrics_file.exists(): try: - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: lines = f.readlines() if lines: # Count operations from logs - version_count = sum(1 for line in lines if '"operation": "version_creation"' in line) - validation_count = sum(1 for line in lines if '"operation": "validation"' in line and '"is_valid": false' in line) + version_count = sum( + 1 + for line in lines + if '"operation": "version_creation"' in line + ) + validation_count = sum( + 1 + for line in lines + if '"operation": "validation"' in line + and '"is_valid": false' in line + ) self.current_metrics["total_versions_created"] = version_count - self.current_metrics["validation_errors_today"] = 
validation_count + self.current_metrics["validation_errors_today"] = ( + validation_count + ) except Exception as e: logger.warning("Failed to load metrics from log", error=str(e)) @@ -346,21 +394,17 @@ def _load_metrics(self): def log_version_creation(dataset_name: str, version: str, **kwargs): """Convenience function to log version creation.""" monitor = DatasetMonitor() - monitor.log_version_creation({ - "dataset_name": dataset_name, - "version": version, - **kwargs - }) + monitor.log_version_creation( + {"dataset_name": dataset_name, "version": version, **kwargs} + ) def log_validation_result(dataset_name: str, version: str, **kwargs): """Convenience function to log validation results.""" monitor = DatasetMonitor() - monitor.log_validation_result({ - "dataset_name": dataset_name, - "version": version, - **kwargs - }) + monitor.log_validation_result( + {"dataset_name": dataset_name, "version": version, **kwargs} + ) def get_health_report(): diff --git a/src/deepiri_modelkit/data/validation.py b/src/deepiri_modelkit/data/validation.py index 25e7b2e..60ef3c7 100644 --- a/src/deepiri_modelkit/data/validation.py +++ b/src/deepiri_modelkit/data/validation.py @@ -2,6 +2,7 @@ Dataset Validation Utilities Provides validation and quality checks for language intelligence datasets """ + import json from pathlib import Path from typing import Dict, List, Any, Optional @@ -35,29 +36,42 @@ def _get_validation_rules(self) -> Dict[str, Any]: "min_text_length": 10, "max_text_length": 10000, "required_fields": ["text"], - "text_quality_checks": True + "text_quality_checks": True, } type_specific_rules = { "lease_abstraction": { "min_samples": 50, "lease_keywords": [ - "lease", "agreement", "landlord", "tenant", "rent", - "premises", "term", "commencement", "expiration" + "lease", + "agreement", + "landlord", + "tenant", + "rent", + "premises", + "term", + "commencement", + "expiration", ], "min_keyword_matches": 2, "check_address_patterns": True, - "check_rent_patterns": True + 
"check_rent_patterns": True, }, "contract_intelligence": { "min_samples": 50, "contract_keywords": [ - "contract", "agreement", "party", "obligation", - "clause", "provision", "section", "article" + "contract", + "agreement", + "party", + "obligation", + "clause", + "provision", + "section", + "article", ], "min_keyword_matches": 2, - "check_legal_patterns": True - } + "check_legal_patterns": True, + }, } if self.dataset_type in type_specific_rules: @@ -75,14 +89,16 @@ def validate_dataset(self, data_path: Path) -> Dict[str, Any]: Returns: Validation results dictionary """ - logger.info("Starting dataset validation", path=str(data_path), type=self.dataset_type) + logger.info( + "Starting dataset validation", path=str(data_path), type=self.dataset_type + ) results = { "is_valid": True, "errors": [], "warnings": [], "statistics": {}, - "quality_score": 0.0 + "quality_score": 0.0, } try: @@ -117,11 +133,13 @@ def validate_dataset(self, data_path: Path) -> Dict[str, Any]: results["errors"].append(f"Validation failed with error: {str(e)}") logger.error("Dataset validation failed", error=str(e)) - logger.info("Dataset validation complete", - valid=results["is_valid"], - quality_score=results["quality_score"], - errors=len(results["errors"]), - warnings=len(results["warnings"])) + logger.info( + "Dataset validation complete", + valid=results["is_valid"], + quality_score=results["quality_score"], + errors=len(results["errors"]), + warnings=len(results["warnings"]), + ) return results @@ -130,7 +148,7 @@ def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: samples = [] if data_path.is_file() and data_path.suffix == ".jsonl": - with open(data_path, 'r', encoding='utf-8') as f: + with open(data_path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: @@ -142,7 +160,7 @@ def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: elif data_path.is_dir(): for file_path in data_path.glob("*.jsonl"): - with 
open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: @@ -150,7 +168,9 @@ def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: sample = json.loads(line) samples.append(sample) except json.JSONDecodeError as e: - logger.warning(f"Invalid JSON in {file_path} at line {line_num}: {e}") + logger.warning( + f"Invalid JSON in {file_path} at line {line_num}: {e}" + ) return samples @@ -171,7 +191,9 @@ def _validate_structure(self, samples: List[Dict], results: Dict): for i, sample in enumerate(samples[:100]): # Check first 100 samples for field in required_fields: if field not in sample: - results["errors"].append(f"Missing required field '{field}' in sample {i}") + results["errors"].append( + f"Missing required field '{field}' in sample {i}" + ) def _validate_content_quality(self, samples: List[Dict], results: Dict): """Validate content quality.""" @@ -202,19 +224,25 @@ def _validate_content_quality(self, samples: List[Dict], results: Dict): seen_texts.add(text) # Statistics - results["statistics"].update({ - "avg_text_length": sum(text_lengths) / len(text_lengths) if text_lengths else 0, - "min_text_length": min(text_lengths) if text_lengths else 0, - "max_text_length": max(text_lengths) if text_lengths else 0, - "empty_texts": empty_texts, - "duplicate_texts": len(duplicate_texts) - }) + results["statistics"].update( + { + "avg_text_length": ( + sum(text_lengths) / len(text_lengths) if text_lengths else 0 + ), + "min_text_length": min(text_lengths) if text_lengths else 0, + "max_text_length": max(text_lengths) if text_lengths else 0, + "empty_texts": empty_texts, + "duplicate_texts": len(duplicate_texts), + } + ) if empty_texts > 0: results["errors"].append(f"Found {empty_texts} empty text samples") if len(duplicate_texts) > len(samples) * 0.01: # >1% duplicates - results["warnings"].append(f"High duplicate rate: {len(duplicate_texts)} duplicates") 
+ results["warnings"].append( + f"High duplicate rate: {len(duplicate_texts)} duplicates" + ) def _validate_type_specific(self, samples: List[Dict], results: Dict): """Type-specific validation.""" @@ -233,10 +261,10 @@ def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): rent_pattern_matches = 0 # Address patterns (street numbers, street names, cities) - address_pattern = r'\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}' + address_pattern = r"\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}" # Rent patterns (dollar amounts) - rent_pattern = r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?' + rent_pattern = r"\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?" for sample in samples[:500]: # Check first 500 samples for performance text = sample.get("text", "").lower() @@ -261,11 +289,13 @@ def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack lease keywords" ) - results["statistics"].update({ - "address_pattern_matches": address_pattern_matches, - "rent_pattern_matches": rent_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate - }) + results["statistics"].update( + { + "address_pattern_matches": address_pattern_matches, + "rent_pattern_matches": rent_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate, + } + ) def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): """Validate contract intelligence dataset.""" @@ -277,11 +307,11 @@ def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): # Legal clause patterns legal_patterns = [ - r'\bsection\s+\d+', - r'\barticle\s+\d+', - r'\bclause\s+\d+', - r'\bparagraph\s+\d+', - r'\bsubsection\s+\d+' + r"\bsection\s+\d+", + r"\barticle\s+\d+", + r"\bclause\s+\d+", + r"\bparagraph\s+\d+", 
+ r"\bsubsection\s+\d+", ] for sample in samples[:500]: # Check first 500 samples @@ -293,7 +323,9 @@ def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): low_keyword_samples += 1 # Legal pattern matching - if any(re.search(pattern, sample.get("text", "")) for pattern in legal_patterns): + if any( + re.search(pattern, sample.get("text", "")) for pattern in legal_patterns + ): legal_pattern_matches += 1 total_checked = min(500, len(samples)) @@ -304,10 +336,12 @@ def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack contract keywords" ) - results["statistics"].update({ - "legal_pattern_matches": legal_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate - }) + results["statistics"].update( + { + "legal_pattern_matches": legal_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate, + } + ) def _calculate_quality_score(self, results: Dict) -> float: """Calculate overall quality score (0.0 to 1.0).""" @@ -349,7 +383,9 @@ def _calculate_quality_score(self, results: Dict) -> float: return max(0.0, min(1.0, score)) -def validate_dataset_quality(data_path: Path, dataset_type: str = "general") -> Dict[str, Any]: +def validate_dataset_quality( + data_path: Path, dataset_type: str = "general" +) -> Dict[str, Any]: """ Convenience function to validate dataset quality. 
diff --git a/src/deepiri_modelkit/logging.py b/src/deepiri_modelkit/logging.py index 9b65998..6530da0 100644 --- a/src/deepiri_modelkit/logging.py +++ b/src/deepiri_modelkit/logging.py @@ -2,6 +2,7 @@ Shared logging utilities for all Deepiri services Used by: Cyrex (runtime), Helox (training), and all microservices """ + import logging import sys import json @@ -12,47 +13,47 @@ class StructuredLogger: """JSON structured logger for all Deepiri services""" - + def __init__(self, name: str, level: int = logging.INFO): self.logger = logging.getLogger(name) self.logger.setLevel(level) - + # Remove existing handlers self.logger.handlers = [] - + # Create console handler with JSON formatting handler = logging.StreamHandler(sys.stdout) handler.setLevel(level) handler.setFormatter(JsonFormatter()) self.logger.addHandler(handler) - + self.logger.propagate = False - + def _log(self, level: int, event: str, **kwargs): """Internal log method with structured data""" extra = {"event": event, "timestamp": datetime.utcnow().isoformat() + "Z"} extra.update(kwargs) self.logger.log(level, json.dumps(extra)) - + def debug(self, event: str, **kwargs): self._log(logging.DEBUG, event, **kwargs) - + def info(self, event: str, **kwargs): self._log(logging.INFO, event, **kwargs) - + def warning(self, event: str, **kwargs): self._log(logging.WARNING, event, **kwargs) - + def error(self, event: str, **kwargs): self._log(logging.ERROR, event, **kwargs) - + def critical(self, event: str, **kwargs): self._log(logging.CRITICAL, event, **kwargs) class JsonFormatter(logging.Formatter): """Format logs as JSON""" - + def format(self, record: logging.LogRecord) -> str: log_data = { "timestamp": datetime.utcnow().isoformat() + "Z", @@ -60,27 +61,47 @@ def format(self, record: logging.LogRecord) -> str: "logger": record.name, "message": record.getMessage(), } - + # Add extra fields if present - if hasattr(record, 'event'): - log_data['event'] = record.event - + if hasattr(record, "event"): + 
log_data["event"] = record.event + # Add any custom fields from extra={} for key, value in record.__dict__.items(): - if key not in ['name', 'msg', 'args', 'created', 'filename', 'funcName', - 'levelname', 'levelno', 'lineno', 'module', 'msecs', - 'message', 'pathname', 'process', 'processName', - 'relativeCreated', 'thread', 'threadName', 'exc_info', - 'exc_text', 'stack_info', 'event', 'timestamp']: + if key not in [ + "name", + "msg", + "args", + "created", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "message", + "pathname", + "process", + "processName", + "relativeCreated", + "thread", + "threadName", + "exc_info", + "exc_text", + "stack_info", + "event", + "timestamp", + ]: log_data[key] = value - + return json.dumps(log_data) def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: """ Get structured logger instance - + Usage: from deepiri_modelkit.logging import get_logger logger = get_logger("my_service") @@ -91,10 +112,10 @@ def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: class ErrorLogger: """Error logging with context""" - + def __init__(self): self.logger = get_logger("error_logger") - + def log_api_error(self, error: Exception, request_id: str, endpoint: str): """Log API errors with context""" self.logger.error( @@ -102,27 +123,31 @@ def log_api_error(self, error: Exception, request_id: str, endpoint: str): error=str(error), error_type=type(error).__name__, request_id=request_id, - endpoint=endpoint + endpoint=endpoint, ) - - def log_model_error(self, error: Exception, model_name: str, input_data: Optional[Dict] = None): + + def log_model_error( + self, error: Exception, model_name: str, input_data: Optional[Dict] = None + ): """Log model inference errors""" self.logger.error( "model_error", error=str(error), error_type=type(error).__name__, model_name=model_name, - input_sample=str(input_data)[:200] if input_data else None + input_sample=str(input_data)[:200] 
if input_data else None, ) - - def log_training_error(self, error: Exception, pipeline: str, config: Optional[Dict] = None): + + def log_training_error( + self, error: Exception, pipeline: str, config: Optional[Dict] = None + ): """Log training pipeline errors""" self.logger.error( "training_error", error=str(error), error_type=type(error).__name__, pipeline=pipeline, - config=config + config=config, ) @@ -144,4 +169,3 @@ def get_error_logger() -> ErrorLogger: if _error_logger is None: _error_logger = ErrorLogger() return _error_logger - diff --git a/src/deepiri_modelkit/ml/__init__.py b/src/deepiri_modelkit/ml/__init__.py index 09c30ff..1a41dcb 100644 --- a/src/deepiri_modelkit/ml/__init__.py +++ b/src/deepiri_modelkit/ml/__init__.py @@ -8,12 +8,14 @@ ConfidenceCalculator, get_confidence_calculator, ) + _HAS_CONFIDENCE = True except ImportError: _HAS_CONFIDENCE = False try: from .semantic import SemanticAnalyzer, get_semantic_analyzer + _HAS_SEMANTIC = True except ImportError: _HAS_SEMANTIC = False diff --git a/src/deepiri_modelkit/ml/confidence.py b/src/deepiri_modelkit/ml/confidence.py index 89d5f2b..ca23d11 100644 --- a/src/deepiri_modelkit/ml/confidence.py +++ b/src/deepiri_modelkit/ml/confidence.py @@ -2,12 +2,14 @@ Redesigned Confidence Classes System Provides structured confidence assessment with multiple attributes """ + from enum import Enum from typing import Dict, List, Optional, Tuple from dataclasses import dataclass try: import numpy as np + HAS_NUMPY = True except ImportError: HAS_NUMPY = False @@ -15,6 +17,7 @@ class ConfidenceLevel(str, Enum): """Confidence level categories""" + VERY_HIGH = "very_high" # 0.9-1.0 HIGH = "high" # 0.75-0.9 MEDIUM = "medium" # 0.5-0.75 @@ -24,6 +27,7 @@ class ConfidenceLevel(str, Enum): class ConfidenceSource(str, Enum): """Sources of confidence information""" + MODEL_PREDICTION = "model_prediction" TRAINING_DATA_COVERAGE = "training_data_coverage" FEATURE_QUALITY = "feature_quality" @@ -46,6 +50,7 @@ class 
ConfidenceAttributes: reliability: Overall reliability score explanation: Human-readable explanation """ + base_score: float level: ConfidenceLevel sources: Dict[str, float] @@ -63,7 +68,7 @@ def to_dict(self) -> Dict: "uncertainty": self.uncertainty, "calibration": self.calibration, "reliability": self.reliability, - "explanation": self.explanation + "explanation": self.explanation, } @@ -78,7 +83,7 @@ def __init__(self): ConfidenceLevel.HIGH: 0.75, ConfidenceLevel.MEDIUM: 0.5, ConfidenceLevel.LOW: 0.25, - ConfidenceLevel.VERY_LOW: 0.0 + ConfidenceLevel.VERY_LOW: 0.0, } def calculate_confidence( @@ -88,7 +93,7 @@ def calculate_confidence( training_coverage: Optional[float] = None, feature_quality: Optional[float] = None, context_match: Optional[float] = None, - historical_accuracy: Optional[Dict[int, float]] = None + historical_accuracy: Optional[Dict[int, float]] = None, ) -> ConfidenceAttributes: """ Calculate comprehensive confidence attributes @@ -105,7 +110,9 @@ def calculate_confidence( ConfidenceAttributes object """ if not HAS_NUMPY: - raise ImportError("numpy is required for ConfidenceCalculator. Install with: pip install numpy") + raise ImportError( + "numpy is required for ConfidenceCalculator. 
Install with: pip install numpy" + ) # Base score: maximum probability base_score = float(np.max(model_probabilities)) @@ -117,7 +124,9 @@ def calculate_confidence( # Calibration: difference between top-2 probabilities (margin) sorted_probs = np.sort(model_probabilities)[::-1] - margin = float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 + margin = ( + float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 + ) calibration = float(margin) # Higher margin = better calibration # Source contributions @@ -130,7 +139,9 @@ def calculate_confidence( if training_coverage is not None: sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = training_coverage else: - sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = 0.7 # Default moderate + sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = ( + 0.7 # Default moderate + ) # Feature quality if feature_quality is not None: @@ -150,12 +161,16 @@ def calculate_confidence( hist_acc = historical_accuracy.get(predicted_class, 0.7) sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = hist_acc else: - sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = 0.7 # Default moderate + sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = ( + 0.7 # Default moderate + ) # Ensemble agreement (if top_k_probs provided) if top_k_probs: agreement = float(np.std(top_k_probs)) # Lower std = higher agreement - sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min(agreement, 1.0) + sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min( + agreement, 1.0 + ) else: sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 0.7 # Default moderate @@ -166,16 +181,17 @@ def calculate_confidence( ConfidenceSource.FEATURE_QUALITY.value: 0.15, ConfidenceSource.CONTEXT_MATCH.value: 0.1, ConfidenceSource.HISTORICAL_ACCURACY.value: 0.1, - ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1 + ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1, } reliability = sum( - sources[source] * weights.get(source, 
0.0) - for source in sources + sources[source] * weights.get(source, 0.0) for source in sources ) # Adjust reliability based on uncertainty and calibration - reliability = reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) + reliability = ( + reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) + ) reliability = max(0.0, min(1.0, reliability)) # Determine confidence level @@ -193,7 +209,7 @@ def calculate_confidence( uncertainty=uncertainty, calibration=calibration, reliability=reliability, - explanation=explanation + explanation=explanation, ) def _get_confidence_level(self, reliability: float) -> ConfidenceLevel: @@ -215,13 +231,15 @@ def _generate_explanation( level: ConfidenceLevel, sources: Dict[str, float], uncertainty: float, - calibration: float + calibration: float, ) -> str: """Generate human-readable explanation""" parts = [] # Main confidence statement - parts.append(f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})") + parts.append( + f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})" + ) # Key factors key_factors = [] @@ -255,7 +273,7 @@ def should_accept_prediction( self, confidence: ConfidenceAttributes, min_reliability: float = 0.7, - min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM + min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM, ) -> Tuple[bool, str]: """ Determine if prediction should be accepted based on confidence @@ -268,14 +286,20 @@ def should_accept_prediction( ConfidenceLevel.LOW: 1, ConfidenceLevel.MEDIUM: 2, ConfidenceLevel.HIGH: 3, - ConfidenceLevel.VERY_HIGH: 4 + ConfidenceLevel.VERY_HIGH: 4, } if confidence.reliability < min_reliability: - return False, f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}" + return ( + False, + f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}", + ) if level_order[confidence.level] < level_order[min_level]: - return False, f"Confidence level 
{confidence.level.value} below required {min_level.value}" + return ( + False, + f"Confidence level {confidence.level.value} below required {min_level.value}", + ) return True, "Confidence meets requirements" diff --git a/src/deepiri_modelkit/ml/semantic.py b/src/deepiri_modelkit/ml/semantic.py index 141c6b1..366af91 100644 --- a/src/deepiri_modelkit/ml/semantic.py +++ b/src/deepiri_modelkit/ml/semantic.py @@ -3,6 +3,7 @@ Inspired by Carnegie Mellon University (CMU) Language Technologies Institute approaches Uses semantic similarity and contextual understanding for dynamic variation generation """ + import json import re from typing import List, Dict, Optional, Set @@ -16,18 +17,21 @@ # Try multiple HTTP clients try: import httpx + HAS_HTTPX = True except ImportError: HAS_HTTPX = False try: import requests + HAS_REQUESTS = True except ImportError: HAS_REQUESTS = False try: import ollama + HAS_OLLAMA_PKG = True except ImportError: HAS_OLLAMA_PKG = False @@ -39,12 +43,16 @@ class SemanticAnalyzer: Inspired by CMU's semantic analysis approaches """ - def __init__(self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b"): + def __init__( + self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b" + ): self.ollama_base_url = ollama_base_url self.model = model self._cache = {} # Cache for semantic analysis results - def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # Reduced from 30s to 15s + def _call_ollama( + self, prompt: str, timeout: float = 15.0 + ) -> Optional[str]: # Reduced from 30s to 15s """Call Ollama API directly via HTTP or Python package""" # Try ollama Python package first (cleaner API) if HAS_OLLAMA_PKG: @@ -55,8 +63,8 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # options={ "temperature": 0.7, "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } + "num_predict": 100, # Reduced from 200 for faster responses + }, ) return 
response.get("response", "").strip() except Exception: @@ -76,10 +84,10 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # "options": { "temperature": 0.7, "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } + "num_predict": 100, # Reduced from 200 for faster responses + }, }, - timeout=timeout + timeout=timeout, ) if response.status_code == 200: @@ -87,7 +95,9 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # logger.debug("Ollama HTTP call succeeded") return result.get("response", "").strip() else: - logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") + logger.debug( + f"Ollama HTTP call failed: HTTP {response.status_code}" + ) elif HAS_REQUESTS: logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") response = requests.post( @@ -99,10 +109,10 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # "options": { "temperature": 0.7, "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } + "num_predict": 100, # Reduced from 200 for faster responses + }, }, - timeout=timeout + timeout=timeout, ) if response.status_code == 200: @@ -110,7 +120,9 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # logger.debug("Ollama HTTP call succeeded") return result.get("response", "").strip() else: - logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") + logger.debug( + f"Ollama HTTP call failed: HTTP {response.status_code}" + ) except Exception as e: logger.debug(f"Ollama HTTP call failed: {e}") @@ -137,7 +149,7 @@ def extract_semantic_verbs(self, text: str, category: str) -> List[str]: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) + json_match = re.search(r"\[.*?\]", response, re.DOTALL) if json_match: verbs = json.loads(json_match.group()) if isinstance(verbs, list) and len(verbs) > 0: @@ -171,7 
+183,7 @@ def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) + json_match = re.search(r"\[.*?\]", response, re.DOTALL) if json_match: prefixes = json.loads(json_match.group()) if isinstance(prefixes, list) and len(prefixes) > 0: @@ -182,8 +194,14 @@ def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: # Fallback: return default prefixes return [ - "I need to", "Can you help me", "Please", "I want to", - "Help me", "I should", "Let me", "I'm going to" + "I need to", + "Can you help me", + "Please", + "I want to", + "Help me", + "I should", + "Let me", + "I'm going to", ] def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: @@ -207,7 +225,7 @@ def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) + json_match = re.search(r"\[.*?\]", response, re.DOTALL) if json_match: suffixes = json.loads(json_match.group()) if isinstance(suffixes, list) and len(suffixes) > 0: @@ -218,11 +236,18 @@ def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: # Fallback: return default suffixes return [ - "", " today", " this week", " as soon as possible", - " when you have time", " - urgent", " - important" + "", + " today", + " this week", + " as soon as possible", + " when you have time", + " - urgent", + " - important", ] - def generate_paraphrases(self, text: str, category: str, num_paraphrases: int = 3) -> List[str]: + def generate_paraphrases( + self, text: str, category: str, num_paraphrases: int = 3 + ) -> List[str]: """ Generate semantic paraphrases using Ollama Inspired by CMU's paraphrase generation approaches @@ -243,12 +268,12 @@ def generate_paraphrases(self, text: str, category: str, num_paraphrases: int = response = 
self._call_ollama(prompt) if response: paraphrases = [] - for line in response.strip().split('\n'): + for line in response.strip().split("\n"): line = line.strip() # Remove common prefixes - for prefix in ['- ', '1. ', '2. ', '3. ', '4. ', '5. ', '* ', '• ']: + for prefix in ["- ", "1. ", "2. ", "3. ", "4. ", "5. ", "* ", "• "]: if line.startswith(prefix): - line = line[len(prefix):].strip() + line = line[len(prefix) :].strip() if line and line != text and len(line) > 10: paraphrases.append(line) @@ -278,7 +303,7 @@ def analyze_semantic_structure(self, text: str) -> Dict: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\{.*?\}', response, re.DOTALL) + json_match = re.search(r"\{.*?\}", response, re.DOTALL) if json_match: return json.loads(json_match.group()) except Exception: @@ -291,7 +316,7 @@ def analyze_semantic_structure(self, text: str) -> Dict: "object": " ".join(words[1:]) if len(words) > 1 else "", "modifiers": [], "temporal": None, - "urgency": None + "urgency": None, } def check_ollama_available(self) -> bool: @@ -307,16 +332,10 @@ def check_ollama_available(self) -> bool: # Fall back to HTTP check try: if HAS_HTTPX: - response = httpx.get( - f"{self.ollama_base_url}/api/tags", - timeout=5.0 - ) + response = httpx.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) return response.status_code == 200 elif HAS_REQUESTS: - response = requests.get( - f"{self.ollama_base_url}/api/tags", - timeout=5.0 - ) + response = requests.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) return response.status_code == 200 except Exception: pass @@ -325,8 +344,7 @@ def check_ollama_available(self) -> bool: def get_semantic_analyzer( - ollama_base_url: Optional[str] = None, - model: Optional[str] = None + ollama_base_url: Optional[str] = None, model: Optional[str] = None ) -> Optional[SemanticAnalyzer]: """ Factory function to get semantic analyzer diff --git a/src/deepiri_modelkit/rag/__init__.py b/src/deepiri_modelkit/rag/__init__.py 
index 0b3bbcc..746f60f 100644 --- a/src/deepiri_modelkit/rag/__init__.py +++ b/src/deepiri_modelkit/rag/__init__.py @@ -37,6 +37,7 @@ MultiQueryRetriever, QueryCache, ) + HAS_ADVANCED_RETRIEVAL = True except ImportError: HAS_ADVANCED_RETRIEVAL = False @@ -53,6 +54,7 @@ EmbeddingCache, QueryResultCache, ) + HAS_CACHING = True except ImportError: HAS_CACHING = False @@ -68,6 +70,7 @@ SystemMetrics, PerformanceTimer, ) + HAS_MONITORING = True except ImportError: HAS_MONITORING = False @@ -85,6 +88,7 @@ BatchProcessingConfig, BatchProcessingResult, ) + HAS_ASYNC = True except ImportError: HAS_ASYNC = False @@ -119,37 +123,44 @@ # Conditionally add advanced features if HAS_ADVANCED_RETRIEVAL: - __all__.extend([ - "AdvancedRetrievalPipeline", - "QueryExpander", - "SynonymQueryExpander", - "RephraseQueryExpander", - "MultiQueryRetriever", - "QueryCache", - ]) + __all__.extend( + [ + "AdvancedRetrievalPipeline", + "QueryExpander", + "SynonymQueryExpander", + "RephraseQueryExpander", + "MultiQueryRetriever", + "QueryCache", + ] + ) if HAS_CACHING: - __all__.extend([ - "AdvancedCacheManager", - "EmbeddingCache", - "QueryResultCache", - ]) + __all__.extend( + [ + "AdvancedCacheManager", + "EmbeddingCache", + "QueryResultCache", + ] + ) if HAS_MONITORING: - __all__.extend([ - "RAGMonitor", - "RetrievalMetrics", - "IndexingMetrics", - "SystemMetrics", - "PerformanceTimer", - ]) + __all__.extend( + [ + "RAGMonitor", + "RetrievalMetrics", + "IndexingMetrics", + "SystemMetrics", + "PerformanceTimer", + ] + ) if HAS_ASYNC: - __all__.extend([ - "AsyncBatchProcessor", - "AsyncDocumentIndexer", - "AsyncDocumentProcessor", - "BatchProcessingConfig", - "BatchProcessingResult", - ]) - + __all__.extend( + [ + "AsyncBatchProcessor", + "AsyncDocumentIndexer", + "AsyncDocumentProcessor", + "BatchProcessingConfig", + "BatchProcessingResult", + ] + ) diff --git a/src/deepiri_modelkit/rag/advanced_retrieval.py b/src/deepiri_modelkit/rag/advanced_retrieval.py index 14e1de6..4de0fbc 100644 --- 
a/src/deepiri_modelkit/rag/advanced_retrieval.py +++ b/src/deepiri_modelkit/rag/advanced_retrieval.py @@ -15,6 +15,7 @@ @dataclass class ExpandedQuery: """Expanded query with multiple variations""" + original_query: str expanded_queries: List[str] query_type: str # "synonym", "rephrase", "keyword", etc. @@ -23,7 +24,7 @@ class ExpandedQuery: class QueryExpander(ABC): """Base class for query expansion strategies""" - + @abstractmethod def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: """Expand a query into multiple variations""" @@ -32,15 +33,15 @@ def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: class SynonymQueryExpander(QueryExpander): """Expand queries using synonyms and related terms""" - + def __init__(self, synonym_dict: Optional[Dict[str, List[str]]] = None): self.synonym_dict = synonym_dict or self._default_synonyms() - + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: """Expand query using synonyms""" words = query.lower().split() expanded = [query] # Always include original - + for word in words: if word in self.synonym_dict: synonyms = self.synonym_dict[word][:max_expansions] @@ -48,14 +49,14 @@ def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: expanded_query = query.replace(word, synonym) if expanded_query not in expanded: expanded.append(expanded_query) - + return ExpandedQuery( original_query=query, - expanded_queries=expanded[:max_expansions + 1], + expanded_queries=expanded[: max_expansions + 1], query_type="synonym", - confidence=0.8 + confidence=0.8, ) - + def _default_synonyms(self) -> Dict[str, List[str]]: """Default synonym dictionary""" return { @@ -74,17 +75,17 @@ def _default_synonyms(self) -> Dict[str, List[str]]: class RephraseQueryExpander(QueryExpander): """Expand queries by rephrasing""" - + def __init__(self, llm_client=None): self.llm_client = llm_client - + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: """Rephrase 
query using LLM or templates""" if self.llm_client: return self._llm_rephrase(query, max_expansions) else: return self._template_rephrase(query, max_expansions) - + def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: """Rephrase using templates""" templates = [ @@ -93,16 +94,16 @@ def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: f"Information about {query}", f"Details on {query}", ] - + expanded = [query] + templates[:max_expansions] - + return ExpandedQuery( original_query=query, expanded_queries=expanded, query_type="rephrase", - confidence=0.7 + confidence=0.7, ) - + def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: """Rephrase using LLM (if available)""" # Placeholder for LLM-based rephrasing @@ -111,16 +112,51 @@ def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: class KeywordExtractor: """Extract keywords from queries for hybrid search""" - + def __init__(self): # Common stop words self.stop_words = { - "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", - "of", "with", "by", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "do", "does", "did", "will", "would", "should", - "could", "may", "might", "must", "can", "this", "that", "these", "those" + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "should", + "could", + "may", + "might", + "must", + "can", + "this", + "that", + "these", + "those", } - + def extract(self, query: str, max_keywords: int = 10) -> List[str]: """Extract keywords from query""" words = query.lower().split() @@ -138,38 +174,36 @@ class MultiQueryRetriever: Multi-query retrieval strategy Generates multiple query variations and combines results """ - + def __init__( self, 
base_retriever, query_expander: Optional[QueryExpander] = None, - fusion_method: str = "rrf" # "rrf" (Reciprocal Rank Fusion) or "mean" + fusion_method: str = "rrf", # "rrf" (Reciprocal Rank Fusion) or "mean" ): self.base_retriever = base_retriever self.query_expander = query_expander or SynonymQueryExpander() self.fusion_method = fusion_method - - def retrieve( - self, - query: RAGQuery, - num_queries: int = 3 - ) -> List[RetrievalResult]: + + def retrieve(self, query: RAGQuery, num_queries: int = 3) -> List[RetrievalResult]: """ Retrieve using multiple query variations - + Args: query: Original RAG query num_queries: Number of query variations to generate - + Returns: Fused retrieval results """ # Expand query - expanded = self.query_expander.expand(query.query, max_expansions=num_queries - 1) - + expanded = self.query_expander.expand( + query.query, max_expansions=num_queries - 1 + ) + # Retrieve for each query variation all_results: Dict[str, List[RetrievalResult]] = {} - + for expanded_query in expanded.expanded_queries: # Create new query with expanded text expanded_rag_query = RAGQuery( @@ -178,135 +212,131 @@ def retrieve( doc_types=query.doc_types, date_range=query.date_range, metadata_filters=query.metadata_filters, - top_k=query.top_k or 10 + top_k=query.top_k or 10, ) - + # Retrieve results results = self.base_retriever.retrieve(expanded_rag_query) all_results[expanded_query] = results - + # Fuse results fused = self._fuse_results(all_results, query.top_k or 10) - + return fused - + def _fuse_results( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int + self, all_results: Dict[str, List[RetrievalResult]], top_k: int ) -> List[RetrievalResult]: """Fuse results from multiple queries""" if self.fusion_method == "rrf": return self._reciprocal_rank_fusion(all_results, top_k) else: return self._mean_score_fusion(all_results, top_k) - + def _reciprocal_rank_fusion( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int, - k: 
int = 60 + self, all_results: Dict[str, List[RetrievalResult]], top_k: int, k: int = 60 ) -> List[RetrievalResult]: """ Reciprocal Rank Fusion (RRF) Combines rankings from multiple queries """ doc_scores: Dict[str, Dict[str, Any]] = {} - + for query_text, results in all_results.items(): for rank, result in enumerate(results): doc_id = result.document.id rrf_score = 1.0 / (k + rank + 1) - + if doc_id not in doc_scores: doc_scores[doc_id] = { "document": result.document, "rrf_score": 0.0, "max_score": result.score, - "count": 0 + "count": 0, } - + doc_scores[doc_id]["rrf_score"] += rrf_score doc_scores[doc_id]["max_score"] = max( - doc_scores[doc_id]["max_score"], - result.score + doc_scores[doc_id]["max_score"], result.score ) doc_scores[doc_id]["count"] += 1 - + # Create fused results fused = [] for doc_id, scores in doc_scores.items(): - fused.append(RetrievalResult( - document=scores["document"], - score=scores["rrf_score"], # Use RRF score - rerank_score=scores["max_score"] # Store max original score - )) - + fused.append( + RetrievalResult( + document=scores["document"], + score=scores["rrf_score"], # Use RRF score + rerank_score=scores["max_score"], # Store max original score + ) + ) + # Sort by RRF score fused.sort(key=lambda x: x.score, reverse=True) - + return fused[:top_k] - + def _mean_score_fusion( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int + self, all_results: Dict[str, List[RetrievalResult]], top_k: int ) -> List[RetrievalResult]: """Fuse results using mean score""" doc_scores: Dict[str, Dict[str, Any]] = {} - + for query_text, results in all_results.items(): for result in results: doc_id = result.document.id - + if doc_id not in doc_scores: doc_scores[doc_id] = { "document": result.document, "scores": [], - "count": 0 + "count": 0, } - + doc_scores[doc_id]["scores"].append(result.score) doc_scores[doc_id]["count"] += 1 - + # Calculate mean scores fused = [] for doc_id, scores in doc_scores.items(): mean_score = 
sum(scores["scores"]) / len(scores["scores"]) - fused.append(RetrievalResult( - document=scores["document"], - score=mean_score, - rerank_score=max(scores["scores"]) - )) - + fused.append( + RetrievalResult( + document=scores["document"], + score=mean_score, + rerank_score=max(scores["scores"]), + ) + ) + # Sort by mean score fused.sort(key=lambda x: x.score, reverse=True) - + return fused[:top_k] class QueryCache: """Cache for query results""" - + def __init__(self, cache_manager=None): self.cache_manager = cache_manager self.cache_ttl = 3600 # 1 hour - + def get_cache_key(self, query: RAGQuery) -> str: """Generate cache key for query""" query_dict = query.to_dict() query_str = json.dumps(query_dict, sort_keys=True) query_hash = hashlib.md5(query_str.encode()).hexdigest() return f"rag:query:{query_hash}" - + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: """Get cached results""" if not self.cache_manager: return None - + cache_key = self.get_cache_key(query) cached = self.cache_manager.get(cache_key) - + if cached: # Reconstruct RetrievalResult objects results = [] @@ -315,29 +345,29 @@ def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: result = RetrievalResult( document=doc, score=item["score"], - rerank_score=item.get("rerank_score") + rerank_score=item.get("rerank_score"), ) results.append(result) return results - + return None - + def set(self, query: RAGQuery, results: List[RetrievalResult]): """Cache results""" if not self.cache_manager: return - + cache_key = self.get_cache_key(query) # Serialize results serialized = [ { "document": r.document.to_dict(), "score": r.score, - "rerank_score": r.rerank_score + "rerank_score": r.rerank_score, } for r in results ] - + self.cache_manager.set(cache_key, serialized, ttl=self.cache_ttl) @@ -349,29 +379,28 @@ class AdvancedRetrievalPipeline: - Result caching - Hybrid search """ - + def __init__( self, base_retriever, query_expander: Optional[QueryExpander] = None, use_multi_query: 
bool = True, use_cache: bool = True, - cache_manager=None + cache_manager=None, ): self.base_retriever = base_retriever self.query_expander = query_expander or SynonymQueryExpander() self.use_multi_query = use_multi_query self.use_cache = use_cache self.query_cache = QueryCache(cache_manager) if use_cache else None - + if use_multi_query: self.multi_query_retriever = MultiQueryRetriever( - base_retriever=base_retriever, - query_expander=query_expander + base_retriever=base_retriever, query_expander=query_expander ) else: self.multi_query_retriever = None - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve with advanced strategies""" # Check cache @@ -379,16 +408,15 @@ def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: cached = self.query_cache.get(query) if cached: return cached - + # Retrieve using multi-query or base retriever if self.use_multi_query and self.multi_query_retriever: results = self.multi_query_retriever.retrieve(query) else: results = self.base_retriever.retrieve(query) - + # Cache results if self.use_cache and self.query_cache: self.query_cache.set(query, results) - - return results + return results diff --git a/src/deepiri_modelkit/rag/async_processing.py b/src/deepiri_modelkit/rag/async_processing.py index 9625443..87a0b0f 100644 --- a/src/deepiri_modelkit/rag/async_processing.py +++ b/src/deepiri_modelkit/rag/async_processing.py @@ -5,6 +5,7 @@ import asyncio from typing import List, Dict, Any, Optional, Callable, Awaitable + # Fix for Python < 3.9 compatibility try: from collections.abc import AsyncIterator @@ -20,6 +21,7 @@ @dataclass class BatchProcessingConfig: """Configuration for batch processing""" + batch_size: int = 100 max_concurrent_batches: int = 5 chunk_size: int = 1000 @@ -33,24 +35,25 @@ class BatchProcessingConfig: @dataclass class BatchProcessingResult: """Result of batch processing operation""" + total_items: int processed_items: int successful_items: int failed_items: int 
processing_time_seconds: float errors: List[Dict[str, Any]] = None - + def __post_init__(self): if self.errors is None: self.errors = [] - + @property def success_rate(self) -> float: """Calculate success rate""" if self.total_items == 0: return 0.0 return self.successful_items / self.total_items - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -68,25 +71,25 @@ class AsyncBatchProcessor: """ Async batch processor for high-performance document processing """ - + def __init__(self, config: BatchProcessingConfig): self.config = config self.semaphore = asyncio.Semaphore(config.max_concurrent_batches) - + async def process_batch( self, items: List[Any], processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Process items in batches asynchronously - + Args: items: List of items to process processor_func: Async function to process each item progress_callback: Optional callback for progress updates - + Returns: BatchProcessingResult with statistics """ @@ -95,135 +98,125 @@ async def process_batch( successful_items = 0 failed_items = 0 errors = [] - + # Split into batches batches = [ - items[i:i + self.config.batch_size] + items[i : i + self.config.batch_size] for i in range(0, total_items, self.config.batch_size) ] - + # Process batches concurrently tasks = [] for batch_idx, batch in enumerate(batches): task = self._process_batch_with_semaphore( - batch, - batch_idx, - processor_func, - progress_callback + batch, batch_idx, processor_func, progress_callback ) tasks.append(task) - + # Wait for all batches batch_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Aggregate results for result in batch_results: if isinstance(result, Exception): failed_items += self.config.batch_size - errors.append({ - "error": str(result), - "type": type(result).__name__ - }) + 
errors.append({"error": str(result), "type": type(result).__name__}) else: successful_items += result["successful"] failed_items += result["failed"] errors.extend(result.get("errors", [])) - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_items, processed_items=total_items, successful_items=successful_items, failed_items=failed_items, processing_time_seconds=processing_time, - errors=errors + errors=errors, ) - + async def _process_batch_with_semaphore( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] + progress_callback: Optional[Callable[[int, int], None]], ) -> Dict[str, Any]: """Process a single batch with semaphore control""" async with self.semaphore: return await self._process_single_batch( - batch, - batch_idx, - processor_func, - progress_callback + batch, batch_idx, processor_func, progress_callback ) - + async def _process_single_batch( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] + progress_callback: Optional[Callable[[int, int], None]], ) -> Dict[str, Any]: """Process a single batch""" successful = 0 failed = 0 errors = [] - + # Process items in batch concurrently tasks = [processor_func(item) for item in batch] results = await asyncio.gather(*tasks, return_exceptions=True) - + for idx, result in enumerate(results): if isinstance(result, Exception): failed += 1 - errors.append({ - "item_index": batch_idx * self.config.batch_size + idx, - "error": str(result), - "type": type(result).__name__ - }) + errors.append( + { + "item_index": batch_idx * self.config.batch_size + idx, + "error": str(result), + "type": type(result).__name__, + } + ) else: successful += 1 - + # Progress callback if progress_callback: total_processed = (batch_idx + 1) * self.config.batch_size progress_callback(total_processed, len(batch) * 
(batch_idx + 1)) - - return { - "successful": successful, - "failed": failed, - "errors": errors - } + + return {"successful": successful, "failed": failed, "errors": errors} class AsyncDocumentIndexer: """ Async document indexer with batching and retry logic """ - + def __init__( self, index_func: Callable[[Document], Awaitable[bool]], - config: Optional[BatchProcessingConfig] = None + config: Optional[BatchProcessingConfig] = None, ): self.index_func = index_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def index_documents( self, documents: List[Document], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Index documents asynchronously in batches - + Args: documents: List of documents to index progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ + async def index_document(doc: Document) -> bool: """Index a single document with retry""" for attempt in range(self.config.max_retries): @@ -238,27 +231,25 @@ async def index_document(doc: Document) -> bool: ) continue raise - + return False - + return await self.batch_processor.process_batch( - documents, - index_document, - progress_callback + documents, index_document, progress_callback ) - + async def index_documents_streaming( self, document_stream: AsyncIterator[Document], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Index documents from async stream - + Args: document_stream: Async iterator of documents progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ @@ -268,39 +259,39 @@ async def index_documents_streaming( failed = 0 errors = [] start_time = time.time() - + async def process_current_batch(): nonlocal 
successful, failed, errors - + if not batch: return - + result = await self.index_documents(batch, progress_callback) successful += result.successful_items failed += result.failed_items errors.extend(result.errors) batch.clear() - + async for document in document_stream: batch.append(document) total_processed += 1 - + if len(batch) >= self.config.batch_size: await process_current_batch() - + # Process remaining batch if batch: await process_current_batch() - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_processed, processed_items=total_processed, successful_items=successful, failed_items=failed, processing_time_seconds=processing_time, - errors=errors + errors=errors, ) @@ -308,65 +299,55 @@ class AsyncDocumentProcessor: """ Async document processor for parallel document processing """ - + def __init__( self, processor_func: Callable[[str, Dict], List[Document]], - config: Optional[BatchProcessingConfig] = None + config: Optional[BatchProcessingConfig] = None, ): self.processor_func = processor_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def process_documents( self, raw_documents: List[Dict[str, Any]], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> tuple[List[Document], BatchProcessingResult]: """ Process raw documents asynchronously - + Args: raw_documents: List of dicts with 'content' and 'metadata' progress_callback: Optional callback for progress - + Returns: Tuple of (processed_documents, processing_result) """ processed_docs = [] errors = [] - + async def process_document(item: Dict[str, Any]) -> List[Document]: """Process a single document""" try: content = item.get("content", "") metadata = item.get("metadata", {}) - docs = await asyncio.to_thread( - self.processor_func, - content, - metadata - ) + docs = await 
asyncio.to_thread(self.processor_func, content, metadata) return docs except Exception as e: - errors.append({ - "item": item.get("id", "unknown"), - "error": str(e) - }) + errors.append({"item": item.get("id", "unknown"), "error": str(e)}) return [] - + result = await self.batch_processor.process_batch( - raw_documents, - process_document, - progress_callback + raw_documents, process_document, progress_callback ) - + # Collect all processed documents tasks = [process_document(item) for item in raw_documents] doc_lists = await asyncio.gather(*tasks, return_exceptions=True) - + for doc_list in doc_lists: if isinstance(doc_list, list): processed_docs.extend(doc_list) - - return processed_docs, result + return processed_docs, result diff --git a/src/deepiri_modelkit/rag/base.py b/src/deepiri_modelkit/rag/base.py index 7693e5d..e7ebf66 100644 --- a/src/deepiri_modelkit/rag/base.py +++ b/src/deepiri_modelkit/rag/base.py @@ -12,6 +12,7 @@ class DocumentType(Enum): """Types of documents that can be indexed""" + REGULATION = "regulation" # Laws, regulations, compliance documents POLICY = "policy" # Insurance policies, company policies MANUAL = "manual" # Equipment manuals, operation guides @@ -31,6 +32,7 @@ class DocumentType(Enum): class IndustryNiche(Enum): """Supported industry niches""" + INSURANCE = "insurance" # Property & casualty insurance MANUFACTURING = "manufacturing" # Industrial manufacturing PROPERTY_MANAGEMENT = "property_management" # Real estate management @@ -47,33 +49,34 @@ class IndustryNiche(Enum): @dataclass class RAGConfig: """Configuration for RAG system""" + # Industry configuration industry: IndustryNiche = IndustryNiche.GENERIC - + # Vector database configuration vector_db_type: str = "milvus" # milvus, pinecone, weaviate, memory collection_name: str = "deepiri_universal_rag" vector_db_host: str = "milvus" vector_db_port: int = 19530 - + # Embedding configuration embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_dimension: 
int = 384 - + # Retrieval configuration top_k: int = 5 # Number of documents to retrieve similarity_threshold: float = 0.7 # Minimum similarity score use_reranking: bool = True # Use cross-encoder reranking reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" - + # Chunking configuration chunk_size: int = 1000 # Characters per chunk chunk_overlap: int = 200 # Overlap between chunks - + # Metadata filtering enable_metadata_filtering: bool = True date_range_filtering: bool = True - + # Multi-modal support support_images: bool = False support_tables: bool = True @@ -83,11 +86,12 @@ class RAGConfig: @dataclass class Document: """Universal document representation""" + id: str content: str doc_type: DocumentType industry: IndustryNiche - + # Metadata title: Optional[str] = None source: Optional[str] = None @@ -95,14 +99,14 @@ class Document: updated_at: Optional[datetime] = None author: Optional[str] = None version: Optional[str] = None - + # Industry-specific metadata metadata: Dict[str, Any] = field(default_factory=dict) - + # Processing metadata chunk_index: Optional[int] = None total_chunks: Optional[int] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for storage""" return { @@ -120,9 +124,9 @@ def to_dict(self) -> Dict[str, Any]: "chunk_index": self.chunk_index, "total_chunks": self.total_chunks, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Document': + def from_dict(cls, data: Dict[str, Any]) -> "Document": """Create from dictionary""" return cls( id=data["id"], @@ -131,8 +135,16 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Document': industry=IndustryNiche(data["industry"]), title=data.get("title"), source=data.get("source"), - created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None, - updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None, + created_at=( + datetime.fromisoformat(data["created_at"]) + if data.get("created_at") + else None 
+ ), + updated_at=( + datetime.fromisoformat(data["updated_at"]) + if data.get("updated_at") + else None + ), author=data.get("author"), version=data.get("version"), metadata=data.get("metadata", {}), @@ -144,10 +156,11 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Document': @dataclass class RetrievalResult: """Result from RAG retrieval""" + document: Document score: float # Similarity score rerank_score: Optional[float] = None # Reranking score if applicable - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -160,20 +173,25 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class RAGQuery: """Query for RAG system""" + query: str industry: Optional[IndustryNiche] = None doc_types: Optional[List[DocumentType]] = None date_range: Optional[tuple[datetime, datetime]] = None metadata_filters: Optional[Dict[str, Any]] = None top_k: Optional[int] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { "query": self.query, "industry": self.industry.value if self.industry else None, - "doc_types": [dt.value for dt in self.doc_types] if self.doc_types else None, - "date_range": [dr.isoformat() for dr in self.date_range] if self.date_range else None, + "doc_types": ( + [dt.value for dt in self.doc_types] if self.doc_types else None + ), + "date_range": ( + [dr.isoformat() for dr in self.date_range] if self.date_range else None + ), "metadata_filters": self.metadata_filters, "top_k": self.top_k, } @@ -184,108 +202,108 @@ class UniversalRAGEngine(ABC): Abstract base class for universal RAG engine Implements common RAG patterns across all industries """ - + def __init__(self, config: RAGConfig): self.config = config self._initialize() - + @abstractmethod def _initialize(self): """Initialize RAG components (vector store, embeddings, etc.)""" pass - + @abstractmethod def index_document(self, document: Document) -> bool: """ Index a single document - + Args: document: Document to index - + Returns: True if successful, 
False otherwise """ pass - + @abstractmethod def index_documents(self, documents: List[Document]) -> Dict[str, Any]: """ Index multiple documents in batch - + Args: documents: List of documents to index - + Returns: Statistics about the indexing operation """ pass - + @abstractmethod def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve relevant documents for a query - + Args: query: Query with filters and parameters - + Returns: List of retrieval results with scores """ pass - + @abstractmethod def generate_with_context( self, query: str, retrieved_docs: List[RetrievalResult], - llm_prompt_template: Optional[str] = None + llm_prompt_template: Optional[str] = None, ) -> Dict[str, Any]: """ Generate response using retrieved context - + Args: query: User query retrieved_docs: Retrieved documents for context llm_prompt_template: Optional custom prompt template - + Returns: Generated response with metadata """ pass - + @abstractmethod def delete_documents(self, doc_ids: List[str]) -> bool: """Delete documents by IDs""" pass - + @abstractmethod def update_document(self, doc_id: str, document: Document) -> bool: """Update an existing document""" pass - + @abstractmethod def get_statistics(self) -> Dict[str, Any]: """Get statistics about indexed documents""" pass - + def search( self, query: str, industry: Optional[IndustryNiche] = None, doc_types: Optional[List[DocumentType]] = None, top_k: Optional[int] = None, - **filters + **filters, ) -> List[RetrievalResult]: """ Convenience method for simple search - + Args: query: Search query industry: Filter by industry doc_types: Filter by document types top_k: Number of results **filters: Additional metadata filters - + Returns: List of retrieval results """ @@ -294,7 +312,6 @@ def search( industry=industry, doc_types=doc_types, top_k=top_k or self.config.top_k, - metadata_filters=filters if filters else None + metadata_filters=filters if filters else None, ) return self.retrieve(rag_query) - diff --git 
a/src/deepiri_modelkit/rag/caching.py b/src/deepiri_modelkit/rag/caching.py index 8da0ff7..6d25942 100644 --- a/src/deepiri_modelkit/rag/caching.py +++ b/src/deepiri_modelkit/rag/caching.py @@ -16,6 +16,7 @@ @dataclass class CacheEntry: """Cache entry with metadata""" + key: str value: Any created_at: datetime @@ -23,19 +24,19 @@ class CacheEntry: access_count: int = 0 last_accessed: Optional[datetime] = None tags: List[str] = None - + def __post_init__(self): if self.tags is None: self.tags = [] if self.last_accessed is None: self.last_accessed = self.created_at - + def is_expired(self) -> bool: """Check if entry is expired""" if self.expires_at is None: return False return datetime.now() > self.expires_at - + def to_dict(self) -> Dict: """Convert to dictionary for serialization""" return { @@ -44,21 +45,31 @@ def to_dict(self) -> Dict: "created_at": self.created_at.isoformat(), "expires_at": self.expires_at.isoformat() if self.expires_at else None, "access_count": self.access_count, - "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None, - "tags": self.tags + "last_accessed": ( + self.last_accessed.isoformat() if self.last_accessed else None + ), + "tags": self.tags, } - + @classmethod - def from_dict(cls, data: Dict) -> 'CacheEntry': + def from_dict(cls, data: Dict) -> "CacheEntry": """Create from dictionary""" return cls( key=data["key"], value=data["value"], created_at=datetime.fromisoformat(data["created_at"]), - expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None, + expires_at=( + datetime.fromisoformat(data["expires_at"]) + if data.get("expires_at") + else None + ), access_count=data.get("access_count", 0), - last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None, - tags=data.get("tags", []) + last_accessed=( + datetime.fromisoformat(data["last_accessed"]) + if data.get("last_accessed") + else None + ), + tags=data.get("tags", []), ) @@ -71,53 +82,52 
@@ class AdvancedCacheManager: - Size limits - LRU eviction """ - + def __init__( self, redis_client=None, default_ttl: int = 3600, max_size: int = 10000, - enable_lru: bool = True + enable_lru: bool = True, ): self.redis_client = redis_client self.default_ttl = default_ttl self.max_size = max_size self.enable_lru = enable_lru - + # In-memory fallback if Redis unavailable self.memory_cache: Dict[str, CacheEntry] = {} self.tag_index: Dict[str, List[str]] = {} # tag -> [keys] - + def _get_key_prefix(self, namespace: str = "rag") -> str: """Get key prefix""" return f"{namespace}:" - + def _serialize_value(self, value: Any) -> str: """Serialize value for storage""" if isinstance(value, (list, dict)): return json.dumps(value) return str(value) - + def _deserialize_value(self, value: str, value_type: type = None) -> Any: """Deserialize value from storage""" try: - if value_type == list or (isinstance(value, str) and value.startswith('[')): + if value_type == list or (isinstance(value, str) and value.startswith("[")): return json.loads(value) - elif value_type == dict or (isinstance(value, str) and value.startswith('{')): + elif value_type == dict or ( + isinstance(value, str) and value.startswith("{") + ): return json.loads(value) return value except (json.JSONDecodeError, TypeError): return value - + def get( - self, - key: str, - namespace: str = "rag", - update_access: bool = True + self, key: str, namespace: str = "rag", update_access: bool = True ) -> Optional[Any]: """Get value from cache""" full_key = f"{self._get_key_prefix(namespace)}{key}" - + # Try Redis first if self.redis_client: try: @@ -125,65 +135,65 @@ def get( if cached: entry_data = json.loads(cached) entry = CacheEntry.from_dict(entry_data) - + if entry.is_expired(): self.delete(key, namespace) return None - + if update_access: entry.access_count += 1 entry.last_accessed = datetime.now() self._update_redis_entry(full_key, entry) - + return entry.value except Exception as e: # Fallback to memory pass - 
+ # Fallback to memory cache if key in self.memory_cache: entry = self.memory_cache[key] - + if entry.is_expired(): del self.memory_cache[key] return None - + if update_access: entry.access_count += 1 entry.last_accessed = datetime.now() - + return entry.value - + return None - + def set( self, key: str, value: Any, namespace: str = "rag", ttl: Optional[int] = None, - tags: Optional[List[str]] = None + tags: Optional[List[str]] = None, ) -> bool: """Set value in cache""" full_key = f"{self._get_key_prefix(namespace)}{key}" ttl = ttl or self.default_ttl tags = tags or [] - + expires_at = datetime.now() + timedelta(seconds=ttl) entry = CacheEntry( key=full_key, value=value, created_at=datetime.now(), expires_at=expires_at, - tags=tags + tags=tags, ) - + # Try Redis first if self.redis_client: try: entry_data = json.dumps(entry.to_dict()) self.redis_client.setex(full_key, ttl, entry_data) - + # Update tag index for tag in tags: tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" @@ -192,12 +202,12 @@ def set( if full_key not in self.tag_index[tag_key]: self.tag_index[tag_key].append(full_key) self.redis_client.sadd(tag_key, full_key) - + return True except Exception as e: # Fallback to memory pass - + # Fallback to memory cache # Check size limit if len(self.memory_cache) >= self.max_size: @@ -207,25 +217,25 @@ def set( # Remove oldest oldest_key = min( self.memory_cache.keys(), - key=lambda k: self.memory_cache[k].created_at + key=lambda k: self.memory_cache[k].created_at, ) del self.memory_cache[oldest_key] - + self.memory_cache[key] = entry - + # Update tag index for tag in tags: if tag not in self.tag_index: self.tag_index[tag] = [] if key not in self.tag_index[tag]: self.tag_index[tag].append(key) - + return True - + def delete(self, key: str, namespace: str = "rag") -> bool: """Delete key from cache""" full_key = f"{self._get_key_prefix(namespace)}{key}" - + # Try Redis if self.redis_client: try: @@ -233,7 +243,7 @@ def delete(self, key: str, namespace: str 
= "rag") -> bool: return True except Exception: pass - + # Memory cache if key in self.memory_cache: entry = self.memory_cache[key] @@ -243,40 +253,40 @@ def delete(self, key: str, namespace: str = "rag") -> bool: self.tag_index[tag].remove(key) del self.memory_cache[key] return True - + return False - + def invalidate_by_tag(self, tag: str, namespace: str = "rag") -> int: """Invalidate all keys with given tag""" tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" count = 0 - + # Get keys from tag index keys_to_delete = [] - + if self.redis_client: try: keys_to_delete = list(self.redis_client.smembers(tag_key)) self.redis_client.delete(tag_key) except Exception: pass - + if tag in self.tag_index: keys_to_delete.extend(self.tag_index[tag]) del self.tag_index[tag] - + # Delete all keys for key in keys_to_delete: if self.delete(key.replace(self._get_key_prefix(namespace), ""), namespace): count += 1 - + return count - + def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: """Invalidate keys matching pattern""" full_pattern = f"{self._get_key_prefix(namespace)}{pattern}" count = 0 - + if self.redis_client: try: keys = self.redis_client.keys(full_pattern) @@ -284,42 +294,41 @@ def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: count = self.redis_client.delete(*keys) except Exception: pass - + # Memory cache keys_to_delete = [ - k for k in self.memory_cache.keys() - if self._match_pattern(k, pattern) + k for k in self.memory_cache.keys() if self._match_pattern(k, pattern) ] for key in keys_to_delete: self.delete(key, namespace) count += 1 - + return count - + def _match_pattern(self, key: str, pattern: str) -> bool: """Simple pattern matching (supports * wildcard)""" import fnmatch + return fnmatch.fnmatch(key, pattern) - + def _evict_lru(self): """Evict least recently used entry""" if not self.memory_cache: return - + lru_key = min( self.memory_cache.keys(), key=lambda k: ( - self.memory_cache[k].last_accessed or - 
self.memory_cache[k].created_at - ) + self.memory_cache[k].last_accessed or self.memory_cache[k].created_at + ), ) del self.memory_cache[lru_key] - + def _update_redis_entry(self, key: str, entry: CacheEntry): """Update Redis entry with new access info""" if not self.redis_client: return - + try: entry_data = json.dumps(entry.to_dict()) # Get remaining TTL @@ -328,7 +337,7 @@ def _update_redis_entry(self, key: str, entry: CacheEntry): self.redis_client.setex(key, ttl, entry_data) except Exception: pass - + def get_stats(self) -> Dict[str, Any]: """Get cache statistics""" stats = { @@ -337,14 +346,14 @@ def get_stats(self) -> Dict[str, Any]: "tag_index_size": len(self.tag_index), "redis_available": self.redis_client is not None, } - + if self.memory_cache: total_access = sum(e.access_count for e in self.memory_cache.values()) stats["total_accesses"] = total_access stats["avg_access_per_entry"] = total_access / len(self.memory_cache) - + return stats - + def clear(self, namespace: str = "rag"): """Clear all cache entries in namespace""" pattern = f"{self._get_key_prefix(namespace)}*" @@ -353,52 +362,48 @@ def clear(self, namespace: str = "rag"): class EmbeddingCache: """Specialized cache for embeddings""" - + def __init__(self, cache_manager: AdvancedCacheManager): self.cache_manager = cache_manager self.namespace = "rag:embeddings" - + def get_embedding_key(self, text: str) -> str: """Generate cache key for embedding""" text_hash = hashlib.md5(text.encode()).hexdigest() return f"emb:{text_hash}" - + def get(self, text: str) -> Optional[Any]: """Get cached embedding""" key = self.get_embedding_key(text) return self.cache_manager.get(key, namespace=self.namespace) - + def set(self, text: str, embedding: Any, ttl: int = 86400): """Cache embedding (24 hour default TTL)""" key = self.get_embedding_key(text) return self.cache_manager.set( - key, - embedding, - namespace=self.namespace, - ttl=ttl, - tags=["embedding"] + key, embedding, namespace=self.namespace, ttl=ttl, 
tags=["embedding"] ) class QueryResultCache: """Specialized cache for query results""" - + def __init__(self, cache_manager: AdvancedCacheManager): self.cache_manager = cache_manager self.namespace = "rag:queries" - + def get_query_key(self, query: RAGQuery) -> str: """Generate cache key for query""" query_dict = query.to_dict() query_str = json.dumps(query_dict, sort_keys=True) query_hash = hashlib.md5(query_str.encode()).hexdigest() return f"query:{query_hash}" - + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: """Get cached query results""" key = self.get_query_key(query) cached = self.cache_manager.get(key, namespace=self.namespace) - + if cached: # Reconstruct RetrievalResult objects results = [] @@ -407,60 +412,55 @@ def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: result = RetrievalResult( document=doc, score=item["score"], - rerank_score=item.get("rerank_score") + rerank_score=item.get("rerank_score"), ) results.append(result) return results - + return None - + def set( self, query: RAGQuery, results: List[RetrievalResult], ttl: int = 3600, - tags: Optional[List[str]] = None + tags: Optional[List[str]] = None, ): """Cache query results""" key = self.get_query_key(query) - + # Serialize results serialized = [ { "document": r.document.to_dict(), "score": r.score, - "rerank_score": r.rerank_score + "rerank_score": r.rerank_score, } for r in results ] - + # Add query tags query_tags = tags or [] if query.industry: - query_tags.append(f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}") + query_tags.append( + f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}" + ) if query.doc_types: for dt in query.doc_types: query_tags.append(f"doctype:{dt.value if hasattr(dt, 'value') else dt}") - + return self.cache_manager.set( - key, - serialized, - namespace=self.namespace, - ttl=ttl, - tags=query_tags + key, serialized, namespace=self.namespace, ttl=ttl, 
tags=query_tags ) - + def invalidate_by_industry(self, industry: str): """Invalidate all queries for an industry""" return self.cache_manager.invalidate_by_tag( - f"industry:{industry}", - namespace=self.namespace + f"industry:{industry}", namespace=self.namespace ) - + def invalidate_by_doc_type(self, doc_type: str): """Invalidate all queries for a document type""" return self.cache_manager.invalidate_by_tag( - f"doctype:{doc_type}", - namespace=self.namespace + f"doctype:{doc_type}", namespace=self.namespace ) - diff --git a/src/deepiri_modelkit/rag/monitoring.py b/src/deepiri_modelkit/rag/monitoring.py index 95429b3..da5e6fa 100644 --- a/src/deepiri_modelkit/rag/monitoring.py +++ b/src/deepiri_modelkit/rag/monitoring.py @@ -16,6 +16,7 @@ @dataclass class RetrievalMetrics: """Metrics for a single retrieval operation""" + query_id: str query_text: str timestamp: datetime @@ -27,7 +28,7 @@ class RetrievalMetrics: query_expansion_used: bool = False industry: Optional[str] = None doc_types: Optional[List[str]] = None - + def to_dict(self) -> Dict: """Convert to dictionary""" return { @@ -48,6 +49,7 @@ def to_dict(self) -> Dict: @dataclass class IndexingMetrics: """Metrics for indexing operations""" + operation_id: str timestamp: datetime operation_type: str # "index", "update", "delete" @@ -55,7 +57,7 @@ class IndexingMetrics: processing_time_ms: float success: bool error: Optional[str] = None - + def to_dict(self) -> Dict: """Convert to dictionary""" return { @@ -72,6 +74,7 @@ def to_dict(self) -> Dict: @dataclass class SystemMetrics: """System-wide metrics""" + total_queries: int = 0 total_indexed_documents: int = 0 cache_hit_rate: float = 0.0 @@ -79,7 +82,7 @@ class SystemMetrics: avg_indexing_time_ms: float = 0.0 error_rate: float = 0.0 last_updated: Optional[datetime] = None - + def to_dict(self) -> Dict: """Convert to dictionary""" return { @@ -89,7 +92,9 @@ def to_dict(self) -> Dict: "avg_retrieval_time_ms": self.avg_retrieval_time_ms, "avg_indexing_time_ms": 
self.avg_indexing_time_ms, "error_rate": self.error_rate, - "last_updated": self.last_updated.isoformat() if self.last_updated else None, + "last_updated": ( + self.last_updated.isoformat() if self.last_updated else None + ), } @@ -97,21 +102,21 @@ class RAGMonitor: """ Monitor RAG system performance and collect metrics """ - + def __init__(self, max_history: int = 10000): self.max_history = max_history - + # Metrics storage self.retrieval_metrics: List[RetrievalMetrics] = [] self.indexing_metrics: List[IndexingMetrics] = [] - + # Aggregated stats self.system_metrics = SystemMetrics() - + # Time windows for analysis self.hourly_stats: Dict[str, Dict] = defaultdict(dict) self.daily_stats: Dict[str, Dict] = defaultdict(dict) - + def record_retrieval( self, query: RAGQuery, @@ -119,11 +124,11 @@ def record_retrieval( retrieval_time_ms: float, cache_hit: bool = False, reranking_used: bool = False, - query_expansion_used: bool = False + query_expansion_used: bool = False, ): """Record retrieval metrics""" query_id = f"q_{int(time.time() * 1000)}" - + metric = RetrievalMetrics( query_id=query_id, query_text=query.query, @@ -134,30 +139,41 @@ def record_retrieval( cache_hit=cache_hit, reranking_used=reranking_used, query_expansion_used=query_expansion_used, - industry=query.industry.value if query.industry and hasattr(query.industry, 'value') else str(query.industry), - doc_types=[dt.value if hasattr(dt, 'value') else str(dt) for dt in query.doc_types] if query.doc_types else None, + industry=( + query.industry.value + if query.industry and hasattr(query.industry, "value") + else str(query.industry) + ), + doc_types=( + [ + dt.value if hasattr(dt, "value") else str(dt) + for dt in query.doc_types + ] + if query.doc_types + else None + ), ) - + self.retrieval_metrics.append(metric) - + # Trim history if len(self.retrieval_metrics) > self.max_history: - self.retrieval_metrics = self.retrieval_metrics[-self.max_history:] - + self.retrieval_metrics = 
self.retrieval_metrics[-self.max_history :] + # Update aggregated stats self._update_system_metrics() - + def record_indexing( self, operation_type: str, num_documents: int, processing_time_ms: float, success: bool, - error: Optional[str] = None + error: Optional[str] = None, ): """Record indexing metrics""" operation_id = f"idx_{int(time.time() * 1000)}" - + metric = IndexingMetrics( operation_id=operation_id, timestamp=datetime.now(), @@ -165,71 +181,81 @@ def record_indexing( num_documents=num_documents, processing_time_ms=processing_time_ms, success=success, - error=error + error=error, ) - + self.indexing_metrics.append(metric) - + # Trim history if len(self.indexing_metrics) > self.max_history: - self.indexing_metrics = self.indexing_metrics[-self.max_history:] - + self.indexing_metrics = self.indexing_metrics[-self.max_history :] + # Update aggregated stats self._update_system_metrics() - + def _update_system_metrics(self): """Update system-wide aggregated metrics""" if not self.retrieval_metrics: return - + # Calculate averages total_queries = len(self.retrieval_metrics) cache_hits = sum(1 for m in self.retrieval_metrics if m.cache_hit) total_retrieval_time = sum(m.retrieval_time_ms for m in self.retrieval_metrics) - + self.system_metrics.total_queries = total_queries - self.system_metrics.cache_hit_rate = cache_hits / total_queries if total_queries > 0 else 0.0 - self.system_metrics.avg_retrieval_time_ms = total_retrieval_time / total_queries if total_queries > 0 else 0.0 - + self.system_metrics.cache_hit_rate = ( + cache_hits / total_queries if total_queries > 0 else 0.0 + ) + self.system_metrics.avg_retrieval_time_ms = ( + total_retrieval_time / total_queries if total_queries > 0 else 0.0 + ) + if self.indexing_metrics: - total_indexed = sum(m.num_documents for m in self.indexing_metrics if m.success) - total_indexing_time = sum(m.processing_time_ms for m in self.indexing_metrics) + total_indexed = sum( + m.num_documents for m in self.indexing_metrics if 
m.success + ) + total_indexing_time = sum( + m.processing_time_ms for m in self.indexing_metrics + ) total_indexing_ops = len(self.indexing_metrics) - + self.system_metrics.total_indexed_documents = total_indexed - self.system_metrics.avg_indexing_time_ms = total_indexing_time / total_indexing_ops if total_indexing_ops > 0 else 0.0 - + self.system_metrics.avg_indexing_time_ms = ( + total_indexing_time / total_indexing_ops + if total_indexing_ops > 0 + else 0.0 + ) + # Error rate total_ops = total_queries + len(self.indexing_metrics) errors = sum(1 for m in self.indexing_metrics if not m.success) self.system_metrics.error_rate = errors / total_ops if total_ops > 0 else 0.0 - + self.system_metrics.last_updated = datetime.now() - + def get_retrieval_stats( - self, - time_window_minutes: Optional[int] = None, - industry: Optional[str] = None + self, time_window_minutes: Optional[int] = None, industry: Optional[str] = None ) -> Dict[str, Any]: """Get retrieval statistics for time window""" metrics = self.retrieval_metrics - + # Filter by time window if time_window_minutes: cutoff = datetime.now() - timedelta(minutes=time_window_minutes) metrics = [m for m in metrics if m.timestamp >= cutoff] - + # Filter by industry if industry: metrics = [m for m in metrics if m.industry == industry] - + if not metrics: return { "count": 0, "avg_time_ms": 0.0, "cache_hit_rate": 0.0, } - + return { "count": len(metrics), "avg_time_ms": sum(m.retrieval_time_ms for m in metrics) / len(metrics), @@ -237,21 +263,21 @@ def get_retrieval_stats( "max_time_ms": max(m.retrieval_time_ms for m in metrics), "cache_hit_rate": sum(1 for m in metrics if m.cache_hit) / len(metrics), "avg_results": sum(m.num_results for m in metrics) / len(metrics), - "avg_top_score": sum(m.top_score for m in metrics if m.top_score) / len([m for m in metrics if m.top_score]), + "avg_top_score": sum(m.top_score for m in metrics if m.top_score) + / len([m for m in metrics if m.top_score]), } - + def get_indexing_stats( - 
self, - time_window_minutes: Optional[int] = None + self, time_window_minutes: Optional[int] = None ) -> Dict[str, Any]: """Get indexing statistics""" metrics = self.indexing_metrics - + # Filter by time window if time_window_minutes: cutoff = datetime.now() - timedelta(minutes=time_window_minutes) metrics = [m for m in metrics if m.timestamp >= cutoff] - + if not metrics: return { "count": 0, @@ -259,9 +285,9 @@ def get_indexing_stats( "avg_time_ms": 0.0, "success_rate": 0.0, } - + successful = [m for m in metrics if m.success] - + return { "count": len(metrics), "total_documents": sum(m.num_documents for m in successful), @@ -269,40 +295,30 @@ def get_indexing_stats( "success_rate": len(successful) / len(metrics) if metrics else 0.0, "error_count": len([m for m in metrics if not m.success]), } - + def get_top_queries( - self, - limit: int = 10, - time_window_minutes: Optional[int] = None + self, limit: int = 10, time_window_minutes: Optional[int] = None ) -> List[Dict[str, Any]]: """Get most frequent queries""" metrics = self.retrieval_metrics - + # Filter by time window if time_window_minutes: cutoff = datetime.now() - timedelta(minutes=time_window_minutes) metrics = [m for m in metrics if m.timestamp >= cutoff] - + # Count query frequencies query_counts: Dict[str, int] = defaultdict(int) for m in metrics: query_counts[m.query_text] += 1 - + # Sort by frequency - top_queries = sorted( - query_counts.items(), - key=lambda x: x[1], - reverse=True - )[:limit] - - return [ - { - "query": query, - "count": count - } - for query, count in top_queries + top_queries = sorted(query_counts.items(), key=lambda x: x[1], reverse=True)[ + :limit ] - + + return [{"query": query, "count": count} for query, count in top_queries] + def get_performance_report(self) -> Dict[str, Any]: """Get comprehensive performance report""" return { @@ -313,37 +329,43 @@ def get_performance_report(self) -> Dict[str, Any]: "indexing_stats_24h": self.get_indexing_stats(time_window_minutes=1440), 
"top_queries_24h": self.get_top_queries(limit=10, time_window_minutes=1440), } - + def export_metrics(self, filepath: str): """Export metrics to JSON file""" data = { - "retrieval_metrics": [m.to_dict() for m in self.retrieval_metrics[-1000:]], # Last 1000 - "indexing_metrics": [m.to_dict() for m in self.indexing_metrics[-1000:]], # Last 1000 + "retrieval_metrics": [ + m.to_dict() for m in self.retrieval_metrics[-1000:] + ], # Last 1000 + "indexing_metrics": [ + m.to_dict() for m in self.indexing_metrics[-1000:] + ], # Last 1000 "system_metrics": self.system_metrics.to_dict(), "exported_at": datetime.now().isoformat(), } - - with open(filepath, 'w') as f: + + with open(filepath, "w") as f: json.dump(data, f, indent=2) class PerformanceTimer: """Context manager for timing operations""" - - def __init__(self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation"): + + def __init__( + self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation" + ): self.monitor = monitor self.operation_name = operation_name self.start_time = None self.end_time = None - + def __enter__(self): self.start_time = time.time() return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.end_time = time.time() return False - + def elapsed_ms(self) -> float: """Get elapsed time in milliseconds""" if self.start_time and self.end_time: @@ -351,4 +373,3 @@ def elapsed_ms(self) -> float: elif self.start_time: return (time.time() - self.start_time) * 1000 return 0.0 - diff --git a/src/deepiri_modelkit/rag/processors.py b/src/deepiri_modelkit/rag/processors.py index 3fdb477..5585cde 100644 --- a/src/deepiri_modelkit/rag/processors.py +++ b/src/deepiri_modelkit/rag/processors.py @@ -13,60 +13,62 @@ class DocumentProcessor(ABC): """Base class for document processing""" - + def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - + @abstractmethod def process(self, raw_content: 
str, metadata: Dict[str, Any]) -> List[Document]: """ Process raw content into structured documents - + Args: raw_content: Raw text content metadata: Document metadata - + Returns: List of processed document chunks """ pass - + def chunk_text(self, text: str) -> List[str]: """ Chunk text into smaller pieces with overlap - + Args: text: Text to chunk - + Returns: List of text chunks """ chunks = [] start = 0 text_length = len(text) - + while start < text_length: end = start + self.chunk_size - + # Find the last sentence boundary within chunk_size if end < text_length: # Look for sentence endings - for sep in ['. ', '.\n', '! ', '!\n', '? ', '?\n']: + for sep in [". ", ".\n", "! ", "!\n", "? ", "?\n"]: last_sep = text.rfind(sep, start, end) if last_sep != -1: end = last_sep + 1 break - + chunk = text[start:end].strip() if chunk: chunks.append(chunk) - + # Move start forward with overlap - start = end - self.chunk_overlap if end - self.chunk_overlap > start else end - + start = ( + end - self.chunk_overlap if end - self.chunk_overlap > start else end + ) + return chunks - + def extract_metadata(self, content: str) -> Dict[str, Any]: """Extract metadata from content (can be overridden)""" return {} @@ -77,72 +79,71 @@ class RegulationProcessor(DocumentProcessor): Processor for regulations, policies, and compliance documents Common across insurance, healthcare, manufacturing, etc. 
""" - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process regulation documents""" # Extract sections and subsections sections = self._extract_sections(raw_content) - + documents = [] - base_id = metadata.get('id', 'reg_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "reg_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, section in enumerate(sections): doc = Document( id=f"{base_id}_chunk_{idx}", - content=section['content'], + content=section["content"], doc_type=DocumentType.REGULATION, industry=industry, - title=metadata.get('title', 'Regulation Document'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), + title=metadata.get("title", "Regulation Document"), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), metadata={ **metadata, - 'section': section.get('section'), - 'subsection': section.get('subsection'), + "section": section.get("section"), + "subsection": section.get("subsection"), }, chunk_index=idx, total_chunks=len(sections), ) documents.append(doc) - + return documents - + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: """Extract sections from regulation text""" # Match section headers like "Section 1.2.3", "Article 5", etc. 
- section_pattern = r'(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)' - + section_pattern = r"(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)" + sections = [] - current_section = {'section': None, 'content': ''} - - lines = content.split('\n') + current_section = {"section": None, "content": ""} + + lines = content.split("\n") for line in lines: match = re.search(section_pattern, line, re.IGNORECASE) if match: # Save previous section - if current_section['content']: + if current_section["content"]: sections.append(current_section) # Start new section - current_section = { - 'section': match.group(0), - 'content': line + '\n' - } + current_section = {"section": match.group(0), "content": line + "\n"} else: - current_section['content'] += line + '\n' - + current_section["content"] += line + "\n" + # Add last section - if current_section['content']: + if current_section["content"]: sections.append(current_section) - + # If no sections found, treat entire content as one chunk if not sections: chunks = self.chunk_text(content) - sections = [{'section': f'Chunk {i+1}', 'content': chunk} - for i, chunk in enumerate(chunks)] - + sections = [ + {"section": f"Chunk {i+1}", "content": chunk} + for i, chunk in enumerate(chunks) + ] + return sections - + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse date string to datetime""" if not date_str: @@ -158,45 +159,53 @@ class HistoricalDataProcessor(DocumentProcessor): Processor for historical operational data - Work orders, maintenance logs, claim records, service history """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process historical data records""" - doc_type_str = metadata.get('doc_type', 'work_order') - doc_type = DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER - + doc_type_str = metadata.get("doc_type", "work_order") + doc_type = ( + DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER + ) + # Historical 
data is typically already structured # We may need minimal chunking - chunks = self.chunk_text(raw_content) if len(raw_content) > self.chunk_size else [raw_content] - + chunks = ( + self.chunk_text(raw_content) + if len(raw_content) > self.chunk_size + else [raw_content] + ) + documents = [] - base_id = metadata.get('id', 'hist_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "hist_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, chunk in enumerate(chunks): doc = Document( id=f"{base_id}_chunk_{idx}", content=chunk, doc_type=doc_type, industry=industry, - title=metadata.get('title', f'{doc_type.value.replace("_", " ").title()}'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - updated_at=self._parse_date(metadata.get('updated_at')), + title=metadata.get( + "title", f'{doc_type.value.replace("_", " ").title()}' + ), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + updated_at=self._parse_date(metadata.get("updated_at")), metadata={ **metadata, - 'record_type': doc_type.value, - 'record_id': metadata.get('record_id'), - 'status': metadata.get('status'), - 'priority': metadata.get('priority'), - 'assigned_to': metadata.get('assigned_to'), + "record_type": doc_type.value, + "record_id": metadata.get("record_id"), + "status": metadata.get("status"), + "priority": metadata.get("priority"), + "assigned_to": metadata.get("assigned_to"), }, chunk_index=idx, total_chunks=len(chunks), ) documents.append(doc) - + return documents - + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse date string to datetime""" if not date_str: @@ -207,7 +216,7 @@ def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: return datetime.fromisoformat(date_str) except (ValueError, AttributeError): try: - return 
datetime.strptime(date_str, '%Y-%m-%d') + return datetime.strptime(date_str, "%Y-%m-%d") except (ValueError, AttributeError): return None @@ -217,73 +226,83 @@ class KnowledgeBaseProcessor(DocumentProcessor): Processor for knowledge base articles, FAQs, and guides - Equipment repair guides, compliance advice, troubleshooting steps """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process knowledge base articles""" - doc_type = DocumentType.KNOWLEDGE_BASE if metadata.get('doc_type') == 'knowledge_base' else DocumentType.FAQ - + doc_type = ( + DocumentType.KNOWLEDGE_BASE + if metadata.get("doc_type") == "knowledge_base" + else DocumentType.FAQ + ) + # Extract Q&A pairs if FAQ format if doc_type == DocumentType.FAQ: qa_pairs = self._extract_qa_pairs(raw_content) if qa_pairs: return self._process_qa_pairs(qa_pairs, metadata) - + # Otherwise, process as regular article chunks = self.chunk_text(raw_content) - + documents = [] - base_id = metadata.get('id', 'kb_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "kb_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, chunk in enumerate(chunks): doc = Document( id=f"{base_id}_chunk_{idx}", content=chunk, doc_type=doc_type, industry=industry, - title=metadata.get('title', 'Knowledge Base Article'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - updated_at=self._parse_date(metadata.get('updated_at')), - author=metadata.get('author'), + title=metadata.get("title", "Knowledge Base Article"), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + updated_at=self._parse_date(metadata.get("updated_at")), + author=metadata.get("author"), metadata={ **metadata, - 'category': metadata.get('category'), - 'tags': metadata.get('tags', []), - 'difficulty_level': 
metadata.get('difficulty_level'), + "category": metadata.get("category"), + "tags": metadata.get("tags", []), + "difficulty_level": metadata.get("difficulty_level"), }, chunk_index=idx, total_chunks=len(chunks), ) documents.append(doc) - + return documents - + def _extract_qa_pairs(self, content: str) -> List[Dict[str, str]]: """Extract Q&A pairs from FAQ content""" qa_pairs = [] - + # Try different FAQ formats # Format 1: Q: ... A: ... - pattern1 = r'Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)' + pattern1 = r"Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)" matches = re.findall(pattern1, content, re.DOTALL | re.IGNORECASE) if matches: - qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) - + qa_pairs.extend( + [{"question": q.strip(), "answer": a.strip()} for q, a in matches] + ) + # Format 2: Question/Answer headers - pattern2 = r'Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)' + pattern2 = r"Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)" matches = re.findall(pattern2, content, re.DOTALL | re.IGNORECASE) if matches: - qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) - + qa_pairs.extend( + [{"question": q.strip(), "answer": a.strip()} for q, a in matches] + ) + return qa_pairs - - def _process_qa_pairs(self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any]) -> List[Document]: + + def _process_qa_pairs( + self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any] + ) -> List[Document]: """Process Q&A pairs into documents""" documents = [] - base_id = metadata.get('id', 'faq_' + str(hash(str(qa_pairs[0])))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "faq_" + str(hash(str(qa_pairs[0])))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, qa in enumerate(qa_pairs): content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}" doc = Document( @@ -291,20 +310,20 @@ def _process_qa_pairs(self, qa_pairs: List[Dict[str, str]], 
metadata: Dict[str, content=content, doc_type=DocumentType.FAQ, industry=industry, - title=qa['question'][:100], # Use question as title - source=metadata.get('source'), + title=qa["question"][:100], # Use question as title + source=metadata.get("source"), metadata={ **metadata, - 'question': qa['question'], - 'answer': qa['answer'], + "question": qa["question"], + "answer": qa["answer"], }, chunk_index=idx, total_chunks=len(qa_pairs), ) documents.append(doc) - + return documents - + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse date string to datetime""" if not date_str: @@ -321,87 +340,89 @@ class ManualProcessor(DocumentProcessor): """ Processor for equipment manuals, operation guides, technical specifications """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process manual documents""" # Extract chapters and sections sections = self._extract_sections(raw_content) - + documents = [] - base_id = metadata.get('id', 'manual_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "manual_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, section in enumerate(sections): doc = Document( id=f"{base_id}_chunk_{idx}", - content=section['content'], + content=section["content"], doc_type=DocumentType.MANUAL, industry=industry, - title=metadata.get('title', 'Equipment Manual'), - source=metadata.get('source'), - version=metadata.get('version'), + title=metadata.get("title", "Equipment Manual"), + source=metadata.get("source"), + version=metadata.get("version"), metadata={ **metadata, - 'chapter': section.get('chapter'), - 'section': section.get('section'), - 'equipment_model': metadata.get('equipment_model'), - 'manufacturer': metadata.get('manufacturer'), + "chapter": section.get("chapter"), + "section": section.get("section"), + "equipment_model": 
metadata.get("equipment_model"), + "manufacturer": metadata.get("manufacturer"), }, chunk_index=idx, total_chunks=len(sections), ) documents.append(doc) - + return documents - + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: """Extract sections from manual text""" # Match chapter/section headers - section_pattern = r'(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)' - + section_pattern = r"(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)" + sections = [] - current_section = {'chapter': None, 'section': None, 'content': ''} - - lines = content.split('\n') + current_section = {"chapter": None, "section": None, "content": ""} + + lines = content.split("\n") for line in lines: match = re.search(section_pattern, line, re.IGNORECASE) if match: # Save previous section - if current_section['content']: + if current_section["content"]: sections.append(current_section) # Start new section section_type = match.group(1).lower() section_num = match.group(2) - section_title = match.group(3).strip() if match.group(3) else '' + section_title = match.group(3).strip() if match.group(3) else "" current_section = { section_type: f"{section_type.title()} {section_num}", - 'section_title': section_title, - 'content': line + '\n' + "section_title": section_title, + "content": line + "\n", } else: - current_section['content'] += line + '\n' - + current_section["content"] += line + "\n" + # Add last section - if current_section['content']: + if current_section["content"]: sections.append(current_section) - + # If no sections found, chunk the content if not sections: chunks = self.chunk_text(content) - sections = [{'section': f'Chunk {i+1}', 'content': chunk} - for i, chunk in enumerate(chunks)] - + sections = [ + {"section": f"Chunk {i+1}", "content": chunk} + for i, chunk in enumerate(chunks) + ] + return sections def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: """ Factory function to get appropriate processor for document type - + Args: 
doc_type: Type of document **kwargs: Additional configuration for processor - + Returns: Configured document processor """ @@ -417,7 +438,6 @@ def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: DocumentType.TECHNICAL_SPEC: ManualProcessor, # Similar processing DocumentType.PROCEDURE: ManualProcessor, # Similar processing } - + processor_class = processor_map.get(doc_type, DocumentProcessor) return processor_class(**kwargs) - diff --git a/src/deepiri_modelkit/rag/retrievers.py b/src/deepiri_modelkit/rag/retrievers.py index dcda0d0..bf40f86 100644 --- a/src/deepiri_modelkit/rag/retrievers.py +++ b/src/deepiri_modelkit/rag/retrievers.py @@ -12,7 +12,7 @@ class BaseRetriever(ABC): """Base class for retrievers""" - + @abstractmethod def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve relevant documents for query""" @@ -24,30 +24,30 @@ class MultiModalRetriever(BaseRetriever): Multi-modal retriever supporting text, images, tables, and code Useful for technical manuals, equipment guides, etc. 
""" - + def __init__( self, text_embeddings, image_embeddings=None, table_embeddings=None, - code_embeddings=None + code_embeddings=None, ): self.text_embeddings = text_embeddings self.image_embeddings = image_embeddings self.table_embeddings = table_embeddings self.code_embeddings = code_embeddings - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve documents across multiple modalities - + Currently focuses on text, but can be extended for other modalities """ # For now, delegate to text retrieval # Future: Add image, table, code retrieval results = self._retrieve_text(query) return results - + def _retrieve_text(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve text documents""" # This will be implemented by the concrete RAG engine @@ -61,115 +61,117 @@ class HybridRetriever(BaseRetriever): - Semantic search (vector similarity) - Keyword search (BM25) - Metadata filtering - + Provides better recall than pure semantic search """ - + def __init__( self, vector_retriever, keyword_retriever=None, vector_weight: float = 0.7, - keyword_weight: float = 0.3 + keyword_weight: float = 0.3, ): self.vector_retriever = vector_retriever self.keyword_retriever = keyword_retriever self.vector_weight = vector_weight self.keyword_weight = keyword_weight - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve using hybrid approach - + Combines vector and keyword search results """ results = [] - + # Vector search vector_results = self._retrieve_vector(query) results.extend(vector_results) - + # Keyword search (if available) if self.keyword_retriever: keyword_results = self._retrieve_keyword(query) results.extend(keyword_results) - + # Merge and re-score - merged_results = self._merge_results(vector_results, keyword_results if self.keyword_retriever else []) - + merged_results = self._merge_results( + vector_results, keyword_results if self.keyword_retriever else [] + ) + return merged_results - + def 
_retrieve_vector(self, query: RAGQuery) -> List[RetrievalResult]: """Vector similarity search""" # Placeholder - implemented by concrete engine return [] - + def _retrieve_keyword(self, query: RAGQuery) -> List[RetrievalResult]: """Keyword search using BM25 or similar""" # Placeholder - implemented by concrete engine return [] - + def _merge_results( self, vector_results: List[RetrievalResult], - keyword_results: List[RetrievalResult] + keyword_results: List[RetrievalResult], ) -> List[RetrievalResult]: """ Merge results from different retrievers using weighted scoring - + Uses Reciprocal Rank Fusion (RRF) for combining rankings """ # Create a dictionary to store combined scores doc_scores: Dict[str, Dict[str, Any]] = {} - + # Process vector results for rank, result in enumerate(vector_results): doc_id = result.document.id # RRF score: 1 / (k + rank) where k=60 is common rrf_score = 1.0 / (60 + rank + 1) doc_scores[doc_id] = { - 'document': result.document, - 'vector_score': result.score, - 'vector_rrf': rrf_score, - 'keyword_rrf': 0.0, - 'keyword_score': 0.0, + "document": result.document, + "vector_score": result.score, + "vector_rrf": rrf_score, + "keyword_rrf": 0.0, + "keyword_score": 0.0, } - + # Process keyword results for rank, result in enumerate(keyword_results): doc_id = result.document.id rrf_score = 1.0 / (60 + rank + 1) - + if doc_id in doc_scores: - doc_scores[doc_id]['keyword_rrf'] = rrf_score - doc_scores[doc_id]['keyword_score'] = result.score + doc_scores[doc_id]["keyword_rrf"] = rrf_score + doc_scores[doc_id]["keyword_score"] = result.score else: doc_scores[doc_id] = { - 'document': result.document, - 'vector_score': 0.0, - 'vector_rrf': 0.0, - 'keyword_rrf': rrf_score, - 'keyword_score': result.score, + "document": result.document, + "vector_score": 0.0, + "vector_rrf": 0.0, + "keyword_rrf": rrf_score, + "keyword_score": result.score, } - + # Calculate combined scores merged = [] for doc_id, scores in doc_scores.items(): combined_rrf = ( - 
self.vector_weight * scores['vector_rrf'] + - self.keyword_weight * scores['keyword_rrf'] + self.vector_weight * scores["vector_rrf"] + + self.keyword_weight * scores["keyword_rrf"] ) - + result = RetrievalResult( - document=scores['document'], + document=scores["document"], score=combined_rrf, rerank_score=None, ) merged.append(result) - + # Sort by combined score merged.sort(key=lambda x: x.score, reverse=True) - + return merged @@ -180,54 +182,54 @@ class ContextualRetriever(BaseRetriever): - Temporal context (recent vs historical) - Industry context (specific to niche) """ - + def __init__( self, base_retriever: BaseRetriever, use_user_context: bool = True, use_temporal_context: bool = True, - use_industry_context: bool = True + use_industry_context: bool = True, ): self.base_retriever = base_retriever self.use_user_context = use_user_context self.use_temporal_context = use_temporal_context self.use_industry_context = use_industry_context - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve with contextual awareness """ # Get base results results = self.base_retriever.retrieve(query) - + # Apply contextual reranking if self.use_temporal_context: results = self._apply_temporal_boost(results, query) - + if self.use_industry_context: results = self._apply_industry_boost(results, query) - + return results - + def _apply_temporal_boost( - self, - results: List[RetrievalResult], - query: RAGQuery + self, results: List[RetrievalResult], query: RAGQuery ) -> List[RetrievalResult]: """ Boost recent documents (useful for regulations, updates) """ import time from datetime import datetime - + current_time = datetime.now().timestamp() - + for result in results: if result.document.updated_at or result.document.created_at: - doc_time = (result.document.updated_at or result.document.created_at).timestamp() + doc_time = ( + result.document.updated_at or result.document.created_at + ).timestamp() # Calculate age in days age_days = (current_time - 
doc_time) / 86400 - + # Apply decay: more recent = higher boost # Documents within 30 days get full boost # Older documents gradually lose boost @@ -239,28 +241,26 @@ def _apply_temporal_boost( temporal_boost = 0.8 else: temporal_boost = 0.7 - + result.score *= temporal_boost - + # Re-sort by adjusted scores results.sort(key=lambda x: x.score, reverse=True) return results - + def _apply_industry_boost( - self, - results: List[RetrievalResult], - query: RAGQuery + self, results: List[RetrievalResult], query: RAGQuery ) -> List[RetrievalResult]: """ Boost documents matching the query's industry """ if not query.industry: return results - + for result in results: if result.document.industry == query.industry: result.score *= 1.1 # 10% boost for industry match - + # Re-sort results.sort(key=lambda x: x.score, reverse=True) return results @@ -269,20 +269,19 @@ def _apply_industry_boost( def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever: """ Factory function to get retriever - + Args: retriever_type: Type of retriever ('hybrid', 'multimodal', 'contextual') **kwargs: Configuration for retriever - + Returns: Configured retriever """ retriever_map = { - 'hybrid': HybridRetriever, - 'multimodal': MultiModalRetriever, - 'contextual': ContextualRetriever, + "hybrid": HybridRetriever, + "multimodal": MultiModalRetriever, + "contextual": ContextualRetriever, } - + retriever_class = retriever_map.get(retriever_type, HybridRetriever) return retriever_class(**kwargs) - diff --git a/src/deepiri_modelkit/rag/testing.py b/src/deepiri_modelkit/rag/testing.py index bc11473..f118dc2 100644 --- a/src/deepiri_modelkit/rag/testing.py +++ b/src/deepiri_modelkit/rag/testing.py @@ -14,13 +14,14 @@ @dataclass class TestCase: """Test case for RAG evaluation""" + query: str expected_doc_ids: List[str] # Document IDs that should be retrieved expected_doc_types: Optional[List[DocumentType]] = None min_score: float = 0.7 # Minimum similarity score top_k: int = 5 metadata: Dict[str, 
Any] = None - + def __post_init__(self): if self.metadata is None: self.metadata = {} @@ -29,6 +30,7 @@ def __post_init__(self): @dataclass class TestResult: """Result of a test case""" + test_case: TestCase retrieved_doc_ids: List[str] retrieved_scores: List[float] @@ -37,7 +39,7 @@ class TestResult: f1_score: float passed: bool error: Optional[str] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -56,37 +58,37 @@ class RAGEvaluator: """ Evaluator for RAG system performance """ - + def __init__(self, rag_engine): self.rag_engine = rag_engine - + def evaluate( - self, - test_cases: List[TestCase], - industry: Optional[IndustryNiche] = None + self, test_cases: List[TestCase], industry: Optional[IndustryNiche] = None ) -> Dict[str, Any]: """ Evaluate RAG system on test cases - + Args: test_cases: List of test cases industry: Industry context - + Returns: Evaluation results with metrics """ results = [] - + for test_case in test_cases: result = self._evaluate_test_case(test_case, industry) results.append(result) - + # Calculate aggregate metrics - total_precision = sum(r.precision for r in results) / len(results) if results else 0.0 + total_precision = ( + sum(r.precision for r in results) / len(results) if results else 0.0 + ) total_recall = sum(r.recall for r in results) / len(results) if results else 0.0 total_f1 = sum(r.f1_score for r in results) / len(results) if results else 0.0 passed_count = sum(1 for r in results if r.passed) - + return { "total_tests": len(test_cases), "passed": passed_count, @@ -97,11 +99,9 @@ def evaluate( "pass_rate": passed_count / len(test_cases) if test_cases else 0.0, "results": [r.to_dict() for r in results], } - + def _evaluate_test_case( - self, - test_case: TestCase, - industry: Optional[IndustryNiche] + self, test_case: TestCase, industry: Optional[IndustryNiche] ) -> TestResult: """Evaluate a single test case""" try: @@ -110,20 +110,20 @@ def _evaluate_test_case( query=test_case.query, 
industry=industry, doc_types=test_case.expected_doc_types, - top_k=test_case.top_k + top_k=test_case.top_k, ) - + # Retrieve documents results = self.rag_engine.retrieve(query) - + # Extract IDs and scores retrieved_doc_ids = [r.document.id for r in results] retrieved_scores = [r.score for r in results] - + # Calculate precision, recall, F1 expected_set = set(test_case.expected_doc_ids) retrieved_set = set(retrieved_doc_ids) - + if not retrieved_set: precision = 0.0 recall = 0.0 @@ -131,22 +131,30 @@ def _evaluate_test_case( else: # Precision: relevant retrieved / total retrieved relevant_retrieved = len(expected_set & retrieved_set) - precision = relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 - + precision = ( + relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 + ) + # Recall: relevant retrieved / total relevant recall = relevant_retrieved / len(expected_set) if expected_set else 0.0 - + # F1 score - f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 - + f1_score = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + # Check if passed passed = ( - precision >= 0.7 and - recall >= 0.7 and - f1_score >= 0.7 and - (not retrieved_scores or max(retrieved_scores) >= test_case.min_score) + precision >= 0.7 + and recall >= 0.7 + and f1_score >= 0.7 + and ( + not retrieved_scores or max(retrieved_scores) >= test_case.min_score + ) ) - + return TestResult( test_case=test_case, retrieved_doc_ids=retrieved_doc_ids, @@ -154,9 +162,9 @@ def _evaluate_test_case( precision=precision, recall=recall, f1_score=f1_score, - passed=passed + passed=passed, ) - + except Exception as e: return TestResult( test_case=test_case, @@ -166,7 +174,7 @@ def _evaluate_test_case( recall=0.0, f1_score=0.0, passed=False, - error=str(e) + error=str(e), ) @@ -174,44 +182,43 @@ class RAGTestFixture: """ Test fixture for creating test data and scenarios """ - + @staticmethod def 
create_test_documents( - industry: IndustryNiche = IndustryNiche.MANUFACTURING, - num_documents: int = 10 + industry: IndustryNiche = IndustryNiche.MANUFACTURING, num_documents: int = 10 ) -> List[Document]: """Create test documents""" documents = [] - + for i in range(num_documents): doc = Document( id=f"test_doc_{i}", content=f"Test document {i} content. This is sample content for testing RAG retrieval.", - doc_type=DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG, + doc_type=( + DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG + ), industry=industry, title=f"Test Document {i}", source="test_fixture", - metadata={"test_index": i} + metadata={"test_index": i}, ) documents.append(doc) - + return documents - + @staticmethod - def create_test_cases( - num_cases: int = 5 - ) -> List[TestCase]: + def create_test_cases(num_cases: int = 5) -> List[TestCase]: """Create test cases""" test_cases = [] - + for i in range(num_cases): test_case = TestCase( query=f"test query {i}", expected_doc_ids=[f"test_doc_{i}", f"test_doc_{i+1}"], - top_k=5 + top_k=5, ) test_cases.append(test_case) - + return test_cases @@ -219,47 +226,45 @@ class PerformanceBenchmark: """ Performance benchmarking for RAG operations """ - + def __init__(self, rag_engine): self.rag_engine = rag_engine - + def benchmark_retrieval( - self, - queries: List[str], - iterations: int = 10 + self, queries: List[str], iterations: int = 10 ) -> Dict[str, Any]: """ Benchmark retrieval performance - + Args: queries: List of test queries iterations: Number of iterations per query - + Returns: Performance metrics """ import time - + total_time = 0.0 total_queries = 0 times = [] - + for query_text in queries: query = RAGQuery(query=query_text, top_k=5) - + for _ in range(iterations): start = time.time() results = self.rag_engine.retrieve(query) elapsed = time.time() - start - + total_time += elapsed total_queries += 1 times.append(elapsed * 1000) # Convert to ms - + avg_time_ms = 
(total_time / total_queries) * 1000 if total_queries > 0 else 0.0 min_time_ms = min(times) if times else 0.0 max_time_ms = max(times) if times else 0.0 - + return { "total_queries": total_queries, "avg_time_ms": avg_time_ms, @@ -267,66 +272,63 @@ def benchmark_retrieval( "max_time_ms": max_time_ms, "queries_per_second": total_queries / total_time if total_time > 0 else 0.0, } - + def benchmark_indexing( - self, - documents: List[Document], - batch_sizes: List[int] = [1, 10, 100] + self, documents: List[Document], batch_sizes: List[int] = [1, 10, 100] ) -> Dict[str, Any]: """ Benchmark indexing performance - + Args: documents: Documents to index batch_sizes: Different batch sizes to test - + Returns: Performance metrics for each batch size """ import time - + results = {} - + for batch_size in batch_sizes: batches = [ - documents[i:i + batch_size] + documents[i : i + batch_size] for i in range(0, len(documents), batch_size) ] - + start = time.time() for batch in batches: self.rag_engine.index_documents(batch) elapsed = time.time() - start - + results[f"batch_size_{batch_size}"] = { "total_documents": len(documents), "num_batches": len(batches), "total_time_seconds": elapsed, - "avg_time_per_doc_ms": (elapsed / len(documents)) * 1000 if documents else 0.0, + "avg_time_per_doc_ms": ( + (elapsed / len(documents)) * 1000 if documents else 0.0 + ), "docs_per_second": len(documents) / elapsed if elapsed > 0 else 0.0, } - + return results def create_evaluation_dataset( - industry: IndustryNiche, - num_documents: int = 100, - num_queries: int = 20 + industry: IndustryNiche, num_documents: int = 100, num_queries: int = 20 ) -> Tuple[List[Document], List[TestCase]]: """ Create a complete evaluation dataset - + Args: industry: Industry for documents num_documents: Number of documents to create num_queries: Number of test queries - + Returns: Tuple of (documents, test_cases) """ documents = RAGTestFixture.create_test_documents(industry, num_documents) test_cases = 
RAGTestFixture.create_test_cases(num_queries) - - return documents, test_cases + return documents, test_cases diff --git a/src/deepiri_modelkit/registry/adapters/__init__.py b/src/deepiri_modelkit/registry/adapters/__init__.py index e949845..44e0e80 100644 --- a/src/deepiri_modelkit/registry/adapters/__init__.py +++ b/src/deepiri_modelkit/registry/adapters/__init__.py @@ -1,2 +1 @@ """Storage adapters for model registry""" - diff --git a/src/deepiri_modelkit/registry/model_registry.py b/src/deepiri_modelkit/registry/model_registry.py index 5c5ffdd..1d2fe3c 100644 --- a/src/deepiri_modelkit/registry/model_registry.py +++ b/src/deepiri_modelkit/registry/model_registry.py @@ -2,6 +2,7 @@ Unified model registry client Supports MLflow, S3/MinIO, and local storage """ + import os from typing import Dict, Any, Optional from pathlib import Path @@ -17,7 +18,7 @@ class ModelRegistryClient: Unified client for model registry operations Supports MLflow, S3/MinIO, and local storage """ - + def __init__( self, registry_type: str = "mlflow", # mlflow, s3, local @@ -26,11 +27,11 @@ def __init__( s3_access_key: Optional[str] = None, s3_secret_key: Optional[str] = None, s3_bucket: Optional[str] = None, - local_path: Optional[str] = None + local_path: Optional[str] = None, ): """ Initialize model registry client - + Args: registry_type: Type of registry (mlflow, s3, local) mlflow_tracking_uri: MLflow tracking URI (defaults to MLFLOW_TRACKING_URI env var or http://mlflow:5000) @@ -41,19 +42,23 @@ def __init__( local_path: Local storage path """ self.registry_type = registry_type - + if registry_type == "mlflow": # Use provided URI, or environment variable, or default - tracking_uri = mlflow_tracking_uri or os.getenv("MLFLOW_TRACKING_URI") or "http://mlflow:5000" + tracking_uri = ( + mlflow_tracking_uri + or os.getenv("MLFLOW_TRACKING_URI") + or "http://mlflow:5000" + ) mlflow.set_tracking_uri(tracking_uri) self.client = mlflow self.tracking_uri = tracking_uri elif registry_type == 
"s3": self.s3_client = boto3.client( - 's3', + "s3", endpoint_url=s3_endpoint, aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key + aws_secret_access_key=s3_secret_key, ) self.s3_bucket = s3_bucket elif registry_type == "local": @@ -61,23 +66,19 @@ def __init__( self.local_path.mkdir(parents=True, exist_ok=True) else: raise ValueError(f"Unknown registry type: {registry_type}") - + def register_model( - self, - model_name: str, - version: str, - model_path: str, - metadata: Dict[str, Any] + self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] ) -> bool: """ Register model in registry - + Args: model_name: Model name version: Model version model_path: Path to model file/directory metadata: Model metadata - + Returns: True if successful """ @@ -87,56 +88,55 @@ def register_model( model_uri = f"runs:/{metadata.get('run_id', 'latest')}/model" mlflow.register_model(model_uri, f"{model_name}-{version}") return True - + elif self.registry_type == "s3": # Upload to S3 s3_key = f"models/{model_name}/{version}/model" self.s3_client.upload_file(model_path, self.s3_bucket, s3_key) - + # Upload metadata import json + metadata_key = f"models/{model_name}/{version}/metadata.json" self.s3_client.put_object( - Bucket=self.s3_bucket, - Key=metadata_key, - Body=json.dumps(metadata) + Bucket=self.s3_bucket, Key=metadata_key, Body=json.dumps(metadata) ) return True - + elif self.registry_type == "local": # Copy to local storage model_dir = self.local_path / model_name / version model_dir.mkdir(parents=True, exist_ok=True) - + import shutil + if os.path.isdir(model_path): shutil.copytree(model_path, model_dir / "model", dirs_exist_ok=True) else: shutil.copy2(model_path, model_dir / "model") - + # Save metadata import json + with open(model_dir / "metadata.json", "w") as f: json.dump(metadata, f) - + return True - + except Exception as e: print(f"Error registering model: {e}") return False - + def get_model( - self, - model_name: str, - 
version: Optional[str] = None + self, model_name: str, version: Optional[str] = None ) -> Dict[str, Any]: """ Get model information from registry - + Args: model_name: Model name version: Model version (optional, gets latest if not specified) - + Returns: Model information dict """ @@ -146,95 +146,91 @@ def get_model( model_uri = f"models:/{model_name}/{version}" else: model_uri = f"models:/{model_name}/latest" - + model = mlflow.pyfunc.load_model(model_uri) - return { - "model": model, - "uri": model_uri, - "type": "mlflow" - } - + return {"model": model, "uri": model_uri, "type": "mlflow"} + elif self.registry_type == "s3": if not version: # List versions and get latest prefix = f"models/{model_name}/" response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, - Prefix=prefix, - Delimiter="/" + Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" ) - versions = [obj["Prefix"].split("/")[-2] for obj in response.get("CommonPrefixes", [])] + versions = [ + obj["Prefix"].split("/")[-2] + for obj in response.get("CommonPrefixes", []) + ] version = max(versions) if versions else None - + if not version: raise ValueError(f"Model {model_name} not found") - + # Download metadata metadata_key = f"models/{model_name}/{version}/metadata.json" - response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=metadata_key) + response = self.s3_client.get_object( + Bucket=self.s3_bucket, Key=metadata_key + ) import json + metadata = json.loads(response["Body"].read()) - + return { "model_path": f"s3://{self.s3_bucket}/models/{model_name}/{version}/model", "metadata": metadata, - "type": "s3" + "type": "s3", } - + elif self.registry_type == "local": if not version: # Get latest version model_dir = self.local_path / model_name if not model_dir.exists(): raise ValueError(f"Model {model_name} not found") - + versions = [d.name for d in model_dir.iterdir() if d.is_dir()] version = max(versions) if versions else None - + if not version: raise ValueError(f"Model {model_name} 
not found") - + model_dir = self.local_path / model_name / version metadata_path = model_dir / "metadata.json" - + import json + with open(metadata_path) as f: metadata = json.load(f) - + return { "model_path": str(model_dir / "model"), "metadata": metadata, - "type": "local" + "type": "local", } - + except Exception as e: print(f"Error getting model: {e}") raise - - def download_model( - self, - model_name: str, - version: str, - destination: str - ) -> str: + + def download_model(self, model_name: str, version: str, destination: str) -> str: """ Download model to destination - + Args: model_name: Model name version: Model version destination: Local destination path - + Returns: Local path to downloaded model """ model_info = self.get_model(model_name, version) - + if self.registry_type == "s3": # Download from S3 s3_key = f"models/{model_name}/{version}/model" os.makedirs(destination, exist_ok=True) - + # Check if it's a file or directory try: self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) @@ -245,92 +241,93 @@ def download_model( except ClientError: # It's a directory, list and download all files prefix = f"{s3_key}/" - paginator = self.s3_client.get_paginator('list_objects_v2') + paginator = self.s3_client.get_paginator("list_objects_v2") for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix): - for obj in page.get('Contents', []): - key = obj['Key'] - local_file = os.path.join(destination, key[len(prefix):]) + for obj in page.get("Contents", []): + key = obj["Key"] + local_file = os.path.join(destination, key[len(prefix) :]) os.makedirs(os.path.dirname(local_file), exist_ok=True) self.s3_client.download_file(self.s3_bucket, key, local_file) return destination - + elif self.registry_type == "local": # Copy from local source = model_info["model_path"] import shutil + if os.path.isdir(source): shutil.copytree(source, destination, dirs_exist_ok=True) else: os.makedirs(os.path.dirname(destination), exist_ok=True) shutil.copy2(source, 
destination) return destination - + else: # MLflow handles loading directly return model_info["uri"] - + def list_models(self, model_name: Optional[str] = None) -> list: """ List available models - + Args: model_name: Optional model name filter - + Returns: List of model information dicts """ try: if self.registry_type == "mlflow": if model_name: - models = mlflow.search_registered_models(filter_string=f"name='{model_name}'") + models = mlflow.search_registered_models( + filter_string=f"name='{model_name}'" + ) else: models = mlflow.search_registered_models() - + return [ { "name": m.name, "versions": [v.version for v in m.latest_versions], - "latest_version": m.latest_versions[0].version if m.latest_versions else None + "latest_version": ( + m.latest_versions[0].version if m.latest_versions else None + ), } for m in models ] - + elif self.registry_type == "s3": prefix = "models/" if model_name: prefix = f"models/{model_name}/" - + response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, - Prefix=prefix, - Delimiter="/" + Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" ) - + models = [] for prefix_obj in response.get("CommonPrefixes", []): model_path = prefix_obj["Prefix"] parts = model_path.strip("/").split("/") if len(parts) >= 2: - models.append({ - "name": parts[1], - "path": model_path - }) - + models.append({"name": parts[1], "path": model_path}) + return models - + elif self.registry_type == "local": models = [] for model_dir in self.local_path.iterdir(): if model_dir.is_dir(): versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - models.append({ - "name": model_dir.name, - "versions": versions, - "latest_version": max(versions) if versions else None - }) + models.append( + { + "name": model_dir.name, + "versions": versions, + "latest_version": max(versions) if versions else None, + } + ) return models - + except Exception as e: print(f"Error listing models: {e}") return [] - diff --git 
a/src/deepiri_modelkit/streaming/event_stream.py b/src/deepiri_modelkit/streaming/event_stream.py index 40e5ed4..904fe0b 100644 --- a/src/deepiri_modelkit/streaming/event_stream.py +++ b/src/deepiri_modelkit/streaming/event_stream.py @@ -1,6 +1,7 @@ """ Redis Streams client for event-driven architecture """ + import redis.asyncio as redis from typing import Dict, Any, Optional, AsyncIterator, Callable import json @@ -15,17 +16,17 @@ class StreamingClient: """ Redis Streams client for publishing and subscribing to events """ - + def __init__( self, redis_url: Optional[str] = None, redis_host: str = "redis", redis_port: int = 6379, - redis_password: Optional[str] = None + redis_password: Optional[str] = None, ): """ Initialize streaming client - + Args: redis_url: Full Redis URL (redis://password@host:port) redis_host: Redis host (if not using redis_url) @@ -39,50 +40,44 @@ def __init__( host=redis_host, port=redis_port, password=redis_password, - decode_responses=True + decode_responses=True, ) self._running = False self._subscriptions = {} - + async def connect(self): """Connect to Redis""" await self.redis.ping() - + async def disconnect(self): """Disconnect from Redis""" await self.redis.close() - + async def publish( - self, - topic: str, - event: Dict[str, Any], - max_length: Optional[int] = 10000 + self, topic: str, event: Dict[str, Any], max_length: Optional[int] = 10000 ) -> str: """ Publish event to stream - + Args: topic: Stream topic name event: Event data (dict) max_length: Max stream length (truncate old messages) - + Returns: Message ID """ # Ensure event has timestamp if "timestamp" not in event: event["timestamp"] = datetime.utcnow().isoformat() - + # Publish to stream message_id = await self.redis.xadd( - topic, - event, - maxlen=max_length, - approximate=True + topic, event, maxlen=max_length, approximate=True ) - + return message_id - + async def subscribe( self, topic: str, @@ -90,11 +85,11 @@ async def subscribe( consumer_group: Optional[str] = 
None, consumer_name: Optional[str] = None, last_id: str = "0", - block_ms: int = 1000 + block_ms: int = 1000, ) -> AsyncIterator[Dict[str, Any]]: """ Subscribe to stream and yield events - + Args: topic: Stream topic name callback: Optional callback function @@ -102,7 +97,7 @@ async def subscribe( consumer_name: Consumer name (unique per consumer) last_id: Last message ID to read from block_ms: Block time in milliseconds - + Yields: Event data (dict) """ @@ -110,17 +105,14 @@ async def subscribe( if consumer_group: try: await self.redis.xgroup_create( - topic, - consumer_group, - id="0", - mkstream=True + topic, consumer_group, id="0", mkstream=True ) except redis.ResponseError as e: if "BUSYGROUP" not in str(e): raise - + self._running = True - + while self._running: try: if consumer_group and consumer_name: @@ -130,51 +122,53 @@ async def subscribe( consumer_name, {topic: ">"}, count=10, - block=block_ms + block=block_ms, ) else: # Direct read messages = await self.redis.xread( - {topic: last_id}, - count=10, - block=block_ms + {topic: last_id}, count=10, block=block_ms ) - + for stream_name, stream_messages in messages: for msg_id, data in stream_messages: # Yield event yield data - + # Call callback if provided if callback: try: - await callback(data) if asyncio.iscoroutinefunction(callback) else callback(data) + ( + await callback(data) + if asyncio.iscoroutinefunction(callback) + else callback(data) + ) except Exception as e: print(f"Callback error: {e}") - + # Update last_id for next read last_id = msg_id - + # Acknowledge if using consumer group if consumer_group and consumer_name: await self.redis.xack(topic, consumer_group, msg_id) - + except asyncio.CancelledError: break except Exception as e: print(f"Stream subscription error: {e}") await asyncio.sleep(1) - + async def subscribe_async( self, topic: str, callback: Callable[[Dict[str, Any]], None], consumer_group: Optional[str] = None, - consumer_name: Optional[str] = None + consumer_name: Optional[str] = 
None, ): """ Subscribe to stream in background task - + Args: topic: Stream topic name callback: Callback function @@ -182,23 +176,19 @@ async def subscribe_async( consumer_name: Consumer name """ async for event in self.subscribe( - topic, - callback, - consumer_group, - consumer_name + topic, callback, consumer_group, consumer_name ): pass # Callback handles events - + def stop(self): """Stop all subscriptions""" self._running = False - + async def get_stream_info(self, topic: str) -> Dict[str, Any]: """Get stream information""" info = await self.redis.xinfo_stream(topic) return dict(info) - + async def get_stream_length(self, topic: str) -> int: """Get number of messages in stream""" return await self.redis.xlen(topic) - diff --git a/src/deepiri_modelkit/streaming/schemas.py b/src/deepiri_modelkit/streaming/schemas.py index 388fa31..830cbae 100644 --- a/src/deepiri_modelkit/streaming/schemas.py +++ b/src/deepiri_modelkit/streaming/schemas.py @@ -1,6 +1,7 @@ """ Streaming event schemas and validation """ + from .topics import StreamTopics from ..contracts.events import ( BaseEvent, @@ -12,7 +13,6 @@ TrainingEvent, ) - # Map topics to event schemas TOPIC_EVENT_SCHEMAS = { StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], @@ -26,31 +26,30 @@ def validate_event(topic: str, event_data: dict) -> BaseEvent: """ Validate event against schema - + Args: topic: Stream topic event_data: Event data dict - + Returns: Validated event object - + Raises: ValueError: If event doesn't match schema """ if topic not in TOPIC_EVENT_SCHEMAS: # Unknown topic, return base event return BaseEvent(**event_data) - + schemas = TOPIC_EVENT_SCHEMAS[topic] event_type = event_data.get("event") - + # Try to match event type to schema for schema in schemas: try: return schema(**event_data) except Exception: continue - + # Fallback to base event return BaseEvent(**event_data) - diff --git a/src/deepiri_modelkit/streaming/sidecar_utils.py 
b/src/deepiri_modelkit/streaming/sidecar_utils.py index 3af5322..1ec9faf 100644 --- a/src/deepiri_modelkit/streaming/sidecar_utils.py +++ b/src/deepiri_modelkit/streaming/sidecar_utils.py @@ -13,7 +13,9 @@ from urllib.parse import urlparse -def env_float(name: str, default: float, logger: Optional[Callable[[str], None]] = None) -> float: +def env_float( + name: str, default: float, logger: Optional[Callable[[str], None]] = None +) -> float: """Read a float env var with safe fallback.""" raw = os.getenv(name) if raw is None: diff --git a/src/deepiri_modelkit/streaming/topics.py b/src/deepiri_modelkit/streaming/topics.py index b609448..0e50984 100644 --- a/src/deepiri_modelkit/streaming/topics.py +++ b/src/deepiri_modelkit/streaming/topics.py @@ -1,27 +1,28 @@ """ Stream topic definitions """ + from enum import Enum -class StreamTopics(str, Enum): - """Redis Stream topics""" - MODEL_EVENTS = "model-events" - INFERENCE_EVENTS = "inference-events" - PLATFORM_EVENTS = "platform-events" - AGI_DECISIONS = "agi-decisions" - TRAINING_EVENTS = "training-events" - # LIS document routing streams (document.* namespace). - DOCUMENT_VECTORIZE = "document.vectorize" - DOCUMENT_TRAINING = "document.training" - DOCUMENT_STRUCTURED = "document.structured" - DOCUMENT_ARTIFACTS = "document.artifacts" - # Cyrex runtime training signals consumed by Helox. - HELOX_TRAINING_RAW = "pipeline.helox-training.raw" - HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" - - @classmethod - def all(cls) -> list: +class StreamTopics(str, Enum): + """Redis Stream topics""" + + MODEL_EVENTS = "model-events" + INFERENCE_EVENTS = "inference-events" + PLATFORM_EVENTS = "platform-events" + AGI_DECISIONS = "agi-decisions" + TRAINING_EVENTS = "training-events" + # LIS document routing streams (document.* namespace). 
+ DOCUMENT_VECTORIZE = "document.vectorize" + DOCUMENT_TRAINING = "document.training" + DOCUMENT_STRUCTURED = "document.structured" + DOCUMENT_ARTIFACTS = "document.artifacts" + # Cyrex runtime training signals consumed by Helox. + HELOX_TRAINING_RAW = "pipeline.helox-training.raw" + HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" + + @classmethod + def all(cls) -> list: """Get all topic names""" return [topic.value for topic in cls] - diff --git a/src/deepiri_modelkit/utils/__init__.py b/src/deepiri_modelkit/utils/__init__.py index a7dfcc9..dbfdac6 100644 --- a/src/deepiri_modelkit/utils/__init__.py +++ b/src/deepiri_modelkit/utils/__init__.py @@ -2,6 +2,7 @@ try: from .device import get_device, get_torch_device + __all__ = ["get_device", "get_torch_device"] except ImportError: __all__ = [] diff --git a/src/deepiri_modelkit/utils/device.py b/src/deepiri_modelkit/utils/device.py index 8b3922a..9c08107 100644 --- a/src/deepiri_modelkit/utils/device.py +++ b/src/deepiri_modelkit/utils/device.py @@ -2,10 +2,12 @@ GPU Device Detection Utility Automatically detects and uses GPU (CUDA) if available, falls back to CPU """ + import os try: import torch + HAS_TORCH = True except ImportError: HAS_TORCH = False @@ -28,7 +30,9 @@ def get_device() -> str: # Log diagnostic information for debugging logger.debug(f"PyTorch version: {torch.__version__}") - logger.debug(f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}") + logger.debug( + f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}" + ) if torch.cuda.is_available(): try: @@ -50,14 +54,17 @@ def get_device() -> str: # Try to get the list of supported compute capabilities # PyTorch 2.9.1 with CUDA 12.6 supports up to sm_90 # RTX 5080/5090 requires sm_120 support (CUDA 12.8+) - test_tensor = torch.tensor([1.0], device='cuda') + test_tensor = torch.tensor([1.0], device="cuda") result = test_tensor * 2.0 _ = result.cpu() del test_tensor, result 
torch.cuda.empty_cache() except RuntimeError as e: - if "no kernel image is available for execution on the device" in str(e) or \ - "cudaErrorNoKernelImageForDevice" in str(e): + if ( + "no kernel image is available for execution on the device" + in str(e) + or "cudaErrorNoKernelImageForDevice" in str(e) + ): logger.error( f"RTX 5080/5090 (sm_{cuda_capability[0]}.{cuda_capability[1]}) detected, but PyTorch doesn't support this compute capability. " f"Current PyTorch supports up to sm_90. " @@ -69,7 +76,7 @@ def get_device() -> str: raise # Test GPU functionality with a simple operation - test_tensor = torch.tensor([1.0], device='cuda') + test_tensor = torch.tensor([1.0], device="cuda") result = test_tensor * 2.0 _ = result.cpu() # Ensure operation completes del test_tensor, result @@ -83,8 +90,10 @@ def get_device() -> str: except RuntimeError as cuda_error: error_msg = str(cuda_error) # Check if this is the RTX 5080 compatibility issue - if "no kernel image is available for execution on the device" in error_msg or \ - "cudaErrorNoKernelImageForDevice" in error_msg: + if ( + "no kernel image is available for execution on the device" in error_msg + or "cudaErrorNoKernelImageForDevice" in error_msg + ): logger.error( f"GPU compute capability not supported by current PyTorch installation. 
" f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'Unknown'}, " @@ -110,7 +119,9 @@ def get_device() -> str: # Check if we're in Docker and might need NVIDIA Container Toolkit if os.path.exists("/.dockerenv"): - logger.debug("Running in Docker container - ensure NVIDIA Container Toolkit is installed") + logger.debug( + "Running in Docker container - ensure NVIDIA Container Toolkit is installed" + ) # Check for NVIDIA runtime if os.path.exists("/proc/driver/nvidia"): logger.warning( @@ -120,16 +131,18 @@ def get_device() -> str: ) # Check MPS (Apple Silicon) - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): try: - test_tensor = torch.tensor([1.0], device='mps') + test_tensor = torch.tensor([1.0], device="mps") result = test_tensor * 2.0 _ = result.cpu() del test_tensor, result logger.info("Apple Silicon (MPS) detected and tested successfully") return "mps" except Exception as mps_error: - logger.warning(f"MPS available but test failed, falling back to CPU: {mps_error}") + logger.warning( + f"MPS available but test failed, falling back to CPU: {mps_error}" + ) # Fallback to CPU logger.info("Using CPU device (no GPU detected or GPU test failed)") @@ -139,5 +152,7 @@ def get_device() -> str: def get_torch_device() -> "torch.device": """Get PyTorch device object""" if not HAS_TORCH: - raise ImportError("torch is required for get_torch_device(). Install with: pip install torch") + raise ImportError( + "torch is required for get_torch_device(). Install with: pip install torch" + ) return torch.device(get_device())