From 1251c161ae5ceb3189375a4d92f7a1c891731509 Mon Sep 17 00:00:00 2001
From: heyufan1995
Date: Wed, 12 Mar 2025 15:20:55 -0400
Subject: [PATCH 1/2] Add CVPR workshop baseline code

Signed-off-by: heyufan1995
---
 vista3d/README.md                      |   5 +
 vista3d/cvpr_workshop/Dockerfile       |  24 +++
 vista3d/cvpr_workshop/README.md        |  34 +++++
 vista3d/cvpr_workshop/infer_cvpr.py    | 147 +++++++++++++++++++
 vista3d/cvpr_workshop/predict.sh       |   4 +
 vista3d/cvpr_workshop/requirements.txt |  13 ++
 vista3d/cvpr_workshop/train_cvpr.py    | 196 +++++++++++++++++++++++++
 vista3d/cvpr_workshop/update_ckpt.py   |  33 +++++
 8 files changed, 456 insertions(+)
 create mode 100755 vista3d/cvpr_workshop/Dockerfile
 create mode 100644 vista3d/cvpr_workshop/README.md
 create mode 100755 vista3d/cvpr_workshop/infer_cvpr.py
 create mode 100755 vista3d/cvpr_workshop/predict.sh
 create mode 100755 vista3d/cvpr_workshop/requirements.txt
 create mode 100755 vista3d/cvpr_workshop/train_cvpr.py
 create mode 100755 vista3d/cvpr_workshop/update_ckpt.py

diff --git a/vista3d/README.md b/vista3d/README.md
index bf8c911..2122de2 100644
--- a/vista3d/README.md
+++ b/vista3d/README.md
@@ -13,6 +13,11 @@ limitations under the License.
 # MONAI **V**ersatile **I**maging **S**egmen**T**ation and **A**nnotation
 [[`Paper`](https://arxiv.org/pdf/2406.05285)] [[`Demo`](https://build.nvidia.com/nvidia/vista-3d)] [[`Checkpoint`]](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing)
+
+## News!
+[03/12/2025] We provide VISTA3D as a baseline for the challenge "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" ([link](https://www.codabench.org/competitions/5263/)). The simplified code, based on MONAI 1.4, is provided [here](./cvpr_workshop/).
+
+[02/26/2025] The VISTA3D paper has been accepted to **CVPR 2025**!
 ## Overview
 The **VISTA3D** is a foundation model trained systematically on 11,454 volumes encompassing 127 types of human anatomical structures and various lesions. It provides accurate out-of-the-box segmentation that matches state-of-the-art supervised models which are trained on each dataset. The model also achieves state-of-the-art zero-shot interactive segmentation in 3D, representing a promising step toward developing a versatile medical image foundation model.
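The workshop code added below builds directly on the VISTA3D components that ship with MONAI 1.4. As a quick orientation, the following minimal sketch shows how the `vista3d132` network used throughout `infer_cvpr.py` and `train_cvpr.py` is instantiated and loaded from a downloaded checkpoint; the file name `model_final.pth` is an assumption and should point at the checkpoint obtained from the link above.
```
# Minimal sketch: instantiate the VISTA3D network that ships with MONAI 1.4
# and load a downloaded checkpoint. "model_final.pth" is an assumed path;
# point it at the file obtained from the checkpoint link above.
import torch
from monai.networks.nets.vista3d import vista3d132

model = vista3d132(in_channels=1)  # single-channel 3D input
state_dict = torch.load("model_final.pth", map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.eval()
```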
diff --git a/vista3d/cvpr_workshop/Dockerfile b/vista3d/cvpr_workshop/Dockerfile
new file mode 100755
index 0000000..186d69f
--- /dev/null
+++ b/vista3d/cvpr_workshop/Dockerfile
@@ -0,0 +1,24 @@
+# Use an appropriate base image with GPU support
+FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
+RUN apt-get update && apt-get install -y \
+    python3 python3-pip && \
+    rm -rf /var/lib/apt/lists/*
+# Set working directory
+WORKDIR /workspace
+
+# Copy inference script and requirements
+COPY infer_cvpr.py /workspace/infer.py
+COPY train_cvpr.py /workspace/train.py
+COPY update_ckpt.py /workspace/update_ckpt.py
+COPY Dockerfile /workspace/Dockerfile
+COPY requirements.txt /workspace/
+COPY model_final.pth /workspace
+# Install Python dependencies
+RUN pip3 install -r requirements.txt
+
+# Copy the prediction script
+COPY predict.sh /workspace/predict.sh
+RUN chmod +x /workspace/predict.sh
+
+# Set default command
+CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md
new file mode 100644
index 0000000..6f13a92
--- /dev/null
+++ b/vista3d/cvpr_workshop/README.md
@@ -0,0 +1,34 @@
+
+
+# Overview
+This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation"([link](https://www.codabench.org/competitions/5263/)) challenge. It
+is based on MONAI 1.4. Many of the functions in the main VISTA3D repository have been moved into MONAI 1.4, and this simplified folder directly uses those components from MONAI.
+
+It is overly simplied to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed.
+
+# Setup
+```
+pip install -r requirements.txt
+```
+
+# Training
+Download VISTA3D pretrained checkpoint or from scratch. Generate a json list that contains your traning data.
+```
+torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py
+```
+
+# Inference
+We provide a Dockerfile to satisfy the challenge format. For more details, refer to the [challenge website]((https://www.codabench.org/competitions/5263/))
+
+
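The training entry point reads its file list through `NPZDataset` (defined in `train_cvpr.py` below), which simply `json.load`s a list of `.npz` file names and resolves them against a hard-coded base path. A minimal sketch of generating such a list is shown here; the file names are hypothetical placeholders.
```
# Sketch: write a training list in the format NPZDataset expects -- a JSON
# array of .npz file names, resolved against the base_path set in train_cvpr.py.
# The file names below are hypothetical placeholders.
import json

train_files = [
    "CT_case_0001.npz",
    "MR_case_0042.npz",
]
with open("subset.json", "w") as f:
    json.dump(train_files, f, indent=2)
```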
diff --git a/vista3d/cvpr_workshop/infer_cvpr.py b/vista3d/cvpr_workshop/infer_cvpr.py
new file mode 100755
index 0000000..7694965
--- /dev/null
+++ b/vista3d/cvpr_workshop/infer_cvpr.py
@@ -0,0 +1,147 @@
+import monai
+import monai.transforms
+import torch
+import argparse
+import numpy as np
+import nibabel as nib
+import glob
+from monai.networks.nets.vista3d import vista3d132
+from monai.utils import optional_import
+from monai.apps.vista3d.inferer import point_based_window_inferer
+from monai.inferers import SlidingWindowInfererAdapt
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+import numpy as np
+import pdb
+import os
+
+def convert_clicks(alldata):
+    # indexes = list(alldata.keys())
+    # data = [alldata[i] for i in indexes]
+    data = alldata
+    B = len(data)  # Number of objects
+    indexes = np.arange(1, B+1).tolist()
+    # Determine the maximum number of points across all objects
+    max_N = max(len(obj['fg']) + len(obj['bg']) for obj in data)
+
+    # Initialize padded arrays
+    point_coords = np.zeros((B, max_N, 3), dtype=int)
+    point_labels = np.full((B, max_N), -1, dtype=int)
+
+    for i, obj in enumerate(data):
+        points = []
+        labels = []
+
+        # Add foreground points
+        for fg_point in obj['fg']:
+            points.append(fg_point)
+            labels.append(1)
+
+        # Add background points
+        for bg_point in obj['bg']:
+            points.append(bg_point)
+            labels.append(0)
+
+        # Fill in the arrays
+        point_coords[i, :len(points)] = points
+        point_labels[i, :len(labels)] = labels
+
+    return point_coords, point_labels, indexes
+
+
+if __name__ == '__main__':
+    # set to true to save nifti files for visualization
+    save_data = False
+    point_inferer = True  # use point-based inference
+    roi_size = [128, 128, 128]
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--test_img_path", type=str, default='./tests')
+    parser.add_argument("--save_path", type=str, default='./outputs/')
+    parser.add_argument("--model", type=str, default='checkpoints/model_final.pth')
+    args = parser.parse_args()
+    os.makedirs(args.save_path, exist_ok=True)
+    # load model
+    checkpoint_path = args.model
+    model = vista3d132(in_channels=1)
+    pretrained_ckpt = torch.load(checkpoint_path, map_location='cuda')
+    model.load_state_dict(pretrained_ckpt, strict=True)
+
+    # load data
+    test_cases = glob.glob(os.path.join(args.test_img_path, "*.npz"))
+    for img_path in test_cases:
+        case_name = os.path.basename(img_path)
+        print(case_name)
+        img = np.load(img_path, allow_pickle=True)
+        img_array = img['imgs']
+        spacing = img['spacing']
+        original_shape = img_array.shape
+        affine = np.diag(spacing.tolist() + [1])  # 4x4 affine matrix
+        if save_data:
+            # Create a NIfTI image
+            nifti_img = nib.Nifti1Image(img_array, affine)
+            # Save the NIfTI file
+            nib.save(nifti_img, img_path.replace('.npz', '.nii.gz'))
+            nifti_img = nib.Nifti1Image(img['gts'], affine)
+            # Save the NIfTI file
+            nib.save(nifti_img, img_path.replace('.npz', 'gts.nii.gz'))
+        clicks = img.get('clicks', [{'fg': [[418, 138, 136]], 'bg': []}])
+        point_coords, point_labels, indexes = convert_clicks(clicks)
+        # preprocess
+        img_array = torch.from_numpy(img_array)
+        img_array = img_array.unsqueeze(0)
+        img_array = monai.transforms.ScaleIntensityRangePercentiles(lower=1, upper=99, b_min=0, b_max=1, clip=True)(img_array)
+        img_array = img_array.unsqueeze(0)  # add channel dim
+        device = 'cuda'
+        # slidingwindow
+        with torch.no_grad():
+            if not point_inferer:
+                model.NINF_VALUE = 0  # set to 0 in case sliding
window is used. + # directly using slidingwindow inferer is not optimal. + val_outputs = SlidingWindowInfererAdapt( + roi_size=roi_size, sw_batch_size=1, with_coord=True, padding_mode="replicate" + )( + inputs=img_array.to(device), + transpose=True, + network=model.to(device), + point_coords=torch.from_numpy(point_coords).to(device), + point_labels=torch.from_numpy(point_labels).to(device) + )[0] > 0 + final_outputs = torch.zeros_like(val_outputs[0], dtype=torch.float32) + for i, v in enumerate(val_outputs): + final_outputs += indexes[i] * v + else: + # point based + final_outputs = torch.zeros_like(img_array[0,0], dtype=torch.float32) + for i, v in enumerate(indexes): + val_outputs = point_based_window_inferer( + inputs=img_array.to(device), + roi_size=roi_size, + transpose=True, + with_coord=True, + predictor=model.to(device), + mode="gaussian", + sw_device=device, + device=device, + center_only=True, # only crop the center + point_coords=torch.from_numpy(point_coords[[i]]).to(device), + point_labels=torch.from_numpy(point_labels[[i]]).to(device) + )[0] > 0 + final_outputs[val_outputs[0]] = v + final_outputs = torch.nan_to_num(final_outputs) + # save data + if save_data: + # Create a NIfTI image + nifti_img = nib.Nifti1Image(final_outputs.to(torch.float32).data.cpu().numpy(), affine) + # Save the NIfTI file + nib.save(nifti_img, os.path.join(args.save_path, case_name.replace('.npz','.nii.gz'))) + np.savez_compressed(os.path.join(args.save_path, case_name), segs=final_outputs.to(torch.float32).data.cpu().numpy()) + + + + + + + + + + \ No newline at end of file diff --git a/vista3d/cvpr_workshop/predict.sh b/vista3d/cvpr_workshop/predict.sh new file mode 100755 index 0000000..438a705 --- /dev/null +++ b/vista3d/cvpr_workshop/predict.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +# Run inference script with input/output folder paths +python3 infer.py --test_img_path /workspace/inputs/ --save_path /workspace/outputs/ --model /workspace/model_final.pth diff --git a/vista3d/cvpr_workshop/requirements.txt b/vista3d/cvpr_workshop/requirements.txt new file mode 100755 index 0000000..c57fbde --- /dev/null +++ b/vista3d/cvpr_workshop/requirements.txt @@ -0,0 +1,13 @@ +tensorboard +matplotlib +monai +torchvision +nibabel +torch +connected-components-3d +pandas +numpy +scipy +cupy-cuda12x +cucim +tqdm \ No newline at end of file diff --git a/vista3d/cvpr_workshop/train_cvpr.py b/vista3d/cvpr_workshop/train_cvpr.py new file mode 100755 index 0000000..c10da20 --- /dev/null +++ b/vista3d/cvpr_workshop/train_cvpr.py @@ -0,0 +1,196 @@ +import os +import json +import monai.transforms +import torch +import torch.nn as nn +import torch.optim as optim +import torch.distributed as dist +from torch.utils.data import Dataset +from torch.nn.parallel import DistributedDataParallel as DDP +import numpy as np +import monai +from tqdm import tqdm +import pdb +from monai.networks.nets import vista3d132 +from monai.apps.vista3d.sampler import sample_prompt_pairs +from torch.utils.tensorboard import SummaryWriter +from monai.data import DataLoader, DistributedSampler +import warnings +import nibabel as nib +warnings.simplefilter("ignore") +# Custom dataset for .npz files + +import matplotlib.pyplot as plt +import torchvision.utils as vutils + +NUM_PATCHES_PER_IMAGE=4 + +def plot_to_tensorboard(writer, epoch, inputs, labels, points, outputs): + """ + Plots B figures, where each figure shows the slice where the point is located + and overlays the point on this slice. 
+
+    Args:
+        writer: TensorBoard writer
+        epoch: Current epoch number
+        inputs: Tensor [1, 1, H, W, D] - Input image
+        labels: Tensor [1, 1, H, W, D] - Ground truth segmentation
+        points: Tensor [B, N, 3] - Click points per object; the last coordinate indexes the depth (slice) dimension
+        outputs: Tensor [B, 1, H, W, D] - Model outputs
+    """
+    B, N, _ = points.shape  # B objects, N click points per object
+    inputs_np = inputs[0, 0].cpu().numpy()  # [H, W, D]
+    labels_np = labels[0, 0].cpu().numpy()  # [H, W, D]
+
+    for b in range(B):
+        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
+
+        # Take the first click point; the last coordinate indexes the depth slice
+        x, y, z = points[b, 0].cpu().numpy().astype(int)
+
+        # Extract the corresponding slice
+        input_slice = inputs_np[:, :, z]  # Get slice at depth z
+        label_slice = labels_np[:, :, z]
+        output_slice = outputs[b, 0].cpu().detach().numpy()[:, :, z] > 0
+
+        # Plot input with point overlay
+        axes[0].imshow(input_slice, cmap='gray')
+        axes[0].scatter(y, x, c='red', marker='x', s=50)
+        axes[0].set_title(f"Input (Slice {z})")
+
+        # Plot label with point overlay
+        axes[1].imshow(label_slice, cmap='gray')
+        axes[1].scatter(y, x, c='red', marker='x', s=50)
+        axes[1].set_title(f"Ground Truth (Slice {z})")
+
+        # Plot output with point overlay
+        axes[2].imshow(output_slice, cmap='gray')
+        axes[2].scatter(y, x, c='red', marker='x', s=50)
+        axes[2].set_title(f"Model Output (Slice {z})")
+
+        plt.tight_layout()
+
+        # Log figure to TensorBoard
+        writer.add_figure(f"Object_{b}_Segmentation", fig, epoch)
+        plt.close(fig)
+class NPZDataset(Dataset):
+    def __init__(self, json_file):
+        with open(json_file, 'r') as f:
+            self.file_paths = json.load(f)
+        self.base_path = '/workspace/VISTA/CVPR-MedSegFMCompetition/trainsubset'
+    def __len__(self):
+        return len(self.file_paths)
+
+    def __getitem__(self, idx):
+        img = np.load(os.path.join(self.base_path, self.file_paths[idx]))
+        img_array = torch.from_numpy(img['imgs']).unsqueeze(0).to(torch.float32)
+        label = torch.from_numpy(img['gts']).unsqueeze(0).to(torch.int32)
+        data = {"image": img_array, "label": label, 'filename': self.file_paths[idx]}
+        affine = np.diag(img['spacing'].tolist() + [1])  # 4x4 affine matrix
+        transforms = monai.transforms.Compose([
+            monai.transforms.ScaleIntensityRangePercentilesd(keys="image", lower=1, upper=99, b_min=0, b_max=1, clip=True),
+            monai.transforms.SpatialPadd(mode=["constant", "constant"], keys=["image", "label"], spatial_size=[128, 128, 128]),
+            monai.transforms.RandCropByLabelClassesd(spatial_size=[128, 128, 128], keys=["image", "label"], label_key="label", num_classes=label.max() + 1, num_samples=NUM_PATCHES_PER_IMAGE),
+            monai.transforms.RandScaleIntensityd(factors=0.2, prob=0.2, keys="image"),
+            monai.transforms.RandShiftIntensityd(offsets=0.2, prob=0.2, keys="image"),
+            monai.transforms.RandGaussianNoised(mean=0., std=0.2, prob=0.2, keys="image"),
+            monai.transforms.RandFlipd(spatial_axis=0, prob=0.2, keys=["image", "label"]),
+            monai.transforms.RandFlipd(spatial_axis=1, prob=0.2, keys=["image", "label"]),
+            monai.transforms.RandFlipd(spatial_axis=2, prob=0.2, keys=["image", "label"]),
+            monai.transforms.RandRotate90d(max_k=3, prob=0.2, keys=["image", "label"])
+        ])
+        data = transforms(data)
+        return data
+# Training function
+def train():
+    epoch_number = 100
+    start_epoch = 30
+    lr = 2e-5
+    checkpoint_dir = "checkpoints"
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    dist.init_process_group(backend="nccl")
+    world_size = int(os.environ["WORLD_SIZE"])
+    local_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local_rank)
+    device = torch.device(f"cuda:{local_rank}")
+
json_file = "subset.json" # Update with your JSON file + dataset = NPZDataset(json_file) + sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=local_rank) + dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32) + model = vista3d132(in_channels=1).to(device) + # pretrained_ckpt = torch.load('/workspace/VISTA/vista3d/bundles/vista3d/models/model.pt', map_location=device) + pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")) + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) + model.load_state_dict(pretrained_ckpt['model'], strict=True) + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05) + lr_scheduler = monai.optimizers.WarmupCosineSchedule(optimizer=optimizer, t_total= epoch_number+1, warmup_multiplier=0.1, warmup_steps=0) + if local_rank == 0: + writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "Events")) + + step = start_epoch * len(dataloader) * NUM_PATCHES_PER_IMAGE + + for epoch in range(start_epoch, epoch_number): + sampler.set_epoch(epoch) + for batch in tqdm(dataloader): + image_l = batch["image"] + label_l = batch["label"] + for _k in range(image_l.shape[0]): + inputs = image_l[[_k]].to(device) + labels = label_l[[_k]].to(device) + label_prompt, point, point_label, prompt_class = sample_prompt_pairs( + labels, + list(set(labels.unique().tolist()) - {0}), + max_point=5, + max_prompt=10, + drop_label_prob=1, + drop_point_prob=0, + ) + skip_update = torch.zeros(1, device=device) + if point is None: + print(f"Iteration skipped due to None prompts at {batch['filename']}") + skip_update = torch.ones(1, device=device) + if world_size > 1: + dist.all_reduce(skip_update, op=dist.ReduceOp.SUM) + if skip_update[0] > 0: + continue # some rank has no foreground, skip this batch + optimizer.zero_grad() + outputs = model( + input_images=inputs, + point_coords=point, + point_labels=point_label + ) + if local_rank==0 and step % 50 == 0: + plot_to_tensorboard(writer, step, inputs, labels, point, outputs) + + loss, loss_n = torch.tensor(0.0, device=device), torch.tensor( + 0.0, device=device + ) + if prompt_class is not None: + for idx in range(len(prompt_class)): + if prompt_class[idx] == 0: + continue # skip background class + loss_n += 1.0 + gt = labels == prompt_class[idx] + loss += monai.losses.DiceCELoss(include_background=False, sigmoid=True, smooth_dr=1.0e-05, + smooth_nr=0, softmax=False, squared_pred=True, + to_onehot_y=False)(outputs[[idx]].float(), gt.float()) + loss /= max(loss_n, 1.0) + print(loss) + loss.backward() + optimizer.step() + step += 1 + if local_rank == 0: + writer.add_scalar('loss', loss.item(), step) + if local_rank == 0 and epoch % 5 == 0: + checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch{epoch}.pth") + torch.save({'model': model.state_dict(), 'epoch': epoch, 'step':step}, checkpoint_path) + print(f"Rank {local_rank}, Epoch {epoch}, Loss: {loss.item()}, Checkpoint saved: {checkpoint_path}") + lr_scheduler.step() + + dist.destroy_process_group() + + +if __name__ == "__main__": + train() + # torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py \ No newline at end of file diff --git a/vista3d/cvpr_workshop/update_ckpt.py b/vista3d/cvpr_workshop/update_ckpt.py new file mode 100755 index 0000000..d798f66 --- /dev/null +++ b/vista3d/cvpr_workshop/update_ckpt.py @@ -0,0 +1,33 @@ +import torch +import argparse + +def remove_module_prefix(input_pth, output_pth): + # Load the checkpoint + checkpoint = 
torch.load(input_pth, map_location="cpu")['model']
+
+    # Modify the state_dict to remove 'module.' prefix
+    new_state_dict = {}
+    for key, value in checkpoint.items():
+        if isinstance(value, dict) and "state_dict" in value:
+            # If the checkpoint contains a 'state_dict' key (common in some saved models)
+            new_state_dict = {k.replace("module.", ""): v for k, v in value["state_dict"].items()}
+            value["state_dict"] = new_state_dict
+            torch.save(value, output_pth)
+            print(f"Updated weights saved to {output_pth}")
+            return
+        elif "module." in key:
+            new_state_dict[key.replace("module.", "")] = value
+        else:
+            new_state_dict[key] = value
+
+    # Save the modified weights
+    torch.save(new_state_dict, output_pth)
+    print(f"Updated weights saved to {output_pth}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Remove 'module.' prefix from PyTorch weights")
+    parser.add_argument("--input", required=True, help="Path to input .pth file")
+    parser.add_argument("--output", required=True, help="Path to save the modified .pth file")
+    args = parser.parse_args()
+
+    remove_module_prefix(args.input, args.output)

From 61dbaf6413bcaf01f5a489614c142cce8871ac59 Mon Sep 17 00:00:00 2001
From: heyufan1995
Date: Fri, 14 Mar 2025 09:16:50 -0400
Subject: [PATCH 2/2] Update readme

Signed-off-by: heyufan1995
---
 vista3d/cvpr_workshop/README.md     | 11 ++++++++---
 vista3d/cvpr_workshop/train_cvpr.py | 10 ++++++----
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/vista3d/cvpr_workshop/README.md b/vista3d/cvpr_workshop/README.md
index 6f13a92..fc9e675 100644
--- a/vista3d/cvpr_workshop/README.md
+++ b/vista3d/cvpr_workshop/README.md
@@ -15,7 +15,7 @@ limitations under the License.
 This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation"([link](https://www.codabench.org/competitions/5263/)) challenge. It
 is based on MONAI 1.4. Many of the functions in the main VISTA3D repository have been moved into MONAI 1.4, and this simplified folder directly uses those components from MONAI.
 
-It is overly simplied to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed.
+It is simplified to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed. The VISTA3D checkpoint finetuned on the challenge subsets is available [here](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing).
 
 # Setup
 ```
@@ -23,12 +23,17 @@ pip install -r requirements.txt
 ```
 
 # Training
-Download VISTA3D pretrained checkpoint or from scratch. Generate a json list that contains your traning data.
+Download the [checkpoint](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) finetuned on the challenge subsets or the original VISTA3D [checkpoint](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing). Generate a JSON list that contains your training data and update the JSON file path in the script.
 ```
 torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py
 ```
 
 # Inference
-We provide a Dockerfile to satisfy the challenge format. For more details, refer to the [challenge website]((https://www.codabench.org/competitions/5263/))
+You can directly download the prebuilt [Docker image](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) for the challenge baseline.
+We provide a Dockerfile to satisfy the challenge format. For more details, refer to the [challenge website](https://www.codabench.org/competitions/5263/).
+```
+docker build -t vista3d:latest .
+docker save -o vista3d.tar.gz vista3d:latest
+```
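To smoke-test the container (or `infer_cvpr.py` run directly) before submission, a small synthetic case can be written in the `.npz` layout the inference script reads (`imgs`, `spacing`, and optionally `clicks` with `fg`/`bg` point lists); the sketch below uses placeholder shapes, paths, and click coordinates, and the segmentation is read back from the `segs` key of the output file.
```
# Sketch: build a tiny synthetic input in the .npz layout read by infer_cvpr.py
# ('imgs', 'spacing', optional 'clicks'/'gts'). All values are placeholders.
import os
import numpy as np

os.makedirs("inputs", exist_ok=True)
volume = np.random.rand(64, 64, 64).astype(np.float32)
clicks = [{"fg": [[32, 32, 32]], "bg": []}]  # one object with one foreground click
np.savez_compressed(
    "inputs/demo_case.npz",
    imgs=volume,
    spacing=np.array([1.0, 1.0, 1.0]),
    clicks=clicks,
)

# After predict.sh (or infer_cvpr.py) has run, the segmentation for this case
# is stored under the 'segs' key of the matching file in the output folder:
# seg = np.load("outputs/demo_case.npz")["segs"]
```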
diff --git a/vista3d/cvpr_workshop/train_cvpr.py b/vista3d/cvpr_workshop/train_cvpr.py
index c10da20..cc9c71c 100755
--- a/vista3d/cvpr_workshop/train_cvpr.py
+++ b/vista3d/cvpr_workshop/train_cvpr.py
@@ -104,23 +104,25 @@ def __getitem__(self, idx):
         return data
 # Training function
 def train():
+    json_file = "subset.json"  # Update with your JSON file
     epoch_number = 100
-    start_epoch = 30
+    start_epoch = 0
     lr = 2e-5
     checkpoint_dir = "checkpoints"
+    start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth'
+
     os.makedirs(checkpoint_dir, exist_ok=True)
     dist.init_process_group(backend="nccl")
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ["LOCAL_RANK"])
     torch.cuda.set_device(local_rank)
     device = torch.device(f"cuda:{local_rank}")
-    json_file = "subset.json"  # Update with your JSON file
     dataset = NPZDataset(json_file)
     sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
     dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32)
     model = vista3d132(in_channels=1).to(device)
-    # pretrained_ckpt = torch.load('/workspace/VISTA/vista3d/bundles/vista3d/models/model.pt', map_location=device)
-    pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
+    pretrained_ckpt = torch.load(start_checkpoint, map_location=device)
+    # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
     model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
     model.load_state_dict(pretrained_ckpt['model'], strict=True)
     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)