17 changes: 9 additions & 8 deletions colossalai/gemini/chunk.py
@@ -4,9 +4,8 @@
 from enum import Enum
 from typing import Optional, Dict, List
 
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 
 
 class TensorState(Enum):
@@ -65,14 +64,16 @@ class Chunk:
     def __init__(self,
                  chunk_size: int,
                  src_rank: int,
+                 process_group: ColoProcessGroup,
                  dtype: torch.dtype,
                  init_device: Optional[torch.device] = None,
                  force_data_on_cuda: bool = False) -> None:
         self.size = chunk_size
         self.utilized_size = 0
         self.src_rank = src_rank
-        self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
-        self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
+        self.process_group = process_group
+        self.is_src_rank = process_group.dp_local_rank() == src_rank
+        self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
         self.dtype = dtype
         device = init_device or get_current_device()
         if force_data_on_cuda:
@@ -150,7 +151,7 @@ def access(self) -> None:
         if not self.is_src_rank:
             alloc_storage(self._payload)
             self.move_device(get_current_device(), update_ptr=False)
-        dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+        dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
 
         # update tensor meta info
         self._update_tensors_ptr()
@@ -193,9 +194,9 @@ def reduce(self, is_all_reduce: bool = False) -> None:
         """
         self.move_device(get_current_device(), update_ptr=False)
         if is_all_reduce:
-            dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
+            dist.all_reduce(self.data, group=self.process_group.dp_process_group())
         else:
-            dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
+            dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
         self._update_tensors_ptr()
         self._update_tensors_state(TensorState.HOLD)
 
@@ -216,7 +217,7 @@ def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) ->
         # invalid calls will be ignored and nothing changes
         if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
             # print(
-            #     f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
+            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
             # )
             return
         self.tensors_info[tensor].state = tensor_state
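With this change a Chunk no longer consults gpc/ParallelMode.DATA; it resolves ranks through the ColoProcessGroup it is handed. A minimal sketch of the new construction path, assuming distributed initialization (e.g. colossalai.launch) has already happened and using the default ColoProcessGroup that spans all launched ranks; sizes and dtypes are illustrative:

import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini.chunk import Chunk

pg = ColoProcessGroup()            # default: pure data parallelism over all launched ranks
chunk = Chunk(chunk_size=1024,
              src_rank=0,          # owner rank inside the data-parallel group
              process_group=pg,    # new required argument, replaces the gpc lookups
              dtype=torch.half)

# rank bookkeeping now goes through the process group
assert chunk.is_src_rank == (pg.dp_local_rank() == 0)
assert chunk.global_src_rank == pg.get_ranks_in_dp()[0]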
15 changes: 9 additions & 6 deletions colossalai/gemini/chunk_mgr.py
@@ -2,9 +2,8 @@
 from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import get_current_device
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .chunk import Chunk, ChunkFullError, TensorState
 
 
@@ -20,10 +19,13 @@ class ChunkManager:
 
     def __init__(self,
                  chunk_size: Optional[int],
+                 process_group: ColoProcessGroup,
                  enable_distributed_storage: bool = False,
                  init_device: Optional[torch.device] = None) -> None:
         assert chunk_size is None or chunk_size > 0
+        assert isinstance(process_group, ColoProcessGroup)
         self.chunk_size = chunk_size
+        self.process_group = process_group
         self.enable_distributed_storage = enable_distributed_storage
         self.device = init_device or get_current_device()
         self.chunk_groups: Dict[str, Deque[Chunk]] = {}
@@ -69,6 +71,7 @@ def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
             src_rank = self._get_next_src_rank(group_name)
             chunk = Chunk(chunk_size,
                           src_rank,
+                          self.process_group,
                           tensor.dtype,
                           self.device,
                           force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
@@ -89,17 +92,17 @@ def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
     def _get_next_src_rank(self, group_name: str) -> int:
         if not self.enable_distributed_storage:
             # the chunk is owned by the current rank if no distributed storage is enabled
-            return gpc.get_local_rank(ParallelMode.DATA)
+            return self.process_group.dp_local_rank()
         if self.chunk_size is None:
             if group_name not in self.rank_load:
-                self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64)
+                self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
 
             # the process owning the tensor will be the process with the smallest number of elements
             src_rank = torch.argmin(self.rank_load[group_name]).item()
         else:
             # chunk is owned by processes in a round-robin fashion
             chunk_idx = len(self.chunk_groups[group_name])
-            src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
+            src_rank = chunk_idx % self.process_group.dp_world_size()
         return src_rank
 
     def access_chunk(self, chunk: Chunk) -> None:
@@ -222,7 +225,7 @@ def exec_lazy_release(self) -> None:
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
-        msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n'
+        msg = f'Rank {self.process_group.dp_local_rank()}:\n'
         msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
         for group_name, group in self.chunk_groups.items():
             msg += f'Group {group_name}:\n'
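ChunkManager now takes the ColoProcessGroup as its second positional argument and uses it for all rank and world-size queries instead of gpc. A minimal sketch mirroring the updated tests, assuming the distributed environment has been launched beforehand:

import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini import ChunkManager

pg = ColoProcessGroup()
chunk_manager = ChunkManager(chunk_size=128,
                             process_group=pg,                 # new required argument
                             enable_distributed_storage=True)  # shard chunk ownership across DP ranks
chunk_manager.create_group('param')
for p in [torch.rand(8, 8) for _ in range(3)]:
    chunk_manager.append_tensor(p, 'param')
print(chunk_manager)  # __repr__ now reports pg.dp_local_rank() instead of the gpc local rank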
9 changes: 3 additions & 6 deletions colossalai/nn/parallel/data_parallel.py
@@ -118,7 +118,7 @@ def grad_handle(self, p, grad):
             return empty_grad
 
         else:
-            #TODO(jiaruifang) fixme
+            # TODO(jiaruifang) fixme
             self.process_group.set_cpu_groups()
             dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
             return grad
@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP):
     For more details, see the API reference of ``GeminiManager``.
     """
 
-    def __init__(self,
-                 module: torch.nn.Module,
-                 gemini_manager: GeminiManager,
-                 process_group: Optional[ColoProcessGroup] = None) -> None:
-        super().__init__(module.half(), process_group=process_group)
+    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
+        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
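ZeroDDP drops its process_group parameter and reuses the group already stored on the chunk manager, keeping a single source of truth. A minimal sketch of the new wiring; the import paths and the wrap_with_zero helper are assumptions for illustration, not part of this diff:

import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.nn.parallel import ZeroDDP

def wrap_with_zero(module: torch.nn.Module) -> ZeroDDP:
    pg = ColoProcessGroup()
    chunk_manager = ChunkManager(None, pg, enable_distributed_storage=True)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    # no process_group argument any more: ZeroDDP reads it from gemini_manager.chunk_manager
    return ZeroDDP(module, gemini_manager)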
6 changes: 6 additions & 0 deletions colossalai/tensor/process_group.py
@@ -171,3 +171,9 @@ def cpu_dp_process_group(self):
     def cpu_tp_process_group(self):
         assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+
+    def get_ranks_in_dp(self):
+        return self._dp_rank_list
+
+    def get_ranks_in_tp(self):
+        return self._tp_rank_list
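The two new accessors expose the global rank lists that Chunk now uses to translate a data-parallel-local src_rank into a global root rank for collectives. A short usage sketch, assuming four launched ranks and that ColoProcessGroup accepts a tp_degree keyword as it does elsewhere in the codebase:

from colossalai.tensor import ProcessGroup as ColoProcessGroup

pg = ColoProcessGroup(tp_degree=2)   # assumed layout: 4 ranks -> dp_degree=2, tp_degree=2
print(pg.get_ranks_in_dp())          # global ranks sharing this rank's data-parallel group
print(pg.get_ranks_in_tp())          # global ranks sharing this rank's tensor-parallel group
# Chunk uses pg.get_ranks_in_dp()[src_rank] as the global root for broadcast/reduce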
6 changes: 3 additions & 3 deletions tests/test_ddp/test_ddp_ignore_params.py
@@ -33,11 +33,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 
 
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
+    pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size)
+    chunk_manager = ChunkManager(chunk_size, pg)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(module, gemini_manager, pg)
+    return ZeroDDP(module, gemini_manager)
 
 
 class Net(torch.nn.Module):
6 changes: 3 additions & 3 deletions tests/test_ddp/test_ddp_state_dict.py
@@ -28,11 +28,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
 
 
 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
+    pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(module, gemini_manager, process_group=pg)
+    return ZeroDDP(module, gemini_manager)
 
 
 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
8 changes: 4 additions & 4 deletions tests/test_tensor/test_chunk.py
@@ -7,8 +7,7 @@
 from colossalai.gemini import ChunkManager
 from colossalai.testing import rerun_if_address_is_in_use, parameterize
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 
 
 def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
@@ -38,12 +37,13 @@ def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 def run_chunk_zero(use_chunk, use_zero):
-    rank = gpc.get_local_rank(ParallelMode.DATA)
+    pg = ColoProcessGroup()
+    rank = pg.rank()
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
     params = [torch.rand(8, 8) for _ in range(3)]
     chunk_size = 128 if use_chunk else None
-    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
+    chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
     chunk_manager.create_group('param')
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
16 changes: 7 additions & 9 deletions tests/test_tensor/test_zero_optim.py
@@ -31,8 +31,6 @@ def check_param_equal(model, torch_model, pg: ProcessGroup):
 def check_grad_equal(model, torch_model, pg: ProcessGroup):
     for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
         if p.grad is not None:
-            torch.distributed.barrier()
-            print(torch.distributed.get_rank(), p.grad)
             assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                       pg.tp_local_rank(), pg.tp_world_size()), \
                 f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
@@ -63,9 +61,9 @@ def init_1d_col_spec(model, pg: ProcessGroup):
                 p.set_tensor_spec(*spec)
 
 
-@parameterize('use_chunk', [False])
-@parameterize('use_zero', [False])
-@parameterize('placement_policy', ['cuda'])
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu'])
 def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
@@ -92,10 +90,11 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
 
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
+                                 pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pg)
+    model = ZeroDDP(model, gemini_manager)
     optim = HybridAdam(model.parameters(), lr=1e-3)
     optim = ZeroOptimizer(optim, model, initial_scale=1)
 
@@ -104,7 +103,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
 
-    # print(chunk_manager)
+    print(chunk_manager)
     check_param_equal(model, torch_model, pg)
 
     model.eval()
@@ -129,13 +128,12 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     if world_size == 4:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
-        # run_gpt(tp_init_spec_func=init_1d_row_spec)
+        run_gpt(tp_init_spec_func=init_1d_row_spec)
     else:
         run_gpt(tp_init_spec_func=init_1d_col_spec)
 
 
 @pytest.mark.dist
-@pytest.mark.skip("buggy test")
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_gpt(world_size):
5 changes: 3 additions & 2 deletions tests/test_zero/test_zero_optim_state_dict.py
@@ -20,13 +20,14 @@
 
 
 def init_zero(model, use_chunk, use_zero, placement_policy):
+    pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
+                                 pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    pg = ProcessGroup()
-    return ZeroDDP(model, gemini_manager, pg)
+    return ZeroDDP(model, gemini_manager)
 
 
 def run_step(model, optim, criterion, data, label):