Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions colossalai/zero/gemini/chunk/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def init_grad_chunk(self) -> "Chunk":
# grad chunk is initialized, just reallocate cuda global chunk
self.grad_chunk.cuda_shard = None
self.grad_chunk.is_gathered = True
self.grad_chunk.l2_norm = None
alloc_storage(self.grad_chunk.cuda_global_chunk)

return self.grad_chunk
1 change: 1 addition & 0 deletions colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def grad_handle(self, p, grad):
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
chunk.grad_chunk.l2_norm = None

# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_zero/test_gemini/test_grad_accum.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
@parameterize("use_grad_checkpoint", [False, True])
def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
):
init_device = get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
Expand All @@ -63,6 +66,10 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
torch_p.data.copy_(p.data)

if use_grad_checkpoint:
gemini_model.gradient_checkpointing_enable()
torch_model.gradient_checkpointing_enable()

world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
Expand All @@ -77,7 +84,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
**placement_config,
)
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)

rank = dist.get_rank()

Expand Down Expand Up @@ -112,6 +119,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
check_grad(gemini_model, torch_model)

if (i + 1) % accum_iter == 0:
torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
torch_optim.step()
gemini_optim.step()
torch_optim.zero_grad()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_zero/test_gemini/test_grad_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
)

optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)

model.train()
torch_model.train()
Expand Down