diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index fa55f97ad661..003fb28b6951 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -1,17 +1,13 @@
 import warnings
 from typing import Optional
 
-import torch
-import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
-from colossalai.booster.plugin.gemini_plugin import GeminiModel
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
-from colossalai.tensor import ProcessGroup, ShardSpec
+from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
 from colossalai.zero.gemini.gemini_ddp import GeminiDDP
 
 from .ddp import DDPStrategy
@@ -42,47 +38,42 @@ class LowLevelZeroStrategy(DDPStrategy):
     """
 
-    def __init__(self,
-                 stage: int = 2,
-                 precision: str = 'fp16',
-                 seed: int = 42,
-                 placement_policy: str = 'cuda',
-                 reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
-                 overlap_communication: bool = True,  # only for stage 1&2
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0,
-                 norm_type: float = 2.0
-                 ) -> None:
+    def __init__(
+            self,
+            stage: int = 2,
+            precision: str = 'fp16',
+            seed: int = 42,
+            placement_policy: str = 'cuda',
+            reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
+            overlap_communication: bool = True,  # only for stage 1&2
+            initial_scale: float = 2**16,
+            growth_factor: float = 2,
+            backoff_factor: float = 0.5,
+            growth_interval: int = 1000,
+            hysteresis: int = 2,
+            min_scale: float = 1,
+            max_scale: float = 2**32,
+            max_norm: float = 0.0,
+            norm_type: float = 2.0) -> None:
         assert stage in (1, 2), f'Unsupported stage "{stage}"'
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
         assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
 
-        plugin_initializer = lambda: LowLevelZeroPlugin(
-            # zero_config
-            stage=stage,
-            precision=precision,
-            # zero_optim_config
-            reduce_bucket_size_in_m=reduce_bucket_size,
-            overlap_communication=overlap_communication,
-            cpu_offload=(placement_policy == 'cpu'),
-            # optim_config
-            initial_scale=initial_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            min_scale=min_scale,
-            max_scale=max_scale,
-            max_norm=max_norm,
-            norm_type=norm_type
-        )
+        plugin_initializer = lambda: LowLevelZeroPlugin(stage=stage,
+                                                        precision=precision,
+                                                        reduce_bucket_size_in_m=reduce_bucket_size,
+                                                        overlap_communication=overlap_communication,
+                                                        cpu_offload=(placement_policy == 'cpu'),
+                                                        initial_scale=initial_scale,
+                                                        growth_factor=growth_factor,
+                                                        backoff_factor=backoff_factor,
+                                                        growth_interval=growth_interval,
+                                                        hysteresis=hysteresis,
+                                                        min_scale=min_scale,
+                                                        max_scale=max_scale,
+                                                        max_norm=max_norm,
+                                                        norm_type=norm_type)
 
         super().__init__(seed, plugin_initializer)
@@ -131,64 +122,55 @@ class GeminiStrategy(DDPStrategy):
     """
 
-    def __init__(self,
-                 seed: int = 42,
-                 shard_init: bool = False,  # only for stage 3
-                 placement_policy: str = 'cuda',
-                 pin_memory: bool = True,  # only for stage 3
-                 force_outputs_fp32: bool = False,  # only for stage 3
-                 search_range_m: int = 32,  # only for stage 3
-                 hidden_dim: Optional[int] = None,  # only for stage 3
-                 min_chunk_size_m: float = 32,  # only for stage 3
-                 gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0,
-                 norm_type: float = 2.0
-                 ) -> None:
-
-        assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+    def __init__(
+            self,
+            seed: int = 42,
+            shard_init: bool = False,  # only for stage 3
+            placement_policy: str = 'auto',
+            pin_memory: bool = True,  # only for stage 3
+            force_outputs_fp32: bool = False,  # only for stage 3
+            search_range_m: int = 32,  # only for stage 3
+            hidden_dim: Optional[int] = None,  # only for stage 3
+            min_chunk_size_m: float = 32,  # only for stage 3
+            gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
+            initial_scale: float = 2**16,
+            growth_factor: float = 2,
+            backoff_factor: float = 0.5,
+            growth_interval: int = 1000,
+            hysteresis: int = 2,
+            min_scale: float = 1,
+            max_scale: float = 2**32,
+            max_norm: float = 0.0,
+            norm_type: float = 2.0) -> None:
 
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
-            warnings.warn(
-                f'Shard init is not supported model.from_pretrained() yet. '
-                'Please load weights after strategy.prepare()'
-            )
+            warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
+                          'Please load weights after strategy.prepare()')
         self.shard_init = shard_init
 
         warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
 
         # NOTE: dist should be initialized before calling get_current_device()
-        plugin_initializer = lambda: GeminiPlugin(
-            # gemini_config
-            device=get_current_device(),
-            placement_policy=placement_policy,
-            precision='fp16',
-            pin_memory=pin_memory,
-            force_outputs_fp32=force_outputs_fp32,
-            strict_ddp_mode=shard_init,
-            search_range_m=search_range_m,
-            hidden_dim=hidden_dim,
-            min_chunk_size_m=min_chunk_size_m,
-            # zero_optim_config
-            gpu_margin_mem_ratio=gpu_margin_mem_ratio,
-            # optim_config
-            initial_scale=initial_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            min_scale=min_scale,
-            max_scale=max_scale,
-            max_norm=max_norm,
-            norm_type=norm_type
-        )
+        plugin_initializer = lambda: GeminiPlugin(chunk_init_device=get_current_device(),
+                                                  placement_policy=placement_policy,
+                                                  precision='fp16',
+                                                  pin_memory=pin_memory,
+                                                  force_outputs_fp32=force_outputs_fp32,
+                                                  strict_ddp_mode=shard_init,
+                                                  search_range_m=search_range_m,
+                                                  hidden_dim=hidden_dim,
+                                                  min_chunk_size_m=min_chunk_size_m,
+                                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+                                                  initial_scale=initial_scale,
+                                                  growth_factor=growth_factor,
+                                                  backoff_factor=backoff_factor,
+                                                  growth_interval=growth_interval,
+                                                  hysteresis=hysteresis,
+                                                  min_scale=min_scale,
+                                                  max_scale=max_scale,
+                                                  max_norm=max_norm,
+                                                  norm_type=norm_type)
 
         super().__init__(seed, plugin_initializer)
@@ -200,16 +182,8 @@ def setup_distributed(self) -> None:
         colossalai.launch_from_torch({}, seed=self.seed)
 
     def model_init_context(self):
-        world_size = dist.get_world_size()
-        shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
-        default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
-        return ColoInitContext(device=get_current_device(),
-                               dtype=torch.half,
-                               default_pg=shard_pg,
-                               default_dist_spec=default_dist_spec)
+        return LazyInitContext(default_device=get_current_device())
 
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        assert isinstance(model, GeminiModel)
-        ddp_model = model.unwrap()
-        assert isinstance(ddp_model, GeminiDDP)
-        return ddp_model.module
+        assert isinstance(model, GeminiDDP)
+        return model.module
diff --git a/applications/Chat/examples/requirements.txt b/applications/Chat/examples/requirements.txt
index 5d0f9f927d17..a7cfb5da7fe1 100644
--- a/applications/Chat/examples/requirements.txt
+++ b/applications/Chat/examples/requirements.txt
@@ -1,3 +1,3 @@
 pandas>=1.4.1
 sentencepiece
-colossalai==0.3.1
\ No newline at end of file
+colossalai>=0.3.1
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index d27a70a3fef6..a52efa635ca6 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -23,7 +23,7 @@ def main(args):
     if args.strategy == 'ddp':
         strategy = DDPStrategy()
     elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy='auto', initial_scale=2**5)
     elif args.strategy == 'colossalai_zero2':
         strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
     else:
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index 190460bc20f6..db11ba78e8a8 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -27,7 +27,7 @@ def train(args):
     if args.strategy == 'ddp':
         strategy = DDPStrategy()
     elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cuda')
+        strategy = GeminiStrategy(placement_policy='auto')
     elif args.strategy == 'colossalai_zero2':
         strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
     else:
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index f068ea2bf5de..15e470768ee4 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -6,24 +6,23 @@
 import torch.distributed as dist
 from coati.dataset import SFTDataset, SupervisedDataset
 from coati.models.bloom import BLOOMActor
+from coati.models.chatglm import ChatGLMActor
+from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
 from coati.models.gpt import GPTActor
 from coati.models.llama import LlamaActor
 from coati.models.opt import OPTActor
-from coati.models.chatglm import ChatGLMActor
 from coati.trainer import SFTTrainer
 from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
 from datasets import load_dataset
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
-from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
+from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
 from transformers.trainer import get_scheduler
 
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter
 
 
 def train(args):
@@ -31,7 +30,7 @@ def train(args):
     if args.strategy == 'ddp':
         strategy = DDPStrategy()
     elif args.strategy == 'colossalai_gemini':
-        strategy = GeminiStrategy(placement_policy='cuda')
+        strategy = GeminiStrategy(placement_policy='auto')
     elif args.strategy == 'colossalai_zero2':
         strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
     elif args.strategy == 'colossalai_zero2_cpu':
@@ -91,16 +90,6 @@ def train(args):
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
 
-    if args.model == 'llama' and args.strategy == 'colossalai_gemini':
-        # this is a hack to deal with the resized embedding
-        # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
-        for name, param in model.named_parameters():
-            if not isinstance(param, ColoParameter):
-                sub_module_name = '.'.join(name.split('.')[:-1])
-                weight_name = name.split('.')[-1]
-                sub_module = model.get_submodule(sub_module_name)
-                setattr(sub_module, weight_name, ColoParameter(param))
-
     # configure optimizer
     if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
diff --git a/applications/Chat/requirements-test.txt b/applications/Chat/requirements-test.txt
index eb1a77875acb..adf2cc1bf545 100644
--- a/applications/Chat/requirements-test.txt
+++ b/applications/Chat/requirements-test.txt
@@ -1,2 +1,2 @@
 pytest
-colossalai==0.3.1
\ No newline at end of file
+colossalai>=0.3.1
diff --git a/applications/Chat/requirements.txt b/applications/Chat/requirements.txt
index e5f5ca0932a8..93276b069671 100644
--- a/applications/Chat/requirements.txt
+++ b/applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
 tqdm
 datasets
 loralib
-colossalai==0.3.1
+colossalai>=0.3.1
 torch<2.0.0, >=1.12.1
 langchain
 tokenizers
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
index 3a3bf5b19cb8..605d86760dad 100644
--- a/applications/Chat/tests/test_checkpoint.py
+++ b/applications/Chat/tests/test_checkpoint.py
@@ -40,7 +40,7 @@ def run_test_checkpoint(strategy_name: str,
     if strategy_name == "ddp":
         strategy = DDPStrategy()
     elif strategy_name == "colossalai_gemini":
-        strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
+        strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
     elif strategy_name == "colossalai_zero2":
         strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
     else:
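Taken together, the patch moves the Chat strategies off the removed `ColoInitContext`/`ColoParameter` APIs onto `LazyInitContext`, renames the Gemini plugin's `device` argument to `chunk_init_device`, and switches the example scripts to the new `placement_policy='auto'`. Below is a minimal sketch of what user code looks like after this change; it assumes `coati` and `colossalai>=0.3.1` are installed and the process was launched with `torchrun` (so distributed env vars are set), and the `'gpt2'` pretrained path is illustrative, not part of the patch.

```python
from coati.models.gpt import GPTActor
from coati.trainer.strategies import GeminiStrategy

# placement_policy='auto' is the new default: Gemini decides at runtime
# where parameter chunks live instead of pinning everything to 'cuda'.
strategy = GeminiStrategy(placement_policy='auto', initial_scale=2**5)

# model_init_context() now returns a LazyInitContext, so parameters are
# materialized lazily rather than eagerly wrapped as ColoParameter; this
# also removes the need for the old llama/Gemini embedding-resize hack.
with strategy.model_init_context():
    model = GPTActor(pretrained='gpt2')
```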