Merged
48 changes: 18 additions & 30 deletions colossalai/zero/low_level/_utils.py
@@ -3,8 +3,9 @@
 
 import torch
 import torch.distributed as dist
-from torch import inf
+from torch import Tensor, inf
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup
 
 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_model_parallel_parameter
@@ -194,56 +195,43 @@ def calculate_global_norm_from_list(norm_list):
     return math.sqrt(total_norm)
 
 
-def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
+def compute_norm(gradients: Tensor, dp_group: ProcessGroup, tp_group: ProcessGroup, norm_type: int = 2) -> int:
     """Clips gradient norm of an iterable of parameters.
     This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-    added functionality to handle model parallel parameters. Note that
-    the gradients are modified in place.
-    Arguments:
-        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-            single Tensor that will have gradients normalized
-        max_norm (float or int): max norm of the gradients
-        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
-            infinity norm.
+    added functionality to handle model parallel parameters.
+
+    Args:
+        gradients (Tensor): The gradients to compute norm
+        dp_group (ProcessGroup): The process group of ZeRO Data Parallelism
+        tp_group (ProcessGroup): The process group of Tensor Parallelism
+        norm_type (int, optional): type of the used p-norm, Can be ``'inf'`` for infinity norm. Defaults to 2.
 
     Returns:
-        Total norm of the parameters (viewed as a single vector).
+        int: The total norm of given gradients
     """
 
-    if mp_group is None:
-        mp_rank = 0
-    else:
-        mp_rank = dist.get_rank(mp_group)
-
     norm_type = float(norm_type)
     if norm_type == inf:
         total_norm = max(g.data.abs().max() for g in gradients)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=dp_group)
 
         # Take max across all GPUs.
-        if mp_group is not None:
+        if tp_group is not None:
             dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX)
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0.0
-        # if dist.get_rank() == 0:
-        #    logger.info(f"Total Norm beginning {total_norm}")
 
-        for g, p in zip(gradients, params):
-            # Pipeline parallelism may replicate parameters. Avoid multi-counting.
-            tp_param_flag = False
-            if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
-                tp_param_flag = True
-            if tp_param_flag or mp_rank == 0:
-                param_norm = g.data.double().norm(2)
-                total_norm += param_norm.item()**2
+        for g in gradients:
+            param_norm = g.data.double().norm(2)
+            total_norm += param_norm.item()**2
 
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
 
-        if mp_group is not None:
-            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
+        if tp_group is not None:
+            dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=tp_group)
 
     total_norm = total_norm_cuda[0].item()**(1. / norm_type)
 
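The refactored compute_norm no longer needs the parameter list: every rank contributes the squared 2-norm of the gradient shards it owns, and the partial sums are combined with an all-reduce over the ZeRO data-parallel group, then over the tensor-parallel group when one is given. A minimal sketch of the same reduction pattern outside the library (the helper name and group handles are illustrative assumptions, not ColossalAI API):

    import torch
    import torch.distributed as dist

    def global_grad_norm(local_grads, dp_group=None, tp_group=None):
        # Squared 2-norms of the gradient shards owned by this rank.
        total = sum(g.data.double().norm(2).item() ** 2 for g in local_grads)
        total_t = torch.tensor([total], dtype=torch.float, device='cuda')
        # Combine partial sums across the ZeRO data-parallel ranks...
        dist.all_reduce(total_t, op=dist.ReduceOp.SUM, group=dp_group)
        # ...and across tensor-parallel ranks, mirroring compute_norm above.
        if tp_group is not None:
            dist.all_reduce(total_t, op=dist.ReduceOp.SUM, group=tp_group)
        return total_t.item() ** 0.5
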
74 changes: 18 additions & 56 deletions colossalai/zero/low_level/low_level_optim.py
@@ -5,19 +5,17 @@
 
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
 
 from colossalai.amp.naive_amp.mixed_precision_mixin import (
     BF16MixedPrecisionMixin,
     FP16MixedPrecisionMixin,
     MixedPrecisionMixin,
 )
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.tensor import ColoParameter, ProcessGroup
+# from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils.cuda import get_current_device
 
 from ._utils import (
@@ -77,11 +75,12 @@ def __init__(
                  overlap_communication: bool = False,
                  partition_grad: bool = False,    # stage 2 flag
                  cpu_offload: bool = False,    # cpu offload
+                 dp_process_group: Optional[ProcessGroup] = None,    # the dp pg for comm
+                 tp_process_group: Optional[ProcessGroup] = None,    # if using tp
                  forced_dtype: Optional[torch.dtype] = None):
 
         # TODO:
-        # 1. process group api
-        # 2. checkpoint IO
+        # 1. state_dict for checkpoint IO
 
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
@@ -96,30 +95,12 @@ def __init__(
         # grad accumulation
         self.require_grad_sync = True
 
-        colo_pg = self._search_colo_process_group()
-        if isinstance(colo_pg, ProcessGroup):
-            self._local_rank = colo_pg.dp_local_rank()
-            self._world_size = colo_pg.dp_world_size()
-            self._dp_global_ranks = colo_pg.get_ranks_in_dp()
-            self._dp_torch_group = colo_pg.dp_process_group()
-            self._mp_torch_group = None
-            if colo_pg.tp_world_size() > 1:
-                self._mp_torch_group = colo_pg.tp_process_group()
-        elif colo_pg is None:
-            dp_parallel_mode = ParallelMode.DATA
-            mp_parallel_mode = ParallelMode.MODEL
-
-            self._dp_parallel_mode = dp_parallel_mode
-            self._mp_parallel_mode = mp_parallel_mode
-            self._local_rank = gpc.get_local_rank(dp_parallel_mode)
-            self._world_size = gpc.get_world_size(dp_parallel_mode)
-            self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
-            self._dp_torch_group = gpc.get_group(dp_parallel_mode)
-            self._mp_torch_group = None
-            if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
-                self._mp_torch_group = gpc.get_group(mp_parallel_mode)
-        else:
-            raise NotImplementedError
+        # if process_group is none, will use the default one
+        self.dp_pg = dp_process_group
+        self._local_rank = dist.get_rank(group=self.dp_pg)
+        self._world_size = dist.get_world_size(group=self.dp_pg)
+
+        self.tp_pg = tp_process_group
 
         # working and master params for mixed precision training
         self._working_param_groups = dict()
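With the gpc lookup and ColoParameter discovery gone, callers hand the optimizer plain torch ProcessGroup handles. A hedged sketch of the wiring on four ranks arranged as 2-way data parallel by 2-way tensor parallel (the layout and variable names are assumptions for illustration):

    import torch.distributed as dist

    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    # Every rank must create every group; each rank then keeps the one it belongs to.
    dp_groups = [dist.new_group(ranks) for ranks in ([0, 2], [1, 3])]
    tp_groups = [dist.new_group(ranks) for ranks in ([0, 1], [2, 3])]
    dp_pg = dp_groups[rank % 2]     # ranks {0,2} or {1,3} share ZeRO shards
    tp_pg = tp_groups[rank // 2]    # ranks {0,1} or {2,3} shard tensors

    optimizer = LowLevelZeroOptimizer(torch_optimizer,    # torch_optimizer is assumed
                                      dp_process_group=dp_pg,
                                      tp_process_group=tp_pg)

Passing None for both falls back to the default (world) group, per the comment in the hunk above.
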
@@ -145,9 +126,9 @@
 
         # ParameterStore will manage the tensor buffers used for zero
         # it will not manage the tensors used by mixed precision training
-        self._param_store = ParameterStore(self._dp_torch_group)
-        self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad)
-        self._bucket_store = BucketStore(self._dp_torch_group)
+        self._param_store = ParameterStore(self.dp_pg)
+        self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
+        self._bucket_store = BucketStore(self.dp_pg)
 
         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -212,22 +193,6 @@ def _sanity_checks(self):
             assert param.dtype == self._dtype, \
                 f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
 
-    def _search_colo_process_group(self):
-        colo_flag = False
-        colo_pg = None
-        for param_group in self.optim.param_groups:
-            group_params = param_group['params']
-            for param in group_params:
-                if isinstance(param, ColoParameter):
-                    colo_flag = True
-                    if colo_pg is None:
-                        colo_pg = param.get_process_group()
-                    else:
-                        assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
-                elif colo_flag:
-                    raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
-        return colo_pg
-
     def _create_master_param_current_rank(self, param_list):
         # split each param evenly by world size
         params_current_rank = []
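The hunk above shows the start of _create_master_param_current_rank, which splits each param evenly by world size so every rank keeps exactly one master slice. A toy version of that partitioning, with zero-padding so the flat length divides the world size (the helper name is illustrative, not the library's):

    import torch

    def shard_for_rank(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        flat = param.detach().flatten()
        # Pad so the flat tensor splits into world_size equal chunks.
        pad = (-flat.numel()) % world_size
        if pad > 0:
            flat = torch.cat([flat, flat.new_zeros(pad)])
        return flat.chunk(world_size)[rank].clone()

    # A 10-element parameter on 4 ranks -> four 3-element shards (2 elements of padding).
    shards = [shard_for_rank(torch.arange(10.), r, 4) for r in range(4)]
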
@@ -291,7 +256,7 @@ def _run_reduction(self):
                 flat_grads = flat_grads.to(self._communication_dtype)
 
             if not self._partition_grads:
-                dist.all_reduce(flat_grads, group=self._dp_torch_group)
+                dist.all_reduce(flat_grads, group=self.dp_pg)
                 if flat_grads.dtype != grad_dtype:
                     flat_grads = flat_grads.to(grad_dtype)
 
@@ -307,7 +272,7 @@
             else:
                 flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
                 recieved_grad = torch.zeros_like(flat_grads_list[0])
-                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group)
+                dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
 
                 if recieved_grad.dtype != grad_dtype:
                     recieved_grad = recieved_grad.to(grad_dtype)
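In stage 2 the full all-reduce is replaced by a reduce-scatter, so each rank receives only the summed slice of the gradient bucket it owns. A self-contained sketch of that collective (the function name and shapes are assumptions; the bucket length must be divisible by the world size):

    import torch
    import torch.distributed as dist

    def partition_grads(flat_grads: torch.Tensor, dp_group=None) -> torch.Tensor:
        world_size = dist.get_world_size(group=dp_group)
        # One equal slice per rank from the flattened gradient bucket.
        grad_list = list(flat_grads.split(flat_grads.numel() // world_size))
        out = torch.zeros_like(grad_list[0])
        # After this call, rank i holds the sum over all ranks of slice i.
        dist.reduce_scatter(out, grad_list, group=dp_group)
        return out
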
Expand Down Expand Up @@ -425,10 +390,7 @@ def step(self, closure=None):

# compute norm
working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
norm_group = compute_norm(gradients=working_grads,
params=real_working_params[group_id],
dp_group=self._dp_torch_group,
mp_group=self._mp_torch_group)
norm_group = compute_norm(gradients=working_grads, dp_group=self.dp_pg, tp_group=self.tp_pg)
norm_groups.append(norm_group)

self._grad_store.reset_grads_by_group_id(group_id)
Expand All @@ -454,7 +416,7 @@ def step(self, closure=None):

for idx, splited_param in enumerate(master_working_param):
full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)]
dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group)
dist.all_gather(full_master_param, splited_param.cuda(), group=self.dp_pg)
working_param = real_working_params[group_id][idx]
full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param)
working_param.data.copy_(full_master_param)
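After the inner optimizer step updates the master shards, each rank all-gathers its peers' shards and copies the trimmed, reshaped result back into the working parameter. A sketch of that reassembly (names are illustrative; padding added during sharding is dropped by the slice):

    import torch
    import torch.distributed as dist

    def rebuild_working_param(shard: torch.Tensor, working_param: torch.Tensor, dp_group=None):
        world_size = dist.get_world_size(group=dp_group)
        gathered = [torch.zeros_like(shard) for _ in range(world_size)]
        dist.all_gather(gathered, shard, group=dp_group)
        # Concatenate the shards, drop padding, restore the original shape.
        full = torch.cat(gathered)[:working_param.numel()].reshape_as(working_param)
        working_param.data.copy_(full)
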
5 changes: 2 additions & 3 deletions tests/test_zero/test_low_level/test_zero_init.py
@@ -33,10 +33,9 @@ def exam_zero_init():
 
     assert optimizer1._local_rank == optimizer2._local_rank
     assert optimizer1._world_size == optimizer2._world_size
-    assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
 
-    mp_group1 = optimizer1._mp_torch_group
-    mp_group2 = optimizer2._mp_torch_group
+    mp_group1 = optimizer1.tp_pg
+    mp_group2 = optimizer2.tp_pg
     assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
     assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
 
4 changes: 3 additions & 1 deletion tests/test_zero/test_low_level/test_zero_tp.py
@@ -57,7 +57,9 @@ def exam_zero_with_tp(overlap_flag, partition_flag):
                                       initial_scale=2,
                                       clip_grad_norm=1.0,
                                       overlap_communication=overlap_flag,
-                                      partition_grad=partition_flag)
+                                      partition_grad=partition_flag,
+                                      dp_process_group=tp_pg.dp_process_group(),
+                                      tp_process_group=tp_pg.tp_process_group())
 
     dp_local_rank = tp_pg.dp_local_rank()
     set_seed(255 + dp_local_rank)
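The test derives the new arguments from ColossalAI's legacy ProcessGroup helper, which exposes the underlying torch groups. Roughly (the tp_degree value and optimizer variable are assumptions; the test's actual setup may differ):

    from colossalai.tensor import ProcessGroup

    tp_pg = ProcessGroup(tp_degree=2)    # assumed 2-way TP; dp degree follows from world size
    optimizer = LowLevelZeroOptimizer(torch_optimizer,    # torch_optimizer is assumed
                                      overlap_communication=True,
                                      partition_grad=True,
                                      dp_process_group=tp_pg.dp_process_group(),
                                      tp_process_group=tp_pg.tp_process_group())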