5 changes: 4 additions & 1 deletion colossalai/booster/plugin/gemini_plugin.py
@@ -2,7 +2,7 @@
import os
import warnings
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -286,3 +286,6 @@ def control_checkpoint_io(self) -> bool:

def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()

def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError
5 changes: 4 additions & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -197,3 +197,6 @@ def control_checkpoint_io(self) -> bool:

def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()

def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError
9 changes: 8 additions & 1 deletion colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union
from typing import Callable, Iterator, List, Tuple, Union

import torch.nn as nn
from torch.optim import Optimizer
@@ -60,6 +60,13 @@ def get_checkpoint_io(self) -> CheckpointIO:
"""
pass

@abstractmethod
def no_sync(self, model: nn.Module) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
pass

@abstractmethod
def prepare_dataloader(self,
dataset: Dataset,
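For context, the new abstract `no_sync` contract asks each plugin to return a context manager that suspends gradient all-reduce. Below is a minimal, hypothetical sketch of one way a plugin could satisfy it with `contextlib.contextmanager`; the `_sync_grads` flag is illustrative only and not an existing attribute in this PR.

```python
from contextlib import contextmanager
from typing import Iterator

import torch.nn as nn


@contextmanager
def _example_no_sync(model: nn.Module) -> Iterator[None]:
    # Illustrative only: flip a flag that the wrapped model could check before
    # launching gradient all-reduce, then restore it on exit.
    model._sync_grads = False
    try:
        yield
    finally:
        model._sync_grads = True
```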
6 changes: 5 additions & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, Iterator, List, Tuple, Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -142,3 +142,7 @@ def control_checkpoint_io(self) -> bool:

def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()

def no_sync(self, model: nn.Module) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
return model.module.no_sync()
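Usage-wise (a hedged sketch mirroring the test added below): wrapping backward passes in `booster.no_sync(model)` defers DDP's gradient all-reduce, so gradients can be accumulated locally over several micro-batches and synchronized only on the last one. The boosted objects and `accum_steps` here are assumed from a setup like the test's, not defined in this PR.

```python
# Illustrative gradient-accumulation loop (assumes model, optimizer, criterion,
# train_dataloader and booster were produced by booster.boost as in the test below).
from contextlib import nullcontext

accum_steps = 4  # hypothetical number of micro-batches per optimizer step
for i, batch in enumerate(train_dataloader):
    # skip gradient synchronization on all but the last micro-batch
    ctx = booster.no_sync(model) if (i + 1) % accum_steps != 0 else nullcontext()
    with ctx:
        loss = criterion(model(batch.cuda()))
        booster.backward(loss, optimizer)
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```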
5 changes: 4 additions & 1 deletion tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Union
from typing import Callable, Iterator, List, Tuple, Union

import torch
import torch.distributed as dist
@@ -49,6 +49,9 @@ def supported_devices(self) -> List[str]:
def supported_precisions(self) -> List[str]:
pass

def no_sync(self, model: nn.Module) -> Iterator[None]:
pass


def check_dataloader_sharding():
plugin = DPPluginWrapper()
60 changes: 60 additions & 0 deletions tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -1,5 +1,8 @@
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
torch.cuda.empty_cache()


class DummyModel(nn.Module):

def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.rand(1))

def forward(self, x):
return self.weight * x


def check_torch_ddp_no_sync():
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = DummyModel()
criterion = lambda x: x.mean()
optimizer = SGD(model.parameters(), lr=1e-3)
# create a custom dataset with values 0 to 10
dataset = torch.arange(0, 10)
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
optimizer,
criterion,
dataloader=train_dataloader)

def fwd_bwd():
output = model(batch.cuda())
loss = criterion(output)
booster.backward(loss, optimizer)

def get_grad_set_over_all_ranks():
for p in model.parameters():
# grad shape is (1, )
assert p.grad.shape == (1,)
grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
dist.all_gather(grad_list, p.grad)
# get grad set of all ranks
grad_set = set([grad.item() for grad in grad_list])
# as the model only has one parameter, we can return here
return grad_set

for i, batch in enumerate(train_dataloader):
if i > 1:
# only check the first two batches
break
# no_sync for the first batch, sync for the second batch
ctx = booster.no_sync(model) if i == 0 else nullcontext()
with ctx:
fwd_bwd()
grad_set = get_grad_set_over_all_ranks()
# for the first batch, all ranks should have different grads
# for the second batch, as grads are synchronized, all ranks should have the same grads
target_num_different_grad = dist.get_world_size() if i == 0 else 1
assert len(grad_set) == target_num_different_grad


def run_dist(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_torch_ddp_plugin()
check_torch_ddp_no_sync()


@rerun_if_address_is_in_use()