diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py
index d0d1769..1c5c231 100644
--- a/src/tilefusion/utils.py
+++ b/src/tilefusion/utils.py
@@ -6,42 +6,87 @@
 
 import numpy as np
 
+# GPU detection - PyTorch based
 try:
-    import cupy as cp
-    from cupyx.scipy.ndimage import shift as cp_shift
-    from cucim.skimage.exposure import match_histograms
-    from cucim.skimage.measure import block_reduce
-    from cucim.skimage.registration import phase_cross_correlation
-    from opm_processing.imageprocessing.ssim_cuda import (
-        structural_similarity_cupy_sep_shared as ssim_cuda,
-    )
-
-    xp = cp
-    USING_GPU = True
-except Exception:
-    cp = None
-    cp_shift = None
-    from skimage.exposure import match_histograms
-    from skimage.measure import block_reduce
-    from skimage.registration import phase_cross_correlation
-    from scipy.ndimage import shift as _shift_cpu
-    from skimage.metrics import structural_similarity as _ssim_cpu
-
-    xp = np
-    USING_GPU = False
+    import torch
+    import torch.nn.functional as F
+    TORCH_AVAILABLE = True
+    CUDA_AVAILABLE = torch.cuda.is_available()
+except ImportError:
+    torch = None
+    F = None
+    TORCH_AVAILABLE = False
+    CUDA_AVAILABLE = False
+
+# CPU fallbacks
+from scipy.ndimage import shift as _shift_cpu
+from skimage.exposure import match_histograms
+from skimage.measure import block_reduce
+from skimage.metrics import structural_similarity as _ssim_cpu
+from skimage.registration import phase_cross_correlation
+
+# Legacy compatibility
+USING_GPU = CUDA_AVAILABLE
+xp = np
+cp = None
 
 
 def shift_array(arr, shift_vec):
-    """Shift array using GPU if available, else CPU fallback."""
-    if USING_GPU and cp_shift is not None:
-        return cp_shift(arr, shift=shift_vec, order=1, prefilter=False)
-    return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False)
+    """
+    Shift array by subpixel amounts using GPU (torch) or CPU (scipy).
+
+    Parameters
+    ----------
+    arr : ndarray
+        Input array; GPU path is used only for 2D arrays with both dims > 1.
+    shift_vec : array-like
+        (dy, dx) shift amounts.
+
+    Returns
+    -------
+    shifted : ndarray
+        Shifted array, same shape and dtype as input.
+    """
+    arr_np = np.asarray(arr)
+
+    if CUDA_AVAILABLE and arr_np.ndim == 2 and min(arr_np.shape) > 1:
+        return _shift_array_torch(arr_np, shift_vec)
+
+    return _shift_cpu(arr_np, shift=shift_vec, order=1, prefilter=False)
+
+
+def _shift_array_torch(arr: np.ndarray, shift_vec) -> np.ndarray:
+    """Bilinear shift on CUDA via grid_sample; zero fill outside, dtype preserved."""
+    h, w = arr.shape
+    dy, dx = float(shift_vec[0]), float(shift_vec[1])
+
+    # Create pixel coordinate grids
+    y_coords = torch.arange(h, device="cuda", dtype=torch.float32)
+    x_coords = torch.arange(w, device="cuda", dtype=torch.float32)
+    grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
+
+    # Apply shift: to shift output by (dy, dx), sample from (y-dy, x-dx)
+    sample_y = grid_y - dy
+    sample_x = grid_x - dx
+
+    # Normalize to [-1, 1] for grid_sample (align_corners=True); caller guarantees h, w > 1
+    sample_x = 2 * sample_x / (w - 1) - 1
+    sample_y = 2 * sample_y / (h - 1) - 1
+
+    # Stack to (H, W, 2) with (x, y) order, add batch dim -> (1, H, W, 2)
+    grid = torch.stack([sample_x, sample_y], dim=-1).unsqueeze(0)
+
+    # Input: (1, 1, H, W); as_tensor also accepts read-only / non-contiguous arrays
+    t = torch.as_tensor(arr, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0)
+
+    # grid_sample with bilinear interpolation
+    out = F.grid_sample(t, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
+
+    return out[0, 0].cpu().numpy().astype(arr.dtype, copy=False)
 
 
 def compute_ssim(arr1, arr2, win_size: int) -> float:
-    """SSIM wrapper that routes to GPU kernel or CPU skimage."""
-    if USING_GPU and "ssim_cuda" in globals():
-        return float(ssim_cuda(arr1, arr2, win_size=win_size))
+    """SSIM using skimage (CPU)."""
     arr1_np = np.asarray(arr1)
     arr2_np = np.asarray(arr2)
     data_range = float(arr1_np.max() - arr1_np.min())
@@ -51,21 +96,7 @@ def compute_ssim(arr1, arr2, win_size: int) -> float:
 
 
 def make_1d_profile(length: int, blend: int) -> np.ndarray:
-    """
-    Create a linear ramp profile over `blend` pixels at each end.
-
-    Parameters
-    ----------
-    length : int
-        Number of pixels.
-    blend : int
-        Ramp width.
-
-    Returns
-    -------
-    prof : (length,) float32
-        Linear profile.
-    """
+    """Create a linear ramp profile over `blend` pixels at each end."""
     blend = min(blend, length // 2)
     prof = np.ones(length, dtype=np.float32)
     if blend > 0:
@@ -76,12 +107,14 @@
 
 
 def to_numpy(arr):
-    """Convert array to numpy, handling both CPU and GPU arrays."""
-    if USING_GPU and cp is not None and isinstance(arr, cp.ndarray):
-        return cp.asnumpy(arr)
+    """Convert array to numpy."""
+    if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor):
+        return arr.cpu().numpy()
     return np.asarray(arr)
 
 
 def to_device(arr):
-    """Move array to current device (GPU if available, else CPU)."""
-    return xp.asarray(arr)
+    """Move array to GPU if available."""
+    if CUDA_AVAILABLE:
+        return torch.as_tensor(np.asarray(arr), device="cuda")
+    return np.asarray(arr)
diff --git a/tests/test_shift_array.py b/tests/test_shift_array.py
new file mode 100644
index 0000000..8874b42
--- /dev/null
+++ b/tests/test_shift_array.py
@@ -0,0 +1,36 @@
+"""Unit tests for GPU shift_array."""
+import numpy as np
+import sys
+sys.path.insert(0, "src")
+
+from tilefusion.utils import shift_array, CUDA_AVAILABLE
+from scipy.ndimage import shift as scipy_shift
+
+
+def test_integer_shift():
+    arr = np.random.rand(256, 256).astype(np.float32)
+    cpu = scipy_shift(arr, (3.0, -5.0), order=1, prefilter=False)
+    gpu = shift_array(arr, (3.0, -5.0))
+    np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4)
+
+
+def test_subpixel_mean_error():
+    arr = np.random.rand(256, 256).astype(np.float32)
+    cpu = scipy_shift(arr, (5.5, -3.2), order=1, prefilter=False)
+    gpu = shift_array(arr, (5.5, -3.2))
+    mean_diff = np.abs(cpu - gpu).mean()
+    assert mean_diff < 0.01, f"Mean diff {mean_diff} too high"
+
+
+def test_zero_shift():
+    arr = np.random.rand(256, 256).astype(np.float32)
+    result = shift_array(arr, (0.0, 0.0))
+    # Allow small tolerance due to grid_sample interpolation
+    np.testing.assert_allclose(result, arr, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    test_integer_shift()
+    test_subpixel_mean_error()
+    test_zero_shift()
+    print("All tests passed")