17 changes: 7 additions & 10 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -181,11 +181,11 @@ class GeminiPlugin(DPPluginBase):
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
-        search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+        search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
         hidden_dim (int, optional): the hidden dimension of DNN.
            Users can provide this argument to speed up searching.
            If users do not know this argument before training, it is ok. We will use a default value 1024.
-        min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+        min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
            If the aggregate size of parameters is still smaller than the minimum chunk size,
            all parameters will be compacted into one small chunk.
         memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
@@ -214,9 +214,9 @@ def __init__(
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
-        search_range_mb: int = 32,
+        search_range_m: int = 32,
         hidden_dim: Optional[int] = None,
-        min_chunk_size_mb: float = 32,
+        min_chunk_size_m: float = 32,
         memstats: Optional[MemStats] = None,
         gpu_margin_mem_ratio: float = 0.0,
         initial_scale: float = 2**32,
@@ -238,9 +238,9 @@ def __init__(
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,
             strict_ddp_mode=strict_ddp_mode,
-            search_range_mb=search_range_mb,
+            search_range_m=search_range_m,
             hidden_dim=hidden_dim,
-            min_chunk_size_mb=min_chunk_size_mb,
+            min_chunk_size_m=min_chunk_size_m,
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
@@ -295,10 +295,7 @@ def configure(

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = GeminiOptimizer(model.unwrap(),
-                                        optimizer,
-                                        self.zero_optim_config,
-                                        self.optim_kwargs,
+            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
                                         self.verbose)

         return model, optimizer, criterion, dataloader, lr_scheduler
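For context, here is a minimal usage sketch of the plugin under the renamed keywords. The values are illustrative, and only keywords that appear in this diff are assumed to exist:

```python
# Hedged sketch: constructs GeminiPlugin with the renamed arguments.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    pin_memory=True,
    search_range_m=32,      # was: search_range_mb
    min_chunk_size_m=32,    # was: min_chunk_size_mb
    hidden_dim=1024,        # optional hint that narrows the chunk-size search
)
booster = Booster(plugin=plugin)
```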
26 changes: 13 additions & 13 deletions colossalai/zero/gemini/chunk/search_utils.py
@@ -114,9 +114,9 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,

 def search_chunk_configuration(
         model: nn.Module,
-        search_range_mb: float,
-        search_interval_byte: int,    # hidden size is the best value for the interval
-        min_chunk_size_mb: float = 32,
+        search_range_m: float,
+        search_interval: int,    # hidden size is the best value for the interval
+        min_chunk_size_m: float = 32,
         filter_exlarge_params: bool = True,
         strict_ddp_flag: bool = False,
         memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
@@ -126,9 +126,9 @@ def search_chunk_configuration(

     Args:
         model (nn.Module): torch module
-        search_range_mb (float): searching range in mega byte.
-        search_interval_byte (int): searching interval in byte.
-        min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
+        search_range_m (float): searching range divided by 2^20.
+        search_interval (int): searching interval.
+        min_chunk_size_m (float, optional): the minimum size of a distributed chunk, divided by 2^20.
         filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
         strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
             all parameters keep replicated in this mode.
@@ -145,9 +145,9 @@ def search_chunk_configuration(
     for p in model.parameters():
         param_order.append(p)

-    search_range_byte = round(search_range_mb * 1024**2)
-    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
-    assert search_range_byte >= 0
+    search_range = round(search_range_m * 1024**2)
+    min_chunk_size = round(min_chunk_size_m * 1024**2)
+    assert search_range >= 0

     params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
     size_lcm = np.lcm.reduce(list(params_dict.keys()))
@@ -162,23 +162,23 @@ def search_chunk_configuration(
         total_param_size += group_acc_size

         # let small parameters keep gathered in CUDA all the time
-        if group_acc_size < min_chunk_size_byte:
+        if group_acc_size < min_chunk_size:
             config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
         else:
             size_dict[dp_degree] = size_list

     if filter_exlarge_params:
         _filter_exlarge_params(model, size_dict)

-    max_size = min_chunk_size_byte
+    max_size = min_chunk_size
     for key in size_dict:
         max_size = max(max_size, max(size_dict[key]))
-    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
+    start_size = int(math.ceil(max_size / search_interval) * search_interval)

     min_chunk_waste = float('+inf')
     best_chunk_size = start_size

-    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+    for chunk_size in range(start_size, start_size + search_range + 1, search_interval):
         temp_waste = 0
         for key in size_dict:
             temp_waste += _get_unused_byte(size_dict[key], chunk_size)
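The rename is a unit correction rather than a behavior change: the searched quantities are raw counts scaled by 2^20, not megabytes, which the old `_mb`/`_byte` suffixes misstated. A small sketch mirroring the scaling applied inside `search_chunk_configuration` above:

```python
# The *_m arguments are multiplied by 2^20 (and rounded) before use,
# so they express "mega" counts rather than megabytes.
search_range_m = 1
min_chunk_size_m = 0.5    # floats are accepted; round() is applied after scaling

search_range = round(search_range_m * 1024**2)        # 1048576
min_chunk_size = round(min_chunk_size_m * 1024**2)    # 524288
assert search_range == 2**20 and min_chunk_size == 2**19
```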
14 changes: 7 additions & 7 deletions colossalai/zero/gemini/chunk/utils.py
@@ -23,10 +23,10 @@ def init_chunk_manager(model: nn.Module,
                        verbose: bool = False,
                        **kwargs) -> ChunkManager:
     if hidden_dim:
-        search_interval_byte = hidden_dim
+        search_interval = hidden_dim
     else:
-        search_interval_byte = 1024    # defaults to 1kb
-    kwargs["search_interval_byte"] = search_interval_byte
+        search_interval = 1024    # defaults to 1024
+    kwargs["search_interval"] = search_interval

     dist.barrier()
     begin = time()
@@ -36,13 +36,13 @@ def init_chunk_manager(model: nn.Module,
     dist.barrier()
     end = time()
     span_s = end - begin
-    mb_size = 1024**2
-    total_size /= mb_size
-    wasted_size /= mb_size
+    mega_unit = 1024**2
+    total_size /= mega_unit
+    wasted_size /= mega_unit

     if verbose and dist.get_rank() == 0:
         print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
-              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+              "used number: {:.2f} * 2^20, wasted number: {:.2f} * 2^20\n".format(total_size, wasted_size),
               "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
               sep='',
               flush=True)
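A hedged sketch of calling `init_chunk_manager` with the renamed keywords. The import path is assumed from this PR's file layout, and the call needs an initialized torch.distributed process group since the function issues `dist.barrier()`:

```python
import torch
import torch.nn as nn
from colossalai.zero.gemini.chunk import init_chunk_manager  # assumed import path

model = nn.Linear(1024, 1024)
chunk_manager = init_chunk_manager(
    model=model,
    init_device=torch.device("cuda"),
    hidden_dim=1024,       # forwarded as search_interval (was: search_interval_byte)
    search_range_m=32,     # was: search_range_mb
    min_chunk_size_m=32,   # was: min_chunk_size_mb
)
```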
16 changes: 8 additions & 8 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -739,9 +739,9 @@ def __init__(self,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
                  scatter_after_inference: bool = True,
-                 search_range_mb: int = 32,
+                 search_range_m: int = 32,
                  hidden_dim: Optional[int] = None,
-                 min_chunk_size_mb: float = 32,
+                 min_chunk_size_m: float = 32,
                  memstats: Optional[MemStats] = None,
                  mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
@@ -763,24 +763,24 @@ def __init__(self,
             placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
             pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
             force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
-            search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
+            search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
             hidden_dim (int, optional): the hidden dimension of DNN.
                 Users can provide this argument to speed up searching.
                 If users do not know this argument before training, it is ok. We will use a default value 1024.
-            min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
+            min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
                 If the aggregate size of parameters is still smaller than the minimum chunk size,
                 all parameters will be compacted into one small chunk.
             memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
         """
         # some ugly hotfix for the compatibility with Lightning
-        if search_range_mb is None:
-            search_range_mb = 32
+        if search_range_m is None:
+            search_range_m = 32

         chunk_manager = init_chunk_manager(model=module,
                                            init_device=device,
                                            hidden_dim=hidden_dim,
-                                           search_range_mb=search_range_mb,
-                                           min_chunk_size_mb=min_chunk_size_mb,
+                                           search_range_m=search_range_m,
+                                           min_chunk_size_m=min_chunk_size_m,
                                            strict_ddp_flag=strict_ddp_mode,
                                            verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
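The "ugly hotfix" above keeps tolerating `None` from callers such as Lightning under the new keyword; a hypothetical helper (not part of the diff) showing the resulting fallback:

```python
def resolve_search_range_m(search_range_m):
    # Mirrors the hotfix in gemini_ddp.py: None falls back to the default of 32.
    if search_range_m is None:
        search_range_m = 32
    return search_range_m

assert resolve_search_range_m(None) == 32
assert resolve_search_range_m(128) == 128
```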
2 changes: 1 addition & 1 deletion tests/test_auto_parallel/test_offload/test_perf.py
@@ -60,7 +60,7 @@ def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str):
                          placement_policy='cpu',
                          pin_memory=True,
                          hidden_dim=8192,
-                         search_range_mb=128)
+                         search_range_m=128)
     gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
     optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
     gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
@@ -75,7 +75,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port):
                          device=get_current_device(),
                          placement_policy='cpu',
                          pin_memory=True,
-                         search_range_mb=128)
+                         search_range_m=128)

     post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
     gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
2 changes: 1 addition & 1 deletion tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -30,7 +30,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
     bert_model.config.save_pretrained(save_directory=pretrained_path)

     # TODO(ver217): use boost api
-    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100)
     chunk_manager = ChunkManager(config_dict)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     bert_model = ZeroDDP(bert_model, gemini_manager)
2 changes: 1 addition & 1 deletion tests/test_tensor/test_tp_with_zero.py
@@ -79,7 +79,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
         tp_init_spec_func(model, pg)

     dp_world_size = pg.dp_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[dp_world_size]['chunk_size'] = 5000
     config_dict[dp_world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
4 changes: 2 additions & 2 deletions tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -52,7 +52,7 @@ def exam_gpt_fwd_bwd(
         torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
@@ -113,7 +113,7 @@ def exam_gpt_inference(
         torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
2 changes: 1 addition & 1 deletion tests/test_zero/test_gemini/test_gemini_use_rmt.py
@@ -56,7 +56,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     assert len(step_list) == 4

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
     chunk_manager = ChunkManager(config_dict)
2 changes: 1 addition & 1 deletion tests/test_zero/test_gemini/test_grad_clip.py
@@ -51,7 +51,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
         p.data.copy_(torch_p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
2 changes: 1 addition & 1 deletion tests/test_zero/test_gemini/test_inference.py
@@ -34,7 +34,7 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):

 def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
     world_size = dist.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
4 changes: 2 additions & 2 deletions tests/test_zero/test_gemini/test_optim.py
@@ -73,7 +73,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
         p.data.copy_(torch_p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
     if placement_policy != 'cuda':
@@ -130,7 +130,7 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)

-    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
+    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_m=1)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
     model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
22 changes: 11 additions & 11 deletions tests/test_zero/test_gemini/test_search.py
@@ -30,9 +30,9 @@ def exam_search_chunk_size():
     model = model_builder()
     init_1d_row_spec(model, pg_tp)
     config_dict, *_ = search_chunk_configuration(model,
-                                                 search_range_mb=1,
-                                                 search_interval_byte=16,
-                                                 min_chunk_size_mb=0,
+                                                 search_range_m=1,
+                                                 search_interval=16,
+                                                 min_chunk_size_m=0,
                                                  filter_exlarge_params=True)

     for key in config_dict:
@@ -54,19 +54,19 @@ def exam_search_strict_ddp():
     with ColoInitContext(device=get_current_device()):
         ddp_model = model_builder()
     re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
-                                                              search_range_mb=1,
-                                                              search_interval_byte=16,
-                                                              min_chunk_size_mb=0,
+                                                              search_range_m=1,
+                                                              search_interval=16,
+                                                              min_chunk_size_m=0,
                                                               filter_exlarge_params=True,
                                                               strict_ddp_flag=False)
     # get the chunk configuration over sharded ddp models
     with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
                          default_dist_spec=default_shard_spec):
         sharded_ddp_model = model_builder()
     sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
-                                                              search_range_mb=1,
-                                                              search_interval_byte=16,
-                                                              min_chunk_size_mb=0,
+                                                              search_range_m=1,
+                                                              search_interval=16,
+                                                              min_chunk_size_m=0,
                                                               filter_exlarge_params=True,
                                                               strict_ddp_flag=True)
     assert re_dict == sh_dict
@@ -91,8 +91,8 @@ def exam_chunk_manager():
     chunk_manager = init_chunk_manager(sharded_ddp_model,
                                        get_current_device(),
                                        hidden_dim=16,
-                                       search_range_mb=1,
-                                       min_chunk_size_mb=0,
+                                       search_range_m=1,
+                                       min_chunk_size_m=0,
                                        filter_exlarge_params=True,
                                        strict_ddp_flag=True)
     config_dict = chunk_manager.dp_degree_chunk_size_dict
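For anyone migrating call sites, the tests above collectively exercise every renamed keyword; as a quick reference (a summary of this diff, not code taken from it):

```python
# old keyword -> new keyword
RENAMED_KWARGS = {
    "search_range_mb": "search_range_m",
    "min_chunk_size_mb": "min_chunk_size_m",
    "search_interval_byte": "search_interval",
}
```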
4 changes: 2 additions & 2 deletions tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -35,7 +35,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
         torch_p.data.copy_(p.data)

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
     chunk_manager = ChunkManager(config_dict)
@@ -67,7 +67,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     torch_model = model_builder()    # get a different model

     world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered