diff --git a/README.md b/README.md
index 209f2a8..f385ef6 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-
Welcome to PathFlowAI 👋
+Welcome to PathFlowAI
@@ -10,11 +10,11 @@
### 🏠 [Homepage](https://github.com/jlevy44/PathFlowAI)
-MedRxiv Manuscript: https://www.medrxiv.org/content/10.1101/19003897v1
+Published in the Proceedings of the Pacific Symposium on Biocomputing 2020, Manuscript: https://psb.stanford.edu/psb-online/proceedings/psb20/Levy.pdf
## Install
-First, install [openslide](https://openslide.org/download/).
+First, install [openslide](https://openslide.org/download/). Note: you may need to install libiconv and shapely using conda. We will update this section with more installation information; please submit issues as well.
```sh
pip install pathflowai
diff --git a/bin/install_apex b/bin/install_apex
index af171c1..3dc2635 100644
--- a/bin/install_apex
+++ b/bin/install_apex
@@ -1,5 +1,7 @@
#!/bin/bash
+export TMPDIR=$HOME/tmp
+mkdir -p $TMPDIR
rm -rf apex
git clone https://github.com/NVIDIA/apex
cd apex
diff --git a/bin/install_lightnet b/bin/install_lightnet
new file mode 100644
index 0000000..eb0cbb7
--- /dev/null
+++ b/bin/install_lightnet
@@ -0,0 +1,11 @@
+#!/bin/bash
+#python -m lightnet download tiny-yolo
+#python -m lightnet download yolo
+rm -rf lightnet
+git clone https://gitlab.com/EAVISE/lightnet.git
+cd lightnet
+pip install .
+cd ..
+rm -rf lightnet
+#wget https://pjreddie.com/media/files/yolo.weights
+#wget https://pjreddie.com/media/files/tiny-yolo.weights
diff --git a/experimental/datasets.py b/experimental/datasets.py
new file mode 100644
index 0000000..4e8e98b
--- /dev/null
+++ b/experimental/datasets.py
@@ -0,0 +1,99 @@
+#
+# Lightnet dataset that works with brambox annotations
+# Copyright EAVISE
+#
+# https://eavise.gitlab.io/lightnet/_modules/lightnet/models/_dataset_brambox.html#BramboxDataset
+# https://eavise.gitlab.io/brambox/notes/02-getting_started.html#Loading-data
+
+import os
+import copy
+import logging
+from PIL import Image
+import numpy as np
+import lightnet.data as lnd
+from pathflowai.utils import load_sql_df
+import dask.array as da
+from os.path import join
+
+try:
+ import brambox as bb
+except ImportError:
+ bb = None
+
+__all__ = ['BramboxDataset']
+log = logging.getLogger(__name__)
+
+# ADD IMAGE ANNOTATION TRANSFORM
+# ADD TRAIN VAL TEST INFO
+
+class BramboxPathFlowDataset(lnd.Dataset):
+ """ Dataset for any brambox annotations.
+
+ Args:
+ annotations (dataframe): Dataframe containing brambox annotations
+ input_dimension (tuple): (width,height) tuple with default dimensions of the network
+ class_label_map (list): List of class_labels
+ identify (function, optional): Lambda/function to get image based on annotation filename or image id; Default **replace/add .png extension to filename/id**
+ img_transform (torchvision.transforms.Compose): Transforms to perform on the images
+ anno_transform (torchvision.transforms.Compose): Transforms to perform on the annotations
+
+ Note:
+ This dataset loads image patches from Zarr-backed slides via Dask arrays (not Pillow)
+ """
+ def __init__(self, input_dir, patch_info_file, patch_size, annotations, input_dimension, class_label_map=None, identify=None, img_transform=None, anno_transform=None):
+ if bb is None:
+ raise ImportError('Brambox needs to be installed to use this dataset')
+ super().__init__(input_dimension)
+
+ self.annos = annotations
+ self.annos['ignore']=0
+ self.annos['class_label']=self.annos['class_label'].astype(int)#-1
+ print(self.annos['class_label'].unique())
+ #print(self.annos.shape)
+ self.keys = self.annos.image.cat.categories # stores unique patches
+ #print(self.keys)
+ self.img_tf = img_transform
+ self.anno_tf = anno_transform
+ self.patch_info=load_sql_df(patch_info_file, patch_size)
+ IDs=self.patch_info['ID'].unique()
+ self.slides = {slide:da.from_zarr(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs}
+ self.id = lambda k: k.split('/')
+ # experiment
+ #self.annos['x_top_left'], self.annos['y_top_left']=self.annos['y_top_left'], self.annos['x_top_left']
+ self.annos['width'], self.annos['height']=self.annos['height'], self.annos['width']
+ # Add class_ids
+ if class_label_map is None:
+ log.warning(f'No class_label_map given, generating it by sorting unique class labels from data alphabetically, which is not always deterministic behaviour')
+ class_label_map = list(np.sort(self.annos.class_label.unique()))
+ self.annos['class_id'] = self.annos.class_label.map(dict((l, i) for i, l in enumerate(class_label_map)))
+
+ def __len__(self):
+ return len(self.keys)
+
+ @lnd.Dataset.resize_getitem
+ def __getitem__(self, index):
+ """ Get transformed image and annotations based on the index of ``self.keys``
+
+ Args:
+ index (int): index of the ``self.keys`` list containing all the image identifiers of the dataset.
+
+ Returns:
+ tuple: (transformed image, list of transformed brambox boxes)
+ """
+ if index >= len(self):
+ raise IndexError(f'list index out of range [{index}/{len(self)-1}]')
+
+ # Load
+ #print(self.keys[index])
+ ID,x,y,patch_size=self.id(self.keys[index])
+ x,y,patch_size=int(x),int(y),int(patch_size)
+ img = self.slides[ID][x:x+patch_size,y:y+patch_size].compute()#Image.open(self.id(self.keys[index]))
+ anno = bb.util.select_images(self.annos, [self.keys[index]])
+
+ # Transform
+ if self.img_tf is not None:
+ img = self.img_tf(img)
+ if self.anno_tf is not None:
+ anno = self.anno_tf(anno)
+
+ return img, anno
diff --git a/experimental/get_anchors.py b/experimental/get_anchors.py
new file mode 100644
index 0000000..d98004f
--- /dev/null
+++ b/experimental/get_anchors.py
@@ -0,0 +1,23 @@
+from sklearn.cluster import KMeans
+import numpy as np, pandas as pd, brambox as bb
+import pickle, argparse
+
+p=argparse.ArgumentParser()
+p.add_argument('--patch_size',default=512,type=int)
+p.add_argument('--n_anchors',default=20,type=int)
+p.add_argument('--sample_p',default=1.,type=float)
+
+args=p.parse_args()
+np.random.seed(42)
+patch_size=args.patch_size
+n_anchors=args.n_anchors
+sample_p=args.sample_p
+annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size)
+annotations=bb.io.load('pandas',annotation_file)
+if sample_p<1.:
+ annotations=annotations.sample(frac=sample_p)
+
+X=annotations[['x_top_left','y_top_left']].astype(float).values+(annotations['width']/2.).astype(float).values.reshape(-1,1)
+km=KMeans(n_clusters=n_anchors,n_jobs=-1).fit(X)
+anchors=km.cluster_centers_
+pickle.dump(anchors,open('anchors.pkl','wb'))
diff --git a/experimental/get_bounding_boxes_from_seg_point_masks.py b/experimental/get_bounding_boxes_from_seg_point_masks.py
new file mode 100644
index 0000000..fb0252f
--- /dev/null
+++ b/experimental/get_bounding_boxes_from_seg_point_masks.py
@@ -0,0 +1,137 @@
+import brambox as bb
+import os
+from os.path import join, basename
+from pathflowai.utils import load_sql_df, npy2da
+import skimage
+import dask, dask.array as da, pandas as pd, numpy as np
+import argparse
+from scipy import ndimage
+from scipy.ndimage.measurements import label
+import pickle
+from dask.distributed import Client
+from multiprocessing import Pool
+from functools import reduce
+
+def get_box(l,prop):
+ c=[prop.centroid[1], prop.centroid[0]]
+ # l=rev_label[i+1]
+ width = prop.bbox[3] - prop.bbox[1] + 1
+ height = prop.bbox[2] - prop.bbox[0] + 1
+ wh=max(width,height)
+ # c = [ci-wh/2 for ci in c]
+ return [l]+c+[wh]
+
+def get_boxes(m,ID='test',x='x',y='y',patch_size='patchsize', num_classes=3):
+ lbls,n_lbl=label(m)
+ obj_labels={}
+ for i in range(1,num_classes+1):
+ obj_labels[i]=np.unique(lbls[m==i].flatten())
+ rev_label={}
+ for k in obj_labels:
+ for i in obj_labels[k]:
+ rev_label[i]=k
+ rev_label={k:rev_label[k] for k in sorted(list(rev_label.keys()))}
+ objProps = list(skimage.measure.regionprops(lbls))
+ #print(len(objProps),len(rev_label))
+ boxes=dask.compute(*[dask.delayed(get_box)(rev_label[i],objProps[i-1]) for i in list(rev_label.keys())],scheduler='threading') # [get_box(rev_label[i],objProps[i-1]) for i in list(rev_label.keys())]#
+ #print(boxes)
+ boxes=pd.DataFrame(np.array(boxes).astype(int),columns=['class_label','x_top_left','y_top_left','width'])
+
+ #boxes['class_label']=m[boxes[['x_top_left','y_top_left']].values.T.tolist()]
+ boxes['height']=boxes['width']
+ boxes['image']='{}/{}/{}/{}'.format(ID,x,y,patch_size)
+ boxes=boxes[['image','class_label','x_top_left','y_top_left','width','height']]
+ boxes.loc[:,'x_top_left']=np.clip(boxes.loc[:,'x_top_left'],0,m.shape[1])
+ boxes.loc[:,'y_top_left']=np.clip(boxes.loc[:,'y_top_left'],0,m.shape[0])
+
+ bbox_df=bb.util.new('annotation').drop(columns=['difficult','ignore','lost','occluded','truncated'])[['image','class_label','x_top_left','y_top_left','width','height']]
+ bbox_df=bbox_df.append(boxes)
+ #print(boxes)
+ return boxes
+
+if __name__=='__main__':
+ p=argparse.ArgumentParser()
+ p.add_argument('--num_classes',default=4,type=int)
+ p.add_argument('--patch_size',default=512,type=int)
+ p.add_argument('--n_workers',default=40,type=int)
+ p.add_argument('--p_sample',default=0.7,type=float)
+ p.add_argument('--input_dir',default='inputs',type=str)
+ p.add_argument('--patch_info_file',default='cell_info.db',type=str)
+ p.add_argument('--reference_mask',default='reference_mask.npy',type=str)
+ #c=Client()
+ # add mode to just use own extracted bounding boxes or from seg, maybe from histomicstk
+
+ args=p.parse_args()
+ num_classes=args.num_classes
+ n_workers=args.n_workers
+ input_dir=args.input_dir
+ patch_info_file=args.patch_info_file
+ patch_size=args.patch_size
+ p_sample=args.p_sample
+ np.random.seed(42)
+ annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size)
+ reference_mask=args.reference_mask
+ if not os.path.exists('widths.pkl'):
+ m=np.load(reference_mask)
+ bbox_df=get_boxes(m)
+ official_widths=dict(bbox_df.groupby('class_label')['width'].mean()+2*bbox_df.groupby('class_label')['width'].std())
+ pickle.dump(official_widths,open('widths.pkl','wb'))
+ else:
+ official_widths=pickle.load(open('widths.pkl','rb'))
+
+ patch_info=load_sql_df(patch_info_file, patch_size)
+ IDs=patch_info['ID'].unique()
+ #slides = {slide:da.from_zarr(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs}
+ masks = {mask:npy2da(join(input_dir,'{}_mask.npy'.format(mask))) for mask in IDs}
+
+ if p_sample < 1.:
+ patch_info=patch_info.sample(frac=p_sample)
+
+ if not os.path.exists(annotation_file):
+ bbox_df=bb.util.new('annotation').drop(columns=['difficult','ignore','lost','occluded','truncated'])[['image','class_label','x_top_left','y_top_left','width','height']]
+ else:
+ bbox_df=bb.io.load('pandas',annotation_file)
+
+ patch_info=patch_info[~np.isin(np.vectorize(lambda i: '/'.join(patch_info.iloc[i][['ID','x','y','patch_size']].astype(str).tolist()))(np.arange(patch_info.shape[0])),set(bbox_df.image.cat.categories))]
+
+ print(patch_info.shape[0])
+
+ def get_boxes_point_seg(m,ID,x,y,patch_size2,num_classes):
+ bbox_dff=get_boxes(m,ID=ID,x=x,y=y,patch_size=patch_size2, num_classes=num_classes)
+ for i in official_widths.keys():
+ bbox_dff.loc[bbox_dff['class_label']==i,'width']=int(official_widths[i])
+ bbox_dff.loc[:,'x_top_left']=(bbox_dff.loc[:,'x_top_left']-bbox_dff['width']/2.).astype(int)
+ bbox_dff.loc[:,'y_top_left']=(bbox_dff.loc[:,'y_top_left']-bbox_dff['width']/2.).astype(int)
+ bbox_dff.loc[:,'x_top_left']=np.clip(bbox_dff.loc[:,'x_top_left'],0,m.shape[1])
+ bbox_dff.loc[:,'y_top_left']=np.clip(bbox_dff.loc[:,'y_top_left'],0,m.shape[0])
+ return bbox_dff
+
+ def process_chunk(patch_info_sub):
+ patch_info_sub=patch_info_sub.reset_index(drop=True)
+ bbox_dfs=[]
+
+ for i in range(patch_info_sub.shape[0]):
+ #print(i)
+ patch=patch_info_sub.iloc[i]
+ ID,x,y,patch_size2=patch[['ID','x','y','patch_size']].tolist()
+ m=masks[ID][x:x+patch_size2,y:y+patch_size2]
+ bbox_dff=get_boxes_point_seg(m,ID,x,y,patch_size2,num_classes)#dask.delayed(get_boxes_point_seg)(m,ID,x,y,patch_size2)
+ #print(bbox_dff)
+ bbox_dfs.append(bbox_dff)
+ return bbox_dfs
+
+ patch_info_subs=np.array_split(patch_info,n_workers)
+
+ p=Pool(n_workers)
+
+ bbox_dfs=reduce(lambda x,y:x+y,p.map(process_chunk,patch_info_subs))
+
+ #bbox_dfs=dask.compute(*bbox_dfs,scheduler='processes')
+
+ bbox_df=pd.concat([bbox_df]+bbox_dfs)
+
+
+ bbox_df.loc[:,'height']=bbox_df['width']
+
+
+ bb.io.save(bbox_df,'pandas',annotation_file)
diff --git a/experimental/get_counts.py b/experimental/get_counts.py
new file mode 100644
index 0000000..d242324
--- /dev/null
+++ b/experimental/get_counts.py
@@ -0,0 +1,73 @@
+import brambox as bb
+import os
+from os.path import join, basename
+from pathflowai.utils import load_sql_df, npy2da, df2sql
+import skimage
+import dask, dask.array as da, pandas as pd, numpy as np
+import argparse
+from scipy import ndimage
+from scipy.ndimage.measurements import label
+import pickle
+from dask.distributed import Client
+from multiprocessing import Pool
+from functools import reduce
+
+def count_cells(m, num_classes=3):
+ lbls,n_lbl=label(m)
+ obj_labels=np.zeros(num_classes)
+ for i in range(1,num_classes+1):
+ obj_labels[i-1]=len(np.unique(lbls[m==i].flatten()))
+ return obj_labels
+
+if __name__=='__main__':
+ p=argparse.ArgumentParser()
+ p.add_argument('--num_classes',default=4,type=int)
+ p.add_argument('--patch_size',default=512,type=int)
+ p.add_argument('--n_workers',default=40,type=int)
+ p.add_argument('--p_sample',default=0.7,type=float)
+ p.add_argument('--input_dir',default='inputs',type=str)
+ p.add_argument('--patch_info_file',default='cell_info.db',type=str)
+ p.add_argument('--reference_mask',default='reference_mask.npy',type=str)
+ #c=Client()
+ # add mode to just use own extracted bounding boxes or from seg, maybe from histomicstk
+
+ args=p.parse_args()
+ num_classes=args.num_classes
+ n_workers=args.n_workers
+ input_dir=args.input_dir
+ patch_info_file=args.patch_info_file
+ patch_size=args.patch_size
+ np.random.seed(42)
+ reference_mask=args.reference_mask
+
+ patch_info=load_sql_df(patch_info_file, patch_size)
+ IDs=patch_info['ID'].unique()
+ #slides = {slide:da.from_zarr(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs}
+ masks = {mask:npy2da(join(input_dir,'{}_mask.npy'.format(mask))) for mask in IDs}
+
+ def process_chunk(patch_info_sub):
+ patch_info_sub=patch_info_sub.reset_index(drop=True)
+ counts=[]
+ for i in range(patch_info_sub.shape[0]):
+ #print(i)
+ patch=patch_info_sub.iloc[i]
+ ID,x,y,patch_size2=patch[['ID','x','y','patch_size']].tolist()
+ m=masks[ID][x:x+patch_size2,y:y+patch_size2]
+ counts.append(dask.delayed(count_cells)(m, num_classes=num_classes))
+
+ return dask.compute(*counts,scheduler='threading')
+
+ patch_info_subs=np.array_split(patch_info,n_workers)
+
+ p=Pool(n_workers)
+
+ counts=reduce(lambda x,y:x+y,p.map(process_chunk,patch_info_subs))
+
+ #bbox_dfs=dask.compute(*bbox_dfs,scheduler='processes')
+
+ counts=pd.DataFrame(np.vstack(counts))
+
+ patch_info=pd.concat([patch_info[['ID','x','y','patch_size','annotation']].reset_index(drop=True),counts.reset_index(drop=True)],axis=1).reset_index()
+ print(patch_info)
+
+ df2sql(patch_info, 'counts_test.db', patch_size, mode='replace')
diff --git a/experimental/object_detection.py b/experimental/object_detection.py
new file mode 100644
index 0000000..34eba09
--- /dev/null
+++ b/experimental/object_detection.py
@@ -0,0 +1,191 @@
+import lightnet as ln
+import torch
+import numpy as np, pandas as pd
+import matplotlib.pyplot as plt
+import brambox as bb
+import dask as da
+from datasets import BramboxPathFlowDataset
+import argparse, pickle
+from sklearn.model_selection import train_test_split
+
+# Settings
+ln.logger.setConsoleLevel('ERROR') # Only show error log messages
+bb.logger.setConsoleLevel('ERROR')
+# https://eavise.gitlab.io/lightnet/notes/02-B-engine.html
+
+p=argparse.ArgumentParser()
+p.add_argument('--num_classes',default=4,type=int)
+p.add_argument('--patch_size',default=512,type=int)
+p.add_argument('--patch_info_file',default='cell_info.db',type=str)
+p.add_argument('--input_dir',default='inputs',type=str)
+p.add_argument('--sample_p',default=1.,type=float)
+p.add_argument('--conf_thresh',default=0.01,type=float)
+p.add_argument('--nms_thresh',default=0.5,type=float)
+
+
+args=p.parse_args()
+np.random.seed(42)
+num_classes=args.num_classes+1
+patch_size=args.patch_size
+batch_size=64
+patch_info_file=args.patch_info_file
+input_dir=args.input_dir
+sample_p=args.sample_p
+conf_thresh=args.conf_thresh
+nms_thresh=args.nms_thresh
+anchors=pickle.load(open('anchors.pkl','rb'))
+
+annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size)
+annotations=bb.io.load('pandas',annotation_file)
+
+if sample_p < 1.:
+ annotations=annotations.sample(frac=sample_p)
+
+annotations_dict={}
+annotations_dict['train'],annotations_dict['test']=train_test_split(annotations)
+annotations_dict['train'],annotations_dict['val']=train_test_split(annotations_dict['train'])
+
+model=ln.models.Yolo(num_classes=num_classes,anchors=anchors.tolist())
+
+loss = ln.network.loss.RegionLoss(
+ num_classes=model.num_classes,
+ anchors=model.anchors,
+ stride=model.stride
+)
+
+transforms = ln.data.transform.Compose([ln.data.transform.RandomHSV(
+ hue=1,
+ saturation=2,
+ value=2
+)])
+
+# Create HyperParameters
+params = ln.engine.HyperParameters(
+ network=model,
+ input_dimension = (patch_size,patch_size),
+ mini_batch_size=16,
+ batch_size=batch_size,
+ max_batches=80000
+)
+
+post = ln.data.transform.Compose([
+ ln.data.transform.GetBoundingBoxes(
+ num_classes=params.network.num_classes,
+ anchors=params.network.anchors,
+ conf_thresh=conf_thresh,
+ ),
+
+ ln.data.transform.NonMaxSuppression(
+ nms_thresh=nms_thresh
+ ),
+
+ ln.data.transform.TensorToBrambox(
+ network_size=(patch_size,patch_size),
+ # class_label_map=class_label_map,
+ )
+])
+
+datasets={k:BramboxPathFlowDataset(input_dir,patch_info_file, patch_size, annotations_dict[k], input_dimension=(patch_size,patch_size), class_label_map=None, identify=None, img_transform=None, anno_transform=None) for k in ['train','val','test']}
+# transforms
+
+params.loss = ln.network.loss.RegionLoss(params.network.num_classes, params.network.anchors)
+params.optim = torch.optim.SGD(params.network.parameters(), lr=1e-4)
+params.scheduler = ln.engine.SchedulerCompositor(
+ # batch scheduler
+ (0, torch.optim.lr_scheduler.CosineAnnealingLR(params.optim,T_max=200))
+ )
+
+dls = {k:ln.data.DataLoader(
+ datasets[k],
+ batch_size = batch_size,
+ collate_fn = ln.data.brambox_collate # We want the data to be grouped as a list
+ ) for k in ['train','val','test']}
+
+params.val_loader=dls['val']
+
+class CustomEngine(ln.engine.Engine):
+ def start(self):
+ """ Do whatever needs to be done before starting """
+ self.params.to(self.device) # Casting parameters to a certain device
+ self.optim.zero_grad() # Make sure to start with no gradients
+ self.loss_acc = [] # Loss accumulator
+
+ def process_batch(self, data):
+ """ Forward and backward pass """
+ data, target = data # Unpack
+ #print(target)
+ data=data.permute(0,3,1,2).float()
+ if torch.cuda.is_available():
+ data=data.cuda()
+
+ #print(data)
+
+ output = self.network(data)
+ #print(output)
+
+ loss = self.loss(output, target)
+
+ #print(loss)
+ loss.backward()
+ bbox=post(output)
+ print(bbox)
+
+ self.loss_acc.append(loss.item())
+
+ @ln.engine.Engine.batch_end(100) # how to pass in validation dataloader
+ def val_loop(self):
+ with torch.no_grad():
+ for i,data in enumerate(self.val_loader):
+ if i > 100:
+ break
+ data, target = data
+ data=data.permute(0,3,1,2).float()
+ if torch.cuda.is_available():
+ data=data.cuda()
+ output = self.network(data)
+ #print(output)
+ loss = self.loss(output, target)
+ print(loss)
+ bbox=post(output)
+ print(bbox)
+ if not i:
+ bbox_final=[bbox]
+ else:
+ bbox_final.append(bbox)
+
+ detections=pd.concat(bbox_final)
+ print(detections)
+ print(annotations_dict['val'])
+ pr=bb.stat.pr(detections, annotations_dict['val'], threshold=0.5)
+ auc=bb.stat.auc(pr)
+ print('VAL AUC={}'.format(auc))
+
+ @ln.engine.Engine.batch_end(300)
+ def save_model(self):
+ self.params.save(f'backup-{self.batch}.state.pt')
+
+ def train_batch(self):
+ """ Weight update and logging """
+ self.optim.step()
+ self.optim.zero_grad()
+
+ batch_loss = sum(self.loss_acc) / len(self.loss_acc)
+ self.loss_acc = []
+ self.log(f'Loss: {batch_loss}')
+
+ def quit(self):
+ if self.batch >= self.max_batches: # Should probably save weights here
+ print('Reached end of training')
+ return True
+ return False
+
+
+
+# Create engine
+engine = CustomEngine(
+ params, dls['train'], # Dataloader (None) is not valid
+ device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+)
+
+for i in range(10):
+ engine()
diff --git a/pathflowai/cli_preprocessing.py b/pathflowai/cli_preprocessing.py
index efedd04..04bb5f5 100644
--- a/pathflowai/cli_preprocessing.py
+++ b/pathflowai/cli_preprocessing.py
@@ -50,10 +50,11 @@ def output_if_exists(filename):
@click.option('-nn', '--n_neighbors', default=5, help='If adjusting mask, number of neighbors connectivity to remove.', show_default=True)
@click.option('-bp', '--basic_preprocess', is_flag=True, help='Basic preprocessing pipeline, annotation areas are not saved. Used for benchmarking tool against comparable pipelines', show_default=True)
@click.option('-ei', '--entire_image', is_flag=True, help='Store entire image in central db rather than patches.', show_default=True)
-def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patches,threshold,patch_size, intensity_threshold, generate_finetune_segmentation, target_segmentation_class, target_threshold, out_db, adjust_mask, n_neighbors, basic_preprocess, entire_image):
+@click.option('-nz', '--no_zarr', is_flag=True, help='Don\'t save zarr format file.', show_default=True)
+def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patches,threshold,patch_size, intensity_threshold, generate_finetune_segmentation, target_segmentation_class, target_threshold, out_db, adjust_mask, n_neighbors, basic_preprocess, entire_image, no_zarr):
"""Preprocessing pipeline that accomplishes 3 things. 1: storage into ZARR format, 2: optional mask adjustment, 3: storage of patch-level information into SQL DB"""
- for ext in ['.npy','.svs','.tiff','.tif', '.vms', '.vmu', '.ndpi', '.scn', '.mrxs', '.svslide', '.bif', '.jpeg', '.png']:
+ for ext in ['.npy','.svs','.tiff','.tif', '.vms', '.vmu', '.ndpi', '.scn', '.mrxs', '.svslide', '.bif', '.jpeg', '.png', '.h5']:
svs_file = output_if_exists(join(input_dir,'{}{}'.format(basename,ext)))
if svs_file != None:
break
@@ -75,13 +76,15 @@ def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patche
npy_mask=npy_mask,
annotations=annotations,
out_zarr=out_zarr,
- out_pkl=out_pkl)
+ out_pkl=out_pkl,
+ no_zarr=no_zarr)
if npy_mask==None and xml_file==None:
+ print('Generating Zero Mask')
npy_mask=join(input_dir,'{}_mask.npz'.format(basename))
target_segmentation_class=1
generate_finetune_segmentation=True
- create_zero_mask(npy_mask,out_zarr,out_pkl)
+ create_zero_mask(npy_mask,out_zarr if not no_zarr else svs_file,out_pkl)
preprocess_point = time.time()
@@ -93,7 +96,7 @@ def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patche
adj_npy=join(adj_dir,os.path.basename(npy_mask))
os.makedirs(adj_dir,exist_ok=True)
if not os.path.exists(adj_npy):
- adjust_mask(npy_mask, out_zarr, adj_npy, n_neighbors)
+ adjust_mask(npy_mask, out_zarr if not no_zarr else svs_file, adj_npy, n_neighbors)
adjust_point = time.time()
print('Adjust took {}'.format(adjust_point-preprocess_point))
@@ -111,7 +114,8 @@ def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patche
target_threshold=target_threshold,
adj_mask=adj_npy,
basic_preprocess=basic_preprocess,
- entire_image=entire_image)
+ entire_image=entire_image,
+ svs_file=svs_file)
patch_point = time.time()
print('Patches took {}'.format(patch_point-adjust_point))
diff --git a/pathflowai/cli_visualizations.py b/pathflowai/cli_visualizations.py
index 91cf3ff..2ad4408 100644
--- a/pathflowai/cli_visualizations.py
+++ b/pathflowai/cli_visualizations.py
@@ -125,7 +125,7 @@ def plot_embeddings(embeddings_file,plotly_output_file, annotations, remove_back
"""Perform UMAP embeddings of patches and plot using plotly."""
import torch
from umap import UMAP
- from visualize import PlotlyPlot
+ from pathflowai.visualize import PlotlyPlot
import pandas as pd, numpy as np
embeddings_dict=torch.load(embeddings_file)
embeddings=embeddings_dict['embeddings']
diff --git a/pathflowai/datasets.py b/pathflowai/datasets.py
index fcf6dd7..9be82ee 100644
--- a/pathflowai/datasets.py
+++ b/pathflowai/datasets.py
@@ -372,6 +372,8 @@ def __init__(self,dataset_df, set, patch_info_file, transformers, input_dir, tar
self.classify_annotations=classify_annotations
print(self.targets)
self.dilation_jitter=DilationJitter(dilation_jitter,self.segmentation,(original_set=='train'))
+ if not self.targets:
+ self.targets = [pos_annotation_class]+list(other_annotations)
def concat(self, other_dataset):
"""Concatenate this dataset with others. Updates its own internal attributes.
@@ -528,7 +530,7 @@ def update_dataset(self, input_dir, new_db, prediction_basename=[]):
self.segmentation_maps = {slide:npy2da(join(self.input_dir,'{}_mask.npy'.format(slide))) for slide in IDs}
self.length = self.patch_info.shape[0]
- @pysnooper.snoop('get_item.log')
+ #@pysnooper.snoop('get_item.log')
def __getitem__(self, i):
patch_info = self.patch_info.iloc[i]
ID = patch_info['ID']
diff --git a/pathflowai/model_training.py b/pathflowai/model_training.py
index 9405b02..26c6b0d 100644
--- a/pathflowai/model_training.py
+++ b/pathflowai/model_training.py
@@ -124,7 +124,8 @@ def train_model_(training_opts):
eta_min=training_opts['eta_min'],
T_mult=training_opts['T_mult']),
loss_fn=training_opts['loss_fn'],
- num_train_batches=num_train_batches)
+ num_train_batches=num_train_batches,
+ seg_out_class=training_opts['seg_out_class'])
if not training_opts['predict']:
@@ -164,11 +165,17 @@ def train_model_(training_opts):
exit()
y_pred = trainer.predict(dataloader)
print(ID,y_pred.shape)
- segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'],ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize'])
+ segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'],ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize'], output_probs=(training_opts['seg_out_class']>=0))
else:
extract_embedding=training_opts['extract_embedding']
if extract_embedding:
- trainer.model.fc = trainer.model.fc[0]
+ architecture=training_opts['architecture']
+ if hasattr(trainer.model,"fc"):
+ trainer.model.fc = trainer.model.fc[0]
+ elif hasattr(trainer.model,"output"):
+ trainer.model.output = trainer.model.output[0]
+ elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'):
+ trainer.model.classifier[6]=trainer.model.classifier[6][0]
trainer.bce=False
y_pred = trainer.predict(dataloaders['test'])
@@ -206,7 +213,7 @@ def train_model_(training_opts):
@click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True)
@click.option('-a', '--architecture', default='alexnet', help='Neural Network Architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn',
- 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50']+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
+ 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
@click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
@click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
@click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True)
@@ -238,7 +245,9 @@ def train_model_(training_opts):
@click.option('-cw', '--custom_weights', default='', help='Comma delimited custom weights', type=click.Path(exists=False), show_default=True)
@click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True)
@click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True)
-def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file):
+@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.', show_default=True)
+@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.', show_default=True)
+def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class):
"""Train and predict using model for regression and classification tasks."""
# add separate pretrain ability on separating cell types, then transfer learn
# add pretrain and efficient net, pretraining remove last layer while loading state dict
@@ -296,11 +305,12 @@ def train_model(segmentation,prediction,pos_annotation_class,other_annotations,s
external_test_db=external_test_db,
external_test_dir=external_test_dir,
prediction_basename=prediction_basename,
- save_val_predictions=True,
+ save_val_predictions=save_val_predictions,
custom_weights=custom_weights,
prediction_set=prediction_set,
user_transforms=dict(),
- dilation_jitter=dict())
+ dilation_jitter=dict(),
+ seg_out_class=seg_out_class)
training_opts = dict(normalization_file="normalization_parameters.pkl",
loss_fn='bce',
diff --git a/pathflowai/models.py b/pathflowai/models.py
index 6ac138c..e841a84 100644
--- a/pathflowai/models.py
+++ b/pathflowai/models.py
@@ -66,6 +66,9 @@ def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True
self.layers.append(nn.Sequential(self.output_layer,output_transform))
self.mlp = nn.Sequential(*self.layers)
+ def forward(self,x):
+ return self.mlp(x)
+
class FixedSegmentationModule(nn.Module):
"""Special model modification for segmentation tasks. Gets output from some of the models' forward loops.
@@ -139,6 +142,11 @@ def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden
else:
model = EfficientNet.from_name(architecture, override_params=dict(num_classes=num_classes))
print(model)
+ elif architecture.startswith('sqnxt'):
+ from pytorchcv.model_provider import get_model as ptcv_get_model
+ model = ptcv_get_model(architecture, pretrained=pretrain)
+ num_ftrs=int(128*int(architecture.split('_')[-1][1]))
+ model.output=MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
else:
#for pretrained on imagenet
model_names = [m for m in dir(models) if not m.startswith('__')]
@@ -161,7 +169,7 @@ def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden
#linear_layer = nn.Linear(num_ftrs, num_classes)
#torch.nn.init.xavier_uniform(linear_layer.weight)
model.fc = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else [])))
- elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenets'):
+ elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'):
num_ftrs = model.classifier[6].in_features
#linear_layer = nn.Linear(num_ftrs, num_classes)
#torch.nn.init.xavier_uniform(linear_layer.weight)
@@ -231,7 +239,7 @@ class ModelTrainer:
num_train_batches:int
Number of training batches for epoch.
"""
- def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None):
+ def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, seg_out_class=-1):
self.model = model
optimizers = {'adam':torch.optim.Adam, 'sgd':torch.optim.SGD}
@@ -240,7 +248,11 @@ def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opt
if 'name' not in list(optimizer_opts.keys()):
optimizer_opts['name']='adam'
self.optimizer = optimizers[optimizer_opts.pop('name')](self.model.parameters(),**optimizer_opts)
- self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O2')
+ if torch.cuda.is_available():
+ self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O2')
+ self.cuda=True
+ else:
+ self.cuda=False
self.scheduler = Scheduler(optimizer=self.optimizer,opts=scheduler_opts)
self.n_epoch = n_epoch
self.validation_dataloader = validation_dataloader
@@ -251,6 +263,7 @@ def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opt
self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn])
self.num_train_batches = num_train_batches
self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn])
+ self.seg_out_class=seg_out_class
def calc_loss(self, y_pred, y_true):
"""Calculates loss supplied in init statement and modified by reweighting.
@@ -348,8 +361,11 @@ def loss_backward(self,loss):
Torch loss calculated.
"""
- with amp.scale_loss(loss,self.optimizer) as scaled_loss:
- scaled_loss.backward()
+ if self.cuda:
+ with amp.scale_loss(loss,self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss.backward()
#@pysnooper.snoop('train_loop.log')
def train_loop(self, epoch, train_dataloader):
@@ -493,13 +509,17 @@ def test_loop(self, test_dataloader):
if torch.cuda.is_available():
X = X.cuda()
if test_dataloader.dataset.segmentation:
- prediction=self.model(X).detach().cpu().numpy().argmax(axis=1)
+ prediction=self.model(X).detach().cpu().numpy()
+ if self.seg_out_class>=0:
+ prediction=prediction[:,self.seg_out_class,...]
+ else:
+ prediction=prediction.argmax(axis=1).astype(int)
pred_size=prediction.shape#size()
#pred_mean=prediction[0].mean(axis=0)
- y_pred.append((prediction).astype(int))
+ y_pred.append(prediction)
else:
prediction=self.model(X)
- if (len(test_dataloader.dataset.targets)-1) or self.bce:
+ if self.loss_fn_name != 'mse' and ((len(test_dataloader.dataset.targets)-1) or self.bce):
prediction=self.sigmoid(prediction)
elif test_dataloader.dataset.classify_annotations:
prediction=F.softmax(prediction,dim=1)
diff --git a/pathflowai/utils.py b/pathflowai/utils.py
index cb80f2a..3f4001b 100644
--- a/pathflowai/utils.py
+++ b/pathflowai/utils.py
@@ -35,9 +35,8 @@
#import xarray as xr, sparse
import pickle
import copy
-
+import h5py
import nonechucks as nc
-
from nonechucks import SafeDataLoader as DataLoader
def load_sql_df(sql_file, patch_size):
@@ -224,6 +223,9 @@ def create_sparse_annotation_arrays(xml_file, img_size, annotations=[]):
interior_points_dict = {annotation:parse_coord_return_boxes(xml_file, annotation_name = annotation, return_coords = False) for annotation in annotations}#grab_interior_points(xml_file, img_size, annotations=annotations) if annotations else {}
return {annotation:interior_points_dict[annotation] for annotation in annotations}#sparse.COO.from_scipy_sparse((sps.coo_matrix(interior_points_dict[annotation],img_size, dtype=np.uint8) if interior_points_dict[annotation] not None else sps.coo_matrix(img_size, dtype=np.uint8)).tocsr()) for annotation in annotations} # [sps.coo_matrix(img_size, dtype=np.uint8)]+
+def load_image(svs_file):
+ return (npy2da(svs_file) if (svs_file.endswith('.npy') or svs_file.endswith('.h5')) else svs2dask_array(svs_file, tile_size=1000, overlap=0))
+
def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[]):
"""Load SVS-like image (including NPY), segmentation/classification annotations, generate dask array and dictionary of annotations.
@@ -246,7 +248,7 @@ def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[]):
Annotation masks.
"""
- arr = npy2da(svs_file) if svs_file.endswith('.npy') else svs2dask_array(svs_file, tile_size=1000, overlap=0)#load_image(svs_file)
+ arr = load_image(svs_file)#npy2da(svs_file) if (svs_file.endswith('.npy') or svs_file.endswith('.h5')) else svs2dask_array(svs_file, tile_size=1000, overlap=0)#load_image(svs_file)
img_size = arr.shape[:2]
masks = {}#{'purple': create_purple_mask(arr,img_size,sparse=False)}
if xml_file is not None:
@@ -261,7 +263,7 @@ def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[]):
#arr = da.concatenate([arr,masks.pop('purple')],axis=2)
return arr, masks#xr.Dataset.from_dict({k:v for k,v in list(data_arr.items())+list(purple_arr.items())+list(mask_arr.items())})#list(dict(image=data_arr,purple=purple_arr,annotations=mask_arr).items()))#arr, masks
-def save_dataset(arr, masks, out_zarr, out_pkl):
+def save_dataset(arr, masks, out_zarr, out_pkl, no_zarr):
"""Saves dask array image, dictionary of annotations to zarr and pickle respectively.
Parameters
@@ -275,13 +277,14 @@ def save_dataset(arr, masks, out_zarr, out_pkl):
out_pkl:str
Pickle output file.
"""
- arr.astype('uint8').to_zarr(out_zarr, overwrite=True)
+ if not no_zarr:
+ arr.astype('uint8').to_zarr(out_zarr, overwrite=True)
pickle.dump(masks,open(out_pkl,'wb'))
#dataset.to_netcdf(out_netcdf, compute=False)
#pickle.dump(dataset, open(out_pkl,'wb'), protocol=-1)
-def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotations=[], out_zarr='output_zarr.zarr', out_pkl='output.pkl'):
+def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotations=[], out_zarr='output_zarr.zarr', out_pkl='output.pkl',no_zarr=False):
"""Run preprocessing pipeline. Store image into zarr format, segmentations maintain as npy, and xml annotations as pickle.
Parameters
@@ -301,7 +304,7 @@ def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotatio
"""
#save_dataset(load_process_image(svs_file, xml_file, npy_mask, annotations), out_netcdf)
arr, masks = load_process_image(svs_file, xml_file, npy_mask, annotations)
- save_dataset(arr, masks,out_zarr, out_pkl)
+ save_dataset(arr, masks,out_zarr, out_pkl, no_zarr)
###################
@@ -380,7 +383,11 @@ def load_dataset(in_zarr, in_pkl):
Annotations dictionary.
"""
- return da.from_zarr(in_zarr), pickle.load(open(in_pkl,'rb'))#xr.open_dataset(in_netcdf)
+ if not os.path.exists(in_pkl):
+ annotations={'annotations':''}
+ else:
+ annotations=pickle.load(open(in_pkl,'rb'))
+ return (da.from_zarr(in_zarr) if in_zarr.endswith('.zarr') else load_image(in_zarr)), annotations#xr.open_dataset(in_netcdf)
def is_valid_patch(xs,ys,patch_size,purple_mask,intensity_threshold,threshold=0.5):
"""Deprecated, computes whether patch is valid."""
@@ -400,7 +407,7 @@ def fix_polygon(poly):
return poly
#@pysnooper.snoop("extract_patch.log")
-def extract_patch_information(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, tries=0, entire_image=False):
+def extract_patch_information(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, tries=0, entire_image=False, svs_file=''):
"""Final step of preprocessing pipeline. Break up image into patches, include if not background and of a certain intensity, find area of each annotation type in patch, spatial information, image ID and dump data to SQL table.
Parameters
@@ -450,7 +457,7 @@ def extract_patch_information(basename, input_dir='./', annotations=[], threshol
from functools import reduce
#from distributed import Client,LocalCluster
max_tries=4
- kargs=dict(basename=basename, input_dir=input_dir, annotations=annotations, threshold=threshold, patch_size=patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, tries=tries)
+ kargs=dict(basename=basename, input_dir=input_dir, annotations=annotations, threshold=threshold, patch_size=patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, tries=tries, svs_file=svs_file)
try:
#,
# 'distributed.scheduler.allowed-failures':20,
@@ -459,13 +466,15 @@ def extract_patch_information(basename, input_dir='./', annotations=[], threshol
#cluster.adapt(minimum=10, maximum=100)
#cluster = LocalCluster(threads_per_worker=1, n_workers=20, memory_limit="80G")
#client=Client()#Client(cluster)#processes=True)#cluster,
-
- arr, masks = load_dataset(join(input_dir,'{}.zarr'.format(basename)),join(input_dir,'{}_mask.pkl'.format(basename)))
+ in_zarr=join(input_dir,'{}.zarr'.format(basename))
+ in_zarr=(in_zarr if os.path.exists(in_zarr) else svs_file)
+ arr, masks = load_dataset(in_zarr,join(input_dir,'{}_mask.pkl'.format(basename)))
if 'annotations' in masks:
segmentation = True
-
#if generate_finetune_segmentation:
- segmentation_mask = npy2da(join(input_dir,'{}_mask.npy'.format(basename)) if not adj_mask else adj_mask)
+ mask=join(input_dir,'{}_mask.npy'.format(basename))
+ mask = (mask if os.path.exists(mask) else mask.replace('.npy','.npz'))
+ segmentation_mask = (npy2da(mask) if not adj_mask else adj_mask)
else:
segmentation = False
annotations=list(annotations)
@@ -531,7 +540,7 @@ def extract_patch_information(basename, input_dir='./', annotations=[], threshol
print(patch_info)
return patch_info
-def generate_patch_pipeline(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, out_db='patch_info.db', generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, entire_image=False):
+def generate_patch_pipeline(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, out_db='patch_info.db', generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, entire_image=False,svs_file=''):
"""Find area coverage of each annotation in each patch and store patch information into SQL db.
Parameters
@@ -561,7 +570,7 @@ def generate_patch_pipeline(basename, input_dir='./', annotations=[], threshold=
basic_preprocess:bool
Do not store patch level information.
"""
- patch_info = extract_patch_information(basename, input_dir, annotations, threshold, patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, entire_image=entire_image)
+ patch_info = extract_patch_information(basename, input_dir, annotations, threshold, patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, entire_image=entire_image,svs_file=svs_file)
conn = sqlite3.connect(out_db)
patch_info.to_sql(str(patch_size), con=conn, if_exists='append')
conn.close()
@@ -705,10 +714,11 @@ def npy2da(npy_file):
arr=da.from_array(np.load(npy_file, mmap_mode = 'r+'))
else:
npy_file=npy_file.replace('.npy','.npz')
- if npy_file.endswith('.npz'):
+	if npy_file.endswith('.npz'):
from scipy.sparse import load_npz
arr=da.from_array(load_npz(npy_file).toarray())
-
+ elif npy_file.endswith('.h5'):
+ arr=da.from_array(h5py.File(npy_file, 'r')['dataset'])
return arr
def grab_interior_points(xml_file, img_size, annotations=[]):
@@ -859,7 +869,7 @@ def fix_names(file_dir):
#######
#@pysnooper.snoop('seg2npy.log')
-def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_output, original_patch_size=500, resized_patch_size=256):
+def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_output, original_patch_size=500, resized_patch_size=256, output_probs=False):
"""Convert segmentation predictions from model to numpy masks.
Parameters
@@ -875,11 +885,12 @@ def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_outpu
"""
import cv2
import copy
+	# output_probs: when True, keep the float probability map instead of uint8 class labels
seg_map_shape=segmentation_map.shape[-2:]
original_seg_shape=copy.deepcopy(seg_map_shape)
if resized_patch_size!=original_patch_size:
seg_map_shape = [int(dim*resized_patch_size/original_patch_size) for dim in seg_map_shape]
- segmentation_map = np.zeros(tuple(seg_map_shape))
+ segmentation_map = np.zeros(tuple(seg_map_shape)).astype(float)
for i in range(patch_info.shape[0]):
patch_info_i = patch_info.iloc[i]
ID = patch_info_i['ID']
@@ -895,4 +906,6 @@ def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_outpu
if resized_patch_size!=original_patch_size:
segmentation_map=cv2.resize(segmentation_map.astype(float), dsize=original_seg_shape, interpolation=cv2.INTER_NEAREST)
os.makedirs(npy_output[:npy_output.rfind('/')],exist_ok=True)
- np.save(npy_output,segmentation_map.astype(np.uint8))
+ if not output_probs:
+ segmentation_map=segmentation_map.astype(np.uint8)
+ np.save(npy_output,segmentation_map)
diff --git a/setup.py b/setup.py
index 78cb253..9d12e57 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,15 @@
'networkx',
'shap',
'pyyaml',
- 'torch-encoding']
+ 'torch-encoding',
+ #'lightnet',
+ 'brambox',
+ 'blosc',
+ 'numcodecs',
+ 'zarr',
+ 'pytorchcv',
+ 'h5py'
+ ]
with open('README.md','r', encoding='utf-8') as f:
long_description = f.read()
@@ -56,7 +64,8 @@ def run(self):
author='Joshua Levy',
author_email='joshualevy44@berkeley.edu',
license='MIT',
- scripts=['bin/install_apex'],
+ scripts=['bin/install_apex',
+ 'bin/install_lightnet'],
#cmdclass={'install': CustomInstallCommand},
entry_points={
'console_scripts':['pathflowai-preprocess=pathflowai.cli_preprocessing:preprocessing',