-
-
Notifications
You must be signed in to change notification settings - Fork 618
Feature/magic eraser #697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Feature/magic eraser #697
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
24bbb6b
Add OCR text selection feature
Aryan-Shan 3bb8d9b
Add OCR text selection feature
Aryan-Shan 9e2cc0d
Add OCR text selection feature
Aryan-Shan 4a3362c
Added text selection feature
Aryan-Shan a223a88
Added text selection feature
Aryan-Shan c52b418
Added Magic Eraser (Model excluded)
Aryan-Shan 244dbc4
Fixed image resolution for magic eraser
Aryan-Shan a91b203
magic eraser feature refactored
Aryan-Shan a6b76e7
magic eraser feature refactored
Aryan-Shan 7a4bd2c
magic eraser feature refactored
Aryan-Shan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
|
|
||
| import cv2 | ||
| import numpy as np | ||
| import onnxruntime as ort | ||
| import os | ||
| from app.logging.setup_logging import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
class Inpainter:
    """Inpaint masked image regions with a LaMa model run through ONNX Runtime.

    The model file is looked up at ``<parent of this file's directory>/models/
    onnx_models/lama_fp32.onnx``. If it is missing or cannot be loaded,
    ``self.session`` is left as ``None`` and loading is retried on the next
    call to :meth:`inpaint`.
    """

    def __init__(self):
        self.output_img_size = 512  # LaMa fixed input size
        self._init_session()

    def _init_session(self):
        """Initialize the ONNX Runtime session.

        Sets ``self.session`` to an ``ort.InferenceSession`` on success, or to
        ``None`` (after logging the reason) when the model file does not exist
        or fails to load.
        """
        model_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "models",
            "onnx_models",
            "lama_fp32.onnx",
        )

        if not os.path.exists(model_path):
            logger.error(f"Inpainting model not found at {model_path}")
            self.session = None
            return

        # Prefer CUDA when ONNX Runtime reports it; otherwise run on CPU only.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        if 'CUDAExecutionProvider' not in ort.get_available_providers():
            providers = ['CPUExecutionProvider']

        try:
            self.session = ort.InferenceSession(model_path, providers=providers)
            logger.info(f"Inpainting model loaded successfully from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load inpainting model: {e}")
            self.session = None

    @staticmethod
    def _binarize(mask: np.ndarray) -> np.ndarray:
        """Scale a 0-255 mask to [0, 1] and threshold it at 0.5 (float32 0/1)."""
        return (mask.astype(np.float32) / 255.0 > 0.5).astype(np.float32)

    def inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Perform inpainting on the image using the mask.

        The image is resized to ``output_img_size`` x ``output_img_size`` for
        inference (aspect ratio is not preserved), the model output is resized
        back, and only the masked pixels are taken from the model output so the
        rest of the image keeps its original quality.

        :param image: Input image (H, W, 3) BGR
        :param mask: Input mask (H, W) or (H, W, 1) 0-255 (255=inpainting area).
            If its height/width differ from the image's, it is resized to the
            image size with nearest-neighbour interpolation.
        :return: Inpainted image (H, W, 3) BGR, dtype uint8
        :raises RuntimeError: if the model could not be loaded.
        :raises ValueError: if the image is not (H, W, 3) or the mask is not
            (H, W) / (H, W, 1).
        """
        if self.session is None:
            # Try to re-init if it failed previously (e.g. download finished)
            self._init_session()
            if self.session is None:
                raise RuntimeError("Inpainting model not loaded.")

        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(
                f"Expected image of shape (H, W, 3), got {image.shape}"
            )

        # Collapse (H, W, 1) to (H, W) so every later step sees one layout.
        if mask.ndim == 3 and mask.shape[2] == 1:
            mask = mask.reshape(mask.shape[:2])
        if mask.ndim != 2:
            raise ValueError(
                f"Expected mask of shape (H, W) or (H, W, 1), got {mask.shape}"
            )

        original_h, original_w = image.shape[:2]

        # The final blend multiplies image and mask element-wise, so the mask
        # must match the image size; otherwise the blend fails on broadcast.
        if mask.shape != (original_h, original_w):
            mask = cv2.resize(
                mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST
            )

        # 1. Preprocess
        # Plain resize to the model size (no padding): simple and fast, at the
        # cost of possible aspect-ratio distortion.
        model_size = (self.output_img_size, self.output_img_size)
        img_resized = cv2.resize(image, model_size, interpolation=cv2.INTER_AREA)
        mask_resized = cv2.resize(mask, model_size, interpolation=cv2.INTER_NEAREST)

        # Normalize Image: [0, 255] -> [0, 1], HWC -> NCHW
        img_input = img_resized.astype(np.float32) / 255.0
        img_input = np.transpose(img_input, (2, 0, 1))  # (3, S, S)
        img_input = np.expand_dims(img_input, axis=0)  # (1, 3, S, S)

        # Normalize Mask: [0, 255] -> {0, 1}, HW -> NCHW
        mask_input = self._binarize(mask_resized)[np.newaxis, np.newaxis, :, :]

        # 2. Inference
        # The first model input is fed the image and the second the mask.
        # NOTE(review): this relies on the exported model's input order.
        model_inputs = self.session.get_inputs()
        inputs = {
            model_inputs[0].name: img_input,
            model_inputs[1].name: mask_input,
        }
        outputs = self.session.run(None, inputs)
        output_data = outputs[0]  # (1, 3, S, S)

        # 3. Postprocess
        # CHW -> HWC
        output_img = np.transpose(output_data[0], (1, 2, 0))  # (S, S, 3)

        # Auto-detect output range: LaMa can be [0, 1] or [0, 255]
        # If max value is small (<= 1.0 + epsilon), assume it's [0, 1] and scale up.
        if output_img.max() <= 1.1:
            output_img = output_img * 255.0

        output_img = np.clip(output_img, 0, 255).astype(np.uint8)

        # Resize back to original
        result_img = cv2.resize(
            output_img, (original_w, original_h), interpolation=cv2.INTER_CUBIC
        )

        # 4. Blend to preserve original quality
        # Blend: original * (1 - mask) + result * mask
        mask_normalized = self._binarize(mask)[:, :, np.newaxis]  # (H, W, 1)
        final_img = (
            image.astype(np.float32) * (1 - mask_normalized)
            + result_img.astype(np.float32) * mask_normalized
        )
        final_img = np.clip(final_img, 0, 255).astype(np.uint8)

        return final_img
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| from fastapi import APIRouter, HTTPException, Body | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| from pydantic import BaseModel | ||
| import cv2 | ||
| import numpy as np | ||
| import base64 | ||
| import os | ||
| from app.models.Inpainter import Inpainter | ||
| from app.logging.setup_logging import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
| router = APIRouter() | ||
|
|
||
| # Initialize Inpainter - GLOBAL instance to avoid reloading model | ||
| inpainter = Inpainter() | ||
|
|
||
class MagicEraserRequest(BaseModel):
    """Request body for the ``/magic-eraser`` endpoint."""

    # Filesystem path of the image to edit.
    image_path: str
    # Mask image encoded as a Base64 string.
    mask_data: str
|
Aryan-Shan marked this conversation as resolved.
|
||
|
|
||
class MagicEraserResponse(BaseModel):
    """Response body for the ``/magic-eraser`` endpoint."""

    # Whether the request was processed successfully.
    success: bool
    # Resulting image as a Base64 string; left as None when not set.
    image_data: str | None = None
    # Error description; left as None when not set.
    error: str | None = None
|
|
||
def base64_to_cv2(b64str):
    """Decode a Base64 string (optionally a data URL) into an OpenCV image.

    :param b64str: Base64 text, with or without a ``data:...;base64,`` prefix.
    :return: The array produced by ``cv2.imdecode`` with ``IMREAD_UNCHANGED``,
        or ``None`` when the text is not valid Base64, decodes to no bytes, or
        cannot be decoded as an image.
    """
    # Drop a data-URL header such as "data:image/png;base64," if present.
    if "," in b64str:
        b64str = b64str.split(",", 1)[1]

    try:
        img_data = base64.b64decode(b64str)
    except ValueError:
        # binascii.Error (malformed Base64) is a ValueError subclass; report it
        # the same way as an undecodable image so callers only check for None.
        return None

    # cv2.imdecode raises on an empty buffer instead of returning None.
    if not img_data:
        return None

    nparr = np.frombuffer(img_data, np.uint8)
    return cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
|
|
||
def cv2_to_base64(img):
    """Encode an image as a PNG data URL.

    :param img: Image array accepted by ``cv2.imencode``.
    :return: String of the form ``data:image/png;base64,<payload>``.
    :raises RuntimeError: if OpenCV reports that PNG encoding failed.
    """
    encoded_ok, buffer = cv2.imencode('.png', img)
    # imencode signals failure through its first return value; without this
    # check a failed encode would be returned as a bogus data URL.
    if not encoded_ok:
        raise RuntimeError("Failed to encode image as PNG")
    b64_str = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{b64_str}"
|
|
||
@router.post("/magic-eraser", response_model=MagicEraserResponse)
def magic_eraser(body: MagicEraserRequest):
    """Erase the masked region of an image and return the result for preview.

    Loads the image at ``body.image_path``, decodes the Base64 mask in
    ``body.mask_data``, runs the module-level ``inpainter`` and returns the
    result as a Base64 PNG data URL.

    :param body: Request with the image path and the Base64-encoded mask.
    :return: ``MagicEraserResponse`` with ``success=True`` and ``image_data``
        on success; otherwise ``success=False`` and a short ``error`` message.
        Unexpected exceptions are logged and reported as a generic error.
    """
    try:
        # 1. Validate the path and load the image
        # Reject ".." only as a whole path component (either separator), so
        # ordinary file names that merely contain ".." are still accepted.
        path_components = body.image_path.replace("\\", "/").split("/")
        if ".." in path_components:
            return MagicEraserResponse(success=False, error="Invalid image path: Path traversal detected")

        abs_path = os.path.abspath(body.image_path)
        if not os.path.exists(abs_path):
            return MagicEraserResponse(success=False, error="Image file not found")

        # Read the same resolved path that was just checked.
        image = cv2.imread(abs_path)
        if image is None:
            return MagicEraserResponse(success=False, error="Failed to load image file")

        # 2. Load Mask
        mask = base64_to_cv2(body.mask_data)
        if mask is None:
            return MagicEraserResponse(success=False, error="Failed to decode mask data")

        # Ensure mask is single channel
        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        # Bring the mask to the image's size so it lines up pixel-for-pixel;
        # nearest-neighbour keeps it binary.
        if mask.shape[:2] != image.shape[:2]:
            mask = cv2.resize(
                mask,
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )

        # 3. Inpaint
        result = inpainter.inpaint(image, mask)

        # 4. Return result as Base64 for preview
        b64_result = cv2_to_base64(result)

        return MagicEraserResponse(success=True, image_data=b64_result)

    except Exception:
        # logger.exception records the traceback; the client only gets a
        # generic message.
        logger.exception("Magic Eraser failed")
        return MagicEraserResponse(success=False, error="Internal processing error")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import cv2 | ||
| import numpy as np | ||
| import sys | ||
| import os | ||
|
|
||
| # Add backend to path | ||
| sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)))) | ||
|
|
||
| from app.models.Inpainter import Inpainter | ||
|
|
||
def test_inpainter():
    """Smoke-test ``Inpainter`` on a synthetic gradient image.

    Builds a 512x512 gradient and a centred square mask, runs one inpaint
    pass, then checks the output shape and that the centre pixel is not
    black. Progress is printed; any exception is printed with its traceback
    and the process exits with status 1. If the model session could not be
    created, a failure message is printed and the function returns.
    """
    print("Initializing Inpainter...")
    try:
        eraser = Inpainter()
        if eraser.session is None:
            print("FAILED: Model session not initialized. Model file might be missing.")
            return

        print("Creating dummy image and mask...")
        side = 512
        # Vertical gradient: every pixel in row i has value i // 2 in all channels.
        row_values = (np.arange(side) // 2).astype(np.uint8)
        img = np.empty((side, side, 3), dtype=np.uint8)
        img[...] = row_values[:, np.newaxis, np.newaxis]

        # White square in the centre marks the area to inpaint.
        mask = np.zeros((side, side), dtype=np.uint8)
        mask[200:300, 200:300] = 255

        print("Running inpaint...")
        result = eraser.inpaint(img, mask)

        print("Inpaint finished.")
        print(f"Result shape: {result.shape}")

        # The output must have the same shape as the input image.
        assert result.shape == img.shape, f"Shape mismatch. Expected {img.shape}, got {result.shape}"

        center_pixel = result[250, 250]
        print(f"Center pixel value: {center_pixel}")

        # An all-zero centre would indicate incorrect scaling [0,1]->uint8.
        assert np.any(center_pixel != 0), "Center pixel is black (0). Model output likely [0, 1] but treated as [0, 255]."

        print("SUCCESS: Inpainter verification passed.")

    except Exception as e:
        print(f"FAILED: Exception occurred: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
|
|
||
| if __name__ == "__main__": | ||
| test_inpainter() | ||
Binary file not shown.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.