-
Notifications
You must be signed in to change notification settings - Fork 30
feat(floware): image url base search and tag based document fetch #221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from typing import List, Optional | ||
| import base64 | ||
| from typing import List, Optional, Tuple | ||
| import uuid | ||
|
|
||
| from common_module.common_container import CommonContainer | ||
|
|
@@ -21,6 +22,7 @@ | |
| ) | ||
| from knowledge_base_module.services.kb_rag_retrieve import KBRagResponse | ||
| from knowledge_base_module.services.image_rag_retrieve import ImageRagRetrieve | ||
| from flo_cloud.cloud_storage import CloudStorageManager | ||
| from pydantic import BaseModel, Field | ||
| from datetime import datetime | ||
| from sqlalchemy import Result | ||
|
|
@@ -71,9 +73,10 @@ class DocWiseEmbeddingSchema(BaseModel): | |
|
|
||
|
|
||
| class ImagePayload(BaseModel): | ||
| """Payload for Image embedding.""" | ||
| """Payload for Image embedding. Use image_data (base64) or image_url (gs:// or s3://); image_url has priority if both are set.""" | ||
|
|
||
| image_data: Optional[str] = None | ||
| image_url: Optional[str] = None | ||
|
|
||
|
|
||
| class DocumentPayload(BaseModel): | ||
|
|
@@ -96,6 +99,112 @@ def convert_uuids_to_str(data): | |
| return data | ||
|
|
||
|
|
||
| def _parse_cloud_image_url(url: str) -> Tuple[str, str, str]: | ||
| """ | ||
| Parse gs:// or s3:// URL into (scheme, bucket, key). | ||
| Returns (scheme, bucket, key) or raises ValueError. | ||
| """ | ||
| url = (url or '').strip() | ||
| if url.startswith('gs://'): | ||
| rest = url[5:] | ||
| if '/' not in rest: | ||
| raise ValueError('Invalid gs:// URL: missing path after bucket') | ||
| bucket, _, key = rest.partition('/') | ||
| return ('gs', bucket, key) | ||
| if url.startswith('s3://'): | ||
| rest = url[5:] | ||
| if '/' not in rest: | ||
| raise ValueError('Invalid s3:// URL: missing path after bucket') | ||
| bucket, _, key = rest.partition('/') | ||
| return ('s3', bucket, key) | ||
| raise ValueError('image_url must be in gs:// or s3:// format') | ||
|
|
||
|
|
||
| async def _resolve_image_data( | ||
| payload: ImagePayload, | ||
| cloud_storage: CloudStorageManager, | ||
| config: dict, | ||
| response_formatter: ResponseFormatter, | ||
| ) -> Tuple[Optional[str], Optional[JSONResponse]]: | ||
| """ | ||
| Resolve image payload to a single image_data string (base64) for the inference API. | ||
| When both are provided, image_url has priority; otherwise uses image_data or fetches from image_url (gs:// or s3://). | ||
| Returns (image_data, None) on success, or (None, error_json_response) on validation/fetch error. | ||
| """ | ||
| if payload.image_url: | ||
| pass | ||
| elif payload.image_data: | ||
| return (payload.image_data, None) | ||
| else: | ||
| return ( | ||
| None, | ||
| JSONResponse( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| content=response_formatter.buildErrorResponse( | ||
| 'Query or Image data should not be empty' | ||
| ), | ||
| ), | ||
| ) | ||
| try: | ||
| scheme, bucket, key = _parse_cloud_image_url(payload.image_url) | ||
| except ValueError as e: | ||
| return ( | ||
| None, | ||
| JSONResponse( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| content=response_formatter.buildErrorResponse(str(e)), | ||
| ), | ||
| ) | ||
| cloud_provider = ( | ||
| (config.get('cloud_config') or {}).get('cloud_provider', '').lower() | ||
| ) | ||
| if scheme == 'gs' and cloud_provider != 'gcp': | ||
| return ( | ||
| None, | ||
| JSONResponse( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| content=response_formatter.buildErrorResponse( | ||
| 'image_url gs:// is only supported when cloud provider is GCP' | ||
| ), | ||
| ), | ||
| ) | ||
| if scheme == 's3' and cloud_provider != 'aws': | ||
| return ( | ||
| None, | ||
| JSONResponse( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| content=response_formatter.buildErrorResponse( | ||
| 'image_url s3:// is only supported when cloud provider is AWS' | ||
| ), | ||
| ), | ||
| ) | ||
| try: | ||
| content = cloud_storage.read_file(bucket, key) | ||
| except Exception as e: | ||
| return ( | ||
| None, | ||
| JSONResponse( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| content=response_formatter.buildErrorResponse( | ||
| f'Failed to fetch image from storage: {e!s}' | ||
| ), | ||
| ), | ||
| ) | ||
| image_bytes = content.read() if hasattr(content, 'read') else content | ||
|
Comment on lines +181 to +193

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Blocking synchronous I/O call inside an async function will stall the event loop.

Proposed fix:

    try:
-       content = cloud_storage.read_file(bucket, key)
+       content = await asyncio.to_thread(cloud_storage.read_file, bucket, key)
    except Exception as e:

You'll also need to add `import asyncio` at the top of the file.

🧰 Tools
🪛 Ruff (0.15.1)
[warning] 183-183: Do not catch blind exception: `Exception` (BLE001)
🤖 Prompt for AI Agents
||
| if not image_bytes: | ||
| return ( | ||
| None, | ||
| JSONResponse( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| content=response_formatter.buildErrorResponse( | ||
| 'Image from URL is empty' | ||
| ), | ||
| ), | ||
| ) | ||
| image_data_b64 = base64.b64encode(image_bytes).decode('utf-8') | ||
| return (image_data_b64, None) | ||
|
|
||
|
|
||
| @rag_retrieval_router.post('/v1/knowledge-base/{kb_id}/retrieve') | ||
| @inject | ||
| async def retrieve_query( | ||
|
|
@@ -128,6 +237,9 @@ async def retrieve_query( | |
| Provide[KnowledgeBaseContainer.image_knowledge_base_retrieve] | ||
| ), | ||
| config: dict = Depends(Provide[KnowledgeBaseContainer.config]), | ||
| cloud_storage: CloudStorageManager = Depends( | ||
| Provide[KnowledgeBaseContainer.cloud_storage] | ||
| ), | ||
| ): | ||
| if not query and not payload: | ||
| return JSONResponse( | ||
|
|
@@ -156,9 +268,14 @@ async def retrieve_query( | |
| limit, | ||
| ) | ||
| else: | ||
| image_data, error_response = await _resolve_image_data( | ||
| payload, cloud_storage, config, response_formatter | ||
| ) | ||
| if error_response is not None: | ||
| return error_response | ||
| inference_url = config['model']['inference_service_url'] | ||
| retrieved_docs = await image_rag_retrieval.retrieve_images( | ||
| payload.image_data, | ||
| image_data, | ||
| inference_url, | ||
| kb_id, | ||
| threshold, | ||
|
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Empty object key passes validation for URLs like `gs://bucket/`.

When the URL is `gs://bucket/`, `rest.partition('/')` produces `('bucket', '/', '')` — the function returns an empty `key`, which will cause a confusing error downstream when the storage read fails. Add a check for a non-empty key.

Proposed fix:

    if url.startswith('gs://'):
        rest = url[5:]
        if '/' not in rest:
            raise ValueError('Invalid gs:// URL: missing path after bucket')
        bucket, _, key = rest.partition('/')
+       if not key:
+           raise ValueError('Invalid gs:// URL: missing object key after bucket')
        return ('gs', bucket, key)
    if url.startswith('s3://'):
        rest = url[5:]
        if '/' not in rest:
            raise ValueError('Invalid s3:// URL: missing path after bucket')
        bucket, _, key = rest.partition('/')
+       if not key:
+           raise ValueError('Invalid s3:// URL: missing object key after bucket')
        return ('s3', bucket, key)

🧰 Tools
🪛 Ruff (0.15.1)
[warning] 111-111: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 117-117: Avoid specifying long messages outside the exception class
(TRY003)
[warning] 120-120: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents