From d267c0896f387935f983b63525c0c42890a9a812 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 8 May 2023 15:03:09 +0800
Subject: [PATCH 1/6] [booster] fix no_sync method

---
 colossalai/booster/plugin/gemini_plugin.py         | 5 ++++-
 colossalai/booster/plugin/low_level_zero_plugin.py | 5 ++++-
 colossalai/booster/plugin/plugin_base.py           | 9 ++++++++-
 colossalai/booster/plugin/torch_ddp_plugin.py      | 6 +++++-
 4 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index fde8912a648f..4b32a0ea62c4 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -2,7 +2,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -286,3 +286,6 @@ def control_checkpoint_io(self) -> bool:
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 828d8b27422f..b153ff3639db 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -197,3 +197,6 @@ def control_checkpoint_io(self) -> bool:
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 7a222022c1b2..9e0c9066ce6b 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -59,3 +59,10 @@ def get_checkpoint_io(self) -> CheckpointIO:
         Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True.
         """
         pass
+
+    @abstractmethod
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        """
+        Context manager to disable gradient synchronization.
+        """
+        pass
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index d30d266c0048..6ec6b7dc430c 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -142,3 +142,7 @@ def control_checkpoint_io(self) -> bool:
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+        return model.module.no_sync()

From dea9249363241106eb1479a911392b1ad86b7b78 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 8 May 2023 15:44:19 +0800
Subject: [PATCH 2/6] [booster] add test for ddp no_sync

---
 .../test_plugin/test_torch_ddp_plugin.py           | 61 +++++++++++++++++++
 1 file changed, 61 insertions(+)

diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 30c4db12309f..5d171bddb414 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -1,7 +1,11 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
+from torch.utils.data import TensorDataset
 
 import colossalai
 from colossalai.booster import Booster
@@ -44,10 +48,67 @@ def check_torch_ddp_plugin():
     torch.cuda.empty_cache()
 
 
+class DummyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(1))
+
+    def forward(self, x):
+        return self.weight * x
+
+
+def check_torch_ddp_no_sync():
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = DummyModel()
+    criterion = lambda x: x.mean()
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    # create a dataset with integers 0 to 9
+    dataset = torch.arange(0, 10)
+    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
+                                                                     optimizer,
+                                                                     criterion,
+                                                                     dataloader=train_dataloader)
+
+    def fwd_bwd():
+        output = model(batch.cuda())
+        loss = criterion(output)
+        booster.backward(loss, optimizer)
+
+    def get_grad_set_over_all_ranks():
+        for p in model.parameters():
+            # grad shape is (1, )
+            assert p.grad.shape == (1,)
+            grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
+            dist.all_gather(grad_list, p.grad)
+            # get grad set of all ranks
+            grad_set = set([grad.item() for grad in grad_list])
+            # as the model only has one parameter, we can return here
+            return grad_set
+
+    for i, batch in enumerate(train_dataloader):
+        if i > 1:
+            # only check the first two batches
+            break
+        # no_sync for the first batch, sync for the second batch
+        ctx = booster.no_sync(model) if i == 0 else nullcontext()
+        with ctx:
+            fwd_bwd()
+        grad_set = get_grad_set_over_all_ranks()
+        # for the first batch, all ranks should have different grads
+        # for the second batch, as grad is synchronized, all ranks should have the same grads
+        target_num_different_grad = dist.get_world_size() if i == 0 else 1
+        assert len(grad_set) == target_num_different_grad
+
+
 def run_dist(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     check_torch_ddp_plugin()
+    check_torch_ddp_no_sync()
 
 
 @rerun_if_address_is_in_use()

From a540dbc9c959a265dc321833140e5ad141695405 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 8 May 2023 15:47:20 +0800
Subject: [PATCH 3/6] [booster] fix merge

---
 colossalai/booster/plugin/plugin_base.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index 4cc95f7eb01a..561f58bc5570 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -67,6 +67,7 @@ def no_sync(self, model: nn.Module) -> Iterator[None]:
         """
         pass
 
+    @abstractmethod
     def prepare_dataloader(self, dataset: Dataset, batch_size: int,

From 7e92c2d26154d1ee86812353020be47284c7959a Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 8 May 2023 15:47:49 +0800
Subject: [PATCH 4/6] [booster] update unit test

---
 tests/test_booster/test_plugin/test_torch_ddp_plugin.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 5d171bddb414..8b3b31aaf722 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -5,7 +5,6 @@
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
-from torch.utils.data import TensorDataset
 
 import colossalai
 from colossalai.booster import Booster

From ed27989a694fa5d732fb9e7c111f728d81eb2f5c Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 8 May 2023 16:48:49 +0800
Subject: [PATCH 5/6] [booster] update unit test

---
 tests/test_booster/test_plugin/test_torch_ddp_plugin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 8b3b31aaf722..fbe44e5ce6fb 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -66,7 +66,7 @@ def check_torch_ddp_no_sync():
     optimizer = SGD(model.parameters(), lr=1e-3)
     # create a dataset with integers 0 to 9
     dataset = torch.arange(0, 10)
-    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
     model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
                                                                      optimizer,
                                                                      criterion,

From 2ae68b24dcb58afef06d50f7d53ae1608e7d32b1 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 8 May 2023 17:21:40 +0800
Subject: [PATCH 6/6] [booster] update unit test

---
 tests/test_booster/test_plugin/test_dp_plugin_base.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py
index eab949828db9..61aeded12203 100644
--- a/tests/test_booster/test_plugin/test_dp_plugin_base.py
+++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -49,6 +49,9 @@ def supported_devices(self) -> List[str]:
         pass
 
     def supported_precisions(self) -> List[str]:
         pass
 
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        pass
+
 
 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
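
A usage note on the API this series introduces: booster.no_sync(model) yields a context manager
that suppresses DDP's gradient all-reduce during backward, the standard way to save communication
while accumulating gradients. GeminiPlugin and LowLevelZeroPlugin raise NotImplementedError here,
presumably because their sharded (ZeRO-style) gradients have no direct analogue of skipping an
all-reduce. The sketch below shows the intended pattern; it is a minimal example only, assuming a
torchrun-launched environment, and everything outside the booster calls (the linear model, the
dummy dataset, accum_steps) is an illustrative placeholder rather than part of the patches.

    # Gradient accumulation sketch using the no_sync API added in this series.
    # Minimal example, assuming a distributed environment launched via torchrun;
    # the model, dataset and accum_steps below are illustrative placeholders.
    from contextlib import nullcontext

    import torch
    import torch.nn as nn
    from torch.optim import SGD

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    colossalai.launch_from_torch(config=dict())

    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = nn.Linear(8, 1)
    optimizer = SGD(model.parameters(), lr=1e-3)
    criterion = lambda output: output.mean()
    dataset = torch.rand(64, 8)    # placeholder dataset
    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=8)
    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
                                                                     optimizer,
                                                                     criterion,
                                                                     dataloader=train_dataloader)

    accum_steps = 4
    for step, batch in enumerate(train_dataloader):
        sync_step = (step + 1) % accum_steps == 0
        # Skip the gradient all-reduce on pure accumulation steps; only the step
        # that flushes the accumulated gradients pays the communication cost.
        ctx = nullcontext() if sync_step else booster.no_sync(model)
        with ctx:
            loss = criterion(model(batch.cuda()))
            booster.backward(loss, optimizer)
        if sync_step:
            optimizer.step()
            optimizer.zero_grad()

This mirrors the test added in patch 2/6: inside the no_sync window each rank keeps its own local
gradient, and the first synchronized backward after the window all-reduces the accumulated
gradients across ranks.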