2 changes: 1 addition & 1 deletion colossalai/zero/low_level/_utils.py
@@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
return total_norm


def sync_param(flat_tensor, tensor_list):
def sync_tensor(flat_tensor, tensor_list):
"""
Synchronize the flattened tensor and unflattened tensor list. When
a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
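As context for the `sync_param` -> `sync_tensor` rename above, here is a minimal standalone sketch (not part of the diff) of the flatten/write-back round trip the helper deals with: `_flatten_dense_tensors` copies data into a new contiguous tensor, so updates to the flat tensor must be copied back into the original tensor list.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

tensor_list = [torch.zeros(2), torch.zeros(3)]
flat = _flatten_dense_tensors(tensor_list)   # new storage, not a view
flat += 1.0                                  # tensor_list is still all zeros

# write the updated values back, slice by slice
for old, new in zip(tensor_list, _unflatten_dense_tensors(flat, tensor_list)):
    old.copy_(new)
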
122 changes: 97 additions & 25 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -1,3 +1,8 @@
from typing import Dict

import torch
from torch import Tensor
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup

from .base_store import BaseStore
@@ -7,35 +12,102 @@ class BucketStore(BaseStore):

def __init__(self, torch_pg: ProcessGroup):
super().__init__(torch_pg)
self._params = dict()
self._num_elements_in_bucket = dict()

# init and reset
self.current_group_id = 0
# mapping between gradient slices and their parameters
self.grad_to_param_mapping = dict()

self._param_list = []
self._padding_size = []

self.reset()

def num_elements_in_bucket(self, reduce_rank: int = None):
return self._num_elements_in_bucket[reduce_rank]
def num_elements_in_bucket(self) -> int:
"""Return the total number of elements in bucket

Returns:
int: the total number of elements in the bucket
"""

return self._num_elements_in_bucket

def add_param_grad(self, group_id: int, param: Tensor, padding_size: int):
"""Add a param to bucket and record the padding size of a param for gradient padding

Args:
group_id (int): The index of a parameter group
param (Tensor): The parameter
padding_size (int): The padding size of the parameter
"""

self._param_list.append(param)
self._padding_size.append(padding_size)
self._num_elements_in_bucket += (param.numel() + padding_size)
self.current_group_id = group_id

def build_grad_in_bucket(self):
"""Orgnize parameters' gradient(padding and split), follows the paramters' splitting method

Data structure of self._grad_in_bucket:
{
rank0: [grad0_rank0, grad1_rank0, ...]
rank1: [grad0_rank1, grad1_rank1, ...]
}
"""

for param, padding_size in zip(self._param_list, self._padding_size):
with torch.no_grad():
grad = param.grad.detach().flatten()
if padding_size > 0:
grad = torch.nn.functional.pad(grad, [0, padding_size])
grad_list = grad.split(grad.numel() // self._world_size)
for rank in range(self._world_size):
grad_current_rank = grad_list[rank].detach()
self.grad_to_param_mapping[id(grad_current_rank)] = id(param)
self._grad_in_bucket[rank].append(grad_current_rank)
param.grad = None

def get_grad(self) -> Dict:
"""Return the dictionary of gradients slices, of which the keys are ranks

Returns:
Dict: The dictionary of gradient slices
"""

return self._grad_in_bucket

def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]

Returns:
Tensor: the flattened gradient slices in the bucket
"""

flat_grad = []
for grad_list in self._grad_in_bucket.values():
flat_grad.append(_flatten_dense_tensors(grad_list))
flat_grad = _flatten_dense_tensors(flat_grad)
return flat_grad

def get_param_id_of_grad(self, grad: Tensor) -> int:
"""Return the id of a parameter which the gradient slice belongs to

Args:
grad (Tensor): the gradient slice

def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
Returns:
int: the id of the parameter that the gradient slice belongs to
"""

def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
return self.grad_to_param_mapping[id(grad)]

def reset(self):
keys = [None] + list(range(self._world_size))
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}

def reset_by_rank(self, reduce_rank=None):
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0

def get_grad(self, reduce_rank: int = None):
param_list = self.get_param(reduce_rank)
for param in param_list:
# the param must have grad for reduction
assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
return [param.grad for param in param_list]

def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
self.grad_to_param_mapping = dict()
self._num_elements_in_bucket = 0
self._param_list = []
self._padding_size = []
self._grad_in_bucket = dict()
for rank in range(self._world_size):
self._grad_in_bucket[rank] = []
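To illustrate the new bucket bookkeeping, here is a small standalone sketch (not part of the diff; it assumes a world size of 2 and toy tensors instead of a real ProcessGroup, and computes a padding size itself, whereas the real optimizer passes it to add_param_grad). It mirrors what add_param_grad, build_grad_in_bucket and get_flatten_grad do: pad each flattened gradient, split it evenly across ranks, remember which parameter each slice came from, then flatten rank by rank.

import torch
from torch._utils import _flatten_dense_tensors

world_size = 2
params = [torch.randn(5, requires_grad=True), torch.randn(4, requires_grad=True)]
for p in params:
    p.grad = torch.ones_like(p)

grad_in_bucket = {rank: [] for rank in range(world_size)}   # rank -> list of grad slices
grad_to_param_mapping = {}                                  # id(slice) -> id(param)

for param in params:
    grad = param.grad.detach().flatten()
    # pad so the gradient splits evenly across ranks (supplied by the caller in the real code)
    padding_size = (world_size - grad.numel() % world_size) % world_size
    if padding_size > 0:
        grad = torch.nn.functional.pad(grad, [0, padding_size])
    for rank, grad_slice in enumerate(grad.split(grad.numel() // world_size)):
        grad_to_param_mapping[id(grad_slice)] = id(param)
        grad_in_bucket[rank].append(grad_slice)

# same layout as get_flatten_grad(): [grad0_rank0, grad1_rank0, grad0_rank1, grad1_rank1]
flat_grad = _flatten_dense_tensors(
    [_flatten_dense_tensors(slices) for slices in grad_in_bucket.values()])
print(flat_grad.numel())   # 10 elements: (5 + 1 padding) + 4
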
118 changes: 61 additions & 57 deletions colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -1,88 +1,92 @@
from typing import List

from torch import Tensor
from torch._utils import _flatten_dense_tensors

from .base_store import BaseStore


class GradientStore(BaseStore):

def __init__(self, *args):
def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
# bookkeeping data structures
self._averaged_gradients = dict()

# for backward reduction hooks
self._grad_acc_objs = []

def append_accumulate_grad_object(self, obj):
"""
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
be attached successfully.

:param obj: An object of :class:`AccumulateGrad` class
:type obj: :class:`AccumulateGrad`
self._grads_of_params maps a parameter to its gradient slices
data structure:
{
group_id:{
param_id: [grad_rank0, grad_rank1, ...]
}
}
"""
self._grads_of_params = dict()
# for zero2, it's `param_id: [grad_local_rank]`
self._working_index = 0 if partition_grad else self._local_rank

self._grad_acc_objs.append(obj)
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
"""Return list of gradient slices of a specific parameter

def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
"""
Return average gradients of a parameter group
Args:
group_id (int): The index of a parameter group
param_id (int): The id of a parameter

:param group_id: The index of parameter group
:type group_id: int

:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
:rtype: List[torch.Tensor]
Returns:
List: the list of gradient slices of a parameter.
"""
if group_id not in self._averaged_gradients:
self._averaged_gradients[group_id] = []

return self._averaged_gradients[group_id]

def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
"""
Append an average gradient to the list of averaged gradients of a parameter group
if group_id in self._grads_of_params:
if param_id in self._grads_of_params[group_id]:
return self._grads_of_params[group_id][param_id]
# the param has no grad, for instance, in layer drop
return []

:param group_id: The index of a parameter group
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor: torch.Tensor
def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
"""Append a gradient slice to the parameter's gradient slice list

Args:
grad (Tensor): The gradient slice to append to list
group_id (int): The index of a parameter group
param_id (int): The id of a parameter
"""

if group_id in self._averaged_gradients:
self._averaged_gradients[group_id].append(tensor)
if group_id not in self._grads_of_params:
self._grads_of_params[group_id] = dict()
if param_id not in self._grads_of_params[group_id]:
self._grads_of_params[group_id][param_id] = [grad]
else:
self._averaged_gradients[group_id] = [tensor]
self._grads_of_params[group_id][param_id].append(grad)

def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
"""For old gradient accumulation, not in use now.
Add a gradient slice on an existing slice of the parameter's gradient

Args:
grad (Tensor): The gradient slice to add onto the existing slice
grad_idx (int): The index of the existing slice
group_id (int): The index of a parameter group
param_id (int): The id of a parameter
"""
Add an average gradient to the list of averaged gradients of a parameter group

:param group_id: The index of a parameter group
:param tensor_idx: The index of a tensor in the list of averaged gradients
:param tensor: A :class:`torch.Tensor` object
:type group_id: int
:type tensor_idx: int
:type tensor: torch.Tensor
self._grads_of_params[group_id][param_id][grad_idx].add_(grad)

"""
self._averaged_gradients[group_id][tensor_idx].add_(tensor)
def get_working_grads_by_group_id(self, group_id: int) -> List:
"""Return list of working gradient slices in the group

def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
Args:
group_id (int): The index of a parameter group

:param group_id: The index of a parameter group
:type group_id: int
Returns:
List: the list of working gradient slices in the group
"""

self._averaged_gradients[group_id] = []
grad_list = []
for param_grads in self._grads_of_params[group_id].values():
grad_list.append(param_grads[self._working_index])

def reset_all_average_gradients(self) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
"""
self._averaged_gradients = dict()
return grad_list

def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()

def reset_all_gradients(self):
self._grads_of_params = dict()
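A similarly standalone sketch (not the real store; it uses a hand-set local rank and a made-up param_id rather than a real ProcessGroup and id(param)) of the {group_id: {param_id: [grad_rank0, grad_rank1, ...]}} bookkeeping and of how the working index selects a slice, mirroring append_gradients_by_param_id and get_working_grads_by_group_id above.

import torch

local_rank, world_size, partition_grad = 1, 2, False
working_index = 0 if partition_grad else local_rank   # as in __init__ above

grads_of_params = {}           # group_id -> param_id -> [grad_rank0, grad_rank1, ...]
group_id, param_id = 0, 140234567890   # param_id stands in for id(param)

# append one slice per rank, as append_gradients_by_param_id would over time
for rank in range(world_size):
    grad_slice = torch.full((4,), float(rank))
    grads_of_params.setdefault(group_id, {}).setdefault(param_id, []).append(grad_slice)

# analogous to get_working_grads_by_group_id(group_id)
working_grads = [slices[working_index] for slices in grads_of_params[group_id].values()]
print(working_grads)           # the slice owned by local rank 1
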