From 46040aa61507019ee0407dde3334a6b38ca305c6 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Mon, 29 Dec 2025 20:39:31 -0800 Subject: [PATCH 1/4] feat: Optimize parallel registration with smart RAM management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Auto-detect parallel mode: enabled for multi-file formats (individual TIFFs, OME-TIFF tiles, zarr), disabled for single-file OME-TIFF (benchmarks showed it was 40% slower due to I/O contention) - Fix RAM estimation to use actual overlap region size instead of full tile size (was 5-10x over-conservative) - Add hybrid batching: simple path when memory fits, batched when >30% RAM needed - Use class max_workers setting instead of hardcoded values - Add worker count logging for transparency - Add tests for overlap bounds calculation 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/tilefusion/core.py | 108 +++++++++++++++++++++++++++++-------- tests/test_registration.py | 52 ++++++++++++++++++ 2 files changed, 139 insertions(+), 21 deletions(-) diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index cff64fb..9a9fb0d 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -518,10 +518,16 @@ def refine_tile_positions_with_cross_correlation( ssim_window: int = None, ch_idx: int = 0, threshold: float = None, - parallel: bool = True, + parallel: Optional[bool] = None, ) -> None: """ Detect and score overlaps between neighboring tile pairs via cross-correlation. + + Parameters + ---------- + parallel : bool, optional + If None (default), auto-detects: enabled for individual TIFFs, + disabled for single-file OME-TIFF (due to I/O contention). 
""" df = downsample_factors or self.downsample_factors sw = ssim_window or self.ssim_window @@ -541,7 +547,18 @@ def refine_tile_positions_with_cross_correlation( # Compute bounds pair_bounds = compute_pair_bounds(adjacent_pairs, (self.Y, self.X)) - # Use parallel processing for CPU mode + # Auto-detect parallel mode if not specified + if parallel is None: + # Parallel helps for individual TIFFs (separate files) + # but hurts for single-file OME-TIFF (I/O contention) + is_multi_file = ( + self._is_zarr_format + or self._is_individual_tiffs_format + or self._is_ome_tiff_tiles_format + ) + parallel = is_multi_file + + # Use parallel processing for CPU mode with enough pairs use_parallel = parallel and not USING_GPU and len(pair_bounds) > 4 if use_parallel: @@ -557,20 +574,32 @@ def _register_parallel( th: float, max_shift: Tuple[int, int], ) -> None: - """Register tile pairs using parallel processing (CPU mode).""" - import psutil + """Register tile pairs using parallel I/O and compute. - available_ram = psutil.virtual_memory().available - patch_size_est = self.Y * self.X * 4 * 2 - max_pairs_in_memory = int(available_ram * 0.3 / patch_size_est) - batch_size = max(16, max_pairs_in_memory) + Uses batching only when estimated memory exceeds 30% of available RAM. 
+ """ + import psutil n_pairs = len(pair_bounds) - n_batches = (n_pairs + batch_size - 1) // batch_size - n_workers = min(cpu_count(), batch_size, 8) + n_workers = min(cpu_count(), n_pairs, self._max_workers) + io_workers = min(n_pairs, self._max_workers) + print(f"Parallel registration: {n_pairs} pairs, {n_workers} compute workers, {io_workers} I/O workers") + + # Estimate memory needed based on actual overlap size + if pair_bounds: + total_pixels = 0 + for _, _, bounds_i_y, bounds_i_x, _, _ in pair_bounds: + patch_h = bounds_i_y[1] - bounds_i_y[0] + patch_w = bounds_i_x[1] - bounds_i_x[0] + total_pixels += patch_h * patch_w + # 4 bytes per float32 pixel, 2 patches per pair + estimated_memory = total_pixels * 4 * 2 + else: + estimated_memory = 0 - if n_batches > 1: - print(f"Processing {n_pairs} pairs in {n_batches} batches") + available_ram = psutil.virtual_memory().available + ram_budget = int(available_ram * 0.3) + needs_batching = estimated_memory > ram_budget def read_pair_patches(args): i_pos, j_pos, bounds_i_y, bounds_i_x, bounds_j_y, bounds_j_x = args @@ -585,25 +614,62 @@ def read_pair_patches(args): except Exception: return (i_pos, j_pos, None, None) - for batch_idx in range(n_batches): - start = batch_idx * batch_size - end = min(start + batch_size, n_pairs) - batch = pair_bounds[start:end] + if needs_batching: + # Batched approach for large datasets + avg_pair_bytes = estimated_memory // n_pairs if n_pairs > 0 else 1 + batch_size = max(16, ram_budget // avg_pair_bytes) + n_batches = (n_pairs + batch_size - 1) // batch_size + + print(f"Processing {n_pairs} pairs in {n_batches} batches (RAM limited, {n_workers} workers)") + + for batch_idx in range(n_batches): + start = batch_idx * batch_size + end = min(start + batch_size, n_pairs) + batch = pair_bounds[start:end] + + with ThreadPoolExecutor(max_workers=io_workers) as io_executor: + patches = list(io_executor.map(read_pair_patches, batch)) + + work_items = [ + (i, j, pi, pj, df, sw, th, max_shift) + for 
i, j, pi, pj in patches + if pi is not None + ] + + desc = f"register {batch_idx+1}/{n_batches}" + with ThreadPoolExecutor(max_workers=n_workers) as executor: + results = list( + tqdm( + executor.map(register_pair_worker, work_items), + total=len(work_items), + desc=desc, + leave=True, + ) + ) + + for i_pos, j_pos, dy_s, dx_s, score in results: + if dy_s is not None: + self.pairwise_metrics[(i_pos, j_pos)] = (dy_s, dx_s, score) - with ThreadPoolExecutor(max_workers=8) as io_executor: - patches = list(io_executor.map(read_pair_patches, batch)) + del patches, work_items, results + gc.collect() + else: + # Simple approach - load all at once + with ThreadPoolExecutor(max_workers=io_workers) as io_executor: + patches = list(io_executor.map(read_pair_patches, pair_bounds)) work_items = [ - (i, j, pi, pj, df, sw, th, max_shift) for i, j, pi, pj in patches if pi is not None + (i, j, pi, pj, df, sw, th, max_shift) + for i, j, pi, pj in patches + if pi is not None ] - desc = f"register {batch_idx+1}/{n_batches}" if n_batches > 1 else "register" with ThreadPoolExecutor(max_workers=n_workers) as executor: results = list( tqdm( executor.map(register_pair_worker, work_items), total=len(work_items), - desc=desc, + desc="register", leave=True, ) ) diff --git a/tests/test_registration.py b/tests/test_registration.py index 13e56d8..f35a2fe 100644 --- a/tests/test_registration.py +++ b/tests/test_registration.py @@ -142,3 +142,55 @@ def test_returns_float_tuple(self): assert isinstance(shift[0], float) assert isinstance(shift[1], float) assert isinstance(ssim, float) + + +class TestOverlapBoundsSize: + """Tests verifying overlap region size from compute_pair_bounds.""" + + def test_horizontal_overlap_much_smaller_than_full_tile(self): + """Verify horizontal overlap region is much smaller than full tile. + + For a 2048x2048 tile with 15% overlap (~307px), the overlap region is + roughly 2048 * 307 = 628,736 pixels, not 2048 * 2048 = 4,194,304 pixels. 
+ """ + tile_shape = (2048, 2048) + overlap_pixels = 307 + dx = tile_shape[1] - overlap_pixels # 1741 + + adjacent_pairs = [ + (0, 1, 0, dx, tile_shape[0], overlap_pixels), # horizontal pair + ] + pair_bounds = compute_pair_bounds(adjacent_pairs, tile_shape) + + _, _, bounds_i_y, bounds_i_x, _, _ = pair_bounds[0] + actual_overlap_h = bounds_i_y[1] - bounds_i_y[0] + actual_overlap_w = bounds_i_x[1] - bounds_i_x[0] + actual_overlap_pixels = actual_overlap_h * actual_overlap_w + + full_tile_pixels = tile_shape[0] * tile_shape[1] + + # Overlap region should be much smaller than full tile + ratio = full_tile_pixels / actual_overlap_pixels + assert ratio > 5, f"Expected overlap to be <20% of full tile, got ratio {ratio:.1f}" + + # Verify the bounds are correct + assert actual_overlap_w == overlap_pixels + assert actual_overlap_h == tile_shape[0] + + def test_vertical_overlap_bounds(self): + """Test bounds for vertical overlap.""" + tile_shape = (2048, 2048) + overlap_pixels = 307 + dy = tile_shape[0] - overlap_pixels + + adjacent_pairs = [ + (0, 1, dy, 0, overlap_pixels, tile_shape[1]), # vertical pair + ] + pair_bounds = compute_pair_bounds(adjacent_pairs, tile_shape) + + _, _, bounds_i_y, bounds_i_x, _, _ = pair_bounds[0] + actual_overlap_h = bounds_i_y[1] - bounds_i_y[0] + actual_overlap_w = bounds_i_x[1] - bounds_i_x[0] + + assert actual_overlap_h == overlap_pixels + assert actual_overlap_w == tile_shape[1] From f92659700b70773828faf6c135a69027cdbe9e22 Mon Sep 17 00:00:00 2001 From: Hongquan Li Date: Mon, 29 Dec 2025 21:15:46 -0800 Subject: [PATCH 2/4] style: Apply black formatting --- src/tilefusion/core.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index 9a9fb0d..b89ce38 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -583,7 +583,9 @@ def _register_parallel( n_pairs = len(pair_bounds) n_workers = min(cpu_count(), n_pairs, self._max_workers) io_workers = 
min(n_pairs, self._max_workers) - print(f"Parallel registration: {n_pairs} pairs, {n_workers} compute workers, {io_workers} I/O workers") + print( + f"Parallel registration: {n_pairs} pairs, {n_workers} compute workers, {io_workers} I/O workers" + ) # Estimate memory needed based on actual overlap size if pair_bounds: @@ -620,7 +622,9 @@ def read_pair_patches(args): batch_size = max(16, ram_budget // avg_pair_bytes) n_batches = (n_pairs + batch_size - 1) // batch_size - print(f"Processing {n_pairs} pairs in {n_batches} batches (RAM limited, {n_workers} workers)") + print( + f"Processing {n_pairs} pairs in {n_batches} batches (RAM limited, {n_workers} workers)" + ) for batch_idx in range(n_batches): start = batch_idx * batch_size @@ -659,9 +663,7 @@ def read_pair_patches(args): patches = list(io_executor.map(read_pair_patches, pair_bounds)) work_items = [ - (i, j, pi, pj, df, sw, th, max_shift) - for i, j, pi, pj in patches - if pi is not None + (i, j, pi, pj, df, sw, th, max_shift) for i, j, pi, pj in patches if pi is not None ] with ThreadPoolExecutor(max_workers=n_workers) as executor: From 03551cbf372d207ccafbc9eac72b741924c78532 Mon Sep 17 00:00:00 2001 From: hongquanli Date: Mon, 29 Dec 2025 21:23:16 -0800 Subject: [PATCH 3/4] Update src/tilefusion/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/tilefusion/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index b89ce38..bc4efaa 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -526,8 +526,9 @@ def refine_tile_positions_with_cross_correlation( Parameters ---------- parallel : bool, optional - If None (default), auto-detects: enabled for individual TIFFs, - disabled for single-file OME-TIFF (due to I/O contention). + If None (default), auto-detects: enabled for multi-file formats + (Zarr, individual TIFFs, OME-TIFF tiles), disabled for single-file + OME-TIFF (due to I/O contention). 
""" df = downsample_factors or self.downsample_factors sw = ssim_window or self.ssim_window From d541c8909a4a7a4d6af3f60ed57b6e4f45759eac Mon Sep 17 00:00:00 2001 From: hongquanli Date: Mon, 29 Dec 2025 21:23:30 -0800 Subject: [PATCH 4/4] Update src/tilefusion/core.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/tilefusion/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index bc4efaa..4086a55 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -619,7 +619,7 @@ def read_pair_patches(args): if needs_batching: # Batched approach for large datasets - avg_pair_bytes = estimated_memory // n_pairs if n_pairs > 0 else 1 + avg_pair_bytes = max(1, estimated_memory // n_pairs) if n_pairs > 0 else 1 batch_size = max(16, ram_budget // avg_pair_bytes) n_batches = (n_pairs + batch_size - 1) // batch_size