colossalai/booster/plugin/gemini_plugin.py (2 additions, 0 deletions)

@@ -363,6 +363,7 @@ def __init__(
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
+        use_fp8: bool = False,
         verbose: bool = False,
         fp8_communication: bool = False,
     ) -> None:
@@ -397,6 +398,7 @@ def __init__(
             max_prefetch=max_prefetch,
             enable_async_reduce=enable_async_reduce,
             fp8_communication=fp8_communication,
+            use_fp8=use_fp8,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
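For orientation, the new flag is consumed like any other GeminiPlugin option. A minimal usage sketch, assuming a model and optimizer defined elsewhere; only the use_fp8 and fp8_communication keywords are taken from this change:

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()
# use_fp8 routes eligible F.linear calls through FP8 compute (via FP8Hook);
# fp8_communication separately casts collective communication to FP8.
plugin = GeminiPlugin(precision="bf16", use_fp8=True, fp8_communication=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)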
colossalai/booster/plugin/hybrid_parallel_plugin.py (1 addition, 1 deletion)

@@ -31,6 +31,7 @@
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
@@ -40,7 +41,6 @@
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

-from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase

 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
colossalai/quantization/fp8.py (2 additions, 2 deletions)

@@ -652,5 +652,5 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


-def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(x, w, bias)
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)
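Renaming the parameters to input/weight keeps linear_fp8 signature-compatible with torch.nn.functional.linear. A minimal sketch of the contract, with illustrative shapes and assuming hardware with FP8 support:

import torch
from colossalai.quantization.fp8 import linear_fp8

x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = linear_fp8(x, w, b)  # same contract as F.linear: out.shape == (4, 128)
out.sum().backward()       # _LinearFp8.backward returns grads for input, weight, and bias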
colossalai/zero/gemini/gemini_ddp.py (8 additions, 3 deletions)

@@ -15,6 +15,7 @@
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -99,6 +100,7 @@ def __init__(
         verbose: bool = False,
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -138,6 +140,9 @@ def __init__(
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.hooks = [self.param_op_hook]
+        if use_fp8:
+            self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -310,7 +315,7 @@ def forward(self, *args, **kwargs):
             outputs = self._inference_forward(*args, **kwargs)
         else:
             self.gemini_manager.pre_iter(*args)
-            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+            with ColoParamOpHookManager.use_hooks(*self.hooks):
                 outputs = self.module(*args, **kwargs)

         if self.force_outputs_fp32:
@@ -319,7 +324,7 @@ def forward(self, *args, **kwargs):

     def _inference_forward(self, *args, **kwargs):
         """This function is only triggered for inference."""
-        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
         if not self.scatter_after_inference:
             # gather all chunks
             for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -372,7 +377,7 @@ def _post_backward(self):

     def backward(self, loss: torch.Tensor):
         self._pre_backward()
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
             loss.backward()
         self._post_backward()
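The new hook list composes FP8 casting with the existing GeminiZeROHook rather than replacing it. The same mechanism can be exercised standalone, much as tests/test_fp8/test_fp8_hook.py does; a rough sketch, in which the toy layer and the ColoParameter wrapping are illustrative assumptions:

import torch
import torch.nn as nn
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

linear = nn.Linear(64, 128).cuda()
# The hook manager only intercepts ops whose parameters are ColoParameters.
linear.weight = ColoParameter(linear.weight.data)
hooks = [FP8Hook()]  # GeminiDDP keeps GeminiZeROHook first in this list
with ColoParamOpHookManager.use_hooks(*hooks):
    out = linear(torch.randn(4, 64, device="cuda"))  # F.linear dispatched to linear_fp8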
examples/language/llama/benchmark.py (7 additions, 0 deletions)

@@ -99,6 +99,8 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument("--use_fp8", action="store_true")
     args = parser.parse_args()

     colossalai.launch_from_torch()
@@ -136,6 +138,7 @@ def empty_init():
             enable_flash_attention=args.xformers,
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -148,6 +151,7 @@ def empty_init():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -207,6 +211,8 @@ def empty_init():
             dp_outside=False,
             overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
+            use_fp8=args.use_fp8,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
@@ -223,6 +229,7 @@ def empty_init():
             initial_scale=2**8,
             precision="bf16",
             overlap_p2p=args.overlap,
+            use_fp8=args.use_fp8,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
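With the flags wired through, a benchmark run can enable both compute- and communication-side FP8 in one go, e.g. torchrun --nproc_per_node 8 benchmark.py --plugin gemini --use_fp8 --use_fp8_comm (launcher and process count are illustrative).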
tests/test_fp8/test_fp8_hook.py (1 addition, 1 deletion)

@@ -4,8 +4,8 @@
 import torch.nn.functional as F

 from colossalai.accelerator import get_accelerator
-from colossalai.booster.plugin.fp8_hook import FP8Hook
 from colossalai.quantization.fp8 import linear_fp8
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device