diff --git a/.gitignore b/.gitignore index b7faf40..a599e53 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,4 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ +.DS_Store diff --git a/src/tilefusion/core.py b/src/tilefusion/core.py index e6d75ab..c37cab7 100644 --- a/src/tilefusion/core.py +++ b/src/tilefusion/core.py @@ -25,6 +25,7 @@ from .registration import ( compute_pair_bounds, find_adjacent_pairs, + find_adjacent_pairs_fast, register_and_score, register_pair_worker, ) @@ -46,6 +47,7 @@ create_zarr_store, write_ngff_metadata, write_scale_group_metadata, + OMETiffReader, ) @@ -388,10 +390,16 @@ def refine_tile_positions_with_cross_correlation( max_shift = (100, 100) - # Find adjacent pairs - adjacent_pairs = find_adjacent_pairs( - self._tile_positions, self._pixel_size, (self.Y, self.X) - ) + # Use fast spatial hashing for large datasets (>100 tiles) + n_tiles = len(self._tile_positions) + if n_tiles > 100: + adjacent_pairs = find_adjacent_pairs_fast( + self._tile_positions, self._pixel_size, (self.Y, self.X) + ) + else: + adjacent_pairs = find_adjacent_pairs( + self._tile_positions, self._pixel_size, (self.Y, self.X) + ) if self._debug: print(f"Found {len(adjacent_pairs)} adjacent tile pairs to register") @@ -402,7 +410,16 @@ def refine_tile_positions_with_cross_correlation( # Use parallel processing for CPU mode use_parallel = parallel and not USING_GPU and len(pair_bounds) > 4 - if use_parallel: + # Use optimized path for OME-TIFF multi-series format (persistent file handle + cache) + is_ome_tiff_multi = ( + not self._is_zarr_format + and not self._is_individual_tiffs_format + and not self._is_ome_tiff_tiles_format + ) + + if use_parallel and is_ome_tiff_multi and n_tiles > 50: + self._register_parallel_optimized(pair_bounds, df, sw, th, max_shift) + elif use_parallel: self._register_parallel(pair_bounds, df, sw, th, max_shift) else: self._register_sequential(pair_bounds, df, sw, th, max_shift) @@ -473,6 +490,79 @@ def 
# ---------------------------------------------------------------------------
# src/tilefusion/core.py
#
# Method of the tile-fusion class (its ``class`` header lies outside this
# diff).  It expects the module-level names ``gc``, ``cpu_count``,
# ``ThreadPoolExecutor``, ``tqdm``, ``register_pair_worker`` and
# ``OMETiffReader`` to be in scope in core.py.
# ---------------------------------------------------------------------------
def _register_parallel_optimized(
    self,
    pair_bounds: List[Tuple],
    df: Tuple[int, int],
    sw: int,
    th: float,
    max_shift: Tuple[int, int],
) -> None:
    """Register tile pairs using one persistent OME-TIFF handle.

    Overlap patches are read sequentially through a single ``OMETiffReader``
    and then registered in a thread pool, batch by batch, so that only one
    batch of patches is held in memory at a time.

    Parameters
    ----------
    pair_bounds : list of tuples
        Each entry is ``(i_pos, j_pos, bounds_i_y, bounds_i_x, bounds_j_y,
        bounds_j_x)`` where every ``bounds_*`` is a ``(start, stop)`` pair
        used to slice the corresponding tile.
    df, sw, th, max_shift
        Forwarded unchanged to ``register_pair_worker``.

    Notes
    -----
    Results are stored in ``self.pairwise_metrics[(i_pos, j_pos)]`` as
    ``(dy_s, dx_s, score)``.  Pairs whose patches cannot be read, or for
    which the worker returns ``dy_s is None``, are skipped (best effort).
    Read failures are reported when ``self._debug`` is set.
    """
    n_pairs = len(pair_bounds)
    if n_pairs == 0:
        # Nothing to register: avoid probing memory or opening the file.
        return

    import psutil

    # Size batches from 30% of currently available RAM.  The estimate is
    # Y * X * 4 * 2 bytes per pair (presumably two float32 full-size tiles,
    # an upper bound for two overlap patches); never let it reach zero.
    available_ram = psutil.virtual_memory().available
    patch_size_est = max(1, self.Y * self.X * 4 * 2)
    max_pairs_in_memory = int(available_ram * 0.3 / patch_size_est)
    batch_size = max(16, max_pairs_in_memory)

    n_batches = (n_pairs + batch_size - 1) // batch_size
    n_workers = min(cpu_count(), batch_size, 8)

    if n_batches > 1:
        print(f"Processing {n_pairs} pairs in {n_batches} batches")

    # (i_pos, j_pos, error) for every pair whose patches could not be read.
    read_failures = []

    # A tile usually takes part in several pairs; a few cached tiles avoid
    # decoding the same tile again for consecutive pairs.
    tile_cache_size = 4

    with OMETiffReader(self.tiff_path, cache_size=tile_cache_size) as reader:

        def read_pair_patches(args):
            i_pos, j_pos, bounds_i_y, bounds_i_x, bounds_j_y, bounds_j_x = args
            try:
                patch_i = reader.read_region(
                    i_pos,
                    slice(bounds_i_y[0], bounds_i_y[1]),
                    slice(bounds_i_x[0], bounds_i_x[1]),
                )
                patch_j = reader.read_region(
                    j_pos,
                    slice(bounds_j_y[0], bounds_j_y[1]),
                    slice(bounds_j_x[0], bounds_j_x[1]),
                )
            except Exception as exc:
                # Deliberately best effort: one unreadable pair must not
                # abort the whole registration, but remember it for reporting.
                read_failures.append((i_pos, j_pos, exc))
                return (i_pos, j_pos, None, None)
            return (i_pos, j_pos, patch_i, patch_j)

        # One pool for all batches instead of a new pool per batch.
        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            for batch_idx in range(n_batches):
                start = batch_idx * batch_size
                end = min(start + batch_size, n_pairs)
                batch = pair_bounds[start:end]

                # Reads stay in this thread; only registration is parallel.
                patches = [read_pair_patches(args) for args in batch]

                work_items = [
                    (i, j, pi, pj, df, sw, th, max_shift)
                    for i, j, pi, pj in patches
                    if pi is not None
                ]

                desc = f"register {batch_idx+1}/{n_batches}" if n_batches > 1 else "register"
                results = list(
                    tqdm(
                        executor.map(register_pair_worker, work_items),
                        total=len(work_items),
                        desc=desc,
                        leave=True,
                    )
                )

                for i_pos, j_pos, dy_s, dx_s, score in results:
                    if dy_s is not None:
                        self.pairwise_metrics[(i_pos, j_pos)] = (dy_s, dx_s, score)

                # Release this batch's patches before reading the next one.
                del patches, work_items, results
                gc.collect()

    if read_failures and self._debug:
        first_i, first_j, first_exc = read_failures[0]
        print(
            f"Skipped {len(read_failures)} tile pairs whose patches could not be read "
            f"(first: ({first_i}, {first_j}): {first_exc!r})"
        )


# ---------------------------------------------------------------------------
# src/tilefusion/io/ome_tiff.py
#
# ``OMETiffReader`` is re-exported from src/tilefusion/io/__init__.py (import
# line and ``__all__`` entry).  ``tifffile`` and ``np`` are expected to be
# imported at the top of ome_tiff.py.
# ---------------------------------------------------------------------------
class OMETiffReader:
    """
    Persistent file handle for OME-TIFF reading.

    Keeps the file open during registration to avoid repeated open/close
    operations on large files.  Use it as a context manager; the file is
    opened in ``__enter__`` and closed in ``__exit__``.

    Parameters
    ----------
    path : Path
        Path of the OME-TIFF file.
    cache_size : int, optional
        Maximum number of decoded tiles kept in a least-recently-used cache.
        The default ``0`` disables caching, so every ``read_region`` call
        decodes the tile again.

    Notes
    -----
    The cache is a plain dict with no locking; share one reader between
    threads only if reads are serialized by the caller.
    """

    def __init__(self, path: Path, cache_size: int = 0):
        self.path = Path(path)
        self._tif = None
        self._cache_size = max(0, int(cache_size))
        # tile_idx -> decoded, channel-expanded, Y-flipped tile.  Dict order
        # is the recency order (oldest first).
        self._tile_cache = {}

    def __enter__(self):
        self._tif = tifffile.TiffFile(self.path)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if self._tif is not None:
                self._tif.close()
        finally:
            # Drop the handle and cached tiles so a closed reader can neither
            # be read from nor closed twice.
            self._tif = None
            self._tile_cache.clear()
        return False

    def _load_tile(self, tile_idx: int) -> np.ndarray:
        """Return the full tile as a (C, Y, X) array, flipped along Y."""
        if self._tif is None:
            raise RuntimeError(
                "OMETiffReader is not open; use it as a context manager "
                "('with OMETiffReader(path) as reader: ...')"
            )

        # pop + re-insert moves a cache hit to the most-recent position.
        tile = self._tile_cache.pop(tile_idx, None)
        if tile is None:
            tile = self._tif.series[tile_idx].asarray()
            if tile.ndim == 2:
                tile = tile[np.newaxis, :, :]
            # Flip along Y axis to correct orientation
            tile = np.flip(tile, axis=-2)

        if self._cache_size > 0:
            self._tile_cache[tile_idx] = tile
            if len(self._tile_cache) > self._cache_size:
                # Evict the least recently used entry (first key).
                del self._tile_cache[next(iter(self._tile_cache))]
        return tile

    def read_region(self, tile_idx: int, y_slice: slice, x_slice: slice) -> np.ndarray:
        """
        Read a region of a tile using the persistent file handle.

        Parameters
        ----------
        tile_idx : int
            Index of the tile.
        y_slice, x_slice : slice
            Region to read.

        Returns
        -------
        arr : ndarray of shape (C, h, w)
            Tile region as float32.

        Raises
        ------
        RuntimeError
            If the reader has not been opened (or was already closed).
        """
        # ``astype`` copies, so callers never hold a view into a cached tile.
        return self._load_tile(tile_idx)[:, y_slice, x_slice].astype(np.float32)


# ---------------------------------------------------------------------------
# src/tilefusion/registration.py
#
# ``np``, ``List`` and ``Tuple`` are expected to be imported at the top of
# registration.py.
# ---------------------------------------------------------------------------
def find_adjacent_pairs_fast(
    tile_positions: List[Tuple[float, float]],
    pixel_size: Tuple[float, float],
    tile_shape: Tuple[int, int],
    min_overlap: int = 15,
) -> List[Tuple[int, int, int, int, int, int]]:
    """
    Find adjacent tile pairs using spatial hashing.

    Tiles are binned into a grid whose cells are one tile in size, and each
    tile is only compared with tiles in its own and the eight surrounding
    cells.  For evenly spread tiles this needs roughly O(n) comparisons
    instead of the O(n²) of an all-pairs scan.

    Parameters
    ----------
    tile_positions : list of (y, x) tuples
        Stage positions for each tile in physical units.
    pixel_size : tuple of (py, px)
        Pixel size in physical units.
    tile_shape : tuple of (Y, X)
        Tile dimensions in pixels.
    min_overlap : int
        Minimum overlap in pixels.

    Returns
    -------
    adjacent_pairs : list of tuples
        Each tuple: (i_pos, j_pos, dy, dx, overlap_y, overlap_x) with
        ``i_pos < j_pos``, sorted by ``(i_pos, j_pos)``.
    """
    n_pos = len(tile_positions)
    if n_pos == 0:
        return []

    Y, X = tile_shape
    py, px = pixel_size

    # Cell size is tile size in physical units
    cell_y = Y * py
    cell_x = X * px

    # Build spatial hash: (cell_y, cell_x) -> list of tile indices.
    # floor (not truncation) keeps cells uniform for negative coordinates.
    grid: dict = {}
    for i, (y, x) in enumerate(tile_positions):
        cell = (int(np.floor(y / cell_y)), int(np.floor(x / cell_x)))
        grid.setdefault(cell, []).append(i)

    # Convert once instead of building fresh arrays for every pair.
    positions = np.asarray(tile_positions, dtype=float)
    pixel = np.asarray(pixel_size, dtype=float)

    adjacent_pairs = []
    for (cy, cx), tiles_in_cell in grid.items():
        # Gather all tiles in the 3x3 neighborhood (own cell included).
        neighborhood = [
            j
            for dy in (-1, 0, 1)
            for dx in (-1, 0, 1)
            for j in grid.get((cy + dy, cx + dx), ())
        ]

        for i_pos in tiles_in_cell:
            for j_pos in neighborhood:
                # Every tile lives in exactly one cell, so requiring i < j
                # visits each unordered pair exactly once (from i's cell);
                # no separate "already checked" bookkeeping is needed.
                if i_pos >= j_pos:
                    continue

                # Offset of tile j relative to tile i, in pixels
                vox_off = np.round((positions[j_pos] - positions[i_pos]) / pixel).astype(int)
                dy_px, dx_px = vox_off

                overlap_y = Y - abs(dy_px)
                overlap_x = X - abs(dx_px)

                # Side-by-side: (almost) same row, enough overlap along X.
                is_horizontal_neighbor = abs(dy_px) < min_overlap and overlap_x >= min_overlap
                # Stacked: (almost) same column, enough overlap along Y.
                is_vertical_neighbor = abs(dx_px) < min_overlap and overlap_y >= min_overlap

                if is_horizontal_neighbor or is_vertical_neighbor:
                    adjacent_pairs.append(
                        (i_pos, j_pos, dy_px, dx_px, overlap_y, overlap_x)
                    )

    # Grid iteration order depends on where tiles fall; sort so the result
    # is in tile-index order regardless of the spatial layout.
    adjacent_pairs.sort(key=lambda pair: (pair[0], pair[1]))
    return adjacent_pairs