diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 0fb992f1da52..54d815ce701e 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -224,8 +224,20 @@ class GeminiPlugin(DPPluginBase):
         >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
 
     Args:
-        device (torch.device): device to place the model.
-        placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        chunk_config_dict (dict, optional): chunk configuration dictionary.
+        chunk_init_device (torch.device, optional): device on which chunks are initialized.
+        placement_policy (str, optional): either "static" or "auto". Defaults to "static".
+        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
+            If `shard_param_frac` is 1.0, it is equivalent to ZeRO-3; if 0.0, it is equivalent to ZeRO-2. Defaults to 1.0.
+        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
+            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it is equivalent to the old "cuda" placement. Defaults to 0.0.
+        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
+            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
+            If `shard_param_frac`, `offload_optim_frac` and `offload_param_frac` are all 1.0, it is equivalent to the old "cpu" placement.
+            When using static placement, we recommend tuning `shard_param_frac` first and then `offload_optim_frac`.
+            Defaults to 0.0.
+        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
+        steady_cuda_cap_ratio (float, optional): ratio of CUDA capacity allowed for model data during steady state. Only for "auto" placement. Defaults to 0.9.
         precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
@@ -257,8 +269,14 @@ class GeminiPlugin(DPPluginBase):
 
     def __init__(
         self,
+        chunk_config_dict: Optional[dict] = None,
         chunk_init_device: Optional[torch.device] = None,
-        placement_policy: str = "cpu",
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,    # only for static placement
+        offload_optim_frac: float = 0.0,    # only for static placement
+        offload_param_frac: float = 0.0,    # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,    # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,    # only for auto placement
        precision: str = "fp16",
        pin_memory: bool = False,
        force_outputs_fp32: bool = False,
@@ -282,8 +300,14 @@ def __init__(
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
+            chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_current_device()),
             placement_policy=placement_policy,
+            shard_param_frac=shard_param_frac,
+            offload_optim_frac=offload_optim_frac,
+            offload_param_frac=offload_param_frac,
+            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
             pin_memory=pin_memory,
             force_outputs_fp32=force_outputs_fp32,
             strict_ddp_mode=strict_ddp_mode,
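For orientation, the equivalences spelled out in the new docstring reduce to a handful of canonical configurations. A minimal sketch (the `GeminiPlugin` arguments are exactly those added above; the surrounding training objects are hypothetical):

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# equivalent to the old "cpu" placement: ZeRO-3 with optimizer states
# and parameter shards offloaded to CPU
plugin = GeminiPlugin(
    placement_policy='static',
    shard_param_frac=1.0,
    offload_optim_frac=1.0,
    offload_param_frac=1.0,
)

# equivalent to ZeRO-2: parameters kept gathered, nothing offloaded
# plugin = GeminiPlugin(placement_policy='static', shard_param_frac=0.0)

booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)  # as in the docstring example
```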
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index c8f66a52ff23..0cd90459b76a 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -59,7 +59,12 @@ def __init__(
         module: torch.nn.Module,
         chunk_config_dict: Optional[dict] = None,
         chunk_init_device: torch.device = torch.device('cpu'),
-        placement_policy: str = "cpu",
+        placement_policy: str = "static",
+        shard_param_frac: float = 1.0,    # only for static placement
+        offload_optim_frac: float = 0.0,    # only for static placement
+        offload_param_frac: float = 0.0,    # only for static placement
+        warmup_non_model_data_ratio: float = 0.8,    # only for auto placement
+        steady_cuda_cap_ratio: float = 0.9,    # only for auto placement
         search_range_m: int = 32,    # chunk search options
         hidden_dim: Optional[int] = None,    # chunk search options
         min_chunk_size_m: float = 32,    # chunk search options
@@ -86,8 +91,14 @@ def __init__(
                                                strict_ddp_flag=strict_ddp_mode,
                                                process_group=process_group,
                                                verbose=verbose)
-        self.gemini_manager = GeminiManager(placement_policy, self.chunk_manager, memstats)
-
+        self.gemini_manager = GeminiManager(placement_policy,
+                                            self.chunk_manager,
+                                            memstats,
+                                            shard_param_frac=shard_param_frac,
+                                            offload_optim_frac=offload_optim_frac,
+                                            offload_param_frac=offload_param_frac,
+                                            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
+                                            steady_cuda_cap_ratio=steady_cuda_cap_ratio)
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
@@ -112,17 +123,17 @@ def __init__(
         for p in module.parameters():
             param_order.append(p)
 
-        self._init_chunks(param_order=param_order,
-                          strict_ddp_mode=strict_ddp_mode,
-                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
-                          pin_memory=pin_memory)
-
         for name, param in module.named_parameters():
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
             for p_name, p_var in m_var.named_parameters(recurse=False):
                 param_name = m_name + '.' + p_name if m_name else p_name
                 self.name2param[param_name] = p_var
+
+        self._init_chunks(param_order=param_order,
+                          strict_ddp_mode=strict_ddp_mode,
+                          cpu_offload=self.gemini_manager.policy_name != 'cuda',
+                          pin_memory=pin_memory)
         super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
@@ -605,7 +616,7 @@ def load_fp32_parameter(chunk_slice, data):
         for chunk_32 in chunk_list:
             chunk_16 = chunk_32.paired_chunk
             assert chunk_16 is not None
-            chunk_16.optim_update()
+            chunk_16.payload.copy_(chunk_32.payload)
 
         for name, buf in persistent_buffers.items():
             if buf is not None:
@@ -664,18 +675,17 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi
             self.fp16_params.append(p)
             self.fp32_params.append(fp32_p)
 
-            self.grads_device[p] = self.gemini_manager.default_device
         self.chunk_manager.close_all_groups()
 
+        self.gemini_manager.setup_grads_device(self.fp16_params, self.grads_device)
+
+        # move master weights to their corresponding devices and set up paired chunks
         for p, fp32_p in zip(self.fp16_params, self.fp32_params):
             chunk_16 = self.chunk_manager.get_chunk(p)
             chunk_32 = self.chunk_manager.get_chunk(fp32_p)
             chunk_32.init_pair(chunk_16)
-
-            # keep gathered chunks are in CUDA
-            if chunk_16.keep_gathered:
-                self.grads_device[p] = get_current_device()
+            if chunk_32.device_type != self.grads_device[p].type:
+                self.chunk_manager.move_chunk(chunk_32, self.grads_device[p])
 
     def _cast_buffers(self):
         for buffer in self.module.buffers():
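The same knobs surface on `GeminiDDP` itself, which the tests below exercise directly. A minimal sketch, assuming the process group has already been set up (e.g. via `colossalai.launch`) and using the constructor signature from the hunk above:

```python
import torch
from colossalai.zero import GeminiDDP

# hypothetical single-layer model; any nn.Module works
model = torch.nn.Linear(1024, 1024).cuda()
model = GeminiDDP(
    model,
    chunk_init_device=torch.device('cuda'),
    placement_policy='static',
    shard_param_frac=0.5,      # shard roughly half of the parameter chunk memory
    offload_optim_frac=0.0,    # keep all optimizer states on CUDA
    pin_memory=True,
)
```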
""" - def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None: + def __init__(self, + placement_policy: str, + chunk_manager: ChunkManager, + memstats: Optional[MemStats] = None, + **placement_kwargs) -> None: assert placement_policy in PlacementPolicyFactory.get_policy_names() self.policy_name = placement_policy @@ -37,7 +41,7 @@ def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: self._memstats = memstats self._mem_stats_collector = ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None - self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector) + self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs) self._compute_list: List[Tuple[Chunk, ...]] = [] self._compute_idx: int = -1 @@ -133,10 +137,6 @@ def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None: if self._warmup and self._placement_policy.need_mem_stats: self._compute_list.append(chunks) - @property - def default_device(self): - return self._placement_policy.get_default_device() - def sample_overall_data(self): if self._mem_stats_collector: self._mem_stats_collector.sample_overall_data() @@ -159,6 +159,6 @@ def cuda_margin_mem(self) -> Optional[float]: def is_cuda_margin_mem_avail(self) -> bool: return self._placement_policy.need_mem_stats - @staticmethod - def get_default_device(policy_name: str) -> torch.device: - return PlacementPolicyFactory.get_default_device(policy_name) + def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, + torch.device]) -> None: + self._placement_policy.setup_grads_device(params, grads_device_map) diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py index 84a868872f88..cd775da5e11f 100644 --- a/colossalai/zero/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -1,4 +1,5 @@ import functools +import warnings from abc import ABC, abstractmethod from time import time from typing import Dict, List, Optional, Tuple, Type @@ -7,6 +8,7 @@ from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.chunk import Chunk from .chunk import Chunk, ChunkManager from .memory_tracer import ChunkMemStatsCollector @@ -17,7 +19,8 @@ class PlacementPolicy(ABC): def __init__(self, chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + **kwargs) -> None: self.chunk_manager = chunk_manager self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector @@ -25,57 +28,87 @@ def __init__(self, def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]: raise NotImplementedError - @staticmethod - def get_default_device() -> torch.device: - return torch.device('cpu') + @abstractmethod + def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, + torch.device]) -> None: + raise NotImplementedError -class CPUPlacementPolicy(PlacementPolicy): +class StaticPlacementPolicy(PlacementPolicy): def __init__(self, chunk_manager: ChunkManager, - mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None: + mem_stats_collector: Optional[ChunkMemStatsCollector] = None, + shard_param_frac: float = 1.0, + offload_optim_frac: float = 0.0, + 
diff --git a/colossalai/zero/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py
index 84a868872f88..cd775da5e11f 100644
--- a/colossalai/zero/gemini/placement_policy.py
+++ b/colossalai/zero/gemini/placement_policy.py
@@ -1,4 +1,5 @@
 import functools
+import warnings
 from abc import ABC, abstractmethod
 from time import time
 from typing import Dict, List, Optional, Tuple, Type
@@ -7,6 +8,7 @@
 
 from colossalai.utils import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity
+from colossalai.zero.gemini.chunk import Chunk
 
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector
@@ -17,7 +19,8 @@ class PlacementPolicy(ABC):
 
     def __init__(self,
                  chunk_manager: ChunkManager,
-                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+                 **kwargs) -> None:
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
 
@@ -25,57 +28,87 @@ def __init__(self,
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
         raise NotImplementedError
 
-    @staticmethod
-    def get_default_device() -> torch.device:
-        return torch.device('cpu')
+    @abstractmethod
+    def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
+                                                                                    torch.device]) -> None:
+        raise NotImplementedError
 
 
-class CPUPlacementPolicy(PlacementPolicy):
+class StaticPlacementPolicy(PlacementPolicy):
 
     def __init__(self,
                  chunk_manager: ChunkManager,
-                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+                 shard_param_frac: float = 1.0,
+                 offload_optim_frac: float = 0.0,
+                 offload_param_frac: float = 0.0,
+                 **kwargs) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
+            warnings.warn('offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0')
+            offload_param_frac = 0.0
+        self.shard_param_frac = shard_param_frac
+        self.offload_optim_frac = offload_optim_frac
+        self.offload_param_frac = offload_param_frac
+        # these should be initialized in setup_grads_device
+        self.keep_gathered_chunk_mem = 0.0
+        self.keep_cuda_chunk_mem = 0.0
 
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
-        volume = 0
-        start = time()
+        can_shard_chunk_mem = sum(chunk.chunk_mem for chunk in can_evict_chunks)
+        can_offload_chunk_mem = can_shard_chunk_mem
         for chunk in can_evict_chunks:
+            if can_shard_chunk_mem <= self.keep_gathered_chunk_mem:
+                break
             self.chunk_manager.release_chunk(chunk)
+            # the memory really saved is (chunk_mem - shard_mem); for simplicity we use chunk_mem
+            can_shard_chunk_mem -= chunk.chunk_mem
+        for chunk in can_evict_chunks:
+            if can_offload_chunk_mem <= self.keep_cuda_chunk_mem:
+                break
             self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
-            volume += chunk.chunk_mem
-        return volume, time() - start
-
-
-class CUDAPlacementPolicy(PlacementPolicy):
-
-    def __init__(self,
-                 chunk_manager: ChunkManager,
-                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
-        assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
-
-    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
-        return 0, 0
-
-    @staticmethod
-    def get_default_device() -> torch.device:
-        return get_current_device()
+            # the memory really saved is shard_mem; for simplicity we use chunk_mem
+            can_offload_chunk_mem -= chunk.chunk_mem
+        return 0, 0.0
+
+    def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor,
+                                                                                    torch.device]) -> None:
+        total_chunk_mem = sum(self.chunk_manager.get_chunk(p).chunk_mem for p in params)
+
+        offload_optim_chunk_mem = total_chunk_mem * self.offload_optim_frac
+        offloaded_optim_chunk_mem = 0
+        chunks = set(self.chunk_manager.get_chunk(p) for p in params)
+        for chunk in chunks:
+            params = chunk.get_tensors()
+            # init offload optim settings
+            # keep-gathered chunks must stay on CUDA
+            if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
+                device = get_current_device()
+            else:
+                device = torch.device('cpu')
+                # the memory really offloaded is chunk.shard_mem; for simplicity we use chunk_mem here
+                offloaded_optim_chunk_mem += chunk.chunk_mem
+            for p in params:
+                grads_device_map[p] = device
+        self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
+        self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
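To make the fraction arithmetic concrete: `setup_grads_device` derives two thresholds that `evict_tensors` later respects. A worked sketch with hypothetical numbers (plain Python, not library code):

```python
# total fp16 chunk memory across all parameters (hypothetical, in MB)
total_chunk_mem = 1024

shard_param_frac = 0.5     # shard half of the parameter chunk memory
offload_param_frac = 0.0   # keep all parameter shards on CUDA

# evict_tensors() stops releasing (sharding) chunks once the gathered
# chunk memory drops to this threshold:
keep_gathered_chunk_mem = total_chunk_mem * (1 - shard_param_frac)    # 512 MB stays gathered
# and stops moving chunks to CPU once CUDA-resident chunk memory
# drops to this threshold:
keep_cuda_chunk_mem = total_chunk_mem * (1 - offload_param_frac)      # 1024 MB, so nothing is offloaded
```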
class AutoPlacementPolicy(PlacementPolicy):
 
     need_mem_stats: bool = True
-    # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
-    # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
-    # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
-    _warmup_non_model_data_ratio: float = 0.8
-    _steady_cuda_cap_ratio: float = 0.9
 
     def __init__(self,
                  chunk_manager: ChunkManager,
-                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None) -> None:
+                 mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+                 warmup_non_model_data_ratio: float = 0.8,
+                 steady_cuda_cap_ratio: float = 0.9,
+                 **kwargs) -> None:
         super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        # model data may use at most (1 - warmup_non_model_data_ratio) of CUDA memory in the warmup phase;
+        # both ratios are now configured via the constructor arguments above
+        self._warmup_non_model_data_ratio = warmup_non_model_data_ratio
+        self._steady_cuda_cap_ratio = steady_cuda_cap_ratio
 
     def evict_tensors(self,
                       can_evict_chunks: List[Chunk],
@@ -105,11 +138,11 @@ def evict_tensors(self,
         used_cuda_model_data = self.chunk_manager.total_mem['cuda']
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
-            max_cuda_non_model_data_per_period = cuda_capacity * AutoPlacementPolicy._warmup_non_model_data_ratio
+            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
             max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
-            cuda_capacity *= AutoPlacementPolicy._steady_cuda_cap_ratio
+            cuda_capacity *= self._steady_cuda_cap_ratio
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
         freed_cuda_model_data = 0
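The steady-state budget computed in the hunk above can be summarized in one place. A paraphrase of the diff with hypothetical numbers, not library code:

```python
# hypothetical figures, in bytes
cuda_capacity = 40 * 1024**3                        # e.g. a 40 GB device
steady_cuda_cap_ratio = 0.9
used_cuda_model_data = 20 * 1024**3                 # chunk memory already on CUDA
max_cuda_non_model_data_per_period = 8 * 1024**3    # activations, buffers, ...

# in steady state only a fraction of the device is budgeted for chunks:
cuda_capacity *= steady_cuda_cap_ratio
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
# if avail_cuda_model_data < cuda_demand, evict_tensors frees chunks until the demand fits
```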
" - f"Need {to_free_memory}, freed {freed_accessed_memory}") - return freed_accessed_memory, time() - start - - @staticmethod - @functools.lru_cache(maxsize=None) - def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list: - next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks} - for i in range(len(compute_list) - 1, compute_idx, -1): - for chunk in compute_list[i]: - if chunk in next_compute_idx: - next_compute_idx[chunk] = i - next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) - return [t for (t, idx) in next_compute_idx] - - @staticmethod - def set_const_memory_boundary(cuda_memory_mb: int) -> None: - boundary = int(cuda_memory_mb * 1024**2) - assert boundary > 0 - ConstPlacementPolicy._accessed_memory_boundary = boundary + def setup_grads_device(self, params: List[torch.Tensor], grads_device_map: Dict[torch.Tensor, + torch.device]) -> None: + for p in params: + chunk = self.chunk_manager.get_chunk(p) + # init offload optim settings + # keep gathered chunks are in CUDA + if chunk.keep_gathered: + grads_device_map[p] = get_current_device() + else: + grads_device_map[p] = torch.device('cpu') class PlacementPolicyFactory: policies: Dict[str, Type[PlacementPolicy]] = { - 'cpu': CPUPlacementPolicy, - 'cuda': CUDAPlacementPolicy, 'auto': AutoPlacementPolicy, - 'const': ConstPlacementPolicy + 'static': StaticPlacementPolicy, } @staticmethod @@ -239,8 +205,3 @@ def create(policy_name: str) -> Type[PlacementPolicy]: @staticmethod def get_policy_names(): return tuple(PlacementPolicyFactory.policies.keys()) - - @staticmethod - def get_default_device(policy_name: str) -> torch.device: - policy_cls = PlacementPolicyFactory.create(policy_name) - return policy_cls.get_default_device() diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 1a2d8cbff625..57a354d0d013 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -16,4 +16,4 @@ triton==2.0.0.dev20221202 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 SentencePiece ninja -flash-attn +flash-attn==2.0.5 diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index c635a7b51537..910c5dbec72c 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -12,19 +12,16 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ColoInitContext from tests.kit.model_zoo import model_zoo def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: try: - if init_method == 'colo': - ctx = ColoInitContext() - elif init_method == 'lazy': + if init_method == 'lazy': ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) with ctx: model = model_fn() diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 43fdcb21df2e..6720be58490b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py 
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 43fdcb21df2e..6720be58490b 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -18,12 +18,45 @@
 )
 from tests.kit.model_zoo import model_zoo
 
+MODEL_PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0
+    },    # zero3
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.5
+    },    # zero3-half
+]
+
+OPTIM_PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 1.0
+    },    # zero2-offload
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.5
+    },    # zero2-offload-half
+]
+
 
 @clear_cache_before_run()
-@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('placement_config', MODEL_PLACEMENT_CONFIGS)
 @parameterize('model_name', ['transformers_bert_for_sequence_classification'])
 @parameterize('use_safetensors', [False, True])
-def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
     from transformers import BertForSequenceClassification
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
@@ -32,7 +65,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
         pretrained_path = os.path.join(tempdir, 'pretrained')
         bert_model.config.save_pretrained(save_directory=pretrained_path)
 
-        plugin = GeminiPlugin(placement_policy=placement_policy)
+        plugin = GeminiPlugin(**placement_config)
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -51,14 +84,14 @@
 
 @clear_cache_before_run()
-@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('placement_config', OPTIM_PLACEMENT_CONFIGS)
 @parameterize('shard', [False, True])
 @parameterize('model_name', ['transformers_gpt'])
 @parameterize('size_per_shard', [32])
-def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
     booster = Booster(plugin=plugin)
 
     model = model_fn()
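The parameterizations above drive checkpointing through the Booster; condensed, the pattern under test looks roughly like this (a sketch assuming the `save_model`/`save_optimizer` checkpoint API this test suite calls, which is not shown in the excerpt):

```python
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# assumes colossalai.launch(...) has been called, as in the tests
model = torch.nn.Linear(16, 16)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(placement_policy='static', shard_param_frac=0.0, offload_optim_frac=1.0)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# sharded save/load; `shard` and `size_per_shard` mirror the test parameters
booster.save_model(model, 'model_ckpt', shard=True, size_per_shard=32)
booster.save_optimizer(optimizer, 'optim_ckpt', shard=True, size_per_shard=32)
booster.load_model(model, 'model_ckpt')
booster.load_optimizer(optimizer, 'optim_ckpt')
```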
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index 57d12c55b9b6..4cbf564ecfb9 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -11,10 +11,28 @@
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
-from tests.components_to_test import run_fwd, run_fwd_bwd
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed
 
+PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0
+    },    # zero3
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.5
+    },    # zero3-half
+    {
+        'placement_policy': 'auto'
+    }
+]
+
 
 def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
     chunk_manager = model.chunk_manager
@@ -27,12 +45,12 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
         assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('keep_gather', [False, True])
 @parameterize('model_name', ['gpt2', 'bert', 'albert'])
 @parameterize('use_grad_checkpoint', [False, True])
 def exam_gpt_fwd_bwd(
-    placement_policy,
+    placement_config,
     keep_gather,
     model_name: str,
     use_grad_checkpoint: bool = False,
@@ -53,7 +71,7 @@ def exam_gpt_fwd_bwd(
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gather
-    model = GeminiDDP(model, config_dict, placement_policy=placement_policy, pin_memory=True)
+    model = GeminiDDP(model, config_dict, init_device, pin_memory=True, **placement_config)
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=1)
@@ -85,62 +103,10 @@ def exam_gpt_fwd_bwd(
     check_grad(model, torch_model)
 
-@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('keep_gather', [False, True])
-@parameterize('model_name', ['gpt2', 'bert', 'albert'])
-@parameterize('scatter_after_inference', [False, True])
-def exam_gpt_inference(
-    placement_policy,
-    keep_gather,
-    model_name: str,
-    scatter_after_inference: bool = False,
-):
-    init_device = get_current_device()
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-
-    set_seed(42)
-    model = model_builder()
-
-    set_seed(42)
-    torch_model = model_builder().cuda()
-    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
-        torch_p.data.copy_(p.data)
-
-    world_size = torch.distributed.get_world_size()
-    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
-    config_dict[world_size]['chunk_size'] = 5000
-    config_dict[world_size]['keep_gathered'] = keep_gather
-    model = GeminiDDP(model, config_dict, pin_memory=True, scatter_after_inference=scatter_after_inference)
-
-    rank = dist.get_rank()
-    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=1)
-    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
-    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
-    torch_model = DDP(torch_model, device_ids=[rank])
-
-    set_seed(rank)
-    model.eval()
-    torch_model.eval()
-    for i, (input_ids, label) in enumerate(train_dataloader):
-        # you can only test a single fwd + bwd.
-        # after bwd param is grad for Gemini, due to the chunk reuse optimization.
-        if i > 0:
-            break
-        with torch.no_grad():
-            input_ids, label = input_ids.cuda(), label.cuda()
-
-            torch_loss = run_fwd(torch_model, input_ids, label, criterion)
-            loss = run_fwd(model, input_ids, label, criterion)
-
-            assert torch.equal(torch_loss, loss)
-
-
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_gpt_fwd_bwd()
-    exam_gpt_inference()
 
 
 @pytest.mark.dist
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index 69058256ae47..82b9133b89c1 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -8,13 +8,36 @@
 from colossalai.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed
 
+PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.0,
+        'offload_param_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 1.0,
+        'offload_param_frac': 0.0
+    },    # zero2-offload
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.5,
+        'offload_param_frac': 0.0
+    },    # zero2-offload-half
+    {
+        'placement_policy': 'auto'
+    }
+]
+
 
 def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
     zero_dict = model.state_dict(only_rank_0=False)
@@ -29,9 +52,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('model_name', ['gpt2'])
-def exam_grad_clipping(placement_policy, model_name: str):
+def exam_grad_clipping(placement_config, model_name: str):
     set_seed(1912)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -51,7 +74,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
-    if placement_policy != 'cuda':
+    if placement_config['placement_policy'] != 'cuda':
         init_device = torch.device('cpu')
     else:
         init_device = None
@@ -59,8 +82,8 @@ def exam_grad_clipping(placement_policy, model_name: str):
     model = GeminiDDP(model,
                       chunk_config_dict=config_dict,
                       chunk_init_device=init_device,
-                      placement_policy=placement_policy,
-                      pin_memory=True)
+                      pin_memory=True,
+                      **placement_config)
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index 74f51601cb23..20d145f9661f 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -17,6 +17,24 @@
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed
 
+PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0
+    },    # zero3
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.5
+    },    # zero3-half
+    {
+        'placement_policy': 'auto'
+    }
+]
+
 
 def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
     zero_dict = model.state_dict(only_rank_0=False)
@@ -31,28 +49,24 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
         assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
 
-def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
+def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
     world_size = dist.get_world_size()
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
-    if placement_policy != 'cuda':
-        init_device = torch.device('cpu')
-    else:
-        init_device = None
-    model = GeminiDDP(model, config_dict, init_device, placement_policy=placement_policy, pin_memory=True)
+    model = GeminiDDP(model, config_dict, pin_memory=True, **placement_config)
     return model
 
-def single_chunk_init(model: torch.nn.Module, placement_policy: str):
-    model = GeminiDDP(model, chunk_init_device=get_current_device(), placement_policy=placement_policy, pin_memory=True)
+def single_chunk_init(model: torch.nn.Module, placement_config: dict):
+    model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
     return model
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('model_name', ['gpt2'])
 @parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
-def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
+def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
     set_seed(19360226)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -62,14 +76,13 @@ def exam_inference(placement_policy: str, model_name: str, model_init_func: Call
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
-
     init_dev = get_current_device()
     model = model_builder().to(init_dev)
 
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         p.data.copy_(torch_p.data)
 
-    model = model_init_func(model, placement_policy)
+    model = model_init_func(model, placement_config)
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 5eb3e4e4ea66..edcbada0acbb 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -15,6 +15,41 @@
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed
 
+PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 1.0
+    },    # zero2-offload
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.5
+    },    # zero2-offload-half
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0
+    },    # zero3
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.5
+    },    # zero3-half
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0,
+        'offload_optim_frac': 1.0,
+        'offload_param_frac': 1.0
+    },    # zero3-offload-all
+    {
+        'placement_policy': 'auto'
+    }
+]
+
 # this model is large enough to slice to chunks
 TEST_MODELS = ['gpt2']
 # these models are too small, all parameters in these models are compacted into one chunk
@@ -50,10 +85,10 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
                      msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('model_name', TEST_MODELS)
 @parameterize('mixed_precision', [torch.half, torch.bfloat16])
-def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
+def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -73,11 +108,7 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = False
-    if placement_policy != 'cuda':
-        init_device = torch.device('cpu')
-    else:
-        init_device = None
-    model = GeminiDDP(model, config_dict, init_device, placement_policy, mixed_precision=mixed_precision)
+    model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision)
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128)
@@ -104,10 +135,10 @@ def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dt
     check_param(model, torch_model, mixed_precision)
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('model_name', EXAMPLE_MODELS)
 @parameterize('mixed_precision', [torch.half, torch.bfloat16])
-def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
+def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -127,7 +158,8 @@ def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.
                       chunk_init_device=get_current_device(),
                       search_range_m=1,
                       pin_memory=True,
-                      mixed_precision=mixed_precision)
+                      mixed_precision=mixed_precision,
+                      **placement_config)
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = GeminiOptimizer(optimizer, model, initial_scale=2)
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
index 12c195efd6ed..656bd709e2a1 100644
--- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py
@@ -9,6 +9,24 @@
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed
 
+PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 1.0
+    },    # zero3
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.5
+    },    # zero3-half
+    {
+        'placement_policy': 'auto'
+    }
+]
+
 
 def ignore_the_first_parameter(model: torch.nn.Module):
     for name, param in model.named_parameters():
@@ -17,10 +35,10 @@ def ignore_the_first_parameter(model: torch.nn.Module):
     return
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('keep_gathered', [True, False])
 @parameterize('model_name', ['gpt2', 'bert'])
-def exam_state_dict(placement_policy, keep_gathered, model_name: str):
+def exam_state_dict(placement_config, keep_gathered, model_name: str):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -35,7 +53,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
-    model = GeminiDDP(model, config_dict, placement_policy=placement_policy, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
 
     model.train()
     zero_dict = model.state_dict(only_rank_0=False)
@@ -47,10 +65,10 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
             assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
 
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('keep_gathered', [True, False])
 @parameterize('model_name', ['gpt2', 'bert'])
-def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
+def exam_load_state_dict(placement_config, keep_gathered, model_name: str):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -65,11 +83,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
 
-    if placement_policy != 'cuda':
-        init_device = torch.device('cpu')
-    else:
-        init_device = None
-    model = GeminiDDP(model, config_dict, init_device, placement_policy, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
 
     torch_dict = torch_model.state_dict()
     model.load_state_dict(torch_dict, strict=False)
@@ -81,11 +95,37 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
             assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
 
+@parameterize('placement_config', PLACEMENT_CONFIGS)
+@parameterize('model_name', ['gpt2', 'bert'])
+def exam_state_dict_shard(placement_config, model_name: str):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    model = model_builder()
+
+    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
+
+    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
+    model = GeminiDDP(model, config_dict, **placement_config)
+    model.train()
+
+    zero_dict = model.state_dict(only_rank_0=False)
+    accumulated_keys = set()
+    # ensure number of shards > 1
+    for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
+        for key, value in shard.items():
+            assert key not in accumulated_keys, f"key `{key}` is duplicated."
+            accumulated_keys.add(key)
+            assert key in zero_dict, f"{key} not in ZeRO dictionary."
+            assert torch.equal(value, zero_dict[key]), f"{key} not equal."
+
+
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     exam_state_dict()
     exam_load_state_dict()
+    exam_state_dict_shard()
 
 
 @pytest.mark.dist
diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
deleted file mode 100644
index c8ac8a8502c0..000000000000
--- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero import GeminiDDP
-from colossalai.zero.gemini.chunk import search_chunk_configuration
-from tests.components_to_test.registry import non_distributed_component_funcs
-
-
-@parameterize('placement_policy', ['cuda', 'cpu'])
-@parameterize('model_name', ['gpt2', 'bert'])
-def exam_state_dict(placement_policy, model_name: str):
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-
-    model = model_builder()
-
-    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
-
-    config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
-    model = GeminiDDP(model, config_dict, placement_policy=placement_policy)
-    model.train()
-
-    zero_dict = model.state_dict(only_rank_0=False)
-    accumulated_keys = set()
-    # ensure number of shards > 1
-    for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
-        for key, value in shard.items():
-            assert key not in accumulated_keys, f"key `{key}` is duplicated."
-            accumulated_keys.add(key)
-            assert key in zero_dict, f"{key} not in ZeRO dictionary."
-            assert torch.equal(value, zero_dict[key]), f"{key} not equal."
-
-
-def run_dist(rank, world_size, port):
-    config = {}
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    exam_state_dict()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@rerun_if_address_is_in_use()
-def test_zero_ddp_state_dict_shard(world_size):
-    spawn(run_dist, world_size)
-
-
-if __name__ == '__main__':
-    test_zero_ddp_state_dict_shard(1)
diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
index 80e8821c1bf7..09725e11ec0c 100644
--- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
+++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py
@@ -10,10 +10,31 @@
 from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_tensor.common_utils import set_seed
 
-
-@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+PLACEMENT_CONFIGS = [
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.0
+    },    # zero2
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 1.0
+    },    # zero2-offload
+    {
+        'placement_policy': 'static',
+        'shard_param_frac': 0.0,
+        'offload_optim_frac': 0.5
+    },    # zero2-offload-half
+    {
+        'placement_policy': 'auto'
+    }
+]
+
+
+@parameterize('placement_config', PLACEMENT_CONFIGS)
 @parameterize('keep_gathered', [True, False])
-def exam_zero_optim_state_dict(placement_policy, keep_gathered):
+def exam_zero_optim_state_dict(placement_config, keep_gathered):
     set_seed(431)
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -21,18 +42,13 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     model = model_builder()
 
     set_seed(451)
-    torch_model = model_builder()    # get a different model
 
     world_size = torch.distributed.get_world_size()
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]['chunk_size'] = 5000
     config_dict[world_size]['keep_gathered'] = keep_gathered
 
-    if placement_policy != 'cuda':
-        init_device = torch.device('cpu')
-    else:
-        init_device = None
-    model = GeminiDDP(model, config_dict, init_device, placement_policy, pin_memory=True)
+    model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True)
 
     optimizer = HybridAdam(model.parameters())
     optim = GeminiOptimizer(optimizer, model, initial_scale=32)    # initialize the link between chunk16 and chunk32