Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.utils.device import get_current_device
from colossalai.zero.low_level import LowLevelZeroOptimizer

from .pp_plugin_base import PipelinePluginBase

Expand Down Expand Up @@ -385,7 +385,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ

total_norm_exponentiated += grad_norm_exponentiated

total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
Expand Down Expand Up @@ -586,7 +588,9 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ

total_norm_exponentiated += grad_norm_exponentiated

total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
Expand Down Expand Up @@ -837,7 +841,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo

total_norm_exponentiated += grad_norm_exponentiated

total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32
)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
Expand Down
19 changes: 13 additions & 6 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import nullcontext
from typing import Optional
import pytest

import pytest
import torch
import torch.distributed as dist

Expand All @@ -11,8 +11,6 @@
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
Expand All @@ -26,7 +24,13 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
ctx = nullcontext()
extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
enable_all_optimization = True if tp_size > 1 else False
plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization)
plugin = GeminiPlugin(
max_norm=1.0,
initial_scale=2**5,
tp_size=tp_size,
extra_dp_size=extra_dp_size,
enable_all_optimization=enable_all_optimization,
)
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
Expand Down Expand Up @@ -66,7 +70,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t
@parameterize("init_method", ["none"])
@parameterize("zero_size", [2])
@parameterize("tp_size", [2])
def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1):
def check_gemini_plugin(
subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
):
"""check gemini plugin over model zoo

Args:
Expand Down Expand Up @@ -156,11 +162,12 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)


@pytest.mark.largedist
@rerun_if_address_is_in_use()
def test_gemini_plugin_3d(early_stop: bool = True):
spawn(run_dist, 8, early_stop=early_stop)


if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_gemini_plugin(early_stop=False)