From 75d1d1cca44f1d0855c49fb15a0ccb16683ba323 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 7 Aug 2024 09:51:18 +0000 Subject: [PATCH 1/6] fix --- colossalai/quantization/fp8.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index bc8c3ced4cdd..805824a896c7 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -376,28 +376,6 @@ def all_to_all_fp8(output_list, input_list, group=None, fp8_format="e5m2"): output_list[i].copy_(cast_from_fp8(tensor, scale, input_type)) -def all_to_all_single_fp8(output_tensor, input_tensor, group=None, fp8_format="e5m2"): - - world_size = dist.get_world_size(group) - - per_slice_len = input_tensor.size(0) // world_size - input_type = input_tensor.dtype - ret, scale = cast_to_fp8(input_tensor, fp8_format=fp8_format) - fp8_type = ret.dtype - input_tensor = ret.view(torch.uint8) - tensor = torch.empty_like(input_tensor) - scale_list = [torch.empty_like(scale) for _ in range(world_size)] - dist.all_to_all_single(tensor, input_tensor, group=group) - dist.all_gather(scale_list, scale, group=group) - cast_tensor_list = [] - - for i in range(world_size): - output_part = tensor[per_slice_len * i : per_slice_len * (i + 1)].view(fp8_type) - output_part = cast_from_fp8(output_part, scale_list[i], input_type) - cast_tensor_list.append(output_part) - output_tensor.copy_(torch.concatenate(cast_tensor_list, dim=0)) - - def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"): world_size = dist.get_world_size(group) From f081275993dde620a4fcc9823a1f43926e808d8e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 8 Aug 2024 08:02:30 +0000 Subject: [PATCH 2/6] fix --- .github/workflows/example_check_on_pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 56fa006b1633..1ccdd59afefd 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -107,7 +107,7 @@ jobs: - name: Install Colossal-AI run: | - BUILD_EXT=1 pip install -v . + BUILD_EXT=1 pip install -v -e . - name: Store Colossal-AI Cache run: | From 4f127135b2c33cde6056d251d0c32d845276c2ac Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 8 Aug 2024 08:04:55 +0000 Subject: [PATCH 3/6] fix --- .github/workflows/example_check_on_pr.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml index 1ccdd59afefd..7a906738cb96 100644 --- a/.github/workflows/example_check_on_pr.yml +++ b/.github/workflows/example_check_on_pr.yml @@ -9,6 +9,7 @@ on: paths: - "examples/**" - "!examples/**.md" + - ".github/workflows/example_check_on_pr.yml" jobs: # This is for changed example files detect and output a matrix containing all the corresponding directory name. 
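[Note on the FP8 helpers referenced in PATCH 1/6: the removed all_to_all_single_fp8 follows the same pattern as the collectives that remain in colossalai/quantization/fp8.py — cast the local tensor to FP8 with a per-tensor scale, exchange the raw FP8 bytes, all-gather the scales, then cast each received slice back to the original dtype. The sketch below illustrates that per-tensor cast/restore round-trip; the helper names mirror cast_to_fp8/cast_from_fp8 from the diff, but the signatures and scaling details are assumptions, not the exact ColossalAI implementation.]

    import torch

    # Hypothetical sketch of a per-tensor FP8 cast/restore round-trip, mirroring the
    # cast_to_fp8/cast_from_fp8 calls seen in the diff. Scaling scheme is an assumption.
    def cast_to_fp8_sketch(x: torch.Tensor, fp8_format: str = "e5m2"):
        fp8_dtype = torch.float8_e5m2 if fp8_format == "e5m2" else torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8_dtype).max
        # One scale per tensor: map the tensor's max magnitude onto the FP8 range.
        amax = x.abs().max().clamp(min=1e-12)
        scale = fp8_max / amax
        return (x * scale).to(fp8_dtype), scale

    def cast_from_fp8_sketch(x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
        # Float8 tensors do not support arithmetic directly: upcast, unscale, restore dtype.
        return (x_fp8.to(torch.float32) / scale).to(dtype)

    x = torch.randn(8, 16, dtype=torch.bfloat16)
    x_fp8, scale = cast_to_fp8_sketch(x, fp8_format="e5m2")
    x_restored = cast_from_fp8_sketch(x_fp8, scale, x.dtype)  # lossy, but close to x

In the distributed collectives above, only the uint8 view of the FP8 payload goes through all_to_all/all_gather, while the per-rank scales are gathered separately so each received slice can be dequantized with the scale of the rank that produced it.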
From 19bf547d44059824acd0590398f4d3ca55be1807 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 15 Aug 2024 06:38:31 +0000 Subject: [PATCH 4/6] zero fp8 --- .../booster/plugin/low_level_zero_plugin.py | 19 ++++++++-- colossalai/quantization/fp8.py | 8 ++-- examples/language/llama/benchmark.py | 37 ++++++++++++++++--- requirements/requirements.txt | 2 +- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 64f264f7eba1..63d46f6f8a2e 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -35,6 +35,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model +from colossalai.quantization.fp8_hook import FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero import LowLevelZeroOptimizer @@ -62,7 +63,9 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: + def __init__( + self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False + ) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -74,11 +77,16 @@ def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None + self.use_fp8 = use_fp8 if self.dtype is not None: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather + self.op_hooks = [] if overlap_allgather: - self.op_hook = ZeroOpHook() + self.op_hooks.append(ZeroOpHook()) + if use_fp8: + self.op_hooks.append(FP8Hook()) + if overlap_allgather or use_fp8: for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: p.__class__ = ColoParameter @@ -335,6 +343,7 @@ def __init__( master_weights: bool = True, verbose: bool = False, fp8_communication: bool = False, + use_fp8: bool = False, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -362,6 +371,7 @@ def __init__( ) self.lora_enabled = False self.verbose = verbose + self.use_fp8 = use_fp8 # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -476,7 +486,10 @@ def configure( if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + model, + self.precision, + overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], + use_fp8=self.use_fp8, ) # TODO: Support Galore + ZeRO diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 4dd7db236c5d..1cdea69143a3 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -728,14 +728,14 @@ def backward(ctx: Any, out_grad) -> Any: return x_grad.reshape(ctx.x_shape), w_grad, bias_grad -@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE) +@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False) def 
_linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _LinearFp8.apply(input, weight, bias) def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: out = _linear_fp8(input, weight, bias) - if SUPPORT_TORCH_COMPILE: - # avoid modifying the tensor created from cuda graph - out = out.clone() + # if SUPPORT_TORCH_COMPILE: + # # avoid modifying the tensor created from cuda graph + # out = out.clone() return out diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 07583161b6fb..95e661d04765 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -10,6 +10,7 @@ from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator, get_profile_context from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision +from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig @@ -17,12 +18,14 @@ import colossalai from colossalai.accelerator import get_accelerator from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import PipelineGradientCheckpointConfig +# torch._dynamo.config.optimize_ddp=False + warnings.filterwarnings("ignore") # ============================== # Constants @@ -64,7 +67,7 @@ def main(): parser.add_argument( "-p", "--plugin", - choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"], + choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu", "zero"], default="gemini", help="Choose which plugin to use", ) @@ -204,7 +207,7 @@ def empty_init(): zero_stage=args.zero, sp_size=args.sp, enable_sequence_parallelism=args.sp > 1, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=False, enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", @@ -215,6 +218,8 @@ def empty_init(): use_fp8=args.use_fp8, **hybrid_kwargs, ) + elif args.plugin == "zero": + plugin = LowLevelZeroPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm, use_fp8=args.use_fp8) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( tp_size=args.tp, @@ -223,7 +228,7 @@ def empty_init(): num_model_chunks=args.n_chunks, zero_stage=args.zero, cpu_offload=True, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=False, enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, @@ -259,7 +264,7 @@ def empty_init(): if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) - + init_ctx = nullcontext() init_kwargs = {} if config.model_type == "chatglm": init_kwargs["empty_init"] = False @@ -277,6 +282,24 @@ def empty_init(): if config.model_type == "chatglm": model.transformer.encoder.gradient_checkpointing = True + # test lora + from peft import LoraConfig + + # lora_config = LoraConfig( + # # init_lora_weights="pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few 
seconds. + # ) + # lora_config = LoraConfig(task_type="CAUSAL_LM", r=20, lora_alpha=32, lora_dropout=0.1) + config_params = { + "init_lora_weights": "pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds. + "task_type": "CAUSAL_LM", + "r": 32, + "lora_alpha": 32, + "lora_dropout": 0.1, + } + + lora_config = LoraConfig(**config_params) + model = plugin.enable_lora(model, lora_config=lora_config) + model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") performance_evaluator = PerformanceEvaluator( @@ -301,6 +324,8 @@ def empty_init(): f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) + writer = SummaryWriter("tensorboard/loss") + with get_profile_context( args.profile, args.ignore_steps, @@ -331,6 +356,8 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] + if dist.get_rank() == 0: + writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=step) booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 651eb66e89ab..6e24b07b8639 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.1.0,<=2.3.0 +torch>=2.4.0 safetensors einops pydantic From 65829ed084f23cd4cfaef7984ef7e92ca649c2a3 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 15 Aug 2024 06:43:37 +0000 Subject: [PATCH 5/6] zero fp8 --- colossalai/quantization/fp8.py | 8 +++---- examples/language/llama/benchmark.py | 36 ++++------------------------ 2 files changed, 8 insertions(+), 36 deletions(-) diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index 1cdea69143a3..4dd7db236c5d 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -728,14 +728,14 @@ def backward(ctx: Any, out_grad) -> Any: return x_grad.reshape(ctx.x_shape), w_grad, bias_grad -@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False) +@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE) def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _LinearFp8.apply(input, weight, bias) def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: out = _linear_fp8(input, weight, bias) - # if SUPPORT_TORCH_COMPILE: - # # avoid modifying the tensor created from cuda graph - # out = out.clone() + if SUPPORT_TORCH_COMPILE: + # avoid modifying the tensor created from cuda graph + out = out.clone() return out diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 95e661d04765..21d081145cd9 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -10,7 +10,6 @@ from model_utils import format_numel_str, get_model_numel from performance_evaluator import PerformanceEvaluator, get_profile_context from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision -from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig @@ -18,14 +17,12 @@ import colossalai from colossalai.accelerator import get_accelerator from 
colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchFSDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam from colossalai.shardformer import PipelineGradientCheckpointConfig -# torch._dynamo.config.optimize_ddp=False - warnings.filterwarnings("ignore") # ============================== # Constants @@ -67,7 +64,7 @@ def main(): parser.add_argument( "-p", "--plugin", - choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu", "zero"], + choices=["gemini", "gemini_auto", "fsdp", "fsdp_cpu", "3d", "3d_cpu"], default="gemini", help="Choose which plugin to use", ) @@ -207,7 +204,7 @@ def empty_init(): zero_stage=args.zero, sp_size=args.sp, enable_sequence_parallelism=args.sp > 1, - enable_fused_normalization=False, + enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", @@ -218,8 +215,6 @@ def empty_init(): use_fp8=args.use_fp8, **hybrid_kwargs, ) - elif args.plugin == "zero": - plugin = LowLevelZeroPlugin(initial_scale=2**5, fp8_communication=args.use_fp8_comm, use_fp8=args.use_fp8) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( tp_size=args.tp, @@ -228,7 +223,7 @@ def empty_init(): num_model_chunks=args.n_chunks, zero_stage=args.zero, cpu_offload=True, - enable_fused_normalization=False, + enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, @@ -264,7 +259,6 @@ def empty_init(): if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) else nullcontext() ) - init_ctx = nullcontext() init_kwargs = {} if config.model_type == "chatglm": init_kwargs["empty_init"] = False @@ -282,24 +276,6 @@ def empty_init(): if config.model_type == "chatglm": model.transformer.encoder.gradient_checkpointing = True - # test lora - from peft import LoraConfig - - # lora_config = LoraConfig( - # # init_lora_weights="pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds. - # ) - # lora_config = LoraConfig(task_type="CAUSAL_LM", r=20, lora_alpha=32, lora_dropout=0.1) - config_params = { - "init_lora_weights": "pissa_niter_4", # Initialize the PiSSA with fast SVD, which completes in just a few seconds. 
- "task_type": "CAUSAL_LM", - "r": 32, - "lora_alpha": 32, - "lora_dropout": 0.1, - } - - lora_config = LoraConfig(**config_params) - model = plugin.enable_lora(model, lora_config=lora_config) - model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") performance_evaluator = PerformanceEvaluator( @@ -324,8 +300,6 @@ def empty_init(): f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) - writer = SummaryWriter("tensorboard/loss") - with get_profile_context( args.profile, args.ignore_steps, @@ -356,8 +330,6 @@ def empty_init(): performance_evaluator.on_step_start(step) outputs = model(**batch) loss = outputs[0] - if dist.get_rank() == 0: - writer.add_scalar(tag="Loss", scalar_value=loss.item(), global_step=step) booster.backward(loss, optimizer) optimizer.step() optimizer.zero_grad() From b04bd8b8139cf5834cf0cd2a7cfa78f8c7dfe974 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 15 Aug 2024 15:04:24 +0800 Subject: [PATCH 6/6] Update requirements.txt --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 6e24b07b8639..651eb66e89ab 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,7 +8,7 @@ click fabric contexttimer ninja -torch>=2.4.0 +torch>=2.1.0,<=2.3.0 safetensors einops pydantic
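[Net effect of the series on the Python sources: PATCH 1/6 removes all_to_all_single_fp8, PATCH 4/6 adds a use_fp8 option to LowLevelZeroPlugin that attaches an FP8Hook to the wrapped model's parameter-op hooks, and PATCHes 5/6 and 6/6 revert the benchmark, torch.compile, and requirements experiments. Below is a minimal usage sketch of the surviving use_fp8 flag with the Booster API — the model, optimizer, and launch details are illustrative assumptions, not taken from the patch.]

    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from colossalai.nn.optimizer import HybridAdam

    # Illustrative sketch (assumes a torchrun launch on a CUDA machine); only the
    # use_fp8/fp8_communication flags come from this patch series, the rest is
    # generic Booster boilerplate.
    colossalai.launch_from_torch()

    plugin = LowLevelZeroPlugin(
        stage=2,
        precision="bf16",
        use_fp8=True,            # new in PATCH 4/6: registers the FP8Hook on module params
        fp8_communication=True,  # pre-existing flag: FP8-compressed ZeRO collectives
    )
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = HybridAdam(model.parameters(), lr=1e-4)
    model, optimizer, *_ = booster.boost(model, optimizer)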