diff --git a/src/deepiri_modelkit/__init__.py b/src/deepiri_modelkit/__init__.py index 4d3f09b..227bf54 100644 --- a/src/deepiri_modelkit/__init__.py +++ b/src/deepiri_modelkit/__init__.py @@ -4,7 +4,13 @@ __version__ = "0.1.0" -from .contracts.models import AIModel, AIModelPydantic, ModelInput, ModelOutput, ModelMetadata +from .contracts.models import ( + AIModel, + AIModelPydantic, + ModelInput, + ModelOutput, + ModelMetadata, +) from .contracts.events import ( ModelReadyEvent, InferenceEvent, @@ -33,4 +39,3 @@ "get_error_logger", "ErrorLogger", ] - diff --git a/src/deepiri_modelkit/contracts/contract.py b/src/deepiri_modelkit/contracts/contract.py index 00abc21..b16aa1a 100644 --- a/src/deepiri_modelkit/contracts/contract.py +++ b/src/deepiri_modelkit/contracts/contract.py @@ -1,6 +1,7 @@ """ Model contract for registry (separated from models.py to avoid Pydantic Protocol conflicts) """ + from __future__ import annotations from typing import Dict, Any, Optional @@ -12,16 +13,18 @@ class ModelContract(BaseModel): """ Complete model contract for registry. - + A contract is serializable metadata that describes a model's interface, input/output schemas, and validation requirements. It does NOT contain the actual model instance (which would be a Protocol type that Pydantic cannot serialize). The model instance should be loaded separately when needed. """ + metadata: ModelMetadata input_schema: Dict[str, Any] output_schema: Dict[str, Any] validation_tests: Optional[list] = None - model_path: Optional[str] = None # Path/reference to where the model can be loaded from + model_path: Optional[str] = ( + None # Path/reference to where the model can be loaded from + ) model_id: Optional[str] = None # Unique identifier for the model instance - diff --git a/src/deepiri_modelkit/contracts/events.py b/src/deepiri_modelkit/contracts/events.py index 9711a93..fefe582 100644 --- a/src/deepiri_modelkit/contracts/events.py +++ b/src/deepiri_modelkit/contracts/events.py @@ -1,6 +1,7 @@ """ Event schemas for streaming service """ + from pydantic import BaseModel, Field from typing import Dict, Any, Optional from datetime import datetime @@ -9,6 +10,7 @@ class EventType(str, Enum): """Event type enumeration""" + MODEL_READY = "model-ready" MODEL_LOADED = "model-loaded" MODEL_FAILED = "model-failed" @@ -26,6 +28,7 @@ class EventType(str, Enum): class BaseEvent(BaseModel): """Base event schema""" + event: str timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) source: str @@ -34,6 +37,7 @@ class BaseEvent(BaseModel): class ModelReadyEvent(BaseEvent): """Event published when model is trained and ready""" + event: str = EventType.MODEL_READY model_name: str version: str @@ -46,6 +50,7 @@ class ModelReadyEvent(BaseEvent): class ModelLoadedEvent(BaseEvent): """Event published when model is loaded in runtime""" + event: str = EventType.MODEL_LOADED model_name: str version: str @@ -55,6 +60,7 @@ class ModelLoadedEvent(BaseEvent): class InferenceEvent(BaseEvent): """Event published after inference completes""" + event: str = EventType.INFERENCE_COMPLETE model_name: str version: str @@ -69,6 +75,7 @@ class InferenceEvent(BaseEvent): class PlatformEvent(BaseEvent): """Event published by platform services""" + event: str # user-interaction, task-created, etc. service: str user_id: Optional[str] = None @@ -79,6 +86,7 @@ class PlatformEvent(BaseEvent): class AGIDecisionEvent(BaseEvent): """Event published by Cyrex-AGI for autonomous decisions""" + event: str = EventType.AGI_DECISION decision_type: str target_service: Optional[str] = None @@ -89,6 +97,7 @@ class AGIDecisionEvent(BaseEvent): class TrainingEvent(BaseEvent): """Event published during training""" + event: str # training-started, training-complete, training-failed experiment_id: str model_name: str @@ -96,4 +105,3 @@ class TrainingEvent(BaseEvent): progress: Optional[float] = None # 0.0 to 1.0 metrics: Optional[Dict[str, Any]] = None error: Optional[str] = None - diff --git a/src/deepiri_modelkit/contracts/models.py b/src/deepiri_modelkit/contracts/models.py index 6e52336..bdc6568 100644 --- a/src/deepiri_modelkit/contracts/models.py +++ b/src/deepiri_modelkit/contracts/models.py @@ -1,7 +1,10 @@ """ Model contracts and interfaces """ -from __future__ import annotations # Defer annotation evaluation to prevent Pydantic from processing Protocol types + +from __future__ import ( + annotations, +) # Defer annotation evaluation to prevent Pydantic from processing Protocol types from typing import Protocol, Dict, Any, Optional, Annotated from pydantic import BaseModel, Field, GetCoreSchemaHandler @@ -11,6 +14,7 @@ class ModelInput(BaseModel): """Standard model input schema""" + data: Dict[str, Any] metadata: Optional[Dict[str, Any]] = None timestamp: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) @@ -18,6 +22,7 @@ class ModelInput(BaseModel): class ModelOutput(BaseModel): """Standard model output schema""" + prediction: Any confidence: Optional[float] = None metadata: Optional[Dict[str, Any]] = None @@ -26,6 +31,7 @@ class ModelOutput(BaseModel): class ModelMetadata(BaseModel): """Model metadata schema""" + name: str version: str description: Optional[str] = None @@ -41,23 +47,23 @@ class AIModel(Protocol): """ Interface that all models must implement Used by both Helox (training) and Cyrex (runtime) - + Note: This is a Protocol (structural type). To use in Pydantic models, use AIModelPydantic instead, which has full Pydantic schema support. """ - + def predict(self, input: ModelInput) -> ModelOutput: """Run inference on input""" ... - + def get_metadata(self) -> ModelMetadata: """Get model metadata""" ... - + def validate(self) -> bool: """Validate model is ready for use""" ... - + def export(self, format: str = "onnx") -> str: """Export model to specified format, returns path""" ... @@ -67,14 +73,14 @@ class AIModelPydantic: """ Pydantic-compatible wrapper for AIModel Protocol. Implements __get_pydantic_core_schema__ to provide full schema support. - + Usage in Pydantic models: model: Optional[AIModelPydantic] = None - + Note: In practice, model instances should be loaded separately and referenced by ID/path rather than stored directly in serializable Pydantic models. """ - + @classmethod def __get_pydantic_core_schema__( cls, @@ -83,61 +89,61 @@ def __get_pydantic_core_schema__( ) -> core_schema.CoreSchema: """ Pydantic Core Schema handler for AIModel Protocol. - + This allows Pydantic to process AIModel types in model fields. Since Protocols are structural types, we validate that the object has the required methods rather than checking exact type. """ + def validate_aimodel(value: Any) -> Any: """Validate that value implements AIModel Protocol interface""" if value is None: return None - + # Check for required Protocol methods - required_methods = ['predict', 'get_metadata', 'validate', 'export'] + required_methods = ["predict", "get_metadata", "validate", "export"] missing_methods = [m for m in required_methods if not hasattr(value, m)] - + if missing_methods: raise ValueError( f"Object does not implement AIModel Protocol. " f"Missing methods: {', '.join(missing_methods)}" ) - + return value - + def serialize_aimodel(value: Any) -> Dict[str, Any]: """Serialize AIModel instance to dict""" if value is None: return None - + # Try to get metadata if available metadata = None - if hasattr(value, 'get_metadata'): + if hasattr(value, "get_metadata"): try: metadata = value.get_metadata() # Convert ModelMetadata to dict if it's a Pydantic model - if hasattr(metadata, 'model_dump'): + if hasattr(metadata, "model_dump"): metadata = metadata.model_dump() - elif hasattr(metadata, 'dict'): + elif hasattr(metadata, "dict"): metadata = metadata.dict() except Exception: pass - + return { "type": "AIModel", "metadata": metadata, - "has_predict": hasattr(value, 'predict'), - "has_validate": hasattr(value, 'validate'), + "has_predict": hasattr(value, "predict"), + "has_validate": hasattr(value, "validate"), } - + return core_schema.no_info_plain_validator_function( validate_aimodel, serialization=core_schema.plain_serializer_function_ser_schema( serialize_aimodel - ) + ), ) # ModelContract moved to contract.py to avoid Pydantic Protocol conflicts # Import it from .contract import ModelContract when needed - diff --git a/src/deepiri_modelkit/contracts/services.py b/src/deepiri_modelkit/contracts/services.py index 9011102..b844a46 100644 --- a/src/deepiri_modelkit/contracts/services.py +++ b/src/deepiri_modelkit/contracts/services.py @@ -1,58 +1,44 @@ """ Service contracts and interfaces """ + from typing import Protocol, Dict, Any, Optional from pydantic import BaseModel class ModelRegistryService(Protocol): """Interface for model registry operations""" - + def register_model( - self, - model_name: str, - version: str, - model_path: str, - metadata: Dict[str, Any] + self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] ) -> bool: """Register a model in the registry""" ... - + def get_model( - self, - model_name: str, - version: Optional[str] = None + self, model_name: str, version: Optional[str] = None ) -> Dict[str, Any]: """Get model information from registry""" ... - + def list_models(self, model_name: Optional[str] = None) -> list: """List available models""" ... - - def download_model( - self, - model_name: str, - version: str, - destination: str - ) -> str: + + def download_model(self, model_name: str, version: str, destination: str) -> str: """Download model to destination, returns local path""" ... class StreamingService(Protocol): """Interface for streaming operations""" - + def publish(self, topic: str, event: Dict[str, Any]) -> bool: """Publish event to topic""" ... - + def subscribe( - self, - topic: str, - callback: callable, - consumer_group: Optional[str] = None + self, topic: str, callback: callable, consumer_group: Optional[str] = None ) -> None: """Subscribe to topic with callback""" ... - diff --git a/src/deepiri_modelkit/data/monitoring.py b/src/deepiri_modelkit/data/monitoring.py index 3fe2d65..81b7a91 100644 --- a/src/deepiri_modelkit/data/monitoring.py +++ b/src/deepiri_modelkit/data/monitoring.py @@ -2,6 +2,7 @@ Dataset Monitoring and Logging Utilities Provides monitoring, alerting, and logging for dataset versioning operations """ + import json import time from pathlib import Path @@ -40,7 +41,7 @@ def __init__(self, log_dir: str = "./logs/dataset_monitoring"): "average_version_creation_time": 0, "validation_errors_today": 0, "last_health_check": None, - "storage_usage_bytes": 0 + "storage_usage_bytes": 0, } self._load_metrics() @@ -59,15 +60,17 @@ def log_version_creation(self, operation_data: Dict[str, Any]): "change_type": operation_data.get("change_type"), "quality_score": operation_data.get("quality_score"), "storage_path": operation_data.get("storage_path"), - "created_by": operation_data.get("created_by") + "created_by": operation_data.get("created_by"), } self._write_log_entry(self.metrics_file, log_entry) self.current_metrics["total_versions_created"] += 1 - logger.info("Version creation logged", - dataset=operation_data.get("dataset_name"), - version=operation_data.get("version")) + logger.info( + "Version creation logged", + dataset=operation_data.get("dataset_name"), + version=operation_data.get("version"), + ) def log_validation_result(self, validation_data: Dict[str, Any]): """Log dataset validation results.""" @@ -80,7 +83,7 @@ def log_validation_result(self, validation_data: Dict[str, Any]): "quality_score": validation_data.get("quality_score"), "error_count": len(validation_data.get("errors", [])), "warning_count": len(validation_data.get("warnings", [])), - "validation_time_seconds": validation_data.get("validation_time", 0) + "validation_time_seconds": validation_data.get("validation_time", 0), } self._write_log_entry(self.metrics_file, log_entry) @@ -90,17 +93,22 @@ def log_validation_result(self, validation_data: Dict[str, Any]): # Check for alerts if validation_data.get("quality_score", 1.0) < 0.7: - self._create_alert("low_quality_score", { - "dataset_name": validation_data.get("dataset_name"), - "version": validation_data.get("version"), - "quality_score": validation_data.get("quality_score"), - "errors": validation_data.get("errors", []) - }) - - logger.info("Validation result logged", - dataset=validation_data.get("dataset_name"), - valid=validation_data.get("is_valid"), - quality=validation_data.get("quality_score")) + self._create_alert( + "low_quality_score", + { + "dataset_name": validation_data.get("dataset_name"), + "version": validation_data.get("version"), + "quality_score": validation_data.get("quality_score"), + "errors": validation_data.get("errors", []), + }, + ) + + logger.info( + "Validation result logged", + dataset=validation_data.get("dataset_name"), + valid=validation_data.get("is_valid"), + quality=validation_data.get("quality_score"), + ) def log_training_usage(self, training_data: Dict[str, Any]): """Log dataset usage in training.""" @@ -113,15 +121,17 @@ def log_training_usage(self, training_data: Dict[str, Any]): "training_duration_seconds": training_data.get("training_duration", 0), "final_loss": training_data.get("final_loss"), "experiment_id": training_data.get("experiment_id"), - "output_model_path": training_data.get("output_model_path") + "output_model_path": training_data.get("output_model_path"), } self._write_log_entry(self.metrics_file, log_entry) - logger.info("Training usage logged", - dataset=training_data.get("dataset_name"), - version=training_data.get("dataset_version"), - model=training_data.get("model_name")) + logger.info( + "Training usage logged", + dataset=training_data.get("dataset_name"), + version=training_data.get("dataset_version"), + model=training_data.get("model_name"), + ) def get_health_report(self) -> Dict[str, Any]: """Generate comprehensive health report.""" @@ -130,13 +140,16 @@ def get_health_report(self) -> Dict[str, Any]: "summary": { "total_versions": self.current_metrics["total_versions_created"], "datasets_tracked": self.current_metrics["total_datasets_tracked"], - "validation_errors_today": self.current_metrics["validation_errors_today"], - "storage_usage_gb": self.current_metrics["storage_usage_bytes"] / (1024**3) + "validation_errors_today": self.current_metrics[ + "validation_errors_today" + ], + "storage_usage_gb": self.current_metrics["storage_usage_bytes"] + / (1024**3), }, "performance": self._analyze_performance(), "quality_trends": self._analyze_quality_trends(), "alerts": self._get_recent_alerts(), - "recommendations": self._generate_recommendations() + "recommendations": self._generate_recommendations(), } self.current_metrics["last_health_check"] = report["timestamp"] @@ -152,12 +165,12 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: "training_runs": [], "validation_runs": [], "popular_datasets": {}, - "quality_distribution": {} + "quality_distribution": {}, } # Read logs and filter by date if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: for line in f: try: entry = json.loads(line.strip()) @@ -167,7 +180,9 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: if entry["operation"] == "version_creation": analytics["version_creations"].append(entry) dataset = entry.get("dataset_name", "unknown") - analytics["popular_datasets"][dataset] = analytics["popular_datasets"].get(dataset, 0) + 1 + analytics["popular_datasets"][dataset] = ( + analytics["popular_datasets"].get(dataset, 0) + 1 + ) elif entry["operation"] == "training_usage": analytics["training_runs"].append(entry) @@ -176,7 +191,12 @@ def get_usage_analytics(self, days: int = 30) -> Dict[str, Any]: analytics["validation_runs"].append(entry) quality = entry.get("quality_score", 0) quality_bucket = f"{int(quality * 10) / 10:.1f}" - analytics["quality_distribution"][quality_bucket] = analytics["quality_distribution"].get(quality_bucket, 0) + 1 + analytics["quality_distribution"][quality_bucket] = ( + analytics["quality_distribution"].get( + quality_bucket, 0 + ) + + 1 + ) except json.JSONDecodeError: continue @@ -189,7 +209,7 @@ def _analyze_performance(self) -> Dict[str, Any]: validation_times = [] if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: for line in f: try: entry = json.loads(line.strip()) @@ -198,16 +218,22 @@ def _analyze_performance(self) -> Dict[str, Any]: creation_times.append(entry["creation_time_seconds"]) elif entry["operation"] == "validation": if "validation_time_seconds" in entry: - validation_times.append(entry["validation_time_seconds"]) + validation_times.append( + entry["validation_time_seconds"] + ) except json.JSONDecodeError: continue return { - "avg_version_creation_time": statistics.mean(creation_times) if creation_times else 0, - "avg_validation_time": statistics.mean(validation_times) if validation_times else 0, + "avg_version_creation_time": ( + statistics.mean(creation_times) if creation_times else 0 + ), + "avg_validation_time": ( + statistics.mean(validation_times) if validation_times else 0 + ), "total_operations": len(creation_times) + len(validation_times), "creation_times": creation_times[-10:], # Last 10 - "validation_times": validation_times[-10:] # Last 10 + "validation_times": validation_times[-10:], # Last 10 } def _analyze_quality_trends(self) -> Dict[str, Any]: @@ -215,11 +241,14 @@ def _analyze_quality_trends(self) -> Dict[str, Any]: quality_scores = [] if self.metrics_file.exists(): - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: for line in f: try: entry = json.loads(line.strip()) - if entry["operation"] == "validation" and "quality_score" in entry: + if ( + entry["operation"] == "validation" + and "quality_score" in entry + ): quality_scores.append(entry["quality_score"]) except json.JSONDecodeError: continue @@ -232,8 +261,8 @@ def _analyze_quality_trends(self) -> Dict[str, Any]: # Simple trend analysis if len(recent_scores) >= 10: - first_half = recent_scores[:len(recent_scores)//2] - second_half = recent_scores[len(recent_scores)//2:] + first_half = recent_scores[: len(recent_scores) // 2] + second_half = recent_scores[len(recent_scores) // 2 :] first_avg = statistics.mean(first_half) second_avg = statistics.mean(second_half) @@ -255,8 +284,8 @@ def _analyze_quality_trends(self) -> Dict[str, Any]: "excellent": len([s for s in quality_scores if s >= 0.9]), "good": len([s for s in quality_scores if 0.7 <= s < 0.9]), "fair": len([s for s in quality_scores if 0.5 <= s < 0.7]), - "poor": len([s for s in quality_scores if s < 0.5]) - } + "poor": len([s for s in quality_scores if s < 0.5]), + }, } def _generate_recommendations(self) -> List[str]: @@ -265,24 +294,34 @@ def _generate_recommendations(self) -> List[str]: # Check for frequent validation errors if self.current_metrics["validation_errors_today"] > 5: - recommendations.append("High validation error rate detected. Review data quality processes.") + recommendations.append( + "High validation error rate detected. Review data quality processes." + ) # Check quality trends quality_analysis = self._analyze_quality_trends() if quality_analysis.get("trend") == "declining": - recommendations.append("Dataset quality is declining. Consider reviewing data sources and annotation processes.") + recommendations.append( + "Dataset quality is declining. Consider reviewing data sources and annotation processes." + ) # Check performance performance = self._analyze_performance() if performance.get("avg_version_creation_time", 0) > 300: # 5 minutes - recommendations.append("Version creation is slow. Consider optimizing storage or processing.") + recommendations.append( + "Version creation is slow. Consider optimizing storage or processing." + ) # General recommendations if self.current_metrics["total_versions_created"] == 0: - recommendations.append("No dataset versions created yet. Start versioning your datasets for reproducibility.") + recommendations.append( + "No dataset versions created yet. Start versioning your datasets for reproducibility." + ) if not recommendations: - recommendations.append("System operating normally. Continue regular monitoring.") + recommendations.append( + "System operating normally. Continue regular monitoring." + ) return recommendations @@ -293,14 +332,12 @@ def _create_alert(self, alert_type: str, alert_data: Dict[str, Any]): "alert_type": alert_type, "severity": "warning", # Could be "info", "warning", "error" "data": alert_data, - "resolved": False + "resolved": False, } self._write_log_entry(self.alerts_file, alert_entry) - logger.warning("Alert created", - type=alert_type, - data=alert_data) + logger.warning("Alert created", type=alert_type, data=alert_data) def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: """Get recent alerts.""" @@ -308,7 +345,7 @@ def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: cutoff_time = datetime.utcnow() - timedelta(hours=hours) if self.alerts_file.exists(): - with open(self.alerts_file, 'r') as f: + with open(self.alerts_file, "r") as f: for line in f: try: alert = json.loads(line.strip()) @@ -322,22 +359,33 @@ def _get_recent_alerts(self, hours: int = 24) -> List[Dict[str, Any]]: def _write_log_entry(self, log_file: Path, entry: Dict[str, Any]): """Write a log entry to file.""" - with open(log_file, 'a') as f: - f.write(json.dumps(entry) + '\n') + with open(log_file, "a") as f: + f.write(json.dumps(entry) + "\n") def _load_metrics(self): """Load current metrics from log files.""" if self.metrics_file.exists(): try: - with open(self.metrics_file, 'r') as f: + with open(self.metrics_file, "r") as f: lines = f.readlines() if lines: # Count operations from logs - version_count = sum(1 for line in lines if '"operation": "version_creation"' in line) - validation_count = sum(1 for line in lines if '"operation": "validation"' in line and '"is_valid": false' in line) + version_count = sum( + 1 + for line in lines + if '"operation": "version_creation"' in line + ) + validation_count = sum( + 1 + for line in lines + if '"operation": "validation"' in line + and '"is_valid": false' in line + ) self.current_metrics["total_versions_created"] = version_count - self.current_metrics["validation_errors_today"] = validation_count + self.current_metrics["validation_errors_today"] = ( + validation_count + ) except Exception as e: logger.warning("Failed to load metrics from log", error=str(e)) @@ -346,21 +394,17 @@ def _load_metrics(self): def log_version_creation(dataset_name: str, version: str, **kwargs): """Convenience function to log version creation.""" monitor = DatasetMonitor() - monitor.log_version_creation({ - "dataset_name": dataset_name, - "version": version, - **kwargs - }) + monitor.log_version_creation( + {"dataset_name": dataset_name, "version": version, **kwargs} + ) def log_validation_result(dataset_name: str, version: str, **kwargs): """Convenience function to log validation results.""" monitor = DatasetMonitor() - monitor.log_validation_result({ - "dataset_name": dataset_name, - "version": version, - **kwargs - }) + monitor.log_validation_result( + {"dataset_name": dataset_name, "version": version, **kwargs} + ) def get_health_report(): diff --git a/src/deepiri_modelkit/data/validation.py b/src/deepiri_modelkit/data/validation.py index 25e7b2e..60ef3c7 100644 --- a/src/deepiri_modelkit/data/validation.py +++ b/src/deepiri_modelkit/data/validation.py @@ -2,6 +2,7 @@ Dataset Validation Utilities Provides validation and quality checks for language intelligence datasets """ + import json from pathlib import Path from typing import Dict, List, Any, Optional @@ -35,29 +36,42 @@ def _get_validation_rules(self) -> Dict[str, Any]: "min_text_length": 10, "max_text_length": 10000, "required_fields": ["text"], - "text_quality_checks": True + "text_quality_checks": True, } type_specific_rules = { "lease_abstraction": { "min_samples": 50, "lease_keywords": [ - "lease", "agreement", "landlord", "tenant", "rent", - "premises", "term", "commencement", "expiration" + "lease", + "agreement", + "landlord", + "tenant", + "rent", + "premises", + "term", + "commencement", + "expiration", ], "min_keyword_matches": 2, "check_address_patterns": True, - "check_rent_patterns": True + "check_rent_patterns": True, }, "contract_intelligence": { "min_samples": 50, "contract_keywords": [ - "contract", "agreement", "party", "obligation", - "clause", "provision", "section", "article" + "contract", + "agreement", + "party", + "obligation", + "clause", + "provision", + "section", + "article", ], "min_keyword_matches": 2, - "check_legal_patterns": True - } + "check_legal_patterns": True, + }, } if self.dataset_type in type_specific_rules: @@ -75,14 +89,16 @@ def validate_dataset(self, data_path: Path) -> Dict[str, Any]: Returns: Validation results dictionary """ - logger.info("Starting dataset validation", path=str(data_path), type=self.dataset_type) + logger.info( + "Starting dataset validation", path=str(data_path), type=self.dataset_type + ) results = { "is_valid": True, "errors": [], "warnings": [], "statistics": {}, - "quality_score": 0.0 + "quality_score": 0.0, } try: @@ -117,11 +133,13 @@ def validate_dataset(self, data_path: Path) -> Dict[str, Any]: results["errors"].append(f"Validation failed with error: {str(e)}") logger.error("Dataset validation failed", error=str(e)) - logger.info("Dataset validation complete", - valid=results["is_valid"], - quality_score=results["quality_score"], - errors=len(results["errors"]), - warnings=len(results["warnings"])) + logger.info( + "Dataset validation complete", + valid=results["is_valid"], + quality_score=results["quality_score"], + errors=len(results["errors"]), + warnings=len(results["warnings"]), + ) return results @@ -130,7 +148,7 @@ def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: samples = [] if data_path.is_file() and data_path.suffix == ".jsonl": - with open(data_path, 'r', encoding='utf-8') as f: + with open(data_path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: @@ -142,7 +160,7 @@ def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: elif data_path.is_dir(): for file_path in data_path.glob("*.jsonl"): - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: for line_num, line in enumerate(f, 1): line = line.strip() if line: @@ -150,7 +168,9 @@ def _load_samples(self, data_path: Path) -> List[Dict[str, Any]]: sample = json.loads(line) samples.append(sample) except json.JSONDecodeError as e: - logger.warning(f"Invalid JSON in {file_path} at line {line_num}: {e}") + logger.warning( + f"Invalid JSON in {file_path} at line {line_num}: {e}" + ) return samples @@ -171,7 +191,9 @@ def _validate_structure(self, samples: List[Dict], results: Dict): for i, sample in enumerate(samples[:100]): # Check first 100 samples for field in required_fields: if field not in sample: - results["errors"].append(f"Missing required field '{field}' in sample {i}") + results["errors"].append( + f"Missing required field '{field}' in sample {i}" + ) def _validate_content_quality(self, samples: List[Dict], results: Dict): """Validate content quality.""" @@ -202,19 +224,25 @@ def _validate_content_quality(self, samples: List[Dict], results: Dict): seen_texts.add(text) # Statistics - results["statistics"].update({ - "avg_text_length": sum(text_lengths) / len(text_lengths) if text_lengths else 0, - "min_text_length": min(text_lengths) if text_lengths else 0, - "max_text_length": max(text_lengths) if text_lengths else 0, - "empty_texts": empty_texts, - "duplicate_texts": len(duplicate_texts) - }) + results["statistics"].update( + { + "avg_text_length": ( + sum(text_lengths) / len(text_lengths) if text_lengths else 0 + ), + "min_text_length": min(text_lengths) if text_lengths else 0, + "max_text_length": max(text_lengths) if text_lengths else 0, + "empty_texts": empty_texts, + "duplicate_texts": len(duplicate_texts), + } + ) if empty_texts > 0: results["errors"].append(f"Found {empty_texts} empty text samples") if len(duplicate_texts) > len(samples) * 0.01: # >1% duplicates - results["warnings"].append(f"High duplicate rate: {len(duplicate_texts)} duplicates") + results["warnings"].append( + f"High duplicate rate: {len(duplicate_texts)} duplicates" + ) def _validate_type_specific(self, samples: List[Dict], results: Dict): """Type-specific validation.""" @@ -233,10 +261,10 @@ def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): rent_pattern_matches = 0 # Address patterns (street numbers, street names, cities) - address_pattern = r'\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}' + address_pattern = r"\d+\s+[A-Za-z0-9\s,.-]+(?:Street|St|Avenue|Ave|Road|Rd|Boulevard|Blvd|Drive|Dr|Lane|Ln|Way|Place|Pl|Court|Ct)\s*,?\s*[A-Za-z\s]+,?\s*\d{5}" # Rent patterns (dollar amounts) - rent_pattern = r'\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?' + rent_pattern = r"\$\d{1,3}(?:,\d{3})*(?:\.\d{2})?" for sample in samples[:500]: # Check first 500 samples for performance text = sample.get("text", "").lower() @@ -261,11 +289,13 @@ def _validate_lease_abstraction(self, samples: List[Dict], results: Dict): f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack lease keywords" ) - results["statistics"].update({ - "address_pattern_matches": address_pattern_matches, - "rent_pattern_matches": rent_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate - }) + results["statistics"].update( + { + "address_pattern_matches": address_pattern_matches, + "rent_pattern_matches": rent_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate, + } + ) def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): """Validate contract intelligence dataset.""" @@ -277,11 +307,11 @@ def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): # Legal clause patterns legal_patterns = [ - r'\bsection\s+\d+', - r'\barticle\s+\d+', - r'\bclause\s+\d+', - r'\bparagraph\s+\d+', - r'\bsubsection\s+\d+' + r"\bsection\s+\d+", + r"\barticle\s+\d+", + r"\bclause\s+\d+", + r"\bparagraph\s+\d+", + r"\bsubsection\s+\d+", ] for sample in samples[:500]: # Check first 500 samples @@ -293,7 +323,9 @@ def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): low_keyword_samples += 1 # Legal pattern matching - if any(re.search(pattern, sample.get("text", "")) for pattern in legal_patterns): + if any( + re.search(pattern, sample.get("text", "")) for pattern in legal_patterns + ): legal_pattern_matches += 1 total_checked = min(500, len(samples)) @@ -304,10 +336,12 @@ def _validate_contract_intelligence(self, samples: List[Dict], results: Dict): f"Low keyword relevance: {keyword_failure_rate:.1%} samples lack contract keywords" ) - results["statistics"].update({ - "legal_pattern_matches": legal_pattern_matches, - "keyword_relevance_score": 1.0 - keyword_failure_rate - }) + results["statistics"].update( + { + "legal_pattern_matches": legal_pattern_matches, + "keyword_relevance_score": 1.0 - keyword_failure_rate, + } + ) def _calculate_quality_score(self, results: Dict) -> float: """Calculate overall quality score (0.0 to 1.0).""" @@ -349,7 +383,9 @@ def _calculate_quality_score(self, results: Dict) -> float: return max(0.0, min(1.0, score)) -def validate_dataset_quality(data_path: Path, dataset_type: str = "general") -> Dict[str, Any]: +def validate_dataset_quality( + data_path: Path, dataset_type: str = "general" +) -> Dict[str, Any]: """ Convenience function to validate dataset quality. diff --git a/src/deepiri_modelkit/logging.py b/src/deepiri_modelkit/logging.py index 9b65998..6530da0 100644 --- a/src/deepiri_modelkit/logging.py +++ b/src/deepiri_modelkit/logging.py @@ -2,6 +2,7 @@ Shared logging utilities for all Deepiri services Used by: Cyrex (runtime), Helox (training), and all microservices """ + import logging import sys import json @@ -12,47 +13,47 @@ class StructuredLogger: """JSON structured logger for all Deepiri services""" - + def __init__(self, name: str, level: int = logging.INFO): self.logger = logging.getLogger(name) self.logger.setLevel(level) - + # Remove existing handlers self.logger.handlers = [] - + # Create console handler with JSON formatting handler = logging.StreamHandler(sys.stdout) handler.setLevel(level) handler.setFormatter(JsonFormatter()) self.logger.addHandler(handler) - + self.logger.propagate = False - + def _log(self, level: int, event: str, **kwargs): """Internal log method with structured data""" extra = {"event": event, "timestamp": datetime.utcnow().isoformat() + "Z"} extra.update(kwargs) self.logger.log(level, json.dumps(extra)) - + def debug(self, event: str, **kwargs): self._log(logging.DEBUG, event, **kwargs) - + def info(self, event: str, **kwargs): self._log(logging.INFO, event, **kwargs) - + def warning(self, event: str, **kwargs): self._log(logging.WARNING, event, **kwargs) - + def error(self, event: str, **kwargs): self._log(logging.ERROR, event, **kwargs) - + def critical(self, event: str, **kwargs): self._log(logging.CRITICAL, event, **kwargs) class JsonFormatter(logging.Formatter): """Format logs as JSON""" - + def format(self, record: logging.LogRecord) -> str: log_data = { "timestamp": datetime.utcnow().isoformat() + "Z", @@ -60,27 +61,47 @@ def format(self, record: logging.LogRecord) -> str: "logger": record.name, "message": record.getMessage(), } - + # Add extra fields if present - if hasattr(record, 'event'): - log_data['event'] = record.event - + if hasattr(record, "event"): + log_data["event"] = record.event + # Add any custom fields from extra={} for key, value in record.__dict__.items(): - if key not in ['name', 'msg', 'args', 'created', 'filename', 'funcName', - 'levelname', 'levelno', 'lineno', 'module', 'msecs', - 'message', 'pathname', 'process', 'processName', - 'relativeCreated', 'thread', 'threadName', 'exc_info', - 'exc_text', 'stack_info', 'event', 'timestamp']: + if key not in [ + "name", + "msg", + "args", + "created", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "message", + "pathname", + "process", + "processName", + "relativeCreated", + "thread", + "threadName", + "exc_info", + "exc_text", + "stack_info", + "event", + "timestamp", + ]: log_data[key] = value - + return json.dumps(log_data) def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: """ Get structured logger instance - + Usage: from deepiri_modelkit.logging import get_logger logger = get_logger("my_service") @@ -91,10 +112,10 @@ def get_logger(name: str, level: int = logging.INFO) -> StructuredLogger: class ErrorLogger: """Error logging with context""" - + def __init__(self): self.logger = get_logger("error_logger") - + def log_api_error(self, error: Exception, request_id: str, endpoint: str): """Log API errors with context""" self.logger.error( @@ -102,27 +123,31 @@ def log_api_error(self, error: Exception, request_id: str, endpoint: str): error=str(error), error_type=type(error).__name__, request_id=request_id, - endpoint=endpoint + endpoint=endpoint, ) - - def log_model_error(self, error: Exception, model_name: str, input_data: Optional[Dict] = None): + + def log_model_error( + self, error: Exception, model_name: str, input_data: Optional[Dict] = None + ): """Log model inference errors""" self.logger.error( "model_error", error=str(error), error_type=type(error).__name__, model_name=model_name, - input_sample=str(input_data)[:200] if input_data else None + input_sample=str(input_data)[:200] if input_data else None, ) - - def log_training_error(self, error: Exception, pipeline: str, config: Optional[Dict] = None): + + def log_training_error( + self, error: Exception, pipeline: str, config: Optional[Dict] = None + ): """Log training pipeline errors""" self.logger.error( "training_error", error=str(error), error_type=type(error).__name__, pipeline=pipeline, - config=config + config=config, ) @@ -144,4 +169,3 @@ def get_error_logger() -> ErrorLogger: if _error_logger is None: _error_logger = ErrorLogger() return _error_logger - diff --git a/src/deepiri_modelkit/ml/__init__.py b/src/deepiri_modelkit/ml/__init__.py index 09c30ff..1a41dcb 100644 --- a/src/deepiri_modelkit/ml/__init__.py +++ b/src/deepiri_modelkit/ml/__init__.py @@ -8,12 +8,14 @@ ConfidenceCalculator, get_confidence_calculator, ) + _HAS_CONFIDENCE = True except ImportError: _HAS_CONFIDENCE = False try: from .semantic import SemanticAnalyzer, get_semantic_analyzer + _HAS_SEMANTIC = True except ImportError: _HAS_SEMANTIC = False diff --git a/src/deepiri_modelkit/ml/confidence.py b/src/deepiri_modelkit/ml/confidence.py index 89d5f2b..ca23d11 100644 --- a/src/deepiri_modelkit/ml/confidence.py +++ b/src/deepiri_modelkit/ml/confidence.py @@ -2,12 +2,14 @@ Redesigned Confidence Classes System Provides structured confidence assessment with multiple attributes """ + from enum import Enum from typing import Dict, List, Optional, Tuple from dataclasses import dataclass try: import numpy as np + HAS_NUMPY = True except ImportError: HAS_NUMPY = False @@ -15,6 +17,7 @@ class ConfidenceLevel(str, Enum): """Confidence level categories""" + VERY_HIGH = "very_high" # 0.9-1.0 HIGH = "high" # 0.75-0.9 MEDIUM = "medium" # 0.5-0.75 @@ -24,6 +27,7 @@ class ConfidenceLevel(str, Enum): class ConfidenceSource(str, Enum): """Sources of confidence information""" + MODEL_PREDICTION = "model_prediction" TRAINING_DATA_COVERAGE = "training_data_coverage" FEATURE_QUALITY = "feature_quality" @@ -46,6 +50,7 @@ class ConfidenceAttributes: reliability: Overall reliability score explanation: Human-readable explanation """ + base_score: float level: ConfidenceLevel sources: Dict[str, float] @@ -63,7 +68,7 @@ def to_dict(self) -> Dict: "uncertainty": self.uncertainty, "calibration": self.calibration, "reliability": self.reliability, - "explanation": self.explanation + "explanation": self.explanation, } @@ -78,7 +83,7 @@ def __init__(self): ConfidenceLevel.HIGH: 0.75, ConfidenceLevel.MEDIUM: 0.5, ConfidenceLevel.LOW: 0.25, - ConfidenceLevel.VERY_LOW: 0.0 + ConfidenceLevel.VERY_LOW: 0.0, } def calculate_confidence( @@ -88,7 +93,7 @@ def calculate_confidence( training_coverage: Optional[float] = None, feature_quality: Optional[float] = None, context_match: Optional[float] = None, - historical_accuracy: Optional[Dict[int, float]] = None + historical_accuracy: Optional[Dict[int, float]] = None, ) -> ConfidenceAttributes: """ Calculate comprehensive confidence attributes @@ -105,7 +110,9 @@ def calculate_confidence( ConfidenceAttributes object """ if not HAS_NUMPY: - raise ImportError("numpy is required for ConfidenceCalculator. Install with: pip install numpy") + raise ImportError( + "numpy is required for ConfidenceCalculator. Install with: pip install numpy" + ) # Base score: maximum probability base_score = float(np.max(model_probabilities)) @@ -117,7 +124,9 @@ def calculate_confidence( # Calibration: difference between top-2 probabilities (margin) sorted_probs = np.sort(model_probabilities)[::-1] - margin = float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 + margin = ( + float(sorted_probs[0] - sorted_probs[1]) if len(sorted_probs) > 1 else 1.0 + ) calibration = float(margin) # Higher margin = better calibration # Source contributions @@ -130,7 +139,9 @@ def calculate_confidence( if training_coverage is not None: sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = training_coverage else: - sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = 0.7 # Default moderate + sources[ConfidenceSource.TRAINING_DATA_COVERAGE.value] = ( + 0.7 # Default moderate + ) # Feature quality if feature_quality is not None: @@ -150,12 +161,16 @@ def calculate_confidence( hist_acc = historical_accuracy.get(predicted_class, 0.7) sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = hist_acc else: - sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = 0.7 # Default moderate + sources[ConfidenceSource.HISTORICAL_ACCURACY.value] = ( + 0.7 # Default moderate + ) # Ensemble agreement (if top_k_probs provided) if top_k_probs: agreement = float(np.std(top_k_probs)) # Lower std = higher agreement - sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min(agreement, 1.0) + sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 1.0 - min( + agreement, 1.0 + ) else: sources[ConfidenceSource.ENSEMBLE_AGREEMENT.value] = 0.7 # Default moderate @@ -166,16 +181,17 @@ def calculate_confidence( ConfidenceSource.FEATURE_QUALITY.value: 0.15, ConfidenceSource.CONTEXT_MATCH.value: 0.1, ConfidenceSource.HISTORICAL_ACCURACY.value: 0.1, - ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1 + ConfidenceSource.ENSEMBLE_AGREEMENT.value: 0.1, } reliability = sum( - sources[source] * weights.get(source, 0.0) - for source in sources + sources[source] * weights.get(source, 0.0) for source in sources ) # Adjust reliability based on uncertainty and calibration - reliability = reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) + reliability = ( + reliability * (1.0 - uncertainty * 0.3) * (0.7 + calibration * 0.3) + ) reliability = max(0.0, min(1.0, reliability)) # Determine confidence level @@ -193,7 +209,7 @@ def calculate_confidence( uncertainty=uncertainty, calibration=calibration, reliability=reliability, - explanation=explanation + explanation=explanation, ) def _get_confidence_level(self, reliability: float) -> ConfidenceLevel: @@ -215,13 +231,15 @@ def _generate_explanation( level: ConfidenceLevel, sources: Dict[str, float], uncertainty: float, - calibration: float + calibration: float, ) -> str: """Generate human-readable explanation""" parts = [] # Main confidence statement - parts.append(f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})") + parts.append( + f"Confidence: {level.value.replace('_', ' ').title()} ({reliability:.2%})" + ) # Key factors key_factors = [] @@ -255,7 +273,7 @@ def should_accept_prediction( self, confidence: ConfidenceAttributes, min_reliability: float = 0.7, - min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM + min_level: ConfidenceLevel = ConfidenceLevel.MEDIUM, ) -> Tuple[bool, str]: """ Determine if prediction should be accepted based on confidence @@ -268,14 +286,20 @@ def should_accept_prediction( ConfidenceLevel.LOW: 1, ConfidenceLevel.MEDIUM: 2, ConfidenceLevel.HIGH: 3, - ConfidenceLevel.VERY_HIGH: 4 + ConfidenceLevel.VERY_HIGH: 4, } if confidence.reliability < min_reliability: - return False, f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}" + return ( + False, + f"Reliability {confidence.reliability:.2%} below threshold {min_reliability:.2%}", + ) if level_order[confidence.level] < level_order[min_level]: - return False, f"Confidence level {confidence.level.value} below required {min_level.value}" + return ( + False, + f"Confidence level {confidence.level.value} below required {min_level.value}", + ) return True, "Confidence meets requirements" diff --git a/src/deepiri_modelkit/ml/semantic.py b/src/deepiri_modelkit/ml/semantic.py index 141c6b1..366af91 100644 --- a/src/deepiri_modelkit/ml/semantic.py +++ b/src/deepiri_modelkit/ml/semantic.py @@ -3,6 +3,7 @@ Inspired by Carnegie Mellon University (CMU) Language Technologies Institute approaches Uses semantic similarity and contextual understanding for dynamic variation generation """ + import json import re from typing import List, Dict, Optional, Set @@ -16,18 +17,21 @@ # Try multiple HTTP clients try: import httpx + HAS_HTTPX = True except ImportError: HAS_HTTPX = False try: import requests + HAS_REQUESTS = True except ImportError: HAS_REQUESTS = False try: import ollama + HAS_OLLAMA_PKG = True except ImportError: HAS_OLLAMA_PKG = False @@ -39,12 +43,16 @@ class SemanticAnalyzer: Inspired by CMU's semantic analysis approaches """ - def __init__(self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b"): + def __init__( + self, ollama_base_url: str = "http://localhost:11434", model: str = "llama3:8b" + ): self.ollama_base_url = ollama_base_url self.model = model self._cache = {} # Cache for semantic analysis results - def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # Reduced from 30s to 15s + def _call_ollama( + self, prompt: str, timeout: float = 15.0 + ) -> Optional[str]: # Reduced from 30s to 15s """Call Ollama API directly via HTTP or Python package""" # Try ollama Python package first (cleaner API) if HAS_OLLAMA_PKG: @@ -55,8 +63,8 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # options={ "temperature": 0.7, "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } + "num_predict": 100, # Reduced from 200 for faster responses + }, ) return response.get("response", "").strip() except Exception: @@ -76,10 +84,10 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # "options": { "temperature": 0.7, "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } + "num_predict": 100, # Reduced from 200 for faster responses + }, }, - timeout=timeout + timeout=timeout, ) if response.status_code == 200: @@ -87,7 +95,9 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # logger.debug("Ollama HTTP call succeeded") return result.get("response", "").strip() else: - logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") + logger.debug( + f"Ollama HTTP call failed: HTTP {response.status_code}" + ) elif HAS_REQUESTS: logger.debug(f"Calling Ollama HTTP with {len(prompt)} char prompt") response = requests.post( @@ -99,10 +109,10 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # "options": { "temperature": 0.7, "top_p": 0.9, - "num_predict": 100 # Reduced from 200 for faster responses - } + "num_predict": 100, # Reduced from 200 for faster responses + }, }, - timeout=timeout + timeout=timeout, ) if response.status_code == 200: @@ -110,7 +120,9 @@ def _call_ollama(self, prompt: str, timeout: float = 15.0) -> Optional[str]: # logger.debug("Ollama HTTP call succeeded") return result.get("response", "").strip() else: - logger.debug(f"Ollama HTTP call failed: HTTP {response.status_code}") + logger.debug( + f"Ollama HTTP call failed: HTTP {response.status_code}" + ) except Exception as e: logger.debug(f"Ollama HTTP call failed: {e}") @@ -137,7 +149,7 @@ def extract_semantic_verbs(self, text: str, category: str) -> List[str]: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) + json_match = re.search(r"\[.*?\]", response, re.DOTALL) if json_match: verbs = json.loads(json_match.group()) if isinstance(verbs, list) and len(verbs) > 0: @@ -171,7 +183,7 @@ def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) + json_match = re.search(r"\[.*?\]", response, re.DOTALL) if json_match: prefixes = json.loads(json_match.group()) if isinstance(prefixes, list) and len(prefixes) > 0: @@ -182,8 +194,14 @@ def generate_semantic_prefixes(self, text: str, category: str) -> List[str]: # Fallback: return default prefixes return [ - "I need to", "Can you help me", "Please", "I want to", - "Help me", "I should", "Let me", "I'm going to" + "I need to", + "Can you help me", + "Please", + "I want to", + "Help me", + "I should", + "Let me", + "I'm going to", ] def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: @@ -207,7 +225,7 @@ def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\[.*?\]', response, re.DOTALL) + json_match = re.search(r"\[.*?\]", response, re.DOTALL) if json_match: suffixes = json.loads(json_match.group()) if isinstance(suffixes, list) and len(suffixes) > 0: @@ -218,11 +236,18 @@ def generate_semantic_suffixes(self, text: str, category: str) -> List[str]: # Fallback: return default suffixes return [ - "", " today", " this week", " as soon as possible", - " when you have time", " - urgent", " - important" + "", + " today", + " this week", + " as soon as possible", + " when you have time", + " - urgent", + " - important", ] - def generate_paraphrases(self, text: str, category: str, num_paraphrases: int = 3) -> List[str]: + def generate_paraphrases( + self, text: str, category: str, num_paraphrases: int = 3 + ) -> List[str]: """ Generate semantic paraphrases using Ollama Inspired by CMU's paraphrase generation approaches @@ -243,12 +268,12 @@ def generate_paraphrases(self, text: str, category: str, num_paraphrases: int = response = self._call_ollama(prompt) if response: paraphrases = [] - for line in response.strip().split('\n'): + for line in response.strip().split("\n"): line = line.strip() # Remove common prefixes - for prefix in ['- ', '1. ', '2. ', '3. ', '4. ', '5. ', '* ', '• ']: + for prefix in ["- ", "1. ", "2. ", "3. ", "4. ", "5. ", "* ", "• "]: if line.startswith(prefix): - line = line[len(prefix):].strip() + line = line[len(prefix) :].strip() if line and line != text and len(line) > 10: paraphrases.append(line) @@ -278,7 +303,7 @@ def analyze_semantic_structure(self, text: str) -> Dict: response = self._call_ollama(prompt) if response: try: - json_match = re.search(r'\{.*?\}', response, re.DOTALL) + json_match = re.search(r"\{.*?\}", response, re.DOTALL) if json_match: return json.loads(json_match.group()) except Exception: @@ -291,7 +316,7 @@ def analyze_semantic_structure(self, text: str) -> Dict: "object": " ".join(words[1:]) if len(words) > 1 else "", "modifiers": [], "temporal": None, - "urgency": None + "urgency": None, } def check_ollama_available(self) -> bool: @@ -307,16 +332,10 @@ def check_ollama_available(self) -> bool: # Fall back to HTTP check try: if HAS_HTTPX: - response = httpx.get( - f"{self.ollama_base_url}/api/tags", - timeout=5.0 - ) + response = httpx.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) return response.status_code == 200 elif HAS_REQUESTS: - response = requests.get( - f"{self.ollama_base_url}/api/tags", - timeout=5.0 - ) + response = requests.get(f"{self.ollama_base_url}/api/tags", timeout=5.0) return response.status_code == 200 except Exception: pass @@ -325,8 +344,7 @@ def check_ollama_available(self) -> bool: def get_semantic_analyzer( - ollama_base_url: Optional[str] = None, - model: Optional[str] = None + ollama_base_url: Optional[str] = None, model: Optional[str] = None ) -> Optional[SemanticAnalyzer]: """ Factory function to get semantic analyzer diff --git a/src/deepiri_modelkit/rag/__init__.py b/src/deepiri_modelkit/rag/__init__.py index 0b3bbcc..746f60f 100644 --- a/src/deepiri_modelkit/rag/__init__.py +++ b/src/deepiri_modelkit/rag/__init__.py @@ -37,6 +37,7 @@ MultiQueryRetriever, QueryCache, ) + HAS_ADVANCED_RETRIEVAL = True except ImportError: HAS_ADVANCED_RETRIEVAL = False @@ -53,6 +54,7 @@ EmbeddingCache, QueryResultCache, ) + HAS_CACHING = True except ImportError: HAS_CACHING = False @@ -68,6 +70,7 @@ SystemMetrics, PerformanceTimer, ) + HAS_MONITORING = True except ImportError: HAS_MONITORING = False @@ -85,6 +88,7 @@ BatchProcessingConfig, BatchProcessingResult, ) + HAS_ASYNC = True except ImportError: HAS_ASYNC = False @@ -119,37 +123,44 @@ # Conditionally add advanced features if HAS_ADVANCED_RETRIEVAL: - __all__.extend([ - "AdvancedRetrievalPipeline", - "QueryExpander", - "SynonymQueryExpander", - "RephraseQueryExpander", - "MultiQueryRetriever", - "QueryCache", - ]) + __all__.extend( + [ + "AdvancedRetrievalPipeline", + "QueryExpander", + "SynonymQueryExpander", + "RephraseQueryExpander", + "MultiQueryRetriever", + "QueryCache", + ] + ) if HAS_CACHING: - __all__.extend([ - "AdvancedCacheManager", - "EmbeddingCache", - "QueryResultCache", - ]) + __all__.extend( + [ + "AdvancedCacheManager", + "EmbeddingCache", + "QueryResultCache", + ] + ) if HAS_MONITORING: - __all__.extend([ - "RAGMonitor", - "RetrievalMetrics", - "IndexingMetrics", - "SystemMetrics", - "PerformanceTimer", - ]) + __all__.extend( + [ + "RAGMonitor", + "RetrievalMetrics", + "IndexingMetrics", + "SystemMetrics", + "PerformanceTimer", + ] + ) if HAS_ASYNC: - __all__.extend([ - "AsyncBatchProcessor", - "AsyncDocumentIndexer", - "AsyncDocumentProcessor", - "BatchProcessingConfig", - "BatchProcessingResult", - ]) - + __all__.extend( + [ + "AsyncBatchProcessor", + "AsyncDocumentIndexer", + "AsyncDocumentProcessor", + "BatchProcessingConfig", + "BatchProcessingResult", + ] + ) diff --git a/src/deepiri_modelkit/rag/advanced_retrieval.py b/src/deepiri_modelkit/rag/advanced_retrieval.py index 14e1de6..4de0fbc 100644 --- a/src/deepiri_modelkit/rag/advanced_retrieval.py +++ b/src/deepiri_modelkit/rag/advanced_retrieval.py @@ -15,6 +15,7 @@ @dataclass class ExpandedQuery: """Expanded query with multiple variations""" + original_query: str expanded_queries: List[str] query_type: str # "synonym", "rephrase", "keyword", etc. @@ -23,7 +24,7 @@ class ExpandedQuery: class QueryExpander(ABC): """Base class for query expansion strategies""" - + @abstractmethod def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: """Expand a query into multiple variations""" @@ -32,15 +33,15 @@ def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: class SynonymQueryExpander(QueryExpander): """Expand queries using synonyms and related terms""" - + def __init__(self, synonym_dict: Optional[Dict[str, List[str]]] = None): self.synonym_dict = synonym_dict or self._default_synonyms() - + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: """Expand query using synonyms""" words = query.lower().split() expanded = [query] # Always include original - + for word in words: if word in self.synonym_dict: synonyms = self.synonym_dict[word][:max_expansions] @@ -48,14 +49,14 @@ def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: expanded_query = query.replace(word, synonym) if expanded_query not in expanded: expanded.append(expanded_query) - + return ExpandedQuery( original_query=query, - expanded_queries=expanded[:max_expansions + 1], + expanded_queries=expanded[: max_expansions + 1], query_type="synonym", - confidence=0.8 + confidence=0.8, ) - + def _default_synonyms(self) -> Dict[str, List[str]]: """Default synonym dictionary""" return { @@ -74,17 +75,17 @@ def _default_synonyms(self) -> Dict[str, List[str]]: class RephraseQueryExpander(QueryExpander): """Expand queries by rephrasing""" - + def __init__(self, llm_client=None): self.llm_client = llm_client - + def expand(self, query: str, max_expansions: int = 3) -> ExpandedQuery: """Rephrase query using LLM or templates""" if self.llm_client: return self._llm_rephrase(query, max_expansions) else: return self._template_rephrase(query, max_expansions) - + def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: """Rephrase using templates""" templates = [ @@ -93,16 +94,16 @@ def _template_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: f"Information about {query}", f"Details on {query}", ] - + expanded = [query] + templates[:max_expansions] - + return ExpandedQuery( original_query=query, expanded_queries=expanded, query_type="rephrase", - confidence=0.7 + confidence=0.7, ) - + def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: """Rephrase using LLM (if available)""" # Placeholder for LLM-based rephrasing @@ -111,16 +112,51 @@ def _llm_rephrase(self, query: str, max_expansions: int) -> ExpandedQuery: class KeywordExtractor: """Extract keywords from queries for hybrid search""" - + def __init__(self): # Common stop words self.stop_words = { - "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", - "of", "with", "by", "is", "are", "was", "were", "be", "been", "being", - "have", "has", "had", "do", "does", "did", "will", "would", "should", - "could", "may", "might", "must", "can", "this", "that", "these", "those" + "the", + "a", + "an", + "and", + "or", + "but", + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "should", + "could", + "may", + "might", + "must", + "can", + "this", + "that", + "these", + "those", } - + def extract(self, query: str, max_keywords: int = 10) -> List[str]: """Extract keywords from query""" words = query.lower().split() @@ -138,38 +174,36 @@ class MultiQueryRetriever: Multi-query retrieval strategy Generates multiple query variations and combines results """ - + def __init__( self, base_retriever, query_expander: Optional[QueryExpander] = None, - fusion_method: str = "rrf" # "rrf" (Reciprocal Rank Fusion) or "mean" + fusion_method: str = "rrf", # "rrf" (Reciprocal Rank Fusion) or "mean" ): self.base_retriever = base_retriever self.query_expander = query_expander or SynonymQueryExpander() self.fusion_method = fusion_method - - def retrieve( - self, - query: RAGQuery, - num_queries: int = 3 - ) -> List[RetrievalResult]: + + def retrieve(self, query: RAGQuery, num_queries: int = 3) -> List[RetrievalResult]: """ Retrieve using multiple query variations - + Args: query: Original RAG query num_queries: Number of query variations to generate - + Returns: Fused retrieval results """ # Expand query - expanded = self.query_expander.expand(query.query, max_expansions=num_queries - 1) - + expanded = self.query_expander.expand( + query.query, max_expansions=num_queries - 1 + ) + # Retrieve for each query variation all_results: Dict[str, List[RetrievalResult]] = {} - + for expanded_query in expanded.expanded_queries: # Create new query with expanded text expanded_rag_query = RAGQuery( @@ -178,135 +212,131 @@ def retrieve( doc_types=query.doc_types, date_range=query.date_range, metadata_filters=query.metadata_filters, - top_k=query.top_k or 10 + top_k=query.top_k or 10, ) - + # Retrieve results results = self.base_retriever.retrieve(expanded_rag_query) all_results[expanded_query] = results - + # Fuse results fused = self._fuse_results(all_results, query.top_k or 10) - + return fused - + def _fuse_results( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int + self, all_results: Dict[str, List[RetrievalResult]], top_k: int ) -> List[RetrievalResult]: """Fuse results from multiple queries""" if self.fusion_method == "rrf": return self._reciprocal_rank_fusion(all_results, top_k) else: return self._mean_score_fusion(all_results, top_k) - + def _reciprocal_rank_fusion( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int, - k: int = 60 + self, all_results: Dict[str, List[RetrievalResult]], top_k: int, k: int = 60 ) -> List[RetrievalResult]: """ Reciprocal Rank Fusion (RRF) Combines rankings from multiple queries """ doc_scores: Dict[str, Dict[str, Any]] = {} - + for query_text, results in all_results.items(): for rank, result in enumerate(results): doc_id = result.document.id rrf_score = 1.0 / (k + rank + 1) - + if doc_id not in doc_scores: doc_scores[doc_id] = { "document": result.document, "rrf_score": 0.0, "max_score": result.score, - "count": 0 + "count": 0, } - + doc_scores[doc_id]["rrf_score"] += rrf_score doc_scores[doc_id]["max_score"] = max( - doc_scores[doc_id]["max_score"], - result.score + doc_scores[doc_id]["max_score"], result.score ) doc_scores[doc_id]["count"] += 1 - + # Create fused results fused = [] for doc_id, scores in doc_scores.items(): - fused.append(RetrievalResult( - document=scores["document"], - score=scores["rrf_score"], # Use RRF score - rerank_score=scores["max_score"] # Store max original score - )) - + fused.append( + RetrievalResult( + document=scores["document"], + score=scores["rrf_score"], # Use RRF score + rerank_score=scores["max_score"], # Store max original score + ) + ) + # Sort by RRF score fused.sort(key=lambda x: x.score, reverse=True) - + return fused[:top_k] - + def _mean_score_fusion( - self, - all_results: Dict[str, List[RetrievalResult]], - top_k: int + self, all_results: Dict[str, List[RetrievalResult]], top_k: int ) -> List[RetrievalResult]: """Fuse results using mean score""" doc_scores: Dict[str, Dict[str, Any]] = {} - + for query_text, results in all_results.items(): for result in results: doc_id = result.document.id - + if doc_id not in doc_scores: doc_scores[doc_id] = { "document": result.document, "scores": [], - "count": 0 + "count": 0, } - + doc_scores[doc_id]["scores"].append(result.score) doc_scores[doc_id]["count"] += 1 - + # Calculate mean scores fused = [] for doc_id, scores in doc_scores.items(): mean_score = sum(scores["scores"]) / len(scores["scores"]) - fused.append(RetrievalResult( - document=scores["document"], - score=mean_score, - rerank_score=max(scores["scores"]) - )) - + fused.append( + RetrievalResult( + document=scores["document"], + score=mean_score, + rerank_score=max(scores["scores"]), + ) + ) + # Sort by mean score fused.sort(key=lambda x: x.score, reverse=True) - + return fused[:top_k] class QueryCache: """Cache for query results""" - + def __init__(self, cache_manager=None): self.cache_manager = cache_manager self.cache_ttl = 3600 # 1 hour - + def get_cache_key(self, query: RAGQuery) -> str: """Generate cache key for query""" query_dict = query.to_dict() query_str = json.dumps(query_dict, sort_keys=True) query_hash = hashlib.md5(query_str.encode()).hexdigest() return f"rag:query:{query_hash}" - + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: """Get cached results""" if not self.cache_manager: return None - + cache_key = self.get_cache_key(query) cached = self.cache_manager.get(cache_key) - + if cached: # Reconstruct RetrievalResult objects results = [] @@ -315,29 +345,29 @@ def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: result = RetrievalResult( document=doc, score=item["score"], - rerank_score=item.get("rerank_score") + rerank_score=item.get("rerank_score"), ) results.append(result) return results - + return None - + def set(self, query: RAGQuery, results: List[RetrievalResult]): """Cache results""" if not self.cache_manager: return - + cache_key = self.get_cache_key(query) # Serialize results serialized = [ { "document": r.document.to_dict(), "score": r.score, - "rerank_score": r.rerank_score + "rerank_score": r.rerank_score, } for r in results ] - + self.cache_manager.set(cache_key, serialized, ttl=self.cache_ttl) @@ -349,29 +379,28 @@ class AdvancedRetrievalPipeline: - Result caching - Hybrid search """ - + def __init__( self, base_retriever, query_expander: Optional[QueryExpander] = None, use_multi_query: bool = True, use_cache: bool = True, - cache_manager=None + cache_manager=None, ): self.base_retriever = base_retriever self.query_expander = query_expander or SynonymQueryExpander() self.use_multi_query = use_multi_query self.use_cache = use_cache self.query_cache = QueryCache(cache_manager) if use_cache else None - + if use_multi_query: self.multi_query_retriever = MultiQueryRetriever( - base_retriever=base_retriever, - query_expander=query_expander + base_retriever=base_retriever, query_expander=query_expander ) else: self.multi_query_retriever = None - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve with advanced strategies""" # Check cache @@ -379,16 +408,15 @@ def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: cached = self.query_cache.get(query) if cached: return cached - + # Retrieve using multi-query or base retriever if self.use_multi_query and self.multi_query_retriever: results = self.multi_query_retriever.retrieve(query) else: results = self.base_retriever.retrieve(query) - + # Cache results if self.use_cache and self.query_cache: self.query_cache.set(query, results) - - return results + return results diff --git a/src/deepiri_modelkit/rag/async_processing.py b/src/deepiri_modelkit/rag/async_processing.py index 9625443..87a0b0f 100644 --- a/src/deepiri_modelkit/rag/async_processing.py +++ b/src/deepiri_modelkit/rag/async_processing.py @@ -5,6 +5,7 @@ import asyncio from typing import List, Dict, Any, Optional, Callable, Awaitable + # Fix for Python < 3.9 compatibility try: from collections.abc import AsyncIterator @@ -20,6 +21,7 @@ @dataclass class BatchProcessingConfig: """Configuration for batch processing""" + batch_size: int = 100 max_concurrent_batches: int = 5 chunk_size: int = 1000 @@ -33,24 +35,25 @@ class BatchProcessingConfig: @dataclass class BatchProcessingResult: """Result of batch processing operation""" + total_items: int processed_items: int successful_items: int failed_items: int processing_time_seconds: float errors: List[Dict[str, Any]] = None - + def __post_init__(self): if self.errors is None: self.errors = [] - + @property def success_rate(self) -> float: """Calculate success rate""" if self.total_items == 0: return 0.0 return self.successful_items / self.total_items - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -68,25 +71,25 @@ class AsyncBatchProcessor: """ Async batch processor for high-performance document processing """ - + def __init__(self, config: BatchProcessingConfig): self.config = config self.semaphore = asyncio.Semaphore(config.max_concurrent_batches) - + async def process_batch( self, items: List[Any], processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Process items in batches asynchronously - + Args: items: List of items to process processor_func: Async function to process each item progress_callback: Optional callback for progress updates - + Returns: BatchProcessingResult with statistics """ @@ -95,135 +98,125 @@ async def process_batch( successful_items = 0 failed_items = 0 errors = [] - + # Split into batches batches = [ - items[i:i + self.config.batch_size] + items[i : i + self.config.batch_size] for i in range(0, total_items, self.config.batch_size) ] - + # Process batches concurrently tasks = [] for batch_idx, batch in enumerate(batches): task = self._process_batch_with_semaphore( - batch, - batch_idx, - processor_func, - progress_callback + batch, batch_idx, processor_func, progress_callback ) tasks.append(task) - + # Wait for all batches batch_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Aggregate results for result in batch_results: if isinstance(result, Exception): failed_items += self.config.batch_size - errors.append({ - "error": str(result), - "type": type(result).__name__ - }) + errors.append({"error": str(result), "type": type(result).__name__}) else: successful_items += result["successful"] failed_items += result["failed"] errors.extend(result.get("errors", [])) - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_items, processed_items=total_items, successful_items=successful_items, failed_items=failed_items, processing_time_seconds=processing_time, - errors=errors + errors=errors, ) - + async def _process_batch_with_semaphore( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] + progress_callback: Optional[Callable[[int, int], None]], ) -> Dict[str, Any]: """Process a single batch with semaphore control""" async with self.semaphore: return await self._process_single_batch( - batch, - batch_idx, - processor_func, - progress_callback + batch, batch_idx, processor_func, progress_callback ) - + async def _process_single_batch( self, batch: List[Any], batch_idx: int, processor_func: Callable[[Any], Awaitable[Any]], - progress_callback: Optional[Callable[[int, int], None]] + progress_callback: Optional[Callable[[int, int], None]], ) -> Dict[str, Any]: """Process a single batch""" successful = 0 failed = 0 errors = [] - + # Process items in batch concurrently tasks = [processor_func(item) for item in batch] results = await asyncio.gather(*tasks, return_exceptions=True) - + for idx, result in enumerate(results): if isinstance(result, Exception): failed += 1 - errors.append({ - "item_index": batch_idx * self.config.batch_size + idx, - "error": str(result), - "type": type(result).__name__ - }) + errors.append( + { + "item_index": batch_idx * self.config.batch_size + idx, + "error": str(result), + "type": type(result).__name__, + } + ) else: successful += 1 - + # Progress callback if progress_callback: total_processed = (batch_idx + 1) * self.config.batch_size progress_callback(total_processed, len(batch) * (batch_idx + 1)) - - return { - "successful": successful, - "failed": failed, - "errors": errors - } + + return {"successful": successful, "failed": failed, "errors": errors} class AsyncDocumentIndexer: """ Async document indexer with batching and retry logic """ - + def __init__( self, index_func: Callable[[Document], Awaitable[bool]], - config: Optional[BatchProcessingConfig] = None + config: Optional[BatchProcessingConfig] = None, ): self.index_func = index_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def index_documents( self, documents: List[Document], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Index documents asynchronously in batches - + Args: documents: List of documents to index progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ + async def index_document(doc: Document) -> bool: """Index a single document with retry""" for attempt in range(self.config.max_retries): @@ -238,27 +231,25 @@ async def index_document(doc: Document) -> bool: ) continue raise - + return False - + return await self.batch_processor.process_batch( - documents, - index_document, - progress_callback + documents, index_document, progress_callback ) - + async def index_documents_streaming( self, document_stream: AsyncIterator[Document], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> BatchProcessingResult: """ Index documents from async stream - + Args: document_stream: Async iterator of documents progress_callback: Optional callback for progress - + Returns: BatchProcessingResult with statistics """ @@ -268,39 +259,39 @@ async def index_documents_streaming( failed = 0 errors = [] start_time = time.time() - + async def process_current_batch(): nonlocal successful, failed, errors - + if not batch: return - + result = await self.index_documents(batch, progress_callback) successful += result.successful_items failed += result.failed_items errors.extend(result.errors) batch.clear() - + async for document in document_stream: batch.append(document) total_processed += 1 - + if len(batch) >= self.config.batch_size: await process_current_batch() - + # Process remaining batch if batch: await process_current_batch() - + processing_time = time.time() - start_time - + return BatchProcessingResult( total_items=total_processed, processed_items=total_processed, successful_items=successful, failed_items=failed, processing_time_seconds=processing_time, - errors=errors + errors=errors, ) @@ -308,65 +299,55 @@ class AsyncDocumentProcessor: """ Async document processor for parallel document processing """ - + def __init__( self, processor_func: Callable[[str, Dict], List[Document]], - config: Optional[BatchProcessingConfig] = None + config: Optional[BatchProcessingConfig] = None, ): self.processor_func = processor_func self.config = config or BatchProcessingConfig() self.batch_processor = AsyncBatchProcessor(self.config) - + async def process_documents( self, raw_documents: List[Dict[str, Any]], - progress_callback: Optional[Callable[[int, int], None]] = None + progress_callback: Optional[Callable[[int, int], None]] = None, ) -> tuple[List[Document], BatchProcessingResult]: """ Process raw documents asynchronously - + Args: raw_documents: List of dicts with 'content' and 'metadata' progress_callback: Optional callback for progress - + Returns: Tuple of (processed_documents, processing_result) """ processed_docs = [] errors = [] - + async def process_document(item: Dict[str, Any]) -> List[Document]: """Process a single document""" try: content = item.get("content", "") metadata = item.get("metadata", {}) - docs = await asyncio.to_thread( - self.processor_func, - content, - metadata - ) + docs = await asyncio.to_thread(self.processor_func, content, metadata) return docs except Exception as e: - errors.append({ - "item": item.get("id", "unknown"), - "error": str(e) - }) + errors.append({"item": item.get("id", "unknown"), "error": str(e)}) return [] - + result = await self.batch_processor.process_batch( - raw_documents, - process_document, - progress_callback + raw_documents, process_document, progress_callback ) - + # Collect all processed documents tasks = [process_document(item) for item in raw_documents] doc_lists = await asyncio.gather(*tasks, return_exceptions=True) - + for doc_list in doc_lists: if isinstance(doc_list, list): processed_docs.extend(doc_list) - - return processed_docs, result + return processed_docs, result diff --git a/src/deepiri_modelkit/rag/base.py b/src/deepiri_modelkit/rag/base.py index 7693e5d..e7ebf66 100644 --- a/src/deepiri_modelkit/rag/base.py +++ b/src/deepiri_modelkit/rag/base.py @@ -12,6 +12,7 @@ class DocumentType(Enum): """Types of documents that can be indexed""" + REGULATION = "regulation" # Laws, regulations, compliance documents POLICY = "policy" # Insurance policies, company policies MANUAL = "manual" # Equipment manuals, operation guides @@ -31,6 +32,7 @@ class DocumentType(Enum): class IndustryNiche(Enum): """Supported industry niches""" + INSURANCE = "insurance" # Property & casualty insurance MANUFACTURING = "manufacturing" # Industrial manufacturing PROPERTY_MANAGEMENT = "property_management" # Real estate management @@ -47,33 +49,34 @@ class IndustryNiche(Enum): @dataclass class RAGConfig: """Configuration for RAG system""" + # Industry configuration industry: IndustryNiche = IndustryNiche.GENERIC - + # Vector database configuration vector_db_type: str = "milvus" # milvus, pinecone, weaviate, memory collection_name: str = "deepiri_universal_rag" vector_db_host: str = "milvus" vector_db_port: int = 19530 - + # Embedding configuration embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2" embedding_dimension: int = 384 - + # Retrieval configuration top_k: int = 5 # Number of documents to retrieve similarity_threshold: float = 0.7 # Minimum similarity score use_reranking: bool = True # Use cross-encoder reranking reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" - + # Chunking configuration chunk_size: int = 1000 # Characters per chunk chunk_overlap: int = 200 # Overlap between chunks - + # Metadata filtering enable_metadata_filtering: bool = True date_range_filtering: bool = True - + # Multi-modal support support_images: bool = False support_tables: bool = True @@ -83,11 +86,12 @@ class RAGConfig: @dataclass class Document: """Universal document representation""" + id: str content: str doc_type: DocumentType industry: IndustryNiche - + # Metadata title: Optional[str] = None source: Optional[str] = None @@ -95,14 +99,14 @@ class Document: updated_at: Optional[datetime] = None author: Optional[str] = None version: Optional[str] = None - + # Industry-specific metadata metadata: Dict[str, Any] = field(default_factory=dict) - + # Processing metadata chunk_index: Optional[int] = None total_chunks: Optional[int] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for storage""" return { @@ -120,9 +124,9 @@ def to_dict(self) -> Dict[str, Any]: "chunk_index": self.chunk_index, "total_chunks": self.total_chunks, } - + @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Document': + def from_dict(cls, data: Dict[str, Any]) -> "Document": """Create from dictionary""" return cls( id=data["id"], @@ -131,8 +135,16 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Document': industry=IndustryNiche(data["industry"]), title=data.get("title"), source=data.get("source"), - created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None, - updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None, + created_at=( + datetime.fromisoformat(data["created_at"]) + if data.get("created_at") + else None + ), + updated_at=( + datetime.fromisoformat(data["updated_at"]) + if data.get("updated_at") + else None + ), author=data.get("author"), version=data.get("version"), metadata=data.get("metadata", {}), @@ -144,10 +156,11 @@ def from_dict(cls, data: Dict[str, Any]) -> 'Document': @dataclass class RetrievalResult: """Result from RAG retrieval""" + document: Document score: float # Similarity score rerank_score: Optional[float] = None # Reranking score if applicable - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -160,20 +173,25 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class RAGQuery: """Query for RAG system""" + query: str industry: Optional[IndustryNiche] = None doc_types: Optional[List[DocumentType]] = None date_range: Optional[tuple[datetime, datetime]] = None metadata_filters: Optional[Dict[str, Any]] = None top_k: Optional[int] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { "query": self.query, "industry": self.industry.value if self.industry else None, - "doc_types": [dt.value for dt in self.doc_types] if self.doc_types else None, - "date_range": [dr.isoformat() for dr in self.date_range] if self.date_range else None, + "doc_types": ( + [dt.value for dt in self.doc_types] if self.doc_types else None + ), + "date_range": ( + [dr.isoformat() for dr in self.date_range] if self.date_range else None + ), "metadata_filters": self.metadata_filters, "top_k": self.top_k, } @@ -184,108 +202,108 @@ class UniversalRAGEngine(ABC): Abstract base class for universal RAG engine Implements common RAG patterns across all industries """ - + def __init__(self, config: RAGConfig): self.config = config self._initialize() - + @abstractmethod def _initialize(self): """Initialize RAG components (vector store, embeddings, etc.)""" pass - + @abstractmethod def index_document(self, document: Document) -> bool: """ Index a single document - + Args: document: Document to index - + Returns: True if successful, False otherwise """ pass - + @abstractmethod def index_documents(self, documents: List[Document]) -> Dict[str, Any]: """ Index multiple documents in batch - + Args: documents: List of documents to index - + Returns: Statistics about the indexing operation """ pass - + @abstractmethod def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve relevant documents for a query - + Args: query: Query with filters and parameters - + Returns: List of retrieval results with scores """ pass - + @abstractmethod def generate_with_context( self, query: str, retrieved_docs: List[RetrievalResult], - llm_prompt_template: Optional[str] = None + llm_prompt_template: Optional[str] = None, ) -> Dict[str, Any]: """ Generate response using retrieved context - + Args: query: User query retrieved_docs: Retrieved documents for context llm_prompt_template: Optional custom prompt template - + Returns: Generated response with metadata """ pass - + @abstractmethod def delete_documents(self, doc_ids: List[str]) -> bool: """Delete documents by IDs""" pass - + @abstractmethod def update_document(self, doc_id: str, document: Document) -> bool: """Update an existing document""" pass - + @abstractmethod def get_statistics(self) -> Dict[str, Any]: """Get statistics about indexed documents""" pass - + def search( self, query: str, industry: Optional[IndustryNiche] = None, doc_types: Optional[List[DocumentType]] = None, top_k: Optional[int] = None, - **filters + **filters, ) -> List[RetrievalResult]: """ Convenience method for simple search - + Args: query: Search query industry: Filter by industry doc_types: Filter by document types top_k: Number of results **filters: Additional metadata filters - + Returns: List of retrieval results """ @@ -294,7 +312,6 @@ def search( industry=industry, doc_types=doc_types, top_k=top_k or self.config.top_k, - metadata_filters=filters if filters else None + metadata_filters=filters if filters else None, ) return self.retrieve(rag_query) - diff --git a/src/deepiri_modelkit/rag/caching.py b/src/deepiri_modelkit/rag/caching.py index 8da0ff7..6d25942 100644 --- a/src/deepiri_modelkit/rag/caching.py +++ b/src/deepiri_modelkit/rag/caching.py @@ -16,6 +16,7 @@ @dataclass class CacheEntry: """Cache entry with metadata""" + key: str value: Any created_at: datetime @@ -23,19 +24,19 @@ class CacheEntry: access_count: int = 0 last_accessed: Optional[datetime] = None tags: List[str] = None - + def __post_init__(self): if self.tags is None: self.tags = [] if self.last_accessed is None: self.last_accessed = self.created_at - + def is_expired(self) -> bool: """Check if entry is expired""" if self.expires_at is None: return False return datetime.now() > self.expires_at - + def to_dict(self) -> Dict: """Convert to dictionary for serialization""" return { @@ -44,21 +45,31 @@ def to_dict(self) -> Dict: "created_at": self.created_at.isoformat(), "expires_at": self.expires_at.isoformat() if self.expires_at else None, "access_count": self.access_count, - "last_accessed": self.last_accessed.isoformat() if self.last_accessed else None, - "tags": self.tags + "last_accessed": ( + self.last_accessed.isoformat() if self.last_accessed else None + ), + "tags": self.tags, } - + @classmethod - def from_dict(cls, data: Dict) -> 'CacheEntry': + def from_dict(cls, data: Dict) -> "CacheEntry": """Create from dictionary""" return cls( key=data["key"], value=data["value"], created_at=datetime.fromisoformat(data["created_at"]), - expires_at=datetime.fromisoformat(data["expires_at"]) if data.get("expires_at") else None, + expires_at=( + datetime.fromisoformat(data["expires_at"]) + if data.get("expires_at") + else None + ), access_count=data.get("access_count", 0), - last_accessed=datetime.fromisoformat(data["last_accessed"]) if data.get("last_accessed") else None, - tags=data.get("tags", []) + last_accessed=( + datetime.fromisoformat(data["last_accessed"]) + if data.get("last_accessed") + else None + ), + tags=data.get("tags", []), ) @@ -71,53 +82,52 @@ class AdvancedCacheManager: - Size limits - LRU eviction """ - + def __init__( self, redis_client=None, default_ttl: int = 3600, max_size: int = 10000, - enable_lru: bool = True + enable_lru: bool = True, ): self.redis_client = redis_client self.default_ttl = default_ttl self.max_size = max_size self.enable_lru = enable_lru - + # In-memory fallback if Redis unavailable self.memory_cache: Dict[str, CacheEntry] = {} self.tag_index: Dict[str, List[str]] = {} # tag -> [keys] - + def _get_key_prefix(self, namespace: str = "rag") -> str: """Get key prefix""" return f"{namespace}:" - + def _serialize_value(self, value: Any) -> str: """Serialize value for storage""" if isinstance(value, (list, dict)): return json.dumps(value) return str(value) - + def _deserialize_value(self, value: str, value_type: type = None) -> Any: """Deserialize value from storage""" try: - if value_type == list or (isinstance(value, str) and value.startswith('[')): + if value_type == list or (isinstance(value, str) and value.startswith("[")): return json.loads(value) - elif value_type == dict or (isinstance(value, str) and value.startswith('{')): + elif value_type == dict or ( + isinstance(value, str) and value.startswith("{") + ): return json.loads(value) return value except (json.JSONDecodeError, TypeError): return value - + def get( - self, - key: str, - namespace: str = "rag", - update_access: bool = True + self, key: str, namespace: str = "rag", update_access: bool = True ) -> Optional[Any]: """Get value from cache""" full_key = f"{self._get_key_prefix(namespace)}{key}" - + # Try Redis first if self.redis_client: try: @@ -125,65 +135,65 @@ def get( if cached: entry_data = json.loads(cached) entry = CacheEntry.from_dict(entry_data) - + if entry.is_expired(): self.delete(key, namespace) return None - + if update_access: entry.access_count += 1 entry.last_accessed = datetime.now() self._update_redis_entry(full_key, entry) - + return entry.value except Exception as e: # Fallback to memory pass - + # Fallback to memory cache if key in self.memory_cache: entry = self.memory_cache[key] - + if entry.is_expired(): del self.memory_cache[key] return None - + if update_access: entry.access_count += 1 entry.last_accessed = datetime.now() - + return entry.value - + return None - + def set( self, key: str, value: Any, namespace: str = "rag", ttl: Optional[int] = None, - tags: Optional[List[str]] = None + tags: Optional[List[str]] = None, ) -> bool: """Set value in cache""" full_key = f"{self._get_key_prefix(namespace)}{key}" ttl = ttl or self.default_ttl tags = tags or [] - + expires_at = datetime.now() + timedelta(seconds=ttl) entry = CacheEntry( key=full_key, value=value, created_at=datetime.now(), expires_at=expires_at, - tags=tags + tags=tags, ) - + # Try Redis first if self.redis_client: try: entry_data = json.dumps(entry.to_dict()) self.redis_client.setex(full_key, ttl, entry_data) - + # Update tag index for tag in tags: tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" @@ -192,12 +202,12 @@ def set( if full_key not in self.tag_index[tag_key]: self.tag_index[tag_key].append(full_key) self.redis_client.sadd(tag_key, full_key) - + return True except Exception as e: # Fallback to memory pass - + # Fallback to memory cache # Check size limit if len(self.memory_cache) >= self.max_size: @@ -207,25 +217,25 @@ def set( # Remove oldest oldest_key = min( self.memory_cache.keys(), - key=lambda k: self.memory_cache[k].created_at + key=lambda k: self.memory_cache[k].created_at, ) del self.memory_cache[oldest_key] - + self.memory_cache[key] = entry - + # Update tag index for tag in tags: if tag not in self.tag_index: self.tag_index[tag] = [] if key not in self.tag_index[tag]: self.tag_index[tag].append(key) - + return True - + def delete(self, key: str, namespace: str = "rag") -> bool: """Delete key from cache""" full_key = f"{self._get_key_prefix(namespace)}{key}" - + # Try Redis if self.redis_client: try: @@ -233,7 +243,7 @@ def delete(self, key: str, namespace: str = "rag") -> bool: return True except Exception: pass - + # Memory cache if key in self.memory_cache: entry = self.memory_cache[key] @@ -243,40 +253,40 @@ def delete(self, key: str, namespace: str = "rag") -> bool: self.tag_index[tag].remove(key) del self.memory_cache[key] return True - + return False - + def invalidate_by_tag(self, tag: str, namespace: str = "rag") -> int: """Invalidate all keys with given tag""" tag_key = f"{self._get_key_prefix(namespace)}tag:{tag}" count = 0 - + # Get keys from tag index keys_to_delete = [] - + if self.redis_client: try: keys_to_delete = list(self.redis_client.smembers(tag_key)) self.redis_client.delete(tag_key) except Exception: pass - + if tag in self.tag_index: keys_to_delete.extend(self.tag_index[tag]) del self.tag_index[tag] - + # Delete all keys for key in keys_to_delete: if self.delete(key.replace(self._get_key_prefix(namespace), ""), namespace): count += 1 - + return count - + def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: """Invalidate keys matching pattern""" full_pattern = f"{self._get_key_prefix(namespace)}{pattern}" count = 0 - + if self.redis_client: try: keys = self.redis_client.keys(full_pattern) @@ -284,42 +294,41 @@ def invalidate_by_pattern(self, pattern: str, namespace: str = "rag") -> int: count = self.redis_client.delete(*keys) except Exception: pass - + # Memory cache keys_to_delete = [ - k for k in self.memory_cache.keys() - if self._match_pattern(k, pattern) + k for k in self.memory_cache.keys() if self._match_pattern(k, pattern) ] for key in keys_to_delete: self.delete(key, namespace) count += 1 - + return count - + def _match_pattern(self, key: str, pattern: str) -> bool: """Simple pattern matching (supports * wildcard)""" import fnmatch + return fnmatch.fnmatch(key, pattern) - + def _evict_lru(self): """Evict least recently used entry""" if not self.memory_cache: return - + lru_key = min( self.memory_cache.keys(), key=lambda k: ( - self.memory_cache[k].last_accessed or - self.memory_cache[k].created_at - ) + self.memory_cache[k].last_accessed or self.memory_cache[k].created_at + ), ) del self.memory_cache[lru_key] - + def _update_redis_entry(self, key: str, entry: CacheEntry): """Update Redis entry with new access info""" if not self.redis_client: return - + try: entry_data = json.dumps(entry.to_dict()) # Get remaining TTL @@ -328,7 +337,7 @@ def _update_redis_entry(self, key: str, entry: CacheEntry): self.redis_client.setex(key, ttl, entry_data) except Exception: pass - + def get_stats(self) -> Dict[str, Any]: """Get cache statistics""" stats = { @@ -337,14 +346,14 @@ def get_stats(self) -> Dict[str, Any]: "tag_index_size": len(self.tag_index), "redis_available": self.redis_client is not None, } - + if self.memory_cache: total_access = sum(e.access_count for e in self.memory_cache.values()) stats["total_accesses"] = total_access stats["avg_access_per_entry"] = total_access / len(self.memory_cache) - + return stats - + def clear(self, namespace: str = "rag"): """Clear all cache entries in namespace""" pattern = f"{self._get_key_prefix(namespace)}*" @@ -353,52 +362,48 @@ def clear(self, namespace: str = "rag"): class EmbeddingCache: """Specialized cache for embeddings""" - + def __init__(self, cache_manager: AdvancedCacheManager): self.cache_manager = cache_manager self.namespace = "rag:embeddings" - + def get_embedding_key(self, text: str) -> str: """Generate cache key for embedding""" text_hash = hashlib.md5(text.encode()).hexdigest() return f"emb:{text_hash}" - + def get(self, text: str) -> Optional[Any]: """Get cached embedding""" key = self.get_embedding_key(text) return self.cache_manager.get(key, namespace=self.namespace) - + def set(self, text: str, embedding: Any, ttl: int = 86400): """Cache embedding (24 hour default TTL)""" key = self.get_embedding_key(text) return self.cache_manager.set( - key, - embedding, - namespace=self.namespace, - ttl=ttl, - tags=["embedding"] + key, embedding, namespace=self.namespace, ttl=ttl, tags=["embedding"] ) class QueryResultCache: """Specialized cache for query results""" - + def __init__(self, cache_manager: AdvancedCacheManager): self.cache_manager = cache_manager self.namespace = "rag:queries" - + def get_query_key(self, query: RAGQuery) -> str: """Generate cache key for query""" query_dict = query.to_dict() query_str = json.dumps(query_dict, sort_keys=True) query_hash = hashlib.md5(query_str.encode()).hexdigest() return f"query:{query_hash}" - + def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: """Get cached query results""" key = self.get_query_key(query) cached = self.cache_manager.get(key, namespace=self.namespace) - + if cached: # Reconstruct RetrievalResult objects results = [] @@ -407,60 +412,55 @@ def get(self, query: RAGQuery) -> Optional[List[RetrievalResult]]: result = RetrievalResult( document=doc, score=item["score"], - rerank_score=item.get("rerank_score") + rerank_score=item.get("rerank_score"), ) results.append(result) return results - + return None - + def set( self, query: RAGQuery, results: List[RetrievalResult], ttl: int = 3600, - tags: Optional[List[str]] = None + tags: Optional[List[str]] = None, ): """Cache query results""" key = self.get_query_key(query) - + # Serialize results serialized = [ { "document": r.document.to_dict(), "score": r.score, - "rerank_score": r.rerank_score + "rerank_score": r.rerank_score, } for r in results ] - + # Add query tags query_tags = tags or [] if query.industry: - query_tags.append(f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}") + query_tags.append( + f"industry:{query.industry.value if hasattr(query.industry, 'value') else query.industry}" + ) if query.doc_types: for dt in query.doc_types: query_tags.append(f"doctype:{dt.value if hasattr(dt, 'value') else dt}") - + return self.cache_manager.set( - key, - serialized, - namespace=self.namespace, - ttl=ttl, - tags=query_tags + key, serialized, namespace=self.namespace, ttl=ttl, tags=query_tags ) - + def invalidate_by_industry(self, industry: str): """Invalidate all queries for an industry""" return self.cache_manager.invalidate_by_tag( - f"industry:{industry}", - namespace=self.namespace + f"industry:{industry}", namespace=self.namespace ) - + def invalidate_by_doc_type(self, doc_type: str): """Invalidate all queries for a document type""" return self.cache_manager.invalidate_by_tag( - f"doctype:{doc_type}", - namespace=self.namespace + f"doctype:{doc_type}", namespace=self.namespace ) - diff --git a/src/deepiri_modelkit/rag/monitoring.py b/src/deepiri_modelkit/rag/monitoring.py index 95429b3..da5e6fa 100644 --- a/src/deepiri_modelkit/rag/monitoring.py +++ b/src/deepiri_modelkit/rag/monitoring.py @@ -16,6 +16,7 @@ @dataclass class RetrievalMetrics: """Metrics for a single retrieval operation""" + query_id: str query_text: str timestamp: datetime @@ -27,7 +28,7 @@ class RetrievalMetrics: query_expansion_used: bool = False industry: Optional[str] = None doc_types: Optional[List[str]] = None - + def to_dict(self) -> Dict: """Convert to dictionary""" return { @@ -48,6 +49,7 @@ def to_dict(self) -> Dict: @dataclass class IndexingMetrics: """Metrics for indexing operations""" + operation_id: str timestamp: datetime operation_type: str # "index", "update", "delete" @@ -55,7 +57,7 @@ class IndexingMetrics: processing_time_ms: float success: bool error: Optional[str] = None - + def to_dict(self) -> Dict: """Convert to dictionary""" return { @@ -72,6 +74,7 @@ def to_dict(self) -> Dict: @dataclass class SystemMetrics: """System-wide metrics""" + total_queries: int = 0 total_indexed_documents: int = 0 cache_hit_rate: float = 0.0 @@ -79,7 +82,7 @@ class SystemMetrics: avg_indexing_time_ms: float = 0.0 error_rate: float = 0.0 last_updated: Optional[datetime] = None - + def to_dict(self) -> Dict: """Convert to dictionary""" return { @@ -89,7 +92,9 @@ def to_dict(self) -> Dict: "avg_retrieval_time_ms": self.avg_retrieval_time_ms, "avg_indexing_time_ms": self.avg_indexing_time_ms, "error_rate": self.error_rate, - "last_updated": self.last_updated.isoformat() if self.last_updated else None, + "last_updated": ( + self.last_updated.isoformat() if self.last_updated else None + ), } @@ -97,21 +102,21 @@ class RAGMonitor: """ Monitor RAG system performance and collect metrics """ - + def __init__(self, max_history: int = 10000): self.max_history = max_history - + # Metrics storage self.retrieval_metrics: List[RetrievalMetrics] = [] self.indexing_metrics: List[IndexingMetrics] = [] - + # Aggregated stats self.system_metrics = SystemMetrics() - + # Time windows for analysis self.hourly_stats: Dict[str, Dict] = defaultdict(dict) self.daily_stats: Dict[str, Dict] = defaultdict(dict) - + def record_retrieval( self, query: RAGQuery, @@ -119,11 +124,11 @@ def record_retrieval( retrieval_time_ms: float, cache_hit: bool = False, reranking_used: bool = False, - query_expansion_used: bool = False + query_expansion_used: bool = False, ): """Record retrieval metrics""" query_id = f"q_{int(time.time() * 1000)}" - + metric = RetrievalMetrics( query_id=query_id, query_text=query.query, @@ -134,30 +139,41 @@ def record_retrieval( cache_hit=cache_hit, reranking_used=reranking_used, query_expansion_used=query_expansion_used, - industry=query.industry.value if query.industry and hasattr(query.industry, 'value') else str(query.industry), - doc_types=[dt.value if hasattr(dt, 'value') else str(dt) for dt in query.doc_types] if query.doc_types else None, + industry=( + query.industry.value + if query.industry and hasattr(query.industry, "value") + else str(query.industry) + ), + doc_types=( + [ + dt.value if hasattr(dt, "value") else str(dt) + for dt in query.doc_types + ] + if query.doc_types + else None + ), ) - + self.retrieval_metrics.append(metric) - + # Trim history if len(self.retrieval_metrics) > self.max_history: - self.retrieval_metrics = self.retrieval_metrics[-self.max_history:] - + self.retrieval_metrics = self.retrieval_metrics[-self.max_history :] + # Update aggregated stats self._update_system_metrics() - + def record_indexing( self, operation_type: str, num_documents: int, processing_time_ms: float, success: bool, - error: Optional[str] = None + error: Optional[str] = None, ): """Record indexing metrics""" operation_id = f"idx_{int(time.time() * 1000)}" - + metric = IndexingMetrics( operation_id=operation_id, timestamp=datetime.now(), @@ -165,71 +181,81 @@ def record_indexing( num_documents=num_documents, processing_time_ms=processing_time_ms, success=success, - error=error + error=error, ) - + self.indexing_metrics.append(metric) - + # Trim history if len(self.indexing_metrics) > self.max_history: - self.indexing_metrics = self.indexing_metrics[-self.max_history:] - + self.indexing_metrics = self.indexing_metrics[-self.max_history :] + # Update aggregated stats self._update_system_metrics() - + def _update_system_metrics(self): """Update system-wide aggregated metrics""" if not self.retrieval_metrics: return - + # Calculate averages total_queries = len(self.retrieval_metrics) cache_hits = sum(1 for m in self.retrieval_metrics if m.cache_hit) total_retrieval_time = sum(m.retrieval_time_ms for m in self.retrieval_metrics) - + self.system_metrics.total_queries = total_queries - self.system_metrics.cache_hit_rate = cache_hits / total_queries if total_queries > 0 else 0.0 - self.system_metrics.avg_retrieval_time_ms = total_retrieval_time / total_queries if total_queries > 0 else 0.0 - + self.system_metrics.cache_hit_rate = ( + cache_hits / total_queries if total_queries > 0 else 0.0 + ) + self.system_metrics.avg_retrieval_time_ms = ( + total_retrieval_time / total_queries if total_queries > 0 else 0.0 + ) + if self.indexing_metrics: - total_indexed = sum(m.num_documents for m in self.indexing_metrics if m.success) - total_indexing_time = sum(m.processing_time_ms for m in self.indexing_metrics) + total_indexed = sum( + m.num_documents for m in self.indexing_metrics if m.success + ) + total_indexing_time = sum( + m.processing_time_ms for m in self.indexing_metrics + ) total_indexing_ops = len(self.indexing_metrics) - + self.system_metrics.total_indexed_documents = total_indexed - self.system_metrics.avg_indexing_time_ms = total_indexing_time / total_indexing_ops if total_indexing_ops > 0 else 0.0 - + self.system_metrics.avg_indexing_time_ms = ( + total_indexing_time / total_indexing_ops + if total_indexing_ops > 0 + else 0.0 + ) + # Error rate total_ops = total_queries + len(self.indexing_metrics) errors = sum(1 for m in self.indexing_metrics if not m.success) self.system_metrics.error_rate = errors / total_ops if total_ops > 0 else 0.0 - + self.system_metrics.last_updated = datetime.now() - + def get_retrieval_stats( - self, - time_window_minutes: Optional[int] = None, - industry: Optional[str] = None + self, time_window_minutes: Optional[int] = None, industry: Optional[str] = None ) -> Dict[str, Any]: """Get retrieval statistics for time window""" metrics = self.retrieval_metrics - + # Filter by time window if time_window_minutes: cutoff = datetime.now() - timedelta(minutes=time_window_minutes) metrics = [m for m in metrics if m.timestamp >= cutoff] - + # Filter by industry if industry: metrics = [m for m in metrics if m.industry == industry] - + if not metrics: return { "count": 0, "avg_time_ms": 0.0, "cache_hit_rate": 0.0, } - + return { "count": len(metrics), "avg_time_ms": sum(m.retrieval_time_ms for m in metrics) / len(metrics), @@ -237,21 +263,21 @@ def get_retrieval_stats( "max_time_ms": max(m.retrieval_time_ms for m in metrics), "cache_hit_rate": sum(1 for m in metrics if m.cache_hit) / len(metrics), "avg_results": sum(m.num_results for m in metrics) / len(metrics), - "avg_top_score": sum(m.top_score for m in metrics if m.top_score) / len([m for m in metrics if m.top_score]), + "avg_top_score": sum(m.top_score for m in metrics if m.top_score) + / len([m for m in metrics if m.top_score]), } - + def get_indexing_stats( - self, - time_window_minutes: Optional[int] = None + self, time_window_minutes: Optional[int] = None ) -> Dict[str, Any]: """Get indexing statistics""" metrics = self.indexing_metrics - + # Filter by time window if time_window_minutes: cutoff = datetime.now() - timedelta(minutes=time_window_minutes) metrics = [m for m in metrics if m.timestamp >= cutoff] - + if not metrics: return { "count": 0, @@ -259,9 +285,9 @@ def get_indexing_stats( "avg_time_ms": 0.0, "success_rate": 0.0, } - + successful = [m for m in metrics if m.success] - + return { "count": len(metrics), "total_documents": sum(m.num_documents for m in successful), @@ -269,40 +295,30 @@ def get_indexing_stats( "success_rate": len(successful) / len(metrics) if metrics else 0.0, "error_count": len([m for m in metrics if not m.success]), } - + def get_top_queries( - self, - limit: int = 10, - time_window_minutes: Optional[int] = None + self, limit: int = 10, time_window_minutes: Optional[int] = None ) -> List[Dict[str, Any]]: """Get most frequent queries""" metrics = self.retrieval_metrics - + # Filter by time window if time_window_minutes: cutoff = datetime.now() - timedelta(minutes=time_window_minutes) metrics = [m for m in metrics if m.timestamp >= cutoff] - + # Count query frequencies query_counts: Dict[str, int] = defaultdict(int) for m in metrics: query_counts[m.query_text] += 1 - + # Sort by frequency - top_queries = sorted( - query_counts.items(), - key=lambda x: x[1], - reverse=True - )[:limit] - - return [ - { - "query": query, - "count": count - } - for query, count in top_queries + top_queries = sorted(query_counts.items(), key=lambda x: x[1], reverse=True)[ + :limit ] - + + return [{"query": query, "count": count} for query, count in top_queries] + def get_performance_report(self) -> Dict[str, Any]: """Get comprehensive performance report""" return { @@ -313,37 +329,43 @@ def get_performance_report(self) -> Dict[str, Any]: "indexing_stats_24h": self.get_indexing_stats(time_window_minutes=1440), "top_queries_24h": self.get_top_queries(limit=10, time_window_minutes=1440), } - + def export_metrics(self, filepath: str): """Export metrics to JSON file""" data = { - "retrieval_metrics": [m.to_dict() for m in self.retrieval_metrics[-1000:]], # Last 1000 - "indexing_metrics": [m.to_dict() for m in self.indexing_metrics[-1000:]], # Last 1000 + "retrieval_metrics": [ + m.to_dict() for m in self.retrieval_metrics[-1000:] + ], # Last 1000 + "indexing_metrics": [ + m.to_dict() for m in self.indexing_metrics[-1000:] + ], # Last 1000 "system_metrics": self.system_metrics.to_dict(), "exported_at": datetime.now().isoformat(), } - - with open(filepath, 'w') as f: + + with open(filepath, "w") as f: json.dump(data, f, indent=2) class PerformanceTimer: """Context manager for timing operations""" - - def __init__(self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation"): + + def __init__( + self, monitor: Optional[RAGMonitor] = None, operation_name: str = "operation" + ): self.monitor = monitor self.operation_name = operation_name self.start_time = None self.end_time = None - + def __enter__(self): self.start_time = time.time() return self - + def __exit__(self, exc_type, exc_val, exc_tb): self.end_time = time.time() return False - + def elapsed_ms(self) -> float: """Get elapsed time in milliseconds""" if self.start_time and self.end_time: @@ -351,4 +373,3 @@ def elapsed_ms(self) -> float: elif self.start_time: return (time.time() - self.start_time) * 1000 return 0.0 - diff --git a/src/deepiri_modelkit/rag/processors.py b/src/deepiri_modelkit/rag/processors.py index 3fdb477..5585cde 100644 --- a/src/deepiri_modelkit/rag/processors.py +++ b/src/deepiri_modelkit/rag/processors.py @@ -13,60 +13,62 @@ class DocumentProcessor(ABC): """Base class for document processing""" - + def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - + @abstractmethod def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """ Process raw content into structured documents - + Args: raw_content: Raw text content metadata: Document metadata - + Returns: List of processed document chunks """ pass - + def chunk_text(self, text: str) -> List[str]: """ Chunk text into smaller pieces with overlap - + Args: text: Text to chunk - + Returns: List of text chunks """ chunks = [] start = 0 text_length = len(text) - + while start < text_length: end = start + self.chunk_size - + # Find the last sentence boundary within chunk_size if end < text_length: # Look for sentence endings - for sep in ['. ', '.\n', '! ', '!\n', '? ', '?\n']: + for sep in [". ", ".\n", "! ", "!\n", "? ", "?\n"]: last_sep = text.rfind(sep, start, end) if last_sep != -1: end = last_sep + 1 break - + chunk = text[start:end].strip() if chunk: chunks.append(chunk) - + # Move start forward with overlap - start = end - self.chunk_overlap if end - self.chunk_overlap > start else end - + start = ( + end - self.chunk_overlap if end - self.chunk_overlap > start else end + ) + return chunks - + def extract_metadata(self, content: str) -> Dict[str, Any]: """Extract metadata from content (can be overridden)""" return {} @@ -77,72 +79,71 @@ class RegulationProcessor(DocumentProcessor): Processor for regulations, policies, and compliance documents Common across insurance, healthcare, manufacturing, etc. """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process regulation documents""" # Extract sections and subsections sections = self._extract_sections(raw_content) - + documents = [] - base_id = metadata.get('id', 'reg_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "reg_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, section in enumerate(sections): doc = Document( id=f"{base_id}_chunk_{idx}", - content=section['content'], + content=section["content"], doc_type=DocumentType.REGULATION, industry=industry, - title=metadata.get('title', 'Regulation Document'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), + title=metadata.get("title", "Regulation Document"), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), metadata={ **metadata, - 'section': section.get('section'), - 'subsection': section.get('subsection'), + "section": section.get("section"), + "subsection": section.get("subsection"), }, chunk_index=idx, total_chunks=len(sections), ) documents.append(doc) - + return documents - + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: """Extract sections from regulation text""" # Match section headers like "Section 1.2.3", "Article 5", etc. - section_pattern = r'(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)' - + section_pattern = r"(Section|Article|Part|Chapter)\s+(\d+(?:\.\d+)*)" + sections = [] - current_section = {'section': None, 'content': ''} - - lines = content.split('\n') + current_section = {"section": None, "content": ""} + + lines = content.split("\n") for line in lines: match = re.search(section_pattern, line, re.IGNORECASE) if match: # Save previous section - if current_section['content']: + if current_section["content"]: sections.append(current_section) # Start new section - current_section = { - 'section': match.group(0), - 'content': line + '\n' - } + current_section = {"section": match.group(0), "content": line + "\n"} else: - current_section['content'] += line + '\n' - + current_section["content"] += line + "\n" + # Add last section - if current_section['content']: + if current_section["content"]: sections.append(current_section) - + # If no sections found, treat entire content as one chunk if not sections: chunks = self.chunk_text(content) - sections = [{'section': f'Chunk {i+1}', 'content': chunk} - for i, chunk in enumerate(chunks)] - + sections = [ + {"section": f"Chunk {i+1}", "content": chunk} + for i, chunk in enumerate(chunks) + ] + return sections - + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse date string to datetime""" if not date_str: @@ -158,45 +159,53 @@ class HistoricalDataProcessor(DocumentProcessor): Processor for historical operational data - Work orders, maintenance logs, claim records, service history """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process historical data records""" - doc_type_str = metadata.get('doc_type', 'work_order') - doc_type = DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER - + doc_type_str = metadata.get("doc_type", "work_order") + doc_type = ( + DocumentType(doc_type_str) if doc_type_str else DocumentType.WORK_ORDER + ) + # Historical data is typically already structured # We may need minimal chunking - chunks = self.chunk_text(raw_content) if len(raw_content) > self.chunk_size else [raw_content] - + chunks = ( + self.chunk_text(raw_content) + if len(raw_content) > self.chunk_size + else [raw_content] + ) + documents = [] - base_id = metadata.get('id', 'hist_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "hist_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, chunk in enumerate(chunks): doc = Document( id=f"{base_id}_chunk_{idx}", content=chunk, doc_type=doc_type, industry=industry, - title=metadata.get('title', f'{doc_type.value.replace("_", " ").title()}'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - updated_at=self._parse_date(metadata.get('updated_at')), + title=metadata.get( + "title", f'{doc_type.value.replace("_", " ").title()}' + ), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + updated_at=self._parse_date(metadata.get("updated_at")), metadata={ **metadata, - 'record_type': doc_type.value, - 'record_id': metadata.get('record_id'), - 'status': metadata.get('status'), - 'priority': metadata.get('priority'), - 'assigned_to': metadata.get('assigned_to'), + "record_type": doc_type.value, + "record_id": metadata.get("record_id"), + "status": metadata.get("status"), + "priority": metadata.get("priority"), + "assigned_to": metadata.get("assigned_to"), }, chunk_index=idx, total_chunks=len(chunks), ) documents.append(doc) - + return documents - + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse date string to datetime""" if not date_str: @@ -207,7 +216,7 @@ def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: return datetime.fromisoformat(date_str) except (ValueError, AttributeError): try: - return datetime.strptime(date_str, '%Y-%m-%d') + return datetime.strptime(date_str, "%Y-%m-%d") except (ValueError, AttributeError): return None @@ -217,73 +226,83 @@ class KnowledgeBaseProcessor(DocumentProcessor): Processor for knowledge base articles, FAQs, and guides - Equipment repair guides, compliance advice, troubleshooting steps """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process knowledge base articles""" - doc_type = DocumentType.KNOWLEDGE_BASE if metadata.get('doc_type') == 'knowledge_base' else DocumentType.FAQ - + doc_type = ( + DocumentType.KNOWLEDGE_BASE + if metadata.get("doc_type") == "knowledge_base" + else DocumentType.FAQ + ) + # Extract Q&A pairs if FAQ format if doc_type == DocumentType.FAQ: qa_pairs = self._extract_qa_pairs(raw_content) if qa_pairs: return self._process_qa_pairs(qa_pairs, metadata) - + # Otherwise, process as regular article chunks = self.chunk_text(raw_content) - + documents = [] - base_id = metadata.get('id', 'kb_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "kb_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, chunk in enumerate(chunks): doc = Document( id=f"{base_id}_chunk_{idx}", content=chunk, doc_type=doc_type, industry=industry, - title=metadata.get('title', 'Knowledge Base Article'), - source=metadata.get('source'), - created_at=self._parse_date(metadata.get('created_at')), - updated_at=self._parse_date(metadata.get('updated_at')), - author=metadata.get('author'), + title=metadata.get("title", "Knowledge Base Article"), + source=metadata.get("source"), + created_at=self._parse_date(metadata.get("created_at")), + updated_at=self._parse_date(metadata.get("updated_at")), + author=metadata.get("author"), metadata={ **metadata, - 'category': metadata.get('category'), - 'tags': metadata.get('tags', []), - 'difficulty_level': metadata.get('difficulty_level'), + "category": metadata.get("category"), + "tags": metadata.get("tags", []), + "difficulty_level": metadata.get("difficulty_level"), }, chunk_index=idx, total_chunks=len(chunks), ) documents.append(doc) - + return documents - + def _extract_qa_pairs(self, content: str) -> List[Dict[str, str]]: """Extract Q&A pairs from FAQ content""" qa_pairs = [] - + # Try different FAQ formats # Format 1: Q: ... A: ... - pattern1 = r'Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)' + pattern1 = r"Q:\s*(.+?)\s*A:\s*(.+?)(?=Q:|$)" matches = re.findall(pattern1, content, re.DOTALL | re.IGNORECASE) if matches: - qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) - + qa_pairs.extend( + [{"question": q.strip(), "answer": a.strip()} for q, a in matches] + ) + # Format 2: Question/Answer headers - pattern2 = r'Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)' + pattern2 = r"Question:\s*(.+?)\s*Answer:\s*(.+?)(?=Question:|$)" matches = re.findall(pattern2, content, re.DOTALL | re.IGNORECASE) if matches: - qa_pairs.extend([{'question': q.strip(), 'answer': a.strip()} for q, a in matches]) - + qa_pairs.extend( + [{"question": q.strip(), "answer": a.strip()} for q, a in matches] + ) + return qa_pairs - - def _process_qa_pairs(self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any]) -> List[Document]: + + def _process_qa_pairs( + self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, Any] + ) -> List[Document]: """Process Q&A pairs into documents""" documents = [] - base_id = metadata.get('id', 'faq_' + str(hash(str(qa_pairs[0])))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "faq_" + str(hash(str(qa_pairs[0])))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, qa in enumerate(qa_pairs): content = f"Question: {qa['question']}\n\nAnswer: {qa['answer']}" doc = Document( @@ -291,20 +310,20 @@ def _process_qa_pairs(self, qa_pairs: List[Dict[str, str]], metadata: Dict[str, content=content, doc_type=DocumentType.FAQ, industry=industry, - title=qa['question'][:100], # Use question as title - source=metadata.get('source'), + title=qa["question"][:100], # Use question as title + source=metadata.get("source"), metadata={ **metadata, - 'question': qa['question'], - 'answer': qa['answer'], + "question": qa["question"], + "answer": qa["answer"], }, chunk_index=idx, total_chunks=len(qa_pairs), ) documents.append(doc) - + return documents - + def _parse_date(self, date_str: Optional[str]) -> Optional[datetime]: """Parse date string to datetime""" if not date_str: @@ -321,87 +340,89 @@ class ManualProcessor(DocumentProcessor): """ Processor for equipment manuals, operation guides, technical specifications """ - + def process(self, raw_content: str, metadata: Dict[str, Any]) -> List[Document]: """Process manual documents""" # Extract chapters and sections sections = self._extract_sections(raw_content) - + documents = [] - base_id = metadata.get('id', 'manual_' + str(hash(raw_content[:100]))) - industry = IndustryNiche(metadata.get('industry', 'generic')) - + base_id = metadata.get("id", "manual_" + str(hash(raw_content[:100]))) + industry = IndustryNiche(metadata.get("industry", "generic")) + for idx, section in enumerate(sections): doc = Document( id=f"{base_id}_chunk_{idx}", - content=section['content'], + content=section["content"], doc_type=DocumentType.MANUAL, industry=industry, - title=metadata.get('title', 'Equipment Manual'), - source=metadata.get('source'), - version=metadata.get('version'), + title=metadata.get("title", "Equipment Manual"), + source=metadata.get("source"), + version=metadata.get("version"), metadata={ **metadata, - 'chapter': section.get('chapter'), - 'section': section.get('section'), - 'equipment_model': metadata.get('equipment_model'), - 'manufacturer': metadata.get('manufacturer'), + "chapter": section.get("chapter"), + "section": section.get("section"), + "equipment_model": metadata.get("equipment_model"), + "manufacturer": metadata.get("manufacturer"), }, chunk_index=idx, total_chunks=len(sections), ) documents.append(doc) - + return documents - + def _extract_sections(self, content: str) -> List[Dict[str, Any]]: """Extract sections from manual text""" # Match chapter/section headers - section_pattern = r'(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)' - + section_pattern = r"(Chapter|Section)\s+(\d+(?:\.\d+)*):?\s*(.+?)(?=\n)" + sections = [] - current_section = {'chapter': None, 'section': None, 'content': ''} - - lines = content.split('\n') + current_section = {"chapter": None, "section": None, "content": ""} + + lines = content.split("\n") for line in lines: match = re.search(section_pattern, line, re.IGNORECASE) if match: # Save previous section - if current_section['content']: + if current_section["content"]: sections.append(current_section) # Start new section section_type = match.group(1).lower() section_num = match.group(2) - section_title = match.group(3).strip() if match.group(3) else '' + section_title = match.group(3).strip() if match.group(3) else "" current_section = { section_type: f"{section_type.title()} {section_num}", - 'section_title': section_title, - 'content': line + '\n' + "section_title": section_title, + "content": line + "\n", } else: - current_section['content'] += line + '\n' - + current_section["content"] += line + "\n" + # Add last section - if current_section['content']: + if current_section["content"]: sections.append(current_section) - + # If no sections found, chunk the content if not sections: chunks = self.chunk_text(content) - sections = [{'section': f'Chunk {i+1}', 'content': chunk} - for i, chunk in enumerate(chunks)] - + sections = [ + {"section": f"Chunk {i+1}", "content": chunk} + for i, chunk in enumerate(chunks) + ] + return sections def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: """ Factory function to get appropriate processor for document type - + Args: doc_type: Type of document **kwargs: Additional configuration for processor - + Returns: Configured document processor """ @@ -417,7 +438,6 @@ def get_processor(doc_type: DocumentType, **kwargs) -> DocumentProcessor: DocumentType.TECHNICAL_SPEC: ManualProcessor, # Similar processing DocumentType.PROCEDURE: ManualProcessor, # Similar processing } - + processor_class = processor_map.get(doc_type, DocumentProcessor) return processor_class(**kwargs) - diff --git a/src/deepiri_modelkit/rag/retrievers.py b/src/deepiri_modelkit/rag/retrievers.py index dcda0d0..bf40f86 100644 --- a/src/deepiri_modelkit/rag/retrievers.py +++ b/src/deepiri_modelkit/rag/retrievers.py @@ -12,7 +12,7 @@ class BaseRetriever(ABC): """Base class for retrievers""" - + @abstractmethod def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve relevant documents for query""" @@ -24,30 +24,30 @@ class MultiModalRetriever(BaseRetriever): Multi-modal retriever supporting text, images, tables, and code Useful for technical manuals, equipment guides, etc. """ - + def __init__( self, text_embeddings, image_embeddings=None, table_embeddings=None, - code_embeddings=None + code_embeddings=None, ): self.text_embeddings = text_embeddings self.image_embeddings = image_embeddings self.table_embeddings = table_embeddings self.code_embeddings = code_embeddings - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve documents across multiple modalities - + Currently focuses on text, but can be extended for other modalities """ # For now, delegate to text retrieval # Future: Add image, table, code retrieval results = self._retrieve_text(query) return results - + def _retrieve_text(self, query: RAGQuery) -> List[RetrievalResult]: """Retrieve text documents""" # This will be implemented by the concrete RAG engine @@ -61,115 +61,117 @@ class HybridRetriever(BaseRetriever): - Semantic search (vector similarity) - Keyword search (BM25) - Metadata filtering - + Provides better recall than pure semantic search """ - + def __init__( self, vector_retriever, keyword_retriever=None, vector_weight: float = 0.7, - keyword_weight: float = 0.3 + keyword_weight: float = 0.3, ): self.vector_retriever = vector_retriever self.keyword_retriever = keyword_retriever self.vector_weight = vector_weight self.keyword_weight = keyword_weight - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve using hybrid approach - + Combines vector and keyword search results """ results = [] - + # Vector search vector_results = self._retrieve_vector(query) results.extend(vector_results) - + # Keyword search (if available) if self.keyword_retriever: keyword_results = self._retrieve_keyword(query) results.extend(keyword_results) - + # Merge and re-score - merged_results = self._merge_results(vector_results, keyword_results if self.keyword_retriever else []) - + merged_results = self._merge_results( + vector_results, keyword_results if self.keyword_retriever else [] + ) + return merged_results - + def _retrieve_vector(self, query: RAGQuery) -> List[RetrievalResult]: """Vector similarity search""" # Placeholder - implemented by concrete engine return [] - + def _retrieve_keyword(self, query: RAGQuery) -> List[RetrievalResult]: """Keyword search using BM25 or similar""" # Placeholder - implemented by concrete engine return [] - + def _merge_results( self, vector_results: List[RetrievalResult], - keyword_results: List[RetrievalResult] + keyword_results: List[RetrievalResult], ) -> List[RetrievalResult]: """ Merge results from different retrievers using weighted scoring - + Uses Reciprocal Rank Fusion (RRF) for combining rankings """ # Create a dictionary to store combined scores doc_scores: Dict[str, Dict[str, Any]] = {} - + # Process vector results for rank, result in enumerate(vector_results): doc_id = result.document.id # RRF score: 1 / (k + rank) where k=60 is common rrf_score = 1.0 / (60 + rank + 1) doc_scores[doc_id] = { - 'document': result.document, - 'vector_score': result.score, - 'vector_rrf': rrf_score, - 'keyword_rrf': 0.0, - 'keyword_score': 0.0, + "document": result.document, + "vector_score": result.score, + "vector_rrf": rrf_score, + "keyword_rrf": 0.0, + "keyword_score": 0.0, } - + # Process keyword results for rank, result in enumerate(keyword_results): doc_id = result.document.id rrf_score = 1.0 / (60 + rank + 1) - + if doc_id in doc_scores: - doc_scores[doc_id]['keyword_rrf'] = rrf_score - doc_scores[doc_id]['keyword_score'] = result.score + doc_scores[doc_id]["keyword_rrf"] = rrf_score + doc_scores[doc_id]["keyword_score"] = result.score else: doc_scores[doc_id] = { - 'document': result.document, - 'vector_score': 0.0, - 'vector_rrf': 0.0, - 'keyword_rrf': rrf_score, - 'keyword_score': result.score, + "document": result.document, + "vector_score": 0.0, + "vector_rrf": 0.0, + "keyword_rrf": rrf_score, + "keyword_score": result.score, } - + # Calculate combined scores merged = [] for doc_id, scores in doc_scores.items(): combined_rrf = ( - self.vector_weight * scores['vector_rrf'] + - self.keyword_weight * scores['keyword_rrf'] + self.vector_weight * scores["vector_rrf"] + + self.keyword_weight * scores["keyword_rrf"] ) - + result = RetrievalResult( - document=scores['document'], + document=scores["document"], score=combined_rrf, rerank_score=None, ) merged.append(result) - + # Sort by combined score merged.sort(key=lambda x: x.score, reverse=True) - + return merged @@ -180,54 +182,54 @@ class ContextualRetriever(BaseRetriever): - Temporal context (recent vs historical) - Industry context (specific to niche) """ - + def __init__( self, base_retriever: BaseRetriever, use_user_context: bool = True, use_temporal_context: bool = True, - use_industry_context: bool = True + use_industry_context: bool = True, ): self.base_retriever = base_retriever self.use_user_context = use_user_context self.use_temporal_context = use_temporal_context self.use_industry_context = use_industry_context - + def retrieve(self, query: RAGQuery) -> List[RetrievalResult]: """ Retrieve with contextual awareness """ # Get base results results = self.base_retriever.retrieve(query) - + # Apply contextual reranking if self.use_temporal_context: results = self._apply_temporal_boost(results, query) - + if self.use_industry_context: results = self._apply_industry_boost(results, query) - + return results - + def _apply_temporal_boost( - self, - results: List[RetrievalResult], - query: RAGQuery + self, results: List[RetrievalResult], query: RAGQuery ) -> List[RetrievalResult]: """ Boost recent documents (useful for regulations, updates) """ import time from datetime import datetime - + current_time = datetime.now().timestamp() - + for result in results: if result.document.updated_at or result.document.created_at: - doc_time = (result.document.updated_at or result.document.created_at).timestamp() + doc_time = ( + result.document.updated_at or result.document.created_at + ).timestamp() # Calculate age in days age_days = (current_time - doc_time) / 86400 - + # Apply decay: more recent = higher boost # Documents within 30 days get full boost # Older documents gradually lose boost @@ -239,28 +241,26 @@ def _apply_temporal_boost( temporal_boost = 0.8 else: temporal_boost = 0.7 - + result.score *= temporal_boost - + # Re-sort by adjusted scores results.sort(key=lambda x: x.score, reverse=True) return results - + def _apply_industry_boost( - self, - results: List[RetrievalResult], - query: RAGQuery + self, results: List[RetrievalResult], query: RAGQuery ) -> List[RetrievalResult]: """ Boost documents matching the query's industry """ if not query.industry: return results - + for result in results: if result.document.industry == query.industry: result.score *= 1.1 # 10% boost for industry match - + # Re-sort results.sort(key=lambda x: x.score, reverse=True) return results @@ -269,20 +269,19 @@ def _apply_industry_boost( def get_retriever(retriever_type: str, **kwargs) -> BaseRetriever: """ Factory function to get retriever - + Args: retriever_type: Type of retriever ('hybrid', 'multimodal', 'contextual') **kwargs: Configuration for retriever - + Returns: Configured retriever """ retriever_map = { - 'hybrid': HybridRetriever, - 'multimodal': MultiModalRetriever, - 'contextual': ContextualRetriever, + "hybrid": HybridRetriever, + "multimodal": MultiModalRetriever, + "contextual": ContextualRetriever, } - + retriever_class = retriever_map.get(retriever_type, HybridRetriever) return retriever_class(**kwargs) - diff --git a/src/deepiri_modelkit/rag/testing.py b/src/deepiri_modelkit/rag/testing.py index bc11473..f118dc2 100644 --- a/src/deepiri_modelkit/rag/testing.py +++ b/src/deepiri_modelkit/rag/testing.py @@ -14,13 +14,14 @@ @dataclass class TestCase: """Test case for RAG evaluation""" + query: str expected_doc_ids: List[str] # Document IDs that should be retrieved expected_doc_types: Optional[List[DocumentType]] = None min_score: float = 0.7 # Minimum similarity score top_k: int = 5 metadata: Dict[str, Any] = None - + def __post_init__(self): if self.metadata is None: self.metadata = {} @@ -29,6 +30,7 @@ def __post_init__(self): @dataclass class TestResult: """Result of a test case""" + test_case: TestCase retrieved_doc_ids: List[str] retrieved_scores: List[float] @@ -37,7 +39,7 @@ class TestResult: f1_score: float passed: bool error: Optional[str] = None - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary""" return { @@ -56,37 +58,37 @@ class RAGEvaluator: """ Evaluator for RAG system performance """ - + def __init__(self, rag_engine): self.rag_engine = rag_engine - + def evaluate( - self, - test_cases: List[TestCase], - industry: Optional[IndustryNiche] = None + self, test_cases: List[TestCase], industry: Optional[IndustryNiche] = None ) -> Dict[str, Any]: """ Evaluate RAG system on test cases - + Args: test_cases: List of test cases industry: Industry context - + Returns: Evaluation results with metrics """ results = [] - + for test_case in test_cases: result = self._evaluate_test_case(test_case, industry) results.append(result) - + # Calculate aggregate metrics - total_precision = sum(r.precision for r in results) / len(results) if results else 0.0 + total_precision = ( + sum(r.precision for r in results) / len(results) if results else 0.0 + ) total_recall = sum(r.recall for r in results) / len(results) if results else 0.0 total_f1 = sum(r.f1_score for r in results) / len(results) if results else 0.0 passed_count = sum(1 for r in results if r.passed) - + return { "total_tests": len(test_cases), "passed": passed_count, @@ -97,11 +99,9 @@ def evaluate( "pass_rate": passed_count / len(test_cases) if test_cases else 0.0, "results": [r.to_dict() for r in results], } - + def _evaluate_test_case( - self, - test_case: TestCase, - industry: Optional[IndustryNiche] + self, test_case: TestCase, industry: Optional[IndustryNiche] ) -> TestResult: """Evaluate a single test case""" try: @@ -110,20 +110,20 @@ def _evaluate_test_case( query=test_case.query, industry=industry, doc_types=test_case.expected_doc_types, - top_k=test_case.top_k + top_k=test_case.top_k, ) - + # Retrieve documents results = self.rag_engine.retrieve(query) - + # Extract IDs and scores retrieved_doc_ids = [r.document.id for r in results] retrieved_scores = [r.score for r in results] - + # Calculate precision, recall, F1 expected_set = set(test_case.expected_doc_ids) retrieved_set = set(retrieved_doc_ids) - + if not retrieved_set: precision = 0.0 recall = 0.0 @@ -131,22 +131,30 @@ def _evaluate_test_case( else: # Precision: relevant retrieved / total retrieved relevant_retrieved = len(expected_set & retrieved_set) - precision = relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 - + precision = ( + relevant_retrieved / len(retrieved_set) if retrieved_set else 0.0 + ) + # Recall: relevant retrieved / total relevant recall = relevant_retrieved / len(expected_set) if expected_set else 0.0 - + # F1 score - f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 - + f1_score = ( + 2 * (precision * recall) / (precision + recall) + if (precision + recall) > 0 + else 0.0 + ) + # Check if passed passed = ( - precision >= 0.7 and - recall >= 0.7 and - f1_score >= 0.7 and - (not retrieved_scores or max(retrieved_scores) >= test_case.min_score) + precision >= 0.7 + and recall >= 0.7 + and f1_score >= 0.7 + and ( + not retrieved_scores or max(retrieved_scores) >= test_case.min_score + ) ) - + return TestResult( test_case=test_case, retrieved_doc_ids=retrieved_doc_ids, @@ -154,9 +162,9 @@ def _evaluate_test_case( precision=precision, recall=recall, f1_score=f1_score, - passed=passed + passed=passed, ) - + except Exception as e: return TestResult( test_case=test_case, @@ -166,7 +174,7 @@ def _evaluate_test_case( recall=0.0, f1_score=0.0, passed=False, - error=str(e) + error=str(e), ) @@ -174,44 +182,43 @@ class RAGTestFixture: """ Test fixture for creating test data and scenarios """ - + @staticmethod def create_test_documents( - industry: IndustryNiche = IndustryNiche.MANUFACTURING, - num_documents: int = 10 + industry: IndustryNiche = IndustryNiche.MANUFACTURING, num_documents: int = 10 ) -> List[Document]: """Create test documents""" documents = [] - + for i in range(num_documents): doc = Document( id=f"test_doc_{i}", content=f"Test document {i} content. This is sample content for testing RAG retrieval.", - doc_type=DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG, + doc_type=( + DocumentType.MANUAL if i % 2 == 0 else DocumentType.MAINTENANCE_LOG + ), industry=industry, title=f"Test Document {i}", source="test_fixture", - metadata={"test_index": i} + metadata={"test_index": i}, ) documents.append(doc) - + return documents - + @staticmethod - def create_test_cases( - num_cases: int = 5 - ) -> List[TestCase]: + def create_test_cases(num_cases: int = 5) -> List[TestCase]: """Create test cases""" test_cases = [] - + for i in range(num_cases): test_case = TestCase( query=f"test query {i}", expected_doc_ids=[f"test_doc_{i}", f"test_doc_{i+1}"], - top_k=5 + top_k=5, ) test_cases.append(test_case) - + return test_cases @@ -219,47 +226,45 @@ class PerformanceBenchmark: """ Performance benchmarking for RAG operations """ - + def __init__(self, rag_engine): self.rag_engine = rag_engine - + def benchmark_retrieval( - self, - queries: List[str], - iterations: int = 10 + self, queries: List[str], iterations: int = 10 ) -> Dict[str, Any]: """ Benchmark retrieval performance - + Args: queries: List of test queries iterations: Number of iterations per query - + Returns: Performance metrics """ import time - + total_time = 0.0 total_queries = 0 times = [] - + for query_text in queries: query = RAGQuery(query=query_text, top_k=5) - + for _ in range(iterations): start = time.time() results = self.rag_engine.retrieve(query) elapsed = time.time() - start - + total_time += elapsed total_queries += 1 times.append(elapsed * 1000) # Convert to ms - + avg_time_ms = (total_time / total_queries) * 1000 if total_queries > 0 else 0.0 min_time_ms = min(times) if times else 0.0 max_time_ms = max(times) if times else 0.0 - + return { "total_queries": total_queries, "avg_time_ms": avg_time_ms, @@ -267,66 +272,63 @@ def benchmark_retrieval( "max_time_ms": max_time_ms, "queries_per_second": total_queries / total_time if total_time > 0 else 0.0, } - + def benchmark_indexing( - self, - documents: List[Document], - batch_sizes: List[int] = [1, 10, 100] + self, documents: List[Document], batch_sizes: List[int] = [1, 10, 100] ) -> Dict[str, Any]: """ Benchmark indexing performance - + Args: documents: Documents to index batch_sizes: Different batch sizes to test - + Returns: Performance metrics for each batch size """ import time - + results = {} - + for batch_size in batch_sizes: batches = [ - documents[i:i + batch_size] + documents[i : i + batch_size] for i in range(0, len(documents), batch_size) ] - + start = time.time() for batch in batches: self.rag_engine.index_documents(batch) elapsed = time.time() - start - + results[f"batch_size_{batch_size}"] = { "total_documents": len(documents), "num_batches": len(batches), "total_time_seconds": elapsed, - "avg_time_per_doc_ms": (elapsed / len(documents)) * 1000 if documents else 0.0, + "avg_time_per_doc_ms": ( + (elapsed / len(documents)) * 1000 if documents else 0.0 + ), "docs_per_second": len(documents) / elapsed if elapsed > 0 else 0.0, } - + return results def create_evaluation_dataset( - industry: IndustryNiche, - num_documents: int = 100, - num_queries: int = 20 + industry: IndustryNiche, num_documents: int = 100, num_queries: int = 20 ) -> Tuple[List[Document], List[TestCase]]: """ Create a complete evaluation dataset - + Args: industry: Industry for documents num_documents: Number of documents to create num_queries: Number of test queries - + Returns: Tuple of (documents, test_cases) """ documents = RAGTestFixture.create_test_documents(industry, num_documents) test_cases = RAGTestFixture.create_test_cases(num_queries) - - return documents, test_cases + return documents, test_cases diff --git a/src/deepiri_modelkit/registry/adapters/__init__.py b/src/deepiri_modelkit/registry/adapters/__init__.py index e949845..44e0e80 100644 --- a/src/deepiri_modelkit/registry/adapters/__init__.py +++ b/src/deepiri_modelkit/registry/adapters/__init__.py @@ -1,2 +1 @@ """Storage adapters for model registry""" - diff --git a/src/deepiri_modelkit/registry/model_registry.py b/src/deepiri_modelkit/registry/model_registry.py index 5c5ffdd..1d2fe3c 100644 --- a/src/deepiri_modelkit/registry/model_registry.py +++ b/src/deepiri_modelkit/registry/model_registry.py @@ -2,6 +2,7 @@ Unified model registry client Supports MLflow, S3/MinIO, and local storage """ + import os from typing import Dict, Any, Optional from pathlib import Path @@ -17,7 +18,7 @@ class ModelRegistryClient: Unified client for model registry operations Supports MLflow, S3/MinIO, and local storage """ - + def __init__( self, registry_type: str = "mlflow", # mlflow, s3, local @@ -26,11 +27,11 @@ def __init__( s3_access_key: Optional[str] = None, s3_secret_key: Optional[str] = None, s3_bucket: Optional[str] = None, - local_path: Optional[str] = None + local_path: Optional[str] = None, ): """ Initialize model registry client - + Args: registry_type: Type of registry (mlflow, s3, local) mlflow_tracking_uri: MLflow tracking URI (defaults to MLFLOW_TRACKING_URI env var or http://mlflow:5000) @@ -41,19 +42,23 @@ def __init__( local_path: Local storage path """ self.registry_type = registry_type - + if registry_type == "mlflow": # Use provided URI, or environment variable, or default - tracking_uri = mlflow_tracking_uri or os.getenv("MLFLOW_TRACKING_URI") or "http://mlflow:5000" + tracking_uri = ( + mlflow_tracking_uri + or os.getenv("MLFLOW_TRACKING_URI") + or "http://mlflow:5000" + ) mlflow.set_tracking_uri(tracking_uri) self.client = mlflow self.tracking_uri = tracking_uri elif registry_type == "s3": self.s3_client = boto3.client( - 's3', + "s3", endpoint_url=s3_endpoint, aws_access_key_id=s3_access_key, - aws_secret_access_key=s3_secret_key + aws_secret_access_key=s3_secret_key, ) self.s3_bucket = s3_bucket elif registry_type == "local": @@ -61,23 +66,19 @@ def __init__( self.local_path.mkdir(parents=True, exist_ok=True) else: raise ValueError(f"Unknown registry type: {registry_type}") - + def register_model( - self, - model_name: str, - version: str, - model_path: str, - metadata: Dict[str, Any] + self, model_name: str, version: str, model_path: str, metadata: Dict[str, Any] ) -> bool: """ Register model in registry - + Args: model_name: Model name version: Model version model_path: Path to model file/directory metadata: Model metadata - + Returns: True if successful """ @@ -87,56 +88,55 @@ def register_model( model_uri = f"runs:/{metadata.get('run_id', 'latest')}/model" mlflow.register_model(model_uri, f"{model_name}-{version}") return True - + elif self.registry_type == "s3": # Upload to S3 s3_key = f"models/{model_name}/{version}/model" self.s3_client.upload_file(model_path, self.s3_bucket, s3_key) - + # Upload metadata import json + metadata_key = f"models/{model_name}/{version}/metadata.json" self.s3_client.put_object( - Bucket=self.s3_bucket, - Key=metadata_key, - Body=json.dumps(metadata) + Bucket=self.s3_bucket, Key=metadata_key, Body=json.dumps(metadata) ) return True - + elif self.registry_type == "local": # Copy to local storage model_dir = self.local_path / model_name / version model_dir.mkdir(parents=True, exist_ok=True) - + import shutil + if os.path.isdir(model_path): shutil.copytree(model_path, model_dir / "model", dirs_exist_ok=True) else: shutil.copy2(model_path, model_dir / "model") - + # Save metadata import json + with open(model_dir / "metadata.json", "w") as f: json.dump(metadata, f) - + return True - + except Exception as e: print(f"Error registering model: {e}") return False - + def get_model( - self, - model_name: str, - version: Optional[str] = None + self, model_name: str, version: Optional[str] = None ) -> Dict[str, Any]: """ Get model information from registry - + Args: model_name: Model name version: Model version (optional, gets latest if not specified) - + Returns: Model information dict """ @@ -146,95 +146,91 @@ def get_model( model_uri = f"models:/{model_name}/{version}" else: model_uri = f"models:/{model_name}/latest" - + model = mlflow.pyfunc.load_model(model_uri) - return { - "model": model, - "uri": model_uri, - "type": "mlflow" - } - + return {"model": model, "uri": model_uri, "type": "mlflow"} + elif self.registry_type == "s3": if not version: # List versions and get latest prefix = f"models/{model_name}/" response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, - Prefix=prefix, - Delimiter="/" + Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" ) - versions = [obj["Prefix"].split("/")[-2] for obj in response.get("CommonPrefixes", [])] + versions = [ + obj["Prefix"].split("/")[-2] + for obj in response.get("CommonPrefixes", []) + ] version = max(versions) if versions else None - + if not version: raise ValueError(f"Model {model_name} not found") - + # Download metadata metadata_key = f"models/{model_name}/{version}/metadata.json" - response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=metadata_key) + response = self.s3_client.get_object( + Bucket=self.s3_bucket, Key=metadata_key + ) import json + metadata = json.loads(response["Body"].read()) - + return { "model_path": f"s3://{self.s3_bucket}/models/{model_name}/{version}/model", "metadata": metadata, - "type": "s3" + "type": "s3", } - + elif self.registry_type == "local": if not version: # Get latest version model_dir = self.local_path / model_name if not model_dir.exists(): raise ValueError(f"Model {model_name} not found") - + versions = [d.name for d in model_dir.iterdir() if d.is_dir()] version = max(versions) if versions else None - + if not version: raise ValueError(f"Model {model_name} not found") - + model_dir = self.local_path / model_name / version metadata_path = model_dir / "metadata.json" - + import json + with open(metadata_path) as f: metadata = json.load(f) - + return { "model_path": str(model_dir / "model"), "metadata": metadata, - "type": "local" + "type": "local", } - + except Exception as e: print(f"Error getting model: {e}") raise - - def download_model( - self, - model_name: str, - version: str, - destination: str - ) -> str: + + def download_model(self, model_name: str, version: str, destination: str) -> str: """ Download model to destination - + Args: model_name: Model name version: Model version destination: Local destination path - + Returns: Local path to downloaded model """ model_info = self.get_model(model_name, version) - + if self.registry_type == "s3": # Download from S3 s3_key = f"models/{model_name}/{version}/model" os.makedirs(destination, exist_ok=True) - + # Check if it's a file or directory try: self.s3_client.head_object(Bucket=self.s3_bucket, Key=s3_key) @@ -245,92 +241,93 @@ def download_model( except ClientError: # It's a directory, list and download all files prefix = f"{s3_key}/" - paginator = self.s3_client.get_paginator('list_objects_v2') + paginator = self.s3_client.get_paginator("list_objects_v2") for page in paginator.paginate(Bucket=self.s3_bucket, Prefix=prefix): - for obj in page.get('Contents', []): - key = obj['Key'] - local_file = os.path.join(destination, key[len(prefix):]) + for obj in page.get("Contents", []): + key = obj["Key"] + local_file = os.path.join(destination, key[len(prefix) :]) os.makedirs(os.path.dirname(local_file), exist_ok=True) self.s3_client.download_file(self.s3_bucket, key, local_file) return destination - + elif self.registry_type == "local": # Copy from local source = model_info["model_path"] import shutil + if os.path.isdir(source): shutil.copytree(source, destination, dirs_exist_ok=True) else: os.makedirs(os.path.dirname(destination), exist_ok=True) shutil.copy2(source, destination) return destination - + else: # MLflow handles loading directly return model_info["uri"] - + def list_models(self, model_name: Optional[str] = None) -> list: """ List available models - + Args: model_name: Optional model name filter - + Returns: List of model information dicts """ try: if self.registry_type == "mlflow": if model_name: - models = mlflow.search_registered_models(filter_string=f"name='{model_name}'") + models = mlflow.search_registered_models( + filter_string=f"name='{model_name}'" + ) else: models = mlflow.search_registered_models() - + return [ { "name": m.name, "versions": [v.version for v in m.latest_versions], - "latest_version": m.latest_versions[0].version if m.latest_versions else None + "latest_version": ( + m.latest_versions[0].version if m.latest_versions else None + ), } for m in models ] - + elif self.registry_type == "s3": prefix = "models/" if model_name: prefix = f"models/{model_name}/" - + response = self.s3_client.list_objects_v2( - Bucket=self.s3_bucket, - Prefix=prefix, - Delimiter="/" + Bucket=self.s3_bucket, Prefix=prefix, Delimiter="/" ) - + models = [] for prefix_obj in response.get("CommonPrefixes", []): model_path = prefix_obj["Prefix"] parts = model_path.strip("/").split("/") if len(parts) >= 2: - models.append({ - "name": parts[1], - "path": model_path - }) - + models.append({"name": parts[1], "path": model_path}) + return models - + elif self.registry_type == "local": models = [] for model_dir in self.local_path.iterdir(): if model_dir.is_dir(): versions = [d.name for d in model_dir.iterdir() if d.is_dir()] - models.append({ - "name": model_dir.name, - "versions": versions, - "latest_version": max(versions) if versions else None - }) + models.append( + { + "name": model_dir.name, + "versions": versions, + "latest_version": max(versions) if versions else None, + } + ) return models - + except Exception as e: print(f"Error listing models: {e}") return [] - diff --git a/src/deepiri_modelkit/streaming/event_stream.py b/src/deepiri_modelkit/streaming/event_stream.py index 40e5ed4..904fe0b 100644 --- a/src/deepiri_modelkit/streaming/event_stream.py +++ b/src/deepiri_modelkit/streaming/event_stream.py @@ -1,6 +1,7 @@ """ Redis Streams client for event-driven architecture """ + import redis.asyncio as redis from typing import Dict, Any, Optional, AsyncIterator, Callable import json @@ -15,17 +16,17 @@ class StreamingClient: """ Redis Streams client for publishing and subscribing to events """ - + def __init__( self, redis_url: Optional[str] = None, redis_host: str = "redis", redis_port: int = 6379, - redis_password: Optional[str] = None + redis_password: Optional[str] = None, ): """ Initialize streaming client - + Args: redis_url: Full Redis URL (redis://password@host:port) redis_host: Redis host (if not using redis_url) @@ -39,50 +40,44 @@ def __init__( host=redis_host, port=redis_port, password=redis_password, - decode_responses=True + decode_responses=True, ) self._running = False self._subscriptions = {} - + async def connect(self): """Connect to Redis""" await self.redis.ping() - + async def disconnect(self): """Disconnect from Redis""" await self.redis.close() - + async def publish( - self, - topic: str, - event: Dict[str, Any], - max_length: Optional[int] = 10000 + self, topic: str, event: Dict[str, Any], max_length: Optional[int] = 10000 ) -> str: """ Publish event to stream - + Args: topic: Stream topic name event: Event data (dict) max_length: Max stream length (truncate old messages) - + Returns: Message ID """ # Ensure event has timestamp if "timestamp" not in event: event["timestamp"] = datetime.utcnow().isoformat() - + # Publish to stream message_id = await self.redis.xadd( - topic, - event, - maxlen=max_length, - approximate=True + topic, event, maxlen=max_length, approximate=True ) - + return message_id - + async def subscribe( self, topic: str, @@ -90,11 +85,11 @@ async def subscribe( consumer_group: Optional[str] = None, consumer_name: Optional[str] = None, last_id: str = "0", - block_ms: int = 1000 + block_ms: int = 1000, ) -> AsyncIterator[Dict[str, Any]]: """ Subscribe to stream and yield events - + Args: topic: Stream topic name callback: Optional callback function @@ -102,7 +97,7 @@ async def subscribe( consumer_name: Consumer name (unique per consumer) last_id: Last message ID to read from block_ms: Block time in milliseconds - + Yields: Event data (dict) """ @@ -110,17 +105,14 @@ async def subscribe( if consumer_group: try: await self.redis.xgroup_create( - topic, - consumer_group, - id="0", - mkstream=True + topic, consumer_group, id="0", mkstream=True ) except redis.ResponseError as e: if "BUSYGROUP" not in str(e): raise - + self._running = True - + while self._running: try: if consumer_group and consumer_name: @@ -130,51 +122,53 @@ async def subscribe( consumer_name, {topic: ">"}, count=10, - block=block_ms + block=block_ms, ) else: # Direct read messages = await self.redis.xread( - {topic: last_id}, - count=10, - block=block_ms + {topic: last_id}, count=10, block=block_ms ) - + for stream_name, stream_messages in messages: for msg_id, data in stream_messages: # Yield event yield data - + # Call callback if provided if callback: try: - await callback(data) if asyncio.iscoroutinefunction(callback) else callback(data) + ( + await callback(data) + if asyncio.iscoroutinefunction(callback) + else callback(data) + ) except Exception as e: print(f"Callback error: {e}") - + # Update last_id for next read last_id = msg_id - + # Acknowledge if using consumer group if consumer_group and consumer_name: await self.redis.xack(topic, consumer_group, msg_id) - + except asyncio.CancelledError: break except Exception as e: print(f"Stream subscription error: {e}") await asyncio.sleep(1) - + async def subscribe_async( self, topic: str, callback: Callable[[Dict[str, Any]], None], consumer_group: Optional[str] = None, - consumer_name: Optional[str] = None + consumer_name: Optional[str] = None, ): """ Subscribe to stream in background task - + Args: topic: Stream topic name callback: Callback function @@ -182,23 +176,19 @@ async def subscribe_async( consumer_name: Consumer name """ async for event in self.subscribe( - topic, - callback, - consumer_group, - consumer_name + topic, callback, consumer_group, consumer_name ): pass # Callback handles events - + def stop(self): """Stop all subscriptions""" self._running = False - + async def get_stream_info(self, topic: str) -> Dict[str, Any]: """Get stream information""" info = await self.redis.xinfo_stream(topic) return dict(info) - + async def get_stream_length(self, topic: str) -> int: """Get number of messages in stream""" return await self.redis.xlen(topic) - diff --git a/src/deepiri_modelkit/streaming/schemas.py b/src/deepiri_modelkit/streaming/schemas.py index 388fa31..830cbae 100644 --- a/src/deepiri_modelkit/streaming/schemas.py +++ b/src/deepiri_modelkit/streaming/schemas.py @@ -1,6 +1,7 @@ """ Streaming event schemas and validation """ + from .topics import StreamTopics from ..contracts.events import ( BaseEvent, @@ -12,7 +13,6 @@ TrainingEvent, ) - # Map topics to event schemas TOPIC_EVENT_SCHEMAS = { StreamTopics.MODEL_EVENTS: [ModelReadyEvent, ModelLoadedEvent], @@ -26,31 +26,30 @@ def validate_event(topic: str, event_data: dict) -> BaseEvent: """ Validate event against schema - + Args: topic: Stream topic event_data: Event data dict - + Returns: Validated event object - + Raises: ValueError: If event doesn't match schema """ if topic not in TOPIC_EVENT_SCHEMAS: # Unknown topic, return base event return BaseEvent(**event_data) - + schemas = TOPIC_EVENT_SCHEMAS[topic] event_type = event_data.get("event") - + # Try to match event type to schema for schema in schemas: try: return schema(**event_data) except Exception: continue - + # Fallback to base event return BaseEvent(**event_data) - diff --git a/src/deepiri_modelkit/streaming/sidecar_utils.py b/src/deepiri_modelkit/streaming/sidecar_utils.py index 3af5322..1ec9faf 100644 --- a/src/deepiri_modelkit/streaming/sidecar_utils.py +++ b/src/deepiri_modelkit/streaming/sidecar_utils.py @@ -13,7 +13,9 @@ from urllib.parse import urlparse -def env_float(name: str, default: float, logger: Optional[Callable[[str], None]] = None) -> float: +def env_float( + name: str, default: float, logger: Optional[Callable[[str], None]] = None +) -> float: """Read a float env var with safe fallback.""" raw = os.getenv(name) if raw is None: diff --git a/src/deepiri_modelkit/streaming/topics.py b/src/deepiri_modelkit/streaming/topics.py index 4e99a15..0e50984 100644 --- a/src/deepiri_modelkit/streaming/topics.py +++ b/src/deepiri_modelkit/streaming/topics.py @@ -1,19 +1,28 @@ """ Stream topic definitions """ + from enum import Enum class StreamTopics(str, Enum): """Redis Stream topics""" + MODEL_EVENTS = "model-events" INFERENCE_EVENTS = "inference-events" PLATFORM_EVENTS = "platform-events" AGI_DECISIONS = "agi-decisions" TRAINING_EVENTS = "training-events" - + # LIS document routing streams (document.* namespace). + DOCUMENT_VECTORIZE = "document.vectorize" + DOCUMENT_TRAINING = "document.training" + DOCUMENT_STRUCTURED = "document.structured" + DOCUMENT_ARTIFACTS = "document.artifacts" + # Cyrex runtime training signals consumed by Helox. + HELOX_TRAINING_RAW = "pipeline.helox-training.raw" + HELOX_TRAINING_STRUCTURED = "pipeline.helox-training.structured" + @classmethod def all(cls) -> list: """Get all topic names""" return [topic.value for topic in cls] - diff --git a/src/deepiri_modelkit/utils/__init__.py b/src/deepiri_modelkit/utils/__init__.py index a7dfcc9..dbfdac6 100644 --- a/src/deepiri_modelkit/utils/__init__.py +++ b/src/deepiri_modelkit/utils/__init__.py @@ -2,6 +2,7 @@ try: from .device import get_device, get_torch_device + __all__ = ["get_device", "get_torch_device"] except ImportError: __all__ = [] diff --git a/src/deepiri_modelkit/utils/device.py b/src/deepiri_modelkit/utils/device.py index 8b3922a..9c08107 100644 --- a/src/deepiri_modelkit/utils/device.py +++ b/src/deepiri_modelkit/utils/device.py @@ -2,10 +2,12 @@ GPU Device Detection Utility Automatically detects and uses GPU (CUDA) if available, falls back to CPU """ + import os try: import torch + HAS_TORCH = True except ImportError: HAS_TORCH = False @@ -28,7 +30,9 @@ def get_device() -> str: # Log diagnostic information for debugging logger.debug(f"PyTorch version: {torch.__version__}") - logger.debug(f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}") + logger.debug( + f"CUDA available (torch.cuda.is_available()): {torch.cuda.is_available()}" + ) if torch.cuda.is_available(): try: @@ -50,14 +54,17 @@ def get_device() -> str: # Try to get the list of supported compute capabilities # PyTorch 2.9.1 with CUDA 12.6 supports up to sm_90 # RTX 5080/5090 requires sm_120 support (CUDA 12.8+) - test_tensor = torch.tensor([1.0], device='cuda') + test_tensor = torch.tensor([1.0], device="cuda") result = test_tensor * 2.0 _ = result.cpu() del test_tensor, result torch.cuda.empty_cache() except RuntimeError as e: - if "no kernel image is available for execution on the device" in str(e) or \ - "cudaErrorNoKernelImageForDevice" in str(e): + if ( + "no kernel image is available for execution on the device" + in str(e) + or "cudaErrorNoKernelImageForDevice" in str(e) + ): logger.error( f"RTX 5080/5090 (sm_{cuda_capability[0]}.{cuda_capability[1]}) detected, but PyTorch doesn't support this compute capability. " f"Current PyTorch supports up to sm_90. " @@ -69,7 +76,7 @@ def get_device() -> str: raise # Test GPU functionality with a simple operation - test_tensor = torch.tensor([1.0], device='cuda') + test_tensor = torch.tensor([1.0], device="cuda") result = test_tensor * 2.0 _ = result.cpu() # Ensure operation completes del test_tensor, result @@ -83,8 +90,10 @@ def get_device() -> str: except RuntimeError as cuda_error: error_msg = str(cuda_error) # Check if this is the RTX 5080 compatibility issue - if "no kernel image is available for execution on the device" in error_msg or \ - "cudaErrorNoKernelImageForDevice" in error_msg: + if ( + "no kernel image is available for execution on the device" in error_msg + or "cudaErrorNoKernelImageForDevice" in error_msg + ): logger.error( f"GPU compute capability not supported by current PyTorch installation. " f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'Unknown'}, " @@ -110,7 +119,9 @@ def get_device() -> str: # Check if we're in Docker and might need NVIDIA Container Toolkit if os.path.exists("/.dockerenv"): - logger.debug("Running in Docker container - ensure NVIDIA Container Toolkit is installed") + logger.debug( + "Running in Docker container - ensure NVIDIA Container Toolkit is installed" + ) # Check for NVIDIA runtime if os.path.exists("/proc/driver/nvidia"): logger.warning( @@ -120,16 +131,18 @@ def get_device() -> str: ) # Check MPS (Apple Silicon) - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): try: - test_tensor = torch.tensor([1.0], device='mps') + test_tensor = torch.tensor([1.0], device="mps") result = test_tensor * 2.0 _ = result.cpu() del test_tensor, result logger.info("Apple Silicon (MPS) detected and tested successfully") return "mps" except Exception as mps_error: - logger.warning(f"MPS available but test failed, falling back to CPU: {mps_error}") + logger.warning( + f"MPS available but test failed, falling back to CPU: {mps_error}" + ) # Fallback to CPU logger.info("Using CPU device (no GPU detected or GPU test failed)") @@ -139,5 +152,7 @@ def get_device() -> str: def get_torch_device() -> "torch.device": """Get PyTorch device object""" if not HAS_TORCH: - raise ImportError("torch is required for get_torch_device(). Install with: pip install torch") + raise ImportError( + "torch is required for get_torch_device(). Install with: pip install torch" + ) return torch.device(get_device()) diff --git a/tests/test_streaming_topics.py b/tests/test_streaming_topics.py new file mode 100644 index 0000000..1542700 --- /dev/null +++ b/tests/test_streaming_topics.py @@ -0,0 +1,27 @@ +from deepiri_modelkit.streaming.topics import StreamTopics + + +def test_document_stream_topics_match_lis_routing_namespace() -> None: + assert StreamTopics.DOCUMENT_VECTORIZE.value == "document.vectorize" + assert StreamTopics.DOCUMENT_TRAINING.value == "document.training" + assert StreamTopics.DOCUMENT_STRUCTURED.value == "document.structured" + assert StreamTopics.DOCUMENT_ARTIFACTS.value == "document.artifacts" + + +def test_helox_training_topics_stay_in_pipeline_namespace() -> None: + assert StreamTopics.HELOX_TRAINING_RAW.value == "pipeline.helox-training.raw" + assert ( + StreamTopics.HELOX_TRAINING_STRUCTURED.value + == "pipeline.helox-training.structured" + ) + + +def test_all_includes_shared_stream_topics() -> None: + topics = set(StreamTopics.all()) + + assert "document.vectorize" in topics + assert "document.training" in topics + assert "document.structured" in topics + assert "document.artifacts" in topics + assert "pipeline.helox-training.raw" in topics + assert "pipeline.helox-training.structured" in topics