Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
017d0f8
feat: Add GPU-accelerated operations via PyTorch
hongquanli Jan 4, 2026
1eae788
style: Apply black formatting
hongquanli Jan 4, 2026
a51bc18
fix: Remove unused variables in _match_histograms_torch
hongquanli Jan 4, 2026
b627a15
feat: Add dtype preservation to shift_array, match_histograms, block_…
hongquanli Jan 4, 2026
bfefa67
fix: Handle 2D block_size for 3D arrays in _block_reduce_torch
hongquanli Jan 4, 2026
f9473c9
refactor: Extract duplicate data_range calculation in compute_ssim
hongquanli Jan 4, 2026
a2d1eb9
refactor: Add named constants for magic numbers (_FFT_EPS, _SSIM_K1, …
hongquanli Jan 4, 2026
a297df5
test: Add CPU fallback and dtype preservation tests
hongquanli Jan 4, 2026
21d2d45
test: Add subpixel phase correlation tests
hongquanli Jan 4, 2026
0e03178
refactor: Add _PARABOLIC_EPS constant for subpixel refinement
hongquanli Jan 4, 2026
f2bd5a1
fix: Guard against 1-pixel arrays in _shift_array_torch
hongquanli Jan 4, 2026
91c7ea9
refactor: Add __all__ export list and document legacy compatibility vars
hongquanli Jan 4, 2026
ea5d0f5
test: Refactor tests to use rng fixture and pytest class style
hongquanli Jan 4, 2026
99a46a8
chore: Clean up unused imports and variables
hongquanli Jan 4, 2026
afe7e4e
Add type hints to public functions in utils.py
hongquanli Jan 4, 2026
4a111c2
Fix histogram matching bug and improve type hints
hongquanli Jan 4, 2026
149784b
Cosmetic cleanups and fix shift_array CPU/GPU consistency
hongquanli Jan 4, 2026
dd94f2a
Document GPU path placeholder values in phase_cross_correlation
hongquanli Jan 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
509 changes: 467 additions & 42 deletions src/tilefusion/utils.py

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@ def sample_tile(rng):
def sample_multichannel_tile(rng):
"""Generate a sample multi-channel tile."""
return rng.random((3, 100, 100), dtype=np.float32) * 65535


@pytest.fixture
def force_cpu(monkeypatch):
"""Force CPU fallback by setting CUDA_AVAILABLE to False."""
import tilefusion.utils as utils

monkeypatch.setattr(utils, "CUDA_AVAILABLE", False)
yield
# monkeypatch automatically restores after test
67 changes: 67 additions & 0 deletions tests/test_block_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Unit tests for GPU block_reduce."""

import numpy as np
import pytest
import sys
from skimage.measure import block_reduce as skimage_block_reduce

sys.path.insert(0, "src")

from tilefusion.utils import block_reduce


class TestBlockReduce:
"""Test block_reduce GPU vs CPU equivalence."""

def test_2d_basic(self, rng):
"""Test 2D block reduce matches skimage."""
arr = rng.random((256, 256)).astype(np.float32)
block_size = (4, 4)

result = block_reduce(arr, block_size, np.mean)
expected = skimage_block_reduce(arr, block_size, np.mean)

np.testing.assert_allclose(result, expected, rtol=1e-5)

def test_2d_large(self, rng):
"""Test larger 2D array."""
arr = rng.random((1024, 1024)).astype(np.float32)
block_size = (8, 8)

result = block_reduce(arr, block_size, np.mean)
expected = skimage_block_reduce(arr, block_size, np.mean)

np.testing.assert_allclose(result, expected, rtol=1e-5)

def test_3d_multichannel(self, rng):
"""Test 3D array with channel dimension."""
arr = rng.random((3, 256, 256)).astype(np.float32)
block_size = (1, 4, 4)

result = block_reduce(arr, block_size, np.mean)
expected = skimage_block_reduce(arr, block_size, np.mean)

np.testing.assert_allclose(result, expected, rtol=1e-5)

def test_output_shape(self, rng):
"""Test output shape is correct."""
arr = rng.random((512, 512)).astype(np.float32)
block_size = (4, 4)

result = block_reduce(arr, block_size, np.mean)

assert result.shape == (128, 128)

def test_non_divisible_shape(self, rng):
"""Test block reduce with non-divisible dimensions."""
arr = rng.random((100, 100)).astype(np.float32)
block_size = (8, 8)

result = block_reduce(arr, block_size, np.mean)
expected = skimage_block_reduce(arr, block_size, np.mean)

np.testing.assert_allclose(result, expected, rtol=1e-5)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
136 changes: 136 additions & 0 deletions tests/test_cpu_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Tests for CPU fallback paths and dtype preservation."""

import numpy as np
import pytest
import sys

sys.path.insert(0, "src")

from tilefusion.utils import (
phase_cross_correlation,
shift_array,
match_histograms,
block_reduce,
compute_ssim,
)


class TestCPUFallback:
"""Test that CPU fallback paths work correctly."""

def test_phase_cross_correlation_cpu(self, force_cpu, rng):
"""Test phase_cross_correlation with CPU fallback."""
ref = rng.random((128, 128)).astype(np.float32)
mov = np.roll(ref, 5, axis=0)

shift, error, phasediff = phase_cross_correlation(ref, mov)

assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5"

def test_shift_array_cpu(self, force_cpu, rng):
"""Test shift_array with CPU fallback."""
arr = rng.random((128, 128)).astype(np.float32)
result = shift_array(arr, (3.0, -2.0))

assert result.shape == arr.shape
assert result.dtype == arr.dtype

def test_match_histograms_cpu(self, force_cpu, rng):
"""Test match_histograms with CPU fallback."""
img = rng.random((128, 128)).astype(np.float32)
ref = rng.random((128, 128)).astype(np.float32) * 2

result = match_histograms(img, ref)

assert result.shape == img.shape

def test_block_reduce_cpu(self, force_cpu, rng):
"""Test block_reduce with CPU fallback."""
arr = rng.random((128, 128)).astype(np.float32)
result = block_reduce(arr, (4, 4), np.mean)

assert result.shape == (32, 32)

def test_compute_ssim_cpu(self, force_cpu, rng):
"""Test compute_ssim with CPU fallback."""
arr1 = rng.random((128, 128)).astype(np.float32)
arr2 = arr1 + rng.random((128, 128)).astype(np.float32) * 0.1

ssim = compute_ssim(arr1, arr2, win_size=7)

assert 0.0 <= ssim <= 1.0


class TestDtypePreservation:
"""Test that dtype is preserved when preserve_dtype=True."""

@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64])
def test_shift_array_dtype(self, dtype, force_cpu, rng):
"""Test shift_array preserves dtype."""
arr = (rng.random((64, 64)) * 255).astype(dtype)
result = shift_array(arr, (1.5, -1.5), preserve_dtype=True)

assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}"

@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64])
def test_match_histograms_dtype(self, dtype, force_cpu, rng):
"""Test match_histograms preserves dtype."""
img = (rng.random((64, 64)) * 255).astype(dtype)
ref = (rng.random((64, 64)) * 255).astype(dtype)
result = match_histograms(img, ref, preserve_dtype=True)

assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}"

@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.float32, np.float64])
def test_block_reduce_dtype(self, dtype, force_cpu, rng):
"""Test block_reduce preserves dtype."""
arr = (rng.random((64, 64)) * 255).astype(dtype)
result = block_reduce(arr, (4, 4), np.mean, preserve_dtype=True)

assert result.dtype == dtype, f"Expected {dtype}, got {result.dtype}"

def test_shift_array_no_preserve(self, force_cpu, rng):
"""Test shift_array returns float when preserve_dtype=False."""
arr = (rng.random((64, 64)) * 255).astype(np.uint16)
result = shift_array(arr, (1.5, -1.5), preserve_dtype=False)

# Should return float64 (scipy default)
assert result.dtype in [np.float32, np.float64]


class TestEdgeCases:
"""Test edge cases and boundary conditions."""

def test_shift_zero(self, force_cpu, rng):
"""Test that zero shift returns nearly identical array."""
arr = rng.random((64, 64)).astype(np.float32)
result = shift_array(arr, (0.0, 0.0))

np.testing.assert_allclose(result, arr, rtol=1e-5, atol=1e-5)

def test_identical_images_ssim(self, force_cpu, rng):
"""Test SSIM of identical images is ~1.0."""
arr = rng.random((64, 64)).astype(np.float32)
ssim = compute_ssim(arr, arr, win_size=7)

assert ssim > 0.99, f"SSIM of identical images should be ~1.0, got {ssim}"

def test_block_reduce_3d(self, force_cpu, rng):
"""Test block_reduce with 3D array."""
arr = rng.random((3, 64, 64)).astype(np.float32)
result = block_reduce(arr, (1, 4, 4), np.mean)

assert result.shape == (3, 16, 16)

def test_different_size_histogram_match(self, force_cpu, rng):
"""Test histogram matching with different sized images."""
img = rng.random((64, 64)).astype(np.float32)
ref = rng.random((128, 128)).astype(np.float32)

result = match_histograms(img, ref)

assert result.shape == img.shape


if __name__ == "__main__":
pytest.main([__file__, "-v"])
102 changes: 102 additions & 0 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Unit tests for GPU phase_cross_correlation (FFT)."""

import numpy as np
import pytest
import sys

sys.path.insert(0, "src")

from tilefusion.utils import phase_cross_correlation
from skimage.registration import phase_cross_correlation as skimage_pcc


class TestPhaseCorrelation:
"""Tests for phase_cross_correlation function."""

def test_known_shift(self, rng):
"""Test detection of known integer shift."""
ref = rng.random((256, 256)).astype(np.float32)

# Create shifted version: mov is ref shifted by (+5, -3)
# phase_cross_correlation returns shift to apply to mov to align with ref
# So it should return (-5, +3)
mov = np.zeros_like(ref)
mov[5:, :253] = ref[:-5, 3:]

shift, _, _ = phase_cross_correlation(ref, mov)

assert abs(shift[0] - (-5)) < 1, f"Y shift {shift[0]} not close to -5"
assert abs(shift[1] - 3) < 1, f"X shift {shift[1]} not close to 3"

def test_zero_shift(self, rng):
"""Test that identical images give zero shift."""
ref = rng.random((256, 256)).astype(np.float32)

shift, _, _ = phase_cross_correlation(ref, ref)

assert abs(shift[0]) < 0.5, f"Y shift {shift[0]} should be ~0"
assert abs(shift[1]) < 0.5, f"X shift {shift[1]} should be ~0"

def test_matches_skimage_direction(self, rng):
"""Test that shift direction matches skimage convention."""
ref = rng.random((128, 128)).astype(np.float32)

# Shift by rolling
mov = np.roll(np.roll(ref, 10, axis=0), -7, axis=1)

gpu_shift, _, _ = phase_cross_correlation(ref, mov)
cpu_shift, _, _ = skimage_pcc(ref, mov)

# Directions should match
assert np.sign(gpu_shift[0]) == np.sign(cpu_shift[0]), "Y direction mismatch"
assert np.sign(gpu_shift[1]) == np.sign(cpu_shift[1]), "X direction mismatch"


class TestSubpixelRefinement:
"""Tests for subpixel phase correlation refinement."""

def test_subpixel_refinement(self, rng):
"""Test subpixel accuracy with upsample_factor > 1."""
ref = rng.random((128, 128)).astype(np.float32)

# Use integer shift for ground truth (subpixel refinement should still work)
mov = np.roll(np.roll(ref, 7, axis=0), -4, axis=1)

# Test with upsample_factor=10 for subpixel refinement
shift_subpixel, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10)

# Should detect the shift direction correctly
assert (
abs(shift_subpixel[0] - (-7)) < 1
), f"Subpixel Y shift {shift_subpixel[0]} not close to -7"
assert (
abs(shift_subpixel[1] - 4) < 1
), f"Subpixel X shift {shift_subpixel[1]} not close to 4"

# Verify reasonable range
assert -10 < shift_subpixel[0] < 0, f"Subpixel Y shift {shift_subpixel[0]} out of range"
assert 0 < shift_subpixel[1] < 10, f"Subpixel X shift {shift_subpixel[1]} out of range"

def test_subpixel_vs_integer_consistency(self, rng):
"""Test that subpixel and integer modes give consistent direction."""
ref = rng.random((64, 64)).astype(np.float32)
mov = np.roll(np.roll(ref, 3, axis=0), -2, axis=1)

shift_int, _, _ = phase_cross_correlation(ref, mov, upsample_factor=1)
shift_sub, _, _ = phase_cross_correlation(ref, mov, upsample_factor=10)

# Signs should match
assert np.sign(shift_int[0]) == np.sign(
shift_sub[0]
), "Y direction mismatch between int/subpixel"
assert np.sign(shift_int[1]) == np.sign(
shift_sub[1]
), "X direction mismatch between int/subpixel"

# Magnitudes should be close
assert abs(shift_int[0] - shift_sub[0]) < 1, "Y magnitude differs too much"
assert abs(shift_int[1] - shift_sub[1]) < 1, "X magnitude differs too much"


if __name__ == "__main__":
pytest.main([__file__, "-v"])
64 changes: 64 additions & 0 deletions tests/test_histogram_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Unit tests for GPU histogram matching."""

import numpy as np
import pytest
import sys

sys.path.insert(0, "src")

from tilefusion.utils import match_histograms
from skimage.exposure import match_histograms as skimage_match


class TestMatchHistograms:
"""Tests for match_histograms function."""

def test_histogram_range(self, rng):
"""Test output is in reference range."""
img = rng.random((256, 256)).astype(np.float32)
ref = rng.random((256, 256)).astype(np.float32) * 2 + 1
result = match_histograms(img, ref)
# Output should be in reference range
assert result.min() >= ref.min() - 0.1
assert result.max() <= ref.max() + 0.1

def test_histogram_correlation(self, rng):
"""Test histogram correlation with skimage."""
img = rng.random((256, 256)).astype(np.float32)
ref = rng.random((256, 256)).astype(np.float32)

cpu = skimage_match(img, ref)
gpu = match_histograms(img, ref)

cpu_hist, _ = np.histogram(cpu.flatten(), bins=100)
gpu_hist, _ = np.histogram(gpu.flatten(), bins=100)
corr = np.corrcoef(cpu_hist, gpu_hist)[0, 1]
assert corr > 0.99, f"Histogram correlation {corr} too low"

def test_same_image(self, rng):
"""Test matching image to itself."""
img = rng.random((128, 128)).astype(np.float32)
result = match_histograms(img, img)
np.testing.assert_allclose(result, img, rtol=1e-5)

def test_different_sizes(self, rng):
"""Test matching images of different sizes."""
img = rng.random((64, 64)).astype(np.float32)
ref = rng.random((128, 128)).astype(np.float32)
result = match_histograms(img, ref)
assert result.shape == img.shape

def test_pixel_values_match_skimage(self, rng):
"""Test pixel-by-pixel matching against skimage."""
img = rng.random((64, 64)).astype(np.float32)
ref = rng.random((64, 64)).astype(np.float32) * 2 + 1

cpu = skimage_match(img, ref)
gpu = match_histograms(img, ref)

# Pixel values should be close (not just histogram shape)
np.testing.assert_allclose(gpu, cpu, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading