diff --git a/src/tilefusion/utils.py b/src/tilefusion/utils.py index d0d1769..1b524a8 100644 --- a/src/tilefusion/utils.py +++ b/src/tilefusion/utils.py @@ -6,42 +6,88 @@ import numpy as np +# GPU detection - PyTorch based try: - import cupy as cp - from cupyx.scipy.ndimage import shift as cp_shift - from cucim.skimage.exposure import match_histograms - from cucim.skimage.measure import block_reduce - from cucim.skimage.registration import phase_cross_correlation - from opm_processing.imageprocessing.ssim_cuda import ( - structural_similarity_cupy_sep_shared as ssim_cuda, - ) - - xp = cp - USING_GPU = True -except Exception: - cp = None - cp_shift = None - from skimage.exposure import match_histograms - from skimage.measure import block_reduce - from skimage.registration import phase_cross_correlation - from scipy.ndimage import shift as _shift_cpu - from skimage.metrics import structural_similarity as _ssim_cpu - - xp = np - USING_GPU = False + import torch + TORCH_AVAILABLE = True + CUDA_AVAILABLE = torch.cuda.is_available() +except ImportError: + torch = None + TORCH_AVAILABLE = False + CUDA_AVAILABLE = False + +# CPU fallbacks +from scipy.ndimage import shift as _shift_cpu +from skimage.exposure import match_histograms as _match_histograms_cpu +from skimage.measure import block_reduce +from skimage.metrics import structural_similarity as _ssim_cpu +from skimage.registration import phase_cross_correlation + +# Legacy compatibility +USING_GPU = CUDA_AVAILABLE +xp = np +cp = None + + +def match_histograms(image, reference): + """ + Match histogram of image to reference using GPU (torch) or CPU (skimage). + + Parameters + ---------- + image : ndarray + Image to transform. + reference : ndarray + Reference image for histogram matching. + + Returns + ------- + matched : ndarray + Image with matched histogram. + """ + image_np = np.asarray(image) + reference_np = np.asarray(reference) + + if CUDA_AVAILABLE and image_np.ndim == 2: + return _match_histograms_torch(image_np, reference_np) + + return _match_histograms_cpu(image_np, reference_np) + + +def _match_histograms_torch(image: np.ndarray, reference: np.ndarray) -> np.ndarray: + """GPU histogram matching using torch operations.""" + # Move to GPU + img = torch.from_numpy(image.astype(np.float32)).cuda().flatten() + ref = torch.from_numpy(reference.astype(np.float32)).cuda().flatten() + + # Get sorted indices + img_sorted, img_indices = torch.sort(img) + ref_sorted, _ = torch.sort(ref) + + # Create inverse mapping + inv_indices = torch.empty_like(img_indices) + inv_indices[img_indices] = torch.arange(len(img), device="cuda") + + # Interpolate reference values at image quantiles + img_quantiles = torch.linspace(0, 1, len(img), device="cuda") + ref_quantiles = torch.linspace(0, 1, len(ref), device="cuda") + + # Map image values to reference values via quantile matching + interp_values = torch.zeros_like(img) + interp_values[img_indices] = ref_sorted[ + (inv_indices.float() / len(img) * len(ref)).long().clamp(0, len(ref) - 1) + ] + + return interp_values.reshape(image.shape).cpu().numpy() def shift_array(arr, shift_vec): - """Shift array using GPU if available, else CPU fallback.""" - if USING_GPU and cp_shift is not None: - return cp_shift(arr, shift=shift_vec, order=1, prefilter=False) - return _shift_cpu(arr, shift=shift_vec, order=1, prefilter=False) + """Shift array using scipy (CPU).""" + return _shift_cpu(np.asarray(arr), shift=shift_vec, order=1, prefilter=False) def compute_ssim(arr1, arr2, win_size: int) -> float: - """SSIM wrapper that routes to GPU kernel or CPU skimage.""" - if USING_GPU and "ssim_cuda" in globals(): - return float(ssim_cuda(arr1, arr2, win_size=win_size)) + """SSIM using skimage (CPU).""" arr1_np = np.asarray(arr1) arr2_np = np.asarray(arr2) data_range = float(arr1_np.max() - arr1_np.min()) @@ -51,21 +97,7 @@ def compute_ssim(arr1, arr2, win_size: int) -> float: def make_1d_profile(length: int, blend: int) -> np.ndarray: - """ - Create a linear ramp profile over `blend` pixels at each end. - - Parameters - ---------- - length : int - Number of pixels. - blend : int - Ramp width. - - Returns - ------- - prof : (length,) float32 - Linear profile. - """ + """Create a linear ramp profile over `blend` pixels at each end.""" blend = min(blend, length // 2) prof = np.ones(length, dtype=np.float32) if blend > 0: @@ -76,12 +108,14 @@ def make_1d_profile(length: int, blend: int) -> np.ndarray: def to_numpy(arr): - """Convert array to numpy, handling both CPU and GPU arrays.""" - if USING_GPU and cp is not None and isinstance(arr, cp.ndarray): - return cp.asnumpy(arr) + """Convert array to numpy.""" + if TORCH_AVAILABLE and torch is not None and isinstance(arr, torch.Tensor): + return arr.cpu().numpy() return np.asarray(arr) def to_device(arr): - """Move array to current device (GPU if available, else CPU).""" - return xp.asarray(arr) + """Move array to GPU if available.""" + if CUDA_AVAILABLE: + return torch.from_numpy(np.asarray(arr)).cuda() + return np.asarray(arr) diff --git a/tests/test_histogram_match.py b/tests/test_histogram_match.py new file mode 100644 index 0000000..3db18b5 --- /dev/null +++ b/tests/test_histogram_match.py @@ -0,0 +1,35 @@ +"""Unit tests for GPU histogram matching.""" +import numpy as np +import sys +sys.path.insert(0, "src") + +from tilefusion.utils import match_histograms, CUDA_AVAILABLE +from skimage.exposure import match_histograms as skimage_match + + +def test_histogram_range(): + img = np.random.rand(256, 256).astype(np.float32) + ref = np.random.rand(256, 256).astype(np.float32) * 2 + 1 + result = match_histograms(img, ref) + # Output should be in reference range + assert result.min() >= ref.min() - 0.1 + assert result.max() <= ref.max() + 0.1 + + +def test_histogram_correlation(): + img = np.random.rand(256, 256).astype(np.float32) + ref = np.random.rand(256, 256).astype(np.float32) + + cpu = skimage_match(img, ref) + gpu = match_histograms(img, ref) + + cpu_hist, _ = np.histogram(cpu.flatten(), bins=100) + gpu_hist, _ = np.histogram(gpu.flatten(), bins=100) + corr = np.corrcoef(cpu_hist, gpu_hist)[0, 1] + assert corr > 0.99, f"Histogram correlation {corr} too low" + + +if __name__ == "__main__": + test_histogram_range() + test_histogram_correlation() + print("All tests passed")