20 changes: 9 additions & 11 deletions colossalai/booster/plugin/dp_plugin_base.py
@@ -20,21 +20,19 @@ def __init__(self) -> None:
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()

-    def prepare_train_dataloader(self,
-                                 dataset,
-                                 batch_size,
-                                 shuffle=False,
-                                 seed=1024,
-                                 drop_last=False,
-                                 pin_memory=False,
-                                 num_workers=0,
-                                 **kwargs):
+    def prepare_dataloader(self,
+                           dataset,
+                           batch_size,
+                           shuffle=False,
+                           seed=1024,
+                           drop_last=False,
+                           pin_memory=False,
+                           num_workers=0,
+                           **kwargs):
        r"""
        Prepare a dataloader for distributed training. The dataloader will be wrapped by
        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.

-        Note:
-        1. Evaluation datasets should not be passed to this function.

        Args:
            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
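The renamed method keeps the old signature, so existing call sites only change the method name. For orientation, here is a minimal sketch of what an implementation along the docstring's lines could look like, assuming a seeded `DistributedSampler` for reproducible shuffling; this is illustrative, not the actual ColossalAI implementation.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare_dataloader(self, dataset, batch_size, shuffle=False, seed=1024,
                       drop_last=False, pin_memory=False, num_workers=0, **kwargs):
    # Shard the dataset across ranks; self.rank and self.world_size are set
    # in __init__ above, and the fixed seed keeps shuffling consistent
    # across processes.
    sampler = DistributedSampler(dataset, num_replicas=self.world_size,
                                 rank=self.rank, shuffle=shuffle, seed=seed)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      drop_last=drop_last, pin_memory=pin_memory,
                      num_workers=num_workers, **kwargs)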
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -156,7 +156,7 @@ class GeminiPlugin(DPPluginBase):
    >>> model, train_dataset, optimizer, criterion = ...
    >>> plugin = GeminiPlugin()

-    >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+    >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    >>> booster = Booster(plugin=plugin)
    >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -95,7 +95,7 @@ class LowLevelZeroPlugin(DPPluginBase):
    >>> model, train_dataset, optimizer, criterion = ...
    >>> plugin = LowLevelZeroPlugin()

-    >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+    >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    >>> booster = Booster(plugin=plugin)
    >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

17 changes: 16 additions & 1 deletion colossalai/booster/plugin/plugin_base.py
@@ -4,7 +4,7 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

from colossalai.checkpoint_io import CheckpointIO
from colossalai.interface import OptimizerWrapper
@@ -59,3 +59,18 @@ def get_checkpoint_io(self) -> CheckpointIO:
        Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
        """
        pass
+
+    @abstractmethod
+    def prepare_dataloader(self,
+                           dataset: Dataset,
+                           batch_size: int,
+                           shuffle: bool = False,
+                           seed: int = 1024,
+                           drop_last: bool = False,
+                           pin_memory: bool = False,
+                           num_workers: int = 0,
+                           **kwargs):
+        """Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader`
+        """
+        pass
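With the abstract method on the base Plugin interface, every plugin now exposes a single dataloader entry point for any split, not just training. A short usage sketch follows; the variable names and batch sizes are illustrative, not taken from the PR.

# One method serves training and evaluation data alike; only the
# per-split arguments differ.
train_loader = plugin.prepare_dataloader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
eval_loader = plugin.prepare_dataloader(eval_dataset, batch_size=64)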
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -72,7 +72,7 @@ class TorchDDPPlugin(DPPluginBase):
    >>> model, train_dataset, optimizer, criterion = ...
    >>> plugin = TorchDDPPlugin()

-    >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+    >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    >>> booster = Booster(plugin=plugin)
    >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

10 changes: 2 additions & 8 deletions examples/tutorial/new_api/cifar_resnet/train.py
@@ -49,14 +49,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
                                         download=True)

    # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_dataloader, test_dataloader


10 changes: 2 additions & 8 deletions examples/tutorial/new_api/cifar_vit/train.py
@@ -63,14 +63,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase):
                                         download=True)

    # Data loader
-    train_dataloader = plugin.prepare_train_dataloader(train_dataset,
-                                                       batch_size=batch_size,
-                                                       shuffle=True,
-                                                       drop_last=True)
-    test_dataloader = plugin.prepare_train_dataloader(test_dataset,
-                                                      batch_size=batch_size,
-                                                      shuffle=False,
-                                                      drop_last=False)
+    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+    test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_dataloader, test_dataloader


16 changes: 8 additions & 8 deletions examples/tutorial/new_api/glue_bert/data.py
@@ -84,26 +84,26 @@ def prepare_data(self):
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
-        return self.plugin.prepare_train_dataloader(self.dataset["train"],
-                                                    batch_size=self.train_batch_size,
-                                                    shuffle=True,
-                                                    drop_last=True)
+        return self.plugin.prepare_dataloader(self.dataset["train"],
+                                              batch_size=self.train_batch_size,
+                                              shuffle=True,
+                                              drop_last=True)

    def val_dataloader(self):
        if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                for x in self.eval_splits
            ]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
-            return self.plugin.prepare_train_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [
-                self.plugin.prepare_train_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                for x in self.eval_splits
            ]

2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -55,7 +55,7 @@ def check_dataloader_sharding():

    # create a custom dataset with 0 to 10
    dataset = TensorDataset(torch.arange(0, 10))
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)

    # get the first batch of data
    batch = next(iter(train_dataloader))[0].cuda()
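The test shards a ten-element dataset with batch_size=2 and pulls each rank's first batch onto the GPU. The assertion itself falls below the visible part of the diff; a plausible continuation, sketched here rather than quoted from the test, would gather every rank's first batch and check that the shards do not overlap.

# Hypothetical sketch of the sharding check; the real test body is not
# shown in this diff.
import torch
import torch.distributed as dist

gathered = [torch.zeros_like(batch) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, batch)
if dist.get_rank() == 0:
    samples = torch.cat(gathered).tolist()
    # every sample should appear on exactly one rank
    assert len(samples) == len(set(samples))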