def _document_row_to_dict(row: dict) -> dict:
    """Convert a raw document row into a JSON-serialisable dict.

    Intended to mirror the shape produced by
    ``KnowledgeBaseDocuments.to_dict()`` (not verified here -- confirm against
    the model if the two must stay byte-for-byte identical).

    Args:
        row: A mapping (or anything ``dict()`` accepts) of column name to
            value, as returned by a raw SQL query.

    Returns:
        A new dict with the same keys. ``uuid.UUID`` values are rendered with
        ``str()``, ``date``/``datetime`` values with ``isoformat()``; every
        other value is passed through unchanged. The input is not modified.
    """
    # Imported locally so this helper does not depend on how (or whether) the
    # enclosing module imports ``datetime``.
    from datetime import date

    def _jsonable(value):
        if isinstance(value, uuid.UUID):
            return str(value)
        # ``datetime`` is a subclass of ``date``, so one check covers both.
        if isinstance(value, date):
            return value.isoformat()
        return value

    # Build a fresh dict instead of assigning into one while iterating it.
    return {key: _jsonable(value) for key, value in dict(row).items()}
Query(0, ge=0, description='The number of items to skip'), limit: int = Query( 10, ge=1, le=100, description='The maximum number of items to return' @@ -199,35 +210,28 @@ async def get_documents( ] = Depends(Provide[KnowledgeBaseContainer.knowledge_base_documents_repository]), ) -> JSONResponse: """Get documents for a knowledge base with optional filtering and pagination.""" - # Validate knowledge base exists - existing_document = await knowledge_base_documents_repository.find_one( - knowledge_base_id=kb_id - ) - if not existing_document: - return JSONResponse( - status_code=status.HTTP_200_OK, - content=response_formatter.buildSuccessResponse(data={'resources': []}), + try: + query_generator = QueryGenerator() + sql_query, query_params = query_generator.get_documents_list_query( + kb_id=str(kb_id), + file_type=file_type, + filter=query_filter, + offset=offset, + limit=limit, ) - - # Fetch documents - async with knowledge_base_documents_repository.session() as session: - query = select(KnowledgeBaseDocuments).where( - KnowledgeBaseDocuments.knowledge_base_id == kb_id + rows = await knowledge_base_documents_repository.execute_query( + sql_query, query_params ) - - if file_type: - query = query.where(KnowledgeBaseDocuments.file_type == file_type) - - query = query.slice(offset, limit) - - results: Result = await session.execute(query) - resources = results.scalars().all() - data = [res.to_dict() for res in resources] - + data = [_document_row_to_dict(row) for row in rows] return JSONResponse( status_code=status.HTTP_200_OK, content=response_formatter.buildSuccessResponse(data={'resources': data}), ) + except ValueError as e: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=response_formatter.buildErrorResponse(str(e)), + ) @kb_document_router.delete('/v1/knowledge-bases/{kb_id}/documents/{document_id}') diff --git a/wavefront/server/modules/knowledge_base_module/knowledge_base_module/controllers/rag_retreival_controller.py 
def _parse_cloud_image_url(url: str) -> Tuple[str, str, str]:
    """Split a ``gs://`` or ``s3://`` URL into ``(scheme, bucket, key)``.

    Leading/trailing whitespace is ignored. ``scheme`` is returned without the
    ``://`` suffix (``'gs'`` or ``'s3'``); ``key`` is everything after the
    first ``/`` that follows the bucket.

    Raises:
        ValueError: if the URL is empty/``None``, uses another scheme, has no
            ``/`` after the bucket, or has an empty bucket or empty key.
    """
    cleaned = (url or '').strip()
    for scheme in ('gs', 's3'):
        prefix = f'{scheme}://'
        if not cleaned.startswith(prefix):
            continue
        bucket, separator, key = cleaned[len(prefix):].partition('/')
        if not separator or not key:
            # Covers both "gs://bucket" and "gs://bucket/".
            raise ValueError(f'Invalid {prefix} URL: missing path after bucket')
        if not bucket:
            raise ValueError(f'Invalid {prefix} URL: missing bucket name')
        return (scheme, bucket, key)
    raise ValueError('image_url must be in gs:// or s3:// format')


async def _resolve_image_data(
    payload: ImagePayload,
    cloud_storage: CloudStorageManager,
    config: dict,
    response_formatter: ResponseFormatter,
) -> Tuple[Optional[str], Optional[JSONResponse]]:
    """Resolve an image payload to one base64 string for the inference API.

    ``payload.image_url`` takes priority when set: it is parsed, checked
    against the configured cloud provider (``gs://`` requires ``gcp``,
    ``s3://`` requires ``aws``, compared case-insensitively), read through
    ``cloud_storage.read_file(bucket, key)`` and base64-encoded. Otherwise
    ``payload.image_data`` is returned as-is.

    Args:
        payload: Object exposing ``image_url`` and ``image_data``.
        cloud_storage: Object exposing ``read_file(bucket, key)``, which may
            return raw bytes or a file-like object with ``read()``.
        config: Application config; ``config['cloud_config']['cloud_provider']``
            is consulted and may be missing or ``None``.
        response_formatter: Used to build error bodies via
            ``buildErrorResponse(message)``.

    Returns:
        ``(image_data, None)`` on success, or ``(None, JSONResponse)`` carrying
        an HTTP 400 when validation or the storage fetch fails.
    """

    def _bad_request(message: str) -> Tuple[None, JSONResponse]:
        # Single place that shapes every validation/fetch failure.
        return (
            None,
            JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=response_formatter.buildErrorResponse(message),
            ),
        )

    if not payload.image_url:
        if payload.image_data:
            return (payload.image_data, None)
        return _bad_request('Query or Image data should not be empty')

    try:
        scheme, bucket, key = _parse_cloud_image_url(payload.image_url)
    except ValueError as e:
        return _bad_request(str(e))

    # ``cloud_provider`` may be absent or explicitly None; normalise to ''.
    raw_provider = (config.get('cloud_config') or {}).get('cloud_provider')
    cloud_provider = str(raw_provider or '').lower()
    if scheme == 'gs' and cloud_provider != 'gcp':
        return _bad_request(
            'image_url gs:// is only supported when cloud provider is GCP'
        )
    if scheme == 's3' and cloud_provider != 'aws':
        return _bad_request(
            'image_url s3:// is only supported when cloud provider is AWS'
        )

    try:
        content = cloud_storage.read_file(bucket, key)
        if hasattr(content, 'read'):
            try:
                image_bytes = content.read()
            finally:
                # Release the stream/connection even if read() fails.
                close = getattr(content, 'close', None)
                if callable(close):
                    close()
        else:
            image_bytes = content
    except Exception as e:
        # Boundary with the storage backend: any failure becomes a 400 for the
        # caller rather than an unhandled error.
        return _bad_request(f'Failed to fetch image from storage: {e!s}')

    if not image_bytes:
        return _bad_request('Image from URL is empty')

    return (base64.b64encode(image_bytes).decode('utf-8'), None)
def get_documents_list_query(
    self,
    kb_id: str,
    file_type: Optional[str] = None,
    filter: Optional[str] = None,
    offset: int = 0,
    limit: int = 10,
) -> Tuple[str, Dict[str, Any]]:
    """Build the SQL that lists a knowledge base's documents.

    Rows are restricted to ``kb_id``, optionally to ``file_type``, and
    optionally by an OData-style ``$filter`` evaluated against the
    ``metadata_value`` column (fields are read with ``->>``).

    Args:
        kb_id: Knowledge base id, bound as ``:kb_id``.
        file_type: When truthy, adds ``file_type = :file_type``.
        filter: Raw OData filter string; parsed with
            ``self.odata_parser.prepare_odata_filter``.
        offset: Rows to skip, bound as ``:offset``.
        limit: Maximum rows, bound as ``:limit``.

    Returns:
        Tuple of (SQL query string, query parameters).

    Raises:
        ValueError: if a bind-parameter name produced by the filter collides
            with one this query already uses.
    """

    def _metadata_accessor(field: str) -> str:
        # The field name originates from the caller-supplied filter and ends
        # up inside a SQL string literal, so double any single quote.
        # NOTE(review): assumes PostgreSQL with standard_conforming_strings
        # on, and that the OData parser has not already escaped the name.
        escaped = str(field).replace("'", "''")
        return f"(metadata_value ->> '{escaped}')"

    params: Dict[str, Any] = {
        'kb_id': kb_id,
        'offset': offset,
        'limit': limit,
    }
    conditions = ['knowledge_base_id = :kb_id']
    if file_type:
        params['file_type'] = file_type
        conditions.append('file_type = :file_type')

    if filter:
        where_clause, filter_params = self.odata_parser.prepare_odata_filter(filter)
        # NOTE(review): a clause that parses to no bind parameters is dropped
        # here, exactly as before -- confirm that is intended.
        if where_clause and filter_params:
            metadata_filter_clause = self.build_metadata_clause(
                where_clause,
                filter_params,
                _metadata_accessor,
            )
            # Never let filter-derived binds overwrite :kb_id / :limit /
            # :offset / :file_type -- that would silently change the scoping.
            clashing = sorted(set(params).intersection(filter_params))
            if clashing:
                raise ValueError(
                    'Filter parameter names conflict with reserved query '
                    f"parameters: {', '.join(clashing)}"
                )
            params.update(filter_params)
            conditions.append(f'({metadata_filter_clause})')

    where_sql = ' AND '.join(conditions)
    # ``id`` breaks ties between equal ``created_at`` values so LIMIT/OFFSET
    # pages neither repeat nor skip rows.
    sql_query = f"""
        SELECT
            id,
            knowledge_base_id,
            file_path,
            file_name,
            file_type,
            file_size,
            created_at,
            updated_at,
            metadata_value
        FROM
            {KnowledgeBaseDocuments.__tablename__}
        WHERE
            {where_sql}
        ORDER BY created_at DESC, id DESC
        LIMIT :limit OFFSET :offset
    """
    return sql_query, params