10 changes: 0 additions & 10 deletions recognition/README.md

This file was deleted.

184 changes: 184 additions & 0 deletions recognition/improved_unet_BochengLin/README.md
@@ -0,0 +1,184 @@
## Improved 3D U-Net for Prostate MRI Segmentation

**Author**: Bocheng Lin 48275565
**Project**: COMP3710 Pattern Analysis Report - Project 7 (Hard Difficulty)

## 1. Problem Description
This project implements 6-class 3D segmentation of the prostate 3D MRI dataset using an improved 3D U-Net model. The goal is to achieve a Dice coefficient >= 0.70 on every foreground class.

## 2. Algorithm Description

The model uses a 3D U-Net architecture with the following improvements:

- **Encoder-Decoder with Skip Connections**: Standard U-Net structure but with strided convolutions (Conv3d stride=2) instead of max-pooling for downsampling
- **Instance Norm**: Works better than batch norm when batch size is small (4 in this case)
- **LeakyReLU**: Better gradient flow compared to ReLU
- **Dice Loss**: Directly optimizes the Dice overlap metric and handles class imbalance well
- **6-class Output**: Background + 5 tissue classes
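
The first two improvements can be sketched as a single encoder block. This is a minimal sketch assuming PyTorch; the actual `modules.py` implementation may differ in naming and details:

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch, stride=1):
    """Conv3d + InstanceNorm3d + LeakyReLU; stride=2 gives learnable downsampling."""
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
        nn.InstanceNorm3d(out_ch),
        nn.LeakyReLU(negative_slope=0.01, inplace=True),
    )

x = torch.randn(4, 1, 32, 32, 16)       # (batch, channels, H, W, D)
down = conv_block(1, 64, stride=2)      # replaces max-pooling
y = down(x)
print(y.shape)                          # spatial dims halved
```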

### Model Details

Encoder: 5 levels (1 → 64 → 128 → 256 → 320 channels)
Decoder: 4 levels with symmetric upsampling and skip connections
Total: ~17.5M parameters

### Training Setup

- Optimizer: Adam (lr=1e-4, weight_decay=1e-5)
- Loss: Dice Loss
- LR schedule: halve the learning rate (0.5×) if validation Dice does not improve for 5 epochs
- Batch size: 4
- Data split: 80% train / 10% val / 10% test
- Early stopping: None (run full 100 epochs)
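
A minimal soft Dice loss for multi-class 3D segmentation looks roughly like the sketch below; the training script's exact formulation (smoothing term, class weighting) may differ:

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, target, num_classes=6, eps=1e-5):
    """logits: (N, C, H, W, D) raw scores; target: (N, H, W, D) integer labels."""
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)                          # sum over batch and spatial dims
    inter = (probs * one_hot).sum(dims)          # per-class intersection
    union = probs.sum(dims) + one_hot.sum(dims)  # per-class denominator
    dice = (2 * inter + eps) / (union + eps)
    return 1 - dice.mean()                       # average over classes

logits = torch.randn(2, 6, 8, 8, 4)
target = torch.randint(0, 6, (2, 8, 8, 4))
loss = dice_loss(logits, target)
```

The matching optimizer and schedule from the setup above would be `torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)` with `torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=5)` stepped on validation Dice.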

## 3. How it Works

### Working Principle

The improved 3D U-Net uses an encoder-decoder architecture with key improvements over standard U-Net:

1. **Encoder** progressively downsamples using **strided convolutions (stride=2)** instead of max-pooling (improvement: learnable downsampling preserves more information)
2. **InstanceNorm + LeakyReLU** at each block (improvement: better for small batch sizes and smoother gradient flow)
3. **Bottleneck** captures abstract features at the deepest level (320 channels)
4. **Decoder** progressively upsamples with **ConvTranspose3d** and skip connections (recovers fine details)
5. **Output layer** produces 6-channel prediction (one per class)

During inference, argmax converts output to class labels.
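
The argmax step amounts to one line (random logits stand in for network output here):

```python
import torch

logits = torch.randn(1, 6, 128, 128, 64)  # (N, classes, H, W, D) network output
pred = torch.argmax(logits, dim=1)        # (N, H, W, D) integer labels 0..5
print(pred.shape)
```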

**Key Improvements:**
- Strided convolutions: Learnable downsampling vs fixed pooling
- Instance normalization: Better than batch norm for small batches
- LeakyReLU: Avoids dead neurons, improves gradient flow

### Architecture Diagram

```
Input (1×128×128×64)
[Conv + InstanceNorm + LeakyReLU] → Strided Conv (stride 2)
↓ (64 ch)
[Conv + InstanceNorm + LeakyReLU] → Strided Conv (stride 2)
↓ (128 ch)
[Conv + InstanceNorm + LeakyReLU] → Strided Conv (stride 2)
↓ (256 ch)
[Conv + InstanceNorm + LeakyReLU] → Strided Conv (stride 2)
↓ (320 ch)
[Conv + InstanceNorm + LeakyReLU] [Bottleneck]
↓ (320 ch)
ConvTranspose3d + Skip + [Conv Block]
↓ (256 ch)
ConvTranspose3d + Skip + [Conv Block]
↓ (128 ch)
ConvTranspose3d + Skip + [Conv Block]
↓ (64 ch)
ConvTranspose3d + Skip + [Conv Block]
↓ (32 ch)
Output Conv (1×1×1, 6 classes)
Output (6×128×128×64)
```

### Segmentation Results

![Per-Class Dice Scores](results/per_class_dice.png)

![Segmentation Example 1](results/segmentation_example_1.png)

![Segmentation Example 2](results/segmentation_example_2.png)

![Segmentation Example 3](results/segmentation_example_3.png)

## 4. Dataset and Preprocessing

### Data

3D MRI volumes with labels (6 classes: background + 5 tissue types)

### Processing

- Center crop to 128×128×64 voxels
- Z-score normalization per volume: $(x - \mu) / \sigma$
- Labels: 0-5 (one-hot for training)
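
The crop and normalization steps on a dummy volume, assuming the input is at least 128×128×64 (the real loader in `dataset.py` additionally handles NIfTI I/O):

```python
import numpy as np

vol = np.random.rand(256, 256, 128).astype(np.float32)

# Center crop to 128x128x64
th, tw, td = 128, 128, 64
h, w, d = vol.shape
y1, x1, z1 = (h - th) // 2, (w - tw) // 2, (d - td) // 2
vol = vol[y1:y1 + th, x1:x1 + tw, z1:z1 + td]

# Per-volume z-score normalization: (x - mu) / sigma
vol = (vol - vol.mean()) / (vol.std() + 1e-8)
print(vol.shape)
```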

### Split

- Train: 80% (~170 samples)
- Val: 10% (~21 samples)
- Test: 10% (~21 samples)

Deterministic split for reproducibility.
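
The split mirrors what `dataset.py` does: shuffle the sorted file list under a fixed seed, then slice. Filenames and the 210-sample count here are placeholders:

```python
import numpy as np

files = sorted(f"case_{i:03d}.nii.gz" for i in range(210))
np.random.seed(42)                 # fixed seed -> same split every run
np.random.shuffle(files)
n = len(files)
train = files[:int(0.8 * n)]
val = files[int(0.8 * n):int(0.9 * n)]
test = files[int(0.9 * n):]
print(len(train), len(val), len(test))
```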

## 5. Project Structure

- `dataset.py` - Prostate3DDataset class for loading and preprocessing 3D MRI data
- `modules.py` - UNet3D_Improved model architecture
- `train.py` - Training script with Dice loss and validation
- `predict.py` - Evaluation script for computing per-class Dice
- `visualize_results.py` - Generate segmentation visualizations and metrics
- `test_model.py` - Model validation script
- `best_model.pth` - Trained model weights
- `environment.yml` - Conda environment configuration

## 6. Usage

### Setup Environment

```bash
conda env create -f environment.yml
conda activate unet3d
```

### Data Path Configuration

The dataset path needs to be specified for training and evaluation scripts. Update the path according to your local setup:

```bash
# Default path (modify as needed)
DATA_PATH=C:\data\HipMRI_3D
```

### Training

```bash
python train.py --data_path C:\data\HipMRI_3D --batch_size 4 --lr 1e-4 --epochs 100
```

### Evaluation on Test Set

```bash
python predict.py --data_path C:\data\HipMRI_3D --model_path best_model.pth
```

### Generate Visualizations

```bash
python visualize_results.py --data_path C:\data\HipMRI_3D --num_samples 3
```

## 7. Dependencies

- PyTorch >= 1.9.0
- torchvision >= 0.10.0
- numpy >= 1.19.0
- nibabel >= 3.2.0
- tqdm >= 4.50.0
- matplotlib >= 3.3.0
- Python >= 3.8

See `environment.yml` for exact versions.

## 8. Reproducibility

- Fixed random seed in data splitting for deterministic train/val/test split
- Model weights saved to `best_model.pth`
- All hyperparameters configurable via command-line arguments
- Preprocessing is deterministic (no random augmentation in prediction)

## 9. References

- U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015)
- Instance Normalization: Ulyanov et al., 2016

## 10. Acknowledgments

This project was developed with assistance from GitHub Copilot.
96 changes: 96 additions & 0 deletions recognition/improved_unet_BochengLin/dataset.py
@@ -0,0 +1,96 @@
import torch
import numpy as np
from torch.utils.data import Dataset
import os
import nibabel as nib

class Prostate3DDataset(Dataset):
    def __init__(self, root_dir, split="train", transform=None, target_shape=(128, 128, 64)):
        """
        Load 3D MRI dataset with semantic segmentation labels.

        Args:
            root_dir: Path to dataset root directory
            split: 'train', 'val', or 'test'
            transform: Optional data augmentation transforms
            target_shape: Output volume size (H, W, D)
        """
        self.img_dir = os.path.join(root_dir, "semantic_MRs")
        self.lbl_dir = os.path.join(root_dir, "semantic_labels_only")
        self.transform = transform
        self.target_shape = target_shape

        all_files = sorted(os.listdir(self.img_dir))

        # Deterministic train/val/test split (80/10/10)
        np.random.seed(42)
        np.random.shuffle(all_files)

        train_split = int(0.8 * len(all_files))
        val_split = int(0.9 * len(all_files))

        if split == "train":
            self.file_list = all_files[:train_split]
        elif split == "val":
            self.file_list = all_files[train_split:val_split]
        elif split == "test":
            self.file_list = all_files[val_split:]
        else:
            raise ValueError(f"Invalid split name: {split}")

        print(f"Loaded {len(self.file_list)} files for the {split} set.")

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        img_path = os.path.join(self.img_dir, file_name)

        # Map image filename to corresponding label filename
        lbl_file_name = file_name.replace("_LFOV", "_SEMANTIC")
        lbl_path = os.path.join(self.lbl_dir, lbl_file_name)

        img = nib.load(img_path).get_fdata().astype(np.float32)
        lbl = nib.load(lbl_path).get_fdata().astype(np.int64)

        # Center crop to target shape
        h, w, d = img.shape
        th, tw, td = self.target_shape
        x1 = int(round((w - tw) / 2.))
        y1 = int(round((h - th) / 2.))
        z1 = int(round((d - td) / 2.))
        img = img[y1:y1+th, x1:x1+tw, z1:z1+td]
        lbl = lbl[y1:y1+th, x1:x1+tw, z1:z1+td]

        # Z-score normalization
        img = (img - img.mean()) / (img.std() + 1e-8)

        # Convert to tensor with channel dimension (1, H, W, D)
        img = torch.from_numpy(img).unsqueeze(0)
        lbl = torch.from_numpy(lbl).unsqueeze(0)

        sample = {"image": img, "label": lbl}

        # Optional data augmentation
        if self.transform:
            sample = self.transform(sample)

        return sample


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test dataset loading")
    parser.add_argument("--data_path", type=str,
                        default=r"C:\data\HipMRI_3D",
                        help="Path to dataset root directory")

    args = parser.parse_args()

    print(f"Loading data from: {args.data_path}")
    train_dataset = Prostate3DDataset(root_dir=args.data_path, split="train")

    if len(train_dataset) > 0:
        sample = train_dataset[0]
        print("Sample image shape:", sample["image"].shape)
        print("Sample label shape:", sample["label"].shape)
74 changes: 74 additions & 0 deletions recognition/improved_unet_BochengLin/environment.yml
@@ -0,0 +1,74 @@
name: unet3d
channels:
- defaults
- https://repo.anaconda.com/pkgs/main
- https://repo.anaconda.com/pkgs/r
- https://repo.anaconda.com/pkgs/msys2
dependencies:
- bzip2=1.0.8=h2bbff1b_6
- ca-certificates=2025.9.9=haa95532_0
- expat=2.7.1=h8ddb27b_0
- libffi=3.4.4=hd77b12b_1
- libzlib=1.3.1=h02ab6af_0
- openssl=3.0.18=h543e019_0
- pip=25.2=pyhc872135_0
- python=3.10.18=h981015d_0
- setuptools=80.9.0=py310haa95532_0
- sqlite=3.50.2=hda9a48d_1
- tk=8.6.15=hf199647_0
- tzdata=2025b=h04d1e81_0
- ucrt=10.0.22621.0=haa95532_0
- vc=14.3=h2df5915_10
- vc14_runtime=14.44.35208=h4927774_10
- vs2015_runtime=14.44.35208=ha6b5a95_10
- wheel=0.45.1=py310haa95532_0
- xz=5.6.4=h4754444_1
- zlib=1.3.1=h02ab6af_0
- pip:
  - click==8.3.0
  - colorama==0.4.6
  - contourpy==1.3.2
  - cycler==0.12.1
  - deprecated==1.2.18
  - einops==0.8.1
  - filelock==3.13.1
  - fonttools==4.60.1
  - fsspec==2024.6.1
  - humanize==4.13.0
  - imageio==2.37.0
  - importlib-resources==6.5.2
  - jinja2==3.1.4
  - kiwisolver==1.4.9
  - lazy-loader==0.4
  - markdown-it-py==4.0.0
  - markupsafe==2.1.5
  - matplotlib==3.10.7
  - mdurl==0.1.2
  - monai==1.5.1
  - mpmath==1.3.0
  - networkx==3.3
  - nibabel==5.3.2
  - numpy==2.1.2
  - opencv-python==4.12.0.88
  - packaging==25.0
  - pillow==11.0.0
  - pygments==2.19.2
  - pyparsing==3.2.5
  - python-dateutil==2.9.0.post0
  - rich==14.2.0
  - scikit-image==0.25.2
  - scipy==1.15.3
  - shellingham==1.5.4
  - simpleitk==2.5.2
  - six==1.17.0
  - sympy==1.13.1
  - tifffile==2025.5.10
  - torch==2.6.0+cu124
  - torchaudio==2.6.0+cu124
  - torchio==0.20.23
  - torchvision==0.21.0+cu124
  - tqdm==4.67.1
  - typer==0.19.2
  - typing-extensions==4.12.2
  - wrapt==1.17.3
prefix: C:\Users\66449\anaconda3\envs\unet3d
23 changes: 23 additions & 0 deletions recognition/improved_unet_BochengLin/git_update.py
@@ -0,0 +1,23 @@
import os

BRANCH = "topic-recognition"
REMOTE = "origin"

# Run git commands from this script's directory
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Show status
print("Git Status:")
os.system("git status")

# Get commit message
msg = input("\nCommit message: ").strip()
if not msg:
    msg = "Update"

# Add, commit, push
os.system("git add -A")
os.system(f'git commit -m "{msg}"')
os.system(f"git push {REMOTE} {BRANCH}")

print("✅ Done!")