diff --git a/wavefront/client/src/pages/apps/[appId]/llm-inference/CreateLLMInferenceDialog.tsx b/wavefront/client/src/pages/apps/[appId]/llm-inference/CreateLLMInferenceDialog.tsx
index 6d718ebb..e05c5047 100644
--- a/wavefront/client/src/pages/apps/[appId]/llm-inference/CreateLLMInferenceDialog.tsx
+++ b/wavefront/client/src/pages/apps/[appId]/llm-inference/CreateLLMInferenceDialog.tsx
@@ -71,6 +71,7 @@ const createLLMInferenceSchema = z.object({
   displayName: z.string().min(1, 'Display name is required'),
   llmModel: z.string().min(1, 'LLM model name is required'),
   type: z.enum(['openai', 'anthropic', 'gemini', 'azure_openai', 'ollama', 'vllm', 'groq']),
+  modelType: z.enum(['llm', 'embedding']),
   apiKey: z.string().optional(),
   baseUrl: z.string().optional(),
 });
@@ -103,6 +104,7 @@ const CreateLLMInferenceDialog: React.FC = ({
       displayName: '',
       llmModel: '',
       type: 'openai',
+      modelType: 'llm',
       apiKey: '',
       baseUrl: getDefaultBaseUrl('openai'),
     },
@@ -131,6 +133,7 @@ const CreateLLMInferenceDialog: React.FC = ({
       displayName: '',
       llmModel: '',
       type: defaultType,
+      modelType: 'llm',
       apiKey: '',
       baseUrl: getDefaultBaseUrl(defaultType),
     });
@@ -171,6 +174,7 @@ const CreateLLMInferenceDialog: React.FC = ({
       llm_model: data.llmModel.trim(),
       api_key: data.apiKey?.trim() || undefined,
       type: data.type,
+      model_type: data.modelType,
       base_url: data.baseUrl?.trim() || undefined,
       parameters: Object.keys(cleanedParams).length > 0 ? cleanedParams : undefined,
     });
@@ -366,6 +370,31 @@ const CreateLLMInferenceDialog: React.FC = ({
             )}
           />
+          <FormField
+            control={form.control}
+            name="modelType"
+            render={({ field }) => (
+              <FormItem>
+                <FormLabel>Model Type*</FormLabel>
+                <Select onValueChange={field.onChange} defaultValue={field.value}>
+                  <FormControl>
+                    <SelectTrigger>
+                      <SelectValue />
+                    </SelectTrigger>
+                  </FormControl>
+                  <SelectContent>
+                    <SelectItem value="llm">LLM</SelectItem>
+                    <SelectItem value="embedding">Embedding</SelectItem>
+                  </SelectContent>
+                </Select>
+                <FormDescription>
+                  Choose whether this is an LLM or embedding model
+                </FormDescription>
+                <FormMessage />
+              </FormItem>
+            )}
+          />
diff --git a/wavefront/client/src/pages/apps/[appId]/llm-inference/[configId].tsx b/wavefront/client/src/pages/apps/[appId]/llm-inference/[configId].tsx
index f58a6628..fa185e14 100644
--- a/wavefront/client/src/pages/apps/[appId]/llm-inference/[configId].tsx
+++ b/wavefront/client/src/pages/apps/[appId]/llm-inference/[configId].tsx
@@ -48,6 +48,7 @@ const llmConfigFormSchema = z.object({
   llm_model: z.string().min(1, 'LLM model is required'),
   type: z.enum(['openai', 'anthropic', 'gemini', 'azure_openai', 'ollama', 'vllm', 'groq']),
   api_key: z.string().optional(),
+  model_type: z.enum(['llm', 'embedding']),
   base_url: z.string().optional(),
   parameters: z.record(z.any()).optional(),
 });
@@ -74,6 +75,7 @@ const LLMInferenceConfigDetail: React.FC = () => {
       display_name: '',
       llm_model: '',
       type: 'openai',
+      model_type: 'llm',
       api_key: '',
       base_url: '',
       parameters: {},
@@ -91,6 +93,7 @@ const LLMInferenceConfigDetail: React.FC = () => {
         display_name: config.display_name,
         llm_model: config.llm_model,
         type: config.type,
+        model_type: (config.model_type as 'llm' | 'embedding') || 'llm',
         api_key: '', // API key is never returned for security
         base_url: config.base_url || '',
         parameters: mergedParams,
@@ -134,6 +137,7 @@ const LLMInferenceConfigDetail: React.FC = () => {
         display_name: data.display_name.trim(),
         llm_model: data.llm_model.trim(),
         type: data.type,
+        model_type: data.model_type,
         parameters: Object.keys(cleanedParams).length > 0 ? cleanedParams : null,
       };
@@ -244,6 +248,7 @@ const LLMInferenceConfigDetail: React.FC = () => {
         display_name: config.display_name,
         llm_model: config.llm_model,
         type: config.type,
+        model_type: (config.model_type as 'llm' | 'embedding') || 'llm',
         api_key: '',
         base_url: config.base_url || '',
         parameters: mergedParams,
@@ -369,6 +374,28 @@ const LLMInferenceConfigDetail: React.FC = () => {
             )}
           />
         )}
+        <FormField
+          control={form.control}
+          name="model_type"
+          render={({ field }) => (
+            <FormItem>
+              <FormLabel>Model Type</FormLabel>
+              <Select onValueChange={field.onChange} value={field.value}>
+                <FormControl>
+                  <SelectTrigger>
+                    <SelectValue />
+                  </SelectTrigger>
+                </FormControl>
+                <SelectContent>
+                  <SelectItem value="llm">LLM</SelectItem>
+                  <SelectItem value="embedding">Embedding</SelectItem>
+                </SelectContent>
+              </Select>
+              <FormMessage />
+            </FormItem>
+          )}
+        />
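With both dialogs wired up, every create and update request can now carry a `model_type`. As a reference for API consumers, here is a minimal sketch of the create call; the endpoint path, port, and auth header are assumptions for illustration, and only the payload shape comes from this diff:

```python
import httpx

# Hypothetical URL and token; the payload fields mirror CreateLLMConfigRequest.
payload = {
    'llm_model': 'text-embedding-3-small',
    'display_name': 'OpenAI Embeddings',
    'type': 'openai',
    'model_type': 'embedding',  # new field; omitting it falls back to 'llm'
    'api_key': 'sk-...',
}
resp = httpx.post(
    'http://localhost:8000/llm-inference-configs',  # assumed route
    json=payload,
    headers={'Authorization': 'Bearer <token>'},  # assumed auth scheme
)
resp.raise_for_status()
print(resp.json())
```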
diff --git a/wavefront/client/src/types/llm-inference-config.ts b/wavefront/client/src/types/llm-inference-config.ts
index f86ffc2a..534ae76d 100644
--- a/wavefront/client/src/types/llm-inference-config.ts
+++ b/wavefront/client/src/types/llm-inference-config.ts
@@ -7,6 +7,7 @@ export interface LLMInferenceConfig {
   llm_model: string;
   display_name: string;
   type: InferenceEngineType;
+  model_type?: string;
   base_url?: string;
   parameters?: Record<string, any> | null;
   is_deleted: boolean;
@@ -19,6 +20,7 @@ export interface CreateLLMConfigRequest {
   display_name: string;
   api_key?: string;
   type: InferenceEngineType;
+  model_type?: string;
   base_url?: string;
   parameters?: Record<string, any> | null;
 }
@@ -28,6 +30,7 @@ export interface UpdateLLMConfigRequest {
   display_name?: string;
   api_key?: string | null;
   type?: InferenceEngineType;
+  model_type?: string;
   base_url?: string | null;
   parameters?: Record<string, any> | null;
 }
diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/models/doc_content.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/models/doc_content.py
index b07cd30b..5ed3cf96 100644
--- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/models/doc_content.py
+++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/models/doc_content.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from typing import Union
+from rag_ingestion.processors.file_processor import DocumentType
 
 
 @dataclass
@@ -7,4 +8,4 @@ class DocContent:
     """Model representing the extracted content from a document file"""
 
     content: Union[str, bytes]
-    parse_type: str
+    document_type: DocumentType
diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/file_processor.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/file_processor.py
index abada45a..02ec6bd9 100644
--- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/file_processor.py
+++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/file_processor.py
@@ -1,48 +1,59 @@
 import os
 import tempfile
 import textract
-from typing import Union
+from typing import Tuple
+from enum import Enum
+from common_module.log.logger import logger
 
 
-class FileProcessor:
-    def process_file(self, file_content: bytes, file_type: str) -> Union[str, bytes]:
-        mime_type = file_type
+class DocumentType(Enum):
+    PDF = 'pdf'
+    IMAGE = 'image'
+    TEXT = 'text'
 
-        if mime_type.startswith('text/plain'):
-            return file_content.decode('utf-8')
 
-        if mime_type.startswith('image/'):
-            return file_content
+class FileProcessor:
+    def process_file(
+        self, file_content: bytes, file_type: str
+    ) -> Tuple[str | bytes, DocumentType]:
+        mime_type = file_type
+        document_type = self.extract_document_type(mime_type)
 
+        if document_type == DocumentType.TEXT:
+            return file_content.decode('utf-8'), DocumentType.TEXT
 
-        if mime_type.startswith('application/'):
-            try:
-                sub_type = mime_type.split('/')[1]
-            except IndexError:
-                raise ValueError(
-                    f'Unsupported file type: Malformed MIME type "{mime_type}"'
-                )
+        if document_type == DocumentType.IMAGE:
+            return file_content, DocumentType.IMAGE
 
-            # Set delete=False to keep the file until we manually call os.unlink
+        if document_type == DocumentType.PDF:
+            # Keep delete=False so the file survives the context manager;
+            # it is removed manually in the finally block below.
             with tempfile.NamedTemporaryFile(
-                mode='w+b', delete=False, suffix=f'.{sub_type}'
+                mode='w+b', delete=False, suffix='.pdf'
             ) as temp_file:
                 temp_file.write(file_content)
-                temp_file.flush()  # Ensure data is written to disk before processing
+                temp_file.flush()
                 temp_file_path = temp_file.name
 
             try:
-                # Process the file using its path
                 text_content = textract.process(
                     temp_file_path, method='pdfminer'
                 ).decode('utf-8')
-                return text_content
+                return text_content, DocumentType.PDF
             except Exception as e:
-                # Re-raise processing errors
+                logger.error(f'Text extraction failed for {mime_type}: {e}')
                 raise RuntimeError(f'Text extraction failed for {mime_type}: {e}')
             finally:
                 os.unlink(temp_file_path)
 
+        # Explicit raise to prevent implicit None return.
+        raise RuntimeError(f'Unsupported or unknown document type: {document_type}')
+
+    def extract_document_type(self, file_type: str) -> DocumentType:
+        if file_type.startswith('text/plain'):
+            return DocumentType.TEXT
+        if file_type.startswith('image/'):
+            return DocumentType.IMAGE
+        if file_type in ('application/pdf', 'application/x-pdf'):
+            return DocumentType.PDF
-        else:
-            raise ValueError(f'Unsupported file type: {mime_type}')
+        # Explicit raise here as well, so an unmatched MIME type cannot
+        # fall through and return None.
+        raise ValueError(f'Unsupported file type: {file_type}')
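The refactor splits MIME sniffing (`extract_document_type`) from content extraction (`process_file`), which now returns a `(content, DocumentType)` tuple instead of a bare `str | bytes`. A minimal usage sketch, assuming the `rag_ingestion` package is importable (importing the module also requires `textract` and the `common_module` logger, since both are imported at top level):

```python
from rag_ingestion.processors.file_processor import FileProcessor, DocumentType

processor = FileProcessor()

# Plain text is decoded to str and tagged TEXT.
content, doc_type = processor.process_file(b'hello world', 'text/plain')
assert content == 'hello world' and doc_type is DocumentType.TEXT

# Image bytes pass through untouched, tagged IMAGE.
content, doc_type = processor.process_file(b'\x89PNG\r\n', 'image/png')
assert isinstance(content, bytes) and doc_type is DocumentType.IMAGE

# Anything else is rejected up front, before any temp file is created.
try:
    processor.extract_document_type('application/zip')
except ValueError as exc:
    print(exc)  # Unsupported file type: application/zip
```

Note that the generic `application/*` handling is gone: only `application/pdf` and `application/x-pdf` reach textract now, so formats like `.docx` that the old code forwarded to textract will raise `ValueError` instead.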
""" - if message.parse_type != 'kb_insertion': - return DocContent(content='', parse_type=message.parse_type) - content: Union[str, bytes] = self.file_processor.process_file( + (content, document_type) = self.file_processor.process_file( file_content, str(message.file_type) ) - return DocContent(content=content, parse_type=message.parse_type) + return DocContent(content=content, document_type=document_type) def __insert_kb_from_message( self, insights: List[ProcessingResult[KbStorageInsights]] @@ -74,27 +72,29 @@ def __insert_kb_from_message( for kb_insight in insights: kb_id = kb_insight.insights.kb_id doc_id = kb_insight.insights.doc_id - file_type = kb_insight.insights.file_type + document_type = kb_insight.insights.doc_content.document_type logger.info('Embeddings storing process is started') - if file_type and file_type.lower() in ('pdf', 'application/pdf'): + if ( + document_type == DocumentType.PDF + or document_type == DocumentType.TEXT + ): extracted_docs = [kb_insight.insights.doc_content.content] docs: List[KnowledgeBaseEmbeddingObject] = ( self.kb_rag_storage.process_document(extracted_docs) ) - else: + elif document_type == DocumentType.IMAGE: image_data = [kb_insight.insights.doc_content.content] docs: List[KnowledgeBaseEmbeddingObject] = [ self.image_embedding.embed_image(image_data) for image_data in image_data ] - embeddings.append( EmbeddingsToStore( kb_embeddings=docs, doc_id=doc_id, kb_id=kb_id, - file_type=file_type, + file_type=document_type, ) ) @@ -126,14 +126,14 @@ async def process( if self.encryption_service else file_content_encrypt ) - doc_content = await self._extract_text_from_message(message, file_content) + doc_content = await self._extract_content(message, file_content) return ProcessingResult[KbStorageInsights]( success=True, insights=KbStorageInsights( doc_id=message.doc_id, doc_content=doc_content, kb_id=message.kb_id, - file_type=message.file_type, + file_type=doc_content.document_type, ), ) diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/service/kb_rag_storage.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/service/kb_rag_storage.py index b408e531..f78ee445 100644 --- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/service/kb_rag_storage.py +++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/service/kb_rag_storage.py @@ -1,12 +1,12 @@ -from typing import Any, List, Dict, Tuple, Optional import logging -import numpy as np import tiktoken -from datetime import datetime import uuid -from rag_ingestion.embeddings.embed import EmbeddingFunc import httpx import time +import ast +import numpy as np +from flo_utils.utils.log import logger +from datetime import datetime from dataclasses import dataclass from rag_ingestion.env import FLOWARE_SERVICE_URL, APP_ENV, PASSTHROUGH_SECRET from rag_ingestion.constants.auth import RootfloHeaders @@ -14,8 +14,9 @@ KnowledgeBaseEmbeddingObject, RetrieveParams, ) -from flo_utils.utils.log import logger -import ast +from typing import Any, List, Dict, Tuple, Optional +from rag_ingestion.embeddings.embed import EmbeddingFunc +from rag_ingestion.processors.file_processor import DocumentType @dataclass @@ -23,7 +24,7 @@ class EmbeddingsToStore: kb_embeddings: List[KnowledgeBaseEmbeddingObject] doc_id: str kb_id: str - file_type: str + file_type: DocumentType class KBRagStorage: diff --git a/wavefront/server/docker-compose.yml b/wavefront/server/docker-compose.yml new file mode 100644 index 00000000..a0b78199 --- /dev/null +++ 
diff --git a/wavefront/server/docker-compose.yml b/wavefront/server/docker-compose.yml
new file mode 100644
index 00000000..a0b78199
--- /dev/null
+++ b/wavefront/server/docker-compose.yml
@@ -0,0 +1,34 @@
+version: '3.8'
+
+services:
+  postgres:
+    image: ankane/pgvector
+    container_name: postgres
+    restart: always
+    environment:
+      - POSTGRES_USER=postgres
+      - POSTGRES_PASSWORD=postgres
+      - POSTGRES_DB=floware
+    ports:
+      - '5432:5432'
+    volumes:
+      - /var/lib/postgres-data:/var/lib/postgresql/data
+      - /var/lib/init.sql:/docker-entrypoint-initdb.d/init.sql
+    networks:
+      - floware-network
+
+  redis:
+    image: redis:7.4
+    container_name: redis
+    ports:
+      - "6379:6379"
+    networks:
+      - floware-network
+
+volumes:
+  db:
+    driver: local
+
+networks:
+  floware-network:
+    driver: bridge
diff --git a/wavefront/server/modules/db_repo_module/db_repo_module/alembic/versions/2025_12_16_1406-f7572bcd9510_embedding_models.py b/wavefront/server/modules/db_repo_module/db_repo_module/alembic/versions/2025_12_16_1406-f7572bcd9510_embedding_models.py
new file mode 100644
index 00000000..82cf4820
--- /dev/null
+++ b/wavefront/server/modules/db_repo_module/db_repo_module/alembic/versions/2025_12_16_1406-f7572bcd9510_embedding_models.py
@@ -0,0 +1,35 @@
+"""embedding_models
+
+Revision ID: f7572bcd9510
+Revises: 10e09e25efa0
+Create Date: 2025-12-16 14:06:06.178161
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'f7572bcd9510'
+down_revision: Union[str, None] = '10e09e25efa0'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # Add model_type column to llm_inference_config table
+    # Set default to 'llm' for all existing rows
+    op.add_column(
+        'llm_inference_config',
+        sa.Column(
+            'model_type', sa.String(length=64), nullable=False, server_default='llm'
+        ),
+    )
+
+
+def downgrade() -> None:
+    # Remove model_type column from llm_inference_config table
+    op.drop_column('llm_inference_config', 'model_type')
diff --git a/wavefront/server/modules/db_repo_module/db_repo_module/models/llm_inference_config.py b/wavefront/server/modules/db_repo_module/db_repo_module/models/llm_inference_config.py
index 0fd17223..f0969aae 100644
--- a/wavefront/server/modules/db_repo_module/db_repo_module/models/llm_inference_config.py
+++ b/wavefront/server/modules/db_repo_module/db_repo_module/models/llm_inference_config.py
@@ -2,7 +2,7 @@
 from datetime import datetime
 from typing import Optional
 
-from sqlalchemy import String, JSON, func
+from sqlalchemy import String, JSON, func, text
 from sqlalchemy.orm import Mapped, mapped_column
 
 from ..database.base import Base
@@ -18,6 +18,9 @@ class LlmInferenceConfig(Base):
     display_name: Mapped[str] = mapped_column(String(length=255), nullable=False)
     api_key: Mapped[Optional[str]] = mapped_column(String(length=512), nullable=True)
     type: Mapped[str] = mapped_column(String(length=64), nullable=False)
+    model_type: Mapped[str] = mapped_column(
+        String(length=64), nullable=False, default='llm', server_default=text("'llm'")
+    )
     base_url: Mapped[Optional[str]] = mapped_column(String(length=512), nullable=True)
     parameters: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
     is_deleted: Mapped[bool] = mapped_column(default=False, nullable=False)
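The column carries both a Python-side `default` and a database `server_default`, so ORM inserts that omit `model_type`, raw SQL inserts, and all pre-existing rows touched by the migration uniformly land on `'llm'`. A sketch of the ORM behavior, assuming a configured session and that the remaining nullable columns can be omitted:

```python
from db_repo_module.models.llm_inference_config import LlmInferenceConfig

config = LlmInferenceConfig(
    llm_model='gpt-4o',
    display_name='GPT-4o',
    type='openai',
)
# model_type is not yet set on the instance; SQLAlchemy applies the
# Python-side default when the INSERT is flushed:
# session.add(config)
# session.flush()
# assert config.model_type == 'llm'
```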
diff --git a/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/controllers/llm_inference_config_controller.py b/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/controllers/llm_inference_config_controller.py
index 7c13c889..1be313d8 100644
--- a/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/controllers/llm_inference_config_controller.py
+++ b/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/controllers/llm_inference_config_controller.py
@@ -68,6 +68,7 @@ async def create_llm_inference_config(
         type=payload.type.value,
         base_url=payload.base_url,
         parameters=payload.parameters,
+        model_type=payload.model_type,
     )
 
     return JSONResponse(
@@ -221,6 +222,22 @@ async def update_llm_inference_config(
                     f'Invalid type value. Must be one of: {valid_values}'
                 ),
             )
+    if payload.model_type is not UNSET:
+        if payload.model_type is None:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content=response_formatter.buildErrorResponse(
+                    'model_type cannot be null'
+                ),
+            )
+        if payload.model_type not in ['llm', 'embedding']:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content=response_formatter.buildErrorResponse(
+                    'Invalid model_type value. Must be "llm" or "embedding"'
+                ),
+            )
+        update_data['model_type'] = payload.model_type
     if payload.base_url is not UNSET:
         update_data['base_url'] = payload.base_url
     if payload.parameters is not UNSET:
diff --git a/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/models/schemas.py b/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/models/schemas.py
index aa5b24d7..1a910eea 100644
--- a/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/models/schemas.py
+++ b/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/models/schemas.py
@@ -28,6 +28,9 @@ class CreateLlmInferenceConfigPayload(BaseModel):
         None, description='API key for the inference engine (optional)'
     )
     type: InferenceEngineType = Field(..., description='Type of inference engine')
+    model_type: Optional[str] = Field(
+        'llm', description='Type of model: "llm" or "embedding" (defaults to "llm")'
+    )
     base_url: Optional[str] = Field(
         None, description='Base URL for the inference API (optional)'
     )
@@ -49,6 +52,9 @@ class UpdateLlmInferenceConfigPayload(BaseModel):
     type: Union[InferenceEngineType, Any] = Field(
         default=UNSET, description='Type of inference engine'
     )
+    model_type: Union[str, Any] = Field(
+        default=UNSET, description='Type of model: "llm" or "embedding"'
+    )
     base_url: Union[str, None, Any] = Field(
         default=UNSET, description='Base URL for the inference API'
     )
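The update payload leans on the module's `UNSET` sentinel so the controller can distinguish "field not sent" from "field sent as null"; only the former leaves `model_type` untouched. A sketch of the three cases, assuming `UNSET` is importable from the schemas module (its actual home is not shown in this diff):

```python
from llm_inference_config_module.models.schemas import (
    UNSET,
    UpdateLlmInferenceConfigPayload,
)

# Omitted entirely: stays UNSET, so the controller skips the field.
assert UpdateLlmInferenceConfigPayload().model_type is UNSET

# Sent with a valid value: validated against 'llm'/'embedding', then applied.
assert UpdateLlmInferenceConfigPayload(model_type='embedding').model_type == 'embedding'

# Sent as null: reaches the controller as None and is rejected with a 400.
assert UpdateLlmInferenceConfigPayload(model_type=None).model_type is None
```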
diff --git a/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/services/llm_inference_config_service.py b/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/services/llm_inference_config_service.py
index 7e7ad42c..d4dbddf1 100644
--- a/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/services/llm_inference_config_service.py
+++ b/wavefront/server/modules/llm_inference_config_module/llm_inference_config_module/services/llm_inference_config_service.py
@@ -42,6 +42,7 @@ async def create_config(
         type: str,
         base_url: Optional[str] = None,
         parameters: Optional[dict] = None,
+        model_type: Optional[str] = 'llm',
     ) -> dict:
         """
         Create a new LLM inference configuration
@@ -53,11 +54,14 @@
             type: Type of inference engine
             base_url: Base URL for the LLM provider (optional)
             parameters: LLM parameters like temperature, max_tokens, etc. (optional)
+            model_type: Type of model: "llm" or "embedding" (defaults to "llm")
 
         Returns:
             Created LLM inference config as dict
         """
-        logger.info(f'Creating LLM inference config - model: {llm_model}, type: {type}')
+        logger.info(
+            f'Creating LLM inference config - model: {llm_model}, type: {type}, model_type: {model_type}'
+        )
 
         config = await self.llm_inference_config_repository.create(
             llm_model=llm_model,
@@ -66,6 +70,7 @@
             type=type,
             base_url=base_url,
             parameters=parameters,
+            model_type=model_type,
         )
 
         # Convert to dict
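For completeness, a sketch of registering an embedding model through the service layer; construction and dependency injection of the service are elided, and the call must run inside an async context:

```python
# Inside an async function, with `service` already wired up:
config = await service.create_config(
    llm_model='text-embedding-3-small',
    display_name='OpenAI Embeddings',
    api_key='sk-...',
    type='openai',
    model_type='embedding',  # defaults to 'llm' when omitted
)
# The returned dict is expected to echo the stored model_type.
assert config['model_type'] == 'embedding'
```

One gap worth noting: unlike the update path, the create path never validates `model_type` against the allowed values, so a payload with `model_type='foo'` would be persisted as-is. Mirroring the `'llm'`/`'embedding'` check from `update_llm_inference_config` in the create controller may be worth a follow-up.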