diff --git a/MONAI_EX b/MONAI_EX index 8049bc3..0fddc4a 160000 --- a/MONAI_EX +++ b/MONAI_EX @@ -1 +1 @@ -Subproject commit 8049bc3830089264c313dda2c92873f20f9358c1 +Subproject commit 0fddc4a7e8e0f992461126da03ca5dfa37492b51 diff --git a/strix/data_checker.py b/strix/data_checker.py index 3b1d195..a4a8a00 100644 --- a/strix/data_checker.py +++ b/strix/data_checker.py @@ -1,15 +1,9 @@ import os -import yaml import click -import numpy as np -import nibabel as nib -from tqdm import tqdm from pathlib import Path from functools import partial from types import SimpleNamespace as sn from sklearn.model_selection import train_test_split -import warnings -from termcolor import colored import torch @@ -19,136 +13,37 @@ from strix.utilities.click_callbacks import get_unknown_options, dump_params, parse_project from strix.utilities.enum import FRAMEWORKS, Phases, SerialFileFormat from strix.utilities.utils import ( - plot_segmentation_masks, setup_logger, get_items, trycatch, generate_synthetic_datalist, ) + +from strix.utilities.check_data_utils import check_dataloader from strix.utilities.registry import DatasetRegistry from strix.configures import config as cfg -from monai_ex.utils import first, check_dir +from monai_ex.utils import check_dir from monai_ex.data import DataLoader -def save_raw_image(data, meta_dict, out_dir, phase, dataset_name, batch_index, logger_name=None): - if isinstance(data, torch.Tensor): - data = data.cpu().numpy() - - if logger_name: - logger = setup_logger(logger_name) - logger.info(f"Saving {phase} image to {out_dir}...") - - for i, patch in enumerate(data): - out_fname = check_dir( - out_dir, - dataset_name, - "raw images", - f"{phase}-batch{batch_index}-{i}.nii.gz", - isFile=True, - ) - nib.save(nib.Nifti1Image(patch.squeeze(), meta_dict["affine"][i]), out_fname) - - -def save_fnames(data, img_meta_key, image_fpath): - fnames = {idx + 1: fname for idx, fname in enumerate(data[img_meta_key]["filename_or_obj"])} - image_fpath = str(image_fpath).split("-chn")[0] - output_path = os.path.splitext(image_fpath)[0] + "-fnames.yml" - with open(output_path, "w") as f: - yaml.dump(fnames, f) - - -def save_2d_image_grid( - images, - nrow, - ncol, - out_dir, - phase, - dataset_name, - batch_index, - axis=None, - chn_idx=None, - overlap_method='mask', - mask=None, - mask_class_num = 2, - fnames: list = None, - alpha: float = None -): - if axis is not None and chn_idx is not None: - images = np.take(images, chn_idx, axis) - if mask is not None and mask.size(axis) > 1: - mask = np.take(mask, chn_idx, axis) - if mask is not None: - data_slice = images.detach().numpy() - mask_slice = mask.detach().numpy() - fig = plot_segmentation_masks(data_slice, mask_slice, nrow, ncol, - alpha = alpha, method = overlap_method, mask_class_num = mask_class_num, fnames = fnames) - - output_fname = f"-chn{chn_idx}" if chn_idx is not None else "" - - output_path = check_dir( - out_dir, - dataset_name, - f"{phase}-batch{batch_index}{output_fname}.png", - isFile=True, - ) - - fig.savefig(output_path, bbox_inches="tight", pad_inches=0) - return output_path - - -def save_3d_image_grid( - images, - axis, - nrow, - ncol, - out_dir, - phase, - dataset_name, - batch_index, - slice_index, - multichannel=False, - overlap_method='mask', - mask=None, - mask_class_num = 2, - fnames: list = None, - alpha: float = None -): - images = np.take(images, slice_index, axis) - if mask is not None: - mask = np.take(mask, slice_index, axis) - if mask is not None: - data_slice = images.detach().numpy() - mask_slice = mask.detach().numpy() - fig = plot_segmentation_masks(data_slice, mask_slice, nrow, ncol, - alpha = alpha, method = overlap_method, mask_class_num=mask_class_num, fnames = fnames) - - if multichannel: - output_fname = f"channel{slice_index}.png" - else: - output_fname = f"slice{slice_index}.png" - - output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) - fig.savefig(output_path, bbox_inches="tight", pad_inches=0) - return output_path - - option = partial(click.option, cls=OptionEx) command = partial(click.command, cls=CommandEx) -check_cmd_history = os.path.join(cfg.get_strix_cfg("cache_dir"), '.strix_check_cmd_history') +check_cmd_history = os.path.join(cfg.get_strix_cfg("cache_dir"), ".strix_check_cmd_history") @command( - "check-data", + "check-data", context_settings={ - "allow_extra_args": True, "ignore_unknown_options": True, "prompt_in_default_map": True, - "default_map": get_items(check_cmd_history, format=SerialFileFormat.YAML, allow_filenotfound=True) - }) -@option("--tensor-dim", prompt=True, type=Choice(["2D", "3D"]), default="2D", help="2D or 3D") -@option( - "--framework", prompt=True, type=Choice(FRAMEWORKS), default="segmentation", help="Choose your framework type" + "allow_extra_args": True, + "ignore_unknown_options": True, + "prompt_in_default_map": True, + "default_map": get_items(check_cmd_history, format=SerialFileFormat.YAML, allow_filenotfound=True), + }, ) +@option("--tensor-dim", prompt=True, type=Choice(["2D", "3D"]), default="2D", help="2D or 3D") +@option("--framework", prompt=True, type=Choice(FRAMEWORKS), default="segmentation", help="Choose your framework type") @option("--alpha", type=float, default=0.7, help="The opacity of mask") +@option("--line-width", type=float, default=0.1, help="The width of contour line") @option("--project", type=click.Path(), callback=parse_project, default=Path.cwd(), help="Project folder path") @option("--data-list", type=str, callback=data_select, default=None, help="Data file list") # todo: Rename @option("--n-batch", prompt=True, type=int, default=9, help="Batch size") @@ -159,7 +54,14 @@ def save_3d_image_grid( @option("--mask-key", type=str, default="mask", help="Specify mask key, default is 'mask'") @option("--seed", type=int, default=101, help="random seed") @option("--out-dir", type=str, prompt=True, default=cfg.get_strix_cfg("OUTPUT_DIR")) -@option("--dump-params", hidden=True, is_flag=True, default=False, callback=partial(dump_params, output_path=check_cmd_history)) +@option("--check-histogram", is_flag=True, help="Plotting histograms of data") +@option( + "--dump-params", + hidden=True, + is_flag=True, + default=False, + callback=partial(dump_params, output_path=check_cmd_history), +) @click.pass_context def check_data(ctx, **args): cargs = sn(**args) @@ -192,188 +94,44 @@ def get_train_valid_datasets(): train_dataset, valid_dataset = get_train_valid_datasets() logger.info(f"Creating dataset '{cargs.data_list}' successfully!") - train_num = min(cargs.n_batch, len(train_dataset)) - valid_num = min(cargs.n_batch, len(valid_dataset)) - train_dataloader = DataLoader(train_dataset, num_workers=1, batch_size=train_num, shuffle=True) - valid_dataloader = DataLoader(valid_dataset, num_workers=1, batch_size=valid_num, shuffle=False) - train_data = first(train_dataloader) - - img_key = cfg.get_key("IMAGE") - msk_key = cfg.get_key("MASK") if cargs.mask_key is None else cargs.mask_key - data_shape = train_data[img_key][0].shape - exist_mask = train_data.get(msk_key) is not None - channel, shape = data_shape[0], data_shape[1:] - logger.info(f"Data Channel: {channel}, Shape: {shape}") + if isinstance(train_dataset, torch.utils.data.DataLoader): + train_dataloader = train_dataset + valid_dataloader = valid_dataset + else: + train_num = min(cargs.n_batch, len(train_dataset)) + valid_num = min(cargs.n_batch, len(valid_dataset)) + train_dataloader = DataLoader(train_dataset, num_workers=1, batch_size=train_num, shuffle=True) + valid_dataloader = DataLoader(valid_dataset, num_workers=1, batch_size=valid_num, shuffle=False) if cargs.mask_overlap and cargs.contour_overlap: raise ValueError("mask_overlap/contour_overlap can only choose one!") - overlap = cargs.mask_overlap or cargs.contour_overlap - - if overlap and not exist_mask: - logger.warn(f"{msk_key} is not found in datalist.") - - if cargs.mask_overlap: - overlap_m = 'mask' - elif cargs.contour_overlap: - overlap_m = 'contour' - - def _check_mask(masks, fnames): - mask_class_num = len(masks.unique()) - 1 - msk = masks if mask_class_num > 0 else None - for i in range(masks.shape[0]): - n_class = len(msk[i,...].unique()) - 1 - if n_class == mask_class_num: - continue - elif n_class > 0: - warnings.warn(colored(f"Other cases had {mask_class_num} kinds of labels, but {fnames[i]} got {n_class}, please check your data", "yellow")) - else: - warnings.warn(colored(f"Case {fnames[i]} has no label (only background), please check your data", "yellow")) - return mask_class_num, msk - - if len(shape) == 2 and channel == 1: - for phase, dataloader in { - Phases.TRAIN.value: train_dataloader, - Phases.VALID.value: valid_dataloader, - }.items(): - for i, data in enumerate(tqdm(dataloader)): - bs = dataloader.batch_size - if exist_mask and overlap: - fnames = data[str(img_key) + '_meta_dict']["filename_or_obj"] - mask_class_num, msk = _check_mask(data[msk_key],fnames) - else: - mask_class_num = 0 - msk=None - row = int(np.ceil(np.sqrt(bs))) - column = row - if (row -1) * column >= bs: - row -= 1 - output_fpath = save_2d_image_grid( - data[img_key], - row, - column, - cargs.out_dir, - phase, - cargs.data_list, - i, - overlap_method=overlap_m, - mask=msk, - mask_class_num = mask_class_num, - fnames = fnames, - alpha = cargs.alpha - ) - if cargs.save_raw: - save_raw_image( - data[img_key], - data[f"{img_key}_meta_dict"], - cargs.out_dir, - phase, - cargs.data_list, - i, - logger_name - ) - - save_fnames(data, img_key + "_meta_dict", output_fpath) - elif len(shape) == 2 and channel > 1: - z_axis = 1 - for phase, dataloader in { - Phases.TRAIN.value: train_dataloader, - Phases.VALID.value: valid_dataloader, - }.items(): - for i, data in enumerate(tqdm(dataloader)): - bs = dataloader.batch_size - if exist_mask and overlap: - fnames = data[str(img_key) + '_meta_dict']["filename_or_obj"] - mask_class_num, msk = _check_mask(data[msk_key],fnames) - else: - mask_class_num = 0 - msk=None - - row = int(np.ceil(np.sqrt(bs))) - column = row - if (row -1) * column >= bs: - row -= 1 - if cargs.save_raw: - save_raw_image( - data[img_key], - data[f"{img_key}_meta_dict"], - cargs.out_dir, - phase, - cargs.data_list, - i, - logger_name - ) - - for ch_idx in range(channel): - output_fpath = save_2d_image_grid( - data[img_key], - row, - column, - cargs.out_dir, - phase, - cargs.data_list, - i, - axis = z_axis, - chn_idx =ch_idx, - overlap_method=overlap_m, - mask=msk, - mask_class_num=mask_class_num, - fnames = fnames, - alpha = cargs.alpha - ) - - save_fnames(data, img_key + "_meta_dict", output_fpath) - - elif len(shape) == 3 and channel == 1: - z_axis = np.argmin(shape) - for phase, dataloader in { - Phases.TRAIN.value: train_dataloader, - Phases.VALID.value: valid_dataloader, - }.items(): - for i, data in enumerate(tqdm(dataloader)): - bs = dataloader.batch_size - if exist_mask and overlap: - fnames = data[str(img_key) + '_meta_dict']["filename_or_obj"] - mask_class_num, msk = _check_mask(data[msk_key],fnames) - else: - mask_class_num = 0 - msk=None - - row = int(np.ceil(np.sqrt(bs))) - column = row - if (row -1) * column >= bs: - row -= 1 - for slice_idx in range(shape[z_axis]): - output_fpath = save_3d_image_grid( - data[img_key], - z_axis + 2, - row, - column, - cargs.out_dir, - phase, - cargs.data_list, - i, - slice_idx, - multichannel=False, - overlap_method=overlap_m, - mask=msk, - mask_class_num=mask_class_num, - fnames = fnames, - alpha = cargs.alpha - ) - - if cargs.save_raw: - save_raw_image( - data[img_key], - data[f"{img_key}_meta_dict"], - cargs.out_dir, - phase, - cargs.data_list, - i, - logger_name - ) - - save_fnames(data, img_key + "_meta_dict", output_fpath) + overlap_m = "mask" if cargs.mask_overlap else "contour" if cargs.contour_overlap else None + + check_dataloader( + phase=Phases.TRAIN, + dataloader=train_dataloader, + mask_key=cargs.mask_key, + out_dir=cargs.out_dir, + dataset_name=cargs.data_list, + overlap_method=overlap_m, + alpha=cargs.alpha, + line_width=cargs.line_width, + save_raw=cargs.save_raw, + check_hist=cargs.check_histogram, + logger=logger, + ) - else: - raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") + check_dataloader( + phase=Phases.VALID, + dataloader=valid_dataloader, + mask_key=cargs.mask_key, + out_dir=cargs.out_dir, + dataset_name=cargs.data_list, + overlap_method=overlap_m, + alpha=cargs.alpha, + line_width=cargs.line_width, + save_raw=cargs.save_raw, + check_hist=cargs.check_histogram, + logger=logger, + ) diff --git a/strix/data_io/base_dataset/basic_dataset.py b/strix/data_io/base_dataset/basic_dataset.py index f73df3d..5d4c263 100644 --- a/strix/data_io/base_dataset/basic_dataset.py +++ b/strix/data_io/base_dataset/basic_dataset.py @@ -23,7 +23,7 @@ def __new__( to_tensor: Union[Sequence[MapTransform], MapTransform], is_supervised: bool, dataset_type: Dataset, - dataset_kwargs: dict, + dataset_kwargs: Optional[dict] = None, additional_transforms: Optional[Sequence[MapTransform]] = None, check_data: bool = True, profiling: bool = False, @@ -72,4 +72,8 @@ def _wrap_range(transforms): else: self.transforms = None - return self.dataset(self.input_data, transform=self.transforms, **self.dataset_kwargs) + if self.dataset_kwargs: + return self.dataset(self.input_data, transform=self.transforms, **self.dataset_kwargs) + else: + print(self.dataset, type(self.dataset)) + return self.dataset(self.input_data, transform=self.transforms) diff --git a/strix/utilities/check_data_utils.py b/strix/utilities/check_data_utils.py new file mode 100644 index 0000000..bfdc508 --- /dev/null +++ b/strix/utilities/check_data_utils.py @@ -0,0 +1,500 @@ +from matplotlib.colors import Normalize +import matplotlib.pyplot as plt +import logging +import warnings +from pathlib import Path +from typing import List, Optional, Union, Callable + +import matplotlib +import nibabel as nib +import numpy as np +import torch +from monai_ex.data import DataLoader +from monai_ex.utils import check_dir, first +from termcolor import colored +from tqdm import tqdm +from skimage import exposure, img_as_float +from strix.configures import config as cfg +from strix.utilities.enum import Phases +from strix.utilities.utils import setup_logger, get_colors, get_colormaps, get_unique_filename + +matplotlib.use("Agg") + + +def save_raw_image(data, meta_dict, out_dir, phase, dataset_name, batch_index, logger_name=None): + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + + if logger_name: + logger = setup_logger(logger_name) + logger.info(f"Saving {phase} image to {out_dir}...") + + for i, patch in enumerate(data): + out_fname = check_dir( + out_dir, + dataset_name, + "raw images", + f"{phase}-batch{batch_index}-{i}.nii.gz", + isFile=True, + ) + nib.save(nib.Nifti1Image(patch.squeeze(), meta_dict["affine"][i]), out_fname) + + +def save_2d_image_grid( + images: torch.Tensor, + nrow: int, + ncol: int, + out_dir: Optional[Union[Path, str]], + phase: str, + dataset_name: str, + batch_index: int, + fnames: List, + axis: Optional[int] = None, + chn_idx: Optional[int] = None, + overlap_method: Optional[str] = "mask", + mask: Optional[torch.Tensor] = None, + mask_class_num: int = 2, + alpha: float = 0.7, + line_width: float = 0.1, +): + if axis is not None and chn_idx is not None: + index = torch.tensor(chn_idx).to(images.device) + images = torch.index_select(images, dim=axis, index=index) + if mask is not None and mask.size(axis) > 1: + mask = torch.index_select(mask, dim=axis, index=index) + + fig = plot_segmentation_masks( + images.detach().cpu().numpy(), + mask.detach().cpu().numpy() if mask is not None else None, + nrow, + ncol, + alpha=alpha, + line_width=line_width, + method=overlap_method, + mask_class_num=mask_class_num, + fnames=fnames, + ) + + if out_dir: + output_fname = f"-chn{chn_idx}" if chn_idx is not None else "" + output_path = check_dir( + out_dir, + dataset_name, + f"{phase}-batch{batch_index}{output_fname}.png", + isFile=True, + ) + + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + + return fig + + +def save_3d_image_grid( + images, + axis: int, + nrow: int, + ncol: int, + out_dir: Optional[Union[Path, str]], + phase: str, + dataset_name: str, + batch_index: int, + slice_index: int, + fnames: List, + multichannel: bool = False, + overlap_method: Optional[str] = "mask", + mask=None, + mask_class_num=2, + alpha: float = 0.7, + line_width: float = 0.1, +): + # images = np.take(images, slice_index, axis) + index = torch.tensor(slice_index).to(images.device) + images = torch.index_select(images, dim=axis, index=index).squeeze(axis).detach().cpu().numpy() + + if mask is not None: + # mask = np.take(mask, slice_index, axis) + mask = torch.index_select(mask, dim=axis, index=index).squeeze(axis).detach().cpu().numpy() + + fig = plot_segmentation_masks( + images, + mask, + nrow, + ncol, + alpha=alpha, + line_width=line_width, + method=overlap_method, + mask_class_num=mask_class_num, + fnames=fnames, + ) + + if out_dir: + output_fname = f"channel{slice_index}.png" if multichannel else f"slice{slice_index}.png" + output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + + return fig + + +def save_histogram_grid( + images: np.ndarray, + nrow: int, + ncol: int, + out_dir: Optional[Union[Path, str]], + phase: str, + dataset_name: str, + batch_index: int, + slice_index: int = None, + multichannel: bool = False, + fnames: Optional[List] = None, +): + images = images.detach().cpu().numpy() + + fig = plot_histogram( + images, + nrow, + ncol, + fnames=fnames, + ) + + if out_dir: + output_fname = f'channel_{slice_index}_histograms.png' if multichannel else 'histograms.png' + output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + + return fig + + +def plot_segmentation_masks( + images: np.ndarray, + masks: Optional[np.ndarray], + nrow: int, + ncol: int, + alpha: float = 0.7, + line_width: float = 0.1, + method: Optional[str] = "mask", + mask_class_num: int = 2, + fnames: Optional[List] = None, +): + # cm = 1/2.54 + plt.close() + ratio = images.shape[-1] / images.shape[-2] + fig = plt.figure(figsize=(10 * ratio * ncol, 10 * nrow), dpi=300) + axes = fig.subplots(nrow, ncol) + color_num = mask_class_num + + if color_num > 0: + colors = get_colors(color_num) + cmap = get_colormaps(color_num) + else: + cmap, colors = None, None + + for i in range(nrow): + for j in range(ncol): + ax = axes[i, j] if nrow > 1 else axes[j] + ax.axis("off") + if i * ncol + j < images.shape[0]: + if fnames: + ax.set_title(fnames[i * ncol + j], fontsize=8) + + # draw images + ax.imshow(images[i * ncol + j, ...].squeeze(), cmap="gray", interpolation="bicubic") + + # draw masks if exists + if method == "mask" and masks is not None and cmap is not None: + ax.imshow( + np.ma.masked_equal(masks[i * ncol + j, ...].squeeze(), 0), + cmap, # ! unbound! + alpha=alpha, + norm=Normalize(vmin=1, vmax=color_num), + ) + elif method == "contour" and masks is not None and colors is not None: + list = [i for i in np.unique(masks[i * ncol + j, ...].squeeze()).tolist() if i] + if len(list) > 0: + ax.contour( + masks[i * ncol + j, ...].squeeze(), + levels=[x - 0.01 for x in list], + colors=colors[min(list) - 1: max(list)], + linewidths=line_width + ) # ! TypeError: slice indices must be integers or None or have an __index__ method >> >> >> > + else: + continue + + plt.subplots_adjust(left=0.1, right=0.2, bottom=0.1, top=0.2, wspace=0.1, hspace=0.2) + return fig + + +def plot_histogram( + images: np.ndarray, + nrow: int, + ncol: int, + fnames: Optional[List] = None, +): + plt.close() + plt.style.use('seaborn-white') + ratio = images.shape[-2] / images.shape[-3] + fig = plt.figure(figsize=(10 * ratio * ncol, 10 * nrow)) + axes = fig.subplots(nrow, ncol, sharey=True) + images = img_as_float(images) + + for i in range(nrow): + for j in range(ncol): + if i * ncol + j < images.shape[0]: + img = images[i * ncol + j, ...] + ax_hist = axes[i, j] if nrow > 1 else axes[j] + ax_cdf = ax_hist.twinx() + if fnames: + ax_hist.set_title(fnames[i * ncol + j], fontsize=8) + + # Display histogram + ax_hist.hist(img.ravel(), bins=256, color='tab:blue') + ax_hist.ticklabel_format(axis='y', style='scientific', scilimits=(0, 0)) + x_min, x_max = ax_hist.get_xlim() + ax_hist.set_xlabel('Pixel intensity') + ax_hist.set_xticks(np.linspace(x_min, x_max, 5)) + y_min, y_max = ax_hist.get_ylim() + ax_hist.set_ylabel('Number of pixels') + ax_hist.set_yticks(np.linspace(y_min, y_max, 5)) + + # Display cumulative distribution + img_cdf, bins = exposure.cumulative_distribution(img, 256) + ax_cdf.plot(bins, img_cdf, 'tab:red') + ax_cdf.set_yticks([]) + + plt.tight_layout() + + return fig + + +def check_dataloader( + phase: Phases, + dataloader: DataLoader, + mask_key: str, + out_dir: Optional[Union[Path, str]], + dataset_name: str, + overlap_method: Optional[str] = None, + alpha: float = 0.7, + line_width: float = 0.1, + save_raw: bool = False, + check_hist: bool = False, + logger: logging.Logger = logging.getLogger("data-check"), + progress_bar: Optional[Callable] = None +): + logger_name = logger.name + first_batch = first(dataloader) + img_key = cfg.get_key("IMAGE") + msk_key = cfg.get_key("MASK") if mask_key is None else mask_key + data_shape = first_batch[img_key][0].shape + exist_mask = first_batch.get(msk_key) is not None + channel, shape = data_shape[0], data_shape[1:] + logger.info(f"Data Channel: {channel}, Shape: {shape}") + figures = [] + + if overlap_method and not exist_mask: + logger.warn(f"{msk_key} is not found in datalist.") + + def _check_mask(masks, fnames): + mask_class_num = len(masks.unique()) - 1 + msk = masks if mask_class_num > 0 else None + for i in range(masks.shape[0]): + n_class = len(msk[i, ...].unique()) - 1 # ! msk is None + if n_class == mask_class_num: + continue + elif n_class > 0: + warnings.warn( + colored( + f"Other cases had {mask_class_num} kinds of labels, but {fnames[i]} got {n_class}, please check your data", + "yellow", + ) + ) + else: + warnings.warn( + colored(f"Case {fnames[i]} has no label (only background), please check your data", "yellow") + ) + return mask_class_num, msk + + if len(shape) == 2 and channel == 1: + for i, data in enumerate(tqdm(dataloader)): + bs = dataloader.batch_size + fnames = get_unique_filename(data[str(img_key) + "_meta_dict"]["filename_or_obj"]) + if exist_mask and overlap_method: + mask_class_num, msk = _check_mask(data[msk_key], fnames) + else: + mask_class_num = 0 + msk = None + row = int(np.ceil(np.sqrt(bs))) + column = row + if (row - 1) * column >= bs: + row -= 1 + + figs = save_2d_image_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + overlap_method=overlap_method, + mask=msk, + mask_class_num=mask_class_num, + alpha=alpha, + line_width=line_width + ) + figures.append(figs) + + if check_hist: + hist = save_histogram_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + ) + figures.append(hist) + + if save_raw: + save_raw_image( + data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name + ) + if progress_bar: + progress_bar((i + 1) / len(dataloader)) + + return figures + + elif len(shape) == 2 and channel > 1: + z_axis = 1 + + for i, data in enumerate(tqdm(dataloader)): + bs = dataloader.batch_size or 1 # prevent None + fnames = get_unique_filename(data[str(img_key) + "_meta_dict"]["filename_or_obj"]) + if exist_mask and overlap_method: + mask_class_num, msk = _check_mask(data[msk_key], fnames) + else: + mask_class_num = 0 + msk = None + + row = int(np.ceil(np.sqrt(bs))) + column = row + if (row - 1) * column >= bs: + row -= 1 + if save_raw: + save_raw_image( + data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name + ) + + channel_figures = {ch_idx : [] for ch_idx in range(channel)} + for ch_idx in range(channel): + figs = save_2d_image_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + axis=z_axis, + chn_idx=ch_idx, + overlap_method=overlap_method, + mask=msk, + mask_class_num=mask_class_num, + alpha=alpha, + line_width=line_width + ) + channel_figures[ch_idx].append(figs) + + if check_hist: + hist = save_histogram_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + ) + figures.append(hist) + + # fill the figures with [{0: fig1, 1: fig2, 2: fig3}, {0: fig1, 1: fig2, 2: fig3}, ...] + for _, ch_idx in enumerate(channel_figures): + if _ == 0: + for item in channel_figures[ch_idx]: + figures.append({ch_idx: item}) + else: + for i, item in enumerate(channel_figures[ch_idx]): + figures[i].update({ch_idx: item}) + + if progress_bar: + progress_bar((i + 1) / len(dataloader)) + + return figures + + elif len(shape) == 3 and channel == 1: + z_axis = int(np.argmin(shape)) + + for i, data in enumerate(tqdm(dataloader)): + bs = dataloader.batch_size or 1 + fnames = get_unique_filename(data[str(img_key) + "_meta_dict"]["filename_or_obj"]) + if exist_mask and overlap_method: + mask_class_num, msk = _check_mask(data[msk_key], fnames) + else: + mask_class_num = 0 + msk = None + + row = int(np.ceil(np.sqrt(bs))) + column = row + if (row - 1) * column >= bs: + row -= 1 + + fig_volume = [] + for slice_idx in range(data[img_key].shape[z_axis + 2]): + figs = save_3d_image_grid( + data[img_key], + z_axis + 2, + row, + column, + out_dir, + phase.value, + dataset_name, + i, + slice_idx, + fnames=fnames, + multichannel=False, + overlap_method=overlap_method, + mask=msk, + mask_class_num=mask_class_num, + alpha=alpha, + line_width=line_width + ) + fig_volume.append(figs) + figures.append(fig_volume) + + if check_hist: + hist = save_histogram_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + batch_index=i, + fnames=fnames, + ) + figures.append(hist) + + if save_raw: + save_raw_image( + data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name + ) + if progress_bar: + progress_bar((i + 1) / len(dataloader)) + + return figures + + else: + raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") diff --git a/strix/utilities/enum.py b/strix/utilities/enum.py index 7c2bc8c..19bcad6 100644 --- a/strix/utilities/enum.py +++ b/strix/utilities/enum.py @@ -142,6 +142,7 @@ class Freezers(Enum): FULL = "full" SUBTASK = "subtask" + FREEZERS = get_enums(Freezers) diff --git a/strix/utilities/utils.py b/strix/utilities/utils.py index 077284c..2c8b41b 100644 --- a/strix/utilities/utils.py +++ b/strix/utilities/utils.py @@ -7,7 +7,7 @@ import warnings from functools import partial from pathlib import Path -from typing import Any, Callable, Optional, TextIO, Union +from typing import Any, Callable, Optional, TextIO, Union, List, Sequence, Tuple import matplotlib import pylab import torch @@ -15,6 +15,7 @@ from PIL import Image from termcolor import colored import logging + matplotlib.use("Agg") import matplotlib.cm as mpl_color_map import matplotlib.colors as mcolors @@ -33,7 +34,9 @@ @trycatch() -def get_items(filelist: Union[str, Path], format: Optional[SerialFileFormat] = None, allow_filenotfound: bool = False) -> Any: +def get_items( + filelist: Union[str, Path], format: Optional[SerialFileFormat] = None, allow_filenotfound: bool = False +) -> Any: """Get items from given serialized file. Support both json and yaml. Args: @@ -43,7 +46,7 @@ def get_items(filelist: Union[str, Path], format: Optional[SerialFileFormat] = N Raises: FileNotFoundError: filelist not found. - ValueError: unknown file format is given when format is not specified + ValueError: unknown file format is given when format is not specified GenericException: json.JSONDecodeError and yaml.YAMLError GenericException: data path not exists @@ -143,16 +146,18 @@ def get_colors(num: Optional[int] = None): else: return list(mcolors.TABLEAU_COLORS.values()) # type: ignore -def get_colormaps(num: int = None): + +def get_colormaps(num: Optional[int] = None): if num is None: num = 10 if num == 1: - return 'tab10' + return "tab10" else: - colormap = [list(plt.cm.tab10(0.1 * i)) for i in range(num)] - new_colormap = LinearSegmentedColormap.from_list('my_colormap', colormap, N=num + 1 if num == 1 else num) + colormap = [list(plt.cm.tab10(0.1 * i)) for i in range(num)] + new_colormap = LinearSegmentedColormap.from_list("my_colormap", colormap, N=num + 1 if num == 1 else num) return new_colormap + def bbox_3D(img): r = np.any(img, axis=(1, 2)) c = np.any(img, axis=(0, 2)) @@ -358,6 +363,32 @@ def output_filename_check(torch_dataset, meta_key="image_meta_dict"): return "" +# todo: check through all data +def get_unique_filename(file_paths: List, index: int = 0): + if len(file_paths) == 1: + fpath = file_paths[0] + if isinstance(fpath, list): + fpath = fpath[0] + return Path(fpath).stem + + if isinstance(prev_fpath := file_paths[0], list): + prev_fpath = prev_fpath[0] + if isinstance(next_fpath := file_paths[1], list): + next_fpath = next_fpath[0] + + root_dir = None + prev_parents = list(Path(prev_fpath).parents)[::-1] + next_parents = list(Path(next_fpath).parents)[::-1] + for (prev_item, next_item) in zip(prev_parents, next_parents): + if prev_item.stem != next_item.stem: + root_dir = prev_item.parent + break + + if root_dir: + return [str(Path(file_path).relative_to(root_dir).parts[index]) for file_path in file_paths] + return "" + + def detect_port(port): """Detect if the port is used""" socket_test = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -522,43 +553,6 @@ def is_bound(data, coord, offsets): return boundaries -def plot_segmentation_masks(images: np.ndarray, - masks: np.ndarray, - nrow:int, - ncol:int, - alpha: float = 0.8, - method = 'mask', - mask_class_num = 2, - fnames = None -) : - # cm = 1/2.54 - plt.close() - ratio = images.shape[-1]/images.shape[-2] - fig = plt.figure(figsize=(15 * ratio * ncol, 15 * nrow)) - axes = fig.subplots(nrow, ncol) - color_num = mask_class_num - if color_num >0: - colors = get_colors(color_num) - cmap = get_colormaps(color_num) - for i in range(nrow): - for j in range(ncol): - axes[i,j].axis('off') - if i* ncol+j < images.shape[0]: - title = str(Path(fnames[i* ncol+j]).stem) - axes[i,j].set_title(title, fontsize=8) - axes[i,j].imshow(images[i*ncol+j,...].squeeze(), cmap = 'gray', interpolation = 'bicubic') - if method == 'mask': - axes[i,j].imshow(np.ma.masked_equal(masks[i*ncol+j,...].squeeze(), 0), cmap, alpha=alpha, norm=Normalize(vmin=1, vmax=color_num)) - elif method == 'contour': - list = [i for i in np.unique(masks[i*ncol+j,...].squeeze()).tolist() if i] - if len(list) > 0: - axes[i,j].contour(masks[i*ncol+j,...].squeeze(), levels=[x-0.01 for x in list], colors=colors[min(list)-1:max(list)]) - else: - continue - plt.subplots_adjust( - left=0.1, right=0.2, bottom=0.1, top=0.2, wspace=0.1, hspace=0.2 - ) - return fig class LogColorFormatter(logging.Formatter): """Logging colored formatter, adapted from https://stackoverflow.com/a/56944256/3638629""" @@ -717,7 +711,7 @@ def setup_logger( def warning_on_one_line(message, category, filename, lineno, file=None, line=None): - return '%s:%s:\n %s: %s\n' % (filename, lineno, category.__name__, message) + return "%s:%s:\n %s: %s\n" % (filename, lineno, category.__name__, message) def singleton(cls): @@ -743,7 +737,9 @@ def save_sourcecode(code_rootdir, out_dir, verbose=True): os.system(f"cd {str(code_rootdir)}; tar -{tar_opt} {outpath} .") -def get_torch_datast(strix_dataset, phase: Phases, opts: dict, synthetic_data_num=100, split_func: Optional[Callable] = None): +def get_torch_datast( + strix_dataset, phase: Phases, opts: dict, synthetic_data_num=100, split_func: Optional[Callable] = None +): """This function return pytorch dataset generated by registered strix dataset. Args: