178 changes: 76 additions & 102 deletions applications/Chat/coati/trainer/strategies/colossalai.py
@@ -1,17 +1,13 @@
import warnings
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.gemini_plugin import GeminiModel
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.zero.gemini.gemini_ddp import GeminiDDP

from .ddp import DDPStrategy
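
For context on the import swap above: ColoInitContext (and the ProcessGroup/ShardSpec machinery) is removed, and LazyInitContext takes its place, deferring parameter allocation until the plugin materializes the model. A minimal sketch of the new pattern, assuming only the API visible in this diff:

import torch.nn as nn
from colossalai.lazy.lazy_init import LazyInitContext

# Modules built inside the context carry lazy (unallocated) parameters;
# the Gemini plugin materializes and shards them later, during prepare().
with LazyInitContext():
    model = nn.Linear(1024, 1024)  # placeholder module for illustration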
@@ -42,47 +38,42 @@ class LowLevelZeroStrategy(DDPStrategy):

"""

def __init__(self,
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:
def __init__(
self,
stage: int = 2,
precision: str = 'fp16',
seed: int = 42,
placement_policy: str = 'cuda',
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
overlap_communication: bool = True, # only for stage 1&2
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0) -> None:

assert stage in (1, 2), f'Unsupported stage "{stage}"'
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'

plugin_initializer = lambda: LowLevelZeroPlugin(
# zero_config
stage=stage,
precision=precision,
# zero_optim_config
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
plugin_initializer = lambda: LowLevelZeroPlugin(stage=stage,
precision=precision,
reduce_bucket_size_in_m=reduce_bucket_size,
overlap_communication=overlap_communication,
cpu_offload=(placement_policy == 'cpu'),
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)

super().__init__(seed, plugin_initializer)

@@ -131,64 +122,55 @@ class GeminiStrategy(DDPStrategy):

"""

def __init__(self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0
) -> None:

assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
def __init__(
self,
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'auto',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
hidden_dim: Optional[int] = None, # only for stage 3
min_chunk_size_m: float = 32, # only for stage 3
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
min_scale: float = 1,
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0) -> None:

# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
f'Shard init is not supported model.from_pretrained() yet. '
'Please load weights after strategy.prepare()'
)
warnings.warn('Shard init is not supported for model.from_pretrained() yet. '
'Please load weights after strategy.prepare()')
self.shard_init = shard_init

warnings.warn('Stage 3 only supports fp16. Precision is set to fp16.')

# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
# gemini_config
device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
# zero_optim_config
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
# optim_config
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type
)
plugin_initializer = lambda: GeminiPlugin(chunk_init_device=get_current_device(),
placement_policy=placement_policy,
precision='fp16',
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=shard_init,
search_range_m=search_range_m,
hidden_dim=hidden_dim,
min_chunk_size_m=min_chunk_size_m,
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
max_norm=max_norm,
norm_type=norm_type)

super().__init__(seed, plugin_initializer)

@@ -200,16 +182,8 @@ def setup_distributed(self) -> None:
colossalai.launch_from_torch({}, seed=self.seed)

def model_init_context(self):
world_size = dist.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
return ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_pg=shard_pg,
default_dist_spec=default_dist_spec)
return LazyInitContext(default_device=get_current_device())

def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiModel)
ddp_model = model.unwrap()
assert isinstance(ddp_model, GeminiDDP)
return ddp_model.module
assert isinstance(model, GeminiDDP)
return model.module
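
Taken together, the rewritten GeminiStrategy is driven the same way as before from the example scripts; a hedged end-to-end sketch, where build_actor is a hypothetical stand-in for the model constructors used in the examples:

from coati.trainer.strategies import GeminiStrategy

strategy = GeminiStrategy(placement_policy='auto', initial_scale=2**5)
with strategy.model_init_context():  # now a LazyInitContext under the hood
    model = build_actor()  # hypothetical helper returning an nn.Module actor
model = strategy.prepare(model)  # signature assumed from existing coati usage; wraps the actor in GeminiDDP
base_module = strategy.unwrap_model(model)  # per this diff: returns model.module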
2 changes: 1 addition & 1 deletion applications/Chat/examples/requirements.txt
@@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
colossalai==0.3.1
colossalai>=0.3.1
2 changes: 1 addition & 1 deletion applications/Chat/examples/train_prompts.py
@@ -23,7 +23,7 @@ def main(args):
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5)
strategy = GeminiStrategy(placement_policy='auto', initial_scale=2**5)
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
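The 'cuda' to 'auto' switch here (repeated in the reward-model, SFT, and checkpoint files below) opts into Gemini's runtime placement, which, per Gemini's documented behavior, moves shards between CPU and GPU based on observed memory pressure rather than pinning everything to the GPU; schematically:

# old: GeminiStrategy(placement_policy='cuda', ...) kept all shards on GPU
strategy = GeminiStrategy(placement_policy='auto', initial_scale=2**5)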
2 changes: 1 addition & 1 deletion applications/Chat/examples/train_reward_model.py
@@ -27,7 +27,7 @@ def train(args):
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda')
strategy = GeminiStrategy(placement_policy='auto')
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
else:
19 changes: 4 additions & 15 deletions applications/Chat/examples/train_sft.py
@@ -6,32 +6,31 @@
import torch.distributed as dist
from coati.dataset import SFTDataset, SupervisedDataset
from coati.models.bloom import BLOOMActor
from coati.models.chatglm import ChatGLMActor
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from coati.models.chatglm import ChatGLMActor
from coati.trainer import SFTTrainer
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from datasets import load_dataset
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer, AutoModel
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from transformers import AutoTokenizer, BloomTokenizerFast, LlamaTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from transformers.trainer import get_scheduler

from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter


def train(args):
# configure strategy
if args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = GeminiStrategy(placement_policy='cuda')
strategy = GeminiStrategy(placement_policy='auto')
elif args.strategy == 'colossalai_zero2':
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
@@ -91,16 +90,6 @@ def train(args):
else:
raise ValueError(f'Unsupported model "{args.model}"')

if args.model == 'llama' and args.strategy == 'colossalai_gemini':
# this is a hack to deal with the resized embedding
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility
for name, param in model.named_parameters():
if not isinstance(param, ColoParameter):
sub_module_name = '.'.join(name.split('.')[:-1])
weight_name = name.split('.')[-1]
sub_module = model.get_submodule(sub_module_name)
setattr(sub_module, weight_name, ColoParameter(param))

# configure optimizer
if args.strategy.startswith('colossalai'):
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
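The deleted block above was a workaround that re-wrapped resized-embedding weights as ColoParameter so Gemini could manage them; with lazy initialization that conversion happens at materialization time, so the hack is no longer needed. A sketch of the simplified path (the attribute names on the actor are assumed, not confirmed by this diff):

with strategy.model_init_context():
    model = LlamaActor(pretrained=args.pretrain)  # built lazily
    # resizing embeddings no longer needs a manual ColoParameter conversion;
    # the new weight tensors are materialized like any other lazy parameter
    model.model.resize_token_embeddings(len(tokenizer))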
2 changes: 1 addition & 1 deletion applications/Chat/requirements-test.txt
@@ -1,2 +1,2 @@
pytest
colossalai==0.3.1
colossalai>=0.3.1
2 changes: 1 addition & 1 deletion applications/Chat/requirements.txt
@@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
colossalai==0.3.1
colossalai>=0.3.1
torch<2.0.0, >=1.12.1
langchain
tokenizers
2 changes: 1 addition & 1 deletion applications/Chat/tests/test_checkpoint.py
@@ -40,7 +40,7 @@ def run_test_checkpoint(strategy_name: str,
if strategy_name == "ddp":
strategy = DDPStrategy()
elif strategy_name == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
elif strategy_name == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else: