diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index fa55f97ad661..5d3fd5db7640 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -7,7 +7,6 @@
 
 import colossalai
 from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
-from colossalai.booster.plugin.gemini_plugin import GeminiModel
 from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
@@ -42,47 +41,42 @@ class LowLevelZeroStrategy(DDPStrategy):
 
     """
 
-    def __init__(self,
-                 stage: int = 2,
-                 precision: str = 'fp16',
-                 seed: int = 42,
-                 placement_policy: str = 'cuda',
-                 reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
-                 overlap_communication: bool = True,    # only for stage 1&2
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0,
-                 norm_type: float = 2.0
-                 ) -> None:
+    def __init__(
+            self,
+            stage: int = 2,
+            precision: str = 'fp16',
+            seed: int = 42,
+            placement_policy: str = 'cuda',
+            reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
+            overlap_communication: bool = True,    # only for stage 1&2
+            initial_scale: float = 2**16,
+            growth_factor: float = 2,
+            backoff_factor: float = 0.5,
+            growth_interval: int = 1000,
+            hysteresis: int = 2,
+            min_scale: float = 1,
+            max_scale: float = 2**32,
+            max_norm: float = 0.0,
+            norm_type: float = 2.0) -> None:
         assert stage in (1, 2), f'Unsupported stage "{stage}"'
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
         assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
 
-        plugin_initializer = lambda: LowLevelZeroPlugin(
-            # zero_config
-            stage=stage,
-            precision=precision,
-            # zero_optim_config
-            reduce_bucket_size_in_m=reduce_bucket_size,
-            overlap_communication=overlap_communication,
-            cpu_offload=(placement_policy == 'cpu'),
-            # optim_config
-            initial_scale=initial_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            min_scale=min_scale,
-            max_scale=max_scale,
-            max_norm=max_norm,
-            norm_type=norm_type
-        )
+        plugin_initializer = lambda: LowLevelZeroPlugin(stage=stage,
+                                                        precision=precision,
+                                                        reduce_bucket_size_in_m=reduce_bucket_size,
+                                                        overlap_communication=overlap_communication,
+                                                        cpu_offload=(placement_policy == 'cpu'),
+                                                        initial_scale=initial_scale,
+                                                        growth_factor=growth_factor,
+                                                        backoff_factor=backoff_factor,
+                                                        growth_interval=growth_interval,
+                                                        hysteresis=hysteresis,
+                                                        min_scale=min_scale,
+                                                        max_scale=max_scale,
+                                                        max_norm=max_norm,
+                                                        norm_type=norm_type)
 
         super().__init__(seed, plugin_initializer)
 
@@ -131,64 +125,57 @@ class GeminiStrategy(DDPStrategy):
 
     """
 
-    def __init__(self,
-                 seed: int = 42,
-                 shard_init: bool = False,    # only for stage 3
-                 placement_policy: str = 'cuda',
-                 pin_memory: bool = True,    # only for stage 3
-                 force_outputs_fp32: bool = False,    # only for stage 3
-                 search_range_m: int = 32,    # only for stage 3
-                 hidden_dim: Optional[int] = None,    # only for stage 3
-                 min_chunk_size_m: float = 32,    # only for stage 3
-                 gpu_margin_mem_ratio: float = 0.0,    # only for stage 3
-                 initial_scale: float = 2**16,
-                 growth_factor: float = 2,
-                 backoff_factor: float = 0.5,
-                 growth_interval: int = 1000,
-                 hysteresis: int = 2,
-                 min_scale: float = 1,
-                 max_scale: float = 2**32,
-                 max_norm: float = 0.0,
-                 norm_type: float = 2.0
-                 ) -> None:
+    def __init__(
+            self,
+            seed: int = 42,
+            shard_init: bool = False,    # only for stage 3
+            placement_policy: str = 'cuda',
+            pin_memory: bool = True,    # only for stage 3
+            force_outputs_fp32: bool = False,    # only for stage 3
+            search_range_m: int = 32,    # only for stage 3
+            hidden_dim: Optional[int] = None,    # only for stage 3
+            min_chunk_size_m: float = 32,    # only for stage 3
+            gpu_margin_mem_ratio: float = 0.0,    # only for stage 3
+            initial_scale: float = 2**16,
+            growth_factor: float = 2,
+            backoff_factor: float = 0.5,
+            growth_interval: int = 1000,
+            hysteresis: int = 2,
+            min_scale: float = 1,
+            max_scale: float = 2**32,
+            max_norm: float = 0.0,
+            norm_type: float = 2.0) -> None:
         assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
 
         # TODO(ver217): support shard_init when using from_pretrained()
         if shard_init:
-            warnings.warn(
-                f'Shard init is not supported model.from_pretrained() yet. '
-                'Please load weights after strategy.prepare()'
-            )
+            warnings.warn(f'Shard init is not supported model.from_pretrained() yet. '
+                          'Please load weights after strategy.prepare()')
         self.shard_init = shard_init
 
         warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
 
         # NOTE: dist should be initialized before calling get_current_device()
-        plugin_initializer = lambda: GeminiPlugin(
-            # gemini_config
-            device=get_current_device(),
-            placement_policy=placement_policy,
-            precision='fp16',
-            pin_memory=pin_memory,
-            force_outputs_fp32=force_outputs_fp32,
-            strict_ddp_mode=shard_init,
-            search_range_m=search_range_m,
-            hidden_dim=hidden_dim,
-            min_chunk_size_m=min_chunk_size_m,
-            # zero_optim_config
-            gpu_margin_mem_ratio=gpu_margin_mem_ratio,
-            # optim_config
-            initial_scale=initial_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            min_scale=min_scale,
-            max_scale=max_scale,
-            max_norm=max_norm,
-            norm_type=norm_type
-        )
+        plugin_initializer = lambda: GeminiPlugin(device=get_current_device(),
+                                                  placement_policy=placement_policy,
+                                                  precision='fp16',
+                                                  pin_memory=pin_memory,
+                                                  force_outputs_fp32=force_outputs_fp32,
+                                                  strict_ddp_mode=shard_init,
+                                                  search_range_m=search_range_m,
+                                                  hidden_dim=hidden_dim,
+                                                  min_chunk_size_m=min_chunk_size_m,
+                                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio,
+                                                  initial_scale=initial_scale,
+                                                  growth_factor=growth_factor,
+                                                  backoff_factor=backoff_factor,
+                                                  growth_interval=growth_interval,
+                                                  hysteresis=hysteresis,
+                                                  min_scale=min_scale,
+                                                  max_scale=max_scale,
+                                                  max_norm=max_norm,
+                                                  norm_type=norm_type)
 
         super().__init__(seed, plugin_initializer)
 
@@ -209,7 +196,5 @@ def model_init_context(self):
                                 default_dist_spec=default_dist_spec)
 
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        assert isinstance(model, GeminiModel)
-        ddp_model = model.unwrap()
-        assert isinstance(ddp_model, GeminiDDP)
-        return ddp_model.module
+        assert isinstance(model, GeminiDDP)
+        return model.module
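
Usage sketch (editorial note, not part of the patch): after this change the Gemini strategy's booster wraps the model in GeminiDDP directly, so unwrap_model() simply returns model.module instead of unwrapping the removed GeminiModel indirection. The snippet below is a minimal illustration, assuming the package exports GeminiStrategy from coati.trainer.strategies and that the script runs under a distributed launcher (e.g. torchrun); the nn.Linear stand-in and the commented names are illustrative, not taken from this diff.

    import torch.nn as nn
    from coati.trainer.strategies import GeminiStrategy

    # Distributed init happens inside the strategy (assumes a torchrun / colossalai launch).
    strategy = GeminiStrategy(seed=42, placement_policy='cuda')

    # Parameters created inside this context can be managed by Gemini (ZeRO stage 3).
    with strategy.model_init_context():
        model = nn.Linear(1024, 1024)

    # After strategy.prepare(...) the module is boosted into a GeminiDDP wrapper;
    # the refactored unwrap_model() takes that wrapper and returns its .module.
    # raw_module = strategy.unwrap_model(boosted_model)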