Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,12 +333,23 @@ def backward(self, loss, retain_graph=False):
self.zero_grad()

def backward_by_grad(self, tensor, grad):
# in lower stage which grad is transfered by higher stage
# we need to pass the optim state down.
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)

if not self.require_grad_sync:
return
self._reduce_grad(self._partition_grads)

# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()

self.zero_grad()

def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
Expand Down
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ def run_bert_test(test_config):
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def run_bloom_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_bloom_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def run_chatglm_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_chatglm_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ def run_gpt2_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
@clear_cache_before_run()
def run_gpt2_3d_test(test_config):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,16 @@ def run_llama_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_llama_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ def run_opt_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_opt_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
Expand Down
10 changes: 10 additions & 0 deletions tests/test_shardformer/test_model/test_shard_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def run_t5_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_t5_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
Expand Down
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ def run_vit_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_vit_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
Expand Down
11 changes: 10 additions & 1 deletion tests/test_shardformer/test_model/test_shard_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-4, 5e-4
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
Expand Down Expand Up @@ -195,6 +195,15 @@ def run_whisper_test(test_config):
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_whisper_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
Expand Down