Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,20 @@ class GeminiPlugin(DPPluginBase):
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" and "auto". Defaults to "static".
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
Defaults to 0.0.
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
Expand Down Expand Up @@ -257,8 +269,14 @@ class GeminiPlugin(DPPluginBase):

def __init__(
self,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "cpu",
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
Comment thread
FrankLeeeee marked this conversation as resolved.
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
Expand All @@ -282,8 +300,14 @@ def __init__(
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
Expand Down
38 changes: 24 additions & 14 deletions colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,12 @@ def __init__(
module: torch.nn.Module,
chunk_config_dict: Optional[dict] = None,
chunk_init_device: torch.device = torch.device('cpu'),
placement_policy: str = "cpu",
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
search_range_m: int = 32, # chunk search options
hidden_dim: Optional[int] = None, # chunk search options
min_chunk_size_m: float = 32, # chunk search options
Expand All @@ -86,8 +91,14 @@ def __init__(
strict_ddp_flag=strict_ddp_mode,
process_group=process_group,
verbose=verbose)
self.gemini_manager = GeminiManager(placement_policy, self.chunk_manager, memstats)

self.gemini_manager = GeminiManager(placement_policy,
self.chunk_manager,
memstats,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
Expand All @@ -112,17 +123,17 @@ def __init__(
for p in module.parameters():
param_order.append(p)

self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)

for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var

self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
super().__init__(module)
self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
self._cast_buffers()
Expand Down Expand Up @@ -605,7 +616,7 @@ def load_fp32_parameter(chunk_slice, data):
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
chunk_16.payload.copy_(chunk_32.payload)

for name, buf in persistent_buffers.items():
if buf is not None:
Expand Down Expand Up @@ -664,18 +675,17 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi

self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device

self.chunk_manager.close_all_groups()

self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
# move master weights to corresponding device and setup paired chunks
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)

# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
if chunk_32.device_type != self.grads_device[p].type:
self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])

def _cast_buffers(self):
for buffer in self.module.buffers():
Expand Down
20 changes: 10 additions & 10 deletions colossalai/zero/gemini/gemini_mgr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from time import time
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch

Expand All @@ -26,7 +26,11 @@ class GeminiManager:
memstats (MemStats, optional): a mem stats collected by a runtime mem tracer. if None then GeminiManager will collect it during a warmup iteration.
"""

def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
def __init__(self,
placement_policy: str,
chunk_manager: ChunkManager,
memstats: Optional[MemStats] = None,
**placement_kwargs) -> None:

assert placement_policy in PlacementPolicyFactory.get_policy_names()
self.policy_name = placement_policy
Expand All @@ -37,7 +41,7 @@ def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats:
self._memstats = memstats
self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager,
self._memstats) if policy_cls.need_mem_stats else None
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
self._compute_list: List[Tuple[Chunk, ...]] = []
self._compute_idx: int = -1

Expand Down Expand Up @@ -133,10 +137,6 @@ def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
if self._warmup and self._placement_policy.need_mem_stats:
self._compute_list.append(chunks)

@property
def default_device(self):
return self._placement_policy.get_default_device()

def sample_overall_data(self):
if self._mem_stats_collector:
self._mem_stats_collector.sample_overall_data()
Expand All @@ -159,6 +159,6 @@ def cuda_margin_mem(self) -> Optional[float]:
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats

@staticmethod
def get_default_device(policy_name: str) -> torch.device:
return PlacementPolicyFactory.get_default_device(policy_name)
def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
torch.device]) -> None:
self._placement_policy.setup_grads_device(params, grads_device_map)
Loading