Batch-size more than 1 #845

@sulaimanvesal

Hi Team,

Thank you for creating MONAI.

I have been struggling for a few days to use MONAI for my task. I have a large dataset of prostate MRI scans with varying image sizes, pixel spacings, and intensities. My input size is 256x256x20, and I want to train a 3D UNet on this data. The main data loader is below, and it works fine with a batch size of 1. However, I want to train with a larger batch size (2, 4, or 8), since I have enough GPU memory. When I increase the batch size to 2, the data loader raises an error. I searched for it but found nothing.

I am not sure, but the issue may be caused by batching two images that have different sizes.

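# Imports assumed from the names used below; `data` is taken to be monai.data,
# since the traceback goes through monai's list_data_collate.
from monai import data
from monai.transforms import (
    AddChanneld, CenterSpatialCropd, Compose, LoadImaged, NormalizeIntensityd,
    RandAdjustContrastd, RandFlipd, RandScaleIntensityd, RandShiftIntensityd,
    SpatialPadd, Spacingd, ToTensord,
)

# A tuple default avoids sharing one mutable list between calls.
def get_loader(train_images, train_segs, valid_images, valid_segs, patch_size=(256, 256, 20)):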
    
    data_dicts_train = [{'image': image_name,
                   'label': label_name} for image_name, label_name in zip(train_images,train_segs)]
    train_transform = Compose([
        LoadImaged(keys=['image','label']),
        AddChanneld(keys=['image', 'label']),
        Spacingd(keys=["image", "label"], pixdim=[0.5, 0.5, 3.0], mode=("bilinear", "nearest")),
        CenterSpatialCropd(keys=['image', 'label'], roi_size = patch_size),
        SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        RandAdjustContrastd(keys=['image'],gamma=(0.5, 2.5),prob=0.2),
        #ScaleIntensityRanged(keys=['image'], a_min=0, a_max=255, b_min=0.0, b_max=1.0),
        ToTensord(keys=['image', 'label'])
    ])
    data_dicts_valid = [{'image': image_name,
                   'label': label_name} for image_name, label_name in zip(valid_images,valid_segs)]
    valid_transform = Compose([
        LoadImaged(keys=['image','label']),
        AddChanneld(keys=['image', 'label']),
        Spacingd(keys=["image", "label"], pixdim=[0.5, 0.5, 3.0], mode=("bilinear", "nearest")),
        CenterSpatialCropd(keys=['image', 'label'], roi_size = patch_size),
        SpatialPadd(keys=['image', 'label'], spatial_size=patch_size, method='end'),
        NormalizeIntensityd(keys=['image'], nonzero=True, channel_wise=True),
        #ScaleIntensityRanged(keys=['image'], a_min=0, a_max=255, b_min=0.0, b_max=1.0),
        ToTensord(keys=['image', 'label'])
    ])
    
    train_ds = data.Dataset(data=data_dicts_train, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True, #collate_fn=list_data_collate,
        num_workers=8,
        pin_memory=True,
    )
    val_ds = data.Dataset(data=data_dicts_valid, transform=valid_transform)
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=8, #collate_fn=list_data_collate,
        pin_memory=False
    )
    return train_loader, val_loader
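
As a rough check on whether mismatched image sizes could be the cause, the shape arithmetic of the crop-then-pad steps above can be sketched in plain Python. This is only a sketch under two assumptions: that CenterSpatialCropd leaves a dimension unchanged when the image is smaller than roi_size, and that SpatialPadd only grows dimensions that are smaller than spatial_size. The helper name `shape_after_crop_and_pad` is hypothetical and not part of MONAI.

```python
def shape_after_crop_and_pad(shape, patch_size):
    """Spatial shape after a center crop to patch_size followed by padding to patch_size."""
    # Assumed crop behaviour: a dimension is cut down to the ROI, never enlarged.
    cropped = [min(size, roi) for size, roi in zip(shape, patch_size)]
    # Assumed pad behaviour: a dimension is grown to the target, never shrunk.
    padded = [max(size, target) for size, target in zip(cropped, patch_size)]
    return tuple(padded)


patch_size = (256, 256, 20)
# Hypothetical resampled volumes: one larger than the patch, one smaller, one mixed.
for shape in [(384, 384, 28), (180, 200, 16), (300, 240, 20)]:
    print(shape, "->", shape_after_crop_and_pad(shape, patch_size))
```

If those assumptions hold, every sample leaves the pipeline with the same spatial shape, which would suggest that differing image sizes alone are unlikely to be what breaks batching.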

Here is the error:

Traceback (most recent call last):
  File "/home/CommonTools/anaconda3/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/home/CommonTools/anaconda3/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/monai/data/thread_buffer.py", line 48, in enqueue_values
    for src_val in self.src:
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/monai/data/utils.py", line 413, in list_data_collate
    ret[key] = default_collate(data_for_batch)
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/CommonTools/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in <listcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
KeyError: '0008|0060'
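
The last frames of the traceback iterate over the keys of the first sample in the batch and index every other sample with the same key, so a KeyError here means that at least one sample's dictionary lacks a key the first one has. The key '0008|0060' has the form of a DICOM tag (group 0008, element 0060 is Modality), which suggests, as an assumption, that the per-image metadata dictionaries differ between files. The sketch below reproduces that mechanism in plain Python and shows one possible workaround: keeping only the metadata keys that all samples share. The function names, the `image_meta_dict` key, and the sample values are illustrative assumptions, not taken from the failing run.

```python
def naive_dict_collate(dicts):
    """Mimic the failing pattern: use the first dict's keys to index every dict."""
    first = dicts[0]
    return {key: [d[key] for d in dicts] for key in first}


def prune_nested_dicts(batch):
    """Return copies of the samples whose nested dicts keep only the shared keys."""
    pruned = [dict(sample) for sample in batch]
    for key in batch[0]:
        values = [sample.get(key) for sample in batch]
        if all(isinstance(value, dict) for value in values):
            shared = [k for k in values[0] if all(k in v for v in values[1:])]
            for sample, value in zip(pruned, values):
                sample[key] = {k: value[k] for k in shared}
    return pruned


# Two hypothetical samples: only the first carries the DICOM-style tag.
batch = [
    {"image": "img0", "image_meta_dict": {"spacing": (0.5, 0.5, 3.0), "0008|0060": "MR"}},
    {"image": "img1", "image_meta_dict": {"spacing": (0.4, 0.4, 3.0)}},
]

try:
    naive_dict_collate([sample["image_meta_dict"] for sample in batch])
except KeyError as err:
    print("collate failed on key:", err)

cleaned = prune_nested_dicts(batch)
print(naive_dict_collate([sample["image_meta_dict"] for sample in cleaned]))
```

With a batch size of 1 there is only one dictionary to index, which would explain why the error appears only from batch size 2 upward. A pruning step like this could be run inside a custom `collate_fn` passed to the DataLoader before the regular collate is called, though whether that is the right fix depends on which metadata the training code needs.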
