131 changes: 82 additions & 49 deletions src/tilefusion/utils.py
@@ -6,42 +6,87 @@

 import numpy as np

+# GPU detection - PyTorch based
 try:
-    import cupy as cp
-    from cupyx.scipy.ndimage import shift as cp_shift
-    from cucim.skimage.exposure import match_histograms
-    from cucim.skimage.measure import block_reduce
-    from cucim.skimage.registration import phase_cross_correlation
-    from opm_processing.imageprocessing.ssim_cuda import (
-        structural_similarity_cupy_sep_shared as ssim_cuda,
-    )
-
-    xp = cp
-    USING_GPU = True
-except Exception:
-    cp = None
-    cp_shift = None
-    from skimage.exposure import match_histograms
-    from skimage.measure import block_reduce
-    from skimage.registration import phase_cross_correlation
-    from scipy.ndimage import shift as _shift_cpu
-    from skimage.metrics import structural_similarity as _ssim_cpu
-
-    xp = np
-    USING_GPU = False
+    import torch
+    import torch.nn.functional as F
+    TORCH_AVAILABLE = True
+    CUDA_AVAILABLE = torch.cuda.is_available()
+except ImportError:
+    torch = None
+    F = None
+    TORCH_AVAILABLE = False
+    CUDA_AVAILABLE = False
+
+# CPU fallbacks
+from scipy.ndimage import shift as _shift_cpu
+from skimage.exposure import match_histograms
+from skimage.measure import block_reduce
+from skimage.metrics import structural_similarity as _ssim_cpu
+from skimage.registration import phase_cross_correlation
+
+# Legacy compatibility
+USING_GPU = CUDA_AVAILABLE
+xp = np
+cp = None


 def shift_array(arr, shift_vec):
-    """Shift array using GPU if available, else CPU fallback."""
-    if USING_GPU and cp_shift is not None:
-        return cp_shift(arr, shift=shift_vec, order=1, prefilter=False)
-    return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False)
+    """
+    Shift array by subpixel amounts using GPU (torch) or CPU (scipy).
+
+    Parameters
+    ----------
+    arr : ndarray
+        2D input array.
+    shift_vec : array-like
+        (dy, dx) shift amounts.
+
+    Returns
+    -------
+    shifted : ndarray
+        Shifted array, same shape as input.
+    """
+    arr_np = np.asarray(arr)
+
+    if CUDA_AVAILABLE and arr_np.ndim == 2:
+        return _shift_array_torch(arr_np, shift_vec)
+
+    return _shift_cpu(arr_np, shift=shift_vec, order=1, prefilter=False)


+def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray:
+    """GPU shift using torch.nn.functional.grid_sample."""
+    h, w = arr.shape
+    dy, dx = float(shift_vec[0]), float(shift_vec[1])
+
+    # Create pixel coordinate grids
+    y_coords = torch.arange(h, device="cuda", dtype=torch.float32)
+    x_coords = torch.arange(w, device="cuda", dtype=torch.float32)
+    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
+
+    # Apply shift: to shift output by (dy, dx), sample from (y-dy, x-dx)
+    sample_y = grid_y - dy
+    sample_x = grid_x - dx
+
+    # Normalize to [-1, 1] for grid_sample (align_corners=True)
+    sample_x = 2 * sample_x / (w - 1) - 1
+    sample_y = 2 * sample_y / (h - 1) - 1
+
+    # Stack to (H, W, 2) with (x, y) order, add batch dim -> (1, H, W, 2)
+    grid = torch.stack([sample_x, sample_y], dim=-1).unsqueeze(0)
+
+    # Input: (1, 1, H, W)
+    t = torch.from_numpy(arr).float().cuda().unsqueeze(0).unsqueeze(0)
+
+    # grid_sample with bilinear interpolation
+    out = F.grid_sample(t, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
+
+    return out.squeeze().cpu().numpy()


 def compute_ssim(arr1, arr2, win_size: int) -> float:
-    """SSIM wrapper that routes to GPU kernel or CPU skimage."""
-    if USING_GPU and "ssim_cuda" in globals():
-        return float(ssim_cuda(arr1, arr2, win_size=win_size))
+    """SSIM using skimage (CPU)."""
     arr1_np = np.asarray(arr1)
     arr2_np = np.asarray(arr2)
     data_range = float(arr1_np.max() - arr1_np.min())
@@ -51,21 +96,7 @@ def compute_ssim(arr1, arr2, win_size: int) -> float:


 def make_1d_profile(length: int, blend: int) -> np.ndarray:
-    """
-    Create a linear ramp profile over `blend` pixels at each end.
-
-    Parameters
-    ----------
-    length : int
-        Number of pixels.
-    blend : int
-        Ramp width.
-
-    Returns
-    -------
-    prof : (length,) float32
-        Linear profile.
-    """
+    """Create a linear ramp profile over `blend` pixels at each end."""
     blend = min(blend, length // 2)
     prof = np.ones(length, dtype=np.float32)
     if blend > 0:
@@ -76,12 +107,14 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray:


 def to_numpy(arr):
-    """Convert array to numpy, handling both CPU and GPU arrays."""
-    if USING_GPU and cp is not None and isinstance(arr, cp.ndarray):
-        return cp.asnumpy(arr)
+    """Convert array to numpy."""
+    if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor):
+        return arr.cpu().numpy()
     return np.asarray(arr)


 def to_device(arr):
-    """Move array to current device (GPU if available, else CPU)."""
-    return xp.asarray(arr)
+    """Move array to GPU if available."""
+    if CUDA_AVAILABLE:
+        return torch.from_numpy(np.asarray(arr)).cuda()
+    return np.asarray(arr)
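
For reference, the coordinate normalization that `_shift_array_torch` feeds to `grid_sample` can be sanity-checked against `scipy.ndimage.shift`. The snippet below is a standalone sketch, not part of this PR: it runs `grid_sample` on CPU with a toy 5x5 array and uses only public torch/scipy APIs that the diff already relies on.

# Standalone check (not part of this PR): output pixel (y, x) samples the
# input at (y - dy, x - dx); with align_corners=True that pixel coordinate is
# normalized as 2 * coord / (size - 1) - 1 before calling grid_sample.
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import shift as scipy_shift

arr = np.arange(25, dtype=np.float32).reshape(5, 5)
dy, dx = 0.5, -0.25
h, w = arr.shape

grid_y, grid_x = torch.meshgrid(
    torch.arange(h, dtype=torch.float32),
    torch.arange(w, dtype=torch.float32),
    indexing="ij",
)
grid = torch.stack(
    [2 * (grid_x - dx) / (w - 1) - 1, 2 * (grid_y - dy) / (h - 1) - 1],
    dim=-1,
).unsqueeze(0)  # (1, H, W, 2), last dim ordered (x, y)

t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
out = F.grid_sample(
    t, grid, mode="bilinear", padding_mode="zeros", align_corners=True
).squeeze().numpy()

ref = scipy_shift(arr, (dy, dx), order=1, prefilter=False)
# Interior pixels agree to float32 round-off; the outermost row/column can
# differ because scipy's 'constant' mode and grid_sample's zero padding
# treat samples beyond the edge differently.
print(np.abs(out - ref)[1:-1, 1:-1].max())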
36 changes: 36 additions & 0 deletions tests/test_shift_array.py
@@ -0,0 +1,36 @@
"""Unit tests for GPU shift_array."""
import numpy as np
import sys
sys.path.insert(0, "src")

from tilefusion.utils import shift_array, CUDA_AVAILABLE
from scipy.ndimage import shift as scipy_shift


def test_integer_shift():
arr = np.random.rand(256, 256).astype(np.float32)
cpu = scipy_shift(arr, (3.0, -5.0), order=1, prefilter=False)
gpu = shift_array(arr, (3.0, -5.0))
np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4)


def test_subpixel_mean_error():
arr = np.random.rand(256, 256).astype(np.float32)
cpu = scipy_shift(arr, (5.5, -3.2), order=1, prefilter=False)
gpu = shift_array(arr, (5.5, -3.2))
mean_diff = np.abs(cpu - gpu).mean()
assert mean_diff < 0.01, f"Mean diff {mean_diff} too high"


def test_zero_shift():
arr = np.random.rand(256, 256).astype(np.float32)
result = shift_array(arr, (0.0, 0.0))
# Allow small tolerance due to grid_sample interpolation
np.testing.assert_allclose(result, arr, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
test_integer_shift()
test_subpixel_mean_error()
test_zero_shift()
print("All tests passed")