diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index d0d1769..5602743 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -6,66 +6,101 @@ import numpy as np +# GPU detection - PyTorch based try: - import cupy as cp - from cupyx.scipy.ndimage import shift as cp_shift - from cucim.skimage.exposure import match_histograms - from cucim.skimage.measure import block_reduce - from cucim.skimage.registration import phase_cross_correlation - from opm_processing.imageprocessing.ssim_cuda import ( - structural_similarity_cupy_sep_shared as ssim_cuda, - ) - - xp = cp - USING_GPU = True -except Exception: - cp = None - cp_shift = None - from skimage.exposure import match_histograms - from skimage.measure import block_reduce - from skimage.registration import phase_cross_correlation - from scipy.ndimage import shift as _shift_cpu - from skimage.metrics import structural_similarity as _ssim_cpu - - xp = np - USING_GPU = False + import torch + import torch.nn.functional as F + TORCH_AVAILABLE = True + CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + torch = None + F = None + TORCH_AVAILABLE = False + CUDA_AVAILABLE = False + +# CPU fallbacks +from scipy.ndimage import shift as _shift_cpu +from skimage.exposure import match_histograms +from skimage.measure import block_reduce +from skimage.metrics import structural_similarity as _ssim_cpu +from skimage.registration import phase_cross_correlation + +# Legacy compatibility +USING_GPU = CUDA_AVAILABLE +xp = np +cp = None -def shift_array(arr, shift_vec): - """Shift array using GPU if available, else CPU fallback.""" - if USING_GPU and cp_shift is not None: - return cp_shift(arr, shift=shift_vec, order=1, prefilter=False) - return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False) +def compute_ssim(arr1, arr2, win_size: int) -> float: + """ + Compute SSIM using GPU (torch) or CPU (skimage). + Parameters + ---------- + arr1, arr2 : ndarray + Input images (2D). + win_size : int + Window size for local statistics. + + Returns + ------- + ssim : float + Mean SSIM value. + """ + arr1_np = np.asarray(arr1, dtype=np.float32) + arr2_np = np.asarray(arr2, dtype=np.float32) + + if CUDA_AVAILABLE and arr1_np.ndim == 2: + data_range = float(arr1_np.max() - arr1_np.min()) + if data_range == 0: + data_range = 1.0 + return _compute_ssim_torch(arr1_np, arr2_np, win_size, data_range) -def compute_ssim(arr1, arr2, win_size: int) -> float: - """SSIM wrapper that routes to GPU kernel or CPU skimage.""" - if USING_GPU and "ssim_cuda" in globals(): - return float(ssim_cuda(arr1, arr2, win_size=win_size)) - arr1_np = np.asarray(arr1) - arr2_np = np.asarray(arr2) data_range = float(arr1_np.max() - arr1_np.min()) if data_range == 0: data_range = 1.0 return float(_ssim_cpu(arr1_np, arr2_np, win_size=win_size, data_range=data_range)) -def make_1d_profile(length: int, blend: int) -> np.ndarray: - """ - Create a linear ramp profile over `blend` pixels at each end. +def _compute_ssim_torch(arr1: np.ndarray, arr2: np.ndarray, win_size: int, data_range: float) -> float: + """GPU SSIM using torch conv2d for local statistics.""" + C1 = (0.01 * data_range) ** 2 + C2 = (0.03 * data_range) ** 2 - Parameters - ---------- - length : int - Number of pixels. - blend : int - Ramp width. + # Create uniform window + window = torch.ones(1, 1, win_size, win_size, device="cuda") / (win_size * win_size) - Returns - ------- - prof : (length,) float32 - Linear profile. - """ + # Convert to tensors (1, 1, H, W) + img1 = torch.from_numpy(arr1).float().cuda().unsqueeze(0).unsqueeze(0) + img2 = torch.from_numpy(arr2).float().cuda().unsqueeze(0).unsqueeze(0) + + # Compute local means + mu1 = F.conv2d(img1, window, padding=win_size // 2) + mu2 = F.conv2d(img2, window, padding=win_size // 2) + + mu1_sq = mu1 ** 2 + mu2_sq = mu2 ** 2 + mu1_mu2 = mu1 * mu2 + + # Compute local variances and covariance + sigma1_sq = F.conv2d(img1 ** 2, window, padding=win_size // 2) - mu1_sq + sigma2_sq = F.conv2d(img2 ** 2, window, padding=win_size // 2) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=win_size // 2) - mu1_mu2 + + # SSIM formula + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + return float(ssim_map.mean().cpu()) + + +def shift_array(arr, shift_vec): + """Shift array using scipy (CPU).""" + return _shift_cpu(np.asarray(arr), shift=shift_vec, order=1, prefilter=False) + + +def make_1d_profile(length: int, blend: int) -> np.ndarray: + """Create a linear ramp profile over `blend` pixels at each end.""" blend = min(blend, length // 2) prof = np.ones(length, dtype=np.float32) if blend > 0: @@ -76,12 +111,14 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray: def to_numpy(arr): - """Convert array to numpy, handling both CPU and GPU arrays.""" - if USING_GPU and cp is not None and isinstance(arr, cp.ndarray): - return cp.asnumpy(arr) + """Convert array to numpy.""" + if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor): + return arr.cpu().numpy() return np.asarray(arr) def to_device(arr): - """Move array to current device (GPU if available, else CPU).""" - return xp.asarray(arr) + """Move array to GPU if available.""" + if CUDA_AVAILABLE: + return torch.from_numpy(np.asarray(arr)).cuda() + return np.asarray(arr) diff --git a/tests/test_ssim.py b/tests/test_ssim.py new file mode 100644 index 0000000..ce7f1a0 --- /dev/null +++ b/tests/test_ssim.py @@ -0,0 +1,38 @@ +"""Unit tests for GPU SSIM.""" +import numpy as np +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import compute_ssim, CUDA_AVAILABLE +from skimage.metrics import structural_similarity as skimage_ssim + + +def test_ssim_similar_images(): + arr1 = np.random.rand(256, 256).astype(np.float32) + arr2 = arr1 + np.random.rand(256, 256).astype(np.float32) * 0.1 + + data_range = arr1.max() - arr1.min() + cpu = skimage_ssim(arr1, arr2, win_size=15, data_range=data_range) + gpu = compute_ssim(arr1, arr2, win_size=15) + + assert abs(cpu - gpu) < 0.01, f"SSIM diff {abs(cpu-gpu)} too high" + + +def test_ssim_identical_images(): + arr = np.random.rand(256, 256).astype(np.float32) + ssim = compute_ssim(arr, arr, win_size=15) + assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}" + + +def test_ssim_different_images(): + arr1 = np.random.rand(256, 256).astype(np.float32) + arr2 = np.random.rand(256, 256).astype(np.float32) + ssim = compute_ssim(arr1, arr2, win_size=15) + assert ssim < 0.5, f"SSIM of random images should be low, got {ssim}" + + +if __name__ == "__main__": + test_ssim_similar_images() + test_ssim_identical_images() + test_ssim_different_images() + print("All tests passed")