Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 89 additions & 20 deletions src/tilefusion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,10 +518,17 @@ def refine_tile_positions_with_cross_correlation(
ssim_window: int = None,
ch_idx: int = 0,
threshold: float = None,
parallel: bool = True,
parallel: Optional[bool] = None,
) -> None:
"""
Detect and score overlaps between neighboring tile pairs via cross-correlation.

Parameters
----------
parallel : bool, optional
If None (default), auto-detects: enabled for multi-file formats
(Zarr, individual TIFFs, OME-TIFF tiles), disabled for single-file
OME-TIFF (due to I/O contention).
"""
df = downsample_factors or self.downsample_factors
sw = ssim_window or self.ssim_window
Expand All @@ -541,7 +548,18 @@ def refine_tile_positions_with_cross_correlation(
# Compute bounds
pair_bounds = compute_pair_bounds(adjacent_pairs, (self.Y, self.X))

# Use parallel processing for CPU mode
# Auto-detect parallel mode if not specified
if parallel is None:
# Parallel helps for individual TIFFs (separate files)
# but hurts for single-file OME-TIFF (I/O contention)
is_multi_file = (
self._is_zarr_format
or self._is_individual_tiffs_format
or self._is_ome_tiff_tiles_format
)
parallel = is_multi_file

# Use parallel processing for CPU mode with enough pairs
use_parallel = parallel and not USING_GPU and len(pair_bounds) > 4

if use_parallel:
Expand All @@ -557,20 +575,34 @@ def _register_parallel(
th: float,
max_shift: Tuple[int, int],
) -> None:
"""Register tile pairs using parallel processing (CPU mode)."""
import psutil
"""Register tile pairs using parallel I/O and compute.

available_ram = psutil.virtual_memory().available
patch_size_est = self.Y * self.X * 4 * 2
max_pairs_in_memory = int(available_ram * 0.3 / patch_size_est)
batch_size = max(16, max_pairs_in_memory)
Uses batching only when estimated memory exceeds 30% of available RAM.
"""
import psutil

n_pairs = len(pair_bounds)
n_batches = (n_pairs + batch_size - 1) // batch_size
n_workers = min(cpu_count(), batch_size, 8)
n_workers = min(cpu_count(), n_pairs, self._max_workers)
io_workers = min(n_pairs, self._max_workers)
print(
f"Parallel registration: {n_pairs} pairs, {n_workers} compute workers, {io_workers} I/O workers"
)

if n_batches > 1:
print(f"Processing {n_pairs} pairs in {n_batches} batches")
# Estimate memory needed based on actual overlap size
if pair_bounds:
total_pixels = 0
for _, _, bounds_i_y, bounds_i_x, _, _ in pair_bounds:
patch_h = bounds_i_y[1] - bounds_i_y[0]
patch_w = bounds_i_x[1] - bounds_i_x[0]
total_pixels += patch_h * patch_w
# 4 bytes per float32 pixel, 2 patches per pair
estimated_memory = total_pixels * 4 * 2
else:
estimated_memory = 0

available_ram = psutil.virtual_memory().available
ram_budget = int(available_ram * 0.3)
needs_batching = estimated_memory > ram_budget

def read_pair_patches(args):
i_pos, j_pos, bounds_i_y, bounds_i_x, bounds_j_y, bounds_j_x = args
Expand All @@ -585,25 +617,62 @@ def read_pair_patches(args):
except Exception:
return (i_pos, j_pos, None, None)

for batch_idx in range(n_batches):
start = batch_idx * batch_size
end = min(start + batch_size, n_pairs)
batch = pair_bounds[start:end]
if needs_batching:
# Batched approach for large datasets
avg_pair_bytes = max(1, estimated_memory // n_pairs) if n_pairs > 0 else 1
batch_size = max(16, ram_budget // avg_pair_bytes)
n_batches = (n_pairs + batch_size - 1) // batch_size

print(
f"Processing {n_pairs} pairs in {n_batches} batches (RAM limited, {n_workers} workers)"
)

for batch_idx in range(n_batches):
start = batch_idx * batch_size
end = min(start + batch_size, n_pairs)
batch = pair_bounds[start:end]

with ThreadPoolExecutor(max_workers=io_workers) as io_executor:
patches = list(io_executor.map(read_pair_patches, batch))

work_items = [
(i, j, pi, pj, df, sw, th, max_shift)
for i, j, pi, pj in patches
if pi is not None
]

desc = f"register {batch_idx+1}/{n_batches}"
with ThreadPoolExecutor(max_workers=n_workers) as executor:
results = list(
tqdm(
executor.map(register_pair_worker, work_items),
total=len(work_items),
desc=desc,
leave=True,
)
)

with ThreadPoolExecutor(max_workers=8) as io_executor:
patches = list(io_executor.map(read_pair_patches, batch))
for i_pos, j_pos, dy_s, dx_s, score in results:
if dy_s is not None:
self.pairwise_metrics[(i_pos, j_pos)] = (dy_s, dx_s, score)

del patches, work_items, results
gc.collect()
else:
# Simple approach - load all at once
with ThreadPoolExecutor(max_workers=io_workers) as io_executor:
patches = list(io_executor.map(read_pair_patches, pair_bounds))

work_items = [
(i, j, pi, pj, df, sw, th, max_shift) for i, j, pi, pj in patches if pi is not None
]

desc = f"register {batch_idx+1}/{n_batches}" if n_batches > 1 else "register"
with ThreadPoolExecutor(max_workers=n_workers) as executor:
results = list(
tqdm(
executor.map(register_pair_worker, work_items),
total=len(work_items),
desc=desc,
desc="register",
leave=True,
)
)
Expand Down
52 changes: 52 additions & 0 deletions tests/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,55 @@ def test_returns_float_tuple(self):
assert isinstance(shift[0], float)
assert isinstance(shift[1], float)
assert isinstance(ssim, float)


class TestOverlapBoundsSize:
"""Tests verifying overlap region size from compute_pair_bounds."""

def test_horizontal_overlap_much_smaller_than_full_tile(self):
"""Verify horizontal overlap region is much smaller than full tile.

For a 2048x2048 tile with 15% overlap (~307px), the overlap region is
roughly 2048 * 307 = 628,736 pixels, not 2048 * 2048 = 4,194,304 pixels.
"""
tile_shape = (2048, 2048)
overlap_pixels = 307
dx = tile_shape[1] - overlap_pixels # 1741

adjacent_pairs = [
(0, 1, 0, dx, tile_shape[0], overlap_pixels), # horizontal pair
]
pair_bounds = compute_pair_bounds(adjacent_pairs, tile_shape)

_, _, bounds_i_y, bounds_i_x, _, _ = pair_bounds[0]
actual_overlap_h = bounds_i_y[1] - bounds_i_y[0]
actual_overlap_w = bounds_i_x[1] - bounds_i_x[0]
actual_overlap_pixels = actual_overlap_h * actual_overlap_w

full_tile_pixels = tile_shape[0] * tile_shape[1]

# Overlap region should be much smaller than full tile
ratio = full_tile_pixels / actual_overlap_pixels
assert ratio > 5, f"Expected overlap to be <20% of full tile, got ratio {ratio:.1f}"

# Verify the bounds are correct
assert actual_overlap_w == overlap_pixels
assert actual_overlap_h == tile_shape[0]

def test_vertical_overlap_bounds(self):
"""Test bounds for vertical overlap."""
tile_shape = (2048, 2048)
overlap_pixels = 307
dy = tile_shape[0] - overlap_pixels

adjacent_pairs = [
(0, 1, dy, 0, overlap_pixels, tile_shape[1]), # vertical pair
]
pair_bounds = compute_pair_bounds(adjacent_pairs, tile_shape)

_, _, bounds_i_y, bounds_i_x, _, _ = pair_bounds[0]
actual_overlap_h = bounds_i_y[1] - bounds_i_y[0]
actual_overlap_w = bounds_i_x[1] - bounds_i_x[0]

assert actual_overlap_h == overlap_pixels
assert actual_overlap_w == tile_shape[1]