[WIP] Support to train/run Deepgrow 2D/3D models #1395
@@ -0,0 +1,10 @@

```python
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
@@ -0,0 +1,268 @@

```python
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Dict, List

import numpy as np

from monai.transforms import AsChannelFirstd, Compose, LoadImaged, Orientationd, Spacingd
from monai.utils import GridSampleMode


def create_dataset(
    datalist,
    output_dir,
    dimension,
    pixdim,
    keys=("image", "label"),
    base_dir=None,
    limit=0,
    relative_path=False,
    transforms=None,
) -> List[Dict]:
    """
    Utility to pre-process an existing dataset and create a new dataset list for Deepgrow training.
    The input datalist is normally a list of images and labels (3D volumes) that need pre-processing
    for the Deepgrow training pipeline.

    Args:
        datalist: a generic dataset with a length property, which normally contains a list of data dictionaries.
            For example, typical input data can be a list of dictionaries::

                [{'image': 'img1.nii', 'label': 'label1.nii'}]

        output_dir: target directory to store the training data for Deepgrow training.
        pixdim: output voxel spacing.
        dimension: dimension for Deepgrow training. It can be 2 or 3.
        keys: image and label keys in the input datalist. Defaults to ('image', 'label').
        base_dir: base directory, in case relative paths are used for the keys in datalist. Defaults to None.
        limit: limit the number of inputs for pre-processing. Defaults to 0 (no limit).
        relative_path: output key values should be based on relative paths. Defaults to False.
        transforms: explicit transforms to execute operations on input data.

    Raises:
        ValueError: when ``dimension`` is not one of [2, 3].
        ValueError: when ``datalist`` is empty.

    Example::

        datalist = create_dataset(
            datalist=[{'image': 'img1.nii', 'label': 'label1.nii'}],
            base_dir=None,
            output_dir=output_2d,
            dimension=2,
            keys=('image', 'label'),
            pixdim=(1.0, 1.0),
            limit=0,
            relative_path=True,
        )

        print(datalist[0]["image"], datalist[0]["label"])
    """
    if dimension not in [2, 3]:
        raise ValueError("dimension can only be 2 or 3, as Deepgrow supports only 2D/3D training")

    if not len(datalist):
        raise ValueError("input datalist is empty")

    if not isinstance(keys, (list, tuple)):
        keys = [keys]

    transforms = _default_transforms(keys, pixdim) if transforms is None else transforms
```
> **Contributor:** Hi @wyli @SachidanandAlle , I feel this
>
> **Contributor (author):** based on Daguang/research ask, we are providing only a utility to flatten/pre-process the data prior to training.
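The flattening the author describes can be sketched without MONAI at all. Below is a minimal NumPy-only sketch of the idea behind `_save_data_2d` (the helper name `flatten_2d` is hypothetical, and masks are kept in memory here instead of being saved as `.npy` files): every slice containing foreground yields one record per unique non-zero label value, with a binary mask for that region.

```python
import numpy as np


def flatten_2d(vol_label):
    """Simplified sketch of Deepgrow's 2D flattening: for every slice with
    foreground, emit one (slice, region) record per unique non-zero label."""
    records = []
    for sid in range(vol_label.shape[0]):
        label = vol_label[sid]
        if np.sum(label) == 0:  # empty slices are skipped for training
            continue
        for region in np.unique(label[label != 0]):
            mask = (label == region).astype(np.float32)
            records.append({"slice": sid, "region": int(region), "mask": mask})
    return records


# 3 slices: slice 0 is empty, slice 1 has region 1, slice 2 has regions 1 and 2
vol = np.zeros((3, 4, 4), dtype=np.int32)
vol[1, 0, 0] = 1
vol[2, 0, 0] = 1
vol[2, 1, 1] = 2
records = flatten_2d(vol)
print([(r["slice"], r["region"]) for r in records])  # [(1, 1), (2, 1), (2, 2)]
```

The empty slice produces no records, and the slice with two regions produces two, which is why the size of the flattened 2D dataset depends on the label content, not just on the number of volumes.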
```python
    new_datalist = []
    for idx in range(len(datalist)):
        if limit and idx >= limit:
            break

        image = datalist[idx][keys[0]]
        label = datalist[idx].get(keys[1]) if len(keys) > 1 else None
        if base_dir:
            image = os.path.join(base_dir, image)
            label = os.path.join(base_dir, label) if label else None

        image = os.path.abspath(image)
        label = os.path.abspath(label) if label else None

        logging.info("Image: {}; Label: {}".format(image, label if label else None))
        if dimension == 2:
            data = _save_data_2d(
                vol_idx=idx,
                data=transforms({"image": image, "label": label}),
                keys=("image", "label"),
                dataset_dir=output_dir,
                relative_path=relative_path,
            )
        else:
            data = _save_data_3d(
                vol_idx=idx,
                data=transforms({"image": image, "label": label}),
                keys=("image", "label"),
                dataset_dir=output_dir,
                relative_path=relative_path,
            )
        new_datalist.extend(data)
    return new_datalist


def _default_transforms(keys, pixdim):
    mode = [GridSampleMode.BILINEAR, GridSampleMode.NEAREST] if len(keys) == 2 else [GridSampleMode.BILINEAR]
    return Compose(
        [
            LoadImaged(keys=keys),
            AsChannelFirstd(keys=keys),
            Spacingd(keys=keys, pixdim=pixdim, mode=mode),
            Orientationd(keys=keys, axcodes="RAS"),
        ]
    )


def _save_data_2d(vol_idx, data, keys, dataset_dir, relative_path):
    vol_image = data[keys[0]]
```
> **Contributor:** Maybe it's better to input
```python
    vol_label = data.get(keys[1])
    data_list = []

    if len(vol_image.shape) == 4:
        logging.info(
            "4D-Image, pick only first series; Image: {}; Label: {}".format(
                vol_image.shape, vol_label.shape if vol_label is not None else None
            )
        )
        vol_image = vol_image[0]
    vol_image = np.moveaxis(vol_image, -1, 0)

    image_count = 0
    label_count = 0
    unique_labels_count = 0
    for sid in range(vol_image.shape[0]):
        image = vol_image[sid, ...]
        label = vol_label[sid, ...] if vol_label is not None else None

        if vol_label is not None and np.sum(label) == 0:
            continue

        image_file_prefix = "vol_idx_{:0>4d}_slice_{:0>3d}".format(vol_idx, sid)
        image_file = os.path.join(dataset_dir, "images", image_file_prefix)
        image_file += ".npy"

        os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
        np.save(image_file, image)
        image_count += 1

        # Test data (no label available)
        if vol_label is None:
            data_list.append(
                {
                    "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
                }
            )
            continue

        # For all labels
        unique_labels = np.unique(label.flatten())
        unique_labels = unique_labels[unique_labels != 0]
        unique_labels_count = max(unique_labels_count, len(unique_labels))

        for idx in unique_labels:
            label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
            label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
            label_file += ".npy"

            os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
            curr_label = (label == idx).astype(np.float32)
            np.save(label_file, curr_label)

            label_count += 1
            data_list.append(
                {
                    "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
                    "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file,
                    "region": int(idx),
                }
            )

    logging.info(
        "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
            vol_idx,
            vol_image.shape,
            image_count,
            vol_label.shape if vol_label is not None else None,
            label_count,
            unique_labels_count,
        )
    )
    return data_list


def _save_data_3d(vol_idx, data, keys, dataset_dir, relative_path):
```
> **Contributor:** I think 2D and 3D have very similar logic, could you please merge them into one function?
>
> **Contributor (author):** Actually, it's different: 2D flattens slices over multiple labels, while 3D does it only for labels.
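The asymmetry the author points out can be made concrete with a small counting sketch (the helper names are hypothetical; they only count the records the two code paths would emit, without touching disk): the 2D path emits one record per (non-empty slice, region) pair, while the 3D path emits one record per region of the whole volume.

```python
import numpy as np


def count_2d_records(vol_label):
    # 2D path: one record per (non-empty slice, region) pair
    return sum(len(np.unique(s[s != 0])) for s in vol_label if np.sum(s) != 0)


def count_3d_records(vol_label):
    # 3D path: one record per region of the whole volume
    return len(np.unique(vol_label[vol_label != 0]))


vol = np.zeros((4, 8, 8), dtype=np.int32)
vol[0:3, 2, 2] = 1  # region 1 appears on slices 0-2
vol[1:4, 5, 5] = 2  # region 2 appears on slices 1-3
print(count_2d_records(vol), count_3d_records(vol))  # 6 2
```

With only two regions in the volume, the 2D flattening already produces three times as many records, which is the structural difference that makes a straightforward merge of the two functions awkward.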
```python
    vol_image = data[keys[0]]
    vol_label = data.get(keys[1])
    data_list = []

    if len(vol_image.shape) == 4:
        logging.info("4D-Image, pick only first series; Image: {}; Label: {}".format(vol_image.shape, vol_label.shape))
        vol_image = vol_image[0]
        vol_image = np.moveaxis(vol_image, -1, 0)

    image_count = 0
    label_count = 0
    unique_labels_count = 0

    image_file_prefix = "vol_idx_{:0>4d}".format(vol_idx)
    image_file = os.path.join(dataset_dir, "images", image_file_prefix)
    image_file += ".npy"

    os.makedirs(os.path.join(dataset_dir, "images"), exist_ok=True)
    np.save(image_file, vol_image)
    image_count += 1

    # Test data (no label available)
    if vol_label is None:
        data_list.append(
            {
                "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
            }
        )
    else:
        # For all labels
        unique_labels = np.unique(vol_label.flatten())
        unique_labels = unique_labels[unique_labels != 0]
        unique_labels_count = max(unique_labels_count, len(unique_labels))

        for idx in unique_labels:
            label_file_prefix = "{}_region_{:0>2d}".format(image_file_prefix, int(idx))
            label_file = os.path.join(dataset_dir, "labels", label_file_prefix)
            label_file += ".npy"

            curr_label = (vol_label == idx).astype(np.float32)
            os.makedirs(os.path.join(dataset_dir, "labels"), exist_ok=True)
            np.save(label_file, curr_label)

            label_count += 1
            data_list.append(
                {
                    "image": image_file.replace(dataset_dir + "/", "") if relative_path else image_file,
                    "label": label_file.replace(dataset_dir + "/", "") if relative_path else label_file,
                    "region": int(idx),
                }
            )

    logging.info(
        "{} => Image Shape: {} => {}; Label Shape: {} => {}; Unique Labels: {}".format(
            vol_idx,
            vol_image.shape,
            image_count,
            vol_label.shape if vol_label is not None else None,
            label_count,
            unique_labels_count,
        )
    )
    return data_list
```
@@ -0,0 +1,81 @@

```python
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import Dict

import torch

from monai.engines.utils import CommonKeys
from monai.engines.workflow import Engine, Events
from monai.transforms import Compose


class Interaction:
    """
    Ignite handler used to introduce interactions (simulation of clicks) for Deepgrow training/evaluation.

    Args:
        transforms: execute additional transformations during every iteration (before train).
            Typically, several Tensor-based transforms composed by `Compose`.
        max_interactions: maximum number of interactions per iteration.
        train: training or evaluation.
        key_probability: field name to fill probability for every interaction.
    """

    def __init__(self, transforms, max_interactions: int, train: bool, key_probability: str = "probability") -> None:
        self.transforms = transforms
        self.max_interactions = max_interactions
        self.train = train
        self.key_probability = key_probability

        if not isinstance(self.transforms, Compose):
            transforms = []
            for t in self.transforms:
                transforms.append(self.init_external_class(t))
            self.transforms = Compose(transforms)

    @staticmethod
    def init_external_class(config_dict):
        class_args = None if config_dict.get("args") is None else dict(config_dict.get("args"))
        class_path = config_dict.get("path", config_dict["name"])

        module_name, class_name = class_path.rsplit(".", 1)
        m = importlib.import_module(module_name)
        c = getattr(m, class_name)
        return c(**class_args) if class_args else c()

    def attach(self, engine: Engine) -> None:
        if not engine.has_event_handler(self, Events.ITERATION_STARTED):
            engine.add_event_handler(Events.ITERATION_STARTED, self)

    def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
        if batchdata is None:
            raise ValueError("Must provide batch data for current iteration.")

        for j in range(self.max_interactions):
            inputs, _ = engine.prepare_batch(batchdata)
            inputs = inputs.to(engine.state.device)

            engine.network.eval()
            with torch.no_grad():
                if engine.amp:
                    with torch.cuda.amp.autocast():
                        predictions = engine.inferer(inputs, engine.network)
                else:
                    predictions = engine.inferer(inputs, engine.network)

            batchdata.update({CommonKeys.PRED: predictions})
            batchdata[self.key_probability] = torch.as_tensor(
                ([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
            )
            batchdata = self.transforms(batchdata)

        return engine._iteration(engine, batchdata)
```
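The config-driven construction pattern in `init_external_class` works with any importable class, not just MONAI transforms. A standalone sketch of the same pattern, exercised with stdlib classes standing in for the transform configs (the config dicts shown are hypothetical examples, not from the PR):

```python
import importlib


def init_external_class(config_dict):
    """Resolve a dotted class path from a config dict and instantiate it
    with optional keyword arguments, as Interaction.init_external_class does."""
    class_args = None if config_dict.get("args") is None else dict(config_dict.get("args"))
    class_path = config_dict.get("path", config_dict["name"])
    module_name, class_name = class_path.rsplit(".", 1)
    c = getattr(importlib.import_module(module_name), class_name)
    return c(**class_args) if class_args else c()


# Stand-in configs using stdlib classes instead of MONAI transforms
d = init_external_class({"name": "datetime.date", "args": {"year": 2021, "month": 1, "day": 15}})
od = init_external_class({"name": "collections.OrderedDict"})
print(d.isoformat(), type(od).__name__)  # 2021-01-15 OrderedDict
```

Note that `config_dict.get("path", config_dict["name"])` lets a config override the import path separately from the logical name, while falling back to the name when no explicit path is given.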
> **Contributor:** As this handler will attach to
>
> **Contributor (author):** The batchdata is updated, and `_iteration()` must happen on the newly modified batchdata. Will returning batchdata be enough, so that `_iteration()` in a later call uses the updated batchdata? In that case, we need one round of training to make sure it works as expected. This code was based on your initial feedback on how to run interactions while training; maybe something got changed recently. Nevertheless, you can confirm how to make `_iteration` run on the new/updated batchdata in the above case.
>
> **Contributor:** Suggest to use `engine.state.batch = self.transforms(batchdata)`, then return nothing from this function. Thanks.
>
> **Contributor (author):** You mean
> `Traceback (most recent call last):`
> `Traceback (most recent call last):`
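Regardless of how the batch is handed back to the engine, the click-probability schedule that `__call__` attaches under `key_probability` can be checked in isolation. A sketch of the same expression (the helper name is hypothetical; `max_interactions=4` keeps the floats exactly representable): during training the probability decays linearly from 1.0 with each interaction step `j`, while evaluation always uses 1.0.

```python
def interaction_probabilities(max_interactions, train):
    """Probability value for each interaction step j, mirroring the
    expression in Interaction.__call__: linear decay when training,
    a constant 1.0 when evaluating."""
    return [
        1.0 - ((1.0 / max_interactions) * j) if train else 1.0
        for j in range(max_interactions)
    ]


print(interaction_probabilities(4, train=True))   # [1.0, 0.75, 0.5, 0.25]
print(interaction_probabilities(4, train=False))  # [1.0, 1.0, 1.0, 1.0]
```

The decaying schedule means early simulated clicks are treated as near-certain guidance, while later ones carry progressively less weight within the same training iteration.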
> **Reviewer:** I think this is `create_datalist` instead of `create_dataset`, right?
>
> **Reviewer:** I think this should be `create_dataset`. As it reads in the raw images, transforms them into new volumes (a lot of `.npy` files), and stores them back on disk, it creates a new dataset. It is just that this function also creates a datalist with the entries of the new dataset and returns it.