3 changes: 1 addition & 2 deletions colossalai/pipeline/schedule/one_f_one_b.py
@@ -128,11 +128,11 @@ def forward_step(self,
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch()

# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():

loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
@@ -158,7 +158,6 @@ def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],

# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)

# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
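For context, a minimal standalone sketch of the micro-batch loss scaling that forward_step applies on the last stage: each micro-batch loss is divided by num_microbatches so the accumulated value matches the full-batch loss. This is a toy illustration, not ColossalAI code; model, criterion, and micro_batches are placeholders.

def accumulate_microbatch_loss(model, criterion, micro_batches, accum_loss=None):
    # Toy example of the scaling used by forward_step; not the ColossalAI API.
    num_microbatches = len(micro_batches)
    for micro_batch in micro_batches:
        output = model(micro_batch)
        # Scale each micro-batch loss so the sum over micro-batches equals the full-batch loss.
        loss = criterion(output, micro_batch) / num_microbatches
        if accum_loss is not None:
            accum_loss.add_(loss.detach())    # keep a detached running total for reporting
        # In the real 1F1B schedule the backward pass runs in a separate backward_step.
        loss.backward()
    return accum_loss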
9 changes: 7 additions & 2 deletions colossalai/zero/low_level/low_level_optim.py
@@ -316,7 +316,6 @@ def _add_to_bucket(self, param, group_id):
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)

@@ -333,6 +332,13 @@ def backward(self, loss, retain_graph=False):

self.zero_grad()

def backward_by_grad(self, tensor, grad):
# in lower pipeline stages, the grad is transferred from the higher stage,
# so we need to pass the optim state down.
if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)

def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@@ -358,7 +364,6 @@ def zero_grad(self, set_to_none=True):

def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'

if not self.require_grad_sync:
return

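The new backward_by_grad lets a non-last pipeline stage run the ZeRO backward pass starting from a gradient received from the next stage, after applying the mixed-precision pre-backward hook. A hedged sketch of how the schedule-side backward_step shown above drives it; this is a placeholder function, not code from this PR:

def backward_step_sketch(optimizer, output_obj, output_obj_grad):
    # Last stage: output_obj is the scalar loss, so a plain backward suffices.
    if output_obj_grad is None:
        optimizer.backward(output_obj)
    # Non-last stage: the gradient w.r.t. output_obj was received from the next stage,
    # so backward_by_grad applies the mixed-precision pre-backward hook and then calls
    # torch.autograd.backward(output_obj, output_obj_grad).
    else:
        optimizer.backward_by_grad(output_obj, output_obj_grad)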
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -107,6 +107,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bert_test(test_config):

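The added test_config entry covers pipeline parallelism (pp_size=2) combined with ZeRO stage 1 in fp16. As a rough illustration of how such a config is typically consumed — an assumption about the test harness, not code from this PR — it corresponds roughly to building a HybridParallelPlugin and Booster like this:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

test_config = {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 2,
    'enable_all_optimization': True,
    'use_lazy_init': True,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1,
}

# 'use_lazy_init' is assumed to be handled by the test harness rather than the plugin.
use_lazy_init = test_config.pop('use_lazy_init')
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)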
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_bloom.py
@@ -110,6 +110,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bloom_test(test_config):

9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -128,6 +128,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_llama.py
@@ -142,6 +142,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_llama_test(test_config):

9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_opt.py
@@ -135,6 +135,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_opt_test(test_config):

9 changes: 9 additions & 0 deletions tests/test_shardformer/test_model/test_shard_t5.py
@@ -118,6 +118,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):
11 changes: 10 additions & 1 deletion tests/test_shardformer/test_model/test_shard_vit.py
@@ -45,7 +45,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)

check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)

# unwrap model
@@ -97,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
torch.cuda.empty_cache()


# TODO: num_microbatches = 2 produces an inf loss
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
@@ -132,6 +132,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': False,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_vit_test(test_config):

67 changes: 37 additions & 30 deletions tests/test_shardformer/test_model/test_shard_whisper.py
@@ -112,37 +112,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

torch.cuda.empty_cache()


#TODO fix WhisperForConditionalGeneration enable jit fused operato
# TODO(jianghai) fix fp16
#TODO fix WhisperForConditionalGeneration enable jit fused operator
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
}])
@parameterize(
'test_config',
[
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
},
# whisper does not support fp16 for now.
])
def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():