diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 000000000..7a52f2ed3
Binary files /dev/null and b/.DS_Store differ
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 000000000..26d33521a
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/PatternAnalysis-2025.iml b/.idea/PatternAnalysis-2025.iml
new file mode 100644
index 000000000..d0876a78d
--- /dev/null
+++ b/.idea/PatternAnalysis-2025.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 000000000..0318588e8
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,12 @@
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 000000000..105ce2da2
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 000000000..9de286525
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 000000000..33e468890
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 000000000..94a25f7f4
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/recognition/.DS_Store b/recognition/.DS_Store
new file mode 100644
index 000000000..e93a6e752
Binary files /dev/null and b/recognition/.DS_Store differ
diff --git a/recognition/GANmodel_47508042/README.md b/recognition/GANmodel_47508042/README.md
new file mode 100644
index 000000000..e07e12ee6
--- /dev/null
+++ b/recognition/GANmodel_47508042/README.md
@@ -0,0 +1,43 @@
+# Pattern Recognition - Generative Model using VQVAE
+
+## Overview
+This project implements a Vector Quantised Variational Autoencoder (VQVAE) to learn a generative model from the HipMRI
+prostate cancer study dataset. The model reconstructs 2D MRI slices using compressed discrete latent codes, effectively
+learning meaningful anatomical representations of prostate structures. The objective is to produce clear, realistic
+reconstructions of MRI slices with an SSIM > 0.6 on the test set.
+
+## How it works
+VQVAE is a discrete latent variable model that combines the representational power of neural networks with the
+compression efficiency of vector quantization. It consists of three main components:
+1. Encoder: A convolutional neural network that maps input MRI slices to a continuous latent space.
+2. Vector Quantizer: This module discretizes the continuous latent representations into a finite set of learned embedding vectors (codebook).
+3. Decoder: Another convolutional neural network that reconstructs the MRI slices from the quantized latent codes.
+
+The model is trained end-to-end using a combination of a reconstruction loss (Mean Squared Error) and a commitment
+loss (VQ loss), as sketched below.
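+
+As a minimal, illustrative sketch of a single training step (the dummy tensor below stands in for a real MRI slice;
+the full loop lives in `train.py`):
+```python
+import torch
+import torch.nn.functional as F
+from modules import VQVAE
+
+model = VQVAE()                          # encoder -> vector quantizer -> decoder
+x = torch.rand(1, 1, 256, 256)           # dummy greyscale slice in [0, 1]
+x_recon, vq_loss, _ = model(x)
+recon_loss = F.mse_loss(x_recon, x)      # reconstruction loss (MSE)
+loss = recon_loss + vq_loss              # total objective, optimised end-to-end
+loss.backward()
+```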
+
+## Files
+ - 'dataset.py' - 'HipMRIDataset' loader
+ - 'modules.py' - Encoder, Decoder, VectorQuantizer, VQVAE model
+ - 'predict.py' - load best model and produce reconstructions
+ - 'train.py' - training loop, validation, checkpointing, plots
+ - 'utils.py' - helper functions for training and evaluation
+ - 'README.md' - documentation
+
+## Install
+```bash
+pip install torch torchvision matplotlib scikit-image numpy nibabel tqdm
+```
+
+## Training
+To train the model, run (replace the --data_root path with the location of your dataset):
+```bash
+python3 train.py --data_root "/Users/justin/Downloads/HipMRIDataset"
+```
+
+## Prediction
+To generate reconstructions using the best model, run:
+```bash
+python3 predict.py --checkpoint outputs/best_checkpoint.pth --data_root "/Users/justin/Downloads/HipMRIDataset"
+```
+
+## Author
+Justin Teng (47508042)
\ No newline at end of file
diff --git a/recognition/GANmodel_47508042/__pycache__/dataset.cpython-310.pyc b/recognition/GANmodel_47508042/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 000000000..3ffd77ddd
Binary files /dev/null and b/recognition/GANmodel_47508042/__pycache__/dataset.cpython-310.pyc differ
diff --git a/recognition/GANmodel_47508042/__pycache__/dataset.cpython-312.pyc b/recognition/GANmodel_47508042/__pycache__/dataset.cpython-312.pyc
new file mode 100644
index 000000000..bc811c459
Binary files /dev/null and b/recognition/GANmodel_47508042/__pycache__/dataset.cpython-312.pyc differ
diff --git a/recognition/GANmodel_47508042/__pycache__/modules.cpython-310.pyc b/recognition/GANmodel_47508042/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 000000000..6a7957bea
Binary files /dev/null and b/recognition/GANmodel_47508042/__pycache__/modules.cpython-310.pyc differ
diff --git a/recognition/GANmodel_47508042/__pycache__/modules.cpython-312.pyc b/recognition/GANmodel_47508042/__pycache__/modules.cpython-312.pyc
new file mode 100644
index 000000000..495e1efff
Binary files /dev/null and b/recognition/GANmodel_47508042/__pycache__/modules.cpython-312.pyc differ
diff --git a/recognition/GANmodel_47508042/__pycache__/utils.cpython-310.pyc b/recognition/GANmodel_47508042/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 000000000..05a7e2941
Binary files /dev/null and b/recognition/GANmodel_47508042/__pycache__/utils.cpython-310.pyc differ
diff --git a/recognition/GANmodel_47508042/__pycache__/utils.cpython-312.pyc b/recognition/GANmodel_47508042/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 000000000..578abb109
Binary files /dev/null and b/recognition/GANmodel_47508042/__pycache__/utils.cpython-312.pyc differ
diff --git a/recognition/GANmodel_47508042/dataset.py b/recognition/GANmodel_47508042/dataset.py
new file mode 100644
index 000000000..2d20488e4
--- /dev/null
+++ b/recognition/GANmodel_47508042/dataset.py
@@ -0,0 +1,85 @@
+import os
+import numpy as np
+import nibabel as nib
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+class HipMRIDataset(Dataset):
+    """
+    dataset for loading 3D hip MRI volumes from NIfTI files and slicing them into 2D images
+ """ + def __init__(self, root_dir, image_size=256, transform=None, max_slices_per_volume=None, recursive=True): + self.root_dir = root_dir # root directory containing NIfTI files + self.image_size = image_size # size to resize images to + self.max_slices = max_slices_per_volume # max number of slices to use per volume + self.recursive = recursive # whether to search subdirectories + + # finding files + self.files = self.collect_files() + if len(self.files) == 0: + raise ValueError(f"no files found in {root_dir}") + + # transform + if transform is None: + self.transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + ]) + else: + self.transform = transform + + # pre-index slices + self.index_map = [] + for f_idx, path in enumerate(self.files): + vol = nib.load(path).get_fdata() + depth = vol.shape[2] + n_slices = depth if self.max_slices is None else min(depth, self.max_slices) + for si in range(n_slices): + self.index_map.append((f_idx, si)) + + def collect_files(self): + """ + collect all NIfTI files in the root directory (and subdirectories if recursive) + """ + exists = ('.nii', '.nii.gz') + files = [] + if self.recursive: + # walk through subdirectories + for root, _, filenames in os.walk(self.root_dir): + for fn in filenames: + if fn.lower().endswith(exists): + files.append(os.path.join(root, fn)) + else: + # only current directory + for fn in os.listdir(self.root_dir): + if fn.lower().endswith(exists): + files.append(os.path.join(self.root_dir, fn)) + return sorted(files) + + def __len__(self): + """ + returns the total number of 2D slices across all volumes + """ + return len(self.index_map) + + def __getitem__(self, index): + """ + get the 2D slice at the given index + """ + file_index, slice_index = self.index_map[index] + path = self.files[file_index] + vol = nib.load(path).get_fdata() + min, max = vol.min(), vol.max() + if max - min < 1e-8: + norm = np.zeros_like(vol, dtype=np.float32) + else: + norm = (vol - min) / (max - min) + + slice2d = norm[:, :, slice_index] + # convert to uint8 image then to PIL so transform works reliably + img = Image.fromarray((slice2d * 255).astype(np.uint8)) + if self.transform: + img = self.transform(img) # tensor in [0,1}, shape (C,H,W) + return img + diff --git a/recognition/GANmodel_47508042/modules.py b/recognition/GANmodel_47508042/modules.py new file mode 100644 index 000000000..dfef14a85 --- /dev/null +++ b/recognition/GANmodel_47508042/modules.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Encoder(nn.Module): + """ + compresses the input image to latent representation + """ + def __init__(self, in_ch=1, hidden=128, z_channels=64): + super().__init__() + self.enc = nn.Sequential( + nn.Conv2d(in_ch, hidden//2, 4, 2, 1), # downsample 1 + nn.ReLU(True), + nn.Conv2d(hidden//2, hidden, 4, 2, 1), # downsample 2 + nn.ReLU(True), + nn.Conv2d(hidden, z_channels, 3, 1, 1), # project to latent space + ) + + def forward(self, x): + """ + forward pass through the encoder + """ + return self.enc(x) # returns latent representation z_e + +class Decoder(nn.Module): + """ + reconstructs the image from latent representation + """ + def __init__(self, out_ch=1, hidden=128, z_channels=64): + super().__init__() + self.dec = nn.Sequential( + nn.Conv2d(z_channels, hidden, 3, 1, 1), + nn.ReLU(True), + nn.ConvTranspose2d(hidden, hidden//2, 4, 2, 1), # upsample 1 + nn.ReLU(True), + nn.ConvTranspose2d(hidden//2, out_ch, 4, 2, 1), # upsample 2 + nn.Sigmoid() # output 
+        )
+
+    def forward(self, z):
+        """
+        forward pass through the decoder
+        """
+        return self.dec(z)  # returns reconstructed image
+
+class VectorQuantizer(nn.Module):
+    """
+    Vector Quantization layer for VQ-VAE
+    """
+    def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
+        super().__init__()
+        self.num_embeddings = num_embeddings    # number of discrete embeddings
+        self.embedding_dim = embedding_dim      # dimension of each embedding
+        self.beta = beta                        # commitment loss weight
+
+        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
+        nn.init.uniform_(self.embedding.weight, -1.0 / self.num_embeddings, 1.0 / self.num_embeddings)
+
+    def forward(self, z):
+        """
+        forward pass through the VQ layer
+        """
+        z_perm = z.permute(0, 2, 3, 1).contiguous()
+        flat_z = z_perm.view(-1, self.embedding_dim)
+
+        # compute squared distances to every codebook vector
+        distances = (
+            torch.sum(flat_z**2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight**2, dim=1)
+            - 2 * torch.matmul(flat_z, self.embedding.weight.t())
+        )
+
+        # encoding: one-hot assignment to the nearest codebook vector
+        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
+        encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings, device=z.device)
+        encodings.scatter_(1, encoding_indices, 1)
+
+        # quantize and reshape
+        quantized = torch.matmul(encodings, self.embedding.weight)
+        quantized = quantized.view(z_perm.shape)
+        quantized = quantized.permute(0, 3, 1, 2).contiguous()
+
+        # losses: codebook loss plus weighted commitment loss
+        e_latent_loss = F.mse_loss(quantized.detach(), z)
+        q_latent_loss = F.mse_loss(quantized, z.detach())
+        loss = q_latent_loss + self.beta * e_latent_loss
+
+        # straight-through estimator: pass gradients from the decoder back to the encoder output
+        quantized = z + (quantized - z).detach()
+
+        # reshape indices for visualisation
+        encoding_indices = encoding_indices.view(z.shape[0], z.shape[2], z.shape[3])
+        return quantized, loss, encoding_indices
+
+class VQVAE(nn.Module):
+    """
+    VQ-VAE model combining encoder, vector quantizer, and decoder
+    """
+    def __init__(self, in_ch=1, hidden=128, z_channels=64, num_embeddings=512, embedding_dim=64, beta=0.25):
+        super().__init__()
+        assert z_channels == embedding_dim
+        self.encoder = Encoder(in_ch, hidden, z_channels)
+        self.vq = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=embedding_dim, beta=beta)
+        self.decoder = Decoder(in_ch, hidden, z_channels)
+
+    def forward(self, x):
+        """
+        forward pass through the VQ-VAE
+        """
+        z_e = self.encoder(x)
+        quantized, vq_loss, indices = self.vq(z_e)
+        x_recon = self.decoder(quantized)
+        return x_recon, vq_loss, indices
diff --git a/recognition/GANmodel_47508042/predict.py b/recognition/GANmodel_47508042/predict.py
new file mode 100644
index 000000000..a318f9965
--- /dev/null
+++ b/recognition/GANmodel_47508042/predict.py
@@ -0,0 +1,40 @@
+import os
+import argparse
+import torch
+from torch.utils.data import DataLoader
+from dataset import HipMRIDataset
+from modules import VQVAE
+from utils import ensure_dir, save_pair_grid
+
+def visualise(checkpoint, data_root, output_dir, batch_size=8, max_slices=40):
+    """
+    Visualise reconstructions from a trained VQ-VAE model
+    """
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    ckpt = torch.load(checkpoint, map_location=device, weights_only=True)   # avoid shadowing the checkpoint path argument
+    model = VQVAE(in_ch=1, hidden=128, z_channels=64, num_embeddings=512, beta=0.25).to(device)
+    model.load_state_dict(ckpt['model_state'])
+    model.eval()
+
+    ds = HipMRIDataset(data_root, image_size=256, max_slices_per_volume=max_slices, recursive=True)
+    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2)
+
+    ensure_dir(output_dir)
+    with torch.no_grad():
+        batch = next(iter(loader))
+        x = batch.to(device)
+        x_recon, _, indices = model(x)
+        save_pair_grid(x.cpu(), x_recon.cpu(), os.path.join(output_dir, 'pred_sample.png'))
+        print("saved reconstructions to", os.path.join(output_dir, "pred_sample.png"))
+        print("indices shape:", indices.shape)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--checkpoint", type=str, required=True)
+    parser.add_argument("--data_root", type=str, required=True)
+    parser.add_argument("--output_dir", type=str, default="outputs/pred")
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--max_slices", type=int, default=40)
+    args = parser.parse_args()
+    # run visualisation
+    visualise(args.checkpoint, args.data_root, args.output_dir, args.batch_size, args.max_slices)
\ No newline at end of file
diff --git a/recognition/GANmodel_47508042/train.py b/recognition/GANmodel_47508042/train.py
new file mode 100644
index 000000000..e4d89fd5f
--- /dev/null
+++ b/recognition/GANmodel_47508042/train.py
@@ -0,0 +1,112 @@
+import os
+import argparse
+import torch
+from torch import optim
+from torch.utils.data import DataLoader, random_split
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from modules import VQVAE
+from dataset import HipMRIDataset
+from utils import ensure_dir, save_pair_grid, batch_ssim
+
+def train_loop(args):
+    """
+    main training loop for VQ-VAE on the hip MRI dataset
+    """
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print("device:", device)
+
+    ds = HipMRIDataset(args.data_root, image_size=args.image_size, max_slices_per_volume=args.max_slices, recursive=True)
+    # split into training and validation sets
+    n_val = max(1, int(len(ds) * args.val_frac))
+    n_train = len(ds) - n_val
+    train_set, val_set = random_split(ds, [n_train, n_val])
+
+    # data loaders
+    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
+    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+    model = VQVAE(in_ch=1, hidden=args.hidden, z_channels=args.z_ch, num_embeddings=args.num_embeddings, beta=args.beta).to(device)
+    optimiser = optim.Adam(model.parameters(), lr=args.lr)
+
+    # output directory
+    ensure_dir(args.output_dir)
+    best_ssim = 0.0
+
+    # training loop
+    for epoch in range(1, args.epochs + 1):
+        model.train()
+        train_loss = 0.0
+        for batch in tqdm(train_loader, desc=f'Epoch {epoch}/{args.epochs} [train]'):
+            x = batch.to(device)
+            x_recon, vq_loss, _ = model(x)
+            recon_loss = F.mse_loss(x_recon, x)
+            loss = recon_loss + vq_loss
+
+            optimiser.zero_grad()
+            loss.backward()
+            optimiser.step()
+
+            train_loss += loss.item() * x.size(0)
+
+        train_loss = train_loss / len(train_loader.dataset)
+
+        # validation
+        model.eval()
+        val_loss = 0.0
+        ssim_scores = []
+        with torch.no_grad():
+            for i, batch in enumerate(tqdm(val_loader, desc=f"Epoch {epoch}/{args.epochs} [val]")):
+                x = batch.to(device)
+                x_recon, vq_loss, _ = model(x)
+                recon_loss = F.mse_loss(x_recon, x)
+                loss = recon_loss + vq_loss
+                val_loss += loss.item() * x.size(0)
+                ssim_scores.append(batch_ssim(x, x_recon))
+
+                # save first validation batch reconstructions
+                if i == 0:
+                    save_pair_grid(x.cpu(), x_recon.cpu(),
+                                   os.path.join(args.output_dir, f"epoch_{epoch:03d}_recon.png"), n=min(4, x.size(0)))
+
+        val_loss = val_loss / len(val_loader.dataset)
+        mean_ssim = float(sum(ssim_scores) / len(ssim_scores))
+
print(f"Epoch {epoch} TrainLoss={train_loss:.4f} ValLoss={val_loss:.4f} ValSSIM={mean_ssim:.4f}") + + # checkpoint by SSIM + if mean_ssim > best_ssim: + best_ssim = mean_ssim + torch.save({ + 'epoch': epoch, + 'model_state': model.state_dict(), + 'optimizer_state': optimiser.state_dict(), + 'ssim': best_ssim + }, os.path.join(args.output_dir, "best_checkpoint.pth")) + print(f" Saved best model (SSIM={best_ssim:.4f})") + + # save periodic snapshot + if epoch % args.save_every == 0: + torch.save(model.state_dict(), os.path.join(args.output_dir, f"vqvae_epoch_{epoch}.pt")) + + print("Training complete. Best SSIM:", best_ssim) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_root", type=str, required=True) + parser.add_argument("--output_dir", type=str, default="outputs") + parser.add_argument("--image_size", type=int, default=256) + parser.add_argument("--epochs", type=int, default=40) + parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument("--z_ch", type=int, default=64) + parser.add_argument("--hidden", type=int, default=128) + parser.add_argument("--num_embeddings", type=int, default=512) + parser.add_argument("--beta", type=float, default=0.25) + parser.add_argument("--max_slices", type=int, default=40) + parser.add_argument("--val_frac", type=float, default=0.10) + parser.add_argument("--workers", type=int, default=2) + parser.add_argument("--save_every", type=int, default=10) + args = parser.parse_args() + ensure_dir(args.output_dir) + train_loop(args) diff --git a/recognition/GANmodel_47508042/utils.py b/recognition/GANmodel_47508042/utils.py new file mode 100644 index 000000000..519313a5d --- /dev/null +++ b/recognition/GANmodel_47508042/utils.py @@ -0,0 +1,45 @@ +import os +import torch +import matplotlib.pyplot as plt +import numpy as np +from skimage.metrics import structural_similarity as ssim + +def ensure_dir(path): + os.makedirs(path, exist_ok=True) + +def save_pair_grid(x_orig, x_rec, path, n=4): + """ + Save a grid with n original / reconstructed pairs (first n in batch). + x_*: tensors (B, C, H, W) in [0,1] + """ + ensure_dir(os.path.dirname(path)) + B = x_orig.shape[0] + n = min(n, B) + fig, axes = plt.subplots(n, 2, figsize=(6, 3*n)) + for i in range(n): + axes[i,0].imshow(x_orig[i].cpu().squeeze(), cmap='gray') + axes[i,0].set_title("Original") + axes[i,0].axis('off') + axes[i,1].imshow(x_rec[i].cpu().squeeze(), cmap='gray') + axes[i,1].set_title("Reconstruction") + axes[i,1].axis('off') + plt.tight_layout() + plt.savefig(path, dpi=150) + plt.close(fig) + +def batch_ssim(x, x_recon): + """ + Compute mean SSIM across a batch. Assumes x and x_recon are (B,1,H,W) in [0,1]. + """ + x_np = x.detach().cpu().numpy() + xr_np = x_recon.detach().cpu().numpy() + scores = [] + for i in range(x_np.shape[0]): + im1 = x_np[i,0] + im2 = xr_np[i,0] + try: + s = ssim(im1, im2, data_range=1.0) + except Exception: + s = 0.0 + scores.append(s) + return float(np.mean(scores))