Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions monailabel/tasks/train/basic_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
stats_path=None,
train_save_interval=20,
val_interval=1,
n_saved=5,
final_filename="checkpoint_final.pt",
key_metric_filename="model.pt",
model_dict_key="model",
Expand All @@ -123,6 +124,7 @@ def __init__(
:param stats_path: Path to save the train stats
:param train_save_interval: checkpoint save interval for training
:param val_interval: validation interval (run every x epochs)
:param n_saved: max checkpoints to save
:param final_filename: name of final checkpoint that will be saved
:param key_metric_filename: best key metric model file name
:param model_dict_key: key to save network weights into checkpoint
Expand Down Expand Up @@ -157,6 +159,7 @@ def __init__(

self._train_save_interval = train_save_interval
self._val_interval = val_interval
self._n_saved = n_saved
self._final_filename = final_filename
self._key_metric_filename = key_metric_filename
self._model_dict_key = model_dict_key
Expand Down Expand Up @@ -340,7 +343,7 @@ def config(self):

@staticmethod
def _validate_transforms(transforms, step="Training", name="pre"):
if not transforms or isinstance(transforms, Compose):
if not transforms or isinstance(transforms, Compose) or callable(transforms):
return transforms
if isinstance(transforms, list):
return Compose(transforms)
Expand Down Expand Up @@ -528,7 +531,7 @@ def _create_evaluator(self, context: Context):
save_dict={self._model_dict_key: context.network},
save_key_metric=True,
key_metric_filename=self._key_metric_filename,
n_saved=5,
n_saved=self._n_saved,
)
)

Expand Down Expand Up @@ -560,7 +563,7 @@ def _create_trainer(self, context: Context):
key_metric_filename=f"train_{self._key_metric_filename}"
if context.evaluator
else self._key_metric_filename,
n_saved=5,
n_saved=self._n_saved,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,20 @@ public void run() {
list.addIntParameter("Width", "Width", bbox[2]);
list.addIntParameter("Height", "Height", bbox[3]);

boolean override = !info.models.get(selectedModel).nuclick;
list.addBooleanParameter("Override", "Override", override);

if (Dialogs.showParameterDialog("MONAILabel", list)) {
String model = (String) list.getChoiceParameterValue("Model");
bbox[0] = list.getIntParameterValue("X").intValue();
bbox[1] = list.getIntParameterValue("Y").intValue();
bbox[2] = list.getIntParameterValue("Width").intValue();
bbox[3] = list.getIntParameterValue("Height").intValue();
override = list.getBooleanParameterValue("Override").booleanValue();

selectedModel = model;
selectedBBox = bbox;
boolean isNuClick = info.models.get(model).nuclick;
runInference(model, new HashSet<String>(Arrays.asList(labels.get(model))), bbox, imageData, isNuClick);
runInference(model, new HashSet<String>(Arrays.asList(labels.get(model))), bbox, imageData, override);
}
} catch (Exception ex) {
ex.printStackTrace();
Expand Down Expand Up @@ -145,9 +148,9 @@ private int[] getBBOX(ROI roi) {
}

private void runInference(String model, Set<String> labels, int[] bbox, ImageData<BufferedImage> imageData,
boolean isNuClick) throws SAXException, IOException, ParserConfigurationException, InterruptedException {
boolean override) throws SAXException, IOException, ParserConfigurationException, InterruptedException {
logger.info("MONAILabel Annotation - Run Inference...");
logger.info("Model: " + model + "; IsNuClick: " + isNuClick + "; Labels: " + labels);
logger.info("Model: " + model + "; override: " + override + "; Labels: " + labels);

String image = Utils.getNameWithoutExtension(imageData.getServerPath());

Expand All @@ -163,7 +166,7 @@ private void runInference(String model, Set<String> labels, int[] bbox, ImageDat

Document dom = MonaiLabelClient.infer(model, image, req);
NodeList annotation_list = dom.getElementsByTagName("Annotation");
int count = updateAnnotations(labels, annotation_list, roi, imageData, !isNuClick);
int count = updateAnnotations(labels, annotation_list, roi, imageData, override);

// Update hierarchy to see changes in QuPath's hierarchy
QP.fireHierarchyUpdate(imageData.getHierarchy());
Expand Down
30 changes: 26 additions & 4 deletions sample-apps/pathology/lib/configs/nuclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import lib.infers
import lib.trainers
from lib.nets import UNet
from monai.networks.nets import BasicUNet

from monailabel.interfaces.config import TaskConfig
from monailabel.interfaces.tasks.infer import InferTask
Expand All @@ -42,11 +42,16 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **

# Download PreTrained Model
if strtobool(self.conf.get("use_pretrained_model", "true")):
url = f"{self.PRE_TRAINED_PATH}/NuClick_UNet_40xAll.pth"
url = f"{self.PRE_TRAINED_PATH}/pathology_nuclick_bunet.pt"
download_file(url, self.path[0])

# Network
self.network = UNet(n_channels=5, n_classes=1)
self.network = BasicUNet(
spatial_dims=2,
in_channels=5,
out_channels=1,
features=(32, 64, 128, 256, 512, 32),
)

def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
task: InferTask = lib.infers.NuClick(
Expand All @@ -59,4 +64,21 @@ def infer(self) -> Union[InferTask, Dict[str, InferTask]]:
return task

def trainer(self) -> Optional[TrainTask]:
return None
output_dir = os.path.join(self.model_dir, self.name)
task: TrainTask = lib.trainers.NuClick(
model_dir=output_dir,
network=self.network,
load_path=self.path[0],
publish_path=self.path[1],
labels=self.labels,
description="Train Nuclei DeepEdit Model",
train_save_interval=1,
config={
"max_epochs": 10,
"train_batch_size": 64,
"dataset_max_region": (10240, 10240),
"dataset_limit": 0,
"dataset_randomize": True,
},
)
return task
18 changes: 10 additions & 8 deletions sample-apps/pathology/lib/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,16 @@ def write_images(self, batch_data, output_data, epoch):
label[y == region] = region

self.logger.info(
"{} - {} - Image: {}; Label: {} (nz: {}); Pred: {} (nz: {})".format(
"{} - {} - Image: {}; Label: {} (nz: {}); Pred: {} (nz: {}); Sig: (pos-nz: {}, neg-nz: {})".format(
bidx,
region,
image.shape,
label.shape,
np.count_nonzero(label),
y_pred.shape,
np.count_nonzero(y_pred[region]),
np.count_nonzero(image[3]) if image.shape == 5 else 0,
np.count_nonzero(image[4]) if image.shape == 5 else 0,
)
)

Expand All @@ -172,15 +174,15 @@ def write_images(self, batch_data, output_data, epoch):
break

def write_region_metrics(self, epoch):
metric_sum = 0
for region in self.metric_data:
metric = self.metric_data[region].mean()
self.logger.info(f"Epoch[{epoch}] Metrics -- Region: {region:0>2d}, {self.tag_name}: {metric:.4f}")
if len(self.metric_data) > 1:
metric_sum = 0
for region in self.metric_data:
metric = self.metric_data[region].mean()
self.logger.info(f"Epoch[{epoch}] Metrics -- Region: {region:0>2d}, {self.tag_name}: {metric:.4f}")

self.writer.add_scalar(f"dice_{region:0>2d}", metric, epoch)
metric_sum += metric
self.writer.add_scalar(f"dice_{region:0>2d}", metric, epoch)
metric_sum += metric

if len(self.metric_data) > 1:
metric_avg = metric_sum / len(self.metric_data)
self.writer.add_scalar("dice_regions_avg", metric_avg, epoch)

Expand Down
7 changes: 2 additions & 5 deletions sample-apps/pathology/lib/infers/nuclick.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ def __call__(self, data):
return d

@staticmethod
def get_clickmap_boundingbox(cx, cy, m, n):
bb = 128
def get_clickmap_boundingbox(cx, cy, m, n, bb=128):
click_map = np.zeros((m, n), dtype=np.uint8)

# Removing points out of image dimension (these points may have been clicked unwanted)
Expand Down Expand Up @@ -162,9 +161,7 @@ def get_clickmap_boundingbox(cx, cy, m, n):
return click_map, bounding_boxes

@staticmethod
def get_patches_and_signals(img, click_map, bounding_boxes, cx, cy, m, n):
bb = 128

def get_patches_and_signals(img, click_map, bounding_boxes, cx, cy, m, n, bb=128):
# total = number of clicks
total = len(bounding_boxes)
img = np.array([img]) # img.shape=(1,3,m,n)
Expand Down
11 changes: 0 additions & 11 deletions sample-apps/pathology/lib/nets/__init__.py

This file was deleted.

105 changes: 0 additions & 105 deletions sample-apps/pathology/lib/nets/unet.py

This file was deleted.

1 change: 1 addition & 0 deletions sample-apps/pathology/lib/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
# limitations under the License.

from .deepedit_nuclei import DeepEditNuclei
from .nuclick import NuClick
from .segmentation_nuclei import SegmentationNuclei
Loading