From 6fb9b0f63919d162eb65a82f4c5156695568e346 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 15:40:24 +0800 Subject: [PATCH 01/13] [plugin] torch ddp plugin add save sharded model --- colossalai/booster/plugin/torch_ddp_plugin.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 99cd2f7791d3..b317ccf48ad9 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator, List, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple, Union import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP @@ -50,6 +50,16 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def save_sharded_model(self, + model: nn.Module, + checkpoint_path: str, + gather_dtensor: bool = False, + variant: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False): + if self.coordinator.is_master(): + super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors) + class TorchDDPModel(ModelWrapper): From b36de46e580544fff05dda9390320ae73543b1f6 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 16:07:24 +0800 Subject: [PATCH 02/13] [test] fix torch ddp ckpt io test --- .../checkpoint_io/checkpoint_io_base.py | 6 +- colossalai/checkpoint_io/utils.py | 62 +++++++++++-------- .../test_torch_ddp_checkpoint_io.py | 51 ++++++++++----- 3 files changed, 72 insertions(+), 47 deletions(-) diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 9cf344ecc41b..fbc8fc5429ad 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Union -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -84,9 +83,8 @@ def load_model(self, # containing no distributed tensors, dtensor -> full tensor conversion # should be done offline via our CLI # the existence of index file means it is a sharded checkpoint - ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) - + # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index ee4bd72e89ec..435feda4ac6a 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,10 +1,12 @@ # coding=utf-8 +import re from pathlib import Path +from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple + import torch import torch.nn as nn -from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator + from colossalai.tensor.d_tensor.d_tensor import DTensor -import re SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -15,6 +17,7 @@ # General helper functions # ====================================== + def calculate_tensor_size(tensor: torch.Tensor) -> float: """ Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. @@ -28,6 +31,7 @@ def calculate_tensor_size(tensor: torch.Tensor) -> float: """ return tensor.numel() * tensor.element_size() / 1024 / 1024 + def is_safetensors_available() -> bool: """ Check whether safetensors is available. @@ -78,7 +82,6 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # Helper functions for saving shard file # ====================================== def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: - """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -100,35 +103,39 @@ def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> It current_block_size = 0 current_block[key] = weight current_block_size += weight_size - + if ret_block != None: yield ret_block, ret_block_size yield current_block, current_block_size -def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False): """ load shard state dict into model """ if use_safetensors and not checkpoint_file.suffix == ".safetensors": raise Exception("load the model using `safetensors`, but no file endwith .safetensors") if use_safetensors: - from safetensors.torch import safe_open from safetensors.torch import load_file as safe_load_file + from safetensors.torch import safe_open with safe_open(checkpoint_file, framework="pt") as f: metadata = f.metadata() if metadata["format"] != "pt": raise NotImplementedError( - f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." - ) + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet.") return safe_load_file(checkpoint_file) else: return torch.load(checkpoint_file) - -def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True): + + +def load_state_dict_into_model(model: nn.Module, + state_dict: torch.Tensor, + missing_keys: List, + strict: bool = False, + load_sub_module: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into - this module and its descendants. + this module and its descendants. Args: state_dict (dict): a dict containing parameters and @@ -166,11 +173,12 @@ def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True) if strict: if len(unexpected_keys) > 0: - error_msgs = 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys)) + error_msgs = 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys)) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, "\n\t".join(error_msgs))) - + model.__class__.__name__, "\n\t".join(error_msgs))) + + # ====================================== # Helper functions for saving state dict # ====================================== @@ -350,6 +358,8 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: return True, index_files[0] else: return False, None + else: + raise RuntimeError(f'Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.') def load_state_dict(checkpoint_file_path: Path): @@ -380,7 +390,6 @@ def load_state_dict(checkpoint_file_path: Path): else: # load with torch return torch.load(checkpoint_file_path) - def add_variant(weights_name: str, variant: Optional[str] = None) -> str: @@ -392,17 +401,18 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name -def get_base_filenames(variant: str=None, use_safetensors: bool=False): - """ - generate base weight filenames - """ - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) +def get_base_filenames(variant: str = None, use_safetensors: bool = False): + """ + generate base weight filenames + """ + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) + + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_variant(save_index_file, variant) - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - save_index_file = add_variant(save_index_file, variant) + return weights_name, save_index_file - return weights_name, save_index_file def get_shard_filename(weights_name: str, idx: int): """ @@ -410,4 +420,4 @@ def get_shard_filename(weights_name: str, idx: int): """ shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") - return shard_file \ No newline at end of file + return shard_file diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 3c05ea9f1b17..8899d2b06e17 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -1,3 +1,4 @@ +import os import tempfile import torch @@ -8,12 +9,13 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPCheckpointIO +from colossalai.cluster import DistCoordinator from colossalai.interface import OptimizerWrapper -from colossalai.testing import check_state_dict_equal, rerun_if_address_is_in_use, spawn +from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -def check_torch_ddp_checkpointIO(): +@parameterize('shard', [True, False]) +def check_torch_ddp_checkpointIO(shard: bool): plugin = TorchDDPPlugin() booster = Booster(plugin=plugin) model = resnet18() @@ -34,23 +36,38 @@ def check_torch_ddp_checkpointIO(): optimizer.step() scheduler.step() - optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() - lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile() - ckpt_io = TorchDDPCheckpointIO() - ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) - ckpt_io.save_lr_scheduler(scheduler, lr_scheduler_ckpt_tempfile.name) + coordinator = DistCoordinator() - new_model = resnet18() - new_optimizer = SGD((new_model.parameters()), lr=0.001) - new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) - _, new_optimizer, _, _, new_scheduler = booster.boost(new_model, new_optimizer, lr_scheduler=new_scheduler) + with tempfile.TemporaryDirectory() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler" + booster.save_model(model, model_ckpt_path, shard=shard) + if not shard: + # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint + booster.save_optimizer(optimizer, optimizer_ckpt_path) + booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) - if ckpt_io.coordinator.is_master(): - ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + new_model = resnet18() + new_optimizer = SGD((new_model.parameters()), lr=0.001) + new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.1) + new_model, new_optimizer, _, _, new_scheduler = booster.boost(new_model, + new_optimizer, + lr_scheduler=new_scheduler) - ckpt_io.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_tempfile.name) - check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + if coordinator.is_master(): + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + + if not shard: + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) + check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + else: + assert not os.path.exists(model_ckpt_path) + assert not os.path.exists(optimizer_ckpt_path) + assert not os.path.exists(lr_scheduler_ckpt_path) def run_dist(rank, world_size, port): From 4b7379e892ed2c9c200f670256376fb0b6d7660b Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 16:45:26 +0800 Subject: [PATCH 03/13] [test] fix torch ddp ckpt io test --- .../test_torch_ddp_checkpoint_io.py | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py index 8899d2b06e17..8a4217941fe3 100644 --- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py @@ -1,7 +1,7 @@ -import os import tempfile import torch +import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD from torchvision.models import resnet18 @@ -9,7 +9,6 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.cluster import DistCoordinator from colossalai.interface import OptimizerWrapper from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn @@ -36,9 +35,11 @@ def check_torch_ddp_checkpointIO(shard: bool): optimizer.step() scheduler.step() - coordinator = DistCoordinator() - with tempfile.TemporaryDirectory() as tempdir: + obj = [tempdir] + dist.broadcast_object_list(obj, src=0) + tempdir = obj[0] # use the same directory on all ranks + model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler" @@ -47,6 +48,7 @@ def check_torch_ddp_checkpointIO(shard: bool): # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint booster.save_optimizer(optimizer, optimizer_ckpt_path) booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path) + dist.barrier() new_model = resnet18() new_optimizer = SGD((new_model.parameters()), lr=0.001) @@ -55,19 +57,16 @@ def check_torch_ddp_checkpointIO(shard: bool): new_optimizer, lr_scheduler=new_scheduler) - if coordinator.is_master(): - booster.load_model(new_model, model_ckpt_path) - check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) - - if not shard: - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) - booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) - check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) - else: - assert not os.path.exists(model_ckpt_path) - assert not os.path.exists(optimizer_ckpt_path) - assert not os.path.exists(lr_scheduler_ckpt_path) + booster.load_model(new_model, model_ckpt_path) + check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) + + if not shard: + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path) + check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False) + + dist.barrier() def run_dist(rank, world_size, port): From 5dedfd4cdb9a950b8cd7c50db8b9c61438d17ce1 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 17:45:23 +0800 Subject: [PATCH 04/13] [test] fix low level zero plugin test --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index d84b96f77a75..2e013ccf895c 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,7 +11,7 @@ from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn`'] +_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] # These models have no parameters _LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] # These models will get stuck From 402e06c1f5e6e86d54136c7f5c9eb310f525a338 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 17:58:42 +0800 Subject: [PATCH 05/13] [test] fix low level zero plugin test --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 2e013ccf895c..090a59ff1c36 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -67,6 +67,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): skipped_models.append(name) continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) + dist.barrier() torch.cuda.empty_cache() if err is None: From d80cce7c3de60cec451cb378304e17c77e05e8c1 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 18:12:46 +0800 Subject: [PATCH 06/13] [test] add debug info --- .github/workflows/build_on_pr.yml | 2 +- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index a9e50e231164..96031226690d 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -138,7 +138,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -s --testmon --testmon-cov=. tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 090a59ff1c36..1a2ff305e86a 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -63,6 +63,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): # FIXME(ver217): fix these models + print(name) if name in ignore_models: skipped_models.append(name) continue From af9d453e0a2fb41b4db4a84d6b4153ed808aaa45 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 18:18:30 +0800 Subject: [PATCH 07/13] [test] add debug info --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 1a2ff305e86a..922e4272141a 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -93,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_low_level_zero_plugin(early_stop: bool = True): - spawn(run_dist, 2, early_stop=early_stop) + spawn(run_dist, 4, early_stop=early_stop) if __name__ == '__main__': From 7b4fcc314b49e0261367dc4a38d9c36f72a134e0 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 18:19:21 +0800 Subject: [PATCH 08/13] [test] add debug info --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 922e4272141a..c7e63092c93e 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -63,7 +63,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): # FIXME(ver217): fix these models - print(name) + print(f'{name}\n\n') if name in ignore_models: skipped_models.append(name) continue From 0c22eafa01ccf62677afd587a2b637395134799b Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 18:28:50 +0800 Subject: [PATCH 09/13] [test] add debug info --- .../test_booster/test_plugin/test_low_level_zero_plugin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index c7e63092c93e..68e6655a0b60 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -13,7 +13,7 @@ # These models are not compatible with AMP _AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch'] # These models will get stuck _STUCK_MODELS = [ 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', @@ -63,7 +63,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): # FIXME(ver217): fix these models - print(f'{name}\n\n') + print(f'\n{name}\n') if name in ignore_models: skipped_models.append(name) continue @@ -93,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_low_level_zero_plugin(early_stop: bool = True): - spawn(run_dist, 4, early_stop=early_stop) + spawn(run_dist, 2, early_stop=early_stop) if __name__ == '__main__': From c5fca9a45cbe63f915951c165efcf77dc0a372ff Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 18:33:55 +0800 Subject: [PATCH 10/13] [test] add debug info --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 68e6655a0b60..2e88c0c541ef 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -68,7 +68,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): skipped_models.append(name) continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - dist.barrier() + torch.cuda.empty_cache() if err is None: @@ -93,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_low_level_zero_plugin(early_stop: bool = True): - spawn(run_dist, 2, early_stop=early_stop) + spawn(run_dist, 4, early_stop=early_stop) if __name__ == '__main__': From e38ff10b0ed4c592b6eabee88a49a17fba7bde94 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 19:26:56 +0800 Subject: [PATCH 11/13] [test] fix low level zero plugin test --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 2e88c0c541ef..081804214386 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -13,7 +13,7 @@ # These models are not compatible with AMP _AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch'] +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch'] # These models will get stuck _STUCK_MODELS = [ 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', From 99c542a291c17f531c5c4817156bdbbb2e1c1292 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 19:34:32 +0800 Subject: [PATCH 12/13] [test] fix low level zero plugin test --- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 081804214386..1987b9909e12 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -13,7 +13,7 @@ # These models are not compatible with AMP _AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch'] +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch'] # These models will get stuck _STUCK_MODELS = [ 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', From 9bb7129fb893a4161fbebda933916b5f0d6b169b Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 18 May 2023 19:50:45 +0800 Subject: [PATCH 13/13] [test] remove debug info --- .github/workflows/build_on_pr.yml | 2 +- tests/test_booster/test_plugin/test_low_level_zero_plugin.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index 96031226690d..a9e50e231164 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -138,7 +138,7 @@ jobs: - name: Execute Unit Testing run: | - CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -s --testmon --testmon-cov=. tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 1987b9909e12..f70f27be2aa7 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -63,7 +63,6 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): # FIXME(ver217): fix these models - print(f'\n{name}\n') if name in ignore_models: skipped_models.append(name) continue