Binary file added .DS_Store
3 changes: 3 additions & 0 deletions .idea/.gitignore

8 changes: 8 additions & 0 deletions .idea/PatternAnalysis-2025.iml

12 changes: 12 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

7 changes: 7 additions & 0 deletions .idea/misc.xml

8 changes: 8 additions & 0 deletions .idea/modules.xml

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Binary file added recognition/.DS_Store
43 changes: 43 additions & 0 deletions recognition/GANmodel_47508042/README.md
@@ -0,0 +1,43 @@
# Pattern Recognition - Generative Model using VQVAE

## Overview
This project implements a Vector Quantised Variational Autoencoder (VQVAE) to learn a generative model from the HipMRI
prostate cancer study dataset. The model reconstructs 2D MRI slices from compressed discrete latent codes, learning
meaningful anatomical representations of prostate structures. The objective is to produce clear, realistic
reconstructions of MRI slices with an SSIM > 0.6 on the test set.
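
A minimal sketch of how per-slice SSIM can be measured with scikit-image (already in the install list below); the arrays `original` and `reconstruction` are illustrative names standing in for 2D float slices in [0, 1], not part of this repository:

```python
import numpy as np
from skimage.metrics import structural_similarity as ssim

# hypothetical arrays: one ground-truth slice and its model reconstruction, both in [0, 1]
original = np.random.rand(256, 256).astype(np.float32)
reconstruction = np.random.rand(256, 256).astype(np.float32)

score = ssim(original, reconstruction, data_range=1.0)
print(f"SSIM: {score:.3f}")  # target for this project is > 0.6 on test slices
```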

## How it works
VQVAE is a discrete latent variable model that combines the representational power of neural networks with the
compression efficiency of vector quantization. It consists of three main components:
1. Encoder: A convolutional neural network that maps input MRI slices to a continuous latent space.
2. Vector Quantizer: This module discretizes the continuous latent representations into a finite set of learned embedding vectors (codebook).
3. Decoder: Another convolutional neural network that reconstructs the MRI slices from the quantized latent codes.
The model is trained end-to-end using a combination of reconstruction loss (Mean Squared Error) and a commitment loss (VQ loss); the full objective is written out below.
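
For reference, the standard VQ-VAE objective (van den Oord et al., 2017) that these terms correspond to is, with sg[.] the stop-gradient operator and beta the commitment weight (0.25 in this code):

```math
\mathcal{L} = \underbrace{\lVert x - \hat{x} \rVert_2^2}_{\text{reconstruction}} + \underbrace{\lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2}_{\text{codebook}} + \beta \, \underbrace{\lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2}_{\text{commitment}}
```

In `modules.py` these appear as the MSE reconstruction loss, `q_latent_loss`, and `beta * e_latent_loss` respectively.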

## Files
- `dataset.py` - `HipMRIDataset` loader that slices NIfTI volumes into 2D images
- `modules.py` - Encoder, Decoder, VectorQuantizer, and VQVAE model
- `predict.py` - load the best checkpoint and produce reconstructions
- `train.py` - training loop, validation, checkpointing, plots
- `utils.py` - helper functions for training and evaluation
- `README.md` - documentation

## Installation
```bash
pip install torch torchvision matplotlib scikit-image numpy nibabel tqdm
```

## Training
To train the model, run:
```bash
python3 train.py --data_root "/Users/justin/Downloads/HipMRIDataset"
```

## Prediction
To generate reconstructions using the best model, run:
```bash
python3 predict.py --checkpoint outputs/best_checkpoint.pth --data_root "/Users/justin/Downloads/HipMRIDataset"
```

## Author
Justin Teng (47508042)
85 changes: 85 additions & 0 deletions recognition/GANmodel_47508042/dataset.py
@@ -0,0 +1,85 @@
import os
import numpy as np
import nibabel as nib
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class HipMRIDataset(Dataset):
"""
dataset for loading 3D hip MRI volumes from NIfTI files and slicing them into 2D images.
"""
def __init__(self, root_dir, image_size=256, transform=None, max_slices_per_volume=None, recursive=True):
self.root_dir = root_dir # root directory containing NIfTI files
self.image_size = image_size # size to resize images to
self.max_slices = max_slices_per_volume # max number of slices to use per volume
self.recursive = recursive # whether to search subdirectories

# finding files
self.files = self.collect_files()
if len(self.files) == 0:
raise ValueError(f"no files found in {root_dir}")

# transform
if transform is None:
self.transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
])
else:
self.transform = transform

# pre-index slices
self.index_map = []
for f_idx, path in enumerate(self.files):
vol = nib.load(path).get_fdata()
depth = vol.shape[2]
n_slices = depth if self.max_slices is None else min(depth, self.max_slices)
for si in range(n_slices):
self.index_map.append((f_idx, si))

def collect_files(self):
"""
collect all NIfTI files in the root directory (and subdirectories if recursive)
"""
        exts = ('.nii', '.nii.gz')
        files = []
        if self.recursive:
            # walk through subdirectories
            for root, _, filenames in os.walk(self.root_dir):
                for fn in filenames:
                    if fn.lower().endswith(exts):
                        files.append(os.path.join(root, fn))
        else:
            # only search the current directory
            for fn in os.listdir(self.root_dir):
                if fn.lower().endswith(exts):
                    files.append(os.path.join(self.root_dir, fn))
return sorted(files)

def __len__(self):
"""
returns the total number of 2D slices across all volumes
"""
return len(self.index_map)

def __getitem__(self, index):
"""
get the 2D slice at the given index
"""
file_index, slice_index = self.index_map[index]
path = self.files[file_index]
vol = nib.load(path).get_fdata()
        vmin, vmax = vol.min(), vol.max()
        if vmax - vmin < 1e-8:
            norm = np.zeros_like(vol, dtype=np.float32)
        else:
            norm = (vol - vmin) / (vmax - vmin)

slice2d = norm[:, :, slice_index]
# convert to uint8 image then to PIL so transform works reliably
img = Image.fromarray((slice2d * 255).astype(np.uint8))
if self.transform:
            img = self.transform(img)  # tensor in [0,1], shape (C,H,W)
return img
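
A minimal usage sketch for this loader (the data path is a placeholder):

```python
from torch.utils.data import DataLoader
from dataset import HipMRIDataset

# each item is a single 2D slice as a (1, 256, 256) float tensor in [0, 1]
ds = HipMRIDataset("/path/to/HipMRIDataset", image_size=256, max_slices_per_volume=40)
loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([8, 1, 256, 256])
```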

113 changes: 113 additions & 0 deletions recognition/GANmodel_47508042/modules.py
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
"""
compresses the input image to latent representation
"""
def __init__(self, in_ch=1, hidden=128, z_channels=64):
super().__init__()
self.enc = nn.Sequential(
nn.Conv2d(in_ch, hidden//2, 4, 2, 1), # downsample 1
nn.ReLU(True),
nn.Conv2d(hidden//2, hidden, 4, 2, 1), # downsample 2
nn.ReLU(True),
nn.Conv2d(hidden, z_channels, 3, 1, 1), # project to latent space
)

def forward(self, x):
"""
forward pass through the encoder
"""
return self.enc(x) # returns latent representation z_e

class Decoder(nn.Module):
"""
reconstructs the image from latent representation
"""
def __init__(self, out_ch=1, hidden=128, z_channels=64):
super().__init__()
self.dec = nn.Sequential(
nn.Conv2d(z_channels, hidden, 3, 1, 1),
nn.ReLU(True),
nn.ConvTranspose2d(hidden, hidden//2, 4, 2, 1), # upsample 1
nn.ReLU(True),
nn.ConvTranspose2d(hidden//2, out_ch, 4, 2, 1), # upsample 2
nn.Sigmoid() # output in [0,1]
)

def forward(self, z):
"""
forward pass through the decoder
"""
return self.dec(z) # returns reconstructed image

class VectorQuantizer(nn.Module):
"""
Vector Quantization layer for VQ-VAE
"""
def __init__(self, num_embeddings=512, embedding_dim=64, beta=0.25):
super().__init__()
self.num_embeddings = num_embeddings # number of discrete embeddings
self.embedding_dim = embedding_dim # dimension of each embedding
self.beta = beta # commitment loss weight

self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
nn.init.uniform_(self.embedding.weight, -1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

def forward(self, z):
"""
forward pass through the VQ layer
"""
        # (B, C, H, W) -> (B, H, W, C), then flatten to (N, embedding_dim)
        z_perm = z.permute(0, 2, 3, 1).contiguous()
        flat_z = z_perm.view(-1, self.embedding_dim)

        # squared L2 distances: ||z - e||^2 expanded as ||z||^2 + ||e||^2 - 2*z.e
        distances = (
            torch.sum(flat_z**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(flat_z, self.embedding.weight.t())
        )

# encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings, device=z.device)
encodings.scatter_(1, encoding_indices, 1)

# quantize and reshape
quantized = torch.matmul(encodings, self.embedding.weight)
quantized = quantized.view(z_perm.shape)
quantized = quantized.permute(0, 3, 1, 2).contiguous()

# losses
e_latent_loss = F.mse_loss(quantized.detach(), z)
q_latent_loss = F.mse_loss(quantized, z.detach())
loss = q_latent_loss + self.beta * e_latent_loss

        # straight-through estimator: gradients flow from decoder to encoder as if quantization were identity
        quantized = z + (quantized - z).detach()

        # reshape indices to (B, H, W) for visualisation
        encoding_indices = encoding_indices.view(z.shape[0], z.shape[2], z.shape[3])
return quantized, loss, encoding_indices

class VQVAE(nn.Module):
"""
VQ-VAE model combining encoder, vector quantizer, and decoder
"""
def __init__(self, in_ch=1, hidden=128, z_channels=64, num_embeddings=512, embedding_dim=64, beta=0.25):
super().__init__()
        assert z_channels == embedding_dim, "encoder output channels must match the codebook embedding dim"
self.encoder = Encoder(in_ch, hidden, z_channels)
self.vq = VectorQuantizer(num_embeddings=num_embeddings, embedding_dim=embedding_dim, beta=beta)
self.decoder = Decoder(in_ch, hidden, z_channels)

def forward(self, x):
"""
forward pass through the VQ-VAE
"""
z_e = self.encoder(x)
quantized, vq_loss, indices = self.vq(z_e)
x_recon = self.decoder(quantized)
return x_recon, vq_loss, indices
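
A quick shape check for the full model (a sketch; random noise stands in for MRI slices):

```python
import torch
from modules import VQVAE

model = VQVAE(in_ch=1, hidden=128, z_channels=64, num_embeddings=512, embedding_dim=64, beta=0.25)
x = torch.rand(2, 1, 256, 256)  # dummy batch in [0, 1]
x_recon, vq_loss, indices = model(x)
print(x_recon.shape)   # torch.Size([2, 1, 256, 256]), the two stride-2 downsamples are undone by the decoder
print(indices.shape)   # torch.Size([2, 64, 64]), one codebook index per latent position
print(vq_loss.item())  # scalar VQ loss, added to the reconstruction loss during training
```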
40 changes: 40 additions & 0 deletions recognition/GANmodel_47508042/predict.py
@@ -0,0 +1,40 @@
import os
import argparse
import torch
from torch.utils.data import DataLoader
from dataset import HipMRIDataset
from modules import VQVAE
from utils import ensure_dir, save_pair_grid

def visualise(checkpoint, data_root, output_dir, batch_size=8, max_slices=40):
"""
Visualise reconstructions from a trained VQ-VAE model
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt = torch.load(checkpoint, map_location=device, weights_only=True)
    model = VQVAE(in_ch=1, hidden=128, z_channels=64, num_embeddings=512, beta=0.25).to(device)
    model.load_state_dict(ckpt['model_state'])
model.eval()

ds = HipMRIDataset(data_root, image_size=256, max_slices_per_volume=max_slices, recursive=True)
loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2)

ensure_dir(output_dir)
with torch.no_grad():
batch = next(iter(loader))
x = batch.to(device)
x_recon, _, indices = model(x)
save_pair_grid(x.cpu(), x_recon.cpu(), os.path.join(output_dir, 'pred_sample.png'))
print("saved reconstructions to", os.path.join(output_dir, "pred_sample.png"))
print("indices shape:", indices.shape)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, required=True)
parser.add_argument("--data_root", type=str, required=True)
parser.add_argument("--output_dir", type=str, default="outputs/pred")
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--max_slices", type=int, default=40)
args = parser.parse_args()
# run visualisation
visualise(args.checkpoint, args.data_root, args.output_dir, args.batch_size, args.max_slices)