2 changes: 1 addition & 1 deletion .github/workflows/build_on_schedule.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: [self-hosted, gpu]
container:
image: hpcaitech/pytorch-cuda:2.1.0-12.1.0
options: --gpus all --rm -v /dev/shm -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 90
steps:
- name: Check GPU Availability # ensure all GPUs have enough memory
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_dispatch.yml
@@ -50,7 +50,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 120
steps:
- name: Install dependencies
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_pr.yml
@@ -41,7 +41,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 120
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-test-${{ matrix.container }}
2 changes: 1 addition & 1 deletion .github/workflows/compatiblity_test_on_schedule.yml
@@ -38,7 +38,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: ${{ matrix.container }}
options: --gpus all --rm -v /dev/shm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
options: --gpus all --rm -v /dev/shm -v /data/scratch/:/data/scratch/
timeout-minutes: 120
steps:
- name: Install dependencies
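The four workflow edits above replace the per-dataset mounts (/data/scratch/cifar-10 and /data/scratch/llama-tiny) with a single mount of the parent directory, so the containers see everything under /data/scratch/. A minimal sketch of a check a CI step could run to confirm the previously mounted datasets are still reachable; the paths are taken from the old mount options, and the check itself is illustrative, not part of this PR:

from pathlib import Path

# Datasets that used to be mounted individually; with the broader mount they
# resolve to the same locations inside the container.
for dataset in ("cifar-10", "llama-tiny"):
    path = Path("/data/scratch") / dataset
    assert path.exists(), f"expected {path} to be available via the /data/scratch/ mount"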
2 changes: 0 additions & 2 deletions applications/ColossalMoE/infer.py
@@ -9,7 +9,6 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.checkpoint import MoECheckpointIO


def parse_args():
@@ -69,7 +68,6 @@ def main():
ep_size=ep_size,
zero_stage=1,
precision=args.precision,
checkpoint_io=MoECheckpointIO,
enable_fused_normalization=args.use_layernorm_kernel,
enable_jit_fused=args.use_kernel,
)
2 changes: 0 additions & 2 deletions applications/ColossalMoE/train.py
@@ -12,7 +12,6 @@
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.moe.checkpoint import MoECheckpointIO
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@@ -158,7 +157,6 @@ def main():
enable_jit_fused=args.use_kernel,
precision=args.precision,
zero_stage=args.zero_stage,
checkpoint_io=MoECheckpointIO,
)

else:
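In both infer.py and train.py the explicit MoECheckpointIO import and the checkpoint_io=MoECheckpointIO argument are removed, so checkpoint I/O falls back to the plugin's default. A hedged sketch of the resulting call site, using only the keyword arguments visible in the hunks above plus assumed tp_size/pp_size parameters that sit outside the shown context (all values are illustrative, not the scripts' actual arguments):

from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# Illustrative values only; the real scripts take these from parsed CLI args.
plugin = MoeHybridParallelPlugin(
    tp_size=1,  # assumed parameter above the visible hunk
    pp_size=1,  # assumed parameter above the visible hunk
    ep_size=2,
    zero_stage=1,
    precision="bf16",
    enable_fused_normalization=False,
    enable_jit_fused=False,
    # no checkpoint_io=MoECheckpointIO any more; the plugin's default checkpoint IO is used
)
booster = Booster(plugin=plugin)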
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/low_level_zero_plugin.py
@@ -448,7 +448,7 @@ def configure(

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
optimizer, **self.zero_optim_kwargs, verbose=self.verbose
optimizer, **zero_optim_kwargs, verbose=self.verbose
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
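This one-line fix makes configure build the LowLevelZeroOptimizer from the locally prepared zero_optim_kwargs rather than the stored self.zero_optim_kwargs attribute. The surrounding method body is not shown in the diff, so the snippet below is a standalone sketch of the failure mode such a fix addresses (per-call adjustments merged into a local dict being silently dropped when the attribute is forwarded instead); it is not the plugin's actual code:

# Standalone illustration only -- names and values are assumptions.
class PluginSketch:
    def __init__(self):
        self.zero_optim_kwargs = {"reduce_bucket_size": 1024 * 1024}

    def build_kwargs(self, **per_call_overrides):
        # Per-call options are merged into a local dict ...
        zero_optim_kwargs = {**self.zero_optim_kwargs, **per_call_overrides}
        # ... so forwarding self.zero_optim_kwargs instead of this local dict
        # would lose the overrides; the fix forwards the local dict.
        return zero_optim_kwargs

print(PluginSketch().build_kwargs(overlap_communication=True))
# -> {'reduce_bucket_size': 1048576, 'overlap_communication': True}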
44 changes: 22 additions & 22 deletions colossalai/zero/low_level/low_level_strategy.py
@@ -34,7 +34,7 @@ class LowLevelOptStrategyBase(ABC):
def __init__(
self,
param_group,
process_group,
dp_process_group,
master_weights,
partition_grad,
cpu_offload,
@@ -46,14 +46,14 @@ def __init__(
self.param_group = param_group
self._dtype = self.param_group["params"][0].dtype

if process_group is None: # if process_group is none, convert to default explicitly
process_group = dist.group.WORLD
if dp_process_group is None: # if dp_process_group is none, convert to default explicitly
dp_process_group = dist.group.WORLD

self.process_group = process_group
self.dp_process_group = dp_process_group

# if process_group is none, will use the default one
self._local_rank = dist.get_rank(group=self.process_group)
self._world_size = dist.get_world_size(group=self.process_group)
# if dp_process_group is none, will use the default one
self._local_rank = dist.get_rank(group=self.dp_process_group)
self._world_size = dist.get_world_size(group=self.dp_process_group)

# master weights copy
self._master_weights = master_weights
@@ -65,9 +65,9 @@ def __init__(

# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
self._param_store = ParameterStore(process_group)
self._grad_store = GradientStore(process_group, partition_grad=partition_grad)
self._bucket_store = BucketStore(process_group, reduce_bucket_size=reduce_bucket_size)
self._param_store = ParameterStore(dp_process_group)
self._grad_store = GradientStore(dp_process_group, partition_grad=partition_grad)
self._bucket_store = BucketStore(dp_process_group, reduce_bucket_size=reduce_bucket_size)

# working and master params for mixed precision training
group_params = []
@@ -224,7 +224,7 @@ def _run_reduction(self):
flat_grads = flat_grads.to(self._communication_dtype)

if not self._partition_grad:
dist.all_reduce(flat_grads, group=self.process_group)
dist.all_reduce(flat_grads, group=self.dp_process_group)
if flat_grads.dtype != grad_dtype:
flat_grads = flat_grads.to(grad_dtype)

@@ -234,7 +234,7 @@
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.process_group)
dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_process_group)

if recieved_grad.dtype != grad_dtype:
recieved_grad = recieved_grad.to(grad_dtype)
@@ -294,7 +294,7 @@ def state_dict(self, optim: torch.optim.Optimizer) -> Dict:
gather_tensor = [
torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
]
dist.all_gather(gather_tensor, v, group=self.process_group)
dist.all_gather(gather_tensor, v, group=self.dp_process_group)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -328,7 +328,7 @@ def get_grad_norm(self, norm_type: int = 2) -> float:
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_process_group)
total_norm = total_norm_cuda.item()

else:
@@ -342,7 +342,7 @@ def get_grad_norm(self, norm_type: int = 2) -> float:
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.process_group
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_process_group
)
total_norm = total_norm_exponentiated_cuda.item() ** (1.0 / norm_type)

@@ -381,7 +381,7 @@ def get_param_grad(self, param):
return None
if self._partition_grad:
tensor_list = [torch.empty_like(grad_maybe_partial[0]) for _ in range(self._world_size)]
dist.all_gather(tensor_list, grad_maybe_partial[0], group=self.process_group)
dist.all_gather(tensor_list, grad_maybe_partial[0], group=self.dp_process_group)
grad_flat = torch.cat(tensor_list, dim=0)
else:
grad_flat = torch.cat(grad_maybe_partial, dim=0)
@@ -420,7 +420,7 @@ class LowLevelOptStrategy(LowLevelOptStrategyBase):
def __init__(
self,
param_group: Dict[str, Any], # from optimizer.param_groups
process_group: Optional[ProcessGroup] = None, # the dp pg for comm
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
reduce_bucket_size: int = 1024 * 1024, # communication
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = False,
@@ -430,7 +430,7 @@ def __init__(
):
super().__init__(
param_group=param_group,
process_group=process_group,
dp_process_group=dp_process_group,
cpu_offload=cpu_offload,
partition_grad=partition_grad,
master_weights=master_weights,
@@ -516,7 +516,7 @@ def post_step(self):
all_splited_param = [
torch.zeros(master_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.process_group)
dist.all_gather(all_splited_param, master_param.to(device).to(self._dtype), group=self.dp_process_group)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

# restore tmp values
@@ -535,7 +535,7 @@ def __init__(
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
process_group: Optional[ProcessGroup] = None, # the dp pg for comm
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
master_weights: bool = True, # master weights
):
for param in param_group["params"]:
@@ -544,7 +544,7 @@

super().__init__(
param_group=param_group,
process_group=process_group,
dp_process_group=dp_process_group,
cpu_offload=cpu_offload,
partition_grad=partition_grad,
master_weights=master_weights,
@@ -556,6 +556,6 @@ def __init__(
# def get_param_grad(self, param): # TODO @botbw: discuss whether it's intuitive to return grad of divided of full moe tensor
# moe_partial_grad = super().get_param_grad(param)
# moe_grad_list = [torch.empty_like(moe_partial_grad) for _ in range(self._world_size)]
# dist.all_gather(moe_grad_list, moe_partial_grad, group=self.process_group)
# dist.all_gather(moe_grad_list, moe_partial_grad, group=self.dp_process_group)
# moe_grad = torch.cat(moe_grad_list, dim=0).reshape(param.shape[0] * self._world_size, *param.shape[1:])
# return moe_grad
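The rename from process_group to dp_process_group is applied consistently through the base class, both strategies, and every collective call (all_reduce, reduce_scatter, all_gather) and store constructor that uses the group. For callers, a minimal hedged sketch of constructing a strategy with the renamed keyword; the model, optimizer, and process group below are illustrative, and the import path is inferred from the file path shown in this diff:

import torch
import torch.distributed as dist
from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy

# Assumes torch.distributed has already been initialized by the launcher.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

strategy = LowLevelOptStrategy(
    param_group=optimizer.param_groups[0],
    dp_process_group=dist.group.WORLD,  # was `process_group` before this PR
    overlap_communication=False,
    partition_grad=False,  # True enables stage-2 style gradient partitioning
)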
4 changes: 2 additions & 2 deletions tests/test_moe/test_moe_zero_fwd_bwd_optim.py
@@ -68,13 +68,13 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
strategies = [
LowLevelOptStrategy(
param_group=zero_optimizer.param_groups[0],
process_group=plugin.global_dp_group,
dp_process_group=plugin.global_dp_group,
overlap_communication=False,
partition_grad=(stage == 2),
),
MoeZeroStrategy(
param_group=zero_optimizer.param_groups[1],
process_group=plugin.moe_dp_group,
dp_process_group=plugin.moe_dp_group,
overlap_communication=True,
partition_grad=(stage == 2),
),