14 changes: 12 additions & 2 deletions applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
@@ -27,6 +27,7 @@ class ColossalAIStrategy(DDPStrategy):

Args:
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
@@ -56,6 +57,7 @@ class ColossalAIStrategy(DDPStrategy):
def __init__(
self,
stage: int = 3,
precision: str = 'fp16',
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
@@ -78,12 +80,17 @@ def __init__(
norm_type: float = 2.0) -> None:
super().__init__(seed)
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
self.stage = stage
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(
'Shard init is not supported with model.from_pretrained() yet. Please load weights after strategy.prepare()'
)
if stage == 3 and precision == 'fp32':
warnings.warn('Stage 3 only supports fp16. Precision is set to fp16.')
precision = 'fp16'
self.precision = precision
self.shard_init = shard_init
self.gemini_config = dict(device=get_current_device(),
placement_policy=placement_policy,
@@ -124,7 +131,10 @@ def model_init_context(self):
return super().model_init_context()

def setup_model(self, model: nn.Module) -> nn.Module:
return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
if self.stage != 3 and self.precision == 'fp16':
model = model.half()
return model

def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
@@ -156,7 +166,7 @@ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
# merge lora_weights into weights
for module in unwrapped_model.modules():
if isinstance(module, LoraLinear):
module.merge_weights=True
module.merge_weights = True
module.eval()
# get state_dict and save
state_dict = unwrapped_model.state_dict()
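
For context, a minimal usage sketch of the new `precision` option, assuming `ColossalAIStrategy` is importable from `chatgpt.trainer.strategies` and using a placeholder `nn.Linear` model; this is illustrative only and not part of the PR.

# Illustrative sketch (not from the PR): how the new `precision` argument is used.
# Assumption: ColossalAIStrategy is exported from chatgpt.trainer.strategies.
import torch.nn as nn
from chatgpt.trainer.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(stage=2, precision='fp16')  # 'fp32' is also accepted for stages 1 and 2
with strategy.model_init_context():
    model = nn.Linear(1024, 1024)  # placeholder module
model = strategy.setup_model(model)  # for stage != 3 with fp16, weights are cast via .half()
# With stage=3, passing precision='fp32' only emits a warning and falls back to fp16.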