From e7a0e330cb5b31688010260469da9c908410fc85 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Thu, 24 Nov 2022 15:33:12 +0800 Subject: [PATCH 1/5] Moved datachecking-related functions to `check_data_utils.py`. Extract `check_dataloader` function to enable checking dataloader. Minor bugs and typing issues fixed. Signed-off-by: Chenglong Wang --- MONAI_EX | 2 +- strix/data_checker.py | 349 +++----------------- strix/data_io/base_dataset/basic_dataset.py | 8 +- strix/utilities/check_data_utils.py | 344 +++++++++++++++++++ strix/utilities/utils.py | 62 +--- 5 files changed, 417 insertions(+), 348 deletions(-) create mode 100644 strix/utilities/check_data_utils.py diff --git a/MONAI_EX b/MONAI_EX index 8049bc3..0fddc4a 160000 --- a/MONAI_EX +++ b/MONAI_EX @@ -1 +1 @@ -Subproject commit 8049bc3830089264c313dda2c92873f20f9358c1 +Subproject commit 0fddc4a7e8e0f992461126da03ca5dfa37492b51 diff --git a/strix/data_checker.py b/strix/data_checker.py index 3b1d195..eacfcb3 100644 --- a/strix/data_checker.py +++ b/strix/data_checker.py @@ -1,15 +1,9 @@ import os -import yaml import click -import numpy as np -import nibabel as nib -from tqdm import tqdm from pathlib import Path from functools import partial from types import SimpleNamespace as sn from sklearn.model_selection import train_test_split -import warnings -from termcolor import colored import torch @@ -19,135 +13,34 @@ from strix.utilities.click_callbacks import get_unknown_options, dump_params, parse_project from strix.utilities.enum import FRAMEWORKS, Phases, SerialFileFormat from strix.utilities.utils import ( - plot_segmentation_masks, setup_logger, get_items, trycatch, generate_synthetic_datalist, ) +from strix.utilities.check_data_utils import check_dataloader from strix.utilities.registry import DatasetRegistry from strix.configures import config as cfg -from monai_ex.utils import first, check_dir +from monai_ex.utils import check_dir from monai_ex.data import DataLoader -def save_raw_image(data, meta_dict, out_dir, phase, dataset_name, batch_index, logger_name=None): - if isinstance(data, torch.Tensor): - data = data.cpu().numpy() - - if logger_name: - logger = setup_logger(logger_name) - logger.info(f"Saving {phase} image to {out_dir}...") - - for i, patch in enumerate(data): - out_fname = check_dir( - out_dir, - dataset_name, - "raw images", - f"{phase}-batch{batch_index}-{i}.nii.gz", - isFile=True, - ) - nib.save(nib.Nifti1Image(patch.squeeze(), meta_dict["affine"][i]), out_fname) - - -def save_fnames(data, img_meta_key, image_fpath): - fnames = {idx + 1: fname for idx, fname in enumerate(data[img_meta_key]["filename_or_obj"])} - image_fpath = str(image_fpath).split("-chn")[0] - output_path = os.path.splitext(image_fpath)[0] + "-fnames.yml" - with open(output_path, "w") as f: - yaml.dump(fnames, f) - - -def save_2d_image_grid( - images, - nrow, - ncol, - out_dir, - phase, - dataset_name, - batch_index, - axis=None, - chn_idx=None, - overlap_method='mask', - mask=None, - mask_class_num = 2, - fnames: list = None, - alpha: float = None -): - if axis is not None and chn_idx is not None: - images = np.take(images, chn_idx, axis) - if mask is not None and mask.size(axis) > 1: - mask = np.take(mask, chn_idx, axis) - if mask is not None: - data_slice = images.detach().numpy() - mask_slice = mask.detach().numpy() - fig = plot_segmentation_masks(data_slice, mask_slice, nrow, ncol, - alpha = alpha, method = overlap_method, mask_class_num = mask_class_num, fnames = fnames) - - output_fname = f"-chn{chn_idx}" if chn_idx is not None else "" - 
- output_path = check_dir( - out_dir, - dataset_name, - f"{phase}-batch{batch_index}{output_fname}.png", - isFile=True, - ) - - fig.savefig(output_path, bbox_inches="tight", pad_inches=0) - return output_path - - -def save_3d_image_grid( - images, - axis, - nrow, - ncol, - out_dir, - phase, - dataset_name, - batch_index, - slice_index, - multichannel=False, - overlap_method='mask', - mask=None, - mask_class_num = 2, - fnames: list = None, - alpha: float = None -): - images = np.take(images, slice_index, axis) - if mask is not None: - mask = np.take(mask, slice_index, axis) - if mask is not None: - data_slice = images.detach().numpy() - mask_slice = mask.detach().numpy() - fig = plot_segmentation_masks(data_slice, mask_slice, nrow, ncol, - alpha = alpha, method = overlap_method, mask_class_num=mask_class_num, fnames = fnames) - - if multichannel: - output_fname = f"channel{slice_index}.png" - else: - output_fname = f"slice{slice_index}.png" - - output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) - fig.savefig(output_path, bbox_inches="tight", pad_inches=0) - return output_path - - option = partial(click.option, cls=OptionEx) command = partial(click.command, cls=CommandEx) -check_cmd_history = os.path.join(cfg.get_strix_cfg("cache_dir"), '.strix_check_cmd_history') +check_cmd_history = os.path.join(cfg.get_strix_cfg("cache_dir"), ".strix_check_cmd_history") @command( - "check-data", + "check-data", context_settings={ - "allow_extra_args": True, "ignore_unknown_options": True, "prompt_in_default_map": True, - "default_map": get_items(check_cmd_history, format=SerialFileFormat.YAML, allow_filenotfound=True) - }) -@option("--tensor-dim", prompt=True, type=Choice(["2D", "3D"]), default="2D", help="2D or 3D") -@option( - "--framework", prompt=True, type=Choice(FRAMEWORKS), default="segmentation", help="Choose your framework type" + "allow_extra_args": True, + "ignore_unknown_options": True, + "prompt_in_default_map": True, + "default_map": get_items(check_cmd_history, format=SerialFileFormat.YAML, allow_filenotfound=True), + }, ) +@option("--tensor-dim", prompt=True, type=Choice(["2D", "3D"]), default="2D", help="2D or 3D") +@option("--framework", prompt=True, type=Choice(FRAMEWORKS), default="segmentation", help="Choose your framework type") @option("--alpha", type=float, default=0.7, help="The opacity of mask") @option("--project", type=click.Path(), callback=parse_project, default=Path.cwd(), help="Project folder path") @option("--data-list", type=str, callback=data_select, default=None, help="Data file list") # todo: Rename @@ -159,7 +52,13 @@ def save_3d_image_grid( @option("--mask-key", type=str, default="mask", help="Specify mask key, default is 'mask'") @option("--seed", type=int, default=101, help="random seed") @option("--out-dir", type=str, prompt=True, default=cfg.get_strix_cfg("OUTPUT_DIR")) -@option("--dump-params", hidden=True, is_flag=True, default=False, callback=partial(dump_params, output_path=check_cmd_history)) +@option( + "--dump-params", + hidden=True, + is_flag=True, + default=False, + callback=partial(dump_params, output_path=check_cmd_history), +) @click.pass_context def check_data(ctx, **args): cargs = sn(**args) @@ -192,188 +91,40 @@ def get_train_valid_datasets(): train_dataset, valid_dataset = get_train_valid_datasets() logger.info(f"Creating dataset '{cargs.data_list}' successfully!") - train_num = min(cargs.n_batch, len(train_dataset)) - valid_num = min(cargs.n_batch, len(valid_dataset)) - train_dataloader = 
DataLoader(train_dataset, num_workers=1, batch_size=train_num, shuffle=True) - valid_dataloader = DataLoader(valid_dataset, num_workers=1, batch_size=valid_num, shuffle=False) - train_data = first(train_dataloader) - - img_key = cfg.get_key("IMAGE") - msk_key = cfg.get_key("MASK") if cargs.mask_key is None else cargs.mask_key - data_shape = train_data[img_key][0].shape - exist_mask = train_data.get(msk_key) is not None - channel, shape = data_shape[0], data_shape[1:] - logger.info(f"Data Channel: {channel}, Shape: {shape}") + if isinstance(train_dataset, torch.utils.data.DataLoader): + train_dataloader = train_dataset + valid_dataloader = valid_dataset + else: + train_num = min(cargs.n_batch, len(train_dataset)) + valid_num = min(cargs.n_batch, len(valid_dataset)) + train_dataloader = DataLoader(train_dataset, num_workers=1, batch_size=train_num, shuffle=True) + valid_dataloader = DataLoader(valid_dataset, num_workers=1, batch_size=valid_num, shuffle=False) if cargs.mask_overlap and cargs.contour_overlap: raise ValueError("mask_overlap/contour_overlap can only choose one!") - overlap = cargs.mask_overlap or cargs.contour_overlap - - if overlap and not exist_mask: - logger.warn(f"{msk_key} is not found in datalist.") - - if cargs.mask_overlap: - overlap_m = 'mask' - elif cargs.contour_overlap: - overlap_m = 'contour' - - def _check_mask(masks, fnames): - mask_class_num = len(masks.unique()) - 1 - msk = masks if mask_class_num > 0 else None - for i in range(masks.shape[0]): - n_class = len(msk[i,...].unique()) - 1 - if n_class == mask_class_num: - continue - elif n_class > 0: - warnings.warn(colored(f"Other cases had {mask_class_num} kinds of labels, but {fnames[i]} got {n_class}, please check your data", "yellow")) - else: - warnings.warn(colored(f"Case {fnames[i]} has no label (only background), please check your data", "yellow")) - return mask_class_num, msk - - if len(shape) == 2 and channel == 1: - for phase, dataloader in { - Phases.TRAIN.value: train_dataloader, - Phases.VALID.value: valid_dataloader, - }.items(): - for i, data in enumerate(tqdm(dataloader)): - bs = dataloader.batch_size - if exist_mask and overlap: - fnames = data[str(img_key) + '_meta_dict']["filename_or_obj"] - mask_class_num, msk = _check_mask(data[msk_key],fnames) - else: - mask_class_num = 0 - msk=None - row = int(np.ceil(np.sqrt(bs))) - column = row - if (row -1) * column >= bs: - row -= 1 - output_fpath = save_2d_image_grid( - data[img_key], - row, - column, - cargs.out_dir, - phase, - cargs.data_list, - i, - overlap_method=overlap_m, - mask=msk, - mask_class_num = mask_class_num, - fnames = fnames, - alpha = cargs.alpha - ) - if cargs.save_raw: - save_raw_image( - data[img_key], - data[f"{img_key}_meta_dict"], - cargs.out_dir, - phase, - cargs.data_list, - i, - logger_name - ) - - save_fnames(data, img_key + "_meta_dict", output_fpath) - elif len(shape) == 2 and channel > 1: - z_axis = 1 - for phase, dataloader in { - Phases.TRAIN.value: train_dataloader, - Phases.VALID.value: valid_dataloader, - }.items(): - for i, data in enumerate(tqdm(dataloader)): - bs = dataloader.batch_size - if exist_mask and overlap: - fnames = data[str(img_key) + '_meta_dict']["filename_or_obj"] - mask_class_num, msk = _check_mask(data[msk_key],fnames) - else: - mask_class_num = 0 - msk=None - - row = int(np.ceil(np.sqrt(bs))) - column = row - if (row -1) * column >= bs: - row -= 1 - if cargs.save_raw: - save_raw_image( - data[img_key], - data[f"{img_key}_meta_dict"], - cargs.out_dir, - phase, - cargs.data_list, - i, - logger_name - 
) - - for ch_idx in range(channel): - output_fpath = save_2d_image_grid( - data[img_key], - row, - column, - cargs.out_dir, - phase, - cargs.data_list, - i, - axis = z_axis, - chn_idx =ch_idx, - overlap_method=overlap_m, - mask=msk, - mask_class_num=mask_class_num, - fnames = fnames, - alpha = cargs.alpha - ) - - save_fnames(data, img_key + "_meta_dict", output_fpath) - - elif len(shape) == 3 and channel == 1: - z_axis = np.argmin(shape) - for phase, dataloader in { - Phases.TRAIN.value: train_dataloader, - Phases.VALID.value: valid_dataloader, - }.items(): - for i, data in enumerate(tqdm(dataloader)): - bs = dataloader.batch_size - if exist_mask and overlap: - fnames = data[str(img_key) + '_meta_dict']["filename_or_obj"] - mask_class_num, msk = _check_mask(data[msk_key],fnames) - else: - mask_class_num = 0 - msk=None - - row = int(np.ceil(np.sqrt(bs))) - column = row - if (row -1) * column >= bs: - row -= 1 - for slice_idx in range(shape[z_axis]): - output_fpath = save_3d_image_grid( - data[img_key], - z_axis + 2, - row, - column, - cargs.out_dir, - phase, - cargs.data_list, - i, - slice_idx, - multichannel=False, - overlap_method=overlap_m, - mask=msk, - mask_class_num=mask_class_num, - fnames = fnames, - alpha = cargs.alpha - ) - - if cargs.save_raw: - save_raw_image( - data[img_key], - data[f"{img_key}_meta_dict"], - cargs.out_dir, - phase, - cargs.data_list, - i, - logger_name - ) - - save_fnames(data, img_key + "_meta_dict", output_fpath) + overlap_m = "mask" if cargs.mask_overlap else "contour" if cargs.contour_overlap else None + + check_dataloader( + phase=Phases.TRAIN, + dataloader=train_dataloader, + mask_key=cargs.mask_key, + out_dir=cargs.out_dir, + dataset_name=cargs.data_list, + overlap_method=overlap_m, + alpha=cargs.alpha, + save_raw=cargs.save_raw, + logger=logger, + ) - else: - raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") + check_dataloader( + phase=Phases.VALID, + dataloader=valid_dataloader, + mask_key=cargs.mask_key, + out_dir=cargs.out_dir, + dataset_name=cargs.data_list, + overlap_method=overlap_m, + alpha=cargs.alpha, + save_raw=cargs.save_raw, + logger=logger, + ) diff --git a/strix/data_io/base_dataset/basic_dataset.py b/strix/data_io/base_dataset/basic_dataset.py index f73df3d..5d4c263 100644 --- a/strix/data_io/base_dataset/basic_dataset.py +++ b/strix/data_io/base_dataset/basic_dataset.py @@ -23,7 +23,7 @@ def __new__( to_tensor: Union[Sequence[MapTransform], MapTransform], is_supervised: bool, dataset_type: Dataset, - dataset_kwargs: dict, + dataset_kwargs: Optional[dict] = None, additional_transforms: Optional[Sequence[MapTransform]] = None, check_data: bool = True, profiling: bool = False, @@ -72,4 +72,8 @@ def _wrap_range(transforms): else: self.transforms = None - return self.dataset(self.input_data, transform=self.transforms, **self.dataset_kwargs) + if self.dataset_kwargs: + return self.dataset(self.input_data, transform=self.transforms, **self.dataset_kwargs) + else: + print(self.dataset, type(self.dataset)) + return self.dataset(self.input_data, transform=self.transforms) diff --git a/strix/utilities/check_data_utils.py b/strix/utilities/check_data_utils.py new file mode 100644 index 0000000..74108cc --- /dev/null +++ b/strix/utilities/check_data_utils.py @@ -0,0 +1,344 @@ +import logging +import warnings +from pathlib import Path +from typing import List, Optional, Union + +import matplotlib +import nibabel as nib +import numpy as np +import torch +from monai_ex.data import DataLoader +from 
monai_ex.utils import check_dir, first +from termcolor import colored +from tqdm import tqdm + +from strix.configures import config as cfg +from strix.utilities.enum import FRAMEWORKS, Phases +from strix.utilities.utils import setup_logger, get_colors, get_colormaps + +matplotlib.use("Agg") +import matplotlib.cm as mpl_color_map +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +from matplotlib.colors import Normalize + + +def save_raw_image(data, meta_dict, out_dir, phase, dataset_name, batch_index, logger_name=None): + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + + if logger_name: + logger = setup_logger(logger_name) + logger.info(f"Saving {phase} image to {out_dir}...") + + for i, patch in enumerate(data): + out_fname = check_dir( + out_dir, + dataset_name, + "raw images", + f"{phase}-batch{batch_index}-{i}.nii.gz", + isFile=True, + ) + nib.save(nib.Nifti1Image(patch.squeeze(), meta_dict["affine"][i]), out_fname) + + +def save_2d_image_grid( + images: torch.Tensor, + nrow: int, + ncol: int, + out_dir: Union[Path, str], + phase: str, + dataset_name: str, + batch_index: int, + fnames: List, + axis: Optional[int] = None, + chn_idx: Optional[int] = None, + overlap_method: Optional[str] = "mask", + mask: Optional[torch.Tensor] = None, + mask_class_num: int = 2, + alpha: float = 0.7, +): + if axis is not None and chn_idx is not None: + index = torch.tensor(chn_idx).to(images.device) + images = torch.index_select(images, dim=axis, index=index) + if mask is not None and mask.size(axis) > 1: + mask = torch.index_select(mask, dim=axis, index=index) + + fig = plot_segmentation_masks( + images.detach().cpu().numpy(), + mask.detach().cpu().numpy() if mask else None, + nrow, + ncol, + alpha=alpha, + method=overlap_method, + mask_class_num=mask_class_num, + fnames=fnames, + ) + + + output_fname = f"-chn{chn_idx}" if chn_idx is not None else "" + + output_path = check_dir( + out_dir, + dataset_name, + f"{phase}-batch{batch_index}{output_fname}.png", + isFile=True, + ) + + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + return output_path + + +def save_3d_image_grid( + images, + axis, + nrow, + ncol, + out_dir, + phase, + dataset_name, + batch_index: int, + slice_index: int, + fnames: List, + multichannel: bool = False, + overlap_method: Optional[str] = "mask", + mask=None, + mask_class_num=2, + alpha: float = 0.7, +): + # images = np.take(images, slice_index, axis) + index = torch.tensor(slice_index).to(images.device) + images = torch.index_select(images, dim=axis, index=index).squeeze(axis).detach().cpu().numpy() + + if mask is not None: + # mask = np.take(mask, slice_index, axis) + mask = torch.index_select(mask, dim=axis, index=index).squeeze(axis).detach().cpu().numpy() + + fig = plot_segmentation_masks( + images, + mask, + nrow, + ncol, + alpha=alpha, + method=overlap_method, + mask_class_num=mask_class_num, + fnames=fnames, + ) + + output_fname = f"channel{slice_index}.png" if multichannel else f"slice{slice_index}.png" + output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + return output_path + + +def plot_segmentation_masks( + images: np.ndarray, + masks: Optional[np.ndarray], + nrow: int, + ncol: int, + alpha: float = 0.8, + method: Optional[str] = "mask", + mask_class_num: int = 2, + fnames: Optional[List] = None, +): + # cm = 1/2.54 + plt.close() + ratio = images.shape[-1] / images.shape[-2] + fig = 
plt.figure(figsize=(15 * ratio * ncol, 15 * nrow)) + axes = fig.subplots(nrow, ncol) + color_num = mask_class_num + + if color_num > 0: + colors = get_colors(color_num) + cmap = get_colormaps(color_num) + + for i in range(nrow): + for j in range(ncol): + axes[i, j].axis("off") + if i * ncol + j < images.shape[0]: + if fnames: + title = str(Path(fnames[i * ncol + j]).stem) + axes[i, j].set_title(title, fontsize=8) + + # draw images + axes[i, j].imshow(images[i * ncol + j, ...].squeeze(), cmap="gray", interpolation="bicubic") + + # draw masks if exists + if method == "mask" and masks is not None: + axes[i, j].imshow( + np.ma.masked_equal(masks[i * ncol + j, ...].squeeze(), 0), + cmap, #! unbound! + alpha=alpha, + norm=Normalize(vmin=1, vmax=color_num), + ) + elif method == "contour" and masks is not None: + list = [i for i in np.unique(masks[i * ncol + j, ...].squeeze()).tolist() if i] + if len(list) > 0: + axes[i, j].contour( + masks[i * ncol + j, ...].squeeze(), + levels=[x - 0.01 for x in list], + colors=colors[min(list) - 1 : max(list)], + ) + else: + continue + + plt.subplots_adjust(left=0.1, right=0.2, bottom=0.1, top=0.2, wspace=0.1, hspace=0.2) + return fig + + +def check_dataloader( + phase: Phases, + dataloader: DataLoader, + mask_key: str, + out_dir: Union[Path, str], + dataset_name: str, + overlap_method: Optional[str] = None, + alpha: float = 0.7, + save_raw: bool = False, + logger: logging.Logger = logging.getLogger("data-check"), +) -> None: + logger_name = logger.name + first_batch = first(dataloader) + img_key = cfg.get_key("IMAGE") + msk_key = cfg.get_key("MASK") if mask_key is None else mask_key + data_shape = first_batch[img_key][0].shape + exist_mask = first_batch.get(msk_key) is not None + channel, shape = data_shape[0], data_shape[1:] + logger.info(f"Data Channel: {channel}, Shape: {shape}") + + if overlap_method and not exist_mask: + logger.warn(f"{msk_key} is not found in datalist.") + + def _check_mask(masks, fnames): + mask_class_num = len(masks.unique()) - 1 + msk = masks if mask_class_num > 0 else None + for i in range(masks.shape[0]): + n_class = len(msk[i, ...].unique()) - 1 #! 
msk is None + if n_class == mask_class_num: + continue + elif n_class > 0: + warnings.warn( + colored( + f"Other cases had {mask_class_num} kinds of labels, but {fnames[i]} got {n_class}, please check your data", + "yellow", + ) + ) + else: + warnings.warn( + colored(f"Case {fnames[i]} has no label (only background), please check your data", "yellow") + ) + return mask_class_num, msk + + if len(shape) == 2 and channel == 1: + for i, data in enumerate(tqdm(dataloader)): + bs = dataloader.batch_size + fnames = data[str(img_key) + "_meta_dict"]["filename_or_obj"] + if exist_mask and overlap_method: + mask_class_num, msk = _check_mask(data[msk_key], fnames) + else: + mask_class_num = 0 + msk = None + row = int(np.ceil(np.sqrt(bs))) + column = row + if (row - 1) * column >= bs: + row -= 1 + save_2d_image_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + overlap_method=overlap_method, + mask=msk, + mask_class_num=mask_class_num, + alpha=alpha, + ) + if save_raw: + save_raw_image( + data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name + ) + + elif len(shape) == 2 and channel > 1: + z_axis = 1 + + for i, data in enumerate(tqdm(dataloader)): + bs = dataloader.batch_size or 1 # prevent None + fnames = data[str(img_key) + "_meta_dict"]["filename_or_obj"] + if exist_mask and overlap_method: + mask_class_num, msk = _check_mask(data[msk_key], fnames) + else: + mask_class_num = 0 + msk = None + + row = int(np.ceil(np.sqrt(bs))) + column = row + if (row - 1) * column >= bs: + row -= 1 + if save_raw: + save_raw_image( + data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name + ) + + for ch_idx in range(channel): + save_2d_image_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + axis=z_axis, + chn_idx=ch_idx, + overlap_method=overlap_method, + mask=msk, + mask_class_num=mask_class_num, + alpha=alpha, + ) + + elif len(shape) == 3 and channel == 1: + z_axis = np.argmin(shape) + + for i, data in enumerate(tqdm(dataloader)): + bs = dataloader.batch_size or 1 + fnames = data[str(img_key) + "_meta_dict"]["filename_or_obj"] + if exist_mask and overlap_method: + mask_class_num, msk = _check_mask(data[msk_key], fnames) + else: + mask_class_num = 0 + msk = None + + row = int(np.ceil(np.sqrt(bs))) + column = row + if (row - 1) * column >= bs: + row -= 1 + for slice_idx in range(shape[z_axis]): + save_3d_image_grid( + data[img_key], + z_axis + 2, + row, + column, + out_dir, + phase.value, + dataset_name, + i, + slice_idx, + fnames=fnames, + multichannel=False, + overlap_method=overlap_method, + mask=msk, + mask_class_num=mask_class_num, + alpha=alpha, + ) + + if save_raw: + save_raw_image( + data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name + ) + + else: + raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") diff --git a/strix/utilities/utils.py b/strix/utilities/utils.py index 077284c..3fb7af6 100644 --- a/strix/utilities/utils.py +++ b/strix/utilities/utils.py @@ -7,7 +7,7 @@ import warnings from functools import partial from pathlib import Path -from typing import Any, Callable, Optional, TextIO, Union +from typing import Any, Callable, Optional, TextIO, Union, List import matplotlib import pylab import torch @@ -15,6 +15,7 @@ from PIL import Image from termcolor import colored import logging + matplotlib.use("Agg") import 
matplotlib.cm as mpl_color_map import matplotlib.colors as mcolors @@ -33,7 +34,9 @@ @trycatch() -def get_items(filelist: Union[str, Path], format: Optional[SerialFileFormat] = None, allow_filenotfound: bool = False) -> Any: +def get_items( + filelist: Union[str, Path], format: Optional[SerialFileFormat] = None, allow_filenotfound: bool = False +) -> Any: """Get items from given serialized file. Support both json and yaml. Args: @@ -43,7 +46,7 @@ def get_items(filelist: Union[str, Path], format: Optional[SerialFileFormat] = N Raises: FileNotFoundError: filelist not found. - ValueError: unknown file format is given when format is not specified + ValueError: unknown file format is given when format is not specified GenericException: json.JSONDecodeError and yaml.YAMLError GenericException: data path not exists @@ -143,16 +146,18 @@ def get_colors(num: Optional[int] = None): else: return list(mcolors.TABLEAU_COLORS.values()) # type: ignore -def get_colormaps(num: int = None): + +def get_colormaps(num: Optional[int] = None): if num is None: num = 10 if num == 1: - return 'tab10' + return "tab10" else: - colormap = [list(plt.cm.tab10(0.1 * i)) for i in range(num)] - new_colormap = LinearSegmentedColormap.from_list('my_colormap', colormap, N=num + 1 if num == 1 else num) + colormap = [list(plt.cm.tab10(0.1 * i)) for i in range(num)] + new_colormap = LinearSegmentedColormap.from_list("my_colormap", colormap, N=num + 1 if num == 1 else num) return new_colormap + def bbox_3D(img): r = np.any(img, axis=(1, 2)) c = np.any(img, axis=(0, 2)) @@ -522,43 +527,6 @@ def is_bound(data, coord, offsets): return boundaries -def plot_segmentation_masks(images: np.ndarray, - masks: np.ndarray, - nrow:int, - ncol:int, - alpha: float = 0.8, - method = 'mask', - mask_class_num = 2, - fnames = None -) : - # cm = 1/2.54 - plt.close() - ratio = images.shape[-1]/images.shape[-2] - fig = plt.figure(figsize=(15 * ratio * ncol, 15 * nrow)) - axes = fig.subplots(nrow, ncol) - color_num = mask_class_num - if color_num >0: - colors = get_colors(color_num) - cmap = get_colormaps(color_num) - for i in range(nrow): - for j in range(ncol): - axes[i,j].axis('off') - if i* ncol+j < images.shape[0]: - title = str(Path(fnames[i* ncol+j]).stem) - axes[i,j].set_title(title, fontsize=8) - axes[i,j].imshow(images[i*ncol+j,...].squeeze(), cmap = 'gray', interpolation = 'bicubic') - if method == 'mask': - axes[i,j].imshow(np.ma.masked_equal(masks[i*ncol+j,...].squeeze(), 0), cmap, alpha=alpha, norm=Normalize(vmin=1, vmax=color_num)) - elif method == 'contour': - list = [i for i in np.unique(masks[i*ncol+j,...].squeeze()).tolist() if i] - if len(list) > 0: - axes[i,j].contour(masks[i*ncol+j,...].squeeze(), levels=[x-0.01 for x in list], colors=colors[min(list)-1:max(list)]) - else: - continue - plt.subplots_adjust( - left=0.1, right=0.2, bottom=0.1, top=0.2, wspace=0.1, hspace=0.2 - ) - return fig class LogColorFormatter(logging.Formatter): """Logging colored formatter, adapted from https://stackoverflow.com/a/56944256/3638629""" @@ -717,7 +685,7 @@ def setup_logger( def warning_on_one_line(message, category, filename, lineno, file=None, line=None): - return '%s:%s:\n %s: %s\n' % (filename, lineno, category.__name__, message) + return "%s:%s:\n %s: %s\n" % (filename, lineno, category.__name__, message) def singleton(cls): @@ -743,7 +711,9 @@ def save_sourcecode(code_rootdir, out_dir, verbose=True): os.system(f"cd {str(code_rootdir)}; tar -{tar_opt} {outpath} .") -def get_torch_datast(strix_dataset, phase: Phases, opts: dict, 
synthetic_data_num=100, split_func: Optional[Callable] = None): +def get_torch_datast( + strix_dataset, phase: Phases, opts: dict, synthetic_data_num=100, split_func: Optional[Callable] = None +): """This function return pytorch dataset generated by registered strix dataset. Args: From 79701eb6b7b8aee75c54edc6c8e62d4807417910 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Thu, 24 Nov 2022 16:20:37 +0800 Subject: [PATCH 2/5] Added progressbar support. Signed-off-by: Chenglong Wang --- strix/utilities/check_data_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/strix/utilities/check_data_utils.py b/strix/utilities/check_data_utils.py index 74108cc..7aed903 100644 --- a/strix/utilities/check_data_utils.py +++ b/strix/utilities/check_data_utils.py @@ -1,7 +1,7 @@ import logging import warnings from pathlib import Path -from typing import List, Optional, Union +from typing import List, Optional, Union, Callable import matplotlib import nibabel as nib @@ -196,6 +196,7 @@ def check_dataloader( alpha: float = 0.7, save_raw: bool = False, logger: logging.Logger = logging.getLogger("data-check"), + progress_bar: Optional[Callable] = None ) -> None: logger_name = logger.name first_batch = first(dataloader) @@ -260,6 +261,8 @@ def _check_mask(masks, fnames): save_raw_image( data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name ) + if progress_bar: + progress_bar(i/len(dataloader)) elif len(shape) == 2 and channel > 1: z_axis = 1 @@ -299,6 +302,8 @@ def _check_mask(masks, fnames): mask_class_num=mask_class_num, alpha=alpha, ) + if progress_bar: + progress_bar(i/len(dataloader)) elif len(shape) == 3 and channel == 1: z_axis = np.argmin(shape) @@ -339,6 +344,8 @@ def _check_mask(masks, fnames): save_raw_image( data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name ) + if progress_bar: + progress_bar(i/len(dataloader)) else: raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") From 461fdf1f15e198971e2fe079ca774683c0a45613 Mon Sep 17 00:00:00 2001 From: Dexuan Li <18693088246@163.com> Date: Thu, 24 Nov 2022 21:04:23 +0800 Subject: [PATCH 3/5] Add an argument to change the width of contour --- strix/data_checker.py | 4 ++++ strix/utilities/check_data_utils.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/strix/data_checker.py b/strix/data_checker.py index eacfcb3..3b614bb 100644 --- a/strix/data_checker.py +++ b/strix/data_checker.py @@ -18,6 +18,7 @@ trycatch, generate_synthetic_datalist, ) + from strix.utilities.check_data_utils import check_dataloader from strix.utilities.registry import DatasetRegistry from strix.configures import config as cfg @@ -42,6 +43,7 @@ @option("--tensor-dim", prompt=True, type=Choice(["2D", "3D"]), default="2D", help="2D or 3D") @option("--framework", prompt=True, type=Choice(FRAMEWORKS), default="segmentation", help="Choose your framework type") @option("--alpha", type=float, default=0.7, help="The opacity of mask") +@option("--line-width", type=float, default=0.1, help="The width of contour line") @option("--project", type=click.Path(), callback=parse_project, default=Path.cwd(), help="Project folder path") @option("--data-list", type=str, callback=data_select, default=None, help="Data file list") # todo: Rename @option("--n-batch", prompt=True, type=int, default=9, help="Batch size") @@ -113,6 +115,7 @@ def get_train_valid_datasets(): dataset_name=cargs.data_list, 
overlap_method=overlap_m, alpha=cargs.alpha, + line_width = cargs.line_width, save_raw=cargs.save_raw, logger=logger, ) @@ -125,6 +128,7 @@ def get_train_valid_datasets(): dataset_name=cargs.data_list, overlap_method=overlap_m, alpha=cargs.alpha, + line_width = cargs.line_width, save_raw=cargs.save_raw, logger=logger, ) diff --git a/strix/utilities/check_data_utils.py b/strix/utilities/check_data_utils.py index 74108cc..d056731 100644 --- a/strix/utilities/check_data_utils.py +++ b/strix/utilities/check_data_utils.py @@ -57,6 +57,7 @@ def save_2d_image_grid( mask: Optional[torch.Tensor] = None, mask_class_num: int = 2, alpha: float = 0.7, + line_width: float = 0.1, ): if axis is not None and chn_idx is not None: index = torch.tensor(chn_idx).to(images.device) @@ -70,6 +71,7 @@ def save_2d_image_grid( nrow, ncol, alpha=alpha, + line_width=line_width, method=overlap_method, mask_class_num=mask_class_num, fnames=fnames, @@ -105,6 +107,7 @@ def save_3d_image_grid( mask=None, mask_class_num=2, alpha: float = 0.7, + line_width: float = 0.1, ): # images = np.take(images, slice_index, axis) index = torch.tensor(slice_index).to(images.device) @@ -120,6 +123,7 @@ def save_3d_image_grid( nrow, ncol, alpha=alpha, + line_width=line_width, method=overlap_method, mask_class_num=mask_class_num, fnames=fnames, @@ -136,7 +140,8 @@ def plot_segmentation_masks( masks: Optional[np.ndarray], nrow: int, ncol: int, - alpha: float = 0.8, + alpha: float = 0.7, + line_width: float = 0.1, method: Optional[str] = "mask", mask_class_num: int = 2, fnames: Optional[List] = None, @@ -178,6 +183,7 @@ def plot_segmentation_masks( masks[i * ncol + j, ...].squeeze(), levels=[x - 0.01 for x in list], colors=colors[min(list) - 1 : max(list)], + linewidths = line_width ) else: continue @@ -194,6 +200,7 @@ def check_dataloader( dataset_name: str, overlap_method: Optional[str] = None, alpha: float = 0.7, + line_width: float = 0.1, save_raw: bool = False, logger: logging.Logger = logging.getLogger("data-check"), ) -> None: @@ -255,6 +262,7 @@ def _check_mask(masks, fnames): mask=msk, mask_class_num=mask_class_num, alpha=alpha, + line_width=line_width ) if save_raw: save_raw_image( @@ -298,6 +306,7 @@ def _check_mask(masks, fnames): mask=msk, mask_class_num=mask_class_num, alpha=alpha, + line_width=line_width ) elif len(shape) == 3 and channel == 1: @@ -333,6 +342,7 @@ def _check_mask(masks, fnames): mask=msk, mask_class_num=mask_class_num, alpha=alpha, + line_width=line_width ) if save_raw: From e0d3a0afc57d09cc072537a4ef21dcc671c51ec7 Mon Sep 17 00:00:00 2001 From: Chenglong Wang Date: Sun, 27 Nov 2022 16:22:15 +0800 Subject: [PATCH 4/5] Enabled not saving figures locally. Added get_unique_filename. 
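With out_dir made optional, the check can now keep the rendered figures in memory instead of writing PNGs to disk. A minimal usage sketch of the new behaviour (the loader and the names "train_loader", "my_dataset" and the output handling are illustrative, not part of this patch):

    from strix.utilities.check_data_utils import check_dataloader
    from strix.utilities.enum import Phases

    # `train_loader` is assumed to be an existing monai_ex DataLoader that yields
    # dicts with the configured IMAGE/MASK keys and their *_meta_dict entries.
    figures = check_dataloader(
        phase=Phases.TRAIN,
        dataloader=train_loader,
        mask_key="mask",
        out_dir=None,            # skip saving figures locally, only return them
        dataset_name="my_dataset",
        overlap_method="mask",
        alpha=0.7,
    )
    # For 2D single-channel data this is a flat list of matplotlib figures, one
    # per batch; multi-channel and 3D data return per-channel dicts / per-slice
    # lists as described in the diff below, so callers can save or display them.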
--- strix/utilities/check_data_utils.py | 108 +++++++++++++++++----------- strix/utilities/utils.py | 28 +++++++- 2 files changed, 93 insertions(+), 43 deletions(-) diff --git a/strix/utilities/check_data_utils.py b/strix/utilities/check_data_utils.py index 7aed903..34e38bb 100644 --- a/strix/utilities/check_data_utils.py +++ b/strix/utilities/check_data_utils.py @@ -13,12 +13,10 @@ from tqdm import tqdm from strix.configures import config as cfg -from strix.utilities.enum import FRAMEWORKS, Phases -from strix.utilities.utils import setup_logger, get_colors, get_colormaps +from strix.utilities.enum import Phases +from strix.utilities.utils import setup_logger, get_colors, get_colormaps, get_unique_filename matplotlib.use("Agg") -import matplotlib.cm as mpl_color_map -import matplotlib.colors as mcolors import matplotlib.pyplot as plt from matplotlib.colors import Normalize @@ -46,7 +44,7 @@ def save_2d_image_grid( images: torch.Tensor, nrow: int, ncol: int, - out_dir: Union[Path, str], + out_dir: Optional[Union[Path, str]], phase: str, dataset_name: str, batch_index: int, @@ -66,7 +64,7 @@ def save_2d_image_grid( fig = plot_segmentation_masks( images.detach().cpu().numpy(), - mask.detach().cpu().numpy() if mask else None, + mask.detach().cpu().numpy() if mask is not None else None, nrow, ncol, alpha=alpha, @@ -75,28 +73,28 @@ def save_2d_image_grid( fnames=fnames, ) + if out_dir: + output_fname = f"-chn{chn_idx}" if chn_idx is not None else "" + output_path = check_dir( + out_dir, + dataset_name, + f"{phase}-batch{batch_index}{output_fname}.png", + isFile=True, + ) - output_fname = f"-chn{chn_idx}" if chn_idx is not None else "" - - output_path = check_dir( - out_dir, - dataset_name, - f"{phase}-batch{batch_index}{output_fname}.png", - isFile=True, - ) + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) - fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) - return output_path + return fig def save_3d_image_grid( images, - axis, - nrow, - ncol, - out_dir, - phase, - dataset_name, + axis: int, + nrow: int, + ncol: int, + out_dir: Optional[Union[Path, str]], + phase: str, + dataset_name: str, batch_index: int, slice_index: int, fnames: List, @@ -125,10 +123,12 @@ def save_3d_image_grid( fnames=fnames, ) - output_fname = f"channel{slice_index}.png" if multichannel else f"slice{slice_index}.png" - output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) - fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) - return output_path + if out_dir: + output_fname = f"channel{slice_index}.png" if multichannel else f"slice{slice_index}.png" + output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + + return fig def plot_segmentation_masks( @@ -157,8 +157,7 @@ def plot_segmentation_masks( axes[i, j].axis("off") if i * ncol + j < images.shape[0]: if fnames: - title = str(Path(fnames[i * ncol + j]).stem) - axes[i, j].set_title(title, fontsize=8) + axes[i, j].set_title(fnames[i * ncol + j], fontsize=8) # draw images axes[i, j].imshow(images[i * ncol + j, ...].squeeze(), cmap="gray", interpolation="bicubic") @@ -178,7 +177,7 @@ def plot_segmentation_masks( masks[i * ncol + j, ...].squeeze(), levels=[x - 0.01 for x in list], colors=colors[min(list) - 1 : max(list)], - ) + ) #! 
TypeError: slice indices must be integers or None or have an __index__ method else: continue @@ -190,14 +189,14 @@ def check_dataloader( phase: Phases, dataloader: DataLoader, mask_key: str, - out_dir: Union[Path, str], + out_dir: Optional[Union[Path, str]], dataset_name: str, overlap_method: Optional[str] = None, alpha: float = 0.7, save_raw: bool = False, logger: logging.Logger = logging.getLogger("data-check"), progress_bar: Optional[Callable] = None -) -> None: +): logger_name = logger.name first_batch = first(dataloader) img_key = cfg.get_key("IMAGE") @@ -206,6 +205,7 @@ def check_dataloader( exist_mask = first_batch.get(msk_key) is not None channel, shape = data_shape[0], data_shape[1:] logger.info(f"Data Channel: {channel}, Shape: {shape}") + figures = [] if overlap_method and not exist_mask: logger.warn(f"{msk_key} is not found in datalist.") @@ -233,7 +233,7 @@ def _check_mask(masks, fnames): if len(shape) == 2 and channel == 1: for i, data in enumerate(tqdm(dataloader)): bs = dataloader.batch_size - fnames = data[str(img_key) + "_meta_dict"]["filename_or_obj"] + fnames = get_unique_filename(data[str(img_key) + "_meta_dict"]["filename_or_obj"]) if exist_mask and overlap_method: mask_class_num, msk = _check_mask(data[msk_key], fnames) else: @@ -243,7 +243,8 @@ def _check_mask(masks, fnames): column = row if (row - 1) * column >= bs: row -= 1 - save_2d_image_grid( + + figs = save_2d_image_grid( data[img_key], row, column, @@ -257,19 +258,23 @@ def _check_mask(masks, fnames): mask_class_num=mask_class_num, alpha=alpha, ) + figures.append(figs) + if save_raw: save_raw_image( data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name ) if progress_bar: - progress_bar(i/len(dataloader)) + progress_bar((i + 1) / len(dataloader)) + + return figures elif len(shape) == 2 and channel > 1: z_axis = 1 for i, data in enumerate(tqdm(dataloader)): bs = dataloader.batch_size or 1 # prevent None - fnames = data[str(img_key) + "_meta_dict"]["filename_or_obj"] + fnames = get_unique_filename(data[str(img_key) + "_meta_dict"]["filename_or_obj"]) if exist_mask and overlap_method: mask_class_num, msk = _check_mask(data[msk_key], fnames) else: @@ -285,8 +290,9 @@ def _check_mask(masks, fnames): data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name ) + channel_figures = {ch_idx : [] for ch_idx in range(channel)} for ch_idx in range(channel): - save_2d_image_grid( + figs = save_2d_image_grid( data[img_key], row, column, @@ -302,15 +308,28 @@ def _check_mask(masks, fnames): mask_class_num=mask_class_num, alpha=alpha, ) + channel_figures[ch_idx].append(figs) + + # fill the figures with [{0: fig1, 1: fig2, 2: fig3}, {0: fig1, 1: fig2, 2: fig3}, ...] 
+ for _, ch_idx in enumerate(channel_figures): + if _ == 0: + for item in channel_figures[ch_idx]: + figures.append({ch_idx: item}) + else: + for i, item in enumerate(channel_figures[ch_idx]): + figures[i].update({ch_idx: item}) + if progress_bar: - progress_bar(i/len(dataloader)) + progress_bar((i + 1) / len(dataloader)) + + return figures elif len(shape) == 3 and channel == 1: - z_axis = np.argmin(shape) + z_axis = int(np.argmin(shape)) for i, data in enumerate(tqdm(dataloader)): bs = dataloader.batch_size or 1 - fnames = data[str(img_key) + "_meta_dict"]["filename_or_obj"] + fnames = get_unique_filename(data[str(img_key) + "_meta_dict"]["filename_or_obj"]) if exist_mask and overlap_method: mask_class_num, msk = _check_mask(data[msk_key], fnames) else: @@ -321,8 +340,10 @@ def _check_mask(masks, fnames): column = row if (row - 1) * column >= bs: row -= 1 - for slice_idx in range(shape[z_axis]): - save_3d_image_grid( + + fig_volume = [] + for slice_idx in range(data[img_key].shape[z_axis + 2]): + figs = save_3d_image_grid( data[img_key], z_axis + 2, row, @@ -339,13 +360,16 @@ def _check_mask(masks, fnames): mask_class_num=mask_class_num, alpha=alpha, ) + fig_volume.append(figs) + figures.append(fig_volume) if save_raw: save_raw_image( data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name ) if progress_bar: - progress_bar(i/len(dataloader)) + progress_bar((i + 1) / len(dataloader)) + return figures else: raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") diff --git a/strix/utilities/utils.py b/strix/utilities/utils.py index 3fb7af6..2c8b41b 100644 --- a/strix/utilities/utils.py +++ b/strix/utilities/utils.py @@ -7,7 +7,7 @@ import warnings from functools import partial from pathlib import Path -from typing import Any, Callable, Optional, TextIO, Union, List +from typing import Any, Callable, Optional, TextIO, Union, List, Sequence, Tuple import matplotlib import pylab import torch @@ -363,6 +363,32 @@ def output_filename_check(torch_dataset, meta_key="image_meta_dict"): return "" +# todo: check through all data +def get_unique_filename(file_paths: List, index: int = 0): + if len(file_paths) == 1: + fpath = file_paths[0] + if isinstance(fpath, list): + fpath = fpath[0] + return Path(fpath).stem + + if isinstance(prev_fpath := file_paths[0], list): + prev_fpath = prev_fpath[0] + if isinstance(next_fpath := file_paths[1], list): + next_fpath = next_fpath[0] + + root_dir = None + prev_parents = list(Path(prev_fpath).parents)[::-1] + next_parents = list(Path(next_fpath).parents)[::-1] + for (prev_item, next_item) in zip(prev_parents, next_parents): + if prev_item.stem != next_item.stem: + root_dir = prev_item.parent + break + + if root_dir: + return [str(Path(file_path).relative_to(root_dir).parts[index]) for file_path in file_paths] + return "" + + def detect_port(port): """Detect if the port is used""" socket_test = socket.socket(socket.AF_INET, socket.SOCK_STREAM) From 77e68ea4bf25759c27fbdb6ad1a9c39f1572d2f2 Mon Sep 17 00:00:00 2001 From: Dexuan Li <18693088246@163.com> Date: Wed, 30 Nov 2022 21:04:06 +0800 Subject: [PATCH 5/5] Add check_histogram and fixed plot problem. 
--- strix/data_checker.py | 7 +- strix/utilities/check_data_utils.py | 133 ++++++++++++++++++++++++++-- strix/utilities/enum.py | 1 + 3 files changed, 130 insertions(+), 11 deletions(-) diff --git a/strix/data_checker.py b/strix/data_checker.py index 3b614bb..a4a8a00 100644 --- a/strix/data_checker.py +++ b/strix/data_checker.py @@ -54,6 +54,7 @@ @option("--mask-key", type=str, default="mask", help="Specify mask key, default is 'mask'") @option("--seed", type=int, default=101, help="random seed") @option("--out-dir", type=str, prompt=True, default=cfg.get_strix_cfg("OUTPUT_DIR")) +@option("--check-histogram", is_flag=True, help="Plotting histograms of data") @option( "--dump-params", hidden=True, @@ -115,8 +116,9 @@ def get_train_valid_datasets(): dataset_name=cargs.data_list, overlap_method=overlap_m, alpha=cargs.alpha, - line_width = cargs.line_width, + line_width=cargs.line_width, save_raw=cargs.save_raw, + check_hist=cargs.check_histogram, logger=logger, ) @@ -128,7 +130,8 @@ def get_train_valid_datasets(): dataset_name=cargs.data_list, overlap_method=overlap_m, alpha=cargs.alpha, - line_width = cargs.line_width, + line_width=cargs.line_width, save_raw=cargs.save_raw, + check_hist=cargs.check_histogram, logger=logger, ) diff --git a/strix/utilities/check_data_utils.py b/strix/utilities/check_data_utils.py index 81ba846..bfdc508 100644 --- a/strix/utilities/check_data_utils.py +++ b/strix/utilities/check_data_utils.py @@ -13,7 +13,7 @@ from monai_ex.utils import check_dir, first from termcolor import colored from tqdm import tqdm - +from skimage import exposure, img_as_float from strix.configures import config as cfg from strix.utilities.enum import Phases from strix.utilities.utils import setup_logger, get_colors, get_colormaps, get_unique_filename @@ -135,6 +135,35 @@ def save_3d_image_grid( return fig +def save_histogram_grid( + images: np.ndarray, + nrow: int, + ncol: int, + out_dir: Optional[Union[Path, str]], + phase: str, + dataset_name: str, + batch_index: int, + slice_index: int = None, + multichannel: bool = False, + fnames: Optional[List] = None, +): + images = images.detach().cpu().numpy() + + fig = plot_histogram( + images, + nrow, + ncol, + fnames=fnames, + ) + + if out_dir: + output_fname = f'channel_{slice_index}_histograms.png' if multichannel else 'histograms.png' + output_path = check_dir(out_dir, dataset_name, f"{phase}-batch{batch_index}", output_fname, isFile=True) + fig.savefig(str(output_path), bbox_inches="tight", pad_inches=0) + + return fig + + def plot_segmentation_masks( images: np.ndarray, masks: Optional[np.ndarray], @@ -149,36 +178,39 @@ def plot_segmentation_masks( # cm = 1/2.54 plt.close() ratio = images.shape[-1] / images.shape[-2] - fig = plt.figure(figsize=(15 * ratio * ncol, 15 * nrow)) + fig = plt.figure(figsize=(10 * ratio * ncol, 10 * nrow), dpi=300) axes = fig.subplots(nrow, ncol) color_num = mask_class_num if color_num > 0: colors = get_colors(color_num) cmap = get_colormaps(color_num) + else: + cmap, colors = None, None for i in range(nrow): for j in range(ncol): - axes[i, j].axis("off") + ax = axes[i, j] if nrow > 1 else axes[j] + ax.axis("off") if i * ncol + j < images.shape[0]: if fnames: - axes[i, j].set_title(fnames[i * ncol + j], fontsize=8) + ax.set_title(fnames[i * ncol + j], fontsize=8) # draw images - axes[i, j].imshow(images[i * ncol + j, ...].squeeze(), cmap="gray", interpolation="bicubic") + ax.imshow(images[i * ncol + j, ...].squeeze(), cmap="gray", interpolation="bicubic") # draw masks if exists - if method == "mask" and masks 
is not None: - axes[i, j].imshow( + if method == "mask" and masks is not None and cmap is not None: + ax.imshow( np.ma.masked_equal(masks[i * ncol + j, ...].squeeze(), 0), cmap, # ! unbound! alpha=alpha, norm=Normalize(vmin=1, vmax=color_num), ) - elif method == "contour" and masks is not None: + elif method == "contour" and masks is not None and colors is not None: list = [i for i in np.unique(masks[i * ncol + j, ...].squeeze()).tolist() if i] if len(list) > 0: - axes[i, j].contour( + ax.contour( masks[i * ncol + j, ...].squeeze(), levels=[x - 0.01 for x in list], colors=colors[min(list) - 1: max(list)], @@ -191,6 +223,48 @@ def plot_segmentation_masks( return fig +def plot_histogram( + images: np.ndarray, + nrow: int, + ncol: int, + fnames: Optional[List] = None, +): + plt.close() + plt.style.use('seaborn-white') + ratio = images.shape[-2] / images.shape[-3] + fig = plt.figure(figsize=(10 * ratio * ncol, 10 * nrow)) + axes = fig.subplots(nrow, ncol, sharey=True) + images = img_as_float(images) + + for i in range(nrow): + for j in range(ncol): + if i * ncol + j < images.shape[0]: + img = images[i * ncol + j, ...] + ax_hist = axes[i, j] if nrow > 1 else axes[j] + ax_cdf = ax_hist.twinx() + if fnames: + ax_hist.set_title(fnames[i * ncol + j], fontsize=8) + + # Display histogram + ax_hist.hist(img.ravel(), bins=256, color='tab:blue') + ax_hist.ticklabel_format(axis='y', style='scientific', scilimits=(0, 0)) + x_min, x_max = ax_hist.get_xlim() + ax_hist.set_xlabel('Pixel intensity') + ax_hist.set_xticks(np.linspace(x_min, x_max, 5)) + y_min, y_max = ax_hist.get_ylim() + ax_hist.set_ylabel('Number of pixels') + ax_hist.set_yticks(np.linspace(y_min, y_max, 5)) + + # Display cumulative distribution + img_cdf, bins = exposure.cumulative_distribution(img, 256) + ax_cdf.plot(bins, img_cdf, 'tab:red') + ax_cdf.set_yticks([]) + + plt.tight_layout() + + return fig + + def check_dataloader( phase: Phases, dataloader: DataLoader, @@ -201,6 +275,7 @@ def check_dataloader( alpha: float = 0.7, line_width: float = 0.1, save_raw: bool = False, + check_hist: bool = False, logger: logging.Logger = logging.getLogger("data-check"), progress_bar: Optional[Callable] = None ): @@ -268,6 +343,19 @@ def _check_mask(masks, fnames): ) figures.append(figs) + if check_hist: + hist = save_histogram_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + ) + figures.append(hist) + if save_raw: save_raw_image( data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name @@ -319,6 +407,19 @@ def _check_mask(masks, fnames): ) channel_figures[ch_idx].append(figs) + if check_hist: + hist = save_histogram_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + i, + fnames=fnames, + ) + figures.append(hist) + # fill the figures with [{0: fig1, 1: fig2, 2: fig3}, {0: fig1, 1: fig2, 2: fig3}, ...] 
for _, ch_idx in enumerate(channel_figures): if _ == 0: @@ -373,6 +474,19 @@ def _check_mask(masks, fnames): fig_volume.append(figs) figures.append(fig_volume) + if check_hist: + hist = save_histogram_grid( + data[img_key], + row, + column, + out_dir, + phase.value, + dataset_name, + batch_index=i, + fnames=fnames, + ) + figures.append(hist) + if save_raw: save_raw_image( data[img_key], data[f"{img_key}_meta_dict"], out_dir, phase.value, dataset_name, i, logger_name @@ -381,5 +495,6 @@ def _check_mask(masks, fnames): progress_bar((i + 1) / len(dataloader)) return figures + else: raise NotImplementedError(f"Not implement data-checking for shape of {shape}, channel of {channel}") diff --git a/strix/utilities/enum.py b/strix/utilities/enum.py index 7c2bc8c..19bcad6 100644 --- a/strix/utilities/enum.py +++ b/strix/utilities/enum.py @@ -142,6 +142,7 @@ class Freezers(Enum): FULL = "full" SUBTASK = "subtask" + FREEZERS = get_enums(Freezers)