48 changes: 17 additions & 31 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -29,7 +29,7 @@
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelOptStrategy, LowLevelZeroOptimizer, MoeZeroStrategy
from colossalai.zero.low_level import LowLevelZeroOptimizer


class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
@@ -68,38 +68,19 @@ def __init__(
if use_pipeline:
init_pipeline_optimizer(optimizer, model)

assert (
len(optimizer.param_groups) == 1
), "Currently only one parameter group is supported, and we will support multiple groups later."
zero_params = list(filter(lambda x: not is_moe_tensor(x), model.parameters()))
moe_params = list(filter(lambda x: is_moe_tensor(x), model.parameters()))

optimizer.param_groups.clear()
optimizer.add_param_group({"params": zero_params})
optimizer.add_param_group({"params": moe_params})
strategies = [
LowLevelOptStrategy(
param_group=optimizer.param_groups[0],
process_group=dp_process_group,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
),
MoeZeroStrategy(
param_group=optimizer.param_groups[1],
process_group=moe_extra_dp_process_group,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
),
]
pg_param_list = {
dp_process_group: [],
moe_extra_dp_process_group: [],
}
for param in model.parameters():
if is_moe_tensor(param):
pg_param_list[moe_extra_dp_process_group].append(param)
else:
pg_param_list[dp_process_group].append(param)

super().__init__(
optimizer=optimizer,
group_strategies=strategies,
pg_param_list=pg_param_list,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@@ -109,6 +90,11 @@ def __init__(
max_scale=max_scale,
clip_grad_norm=clip_grad_norm,
verbose=verbose,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
forced_dtype=forced_dtype,
)

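Note on the refactor above: instead of splitting the optimizer's `param_groups` and wiring up a `LowLevelOptStrategy`/`MoeZeroStrategy` pair, the plugin now hands `LowLevelZeroOptimizer` a `pg_param_list` dict mapping each data-parallel process group to the parameters it owns, and passes the ZeRO knobs (bucket size, communication dtype, overlap, stage-2 partitioning, CPU offload) straight to the base constructor. A minimal sketch of a caller built on that interface follows; the helper name, the bucket size, and the assumption that the remaining keyword arguments keep their defaults are illustrative, not taken from this PR.

```python
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer


def build_moe_zero_optimizer(optimizer, model, dp_pg, moe_dp_pg):
    # Bucket every parameter under the process group that will reduce it:
    # MoE tensors go to the MoE data-parallel group, everything else to the
    # regular data-parallel group.
    pg_param_list = {dp_pg: [], moe_dp_pg: []}
    for param in model.parameters():
        key = moe_dp_pg if is_moe_tensor(param) else dp_pg
        pg_param_list[key].append(param)

    # ZeRO options are now plain constructor arguments (values here are
    # hypothetical; other kwargs are assumed to keep their defaults).
    return LowLevelZeroOptimizer(
        optimizer=optimizer,
        pg_param_list=pg_param_list,
        reduce_bucket_size=12 * 1024**2,
        overlap_communication=True,
        partition_grad=False,  # True would enable ZeRO stage 2
        cpu_offload=False,
    )
```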
3 changes: 1 addition & 2 deletions colossalai/zero/low_level/__init__.py
@@ -1,4 +1,3 @@
from .low_level_optim import LowLevelZeroOptimizer
from .low_level_strategy import LowLevelOptStrategy, LowLevelOptStrategyBase, MoeZeroStrategy

__all__ = ["LowLevelZeroOptimizer", "LowLevelOptStrategy", "MoeZeroStrategy", "LowLevelOptStrategyBase"]
__all__ = ["LowLevelZeroOptimizer"]
3 changes: 1 addition & 2 deletions colossalai/zero/low_level/bookkeeping/__init__.py
@@ -1,6 +1,5 @@
from .bucket_store import BucketStore
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .tensor_bucket import TensorBucket

__all__ = ["GradientStore", "ParameterStore", "BucketStore", "TensorBucket"]
__all__ = ["GradientStore", "BucketStore", "TensorBucket"]
5 changes: 5 additions & 0 deletions colossalai/zero/low_level/bookkeeping/bucket_store.py
@@ -5,6 +5,8 @@
from torch._utils import _flatten_dense_tensors
from torch.distributed import ProcessGroup

from colossalai.accelerator.api import get_accelerator

from .base_store import BaseStore


@@ -13,10 +15,13 @@ def __init__(
self,
torch_pg: ProcessGroup,
reduce_bucket_size: int,
overlap_comm: bool = False,
):
super().__init__(torch_pg)
self.reduce_bucket_size = reduce_bucket_size
self.reset_all()
if overlap_comm:
self.comm_stream = get_accelerator().Stream()

def reset_all(self) -> None:
# init
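With the `bucket_store.py` change, the communication stream used for overlapped reduction lives in `BucketStore` itself and is created only when `overlap_comm` is requested. A minimal construction sketch, assuming `torch.distributed` is already initialized and with illustrative values:

```python
from torch.distributed import ProcessGroup

from colossalai.zero.low_level.bookkeeping import BucketStore


def make_bucket_store(torch_pg: ProcessGroup, overlap: bool = True) -> BucketStore:
    # With overlap_comm=True the store allocates get_accelerator().Stream()
    # as store.comm_stream, so gradient reduction can be enqueued there while
    # the default stream keeps executing backward kernels.
    return BucketStore(
        torch_pg=torch_pg,
        reduce_bucket_size=12 * 1024**2,  # hypothetical size
        overlap_comm=overlap,
    )
```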
7 changes: 2 additions & 5 deletions colossalai/zero/low_level/bookkeeping/gradient_store.py
@@ -6,7 +6,7 @@


class GradientStore(BaseStore):
def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
def __init__(self, *args, partition_grad: bool = False):
super().__init__(*args)
"""
self._grads_of_params mapping the parameter and its gradient slices
@@ -20,8 +20,6 @@ def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool
self._grads_of_params = dict()
# stage 2
self._partition_grads = partition_grad
# grad accumulation
self.require_grad_sync = require_grad_sync
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
@@ -107,8 +105,7 @@ def get_working_grad_by_param_id(self, param_id) -> Tensor:
for group in self._grads_of_params.values():
if param_id in group.keys():
return group[param_id][self._working_index]

raise KeyError(f"Working gradient for param_id {param_id} not found.")
return None

def reset_grads_by_group_id(self, group_id: int):
self._grads_of_params[group_id] = dict()
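After the `gradient_store.py` change, `get_working_grad_by_param_id` returns `None` for an unknown `param_id` instead of raising `KeyError`, so callers are expected to guard the result. A small hedged usage sketch; keying the lookup with `id(param)` and the helper name are illustrative, not prescribed by this PR.

```python
from typing import Optional

from torch import Tensor


def fetch_working_grad(gradient_store, param) -> Optional[Tensor]:
    # A lookup miss now yields None rather than raising KeyError,
    # so the caller checks the result explicitly.
    grad = gradient_store.get_working_grad_by_param_id(id(param))
    if grad is None:
        # No working gradient recorded for this parameter (e.g. it did not
        # take part in the current backward pass).
        return None
    return grad
```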
60 changes: 0 additions & 60 deletions colossalai/zero/low_level/bookkeeping/parameter_store.py

This file was deleted.
