diff --git a/recognition/README.md b/recognition/README.md
deleted file mode 100644
index 32c99e899..000000000
--- a/recognition/README.md
+++ /dev/null
@@ -1,10 +0,0 @@
-# Recognition Tasks
-Various recognition tasks solved in deep learning frameworks.
-
-Tasks may include:
-* Image Segmentation
-* Object detection
-* Graph node classification
-* Image super resolution
-* Disease classification
-* Generative modelling with StyleGAN and Stable Diffusion
\ No newline at end of file
diff --git a/recognition/improved_unet_BochengLin/README.md b/recognition/improved_unet_BochengLin/README.md
new file mode 100644
index 000000000..1916a2acd
--- /dev/null
+++ b/recognition/improved_unet_BochengLin/README.md
@@ -0,0 +1,184 @@
+## Improved 3D U-Net for Prostate MRI Segmentation
+
+**Author**: Bocheng Lin 48275565
+**Project**: COMP3710 Pattern Analysis Report - Project 7 (Hard Difficulty)
+
+## 1. Problem Description
+This project implements 6-class 3D segmentation of the prostate 3D MRI dataset using an Improved UNet3D model. The goal is to achieve Dice >= 0.70 on all foreground classes.
+
+## 2. Algorithm Description
+
+Uses a 3D U-Net architecture with the following improvements:
+
+- **Encoder-Decoder with Skip Connections**: Standard U-Net structure, but with strided convolutions (Conv3d, stride=2) instead of max-pooling for downsampling
+- **Instance Norm**: Works better than batch norm when the batch size is small (4 in this case)
+- **LeakyReLU**: Better gradient flow compared to ReLU
+- **Dice Loss**: Optimizes the segmentation overlap (Dice coefficient) directly and handles class imbalance well
+- **6-class Output**: Background + 5 tissue classes
+
+### Model Details
+
+Encoder: 5 levels (1 → 32 → 64 → 128 → 256 → 320 channels)
+Decoder: 4 levels (320 → 256 → 128 → 64 → 32 channels) with symmetric upsampling and skip connections
+Total: ~17.5M parameters
+
+### Training Setup
+
+- Optimizer: Adam (lr=1e-4, weight_decay=1e-5)
+- Loss: Dice Loss
+- Learning rate: reduced by 0.5× if validation Dice doesn't improve for 5 epochs
+- Batch size: 4
+- Data split: 80% train / 10% val / 10% test
+- Early stopping: none (the full 100 epochs are run)
+
+## 3. How it Works
+
+### Working Principle
+
+The improved 3D U-Net uses an encoder-decoder architecture with key improvements over the standard U-Net:
+
+1. **Encoder** progressively downsamples using **strided convolutions (stride=2)** instead of max-pooling (improvement: learnable downsampling preserves more information)
+2. **InstanceNorm + LeakyReLU** at each block (improvement: better for small batch sizes and smoother gradient flow)
+3. **Bottleneck** captures abstract features at the deepest level (320 channels)
+4. **Decoder** progressively upsamples with **ConvTranspose3d** and skip connections (recovers fine details)
+5. **Output layer** produces a 6-channel prediction (one per class)
+
+During inference, argmax converts the output to class labels.
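+
+As a minimal sketch of that inference step (illustrative only; `predict.py` is the full evaluation script, and the random tensor below stands in for a preprocessed MRI volume):
+
+```python
+import torch
+from modules import UNet3D_Improved
+
+# Load the trained 6-class model (weights produced by train.py).
+model = UNet3D_Improved(in_channels=1, num_classes=6)
+model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
+model.eval()
+
+# Dummy stand-in for a preprocessed volume: (batch, channel, H, W, D).
+volume = torch.randn(1, 1, 128, 128, 64)
+
+with torch.no_grad():
+    logits = model(volume)                # (1, 6, 128, 128, 64) class scores
+    labels = torch.argmax(logits, dim=1)  # (1, 128, 128, 64), labels 0-5
+```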
+
+**Key Improvements:**
+- Strided convolutions: learnable downsampling vs. fixed pooling
+- Instance normalization: better than batch norm for small batches
+- LeakyReLU: avoids dead neurons, improves gradient flow
+
+### Architecture Diagram
+
+```
+Input (1×128×128×64)
+    ↓
+[Conv + InstanceNorm + LeakyReLU] (32 ch)
+    ↓ Strided Conv (stride 2)
+[Conv + InstanceNorm + LeakyReLU] (64 ch)
+    ↓ Strided Conv (stride 2)
+[Conv + InstanceNorm + LeakyReLU] (128 ch)
+    ↓ Strided Conv (stride 2)
+[Conv + InstanceNorm + LeakyReLU] (256 ch)
+    ↓ Strided Conv (stride 2)
+[Conv + InstanceNorm + LeakyReLU] (320 ch) [Bottleneck]
+    ↓
+ConvTranspose3d + Skip + [Conv Block]
+    ↓ (256 ch)
+ConvTranspose3d + Skip + [Conv Block]
+    ↓ (128 ch)
+ConvTranspose3d + Skip + [Conv Block]
+    ↓ (64 ch)
+ConvTranspose3d + Skip + [Conv Block]
+    ↓ (32 ch)
+Output Conv (1×1×1, 6 classes)
+    ↓
+Output (6×128×128×64)
+```
+
+### Segmentation Results
+
+![Per-Class Dice Scores](results/per_class_dice.png)
+
+![Segmentation Example 1](results/segmentation_example_1.png)
+
+![Segmentation Example 2](results/segmentation_example_2.png)
+
+![Segmentation Example 3](results/segmentation_example_3.png)
+
+## 4. Dataset and Preprocessing
+
+### Data
+
+3D MRI volumes with labels (6 classes: background + 5 tissue types)
+
+### Processing
+
+- Center crop to 128×128×64 voxels
+- Z-score normalization per volume: $(x - \mu) / \sigma$
+- Labels: integers 0-5 (converted to one-hot inside the Dice loss)
+
+### Split
+
+- Train: 80% (~170 samples)
+- Val: 10% (~21 samples)
+- Test: 10% (~21 samples)
+
+Deterministic split for reproducibility.
+
+## 5. Project Structure
+
+- `dataset.py` - Prostate3DDataset class for loading and preprocessing 3D MRI data
+- `modules.py` - UNet3D_Improved model architecture
+- `train.py` - Training script with Dice loss and validation
+- `predict.py` - Evaluation script for computing per-class Dice
+- `visualize_results.py` - Generate segmentation visualizations and metrics
+- `test_model.py` - Model validation script
+- `best_model.pth` - Trained model weights
+- `environment.yml` - Conda environment configuration
+
+## 6. Usage
+
+### Setup Environment
+
+```bash
+conda env create -f environment.yml
+conda activate unet3d
+```
+
+### Data Path Configuration
+
+The dataset path needs to be specified for the training and evaluation scripts. Update the path according to your local setup:
+
+```bash
+# Default path (modify as needed)
+DATA_PATH=C:\data\HipMRI_3D
+```
+
+### Training
+
+```bash
+python train.py --data_path C:\data\HipMRI_3D --batch_size 4 --lr 1e-4 --epochs 100
+```
+
+### Evaluation on Test Set
+
+```bash
+python predict.py --data_path C:\data\HipMRI_3D --model_path best_model.pth
+```
+
+### Generate Visualizations
+
+```bash
+python visualize_results.py --data_path C:\data\HipMRI_3D --num_samples 3
+```
+
+## 7. Dependencies
+
+- PyTorch >= 1.9.0
+- torchvision >= 0.10.0
+- numpy >= 1.19.0
+- nibabel >= 3.2.0
+- tqdm >= 4.50.0
+- matplotlib >= 3.3.0
+- Python >= 3.8
+
+See `environment.yml` for exact versions.
+
+## 8. Reproducibility
+
+- Fixed random seed in data splitting for a deterministic train/val/test split
+- Model weights saved to `best_model.pth`
+- All hyperparameters configurable via command-line arguments
+- Preprocessing is deterministic (no random augmentation at prediction time)
+
+## 9. References
+
+- U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015)
+- Instance Normalization: The Missing Ingredient for Fast Stylization (Ulyanov et al., 2016)
+
+## 10. 
Acknowledgments + +This project was developed with assistance from GitHub Copilot. diff --git a/recognition/improved_unet_BochengLin/best_model.pth b/recognition/improved_unet_BochengLin/best_model.pth new file mode 100644 index 000000000..7eac47bcf Binary files /dev/null and b/recognition/improved_unet_BochengLin/best_model.pth differ diff --git a/recognition/improved_unet_BochengLin/dataset.py b/recognition/improved_unet_BochengLin/dataset.py new file mode 100644 index 000000000..08f5752c0 --- /dev/null +++ b/recognition/improved_unet_BochengLin/dataset.py @@ -0,0 +1,96 @@ +import torch +import numpy as np +from torch.utils.data import Dataset +import os +import nibabel as nib + +class Prostate3DDataset(Dataset): + def __init__(self, root_dir, split="train", transform=None, target_shape=(128, 128, 64)): + """ + Load 3D MRI dataset with semantic segmentation labels. + + Args: + root_dir: Path to dataset root directory + split: 'train', 'val', or 'test' + transform: Optional data augmentation transforms + target_shape: Output volume size (H, W, D) + """ + self.img_dir = os.path.join(root_dir, "semantic_MRs") + self.lbl_dir = os.path.join(root_dir, "semantic_labels_only") + self.transform = transform + self.target_shape = target_shape + + all_files = sorted(os.listdir(self.img_dir)) + + # Deterministic train/val/test split (80/10/10) + np.random.seed(42) + np.random.shuffle(all_files) + + train_split = int(0.8 * len(all_files)) + val_split = int(0.9 * len(all_files)) + + if split == "train": + self.file_list = all_files[:train_split] + elif split == "val": + self.file_list = all_files[train_split:val_split] + elif split == "test": + self.file_list = all_files[val_split:] + else: + raise ValueError(f"Invalid split name: {split}") + + print(f"Loaded {len(self.file_list)} files for the {split} set.") + + def __len__(self): + return len(self.file_list) + + def __getitem__(self, idx): + file_name = self.file_list[idx] + img_path = os.path.join(self.img_dir, file_name) + + # Map image filename to corresponding label filename + lbl_file_name = file_name.replace("_LFOV", "_SEMANTIC") + lbl_path = os.path.join(self.lbl_dir, lbl_file_name) + + img = nib.load(img_path).get_fdata().astype(np.float32) + lbl = nib.load(lbl_path).get_fdata().astype(np.int64) + + # Center crop to target shape + h, w, d = img.shape + th, tw, td = self.target_shape + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + z1 = int(round((d - td) / 2.)) + img = img[y1:y1+th, x1:x1+tw, z1:z1+td] + lbl = lbl[y1:y1+th, x1:x1+tw, z1:z1+td] + + # Z-score normalization + img = (img - img.mean()) / (img.std() + 1e-8) + + # Convert to tensor with channel dimension (1, H, W, D) + img = torch.from_numpy(img).unsqueeze(0) + lbl = torch.from_numpy(lbl).unsqueeze(0) + + sample = {"image": img, "label": lbl} + + # Optional data augmentation + if self.transform: + sample = self.transform(sample) + + return sample +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Test dataset loading") + parser.add_argument("--data_path", type=str, + default=r"C:\data\HipMRI_3D", + help="Path to dataset root directory") + + args = parser.parse_args() + + print(f"Loading data from: {args.data_path}") + train_dataset = Prostate3DDataset(root_dir=args.data_path, split="train") + + if len(train_dataset) > 0: + sample = train_dataset[0] + print("Sample image shape:", sample["image"].shape) + print("Sample label shape:", sample["label"].shape) \ No newline at end of file diff --git 
a/recognition/improved_unet_BochengLin/environment.yml b/recognition/improved_unet_BochengLin/environment.yml new file mode 100644 index 000000000..c3f3ba829 --- /dev/null +++ b/recognition/improved_unet_BochengLin/environment.yml @@ -0,0 +1,74 @@ +name: unet3d +channels: + - defaults + - https://repo.anaconda.com/pkgs/main + - https://repo.anaconda.com/pkgs/r + - https://repo.anaconda.com/pkgs/msys2 +dependencies: + - bzip2=1.0.8=h2bbff1b_6 + - ca-certificates=2025.9.9=haa95532_0 + - expat=2.7.1=h8ddb27b_0 + - libffi=3.4.4=hd77b12b_1 + - libzlib=1.3.1=h02ab6af_0 + - openssl=3.0.18=h543e019_0 + - pip=25.2=pyhc872135_0 + - python=3.10.18=h981015d_0 + - setuptools=80.9.0=py310haa95532_0 + - sqlite=3.50.2=hda9a48d_1 + - tk=8.6.15=hf199647_0 + - tzdata=2025b=h04d1e81_0 + - ucrt=10.0.22621.0=haa95532_0 + - vc=14.3=h2df5915_10 + - vc14_runtime=14.44.35208=h4927774_10 + - vs2015_runtime=14.44.35208=ha6b5a95_10 + - wheel=0.45.1=py310haa95532_0 + - xz=5.6.4=h4754444_1 + - zlib=1.3.1=h02ab6af_0 + - pip: + - click==8.3.0 + - colorama==0.4.6 + - contourpy==1.3.2 + - cycler==0.12.1 + - deprecated==1.2.18 + - einops==0.8.1 + - filelock==3.13.1 + - fonttools==4.60.1 + - fsspec==2024.6.1 + - humanize==4.13.0 + - imageio==2.37.0 + - importlib-resources==6.5.2 + - jinja2==3.1.4 + - kiwisolver==1.4.9 + - lazy-loader==0.4 + - markdown-it-py==4.0.0 + - markupsafe==2.1.5 + - matplotlib==3.10.7 + - mdurl==0.1.2 + - monai==1.5.1 + - mpmath==1.3.0 + - networkx==3.3 + - nibabel==5.3.2 + - numpy==2.1.2 + - opencv-python==4.12.0.88 + - packaging==25.0 + - pillow==11.0.0 + - pygments==2.19.2 + - pyparsing==3.2.5 + - python-dateutil==2.9.0.post0 + - rich==14.2.0 + - scikit-image==0.25.2 + - scipy==1.15.3 + - shellingham==1.5.4 + - simpleitk==2.5.2 + - six==1.17.0 + - sympy==1.13.1 + - tifffile==2025.5.10 + - torch==2.6.0+cu124 + - torchaudio==2.6.0+cu124 + - torchio==0.20.23 + - torchvision==0.21.0+cu124 + - tqdm==4.67.1 + - typer==0.19.2 + - typing-extensions==4.12.2 + - wrapt==1.17.3 +prefix: C:\Users\66449\anaconda3\envs\unet3d diff --git a/recognition/improved_unet_BochengLin/git_update.py b/recognition/improved_unet_BochengLin/git_update.py new file mode 100644 index 000000000..7272ee543 --- /dev/null +++ b/recognition/improved_unet_BochengLin/git_update.py @@ -0,0 +1,23 @@ +import os +import subprocess + +BRANCH = "topic-recognition" +REMOTE = "origin" + +os.chdir(os.path.dirname(os.path.abspath(__file__))) + +# Show status +print(" Git Status:") +os.system("git status") + +# Get commit message +msg = input("\n Commit message: ").strip() +if not msg: + msg = "Update" + +# Add, commit, push +os.system("git add -A") +os.system(f'git commit -m "{msg}"') +os.system(f"git push {REMOTE} {BRANCH}") + +print("✅ Done!") diff --git a/recognition/improved_unet_BochengLin/modules.py b/recognition/improved_unet_BochengLin/modules.py new file mode 100644 index 000000000..de1d1ac0c --- /dev/null +++ b/recognition/improved_unet_BochengLin/modules.py @@ -0,0 +1,124 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ConvBlock(nn.Module): + """Double convolution: Conv3d + InstanceNorm + LeakyReLU, repeated twice.""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv_block = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm3d(out_channels), + 
nn.LeakyReLU(negative_slope=0.01, inplace=True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DownBlockImproved(nn.Module): + """Downsampling with strided convolution instead of max-pooling.""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.downsample_conv = nn.Sequential( + nn.Conv3d(in_channels, in_channels, kernel_size=3, stride=2, padding=1, bias=False), + nn.InstanceNorm3d(in_channels), + nn.LeakyReLU(negative_slope=0.01, inplace=True), + ConvBlock(in_channels, out_channels) + ) + + def forward(self, x): + return self.downsample_conv(x) + + +class UpBlockImproved(nn.Module): + """Upsampling with ConvTranspose3d and skip connection concatenation.""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2) + # Concatenation doubles channels: out_channels * 2 + self.conv = ConvBlock(out_channels * 2, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + + # Pad to match skip connection size (handles odd dimensions) + diffZ = x2.size()[2] - x1.size()[2] + diffY = x2.size()[3] - x1.size()[3] + diffX = x2.size()[4] - x1.size()[4] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2, + diffZ // 2, diffZ - diffZ // 2]) + + # Concatenate with skip connection + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + """1x1x1 convolution for final output.""" + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + + +class UNet3D_Improved(nn.Module): + """ + 3D U-Net architecture for semantic segmentation. 
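+    Maps a (B, in_channels, H, W, D) volume to (B, num_classes, H, W, D) logits.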
+ + Encoder: 1 -> 32 -> 64 -> 128 -> 256 -> 320 channels (downsampling) + Decoder: 320 -> 256 -> 128 -> 64 -> 32 channels (upsampling with skip connections) + Output: num_classes channels + + Key features: + - Strided convolutions for downsampling (learnable) + - Instance normalization (better for small batch sizes) + - LeakyReLU activations + - Skip connections to preserve fine details + """ + def __init__(self, in_channels, num_classes): + super(UNet3D_Improved, self).__init__() + self.in_channels = in_channels + self.num_classes = num_classes + + # Encoder + self.inc = ConvBlock(in_channels, 32) + self.down1 = DownBlockImproved(32, 64) + self.down2 = DownBlockImproved(64, 128) + self.down3 = DownBlockImproved(128, 256) + self.down4 = DownBlockImproved(256, 320) + + # Decoder + self.up1 = UpBlockImproved(320, 256) + self.up2 = UpBlockImproved(256, 128) + self.up3 = UpBlockImproved(128, 64) + self.up4 = UpBlockImproved(64, 32) + + # Output + self.outc = OutConv(32, num_classes) + + def forward(self, x): + # Encoder + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + + # Decoder with skip connections + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + + # Output + logits = self.outc(x) + return logits \ No newline at end of file diff --git a/recognition/improved_unet_BochengLin/predict.py b/recognition/improved_unet_BochengLin/predict.py new file mode 100644 index 000000000..f6061422e --- /dev/null +++ b/recognition/improved_unet_BochengLin/predict.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +import argparse +import numpy as np +from tqdm import tqdm + +from dataset import Prostate3DDataset +from modules import UNet3D_Improved + + +def evaluate(model, test_loader, device, num_classes=6): + """Evaluate model and compute per-class Dice coefficients.""" + model.eval() + all_dice_scores = {c: [] for c in range(num_classes)} + + with torch.no_grad(): + pbar = tqdm(test_loader, desc="Evaluating", ncols=100) + for batch in pbar: + images = batch["image"].to(device) + labels = batch["label"].to(device) + + outputs = model(images) + preds = torch.argmax(outputs, dim=1, keepdim=True) + + for c in range(num_classes): + pred_c = (preds == c).float() + target_c = (labels == c).float() + intersection = torch.sum(pred_c * target_c) + union = torch.sum(pred_c) + torch.sum(target_c) + + if union == 0: + dice = 1.0 + else: + dice = (2.0 * intersection + 1.0) / (union + 1.0) + dice = dice.item() + + all_dice_scores[c].append(dice) + + return all_dice_scores + + +def main(args): + """Load model and evaluate on test set.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Load model + model = UNet3D_Improved(in_channels=1, num_classes=6) + model.load_state_dict(torch.load(args.model_path, map_location=device)) + model = model.to(device) + print(f"Loaded model from {args.model_path}") + + # Load test set + test_dataset = Prostate3DDataset(root_dir=args.data_path, split="test") + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) + print(f"Test samples: {len(test_dataset)}") + + # Evaluate + all_dice_scores = evaluate(model, test_loader, device, num_classes=6) + + # Print results + print("\n" + "=" * 50) + print("Test Set Results") + print("=" * 50) + + mean_dice_per_class = [] + for c in range(6): + mean_dice = np.mean(all_dice_scores[c]) + 
mean_dice_per_class.append(mean_dice) + print(f"Class {c}: {mean_dice:.4f}") + + overall_dice = np.mean(mean_dice_per_class) + print(f"\nOverall Dice: {overall_dice:.4f}") + + if overall_dice >= 0.7: + print("✓ SUCCESS: Overall Dice >= 0.7") + else: + print(f"✗ FAIL: Overall Dice {overall_dice:.4f} < 0.7") + + print("=" * 50) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate model on test set") + parser.add_argument("--data_path", type=str, default=r"C:\data\HipMRI_3D", help="Path to dataset") + parser.add_argument("--model_path", type=str, default="best_model.pth", help="Path to model checkpoint") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size") + + args = parser.parse_args() + main(args) diff --git a/recognition/improved_unet_BochengLin/results/per_class_dice.png b/recognition/improved_unet_BochengLin/results/per_class_dice.png new file mode 100644 index 000000000..05ee0ab9f Binary files /dev/null and b/recognition/improved_unet_BochengLin/results/per_class_dice.png differ diff --git a/recognition/improved_unet_BochengLin/results/segmentation_example_1.png b/recognition/improved_unet_BochengLin/results/segmentation_example_1.png new file mode 100644 index 000000000..86196310c Binary files /dev/null and b/recognition/improved_unet_BochengLin/results/segmentation_example_1.png differ diff --git a/recognition/improved_unet_BochengLin/results/segmentation_example_2.png b/recognition/improved_unet_BochengLin/results/segmentation_example_2.png new file mode 100644 index 000000000..9a690be4e Binary files /dev/null and b/recognition/improved_unet_BochengLin/results/segmentation_example_2.png differ diff --git a/recognition/improved_unet_BochengLin/results/segmentation_example_3.png b/recognition/improved_unet_BochengLin/results/segmentation_example_3.png new file mode 100644 index 000000000..31ff1cda0 Binary files /dev/null and b/recognition/improved_unet_BochengLin/results/segmentation_example_3.png differ diff --git a/recognition/improved_unet_BochengLin/test_model.py b/recognition/improved_unet_BochengLin/test_model.py new file mode 100644 index 000000000..51d7f74d8 --- /dev/null +++ b/recognition/improved_unet_BochengLin/test_model.py @@ -0,0 +1,65 @@ +import torch +from modules import UNet3D_Improved + +""" +Basic sanity checks for UNet3D_Improved model. +Tests forward/backward pass and parameter count. 
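+
+Run directly: python test_model.py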
+""" + +print("Testing UNet3D_Improved model...") +print("=" * 50) + +# Setup device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# Initialize model +model = UNet3D_Improved(in_channels=1, num_classes=2) +model = model.to(device) +print(f"✓ Model created") + +# Check parameter count +total_params = sum(p.numel() for p in model.parameters()) +print(f"✓ Total parameters: {total_params:,}") + +# Forward pass test +print("\nTesting forward pass...") +batch_size = 2 +dummy_input = torch.randn(batch_size, 1, 128, 128, 64).to(device) +print(f" Input shape: {dummy_input.shape}") + +try: + with torch.no_grad(): + output = model(dummy_input) + print(f"✓ Output shape: {output.shape}") + print(f"✓ Forward pass successful!") +except Exception as e: + print(f"✗ Error: {e}") + exit(1) + +# Backward pass test +print("\nTesting backward pass...") +try: + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + + # Create dummy batch + dummy_input = torch.randn(batch_size, 1, 128, 128, 64, requires_grad=True).to(device) + dummy_target = torch.randint(0, 2, (batch_size, 128, 128, 64)).to(device) + + # Forward and backward + output = model(dummy_input) + loss = torch.nn.functional.cross_entropy(output, dummy_target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + print(f"✓ Loss: {loss.item():.4f}") + print(f"✓ Backward pass successful!") +except Exception as e: + print(f"✗ Error: {e}") + exit(1) + +print("\n" + "=" * 50) +print("✓ All tests passed!") diff --git a/recognition/improved_unet_BochengLin/train.py b/recognition/improved_unet_BochengLin/train.py new file mode 100644 index 000000000..98c84e104 --- /dev/null +++ b/recognition/improved_unet_BochengLin/train.py @@ -0,0 +1,182 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +import os +import argparse +from tqdm import tqdm +import numpy as np + +from dataset import Prostate3DDataset +from modules import UNet3D_Improved + + +class DiceLoss(nn.Module): + """Dice Loss for multi-class segmentation.""" + def __init__(self, smooth=1.0, num_classes=6): + super(DiceLoss, self).__init__() + self.smooth = smooth + self.num_classes = num_classes + + def forward(self, pred, target): + """pred: (B, C, H, W, D), target: (B, 1, H, W, D)""" + target_one_hot = torch.zeros_like(pred) + for c in range(self.num_classes): + target_one_hot[:, c] = (target.squeeze(1) == c).float() + + pred = torch.softmax(pred, dim=1) + intersection = torch.sum(pred * target_one_hot, dim=(2, 3, 4)) + union = torch.sum(pred, dim=(2, 3, 4)) + torch.sum(target_one_hot, dim=(2, 3, 4)) + dice = (2.0 * intersection + self.smooth) / (union + self.smooth) + return 1.0 - dice.mean() + + +class DiceCoefficient: + """Compute Dice coefficient for validation.""" + def __init__(self, smooth=1.0, num_classes=6): + self.smooth = smooth + self.num_classes = num_classes + + def compute(self, pred, target): + """pred: (B, C, H, W, D), target: (B, 1, H, W, D)""" + pred = torch.argmax(pred, dim=1, keepdim=True) + dice_scores = [] + + for c in range(self.num_classes): + pred_c = (pred == c).float() + target_c = (target == c).float() + intersection = torch.sum(pred_c * target_c) + union = torch.sum(pred_c) + torch.sum(target_c) + if union == 0: + dice_scores.append(1.0) + else: + dice = (2.0 * intersection + self.smooth) / (union + self.smooth) + dice_scores.append(dice.item()) + + return np.mean(dice_scores) + + +def train_epoch(model, train_loader, 
criterion, optimizer, device): + """Train one epoch.""" + model.train() + total_loss = 0.0 + pbar = tqdm(train_loader, desc="Training", ncols=100) + + for batch_idx, batch in enumerate(pbar): + images = batch["image"].to(device) + labels = batch["label"].to(device) + + optimizer.zero_grad() + outputs = model(images) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + total_loss += loss.item() + avg_loss = total_loss / (batch_idx + 1) + pbar.set_postfix({"loss": f"{avg_loss:.4f}"}) + + return total_loss / len(train_loader) + + +def validate(model, val_loader, criterion, device): + """Validate on val set and return loss and Dice.""" + model.eval() + total_loss = 0.0 + dice_metric = DiceCoefficient(num_classes=6) + total_dice = 0.0 + pbar = tqdm(val_loader, desc="Validating", ncols=100) + + with torch.no_grad(): + for batch_idx, batch in enumerate(pbar): + images = batch["image"].to(device) + labels = batch["label"].to(device) + + outputs = model(images) + + # Debug: print shapes and value ranges + if batch_idx == 0: + print(f"\nDebug - Outputs shape: {outputs.shape}, range: [{outputs.min():.4f}, {outputs.max():.4f}]") + print(f"Debug - Labels shape: {labels.shape}, unique values: {torch.unique(labels)}") + + loss = criterion(outputs, labels) + total_loss += loss.item() + + dice = dice_metric.compute(outputs, labels) + total_dice += dice + + avg_dice = total_dice / (batch_idx + 1) + pbar.set_postfix({"dice": f"{avg_dice:.4f}"}) + + return total_loss / len(val_loader), total_dice / len(val_loader) + + +def main(args): + """ + Train 3D U-Net model for prostate MRI segmentation. + + Training workflow: + 1. Initialize model and move to device (CPU/GPU) + 2. Setup Dice loss and Adam optimizer + 3. Load train/val datasets + 4. Train for specified epochs with validation + 5. Save best model based on validation Dice + 6. 
Use learning rate scheduling to adjust learning rate + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Create model - 6 classes (0,1,2,3,4,5) + model = UNet3D_Improved(in_channels=1, num_classes=6) + model = model.to(device) + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + # Loss and optimizer + criterion = DiceLoss(num_classes=6) + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, verbose=True) + + # Data loaders + train_dataset = Prostate3DDataset(root_dir=args.data_path, split="train") + val_dataset = Prostate3DDataset(root_dir=args.data_path, split="val") + + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0) + val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) + + print(f"Train samples: {len(train_dataset)}") + print(f"Val samples: {len(val_dataset)}") + + # Training loop + best_dice = 0.0 + best_model_path = "best_model.pth" + + for epoch in range(args.epochs): + print(f"\nEpoch {epoch + 1}/{args.epochs}") + + train_loss = train_epoch(model, train_loader, criterion, optimizer, device) + val_loss, val_dice = validate(model, val_loader, criterion, device) + + print(f"Train Loss: {train_loss:.4f}") + print(f"Val Loss: {val_loss:.4f}") + print(f"Val Dice: {val_dice:.4f}") + + scheduler.step(val_dice) + + if val_dice > best_dice: + best_dice = val_dice + torch.save(model.state_dict(), best_model_path) + print(f"Saved best model with Dice: {best_dice:.4f}") + + print(f"\nTraining completed. Best Dice: {best_dice:.4f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train 3D U-Net for prostate segmentation") + parser.add_argument("--data_path", type=str, default=r"C:\data\HipMRI_3D", help="Path to dataset") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") + parser.add_argument("--target_dice", type=float, default=0.7, help="Target Dice coefficient") + + args = parser.parse_args() + main(args) diff --git a/recognition/improved_unet_BochengLin/visualize_results.py b/recognition/improved_unet_BochengLin/visualize_results.py new file mode 100644 index 000000000..f976e6fd0 --- /dev/null +++ b/recognition/improved_unet_BochengLin/visualize_results.py @@ -0,0 +1,200 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +from torch.utils.data import DataLoader +import argparse +import os + +from dataset import Prostate3DDataset +from modules import UNet3D_Improved + + +def visualize_segmentation(model, test_loader, device, num_samples=3, save_dir=None): + """ + Generate and save segmentation visualization images. + Shows side-by-side comparison of input, ground truth, and predictions. 
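+    Images are written to save_dir (defaults to a results/ folder next to
+    this script) as segmentation_example_<n>.png.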
+ """ + + if save_dir is None: + script_dir = os.path.dirname(os.path.abspath(__file__)) + save_dir = os.path.join(script_dir, "results") + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + model.eval() + + sample_count = 0 + with torch.no_grad(): + for batch in test_loader: + if sample_count >= num_samples: + break + + images = batch["image"].to(device) + labels = batch["label"].to(device) + + outputs = model(images) + preds = torch.argmax(outputs, dim=1) + + batch_size = images.shape[0] + for b in range(batch_size): + if sample_count >= num_samples: + break + + # Extract middle slice for visualization + image = images[b, 0].cpu().numpy() + label = labels[b, 0].cpu().numpy() + pred = preds[b].cpu().numpy() + + mid_z = image.shape[2] // 2 + + # Create comparison figure + fig, axes = plt.subplots(1, 3, figsize=(15, 4)) + + # Input image + axes[0].imshow(image[:, :, mid_z], cmap='gray') + axes[0].set_title('Input Image (Middle Slice)') + axes[0].axis('off') + + # Ground truth labels + im1 = axes[1].imshow(label[:, :, mid_z], cmap='tab10', vmin=0, vmax=9) + axes[1].set_title('Ground Truth Label') + axes[1].axis('off') + + # Model prediction + im2 = axes[2].imshow(pred[:, :, mid_z], cmap='tab10', vmin=0, vmax=9) + axes[2].set_title('Model Prediction') + axes[2].axis('off') + + plt.tight_layout() + save_path = os.path.join(save_dir, f"segmentation_example_{sample_count + 1}.png") + plt.savefig(save_path, dpi=100, bbox_inches='tight') + print(f"Saved: {save_path}") + plt.close() + + sample_count += 1 + + +def compute_metrics(model, test_loader, device, num_classes=6): + """Compute Dice coefficient for each class on test set.""" + model.eval() + all_dice_scores = {c: [] for c in range(num_classes)} + + with torch.no_grad(): + for batch in test_loader: + images = batch["image"].to(device) + labels = batch["label"].to(device) + + outputs = model(images) + preds = torch.argmax(outputs, dim=1, keepdim=True) + + # Calculate Dice for each class + for c in range(num_classes): + pred_c = (preds == c).float() + target_c = (labels == c).float() + intersection = torch.sum(pred_c * target_c) + union = torch.sum(pred_c) + torch.sum(target_c) + + if union == 0: + dice = 1.0 + else: + dice = (2.0 * intersection + 1.0) / (union + 1.0) + dice = dice.item() + + all_dice_scores[c].append(dice) + + return all_dice_scores + + +def plot_metrics(all_dice_scores, save_dir=None): + """Generate and save Dice score bar chart.""" + + if save_dir is None: + script_dir = os.path.dirname(os.path.abspath(__file__)) + save_dir = os.path.join(script_dir, "results") + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + classes = list(all_dice_scores.keys()) + mean_dices = [np.mean(all_dice_scores[c]) for c in classes] + + fig, ax = plt.subplots(figsize=(10, 6)) + bars = ax.bar(classes, mean_dices, color='skyblue', edgecolor='navy', alpha=0.7) + + # Add value labels on each bar + for i, (c, dice) in enumerate(zip(classes, mean_dices)): + ax.text(i, dice + 0.02, f'{dice:.3f}', ha='center', va='bottom', fontsize=10) + + ax.set_xlabel('Class', fontsize=12) + ax.set_ylabel('Dice Coefficient', fontsize=12) + ax.set_title('Per-Class Dice Scores on Test Set', fontsize=14) + ax.set_ylim([0, 1.0]) + ax.axhline(y=0.7, color='red', linestyle='--', label='Target (0.7)', linewidth=2) + ax.legend() + ax.grid(axis='y', alpha=0.3) + + plt.tight_layout() + save_path = os.path.join(save_dir, "per_class_dice.png") + plt.savefig(save_path, dpi=100, bbox_inches='tight') + print(f"Saved: {save_path}") + plt.close() + + +def 
main(args): + """Generate segmentation visualizations and performance metrics.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Create results directory if needed + if args.save_dir is None: + args.save_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results") + os.makedirs(args.save_dir, exist_ok=True) + + # Load trained model + model = UNet3D_Improved(in_channels=1, num_classes=6) + model.load_state_dict(torch.load(args.model_path, map_location=device)) + model = model.to(device) + print(f"Loaded model from {args.model_path}") + + # Load test dataset + test_dataset = Prostate3DDataset(root_dir=args.data_path, split="test") + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) + + print(f"Generating visualizations for {args.num_samples} test samples...") + + # Generate segmentation visualizations + visualize_segmentation(model, test_loader, device, args.num_samples, args.save_dir) + + # Compute and plot Dice metrics + print("Computing per-class Dice scores...") + all_dice_scores = compute_metrics(model, test_loader, device, num_classes=6) + plot_metrics(all_dice_scores, args.save_dir) + + # Print summary + print("\n" + "="*50) + print("Results Summary") + print("="*50) + mean_dice_per_class = [] + for c in range(6): + mean_dice = np.mean(all_dice_scores[c]) + mean_dice_per_class.append(mean_dice) + print(f"Class {c}: {mean_dice:.4f}") + + overall_dice = np.mean(mean_dice_per_class) + print(f"\nOverall Dice: {overall_dice:.4f}") + print("="*50) + print(f"Visualizations saved to {args.save_dir}/") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate segmentation visualizations and metrics") + parser.add_argument("--data_path", type=str, default=r"C:\data\HipMRI_3D", help="Path to dataset root") + parser.add_argument("--model_path", type=str, default="best_model.pth", help="Path to trained model") + parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation") + parser.add_argument("--num_samples", type=int, default=3, help="Number of samples to visualize") + parser.add_argument("--save_dir", type=str, default=None, help="Output directory for visualizations") + + args = parser.parse_args() + main(args)