12 changes: 10 additions & 2 deletions colossalai/zero/low_level/_utils.py
@@ -91,10 +91,18 @@ def get_grad_accumulate_object(tensor):
     return grad_acc_obj


-def split_half_float_double(tensor_list):
+def split_by_dtype(tensor_list):
+    """
+    Splits a list of PyTorch tensors into sublists based on their data type.
+
+    :param tensor_list: A list of PyTorch tensors.
+    :type tensor_list: list[torch.Tensor]
+    :return: A list of sublists, where each sublist contains tensors of a specific data type.
+    :rtype: list[list[torch.Tensor]]
+    """
     dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
     buckets = []
-    for i, dtype in enumerate(dtypes):
+    for _, dtype in enumerate(dtypes):
         bucket = [t for t in tensor_list if t.type() == dtype]
         if bucket:
             buckets.append(bucket)
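
For orientation, a minimal usage sketch of the renamed helper (assuming a CUDA device is available and the import path implied by this file's location):

import torch

from colossalai.zero.low_level._utils import split_by_dtype

# Two fp16 tensors and one fp32 tensor produce two buckets, ordered by the
# hard-coded dtypes list (half before float); empty buckets are skipped.
tensors = [
    torch.ones(4, dtype=torch.float16, device='cuda'),
    torch.ones(4, dtype=torch.float32, device='cuda'),
    torch.ones(4, dtype=torch.float16, device='cuda'),
]
buckets = split_by_dtype(tensors)
assert len(buckets) == 2       # one fp16 bucket, one fp32 bucket
assert len(buckets[0]) == 2    # both fp16 tensors land in the first bucket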
40 changes: 20 additions & 20 deletions colossalai/zero/low_level/bookkeeping/parameter_store.py
@@ -11,9 +11,9 @@ class ParameterStore(BaseStore):
     def __init__(self, torch_pg: ProcessGroup):
         super().__init__(torch_pg)
         # param partitioning data structures
-        self._fp16_param_to_rank = dict()
-        self._rank_groupid_to_fp16_param_list = dict()
-        self._rank_group_id_to_flat_fp16_param = dict()
+        self._param_to_rank = dict()
+        self._rank_group_id_to_param_list = dict()
+        self._rank_group_id_to_flat_param = dict()

         # param reduction data structures
         self._is_param_reduced = dict()
@@ -29,7 +29,7 @@ def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
         :type rank: int
         """

-        self._fp16_param_to_rank[tensor] = rank
+        self._param_to_rank[tensor] = rank

     def get_param_rank(self, tensor: Tensor) -> int:
         """
@@ -38,7 +38,7 @@ def get_param_rank(self, tensor: Tensor) -> int:
         :param tensor: A :class:`torch.Tensor` object
         :type tensor: torch.Tensor
         """
-        return self._fp16_param_to_rank[tensor]
+        return self._param_to_rank[tensor]

     def belongs_to_current_rank(self, tensor) -> bool:
         """
@@ -51,29 +51,29 @@ def belongs_to_current_rank(self, tensor) -> bool:
         :rtype: bool
         """

-        tensor_rank = self._fp16_param_to_rank[tensor]
+        tensor_rank = self._param_to_rank[tensor]
         return tensor_rank == self._local_rank

-    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
-        if rank not in self._rank_groupid_to_fp16_param_list:
-            self._rank_groupid_to_fp16_param_list[rank] = dict()
+    def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
+        if rank not in self._rank_group_id_to_param_list:
+            self._rank_group_id_to_param_list[rank] = dict()

-        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
-            self._rank_groupid_to_fp16_param_list[rank][group_id] = []
+        if group_id not in self._rank_group_id_to_param_list[rank]:
+            self._rank_group_id_to_param_list[rank][group_id] = []

-        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)
+        self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list)

-    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
-        return self._rank_groupid_to_fp16_param_list[rank][group_id]
+    def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
+        return self._rank_group_id_to_param_list[rank][group_id]

-    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
-        if rank not in self._rank_group_id_to_flat_fp16_param:
-            self._rank_group_id_to_flat_fp16_param[rank] = dict()
+    def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None:
+        if rank not in self._rank_group_id_to_flat_param:
+            self._rank_group_id_to_flat_param[rank] = dict()

-        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor
+        self._rank_group_id_to_flat_param[rank][group_id] = tensor

-    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
-        return self._rank_group_id_to_flat_fp16_param[rank][group_id]
+    def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor:
+        return self._rank_group_id_to_flat_param[rank][group_id]

     def is_param_reduced(self, tensor):
         return self._is_param_reduced[tensor]
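
A minimal sketch of how this renamed bookkeeping API is driven during partitioning; `torch_pg` and `params_per_rank` are hypothetical stand-ins for an initialized process group and the per-rank parameter lists:

# Sketch only: torch_pg is an initialized torch.distributed ProcessGroup and
# params_per_rank is a list of parameter lists, one entry per rank.
store = ParameterStore(torch_pg)
for rank, params in enumerate(params_per_rank):
    store.add_param_list_by_rank_group(rank, group_id=0, tensor_list=params)
    for p in params:
        store.set_param_to_rank(p, rank)

# Later, any caller can look up the shard a given rank owns in group 0,
# or ask which rank owns a particular parameter:
shard = store.get_params_by_rank_group(rank=0, group_id=0)
owner = store.get_param_rank(shard[0])    # == 0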
109 changes: 55 additions & 54 deletions colossalai/zero/low_level/low_level_optim.py
@@ -21,7 +21,7 @@
     has_inf_or_nan,
     reduce_tensor_dp_group,
     release_param_grad,
-    split_half_float_double,
+    split_by_dtype,
     sync_param,
 )
 from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
@@ -90,9 +90,10 @@ def __init__(
             self._mp_torch_group = gpc.get_group(mp_parallel_mode)
         else:
             raise NotImplementedError
-        # fp16 and fp32 params for mixed precision training
-        self._fp16_param_groups = dict()
-        self._fp32_flat_param_groups_of_current_rank = dict()
+
+        # working and master params for mixed precision training
+        self._working_param_groups = dict()
+        self._master_flat_param_groups_of_current_rank = dict()

         # communication params
         self._overlap_communication = overlap_communication
@@ -138,8 +139,8 @@ def __init__(
                 if param.requires_grad:
                     group_params.append(param)

-            # add the fp16 params to fp16_param_groups for bookkeeping
-            self._fp16_param_groups[group_id] = group_params
+            # add the working params to working_param_groups for bookkeeping
+            self._working_param_groups[group_id] = group_params

             # assign parameters to ranks
             # the params in the list are sorted
@@ -148,7 +149,7 @@ def __init__(
             # store the mapping between param to rank
             # each param should belong to only one rank
             for rank, params in enumerate(params_per_rank):
-                self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
+                self._param_store.add_param_list_by_rank_group(rank, group_id, params)
                 for param in params:
                     self._param_store.set_param_to_rank(param, rank)
@@ -159,33 +160,33 @@

             # flatten the reordered tensors
             for rank in range(self._world_size):
-                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
+                tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
                 with torch.no_grad():
                     flat_tensor = flatten(tensor_list)
                 flat_tensor = flat_tensor.data.cuda()
-                self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)
+                self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor)

             # sync parameters
             for rank in range(self._world_size):
-                flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
-                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
+                flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id)
+                tensor_list = self._param_store.get_params_by_rank_group(rank, group_id)
                 sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)

-            # create a copy of fp32 weights of the parameters for which this rank is responsible
-            fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
-            fp32_flat_current_rank = fp16_flat_current_rank.float()
+            # create a copy of fp32 master weights of the parameters for which this rank is responsible
+            working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id)
+            master_flat_current_rank = working_flat_current_rank.float()
             device = 'cpu' if self._cpu_offload else get_current_device()
-            fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-            fp32_flat_current_rank.requires_grad = True
-            self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+            master_flat_current_rank = master_flat_current_rank.to(device)
+            master_flat_current_rank.requires_grad = True
+            self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank

             # need to replace the params in the `params` field in the optimizer
             # so that when the optimizer calls step(), it only updates the tensors
             # managed by this data parallel rank
-            param_group['params'] = [fp32_flat_current_rank]
+            param_group['params'] = [master_flat_current_rank]

             # set reduction state
-            for param in self._fp16_param_groups[group_id]:
+            for param in self._working_param_groups[group_id]:
                 self._param_store.set_param_reduction_state(param, False)

             # intialize communication stream for
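
The flatten-then-sync sequence above is the core of the mixed-precision setup: each rank's list of working params is fused into one flat buffer, the original params are re-pointed at slices of that buffer, and the locally owned shard is cloned to an fp32 master tensor that the inner optimizer will step. A condensed sketch of the same pattern using torch's stock helpers in place of ColossalAI's `flatten`/`sync_param` (an assumption that they behave equivalently here):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Hypothetical fp16 working shard owned by this rank.
working_params = [torch.randn(4, dtype=torch.float16, device='cuda') for _ in range(3)]

# Fuse into one flat buffer, then re-point each param at its slice so
# in-place updates to the flat tensor are visible through the params.
flat = _flatten_dense_tensors(working_params)
for p, q in zip(working_params, _unflatten_dense_tensors(flat, working_params)):
    p.data = q.data

# fp32 master copy of the shard: this is what the wrapped optimizer updates.
master_flat = flat.float()
master_flat.requires_grad = True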
@@ -209,7 +210,7 @@ def loss_scale(self):

     @property
     def num_param_groups(self):
-        return len(self._fp16_param_groups)
+        return len(self._working_param_groups)

     def _sanity_checks(self):
         assert torch.cuda.is_available(), 'CUDA is required'
@@ -261,10 +262,10 @@ def _grad_handler(self, param, grad, reduce_rank):
         return grad

     def _attach_reduction_hook(self):
-        # we iterate over the fp16 params
+        # we iterate over the working params
         # on each param, we register a hook to its AccumulateGrad object
         for group_id in range(self.num_param_groups):
-            param_group = self._fp16_param_groups[group_id]
+            param_group = self._working_param_groups[group_id]
             for param in param_group:
                 if param.requires_grad:
                     # determines the reduction destionation rank
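
The hook mentioned in the comment uses a standard autograd trick: a leaf tensor's AccumulateGrad node is reachable through the graph of a trivial expand, and a hook registered on that node fires during backward once the parameter's gradient is ready. A standalone sketch of the mechanism (not ColossalAI's actual handler, which goes on to bucket and reduce the grad):

import torch

param = torch.randn(4, requires_grad=True)

# expand_as builds a one-node graph; its next function is the AccumulateGrad
# object that writes param.grad during backward.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

def on_grad_ready(*_):
    # Runs when the AccumulateGrad node executes, i.e. once this
    # parameter's gradient has been produced.
    print('grad ready, norm =', param.grad.norm().item())

grad_acc.register_hook(on_grad_ready)
param.sum().backward()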
@@ -315,7 +316,7 @@ def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_ra
             self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)

     def _reduce_grads(self, reduce_rank, grads, bucket_size):
-        grad_buckets_by_dtype = split_half_float_double(grads)
+        grad_buckets_by_dtype = split_by_dtype(grads)

         for tensor_list in grad_buckets_by_dtype:
             self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
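
Splitting by dtype first matters because each bucket is flattened into one contiguous tensor for a single collective call, and flattening requires a uniform dtype. A hedged sketch of the bucket-reduce step using torch.distributed directly rather than ColossalAI's wrappers:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def reduce_bucket(bucket, group=None):
    # All tensors in `bucket` share one dtype, so they can be fused and
    # averaged across the data-parallel group with one all-reduce.
    flat = _flatten_dense_tensors(bucket)
    dist.all_reduce(flat, group=group)
    flat.div_(dist.get_world_size(group=group))
    # Scatter the averaged values back into the original grad tensors.
    for g, synced in zip(bucket, _unflatten_dense_tensors(flat, bucket)):
        g.copy_(synced)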
@@ -418,7 +419,7 @@ def zero_grad(self, set_to_none=True):
         :param set_to_none: Whether set the gradient to None. Default value is True.
         :type set_to_none: bool
         """
-        for _, param_group in self._fp16_param_groups.items():
+        for _, param_group in self._working_param_groups.items():
             for param in param_group:
                 if set_to_none:
                     param.grad = None
@@ -446,33 +447,33 @@ def step(self, closure=None):
             self.zero_grad()
             return

-        # copy the grad of fp16 param to fp32 param
+        # copy the grad of working param to master param
         single_grad_partition_groups = []
         norm_groups = []

         for group_id in range(self.num_param_groups):
             # compute norm
             norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
-                                      params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
-                                                                                             rank=self._local_rank),
+                                      params=self._param_store.get_params_by_rank_group(group_id=group_id,
+                                                                                        rank=self._local_rank),
                                       dp_group=self._dp_torch_group,
                                       mp_group=self._mp_torch_group)
             norm_groups.append(norm_group)

-            # create flat gradient for the flat fp32 params
-            fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
+            # create flat gradient for the flat fp32 master params
+            working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
+            flat_working_avg_grads = flatten(working_avg_grads)

-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+            dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype
+            flat_master_avg_grads = flat_working_avg_grads.to(dtype)

-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert param_shape == flat_fp32_avg_grads.shape, \
-                f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'
+            param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape
+            assert param_shape == flat_master_avg_grads.shape, \
+                f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}'

-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            single_grad_partition_groups.append(flat_master_avg_grads)
+            device = self._master_flat_param_groups_of_current_rank[group_id].device
+            self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device)
             self._grad_store.reset_average_gradients_by_group(group_id)

         # unscale and clip grads
Expand All @@ -481,37 +482,37 @@ def step(self, closure=None):

# update the parameters
self.optim.step()
# release the fp32 grad
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
# release the master grad
release_param_grad(self._master_flat_param_groups_of_current_rank.values())

# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
# update working partition updated by the current rank
for group_id in range(len(self._working_param_groups)):
working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id)
master_param = self._master_flat_param_groups_of_current_rank[group_id]
working_param.data.copy_(master_param)

# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
for index in range(self._world_size):
rank = self._dp_global_ranks[index]
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id)
handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)

for handle in handles:
handle.wait()

##################
# FP16 Utilities #
##################
#############################
# Mixed Precision Utilities #
#############################

def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)

# check for overflow
for group_id in range(len(self._fp16_param_groups)):
for group_id in range(len(self._working_param_groups)):
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
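
End to end, step() now reads: fuse this rank's averaged working grads into an fp32 grad for the master flat param, step the inner optimizer on that master shard, copy the result back into the working flat param, and broadcast every rank's refreshed shard. A compressed sketch of that control flow for a single param group (ColossalAI bookkeeping and the overflow/clip logic elided; names are illustrative):

import torch
import torch.distributed as dist

def sketch_step(inner_optim, master_flat, working_flat, avg_grads, dp_group, my_rank):
    # 1. Fuse the averaged working grads and lift them to the master dtype (fp32).
    master_flat.grad = torch.cat([g.reshape(-1) for g in avg_grads]).to(master_flat.dtype)
    # 2. The inner optimizer only ever sees the fp32 master shard.
    inner_optim.step()
    master_flat.grad = None
    # 3. Copy the updated master values back into the low-precision working shard.
    working_flat.data.copy_(master_flat)
    # 4. Broadcast this rank's refreshed shard; the real code loops over all
    #    ranks so every shard is published to the whole group.
    dist.broadcast(working_flat, src=my_rank, group=dp_group, async_op=True).wait()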
@@ -554,7 +555,7 @@ def _sync_grad(self):

         # accumulate gradient
         for group_id in range(self.num_param_groups):
-            param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
+            param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id)

             avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
@@ -575,8 +576,8 @@ def _reduce_grad_stage1(self):
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_communication:
-            for group_id in range(len(self._fp16_param_groups)):
-                param_group = self._fp16_param_groups[group_id]
+            for group_id in range(len(self._working_param_groups)):
+                param_group = self._working_param_groups[group_id]
                 for param in param_group:
                     if param.grad is not None:
                         self._add_to_reduction_bucket(param)