diff --git a/README.md b/README.md index 209f2a8..f385ef6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -

Welcome to PathFlowAI 👋

+

Welcome to PathFlowAI

Version @@ -10,11 +10,11 @@ ### 🏠 [Homepage](https://github.com/jlevy44/PathFlowAI) -MedRxiv Manuscript: https://www.medrxiv.org/content/10.1101/19003897v1 +Published in the Proceedings of the Pacific Symposium for Biocomputing 2020, Manuscript: https://psb.stanford.edu/psb-online/proceedings/psb20/Levy.pdf ## Install -First, install [openslide](https://openslide.org/download/). +First, install [openslide](https://openslide.org/download/). Note: may need to install libiconv and shapely using conda. Will update with more installation information, please submit issues as well. ```sh pip install pathflowai diff --git a/bin/install_apex b/bin/install_apex index af171c1..3dc2635 100644 --- a/bin/install_apex +++ b/bin/install_apex @@ -1,5 +1,7 @@ #!/bin/bash +export TMPDIR=$HOME/tmp +mkdir -p $TMPDIR rm -rf apex git clone https://github.com/NVIDIA/apex cd apex diff --git a/bin/install_lightnet b/bin/install_lightnet new file mode 100644 index 0000000..eb0cbb7 --- /dev/null +++ b/bin/install_lightnet @@ -0,0 +1,11 @@ +#!/bin/bash +#python -m lightnet download tiny-yolo +#python -m lightnet download yolo +rm -rf lightnet +git clone https://gitlab.com/EAVISE/lightnet.git +cd lightnet +pip install . +cd .. +rm -rf lightnet +#wget https://pjreddie.com/media/files/yolo.weights +#wget https://pjreddie.com/media/files/tiny-yolo.weights diff --git a/experimental/datasets.py b/experimental/datasets.py new file mode 100644 index 0000000..4e8e98b --- /dev/null +++ b/experimental/datasets.py @@ -0,0 +1,99 @@ +# +# Lightnet dataset that works with brambox annotations +# Copyright EAVISE +# +# https://eavise.gitlab.io/lightnet/_modules/lightnet/models/_dataset_brambox.html#BramboxDataset +# https://eavise.gitlab.io/brambox/notes/02-getting_started.html#Loading-data + +import os +import copy +import logging +from PIL import Image +import numpy as np +import lightnet.data as lnd +from pathflowai.utils import load_sql_df +import dask.array as da +from os.path import join + +try: + import brambox as bb +except ImportError: + bb = None + +__all__ = ['BramboxDataset'] +log = logging.getLogger(__name__) + +# ADD IMAGE ANNOTATION TRANSFORM +# ADD TRAIN VAL TEST INFO + +class BramboxPathFlowDataset(lnd.Dataset): + """ Dataset for any brambox annotations. + + Args: + annotations (dataframe): Dataframe containing brambox annotations + input_dimension (tuple): (width,height) tuple with default dimensions of the network + class_label_map (list): List of class_labels + identify (function, optional): Lambda/function to get image based of annotation filename or image id; Default **replace/add .png extension to filename/id** + img_transform (torchvision.transforms.Compose): Transforms to perform on the images + anno_transform (torchvision.transforms.Compose): Transforms to perform on the annotations + + Note: + This dataset opens images with the Pillow library + """ + def __init__(self, input_dir, patch_info_file, patch_size, annotations, input_dimension, class_label_map=None, identify=None, img_transform=None, anno_transform=None): + if bb is None: + raise ImportError('Brambox needs to be installed to use this dataset') + super().__init__(input_dimension) + + self.annos = annotations + self.annos['ignore']=0 + self.annos['class_label']=self.annos['class_label'].astype(int)#-1 + print(self.annos['class_label'].unique()) + #print(self.annos.shape) + self.keys = self.annos.image.cat.categories # stores unique patches + #print(self.keys) + self.img_tf = img_transform + self.anno_tf = anno_transform + self.patch_info=load_sql_df(patch_info_file, patch_size) + IDs=self.patch_info['ID'].unique() + self.slides = {slide:da.from_zarr(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs} + self.id = lambda k: k.split('/') + # experiment + #self.annos['x_top_left'], self.annos['y_top_left']=self.annos['y_top_left'], self.annos['x_top_left'] + self.annos['width'], self.annos['height']=self.annos['height'], self.annos['width'] + # Add class_ids + if class_label_map is None: + log.warning(f'No class_label_map given, generating it by sorting unique class labels from data alphabetically, which is not always deterministic behaviour') + class_label_map = list(np.sort(self.annos.class_label.unique())) + self.annos['class_id'] = self.annos.class_label.map(dict((l, i) for i, l in enumerate(class_label_map))) + + def __len__(self): + return len(self.keys) + + @lnd.Dataset.resize_getitem + def __getitem__(self, index): + """ Get transformed image and annotations based of the index of ``self.keys`` + + Args: + index (int): index of the ``self.keys`` list containing all the image identifiers of the dataset. + + Returns: + tuple: (transformed image, list of transformed brambox boxes) + """ + if index >= len(self): + raise IndexError(f'list index out of range [{index}/{len(self)-1}]') + + # Load + #print(self.keys[index]) + ID,x,y,patch_size=self.id(self.keys[index]) + x,y,patch_size=int(x),int(y),int(patch_size) + img = self.slides[ID][x:x+patch_size,y:y+patch_size].compute()#Image.open(self.id(self.keys[index])) + anno = bb.util.select_images(self.annos, [self.keys[index]]) + + # Transform + if self.img_tf is not None: + img = self.img_tf(img) + if self.anno_tf is not None: + anno = self.anno_tf(anno) + + return img, anno diff --git a/experimental/get_anchors.py b/experimental/get_anchors.py new file mode 100644 index 0000000..d98004f --- /dev/null +++ b/experimental/get_anchors.py @@ -0,0 +1,23 @@ +from sklearn.cluster import KMeans +import numpy as np, pandas as pd, brambox as bb +import pickle, argparse + +p=argparse.ArgumentParser() +p.add_argument('--patch_size',default=512,type=int) +p.add_argument('--n_anchors',default=20,type=int) +p.add_argument('--sample_p',default=1.,type=float) + +args=p.parse_args() +np.random.seed(42) +patch_size=args.patch_size +n_anchors=args.n_anchors +sample_p=args.sample_p +annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size) +annotations=bb.io.load('pandas',annotation_file) +if sample_p<1.: + annotations=annotations.sample(frac=sample_p) + +X=annotations[['x_top_left','y_top_left']].astype(float).values+(annotations['width']/2.).astype(float).values.reshape(-1,1) +km=KMeans(n_clusters=n_anchors,n_jobs=-1).fit(X) +anchors=km.cluster_centers_ +pickle.dump(anchors,open('anchors.pkl','wb')) diff --git a/experimental/get_bounding_boxes_from_seg_point_masks.py b/experimental/get_bounding_boxes_from_seg_point_masks.py new file mode 100644 index 0000000..fb0252f --- /dev/null +++ b/experimental/get_bounding_boxes_from_seg_point_masks.py @@ -0,0 +1,137 @@ +import brambox as bb +import os +from os.path import join, basename +from pathflowai.utils import load_sql_df, npy2da +import skimage +import dask, dask.array as da, pandas as pd, numpy as np +import argparse +from scipy import ndimage +from scipy.ndimage.measurements import label +import pickle +from dask.distributed import Client +from multiprocessing import Pool +from functools import reduce + +def get_box(l,prop): + c=[prop.centroid[1], prop.centroid[0]] + # l=rev_label[i+1] + width = prop.bbox[3] - prop.bbox[1] + 1 + height = prop.bbox[2] - prop.bbox[0] + 1 + wh=max(width,height) + # c = [ci-wh/2 for ci in c] + return [l]+c+[wh] + +def get_boxes(m,ID='test',x='x',y='y',patch_size='patchsize', num_classes=3): + lbls,n_lbl=label(m) + obj_labels={} + for i in range(1,num_classes+1): + obj_labels[i]=np.unique(lbls[m==i].flatten()) + rev_label={} + for k in obj_labels: + for i in obj_labels[k]: + rev_label[i]=k + rev_label={k:rev_label[k] for k in sorted(list(rev_label.keys()))} + objProps = list(skimage.measure.regionprops(lbls)) + #print(len(objProps),len(rev_label)) + boxes=dask.compute(*[dask.delayed(get_box)(rev_label[i],objProps[i-1]) for i in list(rev_label.keys())],scheduler='threading') # [get_box(rev_label[i],objProps[i-1]) for i in list(rev_label.keys())]# + #print(boxes) + boxes=pd.DataFrame(np.array(boxes).astype(int),columns=['class_label','x_top_left','y_top_left','width']) + + #boxes['class_label']=m[boxes[['x_top_left','y_top_left']].values.T.tolist()] + boxes['height']=boxes['width'] + boxes['image']='{}/{}/{}/{}'.format(ID,x,y,patch_size) + boxes=boxes[['image','class_label','x_top_left','y_top_left','width','height']] + boxes.loc[:,'x_top_left']=np.clip(boxes.loc[:,'x_top_left'],0,m.shape[1]) + boxes.loc[:,'y_top_left']=np.clip(boxes.loc[:,'y_top_left'],0,m.shape[0]) + + bbox_df=bb.util.new('annotation').drop(columns=['difficult','ignore','lost','occluded','truncated'])[['image','class_label','x_top_left','y_top_left','width','height']] + bbox_df=bbox_df.append(boxes) + #print(boxes) + return boxes + +if __name__=='__main__': + p=argparse.ArgumentParser() + p.add_argument('--num_classes',default=4,type=int) + p.add_argument('--patch_size',default=512,type=int) + p.add_argument('--n_workers',default=40,type=int) + p.add_argument('--p_sample',default=0.7,type=float) + p.add_argument('--input_dir',default='inputs',type=str) + p.add_argument('--patch_info_file',default='cell_info.db',type=str) + p.add_argument('--reference_mask',default='reference_mask.npy',type=str) + #c=Client() + # add mode to just use own extracted boudning boxes or from seg, maybe from histomicstk + + args=p.parse_args() + num_classes=args.num_classes + n_workers=args.n_workers + input_dir=args.input_dir + patch_info_file=args.patch_info_file + patch_size=args.patch_size + p_sample=args.p_sample + np.random.seed(42) + annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size) + reference_mask=args.reference_mask + if not os.path.exists('widths.pkl'): + m=np.load(reference_mask) + bbox_df=get_boxes(m) + official_widths=dict(bbox_df.groupby('class_label')['width'].mean()+2*bbox_df.groupby('class_label')['width'].std()) + pickle.dump(official_widths,open('widths.pkl','wb')) + else: + official_widths=pickle.load(open('widths.pkl','rb')) + + patch_info=load_sql_df(patch_info_file, patch_size) + IDs=patch_info['ID'].unique() + #slides = {slide:da.from_zarr(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs} + masks = {mask:npy2da(join(input_dir,'{}_mask.npy'.format(mask))) for mask in IDs} + + if p_sample < 1.: + patch_info=patch_info.sample(frac=p_sample) + + if not os.path.exists(annotation_file): + bbox_df=bb.util.new('annotation').drop(columns=['difficult','ignore','lost','occluded','truncated'])[['image','class_label','x_top_left','y_top_left','width','height']] + else: + bbox_df=bb.io.load('pandas',annotation_file) + + patch_info=patch_info[~np.isin(np.vectorize(lambda i: '/'.join(patch_info.iloc[i][['ID','x','y','patch_size']].astype(str).tolist()))(np.arange(patch_info.shape[0])),set(bbox_df.image.cat.categories))] + + print(patch_info.shape[0]) + + def get_boxes_point_seg(m,ID,x,y,patch_size2,num_classes): + bbox_dff=get_boxes(m,ID=ID,x=x,y=y,patch_size=patch_size2, num_classes=num_classes) + for i in official_widths.keys(): + bbox_dff.loc[bbox_dff['class_label']==i,'width']=int(official_widths[i]) + bbox_dff.loc[:,'x_top_left']=(bbox_dff.loc[:,'x_top_left']-bbox_dff['width']/2.).astype(int) + bbox_dff.loc[:,'y_top_left']=(bbox_dff.loc[:,'y_top_left']-bbox_dff['width']/2.).astype(int) + bbox_dff.loc[:,'x_top_left']=np.clip(bbox_dff.loc[:,'x_top_left'],0,m.shape[1]) + bbox_dff.loc[:,'y_top_left']=np.clip(bbox_dff.loc[:,'y_top_left'],0,m.shape[0]) + return bbox_dff + + def process_chunk(patch_info_sub): + patch_info_sub=patch_info_sub.reset_index(drop=True) + bbox_dfs=[] + + for i in range(patch_info_sub.shape[0]): + #print(i) + patch=patch_info_sub.iloc[i] + ID,x,y,patch_size2=patch[['ID','x','y','patch_size']].tolist() + m=masks[ID][x:x+patch_size2,y:y+patch_size2] + bbox_dff=get_boxes_point_seg(m,ID,x,y,patch_size2,num_classes)#dask.delayed(get_boxes_point_seg)(m,ID,x,y,patch_size2) + #print(bbox_dff) + bbox_dfs.append(bbox_dff) + return bbox_dfs + + patch_info_subs=np.array_split(patch_info,n_workers) + + p=Pool(n_workers) + + bbox_dfs=reduce(lambda x,y:x+y,p.map(process_chunk,patch_info_subs)) + + #bbox_dfs=dask.compute(*bbox_dfs,scheduler='processes') + + bbox_df=pd.concat([bbox_df]+bbox_dfs) + + + bbox_df.loc[:,'height']=bbox_df['width'] + + + bb.io.save(bbox_df,'pandas',annotation_file) diff --git a/experimental/get_counts.py b/experimental/get_counts.py new file mode 100644 index 0000000..d242324 --- /dev/null +++ b/experimental/get_counts.py @@ -0,0 +1,73 @@ +import brambox as bb +import os +from os.path import join, basename +from pathflowai.utils import load_sql_df, npy2da, df2sql +import skimage +import dask, dask.array as da, pandas as pd, numpy as np +import argparse +from scipy import ndimage +from scipy.ndimage.measurements import label +import pickle +from dask.distributed import Client +from multiprocessing import Pool +from functools import reduce + +def count_cells(m, num_classes=3): + lbls,n_lbl=label(m) + obj_labels=np.zeros(num_classes) + for i in range(1,num_classes+1): + obj_labels[i-1]=len(np.unique(lbls[m==i].flatten())) + return obj_labels + +if __name__=='__main__': + p=argparse.ArgumentParser() + p.add_argument('--num_classes',default=4,type=int) + p.add_argument('--patch_size',default=512,type=int) + p.add_argument('--n_workers',default=40,type=int) + p.add_argument('--p_sample',default=0.7,type=float) + p.add_argument('--input_dir',default='inputs',type=str) + p.add_argument('--patch_info_file',default='cell_info.db',type=str) + p.add_argument('--reference_mask',default='reference_mask.npy',type=str) + #c=Client() + # add mode to just use own extracted boudning boxes or from seg, maybe from histomicstk + + args=p.parse_args() + num_classes=args.num_classes + n_workers=args.n_workers + input_dir=args.input_dir + patch_info_file=args.patch_info_file + patch_size=args.patch_size + np.random.seed(42) + reference_mask=args.reference_mask + + patch_info=load_sql_df(patch_info_file, patch_size) + IDs=patch_info['ID'].unique() + #slides = {slide:da.from_zarr(join(input_dir,'{}.zarr'.format(slide))) for slide in IDs} + masks = {mask:npy2da(join(input_dir,'{}_mask.npy'.format(mask))) for mask in IDs} + + def process_chunk(patch_info_sub): + patch_info_sub=patch_info_sub.reset_index(drop=True) + counts=[] + for i in range(patch_info_sub.shape[0]): + #print(i) + patch=patch_info_sub.iloc[i] + ID,x,y,patch_size2=patch[['ID','x','y','patch_size']].tolist() + m=masks[ID][x:x+patch_size2,y:y+patch_size2] + counts.append(dask.delayed(count_cells)(m, num_classes=num_classes)) + + return dask.compute(*counts,scheduler='threading') + + patch_info_subs=np.array_split(patch_info,n_workers) + + p=Pool(n_workers) + + counts=reduce(lambda x,y:x+y,p.map(process_chunk,patch_info_subs)) + + #bbox_dfs=dask.compute(*bbox_dfs,scheduler='processes') + + counts=pd.DataFrame(np.vstack(counts)) + + patch_info=pd.concat([patch_info[['ID','x','y','patch_size','annotation']].reset_index(drop=True),counts.reset_index(drop=True)],axis=1).reset_index() + print(patch_info) + + df2sql(patch_info, 'counts_test.db', patch_size, mode='replace') diff --git a/experimental/object_detection.py b/experimental/object_detection.py new file mode 100644 index 0000000..34eba09 --- /dev/null +++ b/experimental/object_detection.py @@ -0,0 +1,191 @@ +import lightnet as ln +import torch +import numpy as np, pandas as pd +import matplotlib.pyplot as plt +import brambox as bb +import dask as da +from datasets import BramboxPathFlowDataset +import argparse, pickle +from sklearn.model_selection import train_test_split + +# Settings +ln.logger.setConsoleLevel('ERROR') # Only show error log messages +bb.logger.setConsoleLevel('ERROR') +# https://eavise.gitlab.io/lightnet/notes/02-B-engine.html + +p=argparse.ArgumentParser() +p.add_argument('--num_classes',default=4,type=int) +p.add_argument('--patch_size',default=512,type=int) +p.add_argument('--patch_info_file',default='cell_info.db',type=str) +p.add_argument('--input_dir',default='inputs',type=str) +p.add_argument('--sample_p',default=1.,type=float) +p.add_argument('--conf_thresh',default=0.01,type=float) +p.add_argument('--nms_thresh',default=0.5,type=float) + + +args=p.parse_args() +np.random.seed(42) +num_classes=args.num_classes+1 +patch_size=args.patch_size +batch_size=64 +patch_info_file=args.patch_info_file +input_dir=args.input_dir +sample_p=args.sample_p +conf_thresh=args.conf_thresh +nms_thresh=args.nms_thresh +anchors=pickle.load(open('anchors.pkl','rb')) + +annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size) +annotations=bb.io.load('pandas',annotation_file) + +if sample_p < 1.: + annotations=annotations.sample(frac=sample_p) + +annotations_dict={} +annotations_dict['train'],annotations_dict['test']=train_test_split(annotations) +annotations_dict['train'],annotations_dict['val']=train_test_split(annotations_dict['train']) + +model=ln.models.Yolo(num_classes=num_classes,anchors=anchors.tolist()) + +loss = ln.network.loss.RegionLoss( + num_classes=model.num_classes, + anchors=model.anchors, + stride=model.stride +) + +transforms = ln.data.transform.Compose([ln.data.transform.RandomHSV( + hue=1, + saturation=2, + value=2 +)]) + +# Create HyperParameters +params = ln.engine.HyperParameters( + network=model, + input_dimension = (patch_size,patch_size), + mini_batch_size=16, + batch_size=batch_size, + max_batches=80000 +) + +post = ln.data.transform.Compose([ + ln.data.transform.GetBoundingBoxes( + num_classes=params.network.num_classes, + anchors=params.network.anchors, + conf_thresh=conf_thresh, + ), + + ln.data.transform.NonMaxSuppression( + nms_thresh=nms_thresh + ), + + ln.data.transform.TensorToBrambox( + network_size=(patch_size,patch_size), + # class_label_map=class_label_map, + ) +]) + +datasets={k:BramboxPathFlowDataset(input_dir,patch_info_file, patch_size, annotations_dict[k], input_dimension=(patch_size,patch_size), class_label_map=None, identify=None, img_transform=None, anno_transform=None) for k in ['train','val','test']} +# transforms + +params.loss = ln.network.loss.RegionLoss(params.network.num_classes, params.network.anchors) +params.optim = torch.optim.SGD(params.network.parameters(), lr=1e-4) +params.scheduler = ln.engine.SchedulerCompositor( + # batch scheduler + (0, torch.optim.lr_scheduler.CosineAnnealingLR(params.optim,T_max=200)) + ) + +dls = {k:ln.data.DataLoader( + datasets[k], + batch_size = batch_size, + collate_fn = ln.data.brambox_collate # We want the data to be grouped as a list + ) for k in ['train','val','test']} + +params.val_loader=dls['val'] + +class CustomEngine(ln.engine.Engine): + def start(self): + """ Do whatever needs to be done before starting """ + self.params.to(self.device) # Casting parameters to a certain device + self.optim.zero_grad() # Make sure to start with no gradients + self.loss_acc = [] # Loss accumulator + + def process_batch(self, data): + """ Forward and backward pass """ + data, target = data # Unpack + #print(target) + data=data.permute(0,3,1,2).float() + if torch.cuda.is_available(): + data=data.cuda() + + #print(data) + + output = self.network(data) + #print(output) + + loss = self.loss(output, target) + + #print(loss) + loss.backward() + bbox=post(output) + print(bbox) + + self.loss_acc.append(loss.item()) + + @ln.engine.Engine.batch_end(100) # how to pass in validation dataloader + def val_loop(self): + with torch.no_grad(): + for i,data in enumerate(self.val_loader): + if i > 100: + break + data, target = data + data=data.permute(0,3,1,2).float() + if torch.cuda.is_available(): + data=data.cuda() + output = self.network(data) + #print(output) + loss = self.loss(output, target) + print(loss) + bbox=post(output) + print(bbox) + if not i: + bbox_final=[bbox] + else: + bbox_final.append(bbox) + + detections=pd.concat(bbox_final) + print(detections) + print(annotations_dict['val']) + pr=bb.stat.pr(detections, annotations_dict['val'], threshold=0.5) + auc=bb.stat.auc(pr) + print('VAL AUC={}'.format(auc)) + + @ln.engine.Engine.batch_end(300) + def save_model(self): + self.params.save(f'backup-{self.batch}.state.pt') + + def train_batch(self): + """ Weight update and logging """ + self.optim.step() + self.optim.zero_grad() + + batch_loss = sum(self.loss_acc) / len(self.loss_acc) + self.loss_acc = [] + self.log(f'Loss: {batch_loss}') + + def quit(self): + if self.batch >= self.max_batches: # Should probably save weights here + print('Reached end of training') + return True + return False + + + +# Create engine +engine = CustomEngine( + params, dls['train'], # Dataloader (None) is not valid + device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +) + +for i in range(10): + engine() diff --git a/pathflowai/cli_preprocessing.py b/pathflowai/cli_preprocessing.py index efedd04..04bb5f5 100644 --- a/pathflowai/cli_preprocessing.py +++ b/pathflowai/cli_preprocessing.py @@ -50,10 +50,11 @@ def output_if_exists(filename): @click.option('-nn', '--n_neighbors', default=5, help='If adjusting mask, number of neighbors connectivity to remove.', show_default=True) @click.option('-bp', '--basic_preprocess', is_flag=True, help='Basic preprocessing pipeline, annotation areas are not saved. Used for benchmarking tool against comparable pipelines', show_default=True) @click.option('-ei', '--entire_image', is_flag=True, help='Store entire image in central db rather than patches.', show_default=True) -def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patches,threshold,patch_size, intensity_threshold, generate_finetune_segmentation, target_segmentation_class, target_threshold, out_db, adjust_mask, n_neighbors, basic_preprocess, entire_image): +@click.option('-nz', '--no_zarr', is_flag=True, help='Don\'t save zarr format file.', show_default=True) +def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patches,threshold,patch_size, intensity_threshold, generate_finetune_segmentation, target_segmentation_class, target_threshold, out_db, adjust_mask, n_neighbors, basic_preprocess, entire_image, no_zarr): """Preprocessing pipeline that accomplishes 3 things. 1: storage into ZARR format, 2: optional mask adjustment, 3: storage of patch-level information into SQL DB""" - for ext in ['.npy','.svs','.tiff','.tif', '.vms', '.vmu', '.ndpi', '.scn', '.mrxs', '.svslide', '.bif', '.jpeg', '.png']: + for ext in ['.npy','.svs','.tiff','.tif', '.vms', '.vmu', '.ndpi', '.scn', '.mrxs', '.svslide', '.bif', '.jpeg', '.png', '.h5']: svs_file = output_if_exists(join(input_dir,'{}{}'.format(basename,ext))) if svs_file != None: break @@ -75,13 +76,15 @@ def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patche npy_mask=npy_mask, annotations=annotations, out_zarr=out_zarr, - out_pkl=out_pkl) + out_pkl=out_pkl, + no_zarr=no_zarr) if npy_mask==None and xml_file==None: + print('Generating Zero Mask') npy_mask=join(input_dir,'{}_mask.npz'.format(basename)) target_segmentation_class=1 generate_finetune_segmentation=True - create_zero_mask(npy_mask,out_zarr,out_pkl) + create_zero_mask(npy_mask,out_zarr if not no_zarr else svs_file,out_pkl) preprocess_point = time.time() @@ -93,7 +96,7 @@ def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patche adj_npy=join(adj_dir,os.path.basename(npy_mask)) os.makedirs(adj_dir,exist_ok=True) if not os.path.exists(adj_npy): - adjust_mask(npy_mask, out_zarr, adj_npy, n_neighbors) + adjust_mask(npy_mask, out_zarr if not no_zarr else svs_file, adj_npy, n_neighbors) adjust_point = time.time() print('Adjust took {}'.format(adjust_point-preprocess_point)) @@ -111,7 +114,8 @@ def preprocess_pipeline(img2npy,basename,input_dir,annotations,preprocess,patche target_threshold=target_threshold, adj_mask=adj_npy, basic_preprocess=basic_preprocess, - entire_image=entire_image) + entire_image=entire_image, + svs_file=svs_file) patch_point = time.time() print('Patches took {}'.format(patch_point-adjust_point)) diff --git a/pathflowai/cli_visualizations.py b/pathflowai/cli_visualizations.py index 91cf3ff..2ad4408 100644 --- a/pathflowai/cli_visualizations.py +++ b/pathflowai/cli_visualizations.py @@ -125,7 +125,7 @@ def plot_embeddings(embeddings_file,plotly_output_file, annotations, remove_back """Perform UMAP embeddings of patches and plot using plotly.""" import torch from umap import UMAP - from visualize import PlotlyPlot + from pathflowai.visualize import PlotlyPlot import pandas as pd, numpy as np embeddings_dict=torch.load(embeddings_file) embeddings=embeddings_dict['embeddings'] diff --git a/pathflowai/datasets.py b/pathflowai/datasets.py index fcf6dd7..9be82ee 100644 --- a/pathflowai/datasets.py +++ b/pathflowai/datasets.py @@ -372,6 +372,8 @@ def __init__(self,dataset_df, set, patch_info_file, transformers, input_dir, tar self.classify_annotations=classify_annotations print(self.targets) self.dilation_jitter=DilationJitter(dilation_jitter,self.segmentation,(original_set=='train')) + if not self.targets: + self.targets = [pos_annotation_class]+list(other_annotations) def concat(self, other_dataset): """Concatenate this dataset with others. Updates its own internal attributes. @@ -528,7 +530,7 @@ def update_dataset(self, input_dir, new_db, prediction_basename=[]): self.segmentation_maps = {slide:npy2da(join(self.input_dir,'{}_mask.npy'.format(slide))) for slide in IDs} self.length = self.patch_info.shape[0] - @pysnooper.snoop('get_item.log') + #@pysnooper.snoop('get_item.log') def __getitem__(self, i): patch_info = self.patch_info.iloc[i] ID = patch_info['ID'] diff --git a/pathflowai/model_training.py b/pathflowai/model_training.py index 9405b02..26c6b0d 100644 --- a/pathflowai/model_training.py +++ b/pathflowai/model_training.py @@ -124,7 +124,8 @@ def train_model_(training_opts): eta_min=training_opts['eta_min'], T_mult=training_opts['T_mult']), loss_fn=training_opts['loss_fn'], - num_train_batches=num_train_batches) + num_train_batches=num_train_batches, + seg_out_class=training_opts['seg_out_class']) if not training_opts['predict']: @@ -164,11 +165,17 @@ def train_model_(training_opts): exit() y_pred = trainer.predict(dataloader) print(ID,y_pred.shape) - segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'],ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize']) + segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'],ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize'], output_probs=(training_opts['seg_out_class']>=0)) else: extract_embedding=training_opts['extract_embedding'] if extract_embedding: - trainer.model.fc = trainer.model.fc[0] + architecture=training_opts['architecture'] + if hasattr(trainer.model,"fc"): + trainer.model.fc = trainer.model.fc[0] + elif hasattr(trainer.model,"output"): + trainer.model.output = trainer.model.output[0] + elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'): + trainer.model.classifier[6]=trainer.model.classifier[6][0] trainer.bce=False y_pred = trainer.predict(dataloaders['test']) @@ -206,7 +213,7 @@ def train_model_(training_opts): @click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True) @click.option('-a', '--architecture', default='alexnet', help='Neural Network Architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn', - 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50']+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True) + 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True) @click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True) @click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True) @click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True) @@ -238,7 +245,9 @@ def train_model_(training_opts): @click.option('-cw', '--custom_weights', default='', help='Comma delimited custom weights', type=click.Path(exists=False), show_default=True) @click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True) @click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True) -def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file): +@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.', show_default=True) +@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.', show_default=True) +def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class): """Train and predict using model for regression and classification tasks.""" # add separate pretrain ability on separating cell types, then transfer learn # add pretrain and efficient net, pretraining remove last layer while loading state dict @@ -296,11 +305,12 @@ def train_model(segmentation,prediction,pos_annotation_class,other_annotations,s external_test_db=external_test_db, external_test_dir=external_test_dir, prediction_basename=prediction_basename, - save_val_predictions=True, + save_val_predictions=save_val_predictions, custom_weights=custom_weights, prediction_set=prediction_set, user_transforms=dict(), - dilation_jitter=dict()) + dilation_jitter=dict(), + seg_out_class=seg_out_class) training_opts = dict(normalization_file="normalization_parameters.pkl", loss_fn='bce', diff --git a/pathflowai/models.py b/pathflowai/models.py index 6ac138c..e841a84 100644 --- a/pathflowai/models.py +++ b/pathflowai/models.py @@ -66,6 +66,9 @@ def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True self.layers.append(nn.Sequential(self.output_layer,output_transform)) self.mlp = nn.Sequential(*self.layers) + def forward(self,x): + return self.mlp(x) + class FixedSegmentationModule(nn.Module): """Special model modification for segmentation tasks. Gets output from some of the models' forward loops. @@ -139,6 +142,11 @@ def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden else: model = EfficientNet.from_name(architecture, override_params=dict(num_classes=num_classes)) print(model) + elif architecture.startswith('sqnxt'): + from pytorchcv.model_provider import get_model as ptcv_get_model + model = ptcv_get_model(architecture, pretrained=pretrain) + num_ftrs=int(128*int(architecture.split('_')[-1][1])) + model.output=MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp else: #for pretrained on imagenet model_names = [m for m in dir(models) if not m.startswith('__')] @@ -161,7 +169,7 @@ def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden #linear_layer = nn.Linear(num_ftrs, num_classes) #torch.nn.init.xavier_uniform(linear_layer.weight) model.fc = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else []))) - elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenets'): + elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'): num_ftrs = model.classifier[6].in_features #linear_layer = nn.Linear(num_ftrs, num_classes) #torch.nn.init.xavier_uniform(linear_layer.weight) @@ -231,7 +239,7 @@ class ModelTrainer: num_train_batches:int Number of training batches for epoch. """ - def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None): + def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, seg_out_class=-1): self.model = model optimizers = {'adam':torch.optim.Adam, 'sgd':torch.optim.SGD} @@ -240,7 +248,11 @@ def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opt if 'name' not in list(optimizer_opts.keys()): optimizer_opts['name']='adam' self.optimizer = optimizers[optimizer_opts.pop('name')](self.model.parameters(),**optimizer_opts) - self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O2') + if torch.cuda.is_available(): + self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O2') + self.cuda=True + else: + self.cuda=False self.scheduler = Scheduler(optimizer=self.optimizer,opts=scheduler_opts) self.n_epoch = n_epoch self.validation_dataloader = validation_dataloader @@ -251,6 +263,7 @@ def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opt self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn]) self.num_train_batches = num_train_batches self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn]) + self.seg_out_class=seg_out_class def calc_loss(self, y_pred, y_true): """Calculates loss supplied in init statement and modified by reweighting. @@ -348,8 +361,11 @@ def loss_backward(self,loss): Torch loss calculated. """ - with amp.scale_loss(loss,self.optimizer) as scaled_loss: - scaled_loss.backward() + if self.cuda: + with amp.scale_loss(loss,self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() #@pysnooper.snoop('train_loop.log') def train_loop(self, epoch, train_dataloader): @@ -493,13 +509,17 @@ def test_loop(self, test_dataloader): if torch.cuda.is_available(): X = X.cuda() if test_dataloader.dataset.segmentation: - prediction=self.model(X).detach().cpu().numpy().argmax(axis=1) + prediction=self.model(X).detach().cpu().numpy() + if self.seg_out_class>=0: + prediction=prediction[:,self.seg_out_class,...] + else: + prediction=prediction.argmax(axis=1).astype(int) pred_size=prediction.shape#size() #pred_mean=prediction[0].mean(axis=0) - y_pred.append((prediction).astype(int)) + y_pred.append(prediction) else: prediction=self.model(X) - if (len(test_dataloader.dataset.targets)-1) or self.bce: + if self.loss_fn_name != 'mse' and ((len(test_dataloader.dataset.targets)-1) or self.bce): prediction=self.sigmoid(prediction) elif test_dataloader.dataset.classify_annotations: prediction=F.softmax(prediction,dim=1) diff --git a/pathflowai/utils.py b/pathflowai/utils.py index cb80f2a..3f4001b 100644 --- a/pathflowai/utils.py +++ b/pathflowai/utils.py @@ -35,9 +35,8 @@ #import xarray as xr, sparse import pickle import copy - +import h5py import nonechucks as nc - from nonechucks import SafeDataLoader as DataLoader def load_sql_df(sql_file, patch_size): @@ -224,6 +223,9 @@ def create_sparse_annotation_arrays(xml_file, img_size, annotations=[]): interior_points_dict = {annotation:parse_coord_return_boxes(xml_file, annotation_name = annotation, return_coords = False) for annotation in annotations}#grab_interior_points(xml_file, img_size, annotations=annotations) if annotations else {} return {annotation:interior_points_dict[annotation] for annotation in annotations}#sparse.COO.from_scipy_sparse((sps.coo_matrix(interior_points_dict[annotation],img_size, dtype=np.uint8) if interior_points_dict[annotation] not None else sps.coo_matrix(img_size, dtype=np.uint8)).tocsr()) for annotation in annotations} # [sps.coo_matrix(img_size, dtype=np.uint8)]+ +def load_image(svs_file): + return (npy2da(svs_file) if (svs_file.endswith('.npy') or svs_file.endswith('.h5')) else svs2dask_array(svs_file, tile_size=1000, overlap=0)) + def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[]): """Load SVS-like image (including NPY), segmentation/classification annotations, generate dask array and dictionary of annotations. @@ -246,7 +248,7 @@ def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[]): Annotation masks. """ - arr = npy2da(svs_file) if svs_file.endswith('.npy') else svs2dask_array(svs_file, tile_size=1000, overlap=0)#load_image(svs_file) + arr = load_image(svs_file)#npy2da(svs_file) if (svs_file.endswith('.npy') or svs_file.endswith('.h5')) else svs2dask_array(svs_file, tile_size=1000, overlap=0)#load_image(svs_file) img_size = arr.shape[:2] masks = {}#{'purple': create_purple_mask(arr,img_size,sparse=False)} if xml_file is not None: @@ -261,7 +263,7 @@ def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[]): #arr = da.concatenate([arr,masks.pop('purple')],axis=2) return arr, masks#xr.Dataset.from_dict({k:v for k,v in list(data_arr.items())+list(purple_arr.items())+list(mask_arr.items())})#list(dict(image=data_arr,purple=purple_arr,annotations=mask_arr).items()))#arr, masks -def save_dataset(arr, masks, out_zarr, out_pkl): +def save_dataset(arr, masks, out_zarr, out_pkl, no_zarr): """Saves dask array image, dictionary of annotations to zarr and pickle respectively. Parameters @@ -275,13 +277,14 @@ def save_dataset(arr, masks, out_zarr, out_pkl): out_pkl:str Pickle output file. """ - arr.astype('uint8').to_zarr(out_zarr, overwrite=True) + if not no_zarr: + arr.astype('uint8').to_zarr(out_zarr, overwrite=True) pickle.dump(masks,open(out_pkl,'wb')) #dataset.to_netcdf(out_netcdf, compute=False) #pickle.dump(dataset, open(out_pkl,'wb'), protocol=-1) -def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotations=[], out_zarr='output_zarr.zarr', out_pkl='output.pkl'): +def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotations=[], out_zarr='output_zarr.zarr', out_pkl='output.pkl',no_zarr=False): """Run preprocessing pipeline. Store image into zarr format, segmentations maintain as npy, and xml annotations as pickle. Parameters @@ -301,7 +304,7 @@ def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotatio """ #save_dataset(load_process_image(svs_file, xml_file, npy_mask, annotations), out_netcdf) arr, masks = load_process_image(svs_file, xml_file, npy_mask, annotations) - save_dataset(arr, masks,out_zarr, out_pkl) + save_dataset(arr, masks,out_zarr, out_pkl, no_zarr) ################### @@ -380,7 +383,11 @@ def load_dataset(in_zarr, in_pkl): Annotations dictionary. """ - return da.from_zarr(in_zarr), pickle.load(open(in_pkl,'rb'))#xr.open_dataset(in_netcdf) + if not os.path.exists(in_pkl): + annotations={'annotations':''} + else: + annotations=pickle.load(open(in_pkl,'rb')) + return (da.from_zarr(in_zarr) if in_zarr.endswith('.zarr') else load_image(in_zarr)), annotations#xr.open_dataset(in_netcdf) def is_valid_patch(xs,ys,patch_size,purple_mask,intensity_threshold,threshold=0.5): """Deprecated, computes whether patch is valid.""" @@ -400,7 +407,7 @@ def fix_polygon(poly): return poly #@pysnooper.snoop("extract_patch.log") -def extract_patch_information(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, tries=0, entire_image=False): +def extract_patch_information(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, tries=0, entire_image=False, svs_file=''): """Final step of preprocessing pipeline. Break up image into patches, include if not background and of a certain intensity, find area of each annotation type in patch, spatial information, image ID and dump data to SQL table. Parameters @@ -450,7 +457,7 @@ def extract_patch_information(basename, input_dir='./', annotations=[], threshol from functools import reduce #from distributed import Client,LocalCluster max_tries=4 - kargs=dict(basename=basename, input_dir=input_dir, annotations=annotations, threshold=threshold, patch_size=patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, tries=tries) + kargs=dict(basename=basename, input_dir=input_dir, annotations=annotations, threshold=threshold, patch_size=patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, tries=tries, svs_file=svs_file) try: #, # 'distributed.scheduler.allowed-failures':20, @@ -459,13 +466,15 @@ def extract_patch_information(basename, input_dir='./', annotations=[], threshol #cluster.adapt(minimum=10, maximum=100) #cluster = LocalCluster(threads_per_worker=1, n_workers=20, memory_limit="80G") #client=Client()#Client(cluster)#processes=True)#cluster, - - arr, masks = load_dataset(join(input_dir,'{}.zarr'.format(basename)),join(input_dir,'{}_mask.pkl'.format(basename))) + in_zarr=join(input_dir,'{}.zarr'.format(basename)) + in_zarr=(in_zarr if os.path.exists(in_zarr) else svs_file) + arr, masks = load_dataset(in_zarr,join(input_dir,'{}_mask.pkl'.format(basename))) if 'annotations' in masks: segmentation = True - #if generate_finetune_segmentation: - segmentation_mask = npy2da(join(input_dir,'{}_mask.npy'.format(basename)) if not adj_mask else adj_mask) + mask=join(input_dir,'{}_mask.npy'.format(basename)) + mask = (mask if os.path.exists(mask) else mask.replace('.npy','.npz')) + segmentation_mask = (npy2da(mask) if not adj_mask else adj_mask) else: segmentation = False annotations=list(annotations) @@ -531,7 +540,7 @@ def extract_patch_information(basename, input_dir='./', annotations=[], threshol print(patch_info) return patch_info -def generate_patch_pipeline(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, out_db='patch_info.db', generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, entire_image=False): +def generate_patch_pipeline(basename, input_dir='./', annotations=[], threshold=0.5, patch_size=224, out_db='patch_info.db', generate_finetune_segmentation=False, target_class=0, intensity_threshold=100., target_threshold=0., adj_mask='', basic_preprocess=False, entire_image=False,svs_file=''): """Find area coverage of each annotation in each patch and store patch information into SQL db. Parameters @@ -561,7 +570,7 @@ def generate_patch_pipeline(basename, input_dir='./', annotations=[], threshold= basic_preprocess:bool Do not store patch level information. """ - patch_info = extract_patch_information(basename, input_dir, annotations, threshold, patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, entire_image=entire_image) + patch_info = extract_patch_information(basename, input_dir, annotations, threshold, patch_size, generate_finetune_segmentation=generate_finetune_segmentation, target_class=target_class, intensity_threshold=intensity_threshold, target_threshold=target_threshold, adj_mask=adj_mask, basic_preprocess=basic_preprocess, entire_image=entire_image,svs_file=svs_file) conn = sqlite3.connect(out_db) patch_info.to_sql(str(patch_size), con=conn, if_exists='append') conn.close() @@ -705,10 +714,11 @@ def npy2da(npy_file): arr=da.from_array(np.load(npy_file, mmap_mode = 'r+')) else: npy_file=npy_file.replace('.npy','.npz') - if npy_file.endswith('.npz'): + elif npy_file.endswith('.npz'): from scipy.sparse import load_npz arr=da.from_array(load_npz(npy_file).toarray()) - + elif npy_file.endswith('.h5'): + arr=da.from_array(h5py.File(npy_file, 'r')['dataset']) return arr def grab_interior_points(xml_file, img_size, annotations=[]): @@ -859,7 +869,7 @@ def fix_names(file_dir): ####### #@pysnooper.snoop('seg2npy.log') -def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_output, original_patch_size=500, resized_patch_size=256): +def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_output, original_patch_size=500, resized_patch_size=256, output_probs=False): """Convert segmentation predictions from model to numpy masks. Parameters @@ -875,11 +885,12 @@ def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_outpu """ import cv2 import copy + print(output_probs) seg_map_shape=segmentation_map.shape[-2:] original_seg_shape=copy.deepcopy(seg_map_shape) if resized_patch_size!=original_patch_size: seg_map_shape = [int(dim*resized_patch_size/original_patch_size) for dim in seg_map_shape] - segmentation_map = np.zeros(tuple(seg_map_shape)) + segmentation_map = np.zeros(tuple(seg_map_shape)).astype(float) for i in range(patch_info.shape[0]): patch_info_i = patch_info.iloc[i] ID = patch_info_i['ID'] @@ -895,4 +906,6 @@ def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_outpu if resized_patch_size!=original_patch_size: segmentation_map=cv2.resize(segmentation_map.astype(float), dsize=original_seg_shape, interpolation=cv2.INTER_NEAREST) os.makedirs(npy_output[:npy_output.rfind('/')],exist_ok=True) - np.save(npy_output,segmentation_map.astype(np.uint8)) + if not output_probs: + segmentation_map=segmentation_map.astype(np.uint8) + np.save(npy_output,segmentation_map) diff --git a/setup.py b/setup.py index 78cb253..9d12e57 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,15 @@ 'networkx', 'shap', 'pyyaml', - 'torch-encoding'] + 'torch-encoding', + #'lightnet', + 'brambox', + 'blosc', + 'numcodecs', + 'zarr', + 'pytorchcv', + 'h5py' + ] with open('README.md','r', encoding='utf-8') as f: long_description = f.read() @@ -56,7 +64,8 @@ def run(self): author='Joshua Levy', author_email='joshualevy44@berkeley.edu', license='MIT', - scripts=['bin/install_apex'], + scripts=['bin/install_apex', + 'bin/install_lightnet'], #cmdclass={'install': CustomInstallCommand}, entry_points={ 'console_scripts':['pathflowai-preprocess=pathflowai.cli_preprocessing:preprocessing',