Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pnpm-debug.log*
lerna-debug.log*

backend/app/models/image-generation/*
backend/app/models/onnx_models/*

node_modules

Expand Down
115 changes: 115 additions & 0 deletions backend/app/models/Inpainter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@

import cv2
import numpy as np
import onnxruntime as ort
import os
from app.logging.setup_logging import get_logger

logger = get_logger(__name__)

class Inpainter:
    """Image inpainting with a LaMa ONNX model run through ONNX Runtime.

    The session is created when the object is constructed. If that fails
    (model file missing or ONNX Runtime raises), ``self.session`` is left as
    ``None`` and :meth:`inpaint` retries the load on its next call.
    """

    def __init__(self):
        self.output_img_size = 512  # LaMa fixed input size
        self._init_session()

    def _init_session(self):
        """Initialize the ONNX Runtime session.

        Sets ``self.session`` to an ``ort.InferenceSession`` on success and to
        ``None`` on failure. Failures are logged, never raised, so that a
        missing model does not prevent the object from being constructed.
        """
        model_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "models",
            "onnx_models",
            "lama_fp32.onnx",
        )

        if not os.path.exists(model_path):
            logger.error(f"Inpainting model not found at {model_path}")
            self.session = None
            return

        # Prefer CUDA when this onnxruntime build offers it; always keep CPU
        # as the fallback provider.
        providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        try:
            self.session = ort.InferenceSession(model_path, providers=providers)
        except Exception as e:
            logger.error(f"Failed to load inpainting model: {e}")
            self.session = None
        else:
            logger.info(f"Inpainting model loaded successfully from {model_path}")

    @staticmethod
    def _check_inputs(image, mask):
        """Validate ``image`` and return ``mask`` collapsed to a 2-D array."""
        if (
            not isinstance(image, np.ndarray)
            or image.ndim != 3
            or image.shape[2] != 3
            or image.size == 0
        ):
            raise ValueError("image must be a non-empty (H, W, 3) array")
        if (
            not isinstance(mask, np.ndarray)
            or mask.ndim not in (2, 3)
            or mask.size == 0
        ):
            raise ValueError("mask must be a non-empty (H, W) or (H, W, C) array")

        if mask.ndim == 3:
            # A pixel is masked if any channel marks it.
            mask = mask[:, :, 0] if mask.shape[2] == 1 else mask.max(axis=2)
        return np.ascontiguousarray(mask)

    @staticmethod
    def _binarize(mask):
        """Map a 0-255 mask to float32 {0.0, 1.0}, thresholding at 0.5."""
        return (mask.astype(np.float32) / 255.0 > 0.5).astype(np.float32)

    def inpaint(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """
        Perform inpainting on the image using the mask.

        The image is resized to the model's square input size (aspect ratio
        is not preserved), run through the model, resized back, and then
        blended so that pixels outside the mask keep their original values.

        :param image: Input image (H, W, 3) BGR
        :param mask: Input mask (H, W) or (H, W, C), 0-255 (255=inpainting
            area). If its height/width differ from the image, it is resized
            (nearest neighbour) to the image size for the final blend.
        :return: Inpainted image (H, W, 3) BGR, dtype uint8
        :raises RuntimeError: if the model session cannot be loaded.
        :raises ValueError: if ``image`` or ``mask`` has an unsupported shape.
        """
        if self.session is None:
            # Try to re-init if it failed previously (e.g. download finished)
            self._init_session()
            if self.session is None:
                raise RuntimeError("Inpainting model not loaded.")

        mask_2d = self._check_inputs(image, mask)
        original_h, original_w = image.shape[:2]
        side = self.output_img_size

        # 1. Preprocess: plain resize to the fixed model input size.
        img_resized = cv2.resize(image, (side, side), interpolation=cv2.INTER_AREA)
        mask_resized = cv2.resize(mask_2d, (side, side), interpolation=cv2.INTER_NEAREST)

        # Image: [0, 255] -> [0, 1], HWC -> NCHW (1, 3, side, side).
        # NOTE(review): channels are passed in the order received (BGR);
        # confirm the exported model does not expect RGB.
        img_input = np.ascontiguousarray(
            np.transpose(img_resized.astype(np.float32) / 255.0, (2, 0, 1))[np.newaxis]
        )

        # Mask: binary float32, HW -> NCHW (1, 1, side, side).
        mask_input = self._binarize(mask_resized)[np.newaxis, np.newaxis, :, :]

        # 2. Inference: the first model input is the image, the second the mask.
        session_inputs = self.session.get_inputs()
        outputs = self.session.run(
            None,
            {
                session_inputs[0].name: img_input,
                session_inputs[1].name: mask_input,
            },
        )

        # 3. Postprocess: CHW -> HWC.
        output_img = np.transpose(outputs[0][0], (1, 2, 0))

        # Auto-detect output range: LaMa can be [0, 1] or [0, 255]
        # If max value is small (<= 1.0 + epsilon), assume it's [0, 1] and scale up.
        if output_img.max() <= 1.1:
            output_img = output_img * 255.0

        output_img = np.clip(output_img, 0, 255).astype(np.uint8)

        # Resize back to original
        result_img = cv2.resize(
            output_img, (original_w, original_h), interpolation=cv2.INTER_CUBIC
        )

        # 4. Blend to preserve original quality outside the masked region.
        # The blend mask must match the image size exactly, otherwise the
        # element-wise blend below cannot broadcast.
        if mask_2d.shape[:2] != (original_h, original_w):
            mask_2d = cv2.resize(
                mask_2d, (original_w, original_h), interpolation=cv2.INTER_NEAREST
            )
        blend = self._binarize(mask_2d)[:, :, np.newaxis]

        # Blend: original * (1 - mask) + result * mask
        final_img = (
            image.astype(np.float32) * (1 - blend)
            + result_img.astype(np.float32) * blend
        )
        return np.clip(final_img, 0, 255).astype(np.uint8)
76 changes: 76 additions & 0 deletions backend/app/routes/edit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from fastapi import APIRouter, HTTPException, Body
Comment thread
coderabbitai[bot] marked this conversation as resolved.
from pydantic import BaseModel
import cv2
import numpy as np
import base64
import os
from app.models.Inpainter import Inpainter
from app.logging.setup_logging import get_logger

logger = get_logger(__name__)
router = APIRouter()

# Initialize Inpainter - GLOBAL instance to avoid reloading model
# on every request. It is constructed when this module is imported; if the
# model cannot be loaded at that point the instance keeps session=None and
# Inpainter.inpaint() retries the load on its next call.
inpainter = Inpainter()

class MagicEraserRequest(BaseModel):
    """Request body for the magic-eraser endpoint."""

    # Filesystem path of the image to edit; the handler reads it with cv2.imread.
    image_path: str
    # Encoded mask image; an optional "data:...;base64," prefix is stripped
    # before decoding.
    mask_data: str # Base64 string
Comment thread
Aryan-Shan marked this conversation as resolved.

class MagicEraserResponse(BaseModel):
    """Response body for the magic-eraser endpoint.

    The handler reports failures through ``success=False`` and ``error``
    rather than through an HTTP error status.
    """

    success: bool
    # PNG data URL ("data:image/png;base64,...") of the result when successful.
    image_data: str | None = None # Base64 string
    # Human-readable failure reason when success is False.
    error: str | None = None

def base64_to_cv2(b64str):
    """Decode a Base64-encoded image into an OpenCV array.

    :param b64str: Base64 text, optionally prefixed with a data-URL header
        such as ``data:image/png;base64,``.
    :return: The decoded image with its channels left unchanged, or ``None``
        when the text is not valid Base64, is empty, or is not a decodable
        image.
    """
    # Drop a data-URL header if present; Base64 itself never contains commas.
    if "," in b64str:
        b64str = b64str.split(",", 1)[1]

    try:
        img_data = base64.b64decode(b64str)
    except ValueError:
        # binascii.Error (malformed Base64) is a ValueError subclass. Report
        # it like any other undecodable payload instead of raising.
        return None

    nparr = np.frombuffer(img_data, np.uint8)
    if nparr.size == 0:
        # cv2.imdecode raises on an empty buffer rather than returning None.
        return None

    return cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)

def cv2_to_base64(img):
    """Encode an OpenCV image as a PNG data URL.

    :param img: Image array accepted by ``cv2.imencode``.
    :return: ``data:image/png;base64,...`` string containing the PNG bytes.
    :raises ValueError: if OpenCV reports that PNG encoding failed.
    """
    # imencode returns (success_flag, buffer); without this check a failed
    # encode would silently produce an empty data URL.
    ok, buffer = cv2.imencode('.png', img)
    if not ok:
        raise ValueError("Failed to encode image as PNG")

    b64_str = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{b64_str}"

@router.post("/magic-eraser", response_model=MagicEraserResponse)
def magic_eraser(body: MagicEraserRequest):
    """Erase the masked region of an image and return the result as a PNG data URL.

    All failures are reported in the response body (``success=False`` plus an
    ``error`` message) rather than as an HTTP error status.

    :param body: Path of the image on disk and the Base64-encoded mask.
    :return: ``MagicEraserResponse`` with ``image_data`` set on success.
    """
    try:
        # 1. Validate and load the image.
        # Reject ".." only when it is a whole path component, so file names
        # that merely contain two dots (e.g. "photo..jpg") are still accepted.
        # Both separator styles are checked regardless of platform.
        path_parts = body.image_path.replace("\\", "/").split("/")
        if ".." in path_parts:
            return MagicEraserResponse(success=False, error="Invalid image path: Path traversal detected")

        abs_path = os.path.abspath(body.image_path)
        if not os.path.isfile(abs_path):
            return MagicEraserResponse(success=False, error="Image file not found")

        # Read the same path that was validated above.
        image = cv2.imread(abs_path)
        if image is None:
            return MagicEraserResponse(success=False, error="Failed to load image file")

        # 2. Load Mask
        mask = base64_to_cv2(body.mask_data)
        if mask is None:
            return MagicEraserResponse(success=False, error="Failed to decode mask data")

        # Ensure mask is single channel
        # NOTE(review): any alpha channel in the mask is not used here;
        # confirm the client encodes the mask in the colour channels.
        if len(mask.shape) == 3:
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

        # 3. Inpaint
        result = inpainter.inpaint(image, mask)

        # 4. Return result as Base64 for preview
        return MagicEraserResponse(success=True, image_data=cv2_to_base64(result))

    except Exception:
        # Top-level boundary: log the full traceback, but do not leak
        # internal details to the client.
        logger.exception("Magic Eraser failed")
        return MagicEraserResponse(success=False, error="Internal processing error")
2 changes: 2 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from app.routes.images import router as images_router
from app.routes.face_clusters import router as face_clusters_router
from app.routes.user_preferences import router as user_preferences_router
from app.routes.edit import router as edit_router
from fastapi.openapi.utils import get_openapi
from app.logging.setup_logging import (
configure_uvicorn_logging,
Expand Down Expand Up @@ -132,6 +133,7 @@ async def root():
app.include_router(
user_preferences_router, prefix="/user-preferences", tags=["User Preferences"]
)
app.include_router(edit_router, prefix="/edit", tags=["Edit"])


# Entry point for running with: python3 main.py
Expand Down
4 changes: 3 additions & 1 deletion backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ mkdocs-material==9.6.16
mkdocs-material-extensions==1.3.1
mkdocs-swagger-ui-tag==0.7.1
mpmath==1.3.0
numpy==1.26.4
numpy<2.0.0
tqdm==4.66.4
requests==2.31.0
onnxruntime==1.17.1
opencv-python==4.9.0.80
orjson==3.10.3
Expand Down
54 changes: 54 additions & 0 deletions backend/tests/test_inpainter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import cv2
import numpy as np
import sys
import os

# Add backend to path: the parent of the directory containing this file is
# appended to sys.path so that the `app` package import below can resolve
# when this file is run directly.
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__))))

from app.models.Inpainter import Inpainter

def test_inpainter():
    """Smoke-test Inpainter on a synthetic gradient image with a square mask.

    Prints progress to stdout. Returns early (after printing a message) when
    the model session could not be created, and exits the process with
    status 1 if any exception, including a failed assertion, occurs.
    """
    print("Initializing Inpainter...")
    try:
        inpainter = Inpainter()
        if inpainter.session is None:
            print("FAILED: Model session not initialized. Model file might be missing.")
            return

        print("Creating dummy image and mask...")
        # 512x512 vertical gradient: row i holds the value i // 2 in every channel.
        side = 512
        row_values = (np.arange(side) // 2).astype(np.uint8)
        img = np.zeros((side, side, 3), dtype=np.uint8)
        img[:] = row_values[:, np.newaxis, np.newaxis]

        # White square in the centre marks the region to inpaint.
        mask = np.zeros((side, side), dtype=np.uint8)
        mask[200:300, 200:300] = 255

        print("Running inpaint...")
        result = inpainter.inpaint(img, mask)

        print("Inpaint finished.")
        print(f"Result shape: {result.shape}")

        # The output must have the same shape as the input image.
        assert result.shape == img.shape, f"Shape mismatch. Expected {img.shape}, got {result.shape}"

        # Basic sanity check on the middle of the masked square.
        center_pixel = result[250, 250]
        print(f"Center pixel value: {center_pixel}")

        # An all-zero centre would indicate incorrect scaling [0,1]->uint8.
        centre_is_black = np.all(center_pixel == 0)
        assert not centre_is_black, "Center pixel is black (0). Model output likely [0, 1] but treated as [0, 255]."

        print("SUCCESS: Inpainter verification passed.")

    except Exception as e:
        print(f"FAILED: Exception occurred: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

# Allow running this check directly as a script.
if __name__ == "__main__":
    test_inpainter()
Binary file added debug.txt
Binary file not shown.
94 changes: 94 additions & 0 deletions docs/backend/backend_python/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,47 @@
}
}
}
},
"/edit/magic-eraser": {
"post": {
"tags": [
"Edit"
],
"summary": "Magic Eraser",
"operationId": "magic_eraser_edit_magic_eraser_post",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MagicEraserRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MagicEraserResponse"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
},
"components": {
Expand Down Expand Up @@ -2266,6 +2307,59 @@
],
"title": "InputType"
},
"MagicEraserRequest": {
"properties": {
"image_path": {
"type": "string",
"title": "Image Path"
},
"mask_data": {
"type": "string",
"title": "Mask Data"
}
},
"type": "object",
"required": [
"image_path",
"mask_data"
],
"title": "MagicEraserRequest"
},
"MagicEraserResponse": {
"properties": {
"success": {
"type": "boolean",
"title": "Success"
},
"image_data": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Image Data"
},
"error": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Error"
}
},
"type": "object",
"required": [
"success"
],
"title": "MagicEraserResponse"
},
"MetadataModel": {
"properties": {
"name": {
Expand Down
Loading