Closed
Description
Hi Team,
Thank you for creating MONAI.
I have been struggling for a few days to use MONAI for my task. I have a large prostate MRI dataset whose images vary in size, pixel spacing, and intensity. My input size is 256x256x20, and I want to train a 3D UNet on this data. My data loader is below, and it works fine with a batch size of 1. However, I want to train with a larger batch size (2, 4, or 8), since I have enough GPU memory. When I increase the batch size to 2, the data loader raises an error, and searching for it turned up nothing.
I am not sure, but maybe reading two images with different sizes causes this issue.
```python
# Imports assumed: `data` refers to monai.data.
from monai import data
from monai.transforms import (
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    RandAdjustContrastd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    Spacingd,
    SpatialPadd,
    ToTensord,
)


def get_loader(train_images, train_segs, valid_images, valid_segs, patch_size=[256, 256, 20]):
    data_dicts_train = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(train_images, train_segs)]
    train_transform = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        # Resample to a common spacing, then crop/pad to a fixed spatial size.
        Spacingd(keys=["image", "label"], pixdim=[0.5, 0.5, 3.0], mode=("bilinear", "nearest")),
        CenterSpatialCropd(keys=['image', 'label'], roi_size=patch_size),
        SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'),
        # Augmentation.
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.2),
        # ScaleIntensityRanged(keys=['image'], a_min=0, a_max=255, b_min=0.0, b_max=1.0),
        ToTensord(keys=['image', 'label']),
    ])

    data_dicts_valid = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(valid_images, valid_segs)]
    valid_transform = Compose([
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        Spacingd(keys=["image", "label"], pixdim=[0.5, 0.5, 3.0], mode=("bilinear", "nearest")),
        CenterSpatialCropd(keys=['image', 'label'], roi_size=patch_size),
        SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'),
        NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        # ScaleIntensityRanged(keys=['image'], a_min=0, a_max=255, b_min=0.0, b_max=1.0),
        ToTensord(keys=['image', 'label']),
    ])

    train_ds = data.Dataset(data=data_dicts_train, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=1,  # works with 1; fails with 2 or more
        shuffle=True,  # collate_fn=list_data_collate,
        num_workers=8,
        pin_memory=True,
    )
    val_ds = data.Dataset(data=data_dicts_valid, transform=valid_transform)
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=8,  # collate_fn=list_data_collate,
        pin_memory=False,
    )
    return train_loader, val_loader
```
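For context, `Spacingd` changes the voxel count per axis (roughly voxels × old spacing ÷ new spacing), and `CenterSpatialCropd` followed by `SpatialPadd` is meant to bring every sample back to `patch_size`. The minimal NumPy sketch below uses two hypothetical helpers, `resampled_shape` and `center_crop_then_end_pad`. They are not MONAI's implementation, whose rounding and crop offsets may differ. Under those assumptions, the sketch suggests that differing source sizes alone should not leave mismatched shapes in a batch:

```python
import numpy as np


def resampled_shape(shape, spacing, new_spacing):
    """Hypothetical helper: approximate voxel count per axis after resampling.
    Physical extent (voxels * spacing) divided by the new spacing, rounded."""
    return tuple(int(round(n * s / t)) for n, s, t in zip(shape, spacing, new_spacing))


def center_crop_then_end_pad(arr, size):
    """Hypothetical helper mirroring CenterSpatialCropd followed by SpatialPadd(method='end')."""
    # Center-crop any axis that is longer than the target size.
    slices = []
    for n, t in zip(arr.shape, size):
        start = max((n - t) // 2, 0)
        slices.append(slice(start, start + min(n, t)))
    cropped = arr[tuple(slices)]
    # Zero-pad at the end of any axis that is still shorter than the target size.
    pad_width = [(0, t - n) for n, t in zip(cropped.shape, size)]
    return np.pad(cropped, pad_width)


patch_size = (256, 256, 20)
target_spacing = (0.5, 0.5, 3.0)
# Made-up example volumes: (shape in voxels, spacing in mm).
volumes = [((320, 320, 24), (0.625, 0.625, 3.6)), ((512, 512, 18), (0.3, 0.3, 4.0))]
for shape, spacing in volumes:
    new_shape = resampled_shape(shape, spacing, target_spacing)
    out = center_crop_then_end_pad(np.zeros(new_shape, dtype=np.float32), patch_size)
    print(shape, "->", new_shape, "->", out.shape)
```

With this arithmetic, the two made-up volumes resample to 400x400x29 and 307x307x24 voxels, and both end up as 256x256x20 arrays after the crop and pad.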
Here is the error:
```
Traceback (most recent call last):
  File "/home/CommonTools/anaconda3/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/CommonTools/anaconda3/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/monai/data/thread_buffer.py", line 48, in enqueue_values
    for src_val in self.src:
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/monai/data/utils.py", line 413, in list_data_collate
    ret[key] = default_collate(data_for_batch)
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in <listcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
KeyError: '0008|0060'
```
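The traceback ends inside `default_collate`, which, for a batch of dictionaries, takes the keys from the first sample and looks each one up in every other sample. `'0008|0060'` has the form of a DICOM tag (Modality) and presumably comes from an image's metadata dictionary. One plausible reading is that the first image in the batch carries a metadata key that the second one lacks, which cannot happen with a batch size of 1. The pure-Python sketch below illustrates that mechanism with two hypothetical functions. `naive_collate` is a simplified stand-in for the dictionary branch of PyTorch's `default_collate`, with no tensor stacking. `collate_shared_keys` is one possible workaround that keeps only the keys present in every sample:

```python
def naive_collate(batch):
    """Simplified stand-in for the dict branch of torch's default_collate:
    keys come from the FIRST sample and are looked up in every sample."""
    elem = batch[0]
    if isinstance(elem, dict):
        return {key: naive_collate([d[key] for d in batch]) for key in elem}
    return list(batch)  # leaf values are just grouped, not stacked into tensors


def collate_shared_keys(batch):
    """Hypothetical workaround: recurse only into keys present in every sample."""
    elem = batch[0]
    if isinstance(elem, dict):
        shared = [key for key in elem if all(key in d for d in batch)]
        return {key: collate_shared_keys([d[key] for d in batch]) for key in shared}
    return list(batch)


# Made-up samples: only the first image's metadata carries the DICOM Modality tag.
batch = [
    {"image": "img0", "image_meta_dict": {"spacing": (0.5, 0.5, 3.0), "0008|0060": "MR"}},
    {"image": "img1", "image_meta_dict": {"spacing": (0.5, 0.5, 3.0)}},
]

print(naive_collate(batch[:1]))  # batch size 1: nothing to mismatch
try:
    naive_collate(batch)
except KeyError as err:
    print("KeyError:", err)
print(collate_shared_keys(batch))
```

If the sample with the extra key happened to come second, the key would be silently dropped instead of raising an error, so the failure may appear only for some batches. In a real MONAI pipeline, the analogous fix would be to remove or filter the metadata dictionaries (e.g. `image_meta_dict`, `label_meta_dict`) before batching, for example with `DeleteItemsd` or `SelectItemsd` if your MONAI version provides them, or to pass a custom `collate_fn`. This is a suggestion, not a confirmed resolution of this issue.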