18 changes: 13 additions & 5 deletions colossalai/booster/booster.py
@@ -1,4 +1,3 @@
import warnings
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

@@ -8,6 +7,8 @@
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.logging import get_dist_logger

SUPPORT_PEFT = False
try:
import peft
@@ -81,20 +82,26 @@ def __init__(
plugin, Plugin
), f"Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}."
self.plugin = plugin
self.logger = get_dist_logger()

# set accelerator
if self.plugin and self.plugin.control_device():
self.accelerator = None
if device is not None:
warnings.warn("The plugin will control the accelerator, so the device argument will be ignored.")
self.logger.warning(
"The plugin will control the accelerator," "so the device argument will be ignored.", ranks=[0]
)
else:
device = device or "cuda"
self.accelerator = Accelerator(device)

# set precision
if self.plugin and self.plugin.control_precision():
if mixed_precision is not None:
warnings.warn("The plugin will control the precision, so the mixed_precision argument will be ignored.")
self.logger.warning(
"The plugin will control the precision," "so the mixed_precision argument will be ignored.",
ranks=[0],
)
self.mixed_precision = None
elif mixed_precision is None:
self.mixed_precision = None
@@ -267,8 +274,9 @@ def enable_lora(
), "Please provide pretrained directory path if not passing in lora configuration."
if quantize is True:
if bnb_quantization_config is not None:
warnings.warn(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk."
self.logger.warning(
"User defined BnbQuantizationConfig is not fully tested in ColossalAI. Use it at your own risk.",
ranks=[0],
)
else:
bnb_quantization_config = BnbQuantizationConfig(
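Note: the recurring change in this file (and in the plugins below) swaps stdlib `warnings.warn` for the distributed logger, so each message is emitted once on rank 0 instead of once per process. A minimal sketch of the pattern, assuming a ColossalAI environment where `get_dist_logger()` and the `ranks` argument behave as used in the hunks above; the `DemoPlugin` class is hypothetical and exists only to illustrate the idiom.

```python
# Minimal sketch of the warnings -> distributed-logger migration shown above.
# Assumes colossalai is installed; `ranks=[0]` restricts output to rank 0.
from typing import Optional

from colossalai.logging import get_dist_logger


class DemoPlugin:
    """Hypothetical class, used only to illustrate the logging pattern."""

    def __init__(self, device: Optional[str] = None) -> None:
        self.logger = get_dist_logger()
        if device is not None:
            # Previously: warnings.warn("..."), which fired on every rank.
            self.logger.warning(
                "The plugin will control the accelerator, so the device argument will be ignored.",
                ranks=[0],
            )
```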
26 changes: 16 additions & 10 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -1,5 +1,4 @@
import gc
import logging
import os
import random
from pathlib import Path
@@ -27,6 +26,7 @@
)
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@@ -63,6 +63,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
super().__init__()
self.coordinator = DistCoordinator()
self.logger = get_dist_logger()

def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
@@ -118,7 +119,7 @@ def save_sharded_model(
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
return

Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
@@ -143,10 +144,11 @@ def save_sharded_model(
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.unwrap(), checkpoint_path)
logging.info(
self.logger.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
f"index located at {save_index_file}.",
ranks=[0],
)

def load_sharded_model(
@@ -168,7 +170,7 @@ def save_sharded_optimizer(
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -201,10 +203,11 @@ def save_sharded_optimizer(
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
logging.info(
self.logger.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
f"index located at {save_index_file}.",
ranks=[0],
)

def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
@@ -214,7 +217,7 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
self.logger.error(f"Provided path ({checkpoint_index_file}) should be a file", ranks=[0])

assert isinstance(optimizer, GeminiOptimizer)

@@ -369,9 +372,12 @@ def __init__(
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
if get_accelerator().name == "npu":
assert placement_policy == "static", "NPU only supports static placement policy"

self.logger = get_dist_logger()
if enable_async_reduce and not pin_memory:
logging.warning(
f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set."
self.logger.warning(
f"enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set.",
ranks=[0],
)
pin_memory = True
self.gemini_config = dict(
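For context, the checkpoint hunks above converge on one pattern: reject file paths, create the directory, write the shards and index, and report the index location from rank 0 only. A condensed sketch under those assumptions; `save_index_stub` is a hypothetical helper with no real sharding, not a ColossalAI API.

```python
# Hypothetical, simplified sketch of the checkpoint-path handling above.
import os
from pathlib import Path

from colossalai.logging import get_dist_logger


def save_index_stub(checkpoint_path: str, save_index_file: str = "model.index.json") -> None:
    logger = get_dist_logger()
    if os.path.isfile(checkpoint_path):
        # Mirrors the error branch above: the checkpoint target must be a directory.
        logger.error(f"Provided path ({checkpoint_path}) should be a directory, not a file", ranks=[0])
        return
    Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
    index_path = os.path.join(checkpoint_path, save_index_file)
    # ... write shards and the index file here ...
    logger.info(
        f"The model is split into checkpoint shards. "
        f"You can find where each parameter has been saved in the "
        f"index located at {index_path}.",
        ranks=[0],
    )
```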
35 changes: 22 additions & 13 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,6 +1,5 @@
import ctypes
import random
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
@@ -27,6 +26,7 @@
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1023,6 +1023,7 @@ def __init__(
inner_ring_size: int = None,
) -> None:
super().__init__()
self.logger = get_dist_logger()

assert (
dist.get_world_size() % (tp_size * pp_size) == 0
@@ -1040,8 +1041,9 @@ def __init__(
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
self.logger.warning(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size.",
ranks=[0],
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
@@ -1126,7 +1128,12 @@ def __init__(
else:
raise NotImplementedError()
if sequence_parallelism_mode == "ring_attn":
assert parallel_output, "Ring Attention doesn't support gathering output yet."
if not parallel_output:
self.logger.warning(
"parallel_output must be True for Zigzag Ring Attention, as we've not supported Zigzag all-gather yet.",
ranks=[0],
)
parallel_output = True

self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
@@ -1231,7 +1238,10 @@ def configure(
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_config["partition_grad"] = False
zero_stage = 0

@@ -1287,9 +1297,10 @@ def configure(
else:
is_zero = self.dp_size > 1
if self.dp_size == 1:
warnings.warn(
self.logger.warning(
"Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
"If you do not intend to use cpu_offload, please consider set zero_stage=0."
"If you do not intend to use cpu_offload, please consider set zero_stage=0.",
ranks=[0],
)

assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
Expand Down Expand Up @@ -1332,7 +1343,7 @@ def execute_pipeline(
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"

if return_outputs:
warnings.warn("return_outputs may lead to significant extra memory consumption.")
self.logger.warning("return_outputs may lead to significant extra memory consumption.", ranks=[0])

# Create a context for gradient synchronization based on the optimizer type.
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
@@ -1346,10 +1357,8 @@ def execute_pipeline(
)

# run with gradients accumulation
if (
model.require_grad_sync == False
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
or not torch.is_grad_enabled()
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs

@@ -1449,7 +1458,7 @@ def enable_lora(
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
assert self.pp_size == 1 and self.tp_size == 1
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
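A small sketch of the behavioral change in the ring-attention hunk above: the hard assertion on `parallel_output` becomes a rank-0 warning followed by a forced override. The standalone `resolve_parallel_output` helper is illustrative only and not part of the plugin's API.

```python
# Illustrative only: mirrors the new "ring_attn" branch above, where the
# assertion is replaced by a rank-0 warning plus an override.
from colossalai.logging import get_dist_logger


def resolve_parallel_output(sequence_parallelism_mode: str, parallel_output: bool) -> bool:
    logger = get_dist_logger()
    if sequence_parallelism_mode == "ring_attn" and not parallel_output:
        logger.warning(
            "parallel_output must be True for Zigzag Ring Attention, as Zigzag all-gather is not yet supported.",
            ranks=[0],
        )
        parallel_output = True
    return parallel_output
```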
39 changes: 18 additions & 21 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,7 +1,5 @@
import enum
import logging
import os
import warnings
from contextlib import nullcontext
from functools import partial
from pathlib import Path
@@ -33,6 +31,7 @@
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.tensor.colo_parameter import ColoParameter
@@ -62,9 +61,7 @@ class OptimizerParamCheckState(enum.Enum):


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
def __init__(
self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
) -> None:
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
super().__init__(module)
self.dtype = None
if precision == "fp16":
@@ -76,7 +73,7 @@ def __init__(
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None and cast_inputs:
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
self.overlap_allgather = overlap_allgather
if overlap_allgather:
@@ -140,7 +137,7 @@ def save_sharded_optimizer(
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)
@@ -177,10 +174,11 @@ def save_sharded_optimizer(
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(
self.logger.info(
f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
f"index located at {save_index_file}.",
ranks=[0],
)

def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
Expand Down Expand Up @@ -267,7 +265,7 @@ def save_sharded_model(

def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
return
from peft import PeftModel

@@ -336,7 +334,6 @@ def __init__(
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
cast_inputs: bool = True,
) -> None:
super().__init__()
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -363,8 +360,7 @@ def __init__(
)
self.lora_enabled = False
self.verbose = verbose
self.cast_inputs = cast_inputs

self.logger = get_dist_logger()
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

Expand Down Expand Up @@ -400,7 +396,7 @@ def enable_lora(

assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
self.lora_enabled = True
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

if bnb_quantization_config is not None:
model = quantize_model(model, bnb_quantization_config)
@@ -449,8 +445,9 @@ def add_lora_params_to_optimizer(self, model, optimizer):
origin_param = name2param[origin_key]
group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
warnings.warn(
f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
self.logger.warning(
f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.",
ranks=[0],
)
elif (
check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
@@ -478,10 +475,7 @@ def configure(

if not isinstance(model, ModelWrapper):
model = LowLevelZeroModel(
model,
self.precision,
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
)

# TODO: Support Galore + ZeRO
@@ -493,7 +487,10 @@ def configure(
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
warnings.warn("Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.")
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0

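Finally, both the hybrid-parallel and low-level ZeRO diffs apply the same guard when a distributed GaLore optimizer is combined with ZeRO. A hedged sketch of that guard, with the plugin state simplified to plain arguments; `maybe_disable_zero` is a hypothetical helper, whereas `DistGaloreAwamW` and the logger calls are the ones used in the hunks above.

```python
# Hedged sketch of the GaLore/ZeRO guard appearing in both plugin diffs:
# distributed GaLore currently works only with TP or vanilla DP, so gradient
# partitioning is switched off and the ZeRO stage is reset to 0.
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW


def maybe_disable_zero(optimizer, zero_stage: int, zero_config: dict, dp_size: int) -> int:
    logger = get_dist_logger()
    if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
        logger.warning(
            "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
            ranks=[0],
        )
        zero_config["partition_grad"] = False
        zero_stage = 0
    return zero_stage
```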