-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrun_model.py
More file actions
80 lines (67 loc) · 2.36 KB
/
run_model.py
File metadata and controls
80 lines (67 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
''' Runner script to (locally) run training '''
import os, sys
import shutil
import tempfile
import matplotlib.pyplot as plt
import PIL
import torch
import numpy as np
from sklearn.metrics import classification_report
from monai.config import print_config
from monai.data import decollate_batch, DataLoader
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121, EfficientNetBN
from monai.transforms import (
Activations,
EnsureChannelFirst,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
)
from monai.utils import set_determinism
from mri_dataset import MRIDataset_bl
from glob import glob
import pandas as pd
from collections import Counter
from datetime import datetime
from helper_functions import load_adni
from helper_functions import get_images_from_dirs
# Root of the full ADNI baseline dataset. The hard-coded path is kept as the
# default; set the ADNI_ROOT_DIR environment variable to run elsewhere
# without editing this script.
root_dir = os.environ.get(
    'ADNI_ROOT_DIR',
    '/home/imber/Projects/PASSIAN/data/ADNI/aramis_preproc/CAPS_ADNI_bl',
)  # the full adni baseline dataset
print("Dataset directory is: ", root_dir)
# glob() returns matches in filesystem-dependent order, so sort them to make
# the choice of labels file reproducible. Fail with an explicit message
# instead of a bare IndexError when the directory holds no CSV at all.
_csv_candidates = sorted(glob(os.path.join(root_dir, '*.csv')))
if not _csv_candidates:
    raise FileNotFoundError(f"No labels CSV (*.csv) found in {root_dir}")
if len(_csv_candidates) > 1:
    print("Warning: several CSV files found; using", _csv_candidates[0])
labels_csv_path = _csv_candidates[0]
## Set deterministic
set_determinism(seed=317)
# load_adni (project helper) returns three values. image_files_list and
# image_class are used below as parallel per-sample sequences; the meaning of
# `cn` is not visible here — TODO confirm against helper_functions.
image_files_list, image_class, cn = load_adni(root_dir, labels_csv_path)
# TODO: move the split logic below into a helper function for tidiness.
## Prep train, val, test
# Run / split configuration.
num_class = 3
bs = 10  # batch size
max_epochs = 300  # AUC 0.7741 after 100 epochs
val_frac = 0.2  # share of samples held out for validation
test_frac = 0.1  # share of samples held out for testing
length = len(image_files_list)

# Shuffle sample positions with NumPy's global RNG. The resulting split is
# only reproducible if that RNG was seeded beforehand (presumably by the
# set_determinism() call earlier in this script — confirm).
indices = np.arange(length)
np.random.shuffle(indices)

# Shuffled positions are carved up as [ test | val | train ].
test_split = int(test_frac * length)
val_split = test_split + int(val_frac * length)
test_indices = indices[:test_split]
val_indices = indices[test_split:val_split]
train_indices = indices[val_split:]


def _take(seq, positions):
    """Return a list holding ``seq[p]`` for every ``p`` in ``positions``."""
    return [seq[p] for p in positions]


train_x, train_y = _take(image_files_list, train_indices), _take(image_class, train_indices)
val_x, val_y = _take(image_files_list, val_indices), _take(image_class, val_indices)
test_x, test_y = _take(image_files_list, test_indices), _take(image_class, test_indices)

print(f"Training count: {len(train_x)}, Validation count: {len(val_x)}, Test count: {len(test_x)}")
# Report (unique label values, per-label counts) for each subset.
for _caption, _labels in (
    ('Class balance on training population:', train_y),
    ('Class balance on validation population:', val_y),
    ('Class balance on test population:', test_y),
):
    print(_caption, torch.tensor(_labels).unique(return_counts=True))
## TODO: implement 5-fold crossval
## TODO: clean up imports