knowledge_base_module/controllers/knowledge_base_document_controller.py
@@ -27,8 +27,7 @@
from flo_cloud.message_queue import MessageQueueManager
from flo_cloud.cloud_storage import CloudStorageManager
from pydantic import BaseModel
-from sqlalchemy import Result
-from sqlalchemy import select
+from knowledge_base_module.queries.generate_query import QueryGenerator

kb_document_router = APIRouter()

@@ -182,11 +181,23 @@ async def upload_document(
os.unlink(temp_file_path)


def _document_row_to_dict(row: dict) -> dict:
"""Convert a raw document row to the same format as KnowledgeBaseDocuments.to_dict()."""
result = dict(row)
for key, value in result.items():
if isinstance(value, uuid.UUID):
result[key] = str(value)
elif isinstance(value, datetime):
result[key] = value.isoformat()
return result
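As a quick sanity check of the helper's behavior (the row values below are made up for illustration):

import uuid
from datetime import datetime

row = {
    'id': uuid.UUID('12345678-1234-5678-1234-567812345678'),
    'file_name': 'report.pdf',
    'created_at': datetime(2024, 1, 1, 12, 0, 0),
}
# UUIDs become strings, datetimes become ISO-8601 strings, other values pass through
print(_document_row_to_dict(row))
# {'id': '12345678-1234-5678-1234-567812345678', 'file_name': 'report.pdf',
#  'created_at': '2024-01-01T12:00:00'}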


@kb_document_router.get('/v1/knowledge-bases/{kb_id}/documents')
@inject
async def get_documents(
kb_id: uuid.UUID,
file_type: Optional[str] = Query(None, description='Type of file to filter by'),
+    query_filter: Optional[str] = Query(None, alias='$filter'),
offset: int = Query(0, ge=0, description='The number of items to skip'),
limit: int = Query(
10, ge=1, le=100, description='The maximum number of items to return'
Expand All @@ -199,35 +210,28 @@ async def get_documents(
] = Depends(Provide[KnowledgeBaseContainer.knowledge_base_documents_repository]),
) -> JSONResponse:
"""Get documents for a knowledge base with optional filtering and pagination."""
-    # Validate knowledge base exists
-    existing_document = await knowledge_base_documents_repository.find_one(
-        knowledge_base_id=kb_id
-    )
-    if not existing_document:
-        return JSONResponse(
-            status_code=status.HTTP_200_OK,
-            content=response_formatter.buildSuccessResponse(data={'resources': []}),
+    try:
+        query_generator = QueryGenerator()
+        sql_query, query_params = query_generator.get_documents_list_query(
+            kb_id=str(kb_id),
+            file_type=file_type,
+            filter=query_filter,
+            offset=offset,
+            limit=limit,
         )

-    # Fetch documents
-    async with knowledge_base_documents_repository.session() as session:
-        query = select(KnowledgeBaseDocuments).where(
-            KnowledgeBaseDocuments.knowledge_base_id == kb_id
+        rows = await knowledge_base_documents_repository.execute_query(
+            sql_query, query_params
         )

-        if file_type:
-            query = query.where(KnowledgeBaseDocuments.file_type == file_type)
-
-        query = query.slice(offset, limit)
-
-        results: Result = await session.execute(query)
-        resources = results.scalars().all()
-        data = [res.to_dict() for res in resources]
-
+        data = [_document_row_to_dict(row) for row in rows]
         return JSONResponse(
             status_code=status.HTTP_200_OK,
             content=response_formatter.buildSuccessResponse(data={'resources': data}),
         )
+    except ValueError as e:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content=response_formatter.buildErrorResponse(str(e)),
+        )
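For reference, a client call exercising the new $filter parameter might look like the sketch below; the host, kb_id, and metadata field are hypothetical, and the accepted grammar depends on odata_parser.prepare_odata_filter:

import requests

resp = requests.get(
    'http://localhost:8000/v1/knowledge-bases/3f0e8a52-0000-0000-0000-000000000000/documents',
    params={
        'file_type': 'pdf',
        '$filter': "author eq 'smith'",  # assumed OData-style comparison on a metadata field
        'offset': 0,
        'limit': 10,
    },
)
print(resp.json())  # success envelope containing {'resources': [...]}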


@kb_document_router.delete('/v1/knowledge-bases/{kb_id}/documents/{document_id}')
knowledge_base_module/controllers/rag_retreival_controller.py
@@ -1,4 +1,5 @@
-from typing import List, Optional
+import base64
+from typing import List, Optional, Tuple
import uuid

from common_module.common_container import CommonContainer
@@ -21,6 +22,7 @@
)
from knowledge_base_module.services.kb_rag_retrieve import KBRagResponse
from knowledge_base_module.services.image_rag_retrieve import ImageRagRetrieve
+from flo_cloud.cloud_storage import CloudStorageManager
from pydantic import BaseModel, Field
from datetime import datetime
from sqlalchemy import Result
@@ -71,9 +73,10 @@ class DocWiseEmbeddingSchema(BaseModel):


class ImagePayload(BaseModel):
"""Payload for Image embedding."""
"""Payload for Image embedding. Use image_data (base64) or image_url (gs:// or s3://); image_url has priority if both are set."""

image_data: Optional[str] = None
+    image_url: Optional[str] = None
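
A minimal illustration of the documented precedence (hypothetical values; when both fields are set, image_url wins):

payload = ImagePayload(
    image_data='aGVsbG8=',                      # base64 fallback, ignored here
    image_url='gs://my-bucket/images/cat.png',  # takes priority when present
)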


class DocumentPayload(BaseModel):
@@ -96,6 +99,112 @@ def convert_uuids_to_str(data):
return data


def _parse_cloud_image_url(url: str) -> Tuple[str, str, str]:
"""
    Parse a gs:// or s3:// URL into its components.
    Returns (scheme, bucket, key) or raises ValueError.
"""
url = (url or '').strip()
if url.startswith('gs://'):
rest = url[5:]
if '/' not in rest:
raise ValueError('Invalid gs:// URL: missing path after bucket')
bucket, _, key = rest.partition('/')
return ('gs', bucket, key)
if url.startswith('s3://'):
rest = url[5:]
if '/' not in rest:
raise ValueError('Invalid s3:// URL: missing path after bucket')
bucket, _, key = rest.partition('/')
return ('s3', bucket, key)
raise ValueError('image_url must be in gs:// or s3:// format')
Comment on lines +102 to +120
⚠️ Potential issue | 🟡 Minor

Empty object key passes validation for URLs like gs://bucket/.

When the URL is gs://bucket/, rest.partition('/') produces ('bucket', '/', '') — the function returns an empty key, which will cause a confusing error downstream when the storage read fails. Add a check for a non-empty key.

Proposed fix
     if url.startswith('gs://'):
         rest = url[5:]
         if '/' not in rest:
             raise ValueError('Invalid gs:// URL: missing path after bucket')
         bucket, _, key = rest.partition('/')
+        if not key:
+            raise ValueError('Invalid gs:// URL: missing object key after bucket')
         return ('gs', bucket, key)
     if url.startswith('s3://'):
         rest = url[5:]
         if '/' not in rest:
             raise ValueError('Invalid s3:// URL: missing path after bucket')
         bucket, _, key = rest.partition('/')
+        if not key:
+            raise ValueError('Invalid s3:// URL: missing object key after bucket')
         return ('s3', bucket, key)
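
To see the flagged edge case concretely (plain Python string semantics, nothing project-specific):

rest = 'bucket/'  # what remains of 'gs://bucket/' after stripping the scheme
bucket, sep, key = rest.partition('/')
print((bucket, sep, key))  # ('bucket', '/', ''): key is empty even though '/' in rest is True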
🧰 Tools
🪛 Ruff (0.15.1)

[warning] 111-111: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 117-117: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 120-120: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@wavefront/server/modules/knowledge_base_module/knowledge_base_module/controllers/rag_retreival_controller.py`
around lines 102 - 120, infer that _parse_cloud_image_url currently accepts
URLs like "gs://bucket/" because rest.partition('/') yields an empty key; update
_parse_cloud_image_url to validate that the "key" component is non-empty after
partitioning for both gs:// and s3:// branches and raise ValueError (e.g.,
"Invalid gs:// URL: missing path after bucket" / "Invalid s3:// URL: missing
path after bucket") when key == "" so callers don't receive an empty object key.



async def _resolve_image_data(
payload: ImagePayload,
cloud_storage: CloudStorageManager,
config: dict,
response_formatter: ResponseFormatter,
) -> Tuple[Optional[str], Optional[JSONResponse]]:
"""
Resolve image payload to a single image_data string (base64) for the inference API.
When both are provided, image_url has priority; otherwise uses image_data or fetches from image_url (gs:// or s3://).
Returns (image_data, None) on success, or (None, error_json_response) on validation/fetch error.
"""
if payload.image_url:
pass
elif payload.image_data:
return (payload.image_data, None)
else:
return (
None,
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'Query or Image data should not be empty'
),
),
)
try:
scheme, bucket, key = _parse_cloud_image_url(payload.image_url)
except ValueError as e:
return (
None,
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(str(e)),
),
)
cloud_provider = (
(config.get('cloud_config') or {}).get('cloud_provider', '').lower()
)
if scheme == 'gs' and cloud_provider != 'gcp':
return (
None,
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'image_url gs:// is only supported when cloud provider is GCP'
),
),
)
if scheme == 's3' and cloud_provider != 'aws':
return (
None,
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'image_url s3:// is only supported when cloud provider is AWS'
),
),
)
try:
content = cloud_storage.read_file(bucket, key)
except Exception as e:
return (
None,
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
f'Failed to fetch image from storage: {e!s}'
),
),
)
image_bytes = content.read() if hasattr(content, 'read') else content
Comment on lines +181 to +193
⚠️ Potential issue | 🟠 Major

Blocking synchronous I/O call inside an async function will stall the event loop.

cloud_storage.read_file(bucket, key) is a synchronous SDK call. In the same codebase, upload_document (in knowledge_base_document_controller.py, line 130) wraps the analogous cloud_storage.save_small_file with asyncio.to_thread(...). The same treatment is needed here to avoid blocking the FastAPI event loop under concurrent requests.

Proposed fix
     try:
-        content = cloud_storage.read_file(bucket, key)
+        content = await asyncio.to_thread(cloud_storage.read_file, bucket, key)
     except Exception as e:

You'll also need to add import asyncio at the top of the file.
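
As a general illustration of the to_thread pattern (standalone sketch, not this module's actual wiring):

import asyncio
import time

def read_file_blocking(bucket: str, key: str) -> bytes:
    time.sleep(1)  # stand-in for a synchronous cloud SDK read
    return b'fake-bytes'

async def handler() -> bytes:
    # The blocking read runs in a worker thread; the event loop stays free
    return await asyncio.to_thread(read_file_blocking, 'my-bucket', 'images/cat.png')

asyncio.run(handler())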

🧰 Tools
🪛 Ruff (0.15.1)

[warning] 183-183: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@wavefront/server/modules/knowledge_base_module/knowledge_base_module/controllers/rag_retreival_controller.py`
around lines 181 - 193, The call to cloud_storage.read_file(bucket, key) is
synchronous and will block the async event loop; change it to run in a thread
using asyncio.to_thread (e.g., content = await
asyncio.to_thread(cloud_storage.read_file, bucket, key)) and preserve the
existing try/except so any exception raised during the threaded call is caught
and returned the same way; also add import asyncio at the top of the module so
asyncio.to_thread is available.

if not image_bytes:
return (
None,
JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=response_formatter.buildErrorResponse(
'Image from URL is empty'
),
),
)
image_data_b64 = base64.b64encode(image_bytes).decode('utf-8')
return (image_data_b64, None)
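
A minimal sketch of exercising the helper with stubbed dependencies (all stub classes and values here are hypothetical):

import asyncio

class FakeStorage:
    def read_file(self, bucket: str, key: str) -> bytes:
        return b'raw-image-bytes'  # pretend this came from bucket/key

class FakeFormatter:
    def buildErrorResponse(self, msg: str) -> dict:
        return {'error': msg}

async def demo() -> None:
    payload = ImagePayload(image_url='gs://my-bucket/images/cat.png')
    config = {'cloud_config': {'cloud_provider': 'gcp'}}
    image_data, err = await _resolve_image_data(
        payload, FakeStorage(), config, FakeFormatter()
    )
    print(image_data, err)  # base64 of the stub bytes, err is None

asyncio.run(demo())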


@rag_retrieval_router.post('/v1/knowledge-base/{kb_id}/retrieve')
@inject
async def retrieve_query(
@@ -128,6 +237,9 @@ async def retrieve_query(
Provide[KnowledgeBaseContainer.image_knowledge_base_retrieve]
),
config: dict = Depends(Provide[KnowledgeBaseContainer.config]),
+    cloud_storage: CloudStorageManager = Depends(
+        Provide[KnowledgeBaseContainer.cloud_storage]
+    ),
):
if not query and not payload:
return JSONResponse(
@@ -156,9 +268,14 @@ async def retrieve_query(
limit,
)
else:
+        image_data, error_response = await _resolve_image_data(
+            payload, cloud_storage, config, response_formatter
+        )
+        if error_response is not None:
+            return error_response
inference_url = config['model']['inference_service_url']
retrieved_docs = await image_rag_retrieval.retrieve_images(
-            payload.image_data,
+            image_data,
inference_url,
kb_id,
threshold,
knowledge_base_module/queries/generate_query.py
@@ -280,6 +280,64 @@ def get_image_embedding_dino(

return sql_query, params

def get_documents_list_query(
self,
kb_id: str,
file_type: Optional[str] = None,
filter: Optional[str] = None,
offset: int = 0,
limit: int = 10,
) -> Tuple[str, Dict[str, Any]]:
"""
Generate SQL query to list knowledge base documents with optional
metadata filter (OData-style $filter) and file_type.

Returns:
Tuple of (SQL query string, query parameters)
"""
params: Dict[str, Any] = {
'kb_id': kb_id,
'offset': offset,
'limit': limit,
}
conditions = ['knowledge_base_id = :kb_id']
if file_type:
params['file_type'] = file_type
conditions.append('file_type = :file_type')

metadata_filter_clause = ''
if filter:
where_clause, filter_params = self.odata_parser.prepare_odata_filter(filter)
if where_clause and filter_params:
metadata_filter_clause = self.build_metadata_clause(
where_clause,
filter_params,
lambda field: f"(metadata_value ->> '{field}')",
)
params.update(filter_params)
conditions.append(f'({metadata_filter_clause})')

where_sql = ' AND '.join(conditions)
sql_query = f"""
SELECT
id,
knowledge_base_id,
file_path,
file_name,
file_type,
file_size,
created_at,
updated_at,
metadata_value
FROM
{KnowledgeBaseDocuments.__tablename__}
WHERE
{where_sql}
ORDER BY created_at DESC
LIMIT :limit OFFSET :offset
"""
return sql_query, params
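
For illustration, a call without $filter generates roughly the following (the table name comes from KnowledgeBaseDocuments.__tablename__; the UUID is made up):

generator = QueryGenerator()
sql, params = generator.get_documents_list_query(
    kb_id='3f0e8a52-0000-0000-0000-000000000000',
    file_type='pdf',
)
# params == {'kb_id': '3f0e8a52-...', 'offset': 0, 'limit': 10, 'file_type': 'pdf'}
# The WHERE clause is: knowledge_base_id = :kb_id AND file_type = :file_type,
# ordered by created_at DESC with LIMIT :limit OFFSET :offset.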

@staticmethod
def get_update_tokens_query() -> str:
"""