2 changes: 1 addition & 1 deletion colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(

super().__init__(
optimizer=optimizer,
pg_param_list=pg_param_list,
pg_to_param_list=pg_param_list,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
Expand Down
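The rename `pg_param_list` -> `pg_to_param_list` is a breaking keyword change for any code constructing the optimizer directly. A minimal sketch of an updated call site (the model, optimizer, and import path are illustrative assumptions, not taken from this PR):

import torch.distributed as dist
import torch.nn as nn
from torch.optim import Adam

from colossalai.zero import LowLevelZeroOptimizer  # assumed public import path

model = nn.Linear(8, 8).cuda()
optim = Adam(model.parameters(), lr=1e-3)

# before: LowLevelZeroOptimizer(optim, pg_param_list={...})
zero_optim = LowLevelZeroOptimizer(
    optim,
    pg_to_param_list={dist.group.WORLD: list(model.parameters())},  # renamed kwarg
)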
4 changes: 1 addition & 3 deletions  colossalai/zero/low_level/bookkeeping/bucket_store.py

@@ -15,13 +15,11 @@ def __init__(
         self,
         torch_pg: ProcessGroup,
         reduce_bucket_size: int,
-        overlap_comm: bool = False,
     ):
         super().__init__(torch_pg)
         self.reduce_bucket_size = reduce_bucket_size
         self.reset_all()
-        if overlap_comm:
-            self.comm_stream = get_accelerator().Stream()
+        self.comm_stream = get_accelerator().Stream()

     def reset_all(self) -> None:
         # init
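Note the behavioral change: `BucketStore` previously allocated its communication stream only when `overlap_comm=True`; it now allocates one unconditionally, so the constructor flag is dropped. For readers unfamiliar with the pattern, here is a hedged sketch of how a dedicated stream lets a bucket reduction overlap later compute (illustrative only, not ColossalAI's reduction code):

import torch
import torch.distributed as dist

def reduce_bucket_async(flat_bucket: torch.Tensor, comm_stream: torch.cuda.Stream) -> None:
    # order the side stream after pending compute so the grads are fully written
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        # the collective runs on comm_stream and overlaps later default-stream kernels
        dist.all_reduce(flat_bucket)
    # keep the bucket's storage alive until comm_stream is done with it
    flat_bucket.record_stream(comm_stream)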
74 changes: 40 additions & 34 deletions  colossalai/zero/low_level/low_level_optim.py
@@ -29,7 +29,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
     def __init__(
         self,
         num_working_param_groups: int,
-        grad_stores: Dict[nn.Parameter, GradientStore],
+        pg_to_grad_store: Dict[ProcessGroup, GradientStore],
         initial_scale: float = 2**16,
         min_scale: float = 1,
         growth_factor: float = 2,
@@ -48,10 +48,10 @@ def __init__(
             max_scale,
         )
         self.num_working_param_groups = num_working_param_groups
-        self.grad_stores = grad_stores
+        self.pg_to_grad_store = pg_to_grad_store

     def check_local_overflow(self) -> bool:
-        for store in self.grad_stores.values():
+        for store in self.pg_to_grad_store.values():
             for group_id in range(self.num_working_param_groups):
                 for avg_grad in store.get_working_grads_by_group_id(group_id):
                     if avg_grad is not None and has_inf_or_nan(avg_grad):
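`check_local_overflow` now walks every store in `pg_to_grad_store`, so overflow detection covers each process group's gradients. For reference, `has_inf_or_nan` amounts to the following check (ColossalAI ships its own helper; this sketch only shows the semantics the mixin relies on):

import torch

def has_inf_or_nan(t: torch.Tensor) -> bool:
    # True if any element is +/-inf or NaN, i.e. the loss scale must back off
    return bool(torch.isinf(t).any() or torch.isnan(t).any())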
@@ -65,7 +65,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
     def __init__(
         self,
         optimizer: Optimizer,
-        pg_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None,
+        pg_to_param_list: Dict[ProcessGroup, List[nn.Parameter]] = None,
         initial_scale: int = 2**16,  # grad scaler config
         min_scale: int = 1,
         growth_factor: float = 2.0,
@@ -89,14 +89,14 @@ def __init__(
         self._logger = get_dist_logger()
         self._verbose = verbose

-        if pg_param_list is None:
-            pg_param_list = {dist.group.WORLD: []}
+        if pg_to_param_list is None:
+            pg_to_param_list = {dist.group.WORLD: []}
             for group in self.optim.param_groups:
-                pg_param_list[dist.group.WORLD].extend(group["params"])
+                pg_to_param_list[dist.group.WORLD].extend(group["params"])

-        self.pg_param_list = pg_param_list
+        self.pg_to_param_list = pg_to_param_list
         param_to_pg = {}
-        for grp, param_list in pg_param_list.items():
+        for grp, param_list in pg_to_param_list.items():
             for p in param_list:
                 assert isinstance(p, nn.Parameter)
                 param_to_pg[p] = grp
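When `pg_to_param_list` is omitted, the branch above maps every parameter to the global WORLD group and then inverts the dict into `param_to_pg`. A self-contained sketch of the same bookkeeping (the optimizer and the string `WORLD` are stand-ins for this sketch):

import torch
import torch.nn as nn
from torch.optim import SGD

model = nn.Linear(4, 4)
optim = SGD(model.parameters(), lr=0.1)

WORLD = "world"  # stand-in for dist.group.WORLD
pg_to_param_list = {WORLD: []}
for group in optim.param_groups:
    pg_to_param_list[WORLD].extend(group["params"])

# invert: each parameter keyed (by identity) to its owning process group
param_to_pg = {p: pg for pg, params in pg_to_param_list.items() for p in params}
assert all(param_to_pg[p] is WORLD for p in model.parameters())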
@@ -148,15 +148,18 @@ def __init__(
         self.working_to_master_param = dict()

         # NOTE need to gurantee the order of process group is the same accross all ranks
-        self.grad_stores = {pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_param_list}
-        # param id to grad store, have to use id(param) as key since it is used in stores
-        self.pid2grad_store = {id(param): self.grad_stores[param_to_pg[param]] for param in param_to_pg}
-        self.bucket_stores = {
-            pg: BucketStore(pg, reduce_bucket_size, overlap_comm=self._overlap_communication)
-            for pg in self.pg_param_list
+        # process_group <---> xxx_store
+        # process_group <---> [param1 param2 ...]
+        # each process group have its own stores
+        # param belonging to one process_group will use corresponding store
+        self.pg_to_grad_store = {
+            pg: GradientStore(pg, partition_grad=self._partition_grads) for pg in self.pg_to_param_list
         }
+        # param id to grad store, have to use id(param) as key since it is used in stores
+        self.pid_to_grad_store = {id(param): self.pg_to_grad_store[param_to_pg[param]] for param in param_to_pg}
+        self.pg_to_bucket_store = {pg: BucketStore(pg, reduce_bucket_size) for pg in self.pg_to_param_list}
         # param id to bucket store, have to use id(param) as key since it is used in stores
-        self.pid2bucket_store = {id(param): self.bucket_stores[param_to_pg[param]] for param in param_to_pg}
+        self.pid_to_bucket_store = {id(param): self.pg_to_bucket_store[param_to_pg[param]] for param in param_to_pg}

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
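The new comments spell out the scheme: every process group owns one grad store and one bucket store, and each parameter reaches its group's stores through `id(param)`, because the stores key their entries by parameter id. A self-contained toy version of the two-level mapping:

import torch
import torch.nn as nn

class ToyStore:  # stand-in for GradientStore / BucketStore
    def __init__(self, pg):
        self.pg = pg

pg_to_param_list = {
    "pg_dense": [nn.Parameter(torch.zeros(2))],
    "pg_moe": [nn.Parameter(torch.zeros(2))],
}
param_to_pg = {p: pg for pg, ps in pg_to_param_list.items() for p in ps}

pg_to_store = {pg: ToyStore(pg) for pg in pg_to_param_list}               # pg -> store
pid_to_store = {id(p): pg_to_store[param_to_pg[p]] for p in param_to_pg}  # id(param) -> store

moe_param = pg_to_param_list["pg_moe"][0]
assert pid_to_store[id(moe_param)] is pg_to_store["pg_moe"]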
@@ -190,7 +193,7 @@ def __init__(
         if self._dtype is torch.float16:
             self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(
                 self.num_param_groups,
-                self.grad_stores,
+                self.pg_to_grad_store,
                 initial_scale=initial_scale,
                 min_scale=min_scale,
                 growth_factor=growth_factor,
@@ -231,9 +234,9 @@ def _create_master_param_current_rank(self, param_list):

         for param in param_list:
             padding_size = (
-                self.pid2bucket_store[id(param)].world_size
-                - param.numel() % self.pid2bucket_store[id(param)].world_size
-            ) % self.pid2bucket_store[id(param)].world_size
+                self.pid_to_bucket_store[id(param)].world_size
+                - param.numel() % self.pid_to_bucket_store[id(param)].world_size
+            ) % self.pid_to_bucket_store[id(param)].world_size
             self.record_param_padding_size(param, padding_size)

             with torch.no_grad():
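The padding expression above rounds each parameter's element count up to a multiple of the group's `world_size`, so the flattened tensor later splits evenly across ranks. Worked out in isolation:

def zero_padding(numel: int, world_size: int) -> int:
    # elements to append so that (numel + padding) % world_size == 0
    return (world_size - numel % world_size) % world_size

assert zero_padding(10, 4) == 2  # 10 -> 12, i.e. 3 elements per rank
assert zero_padding(8, 4) == 0   # already divisible; the outer % avoids padding by 4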
@@ -246,9 +249,9 @@ def _create_master_param_current_rank(self, param_list):
                     padding_param = param.data.view(-1)

                 splited_params = padding_param.split(
-                    padding_param.numel() // self.pid2bucket_store[id(param)].world_size
+                    padding_param.numel() // self.pid_to_bucket_store[id(param)].world_size
                 )
-                splited_params = splited_params[self.pid2bucket_store[id(param)].local_rank]
+                splited_params = splited_params[self.pid_to_bucket_store[id(param)].local_rank]

                 # use fp32 when master_weights is True
                 if self._master_weights is True:
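With padding in place, the shard selection above is plain integer arithmetic: split the flat tensor into `world_size` equal pieces and keep the piece at `local_rank`. Continuing the worked example:

import torch

world_size, local_rank = 4, 1
flat = torch.arange(12.0)                        # padded numel from the example above
shards = flat.split(flat.numel() // world_size)  # four shards of 3 elements each
assert shards[local_rank].tolist() == [3.0, 4.0, 5.0]  # rank 1 owns elements 3..5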
@@ -288,7 +291,7 @@ def _grad_handler(param, group_id):
     #######################

     def _run_reduction(self):
-        for bucket_store in self.bucket_stores.values():
+        for bucket_store in self.pg_to_bucket_store.values():
             if bucket_store.num_elements_in_bucket() <= 0:
                 continue

@@ -367,10 +370,13 @@ def _add_grad(
         param_id: int,
         rank: int = 0,
     ) -> None:
-        if len(self.pid2grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
-            self.pid2grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id)
+        if (
+            len(self.pid_to_grad_store[param_id].get_partitioned_gradients_by_param_id(group_id, param_id))
+            < partition_num
+        ):
+            self.pid_to_grad_store[param_id].append_gradients_by_param_id(grad, group_id, param_id)
         else:
-            self.pid2grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id)
+            self.pid_to_grad_store[param_id].add_gradients_by_param_id(grad, rank, group_id, param_id)

     def _add_to_bucket(self, param, group_id):
         param_size = param.numel()
@@ -380,13 +386,13 @@ def _add_to_bucket(self, param, group_id):
         # or got a grad of param from another group
         # after reduction, the bucket will be empty
         if (
-            self.pid2bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size
-            or group_id != self.pid2bucket_store[id(param)].current_group_id
+            self.pid_to_bucket_store[id(param)].num_elements_in_bucket() + param_size > self._reduce_bucket_size
+            or group_id != self.pid_to_bucket_store[id(param)].current_group_id
         ):
             self._run_reduction()

         padding_size = self.get_param_padding_size(param)
-        self.pid2bucket_store[id(param)].add_param_grad(group_id, param, padding_size)
+        self.pid_to_bucket_store[id(param)].add_param_grad(group_id, param, padding_size)

     ################################
     # torch.optim.Optimizer methods
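The flush condition itself is unchanged by the rename; restated as a standalone predicate, a bucket is reduced when the incoming gradient would overflow the element budget (`reduce_bucket_size`), or when it belongs to a different param group than the bucket currently holds (buckets never mix groups):

def should_flush(in_bucket: int, incoming: int, bucket_cap: int,
                 current_group: int, incoming_group: int) -> bool:
    # reduce first if the new grad would overflow the bucket,
    # or if it comes from a different optimizer param group
    return in_bucket + incoming > bucket_cap or incoming_group != current_group

assert should_flush(900, 200, 1000, 0, 0)     # overflow triggers a flush
assert should_flush(100, 50, 1000, 0, 1)      # group change triggers a flush
assert not should_flush(100, 50, 1000, 0, 0)  # otherwise keep bucketing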
@@ -429,11 +435,11 @@ def backward_by_grad(self, tensor, grad):
         get_accelerator().synchronize()

     def zero_bucket_stores(self):
-        for bucket_store in self.bucket_stores.values():
+        for bucket_store in self.pg_to_bucket_store.values():
             bucket_store.reset_all()

     def zero_grad_stores(self):
-        for grad_store in self.grad_stores.values():
+        for grad_store in self.pg_to_grad_store.values():
             grad_store.reset_all_gradients()

     def zero_grad(self, set_to_none=True):
@@ -492,7 +498,7 @@ def step(self, closure=None):
                 # if a working param requires grad and has no grad
                 # it is not 'really' working, e.g. the droped layer
                 # else the splited grad should be attached to the splited param
-                grad_store = self.pid2grad_store[id(working_param)]
+                grad_store = self.pid_to_grad_store[id(working_param)]
                 grads = grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 grad_index = 0 if self._partition_grads else grad_store.local_rank
                 if len(grads) > 0:
@@ -507,7 +513,7 @@ def step(self, closure=None):

             # compute norm
             norm_group = 0
-            for grad_store in self.grad_stores.values():
+            for grad_store in self.pg_to_grad_store.values():
                 working_grads = grad_store.get_working_grads_by_group_id(group_id)
                 norm_group += self._compute_grad_norm(pg=grad_store.torch_pg, gradients=working_grads)

@@ -840,7 +846,7 @@ def get_padding_map(self) -> Dict[int, Tensor]:
         return self._padding_map

     def get_param_grad(self, working_param: nn.Parameter) -> Tensor:
-        grad_store = self.pid2grad_store[id(working_param)]
+        grad_store = self.pid_to_grad_store[id(working_param)]
         partial_grad = grad_store.get_working_grad_by_param_id(id(working_param))
         if partial_grad is None:
             return None
2 changes: 1 addition & 1 deletion tests/test_moe/test_moe_zero_fwd_bwd_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.

zero_optimizer = LowLevelZeroOptimizer(
zero_optimizer,
pg_param_list=pg_param_list,
pg_to_param_list=pg_param_list,
master_weights=master_weights,
initial_scale=1,
overlap_communication=False,
Expand Down