19 changes: 16 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -35,6 +35,7 @@
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
@@ -62,7 +63,9 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -74,11 +77,16 @@ def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool =
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
+        self.use_fp8 = use_fp8
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
@@ -335,6 +343,7 @@ def __init__(
         master_weights: bool = True,
         verbose: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -362,6 +371,7 @@ def __init__(
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.use_fp8 = use_fp8

         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -476,7 +486,10 @@ def configure(

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                use_fp8=self.use_fp8,
             )

         # TODO: Support Galore + ZeRO
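Taken together, this file threads a new use_fp8 flag from LowLevelZeroPlugin into LowLevelZeroModel, which registers an FP8Hook alongside the existing ZeroOpHook. A minimal usage sketch follows; the model, optimizer, and launch details are illustrative placeholders and not part of this diff.

# Sketch: enabling the new use_fp8 flag with the low-level ZeRO plugin.
# The toy model and optimizer are placeholders; only the use_fp8 argument
# comes from this pull request.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # assumes a torchrun-style distributed environment

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# use_fp8=True makes LowLevelZeroModel append an FP8Hook to its op hooks,
# so supported parameter operations can be cast to FP8 during the forward pass.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)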
1 change: 0 additions & 1 deletion examples/language/llama/benchmark.py
@@ -259,7 +259,6 @@ def empty_init():
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
         else nullcontext()
     )
-
     init_kwargs = {}
     if config.model_type == "chatglm":
         init_kwargs["empty_init"] = False