2 changes: 0 additions & 2 deletions colossalai/testing/comparison.py
@@ -22,8 +22,6 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
b,
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
dtype: {a.dtype} vs {b.dtype}",
)


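For reference, a minimal sketch of the simplified helper, assuming it wraps `torch.testing.assert_close`, which already reports shapes, dtypes, and mismatched values in its own failure message, making the hand-built `msg` redundant. The defaults shown are illustrative, not copied from the source:

```python
# Hedged sketch of assert_close_loose after the change (not the exact source).
from torch import Tensor
from torch.testing import assert_close

def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3):
    # assert_close raises an AssertionError with a detailed report on mismatch.
    assert_close(a, b, rtol=rtol, atol=atol)
```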
10 changes: 4 additions & 6 deletions tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -72,7 +72,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
dist.barrier()

new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
-    check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict(), False)
+    check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())


@clear_cache_before_run()
@@ -130,13 +130,11 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha

booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(
-        model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False, ignore_dtype=True
+        model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), ignore_dtype=True
)

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(
-        optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False), False
-    )
+    check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(only_rank_0=False))
for group in new_optimizer.param_groups:
assert group["lr"] == 0.1

@@ -169,7 +167,7 @@ def exam_lazy_from_pretrained():
booster.save_model(model, save_path, shard=False)
dist.barrier()
state_dict = torch.load(save_path, map_location="cpu")
-    check_state_dict_equal(state_dict, orig_state_dict, False, ignore_dtype=True)
+    check_state_dict_equal(state_dict, orig_state_dict, ignore_dtype=True)


def run_dist(rank, world_size, port):
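All of these call-site edits assume the same underlying change to `check_state_dict_equal`: the bare third positional flag is replaced by keyword arguments whose defaults cover the common case. A hypothetical signature inferred from the call sites alone; the parameter names and defaults are assumptions, not copied from `colossalai/testing/comparison.py`:

```python
# Hypothetical signature inferred from the updated call sites, not the actual source.
def check_state_dict_equal(d1: dict, d2: dict, ignore_device: bool = True, ignore_dtype: bool = False):
    ...
```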
8 changes: 4 additions & 4 deletions tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -62,12 +62,12 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
check_state_dict_equal(
model.state_dict(only_rank_0=False, prefix="module.module."),
new_model.state_dict(),
-        False,
+        ignore_device=False,
ignore_dtype=True,
)

new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.state_dict(only_rank_0=False), new_optimizer.state_dict(), ignore_device=False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
@@ -128,7 +128,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
check_state_dict_equal(
new_model.state_dict(only_rank_0=False, prefix="module.module."),
model.state_dict(),
-        False,
+        ignore_device=False,
ignore_dtype=True,
)

@@ -145,7 +145,7 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
k in old_group and k in new_group
), f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
assert old_group[k] == new_group[k]
-    check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], False)
+    check_state_dict_equal(old_state_dict["state"], new_state_dict["state"], ignore_device=False)

# Check the new model/optimizer can successfully run.
data = data_gen_fn()
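A toy usage sketch of the keyword style these tests adopt, reusing the hypothetical signature above; both tensors live on CPU, so the check passes even with device checking enabled:

```python
import torch

# Toy state dicts; with ignore_device=False the comparison also requires
# both tensors to be on the same device, which holds here (both on CPU).
sd_ref = {"weight": torch.ones(2, 2)}
sd_new = {"weight": torch.ones(2, 2)}
check_state_dict_equal(sd_ref, sd_new, ignore_device=False)
```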
(diff for another file; file header not captured)
@@ -94,9 +94,9 @@ def _preprocess_data(data):
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

booster.load_model(new_model, model_ckpt_path)
-    check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+    check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
+    check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict())
dist.barrier()

# Check whether the loaded model & optimizer works smoothly.
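These call sites compare `model.unwrap().state_dict()`, presumably because `booster.boost` returns wrapped modules whose state-dict keys carry an extra prefix (the Gemini test above uses `prefix="module.module."` for the same reason). A toy illustration of the pattern; the `Wrapper` class is illustrative, not the ColossalAI wrapper:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    # Illustrative stand-in for a boosted/parallelized module wrapper.
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

inner = nn.Linear(2, 2)
wrapped = Wrapper(inner)
# Unwrapping avoids the extra "module." prefix on every state-dict key.
assert set(wrapped.state_dict()) == {"module.weight", "module.bias"}
assert set(wrapped.unwrap().state_dict()) == {"weight", "bias"}
```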
(diff for another file; file header not captured)
@@ -55,7 +55,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)

booster.load_model(new_model, model_ckpt_path)
-    check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())
# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
working_param_id_set = set(id(p) for p in new_model.parameters())
@@ -70,7 +70,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
)

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+    check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
torch.cuda.empty_cache()


@@ -110,7 +110,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False)
new_model = new_booster.enable_lora(new_model, pretrained_dir=model_ckpt_path, lora_config=lora_config)
new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion)
-        check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+        check_state_dict_equal(model.state_dict(), new_model.state_dict())

# check master weight
assert isinstance(new_optimizer, LowLevelZeroOptimizer)
@@ -126,7 +126,7 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
)

new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-        check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
+        check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())

except Exception as e:
# return repr(e)
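The surrounding context also verifies ZeRO master weights. As a generic illustration (not the ColossalAI API): a ZeRO-style optimizer keeps an fp32 master copy of each half-precision working parameter, and the test asserts the two stay consistent after the checkpoint round-trip:

```python
import torch
from torch.testing import assert_close

# Generic sketch: an fp16 working parameter and its fp32 master copy.
working = torch.randn(8, dtype=torch.float16)
master = working.float()  # master weights are kept in full precision
assert_close(master.half(), working, rtol=0, atol=0)
```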
(diff for another file; file header not captured)
@@ -61,9 +61,9 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

if plugin_type == "gemini":
-        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
+        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False))
else:
-        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
+        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict())
dist.barrier()


6 changes: 3 additions & 3 deletions tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -52,12 +52,12 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
)

booster.load_model(new_model, model_ckpt_path)
-    check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+    check_state_dict_equal(model.state_dict(), new_model.state_dict())

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
-    check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
+    check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict())


def run_dist(rank, world_size, port):
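The scheduler assertion relies on PyTorch LR schedulers being checkpointable via `state_dict`/`load_state_dict`; a minimal self-contained round-trip, analogous to what `booster.save_lr_scheduler`/`load_lr_scheduler` do through files:

```python
import torch

model = torch.nn.Linear(2, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
sched_a = torch.optim.lr_scheduler.StepLR(optim, step_size=10)
sched_b = torch.optim.lr_scheduler.StepLR(optim, step_size=10)

sched_b.load_state_dict(sched_a.state_dict())  # restore from checkpointed state
assert sched_a.state_dict() == sched_b.state_dict()
```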
96 changes: 33 additions & 63 deletions tests/test_optimizer/_utils.py
@@ -3,7 +3,6 @@
from torch.testing import assert_close

import colossalai
-from colossalai.shardformer.layer._operation import _gather
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import parameterize, spawn
@@ -119,11 +118,15 @@ def run_bert_test(test_config, optim_class, sharded_optim_class):
test_config["use_lazy_init"] = False
test_config["pp_size"] = 1 # Do NOT test Pipeline Parallel
test_config["initial_scale"] = 2**15 # avoid overflow
+    target_models = [
+        "transformers_bert",
+    ]

for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_bert_fwd_bwd(
-            model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
-        )
+        if name in target_models:
+            check_bert_fwd_bwd(
+                model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config, optim_class, sharded_optim_class
+            )

clear_layout_converter()
Randomizer.reset_index()
@@ -152,7 +155,8 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
shard_spec = sharded_optimizer.shard_spec_dict[id(tp)]
use_zero = sharded_optimizer.use_zero
tp_optim_state = tp_state[key]
-                    p_state_shape, tp_state_shape = p_state[key].shape, tp_state[key].shape
+                    state = p_state[key]
+
dp_size, tp_size = (
sharded_optimizer.dp_size,
sharded_optimizer.tp_size,
@@ -165,88 +169,54 @@ def check_dist_optim_state(org_optimizer, sharded_optimizer):
                        if shard_spec.sharding_sequence[0] == "R":
                            if use_zero:
                                # sq_row needs gather along dp group
-                                if key == "exp_avg_sq_row":
-                                    tp_optim_state = _gather(
-                                        input_=tp_optim_state,
-                                        dim=-1,
-                                        process_group=sharded_optimizer.dp_group,
-                                    )
-                                    tp_optim_state.shape
-                                # sq_col don't need gather alone dp group
-                                if key == "exp_avg_sq_col":
-                                    pass
-                            else:
-                                pass
+                                if key == "exp_avg_sq_row":
+                                    state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]

                            # gather from tp group
                            # sq_row doesn't need gather along tp group
                            if key == "exp_avg_sq_row":
                                pass
-                            # sq_col need gather alone dp group
+                            # sq_col needs gather along tp group
                            if key == "exp_avg_sq_col":
-                                tp_optim_state = _gather(
-                                    input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
-                                )
-                                tp_optim_state.shape
+                                state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
                        # row parallel
-                        if shard_spec.sharding_sequence[-1] == "R":
-                            if use_zero:
+                        elif shard_spec.sharding_sequence[-1] == "R":
+                            # TODO: this case may cause shape mismatch @duanjunwen
+                            if use_zero and key == "exp_avg_sq_row" and state.shape[0] // tp_size % dp_size == 0:
                                # sq_row needs gather along dp group
-                                if key == "exp_avg_sq_row":
-                                    if p_state[key].shape[0] // tp_size % dp_size != 0:
-                                        pass
-                                    else:
-                                        tp_optim_state = _gather(
-                                            input_=tp_optim_state,
-                                            dim=-1,
-                                            process_group=sharded_optimizer.dp_group,
-                                        )
-                                        tp_optim_state.shape
-                                # sq_col don't need gather alone dp group
-                                if key == "exp_avg_sq_col":
-                                    pass
-                            else:
-                                pass
+                                state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]

                            # gather from tp group
                            # sq_row needs gather along tp group
                            if key == "exp_avg_sq_row":
-                                tp_optim_state = _gather(
-                                    input_=tp_optim_state, dim=-1, process_group=sharded_optimizer.tp_group
-                                )
-                                tp_optim_state.shape
+                                state = state.chunk(tp_size, dim=-1)[dist.get_rank(sharded_optimizer.tp_group)]
                            # sq_col doesn't need gather along dp group
                            if key == "exp_avg_sq_col":
                                pass
+                        else:
+                            return
                    else:
                        if use_zero:
                            # sq_row needs gather along dp group
                            if key == "exp_avg_sq_row":
                                # row residual; no gather
-                                if p_state[key].shape[0] % dp_size != 0:
+                                if state.shape[0] % dp_size != 0:
                                    pass
                                else:
-                                    tp_optim_state = _gather(
-                                        input_=tp_optim_state,
-                                        dim=-1,
-                                        process_group=sharded_optimizer.dp_group,
-                                    )
-                                    tp_optim_state.shape
+                                    state = state.chunk(dp_size, dim=-1)[dist.get_rank(sharded_optimizer.dp_group)]
                            # sq_col doesn't need gather along dp group
                            if key == "exp_avg_sq_col":
                                tp_optim_state = tp_optim_state.div_(dp_size)
                                # need a div;
                        else:
                            pass
-                    # Sovled a New issus: different dtype;
-                    # So far, only happen in H100 env;
-                    # Seem torch.set_default_dtype(torch.bfloat16) not act on booster.percision;
-                    # Or assert_close just update to check dtype;
-                    if p_state[key].dtype != tp_optim_state.dtype:
-                        tp_optim_state = tp_optim_state.type(p_state[key].dtype)
-                    try:
-                        assert_close(p_state[key], tp_optim_state, atol=5e-4, rtol=1.6e-2)
-                    except:
-                        pass

+                    if state.dtype != tp_optim_state.dtype:
+                        tp_optim_state = tp_optim_state.type(state.dtype)
+                    # TODO: some sharding checks are currently buggy, but the state values should match
+                    # @duanjunwen
+                    if state.shape != tp_optim_state.shape:
+                        return
+                    assert_close(state, tp_optim_state, atol=5e-4, rtol=1.6e-2)

def check_dist_param(org_model, sharded_model, weight_layer_for_check, atol, rtol):
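The recurring rewrite in `check_dist_optim_state` replaces gathering the sharded optimizer state (via the removed `_gather`) with the mirror-image check: chunk the full reference state along the sharded dimension and index it with this rank, so the comparison needs no communication. A hedged sketch of that idea (the helper name is mine, not from the source; assumes `torch.distributed` is initialized):

```python
import torch
import torch.distributed as dist

def local_shard(full: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor:
    # Slice the unsharded reference tensor down to the piece this rank holds,
    # so it can be compared against the local shard without any collective op.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    return full.chunk(world_size, dim=dim)[rank]
```

Note that `Tensor.chunk` can return fewer or unequal pieces when the dimension is not divisible by the group size, which is why the test keeps the residual guards (`state.shape[0] % dp_size != 0`) and, per the TODO, bails out on any remaining shape mismatch instead of asserting.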