From 30412866e0c6860de0787cd9c5e9a5bfffdc712c Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Mon, 3 Apr 2023 10:11:03 +0800 Subject: [PATCH 01/27] [chatgpt] add pre-trained model RoBERTa for RLHF stage 2 & 3 (#3223) * Add RoBERTa for RLHF Stage 2 & 3 (test) RoBERTa for RLHF Stage 2 & 3 (still in testing) * Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)" This reverts commit 06741d894dcbe958acd4e10d771f22275e20e368. * Add RoBERTa for RLHF stage 2 & 3 1. add roberta folder under model folder 2. add roberta option in train_reward_model.py 3. add some test in testci * add test for reward model training * Update test_ci.sh * Revert "Update test_ci.sh" This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a. * Add RoBERTa for RLHF Stage 2 & 3 (test) RoBERTa for RLHF Stage 2 & 3 (still in testing) * Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)" This reverts commit 06741d894dcbe958acd4e10d771f22275e20e368. * Add RoBERTa for RLHF stage 2 & 3 1. add roberta folder under model folder 2. add roberta option in train_reward_model.py 3. add some test in testci * Update test_ci.sh * Revert "Update test_ci.sh" This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a. * update roberta with coati --- .../Chat/coati/models/roberta/__init__.py | 5 +++ .../coati/models/roberta/roberta_actor.py | 35 +++++++++++++++++ .../coati/models/roberta/roberta_critic.py | 38 ++++++++++++++++++ .../Chat/coati/models/roberta/roberta_rm.py | 39 +++++++++++++++++++ applications/Chat/examples/inference.py | 9 ++++- applications/Chat/examples/test_ci.sh | 20 ++++++++++ applications/Chat/examples/train_dummy.py | 10 ++++- applications/Chat/examples/train_prompts.py | 17 ++++++-- .../Chat/examples/train_reward_model.py | 9 ++++- 9 files changed, 173 insertions(+), 9 deletions(-) create mode 100644 applications/Chat/coati/models/roberta/__init__.py create mode 100644 applications/Chat/coati/models/roberta/roberta_actor.py create mode 100644 applications/Chat/coati/models/roberta/roberta_critic.py create mode 100644 applications/Chat/coati/models/roberta/roberta_rm.py diff --git a/applications/Chat/coati/models/roberta/__init__.py b/applications/Chat/coati/models/roberta/__init__.py new file mode 100644 index 000000000000..0f4a8de067b1 --- /dev/null +++ b/applications/Chat/coati/models/roberta/__init__.py @@ -0,0 +1,5 @@ +from .roberta_actor import RoBERTaActor +from .roberta_critic import RoBERTaCritic +from .roberta_rm import RoBERTaRM + +__all__ = ['RoBERTaActor', 'RoBERTaCritic', 'RoBERTaRM'] \ No newline at end of file diff --git a/applications/Chat/coati/models/roberta/roberta_actor.py b/applications/Chat/coati/models/roberta/roberta_actor.py new file mode 100644 index 000000000000..e35fa6eb19a8 --- /dev/null +++ b/applications/Chat/coati/models/roberta/roberta_actor.py @@ -0,0 +1,35 @@ +from typing import Optional + +from transformers.models.roberta.configuration_roberta import RobertaConfig +from transformers.models.roberta.modeling_roberta import RobertaForCausalLM + +from ..base import Actor + +class RoBERTaActor(Actor): + """ + RoBERTa Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (RoBERTaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[RobertaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = RobertaForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = RobertaForCausalLM(config) + else: + model = RobertaForCausalLM(RobertaConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/Chat/coati/models/roberta/roberta_critic.py b/applications/Chat/coati/models/roberta/roberta_critic.py new file mode 100644 index 000000000000..c8dc0d9e14f2 --- /dev/null +++ b/applications/Chat/coati/models/roberta/roberta_critic.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch.nn as nn +from transformers.models.roberta.configuration_roberta import RobertaConfig +from transformers.models.roberta.modeling_roberta import RobertaModel + +from ..base import Critic + + +class RoBERTaCritic(Critic): + """ + RoBERTa Critic model. + + Args: + pretrained (str): Pretrained model name or path. + config (RoBERTa Config): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[RobertaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none', + **kwargs) -> None: + if pretrained is not None: + model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False) + elif config is not None: + model = RobertaModel(config) + else: + model = RobertaModel(RobertaConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + value_head = nn.Linear(model.config.hidden_size, 1) + super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs) diff --git a/applications/Chat/coati/models/roberta/roberta_rm.py b/applications/Chat/coati/models/roberta/roberta_rm.py new file mode 100644 index 000000000000..77075052978b --- /dev/null +++ b/applications/Chat/coati/models/roberta/roberta_rm.py @@ -0,0 +1,39 @@ +from typing import Optional + +import torch.nn as nn +from transformers import RobertaConfig, RobertaModel + + +from ..base import RewardModel + + +class RoBERTaRM(RewardModel): + """ + RoBERTa Reward model. + + Args: + pretrained (str): Pretrained model name or path. + config (RoBERTaConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): Rank of the low-rank approximation. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: Optional[str] = None, + config: Optional[RobertaConfig] = None, + checkpoint: bool = False, + lora_rank: int = 0, + lora_train_bias: str = 'none') -> None: + if pretrained is not None: + model = RobertaModel.from_pretrained(pretrained, add_pooling_layer=False) + elif config is not None: + model = RobertaModel(config) + else: + model = RobertaModel(RobertaConfig()) + if checkpoint: + model.gradient_checkpointing_enable() + + value_head = nn.Linear(model.config.hidden_size, 1) + value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1)) + super().__init__(model, value_head, lora_rank, lora_train_bias) \ No newline at end of file diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py index f75950804d2e..ae59d91c1822 100644 --- a/applications/Chat/examples/inference.py +++ b/applications/Chat/examples/inference.py @@ -4,7 +4,8 @@ from coati.models.bloom import BLOOMActor from coati.models.gpt import GPTActor from coati.models.opt import OPTActor -from transformers import AutoTokenizer +from coati.models.roberta import RoBERTaActor +from transformers import AutoTokenizer, RobertaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer @@ -16,6 +17,8 @@ def eval(args): actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device()) elif args.model == 'opt': actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + elif args.model == 'roberta': + actor = RoBERTaActor(pretrained=args.pretrain).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -31,6 +34,8 @@ def eval(args): tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") else: raise ValueError(f'Unsupported model "{args.model}"') @@ -49,7 +54,7 @@ def eval(args): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta']) # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh index db1d0b64e3b3..64cf68a0a13f 100755 --- a/applications/Chat/examples/test_ci.sh +++ b/applications/Chat/examples/test_ci.sh @@ -40,6 +40,13 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \ --save_path ${BASE}/actor_checkpoint_dummy.pt python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'gpt2' --model gpt2 +torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \ + --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ + --update_timesteps 2 --max_epochs 1 --train_batch_size 2\ + --pretrain 'roberta-base' --model roberta --lora_rank 4\ + --save_path ${BASE}/actor_checkpoint_dummy.pt +python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'roberta-base' --model roberta + rm -rf ${BASE}/actor_checkpoint_dummy.pt # train prompts @@ -68,6 +75,13 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \ --save_path ${BASE}/actor_checkpoint_prompts.pt python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2 +torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \ + --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \ + --update_timesteps 2 --max_epochs 1 --train_batch_size 2\ + --pretrain 'roberta-base' --model roberta --lora_rank 4\ + --save_path ${BASE}/actor_checkpoint_prompts.pt +python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'roberta-base' --model roberta + rm -rf ${BASE}/actor_checkpoint_prompts.pt # train rm @@ -94,4 +108,10 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ --test True --lora_rank 4 +torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ + --pretrain 'roberta-base' --model 'roberta' \ + --strategy colossalai_zero2 --loss_fn 'log_exp'\ + --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\ + --test True --lora_rank 4 + rm -rf ${BASE}/rm_ckpt.pt diff --git a/applications/Chat/examples/train_dummy.py b/applications/Chat/examples/train_dummy.py index d944b018de8f..4ac7ace44803 100644 --- a/applications/Chat/examples/train_dummy.py +++ b/applications/Chat/examples/train_dummy.py @@ -6,11 +6,12 @@ from coati.models.bloom import BLOOMActor, BLOOMCritic from coati.models.gpt import GPTActor, GPTCritic from coati.models.opt import OPTActor, OPTCritic +from coati.models.roberta import RoBERTaActor, RoBERTaCritic from coati.trainer import PPOTrainer from coati.trainer.callbacks import SaveCheckpoint from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from torch.optim import Adam -from transformers import AutoTokenizer, BloomTokenizerFast +from transformers import AutoTokenizer, BloomTokenizerFast, RobertaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from colossalai.nn.optimizer import HybridAdam @@ -46,6 +47,9 @@ def main(args): elif args.model == 'opt': actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'roberta': + actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = RoBERTaCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -69,6 +73,8 @@ def main(args): tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") else: raise ValueError(f'Unsupported model "{args.model}"') @@ -128,7 +134,7 @@ def main(args): parser.add_argument('--strategy', choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive') - parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt']) + parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt', 'roberta']) parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt') parser.add_argument('--need_optim_ckpt', type=bool, default=False) diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index 6643796d7a8b..5ded6d8432ed 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -8,13 +8,14 @@ from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.models.roberta import RoBERTaRM, RoBERTaActor, RoBERTaCritic from coati.trainer import PPOTrainer from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from coati.utils import prepare_llama_tokenizer_and_embedding from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer, RobertaTokenizer from colossalai.nn.optimizer import HybridAdam @@ -44,6 +45,8 @@ def main(args): initial_model = OPTActor(pretrained=args.pretrain) elif args.model == 'llama': initial_model = LlamaActor(pretrained=args.pretrain) + elif args.model == 'roberta': + initial_model = RoBERTaActor(pretrained=args.pretrain) else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -60,6 +63,8 @@ def main(args): reward_model = OPTRM(pretrained=args.rm_pretrain) elif rm_model_name == 'llama': reward_model = LlamaRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'roberta': + reward_model = RoBERTaRM(pretrained=args.rm_pretrain) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -79,6 +84,8 @@ def main(args): actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) elif args.model == 'llama': actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + elif args.model == 'roberta': + actor = RoBERTaActor(pretrained=args.pretrain, lora_rank=args.lora_rank) else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -90,6 +97,8 @@ def main(args): critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) elif rm_model_name == 'llama': critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'roberta': + critic = RoBERTaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) else: raise ValueError(f'Unsupported reward model "{rm_model_name}"') @@ -119,6 +128,8 @@ def main(args): elif args.model == 'llama': tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) tokenizer.eos_token = '<\s>' + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") else: raise ValueError(f'Unsupported model "{args.model}"') @@ -200,9 +211,9 @@ def tokenize_fn(texts): choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive', help='strategy to use') - parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta']) parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama', 'roberta']) parser.add_argument('--rm_path', type=str, default=None) parser.add_argument('--rm_pretrain', type=str, default=None) parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py index 729dfa23128f..aa1b51dea7f9 100644 --- a/applications/Chat/examples/train_reward_model.py +++ b/applications/Chat/examples/train_reward_model.py @@ -11,12 +11,13 @@ from coati.models.gpt import GPTRM from coati.models.llama import LlamaRM from coati.models.opt import OPTRM +from coati.models.roberta import RoBERTaRM from coati.trainer import RewardModelTrainer from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from coati.utils import prepare_llama_tokenizer_and_embedding from datasets import load_dataset from torch.optim import Adam -from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer +from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from colossalai.nn.optimizer import HybridAdam @@ -47,6 +48,8 @@ def train(args): model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'llama': model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + elif args.model == 'roberta': + model = RoBERTaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') @@ -67,6 +70,8 @@ def train(args): tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large') elif args.model == 'llama': tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + elif args.model == 'roberta': + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") else: raise ValueError(f'Unsupported model "{args.model}"') max_len = args.max_len @@ -140,7 +145,7 @@ def train(args): parser.add_argument('--strategy', choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], default='naive') - parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama'], default='bloom') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom') parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) parser.add_argument('--need_optim_ckpt', type=bool, default=False) From 638a07a7f9b504e6c9781e9aa2a9b6c5e9dc49ed Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 3 Apr 2023 17:12:22 +0800 Subject: [PATCH 02/27] [test] fixed gemini plugin test (#3411) * [test] fixed gemini plugin test * polish code * polish code --- .../offload/base_offload_module.py | 16 ++--- .../auto_parallel/offload/mem_optimize.py | 21 +++--- colossalai/auto_parallel/offload/runtime.py | 59 +++++++++-------- colossalai/auto_parallel/offload/util.py | 33 ++++++---- .../test_offload/test_perf.py | 64 +++++++++---------- .../test_plugin/test_gemini_plugin.py | 60 ++++++----------- .../test_zero/low_level_zero/test_zero1_2.py | 2 + 7 files changed, 124 insertions(+), 131 deletions(-) diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 59cea4ece266..3a32f0722a1b 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -1,10 +1,11 @@ -from typing import Optional, Set from functools import partial +from typing import Optional, Set + import torch import torch.nn as nn -from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.gemini.tensor_utils import free_storage +from colossalai.nn.parallel.data_parallel import _cast_float from .region_manager import RegionManager from .util import GlobalRuntimeInfo @@ -20,10 +21,7 @@ class BaseOffloadModule: is_sync (bool): synchronous mode or not. """ - def __init__(self, - model: nn.Module, - region_manager: RegionManager, - is_sync=True): + def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): self.model = model self.region_manager = region_manager @@ -69,8 +67,8 @@ def _post_backward(self): for p in self.model.parameters(): p.grad = None - GlobalRuntimeInfo.fwd_prefetch_event_map.clear() - GlobalRuntimeInfo.bwd_prefetch_event_map.clear() + GlobalRuntimeInfo().fwd_prefetch_event_map.clear() + GlobalRuntimeInfo().bwd_prefetch_event_map.clear() def grad_handle(self, p, grad): empty_grad = torch.empty_like(grad) @@ -82,7 +80,7 @@ def grad_handle(self, p, grad): self.overflow_counter += region.has_inf_or_nan master_stream = torch.cuda.current_stream() with torch.cuda.stream(self.grad_offload_stream): - GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream) + GlobalRuntimeInfo().d2h_stream.wait_stream(master_stream) region.move_grad_to_cpu() return empty_grad diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py index 02778696a106..d56166dea982 100644 --- a/colossalai/auto_parallel/offload/mem_optimize.py +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -1,4 +1,5 @@ from typing import Dict + import torch import torch.fx from torch.fx import GraphModule @@ -7,10 +8,11 @@ from colossalai.fx import ColoTracer, is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from .region_manager import RegionManager -from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass from .base_offload_module import BaseOffloadModule -from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo +from .region_manager import RegionManager +from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass +from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem + def memory_optimize(model: torch.nn.Module, inps: Dict[str, torch.Tensor], @@ -29,13 +31,14 @@ def memory_optimize(model: torch.nn.Module, region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget) region_manager._build_regions() - GlobalRuntimeInfo.region_list = region_manager.region_list + GlobalRuntimeInfo().region_list = region_manager.region_list - act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2 - max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2 - total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2 + act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2 + max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2 + total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2 print( - f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}") + f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" + ) if solver_name == 'syn': gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) @@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module, raise TypeError(f"Unknown solver name {solver_name}!") gm.recompile() - optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn') + optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') return optimized_model diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py index 91c7945bd65f..764ac608826b 100644 --- a/colossalai/auto_parallel/offload/runtime.py +++ b/colossalai/auto_parallel/offload/runtime.py @@ -1,4 +1,5 @@ from typing import List + import torch from torch.fx.node import Node @@ -23,13 +24,13 @@ def forward(ctx, input_, fwd_info, bwd_info): ctx.bwd_info = bwd_info d2h_rid = fwd_info.get('d2h_rid', None) if d2h_rid is not None: - free_region = GlobalRuntimeInfo.region_list[d2h_rid] + free_region = GlobalRuntimeInfo().region_list[d2h_rid] assert isinstance(free_region, Region) free_region.free_cuda_data() h2d_rid = fwd_info.get('h2d_rid', None) if h2d_rid is not None: - h2d_region = GlobalRuntimeInfo.region_list[h2d_rid] + h2d_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(h2d_region, Region) h2d_region.move_param_to_cuda() @@ -40,7 +41,7 @@ def backward(ctx, grad_output): h2d_rid = ctx.bwd_info.get('h2d_rid', None) if h2d_rid is not None: - pref_region = GlobalRuntimeInfo.region_list[h2d_rid] + pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) pref_region.move_param_to_cuda() @@ -65,23 +66,22 @@ def forward(ctx, input_, fwd_info, bwd_info): sync_rid = fwd_info.get('sync_rid', None) if sync_rid is not None: - prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get( - sync_rid, None) + prefetch_event = GlobalRuntimeInfo().fwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() h2d_rid = fwd_info.get('h2d_rid', None) if h2d_rid is not None: - pref_region = GlobalRuntimeInfo.region_list[h2d_rid] + pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) master_stream = torch.cuda.current_stream() - with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream): - GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream) + with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream): + GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream) pref_region.move_param_to_cuda() prefetch_event = torch.cuda.Event() - prefetch_event.record(GlobalRuntimeInfo.h2d_stream) - GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event + prefetch_event.record(GlobalRuntimeInfo().h2d_stream) + GlobalRuntimeInfo().fwd_prefetch_event_map[h2d_rid] = prefetch_event return input_ @@ -90,10 +90,9 @@ def backward(ctx, grad_output): sync_rid = ctx.bwd_info.get('sync_rid', None) if sync_rid is not None: - wait_region = GlobalRuntimeInfo.region_list[sync_rid] + wait_region = GlobalRuntimeInfo().region_list[sync_rid] assert isinstance(wait_region, Region) - prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get( - sync_rid, None) + prefetch_event = GlobalRuntimeInfo().bwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() else: @@ -101,16 +100,16 @@ def backward(ctx, grad_output): h2d_rid = ctx.bwd_info.get('h2d_rid', None) if h2d_rid is not None: - pref_region = GlobalRuntimeInfo.region_list[h2d_rid] + pref_region = GlobalRuntimeInfo().region_list[h2d_rid] assert isinstance(pref_region, Region) master_stream = torch.cuda.current_stream() - with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream): - GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream) + with torch.cuda.stream(GlobalRuntimeInfo().h2d_stream): + GlobalRuntimeInfo().h2d_stream.wait_stream(master_stream) pref_region.move_param_to_cuda() prefetch_event = torch.cuda.Event() - prefetch_event.record(GlobalRuntimeInfo.h2d_stream) - GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event + prefetch_event.record(GlobalRuntimeInfo().h2d_stream) + GlobalRuntimeInfo().bwd_prefetch_event_map[h2d_rid] = prefetch_event return grad_output, None, None @@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret + def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): ''' Convert Prefetch and Offload operation into runtime action. @@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, + new_node = mod_graph.create_node('call_function', + convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)) replace_node_users(last_inp_node, new_node) @@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ # upload parameters of the first region last_inp_node = tuple(mod_graph.nodes)[0] - first_region_with_p = [ - region for region in region_list if region.param_size][0] + first_region_with_p = [region for region in region_list if region.param_size][0] fwd_info = {"h2d_rid": first_region_with_p.r_id} with mod_graph.inserting_after(last_inp_node): - upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, + upload_apply_node = mod_graph.create_node('call_function', + convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})) replace_node_users(last_inp_node, upload_apply_node) last_inp_node = upload_apply_node @@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ fwd_info['h2d_rid'] = fwd_prefetch_region.r_id # forward offload - if r_idx > 0 and region_list[r_idx-1].need_offload: + if r_idx > 0 and region_list[r_idx - 1].need_offload: fwd_info['d2h_rid'] = r_idx - 1 bwd_info = {} # backward prefetch - if r_idx > 0 and region_list[r_idx-1].need_offload: + if r_idx > 0 and region_list[r_idx - 1].need_offload: bwd_info['sync_rid'] = r_idx - 1 - if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region: - bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id + if r_idx > 0 and region_list[r_idx - 1].bwd_prefetch_region: + bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, + new_node = mod_graph.create_node('call_function', + convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)) replace_node_users(last_inp_node, new_node) @@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ if region.bwd_prefetch_region: bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, + new_node = mod_graph.create_node('call_function', + convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)) replace_node_users(last_inp_node, new_node) # gm.graph.print_tabular() diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py index a99c4eb20225..6b010512cc9c 100644 --- a/colossalai/auto_parallel/offload/util.py +++ b/colossalai/auto_parallel/offload/util.py @@ -1,6 +1,9 @@ from dataclasses import dataclass from typing import List + import torch + +from colossalai.context.singleton_meta import SingletonMeta from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp from .region import Region @@ -12,6 +15,7 @@ class NodeInfo: runtime_fwd_mem: float = 0 runtime_bwd_mem: float = 0 + class NvDevicePower: """ NVIDIA GPU computing performance (TFLOPs). @@ -30,12 +34,14 @@ class NvDevicePower: A100_FP32 = 19.5 -class GlobalRuntimeInfo: - h2d_stream = torch.cuda.Stream() - d2h_stream = torch.cuda.Stream() - fwd_prefetch_event_map = {} - bwd_prefetch_event_map = {} - region_list = [] +class GlobalRuntimeInfo(metaclass=SingletonMeta): + + def __init__(self): + self.h2d_stream = torch.cuda.Stream() + self.d2h_stream = torch.cuda.Stream() + self.fwd_prefetch_event_map = {} + self.bwd_prefetch_event_map = {} + self.region_list = [] def compute_act_peak_mem(region_list: List[Region]) -> float: @@ -70,21 +76,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: return act_peak_mem + def compute_max_param_mem(region_list: List[Region]) -> float: return max(region.param_size for region in region_list) + def compute_total_param_mem(region_list: List[Region]) -> float: return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid) + def requires_upload_p_in_fwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or ( - shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid + and shared_reg.need_offload) + def requires_release_p_in_bwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or ( - shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid + and shared_reg.need_offload) + def requires_offload_g_in_bwd(region: Region): return region.param_size and (region.r_id <= region.shared_rid) - - diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index d569570f4b7d..17bf9cb87f51 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -1,46 +1,44 @@ import time -import pytest from functools import partial +import pytest import torch -from torch.utils._pytree import tree_map import torch.multiprocessing as mp +from torch.utils._pytree import tree_map import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.fx.profiler import parameter_size -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.utils import free_port, get_current_device -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.testing import parameterize - -from tests.test_tensor.common_utils import set_seed +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext from tests.test_auto_parallel.test_offload.model_utils import * +from tests.test_tensor.common_utils import set_seed @parameterize('model_name', ['gpt2_']) @parameterize('memory_budget', [5000]) @parameterize('solver_name', ['asyn']) -def exam_fwd_bwd( - model_name: str, - memory_budget: float, - solver_name: str -): +def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): # build model get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() - label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) criterion = LMLoss() set_seed(42) start_time = time.time() model = model_builder() model.train() - param_size = parameter_size(model) / 1024 ** 2 / 2 + param_size = parameter_size(model) / 1024**2 / 2 init_time = time.time() - start_time print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") @@ -92,13 +90,11 @@ def exam_fwd_bwd( torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'gemini | model_name: {model_name}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) del data_args @@ -129,22 +125,26 @@ def exam_fwd_bwd( torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'solver_name: {solver_name} | model_name: {model_name}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) -@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') -def test_perf(rank, world_size, port): + +def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') exam_fwd_bwd() -if __name__ == '__main__': - run_func = partial(test_perf, world_size=1, port=free_port()) +@pytest.mark.skip("this test failed") +@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +def test_perf(): + run_func = partial(run_dist, world_size=1, port=free_port()) mp.spawn(run_func, nprocs=1) + + +if __name__ == '__main__': + test_perf() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 7a0d4a15d53a..169983a76110 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -21,9 +21,6 @@ def check_gemini_plugin(early_stop: bool = True): Args: early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True. """ - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) - booster = Booster(plugin=plugin) - passed_models = [] failed_info = {} # (model_name, error) pair @@ -34,46 +31,23 @@ def check_gemini_plugin(early_stop: bool = True): continue # These models are not compatible with gemini if name in [ - 'diffusers_clip_vision_model', - 'timm_resnet', - 'timm_beit', - 'timm_beitv2', - 'timm_eca_nfnet', - 'timm_efficientformer', - 'timm_hrnet_w18_small', - 'timm_nf_ecaresnet101', - 'timm_nf_regnet_b0', - 'timm_skresnet18', - 'timm_wide_resnet50_2', - 'timm_convit', - 'timm_dm_nfnet', - 'timm_swin_transformer', - 'torchaudio_conformer', - 'torchaudio_deepspeech', - 'torchaudio_wavernn', - 'torchaudio_tacotron', - 'deepfm_interactionarch', - 'deepfm_simpledeepfmnn', - 'dlrm', - 'dlrm_interactionarch', - 'torchvision_googlenet', - 'torchvision_inception_v3', - 'torchvision_mobilenet_v3_small', - 'torchvision_resnet18', - 'torchvision_resnext50_32x4d', - 'torchvision_wide_resnet50_2', - 'torchvision_vit_b_16', - 'torchvision_convnext_base', - 'torchvision_swin_s', - 'transformers_albert', - 'transformers_albert_for_pretraining', - 'transformers_bert', - 'transformers_bert_for_pretraining', - 'transformers_gpt_double_heads', - 'torchaudio_hubert_base', + 'diffusers_clip_vision_model', 'timm_resnet', 'timm_beit', 'timm_beitv2', 'timm_eca_nfnet', + 'timm_efficientformer', 'timm_hrnet_w18_small', 'timm_nf_ecaresnet101', 'timm_nf_regnet_b0', + 'timm_skresnet18', 'timm_wide_resnet50_2', 'timm_convit', 'timm_dm_nfnet', 'timm_swin_transformer', + 'torchaudio_conformer', 'torchaudio_deepspeech', 'torchaudio_wavernn', 'torchaudio_tacotron', + 'deepfm_interactionarch', 'deepfm_simpledeepfmnn', 'dlrm', 'dlrm_interactionarch', + 'torchvision_googlenet', 'torchvision_inception_v3', 'torchvision_mobilenet_v3_small', + 'torchvision_resnet18', 'torchvision_resnext50_32x4d', 'torchvision_wide_resnet50_2', + 'torchvision_vit_b_16', 'torchvision_convnext_base', 'torchvision_swin_s', 'transformers_albert', + 'transformers_albert_for_pretraining', 'transformers_bert', 'transformers_bert_for_pretraining', + 'transformers_gpt_double_heads', 'torchaudio_hubert_base', 'torchaudio_wav2vec2_base', + 'transformers_t5_for_conditional_generation', 'transformers_t5', 'transformers_t5_encoder_model' ]: continue + try: + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, max_norm=1.0, initial_scale=2**5) + booster = Booster(plugin=plugin) model = model_fn() optimizer = HybridAdam(model.parameters(), lr=1e-3) criterion = lambda x: x.mean() @@ -97,10 +71,15 @@ def check_gemini_plugin(early_stop: bool = True): booster.backward(loss, optimizer) optimizer.step() passed_models.append(name) + + del booster, plugin, model, optimizer, criterion, data, output, loss except Exception as e: failed_info[name] = e if early_stop: raise e + + torch.cuda.empty_cache() + if dist.get_rank() == 0: print(f'Passed models({len(passed_models)}): {passed_models}\n\n') print(f'Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n') @@ -138,7 +117,6 @@ def run_dist(rank, world_size, port, early_stop: bool = True): check_gemini_plugin(early_stop=early_stop) -@pytest.mark.skip(reason='Skip gemini plugin test due to OOM') @rerun_if_address_is_in_use() def test_gemini_plugin(early_stop: bool = True): world_size = 2 diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py index 930b6129174e..ed76e0171fb4 100644 --- a/tests/test_zero/low_level_zero/test_zero1_2.py +++ b/tests/test_zero/low_level_zero/test_zero1_2.py @@ -9,6 +9,7 @@ from torch.testing import assert_close import colossalai +from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing.random import seed_all from colossalai.utils import free_port from colossalai.zero import LowLevelZeroOptimizer @@ -176,6 +177,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist +@rerun_if_address_is_in_use() def test_zero_1_2(): world_size = 2 run_func = partial(run_dist, world_size=world_size, port=free_port()) From b09adff724c2bbded1c71cc51a707f736a0e2899 Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Tue, 4 Apr 2023 09:46:23 +0800 Subject: [PATCH 03/27] [chat]fix sft training for bloom, gpt and opt (#3418) fix sft training for bloom, gpt and opt --- applications/Chat/coati/models/bloom/bloom_lm.py | 3 +++ applications/Chat/coati/models/gpt/gpt_lm.py | 3 +++ applications/Chat/coati/models/opt/opt_lm.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/applications/Chat/coati/models/bloom/bloom_lm.py b/applications/Chat/coati/models/bloom/bloom_lm.py index 628af2e341a2..e4184fcd0d9c 100644 --- a/applications/Chat/coati/models/bloom/bloom_lm.py +++ b/applications/Chat/coati/models/bloom/bloom_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) diff --git a/applications/Chat/coati/models/gpt/gpt_lm.py b/applications/Chat/coati/models/gpt/gpt_lm.py index 23fc13bf23a4..c558d7e9ea8d 100644 --- a/applications/Chat/coati/models/gpt/gpt_lm.py +++ b/applications/Chat/coati/models/gpt/gpt_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) diff --git a/applications/Chat/coati/models/opt/opt_lm.py b/applications/Chat/coati/models/opt/opt_lm.py index 65d79e1b2307..47afae847f13 100644 --- a/applications/Chat/coati/models/opt/opt_lm.py +++ b/applications/Chat/coati/models/opt/opt_lm.py @@ -33,3 +33,6 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) + + def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): + return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs) From 26b7aac0be10fb83692e197ca326f8b67c1c990b Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 4 Apr 2023 13:48:16 +0800 Subject: [PATCH 04/27] [zero] reorganize zero/gemini folder structure (#3424) * [zero] refactor low-level zero folder structure * [zero] fix legacy zero import path * [zero] fix legacy zero import path * [zero] remove useless import * [zero] refactor gemini folder structure * [zero] refactor gemini folder structure * [zero] refactor legacy zero import path * [zero] refactor gemini folder structure * [zero] refactor gemini folder structure * [zero] refactor gemini folder structure * [zero] refactor legacy zero import path * [zero] fix test import path * [zero] fix test * [zero] fix circular import * [zero] update import --- .../coati/trainer/strategies/colossalai.py | 9 +- .../offload/base_offload_module.py | 2 +- colossalai/auto_parallel/offload/region.py | 15 +- colossalai/booster/plugin/gemini_plugin.py | 6 +- colossalai/engine/_base_engine.py | 2 +- .../engine/schedule/_pipeline_schedule.py | 2 +- colossalai/gemini/__init__.py | 9 - colossalai/initialize.py | 5 +- colossalai/nn/layer/moe/experts.py | 2 +- colossalai/nn/layer/moe/layers.py | 2 +- colossalai/nn/optimizer/gemini_optimizer.py | 15 - colossalai/nn/parallel/__init__.py | 8 +- colossalai/nn/parallel/data_parallel.py | 525 +--------------- colossalai/nn/parallel/gemini_parallel.py | 63 -- colossalai/zero/__init__.py | 57 +- colossalai/zero/gemini/__init__.py | 11 + .../{ => zero}/gemini/chunk/__init__.py | 0 colossalai/{ => zero}/gemini/chunk/chunk.py | 0 colossalai/{ => zero}/gemini/chunk/manager.py | 3 +- .../{ => zero}/gemini/chunk/search_utils.py | 2 +- colossalai/{ => zero}/gemini/chunk/utils.py | 5 +- .../gemini}/colo_init_context.py | 5 +- colossalai/zero/gemini/gemini_ddp.py | 590 ++++++++++++++++++ .../zero/{utils => gemini}/gemini_hook.py | 4 +- colossalai/{ => zero}/gemini/gemini_mgr.py | 6 +- .../gemini/gemini_optimizer.py} | 14 +- .../gemini/memory_tracer/__init__.py | 0 .../memory_tracer/chunk_memstats_collector.py | 4 +- .../gemini/memory_tracer/memory_monitor.py | 0 .../gemini/memory_tracer/memory_stats.py | 2 +- .../memory_tracer/memstats_collector.py | 13 +- .../memory_tracer/param_runtime_order.py | 0 .../memory_tracer/runtime_mem_tracer.py | 9 +- .../static_memstats_collector.py | 2 +- .../{ => zero}/gemini/memory_tracer/utils.py | 0 .../{ => zero}/gemini/placement_policy.py | 5 +- .../{nn/parallel => zero/gemini}/utils.py | 5 +- colossalai/zero/legacy/__init__.py | 44 ++ colossalai/zero/legacy/gemini/__init__.py | 9 + .../legacy}/gemini/gemini_context.py | 0 .../legacy}/gemini/ophooks/__init__.py | 0 .../gemini/ophooks/_shard_grad_ophook.py | 0 .../gemini/ophooks/_shard_param_ophook.py | 1 + .../gemini/ophooks/runtime_mem_tracer_hook.py | 4 +- .../{ => zero/legacy}/gemini/ophooks/utils.py | 0 .../legacy}/gemini/paramhooks/__init__.py | 0 .../gemini/paramhooks/_param_hookmgr.py | 0 .../legacy}/gemini/stateful_tensor.py | 8 +- .../legacy}/gemini/stateful_tensor_mgr.py | 15 +- .../legacy}/gemini/tensor_placement_policy.py | 7 +- .../{ => zero/legacy}/gemini/tensor_utils.py | 6 +- .../zero/{ => legacy}/init_ctx/__init__.py | 0 .../{ => legacy}/init_ctx/init_context.py | 8 +- .../zero/{ => legacy}/shard_utils/__init__.py | 0 .../shard_utils/base_shard_strategy.py | 3 +- .../bucket_tensor_shard_strategy.py | 11 +- .../zero/{ => legacy}/shard_utils/commons.py | 4 +- .../shard_utils/tensor_shard_strategy.py | 11 +- .../{ => legacy}/sharded_model/__init__.py | 2 +- .../zero/{ => legacy}/sharded_model/_utils.py | 6 +- .../sharded_model/reduce_scatter.py | 0 .../sharded_model/sharded_model_v2.py | 20 +- .../zero/{ => legacy}/sharded_model/utils.py | 5 +- .../sharded_model}/zero_hook.py | 10 +- .../zero/legacy/sharded_optim/__init__.py | 3 + .../sharded_optim/sharded_optim_v2.py | 10 +- .../zero/legacy/sharded_param/__init__.py | 4 + .../sharded_param/sharded_param.py | 12 +- .../sharded_param/sharded_tensor.py | 3 +- colossalai/zero/low_level/__init__.py | 3 + .../{sharded_optim => low_level}/_utils.py | 0 .../bookkeeping/__init__.py | 0 .../bookkeeping/base_store.py | 0 .../bookkeeping/bucket_store.py | 0 .../bookkeeping/gradient_store.py | 0 .../bookkeeping/parameter_store.py | 0 .../bookkeeping/tensor_bucket.py | 0 .../low_level_optim.py | 0 colossalai/zero/sharded_optim/__init__.py | 4 - colossalai/zero/sharded_param/__init__.py | 4 - colossalai/zero/utils/__init__.py | 3 - .../zero_wrapper.py => zero/wrapper.py} | 6 +- docs/source/en/features/nvme_offload.md | 2 +- docs/source/zh-Hans/features/nvme_offload.md | 2 +- examples/images/dreambooth/debug.py | 2 +- .../dreambooth/train_dreambooth_colossalai.py | 5 +- .../train_dreambooth_colossalai_lora.py | 5 +- examples/images/vit/test_vit.py | 2 +- examples/images/vit/train.py | 2 +- examples/language/bert/train_bert_demo.py | 3 +- .../language/gpt/gemini/train_gpt_demo.py | 3 +- examples/language/opt/train_gemini_opt.py | 20 +- examples/language/palm/train.py | 8 +- .../roberta/pretraining/run_pretraining.py | 138 ++-- examples/tutorial/opt/opt/run_clm.py | 22 +- .../test_offload/test_perf.py | 3 +- .../test_compatibility_with_gemini.py | 3 +- tests/test_ddp/test_ddp_ignore_params.py | 8 +- tests/test_ddp/test_ddp_state_dict.py | 15 +- tests/test_gemini/test_gemini_manager.py | 146 ++--- tests/test_gemini/test_param_op.py | 2 +- tests/test_gemini/test_runtime_mem_tracer.py | 4 +- tests/test_gemini/update/test_chunk_mgrv2.py | 2 +- tests/test_gemini/update/test_chunkv2.py | 4 +- tests/test_gemini/update/test_fwd_bwd.py | 8 +- .../test_gemini/update/test_gemini_use_rmt.py | 10 +- .../update/test_get_torch_model.py | 5 +- tests/test_gemini/update/test_grad_clip.py | 8 +- tests/test_gemini/update/test_inference.py | 8 +- tests/test_gemini/update/test_optim.py | 8 +- tests/test_gemini/update/test_search.py | 4 +- .../update/test_zeroddp_state_dict.py | 7 +- .../update/test_zerooptim_state_dict.py | 8 +- tests/test_moe/test_moe_checkpoint.py | 2 +- tests/test_moe/test_moe_colo_init.py | 123 ++-- tests/test_moe/test_moe_zero_init.py | 226 ++++--- tests/test_moe/test_moe_zero_model.py | 10 +- tests/test_moe/test_moe_zero_optim.py | 12 +- tests/test_optimizer/test_cpu_adam.py | 2 +- .../test_optimizer/test_fused_adam_kernel.py | 2 +- tests/test_optimizer/test_hybrid_adam.py | 4 +- tests/test_tensor/model/test_gpt2.py | 2 +- tests/test_tensor/model/test_model.py | 2 +- tests/test_tensor/model/test_module_spec.py | 2 +- tests/test_tensor/test_context.py | 2 +- tests/test_tensor/test_tp_with_zero.py | 6 +- tests/test_utils/test_colo_checkpoint.py | 26 +- tests/test_utils/test_commons.py | 13 +- .../test_zero_gradient_clippling.py | 13 +- tests/test_zero/common.py | 5 +- .../low_level_zero/test_zero_init.py | 3 +- .../test_zero/low_level_zero/test_zero_tp.py | 3 +- tests/test_zero/test_found_inf.py | 144 ++--- tests/test_zero/test_init_context.py | 6 +- tests/test_zero/test_shard_model_v2.py | 10 +- tests/test_zero/test_shard_param.py | 11 +- .../test_sharded_optim_state_dict.py | 25 +- tests/test_zero/test_sharded_optim_v2.py | 12 +- .../test_sharded_optim_with_sync_bn.py | 9 +- tests/test_zero/test_state_dict.py | 14 +- tests/test_zero/test_tensor_utils.py | 25 +- tests/test_zero/test_zero_engine.py | 14 +- 142 files changed, 1427 insertions(+), 1396 deletions(-) delete mode 100644 colossalai/gemini/__init__.py delete mode 100644 colossalai/nn/optimizer/gemini_optimizer.py delete mode 100644 colossalai/nn/parallel/gemini_parallel.py create mode 100644 colossalai/zero/gemini/__init__.py rename colossalai/{ => zero}/gemini/chunk/__init__.py (100%) rename colossalai/{ => zero}/gemini/chunk/chunk.py (100%) rename colossalai/{ => zero}/gemini/chunk/manager.py (99%) rename colossalai/{ => zero}/gemini/chunk/search_utils.py (98%) rename colossalai/{ => zero}/gemini/chunk/utils.py (91%) rename colossalai/{utils/model => zero/gemini}/colo_init_context.py (97%) create mode 100644 colossalai/zero/gemini/gemini_ddp.py rename colossalai/zero/{utils => gemini}/gemini_hook.py (95%) rename colossalai/{ => zero}/gemini/gemini_mgr.py (97%) rename colossalai/{nn/optimizer/zero_optimizer.py => zero/gemini/gemini_optimizer.py} (97%) rename colossalai/{ => zero}/gemini/memory_tracer/__init__.py (100%) rename colossalai/{ => zero}/gemini/memory_tracer/chunk_memstats_collector.py (91%) rename colossalai/{ => zero}/gemini/memory_tracer/memory_monitor.py (100%) rename colossalai/{ => zero}/gemini/memory_tracer/memory_stats.py (98%) rename colossalai/{ => zero}/gemini/memory_tracer/memstats_collector.py (92%) rename colossalai/{ => zero}/gemini/memory_tracer/param_runtime_order.py (100%) rename colossalai/{ => zero}/gemini/memory_tracer/runtime_mem_tracer.py (95%) rename colossalai/{ => zero}/gemini/memory_tracer/static_memstats_collector.py (98%) rename colossalai/{ => zero}/gemini/memory_tracer/utils.py (100%) rename colossalai/{ => zero}/gemini/placement_policy.py (98%) rename colossalai/{nn/parallel => zero/gemini}/utils.py (97%) create mode 100644 colossalai/zero/legacy/__init__.py create mode 100644 colossalai/zero/legacy/gemini/__init__.py rename colossalai/{ => zero/legacy}/gemini/gemini_context.py (100%) rename colossalai/{ => zero/legacy}/gemini/ophooks/__init__.py (100%) rename colossalai/{ => zero/legacy}/gemini/ophooks/_shard_grad_ophook.py (100%) rename colossalai/{ => zero/legacy}/gemini/ophooks/_shard_param_ophook.py (99%) rename colossalai/{ => zero/legacy}/gemini/ophooks/runtime_mem_tracer_hook.py (96%) rename colossalai/{ => zero/legacy}/gemini/ophooks/utils.py (100%) rename colossalai/{ => zero/legacy}/gemini/paramhooks/__init__.py (100%) rename colossalai/{ => zero/legacy}/gemini/paramhooks/_param_hookmgr.py (100%) rename colossalai/{ => zero/legacy}/gemini/stateful_tensor.py (97%) rename colossalai/{ => zero/legacy}/gemini/stateful_tensor_mgr.py (94%) rename colossalai/{ => zero/legacy}/gemini/tensor_placement_policy.py (96%) rename colossalai/{ => zero/legacy}/gemini/tensor_utils.py (97%) rename colossalai/zero/{ => legacy}/init_ctx/__init__.py (100%) rename colossalai/zero/{ => legacy}/init_ctx/init_context.py (97%) rename colossalai/zero/{ => legacy}/shard_utils/__init__.py (100%) rename colossalai/zero/{ => legacy}/shard_utils/base_shard_strategy.py (87%) rename colossalai/zero/{ => legacy}/shard_utils/bucket_tensor_shard_strategy.py (89%) rename colossalai/zero/{ => legacy}/shard_utils/commons.py (95%) rename colossalai/zero/{ => legacy}/shard_utils/tensor_shard_strategy.py (86%) rename colossalai/zero/{ => legacy}/sharded_model/__init__.py (61%) rename colossalai/zero/{ => legacy}/sharded_model/_utils.py (95%) rename colossalai/zero/{ => legacy}/sharded_model/reduce_scatter.py (100%) rename colossalai/zero/{ => legacy}/sharded_model/sharded_model_v2.py (97%) rename colossalai/zero/{ => legacy}/sharded_model/utils.py (91%) rename colossalai/zero/{utils => legacy/sharded_model}/zero_hook.py (92%) create mode 100644 colossalai/zero/legacy/sharded_optim/__init__.py rename colossalai/zero/{ => legacy}/sharded_optim/sharded_optim_v2.py (97%) create mode 100644 colossalai/zero/legacy/sharded_param/__init__.py rename colossalai/zero/{ => legacy}/sharded_param/sharded_param.py (93%) rename colossalai/zero/{ => legacy}/sharded_param/sharded_tensor.py (92%) create mode 100644 colossalai/zero/low_level/__init__.py rename colossalai/zero/{sharded_optim => low_level}/_utils.py (100%) rename colossalai/zero/{sharded_optim => low_level}/bookkeeping/__init__.py (100%) rename colossalai/zero/{sharded_optim => low_level}/bookkeeping/base_store.py (100%) rename colossalai/zero/{sharded_optim => low_level}/bookkeeping/bucket_store.py (100%) rename colossalai/zero/{sharded_optim => low_level}/bookkeeping/gradient_store.py (100%) rename colossalai/zero/{sharded_optim => low_level}/bookkeeping/parameter_store.py (100%) rename colossalai/zero/{sharded_optim => low_level}/bookkeeping/tensor_bucket.py (100%) rename colossalai/zero/{sharded_optim => low_level}/low_level_optim.py (100%) delete mode 100644 colossalai/zero/sharded_optim/__init__.py delete mode 100644 colossalai/zero/sharded_param/__init__.py delete mode 100644 colossalai/zero/utils/__init__.py rename colossalai/{nn/parallel/zero_wrapper.py => zero/wrapper.py} (95%) diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 521c536406d1..ba85ba76d4b1 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -14,17 +14,16 @@ import colossalai from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, HybridAdam -from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext - -logger = get_dist_logger(__name__) +from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini.utils import get_static_torch_model from .base import Strategy from .ddp import DDPStrategy +logger = get_dist_logger(__name__) + class ColossalAIStrategy(DDPStrategy): """ diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 3a32f0722a1b..d0c328e134ff 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn -from colossalai.gemini.tensor_utils import free_storage from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.zero.legacy.gemini.tensor_utils import free_storage from .region_manager import RegionManager from .util import GlobalRuntimeInfo diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index e6907cc4b81d..9a2f558c3145 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -1,7 +1,10 @@ -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple + import torch from torch.fx import Node -from colossalai.gemini.tensor_utils import alloc_storage, free_storage + +from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage + class Region: """ @@ -52,15 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None): Map the parameters in the region to a contiguous memory space. """ - self.fp16_data = torch.zeros( - self.param_num, dtype=torch.half, device='cuda') + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') offset = 0 for param in self.fp16_params: param.data = param.data.cuda() p_num = param.data.numel() self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) - param.data = self.fp16_data[offset:offset + - p_num].view(param.data.shape) + param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) self.param_to_range[param] = (offset, offset + p_num) offset += p_num @@ -141,4 +142,4 @@ def split(self, cut_node_idx: int, cut_param_idx: int): def __update_params_ptr(self) -> None: for param in self.fp16_params: begin, end = self.param_to_range[param] - param.data = self.fp16_data[begin:end].view(param.data.shape) \ No newline at end of file + param.data = self.fp16_data[begin:end].view(param.data.shape) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index c3c9d007d44f..3c6e539ba972 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -14,12 +14,12 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator -from colossalai.gemini.memory_tracer import MemStats from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.tensor.colo_parameter import ColoParameter from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import _convert_to_coloparam +from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam +from colossalai.zero.gemini.memory_tracer import MemStats from .plugin_base import Plugin diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 59d8e1058652..ff8979d82401 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -10,8 +10,8 @@ from colossalai.engine.gradient_handler import BaseGradientHandler from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule -from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively from colossalai.logging import get_dist_logger +from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively class Engine: diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 712ae8242409..38175fe0941c 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -157,7 +157,7 @@ def load_micro_batch(self): return self._move_to_device(mciro_batch_data) def pre_processing(self, engine): - from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + from colossalai.zero.legacy import ShardedModelV2 # TODO: remove this after testing new zero with pipeline parallelism model = engine.model diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py deleted file mode 100644 index 7a5a44ebb1ef..000000000000 --- a/colossalai/gemini/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration -from .gemini_mgr import GeminiManager -from .stateful_tensor_mgr import StatefulTensorMgr -from .tensor_placement_policy import TensorPlacementPolicyFactory - -__all__ = [ - 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', - 'search_chunk_configuration' -] diff --git a/colossalai/initialize.py b/colossalai/initialize.py index f3719dcb47b3..5d3f3e5530cb 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -29,13 +29,12 @@ PipelineSchedule, get_tensor_shape, ) -from colossalai.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param from colossalai.utils.moe import sync_moe_model_param -from colossalai.zero import convert_to_zero_v2 -from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 +from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2 +from colossalai.zero.legacy.gemini.ophooks import BaseOpHook def get_default_parser(): diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 4fb9ad332c24..2e5d9e6e79a9 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -9,7 +9,7 @@ from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils import get_current_device -from colossalai.zero.init_ctx import no_shard_zero_decrator +from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator class MoeExperts(nn.Module): diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 0969eb818229..b90d1f0bfcc6 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -18,7 +18,7 @@ from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator from colossalai.utils import get_current_device -from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator +from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator @no_shard_zero_decrator(is_replicated=True) diff --git a/colossalai/nn/optimizer/gemini_optimizer.py b/colossalai/nn/optimizer/gemini_optimizer.py deleted file mode 100644 index 31d161612600..000000000000 --- a/colossalai/nn/optimizer/gemini_optimizer.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Any - -import torch - -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer - -__all__ = ['GeminiAdamOptimizer'] - - -class GeminiAdamOptimizer(ZeroOptimizer): - - def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: - optimizer = HybridAdam(model.parameters(), **defaults) - super().__init__(optimizer, model, **defaults) diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py index 2afc8f18c36f..17e010f478c9 100644 --- a/colossalai/nn/parallel/__init__.py +++ b/colossalai/nn/parallel/__init__.py @@ -1,5 +1,5 @@ -from .data_parallel import ColoDDP, ZeroDDP -from .gemini_parallel import GeminiDDP -from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper +from .data_parallel import ColoDDP -__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper'] +__all__ = [ + 'ColoDDP', +] diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index a9d001bd0a9c..f839d6b28444 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -1,31 +1,14 @@ -import itertools from collections import OrderedDict from functools import partial -from typing import Dict, Iterable, List, Optional, Set +from typing import Iterable, Optional, Set import torch import torch.distributed as dist -import torch.nn as nn -from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.gemini.memory_tracer import OrderedParamGenerator -from colossalai.logging import get_dist_logger -from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.tensor import ReplicaSpec -from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import get_current_device, is_ddp_ignored -from colossalai.zero.utils.gemini_hook import GeminiZeROHook +from colossalai.utils import is_ddp_ignored from .reducer import Reducer -from .utils import get_static_torch_model - -try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys -except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' def free_storage(data: torch.Tensor) -> None: @@ -189,507 +172,3 @@ def state_dict(self, destination=None, prefix='', keep_vars=False): def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): return self.module.load_state_dict(state_dict, strict) - - -class ZeroDDP(ColoDDP): - """ZeRO DDP for ColoTensor. - Warning: Nested ZeroDDP is not supported now. - It is designed to be used with ChunkManager and GeminiManager. - For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. - - Args: - module (torch.nn.Module): Module to apply ZeRO-DP. - gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. - For more details, see the API reference of ``GeminiManager``. - pin_memory (bool): Chunks on CPU Memory use pin-memory. - force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. - Defaults to False. - strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. - Defaults to False. Users can set it to True, when they clearly know that they only need DDP. - """ - - def __init__(self, - module: torch.nn.Module, - gemini_manager: GeminiManager, - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False) -> None: - super().__init__(module, process_group=ColoProcessGroup()) - self.gemini_manager = gemini_manager - self.chunk_manager: ChunkManager = gemini_manager.chunk_manager - self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(gemini_manager) - self.fp32_params: List[ColoTensor] = list() - self.fp16_params: List[ColoParameter] = list() - self.overflow_counter = 0 - self.grads_device: Dict[torch.Tensor, torch.device] = dict() - self.param2name: Dict[nn.Parameter, str] = dict() - self.name2param: Dict[str, nn.Parameter] = dict() - - self._cast_buffers() - self._logger = get_dist_logger() - - if self.gemini_manager._premade_memstats_: - # build chunk in param runtime visited order. - param_order = self.gemini_manager.memstats()._param_runtime_order - else: - # build chunk in param initialized order. - # Note: in this way, it can not get filter unused params during runtime. - param_order = OrderedParamGenerator() - for p in module.parameters(): - param_order.append(p) - - self._init_chunks(param_order=param_order, - strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != 'cuda', - pin_memory=pin_memory) - - for name, param in module.named_parameters(): - self.param2name[param] = name - for m_name, m_var in module.named_modules(): - for p_name, p_var in m_var.named_parameters(recurse=False): - param_name = m_name + '.' + p_name if m_name else p_name - self.name2param[param_name] = p_var - - def _post_forward(self): - """This function is only triggered for inference. - """ - access_list = list(self.chunk_manager.accessed_chunks) - # we need to scatter all accessed chunks and move them to their original places - for chunk in access_list: - if chunk.keep_gathered: - self.chunk_manager.fake_release_chunk(chunk) - else: - assert chunk.can_release - self.chunk_manager.release_chunk(chunk) - first_param = next(iter(chunk.tensors_info)) - self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) - assert self.chunk_manager.accessed_mem == 0 - # reset all recorded attributes - self.gemini_manager.reset_attributes() - - def forward(self, *args, **kwargs): - # check whether we are in a inference mode - grad_flag = torch.is_grad_enabled() - if not grad_flag: - assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( - ), "You should run a completed iteration as your warmup iter" - - args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) - self.module.zero_grad(set_to_none=True) - self.gemini_manager.pre_iter(*args) - with ColoParamOpHookManager.use_hooks(self.param_op_hook): - outputs = self.module(*args, **kwargs) - # scatter chunks in the inference mode - if not grad_flag: - self._post_forward() - - if self.force_outputs_fp32: - return _cast_float(outputs, torch.float) - return outputs - - def _setup_grads_ptr(self): - for p in self.module.parameters(): - if is_ddp_ignored(p): - continue - p.grad = None - - def _pre_backward(self): - # set a visit label for all parameters - # the label is used to check whether the parameter is correctly reduced - for param in self.param2name: - if not is_ddp_ignored(param): - setattr(param, "_gemini_reduced", False) - - def _post_backward(self): - if self.chunk_manager.accessed_mem != 0: - error_params = ["Reduction failed at followed parameters:"] - for param in self.param2name: - if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): - error_params.append(self.param2name[param]) - error_str = "\n\t".join(error_params) - raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", - "The most possible reason is that the model is not compatible with ZeroDDP.\n", - f"{error_str}") - self._setup_grads_ptr() - self._logger.debug( - f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' - ) - self.gemini_manager.post_iter() - - def backward(self, loss: torch.Tensor): - self._pre_backward() - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): - loss.backward() - self._post_backward() - - def backward_by_grad(self, tensor, grad): - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): - torch.autograd.backward(tensor, grad) - self._post_backward() - - def grad_handle(self, p, grad): - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) - with torch._C.DisableTorchFunction(): - chunk = self.chunk_manager.get_chunk(p) - if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: - raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " - "Some unsupported torch function is operated upon this parameter.") - self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) - chunk.copy_tensor_to_chunk_slice(p, grad) - reduced = self.chunk_manager.reduce_chunk(chunk) - if reduced: - if chunk.is_gathered: - chunk.cuda_global_chunk.div_(chunk.pg_size) - else: - chunk.cuda_shard.div_(chunk.pg_size) - # check overflow elements - self.overflow_counter += chunk.has_inf_or_nan - # record l2 norm for gradient clipping - if chunk.l2_norm_flag: - chunk.set_l2_norm() - self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) - return empty_grad - - def zero_grad(self, set_to_none: bool = False) -> None: - self.module.zero_grad(set_to_none=True) - - def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: - for tensor in chunk.get_tensors(): - self.grads_device[tensor] = device - - def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): - """Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are included. - Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. - - Warning: The non strict state dict would ignore the parameters if the tensors of the parameters - are shared with other parameters which have been included in the dictionary. - When you need to load the state dict, you should set the argument `strict` to False. - - Returns: - dict: - a dictionary containing a whole state of the module - """ - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) - - for hook in self._state_dict_hooks.values(): - hook_result = hook(self, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: - """ - get param content from chunks. - - Args: - param_list (_type_): a list of torch.nn.Parameters - only_rank_0 (_type_): _description_ - - Returns: - Dict: a dict whose key is param name and value is param with correct payload - """ - # save parameters - param_to_save_data = dict() - chunk_list = self.chunk_manager.get_chunks(param_list) - for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - record_tensor = torch.empty([0]) - record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) - if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() - - assert tensor not in param_to_save_data - param_to_save_data[tensor] = record_tensor - - del temp_chunk - return param_to_save_data - - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): - r"""Saves module state to `destination` dictionary, containing a state - of the module, but not its descendants. This is called on every - submodule in :meth:`~torch.nn.Module.state_dict`. - - In rare cases, subclasses can achieve class-specific behavior by - overriding this method with custom logic. - - Args: - destination (dict): a dict where state will be stored - prefix (str): the prefix for parameters and buffers used in this - module - """ - assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." - - # get copies of fp32 parameters in CPU - param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) - # get the mapping between copies and fp16 parameters - p_mapping = dict() - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - name = self.param2name[p] - assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - p_mapping[p] = record_parameter - for name, param in self.name2param.items(): - if param is not None: - if is_ddp_ignored(param): - # deal with ddp ignored parameters - destination[prefix + name] = param if keep_vars else param.detach() - else: - destination[prefix + name] = p_mapping[param] - del p_mapping - del param_to_save_data - - # save all buffers - for name, buf in self.named_buffers(): - if buf is not None and name not in self._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() - # save extra states - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: - destination[extra_state_key] = self.get_extra_state() - - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): - r"""Copies parameters and buffers from :attr:`state_dict` into - this module and its descendants. If :attr:`strict` is ``True``, then - the keys of :attr:`state_dict` must exactly match the keys returned - by this module's :meth:`~torch.nn.Module.state_dict` function. - - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - strict (bool, optional): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` - - Returns: - ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - * **missing_keys** is a list of str containing the missing keys - * **unexpected_keys** is a list of str containing the unexpected keys - - Note: - If a parameter or buffer is registered as ``None`` and its corresponding key - exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a - ``RuntimeError``. - """ - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] - - prefix = '' - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) - - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) - return _IncompatibleKeys(missing_keys, unexpected_keys) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - r"""Copies parameters and buffers from :attr:`state_dict` into only - this module, but not its descendants. This is called on every submodule - in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this - module in input :attr:`state_dict` is provided as :attr:`local_metadata`. - For state dicts without metadata, :attr:`local_metadata` is empty. - Subclasses can achieve class-specific backward compatible loading using - the version number at `local_metadata.get("version", None)`. - - .. note:: - :attr:`state_dict` is not the same object as the input - :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So - it can be modified. - - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - prefix (str): the prefix for parameters and buffers used in this - module - local_metadata (dict): a dict containing the metadata for this module. - See - strict (bool): whether to strictly enforce that the keys in - :attr:`state_dict` with :attr:`prefix` match the names of - parameters and buffers in this module - missing_keys (list of str): if ``strict=True``, add missing keys to - this list - unexpected_keys (list of str): if ``strict=True``, add unexpected - keys to this list - error_msgs (list of str): error messages should be added to this - list, and will be reported together in - :meth:`~torch.nn.Module.load_state_dict` - """ - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} - - def load(param_name, dest_tensor, copy_func): - state_key = prefix + param_name - if state_key in state_dict: - input_param = state_dict[state_key] - # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ - if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: - input_param = input_param[0] - if input_param.shape != dest_tensor.shape: - # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(state_key, input_param.shape, - dest_tensor.shape)) - return - try: - with torch.no_grad(): - copy_func(input_param) - except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), - input_param.size(), ex.args)) - elif strict: - missing_keys.append(state_key) - - def load_fp32_parameter(chunk_slice, data): - chunk_slice.copy_(data.flatten()) - - for name, param in self.named_parameters(): - if is_ddp_ignored(param): - # deal with ddp ignored parameters - load(name, param, param.copy_) - - fp32_to_name = dict() - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - if p is not None: - name = self.param2name[p] - fp32_to_name[fp32_p] = name - - chunk_list = self.chunk_manager.get_chunks(self.fp32_params) - for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - parameter_name = fp32_to_name[tensor] - parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] - load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) - - if chunk.is_gathered: - chunk.cuda_global_chunk.copy_(temp_chunk) - elif chunk.cuda_shard is not None: - chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) - else: - chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) - - del temp_chunk - - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.optim_update() - - for name, buf in persistent_buffers.items(): - if buf is not None: - load(name, buf, buf.copy_) - - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", - torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: - if extra_state_key in state_dict: - self.set_extra_state(state_dict[extra_state_key]) - elif strict: - missing_keys.append(extra_state_key) - elif strict and (extra_state_key in state_dict): - unexpected_keys.append(extra_state_key) - - if strict: - for key in state_dict.keys(): - if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - if input_name not in local_state: - unexpected_keys.append(key) - - def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): - ddp_pg = ColoProcessGroup() - for p in param_order.generate(): - assert isinstance(p, ColoParameter) - - # gather sharded parameters in the strict ddp mode - if strict_ddp_mode: - if not p.is_replicate(): - p.set_dist_spec(ReplicaSpec()) - p.set_process_group(pg=ddp_pg) - - # ignore the parameters with no gradient - if not p.requires_grad: - self.set_params_to_ignore([p]) - - # move ignored parameters to CUDA - if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=torch.float16) - continue - - # create a fp32 parameter - fp32_data = p.data.float() - fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) - # create a fp16 parameter - p.data = p.data.half() - - # register the fp16 parameter and fp32 parameter in the chunk manager - dp_world_size = p.process_group.dp_world_size() - self.chunk_manager.register_tensor(tensor=p, - group_type='fp16_param', - config_key=dp_world_size, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - self.chunk_manager.register_tensor(tensor=fp32_p, - group_type='fp32_param', - config_key=dp_world_size, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - - self.fp16_params.append(p) - self.fp32_params.append(fp32_p) - self.grads_device[p] = self.gemini_manager.default_device - - self.chunk_manager.close_all_groups() - - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - chunk_16 = self.chunk_manager.get_chunk(p) - chunk_32 = self.chunk_manager.get_chunk(fp32_p) - chunk_32.init_pair(chunk_16) - - # keep gathered chunks are in CUDA - if chunk_16.keep_gathered: - self.grads_device[p] = get_current_device() - - def _cast_buffers(self): - for buffer in self.module.buffers(): - buffer.data = buffer.cuda() - if torch.is_floating_point(buffer): - buffer.data = buffer.half() diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py deleted file mode 100644 index 2c6e15d91736..000000000000 --- a/colossalai/nn/parallel/gemini_parallel.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Optional - -import torch - -from colossalai.gemini.chunk import init_chunk_manager -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.gemini.memory_tracer import MemStats - -from .data_parallel import ZeroDDP - - -class GeminiDDP(ZeroDDP): - - def __init__(self, - module: torch.nn.Module, - device: torch.device, - placement_policy: str = "cpu", - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - search_range_mb: int = 32, - hidden_dim: Optional[int] = None, - min_chunk_size_mb: float = 32, - memstats: Optional[MemStats] = None) -> None: - """ - A torch.Module warpper using ZeRO-DP and Genimi. - ZeRO is for parallel. Gemini is for memory management. - WARNING: The class will modify the module inline! - - Example: - model is initialized under the context of ColoInitContext - >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") - >>> logits = model(x) - >>> loss = criterion(logits, labels) - >>> model.backward(loss) - - Args: - module (torch.nn.Module): the model to be wrapped. - device (torch.device): device to place the model. - placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". - pin_memory (bool, optional): use pin memory on CPU. Defaults to False. - force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. - search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. - hidden_dim (int, optional): the hidden dimension of DNN. - Users can provide this argument to speed up searching. - If users do not know this argument before training, it is ok. We will use a default value 1024. - min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. - If the aggregate size of parameters is still samller than the minimum chunk size, - all parameters will be compacted into one small chunk. - memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. - """ - # some ugly hotfix for the compatibility with Lightning - if search_range_mb is None: - search_range_mb = 32 - - chunk_manager = init_chunk_manager(model=module, - init_device=device, - hidden_dim=hidden_dim, - search_range_mb=search_range_mb, - min_chunk_size_mb=min_chunk_size_mb, - strict_ddp_flag=strict_ddp_mode) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 098ccbb45c5a..3465079e4fbb 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -1,41 +1,16 @@ -from typing import Tuple - -import torch -import torch.nn as nn - -from colossalai.logging import get_dist_logger -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2 - -from ..nn.optimizer.zero_optimizer import ZeroOptimizer - - -def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, - optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: - """ - A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading - - :param model: Your model object - :type model: :class:`torch.nn.Module` - :param optimizer_config: Your optimizer object - :type optimizer_config: :class:`dict` - - :return: (model, optimizer) - :rtype: Tuple - """ - - logger = get_dist_logger('convert_to_zero_v2') - - logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) - if optimizer_config is None: - optimizer_config = dict() - logger.info(f'model_config is {model_config}', ranks=[0]) - if model_config is None: - model_config = dict() - - zero_model = ShardedModelV2(model, **model_config) - zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) - return zero_model, zero_optimizer - - -__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] +from .gemini import ( + ColoInitContext, + GeminiAdamOptimizer, + GeminiDDP, + ZeroDDP, + ZeroOptimizer, + get_static_torch_model, + post_process_colo_init_ctx, +) +from .low_level import LowLevelZeroOptimizer +from .wrapper import zero_model_wrapper, zero_optim_wrapper + +__all__ = [ + 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', + 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' +] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py new file mode 100644 index 000000000000..60f85ca2f540 --- /dev/null +++ b/colossalai/zero/gemini/__init__.py @@ -0,0 +1,11 @@ +from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration +from .colo_init_context import ColoInitContext, post_process_colo_init_ctx +from .gemini_ddp import GeminiDDP, ZeroDDP +from .gemini_mgr import GeminiManager +from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer +from .utils import get_static_torch_model + +__all__ = [ + 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP', + 'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' +] diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py similarity index 100% rename from colossalai/gemini/chunk/__init__.py rename to colossalai/zero/gemini/chunk/__init__.py diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py similarity index 100% rename from colossalai/gemini/chunk/chunk.py rename to colossalai/zero/gemini/chunk/chunk.py diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py similarity index 99% rename from colossalai/gemini/chunk/manager.py rename to colossalai/zero/gemini/chunk/manager.py index 2fa65c970316..d85df0b00476 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -3,10 +3,11 @@ import torch -from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState from colossalai.tensor import ColoTensor from colossalai.utils import get_current_device +from .chunk import Chunk, ChunkFullError, TensorState + class ChunkManager: """ diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py similarity index 98% rename from colossalai/gemini/chunk/search_utils.py rename to colossalai/zero/gemini/chunk/search_utils.py index fe9650721d74..a69b782ead2e 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -5,9 +5,9 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator from colossalai.tensor import ColoParameter from colossalai.utils import is_ddp_ignored +from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py similarity index 91% rename from colossalai/gemini/chunk/utils.py rename to colossalai/zero/gemini/chunk/utils.py index 83512b8e0ee5..283f74203592 100644 --- a/colossalai/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -5,10 +5,11 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.gemini.chunk import ChunkManager -from colossalai.gemini.chunk.search_utils import search_chunk_configuration from colossalai.utils import is_ddp_ignored +from .manager import ChunkManager +from .search_utils import search_chunk_configuration + def safe_div(a, b): if a == 0: diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py similarity index 97% rename from colossalai/utils/model/colo_init_context.py rename to colossalai/zero/gemini/colo_init_context.py index 87ae413a2a8a..5937ee9eff9a 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/zero/gemini/colo_init_context.py @@ -3,10 +3,8 @@ import torch from torch import nn -from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup - -from .utils import InsertPostInitMethodToModuleSubClasses +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses # find named_params includes replica @@ -89,6 +87,7 @@ def __init__(self, self._default_dist_spec = default_dist_spec def _register_colo_modules(self): + from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module register_colo_module(torch.nn.Linear, ColoLinear()) register_colo_module(torch.nn.Embedding, ColoEmbedding()) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py new file mode 100644 index 000000000000..50f1b1ef1ccc --- /dev/null +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -0,0 +1,590 @@ +import itertools +from collections import OrderedDict +from functools import partial +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.logging import get_dist_logger +from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor import ReplicaSpec +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager +from .gemini_hook import GeminiZeROHook +from .gemini_mgr import GeminiManager +from .memory_tracer import MemStats, OrderedParamGenerator +from .utils import get_temp_total_chunk_on_cuda + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +__all__ = [ + 'ZeroDDP', + 'GeminiDDP', +] + + +class ZeroDDP(ColoDDP): + """ZeRO DDP for ColoTensor. + Warning: Nested ZeroDDP is not supported now. + It is designed to be used with ChunkManager and GeminiManager. + For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. + + Args: + module (torch.nn.Module): Module to apply ZeRO-DP. + gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. + For more details, see the API reference of ``GeminiManager``. + pin_memory (bool): Chunks on CPU Memory use pin-memory. + force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. + Defaults to False. + strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. + Defaults to False. Users can set it to True, when they clearly know that they only need DDP. + """ + + def __init__(self, + module: torch.nn.Module, + gemini_manager: GeminiManager, + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False) -> None: + super().__init__(module, process_group=ColoProcessGroup()) + self.gemini_manager = gemini_manager + self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + self.force_outputs_fp32 = force_outputs_fp32 + self.param_op_hook = GeminiZeROHook(gemini_manager) + self.fp32_params: List[ColoTensor] = list() + self.fp16_params: List[ColoParameter] = list() + self.overflow_counter = 0 + self.grads_device: Dict[torch.Tensor, torch.device] = dict() + self.param2name: Dict[nn.Parameter, str] = dict() + self.name2param: Dict[str, nn.Parameter] = dict() + + self._cast_buffers() + self._logger = get_dist_logger() + + if self.gemini_manager._premade_memstats_: + # build chunk in param runtime visited order. + param_order = self.gemini_manager.memstats()._param_runtime_order + else: + # build chunk in param initialized order. + # Note: in this way, it can not get filter unused params during runtime. + param_order = OrderedParamGenerator() + for p in module.parameters(): + param_order.append(p) + + self._init_chunks(param_order=param_order, + strict_ddp_mode=strict_ddp_mode, + cpu_offload=self.gemini_manager.policy_name != 'cuda', + pin_memory=pin_memory) + + for name, param in module.named_parameters(): + self.param2name[param] = name + for m_name, m_var in module.named_modules(): + for p_name, p_var in m_var.named_parameters(recurse=False): + param_name = m_name + '.' + p_name if m_name else p_name + self.name2param[param_name] = p_var + + def _post_forward(self): + """This function is only triggered for inference. + """ + access_list = list(self.chunk_manager.accessed_chunks) + # we need to scatter all accessed chunks and move them to their original places + for chunk in access_list: + if chunk.keep_gathered: + self.chunk_manager.fake_release_chunk(chunk) + else: + assert chunk.can_release + self.chunk_manager.release_chunk(chunk) + first_param = next(iter(chunk.tensors_info)) + self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) + assert self.chunk_manager.accessed_mem == 0 + # reset all recorded attributes + self.gemini_manager.reset_attributes() + + def forward(self, *args, **kwargs): + # check whether we are in a inference mode + grad_flag = torch.is_grad_enabled() + if not grad_flag: + assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( + ), "You should run a completed iteration as your warmup iter" + + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + self.module.zero_grad(set_to_none=True) + self.gemini_manager.pre_iter(*args) + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + # scatter chunks in the inference mode + if not grad_flag: + self._post_forward() + + if self.force_outputs_fp32: + return _cast_float(outputs, torch.float) + return outputs + + def _setup_grads_ptr(self): + for p in self.module.parameters(): + if is_ddp_ignored(p): + continue + p.grad = None + + def _pre_backward(self): + # set a visit label for all parameters + # the label is used to check whether the parameter is correctly reduced + for param in self.param2name: + if not is_ddp_ignored(param): + setattr(param, "_gemini_reduced", False) + + def _post_backward(self): + if self.chunk_manager.accessed_mem != 0: + error_params = ["Reduction failed at followed parameters:"] + for param in self.param2name: + if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): + error_params.append(self.param2name[param]) + error_str = "\n\t".join(error_params) + raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with ZeroDDP.\n", + f"{error_str}") + self._setup_grads_ptr() + self._logger.debug( + f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + ) + self.gemini_manager.post_iter() + + def backward(self, loss: torch.Tensor): + self._pre_backward() + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def backward_by_grad(self, tensor, grad): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + torch.autograd.backward(tensor, grad) + self._post_backward() + + def grad_handle(self, p, grad): + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + with torch._C.DisableTorchFunction(): + chunk = self.chunk_manager.get_chunk(p) + if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: + raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + "Some unsupported torch function is operated upon this parameter.") + self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) + chunk.copy_tensor_to_chunk_slice(p, grad) + reduced = self.chunk_manager.reduce_chunk(chunk) + if reduced: + if chunk.is_gathered: + chunk.cuda_global_chunk.div_(chunk.pg_size) + else: + chunk.cuda_shard.div_(chunk.pg_size) + # check overflow elements + self.overflow_counter += chunk.has_inf_or_nan + # record l2 norm for gradient clipping + if chunk.l2_norm_flag: + chunk.set_l2_norm() + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + return empty_grad + + def zero_grad(self, set_to_none: bool = False) -> None: + self.module.zero_grad(set_to_none=True) + + def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: + for tensor in chunk.get_tensors(): + self.grads_device[tensor] = device + + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Warning: The non strict state dict would ignore the parameters if the tensors of the parameters + are shared with other parameters which have been included in the dictionary. + When you need to load the state dict, you should set the argument `strict` to False. + + Returns: + dict: + a dictionary containing a whole state of the module + """ + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) + + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: + """ + get param content from chunks. + + Args: + param_list (_type_): a list of torch.nn.Parameters + only_rank_0 (_type_): _description_ + + Returns: + Dict: a dict whose key is param name and value is param with correct payload + """ + # save parameters + param_to_save_data = dict() + chunk_list = self.chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + if record_flag: + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + + assert tensor not in param_to_save_data + param_to_save_data[tensor] = record_tensor + + del temp_chunk + return param_to_save_data + + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + + # get copies of fp32 parameters in CPU + param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) + # get the mapping between copies and fp16 parameters + p_mapping = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter + for name, param in self.name2param.items(): + if param is not None: + if is_ddp_ignored(param): + # deal with ddp ignored parameters + destination[prefix + name] = param if keep_vars else param.detach() + else: + destination[prefix + name] = p_mapping[param] + del p_mapping + del param_to_save_data + + # save all buffers + for name, buf in self.named_buffers(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + # save extra states + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + prefix = '' + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + def load(param_name, dest_tensor, copy_func): + state_key = prefix + param_name + if state_key in state_dict: + input_param = state_dict[state_key] + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + if input_param.shape != dest_tensor.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(state_key, input_param.shape, + dest_tensor.shape)) + return + try: + with torch.no_grad(): + copy_func(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), + input_param.size(), ex.args)) + elif strict: + missing_keys.append(state_key) + + def load_fp32_parameter(chunk_slice, data): + chunk_slice.copy_(data.flatten()) + + for name, param in self.named_parameters(): + if is_ddp_ignored(param): + # deal with ddp ignored parameters + load(name, param, param.copy_) + + fp32_to_name = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + if p is not None: + name = self.param2name[p] + fp32_to_name[fp32_p] = name + + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + parameter_name = fp32_to_name[tensor] + parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + + if chunk.is_gathered: + chunk.cuda_global_chunk.copy_(temp_chunk) + elif chunk.cuda_shard is not None: + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + else: + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + + del temp_chunk + + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.optim_update() + + for name, buf in persistent_buffers.items(): + if buf is not None: + load(name, buf, buf.copy_) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", + torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + if input_name not in local_state: + unexpected_keys.append(key) + + def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): + ddp_pg = ColoProcessGroup() + for p in param_order.generate(): + assert isinstance(p, ColoParameter) + + # gather sharded parameters in the strict ddp mode + if strict_ddp_mode: + if not p.is_replicate(): + p.set_dist_spec(ReplicaSpec()) + p.set_process_group(pg=ddp_pg) + + # ignore the parameters with no gradient + if not p.requires_grad: + self.set_params_to_ignore([p]) + + # move ignored parameters to CUDA + if is_ddp_ignored(p): + p.data = p.data.to(device=get_current_device(), dtype=torch.float16) + continue + + # create a fp32 parameter + fp32_data = p.data.float() + fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + # create a fp16 parameter + p.data = p.data.half() + + # register the fp16 parameter and fp32 parameter in the chunk manager + dp_world_size = p.process_group.dp_world_size() + self.chunk_manager.register_tensor(tensor=p, + group_type='fp16_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.chunk_manager.register_tensor(tensor=fp32_p, + group_type='fp32_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + + self.fp16_params.append(p) + self.fp32_params.append(fp32_p) + self.grads_device[p] = self.gemini_manager.default_device + + self.chunk_manager.close_all_groups() + + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + chunk_32 = self.chunk_manager.get_chunk(fp32_p) + chunk_32.init_pair(chunk_16) + + # keep gathered chunks are in CUDA + if chunk_16.keep_gathered: + self.grads_device[p] = get_current_device() + + def _cast_buffers(self): + for buffer in self.module.buffers(): + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.half() + + +class GeminiDDP(ZeroDDP): + + def __init__(self, + module: torch.nn.Module, + device: torch.device, + placement_policy: str = "cpu", + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + search_range_mb: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_mb: float = 32, + memstats: Optional[MemStats] = None) -> None: + """ + A torch.Module warpper using ZeRO-DP and Genimi. + ZeRO is for parallel. Gemini is for memory management. + WARNING: The class will modify the module inline! + + Example: + model is initialized under the context of ColoInitContext + >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): the model to be wrapped. + device (torch.device): device to place the model. + placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. + search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok. We will use a default value 1024. + min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. + If the aggregate size of parameters is still samller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. + """ + # some ugly hotfix for the compatibility with Lightning + if search_range_mb is None: + search_range_mb = 32 + + chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb, + strict_ddp_flag=strict_ddp_mode) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py similarity index 95% rename from colossalai/zero/utils/gemini_hook.py rename to colossalai/zero/gemini/gemini_hook.py index bddc307a0504..dbc2924858e6 100644 --- a/colossalai/zero/utils/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -5,10 +5,10 @@ import torch -from colossalai.gemini import TensorState -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored +from colossalai.zero.gemini import TensorState +from colossalai.zero.gemini.gemini_mgr import GeminiManager class TrainingPhase(Enum): diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py similarity index 97% rename from colossalai/gemini/gemini_mgr.py rename to colossalai/zero/gemini/gemini_mgr.py index 72a5e4a7f19b..c38e6eff840d 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -4,10 +4,8 @@ import torch -from colossalai.gemini.chunk import Chunk, ChunkManager -from colossalai.gemini.memory_tracer import MemStats - -from .memory_tracer import ChunkMemStatsCollector +from .chunk import Chunk, ChunkManager +from .memory_tracer import ChunkMemStatsCollector, MemStats from .placement_policy import PlacementPolicyFactory diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py similarity index 97% rename from colossalai/nn/optimizer/zero_optimizer.py rename to colossalai/zero/gemini/gemini_optimizer.py index 422ebb7a3944..8e0237ddc7bc 100644 --- a/colossalai/nn/optimizer/zero_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,12 +10,15 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam -from colossalai.nn.parallel.data_parallel import ZeroDDP from colossalai.utils import disposable, get_current_device, is_ddp_ignored +from .chunk import Chunk, ChunkManager +from .gemini_ddp import ZeroDDP + +__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] + _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} @@ -316,3 +319,10 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): fake_params_list.append(fake_param) group['params'] = fake_params_list + + +class GeminiAdamOptimizer(ZeroOptimizer): + + def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: + optimizer = HybridAdam(model.parameters(), **defaults) + super().__init__(optimizer, model, **defaults) diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py similarity index 100% rename from colossalai/gemini/memory_tracer/__init__.py rename to colossalai/zero/gemini/memory_tracer/__init__.py diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py similarity index 91% rename from colossalai/gemini/memory_tracer/chunk_memstats_collector.py rename to colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index 1a5b6bf525be..f5eb05b4f22a 100644 --- a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,10 +1,10 @@ from typing import Optional -from colossalai.gemini.chunk import ChunkManager -from colossalai.gemini.memory_tracer import MemStats from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.chunk import ChunkManager +from .memory_stats import MemStats from .memstats_collector import MemStatsCollector diff --git a/colossalai/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py similarity index 100% rename from colossalai/gemini/memory_tracer/memory_monitor.py rename to colossalai/zero/gemini/memory_tracer/memory_monitor.py diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py similarity index 98% rename from colossalai/gemini/memory_tracer/memory_stats.py rename to colossalai/zero/gemini/memory_tracer/memory_stats.py index 84fa00fb9361..9a45034ee27e 100644 --- a/colossalai/gemini/memory_tracer/memory_stats.py +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -2,7 +2,7 @@ import torch -from colossalai.gemini.memory_tracer import OrderedParamGenerator +from .param_runtime_order import OrderedParamGenerator class MemStats(object): diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py similarity index 92% rename from colossalai/gemini/memory_tracer/memstats_collector.py rename to colossalai/zero/gemini/memory_tracer/memstats_collector.py index d939da6eb4cf..0694be48550a 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -1,12 +1,7 @@ import time -from typing import List, Optional - -import torch - -from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.utils.memory import colo_device_memory_used +from typing import Optional +from .memory_monitor import SyncCudaMemoryMonitor from .memory_stats import MemStats @@ -49,7 +44,7 @@ def next_period_non_model_data_usage(self, device_type: str) -> int: assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ - f"step total {self._step_total}" + f"step total {self._step_total}" next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] self._step_idx = (self._step_idx + 1) % self._step_total return next_non_model_data @@ -75,6 +70,8 @@ def record_model_data_volume(self) -> None: Sampling model data statistics. """ if self._start_flag and not self.use_outside_memstats: + from colossalai.zero.legacy.gemini import StatefulTensor + # The following code work for ZeroInitContext, which is deprecated in v0.1.12 cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] self._memstats.record_max_cuda_model_data(cuda_mem) diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py similarity index 100% rename from colossalai/gemini/memory_tracer/param_runtime_order.py rename to colossalai/zero/gemini/memory_tracer/param_runtime_order.py diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py similarity index 95% rename from colossalai/gemini/memory_tracer/runtime_mem_tracer.py rename to colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py index a643751da7e2..0c9eac8b63e3 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,9 +1,14 @@ import torch.nn -from colossalai.gemini.memory_tracer import MemStats -from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import ( + GradMemStats, + GradMemTracerHook, + ParamMemTracerHook, +) + +from .memory_stats import MemStats __all__ = ['RuntimeMemTracer'] diff --git a/colossalai/gemini/memory_tracer/static_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py similarity index 98% rename from colossalai/gemini/memory_tracer/static_memstats_collector.py rename to colossalai/zero/gemini/memory_tracer/static_memstats_collector.py index 3209881e100c..b8f9a095f422 100644 --- a/colossalai/gemini/memory_tracer/static_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py @@ -6,7 +6,7 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta -from colossalai.gemini.chunk import ChunkManager +from colossalai.zero.gemini.chunk import ChunkManager if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor diff --git a/colossalai/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py similarity index 100% rename from colossalai/gemini/memory_tracer/utils.py rename to colossalai/zero/gemini/memory_tracer/utils.py diff --git a/colossalai/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py similarity index 98% rename from colossalai/gemini/placement_policy.py rename to colossalai/zero/gemini/placement_policy.py index fed1cc2985ff..84a868872f88 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -5,11 +5,12 @@ import torch -from colossalai.gemini.chunk import Chunk, ChunkManager -from colossalai.gemini.memory_tracer import ChunkMemStatsCollector from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from .chunk import Chunk, ChunkManager +from .memory_tracer import ChunkMemStatsCollector + class PlacementPolicy(ABC): need_mem_stats: bool = False diff --git a/colossalai/nn/parallel/utils.py b/colossalai/zero/gemini/utils.py similarity index 97% rename from colossalai/nn/parallel/utils.py rename to colossalai/zero/gemini/utils.py index 08fdb6026e38..e52b5b836b0b 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -6,9 +6,10 @@ import torch.distributed as dist import torch.nn as nn -from colossalai.gemini.chunk import Chunk from colossalai.utils import get_current_device +from .chunk import Chunk + def get_temp_total_chunk_on_cuda(chunk: Chunk): if chunk.is_gathered: @@ -77,7 +78,7 @@ def get_static_torch_model(zero_ddp_model, Returns: torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ - from colossalai.nn.parallel import ZeroDDP + from colossalai.zero.gemini.gemini_ddp import ZeroDDP assert isinstance(zero_ddp_model, ZeroDDP) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/zero/legacy/__init__.py new file mode 100644 index 000000000000..35570a1f539a --- /dev/null +++ b/colossalai/zero/legacy/__init__.py @@ -0,0 +1,44 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from colossalai.logging import get_dist_logger + +from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator +from .sharded_model import ShardedModelV2 +from .sharded_optim import ShardedOptimizerV2 + + +def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, + optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: + """ + A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading + + :param model: Your model object + :type model: :class:`torch.nn.Module` + :param optimizer_config: Your optimizer object + :type optimizer_config: :class:`dict` + + :return: (model, optimizer) + :rtype: Tuple + """ + + logger = get_dist_logger('convert_to_zero_v2') + + logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) + if optimizer_config is None: + optimizer_config = dict() + logger.info(f'model_config is {model_config}', ranks=[0]) + if model_config is None: + model_config = dict() + + zero_model = ShardedModelV2(model, **model_config) + zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) + return zero_model, zero_optimizer + + +__all__ = [ + 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', + 'no_shard_zero_decrator' +] diff --git a/colossalai/zero/legacy/gemini/__init__.py b/colossalai/zero/legacy/gemini/__init__.py new file mode 100644 index 000000000000..754ae9bc0044 --- /dev/null +++ b/colossalai/zero/legacy/gemini/__init__.py @@ -0,0 +1,9 @@ +from .ophooks import BaseOpHook, register_ophooks_recursively +from .stateful_tensor import StatefulTensor +from .stateful_tensor_mgr import StatefulTensorMgr +from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy + +__all__ = [ + 'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy', + 'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook' +] diff --git a/colossalai/gemini/gemini_context.py b/colossalai/zero/legacy/gemini/gemini_context.py similarity index 100% rename from colossalai/gemini/gemini_context.py rename to colossalai/zero/legacy/gemini/gemini_context.py diff --git a/colossalai/gemini/ophooks/__init__.py b/colossalai/zero/legacy/gemini/ophooks/__init__.py similarity index 100% rename from colossalai/gemini/ophooks/__init__.py rename to colossalai/zero/legacy/gemini/ophooks/__init__.py diff --git a/colossalai/gemini/ophooks/_shard_grad_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py similarity index 100% rename from colossalai/gemini/ophooks/_shard_grad_ophook.py rename to colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py diff --git a/colossalai/gemini/ophooks/_shard_param_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py similarity index 99% rename from colossalai/gemini/ophooks/_shard_param_ophook.py rename to colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py index 57f76970cc86..80736d14085e 100644 --- a/colossalai/gemini/ophooks/_shard_param_ophook.py +++ b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py @@ -1,4 +1,5 @@ import torch + from colossalai.registry import OPHOOKS from . import BaseOpHook diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py similarity index 96% rename from colossalai/gemini/ophooks/runtime_mem_tracer_hook.py rename to colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py index 6d0df4e615ca..f40d6ced1ee0 100644 --- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py @@ -5,9 +5,9 @@ import torch -from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor -from colossalai.gemini.tensor_utils import alloc_storage, free_storage from colossalai.tensor.param_op_hook import ColoParamOpHook +from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor +from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage class TrainingPhase(Enum): diff --git a/colossalai/gemini/ophooks/utils.py b/colossalai/zero/legacy/gemini/ophooks/utils.py similarity index 100% rename from colossalai/gemini/ophooks/utils.py rename to colossalai/zero/legacy/gemini/ophooks/utils.py diff --git a/colossalai/gemini/paramhooks/__init__.py b/colossalai/zero/legacy/gemini/paramhooks/__init__.py similarity index 100% rename from colossalai/gemini/paramhooks/__init__.py rename to colossalai/zero/legacy/gemini/paramhooks/__init__.py diff --git a/colossalai/gemini/paramhooks/_param_hookmgr.py b/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py similarity index 100% rename from colossalai/gemini/paramhooks/_param_hookmgr.py rename to colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py diff --git a/colossalai/gemini/stateful_tensor.py b/colossalai/zero/legacy/gemini/stateful_tensor.py similarity index 97% rename from colossalai/gemini/stateful_tensor.py rename to colossalai/zero/legacy/gemini/stateful_tensor.py index 18fc8fd14d3c..1619ae40798d 100644 --- a/colossalai/gemini/stateful_tensor.py +++ b/colossalai/zero/legacy/gemini/stateful_tensor.py @@ -1,9 +1,9 @@ from enum import Enum -from typing import Optional +from typing import Optional, Union + import torch -from typing import Union -from colossalai.gemini.gemini_context import GeminiMemoryManager +from .gemini_context import GeminiMemoryManager def sizeof_tensor(tensor: torch.Tensor): @@ -19,7 +19,7 @@ class TensorState(Enum): class StatefulTensor(object): - """A Structure stores a Torch Tensor and labeled states. + """A Structure stores a Torch Tensor and labeled states. Inspired from the paper: PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management diff --git a/colossalai/gemini/stateful_tensor_mgr.py b/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py similarity index 94% rename from colossalai/gemini/stateful_tensor_mgr.py rename to colossalai/zero/legacy/gemini/stateful_tensor_mgr.py index c300f9bffc89..3b37444b0fe0 100644 --- a/colossalai/gemini/stateful_tensor_mgr.py +++ b/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py @@ -1,13 +1,16 @@ import functools -import torch import types -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy +from time import time from typing import List + +import torch + from colossalai.logging import get_dist_logger -from time import time +from colossalai.utils.cuda import get_current_device + +from .stateful_tensor import StatefulTensor, TensorState +from .tensor_placement_policy import TensorPlacementPolicy +from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage class StatefulTensorMgr(object): diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/zero/legacy/gemini/tensor_placement_policy.py similarity index 96% rename from colossalai/gemini/tensor_placement_policy.py rename to colossalai/zero/legacy/gemini/tensor_placement_policy.py index 0e575254c0b6..165ae51fee60 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/zero/legacy/gemini/tensor_placement_policy.py @@ -5,11 +5,12 @@ import torch -from colossalai.gemini.memory_tracer import MemStatsCollector -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.memory_tracer import MemStatsCollector + +from .stateful_tensor import StatefulTensor +from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage class TensorPlacementPolicy(ABC): diff --git a/colossalai/gemini/tensor_utils.py b/colossalai/zero/legacy/gemini/tensor_utils.py similarity index 97% rename from colossalai/gemini/tensor_utils.py rename to colossalai/zero/legacy/gemini/tensor_utils.py index bcc159f9954a..b7f23e0253fd 100644 --- a/colossalai/gemini/tensor_utils.py +++ b/colossalai/zero/legacy/gemini/tensor_utils.py @@ -1,6 +1,8 @@ +from typing import Tuple, Union + import torch -from colossalai.gemini.stateful_tensor import StatefulTensor -from typing import Union, Tuple + +from .stateful_tensor import StatefulTensor def is_storage_empty(tensor: torch.Tensor) -> bool: diff --git a/colossalai/zero/init_ctx/__init__.py b/colossalai/zero/legacy/init_ctx/__init__.py similarity index 100% rename from colossalai/zero/init_ctx/__init__.py rename to colossalai/zero/legacy/init_ctx/__init__.py diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py similarity index 97% rename from colossalai/zero/init_ctx/init_context.py rename to colossalai/zero/legacy/init_ctx/init_context.py index b40b69962cf7..f8be0ca4f3fc 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -13,10 +13,10 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.sharded_param import ShardedParamV2 +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.zero.legacy.sharded_param import ShardedParamV2 @dataclass diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/legacy/shard_utils/__init__.py similarity index 100% rename from colossalai/zero/shard_utils/__init__.py rename to colossalai/zero/legacy/shard_utils/__init__.py diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/legacy/shard_utils/base_shard_strategy.py similarity index 87% rename from colossalai/zero/shard_utils/base_shard_strategy.py rename to colossalai/zero/legacy/shard_utils/base_shard_strategy.py index 7c2f4c9f6659..7ca951091640 100644 --- a/colossalai/zero/shard_utils/base_shard_strategy.py +++ b/colossalai/zero/legacy/shard_utils/base_shard_strategy.py @@ -2,7 +2,8 @@ from typing import List, Optional import torch.distributed as dist -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor + +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor class BaseShardStrategy(ABC): diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py similarity index 89% rename from colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py rename to colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py index a7bd7cf538e7..11297bf6d62c 100644 --- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py @@ -2,17 +2,18 @@ import torch import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from torch._utils import _flatten_dense_tensors as flatten +from colossalai.utils import get_current_device +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor + from .tensor_shard_strategy import TensorShardStrategy class BucketTensorShardStrategy(TensorShardStrategy): - """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, - which will fully utilize network bandwidth. - It is especially useful when sub-module contains bias, + """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, + which will fully utilize network bandwidth. + It is especially useful when sub-module contains bias, since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small). """ diff --git a/colossalai/zero/shard_utils/commons.py b/colossalai/zero/legacy/shard_utils/commons.py similarity index 95% rename from colossalai/zero/shard_utils/commons.py rename to colossalai/zero/legacy/shard_utils/commons.py index 71cef44c177f..bf5ae325caf4 100644 --- a/colossalai/zero/shard_utils/commons.py +++ b/colossalai/zero/legacy/shard_utils/commons.py @@ -1,7 +1,7 @@ -import torch -import torch.nn.functional as F from typing import Tuple +import torch + def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: """Return the local shard of a full tensor.""" diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py similarity index 86% rename from colossalai/zero/shard_utils/tensor_shard_strategy.py rename to colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py index 5bdd95400d82..d1df4803b820 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py @@ -2,11 +2,12 @@ import torch import torch.distributed as dist + from colossalai.utils import get_current_device -from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.shard_utils.commons import get_shard -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.shard_utils.commons import get_shard +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor class TensorShardStrategy(BaseShardStrategy): @@ -27,7 +28,7 @@ def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGr Args: t (ShardedTensor): a tensor to be sharded. - process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. + process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. Defaults to None. """ if t.is_sharded: diff --git a/colossalai/zero/sharded_model/__init__.py b/colossalai/zero/legacy/sharded_model/__init__.py similarity index 61% rename from colossalai/zero/sharded_model/__init__.py rename to colossalai/zero/legacy/sharded_model/__init__.py index 725179295c60..93120bdc34b4 100644 --- a/colossalai/zero/sharded_model/__init__.py +++ b/colossalai/zero/legacy/sharded_model/__init__.py @@ -1,3 +1,3 @@ from .sharded_model_v2 import ShardedModelV2 -__all__ = ['ShardedModelV2'] \ No newline at end of file +__all__ = ['ShardedModelV2'] diff --git a/colossalai/zero/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py similarity index 95% rename from colossalai/zero/sharded_model/_utils.py rename to colossalai/zero/legacy/sharded_model/_utils.py index 85a3ab73dd1b..2bd01531a78f 100644 --- a/colossalai/zero/sharded_model/_utils.py +++ b/colossalai/zero/legacy/sharded_model/_utils.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Union import torch import torch.nn.functional as F -from typing import Union -from colossalai.gemini.stateful_tensor import StatefulTensor + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor def get_gradient_predivide_factor(world_size: int) -> float: diff --git a/colossalai/zero/sharded_model/reduce_scatter.py b/colossalai/zero/legacy/sharded_model/reduce_scatter.py similarity index 100% rename from colossalai/zero/sharded_model/reduce_scatter.py rename to colossalai/zero/legacy/sharded_model/reduce_scatter.py diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py similarity index 97% rename from colossalai/zero/sharded_model/sharded_model_v2.py rename to colossalai/zero/legacy/sharded_model/sharded_model_v2.py index 12e8f65d4a35..edd2cc8e68fe 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -13,19 +13,18 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector -from colossalai.gemini.ophooks import register_ophooks_recursively -from colossalai.gemini.paramhooks import BaseParamHookMgr -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory -from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu from colossalai.logging import get_dist_logger from colossalai.utils import disposable, get_current_device from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer -from colossalai.zero.utils import ZeroHook +from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector +from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively +from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr +from colossalai.zero.legacy.gemini.stateful_tensor import TensorState +from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer from ._utils import ( cast_float_arguments, @@ -35,6 +34,7 @@ free_storage, get_gradient_predivide_factor, ) +from .zero_hook import ZeroHook try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/legacy/sharded_model/utils.py similarity index 91% rename from colossalai/zero/sharded_model/utils.py rename to colossalai/zero/legacy/sharded_model/utils.py index 69f5a23ac920..08806e78ea3b 100644 --- a/colossalai/zero/sharded_model/utils.py +++ b/colossalai/zero/legacy/sharded_model/utils.py @@ -1,7 +1,8 @@ +import copy + import torch -from colossalai.zero.sharded_model import ShardedModelV2 -import copy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/legacy/sharded_model/zero_hook.py similarity index 92% rename from colossalai/zero/utils/zero_hook.py rename to colossalai/zero/legacy/sharded_model/zero_hook.py index 87bf2c0f5086..50f4bdfc775d 100644 --- a/colossalai/zero/utils/zero_hook.py +++ b/colossalai/zero/legacy/sharded_model/zero_hook.py @@ -3,14 +3,14 @@ import torch import torch.distributed as dist -from colossalai.gemini.memory_tracer import MemStatsCollector -from colossalai.gemini.ophooks import BaseOpHook -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.logging import get_dist_logger from colossalai.registry import OPHOOKS from colossalai.utils import get_current_device -from colossalai.zero.shard_utils import BaseShardStrategy +from colossalai.zero.gemini.memory_tracer import MemStatsCollector +from colossalai.zero.legacy.gemini.ophooks import BaseOpHook +from colossalai.zero.legacy.gemini.stateful_tensor import TensorState +from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.legacy.shard_utils import BaseShardStrategy @OPHOOKS.register_module diff --git a/colossalai/zero/legacy/sharded_optim/__init__.py b/colossalai/zero/legacy/sharded_optim/__init__.py new file mode 100644 index 000000000000..b71a70aeffa4 --- /dev/null +++ b/colossalai/zero/legacy/sharded_optim/__init__.py @@ -0,0 +1,3 @@ +from .sharded_optim_v2 import ShardedOptimizerV2 + +__all__ = ['ShardedOptimizerV2'] diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py similarity index 97% rename from colossalai/zero/sharded_optim/sharded_optim_v2.py rename to colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py index 43a0b7d76107..7ce1c056f583 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py @@ -14,13 +14,13 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32 class OptimState(Enum): diff --git a/colossalai/zero/legacy/sharded_param/__init__.py b/colossalai/zero/legacy/sharded_param/__init__.py new file mode 100644 index 000000000000..47e2ce2fa0e0 --- /dev/null +++ b/colossalai/zero/legacy/sharded_param/__init__.py @@ -0,0 +1,4 @@ +from .sharded_param import ShardedParamV2 +from .sharded_tensor import ShardedTensor + +__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/legacy/sharded_param/sharded_param.py similarity index 93% rename from colossalai/zero/sharded_param/sharded_param.py rename to colossalai/zero/legacy/sharded_param/sharded_param.py index db0f2d149431..4bcc4b62104a 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/legacy/sharded_param/sharded_param.py @@ -1,9 +1,11 @@ +from typing import List, Optional, Tuple + import torch -from typing import Optional, Tuple -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.gemini.tensor_utils import colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from typing import List + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage + +from .sharded_tensor import ShardedTensor EMPTY_TENSOR_DICT = {} diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/legacy/sharded_param/sharded_tensor.py similarity index 92% rename from colossalai/zero/sharded_param/sharded_tensor.py rename to colossalai/zero/legacy/sharded_param/sharded_tensor.py index 77f4aec30f32..af60312600f2 100644 --- a/colossalai/zero/sharded_param/sharded_tensor.py +++ b/colossalai/zero/legacy/sharded_param/sharded_tensor.py @@ -1,5 +1,6 @@ import torch -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState class ShardedTensor(StatefulTensor): diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py new file mode 100644 index 000000000000..ae3c1de3a5bc --- /dev/null +++ b/colossalai/zero/low_level/__init__.py @@ -0,0 +1,3 @@ +from .low_level_optim import LowLevelZeroOptimizer + +__all__ = ['LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/low_level/_utils.py similarity index 100% rename from colossalai/zero/sharded_optim/_utils.py rename to colossalai/zero/low_level/_utils.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/__init__.py rename to colossalai/zero/low_level/bookkeeping/__init__.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/base_store.py rename to colossalai/zero/low_level/bookkeeping/base_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/bucket_store.py rename to colossalai/zero/low_level/bookkeeping/bucket_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/gradient_store.py rename to colossalai/zero/low_level/bookkeeping/gradient_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/parameter_store.py rename to colossalai/zero/low_level/bookkeeping/parameter_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py rename to colossalai/zero/low_level/bookkeeping/tensor_bucket.py diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py similarity index 100% rename from colossalai/zero/sharded_optim/low_level_optim.py rename to colossalai/zero/low_level/low_level_optim.py diff --git a/colossalai/zero/sharded_optim/__init__.py b/colossalai/zero/sharded_optim/__init__.py deleted file mode 100644 index 30c26fb75f30..000000000000 --- a/colossalai/zero/sharded_optim/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .low_level_optim import LowLevelZeroOptimizer -from .sharded_optim_v2 import ShardedOptimizerV2 - -__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py deleted file mode 100644 index 5642a504acf7..000000000000 --- a/colossalai/zero/sharded_param/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 - -__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py deleted file mode 100644 index c4e687228957..000000000000 --- a/colossalai/zero/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .zero_hook import ZeroHook - -__all__ = ['ZeroHook'] \ No newline at end of file diff --git a/colossalai/nn/parallel/zero_wrapper.py b/colossalai/zero/wrapper.py similarity index 95% rename from colossalai/nn/parallel/zero_wrapper.py rename to colossalai/zero/wrapper.py index be8d1da7c24e..4553249e271d 100644 --- a/colossalai/nn/parallel/zero_wrapper.py +++ b/colossalai/zero/wrapper.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from .gemini_parallel import GeminiDDP +from .gemini import GeminiDDP def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None): @@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module, config_dict['max_scale'] = max_scale if zero_stage in [1, 2]: - from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer + from colossalai.zero.low_level import LowLevelZeroOptimizer config_dict['partition_grad'] = zero_stage == 2 config_dict['clip_grad_norm'] = max_norm return LowLevelZeroOptimizer(optimizer, **config_dict) else: - from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer + from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer config_dict['clipping_norm'] = max_norm return ZeroOptimizer(optimizer, model, **config_dict) diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md index 2933c3db6c58..38d2c4af904c 100644 --- a/docs/source/en/features/nvme_offload.md +++ b/docs/source/en/features/nvme_offload.md @@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from colossalai.utils.model.colo_init_context import ColoInitContext ``` diff --git a/docs/source/zh-Hans/features/nvme_offload.md b/docs/source/zh-Hans/features/nvme_offload.md index f33474efaa78..fd75ed1f5b3e 100644 --- a/docs/source/zh-Hans/features/nvme_offload.md +++ b/docs/source/zh-Hans/features/nvme_offload.md @@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from colossalai.utils.model.colo_init_context import ColoInitContext ``` diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py index c4adb48230be..33219b2caa29 100644 --- a/examples/images/dreambooth/debug.py +++ b/examples/images/dreambooth/debug.py @@ -5,7 +5,7 @@ from diffusers import AutoencoderKL import colossalai -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, post_process_colo_init_ctx path = "/data/scratch/diffuser/stable-diffusion-v1-4" diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 5c4c86bc7073..e6159e1058b9 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -21,10 +21,9 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer +from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 3d789ae2ce0f..1b2fc778d5ed 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -23,10 +23,9 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer +from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py index 90f2475b885e..6a587e1df96a 100644 --- a/examples/images/vit/test_vit.py +++ b/examples/images/vit/test_vit.py @@ -18,7 +18,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext def set_seed(seed): diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py index 0b4489244368..b42cf2bedc6b 100644 --- a/examples/images/vit/train.py +++ b/examples/images/vit/train.py @@ -19,7 +19,7 @@ from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext def init_1d_row_for_linear_weight_spec(model, world_size: int): diff --git a/examples/language/bert/train_bert_demo.py b/examples/language/bert/train_bert_demo.py index b690ff787d01..9a0278b2c711 100644 --- a/examples/language/bert/train_bert_demo.py +++ b/examples/language/bert/train_bert_demo.py @@ -12,10 +12,9 @@ import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper CAI_VERSION = colossalai.__version__ diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index f46226bce2b5..b2a7fa36d021 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -13,10 +13,9 @@ import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper CAI_VERSION = colossalai.__version__ diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py index 4993ce25db17..4874f831c2ec 100755 --- a/examples/language/opt/train_gemini_opt.py +++ b/examples/language/opt/train_gemini_opt.py @@ -34,12 +34,9 @@ import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import GeminiDDP -from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext - from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP def get_data(batch_size, seq_len, vocab_size): @@ -179,13 +176,15 @@ def main(): # build model if args.model_name_or_path is None: logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev, dtype=torch.half, + with ColoInitContext(device=init_dev, + dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): model = OPTForCausalLM(config) else: logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev, dtype=torch.half, + with ColoInitContext(device=init_dev, + dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): model = OPTForCausalLM.from_pretrained(args.model_name_or_path, @@ -198,8 +197,11 @@ def main(): numel = sum([p.numel() for p in model.parameters()]) PLACEMENT_POLICY = 'cpu' - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, - pin_memory=True, strict_ddp_mode=args.shardinit) + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=PLACEMENT_POLICY, + pin_memory=True, + strict_ddp_mode=args.shardinit) optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) SEQ_LEN = 1024 diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 2f012780da77..7923e4fc855d 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -15,11 +15,9 @@ import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import MultiTimer, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP # constants @@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: return model -## Parameter Sharding Strategies for Tensor Parallelism +# Parameter Sharding Strategies for Tensor Parallelism def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) param.set_tensor_spec(*spec) @@ -232,7 +230,7 @@ def __len__(self): tensor_parallelize(model, pg) model = gemini_zero_dpp(model, pg, args.placement) - #optimizer + # optimizer #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py index 9840a122cbc4..eef7bb6ad5cd 100644 --- a/examples/language/roberta/pretraining/run_pretraining.py +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -1,69 +1,67 @@ -import colossalai import math +import os +import time +from functools import partial + import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -import colossalai.nn as col_nn from arguments import parse_args -from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt -from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args -from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer -from utils.logger import Logger from evaluation import evaluate from loss import LossForPretraining - -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt from tqdm import tqdm -import os -import time -from functools import partial - from transformers import AutoTokenizer +from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator +from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables +from utils.logger import Logger -from colossalai.gemini import ChunkManager, GeminiManager -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.utils import get_current_device +import colossalai +import colossalai.nn as col_nn +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.optimizer import HybridAdam from colossalai.nn.parallel import ZeroDDP -from colossalai.zero import ZeroOptimizer from colossalai.tensor import ProcessGroup -from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +from colossalai.zero import ZeroOptimizer +from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager +from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy def main(): args = parse_args() launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) - + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) os.environ['CUDA_LAUNCH_BLOCKING'] = '1' - + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) - + if args.vscode_debug: colossalai.launch(config={}, - rank=args.rank, - world_size=args.world_size, - host=args.host, - port=args.port, - backend=args.backend) + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(args.colossal_config) #args.colossal_config + colossalai.launch_from_torch(args.colossal_config) # args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) - logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + - f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') + logger.info( + f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + ) log_args(logger, args) args.tokenizer = tokenizer args.logger = logger set_global_variables(launch_time, args.tensorboard_path) - + use_zero = hasattr(gpc.config, 'zero') world_size = torch.distributed.get_world_size() @@ -71,8 +69,8 @@ def main(): if use_zero: shard_strategy = TensorShardStrategy() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, - shard_param=True): - + shard_param=True): + config, model, numel = get_model(args, logger) # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) else: @@ -82,9 +80,10 @@ def main(): os.mkdir(os.path.join(args.ckpt_path, launch_time)) logger.info(f'Model numel: {numel}') - + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + # len(dataloader) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size total_steps = steps_per_epoch * args.epoch # build optimizer and lr_scheduler @@ -98,18 +97,23 @@ def main(): o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 optimizer = get_optimizer(model, lr=args.lr) optimizer.load_state_dict(o_l_state_dict['optimizer']) - lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] + # o_l_state_dict['lr_scheduler']['last_epoch'] + lr_scheduler = get_lr_scheduler(optimizer, + total_steps=total_steps, + last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) - + start_epoch = o_l_state_dict['epoch'] start_shard = o_l_state_dict['shard'] + 1 # global_step = o_l_state_dict['global_step'] + 1 - logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') + logger.info( + f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + ) else: optimizer = get_optimizer(model, lr=args.lr) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -124,12 +128,11 @@ def main(): # initialize with colossalai engine, _, _, lr_scheduelr = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) - + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + logger.info(get_mem_info(prefix='After init model, ')) - best_loss = None eval_loss = 0 @@ -146,13 +149,16 @@ def main(): dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) + iterator_data = tqdm(enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour='cyan', + smoothing=1) else: iterator_data = enumerate(dataset_iterator) engine.train() - - for step, batch_data in iterator_data: + + for step, batch_data in iterator_data: # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") @@ -162,7 +168,7 @@ def main(): # nsp_label = batch_data[5].cuda() output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - + loss = engine.criterion(output.logits, mlm_label) pretrain_dataset_provider.prefetch_batch() @@ -172,14 +178,15 @@ def main(): engine.step() lr_scheduelr.step() engine.zero_grad() - + global_step += 1 if global_step % args.log_interval == 0 and global_step != 0 \ - and torch.distributed.get_rank() == 0: + and torch.distributed.get_rank() == 0: elapsed_time = timers('interval_time').elapsed(reset=False) elapsed_time_per_iteration = elapsed_time / global_step - samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + numel, args, config, elapsed_time, global_step, world_size) cur_loss = train_loss / args.log_interval current_lr = lr_scheduelr.get_last_lr()[0] @@ -189,12 +196,13 @@ def main(): if args.wandb: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_train({ - 'lr': current_lr, - 'loss': cur_loss, - 'ppl': math.exp(cur_loss), - 'mins_batch': elapsed_time_per_iteration - }, global_step) + tensorboard_log.log_train( + { + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) train_loss = 0 @@ -202,12 +210,14 @@ def main(): logger.info('*' * 100) eval_loss += evaluate(engine, args, logger, global_step) - save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) - - + save_ckpt(engine.model, optimizer, lr_scheduelr, + os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, + shard, global_step) + eval_loss /= len(os.listdir(args.data_path_prefix)) - logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ - f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') + logger.info( + f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') logger.info('-' * 100) if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index c4f576cb18aa..e618b4d66957 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -30,24 +30,13 @@ import datasets import torch import torch.distributed as dist +import transformers from accelerate.utils import set_seed from context import barrier_context from datasets import load_dataset from packaging import version from torch.utils.data import DataLoader from tqdm.auto import tqdm - -import colossalai -import transformers -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader -from colossalai.utils.model.colo_init_context import ColoInitContext from transformers import ( CONFIG_MAPPING, MODEL_MAPPING, @@ -61,6 +50,15 @@ ) from transformers.utils.versions import require_version +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer + require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 17bf9cb87f51..c925843fb2b6 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.testing import parameterize from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 760401c3f2c2..9879ae461848 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -11,12 +11,11 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor.process_group import ProcessGroup from colossalai.testing import assert_close, rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper class MLP(torch.nn.Module): diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index 679c8b0f6afe..2ad20f6bec72 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -10,14 +10,14 @@ import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.nn.parallel import ColoDDP, ZeroDDP +from colossalai.nn.parallel import ColoDDP from colossalai.tensor import ProcessGroup from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager def set_seed(seed): diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index f229364c6eb1..bd4742ff2cc9 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -1,18 +1,19 @@ import copy +from collections import OrderedDict +from functools import partial import pytest -import colossalai import torch import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.parallel import ColoDDP +from colossalai.tensor import ColoParameter, ProcessGroup from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from functools import partial +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ColoDDP -from collections import OrderedDict -from colossalai.tensor import ProcessGroup, ColoParameter def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): diff --git a/tests/test_gemini/test_gemini_manager.py b/tests/test_gemini/test_gemini_manager.py index 0c138f101f75..aee9432532c6 100644 --- a/tests/test_gemini/test_gemini_manager.py +++ b/tests/test_gemini/test_gemini_manager.py @@ -1,73 +1,73 @@ -import pytest -import torch - -from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor - - -@pytest.mark.dist -def test_gemini_manager(): - # reset the manager, in case that there exists memory information left - manager = StatefulTensor.GST_MGR - manager.reset() - - # occupation 8 - st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) - # occupation 60 - st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) - - # occupation 28 - t1 = torch.empty(7, device='cuda') - # occupation 12 - t2 = torch.empty(3, device='cpu') - st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) - st4 = StatefulTensor(None, TensorState.FREE) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 60 - assert manager.total_mem['cuda'] == 36 - assert manager.state_mem['cpu'][TensorState.HOLD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 - - st4.payload_reset(t2) - st3.payload_reset(t2) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 84 - assert manager.total_mem['cuda'] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD] == 72 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 - - st1.move_to(torch.device('cpu')) - st2.move_to(torch.device('cpu')) - st3.move_to(torch.device('cuda', 0)) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 80 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - - st1.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.HOLD_AFTER_BWD) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 - assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 - assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 - - -if __name__ == '__main__': - test_gemini_manager() +import pytest +import torch + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState + + +@pytest.mark.dist +def test_gemini_manager(): + # reset the manager, in case that there exists memory information left + manager = StatefulTensor.GST_MGR + manager.reset() + + # occupation 8 + st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) + # occupation 60 + st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) + + # occupation 28 + t1 = torch.empty(7, device='cuda') + # occupation 12 + t2 = torch.empty(3, device='cpu') + st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) + st4 = StatefulTensor(None, TensorState.FREE) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 60 + assert manager.total_mem['cuda'] == 36 + assert manager.state_mem['cpu'][TensorState.HOLD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 + + st4.payload_reset(t2) + st3.payload_reset(t2) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 84 + assert manager.total_mem['cuda'] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD] == 72 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 + + st1.move_to(torch.device('cpu')) + st2.move_to(torch.device('cpu')) + st3.move_to(torch.device('cuda', 0)) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 80 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + + st1.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.HOLD_AFTER_BWD) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 + assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 + assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 + + +if __name__ == '__main__': + test_gemini_manager() diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py index daf386d6d6af..9ebacdb70f1a 100644 --- a/tests/test_gemini/test_param_op.py +++ b/tests/test_gemini/test_param_op.py @@ -2,7 +2,7 @@ import torch -from colossalai.gemini.paramhooks import BaseParamHookMgr +from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py index 294868458c47..9a3e93493ec3 100644 --- a/tests/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -3,8 +3,8 @@ import numpy as np import torch -from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py index 7d192fc631a6..ba0945551b18 100644 --- a/tests/test_gemini/update/test_chunk_mgrv2.py +++ b/tests/test_gemini/update/test_chunk_mgrv2.py @@ -5,10 +5,10 @@ import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port +from colossalai.zero.gemini.chunk import ChunkManager from tests.test_tensor.common_utils import debug_print CUDA_MEM_0 = {False: 512, True: 1024} diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py index 96855410bea6..5f9ba5d3a827 100644 --- a/tests/test_gemini/update/test_chunkv2.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -6,12 +6,12 @@ import torch.multiprocessing as mp import colossalai -from colossalai.gemini import TensorState -from colossalai.gemini.chunk import Chunk from colossalai.tensor import ColoParameter from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device +from colossalai.zero.gemini import TensorState +from colossalai.zero.gemini.chunk import Chunk def dist_sum(x): diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index 2821dc78d984..8cfacd018324 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -8,16 +8,14 @@ import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py index 8cf17a0a726e..9d5419e9452d 100644 --- a/tests/test_gemini/update/test_gemini_use_rmt.py +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -5,15 +5,13 @@ import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_gemini/update/test_get_torch_model.py b/tests/test_gemini/update/test_get_torch_model.py index e6d586b37041..c014ced975ce 100644 --- a/tests/test_gemini/update/test_get_torch_model.py +++ b/tests/test_gemini/update/test_get_torch_model.py @@ -6,13 +6,12 @@ import torch.multiprocessing as mp import colossalai -from colossalai.nn.parallel import GeminiDDP -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiDDP +from colossalai.zero.gemini.utils import get_static_torch_model from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py index d97ba94399c0..65f252c558d7 100644 --- a/tests/test_gemini/update/test_grad_clip.py +++ b/tests/test_gemini/update/test_grad_clip.py @@ -10,15 +10,13 @@ import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_gemini/update/test_inference.py index b057448ad378..12392d6e57ad 100644 --- a/tests/test_gemini/update/test_inference.py +++ b/tests/test_gemini/update/test_inference.py @@ -10,15 +10,13 @@ import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper +from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index cd3aa6051d78..7364e59d10b4 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -9,16 +9,14 @@ import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ColoParameter, ColoTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx +from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py index 2fcdd5380906..71cdf9a18840 100644 --- a/tests/test_gemini/update/test_search.py +++ b/tests/test_gemini/update/test_search.py @@ -6,11 +6,11 @@ import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import init_chunk_manager, search_chunk_configuration from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index 00d835842f79..7e759808d1cb 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -7,13 +7,12 @@ from torch.testing import assert_close import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index fd13af6b2b0a..996dc4eb8b56 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -6,15 +6,13 @@ import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index f99e74ea55c1..5b6fe441112d 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -11,7 +11,7 @@ from colossalai.nn.layer.moe import load_moe_model, save_moe_model from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_tensor.common_utils import debug_print from tests.test_zero.common import CONFIG diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index ae0c1390c129..23ad1a3dc6a4 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -1,63 +1,60 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import parameterize -from colossalai.utils import free_port -from colossalai.context import MOE_CONTEXT -from colossalai.tensor import ColoParameter -from colossalai.utils.model.colo_init_context import ColoInitContext - -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import get_current_device - -from tests.test_zero.common import CONFIG -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print - - -@parameterize("init_device_type", ['cpu', 'cuda']) -def exam_moe_colo_init(init_device_type): - world_size = dist.get_world_size() - - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - with ColoInitContext(device=init_device): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) - - if hasattr(param, "moe_info"): - param.set_process_group(param.moe_info.pg) - - if hasattr(param, "moe_info"): - assert param.process_group.dp_world_size() == param.moe_info.dp_size - else: - assert param.process_group.dp_world_size() == world_size - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - exam_moe_colo_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_colo_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_moe_colo_init(world_size=4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.tensor import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero import ColoInitContext +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print +from tests.test_zero.common import CONFIG + + +@parameterize("init_device_type", ['cpu', 'cuda']) +def exam_moe_colo_init(init_device_type): + world_size = dist.get_world_size() + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + with ColoInitContext(device=init_device): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) + + if hasattr(param, "moe_info"): + param.set_process_group(param.moe_info.pg) + + if hasattr(param, "moe_info"): + assert param.process_group.dp_world_size() == param.moe_info.dp_size + else: + assert param.process_group.dp_world_size() == world_size + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_colo_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_colo_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 04dc9c514dd0..5987e31f71c9 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -1,114 +1,112 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.nn import CheckpointModule -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize -from colossalai.utils import free_port -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer import MoeModule -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) - -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import get_current_device -from tests.test_zero.common import CONFIG - - -class MoeModel(nn.Module): - - def __init__(self, checkpoint: bool = False): - - class TestSubModule(CheckpointModule): - - def __init__(self): - super().__init__(checkpoint) - expert_cls = nn.Linear - expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) - self.proj = nn.Linear(16, 4) - - def _forward(self, x): - x, y = self.moe(x) - x = self.proj(x) - return x, y - - super().__init__() - self.test_embed = nn.Linear(4, 16) - self.test_transform = TestSubModule() - - def forward(self, x): - MOE_CONTEXT.reset_loss() - - x = self.test_embed(x) - x, y = self.test_transform(x) - - MOE_CONTEXT.add_loss(y) - return x - - -@parameterize("init_device_type", ['cpu', 'cuda']) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_moe_zero_init(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_moe_zero_init") - - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert hasattr(param, 'colo_attr') - - # the parameters in moe experts and its gate should not be sharded - if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): - assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) - else: - assert param.colo_attr.sharded_data_tensor.is_sharded - - # the parameters in moe experts is not replicated - if 'experts' in name: - assert not param.colo_attr.is_replicated - else: - assert param.colo_attr.is_replicated - - if param.colo_attr.param_is_sharded: - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' - else: - assert param.colo_attr.data_payload.device.type == 'cuda' - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - run_moe_zero_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_moe_zero_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_moe_zero_init(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.logging import get_dist_logger +from colossalai.nn import CheckpointModule +from colossalai.nn.layer import MoeModule +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from tests.test_zero.common import CONFIG + + +class MoeModel(nn.Module): + + def __init__(self, checkpoint: bool = False): + + class TestSubModule(CheckpointModule): + + def __init__(self): + super().__init__(checkpoint) + expert_cls = nn.Linear + expert_args_dict = dict(in_features=16, out_features=16) + self.moe = MoeModule(dim_model=16, + num_experts=8, + use_residual=True, + expert_cls=expert_cls, + **expert_args_dict) + self.proj = nn.Linear(16, 4) + + def _forward(self, x): + x, y = self.moe(x) + x = self.proj(x) + return x, y + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_CONTEXT.reset_loss() + + x = self.test_embed(x) + x, y = self.test_transform(x) + + MOE_CONTEXT.add_loss(y) + return x + + +@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_moe_zero_init(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_moe_zero_init") + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert hasattr(param, 'colo_attr') + + # the parameters in moe experts and its gate should not be sharded + if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) + else: + assert param.colo_attr.sharded_data_tensor.is_sharded + + # the parameters in moe experts is not replicated + if 'experts' in name: + assert not param.colo_attr.is_replicated + else: + assert param.colo_attr.is_replicated + + if param.colo_attr.param_is_sharded: + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + else: + assert param.colo_attr.data_payload.device.type == 'cuda' + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_moe_zero_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_zero_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index d608ebf0718e..d38f66fef658 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -10,11 +10,11 @@ from colossalai.nn import MoeLoss from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 9d9a7bd17390..7e140bf862f2 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -12,12 +12,12 @@ from colossalai.nn.optimizer import CPUAdam from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_zero.common import CONFIG, check_sharded_model_params diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index d317dc2e34ad..ea1c044f5820 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -56,7 +56,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype): eps = 1e-8 weight_decay = 0 - for i in range(1024): + for i in range(3): p_data = torch.rand(64, dtype=p_dtype) p_data_copy = p_data.clone().float() p_grad = torch.rand(64, dtype=g_dtype) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 7b9b6e9c48ba..8ff6618aee2e 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -54,7 +54,7 @@ def test_adam(adamw, step, p_dtype, g_dtype): count = 0 - for i in range(1024): + for i in range(3): p = torch.rand(64, dtype=p_dtype).cuda() p_copy = p.clone().float() g = torch.rand(p.shape, dtype=g_dtype).cuda() diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py index d19192add3fb..2576d8ffee43 100644 --- a/tests/test_optimizer/test_hybrid_adam.py +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -1,12 +1,12 @@ import torch import torch.nn as nn -from torch.optim.adam import Adam from torch.optim import AdamW +from torch.optim.adam import Adam from colossalai.nn.optimizer.hybrid_adam import HybridAdam from colossalai.testing import parameterize -RE = 1024 +RE = 3 @parameterize('adamw', [False, True]) diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index ad8ac87b2e1e..0d6a3fe26c2d 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import ( debug_print, diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py index 3f53b94e0642..83abc641cbd4 100644 --- a/tests/test_tensor/model/test_model.py +++ b/tests/test_tensor/model/test_model.py @@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import ( check_equal, diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py index 997b416f12c3..739bf2b0a641 100644 --- a/tests/test_tensor/model/test_module_spec.py +++ b/tests/test_tensor/model/test_module_spec.py @@ -20,7 +20,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py index 2f7aebed5bc4..047371f45bda 100644 --- a/tests/test_tensor/test_context.py +++ b/tests/test_tensor/test_context.py @@ -17,7 +17,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 1a6d23f6a2eb..94e39e5d1546 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -7,14 +7,12 @@ import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import search_chunk_configuration -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.zero.gemini import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed, tensor_shard_equal from tests.test_tensor.model.test_gpt2 import init_megatron_spec diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index a5ea75fffc36..7c2ad9078f98 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,25 +1,23 @@ -import os, shutil -import torch -import pytest +import os +import shutil from copy import deepcopy from functools import partial -import torch.multiprocessing as mp +import pytest +import torch import torch.distributed as dist - -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import MultiplicativeLR -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +import torch.multiprocessing as mp +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR import colossalai +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup -from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint -from colossalai.nn.optimizer import ColossalaiOptimizer - +from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 0ecb7446c788..6bfa6f33c812 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,13 +1,12 @@ -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.zero.sharded_param import ShardedTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline -import colossalai - import torch - import torch.multiprocessing as mp +import colossalai +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.zero.legacy.sharded_param import ShardedTensor + def run_tensor_move(rank): colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py index 8bdae88464b1..920656726d63 100644 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -2,21 +2,22 @@ # -*- encoding: utf-8 -*- import copy +from functools import partial -import colossalai -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn -from colossalai.logging import disable_existing_loggers -from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ -from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy -from functools import partial + +import colossalai +from colossalai.logging import disable_existing_loggers from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port +from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy +from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 def checkpoint_wrapper(module, enable=True): diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py index bc6cd75a6a60..2c3d122c79af 100644 --- a/tests/test_zero/common.py +++ b/tests/test_zero/common.py @@ -2,10 +2,11 @@ import torch import torch.distributed as dist + from colossalai.logging import get_dist_logger from colossalai.utils import checkpoint -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.shard_utils import TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 LOGGER = get_dist_logger('zero_test') diff --git a/tests/test_zero/low_level_zero/test_zero_init.py b/tests/test_zero/low_level_zero/test_zero_init.py index 1305da5df9c5..803d0021df96 100644 --- a/tests/test_zero/low_level_zero/test_zero_init.py +++ b/tests/test_zero/low_level_zero/test_zero_init.py @@ -9,8 +9,7 @@ import colossalai from colossalai.tensor import ProcessGroup from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer class MlpModel(nn.Module): diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py index 15d3530ff90a..bb7495583dc9 100644 --- a/tests/test_zero/low_level_zero/test_zero_tp.py +++ b/tests/test_zero/low_level_zero/test_zero_tp.py @@ -11,8 +11,7 @@ from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py index 34283f5015e1..641136718161 100644 --- a/tests/test_zero/test_found_inf.py +++ b/tests/test_zero/test_found_inf.py @@ -1,72 +1,72 @@ -from functools import partial - -import colossalai -from colossalai.utils.cuda import get_current_device -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_zero.test_sharded_optim_v2 import _run_step - -from common import CONFIG - - -@parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers'] - shard_strategy = shard_strategy_class() - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=True, - ) - - sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - assert zero_model.overflow_counter == 0 - data, label = data.cuda(), label.cuda() - _run_step(zero_model, sharded_optim, data, label, criterion, False) - for param in zero_model.parameters(): - assert not has_inf_or_nan(param.colo_attr.data_payload) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_found_inf() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_found_inf(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_found_inf(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from common import CONFIG + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_zero.test_sharded_optim_v2 import _run_step + + +@parameterize("cpu_offload", [True, False]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers'] + shard_strategy = shard_strategy_class() + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=True, + ) + + sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + assert zero_model.overflow_counter == 0 + data, label = data.cuda(), label.cuda() + _run_step(zero_model, sharded_optim, data, label, criterion, False) + for param in zero_model.parameters(): + assert not has_inf_or_nan(param.colo_attr.data_payload) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_found_inf() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_found_inf(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_found_inf(world_size=2) diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py index 0cba7a492380..0eb8842de3e3 100644 --- a/tests/test_zero/test_init_context.py +++ b/tests/test_zero/test_init_context.py @@ -9,14 +9,14 @@ from common import CONFIG import colossalai -from colossalai.gemini.memory_tracer.utils import colo_model_mem_usage from colossalai.logging import get_dist_logger from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.memory import colo_device_memory_used -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py index 95a9dee38acf..884444adf167 100644 --- a/tests/test_zero/test_shard_model_v2.py +++ b/tests/test_zero/test_shard_model_v2.py @@ -12,11 +12,11 @@ import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py index 8db2b7e79604..6085de3c8919 100644 --- a/tests/test_zero/test_shard_param.py +++ b/tests/test_zero/test_shard_param.py @@ -1,17 +1,18 @@ from copy import deepcopy from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp + +import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) -from colossalai.zero.sharded_param import ShardedTensor -from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_param import ShardedTensor +from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 from tests.test_zero.common import CONFIG, allclose -from colossalai.gemini.stateful_tensor import StatefulTensor @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_sharded_optim_state_dict.py index f8c42930b281..d257a02854c6 100644 --- a/tests/test_zero/test_sharded_optim_state_dict.py +++ b/tests/test_zero/test_sharded_optim_state_dict.py @@ -1,20 +1,21 @@ +from functools import partial + import pytest -import colossalai import torch import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize + +import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed def init_zero(model_builder, placement_policy): diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py index 8fe7eb639eab..3eea13d5d5c2 100644 --- a/tests/test_zero/test_sharded_optim_v2.py +++ b/tests/test_zero/test_sharded_optim_v2.py @@ -13,12 +13,12 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_zero/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_sharded_optim_with_sync_bn.py index ea5b315188a3..05512f59a36a 100644 --- a/tests/test_zero/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero/test_sharded_optim_with_sync_bn.py @@ -3,18 +3,19 @@ from functools import partial -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchvision.models import resnet50 + +import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from torchvision.models import resnet50 +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py index 7ac9b151e4d6..c435d9bb1ef7 100644 --- a/tests/test_zero/test_state_dict.py +++ b/tests/test_zero/test_state_dict.py @@ -4,20 +4,20 @@ from copy import deepcopy from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from common import CONFIG + +import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs -from common import CONFIG - @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_zero_state_dict(shard_strategy_class): diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_tensor_utils.py index 81855ff5e10a..3114481707c2 100644 --- a/tests/test_zero/test_tensor_utils.py +++ b/tests/test_zero/test_tensor_utils.py @@ -1,18 +1,21 @@ +from functools import partial + import pytest +import torch +import torch.multiprocessing as mp import colossalai -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, - colo_model_tensor_clone) -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use - -import torch - -from functools import partial -import torch.multiprocessing as mp +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.zero.legacy.gemini.tensor_utils import ( + colo_model_data_move_to_cpu, + colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, + colo_model_tensor_clone, + colo_tensor_mem_usage, +) def _run_colo_tensor_mem_usage(): diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_zero_engine.py index 80ded65d634c..1e7f53358526 100644 --- a/tests/test_zero/test_zero_engine.py +++ b/tests/test_zero/test_zero_engine.py @@ -3,21 +3,21 @@ from functools import partial -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP - -from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params) def run_dist(rank, world_size, port, parallel_config): From 1beb85cc25f35d51083bde6fbaa99a5c4c7fd387 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 4 Apr 2023 15:23:01 +0800 Subject: [PATCH 05/27] [checkpoint] refactored the API and added safetensors support (#3427) * [checkpoint] refactored the API and added safetensors support * polish code --- colossalai/booster/plugin/torch_ddp_plugin.py | 4 +- colossalai/checkpoint_io/__init__.py | 5 +- .../checkpoint_io/checkpoint_io_base.py | 332 ++++-------------- .../checkpoint_io/general_checkpoint_io.py | 53 ++- colossalai/checkpoint_io/index_file.py | 150 ++++++++ colossalai/checkpoint_io/utils.py | 278 +++++++++++++++ requirements/requirements.txt | 1 + .../test_plugin/test_torch_ddp_plugin.py | 23 ++ .../test_general_checkpoint_io.py | 13 +- 9 files changed, 579 insertions(+), 280 deletions(-) create mode 100644 colossalai/checkpoint_io/index_file.py create mode 100644 colossalai/checkpoint_io/utils.py diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index d7f3d22d93cc..e2abe11ba143 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -33,7 +33,7 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = # the model should be unwrapped in self.load_model via ModelWrapper.unwrap return super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str): + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool): """ Save model to checkpoint but only on master process. """ @@ -41,7 +41,7 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: str): if self.coordinator.is_master(): super().save_unsharded_model(model, checkpoint) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py index 3cec630b2f86..c25048e25754 100644 --- a/colossalai/checkpoint_io/__init__.py +++ b/colossalai/checkpoint_io/__init__.py @@ -1,4 +1,5 @@ -from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile +from .checkpoint_io_base import CheckpointIO from .general_checkpoint_io import GeneralCheckpointIO +from .index_file import CheckpointIndexFile -__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO'] +__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO'] diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index d6eef7a96cdc..b91b00831e52 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -1,7 +1,6 @@ -import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Union +from typing import Union import torch import torch.nn as nn @@ -10,7 +9,9 @@ from colossalai.interface import ModelWrapper -__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile'] +from .utils import has_index_file + +__all__ = ['CheckpointIO'] class CheckpointIO(ABC): @@ -25,15 +26,31 @@ class CheckpointIO(ABC): >>> # load model from checkpoint >>> model = checkpoint_io.load_model(model, 'model.pt') >>> - >>> # save model to checkpoint + >>> # save model to checkpoint, any distributed tensor is gathered by default >>> checkpoint_io.save_model(model, 'model.pt') >>> + >>> # if the model contains distributed tensor, and you don't want to gather it + >>> # each rank will save its own shard of the distributed tensor + >>> checkpoint_io.save_model(model, 'model.pt', gather_dtensor=False) + >>> >>> # save model to sharded checkpoints >>> checkpoint_io.save_model(model, './checkpoints/', shard=True) >>> + >>> # save model to sharded and assume we don't want to gather distributed tensors + >>> checkpoint_io.save_model(model, './checkpoints/', shard=True, gather_dtensor=False) + >>> + >>> # Note: + >>> # 1. we don't support loading from distributed tensors, conversion from distributed tensors + >>> # checkpoints to full tensor checkpoint should be done offline via our CLI + >>> # 2. you don't have to specify whether the model is sharded or not when loading the model + >>> # as it will be automatically detected + >>> >>> # load model from sharded checkpoints >>> model = checkpoint_io.load_model(model, './checkpoints/') >>> + >>> # load model from unsharded checkpoints + >>> model = checkpoint_io.load_model(model, './checkpoints/') + >>> >>> # load optimizer from checkpoint >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt') >>> @@ -58,21 +75,27 @@ def load_model(self, 1. a file path, e.g. 'model.pt' 2. a path to a json file which defines the index to the sharded checkpoint 3. a path to a folder containing a unique .index.json file for sharded checkpoint + Distributed tensors cannot be loaded directly unless gathered offline via our CLI. strict (bool): whether to strictly enforce that the param name in the checkpoint match the keys returned by this module's. """ + # since we only support loaded sharded and unsharded weight format + # containing no distributed tensors, dtensor -> full tensor conversion + # should be done offline via our CLI + # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) - is_sharded = self.is_sharded_checkpoint(ckpt_path) + index_file_exists, index_file_path = has_index_file(checkpoint) + # return the origin model instead of the unwrapped model origin_model = model if isinstance(model, ModelWrapper): model = model.unwrap() - if is_sharded: - self.load_sharded_model(model, ckpt_path, strict) + if index_file_exists: + self.load_sharded_model(model, index_file_path, strict) else: - self.load_unsharded_model(model, ckpt_path, strict) + self.load_unsharded_model(model, checkpoint, strict) return origin_model @@ -80,8 +103,10 @@ def save_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, shard: bool = False, + gather_dtensor: bool = True, prefix: str = None, - size_per_shard: int = 1024): + size_per_shard: int = 1024, + use_safetensors: bool = False): """ Save model to checkpoint. @@ -103,17 +128,19 @@ def save_model(self, shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure that the checkpoint path is a directory path instead of a file path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. prefix (str): prefix for the model checkpoint file name when shard=True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard = True. + use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved """ if isinstance(model, ModelWrapper): model = model.unwrap() if shard: - self.save_sharded_model(model, checkpoint, prefix, size_per_shard) + self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors) else: - self.save_unsharded_model(model, checkpoint) + self.save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) def load_optimizer(self, optimizer: Optimizer, checkpoint: str): """ @@ -123,22 +150,27 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str): optimizer (Optimizer): optimizer to be loaded. checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the """ - ckpt_path = Path(checkpoint) - is_sharded = self.is_sharded_checkpoint(ckpt_path) + index_file_exists, index_file_path = has_index_file(checkpoint) - if is_sharded: - self.load_sharded_optimizer(optimizer, ckpt_path) + if Path(checkpoint).is_dir() and not index_file_exists: + # if the checkpoint is a directory and there is no index file, raise error + raise ValueError(f'Cannot find index file in {checkpoint}') + + if index_file_exists: + # the existence of index file means it is a sharded checkpoint + self.load_sharded_optimizer(optimizer, index_file_path) else: - self.load_unsharded_optimizer(optimizer, ckpt_path) + self.load_unsharded_optimizer(optimizer, checkpoint) def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, + gather_dtensor=True, prefix: str = None, size_per_shard: int = 1024): """ - Save optimizer to checkpoint. + Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors. Args: optimizer (Optimizer): optimizer to be saved. @@ -148,30 +180,33 @@ def save_optimizer(self, 3. a path to a folder containing a unique .index.json file for sharded checkpoint shard (bool): whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into multiple files. The optimizer shards will be specificed by a `optimizer.index.json` file. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True. prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. """ if shard: - self.save_sharded_optimizer(optimizer, checkpoint, prefix, size_per_shard) + self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) else: - self.save_unsharded_optimizer(optimizer, checkpoint) + self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) # ======================================================== # Abstract methods for model loading/saving implementation # ======================================================== @abstractmethod - def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool): + def load_sharded_model(self, model: nn.Module, index_file_path: str, strict: bool): """ Load model from sharded checkpoint. Args: model (nn.Module): model to be loaded. - checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. + index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. + strict (bool): whether to strictly enforce that the param name in + the checkpoint match the keys returned by this module's. """ pass @abstractmethod - def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool): + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): """ Load model from unsharded checkpoint. @@ -184,26 +219,31 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool) pass @abstractmethod - def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int): + def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, prefix: str, + size_per_shard: int, use_safetensors: bool): """ Save model to sharded checkpoint. Args: model (nn.Module): model to be saved. - checkpoint (Path): checkpoint path. It should be a directory path. + checkpoint (str): checkpoint path. It should be a directory path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. prefix (str): prefix for the model checkpoint. size_per_shard (int): size per shard in MB. + use_safetensors (bool): whether to use safe tensors. """ pass @abstractmethod - def save_unsharded_model(self, model: nn.Module, checkpoint: Path): + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to unsharded checkpoint. Args: model (nn.Module): model to be saved. - checkpoint (Path): checkpoint path. It should be a single file path pointing to a model weight binary. + checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. + use_safetensors (bool): whether to use safe tensors. """ pass @@ -212,13 +252,13 @@ def save_unsharded_model(self, model: nn.Module, checkpoint: Path): # ======================================================== @abstractmethod - def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int): """ Load optimizer from sharded checkpoint. Args: optimizer (Optimizer): optimizer to be loaded. - checkpoint (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. + index_file_path (str): checkpoint path. It should be path to the .index.json file or a path to a directory which contains a .index.json file. prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ @@ -236,26 +276,29 @@ def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): pass @abstractmethod - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int): """ Save optimizer to sharded checkpoint. Args: optimizer (Optimizer): optimizer to be saved. checkpoint (Path): checkpoint path. It should be a directory path. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. prefix (str): prefix for the optimizer checkpoint. size_per_shard (int): size per shard in MB. """ pass @abstractmethod - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool): """ Save optimizer to unsharded checkpoint. Args: optimizer (Optimizer): optimizer to be saved. checkpoint (str): checkpoint path. It should be a single file path pointing to a model weight binary. + gather_dtensor (bool): whether to gather the distributed tensor to the first device. """ pass @@ -264,7 +307,6 @@ def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): # as this is quite standard, there is no need # to make them abstract # ============================================ - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ Save lr scheduler to checkpoint. @@ -285,231 +327,3 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ state_dict = torch.load(checkpoint) lr_scheduler.load_state_dict(state_dict) - - # ======================================== - # Helper functions for loading state dict - # ======================================== - - def get_sharded_checkpoint_index_file(self, checkpoint_path: Path): - """ - Get the index file path for a sharded checkpoint. - - Args: - checkpoint_path (Path): path to the checkpoint. - - Returns: - Path: path to the index file. - """ - if checkpoint_path.is_file(): - # check if it is .index.json - if checkpoint_path.name.endswith('.index.json'): - return checkpoint_path - else: - raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ') - elif checkpoint_path.is_dir(): - # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.json')) - if len(index_files) == 1: - return index_files[0] - else: - raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ') - - def is_sharded_checkpoint(self, checkpoint_path: Path): - """ - Check whether the checkpoint is sharded. - - Args: - checkpoint (str): checkpoint path. - - Returns: - bool: whether the checkpoint is sharded. - """ - if checkpoint_path.is_file(): - # check if it is .index.json - if checkpoint_path.name.endswith('.index.json'): - return True - else: - return False - elif checkpoint_path.is_dir(): - # check if there is only one a file ending with .index.json in this directory - index_files = list(checkpoint_path.glob('*.index.json')) - if len(index_files) == 1: - return True - else: - raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ') - - def get_checkpoint_shard_filenames(self, index_file_path: Path): - """ - Get checkpoint shard filenames from a json file. - - Args: - index_file_path (Path): path to the json file. - - Returns: - list: checkpoint shard filenames. - """ - with open(str(index_file_path), 'r') as f: - shard_filenames = json.load(f) - - if "weight_map" in index: - index = index["weight_map"] - - checkpoint_root_path = index_file_path.absolute().parent - - # read the checkpoint file list from the json file and get a list of unique file names - checkpoint_files = sorted(list(set(index.values()))) - - # get the absolute paths for all checkpoint files - checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files] - return shard_filenames - - def load_safetensors_state_dict(self, *args, **kwargs): - """ - Load safetensors state dict from checkpoint. - """ - # TODO(FrankLeeeee): support huggingface safetensors - raise NotImplementedError("This method is not implemented to support safe tensors") - - def load_state_dict(self, checkpoint_file_path: Path): - """ - Load state dict from checkpoint. - - Args: - checkpoint_file_path (Path): path to the checkpoint file. - - Returns: - dict: state dict. - """ - return torch.load(str(checkpoint_file_path)) - - # ====================================== - # Helper functions for saving state dict - # ====================================== - - def save_safetensors_state_dict(self, *args, **kwargs): - """ - Save safetensors state dict to checkpoint. - """ - # TODO(FrankLeeeee): support huggingface safetensors - raise NotImplementedError("This method is not implemented to support safe tensors") - - def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None): - """ - Generate checkpoint shard file name. - - Args: - index (int): index of the shard. - total_number (int): total number of shards. - prefix (str): prefix of the shard file name. Default: None. - """ - if prefix is None: - return f"{index}-of-{total_number}.bin" - else: - return f"{prefix}-{index}-of-{total_number}.bin" - - def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path): - """ - Save state dict to checkpoint. - - Args: - state_dict (dict): state dict. - checkpoint_file_path (Path): path to the checkpoint file. - """ - torch.save(state_dict, str(checkpoint_file_path)) - - def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str, - checkpoint_path: Path): - """ - Save state dict as shard. - - Args: - state_dict (dict): state dict. - checkpoint_path (Path): path to the checkpoint file. - """ - # generate the shard name - shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix) - shard_file_path = checkpoint_path.joinpath(shard_file_name) - - # save the shard - self.save_checkpoint(state_dict, shard_file_path) - - def calculate_param_size(self, param: torch.Tensor): - """ - Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. - If so, a new shard should be created. - - ArgsL - param (torch.Tensor): parameter tensor. - """ - # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so - return param.numel() * param.element_size() / 1024 / 1024 - - -class ShardCheckpointIndexFile: - """ - This class is a data structure to keep the content in the index.json file for sharded checkpoint. - - Example: - >>> index = ShardCheckpointIndexFile() - >>> index.load('index.json') - >>> index.append_metadata('model_type', 'bert') - >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin') - >>> index.export('index.json') - """ - - def __init__(self) -> None: - self.metadata: dict = dict() - self.weight_map: dict = dict() - - def load(self, json_path: str): - """ - Load the index file from a json file. - - Args: - json_path (str): path to the json file. - """ - # load the json file - with open(json_path, 'r') as f: - index = json.load(f) - - # assign attributes if exists - if "metadata" in index: - self.metadata = index["metadata"] - if "weight_map" in index: - self.weight_map = index["weight_map"] - - def export(self, json_path: str): - """ - Export the index file to a json file. - - Args: - json_path (str): path to the json file. - """ - # create the index file - index = dict() - index["metadata"] = self.metadata - index["weight_map"] = self.weight_map - - # export the index file - with open(json_path, 'w') as f: - json.dump(index, f, indent=4) - - def append_weight_map(self, param_name: str, shard_file: str): - """ - Append a weight map entry to the index file. - - Args: - param_name (str): name of the parameter. - shard_file (str): name of the shard file. - """ - self.weight_map[param_name] = shard_file - - def append_meta_data(self, name: str, val: Any): - """ - Append a metadata entry to the index file. - - Args: - name (str): name of the metadata. - val (Any): value of the metadata. - """ - self.metadata[name] = val diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index cfabcfa5589f..c779f4c17355 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -4,42 +4,67 @@ from torch.optim import Optimizer from .checkpoint_io_base import CheckpointIO +from .index_file import CheckpointIndexFile +from .utils import has_index_file, load_state_dict, save_state_dict __all__ = ['GeneralCheckpointIO'] class GeneralCheckpointIO(CheckpointIO): - def load_sharded_model(self, model: nn.Module, checkpoint: Path, strict: bool): - index_file_path = self.get_sharded_checkpoint_index_file(checkpoint) + def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool): + # load the index file + index_file = CheckpointIndexFile.from_file(index_file_path) # iterate over the shard checkpoint files # and load each - shard_files = self.get_checkpoint_shard_filenames(index_file_path) - for shard_file in shard_files: - shard_checkpoint = self.load_state_dict(shard_file) + index_file.assert_no_dtensor_checkpoint() + checkpoint_file_list, _ = index_file.get_checkpoint_fileanames() + for shard_file in checkpoint_file_list: + shard_checkpoint = load_state_dict(shard_file) model.load_state_dict(shard_checkpoint, strict=strict) - def load_unsharded_model(self, model: nn.Module, checkpoint: Path, strict: bool): - checkpoint = self.load_state_dict(str(checkpoint)) + def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): + checkpoint = load_state_dict(checkpoint) model.load_state_dict(checkpoint, strict=strict) - def save_sharded_model(self, model: nn.Module, checkpoint: Path, prefix: str, size_per_shard: int): + def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int, use_safetensors: bool): # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_unsharded_model(self, model: nn.Module, checkpoint: Path): - self.save_checkpoint(model.state_dict(), checkpoint) + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): + state_dict = model.state_dict() + + # TODO(FrankLeeeee): add support for gather_dtensor + if gather_dtensor: + pass + + # save the checkpoint + save_state_dict(state_dict, checkpoint, use_safetensors) def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - checkpoint = self.load_state_dict(checkpoint) + checkpoint = load_state_dict(checkpoint) optimizer.load_state_dict(checkpoint) - def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, prefix: str, size_per_shard: int): + def save_sharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + prefix: str, + size_per_shard: int, + ): raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.") - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - self.save_checkpoint(optimizer.state_dict(), checkpoint) + def save_unsharded_optimizer( + self, + optimizer: Optimizer, + checkpoint: Path, + gather_dtensor: bool, + ): + # TODO(FrankLeeeee): handle distributed tensors + save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py new file mode 100644 index 000000000000..32ff1b762e88 --- /dev/null +++ b/colossalai/checkpoint_io/index_file.py @@ -0,0 +1,150 @@ +import json +from pathlib import Path +from typing import Any, List, Union + +from .utils import is_dtensor_checkpoint + +__all__ = ['CheckpointIndexFile'] + + +class CheckpointIndexFile: + """ + This class is a data structure to keep the content in the index.json file for sharded checkpoint. + + Example: + >>> index = CheckpointIndexFile.from_file('model.index.json') + >>> index.append_metadata('model_type', 'bert') + >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'model_0001-of-0002.bin') + >>> index.export('new_index.json') + """ + + def __init__(self) -> None: + self.root_path = None + self.metadata: dict = dict() + self.weight_map: dict = dict() + + @staticmethod + def from_file(index_path: Union[str, Path]): + """ + Create a CheckpointIndexFile object from a json file. + + Args: + index_path (str): path to the json file. + + Returns: + CheckpointIndexFile: CheckpointIndexFile object. + """ + index = CheckpointIndexFile() + index.load(index_path) + return index + + def load(self, json_path: str): + """ + Load the index file from a json file. + + Args: + json_path (str): path to the json file. + """ + # load the json file + with open(json_path, 'r') as f: + index = json.load(f) + + # assign attributes if exists + if "metadata" in index: + self.metadata = index["metadata"] + if "weight_map" in index: + self.weight_map = index["weight_map"] + + # assign the root directory for the index file + self.root_path = Path(json_path).absolute().parent + + def export(self, json_path: str): + """ + Export the index file to a json file. + + Args: + json_path (str): path to the json file. + """ + # create the index file + index = dict() + index["metadata"] = self.metadata + index["weight_map"] = self.weight_map + + # export the index file + with open(json_path, 'w') as f: + json.dump(index, f, indent=4) + + def append_weight_map(self, param_name: str, shard_file: str): + """ + Append a weight map entry to the index file. + + Args: + param_name (str): name of the parameter. + shard_file (str): name of the shard file. + """ + self.weight_map[param_name] = shard_file + + def append_meta_data(self, name: str, val: Any): + """ + Append a metadata entry to the index file. + + Args: + name (str): name of the metadata. + val (Any): value of the metadata. + """ + self.metadata[name] = val + + def contains_dtensor(self): + """ + Check if the index file contains any distributed tensor. The distributed tensors will be stored in + `dtensor/module.linear.weight.*.bin` or `dtensor/module.linear.weight.*.safetensors` in the weight map. + + Returns: + bool: True if the index file contains any distributed tensor, False otherwise. + """ + for value in self.weight_map.values(): + if value.endswith(".*.bin") or value.endswith(".*.safetensors"): + return True + return False + + def get_checkpoint_fileanames(self) -> List[str]: + """ + Get the set of checkpoint filenames in the weight map. + + Returns: + list: checkpoint shard filenames. + """ + # read the checkpoint file list from the json file and get a list of unique file names + checkpoint_files = sorted(list(set(self.weight_map.values()))) + + # get the absolute paths for all checkpoint files + checkpoint_files = [str(self.root_path.joinpath(f)) for f in checkpoint_files] + + dtensor_list = [] + checkpoint_list = [] + + for ckpt_file in checkpoint_files: + if is_dtensor_checkpoint(ckpt_file): + dtensor_list.append(ckpt_file) + else: + checkpoint_list.append(ckpt_file) + + return checkpoint_list, dtensor_list + + def assert_no_dtensor_checkpoint(self): + for val in self.weight_map.values(): + if is_dtensor_checkpoint(val): + raise ValueError(f"Checkpoint file {val} contains distributed tensor") + + def get_checkpoint_file(self, param_name: str) -> str: + """ + Get the checkpoint file name for a parameter. + + Args: + param_name (str): name of the parameter. + + Returns: + str: checkpoint file name. + """ + ckpt_path = self.weight_map[param_name] + return ckpt_path diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py new file mode 100644 index 000000000000..76c9db0afaff --- /dev/null +++ b/colossalai/checkpoint_io/utils.py @@ -0,0 +1,278 @@ +from pathlib import Path +from typing import List, Optional, Tuple + +import torch + +# ====================================== +# General helper functions +# ====================================== + + +def calculate_tensor_size(tensor: torch.Tensor) -> float: + """ + Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. + If so, a new shard should be created. + + Args: + tenosr (torch.Tensor): the tensor to calculate size for. + + Returns: + float: size of the tensor in MB. + """ + return tensor.numel() * tensor.element_size() / 1024 / 1024 + + +def is_safetensors_available() -> bool: + """ + Check whether safetensors is available. + + Returns: + bool: whether safetensors is available. + """ + try: + import safetensors + return True + except ImportError: + return False + + +def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool: + """ + Check whether the checkpoint file is a dtensor checkpoint. + + Args: + checkpoint_file_path (str): path to the checkpoint file. + + Returns: + bool: whether the checkpoint file is a dtensor checkpoint. + """ + if checkpoint_file_path.endswith('.*.safetensors') or checkpoint_file_path.endswith('.*.bin'): + return True + else: + return False + + +def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: + """ + Check whether the checkpoint file is a safetensor checkpoint. + + Args: + checkpoint_file_path (str): path to the checkpoint file. + + Returns: + bool: whether the checkpoint file is a safetensor checkpoint. + """ + if checkpoint_file_path.endswith('.safetensors'): + return True + else: + return False + + +# ====================================== +# Helper functions for saving state dict +# ====================================== + + +def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None: + """ + Save state dict to checkpoint. + + Args: + state_dict (dict): state dict. + checkpoint_file_path (str): path to the checkpoint file. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + if use_safetensors: + assert is_safetensors_available(), "safetensors is not available." + assert checkpoint_file_path.endswith('.safetensors'), \ + "safetensors only supports .safetensors suffix for checkpoint file." + from safetensors.torch import save_file + save_file(state_dict, checkpoint_file_path) + else: + torch.save(state_dict, checkpoint_file_path) + + +def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None: + """ + Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains + only one tensor. + + Args: + tensor (Tensor): tensor to be saved. + index_file (CheckpointIndexFile): path to the checkpoint file. + size_per_shard (int): size per shard in MB. + """ + root_path = index_file.root_path + output_root_path = root_path.joinpath('dtensor') + + # create directory + output_root_path.mkdir(exist_ok=True) + + # save tensor to this directory + # TODO(YuliangLiu): get index of the tensor shard + # e.g. index = + index = 0 + + # save tensor to file + ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors) + ckpt_file_path = output_root_path.joinpath(ckpt_file_name) + + # dtensor ckpt file always contains only one tensor + state_dict = {name: tensor} + save_state_dict(state_dict, str(ckpt_file_path), use_safetensors) + + # update the weight map + # * means all shards + ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors) + index_file.append_weight_map(name, ckpt_file_name_in_weight_map) + + +def get_checkpoint_file_suffix(use_safetensors: bool) -> str: + """ + Get checkpoint file suffix. + + Args: + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: checkpoint file suffix. + """ + if use_safetensors: + return '.safetensors' + else: + return '.bin' + + +def generate_checkpoint_shard_file_name(index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None) -> str: + """ + Generate checkpoint shard file name. + + Args: + index (int): index of the shard. + total_number (int): total number of shards. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + prefix (str): prefix of the shard file name. Default: None. + + Returns: + str: checkpoint shard file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + + if prefix is None: + return f"{index:05d}-of-{total_number:05d}.{suffix}" + else: + return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}" + + +def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str: + """ + Generate dtensor file name. + + Args: + param_name (str): name of the distributed parameter. + index (int): index of the shard. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + + Returns: + str: dtensor file name. + """ + suffix = get_checkpoint_file_suffix(use_safetensors) + return f'{param_name}.{index}.{suffix}' + + +def save_state_dict_as_shard( + state_dict: dict, + checkpoint_path: str, + index: int, + total_number: int, + use_safetensors: bool, + prefix: str = None, +) -> None: + """ + Save state dict as shard. + + Args: + state_dict (dict): state dict. + checkpoint_path (str): path to the checkpoint file. + index (int): index of the shard. + total_number (int): total number of shards. + prefix (str): prefix of the shard file name. + use_safetensors (bool): whether to use safetensors to save the checkpoint. + """ + # generate the shard name + shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix) + shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute() + + # save the shard + save_state_dict(state_dict, str(shard_file_path), use_safetensors) + + +# ======================================== +# Helper functions for loading state dict +# ======================================== + + +def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: + """ + Check whether the checkpoint has an index file. + + Args: + checkpoint_path (str): path to the checkpoint. + + Returns: + Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path) + """ + checkpoint_path = Path(checkpoint_path) + if checkpoint_path.is_file(): + # check if it is .index.json + if checkpoint_path.name.endswith('.index.json'): + return True, checkpoint_path + else: + return False, None + elif checkpoint_path.is_dir(): + # check if there is only one a file ending with .index.json in this directory + index_files = list(checkpoint_path.glob('*.index.json')) + + # if we found a .index.json file, make sure there is only one + if len(index_files) > 0: + assert len( + index_files + ) == 1, f'Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}' + + if len(index_files) == 1: + return True, index_files[0] + else: + return False, None + + +def load_state_dict(checkpoint_file_path: Path): + """ + Load state dict from checkpoint. + + Args: + checkpoint_file_path (Path): path to the checkpoint file. + + Returns: + dict: state dict. + """ + + assert not is_dtensor_checkpoint(checkpoint_file_path), \ + f'Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline.' + + if is_safetensor_checkpoint(checkpoint_file_path): + assert is_safetensors_available(), \ + f'Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors.' + # load with safetensors + from safetensors import safe_open + state_dict = {} + with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + return state_dict + + else: + # load with torch + return torch.load(checkpoint_file_path) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 4e4f35edb2d9..b34dc2e223ae 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -9,3 +9,4 @@ fabric contexttimer ninja torch>=1.11 +safetensors diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 2dcc5a5bba27..71e8582cc364 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -71,6 +71,29 @@ def check_dataloader_sharding(): batch_to_compare), 'Same number was found across ranks but expected it to be different' +def check_checkpoint_save_and_load(): + model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet'] + + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = SGD(model.parameters(), lr=1e-3) + criterion = lambda x: x.mean() + data = data_gen_fn() + + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + + def run_dist(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index f9f0e03c4fa1..dfbb16af4ec6 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,5 +1,6 @@ import tempfile +import pytest import torch from torch.optim import Adam from torchvision.models import resnet18 @@ -14,7 +15,8 @@ # ======== -def test_unsharded_checkpoint(): +@pytest.mark.parametrize('use_safetensors', [True, False]) +def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() optimizer = Adam(model.parameters(), lr=0.001) @@ -29,12 +31,16 @@ def test_unsharded_checkpoint(): optimizer.step() # create a temp file for checkpoint - model_ckpt_tempfile = tempfile.NamedTemporaryFile() + if use_safetensors: + suffix = ".safetensors" + else: + suffix = ".bin" + model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix) optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() # save the model and optimizer ckpt_io = GeneralCheckpointIO() - ckpt_io.save_model(model, model_ckpt_tempfile.name) + ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors) ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name) # create new model @@ -68,3 +74,4 @@ def recursive_check(d1, d2): # check for model and optimizer state dict recursively recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) From 773955abfaa3b5aef832ad4a33ce053183edee0e Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Tue, 4 Apr 2023 15:30:01 +0800 Subject: [PATCH 06/27] fix save_model inin naive and ddp strategy (#3436) Co-authored-by: Yuanchen Xu --- .../Chat/coati/trainer/strategies/ddp.py | 34 ++++++++++++++----- .../Chat/coati/trainer/strategies/naive.py | 27 ++++++++++++--- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py index 83cbbe633de9..8a8c4b3c2f4e 100644 --- a/applications/Chat/coati/trainer/strategies/ddp.py +++ b/applications/Chat/coati/trainer/strategies/ddp.py @@ -1,3 +1,5 @@ +from typing import Optional + import os import random @@ -5,12 +7,13 @@ import torch import torch.distributed as dist import torch.nn as nn -from coati.models.base import Actor +from coati.models.base import LM, Actor, RewardModel from coati.models.lora import LoraLinear from coati.replay_buffer import ReplayBuffer from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.data import DataLoader +from transformers.tokenization_utils_base import PreTrainedTokenizerBase from .base import Strategy from .naive import NaiveStrategy @@ -72,17 +75,32 @@ def _unwrap_actor(actor: Actor) -> nn.Module: model: DDP = Strategy._unwrap_actor(actor) return model.module - def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None: + def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + if only_rank0 and dist.get_rank() != 0: + return None + for module in model.modules(): if isinstance(module, LoraLinear): module.merge_weights = True module.eval() - - if only_rank0 and dist.get_rank() != 0: - return - model = model.model.module - state_dict = model.state_dict() - torch.save(state_dict, path) + + if isinstance(model, RewardModel): + state_dict = model.state_dict() + if only_rank0 and dist.get_rank() != 0: + return + torch.save(state_dict, path) + else: + try: + if isinstance(model, LM): + model = model.model + model.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) + except AttributeError: + state_dict = model.state_dict() + if only_rank0 and dist.get_rank() != 0: + return + torch.save(state_dict, path) def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: if only_rank0 and dist.get_rank() != 0: diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py index 80768d7e649c..bb47e5ab2688 100644 --- a/applications/Chat/coati/trainer/strategies/naive.py +++ b/applications/Chat/coati/trainer/strategies/naive.py @@ -1,11 +1,14 @@ -from typing import Any +from typing import Any, Optional import torch import torch.nn as nn import torch.optim as optim from coati.replay_buffer import ReplayBuffer +from coati.models.base import LM, RewardModel +from coati.models.lora import LoraLinear from torch.optim import Optimizer from torch.utils.data import DataLoader +from transformers.tokenization_utils_base import PreTrainedTokenizerBase from .base import Strategy @@ -38,9 +41,25 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False pin_memory=pin_memory, collate_fn=replay_buffer.collate_fn) - def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None: - unwrapped_model = self._unwrap_model(model) - torch.save(unwrapped_model.state_dict(), path) + def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + for module in model.modules(): + if isinstance(module, LoraLinear): + module.merge_weights = True + module.eval() + + if isinstance(model, RewardModel): + state_dict = model.state_dict() + torch.save(state_dict, path) + else: + try: + if isinstance(model, LM): + model = model.model + model.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) + except AttributeError: + state_dict = model.state_dict() + torch.save(state_dict, path) def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None: unwrapped_model = self._unwrap_model(model) From 573af8418406a319e91be07f58fca798a6e72dbd Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 4 Apr 2023 17:32:51 +0800 Subject: [PATCH 07/27] [example] update examples related to zero/gemini (#3431) * [zero] update legacy import * [zero] update examples * [example] fix opt tutorial * [example] fix opt tutorial * [example] fix opt tutorial * [example] fix opt tutorial * [example] fix import --- colossalai/zero/legacy/__init__.py | 3 ++- .../roberta/configs/colossalai_ddp.py | 7 ++++++- .../roberta/configs/colossalai_zero.py | 9 ++++++-- examples/tutorial/opt/opt/colossalai_zero.py | 6 +++++- examples/tutorial/opt/opt/requirements.txt | 1 + examples/tutorial/opt/opt/run_clm.py | 6 +++++- examples/tutorial/opt/opt/test_ci.sh | 21 +++++++++++++++++++ examples/tutorial/opt/test_ci.sh | 3 +++ 8 files changed, 50 insertions(+), 6 deletions(-) create mode 100755 examples/tutorial/opt/opt/test_ci.sh create mode 100755 examples/tutorial/opt/test_ci.sh diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/zero/legacy/__init__.py index 35570a1f539a..3783d38e61b2 100644 --- a/colossalai/zero/legacy/__init__.py +++ b/colossalai/zero/legacy/__init__.py @@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator +from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy from .sharded_model import ShardedModelV2 from .sharded_optim import ShardedOptimizerV2 @@ -40,5 +41,5 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model __all__ = [ 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', - 'no_shard_zero_decrator' + 'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy' ] diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py index c3c59aa4079c..3146ffc45eef 100644 --- a/examples/language/roberta/configs/colossalai_ddp.py +++ b/examples/language/roberta/configs/colossalai_ddp.py @@ -1,4 +1,9 @@ -from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.nn.optimizer import FusedAdam +try: + from colossalai.zero.shard_utils import TensorShardStrategy +except ImportError: + # colossalai > 0.2.8 + from colossalai.zero.legacy import TensorShardStrategy + clip_grad_norm = 1.0 diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py index c5debdce0988..bae4c723ccc8 100644 --- a/examples/language/roberta/configs/colossalai_zero.py +++ b/examples/language/roberta/configs/colossalai_zero.py @@ -1,6 +1,11 @@ -from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.nn.optimizer import FusedAdam +try: + from colossalai.zero.shard_utils import TensorShardStrategy +except ImportError: + # colossalai > 0.2.8 + from colossalai.zero.legacy import TensorShardStrategy + # fp16 = dict( # mode=AMP_TYPE.TORCH, # ) @@ -29,4 +34,4 @@ weight_decay=1e-2, ) -# 64433 \ No newline at end of file +# 64433 diff --git a/examples/tutorial/opt/opt/colossalai_zero.py b/examples/tutorial/opt/opt/colossalai_zero.py index 833745f3e8d8..7c2c152450c5 100644 --- a/examples/tutorial/opt/opt/colossalai_zero.py +++ b/examples/tutorial/opt/opt/colossalai_zero.py @@ -1,4 +1,8 @@ -from colossalai.zero.shard_utils import TensorShardStrategy +try: + from colossalai.zero.shard_utils import TensorShardStrategy +except ImportError: + # colossalai > 0.2.8 + from colossalai.zero.legacy import TensorShardStrategy zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), tensor_placement_policy="auto", diff --git a/examples/tutorial/opt/opt/requirements.txt b/examples/tutorial/opt/opt/requirements.txt index c34df7992d3f..d0ed2c717aee 100644 --- a/examples/tutorial/opt/opt/requirements.txt +++ b/examples/tutorial/opt/opt/requirements.txt @@ -4,3 +4,4 @@ datasets >= 1.8.0 sentencepiece != 0.1.92 protobuf accelerate == 0.13.2 +transformers diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index e618b4d66957..fdc86adab665 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -413,7 +413,11 @@ def main(): cai_version = colossalai.__version__ logger.info(f'using Colossal-AI version {cai_version}') if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP + try: + from colossalai.nn.parallel import GeminiDDP + except ImportError: + # this works for unreleased main branch, and this may be released on 0.2.9 + from colossalai.zero import GeminiDDP model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True) elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): from colossalai.gemini import ChunkManager, GeminiManager diff --git a/examples/tutorial/opt/opt/test_ci.sh b/examples/tutorial/opt/opt/test_ci.sh new file mode 100755 index 000000000000..e505da1364de --- /dev/null +++ b/examples/tutorial/opt/opt/test_ci.sh @@ -0,0 +1,21 @@ +#!/bin/bash + +set -xue + +pip install -r requirements.txt + +BS=8 +MEMCAP=0 +GPUNUM=2 +MODLE="facebook/opt-125m" + +torchrun \ + --nproc_per_node ${GPUNUM} \ + --master_port 19198 \ + run_clm.py \ + -s \ + --output_dir $PWD \ + --mem_cap ${MEMCAP} \ + --model_name_or_path ${MODLE} \ + --per_device_train_batch_size ${BS} \ + --num_train_epochs 1 diff --git a/examples/tutorial/opt/test_ci.sh b/examples/tutorial/opt/test_ci.sh new file mode 100755 index 000000000000..8341bb10510f --- /dev/null +++ b/examples/tutorial/opt/test_ci.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +cd opt && bash test_ci.sh From ffcdbf0f6519366f322f2809e19f741492779c6c Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 4 Apr 2023 17:40:45 +0800 Subject: [PATCH 08/27] [autoparallel]integrate auto parallel feature with new tracer (#3408) * [autoparallel] integrate new analyzer in module level * unify the profiling method * polish * fix no codegen bug * fix pass bug * fix liveness test * polish --- .../_analyzer/_subclasses/flop_tensor.py | 23 +++- colossalai/_analyzer/fx/codegen.py | 19 +-- colossalai/_analyzer/fx/node_util.py | 2 +- colossalai/_analyzer/fx/passes/shape_prop.py | 9 +- .../_analyzer/fx/tracer/bias_addition.py | 114 +++++++++++----- colossalai/_analyzer/fx/tracer/tracer.py | 2 +- .../auto_parallel/meta_profiler/__init__.py | 2 +- .../meta_profiler/meta_registry/activation.py | 4 +- .../meta_registry/binary_elementwise_ops.py | 6 +- .../meta_profiler/meta_registry/conv.py | 28 ++-- .../meta_profiler/meta_registry/embedding.py | 8 +- .../meta_profiler/meta_registry/linear.py | 79 +++++------ .../meta_profiler/meta_registry/non_spmd.py | 2 - .../meta_profiler/meta_registry/norm.py | 34 ++--- .../meta_profiler/meta_registry/pooling.py | 14 +- .../meta_profiler/meta_registry/tensor.py | 10 +- .../meta_profiler/meta_registry/where.py | 4 +- .../{metainfo.py => shard_metainfo.py} | 16 +-- .../passes/comm_metainfo_pass.py | 28 ++-- .../auto_parallel/passes/meta_info_prop.py | 10 +- .../passes/runtime_apply_pass.py | 16 ++- .../passes/runtime_preparation_pass.py | 14 +- .../auto_parallel/tensor_shard/initialize.py | 17 ++- .../node_handler/batch_norm_handler.py | 2 - .../tensor_shard/node_handler/node_handler.py | 22 +-- .../solver/strategies_constructor.py | 18 +-- tests/test_analyzer/__init__.py | 0 .../test_size_value_converting_pass.py | 11 +- .../test_bias_addition_forward.py | 8 +- .../test_tensor_shard/test_checkpoint.py | 14 +- .../test_compatibility_with_ddp.py | 8 +- .../test_compatibility_with_gemini.py | 8 +- .../test_find_repeat_block.py | 8 +- .../test_gpt/test_runtime_with_gpt_modules.py | 34 +++-- .../test_gpt/test_solver_with_gpt_module.py | 8 +- .../test_liveness_analysis.py | 16 ++- .../test_metainfo/test_embedding_metainfo.py | 2 +- .../test_metainfo/test_linear_metainfo.py | 2 +- .../test_metainfo/test_matmul_metainfo.py | 2 +- .../test_metainfo/test_norm_metainfo.py | 2 +- .../test_metainfo/test_tensor_metainfo.py | 2 +- .../test_metainfo/test_where_metainfo.py | 2 +- .../test_tensor_shard/test_metainfo/utils.py | 19 ++- .../test_param_resharding_cost.py | 126 ------------------ .../test_shape_consistency_pass.py | 86 ------------ .../test_solver_with_resnet_v2.py | 7 +- 46 files changed, 397 insertions(+), 471 deletions(-) rename colossalai/auto_parallel/meta_profiler/{metainfo.py => shard_metainfo.py} (94%) create mode 100644 tests/test_analyzer/__init__.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py delete mode 100644 tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py index dd35b00b3fab..59991dc50912 100644 --- a/colossalai/_analyzer/_subclasses/flop_tensor.py +++ b/colossalai/_analyzer/_subclasses/flop_tensor.py @@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: # Inputs contains the shapes of two matrices. input_shapes = [v.shape for v in inputs] assert len(input_shapes) == 2, input_shapes - assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + + # There are three cases: 1) gemm, 2) gemv, 3) dot + if all(len(shape) == 2 for shape in input_shapes): + # gemm + assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes + elif all(len(shape) == 1 for shape in input_shapes): + # dot + assert input_shapes[0][0] == input_shapes[1][0], input_shapes + + # expand shape + input_shapes[0] = torch.Size([1, input_shapes[0][0]]) + input_shapes[1] = torch.Size([input_shapes[1][0], 1]) + else: + # gemv + if len(input_shapes[0]) == 1: + assert input_shapes[0][0] == input_shapes[1][-2], input_shapes + input_shapes.reverse() + else: + assert input_shapes[1][0] == input_shapes[0][-1], input_shapes + + # expand the shape of the vector to [batch size, 1] + input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1]) flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1] return flops diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py index 1117c0103166..b768e59004b1 100644 --- a/colossalai/_analyzer/fx/codegen.py +++ b/colossalai/_analyzer/fx/codegen.py @@ -1,8 +1,12 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple import torch + +try: + from torch.fx.graph import CodeGen +except: + pass from torch.fx.graph import ( - CodeGen, PythonCode, _custom_builtins, _format_target, @@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool: """ Check if the node could end the ckpt region at `ckpt_level` """ - if len(node.meta['info'].to_recompute) > ckpt_level: - return node.meta['info'].to_recompute[ckpt_level] is not None + if len(node.meta['info'].activation_checkpoint) > ckpt_level: + return node.meta['info'].activation_checkpoint[ckpt_level] is not None return True @@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0): current_region = None for idx, node in enumerate(node_list): - if len(node.meta['info'].to_recompute) > ckpt_level: - act_ckpt_label = node.meta['info'].to_recompute[ckpt_level] + if len(node.meta['info'].activation_checkpoint) > ckpt_level: + act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level] # this activation checkpoint label is not set yet # meaning this is the first node of the activation ckpt region @@ -152,12 +156,12 @@ def emit_ckpt_func(body, # label given by each layer, e.g. if you are currently at level (0, 1, 1) # the label will be '0_1_1' - label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]]) + label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]]) ckpt_fn_def = _gen_ckpt_fn_def(label, inputs) ckpt_func.append(f'{ckpt_fn_def}\n') # if there is more level to fetch - if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)): + if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)): ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] @@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, ckpt_regions = _find_nested_ckpt_regions(nodes, 0) start_idx = [item[0] for item in ckpt_regions] end_idx = [item[1] for item in ckpt_regions] - node_list = list(nodes) node_idx = 0 diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py index 8c8956d8ea7c..fbe8400a437e 100644 --- a/colossalai/_analyzer/fx/node_util.py +++ b/colossalai/_analyzer/fx/node_util.py @@ -112,7 +112,7 @@ class MetaInfo: # should keep the same whenever manipulated # ============================= Invariant ================================== - to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen + activation_checkpoint: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen to_offload: Optional[bool] = False sharding_spec: str = 'RR' diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py index b3859e250ac8..23e83013e02f 100644 --- a/colossalai/_analyzer/fx/passes/shape_prop.py +++ b/colossalai/_analyzer/fx/passes/shape_prop.py @@ -237,7 +237,14 @@ def propagate(self, *args, device=None): Returns: Any: The value returned from executing the Module """ - wrap_fn = lambda elem: MetaTensor(elem, device=device) + + # wrap_fn = lambda elem: MetaTensor(elem, device=device) + def wrap_fn(elem, device=device): + if isinstance(elem, torch.Tensor): + return MetaTensor(elem, device=device) + else: + return elem + with self._mode: return super().run(*tree_map(wrap_fn, args)) diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py index 495678501664..1e75b47ca5b0 100644 --- a/colossalai/_analyzer/fx/tracer/bias_addition.py +++ b/colossalai/_analyzer/fx/tracer/bias_addition.py @@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None): @register_tracer_impl(F.conv1d, name='_bias_addition_impl') -def conv1d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1): if bias is None: - return F.conv1d(input, weight, **kwargs) + return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1)) + return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1)) @register_tracer_impl(F.conv2d, name='_bias_addition_impl') -def conv2d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1): if bias is None: - return F.conv2d(input, weight, **kwargs) + return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1)) + return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1, 1)) @register_tracer_impl(F.conv3d, name='_bias_addition_impl') -def conv3d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1): if bias is None: - return F.conv3d(input, weight, **kwargs) + return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1)) + return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape( + (-1, 1, 1, 1)) @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl') -def conv_transpose1d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv_transpose1d_impl(input, + weight, + bias=None, + stride=_single(1), + padding=_single(0), + output_padding=_single(0), + groups=1, + dilation=_single(1)): if bias is None: - return F.conv_transpose1d(input, weight, **kwargs) + return F.conv_transpose1d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1)) + return F.conv_transpose1d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1)) @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl') -def conv_transpose2d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv_transpose2d_impl(input, + weight, + bias=None, + stride=_pair(1), + padding=_pair(0), + output_padding=_pair(0), + groups=1, + dilation=_pair(1)): if bias is None: - return F.conv_transpose2d(input, weight, **kwargs) + return F.conv_transpose2d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1)) + return F.conv_transpose2d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1, 1)) @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl') -def conv_transpose3d_impl(input, weight, **kwargs): - bias = getattr(kwargs, 'bias', None) +def conv_transpose3d_impl(input, + weight, + bias=None, + stride=_triple(1), + padding=_triple(0), + output_padding=_triple(0), + groups=1, + dilation=_triple(1)): if bias is None: - return F.conv_transpose3d(input, weight, **kwargs) + return F.conv_transpose3d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) else: - new_kwargs = kwargs - new_kwargs['bias'] = None - return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1)) + return F.conv_transpose3d(input, + weight, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + dilation=dilation) + bias.reshape((-1, 1, 1, 1)) @register_tracer_impl(torch.addmm, name='_bias_addition_impl') diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py index 1a247449f3d8..6958a00a6a72 100644 --- a/colossalai/_analyzer/fx/tracer/tracer.py +++ b/colossalai/_analyzer/fx/tracer/tracer.py @@ -155,7 +155,7 @@ def create_proxy(self, def create_node(self, *args, **kwargs) -> Node: node = super().create_node(*args, **kwargs) - n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions)) + n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions)) return node def trace(self, diff --git a/colossalai/auto_parallel/meta_profiler/__init__.py b/colossalai/auto_parallel/meta_profiler/__init__.py index bfd36195149b..3741d8e5a8ad 100644 --- a/colossalai/auto_parallel/meta_profiler/__init__.py +++ b/colossalai/auto_parallel/meta_profiler/__init__.py @@ -1,3 +1,3 @@ from .meta_registry import * -from .metainfo import * from .registry import meta_register +from .shard_metainfo import * diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py index faeed9f29e61..0f2e9e44f91c 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py @@ -2,9 +2,9 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import elementwise_flop_counter from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py index 281a92c0d4f1..e451748512b9 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py @@ -2,9 +2,9 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION from ..registry import meta_register @@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train """Meta information generator for binary elementwise operations NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`, - they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify + they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify this behavior, it is critical for better memory estimation. Returns: diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py index d1bb6e7fa798..4336bf68363c 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/conv.py @@ -2,6 +2,8 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( MemoryCost, OperationData, @@ -10,8 +12,6 @@ StrategiesVector, TrainCycleItem, ) -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from colossalai.tensor.sharding_spec import ShardingSpec from ..registry import meta_register @@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # calculate memory cost # TODO: use profiler to check conv temp memory # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost( - activation=activation_size([input_tensor, output_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), - temp=0, - buffer=0) - - bwd_memory_cost = MemoryCost( - activation=activation_size([input_tensor, weight_tensor, bias_tensor]) - if has_bias else activation_size([input_tensor, weight_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor), - temp=0, - buffer=0) + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) + + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) + if has_bias else compute_size_in_bytes(weight_tensor), + temp=0, + buffer=0) # total cost is the sum of forward and backward cost total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py index 2997f31adff8..d5d80f5b3700 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py @@ -2,9 +2,9 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register @@ -34,11 +34,11 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0) - bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py index 617375721222..7697fc6c383d 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py @@ -3,6 +3,8 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( MemoryCost, OperationData, @@ -11,8 +13,6 @@ StrategiesVector, TrainCycleItem, ) -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from colossalai.tensor.sharding_spec import ShardingSpec from ..registry import meta_register @@ -112,14 +112,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, buffer=0) # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, buffer=0) @@ -148,14 +148,14 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # NOTE: Linear don't have buffer and temp in forward and backward phase # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]), - parameter=activation_size(weight_tensor), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), + parameter=compute_size_in_bytes(weight_tensor), temp=0, buffer=0) # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor]), - parameter=activation_size(weight_tensor), + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), + parameter=compute_size_in_bytes(weight_tensor), temp=0, buffer=0) @@ -210,48 +210,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L # Check dimension if all(len(tensor.shape) == 1 for tensor in input_tensors): # Dot - fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors) + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2 - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1: # gemv case 1: matrix-vector multiplication # & # batched gemv case 1: batched matrix-vector multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default]( + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) # combine the dimensions of output bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors) + \ - flop_mapping[torch.ops.aten.mv.default]( + flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], output_tensors) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2: # gemv case 2: vector-matrix multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors) + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ - flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, - temp=activation_size(input_tensors[1]), + temp=compute_size_in_bytes(input_tensors[1]), buffer=0) elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: # batched gemv case 2: vector-batched matrix multiplication - fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default]( + fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], [output_tensors[0].reshape(-1)]) @@ -260,15 +260,15 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors ) + \ - flop_mapping[torch.ops.aten.mv.default]( + flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], output_tensors ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]])) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]), + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), parameter=0, - temp=activation_size(input_tensors[1]), + temp=compute_size_in_bytes(input_tensors[1]), buffer=0) elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: @@ -287,8 +287,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: # batched gemm case 2: matrix-batched matrix multiplication @@ -306,11 +306,12 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]), - temp=activation_size(output_tensors)) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]), + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + + compute_size_in_bytes(input_tensors[1]), + temp=compute_size_in_bytes(output_tensors)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), parameter=0, - temp=activation_size(input_tensors[1]) + activation_size(output_tensors)) + temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) elif all(len(tensor.shape) >= 3 for tensor in input_tensors): # Batched matrix-batched matrix multiplication @@ -351,8 +352,8 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] ) - fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors)) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors)) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) else: # Case 2: batch dimensions are different @@ -381,10 +382,10 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ) fwd_mem_cost = MemoryCost( - activation=activation_size([output_tensors[0], extended_input_0, extended_input_1])) - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) - - activation_size([extended_input_0, extended_input_1]), - temp=activation_size([extended_input_0, extended_input_1])) + activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - + compute_size_in_bytes([extended_input_0, extended_input_1]), + temp=compute_size_in_bytes([extended_input_0, extended_input_1])) # compute cost compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py index 4634d3ccdcfd..12874810b13e 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py @@ -4,8 +4,6 @@ import torch from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py index 3a1db396e188..b872fdc8bdcd 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py @@ -2,6 +2,8 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( MemoryCost, OperationData, @@ -10,8 +12,6 @@ StrategiesVector, TrainCycleItem, ) -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from colossalai.tensor.sharding_spec import ShardingSpec from ..registry import meta_register @@ -77,17 +77,18 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt # calculate memory cost # the fwd activation cost is output plus saved mean and saved inv std # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, mean_tensor, var_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( + [input_tensor, output_tensor, mean_tensor, var_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, - buffer=activation_size([mean_tensor, var_tensor])) + buffer=compute_size_in_bytes([mean_tensor, var_tensor])) # the bwd memory cost is quite tricky here, BatchNorm will remove saved mean # and saved inv std during backward phase - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), - temp=activation_size([mean_tensor, var_tensor]), - buffer=activation_size([mean_tensor, var_tensor])) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([mean_tensor, var_tensor]), + buffer=compute_size_in_bytes([mean_tensor, var_tensor])) # total cost is the sum of forward and backward cost total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, @@ -131,15 +132,16 @@ def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem # memory cost # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), + fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes( + [input_tensor, output_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), temp=0, - buffer=activation_size([running_mean, running_var])) + buffer=compute_size_in_bytes([running_mean, running_var])) - bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]), - parameter=activation_size([weight_tensor, bias_tensor]), - temp=activation_size([running_mean, running_var]), - buffer=activation_size([running_mean, running_var])) + bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), + parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), + temp=compute_size_in_bytes([running_mean, running_var]), + buffer=compute_size_in_bytes([running_mean, running_var])) total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py index 21272ea09ac1..d785dfcca9ba 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py @@ -2,9 +2,9 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register @@ -52,8 +52,8 @@ def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) # calculate memory cost - fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(output_tensor)) - bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=activation_size(input_tensor)) + fwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(output_tensor)) + bwd_mem_cost = MemoryCost() if is_inplace else MemoryCost(activation=compute_size_in_bytes(input_tensor)) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation) @@ -114,11 +114,11 @@ def maxpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, # calculate memory cost # NOTE: the index matrix will be discarded in backward phase # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_mem_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, index_matrix])) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor, index_matrix])) # temp memory for backward is the index matrix to be discarded - bwd_mem_cost = MemoryCost(activation=activation_size(input_tensor) - activation_size(index_matrix), - temp=activation_size(index_matrix)) + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensor) - compute_size_in_bytes(index_matrix), + temp=compute_size_in_bytes(index_matrix)) # total cost total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, temp=bwd_mem_cost.temp) diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py index 332e649d2d7e..97fe3c6196f5 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py @@ -2,9 +2,9 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register @@ -35,11 +35,11 @@ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[tor # memory costs # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward - fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0) + fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * 2, parameter=0, temp=0, buffer=0) - bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor, + bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(outputs) * bwd_mem_out_factor, parameter=0, - temp=activation_size(outputs) * bwd_mem_tmp_factor, + temp=compute_size_in_bytes(outputs) * bwd_mem_tmp_factor, buffer=0) total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py index c67eb40bc80e..5cba1b5b6e2b 100644 --- a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py +++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py @@ -2,9 +2,9 @@ import torch +from colossalai._analyzer._subclasses.flop_tensor import flop_mapping +from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem -from colossalai.fx.profiler.memory_utils import activation_size -from colossalai.fx.profiler.opcount import flop_mapping from ..registry import meta_register diff --git a/colossalai/auto_parallel/meta_profiler/metainfo.py b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py similarity index 94% rename from colossalai/auto_parallel/meta_profiler/metainfo.py rename to colossalai/auto_parallel/meta_profiler/shard_metainfo.py index 44b1882e06cc..0eee908b48b7 100644 --- a/colossalai/auto_parallel/meta_profiler/metainfo.py +++ b/colossalai/auto_parallel/meta_profiler/shard_metainfo.py @@ -15,11 +15,11 @@ from .constants import INPLACE_MODULE, INPLACE_OPS, NO_SAVE_ACTIVATION from .registry import meta_register -__all__ = ['MetaInfo'] +__all__ = ['ShardMetaInfo'] -class MetaInfo: - """MetaInfo class +class ShardMetaInfo: + """ShardMetaInfo class This class is used to store meta info based on sharding strategy and the given target function. """ @@ -46,9 +46,9 @@ def __init__(self, strategy: ShardingStrategy = None, target: Callable = None) - # target function self._target = target - # compute metainfo if possible + # compute shard_metainfo if possible if self._strategy is not None and self._target is not None: - self.compute_metainfo() + self.compute_shard_metainfo() @property def strategy(self) -> ShardingStrategy: @@ -62,13 +62,13 @@ def target(self) -> Callable: def strategy(self, strategy: ShardingStrategy) -> None: self._strategy = strategy if self._strategy is not None and self._target is not None: - self.compute_metainfo() + self.compute_shard_metainfo() @target.setter def target(self, target: Callable) -> None: self._target = target if self._strategy is not None and self._target is not None: - self.compute_metainfo() + self.compute_shard_metainfo() def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: ShardingSpec): """ @@ -93,7 +93,7 @@ def compute_sharded_opdata(self, operation_data: OperationData, sharding_spec: S return op_data - def compute_metainfo(self): + def compute_shard_metainfo(self): """ Compute meta info based on sharding strategy and the given target function. """ diff --git a/colossalai/auto_parallel/passes/comm_metainfo_pass.py b/colossalai/auto_parallel/passes/comm_metainfo_pass.py index ab3acb0563ff..ffda58e0689f 100644 --- a/colossalai/auto_parallel/passes/comm_metainfo_pass.py +++ b/colossalai/auto_parallel/passes/comm_metainfo_pass.py @@ -4,7 +4,7 @@ from torch.fx import GraphModule from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.meta_profiler import ShardMetaInfo from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.tensor.comm_spec import CommSpec @@ -14,15 +14,15 @@ shape_consistency_manager = ShapeConsistencyManager() -def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, - target_sharding_spec: ShardingSpec) -> MetaInfo: +def _construct_shard_meta_info(node: Node, origin_sharding_spec: ShardingSpec, + target_sharding_spec: ShardingSpec) -> ShardMetaInfo: # get comm_action_sequence and total_cost from shape_consistency_manager _, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( origin_sharding_spec, target_sharding_spec) - meta_info = MetaInfo() + meta_info = ShardMetaInfo() # NOTE: the cost in shape_consistency_manager.mem_cost is the count in number of numel - # get mem cost for MetaInfo + # get mem cost for ShardMetaInfo mem_cost = shape_consistency_manager.mem_cost(comm_action_sequence) # extract user that has _meta_data and extract element length input_node = next(n for n in node._input_nodes if hasattr(n, '_meta_data')) @@ -36,12 +36,12 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, meta_info.memory_cost = mem_cost - # get computation cost for MetaInfo + # get computation cost for ShardMetaInfo meta_info.compute_cost = TrainCycleItem(total_cost['forward'] * element_length, total_cost['backward'] * element_length, total_cost['total'] * element_length) - # get tensor shape for MetaInfo + # get tensor shape for ShardMetaInfo origin_sharding_spec: ShardingSpec target_sharding_spec: ShardingSpec input_shape = origin_sharding_spec.get_sharded_shape_per_device() @@ -54,7 +54,7 @@ def _construct_meta_info(node: Node, origin_sharding_spec: ShardingSpec, return meta_info -def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> MetaInfo: +def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) -> ShardMetaInfo: """ This method is used to construct `MetaInto` for shape consistency node """ @@ -65,17 +65,17 @@ def _runtime_apply_meta_info(node: Node, origin_spec_dict, sharding_spec_dict) - origin_sharding_spec, target_sharding_spec = origin_spec_dict[node_index], sharding_spec_dict[node_index][ user_node_index] - return _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + return _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) -def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> MetaInfo: +def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> ShardMetaInfo: # extract node_index and op_data_name node_index, op_data_name = node.args[2], node.args[3] comm_action = comm_actions_dict[node_index][op_data_name] if isinstance(comm_action.comm_spec, CommSpec): # this case is for all_reduce, there will be no memory cost - meta_info = MetaInfo() + meta_info = ShardMetaInfo() meta_info.memory_cost = TrainCycleItem(MemoryCost(), MemoryCost(), MemoryCost) output_node = next(n for n in node.users if hasattr(n, '_meta_data')) element_length = output_node._meta_data.element_size() @@ -93,7 +93,7 @@ def _runtime_comm_spec_apply_meta_info(node: Node, comm_actions_dict: Dict) -> M # this case will be handled by shape consistency manager origin_sharding_spec, target_sharding_spec = comm_action.comm_spec['src_spec'], comm_action.comm_spec[ 'tgt_spec'] - meta_info = _construct_meta_info(node, origin_sharding_spec, target_sharding_spec) + meta_info = _construct_shard_meta_info(node, origin_sharding_spec, target_sharding_spec) return meta_info @@ -105,9 +105,9 @@ def comm_metainfo_pass(gm: GraphModule, sharding_spec_dict: Dict, origin_spec_di """ for node in gm.graph.nodes: if node.target == runtime_apply: - setattr(node, 'best_metainfo', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) + setattr(node, 'best_strategy_info', _runtime_apply_meta_info(node, origin_spec_dict, sharding_spec_dict)) elif node.target == runtime_comm_spec_apply: - setattr(node, 'best_metainfo', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) + setattr(node, 'best_strategy_info', _runtime_comm_spec_apply_meta_info(node, comm_actions_dict)) else: pass return gm diff --git a/colossalai/auto_parallel/passes/meta_info_prop.py b/colossalai/auto_parallel/passes/meta_info_prop.py index f7e07ef1ec18..bc0960483980 100644 --- a/colossalai/auto_parallel/passes/meta_info_prop.py +++ b/colossalai/auto_parallel/passes/meta_info_prop.py @@ -7,7 +7,7 @@ from torch.fx import GraphModule from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai.auto_parallel.meta_profiler import ShardMetaInfo from colossalai.auto_parallel.passes.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS from colossalai.fx._compatibility import compatibility from colossalai.fx.profiler import GraphInfo @@ -96,12 +96,12 @@ def node_handler(self, node: Node) -> None: """ Handle other kind of nodes """ - assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}" + assert hasattr(node, 'best_strategy_info'), f"Cannot find best_strategy_info in node {node}, {node.op}" graph_info = GraphInfo() - meta_info = node.best_metainfo - meta_info: MetaInfo + meta_info = node.best_strategy_info + meta_info: ShardMetaInfo - # set data_ptr for input_tensor in MetaInfo class + # set data_ptr for input_tensor in ShardMetaInfo class input_tensors: List[torch.Tensor] = meta_info.fwd_in buffer_tensors: List[torch.Tensor] = meta_info.fwd_buffer output_tensors: List[torch.Tensor] = meta_info.fwd_out diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py index 9d83f105748b..a473bb6e973d 100644 --- a/colossalai/auto_parallel/passes/runtime_apply_pass.py +++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py @@ -4,7 +4,7 @@ import torch from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler import MetaInfo +from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, CommType, @@ -128,9 +128,10 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule): runtime_apply, args=(node, origin_dict_node, input_dict_node, node_to_index_dict[node], user_node_index)) - if 'activation_checkpoint' in user_node.meta: - shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint'] - + if hasattr(user_node.meta['info'], 'activation_checkpoint'): + MetaInfo(shape_consistency_node, + mod_dir=user_node.meta['info'].mod_dir, + activation_checkpoint=tuple(user_node.meta['info'].activation_checkpoint)) new_args = list(user_node.args) new_kwargs = dict(user_node.kwargs) # the origin node may be a positional argument or key word argument of user node @@ -210,9 +211,10 @@ def _comm_spec_apply(gm: torch.fx.GraphModule): # substitute the origin node with comm_spec_apply_node new_kwargs[str(node)] = comm_spec_apply_node user.kwargs = new_kwargs - - if 'activation_checkpoint' in node.meta: - comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + if hasattr(node.meta['info'], 'activation_checkpoint'): + MetaInfo(comm_spec_apply_node, + mod_dir=node.meta['info'].mod_dir, + activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) return gm diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 3be3084222fe..e1d0c627274e 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -6,6 +6,7 @@ from torch.fx import symbolic_trace from torch.fx.node import Node +from colossalai._analyzer.fx.node_util import MetaInfo from colossalai.auto_parallel.tensor_shard.constants import RESHAPE_FUNC_OP from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommAction, @@ -74,9 +75,9 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name( str(node)) - # attach the corresponding metainfo if node has the attribute `metainfo_vector` - if hasattr(node, 'metainfo_vector'): - setattr(node, 'best_metainfo', node.metainfo_vector[strategy_index]) + # attach the corresponding metainfo if node has the attribute `strategies_info` + if hasattr(node, 'strategies_info'): + setattr(node, 'best_strategy_info', node.strategies_info[strategy_index]) # the dict to get input sharding specs of user node sharding_spec_convert_dict = {} @@ -172,8 +173,11 @@ def _post_processing(node, size_processing_node): # It will be used to replace the original node with processing node in slice object node_pairs[node] = size_processing_node size_processing_node._meta_data = node._meta_data - if 'activation_checkpoint' in node.meta: - size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint'] + + if hasattr(node.meta['info'], 'activation_checkpoint'): + MetaInfo(size_processing_node, + mod_dir=node.meta['info'].mod_dir, + activation_checkpoint=tuple(node.meta['info'].activation_checkpoint)) user_list = list(node.users.keys()) for user in user_list: diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py index 60472eee52ca..b406ca6fb7e0 100644 --- a/colossalai/auto_parallel/tensor_shard/initialize.py +++ b/colossalai/auto_parallel/tensor_shard/initialize.py @@ -6,6 +6,10 @@ from torch.fx import GraphModule from torch.fx.graph import Graph +from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference @@ -13,8 +17,6 @@ from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec @@ -126,6 +128,7 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc def transform_to_sharded_model(gm: ColoGraphModule, + meta_args: Dict, solution: List[int], device_mesh: DeviceMesh, strategies_constructor: StrategiesConstructor, @@ -142,6 +145,7 @@ def transform_to_sharded_model(gm: ColoGraphModule, strategies_constructor, overlap=overlap) gm = runtime_apply_pass(gm) + shape_prop_pass(gm, *meta_args.values(), sharding_spec_dict, origin_spec_dict, comm_actions_dict) gm.recompile() sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict) @@ -243,10 +247,13 @@ def initialize_model(model: nn.Module, solution will be used to debug or help to analyze the sharding result. Therefore, we will not just return a series of integers, but return the best strategies. ''' - tracer = ColoTracer(trace_act_ckpt=True) + tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) + graph.set_codegen(ActivationCheckpointCodeGen()) gm = ColoGraphModule(model, graph, model.__class__.__name__) + + shape_prop_pass(gm, *meta_args.values()) gm.recompile() strategies_constructor = build_strategy_constructor(graph, @@ -261,7 +268,9 @@ def initialize_model(model: nn.Module, if save_solver_solution: torch.save(solution, solution_path) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap) + gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_args, solution, device_mesh, strategies_constructor, + overlap) + model_to_return = ModuleWrapper(gm, *sharding_spec_dicts) if return_solution: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py index 57b623b0122c..cb1bb36b7879 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/batch_norm_handler.py @@ -2,8 +2,6 @@ import torch -from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo - from ..sharding_strategy import OperationData, OperationDataType, StrategiesVector from .node_handler import MetaInfoModuleHandler, ModuleHandler from .registry import operator_registry diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py index 136e57c5e0f5..ab391ebfaf80 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py @@ -4,7 +4,7 @@ import torch from torch.fx.node import Node -from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register +from colossalai.auto_parallel.meta_profiler.shard_metainfo import ShardMetaInfo, meta_register from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, @@ -258,7 +258,7 @@ class MetaInfoNodeHandler(NodeHandler): def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: """ This method is inherited from NodeHandler. It will register the strategies first, - and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class. + and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class. """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() @@ -266,15 +266,15 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV # is not patched, we will use the default cost model to compute the cost. # TODO: patch all torch functions and modules to make it clean if meta_register.has(target.__class__) or meta_register.has(target): - metainfo_vector = [] + strategies_info = [] for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) + metainfo = ShardMetaInfo(strategy, target) strategy.compute_cost = metainfo.compute_cost strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) + strategies_info.append(metainfo) # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + setattr(self, "strategies_info", strategies_info) else: logger = get_dist_logger() @@ -313,7 +313,7 @@ class MetaInfoModuleHandler(ModuleHandler): def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesVector: """ This method is inherited from NodeHandler. It will register the strategies first, - and rewrite the memory_cost and compute_cost of the strategy using the MetaInfo class. + and rewrite the memory_cost and compute_cost of the strategy using the ShardMetaInfo class. """ super().register_strategy(compute_resharding_cost=compute_resharding_cost) target = self.get_target_function() @@ -321,15 +321,15 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV # is not patched, we will use the default cost model to compute the cost. # TODO: patch all torch functions and modules to make it clean if meta_register.has(target.__class__) or meta_register.has(target): - metainfo_vector = [] + strategies_info = [] for strategy in self.strategies_vector: - metainfo = MetaInfo(strategy, target) + metainfo = ShardMetaInfo(strategy, target) strategy.compute_cost = metainfo.compute_cost strategy.memory_cost = metainfo.memory_cost - metainfo_vector.append(metainfo) + strategies_info.append(metainfo) # attach metainfos to the handler - setattr(self, "metainfo_vector", metainfo_vector) + setattr(self, "strategies_info", strategies_info) else: logger = get_dist_logger() diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 59ead1ca8fac..044a8ac847ea 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -137,9 +137,9 @@ def _check_no_strategy_for_data(data): shard_option=self.solver_options.shard_option, solver_perference=self.solver_options.solver_perference) handler.register_strategy() - # attach metainfo_vector to node - if hasattr(handler, 'metainfo_vector'): - setattr(node, 'metainfo_vector', handler.metainfo_vector) + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) # call_function node elif node.op == 'call_function': @@ -150,9 +150,9 @@ def _check_no_strategy_for_data(data): shard_option=self.solver_options.shard_option, solver_perference=self.solver_options.solver_perference) handler.register_strategy() - # attach metainfo_vector to node - if hasattr(handler, 'metainfo_vector'): - setattr(node, 'metainfo_vector', handler.metainfo_vector) + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) # call_method node elif node.op == 'call_method': @@ -163,9 +163,9 @@ def _check_no_strategy_for_data(data): shard_option=self.solver_options.shard_option, solver_perference=self.solver_options.solver_perference) handler.register_strategy() - # attach metainfo_vector to node - if hasattr(handler, 'metainfo_vector'): - setattr(node, 'metainfo_vector', handler.metainfo_vector) + # attach strategies_info to node + if hasattr(handler, 'strategies_info'): + setattr(node, 'strategies_info', handler.strategies_info) # output node elif node.op == 'output': diff --git a/tests/test_analyzer/__init__.py b/tests/test_analyzer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index 3494830080ff..7d4fd844ab26 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -1,10 +1,12 @@ +import pytest import torch import torch.nn.functional as F +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec @@ -33,6 +35,7 @@ def recover_narrow(gm, narrow_node): return gm +@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) @@ -40,14 +43,14 @@ def test_size_value_converting_pass(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) meta_args = {'x': torch.rand(4, 8).to('meta')} input = torch.rand(4, 8) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_args) - x_node = list(graph.nodes)[0] x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]}) setattr(x_node, 'sharding_spec', x_sharding_spec) gm = ColoGraphModule(model, graph) gm = insert_narrow(gm, x_node) + shape_prop_pass(gm, *meta_args.values()) gm.recompile() size = gm(input) assert size == torch.Size([2, 8]) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index f43885a6ac44..6d1b28912c9b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -4,7 +4,12 @@ import torch import torch.multiprocessing as mp -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -77,6 +82,7 @@ def check_conv_module(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 0b42722fec5f..7a4c8d32ed80 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -8,13 +8,15 @@ from torch.utils.checkpoint import checkpoint from transformers.pytorch_utils import Conv1D -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.graph_module import ColoGraphModule -from colossalai.fx.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import free_port @@ -43,6 +45,7 @@ def check_act_ckpt(rank, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE) + input = torch.rand(1, 64, HIDDEN_SIZE) input_sample = { 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'), } @@ -54,10 +57,11 @@ def check_act_ckpt(rank, world_size, port): gm = initialize_model(model, input_sample, device_mesh) code = gm.module.graph.python_code('self').src assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code - assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code + assert "view_3 = torch.utils.checkpoint.checkpoint(self.checkpoint_0, view_1, comm_actions_dict, use_reentrant=False)" in code @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index e4982a5d7f5a..7c3277c69970 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -6,7 +6,12 @@ import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -93,6 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 9879ae461848..e4435a049f62 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -6,7 +6,12 @@ import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model +try: + from colossalai.auto_parallel.tensor_shard.initialize import initialize_model + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers @@ -101,6 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason='No codegen found') @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index 90301521f207..e7fccad36caf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -5,8 +5,11 @@ from torch.fx import GraphModule from transformers.pytorch_utils import Conv1D +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.testing import parameterize from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -83,11 +86,12 @@ def test_repeat_blocks(model_cls): model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')} graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) gm.recompile() node_list = list(graph.nodes) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index ebeef9870fe9..8688890efc93 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -10,15 +10,23 @@ import transformers from torch.fx import GraphModule -from colossalai.auto_parallel.tensor_shard.initialize import ( - ModuleWrapper, - build_strategy_constructor, - solve_solution, - transform_to_sharded_model, -) +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer + +try: + from colossalai.auto_parallel.tensor_shard.initialize import ( + ModuleWrapper, + build_strategy_constructor, + solve_solution, + transform_to_sharded_model, + ) + NO_CODEGEN = False +except: + NO_CODEGEN = True + from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import to_global @@ -52,9 +60,8 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor param_sharding_spec = best_sharding_spec_dict[new_name] grad_to_compare = copy.deepcopy(param_grad) param_grad_global = to_global(grad_to_compare, param_sharding_spec) - try: - assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03) + assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-05) except: difference = param_grad_global - origin_param_grad avg_diff = difference.abs().sum() / difference.numel() @@ -66,7 +73,7 @@ def check_attention_layer(rank, model_cls, world_size, port): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM) + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda') @@ -111,15 +118,17 @@ def check_attention_layer(rank, model_cls, world_size, port): # [[0, 1] # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) graph = tracer.trace(root=model, meta_args=meta_input_sample) gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *meta_input_sample.values()) gm.recompile() strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') solution = solve_solution(gm, strategies_constructor, memory_budget=-1) - gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) + gm, sharding_spec_dicts = transform_to_sharded_model(gm, meta_input_sample, solution, device_mesh, + strategies_constructor) gm = ModuleWrapper(gm, *sharding_spec_dicts) nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies] @@ -176,6 +185,7 @@ def check_attention_layer(rank, model_cls, world_size, port): @run_on_environment_flag(name='AUTO_PARALLEL') +@pytest.mark.skipif(NO_CODEGEN, reason="no codegen module") @pytest.mark.dist @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 4adb4fbaf047..5f0688d5f5f2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -3,11 +3,12 @@ import transformers from torch.fx import GraphModule +from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing import parameterize from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -21,7 +22,7 @@ @run_on_environment_flag(name='AUTO_PARALLEL') @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): - config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM) + config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) if model_cls == GPT2MLP: model = model_cls(intermediate_size=4 * config.hidden_size, config=config) else: @@ -33,7 +34,7 @@ def test_self_attention_block(model_cls): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) shape_consistency_manager = ShapeConsistencyManager() - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) if model_cls == GPT2MLP: input_sample = { 'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'), @@ -52,6 +53,7 @@ def test_self_attention_block(model_cls): graph = tracer.trace(root=model, meta_args=input_sample) gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) print(gm.graph) gm.recompile() solver_options = SolverOptions() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index f5de7bf702ff..8d421243827e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -1,8 +1,11 @@ +import pytest import torch import torch.nn as nn +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser -from colossalai.fx import ColoGraphModule, ColoTracer class LinearModel(nn.Module): @@ -22,15 +25,14 @@ def forward(self, x1, x2): return out +@pytest.mark.skip('meta tensor has some bugs in 1.11') def test_liveness_analysis(): model = LinearModel() - tracer = ColoTracer() - graph = tracer.trace(model, - meta_args={ - 'x1': torch.rand(4, 4, device='meta'), - 'x2': torch.rand(4, 4, device='meta') - }) + tracer = ColoTracer(bias_addition_split=True) + meta_args = {'x1': torch.rand(4, 4, device='meta'), 'x2': torch.rand(4, 4, device='meta')} + graph = tracer.trace(model, meta_args=meta_args) gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) + shape_prop_pass(gm, *meta_args.values()) graph_analyser = GraphAnalyser(gm) liveness_list = graph_analyser.liveness_analysis() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index 2fb1306546ca..5f3d2df503f6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -24,7 +24,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index e9c0601eb1e4..ddc8e3c6ac02 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -17,7 +17,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register class MyModule(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index fd29c63fb522..1242b9db0750 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -24,7 +24,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index 9d3ab9c82670..d3342d310157 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -23,7 +23,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register def _batchnorm_module_mem_test(rank, world_size, port): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index a0ab66fdc060..a544e9a3cb2f 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -24,7 +24,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register class SplitModule(nn.Module): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index 20156f9ab4d5..2ae13ea2b1f8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -22,7 +22,7 @@ from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py index 60ecd1dd9801..4ca85d34da30 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py @@ -5,16 +5,19 @@ import torch from torch.fx import GraphModule +from colossalai._analyzer.fx.graph_module import ColoGraphModule +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import MetaInfo + from colossalai.auto_parallel.meta_profiler import ShardMetaInfo def mem_test_for_node_strategy(rank: int, @@ -30,14 +33,16 @@ def mem_test_for_node_strategy(rank: int, model_to_shard, args_to_shard, kwargs_to_shard = copy.deepcopy(model), copy.deepcopy(input_args), copy.deepcopy( input_kwargs) - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) input_sample = {} for input_arg, meta_arg_name in zip(input_args, meta_arg_names): input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta') for meta_kwarg_name, input_kwarg in input_kwargs.items(): input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta') graph = tracer.trace(root=model_to_shard, meta_args=input_sample) - gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + gm = ColoGraphModule(model_to_shard, graph, model_to_shard.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) + gm.recompile() solver_options = SolverOptions() strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor.build_strategies_and_cost() @@ -108,10 +113,10 @@ def mem_test_for_node_strategy(rank: int, # estimated memory if target_node.op == "call_module": - metainfo = MetaInfo(target_node.strategies_vector[strategy_index], - target_node.graph.owning_module.get_submodule(target_node.target)) + metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], + target_node.graph.owning_module.get_submodule(target_node.target)) else: - metainfo = MetaInfo(target_node.strategies_vector[strategy_index], target_node.target) + metainfo = ShardMetaInfo(target_node.strategies_vector[strategy_index], target_node.target) print("estimated memory:") print( diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py deleted file mode 100644 index 92f011ba30d2..000000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py +++ /dev/null @@ -1,126 +0,0 @@ -import torch - -from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag - - -def _param_resharding_cost_assertion(node): - for strategy in node.strategies_vector: - for prev_node, resharding_cost in strategy.resharding_costs.items(): - if strategy.get_op_data_by_name(str(prev_node)).type == OperationDataType.PARAM: - for cost in resharding_cost: - assert cost.fwd == 0 - assert cost.bwd == 0 - assert cost.total == 0 - - -class LinearModel(torch.nn.Module): - - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - x = self.linear(x) - x = x * 2 - - return x - - -class ConvModel(torch.nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size, bias=True): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - bias=bias) - - def forward(self, x): - x = self.conv(x) - x = x * 2 - - return x - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_linear_module(): - model = LinearModel(4, 8) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - tracer = ColoTracer() - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %linear_weight : [#users=1] = get_attr[target=linear.weight] - # %linear_bias : [#users=1] = get_attr[target=linear.bias] - # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {}) - # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) - # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')}) - # def forward(self, x : torch.Tensor): - # linear_weight = self.linear.weight - # linear_bias = self.linear.bias - # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None - # add = linear + linear_bias; linear = linear_bias = None - # mul = add * 2; add = None - # return mul - gm = ColoGraphModule(model, graph) - gm.recompile() - node_list = list(graph.nodes) - - solver_options = SolverOptions() - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - linear_node = node_list[3] - _param_resharding_cost_assertion(linear_node) - - -@run_on_environment_flag(name='AUTO_PARALLEL') -def test_conv_module(): - model = ConvModel(3, 6, 2) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - tracer = ColoTracer() - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv_weight : [#users=1] = get_attr[target=conv.weight] - # %conv_bias : [#users=1] = get_attr[target=conv.bias] - # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {}) - # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {}) - # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {}) - # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {}) - # return mul - graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')}) - # def forward(self, x : torch.Tensor): - # conv_weight = self.conv.weight - # conv_bias = self.conv.bias - # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None - # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None - # add = conv2d + view; conv2d = view = None - # mul = add * 2; add = None - # return mul - gm = ColoGraphModule(model, graph) - - gm.recompile() - node_list = list(graph.nodes) - conv_node = node_list[3] - solver_options = SolverOptions() - strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) - strategies_constructor.build_strategies_and_cost() - _param_resharding_cost_assertion(conv_node) - - -if __name__ == '__main__': - test_linear_module() - test_conv_module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py deleted file mode 100644 index 24a3ae5b42c3..000000000000 --- a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py +++ /dev/null @@ -1,86 +0,0 @@ -import copy -from functools import partial - -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.initialize import initialize_model -from colossalai.device.device_mesh import DeviceMesh -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port - - -class ConvModel(nn.Module): - - def __init__(self, c_in, c_out): - super().__init__() - self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False) - - def forward(self, x): - x = self.conv(x) - x = torch.flatten(x) - return x - - -def check_apply(rank, world_size, port): - disable_existing_loggers() - launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - input = torch.rand(4, 4, 4, 4).cuda() - test_input = copy.deepcopy(input) - # graph(): - # %x : torch.Tensor [#users=1] = placeholder[target=x] - # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) - # return conv - model = ConvModel(4, 4).cuda() - test_model = copy.deepcopy(model) - physical_mesh_id = torch.arange(0, 4) - mesh_shape = (2, 2) - # [[0, 1] - # [2, 3]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) - meta_args = {'x': torch.rand(4, 4, 4, 4).to('meta')} - gm = initialize_model(model, meta_args, device_mesh) - - output = gm(input) - origin_output = test_model(test_input) - assert output.equal(origin_output) - origin_loss = origin_output.sum() - loss = output.sum() - - origin_loss.backward() - loss.backward() - - grad_0 = test_model.conv.weight.grad.narrow(0, 0, 1) - grad_1 = test_model.conv.weight.grad.narrow(0, 1, 1) - grad_2 = test_model.conv.weight.grad.narrow(0, 2, 1) - grad_3 = test_model.conv.weight.grad.narrow(0, 3, 1) - - if rank == 0: - assert_close(gm.module.conv.weight.grad.data, grad_0.data) - elif rank == 1: - assert_close(gm.module.conv.weight.grad.data, grad_1.data) - elif rank == 2: - assert_close(gm.module.conv.weight.grad.data, grad_2.data) - elif rank == 3: - assert_close(gm.module.conv.weight.grad.data, grad_3.data) - else: - raise ValueError(f'rank {rank} does not exist.') - - -# skip this test due to pulp not installed in CI environment -@run_on_environment_flag(name='AUTO_PARALLEL') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_apply(): - world_size = 4 - run_func = partial(check_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_apply() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index bbfc3e1fcc14..fb47baab9476 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -2,11 +2,13 @@ from torch.fx import GraphModule from torchvision.models import resnet50 +from colossalai._analyzer.fx.passes import shape_prop_pass +# from colossalai.fx.tracer.tracer import ColoTracer +from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.tracer import ColoTracer from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.testing.pytest_wrapper import run_on_environment_flag @@ -20,7 +22,7 @@ def test_cost_graph(): device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) shape_consistency_manager = ShapeConsistencyManager() - tracer = ColoTracer() + tracer = ColoTracer(bias_addition_split=True) model = resnet50(num_classes=100000) input_sample = {'x': torch.rand(128, 3, 224, 224).to('meta')} @@ -50,6 +52,7 @@ def test_cost_graph(): # %fc : [#users=1] = call_module[target=fc](args = (%flatten,), kwargs = {}) # return fc gm = GraphModule(model, graph, model.__class__.__name__) + shape_prop_pass(gm, *input_sample.values()) gm.recompile() solver_options = SolverOptions() From b92313903f36e863c91b3e9bac621dba31cf52c0 Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Wed, 5 Apr 2023 09:45:42 +0800 Subject: [PATCH 09/27] fix save_model indent error in ppo trainer (#3450) Co-authored-by: Yuanchen Xu --- applications/Chat/coati/trainer/ppo.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 84254d50d7e7..6b99855be20e 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -117,6 +117,9 @@ def training_step(self, experience: Experience) -> Dict[str, float]: return {'reward': experience.reward.mean().item()} + def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) + def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: origin_model = strategy._unwrap_actor(actor) @@ -129,7 +132,3 @@ def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, acto new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn return new_kwargs - - -def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: - self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) From 46c009dba462800b4f7bd54b4558a37e6326726d Mon Sep 17 00:00:00 2001 From: Hakjin Lee Date: Thu, 6 Apr 2023 00:24:43 +0900 Subject: [PATCH 10/27] [format] Run lint on colossalai.engine (#3367) --- .../engine/gradient_accumulation/__init__.py | 15 +++++++++++---- .../gradient_handler/_base_gradient_handler.py | 2 +- .../_data_parallel_gradient_handler.py | 7 ++++--- .../_pipeline_parallel_gradient_handler.py | 7 ++++--- .../_sequence_parallel_gradient_handler.py | 7 ++++--- .../gradient_handler/_zero_gradient_handler.py | 1 + colossalai/engine/schedule/__init__.py | 2 +- colossalai/engine/schedule/_base_schedule.py | 2 +- .../engine/schedule/_non_pipeline_schedule.py | 9 +++++---- 9 files changed, 32 insertions(+), 20 deletions(-) diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/engine/gradient_accumulation/__init__.py index 4585b9a2529c..4cb6f4ad7384 100644 --- a/colossalai/engine/gradient_accumulation/__init__.py +++ b/colossalai/engine/gradient_accumulation/__init__.py @@ -1,10 +1,17 @@ +from typing import Iterable, List + import torch.nn as nn -from typing import List -from colossalai.engine import BaseGradientHandler -from typing import Iterable from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, GradAccumGradientHandler + +from colossalai.engine import BaseGradientHandler + +from ._gradient_accumulation import ( + GradAccumDataloader, + GradAccumGradientHandler, + GradAccumLrSchedulerByStep, + GradAccumOptimizer, +) __all__ = [ 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/engine/gradient_handler/_base_gradient_handler.py index c212359867d1..7d96dd8a88a6 100644 --- a/colossalai/engine/gradient_handler/_base_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_base_gradient_handler.py @@ -5,7 +5,7 @@ class BaseGradientHandler(ABC): - """A basic helper class to handle all-reduce operations of gradients across different parallel groups + """A basic helper class to handle all-reduce operations of gradients across different parallel groups before optimization. Args: diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py index d113fc516459..5cc7169c5a9f 100644 --- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -1,16 +1,17 @@ from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER -from ._base_gradient_handler import BaseGradientHandler + from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @GRADIENT_HANDLER.register_module class DataParallelGradientHandler(BaseGradientHandler): """A helper class to handle all-reduce operations in a data parallel group. - A all-reduce collective communication will be operated in + A all-reduce collective communication will be operated in :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are + For better performance, it bucketizes the gradients of all parameters that are the same type to improve the efficiency of communication. Args: diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 83f5c00cf2af..5b49a9c0360d 100644 --- a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -4,9 +4,10 @@ import torch import torch.distributed as dist +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from ._base_gradient_handler import BaseGradientHandler @@ -14,9 +15,9 @@ @GRADIENT_HANDLER.register_module class PipelineSharedModuleGradientHandler(BaseGradientHandler): """A helper class to handle all-reduce operations in sub parallel groups. - A all-reduce collective communication will be operated in + A all-reduce collective communication will be operated in :func:`handle_gradient` among all sub pipeline parallel groups. - For better performance, it bucketizes the gradients of all parameters that are + For better performance, it bucketizes the gradients of all parameters that are the same type to improve the efficiency of communication. Args: diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py index 53a8ea935a42..ea4f0fbb1c71 100644 --- a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -1,16 +1,17 @@ from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER -from ._base_gradient_handler import BaseGradientHandler + from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @GRADIENT_HANDLER.register_module class SequenceParallelGradientHandler(BaseGradientHandler): """A helper class to handle all-reduce operations in a data parallel group. - A all-reduce collective communication will be operated in + A all-reduce collective communication will be operated in :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are + For better performance, it bucketizes the gradients of all parameters that are the same type to improve the efficiency of communication. Args: diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/engine/gradient_handler/_zero_gradient_handler.py index f85303e75184..19fd1e97f86f 100644 --- a/colossalai/engine/gradient_handler/_zero_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_zero_gradient_handler.py @@ -1,4 +1,5 @@ from colossalai.registry import GRADIENT_HANDLER + from ._base_gradient_handler import BaseGradientHandler diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py index 54170286e99b..0f2c039d7057 100644 --- a/colossalai/engine/schedule/__init__.py +++ b/colossalai/engine/schedule/__init__.py @@ -1,5 +1,5 @@ from ._base_schedule import BaseSchedule -from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape from ._non_pipeline_schedule import NonPipelineSchedule +from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape __all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index ba797bad9778..a2d50041127a 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -2,10 +2,10 @@ # -*- encoding: utf-8 -*- from abc import ABC, abstractmethod +from typing import Callable, Iterable import torch -from typing import Iterable, Callable from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py index c62bfb7d7375..b9239d928a7b 100644 --- a/colossalai/engine/schedule/_non_pipeline_schedule.py +++ b/colossalai/engine/schedule/_non_pipeline_schedule.py @@ -1,13 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Iterable +import inspect +from typing import Callable, Iterable import torch -import inspect -from ._base_schedule import BaseSchedule + from colossalai.utils import conditional_context -from typing import Callable + +from ._base_schedule import BaseSchedule class NonPipelineSchedule(BaseSchedule): From 72cb4dd433aa58bab00e207840f50c18eab7d9b2 Mon Sep 17 00:00:00 2001 From: Camille Zhong <44392324+Camille7777@users.noreply.github.com> Date: Thu, 6 Apr 2023 09:30:28 +0800 Subject: [PATCH 11/27] [Chat] fix the tokenizer "int too big to convert" error in SFT training (#3453) * Add RoBERTa for RLHF Stage 2 & 3 (test) RoBERTa for RLHF Stage 2 & 3 (still in testing) * Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)" This reverts commit 06741d894dcbe958acd4e10d771f22275e20e368. * Add RoBERTa for RLHF stage 2 & 3 1. add roberta folder under model folder 2. add roberta option in train_reward_model.py 3. add some test in testci * Update test_ci.sh * Revert "Update test_ci.sh" This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a. * Add RoBERTa for RLHF Stage 2 & 3 (test) RoBERTa for RLHF Stage 2 & 3 (still in testing) * Revert "Add RoBERTa for RLHF Stage 2 & 3 (test)" This reverts commit 06741d894dcbe958acd4e10d771f22275e20e368. * Add RoBERTa for RLHF stage 2 & 3 1. add roberta folder under model folder 2. add roberta option in train_reward_model.py 3. add some test in testci * Update test_ci.sh * Revert "Update test_ci.sh" This reverts commit 9c7352b81766f3177d31eeec0ec178a301df966a. * update roberta with coati * chat ci update * Revert "chat ci update" This reverts commit 17ae7ae01fa752bd3289fc39069868fde99cf846. * [Chat] fix the tokenizer "int too big to convert" error in SFT training fix the tokenizer error during SFT training using Bloom and OPT --- applications/Chat/coati/dataset/sft_dataset.py | 11 ++++++----- applications/Chat/examples/train_sft.py | 9 ++++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py index 76ae6b158822..91e38f06daba 100644 --- a/applications/Chat/coati/dataset/sft_dataset.py +++ b/applications/Chat/coati/dataset/sft_dataset.py @@ -78,14 +78,14 @@ def __getitem__(self, idx): # return dict(self.prompts[idx], self.prompts[idx]) -def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( text, return_tensors="pt", padding="longest", - max_length=tokenizer.model_max_length, + max_length=max_length, truncation=True, ) for text in strings ] @@ -105,10 +105,11 @@ def preprocess( sources: Sequence[str], targets: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, + max_length: int, ) -> Dict: """Preprocess the data by tokenizing.""" examples = [s + t for s, t in zip(sources, targets)] - examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] + examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)] input_ids = examples_tokenized["input_ids"] labels = copy.deepcopy(input_ids) for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): @@ -119,7 +120,7 @@ def preprocess( class SupervisedDataset(Dataset): """Dataset for supervised fine-tuning.""" - def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None): + def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512): super(SupervisedDataset, self).__init__() logger.info("Loading data...") list_data_dict = jload(data_path) @@ -138,7 +139,7 @@ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] logger.info("Tokenizing inputs... This may take some time...") - data_dict = preprocess(sources, targets, tokenizer) + data_dict = preprocess(sources, targets, tokenizer, max_length) self.input_ids = data_dict["input_ids"] self.labels = data_dict["labels"] diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index 035d5a1ded1d..c0ac7b177694 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -71,6 +71,7 @@ def train(args): else: raise ValueError(f'Unsupported model "{args.model}"') tokenizer.pad_token = tokenizer.eos_token + max_len = args.max_len if args.model == 'llama': tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) @@ -99,13 +100,14 @@ def train(args): train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train') eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test') - train_dataset = SFTDataset(train_data, tokenizer) - eval_dataset = SFTDataset(eval_data, tokenizer) + train_dataset = SFTDataset(train_data, tokenizer, max_len) + eval_dataset = SFTDataset(eval_data, tokenizer, max_len) else: train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.dataset, - max_datasets_size=args.max_datasets_size) + max_datasets_size=args.max_datasets_size, + max_length=max_len) eval_dataset = None data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) @@ -176,6 +178,7 @@ def train(args): parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument('--max_epochs', type=int, default=3) parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--max_len', type=int, default=512) parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") parser.add_argument('--lr', type=float, default=5e-6) From 933048ad3e68f243d0cfeda1e7a8cf1b55a4cf88 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 6 Apr 2023 09:38:25 +0800 Subject: [PATCH 12/27] [test] reorganize zero/gemini tests (#3445) --- tests/test_moe/test_moe_checkpoint.py | 2 +- tests/test_moe/test_moe_colo_init.py | 2 +- tests/test_moe/test_moe_zero_init.py | 2 +- tests/test_moe/test_moe_zero_model.py | 2 +- tests/test_moe/test_moe_zero_optim.py | 2 +- .../update => test_zero/test_gemini}/test_chunk_mgrv2.py | 0 .../update => test_zero/test_gemini}/test_chunkv2.py | 0 .../update => test_zero/test_gemini}/test_fwd_bwd.py | 0 .../update => test_zero/test_gemini}/test_gemini_use_rmt.py | 0 .../update => test_zero/test_gemini}/test_get_torch_model.py | 0 .../update => test_zero/test_gemini}/test_grad_clip.py | 0 .../update => test_zero/test_gemini}/test_inference.py | 0 .../{test_gemini/update => test_zero/test_gemini}/test_optim.py | 0 tests/{ => test_zero}/test_gemini/test_runtime_mem_tracer.py | 0 .../update => test_zero/test_gemini}/test_search.py | 0 .../update => test_zero/test_gemini}/test_zeroddp_state_dict.py | 0 .../test_gemini}/test_zerooptim_state_dict.py | 0 tests/test_zero/{ => test_legacy}/common.py | 0 tests/test_zero/{ => test_legacy}/test_found_inf.py | 2 +- .../test_legacy}/test_gemini_manager.py | 0 tests/test_zero/{ => test_legacy}/test_init_context.py | 0 tests/{test_gemini => test_zero/test_legacy}/test_param_op.py | 0 tests/test_zero/{ => test_legacy}/test_shard_model_v2.py | 0 tests/test_zero/{ => test_legacy}/test_shard_param.py | 2 +- .../{ => test_legacy}/test_sharded_optim_state_dict.py | 0 tests/test_zero/{ => test_legacy}/test_sharded_optim_v2.py | 0 .../{ => test_legacy}/test_sharded_optim_with_sync_bn.py | 0 tests/test_zero/{ => test_legacy}/test_state_dict.py | 1 - tests/test_zero/{ => test_legacy}/test_tensor_utils.py | 0 tests/test_zero/{ => test_legacy}/test_zero_engine.py | 0 .../{low_level_zero => test_low_level}/test_grad_acc.py | 0 .../{low_level_zero => test_low_level}/test_zero1_2.py | 0 .../{low_level_zero => test_low_level}/test_zero_init.py | 0 .../{low_level_zero => test_low_level}/test_zero_tp.py | 0 34 files changed, 7 insertions(+), 8 deletions(-) rename tests/{test_gemini/update => test_zero/test_gemini}/test_chunk_mgrv2.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_chunkv2.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_fwd_bwd.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_gemini_use_rmt.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_get_torch_model.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_grad_clip.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_inference.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_optim.py (100%) rename tests/{ => test_zero}/test_gemini/test_runtime_mem_tracer.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_search.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_zeroddp_state_dict.py (100%) rename tests/{test_gemini/update => test_zero/test_gemini}/test_zerooptim_state_dict.py (100%) rename tests/test_zero/{ => test_legacy}/common.py (100%) rename tests/test_zero/{ => test_legacy}/test_found_inf.py (97%) rename tests/{test_gemini => test_zero/test_legacy}/test_gemini_manager.py (100%) rename tests/test_zero/{ => test_legacy}/test_init_context.py (100%) rename tests/{test_gemini => test_zero/test_legacy}/test_param_op.py (100%) rename tests/test_zero/{ => test_legacy}/test_shard_model_v2.py (100%) rename tests/test_zero/{ => test_legacy}/test_shard_param.py (98%) rename tests/test_zero/{ => test_legacy}/test_sharded_optim_state_dict.py (100%) rename tests/test_zero/{ => test_legacy}/test_sharded_optim_v2.py (100%) rename tests/test_zero/{ => test_legacy}/test_sharded_optim_with_sync_bn.py (100%) rename tests/test_zero/{ => test_legacy}/test_state_dict.py (98%) rename tests/test_zero/{ => test_legacy}/test_tensor_utils.py (100%) rename tests/test_zero/{ => test_legacy}/test_zero_engine.py (100%) rename tests/test_zero/{low_level_zero => test_low_level}/test_grad_acc.py (100%) rename tests/test_zero/{low_level_zero => test_low_level}/test_zero1_2.py (100%) rename tests/test_zero/{low_level_zero => test_low_level}/test_zero_init.py (100%) rename tests/test_zero/{low_level_zero => test_low_level}/test_zero_tp.py (100%) diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index 5b6fe441112d..d2cff44ad9bd 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -14,7 +14,7 @@ from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_tensor.common_utils import debug_print -from tests.test_zero.common import CONFIG +from tests.test_zero.test_legacy.common import CONFIG def exam_moe_checkpoint(): diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index 23ad1a3dc6a4..4826d87ac044 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -13,7 +13,7 @@ from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_tensor.common_utils import debug_print -from tests.test_zero.common import CONFIG +from tests.test_zero.test_legacy.common import CONFIG @parameterize("init_device_type", ['cpu', 'cuda']) diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 5987e31f71c9..18b50eb5c482 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -14,7 +14,7 @@ from colossalai.utils import free_port, get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from tests.test_zero.common import CONFIG +from tests.test_zero.test_legacy.common import CONFIG class MoeModel(nn.Module): diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index d38f66fef658..49c452938e25 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -17,7 +17,7 @@ from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd +from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd @parameterize("enable_autocast", [False]) diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 7e140bf862f2..b43e52bb4c6a 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -20,7 +20,7 @@ from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_zero.common import CONFIG, check_sharded_model_params +from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params def _run_step(model, optimizer, data, label, criterion, grad_handler): diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py similarity index 100% rename from tests/test_gemini/update/test_chunk_mgrv2.py rename to tests/test_zero/test_gemini/test_chunk_mgrv2.py diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py similarity index 100% rename from tests/test_gemini/update/test_chunkv2.py rename to tests/test_zero/test_gemini/test_chunkv2.py diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py similarity index 100% rename from tests/test_gemini/update/test_fwd_bwd.py rename to tests/test_zero/test_gemini/test_fwd_bwd.py diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py similarity index 100% rename from tests/test_gemini/update/test_gemini_use_rmt.py rename to tests/test_zero/test_gemini/test_gemini_use_rmt.py diff --git a/tests/test_gemini/update/test_get_torch_model.py b/tests/test_zero/test_gemini/test_get_torch_model.py similarity index 100% rename from tests/test_gemini/update/test_get_torch_model.py rename to tests/test_zero/test_gemini/test_get_torch_model.py diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py similarity index 100% rename from tests/test_gemini/update/test_grad_clip.py rename to tests/test_zero/test_gemini/test_grad_clip.py diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_zero/test_gemini/test_inference.py similarity index 100% rename from tests/test_gemini/update/test_inference.py rename to tests/test_zero/test_gemini/test_inference.py diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_zero/test_gemini/test_optim.py similarity index 100% rename from tests/test_gemini/update/test_optim.py rename to tests/test_zero/test_gemini/test_optim.py diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py similarity index 100% rename from tests/test_gemini/test_runtime_mem_tracer.py rename to tests/test_zero/test_gemini/test_runtime_mem_tracer.py diff --git a/tests/test_gemini/update/test_search.py b/tests/test_zero/test_gemini/test_search.py similarity index 100% rename from tests/test_gemini/update/test_search.py rename to tests/test_zero/test_gemini/test_search.py diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py similarity index 100% rename from tests/test_gemini/update/test_zeroddp_state_dict.py rename to tests/test_zero/test_gemini/test_zeroddp_state_dict.py diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py similarity index 100% rename from tests/test_gemini/update/test_zerooptim_state_dict.py rename to tests/test_zero/test_gemini/test_zerooptim_state_dict.py diff --git a/tests/test_zero/common.py b/tests/test_zero/test_legacy/common.py similarity index 100% rename from tests/test_zero/common.py rename to tests/test_zero/test_legacy/common.py diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py similarity index 97% rename from tests/test_zero/test_found_inf.py rename to tests/test_zero/test_legacy/test_found_inf.py index 641136718161..03a1a609b672 100644 --- a/tests/test_zero/test_found_inf.py +++ b/tests/test_zero/test_legacy/test_found_inf.py @@ -4,6 +4,7 @@ import torch import torch.multiprocessing as mp from common import CONFIG +from test_sharded_optim_v2 import _run_step import colossalai from colossalai.nn.optimizer import HybridAdam @@ -16,7 +17,6 @@ from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_zero.test_sharded_optim_v2 import _run_step @parameterize("cpu_offload", [True, False]) diff --git a/tests/test_gemini/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py similarity index 100% rename from tests/test_gemini/test_gemini_manager.py rename to tests/test_zero/test_legacy/test_gemini_manager.py diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py similarity index 100% rename from tests/test_zero/test_init_context.py rename to tests/test_zero/test_legacy/test_init_context.py diff --git a/tests/test_gemini/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py similarity index 100% rename from tests/test_gemini/test_param_op.py rename to tests/test_zero/test_legacy/test_param_op.py diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py similarity index 100% rename from tests/test_zero/test_shard_model_v2.py rename to tests/test_zero/test_legacy/test_shard_model_v2.py diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py similarity index 98% rename from tests/test_zero/test_shard_param.py rename to tests/test_zero/test_legacy/test_shard_param.py index 6085de3c8919..b76648321451 100644 --- a/tests/test_zero/test_shard_param.py +++ b/tests/test_zero/test_legacy/test_shard_param.py @@ -4,6 +4,7 @@ import pytest import torch import torch.multiprocessing as mp +from common import CONFIG, allclose import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use @@ -12,7 +13,6 @@ from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_param import ShardedTensor from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 -from tests.test_zero.common import CONFIG, allclose @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py similarity index 100% rename from tests/test_zero/test_sharded_optim_state_dict.py rename to tests/test_zero/test_legacy/test_sharded_optim_state_dict.py diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py similarity index 100% rename from tests/test_zero/test_sharded_optim_v2.py rename to tests/test_zero/test_legacy/test_sharded_optim_v2.py diff --git a/tests/test_zero/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py similarity index 100% rename from tests/test_zero/test_sharded_optim_with_sync_bn.py rename to tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py similarity index 98% rename from tests/test_zero/test_state_dict.py rename to tests/test_zero/test_legacy/test_state_dict.py index c435d9bb1ef7..40d2820d800a 100644 --- a/tests/test_zero/test_state_dict.py +++ b/tests/test_zero/test_legacy/test_state_dict.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from copy import deepcopy from functools import partial import pytest diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py similarity index 100% rename from tests/test_zero/test_tensor_utils.py rename to tests/test_zero/test_legacy/test_tensor_utils.py diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py similarity index 100% rename from tests/test_zero/test_zero_engine.py rename to tests/test_zero/test_legacy/test_zero_engine.py diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py similarity index 100% rename from tests/test_zero/low_level_zero/test_grad_acc.py rename to tests/test_zero/test_low_level/test_grad_acc.py diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py similarity index 100% rename from tests/test_zero/low_level_zero/test_zero1_2.py rename to tests/test_zero/test_low_level/test_zero1_2.py diff --git a/tests/test_zero/low_level_zero/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py similarity index 100% rename from tests/test_zero/low_level_zero/test_zero_init.py rename to tests/test_zero/test_low_level/test_zero_init.py diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py similarity index 100% rename from tests/test_zero/low_level_zero/test_zero_tp.py rename to tests/test_zero/test_low_level/test_zero_tp.py From 8f740deb5323bf940724911541456a1ab0329179 Mon Sep 17 00:00:00 2001 From: YH <100389977+yhna940@users.noreply.github.com> Date: Thu, 6 Apr 2023 10:43:31 +0900 Subject: [PATCH 13/27] Fix typo (#3448) --- colossalai/tensor/param_op_hook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index ed705da0eb0d..9c2e0d4adbf1 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -168,12 +168,12 @@ def _get_grad_args(*args): # if there is no grad tensor, the backward of PreFwdPostBwd can't be triggered arg_zero = args[0] if not isinstance(arg_zero, tuple): - raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.") + raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") check_grad_flag = False for obj in arg_zero: check_grad_flag |= _is_grad_tensor(obj) if not check_grad_flag: - raise NotImplementedError("Some torch function is incompatible because of its complcated inputs.") + raise NotImplementedError("Some torch function is incompatible because of its complicated inputs.") return arg_zero, args[1:] From 7d8d825681a41df62f9f546d8586c50a3757e3a8 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 6 Apr 2023 09:43:51 +0800 Subject: [PATCH 14/27] [booster] fixed the torch ddp plugin with the new checkpoint api (#3442) --- colossalai/booster/plugin/gemini_plugin.py | 9 +++++---- colossalai/booster/plugin/torch_ddp_plugin.py | 6 +++--- examples/tutorial/new_api/torch_ddp/README.md | 6 +++--- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 3c6e539ba972..6693b1f44d62 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io.utils import save_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.tensor.colo_parameter import ColoParameter @@ -83,7 +84,7 @@ def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = # the model should be unwrapped in self.load_model via ModelWrapper.unwrap return super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str): + def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ @@ -91,14 +92,14 @@ def save_unsharded_model(self, model: GeminiDDP, checkpoint: str): # as there is communication when get state dict, this must be called on all processes state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): - self.save_checkpoint(state_dict, checkpoint) + save_state_dict(state_dict, checkpoint, use_safetensors) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ # TODO(ver217): optimizer state dict is sharded - super().save_unsharded_optimizer(optimizer, checkpoint) + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index e2abe11ba143..c5e310c7e769 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -33,20 +33,20 @@ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = # the model should be unwrapped in self.load_model via ModelWrapper.unwrap return super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool): + def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ Save model to checkpoint but only on master process. """ # the model should be unwrapped in self.load_model via ModelWrapper.unwrap if self.coordinator.is_master(): - super().save_unsharded_model(model, checkpoint) + super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors) def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ Save optimizer to checkpoint but only on master process. """ if self.coordinator.is_master(): - super().save_unsharded_optimizer(optimizer, checkpoint) + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ diff --git a/examples/tutorial/new_api/torch_ddp/README.md b/examples/tutorial/new_api/torch_ddp/README.md index 62d5a083d0a1..e120bacb0c84 100644 --- a/examples/tutorial/new_api/torch_ddp/README.md +++ b/examples/tutorial/new_api/torch_ddp/README.md @@ -2,10 +2,10 @@ ## 🚀 Quick Start -This example provides a training script and and evaluation script. The training script provides a an example of training ResNet on CIFAR10 dataset from scratch. +This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. - Training Arguments - - `-r, `--resume`: resume from checkpoint file path + - `-r`, `--resume`: resume from checkpoint file path - `-c`, `--checkpoint`: the folder to save checkpoints - `-i`, `--interval`: epoch interval to save checkpoints - `-f`, `--fp16`: use fp16 @@ -41,4 +41,4 @@ Expected accuracy performance will be: | --------- | ------------------------ | --------------------- | --------------------- | | ResNet-18 | 85.85% | 85.03% | 85.12% | -**Note: the baseline is a adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** +**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** From 57a3c4db6d5af3f3d46dfcbc52afb55cb5415298 Mon Sep 17 00:00:00 2001 From: kingkingofall <83848390+kingkingofall@users.noreply.github.com> Date: Thu, 6 Apr 2023 10:58:53 +0800 Subject: [PATCH 15/27] [chat]fix readme (#3429) * fix stage 2 fix stage 2 * add torch --- applications/Chat/examples/README.md | 2 +- applications/Chat/inference/README.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md index 49401ec30db5..6c02606eab93 100644 --- a/applications/Chat/examples/README.md +++ b/applications/Chat/examples/README.md @@ -57,7 +57,7 @@ You can run the `examples/train_rm.sh` to start a reward model training. You can also use the following cmd to start training a reward model. ``` -torchrun --standalone --nproc_per_node=4 train_reward_model.py +torchrun --standalone --nproc_per_node=4 train_reward_model.py \ --pretrain "/path/to/LLaMa-7B/" \ --model 'llama' \ --strategy colossalai_zero2 \ diff --git a/applications/Chat/inference/README.md b/applications/Chat/inference/README.md index 6c23bc73cd60..434677c98fa5 100644 --- a/applications/Chat/inference/README.md +++ b/applications/Chat/inference/README.md @@ -51,6 +51,7 @@ Please ensure you have downloaded HF-format model weights of LLaMA models. Usage: ```python +import torch from transformers import LlamaForCausalLM USE_8BIT = True # use 8-bit quantization; otherwise, use fp16 From 73afb6359489295eafa95378341b9da866fe6097 Mon Sep 17 00:00:00 2001 From: Dr-Corgi Date: Thu, 6 Apr 2023 11:19:14 +0800 Subject: [PATCH 16/27] [chat]fix save_model(#3377) The function save_model should be a part of PPOTrainer. --- applications/Chat/coati/trainer/ppo.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 6b99855be20e..5c7c71d20b16 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -116,6 +116,9 @@ def training_step(self, experience: Experience) -> Dict[str, float]: self.critic_optim.zero_grad() return {'reward': experience.reward.mean().item()} + + def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: + self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None: self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer) From 62f4e2eb0760ac8bfe28834b061dbc2bda93ade9 Mon Sep 17 00:00:00 2001 From: YY Lin Date: Thu, 6 Apr 2023 11:54:52 +0800 Subject: [PATCH 17/27] [Chat]Add Peft support & fix the ptx bug (#3433) * Update ppo.py Fix the bug of fetching wrong batch data * Add peft model support in SFT and Prompts training In stage-1 and stage-3, the peft model supports are added. So the trained artifacts will be only a small lora additions instead of the whole bunch of files. * Delete test_prompts.txt * Delete test_pretrained.txt * Move the peft stuffs to a community folder. * Move the demo sft to community * delete dirty files * Add instructions to install peft using source * Remove Chinese comments * remove the Chinese comments --- applications/Chat/coati/trainer/ppo.py | 7 +- .../Chat/examples/community/EasyPeftModel.md | 24 ++ .../Chat/examples/community/easy_dataset.py | 242 ++++++++++++++++++ .../Chat/examples/community/easy_models.py | 97 +++++++ .../examples/community/train_peft_prompts.py | 227 ++++++++++++++++ .../Chat/examples/community/train_peft_sft.py | 187 ++++++++++++++ 6 files changed, 781 insertions(+), 3 deletions(-) create mode 100644 applications/Chat/examples/community/EasyPeftModel.md create mode 100644 applications/Chat/examples/community/easy_dataset.py create mode 100644 applications/Chat/examples/community/easy_models.py create mode 100644 applications/Chat/examples/community/train_peft_prompts.py create mode 100644 applications/Chat/examples/community/train_peft_sft.py diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index 5c7c71d20b16..2b0cfcc16f24 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -92,9 +92,10 @@ def training_step(self, experience: Experience) -> Dict[str, float]: # ptx loss if self.ptx_coef != 0: - ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device()) - label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:] - attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device()) + batch = next(iter(self.pretrain_dataloader)) + ptx = batch['input_ids'].to(torch.cuda.current_device()) + label = batch['labels'].to(torch.cuda.current_device())[:, 1:] + attention_mask = batch['attention_mask'].to(torch.cuda.current_device()) ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :] ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1)) actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) diff --git a/applications/Chat/examples/community/EasyPeftModel.md b/applications/Chat/examples/community/EasyPeftModel.md new file mode 100644 index 000000000000..16c4af76b91f --- /dev/null +++ b/applications/Chat/examples/community/EasyPeftModel.md @@ -0,0 +1,24 @@ +# Add Peft support for SFT and Prompts model training + +The orginal implementation just adopts the loralib and merges the layers into the final model. The huggingface peft is a better lora model implementation and can be easily training and distributed. + +Since reward model is relative small, I just keep it as original one. I suggest train full model to get the proper reward/critic model. + +# Prelimenary installation +Since the current pypi peft package(0.2) has some bugs, please install the peft package using source. +``` +git clone https://github.com/huggingface/peft +cd peft +pip install . +``` + +# Usage +For SFT training, just call train_peft_sft.py + +Its arguments are almost identical to train_sft.py instead adding a new eval_dataset if you have a eval_dataset file. The data file is just a plain datafile, please check the format in the easy_dataset.py. + +For stage-3 rlhf training, call train_peft_prompts.py. +Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported. + +# Dataformat +Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt. \ No newline at end of file diff --git a/applications/Chat/examples/community/easy_dataset.py b/applications/Chat/examples/community/easy_dataset.py new file mode 100644 index 000000000000..15dd9a3ccd1d --- /dev/null +++ b/applications/Chat/examples/community/easy_dataset.py @@ -0,0 +1,242 @@ +import copy +from typing import Dict, Sequence +from datasets import load_dataset +from torch.utils.data import Dataset +from transformers import AutoTokenizer +import torch +from tqdm import tqdm +import json + +from tqdm import tqdm +import json + +IGNORE_INDEX = -100 + + +def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :int = 512) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: AutoTokenizer, + max_length :int = 512 +) -> Dict: + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer,max_length) for strings in (examples, sources)] + input_ids = examples_tokenized["input_ids"] + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + + +class EasySupervisedDataset(Dataset): + def __init__(self, data_file :str, tokenizer :AutoTokenizer,max_length :int = 512) -> None: + super(EasySupervisedDataset,self).__init__() + with open(data_file,"r",encoding="UTF-8") as f: + all_lines = f.readlines() + #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" + sources,targets = [],[] + for line in all_lines: + if "回答:" in line: + sep_index = line.index("回答:") + sources.append(line[:sep_index+3]) + targets.append(line[sep_index+3:]+tokenizer.eos_token) + else: + sources.append(line) + targets.append(""+tokenizer.eos_token) + data_dict = preprocess(sources, targets, tokenizer,max_length) + + self.input_ids = data_dict["input_ids"] + self.labels = data_dict["labels"] + self.data_file = data_file + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + return dict(input_ids=self.input_ids[i], labels=self.labels[i]) + + def __repr__(self): + return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + + def __str__(self): + return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + +class EasyPromptsDataset(Dataset): + def __init__(self,data_file :str, tokenizer :AutoTokenizer, max_length :int = 96) -> None: + super(EasyPromptsDataset,self).__init__() + with open(data_file,"r",encoding="UTF-8") as f: + all_lines = f.readlines() + all_lines = [line if "回答:" not in line else line[:line.index("回答:")+3] for line in all_lines] + self.prompts = [ + tokenizer(line, + return_tensors='pt', + max_length=max_length, + padding='max_length', + truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + for line in tqdm(all_lines) + ] + self.data_file = data_file + def __len__(self): + return len(self.prompts) + + def __getitem__(self, idx): + return self.prompts[idx] + + def __repr__(self): + return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" + + def __str__(self): + return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" + + +class EasyRewardDataset(Dataset): + def __init__(self,train_file :str,tokenizer :AutoTokenizer, special_token = None,max_length = 512) -> None: + super(EasyRewardDataset,self).__init__() + self.chosen = [] + self.reject = [] + if special_token is None: + self.end_token = tokenizer.eos_token + else: + self.end_token = special_token + print(self.end_token) + #read all lines in the train_file to a list + with open(train_file,"r",encoding="UTF-8") as f: + all_lines = f.readlines() + for line in tqdm(all_lines): + data = json.loads(line) + prompt = "提问:"+data['prompt']+" 回答:" + + chosen = prompt + data['chosen'] + self.end_token + chosen_token = tokenizer(chosen, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.chosen.append({ + "input_ids": chosen_token['input_ids'], + "attention_mask": chosen_token['attention_mask'] + }) + + reject = prompt + data['rejected'] + self.end_token + reject_token = tokenizer(reject, + max_length=max_length, + padding="max_length", + truncation=True, + return_tensors="pt") + self.reject.append({ + "input_ids": reject_token['input_ids'], + "attention_mask": reject_token['attention_mask'] + }) + + def __len__(self): + length = len(self.chosen) + return length + + def __getitem__(self, idx): + return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ + "input_ids"], self.reject[idx]["attention_mask"] + + #python representation of the object and the string representation of the object + def __repr__(self): + return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + + def __str__(self): + return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + +''' +Easy SFT just accept a text file which can be read line by line. However the datasest will group texts together to max_length so LLM will learn the texts meaning better. +If individual lines are not related, just set is_group_texts to False. +''' +class EasySFTDataset(Dataset): + + def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_group_texts = True) -> None: + super().__init__() + #read the data_file line by line + with open(data_file,"r",encoding="UTF-8") as f: + #encode the text data line by line and put raw python list input_ids only to raw_input_ids list + raw_input_ids = [] + for line in f: + encoded_ids = tokenizer.encode(line) + #if the encoded_ids is longer than max_length, then split it into several parts + if len(encoded_ids) > max_length: + for i in range(0,len(encoded_ids),max_length): + raw_input_ids.append(encoded_ids[i:i+max_length]) + else: + raw_input_ids.append(encoded_ids) + + grouped_inpup_ids = [] + current_input_ids = [] + attention_mask = [] + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + if is_group_texts: + for input_ids in raw_input_ids: + if len(current_input_ids) + len(input_ids) > max_length: + #pad the current_input_ids to max_length with tokenizer.pad_token_id + padded_length = max_length - len(current_input_ids) + current_input_ids.extend([tokenizer.pad_token_id] * padded_length) + grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long)) + attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + current_input_ids = [] + else: + current_input_ids.extend(input_ids) + if len(current_input_ids) > 0: + padded_length = max_length - len(current_input_ids) + current_input_ids.extend([tokenizer.pad_token_id] * padded_length) + grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long)) + attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + else: + #just append the raw_input_ids to max_length + for input_ids in raw_input_ids: + padded_length = max_length - len(input_ids) + input_ids.extend([tokenizer.pad_token_id] * padded_length) + attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + grouped_inpup_ids.append(torch.tensor(input_ids,dtype=torch.long)) + self.input_ids = grouped_inpup_ids + self.labels = copy.deepcopy(self.input_ids) + self.file_name = data_file + self.attention_mask = attention_mask + + def __len__(self): + return len(self.input_ids) + + #get item from dataset + def __getitem__(self,idx): + return dict(input_ids=self.input_ids[idx],labels=self.labels[idx],attention_mask=self.attention_mask[idx]) + + #generate the dataset description to be printed by print in python + def __repr__(self): + return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" + + #generate the dataset description to be printed by print in python + def __str__(self): + return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" + + + + + \ No newline at end of file diff --git a/applications/Chat/examples/community/easy_models.py b/applications/Chat/examples/community/easy_models.py new file mode 100644 index 000000000000..080fc1802a02 --- /dev/null +++ b/applications/Chat/examples/community/easy_models.py @@ -0,0 +1,97 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules import Module + +from coati.models.generation import generate +from coati.models.utils import log_probs_from_logits,masked_mean +from transformers import BloomConfig,BloomForCausalLM +from peft import PeftModel + +class Actor(Module): + """ + Actor model base class. + + Args: + model (nn.Module): Actor Model. + """ + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + @torch.no_grad() + def generate( + self, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + sequences = generate(self.model, input_ids, **kwargs) + attention_mask = None + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + if not return_action_mask: + return sequences, attention_mask, None + input_len = input_ids.size(1) + eos_token_id = kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] + + def forward(self, + sequences: torch.LongTensor, + num_actions: int, + attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Returns action log probs + """ + output = self.model(sequences, attention_mask=attention_mask) + logits = output['logits'] + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + def get_base_model(self): + return self.model + + +class BLOOMActor(Actor): + """ + BLOOM Actor model. + + Args: + pretrained (str): Pretrained model name or path. + config (BloomConfig): Model config. + checkpoint (bool): Enable gradient checkpointing. + lora_rank (int): LoRA rank. + lora_train_bias (str): LoRA bias training mode. + """ + + def __init__(self, + pretrained: str = None, + config: Optional[BloomConfig] = None, + checkpoint: bool = False, + lora_path: str = None) -> None: + if pretrained is not None: + model = BloomForCausalLM.from_pretrained(pretrained) + elif config is not None: + model = BloomForCausalLM(config) + else: + model = BloomForCausalLM(BloomConfig()) + if lora_path is not None: + model = PeftModel.from_pretrained(model,lora_path) + if checkpoint: + model.gradient_checkpointing_enable() + super().__init__(model) + + def print_trainable_parameters(self): + self.get_base_model().print_trainable_parameters() + diff --git a/applications/Chat/examples/community/train_peft_prompts.py b/applications/Chat/examples/community/train_peft_prompts.py new file mode 100644 index 000000000000..b9394c9e4190 --- /dev/null +++ b/applications/Chat/examples/community/train_peft_prompts.py @@ -0,0 +1,227 @@ +import argparse + +import pandas as pd +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset +from coati.models.bloom import BLOOMRM, BLOOMCritic +from easy_models import BLOOMActor +from coati.models.gpt import GPTRM, GPTActor, GPTCritic +from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM +from coati.models.opt import OPTRM, OPTActor, OPTCritic +from coati.trainer import PPOTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer + +from colossalai.nn.optimizer import HybridAdam +from peft import PeftModel +from easy_dataset import EasyPromptsDataset,EasySupervisedDataset + +def main(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5) + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cpu') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + if args.rm_path is not None: + state_dict = torch.load(args.rm_path, map_location='cpu') + + # configure model + if args.model == 'bloom': + # initial_model = BLOOMActor(pretrained=args.pretrain) + print('Using peft lora to load Bloom model as inital_model') + initial_model = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path) + print('Using peft lora to load Bloom model as initial_model (Done)') + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if args.rm_model == None: + rm_model_name = args.model + else: + rm_model_name = args.rm_model + + if rm_model_name == 'gpt2': + reward_model = GPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'bloom': + print("load bloom reward model ",args.rm_pretrain) + reward_model = BLOOMRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'opt': + reward_model = OPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'llama': + reward_model = LlamaRM(pretrained=args.rm_pretrain) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + print('Loading reward model from', args.rm_path) + reward_model.load_state_dict(state_dict) + + if args.strategy != 'colossalai_gemini': + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) + + with strategy.model_init_context(): + if args.model == 'bloom': + # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) + print('Using peft lora to load Bloom model as Actor') + actor = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path) + print('Using peft lora to load Bloom model as Actor (Done)') + else: + raise ValueError(f'Unsupported actor model "{args.model}"') + + if rm_model_name == 'gpt2': + critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'bloom': + print("load bloom critic ",args.rm_pretrain," lora_rank ",args.lora_rank," use_action_mask ",True) + critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + print("load bloom critic (Done) ") + elif rm_model_name == 'opt': + critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + elif rm_model_name == 'llama': + critic = LlamaCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) + else: + raise ValueError(f'Unsupported reward model "{rm_model_name}"') + + if args.rm_path is not None: + print('Loading reward model from', args.rm_path) + critic.load_state_dict(state_dict) + del state_dict + + if args.strategy != 'colossalai_gemini': + critic.to(torch.float16).to(torch.cuda.current_device()) + actor.to(torch.float16).to(torch.cuda.current_device()) + + # configure optimizer + if args.strategy.startswith('colossalai'): + actor_optim = HybridAdam(actor.parameters(), lr=1e-7) + critic_optim = HybridAdam(critic.parameters(), lr=1e-7) + else: + actor_optim = Adam(actor.parameters(), lr=1e-7) + critic_optim = Adam(critic.parameters(), lr=1e-7) + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained(args.rm_pretrain) + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.rm_pretrain) + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained(args.rm_pretrain) + elif args.model == 'llama': + tokenizer = LlamaTokenizer.from_pretrained(args.pretrain) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor) + else: + tokenizer.pad_token = tokenizer.eos_token + + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + + prompt_dataset = EasyPromptsDataset(args.prompt_path,tokenizer) + if dist.is_initialized() and dist.get_world_size() > 1: + prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) + else: + prompt_sampler = None + prompt_dataloader = DataLoader(prompt_dataset, + shuffle=(prompt_sampler is None), + sampler=prompt_sampler, + batch_size=args.train_batch_size) + + pretrain_dataset = EasySupervisedDataset(args.pretrain_dataset, tokenizer) + if dist.is_initialized() and dist.get_world_size() > 1: + pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True) + else: + pretrain_sampler = None + pretrain_dataloader = DataLoader(pretrain_dataset, + shuffle=(pretrain_sampler is None), + sampler=pretrain_sampler, + batch_size=args.ptx_batch_size, + collate_fn=data_collator) + + def tokenize_fn(texts): + # MUST padding to max length to ensure inputs of all ranks have the same length + # Different length may lead to hang when using gemini, as different generation steps + batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True) + return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()} + + (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) + + # configure trainer + trainer = PPOTrainer( + strategy, + actor, + critic, + reward_model, + initial_model, + actor_optim, + critic_optim, + kl_coef=args.kl_coef, + ptx_coef=args.ptx_coef, + max_epochs=args.max_epochs, + train_batch_size=args.train_batch_size, + experience_batch_size=args.experience_batch_size, + tokenizer=tokenize_fn, + max_length=512, + do_sample=True, + temperature=1.0, + top_k=50, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + trainer.fit(prompt_dataloader=prompt_dataloader, + pretrain_dataloader=pretrain_dataloader, + num_episodes=args.num_episodes, + max_timesteps=args.max_timesteps, + update_timesteps=args.update_timesteps) + + # save model checkpoint after fitting + trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset') + parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset') + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive', + help='strategy to use') + parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--sft_lora_path', type=str, default=None) + parser.add_argument('--rm_model', default=None, choices=['gpt2', 'bloom', 'opt', 'llama']) + parser.add_argument('--rm_path', type=str, default=None) + parser.add_argument('--rm_pretrain', type=str, default=None) + parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--num_episodes', type=int, default=10) + parser.add_argument('--max_timesteps', type=int, default=10) + parser.add_argument('--update_timesteps', type=int, default=10) + parser.add_argument('--max_epochs', type=int, default=5) + parser.add_argument('--train_batch_size', type=int, default=2) + parser.add_argument('--ptx_batch_size', type=int, default=1) + parser.add_argument('--experience_batch_size', type=int, default=8) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--kl_coef', type=float, default=0.1) + parser.add_argument('--ptx_coef', type=float, default=0.9) + args = parser.parse_args() + main(args) diff --git a/applications/Chat/examples/community/train_peft_sft.py b/applications/Chat/examples/community/train_peft_sft.py new file mode 100644 index 000000000000..65d901261bfc --- /dev/null +++ b/applications/Chat/examples/community/train_peft_sft.py @@ -0,0 +1,187 @@ +import argparse +import os + +import loralib as lora +import torch +import torch.distributed as dist +from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset +from coati.models.base import RewardModel +from coati.models.bloom import BLOOMLM +from coati.models.gpt import GPTLM +from coati.models.llama import LlamaLM +from coati.models.opt import OPTLM +from coati.trainer import SFTTrainer +from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy +from coati.utils import prepare_llama_tokenizer_and_embedding +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from transformers import AutoTokenizer, BloomTokenizerFast,AutoModelForCausalLM +from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer + +from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter + +from torch.utils.data.dataloader import default_collate +from peft import LoraConfig, TaskType,get_peft_model,PeftModel +from easy_dataset import EasyDataset + +def train(args): + # configure strategy + if args.strategy == 'naive': + strategy = NaiveStrategy() + elif args.strategy == 'ddp': + strategy = DDPStrategy() + elif args.strategy == 'colossalai_gemini': + strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') + elif args.strategy == 'colossalai_zero2': + strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') + else: + raise ValueError(f'Unsupported strategy "{args.strategy}"') + + # configure model + with strategy.model_init_context(): + print('Warning: currently only bloom is tested, gpt2,llama and opt are not tested') + model = AutoModelForCausalLM.from_pretrained(args.pretrain).to(torch.cuda.current_device()) + #if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json + if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \ + and os.path.exists(args.save_path+'/adapter_model.bin'): + print("loading from saved peft model ",args.save_path) + model = PeftModel.from_pretrained(model, args.save_path) + else: + #we'll use peft lora library to do the lora + lora_rank = args.lora_rank if args.lora_rank > 0 else 32 + #config lora with rank of lora_rank + lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1) + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + + + # configure tokenizer + if args.model == 'gpt2': + tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'bloom': + tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain) + tokenizer.pad_token = tokenizer.eos_token + elif args.model == 'opt': + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + elif args.model == 'llama': + tokenizer = AutoTokenizer.from_pretrained( + args.pretrain, + padding_side="right", + use_fast=False, + ) + tokenizer.eos_token = '<\s>' + else: + raise ValueError(f'Unsupported model "{args.model}"') + tokenizer.pad_token = tokenizer.eos_token + if args.model == 'llama': + tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) + + if args.strategy == 'colossalai_gemini': + # this is a hack to deal with the resized embedding + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity + for name, param in model.named_parameters(): + if not isinstance(param, ColoParameter): + sub_module_name = '.'.join(name.split('.')[:-1]) + weight_name = name.split('.')[-1] + sub_module = model.get_submodule(sub_module_name) + setattr(sub_module, weight_name, ColoParameter(param)) + else: + tokenizer.pad_token = tokenizer.eos_token + + # configure optimizer + if args.strategy.startswith('colossalai'): + optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0) + else: + optim = Adam(model.parameters(), lr=args.lr) + + logger = get_dist_logger() + logger.set_level('WARNING') + + # configure dataset + law_dataset = EasyDataset(args.dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text) + train_dataset = law_dataset + print(train_dataset) + eval_dataset = None + if args.eval_dataset is not None: + eval_dataset = EasyDataset(args.eval_dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text) + data_collator = default_collate + if dist.is_initialized() and dist.get_world_size() > 1: + train_sampler = DistributedSampler(train_dataset, + shuffle=True, + seed=42, + drop_last=True, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + if eval_dataset is not None: + eval_sampler = DistributedSampler(eval_dataset, + shuffle=False, + seed=42, + drop_last=False, + rank=dist.get_rank(), + num_replicas=dist.get_world_size()) + else: + train_sampler = None + eval_sampler = None + + train_dataloader = DataLoader(train_dataset, + shuffle=(train_sampler is None), + sampler=train_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + if eval_dataset is not None: + eval_dataloader = DataLoader(eval_dataset, + shuffle=(eval_sampler is None), + sampler=eval_sampler, + batch_size=args.batch_size, + collate_fn=data_collator, + pin_memory=True) + else: + eval_dataloader = None + + trainer = SFTTrainer(model=model, + strategy=strategy, + optim=optim, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + batch_size=args.batch_size, + max_epochs=args.max_epochs, + accimulation_steps=args.accimulation_steps) + + trainer.fit(logger=logger, log_interval=args.log_interval) + + # save model checkpoint after fitting on only rank0 + trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer) + # save optimizer checkpoint on all ranks + if args.need_optim_ckpt: + strategy.save_optimizer(trainer.optimizer, + 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--strategy', + choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], + default='naive') + parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') + parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--dataset', type=str, default=None) + parser.add_argument('--eval_dataset', type=str, default=None) + parser.add_argument('--save_path', type=str, default='output') + parser.add_argument('--need_optim_ckpt', type=bool, default=False) + parser.add_argument('--max_epochs', type=int, default=3) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") + parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") + parser.add_argument('--lr', type=float, default=5e-6) + parser.add_argument('--accimulation_steps', type=int, default=8) + parser.add_argument('--enable_peft_lora',action='store_true', default=False) + parser.add_argument("--is_short_text",action='store_true', default=False) + args = parser.parse_args() + train(args) From 80eba05b0abc0ce24f02254cbe2c7b8f9ff5d688 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 6 Apr 2023 14:51:35 +0800 Subject: [PATCH 18/27] [test] refactor tests with spawn (#3452) * [test] added spawn decorator * polish code * polish code * polish code * polish code * polish code * polish code --- .github/workflows/build_on_pr.yml | 19 +- applications/Chat/tests/test_checkpoint.py | 8 +- applications/Chat/tests/test_data.py | 8 +- colossalai/cli/benchmark/benchmark.py | 3 +- colossalai/testing/__init__.py | 16 +- colossalai/testing/utils.py | 88 ++++++- colossalai/utils/__init__.py | 2 - colossalai/utils/common.py | 17 -- docs/requirements-doc-test.txt | 1 + docs/source/en/basics/colotensor_concept.md | 7 +- .../zh-Hans/basics/colotensor_concept.md | 7 +- examples/images/vit/test_vit.py | 8 +- .../auto_offload/train_gpt_offload.py | 39 ++-- .../auto_parallel/auto_parallel_with_gpt.py | 7 +- .../auto_parallel/auto_ckpt_batchsize_test.py | 11 +- .../auto_parallel/auto_ckpt_solver_test.py | 6 +- requirements/requirements-test.txt | 1 + tests/test_amp/test_naive_fp16.py | 10 +- tests/test_amp/test_torch_fp16.py | 10 +- .../test_fx/test_bias_addition.py | 3 +- tests/test_analyzer/test_fx/test_mod_dir.py | 11 +- .../test_analyzer/test_fx/test_nested_ckpt.py | 5 +- .../test_analyzer/test_fx/test_shape_prop.py | 4 +- .../test_fx/test_symbolic_profile.py | 4 +- .../test_subclasses/test_aten.py | 5 +- .../test_subclasses/test_flop_tensor.py | 4 +- .../test_subclasses/test_meta_mode.py | 5 +- .../test_C_solver_consistency.py | 10 +- .../test_ckpt_torchvision.py | 17 +- .../test_ckpt_solvers/test_linearize.py | 3 + .../test_offload/test_perf.py | 10 +- .../test_offload/test_solver.py | 23 +- .../test_pass/test_node_converting_pass.py | 2 + .../test_size_value_converting_pass.py | 2 + .../test_bias_addition_forward.py | 12 +- .../test_tensor_shard/test_checkpoint.py | 12 +- .../test_compatibility_with_ddp.py | 10 +- .../test_compatibility_with_gemini.py | 14 +- .../test_find_repeat_block.py | 4 +- .../test_gpt/test_runtime_with_gpt_modules.py | 9 +- .../test_gpt/test_solver_with_gpt_module.py | 6 +- .../test_liveness_analysis.py | 3 + .../test_metainfo/test_activation_metainfo.py | 15 +- .../test_binary_elementwise_metainfo.py | 10 +- .../test_metainfo/test_conv_metainfo.py | 17 +- .../test_metainfo/test_embedding_metainfo.py | 27 +-- .../test_metainfo/test_linear_metainfo.py | 20 +- .../test_metainfo/test_matmul_metainfo.py | 25 +- .../test_metainfo/test_norm_metainfo.py | 22 +- .../test_metainfo/test_pooling_metainfo.py | 15 +- .../test_metainfo/test_tensor_metainfo.py | 22 +- .../test_metainfo/test_where_metainfo.py | 23 +- .../test_node_handler/test_addbmm_handler.py | 39 ++-- .../test_node_handler/test_addmm_handler.py | 17 +- .../test_batch_norm_handler.py | 11 +- .../test_bias_linear_function_node.py | 12 +- .../test_bias_linear_module_node.py | 16 +- .../test_binary_elementwise_handler.py | 39 ++-- .../test_node_handler/test_bmm_handler.py | 14 +- .../test_node_handler/test_conv_handler.py | 19 +- .../test_default_reshape_handler.py | 3 +- .../test_embedding_handler.py | 14 +- .../test_node_handler/test_getattr_handler.py | 2 + .../test_node_handler/test_getitem_handler.py | 13 +- .../test_layer_norm_handler.py | 11 +- .../test_node_handler/test_linear_handler.py | 35 ++- .../test_node_handler/test_matmul_handler.py | 3 +- .../test_norm_pooling_handler.py | 5 +- .../test_node_handler/test_output_handler.py | 4 +- .../test_permute_and_transpose_handler.py | 21 +- .../test_placeholder_handler.py | 4 +- .../test_node_handler/test_shard_option.py | 4 +- .../test_node_handler/test_softmax_handler.py | 17 +- .../test_node_handler/test_split_handler.py | 18 +- .../test_node_handler/test_sum_handler.py | 13 +- .../test_tensor_constructor.py | 3 +- .../test_unary_element_wise_handler.py | 3 +- .../test_node_handler/test_view_handler.py | 14 +- .../test_node_handler/test_where_handler.py | 2 + .../test_solver_with_resnet_v2.py | 3 +- .../benchmark_autochunk_alphafold.py | 2 +- .../test_autochunk_alphafold_utils.py | 2 +- .../test_autochunk_evoformer_block.py | 12 +- .../test_autochunk_evoformer_stack.py | 12 +- .../test_autochunk_extramsa_block.py | 12 +- .../test_autochunk_diffuser_utils.py | 2 +- .../test_autochunk_unet.py | 14 +- .../test_autochunk_gpt.py | 14 +- .../test_autochunk_transformer_utils.py | 2 +- .../test_autochunk_vit/test_autochunk_vit.py | 12 +- .../test_autochunk_vit_utils.py | 2 +- tests/test_booster/test_accelerator.py | 19 +- .../test_mixed_precision/test_fp16_torch.py | 10 +- .../test_plugin/test_gemini_plugin.py | 11 +- .../test_plugin/test_torch_ddp_plugin.py | 10 +- .../test_general_checkpoint_io.py | 4 +- .../test_cluster/test_device_mesh_manager.py | 11 +- .../test_comm/test_boardcast_send_recv_v2.py | 15 +- tests/test_comm/test_comm.py | 12 +- tests/test_comm/test_object_list_p2p.py | 21 +- tests/test_comm/test_object_list_p2p_v2.py | 17 +- tests/test_context/test_hybrid_parallel.py | 22 +- tests/test_data/test_data_parallel_sampler.py | 14 +- .../test_deterministic_dataloader.py | 15 +- .../test_cifar_with_data_pipeline_tensor.py | 21 +- ...test_cifar_with_data_pipeline_tensor_v2.py | 215 +++++++++--------- tests/test_ddp/test_ddp_ignore_params.py | 8 +- tests/test_ddp/test_ddp_state_dict.py | 9 +- tests/test_ddp/test_reducer.py | 17 +- tests/test_device/test_alpha_beta.py | 12 +- tests/test_device/test_extract_alpha_beta.py | 15 +- tests/test_device/test_init_logical_pg.py | 13 +- .../test_search_logical_device_mesh.py | 12 +- tests/test_engine/test_engine.py | 11 +- .../test_engine/test_gradient_accumluation.py | 18 +- .../test_activation_checkpoint_codegen.py | 18 +- ...st_nested_activation_checkpoint_codegen.py | 19 +- .../test_codegen/test_offload_codegen.py | 18 +- tests/test_fx/test_coloproxy.py | 7 +- tests/test_fx/test_comm_size_compute.py | 11 +- tests/test_fx/test_complete_workflow.py | 87 ------- tests/test_fx/test_graph_manipulation.py | 9 +- tests/test_fx/test_meta/test_aten.py | 3 + tests/test_fx/test_meta/test_backward.py | 5 + tests/test_fx/test_meta/test_meta_trace.py | 5 + tests/test_fx/test_meta_info_prop.py | 5 +- tests/test_fx/test_parallel_1d.py | 18 +- tests/test_fx/test_pipeline_passes.py | 16 +- .../test_profiler_meta_info_prop.py | 4 +- .../test_activation_checkpoint_annotation.py | 2 + .../test_tracer/test_bias_addition_module.py | 3 + .../test_fx/test_tracer/test_control_flow.py | 3 + .../test_tracer/test_functional_conv.py | 3 + .../test_hf_model/test_hf_albert.py | 2 + .../test_tracer/test_hf_model/test_hf_bert.py | 2 + .../test_hf_model/test_hf_diffuser.py | 3 + .../test_tracer/test_hf_model/test_hf_gpt.py | 2 + .../test_tracer/test_hf_model/test_hf_opt.py | 2 + .../test_tracer/test_hf_model/test_hf_t5.py | 2 + .../test_tracer/test_patched_module.py | 16 ++ tests/test_fx/test_tracer/test_patched_op.py | 7 +- .../test_timm_model/test_timm_model.py | 2 + .../test_torchaudio_model.py | 2 + .../test_torchrec_model/test_deepfm_model.py | 2 + .../test_torchrec_model/test_dlrm_model.py | 2 + .../test_torchvision_model.py | 2 + tests/test_layers/test_1d/test_1d.py | 10 +- tests/test_layers/test_2d/test_2d.py | 31 +-- tests/test_layers/test_2p5d/test_2p5d.py | 15 +- tests/test_layers/test_3d/test_3d.py | 27 ++- tests/test_layers/test_cache_embedding.py | 38 ++-- .../test_sequence/test_sequence.py | 19 +- tests/test_moe/test_grad_handler.py | 17 +- tests/test_moe/test_kernel.py | 21 +- tests/test_moe/test_moe_checkpoint.py | 10 +- tests/test_moe/test_moe_colo_init.py | 10 +- tests/test_moe/test_moe_group.py | 17 +- tests/test_moe/test_moe_zero_init.py | 10 +- tests/test_moe/test_moe_zero_model.py | 9 +- tests/test_moe/test_moe_zero_optim.py | 10 +- tests/test_ops/test_addmm_tp.py | 18 +- tests/test_ops/test_embedding_bag_tp.py | 14 +- tests/test_ops/test_embedding_tp.py | 16 +- tests/test_ops/test_linear_tp.py | 16 +- tests/test_ops/test_loss_func.py | 100 ++++---- tests/test_ops/test_op.py | 20 +- tests/test_ops/test_view.py | 197 ++++++++-------- tests/test_optimizer/test_cpu_adam.py | 3 +- tests/test_optimizer/test_fused_adam.py | 5 +- .../test_optimizer/test_fused_adam_kernel.py | 3 +- tests/test_optimizer/test_hybrid_adam.py | 3 +- tests/test_optimizer/test_nvme.py | 11 +- tests/test_pipeline/rpc_test_utils.py | 14 +- tests/test_pipeline/test_middleware_1f1b.py | 71 +++--- tests/test_pipeline/test_pipelinable.py | 10 +- .../test_pipeline_process_group.py | 9 +- tests/test_tensor/core/test_dist_spec_mgr.py | 14 +- tests/test_tensor/core/test_tensor.py | 15 +- tests/test_tensor/model/test_gpt2.py | 9 +- tests/test_tensor/model/test_model.py | 12 +- tests/test_tensor/model/test_module_spec.py | 14 +- .../test_tensor/test_colo_checkpoint_tools.py | 88 ++++--- tests/test_tensor/test_comm_spec_apply.py | 11 +- tests/test_tensor/test_context.py | 9 +- .../test_dtensor/test_comm_spec.py | 9 +- .../test_tensor/test_dtensor/test_dtensor.py | 9 +- .../test_dtensor/test_layout_converter.py | 16 +- tests/test_tensor/test_mix_gather.py | 9 +- tests/test_tensor/test_parameter.py | 7 +- .../test_shape_consistency_apply.py | 9 +- tests/test_tensor/test_sharded_linear.py | 9 +- tests/test_tensor/test_tp_with_zero.py | 9 +- tests/test_trainer/test_pipeline/test_p2p.py | 26 ++- .../test_pipeline/test_pipeline_schedule.py | 29 +-- .../test_trainer_with_non_pipe_schedule.py | 13 +- .../test_trainer_with_pipe_schedule.py | 21 +- .../test_activation_checkpointing.py | 9 +- .../test_checkpoint/test_checkpoint_1d.py | 157 +++++++------ .../test_checkpoint/test_checkpoint_2d.py | 157 +++++++------ .../test_checkpoint/test_checkpoint_2p5d.py | 157 +++++++------ .../test_checkpoint/test_checkpoint_3d.py | 157 +++++++------ .../test_checkpoint_io/test_load.py | 18 +- .../test_checkpoint_io/test_merge.py | 31 ++- .../test_checkpoint_io/test_redist.py | 25 +- .../test_checkpoint_io/test_save.py | 26 ++- tests/test_utils/test_colo_checkpoint.py | 13 +- tests/test_utils/test_commons.py | 10 +- tests/test_utils/test_flash_attention.py | 13 +- .../test_lazy_init/test_distribute.py | 9 +- tests/test_utils/test_memory.py | 10 +- .../test_utils/test_norm_gradient_clipping.py | 20 +- .../test_zero_gradient_clippling.py | 9 +- .../test_zero/test_gemini/test_chunk_mgrv2.py | 9 +- tests/test_zero/test_gemini/test_chunkv2.py | 10 +- tests/test_zero/test_gemini/test_fwd_bwd.py | 9 +- .../test_gemini/test_gemini_use_rmt.py | 11 +- .../test_gemini/test_get_torch_model.py | 10 +- tests/test_zero/test_gemini/test_grad_clip.py | 12 +- tests/test_zero/test_gemini/test_inference.py | 8 +- tests/test_zero/test_gemini/test_optim.py | 10 +- .../test_gemini/test_runtime_mem_tracer.py | 2 + tests/test_zero/test_gemini/test_search.py | 11 +- .../test_gemini/test_zeroddp_state_dict.py | 10 +- .../test_gemini/test_zerooptim_state_dict.py | 9 +- tests/test_zero/test_legacy/test_found_inf.py | 9 +- .../test_legacy/test_gemini_manager.py | 2 + .../test_legacy/test_init_context.py | 9 +- tests/test_zero/test_legacy/test_param_op.py | 2 + .../test_legacy/test_shard_model_v2.py | 9 +- .../test_zero/test_legacy/test_shard_param.py | 11 +- .../test_sharded_optim_state_dict.py | 9 +- .../test_legacy/test_sharded_optim_v2.py | 9 +- .../test_sharded_optim_with_sync_bn.py | 10 +- .../test_zero/test_legacy/test_state_dict.py | 7 +- .../test_legacy/test_tensor_utils.py | 9 +- .../test_zero/test_legacy/test_zero_engine.py | 12 +- .../test_zero/test_low_level/test_grad_acc.py | 8 +- .../test_zero/test_low_level/test_zero1_2.py | 9 +- .../test_low_level/test_zero_init.py | 10 +- .../test_zero/test_low_level/test_zero_tp.py | 11 +- 240 files changed, 1721 insertions(+), 2340 deletions(-) delete mode 100644 tests/test_fx/test_complete_workflow.py diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index f595e677394a..e6febeeb4d87 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -8,10 +8,10 @@ jobs: detect: name: Detect file change if: | - github.event.pull_request.draft == false && - github.base_ref == 'main' && - github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && - contains( github.event.pull_request.labels.*.name, 'Run Build and Test') + github.event.pull_request.draft == false && + github.base_ref == 'main' && + github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI' && + contains( github.event.pull_request.labels.*.name, 'Run Build and Test') outputs: changedExtenisonFiles: ${{ steps.find-extension-change.outputs.all_changed_files }} anyExtensionFileChanged: ${{ steps.find-extension-change.outputs.any_changed }} @@ -27,10 +27,10 @@ jobs: - name: Locate base commit id: locate-base-sha run: | - curBranch=$(git rev-parse --abbrev-ref HEAD) - commonCommit=$(git merge-base origin/main $curBranch) - echo $commonCommit - echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT + curBranch=$(git rev-parse --abbrev-ref HEAD) + commonCommit=$(git merge-base origin/main $curBranch) + echo $commonCommit + echo "baseSHA=$commonCommit" >> $GITHUB_OUTPUT - name: Find the changed extension-related files id: find-extension-change @@ -63,7 +63,6 @@ jobs: echo "$file was changed" done - build: name: Build and Test Colossal-AI needs: detect @@ -124,7 +123,7 @@ jobs: - name: Execute Unit Testing if: needs.detect.outputs.anyLibraryFileChanged == 'true' run: | - PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/ + CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --cov=. --cov-report xml tests/ env: DATA: /data/scratch/cifar-10 NCCL_SHM_DISABLE: 1 diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py index 8c7848525201..4c05a3431699 100644 --- a/applications/Chat/tests/test_checkpoint.py +++ b/applications/Chat/tests/test_checkpoint.py @@ -1,19 +1,16 @@ import os import tempfile from contextlib import nullcontext -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from coati.models.gpt import GPTActor from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) @@ -90,8 +87,7 @@ def run_dist(rank, world_size, port, strategy): @pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini']) @rerun_if_address_is_in_use() def test_checkpoint(world_size, strategy): - run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, strategy=strategy) if __name__ == '__main__': diff --git a/applications/Chat/tests/test_data.py b/applications/Chat/tests/test_data.py index 577309a0fceb..2e4d4ceac05f 100644 --- a/applications/Chat/tests/test_data.py +++ b/applications/Chat/tests/test_data.py @@ -1,11 +1,9 @@ import os from copy import deepcopy -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from coati.experience_maker import NaiveExperienceMaker from coati.models.base import RewardModel from coati.models.gpt import GPTActor, GPTCritic @@ -13,8 +11,7 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy from transformers.models.gpt2.configuration_gpt2 import GPT2Config -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4) @@ -114,8 +111,7 @@ def run_dist(rank, world_size, port, strategy): @pytest.mark.parametrize('strategy', ['ddp', 'colossalai']) @rerun_if_address_is_in_use() def test_data(world_size, strategy): - run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, strategy=strategy) if __name__ == '__main__': diff --git a/colossalai/cli/benchmark/benchmark.py b/colossalai/cli/benchmark/benchmark.py index f40f8f2f995e..97a9f45722dd 100644 --- a/colossalai/cli/benchmark/benchmark.py +++ b/colossalai/cli/benchmark/benchmark.py @@ -10,7 +10,8 @@ from colossalai.context.random import reset_seeds from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.utils import MultiTimer, free_port +from colossalai.testing import free_port +from colossalai.utils import MultiTimer from .models import MLP diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index e3dd500dea8e..c53e0f44c7e0 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -1,7 +1,17 @@ -from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group -from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus +from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal +from .pytest_wrapper import run_on_environment_flag +from .utils import ( + clear_cache_before_run, + free_port, + parameterize, + rerun_if_address_is_in_use, + rerun_on_exception, + skip_if_not_enough_gpus, + spawn, +) __all__ = [ 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', - 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus' + 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', + 'clear_cache_before_run', 'run_on_environment_flag' ] diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 64c1d6e7bcd0..eac83e6d7bd5 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -1,8 +1,13 @@ +import gc +import random import re -import torch -from typing import Callable, List, Any +import socket from functools import partial from inspect import signature +from typing import Any, Callable, List + +import torch +import torch.multiprocessing as mp from packaging import version @@ -43,7 +48,7 @@ def say_something(person, msg): # > davis: hello # > davis: bye # > davis: stop - + Args: argument (str): the name of the argument to parameterize values (List[Any]): a list of values to iterate for this argument @@ -85,13 +90,13 @@ def test_method(): def test_method(): print('hey') raise RuntimeError('Address already in use') - + # rerun for infinite times if Runtime error occurs @rerun_on_exception(exception_type=RuntimeError, max_try=None) def test_method(): print('hey') raise RuntimeError('Address already in use') - + # rerun only the exception message is matched with pattern # for infinite times if Runtime error occurs @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") @@ -101,10 +106,10 @@ def test_method(): Args: exception_type (Exception, Optional): The type of exception to detect for rerun - pattern (str, Optional): The pattern to match the exception message. + pattern (str, Optional): The pattern to match the exception message. If the pattern is not None and matches the exception message, the exception will be detected for rerun - max_try (int, Optional): Maximum reruns for this function. The default value is 5. + max_try (int, Optional): Maximum reruns for this function. The default value is 5. If max_try is None, it will rerun foreven if exception keeps occurings """ @@ -202,3 +207,72 @@ def _execute_by_gpu_num(*args, **kwargs): return _execute_by_gpu_num return _wrap_func + + +def free_port() -> int: + """Get a free port on localhost. + + Returns: + int: A free port on localhost. + """ + while True: + port = random.randint(20000, 65000) + try: + with socket.socket() as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("localhost", port)) + return port + except OSError: + continue + + +def spawn(func, nprocs=1, **kwargs): + """ + This function is used to spawn processes for testing. + + Usage: + # must contians arguments rank, world_size, port + def do_something(rank, world_size, port): + ... + + spawn(do_something, nprocs=8) + + # can also pass other arguments + def do_something(rank, world_size, port, arg1, arg2): + ... + + spawn(do_something, nprocs=8, arg1=1, arg2=2) + + Args: + func (Callable): The function to be spawned. + nprocs (int, optional): The number of processes to spawn. Defaults to 1. + """ + port = free_port() + wrapped_func = partial(func, world_size=nprocs, port=port, **kwargs) + mp.spawn(wrapped_func, nprocs=nprocs) + + +def clear_cache_before_run(): + """ + This function is a wrapper to clear CUDA and python cache before executing the function. + + Usage: + @clear_cache_before_run() + def test_something(): + ... + """ + + def _wrap_func(f): + + def _clear_cache(*args, **kwargs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_max_memory_cached() + torch.cuda.synchronize() + gc.collect() + f(*args, **kwargs) + + return _clear_cache + + return _wrap_func diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 3f16bd91e5fe..7b2e8480c66c 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -7,7 +7,6 @@ count_zeros_fp32, disposable, ensure_path_exists, - free_port, is_ddp_ignored, is_dp_rank_0, is_model_parallel_parameter, @@ -37,7 +36,6 @@ __all__ = [ 'checkpoint', - 'free_port', 'print_rank_0', 'sync_model_param', 'is_ddp_ignored', diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index e15981140be1..95b3b8014af1 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -50,23 +50,6 @@ def ensure_path_exists(filename: str): Path(dirpath).mkdir(parents=True, exist_ok=True) -def free_port() -> int: - """Get a free port on localhost. - - Returns: - int: A free port on localhost. - """ - while True: - port = random.randint(20000, 65000) - try: - with socket.socket() as sock: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(("localhost", port)) - return port - except OSError: - continue - - def sync_model_param(model, parallel_mode): r"""Make sure data parameters are consistent during Data Parallel Mode. diff --git a/docs/requirements-doc-test.txt b/docs/requirements-doc-test.txt index 6a6bb3bee9b0..79e04bd5615d 100644 --- a/docs/requirements-doc-test.txt +++ b/docs/requirements-doc-test.txt @@ -4,3 +4,4 @@ packaging tensornvme psutil transformers +pytest diff --git a/docs/source/en/basics/colotensor_concept.md b/docs/source/en/basics/colotensor_concept.md index 2d8acd88dfd4..1b855c03b919 100644 --- a/docs/source/en/basics/colotensor_concept.md +++ b/docs/source/en/basics/colotensor_concept.md @@ -56,12 +56,12 @@ Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp ```python import torch import torch.multiprocessing as mp -from colossalai.utils import free_port, print_rank_0 +from colossalai.utils import print_rank_0 from functools import partial import colossalai from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.utils import free_port +from colossalai.testing import spawn import torch @@ -83,8 +83,7 @@ def run_dist_tests(rank, world_size, port): print_rank_0(f"shape {t1.shape}, {t1.data}") def test_dist_cases(world_size): - run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_tests, world_size) if __name__ == '__main__': test_dist_cases(4) diff --git a/docs/source/zh-Hans/basics/colotensor_concept.md b/docs/source/zh-Hans/basics/colotensor_concept.md index cac5b9a4b40d..d6a332df2e9c 100644 --- a/docs/source/zh-Hans/basics/colotensor_concept.md +++ b/docs/source/zh-Hans/basics/colotensor_concept.md @@ -57,12 +57,12 @@ ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs. ```python import torch import torch.multiprocessing as mp -from colossalai.utils import free_port, print_rank_0 +from colossalai.utils import print_rank_0 from functools import partial import colossalai from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern -from colossalai.utils import free_port +from colossalai.testing import spawn import torch @@ -84,8 +84,7 @@ def run_dist_tests(rank, world_size, port): print_rank_0(f"shape {t1.shape}, {t1.data}") def test_dist_cases(world_size): - run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_tests, world_size) if __name__ == '__main__': test_dist_cases(4) diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py index 6a587e1df96a..c0ae35bca871 100644 --- a/examples/images/vit/test_vit.py +++ b/examples/images/vit/test_vit.py @@ -1,11 +1,9 @@ import os import random -from functools import partial import numpy as np import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from vit import get_training_components @@ -15,8 +13,7 @@ from colossalai.core import global_context as gpc from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext @@ -156,8 +153,7 @@ def run_dist(rank, world_size, port, use_ddp): @pytest.mark.parametrize('use_ddp', [False, True]) @rerun_if_address_is_in_use() def test_vit(world_size, use_ddp): - run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, use_ddp=use_ddp) if __name__ == '__main__': diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 729d1ce4456b..89415c23f93c 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -1,20 +1,20 @@ -import time -import pytest import argparse -from functools import partial +import time +import pytest import torch +from model_zoo import GPTLMLoss, get_gpt2_components from torch.utils._pytree import tree_map -import torch.multiprocessing as mp import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.fx.profiler import parameter_size -from colossalai.utils import free_port, get_current_device from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML -from model_zoo import get_gpt2_components, GPTLMLoss +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import spawn +from colossalai.utils import get_current_device + def parse_args(): parser = argparse.ArgumentParser() @@ -24,6 +24,7 @@ def parse_args(): parser.add_argument('--memory_budget', type=float, default=16) return parser.parse_args() + @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') def train_gpt(args): memory_budget = args.memory_budget * 1024 * 1024 * 1024 @@ -33,13 +34,16 @@ def train_gpt(args): # build model model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) - label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) criterion = GPTLMLoss() start_time = time.time() model = model_builder() model.train() - param_size = parameter_size(model) / 1024 ** 2 / 2 + param_size = parameter_size(model) / 1024**2 / 2 init_time = time.time() - start_time print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") @@ -74,21 +78,20 @@ def train_gpt(args): torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'solver_type: {solver_type} | model_type: {model_type}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) + def run(rank, world_size, port, args): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') train_gpt(args) + if __name__ == '__main__': args = parse_args() - run_func = partial(run, world_size=1, port=free_port(), args=args) - mp.spawn(run_func, nprocs=1) + spawn(run, 1, args=args) diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py index 6ceb7fd87c0a..e331fc8fcf10 100644 --- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py +++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py @@ -1,18 +1,13 @@ from functools import partial from time import time -from typing import Dict, Optional, Tuple, Union import psutil import torch -import torch.multiprocessing as mp -import torch.nn as nn import transformers from gpt_modules import GPT2LMHeadModel, GPTLMLoss -from torch.fx import GraphModule -from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model +from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize from colossalai.core import global_context as gpc -from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch_from_torch from colossalai.logging import disable_existing_loggers, get_dist_logger diff --git a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py index 5decfc695f6f..5a68aae18041 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_batchsize_test.py @@ -1,19 +1,14 @@ -import time -from argparse import ArgumentParser from copy import deepcopy from functools import partial -import matplotlib.pyplot as plt -import numpy as np import torch -import torch.multiprocessing as mp import torchvision.models as tm from bench_utils import bench, data_gen_resnet import colossalai from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace -from colossalai.utils import free_port +from colossalai.testing import spawn def _benchmark(rank, world_size, port): @@ -50,9 +45,7 @@ def _benchmark(rank, world_size, port): def auto_activation_checkpoint_batchsize_benchmark(): - world_size = 1 - run_func_module = partial(_benchmark, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_benchmark, 1) if __name__ == "__main__": diff --git a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py index ab0f2ef661df..aa5c47294a82 100644 --- a/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py +++ b/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py @@ -4,14 +4,13 @@ import matplotlib.pyplot as plt import torch -import torch.multiprocessing as mp import torchvision.models as tm from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium import colossalai from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.fx import metainfo_trace, symbolic_trace -from colossalai.utils import free_port +from colossalai.testing import spawn def _benchmark(rank, world_size, port, args): @@ -77,8 +76,7 @@ def _benchmark(rank, world_size, port, args): def auto_activation_checkpoint_benchmark(args): world_size = 1 - run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_benchmark, world_size, args=args) if __name__ == "__main__": diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 05c0e6ac5e5c..82b6173b3517 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -12,3 +12,4 @@ contexttimer einops triton==2.0.0.dev20221202 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn +requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611 diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py index c01de469b8f1..6ce4c7f49725 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_amp/test_naive_fp16.py @@ -1,14 +1,11 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.amp import convert_to_apex_amp, convert_to_naive_amp -from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -87,10 +84,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_naive_amp(): - world_size = 1 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 1) if __name__ == '__main__': diff --git a/tests/test_amp/test_torch_fp16.py b/tests/test_amp/test_torch_fp16.py index e65dd8cded26..6451aa6264a3 100644 --- a/tests/test_amp/test_torch_fp16.py +++ b/tests/test_amp/test_torch_fp16.py @@ -1,14 +1,11 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.amp import convert_to_apex_amp, convert_to_torch_amp -from colossalai.testing import assert_close_loose, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs @@ -87,10 +84,9 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_torch_amp(): - world_size = 1 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 1) if __name__ == '__main__': diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py index 61951e9a5da9..f7b5eb140f24 100644 --- a/tests/test_analyzer/test_fx/test_bias_addition.py +++ b/tests/test_analyzer/test_fx/test_bias_addition.py @@ -3,7 +3,7 @@ from packaging import version from torch.utils.checkpoint import checkpoint -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize try: from colossalai._analyzer.fx import symbolic_trace @@ -81,6 +81,7 @@ def forward(self, x): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize("bias", [True, False]) @parameterize("bias_addition_split", [True, False]) @parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)]) diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py index 15e0c2ec21c7..f62147b297a2 100644 --- a/tests/test_analyzer/test_fx/test_mod_dir.py +++ b/tests/test_analyzer/test_fx/test_mod_dir.py @@ -1,6 +1,8 @@ import pytest import torch +from colossalai.testing import clear_cache_before_run, parameterize + try: from colossalai._analyzer.fx import symbolic_trace except: @@ -62,9 +64,10 @@ def forward(self, x): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') -@pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("bias_addition_split", [True, False]) -@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)]) +@clear_cache_before_run() +@parameterize("bias", [True, False]) +@parameterize("bias_addition_split", [True, False]) +@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)]) def test_mod_dir(bias, bias_addition_split, shape): model = AModel(bias=bias) x = torch.rand(shape) @@ -75,4 +78,4 @@ def test_mod_dir(bias, bias_addition_split, shape): if __name__ == '__main__': - test_mod_dir(True, True, (3, 3, 3)) + test_mod_dir(bias=True, bias_addition_split=True, shape=(3, 3, 3)) diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py index c31aab6752f8..bd16f5a4f95d 100644 --- a/tests/test_analyzer/test_fx/test_nested_ckpt.py +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -1,7 +1,9 @@ +import pytest import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint -import pytest + +from colossalai.testing import clear_cache_before_run try: from colossalai._analyzer.fx import symbolic_trace @@ -42,6 +44,7 @@ def forward(self, x): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() def test_nested_ckpt(): model = MyModule() x = torch.rand(10, 10) diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py index 08f4ff2cbd1f..a849feb795e5 100644 --- a/tests/test_analyzer/test_fx/test_shape_prop.py +++ b/tests/test_analyzer/test_fx/test_shape_prop.py @@ -3,7 +3,7 @@ import torchvision.models as tm from packaging import version -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -32,6 +32,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tm_models) def test_torchvision_shape_prop(m): with MetaTensorMode(): @@ -46,6 +47,7 @@ def test_torchvision_shape_prop(m): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tmm_models) def test_timm_shape_prop(m): with MetaTensorMode(): diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py index be781599f14b..17deee7a7118 100644 --- a/tests/test_analyzer/test_fx/test_symbolic_profile.py +++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py @@ -3,7 +3,7 @@ import torchvision.models as tm from packaging import version -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -19,6 +19,7 @@ def _check_gm_validity(gm: torch.fx.GraphModule): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tm_models) def test_torchvision_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): @@ -33,6 +34,7 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() @parameterize('m', tmm_models) def test_timm_profile(m, verbose=False, bias_addition_split=False): with MetaTensorMode(): diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py index 591a8d617580..b7858110ac09 100644 --- a/tests/test_analyzer/test_subclasses/test_aten.py +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -1,9 +1,11 @@ from typing import Any, Callable, Union -import pytest +import pytest import torch import torch.nn as nn +from colossalai.testing import clear_cache_before_run + try: from colossalai._analyzer._subclasses import MetaTensor except: @@ -72,6 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12') +@clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): for f, x in v: diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py index 752836141fe7..da3829e40146 100644 --- a/tests/test_analyzer/test_subclasses/test_flop_tensor.py +++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py @@ -4,6 +4,7 @@ import torchvision.models as tm from packaging import version +from colossalai.testing import clear_cache_before_run, parameterize from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models try: @@ -39,7 +40,8 @@ def test_flop_count_module(m): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('func, args, kwargs', odd_cases) +@clear_cache_before_run() +@parameterize('func, args, kwargs', odd_cases) def test_flop_count_function(func, args, kwargs): rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True) assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}' diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py index 160d411f6c39..d2a0a1b9cfb5 100644 --- a/tests/test_analyzer/test_subclasses/test_meta_mode.py +++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py @@ -3,6 +3,8 @@ import torchvision.models as tm from packaging import version +from colossalai.testing import clear_cache_before_run, parameterize + try: from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode except: @@ -30,7 +32,8 @@ def run_and_compare(model): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') -@pytest.mark.parametrize('m', tm_models + tmm_models) +@clear_cache_before_run() +@parameterize('m', tm_models + tmm_models) def test_meta_mode_shape(m): run_and_compare(m()) diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py index f8dd0b16b7f6..f184f64b35d0 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py @@ -3,7 +3,6 @@ import pytest import torch import torch.fx -import torch.multiprocessing as mp import torchvision.models as tm import colossalai @@ -13,7 +12,7 @@ # from colossalai.fx.passes.algorithms import solver_rotor # from colossalai.fx.passes.algorithms.operation import Sequence from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -26,8 +25,8 @@ withcodegen = False -def _run_C_solver_consistency_test(rank=0): - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_C_solver_consistency_test(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') for M, mem_budget in [(tm.resnet50, 4000), (tm.densenet121, 8080)]: model = M() @@ -70,8 +69,9 @@ def _run_C_solver_consistency_test(rank=0): @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0") +@rerun_if_address_is_in_use() def test_C_solver_consistency(): - mp.spawn(_run_C_solver_consistency_test, nprocs=1) + spawn(_run_C_solver_consistency_test, 1) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py index 89600ea098a9..db268b91d0a0 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py @@ -4,7 +4,6 @@ import pytest import torch -import torch.multiprocessing as mp import torchvision.models as tm from torch.fx import GraphModule @@ -15,7 +14,7 @@ from colossalai.fx.graph_module import ColoGraphModule # from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -68,8 +67,8 @@ def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Call assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' -def _run_ckpt_solver(rank): - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_ckpt_solver(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True @@ -98,12 +97,13 @@ def _run_ckpt_solver(rank): @pytest.mark.skip("TODO(super-dainiu): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() def test_ckpt_solver(): - mp.spawn(_run_ckpt_solver, nprocs=1) + spawn(_run_ckpt_solver, 1) -def _run_ckpt_solver_torch11(rank): - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') +def _run_ckpt_solver_torch11(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') MODEL_LIST = [tm.densenet121] torch.backends.cudnn.deterministic = True @@ -131,8 +131,9 @@ def _run_ckpt_solver_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() def test_ckpt_solver_torch11(): - mp.spawn(_run_ckpt_solver_torch11, nprocs=1) + spawn(_run_ckpt_solver_torch11, 1) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py index 0f90ba0b0989..59880815dc5e 100644 --- a/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py +++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py @@ -8,6 +8,7 @@ # from colossalai.fx.passes.algorithms import linearize, solver_rotor # from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss) from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.testing import clear_cache_before_run if is_compatible_with_meta(): from colossalai.fx.profiler.tensor import MetaTensor @@ -24,6 +25,7 @@ @pytest.mark.skip(reason='TODO: modify the logger') @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0") +@clear_cache_before_run() def test_linearize(): MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} tracer = ColoTracer() @@ -84,6 +86,7 @@ def test_linearize(): @pytest.mark.skip("TODO(lyl): refactor all tests.") @pytest.mark.skip(reason="torch11 meta tensor not implemented") @pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0") +@clear_cache_before_run() def test_linearize_torch11(): MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]} tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index c925843fb2b6..80f134fd85d0 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -1,9 +1,7 @@ import time -from functools import partial import pytest import torch -import torch.multiprocessing as mp from torch.utils._pytree import tree_map import colossalai @@ -12,8 +10,8 @@ from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed @@ -140,9 +138,9 @@ def run_dist(rank, world_size, port): @pytest.mark.skip("this test failed") @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@rerun_if_address_is_in_use() def test_perf(): - run_func = partial(run_dist, world_size=1, port=free_port()) - mp.spawn(run_func, nprocs=1) + spawn(run_dist, 1) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py index 2efbb750f80d..aa2c9a36849f 100644 --- a/tests/test_auto_parallel/test_offload/test_solver.py +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -3,20 +3,20 @@ from torch.fx import GraphModule from torch.utils._pytree import tree_map +from colossalai.auto_parallel.offload.region_manager import RegionManager +from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory from colossalai.fx import ColoTracer, is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.auto_parallel.offload.region_manager import RegionManager -from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_offload.model_utils import * + @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') +@clear_cache_before_run() @parameterize('model_name', ['gpt2_', 'bert_']) @parameterize('memory_budget', [4000]) @parameterize('solver_name', ['syn', 'asyn']) -def solver_test(model_name: str, - memory_budget: float, - solver_name: str): +def solver_test(model_name: str, memory_budget: float, solver_name: str): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() @@ -52,11 +52,16 @@ def solver_test(model_name: str, for region in region_list: need_offload = region.need_offload to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None - print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + print( + f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) for region in region_list.__reversed__(): need_offload = region.need_offload to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None - print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + print( + f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) + if __name__ == '__main__': - solver_test() \ No newline at end of file + solver_test() diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py index d0d107610f7a..429e89aae5d3 100644 --- a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py @@ -6,6 +6,7 @@ from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.tracer import ColoTracer from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import clear_cache_before_run class TestModule(torch.nn.Module): @@ -26,6 +27,7 @@ def insert_narrow(gm, x_node): return gm +@clear_cache_before_run() def test_node_args_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py index 7d4fd844ab26..bca81201c6ef 100644 --- a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py +++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py @@ -8,6 +8,7 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.testing import clear_cache_before_run class TestModule(torch.nn.Module): @@ -36,6 +37,7 @@ def recover_narrow(gm, narrow_node): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() def test_size_value_converting_pass(): model = TestModule() physical_mesh_id = torch.arange(0, 4) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py index 6d1b28912c9b..9fbe674ef4f4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.multiprocessing as mp try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model @@ -13,9 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn class LinearModel(torch.nn.Module): @@ -86,11 +83,8 @@ def check_conv_module(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_bias_addition_module(): - world_size = 4 - run_func_linear = partial(check_linear_module, world_size=world_size, port=free_port()) - mp.spawn(run_func_linear, nprocs=world_size) - run_func_conv = partial(check_conv_module, world_size=world_size, port=free_port()) - mp.spawn(run_func_conv, nprocs=world_size) + spawn(check_linear_module, 4) + spawn(check_conv_module, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py index 7a4c8d32ed80..398458306e3d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py @@ -1,9 +1,7 @@ -from functools import partial -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from torch.utils.checkpoint import checkpoint from transformers.pytorch_utils import Conv1D @@ -17,9 +15,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn HIDDEN_SIZE = 16 @@ -65,9 +61,7 @@ def check_act_ckpt(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_mlp_layer(): - world_size = 4 - run_func = partial(check_act_ckpt, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_act_ckpt, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py index 7c3277c69970..6908a1781869 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py @@ -1,9 +1,7 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP try: @@ -15,9 +13,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn class MLP(torch.nn.Module): @@ -102,9 +98,7 @@ def check_compatibility_with_ddp(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_compatibility_with_ddp(): - world_size = 4 - run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_compatibility_with_ddp, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index e4435a049f62..05704acbf7fd 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -1,10 +1,7 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp -from torch.nn.parallel import DistributedDataParallel as DDP try: from colossalai.auto_parallel.tensor_shard.initialize import initialize_model @@ -17,10 +14,9 @@ from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.process_group import ProcessGroup -from colossalai.testing import assert_close, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port, get_current_device -from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper +from colossalai.testing import assert_close, rerun_if_address_is_in_use, run_on_environment_flag, spawn +from colossalai.utils import get_current_device +from colossalai.zero import post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper class MLP(torch.nn.Module): @@ -110,9 +106,7 @@ def check_auto_parallel_with_gemini(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_auto_parallel_with_gemini(): - world_size = 4 - run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_auto_parallel_with_gemini, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py index e7fccad36caf..a0b407b240e1 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py @@ -10,8 +10,7 @@ # from colossalai.fx.tracer.tracer import ColoTracer from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks -from colossalai.testing import parameterize -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, parameterize, run_on_environment_flag NUM_REPEAT_BLOCKS = 4 BATCH_SIZE = 1 @@ -81,6 +80,7 @@ def forward(self, x): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() @parameterize('model_cls', [RepeatModel, NonRepeatModel]) def test_repeat_blocks(model_cls): diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 8688890efc93..48d2672c6571 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -1,12 +1,10 @@ import copy import random -from functools import partial from typing import Dict import numpy as np import pytest import torch -import torch.multiprocessing as mp import transformers from torch.fx import GraphModule @@ -30,9 +28,8 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import to_global -from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use +from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model BATCH_SIZE = 1 @@ -190,9 +187,7 @@ def check_attention_layer(rank, model_cls, world_size, port): @parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model]) @rerun_if_address_is_in_use() def test_mlp_layer(model_cls): - world_size = 4 - run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_attention_layer, 4, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py index 5f0688d5f5f2..5a8c3c4bf5a0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn import transformers from torch.fx import GraphModule @@ -7,10 +6,10 @@ from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP from colossalai.auto_parallel.tensor_shard.options import SolverOptions -from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor +from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize from colossalai.testing.pytest_wrapper import run_on_environment_flag from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model @@ -20,6 +19,7 @@ @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() @parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model]) def test_self_attention_block(model_cls): config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=16, n_embd=HIDDEN_DIM) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py index 8d421243827e..d10b222c060d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_liveness_analysis.py @@ -6,6 +6,8 @@ from colossalai._analyzer.fx.passes import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer from colossalai.auto_parallel.tensor_shard.solver import GraphAnalyser +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import clear_cache_before_run class LinearModel(nn.Module): @@ -26,6 +28,7 @@ def forward(self, x1, x2): @pytest.mark.skip('meta tensor has some bugs in 1.11') +@clear_cache_before_run() def test_liveness_analysis(): model = LinearModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py index e41ac4fa690b..e0a2133e654e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py @@ -1,23 +1,14 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn from colossalai.auto_parallel.meta_profiler import meta_register from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results +from colossalai.testing.utils import clear_cache_before_run, parameterize +from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() @parameterize('func', [ torch.nn.functional.softmax, torch.nn.functional.relu, diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py index 1b745d8906b0..68ccc7835bc3 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_binary_elementwise_metainfo.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh @@ -10,8 +7,7 @@ from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy @@ -62,9 +58,7 @@ def _binary_elementwise_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_binary_elementwise_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_binary_elementwise_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py index a973a8182cf3..c6f7b88f44a5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_conv_metainfo.py @@ -1,17 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy @@ -25,7 +20,7 @@ def forward(self, input): return nn.functional.conv2d(input, self.conv_weight) -def _conv_module_mem_test(rank, bias, world_size, port): +def _conv_module_mem_test(rank, world_size, port, bias): """This function is for conv memory test Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL @@ -62,9 +57,7 @@ def _conv_module_mem_test(rank, bias, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_meta_concrete_info_match(bias=False): - world_size = 4 - run_func_module = partial(_conv_module_mem_test, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_conv_module_mem_test, 4, bias=bias) def _conv_function_mem_test(rank, world_size, port): @@ -103,9 +96,7 @@ def _conv_function_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_conv_function_concrete_info_match(): - world_size = 4 - run_func_module = partial(_conv_function_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_conv_function_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py index 5f3d2df503f6..e3f76a95c4a5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py @@ -1,33 +1,16 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import meta_register @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() def test_embedding_meta_info(): meta_func = meta_register.get(torch.nn.Embedding) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py index ddc8e3c6ac02..fb3ded339ddf 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_linear_metainfo.py @@ -1,24 +1,14 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy -if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register - class MyModule(nn.Module): @@ -63,9 +53,7 @@ def _linear_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_module_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_linear_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_linear_module_mem_test, 4) def _linear_function_mem_test(rank, world_size, port): @@ -101,9 +89,7 @@ def _linear_function_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_function_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_linear_function_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_linear_function_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py index 1242b9db0750..2d2d77f0c637 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py @@ -1,26 +1,8 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.testing.utils import clear_cache_before_run, parameterize from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': @@ -28,6 +10,7 @@ @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() @parameterize( 'tensor_shapes', [ diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py index d3342d310157..808172977b60 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py @@ -1,29 +1,17 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results if torch.__version__ >= '1.12.0': - from colossalai.auto_parallel.meta_profiler import ShardMetaInfo, meta_register + from colossalai.auto_parallel.meta_profiler import meta_register def _batchnorm_module_mem_test(rank, world_size, port): @@ -62,9 +50,7 @@ def _batchnorm_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_batchnorm_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_batchnorm_module_mem_test, 4) @pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py index 529686d27d19..4cddf4e19fca 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_pooling_metainfo.py @@ -1,17 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing.utils import rerun_if_address_is_in_use, spawn from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy @@ -51,9 +46,7 @@ def _adaptiveavgpool_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_adaptiveavgpool_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_adaptiveavgpool_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_adaptiveavgpool_module_mem_test, 4) def _maxpool_module_mem_test(rank, world_size, port): @@ -92,9 +85,7 @@ def _maxpool_module_mem_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_maxpool_meta_concrete_info_match(): - world_size = 4 - run_func_module = partial(_maxpool_module_mem_test, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(_maxpool_module_mem_test, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py index a544e9a3cb2f..6e8145885d67 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py @@ -1,26 +1,9 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType +from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': @@ -37,6 +20,7 @@ def forward(self, x): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() def test_tensor_meta_info(): """test tensor related meta information We will just use torch.Tensor.split for the test diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py index 2ae13ea2b1f8..b4564312eeb4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py @@ -1,24 +1,8 @@ import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn - -from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler -from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( - MemoryCost, - OperationData, - OperationDataType, - ShardingStrategy, - StrategiesVector, - TrainCycleItem, -) -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx import ColoGraphModule, ColoTracer -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, TrainCycleItem +from colossalai.testing.utils import clear_cache_before_run from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results if torch.__version__ >= '1.12.0': @@ -26,6 +10,7 @@ @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() def test_where_meta_info(): meta_func = meta_register.get(torch.where) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py index ffc15e403f35..80e6a6c1460c 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler @@ -11,9 +8,7 @@ from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -45,7 +40,7 @@ def forward(self, bias, x1, x2): return output -def check_2d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, port): +def check_2d_device_mesh(rank, world_size, port, module, bias_shape, using_kwargs): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = module(using_kwargs).cuda() @@ -249,14 +244,13 @@ def check_1d_device_mesh(rank, module, bias_shape, using_kwargs, world_size, por @parameterize('using_kwargs', [True, False]) @rerun_if_address_is_in_use() def test_2d_device_mesh(module, bias_shape, using_kwargs): - world_size = 4 - run_func = partial(check_2d_device_mesh, - module=module, - bias_shape=bias_shape, - world_size=world_size, - using_kwargs=using_kwargs, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn( + check_2d_device_mesh, + 4, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + ) @pytest.mark.skip("skip due to bias cases not ready") @@ -267,14 +261,13 @@ def test_2d_device_mesh(module, bias_shape, using_kwargs): @parameterize('using_kwargs', [True, False]) @rerun_if_address_is_in_use() def test_1d_device_mesh(module, bias_shape, using_kwargs): - world_size = 4 - run_func = partial(check_1d_device_mesh, - module=module, - bias_shape=bias_shape, - using_kwargs=using_kwargs, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn( + check_1d_device_mesh, + 4, + module=module, + bias_shape=bias_shape, + using_kwargs=using_kwargs, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py index 35f12ce83af2..fe6554cd81ee 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addmm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -17,9 +14,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -45,7 +40,7 @@ def forward(self, m1): return x -def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port): +def check_addmm_function_handler(rank, world_size, port, input_shape, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') if model_cls == AddmmModel: @@ -189,13 +184,7 @@ def check_addmm_function_handler(rank, input_shape, model_cls, world_size, port) @parameterize('model_cls', [AddmmModel, AddmmModel_with_param]) @rerun_if_address_is_in_use() def test_addmm_handler(input_shape, model_cls): - world_size = 4 - run_func_function = partial(check_addmm_function_handler, - input_shape=input_shape, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_function, nprocs=world_size) + spawn(check_addmm_function_handler, 4, input_shape=input_shape, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py index 2069b5e8a4de..b47b3508ad1b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,9 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -114,9 +109,7 @@ def check_bn_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_bn_module_handler(): - world_size = 4 - run_func = partial(check_bn_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_bn_module_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py index dca5f6e227fa..800bc11a50e4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_function_node.py @@ -1,9 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn import torch.nn.functional as F from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -19,9 +15,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy WEIGHT_SHAPE = (32, 16) @@ -168,9 +162,7 @@ def check_linear_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(): - world_size = 4 - run_func_module = partial(check_linear_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(check_linear_module_handler) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py index 14d4a73fb4f8..c29a065d10ba 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bias_linear_module_node.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass from colossalai._analyzer.fx.tracer.tracer import ColoTracer -from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler +from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( OperationData, OperationDataType, @@ -18,9 +14,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -35,7 +29,7 @@ def forward(self, x): return x -def check_linear_module_handler(rank, bias, world_size, port): +def check_linear_module_handler(rank, world_size, port, bias): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearModule(16, 32, bias=bias).cuda() @@ -157,9 +151,7 @@ def check_linear_module_handler(rank, bias, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(bias=True): - world_size = 4 - run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) + spawn(check_linear_module_handler, bias=bias) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py index 2414749f60a4..83f3aafe220e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,13 +10,11 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size, port): +def check_binary_elementwise_handler_with_tensor(rank, world_size, port, op, other_dim): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -149,7 +144,7 @@ def forward(self, x1): return out -def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port): +def check_binary_elementwise_handler_with_int(rank, world_size, port, op, other_dim, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -236,13 +231,12 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, wo @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_tensor(op, other_dim): - world_size = 4 - run_func_tensor = partial(check_binary_elementwise_handler_with_tensor, - op=op, - other_dim=other_dim, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_tensor, nprocs=world_size) + spawn( + check_binary_elementwise_handler_with_tensor, + 4, + op=op, + other_dim=other_dim, + ) @run_on_environment_flag(name='AUTO_PARALLEL') @@ -252,14 +246,13 @@ def test_binary_elementwise_handler_with_tensor(op, other_dim): @pytest.mark.dist @rerun_if_address_is_in_use() def test_binary_elementwise_handler_with_int(op, model_cls, other_dim): - world_size = 4 - run_func_int = partial(check_binary_elementwise_handler_with_int, - op=op, - model_cls=model_cls, - other_dim=other_dim, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_int, nprocs=world_size) + spawn( + check_binary_elementwise_handler_with_int, + 4, + op=op, + model_cls=model_cls, + other_dim=other_dim, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index 34c20c1ac0fe..f4fdc458f80e 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,9 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -207,11 +202,8 @@ def check_1d_device_mesh(rank, module, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_bmm_handler(module): - world_size = 4 - run_func_2d = partial(check_2d_device_mesh, module=module, world_size=world_size, port=free_port()) - mp.spawn(run_func_2d, nprocs=world_size) - run_func_1d = partial(check_1d_device_mesh, module=module, world_size=world_size, port=free_port()) - mp.spawn(run_func_1d, nprocs=world_size) + spawn(check_2d_device_mesh, 4, module=module) + spawn(check_1d_device_mesh, 4, module=module) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py index fe1a0d726db0..f9632b1cd8f9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -13,13 +10,11 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_conv_module_handler(rank, bias, world_size, port): +def check_conv_module_handler(rank, world_size, port, bias): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda() @@ -155,7 +150,7 @@ def forward(self, input, others, bias=None): return x -def check_conv_function_handler(rank, bias, world_size, port): +def check_conv_function_handler(rank, world_size, port, bias): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = ConvModel().cuda() @@ -302,9 +297,7 @@ def check_conv_function_handler(rank, bias, world_size, port): # @parameterize('bias', [True, False]) @rerun_if_address_is_in_use() def test_conv_module_handler(bias=False): - world_size = 4 - run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_conv_module_handler, 4, bias=bias) @run_on_environment_flag(name='AUTO_PARALLEL') @@ -314,9 +307,7 @@ def test_conv_module_handler(bias=False): # @parameterize('bias', [True, False]) @rerun_if_address_is_in_use() def test_conv_function_handler(bias=False): - world_size = 4 - run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_conv_function_handler, 4, bias=bias) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py index 8e5b7512ca0e..64f56ba98e2b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py @@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class ReshapeModel(nn.Module): @@ -23,6 +23,7 @@ def forward(self, input, other): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_reshape_handler(): model = ReshapeModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py index a61d2ed5c108..4fa0313b1cb5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_embedding_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -16,9 +13,8 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy NUM_EMBEDDINGS = 16 @@ -272,18 +268,14 @@ def check_embedding_function_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_module_handler(): - world_size = 4 - run_func = partial(check_embedding_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_embedding_module_handler, 4) @run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.dist @rerun_if_address_is_in_use() def test_embedding_function_handler(): - world_size = 4 - run_func = partial(check_embedding_function_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_embedding_function_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py index fb611330946a..a089df743ec0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getattr_handler.py @@ -8,6 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.getattr_handler import GetattrHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run class GetattrModel(nn.Module): @@ -22,6 +23,7 @@ def forward(self, input): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() def test_getattr_handler(): model = GetattrModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py index 9a29808ebb31..a2e0968b18bb 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -14,12 +13,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -103,12 +100,7 @@ def check_getitem_from_tensor_handler(rank, getitem_index, world_size, port): # @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))]) @parameterize('getitem_index', [1, (1, 4), slice(0, 2), (slice(None), slice(None))]) def test_getitem_from_tensor_handler(getitem_index): - world_size = 4 - run_func = partial(check_getitem_from_tensor_handler, - getitem_index=getitem_index, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_getitem_from_tensor_handler, 4) class GetItemFromTupleModel(nn.Module): @@ -123,6 +115,7 @@ def forward(self, input): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_getitem_from_tuple_handler(): model = GetItemFromTupleModel() tracer = ColoTracer() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py index edd7bae6c979..ad72c2026b9a 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_layer_norm_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -11,12 +8,10 @@ from colossalai.auto_parallel.tensor_shard.node_handler.layer_norm_handler import LayerNormModuleHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -104,9 +99,7 @@ def check_ln_module_handler(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_ln_module_handler(): - world_size = 4 - run_func = partial(check_ln_module_handler, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_ln_module_handler, 4) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py index bec5c3dc5e28..ec695cd8f7b9 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -18,14 +15,13 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.testing.utils import parameterize -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy -def check_linear_module_handler(rank, bias, input_shape, world_size, port): +def check_linear_module_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda() @@ -172,7 +168,7 @@ def forward(self, input, others, bias=None): return x -def check_linear_function_handler(rank, bias, input_shape, world_size, port): +def check_linear_function_handler(rank, world_size, port, bias, input_shape): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearModel().cuda() @@ -313,19 +309,18 @@ def check_linear_function_handler(rank, bias, input_shape, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_linear_handler(input_shape, bias=False): - world_size = 4 - run_func_module = partial(check_linear_module_handler, - bias=bias, - input_shape=input_shape, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_module, nprocs=world_size) - run_func_function = partial(check_linear_function_handler, - bias=bias, - input_shape=input_shape, - world_size=world_size, - port=free_port()) - mp.spawn(run_func_function, nprocs=world_size) + spawn( + check_linear_module_handler, + 4, + bias=bias, + input_shape=input_shape, + ) + spawn( + check_linear_function_handler, + 4, + bias=bias, + input_shape=input_shape, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py index 46c3ff4434d7..938acd3d1eea 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py @@ -18,7 +18,7 @@ StrategiesVector, ) from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.utils import parameterize +from colossalai.testing.utils import clear_cache_before_run, parameterize class MatMulModule(nn.Module): @@ -28,6 +28,7 @@ def forward(self, x1, x2): @pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations") +@clear_cache_before_run() @parameterize( 'tensor_shapes', [ diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index aacc7d9aeb64..6bff9f9648e2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -1,4 +1,3 @@ -import pytest import torch import torch.nn as nn @@ -8,11 +7,11 @@ from colossalai.auto_parallel.tensor_shard.node_handler.normal_pooling_handler import NormPoolingHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer.meta_patch.patched_module import linear -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_norm_pool_handler(): model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py index 5efbb4f5f6a4..5259455d2179 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py @@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.output_handler import OutputHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize class OutputModel(nn.Module): @@ -23,7 +23,7 @@ def forward(self, x): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @parameterize('output_option', ['distributed', 'replicated']) -@rerun_if_address_is_in_use() +@clear_cache_before_run() def test_output_handler(output_option): model = OutputModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index 0a5ad3e3523d..f071cd120fb7 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -2,7 +2,6 @@ import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -15,9 +14,8 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -55,7 +53,7 @@ def forward(self, input, other): return permute_node -def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, port): +def check_view_handler(rank, world_size, port, call_function, reshape_dims, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') if call_function == torch.permute: @@ -328,14 +326,13 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, @parameterize('reshape_dims', [((0, 2, 1, 3), (1, 2)), ((2, 0, 1, 3), (1, 3))]) @parameterize('model_cls', [ConvReshapeModel, LinearReshapeModel]) def test_view_handler(call_function, reshape_dims, model_cls): - world_size = 4 - run_func = partial(check_view_handler, - call_function=call_function, - reshape_dims=reshape_dims, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn( + check_view_handler, + 4, + call_function=call_function, + reshape_dims=reshape_dims, + model_cls=model_cls, + ) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py index 5e8fb51edbff..6d02b0e0ba74 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_placeholder_handler.py @@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize class PlaceholderModel(nn.Module): @@ -22,7 +22,7 @@ def forward(self, input): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') @parameterize('placeholder_option', ['distributed', 'replicated']) -@rerun_if_address_is_in_use() +@clear_cache_before_run() def test_placeholder_handler(placeholder_option): model = PlaceholderModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py index e589fff996c6..14c364c45fc4 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py @@ -1,5 +1,4 @@ import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -9,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.options import ShardOption from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class LinearModel(nn.Module): @@ -108,6 +107,7 @@ def check_shard_option(shard_option): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_shard_option(): # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]: for shard_option in [ShardOption.SHARD_LAST_AXIS]: diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index db463a4e9d6a..75ae0416ef98 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn import torch.nn.functional as F @@ -15,9 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -33,7 +28,7 @@ def forward(self, input, other): return softmax_node -def check_split_handler(rank, softmax_dim, model_cls, world_size, port): +def check_split_handler(rank, world_size, port, softmax_dim, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = model_cls(softmax_dim=softmax_dim).cuda() @@ -176,13 +171,7 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): @parameterize('softmax_dim', [0, 1, 2, 3]) @parameterize('model_cls', [LinearSplitModel]) def test_split_handler(softmax_dim, model_cls): - world_size = 4 - run_func = partial(check_split_handler, - softmax_dim=softmax_dim, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_split_handler, 4, softmax_dim=softmax_dim, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index db59ea60ef4b..f860c629b0a0 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -15,9 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -47,7 +42,7 @@ def forward(self, input, other): return split_node -def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port): +def check_split_handler(rank, world_size, port, split_size, split_dim, model_cls): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = model_cls(split_size=split_size, split_dim=split_dim).cuda() @@ -258,14 +253,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port @parameterize('split_dim', [0, 1, 2]) @parameterize('model_cls', [ConvSplitModel, LinearSplitModel]) def test_split_handler(split_size, split_dim, model_cls): - world_size = 4 - run_func = partial(check_split_handler, - split_size=split_size, - split_dim=split_dim, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_split_handler, 4, split_size=split_size, split_dim=split_dim, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py index add51d73f2a4..c11291ecac96 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_sum_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -14,9 +11,7 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use -from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, run_on_environment_flag, spawn from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -36,7 +31,7 @@ def forward(self, input, other): return sum_node -def check_sum_handler(rank, sum_dims, keepdim, world_size, port): +def check_sum_handler(rank, world_size, port, sum_dims, keepdim): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model = LinearSumModel(sum_dims=sum_dims, keepdim=keepdim).cuda() @@ -228,9 +223,7 @@ def check_sum_handler(rank, sum_dims, keepdim, world_size, port): @parameterize('sum_dims', [(0, 2), 1]) @parameterize('keepdim', [False, True]) def test_sum_handler(sum_dims, keepdim): - world_size = 4 - run_func = partial(check_sum_handler, sum_dims=sum_dims, keepdim=keepdim, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_sum_handler, 4, sum_dims=sum_dims, keepdim=keepdim) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py index f54b208c3380..5b6ac051a8ef 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_tensor_constructor.py @@ -7,7 +7,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.tensor_constructor_handler import TensorConstructorHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class TensorConstructorModel(nn.Module): @@ -22,6 +22,7 @@ def forward(self, x): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_where_handler(): model = TensorConstructorModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py index bd88089734a7..f4e6dafdfd69 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_unary_element_wise_handler.py @@ -8,7 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import UnaryElementwiseHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag class ReLuModel(nn.Module): @@ -24,6 +24,7 @@ def forward(self, input, other): @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_elementwise_handler(): model = ReLuModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index 300e8f94e7fe..fbb194d8e0b8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from colossalai._analyzer.fx.graph_module import ColoGraphModule @@ -15,9 +12,8 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.pytest_wrapper import run_on_environment_flag -from colossalai.utils import free_port from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy @@ -255,13 +251,7 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): @parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)]) @parameterize('model_cls', [ConvViewModel, LinearViewModel]) def test_view_handler(tgt_shape, model_cls): - world_size = 4 - run_func = partial(check_view_handler, - tgt_shape=tgt_shape, - model_cls=model_cls, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_view_handler, 4, tgt_shape=tgt_shape, model_cls=model_cls) if __name__ == '__main__': diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index c150ebd90053..bd7635ac1737 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -8,6 +8,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh +from colossalai.testing import clear_cache_before_run class ConvModel(nn.Module): @@ -21,6 +22,7 @@ def forward(self, condition, x, y): @pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0') +@clear_cache_before_run() def test_where_handler(): model = ConvModel() tracer = ColoTracer(bias_addition_split=True) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py index fb47baab9476..0d93e4e40527 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py @@ -10,10 +10,11 @@ from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag @run_on_environment_flag(name='AUTO_PARALLEL') +@clear_cache_before_run() def test_cost_graph(): physical_mesh_id = torch.arange(0, 8) mesh_shape = (2, 4) diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py index 9a2240d62de4..d07145e48e1f 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py +++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py @@ -8,7 +8,7 @@ from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py index cb250d6402e2..15610e2b50dc 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py @@ -9,7 +9,7 @@ from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py index 17a5abf4cab8..9e4cb7ee9f95 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py @@ -1,10 +1,8 @@ -from functools import partial from typing import Dict, List, Tuple import pytest import torch import torch.fx -import torch.multiprocessing as mp try: from fastfold.model.nn.evoformer import EvoformerBlock @@ -15,6 +13,7 @@ from test_autochunk_alphafold_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): @@ -66,18 +65,19 @@ def get_chunk_target() -> Dict: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("max_memory", [None, 20, 24]) -@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) def test_evoformer_block(data_args, max_memory): - run_func = partial( + spawn( run_test, + 1, data_args=data_args, max_memory=max_memory, get_model=get_model, get_data=get_data, get_chunk_target=get_chunk_target, ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py index 5210c1c8d48e..6b47033e199f 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py @@ -1,10 +1,8 @@ -from functools import partial from typing import List, Tuple import pytest import torch import torch.fx -import torch.multiprocessing as mp try: from fastfold.model.nn.evoformer import EvoformerStack @@ -15,6 +13,7 @@ from test_autochunk_alphafold_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): @@ -61,17 +60,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("max_memory", [None, 20, 24]) -@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_evoformer_stack(data_args, max_memory): - run_func = partial( + spawn( run_test, + 1, data_args=data_args, max_memory=max_memory, get_model=get_model, get_data=get_data, ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py index ad955479e617..b4c577c18ee6 100644 --- a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py +++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py @@ -1,10 +1,8 @@ -from functools import partial from typing import Dict, List, Tuple import pytest import torch import torch.fx -import torch.multiprocessing as mp try: from fastfold.model.nn.evoformer import ExtraMSABlock @@ -14,6 +12,7 @@ from test_autochunk_alphafold_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_model(): @@ -57,17 +56,18 @@ def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("max_memory", [None, 20, 24]) -@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len) +@clear_cache_before_run() +@parameterize("max_memory", [None, 20, 24]) +@parameterize("data_args", [(32, 64)]) # (msa_len, pair_len) def test_extramsa_block(data_args, max_memory): - run_func = partial( + spawn( run_test, + 1, data_args=data_args, max_memory=max_memory, get_model=get_model, get_data=get_data, ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py index 529250fe8f51..e245f10d4576 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py @@ -8,7 +8,7 @@ from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py index 16c5b10ff4ae..ff0d4a1b53f5 100644 --- a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py +++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py @@ -1,9 +1,7 @@ -from functools import partial from typing import List, Tuple import pytest import torch -import torch.multiprocessing as mp try: from diffusers import UNet2DModel @@ -16,6 +14,7 @@ from test_autochunk_diffuser_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn BATCH_SIZE = 1 HEIGHT = 448 @@ -37,17 +36,18 @@ def get_data(shape: tuple) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("shape", [LATENTS_SHAPE]) -@pytest.mark.parametrize("max_memory", [None, 150, 300]) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("shape", [LATENTS_SHAPE]) +@parameterize("max_memory", [None, 150, 300]) def test_evoformer_block(model, shape, max_memory): - run_func = partial( + spawn( run_test, + 1, max_memory=max_memory, model=model, data=get_data(shape), ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py index 018a2557a974..384706639e10 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py @@ -1,9 +1,7 @@ -from functools import partial from typing import List, Tuple import pytest import torch -import torch.multiprocessing as mp try: from transformers import GPT2Config, GPT2Model @@ -16,6 +14,7 @@ from test_autochunk_transformer_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn BATCH_SIZE = 1 SEQ_LENGTH = 512 @@ -35,18 +34,19 @@ def get_data(shape: tuple) -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) -@pytest.mark.parametrize("max_memory", [None, 6, 8]) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("shape", [(BATCH_SIZE, SEQ_LENGTH)]) +@parameterize("max_memory", [None, 6, 8]) def test_autochunk_gpt(model, shape, max_memory): - run_func = partial( + spawn( run_test, + 1, data=get_data(shape), max_memory=max_memory, model=model, config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4), ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py index bc5eda7edf91..faba138cd42c 100644 --- a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py +++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py @@ -8,7 +8,7 @@ from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py index 2b7cbf1390d2..a98aa0e03954 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py @@ -1,9 +1,7 @@ -from functools import partial from typing import List, Tuple import pytest import torch -import torch.multiprocessing as mp try: from timm.models.vision_transformer import vit_large_patch16_384 as vit @@ -16,6 +14,7 @@ from test_autochunk_vit_utils import run_test from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE +from colossalai.testing import clear_cache_before_run, parameterize, spawn def get_data() -> Tuple[List, List]: @@ -28,16 +27,17 @@ def get_data() -> Tuple[List, List]: not (AUTOCHUNK_AVAILABLE and HAS_REPO), reason="torch version is lower than 1.12.0", ) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_memory", [None, 32, 40]) +@clear_cache_before_run() +@parameterize("model", MODELS) +@parameterize("max_memory", [None, 32, 40]) def test_evoformer_block(model, max_memory): - run_func = partial( + spawn( run_test, + 1, max_memory=max_memory, model=model, data=get_data(), ) - mp.spawn(run_func, nprocs=1) if __name__ == "__main__": diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py index 035dd59799b4..317606fc4781 100644 --- a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py +++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py @@ -8,7 +8,7 @@ from colossalai.core import global_context as gpc from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.utils import free_port +from colossalai.testing import free_port if AUTOCHUNK_AVAILABLE: from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py index 6958a87e2a08..895c494d0c17 100644 --- a/tests/test_booster/test_accelerator.py +++ b/tests/test_booster/test_accelerator.py @@ -1,27 +1,14 @@ -from functools import partial - -import torch.multiprocessing as mp import torch.nn as nn from colossalai.booster.accelerator import Accelerator -from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.testing import clear_cache_before_run, parameterize +@clear_cache_before_run() @parameterize('device', ['cpu', 'cuda']) -def run_accelerator(device): +def test_accelerator(device): acceleartor = Accelerator(device) model = nn.Linear(8, 8) model = acceleartor.configure_model(model) assert next(model.parameters()).device.type == device del model, acceleartor - - -def run_dist(rank): - run_accelerator() - - -@rerun_if_address_is_in_use() -def test_accelerator(): - world_size = 1 - run_func = partial(run_dist) - mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index bacf29014193..963387da262b 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -1,13 +1,9 @@ -from functools import partial - import torch -import torch.multiprocessing as mp from torch.optim import Adam import colossalai from colossalai.booster.mixed_precision import FP16TorchMixedPrecision -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -41,6 +37,4 @@ def run_torch_amp(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_ddp_plugin(): - world_size = 1 - run_func = partial(run_torch_amp, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_torch_amp, 1) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 169983a76110..a3c63fd09d26 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -1,17 +1,12 @@ -from functools import partial - -import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.nn.optimizer import HybridAdam from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -119,9 +114,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_gemini_plugin(early_stop: bool = True): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port(), early_stop=early_stop) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2, early_stop=early_stop) if __name__ == '__main__': diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 71e8582cc364..c225a1a069dd 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -1,8 +1,5 @@ -from functools import partial - import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import SGD @@ -10,8 +7,7 @@ from colossalai.booster import Booster from colossalai.booster.plugin import TorchDDPPlugin from colossalai.interface import OptimizerWrapper -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -103,6 +99,4 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_torch_ddp_plugin(): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index dfbb16af4ec6..0f78184f70e1 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -6,6 +6,7 @@ from torchvision.models import resnet18 from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.testing import clear_cache_before_run, parameterize # ======== # Note: @@ -15,7 +16,8 @@ # ======== -@pytest.mark.parametrize('use_safetensors', [True, False]) +@clear_cache_before_run() +@parameterize('use_safetensors', [True, False]) def test_unsharded_checkpoint(use_safetensors: bool): # create a model and optimizer model = resnet18() diff --git a/tests/test_cluster/test_device_mesh_manager.py b/tests/test_cluster/test_device_mesh_manager.py index b79814735325..b42ef1fe0062 100644 --- a/tests/test_cluster/test_device_mesh_manager.py +++ b/tests/test_cluster/test_device_mesh_manager.py @@ -1,14 +1,9 @@ -from functools import partial - import torch -import torch.multiprocessing as mp from colossalai.cluster.device_mesh_manager import DeviceMeshInfo, DeviceMeshManager -from colossalai.device.device_mesh import DeviceMesh -from colossalai.fx.tracer import ColoTracer from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port +from colossalai.testing import spawn def check_device_mesh_manager(rank, world_size, port): @@ -31,9 +26,7 @@ def check_device_mesh_manager(rank, world_size, port): def test_device_mesh_manager(): - world_size = 4 - run_func = partial(check_device_mesh_manager, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_device_mesh_manager, 4) if __name__ == '__main__': diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py index 1520d6054043..253f6f21cd80 100644 --- a/tests/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_comm/test_boardcast_send_recv_v2.py @@ -1,17 +1,12 @@ -from functools import partial -from typing import List - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group + +from colossalai.communication.p2p_v2 import _recv_object, _send_object from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn disable_existing_loggers() world_size = 4 @@ -45,9 +40,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_object_list_p2p(): - disable_existing_loggers() - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, world_size) if __name__ == '__main__': diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py index 07cb67730d24..747596bd2ded 100644 --- a/tests/test_comm/test_comm.py +++ b/tests/test_comm/test_comm.py @@ -1,15 +1,13 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp + from colossalai.communication import all_gather, all_reduce, reduce_scatter from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) @@ -66,9 +64,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_comm(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 4) if __name__ == '__main__': diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_comm/test_object_list_p2p.py index 701e3e8ade79..e9d7630c1543 100644 --- a/tests/test_comm/test_object_list_p2p.py +++ b/tests/test_comm/test_object_list_p2p.py @@ -1,15 +1,18 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward + +from colossalai.communication.p2p import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, +) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=2)) torch.manual_seed(123) @@ -96,9 +99,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_object_list_p2p(): - world_size = 2 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 2) if __name__ == '__main__': diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py index c639ac9f8ef3..cae38385b6e1 100644 --- a/tests/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_comm/test_object_list_p2p_v2.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group -from colossalai.context import ParallelMode, Initializer_Pipeline + +from colossalai.communication.p2p_v2 import recv_backward, recv_forward, send_backward, send_forward +from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, spawn disable_existing_loggers() @@ -121,10 +117,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_object_list_p2p(): - disable_existing_loggers() - run_func = partial(check_layer, world_size=world_size, port=free_port()) - disable_existing_loggers() - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, world_size) if __name__ == '__main__': diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py index f311b1d2e736..9f26a5af53ce 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_context/test_hybrid_parallel.py @@ -1,19 +1,17 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial from pathlib import Path + import pytest import torch -import torch.multiprocessing as mp from colossalai import launch +from colossalai.context import reset_seeds from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import free_port -from colossalai.context import reset_seeds from colossalai.global_variables import tensor_parallel_env as tp_env -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) @@ -134,9 +132,14 @@ def init_context(config_path, rank, world_size, backend, port, host): torch.cuda.empty_cache() -def run_dist(rank, world_size, backend, port_list, host): - for config_path, port in zip(CONFIG_PATH_LIST, port_list): - init_context(config_path=config_path, rank=rank, world_size=world_size, backend=backend, port=port, host=host) +def run_dist(rank, world_size, port, backend, port_list, host): + for config_path, current_port in zip(CONFIG_PATH_LIST, port_list): + init_context(config_path=config_path, + rank=rank, + world_size=world_size, + backend=backend, + port=current_port, + host=host) reset_seeds() @@ -156,8 +159,7 @@ def test_context(): port_list.append(port) break - test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost') - mp.spawn(test_fn, nprocs=world_size) + spawn(run_dist, world_size, backend='gloo', port_list=port_list, host='localhost') if __name__ == '__main__': diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py index 54fa44bdc0c2..2ad3fd696c39 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_data/test_data_parallel_sampler.py @@ -2,20 +2,18 @@ # -*- encoding: utf-8 -*- import os -from functools import partial from pathlib import Path import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp +from torchvision import datasets, transforms import colossalai -from torchvision import transforms, datasets -from colossalai.context import ParallelMode, Config +from colossalai.context import Config, ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader, free_port -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader CONFIG = Config(dict( parallel=dict( @@ -58,9 +56,7 @@ def run_data_sampler(rank, world_size, port): @pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): - world_size = 4 - test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) - mp.spawn(test_func, nprocs=world_size) + spawn(run_data_sampler, 4) if __name__ == '__main__': diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py index 4d76e7f137f1..239e79dff7d8 100644 --- a/tests/test_data/test_deterministic_dataloader.py +++ b/tests/test_data/test_deterministic_dataloader.py @@ -2,21 +2,18 @@ # -*- encoding: utf-8 -*- import os -from functools import partial from pathlib import Path import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp -from torchvision import transforms, datasets +from torchvision import datasets, transforms import colossalai -from colossalai.context import ParallelMode, Config +from colossalai.context import Config, ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader, free_port -from colossalai.testing import rerun_if_address_is_in_use -from torchvision import transforms +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader CONFIG = Config( dict( @@ -70,9 +67,7 @@ def run_data_sampler(rank, world_size, port): @pytest.mark.cpu @rerun_if_address_is_in_use() def test_data_sampler(): - world_size = 4 - test_func = partial(run_data_sampler, world_size=world_size, port=free_port()) - mp.spawn(test_func, nprocs=world_size) + spawn(run_data_sampler, 4) if __name__ == '__main__': diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py index 3c2390c92837..4d63592f12b0 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -1,25 +1,22 @@ import os - -from functools import partial from pathlib import Path -import colossalai import pytest import torch -import torch.multiprocessing as mp +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader from colossalai.pipeline.pipelinable import PipelinableContext -from torchvision.datasets import CIFAR10 -from torchvision import transforms +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader BATCH_SIZE = 4 NUM_EPOCHS = 60 @@ -96,9 +93,7 @@ def run_trainer(rank, world_size, port): @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_hybrid_parallel(): - world_size = 8 - run_func = partial(run_trainer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_trainer, 8) if __name__ == '__main__': diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py index 2bafe0f7e374..67d2ba5f5d98 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py @@ -1,111 +1,104 @@ -import os - -from functools import partial -from pathlib import Path - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks -from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.logging import disable_existing_loggers -from torchvision.datasets import CIFAR10 -from torchvision import transforms - -from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 - -disable_existing_loggers() -BATCH_SIZE = 4 -NUM_EPOCHS = 10 -WARMUP_EPOCHS = 5 -CONFIG = dict(NUM_MICRO_BATCHES=2, - parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), - fp16=dict(mode=AMP_TYPE.NAIVE), - gradient_accumulation=2) - - -def run_trainer(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - disable_existing_loggers() - # get logger - logger = get_dist_logger() - - pipelinable = PipelinableContext() - try: - from titans.model.vit import vit_tiny_patch4_32 - except ImportError: - logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') - logger.warning('please install titan from https://github.com/hpcaitech/Titans') - return - with pipelinable: - model = vit_tiny_patch4_32() - pipelinable.to_layer_list() - pipelinable.policy = "uniform" - model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) - - # craete dataloaders - root = Path(os.environ['DATA']) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) - - # create loss function - criterion = CrossEntropyLoss(label_smoothing=0.1) - - # create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) - - # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) - - # intiailize - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) - - logger = get_dist_logger() - - trainer = Trainer(engine=engine, logger=logger) - - hook_list = [ - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - ] - - trainer.fit(train_dataloader=train_dataloader, - max_steps=2, - epochs=NUM_EPOCHS, - hooks=hook_list, - display_progress=True) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_hybrid_parallel(): - world_size = 2 - run_func = partial(run_trainer, world_size=world_size, port=free_port()) - disable_existing_loggers() - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_hybrid_parallel() +import os +from pathlib import Path + +import pytest +import torch +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer, hooks +from colossalai.utils import get_dataloader + +disable_existing_loggers() +BATCH_SIZE = 4 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # craete dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # intiailize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + max_steps=2, + epochs=NUM_EPOCHS, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + spawn(run_trainer, 2) + disable_existing_loggers() + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index 2ad20f6bec72..39efcd41a1d4 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -1,19 +1,16 @@ import os import random -from functools import partial from typing import Callable, Type import numpy as np import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.nn.parallel import ColoDDP from colossalai.tensor import ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -88,8 +85,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_ddp_ignore_params(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index bd4742ff2cc9..54f89f972765 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -1,16 +1,12 @@ -import copy from collections import OrderedDict -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.parallel import ColoDDP from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -64,8 +60,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_state_dict(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py index 5b302d99ffb1..e8d3a112c938 100644 --- a/tests/test_ddp/test_reducer.py +++ b/tests/test_ddp/test_reducer.py @@ -1,15 +1,15 @@ +from functools import partial + import pytest -import colossalai import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from functools import partial -from colossalai.nn.parallel.reducer import Reducer import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group +import colossalai +from colossalai.nn.parallel.reducer import Reducer +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device + REDUCE_CNT = 0 @@ -40,8 +40,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_reducer(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_device/test_alpha_beta.py b/tests/test_device/test_alpha_beta.py index 99abacd1342b..ab933ed57d0d 100644 --- a/tests/test_device/test_alpha_beta.py +++ b/tests/test_device/test_alpha_beta.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest -import torch.multiprocessing as mp from colossalai.device import AlphaBetaProfiler from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_alpha_beta(rank, physical_devices, world_size, port): +def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') profiler = AlphaBetaProfiler(physical_devices) @@ -24,9 +20,7 @@ def check_alpha_beta(rank, physical_devices, world_size, port): @parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): - world_size = 4 - run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_alpha_beta, 4, physical_devices=physical_devices) if __name__ == '__main__': diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py index e32bebdd908e..52604b9c6a49 100644 --- a/tests/test_device/test_extract_alpha_beta.py +++ b/tests/test_device/test_extract_alpha_beta.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest -import torch.multiprocessing as mp from colossalai.device import AlphaBetaProfiler from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_extract_alpha_beta(rank, physical_devices, world_size, port): +def check_extract_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') profiler = AlphaBetaProfiler(physical_devices) @@ -27,12 +23,7 @@ def check_extract_alpha_beta(rank, physical_devices, world_size, port): @parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): - world_size = 4 - run_func = partial(check_extract_alpha_beta, - physical_devices=physical_devices, - world_size=world_size, - port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_extract_alpha_beta, 4, physical_devices=physical_devices) if __name__ == '__main__': diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py index 3172897fb5cd..2b7060c4846a 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -1,15 +1,12 @@ -import torch -from functools import partial import pytest +import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.distributed import ReduceOp from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_layer(rank, world_size, port): @@ -40,9 +37,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_logical_pg(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 4) if __name__ == '__main__': diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py index 591eafb2a50d..b22a76eabc2f 100644 --- a/tests/test_device/test_search_logical_device_mesh.py +++ b/tests/test_device/test_search_logical_device_mesh.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest -import torch.multiprocessing as mp from colossalai.device import AlphaBetaProfiler from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -def check_alpha_beta(rank, physical_devices, world_size, port): +def check_alpha_beta(rank, world_size, port, physical_devices): disable_existing_loggers() launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') profiler = AlphaBetaProfiler(physical_devices) @@ -27,9 +23,7 @@ def check_alpha_beta(rank, physical_devices, world_size, port): @parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]]) @rerun_if_address_is_in_use() def test_profile_alpha_beta(physical_devices): - world_size = 4 - run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_alpha_beta, 4, physical_devices=physical_devices) if __name__ == '__main__': diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py index fb5bd1e1602e..62493cf3712d 100644 --- a/tests/test_engine/test_engine.py +++ b/tests/test_engine/test_engine.py @@ -1,13 +1,10 @@ -from functools import partial +import pytest import colossalai -import pytest -import torch.multiprocessing as mp from colossalai.amp import AMP_TYPE from colossalai.core import global_context as gpc -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize, rerun_if_address_is_in_use CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), @@ -58,9 +55,7 @@ def run_engine(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_engine(): - world_size = 2 - run_func = partial(run_engine, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_engine, 2) if __name__ == '__main__': diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_engine/test_gradient_accumluation.py index 7f5ee47be8e6..7783827c7c44 100644 --- a/tests/test_engine/test_gradient_accumluation.py +++ b/tests/test_engine/test_gradient_accumluation.py @@ -1,22 +1,20 @@ import os -from functools import partial from pathlib import Path -import colossalai -from colossalai.testing.utils import rerun_if_address_is_in_use import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import free_port, get_dataloader -from colossalai.testing import rerun_if_address_is_in_use from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader + # Config BATCH_SIZE = 2 NUM_CLASSES = 10 @@ -90,9 +88,7 @@ def run_no_pipeline(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_engine(): - world_size = 4 - func = partial(run_no_pipeline, world_size=world_size, port=free_port()) - mp.spawn(func, nprocs=world_size) + spawn(run_no_pipeline, 4) if __name__ == '__main__': diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py index 83df1bb5e69c..ab483f7e47a3 100644 --- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py @@ -1,15 +1,13 @@ import pytest import torch -import torch.multiprocessing as mp import torch.nn.functional as F -from torch.fx import GraphModule from torch.utils.checkpoint import checkpoint import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -65,9 +63,9 @@ def forward(self, x, y): return y1 + y2 + y3 + y4 + y5 + y6 -def _run_act_ckpt_codegen(rank): +def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -118,13 +116,14 @@ def _run_act_ckpt_codegen(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() def test_act_ckpt_codegen(): - mp.spawn(_run_act_ckpt_codegen, nprocs=1) + spawn(_run_act_ckpt_codegen, 1) -def _run_act_ckpt_python_code_torch11(rank): +def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -174,8 +173,9 @@ def _run_act_ckpt_python_code_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + spawn(_run_act_ckpt_python_code_torch11, 1) if __name__ == '__main__': diff --git a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py index 6b3a49d181e1..9064023d4f68 100644 --- a/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py +++ b/tests/test_fx/test_codegen/test_nested_activation_checkpoint_codegen.py @@ -1,15 +1,11 @@ import pytest import torch -import torch.multiprocessing as mp -import torch.nn.functional as F -from torch.fx import GraphModule -from torch.utils.checkpoint import checkpoint import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -35,9 +31,9 @@ def forward(self, x): return self.linear6(self.linear5(self.linear4(self.linear3(self.linear2(self.linear1(x)))))) -def _run_act_ckpt_codegen(rank): +def _run_act_ckpt_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -89,12 +85,12 @@ def _run_act_ckpt_codegen(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def test_act_ckpt_codegen(): - mp.spawn(_run_act_ckpt_codegen, nprocs=1) + spawn(_run_act_ckpt_codegen, 1) -def _run_act_ckpt_python_code_torch11(rank): +def _run_act_ckpt_python_code_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and run forward model = MyModule() @@ -146,8 +142,9 @@ def _run_act_ckpt_python_code_torch11(rank): @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0') @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not done") +@rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1) + spawn(_run_act_ckpt_python_code_torch11, 1) if __name__ == '__main__': diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py index 5d090066c763..96e88eb92b33 100644 --- a/tests/test_fx/test_codegen/test_offload_codegen.py +++ b/tests/test_fx/test_codegen/test_offload_codegen.py @@ -2,15 +2,13 @@ import pytest import torch -import torch.multiprocessing as mp -import torch.nn.functional as F from torch.fx import GraphModule import colossalai from colossalai.core import global_context as gpc from colossalai.fx import ColoTracer from colossalai.fx.graph_module import ColoGraphModule -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn try: from colossalai.fx.codegen import ActivationCheckpointCodeGen @@ -66,9 +64,9 @@ def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.T assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one" -def _run_offload_codegen(rank): +def _run_offload_codegen(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and input model = MyNet().cuda() @@ -116,13 +114,14 @@ def _run_offload_codegen(rank): @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') +@rerun_if_address_is_in_use() def test_act_ckpt_codegen(): - mp.spawn(_run_offload_codegen, nprocs=1) + spawn(_run_offload_codegen, 1) -def _run_offload_codegen_torch11(rank): +def _run_offload_codegen_torch11(rank, world_size, port): # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly - colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') # build model and input model = MyNet().cuda() @@ -171,8 +170,9 @@ def _run_offload_codegen_torch11(rank): @pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented") +@rerun_if_address_is_in_use() def test_act_ckpt_python_code_torch11(): - mp.spawn(_run_offload_codegen_torch11, nprocs=1) + spawn(_run_offload_codegen_torch11, 1) if __name__ == "__main__": diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py index 2bb6cf86466c..96cf5198da10 100644 --- a/tests/test_fx/test_coloproxy.py +++ b/tests/test_fx/test_coloproxy.py @@ -1,9 +1,11 @@ +import pytest import torch import torch.nn as nn +from torch.fx import GraphModule + from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer -from torch.fx import GraphModule -import pytest +from colossalai.testing import clear_cache_before_run class Conv1D(nn.Module): @@ -23,6 +25,7 @@ def forward(self, x): return x +@clear_cache_before_run() def test_coloproxy(): tracer = ColoTracer() diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index 8825bbb461d6..d3daadd71406 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -1,13 +1,11 @@ -import colossalai -import colossalai.nn as col_nn -import pytest import torch -import torch.nn as nn +from torch.fx import symbolic_trace + from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass) +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.utils import get_comm_size -from torch.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run is_compatible = is_compatible_with_meta() if is_compatible: @@ -35,6 +33,7 @@ def forward(self, x): return x +@clear_cache_before_run() def test_comm_size_compute(): model = MLP(MODEL_DIM) input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta') diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py deleted file mode 100644 index a21a351f8d77..000000000000 --- a/tests/test_fx/test_complete_workflow.py +++ /dev/null @@ -1,87 +0,0 @@ -from functools import partial - -import pytest -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn - -import colossalai -from colossalai.fx import ColoTracer -from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass -from colossalai.tensor import ProcessGroup -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.model.lazy_init_context import LazyInitContext - - -class MLP(torch.nn.Module): - - def __init__(self, dim: int): - super().__init__() - self.linear1 = torch.nn.Linear(dim, dim) - self.linear2 = torch.nn.Linear(dim, dim) - self.dropout = torch.nn.Dropout(0) - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.linear1(x) - x = self.dropout(x) - x = self.relu(x) - x = self.linear2(x) - return x - - -def run_workflow(world_size, dev): - # initailization - with LazyInitContext() as ctx: - model = MLP(16) - - for param in model.parameters(): - assert param.is_meta - - # tracing - tracer = ColoTracer() - graph = tracer.trace(model) - gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - - # annotate - annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup(tp_degree=world_size)) - annotated_gm.recompile() - - # materialization and sharding - ctx.lazy_init_parameters(annotated_gm, device=dev) - for param in model.parameters(): - assert not param.is_meta - - # # check sharding - assert list(model.linear1.weight.shape) == [16 // world_size, 16] - assert list(model.linear1.bias.shape) == [16 // world_size] - assert list(model.linear2.weight.shape) == [16, 16 // world_size] - - # test forward to make sure that IR transform will produce the same results - # like how ColoTensor would do it normally - data = torch.rand(4, 16, device=dev) - non_fx_out = model(data) - fx_out = annotated_gm(data) - assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}' - - -def run_dist(rank, world_size, dev, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_workflow(world_size, dev) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@pytest.mark.parametrize('dev', ['cuda', 'cpu']) -@rerun_if_address_is_in_use() -def test_complete_workflow(world_size, dev): - if dev == 'cpu' and world_size > 1: - return - run_func = partial(run_dist, world_size=world_size, dev=dev, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_complete_workflow(1, 'cuda') diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py index fb33e58a778c..175b69dd96fe 100644 --- a/tests/test_fx/test_graph_manipulation.py +++ b/tests/test_fx/test_graph_manipulation.py @@ -1,9 +1,11 @@ -import colossalai import torch -from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes -from colossalai.fx import ColoTracer from torch.fx import GraphModule + +import colossalai +from colossalai.fx import ColoTracer from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top +from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): @@ -25,6 +27,7 @@ def forward(self, x): return l4, l5 +@clear_cache_before_run() def test_graph_manipulation(): model = MLP(4) tracer = ColoTracer() diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index 209ded89cfb9..e490522dbf15 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -3,7 +3,9 @@ import pytest import torch import torch.nn as nn + from colossalai.fx._compatibility import is_compatible_with_meta +from colossalai.testing import clear_cache_before_run if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor @@ -71,6 +73,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_meta_aten(): for (aten_op, requires_backward), v in registered_meta.items(): for f, x in v: diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 351c02c5744a..7aed6fd4597b 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -2,11 +2,14 @@ import timm.models as tmm import torch import torchvision.models as tm + from colossalai.fx._compatibility import is_compatible_with_meta if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor +from colossalai.testing import clear_cache_before_run + tm_models = [ tm.vgg11, tm.resnet18, @@ -28,6 +31,7 @@ @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_torchvision_models(): for m in tm_models: model = m() @@ -36,6 +40,7 @@ def test_torchvision_models(): @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_timm_models(): for m in tmm_models: model = m() diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 404b6d27d2d4..61614f8a6623 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -2,11 +2,14 @@ import timm.models as tmm import torch import torchvision.models as tm + from colossalai.fx._compatibility import is_compatible_with_meta if is_compatible_with_meta(): from colossalai.fx import meta_trace +from colossalai.testing import clear_cache_before_run + tm_models = [ tm.vgg11, tm.resnet18, @@ -28,6 +31,7 @@ @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_torchvision_models_trace(): for m in tm_models: model = m() @@ -36,6 +40,7 @@ def test_torchvision_models_trace(): @pytest.mark.skipif(not is_compatible_with_meta(), reason='torch version is lower than 1.12.0') +@clear_cache_before_run() def test_timm_models_trace(): for m in tmm_models: model = m() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 6fac180d8ba2..a12512696a73 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -1,7 +1,9 @@ import torch +from torch.fx import symbolic_trace + from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata -from torch.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor @@ -18,6 +20,7 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): assert meta_info_spec.numel == orig_tensor.numel() +@clear_cache_before_run() def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 8963ba29cb03..1044be7db1f4 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -1,18 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use from torch.fx import symbolic_trace + +from colossalai.core import global_context as gpc from colossalai.fx.passes import column_shard_linear_pass +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn class MLP(torch.nn.Module): @@ -52,11 +49,10 @@ def check_layer(rank, world_size, port): @pytest.mark.dist +@clear_cache_before_run() @rerun_if_address_is_in_use() def test_1d(): - world_size = 2 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 2) if __name__ == '__main__': diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py index de8a9402ba56..1078dac9db7c 100644 --- a/tests/test_fx/test_pipeline_passes.py +++ b/tests/test_fx/test_pipeline_passes.py @@ -1,12 +1,17 @@ +import pytest import torch import torch.nn as nn -import colossalai -import colossalai.nn as col_nn from torch.fx import symbolic_trace -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \ - uniform_split_pass, balanced_split_pass_v2 -import pytest +import colossalai +import colossalai.nn as col_nn +from colossalai.fx.passes.adding_split_node_pass import ( + balanced_split_pass, + balanced_split_pass_v2, + split_with_split_nodes_pass, + uniform_split_pass, +) +from colossalai.testing import clear_cache_before_run MODEL_DIM = 16 BATCH_SIZE = 8 @@ -39,6 +44,7 @@ def pipeline_pass_test_helper(model, data, pass_func): assert output.equal(origin_output) +@clear_cache_before_run() def test_pipeline_passes(): model = MLP(MODEL_DIM) data = torch.rand(BATCH_SIZE, MODEL_DIM) diff --git a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py index c717960181ad..b5a6bbe8bf18 100644 --- a/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py +++ b/tests/test_fx/test_profiler/test_profiler_meta_info_prop.py @@ -9,7 +9,7 @@ from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size from colossalai.fx.tracer.tracer import ColoTracer -from colossalai.testing.pytest_wrapper import run_on_environment_flag +from colossalai.testing import clear_cache_before_run, run_on_environment_flag if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor @@ -126,6 +126,7 @@ def run_gpt_forward(gm: torch.fx.GraphModule): @run_on_environment_flag(name='FX_PROFILER') +@clear_cache_before_run() def test_meta_info_prop(): for m in [ tm.alexnet, tm.resnet18, tm.resnet34, tm.resnet50, tm.resnet101, tm.resnet152, tm.densenet121, @@ -155,6 +156,7 @@ def test_meta_info_prop(): @run_on_environment_flag(name='FX_PROFILER') +@clear_cache_before_run() def test_gpt_meta_info_prop(): for m in [gpt2_medium]: model = m().cuda() diff --git a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py index a834951bb695..632ab8c09750 100644 --- a/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py +++ b/tests/test_fx/test_tracer/test_activation_checkpoint_annotation.py @@ -4,6 +4,7 @@ from torch.utils.checkpoint import checkpoint from colossalai.fx import ColoTracer +from colossalai.testing import clear_cache_before_run class MLP(torch.nn.Module): @@ -35,6 +36,7 @@ def forward(self, x): return x +@clear_cache_before_run() def test_activation_checkpoint_annotation(): module = MyModule() diff --git a/tests/test_fx/test_tracer/test_bias_addition_module.py b/tests/test_fx/test_tracer/test_bias_addition_module.py index afa30a217604..2f88d8c784e8 100644 --- a/tests/test_fx/test_tracer/test_bias_addition_module.py +++ b/tests/test_fx/test_tracer/test_bias_addition_module.py @@ -1,6 +1,7 @@ import torch from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing import clear_cache_before_run class LinearModel(torch.nn.Module): @@ -32,6 +33,7 @@ def forward(self, x): return x +@clear_cache_before_run() def test_linear_module(): model = LinearModel(3, 6) tracer = ColoTracer() @@ -68,6 +70,7 @@ def test_linear_module(): assert add_node._meta_data.shape == (3, 6) +@clear_cache_before_run() def test_conv_module(): model = ConvModel(3, 6, 2) tracer = ColoTracer() diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py index ed842cff2776..820729dadb3e 100644 --- a/tests/test_fx/test_tracer/test_control_flow.py +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -1,7 +1,9 @@ import torch import torch.nn as nn from torch.fx import GraphModule + from colossalai.fx import ColoTracer as Tracer +from colossalai.testing import clear_cache_before_run class ControlFlowModel(nn.Module): @@ -21,6 +23,7 @@ def forward(self, x, y): return x1 - y1 +@clear_cache_before_run() def test_control_flow(): model = ControlFlowModel() tracer = Tracer() diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py index 95670b85f335..a552e905223d 100644 --- a/tests/test_fx/test_tracer/test_functional_conv.py +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -1,8 +1,11 @@ import torch from torch.nn import functional as F + from colossalai.fx.tracer.meta_patch import patched_function +from colossalai.testing import clear_cache_before_run +@clear_cache_before_run() def test_conv(): # test F.conv_1d data_1d = torch.rand(3, 16, 10) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py index 31ba2290ed99..f4d681221191 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py @@ -3,6 +3,7 @@ from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo BATCH_SIZE = 2 @@ -10,6 +11,7 @@ @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_albert(): sub_registry = model_zoo.get_sub_registry('transformers_albert') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 8db6817c66dc..a833bb30c056 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -3,10 +3,12 @@ from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_bert(): sub_registry = model_zoo.get_sub_registry('transformers_bert') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 92ece357bfed..0cbea82e083a 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -2,6 +2,7 @@ import torch from colossalai.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from colossalai.testing.random import seed_all from tests.kit.model_zoo import model_zoo @@ -40,6 +41,7 @@ def assert_fn(ta, tb): @pytest.mark.skip(reason='cannot pass this test yet') +@clear_cache_before_run() def test_diffusers(): seed_all(9091, cuda_deterministic=True) @@ -52,6 +54,7 @@ def test_diffusers(): print(f"{name:40s} √") +@clear_cache_before_run() def test_torch_diffusers(): seed_all(65535, cuda_deterministic=True) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 796c17e398d5..67107469d8bb 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -3,10 +3,12 @@ from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_gpt(): sub_registry = model_zoo.get_sub_registry('transformers_gpt') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index e7bfa607082e..369545b03de1 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -3,10 +3,12 @@ from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_opt(): sub_registry = model_zoo.get_sub_registry('transformers_opt') diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py index 5f7e4f81c44e..811cf3b21430 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py @@ -3,10 +3,12 @@ from hf_tracer_utils import trace_model_and_compare_output from packaging import version +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_t5(): sub_registry = model_zoo.get_sub_registry('transformers_t5') diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index 94a93e16f3c7..ef778e21801a 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -1,5 +1,7 @@ import torch + from colossalai.fx.tracer.meta_patch import patched_module +from colossalai.testing import clear_cache_before_run def _run(data, module, patch_fn): @@ -31,6 +33,7 @@ def _assert_output_shape(data, module, patch_fn, expect_exception, output_shape) assert output.shape == output_shape +@clear_cache_before_run() def test_linear(): # test linear patch can produce the meta output with correct shape data = torch.rand(2, 4, device='meta') @@ -42,6 +45,7 @@ def test_linear(): _assert_output_shape(data, module, patched_module.torch_nn_linear, True, None) +@clear_cache_before_run() def test_rnn(): # test rnn patch can produce the meta output with correct shape data = (torch.randn(5, 3, 10), torch.randn(2, 3, 20)) @@ -58,6 +62,7 @@ def test_rnn(): _assert_output_shape(meta_data, module, patched_module.torch_nn_rnn, True, None) +@clear_cache_before_run() def test_embedding(): data = torch.rand(2, 4, device='meta') @@ -134,6 +139,7 @@ def test_embedding(): output_shape=None) +@clear_cache_before_run() def test_conv1d(): # test conv 1d data = torch.rand(2, 3, 4) @@ -212,6 +218,7 @@ def test_conv2d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv3d(): # test conv 3d data = torch.rand(2, 3, 4, 4, 4) @@ -253,6 +260,7 @@ def test_conv3d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv_transpose1d(): # test conv transpose1d data = torch.rand(2, 3, 4) @@ -276,6 +284,7 @@ def test_conv_transpose1d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv_transpose2d(): # test conv transpose2d data = torch.rand(2, 3, 4, 4) @@ -299,6 +308,7 @@ def test_conv_transpose2d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_conv_transpose3d(): # test conv transpose2d data = torch.rand(2, 3, 4, 4, 4) @@ -322,6 +332,7 @@ def test_conv_transpose3d(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_pool1d(): combinations = [[torch.nn.MaxPool1d, patched_module.torch_nn_maxpool1d], [torch.nn.AvgPool1d, patched_module.torch_nn_avgpool1d]] @@ -349,6 +360,7 @@ def test_pool1d(): _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) +@clear_cache_before_run() def test_pool2d(): combinations = [[torch.nn.MaxPool2d, patched_module.torch_nn_maxpool2d], [torch.nn.AvgPool2d, patched_module.torch_nn_avgpool2d]] @@ -379,6 +391,7 @@ def test_pool2d(): _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) +@clear_cache_before_run() def test_pool3d(): combinations = [[torch.nn.MaxPool3d, patched_module.torch_nn_maxpool3d], [torch.nn.AvgPool3d, patched_module.torch_nn_avgpool3d]] @@ -410,6 +423,7 @@ def test_pool3d(): # adapative pooling is different from other pooling, so test it individually +@clear_cache_before_run() def test_adaptive_pooling_1d(): pooler = torch.nn.AdaptiveAvgPool1d(output_size=3) patch_func = patched_module.torch_nn_adapative_pooling_1d @@ -434,6 +448,7 @@ def test_adaptive_pooling_1d(): _assert_output_shape(data=data, module=pooler, patch_fn=patch_func, expect_exception=True, output_shape=None) +@clear_cache_before_run() def test_adaptive_pooling_2d(): pooler = torch.nn.AdaptiveAvgPool2d(output_size=3) patch_func = patched_module.torch_nn_adapative_pooling_2d @@ -458,6 +473,7 @@ def test_adaptive_pooling_2d(): output_shape=output.shape) +@clear_cache_before_run() def test_adaptive_pooling_3d(): pooler = torch.nn.AdaptiveAvgPool3d(output_size=3) patch_func = patched_module.torch_nn_adapative_pooling_3d diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py index 4406f02db24b..e0c5f560c49e 100644 --- a/tests/test_fx/test_tracer/test_patched_op.py +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -1,6 +1,9 @@ +from functools import partial + import torch + from colossalai.fx.tracer.meta_patch import patched_function -from functools import partial +from colossalai.testing import clear_cache_before_run def _run(data, patch_fn): @@ -22,6 +25,7 @@ def _assert_output_shape(data, patch_fn, expect_exception, output_shape): assert output.shape == output_shape +@clear_cache_before_run() def test_repeat_interleave(): patch_fn = patched_function.torch_repeat_interleave @@ -63,6 +67,7 @@ def test_repeat_interleave(): output_shape=materialized_output.shape) +@clear_cache_before_run() def test_torch_max(): data = torch.rand(4, 3) out = torch.max(data) diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index b175d8b10c67..aa14f514c7d6 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -3,6 +3,7 @@ from packaging import version from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo @@ -43,6 +44,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12') +@clear_cache_before_run() def test_timm_models(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index 66f4be5a6f7f..eafcaca10b1d 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -3,12 +3,14 @@ from packaging import version from torchaudio_utils import trace_and_compare +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo # We cannot handle the tensors constructed with constant during forward, such as ``torch.empty(0).to(device=Proxy.device)`` # TODO: We could handle this case by hijacking torch.Tensor.to function. @pytest.mark.skip +@clear_cache_before_run() def test_torchaudio_models(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py index 40f83d47a7cc..df02568c0049 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py @@ -2,6 +2,7 @@ import torch from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo BATCH = 2 @@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' +@clear_cache_before_run() def test_torchrec_deepfm_models(): deepfm_models = model_zoo.get_sub_registry('deepfm') torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py index 6d4b6ab81b12..9776452be9c8 100644 --- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py +++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py @@ -2,6 +2,7 @@ import torch from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo BATCH = 2 @@ -47,6 +48,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None): ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' +@clear_cache_before_run() def test_torchrec_dlrm_models(): torch.backends.cudnn.deterministic = True dlrm_models = model_zoo.get_sub_registry('dlrm') diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py index 8dbbf9f5aab7..bd259475ae5a 100644 --- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py +++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py @@ -1,9 +1,11 @@ import torch from colossalai._analyzer.fx import symbolic_trace +from colossalai.testing import clear_cache_before_run from tests.kit.model_zoo import model_zoo +@clear_cache_before_run() def test_torchvision_models(): torch.backends.cudnn.deterministic = True tv_sub_registry = model_zoo.get_sub_registry('torchvision') diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index 897590f0d9c8..891512542475 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -1,18 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from checks_1d.check_layer_1d import * from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) @@ -40,9 +36,7 @@ def check_layer(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_1d(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer, 4) if __name__ == '__main__': diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py index da235d0cf168..bcea5ce7b25d 100644 --- a/tests/test_layers/test_2d/test_2d.py +++ b/tests/test_layers/test_2d/test_2d.py @@ -1,22 +1,27 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp +from checks_2d.check_layer_2d import ( + check_classifier_given_embed_weight, + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) +from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, - check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, - check_vocab_parallel_classifier_given_embed_weight, - check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, - check_vocab_parallel_loss) -from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) @@ -57,9 +62,7 @@ def check_layer_and_operation(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_2d(): - world_size = 4 - run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer_and_operation, 4) if __name__ == '__main__': diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py index 365e2d934df8..373d834d0032 100644 --- a/tests/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -1,15 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp +from checks_2p5d.check_layer_2p5d import * +from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from checks_2p5d.check_layer_2p5d import * -from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict( pipeline=dict(size=1), @@ -53,9 +50,7 @@ def check_layer_and_operation(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_2p5d(): - world_size = 4 - run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer_and_operation, 4) if __name__ == '__main__': diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index 29a8b3aea239..fde71a4a0d26 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_layers/test_3d/test_3d.py @@ -1,19 +1,24 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp +from checks_3d.check_layer_3d import ( + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear, - check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight, - check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, - check_vocab_parallel_loss) +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn CONFIG = dict( parallel=dict( @@ -52,9 +57,7 @@ def check_layer_and_operation(rank, world_size, port): @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_3d(): - world_size = 8 - run_func = partial(check_layer_and_operation, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_layer_and_operation, 8) if __name__ == '__main__': diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index cff9072c7a06..22d4f02a48d7 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -1,20 +1,21 @@ -import pytest -from functools import partial - -import numpy as np import random +from typing import List +import numpy as np +import pytest import torch -import torch.multiprocessing as mp import colossalai -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ - ColoTensor, ColoTensorSpec -from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \ - ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig -from typing import List +from colossalai.nn.parallel.layers import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + ParallelCachedEmbeddingBag, + ParallelCachedEmbeddingBagTablewise, + TablewiseEmbeddingBagConfig, +) +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn NUM_EMBED, EMBED_DIM = 10, 8 BATCH_SIZE = 8 @@ -44,6 +45,7 @@ def synthesize_1d_sparse_feature( @pytest.mark.skip +@clear_cache_before_run() def test_cachemgr(): model = torch.nn.EmbeddingBag(10000, 128) # 10 chunks, 5 in cuda @@ -72,6 +74,7 @@ def test_cachemgr(): assert mgr.cuda_available_chunk_num == 5 +@clear_cache_before_run() def test_reorder_with_freq(): num_embed = 100 chunk_size = 1 @@ -102,7 +105,8 @@ def test_reorder_with_freq(): f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}" -@pytest.mark.parametrize('use_LFU', [True, False]) +@clear_cache_before_run() +@parameterize('use_LFU', [True, False]) def test_freq_aware_embed(use_LFU: bool): device = torch.device('cuda', 0) evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET @@ -148,7 +152,8 @@ def test_freq_aware_embed(use_LFU: bool): f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}" -@pytest.mark.parametrize('init_freq', [True, False]) +@clear_cache_before_run() +@parameterize('init_freq', [True, False]) def test_lfu_strategy(init_freq: bool): # minimal test to check behavior Bag = CachedEmbeddingBag(5, @@ -248,7 +253,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): input0 [1,2,3] [6,7] [] input1 [] [9] [13,15] input2 [1,5] [6,8] [11] - ↑ ↑ ↑ + ↑ ↑ ↑ rank 0 rank 0 rank 1 in KJT format ''' @@ -363,8 +368,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_parallel_freq_aware_embed(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py index 3862c4ccd439..aac192d7eff0 100644 --- a/tests/test_layers/test_sequence/test_sequence.py +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -1,14 +1,11 @@ -import colossalai -import colossalai.nn as col_nn +import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp -import pytest -from colossalai.core import global_context as gpc +import colossalai from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use -from functools import partial +from colossalai.core import global_context as gpc +from colossalai.testing import rerun_if_address_is_in_use, spawn CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) @@ -121,8 +118,8 @@ def check_ring_av(rank, world_size): 'attention output cannot match' -def run_test(rank, world_size): - colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=29500) +def run_test(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, config=CONFIG, host='localhost', port=port) # check_ring_qk(rank, world_size) check_ring_av(rank, world_size) @@ -134,9 +131,7 @@ def run_test(rank, world_size): @pytest.mark.dist @rerun_if_address_is_in_use() def test_sequence(): - world_size = 4 - run_func = partial(run_test, world_size=world_size) - mp.spawn(run_func, nprocs=world_size) + spawn(run_test, 4) if __name__ == '__main__': diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index e7b9a55277c6..e7002a75f3f7 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -1,16 +1,15 @@ -from functools import partial import pytest import torch -import torch.nn as nn -import torch.multiprocessing as mp import torch.distributed as dist +import torch.nn as nn + import colossalai -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils.moe import sync_moe_model_param from colossalai.engine.gradient_handler import MoeGradientHandler -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from colossalai.utils.moe import sync_moe_model_param BATCH_SIZE = 4 DIM = 16 @@ -65,9 +64,7 @@ def run_test(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_grad_handler(): - world_size = 4 - run_func = partial(run_test, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_test, 4) if __name__ == '__main__': diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 62f9241642b9..ad9a172b72aa 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,15 +1,14 @@ -from functools import partial import pytest import torch import torch.nn as nn -import torch.multiprocessing as mp + import colossalai from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.core import global_context as gpc +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device BATCH_SIZE = 16 NUM_EXPERTS = 4 @@ -90,15 +89,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f @pytest.mark.parametrize("router", [Top1Router, Top2Router]) @rerun_if_address_is_in_use() def test_moe_kernel(rs, hidden_size, data_type, router): - world_size = 4 - run_func = partial(run_routing, - world_size=world_size, - port=free_port(), - rs=rs, - hidden_size=hidden_size, - data_type=data_type, - router=router) - mp.spawn(run_func, nprocs=world_size) + spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index d2cff44ad9bd..8a0283ba71fc 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -1,19 +1,16 @@ import os -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.context import MOE_CONTEXT from colossalai.nn.layer.moe import load_moe_model, save_moe_model -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print from tests.test_zero.test_legacy.common import CONFIG @@ -46,8 +43,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_moe_checkpoint(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index 4826d87ac044..555338fcf9fc 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -1,15 +1,12 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.context import MOE_CONTEXT from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_tensor.common_utils import debug_print @@ -52,8 +49,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_moe_colo_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 3126f59e246e..6dc3f5f18b6d 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -1,21 +1,20 @@ -from functools import partial import pytest -import torch.nn as nn -import torch.multiprocessing as mp import torch.distributed as dist +import torch.nn as nn + import colossalai -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Experts from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe import Experts +from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.utils.moe import sync_moe_model_param -from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use D_MODEL = 4 D_FF = 8 CONFIG = dict() -def run_test(rank, port): +def run_test(rank, world_size, port): world_size = 4 colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') expert_module = nn.Linear @@ -62,9 +61,7 @@ def run_test(rank, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_moe_initialization(): - world_size = 4 - run_func = partial(run_test, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_test, 4) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 18b50eb5c482..79722f9f4056 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn import colossalai @@ -10,8 +7,8 @@ from colossalai.logging import get_dist_logger from colossalai.nn import CheckpointModule from colossalai.nn.layer import MoeModule -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from tests.test_zero.test_legacy.common import CONFIG @@ -104,8 +101,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_moe_zero_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 49c452938e25..ec37967f18c5 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.context import MOE_CONTEXT from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -67,8 +63,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_moe_zero_model(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index b43e52bb4c6a..efc6e9ddae27 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.amp import convert_to_apex_amp @@ -10,8 +7,8 @@ from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -116,8 +113,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2]) @rerun_if_address_is_in_use() def test_moe_zero_optim(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_addmm_tp.py b/tests/test_ops/test_addmm_tp.py index 5182868b5bbd..ecd3721b902e 100644 --- a/tests/test_ops/test_addmm_tp.py +++ b/tests/test_ops/test_addmm_tp.py @@ -1,14 +1,11 @@ -import colossalai -import torch import pytest +import torch import torch.nn as nn -import torch.multiprocessing as mp -from colossalai.tensor import ColoTensor, ProcessGroup -from colossalai.tensor import ColoTensorSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from functools import partial -from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal class Conv1D(nn.Module): @@ -69,8 +66,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_addmm_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_embedding_bag_tp.py b/tests/test_ops/test_embedding_bag_tp.py index c7a1604e5455..d3d3dcf7e2c9 100644 --- a/tests/test_ops/test_embedding_bag_tp.py +++ b/tests/test_ops/test_embedding_bag_tp.py @@ -1,14 +1,11 @@ +import pytest +import torch from torch.nn import functional as F -from functools import partial import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func): @@ -39,8 +36,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_embedding_bag_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_embedding_tp.py b/tests/test_ops/test_embedding_tp.py index 541dc5c09324..c0b376e2c92a 100644 --- a/tests/test_ops/test_embedding_tp.py +++ b/tests/test_ops/test_embedding_tp.py @@ -1,14 +1,11 @@ +import pytest +import torch from torch.nn import functional as F -from functools import partial import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func, pg: ProcessGroup): @@ -40,8 +37,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_embedding_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_linear_tp.py b/tests/test_ops/test_linear_tp.py index 603e98564de8..c88adfdd9a77 100644 --- a/tests/test_ops/test_linear_tp.py +++ b/tests/test_ops/test_linear_tp.py @@ -1,14 +1,11 @@ -from functools import partial - -import colossalai import pytest import torch -import torch.multiprocessing as mp import torch.nn.functional as F -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use, spawn +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func, split_bias): @@ -44,8 +41,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_linear_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_ops/test_loss_func.py b/tests/test_ops/test_loss_func.py index 9210242a0a9f..fc55c7f77254 100644 --- a/tests/test_ops/test_loss_func.py +++ b/tests/test_ops/test_loss_func.py @@ -1,52 +1,48 @@ -import torch -import pytest -import colossalai -import torch.nn.functional as F -import torch.multiprocessing as mp -from functools import partial -from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec -from colossalai.utils import get_current_device -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern - - -def check_cross_entropy(): - input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - input_ct.copy_(input_t) - - target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) - - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) - input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) - - output = F.cross_entropy(input_t, target) - output_colo = F.cross_entropy(input_shard, target) - assert torch.allclose(output_colo, output) - - output.backward() - output_colo.backward() - - assert torch.allclose(input_t.grad, input_ct.grad) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_cross_entropy() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_loss_func(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_loss_func(1) +import pytest +import torch +import torch.nn.functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device + + +def check_cross_entropy(): + input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + input_ct.copy_(input_t) + + target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) + input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) + + output = F.cross_entropy(input_t, target) + output_colo = F.cross_entropy(input_shard, target) + assert torch.allclose(output_colo, output) + + output.backward() + output_colo.backward() + + assert torch.allclose(input_t.grad, input_ct.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_cross_entropy() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_loss_func(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_loss_func(1) diff --git a/tests/test_ops/test_op.py b/tests/test_ops/test_op.py index 8d3cf50ff2aa..4176d3b64d90 100644 --- a/tests/test_ops/test_op.py +++ b/tests/test_ops/test_op.py @@ -1,14 +1,12 @@ -import torch import pytest -import colossalai +import torch import torch.nn.functional as F -import torch.multiprocessing as mp -from functools import partial -from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec -from colossalai.utils import get_current_device from torch.nn import Parameter -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device def _run_layer_norm(): @@ -66,8 +64,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_element_wise_ops(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) def run_dist2(rank, world_size, port): @@ -79,8 +76,7 @@ def run_dist2(rank, world_size, port): @pytest.mark.parametrize('world_size', [1]) @rerun_if_address_is_in_use() def test_ln(world_size): - run_func = partial(run_dist2, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist2, world_size) def check_all(): diff --git a/tests/test_ops/test_view.py b/tests/test_ops/test_view.py index fc6fc2d3c291..a9f2033201c7 100644 --- a/tests/test_ops/test_view.py +++ b/tests/test_ops/test_view.py @@ -1,100 +1,97 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec -from colossalai.tensor.distspec import DistPlacementPattern -from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print - - -def exam_view_core(pg): - # the case of replicated ColoTensors - x = torch.randn(4, 4).cuda() - x_colo = ColoTensor(x, ColoTensorSpec(pg)) - - y = x.view(2, -1, 2) - y_colo = x_colo.view(2, -1, 2) - - assert torch.all(y == y_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - # the perfect case of col-sliced ColoTensors - split_param_col_tp1d(x_colo, pg) - - z = x.view(torch.Size((2, 1, 2, -1))) - z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) - if dist.get_rank() == 0: - z = z[:, :, :, 0:2] - else: - z = z[:, :, :, 2:] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the perfect case of row-sliced ColoTensors - split_param_row_tp1d(x_colo, pg) - - z = x.view(torch.Size((-1, 2, 2))) - z_colo = x_colo.view(torch.Size((-1, 2, 2))) - if dist.get_rank() == 0: - z = z[0:2, :, :] - else: - z = z[2:, :, :] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the normal case of row-sliced ColoTensors - z = x.view(-1, 2, 2, 2) - z_colo = x_colo.view(-1, 2, 2, 2) - assert torch.all(z == z_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - - -def exam_view_autograd(pg): - x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - y.copy_(x) - y = ColoTensor(y, ColoTensorSpec(pg)) - y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - - xx = x.view(2, 2, -1) - yy_slice = y_slice.view(2, 2, -1) - yy = yy_slice.to_replicate() - grad = torch.randn(2, 2, 4, device=get_current_device()) - - xx.backward(grad) - yy.backward(grad) - assert torch.all(x.grad == y.grad) - - -def exam_view_errors(pg): - x = torch.randn(8, 2, device=get_current_device()) - x = ColoTensor(x, ColoTensorSpec(pg)) - split_param_row_tp1d(x, pg) - - x.view('a', 'b', 'c') - x.view(8, -1) - x.view([-2, -2, -2]) - x.view((-1, -1, -1)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) - exam_view_core(pg) - exam_view_autograd(pg) - # exam_view_errors(pg) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_view(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_view(2) +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.tensor.distspec import DistPlacementPattern +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device +from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d + + +def exam_view_core(pg): + # the case of replicated ColoTensors + x = torch.randn(4, 4).cuda() + x_colo = ColoTensor(x, ColoTensorSpec(pg)) + + y = x.view(2, -1, 2) + y_colo = x_colo.view(2, -1, 2) + + assert torch.all(y == y_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + # the perfect case of col-sliced ColoTensors + split_param_col_tp1d(x_colo, pg) + + z = x.view(torch.Size((2, 1, 2, -1))) + z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) + if dist.get_rank() == 0: + z = z[:, :, :, 0:2] + else: + z = z[:, :, :, 2:] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the perfect case of row-sliced ColoTensors + split_param_row_tp1d(x_colo, pg) + + z = x.view(torch.Size((-1, 2, 2))) + z_colo = x_colo.view(torch.Size((-1, 2, 2))) + if dist.get_rank() == 0: + z = z[0:2, :, :] + else: + z = z[2:, :, :] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the normal case of row-sliced ColoTensors + z = x.view(-1, 2, 2, 2) + z_colo = x_colo.view(-1, 2, 2, 2) + assert torch.all(z == z_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + + +def exam_view_autograd(pg): + x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + y.copy_(x) + y = ColoTensor(y, ColoTensorSpec(pg)) + y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + + xx = x.view(2, 2, -1) + yy_slice = y_slice.view(2, 2, -1) + yy = yy_slice.to_replicate() + grad = torch.randn(2, 2, 4, device=get_current_device()) + + xx.backward(grad) + yy.backward(grad) + assert torch.all(x.grad == y.grad) + + +def exam_view_errors(pg): + x = torch.randn(8, 2, device=get_current_device()) + x = ColoTensor(x, ColoTensorSpec(pg)) + split_param_row_tp1d(x, pg) + + x.view('a', 'b', 'c') + x.view(8, -1) + x.view([-2, -2, -2]) + x.view((-1, -1, -1)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + exam_view_core(pg) + exam_view_autograd(pg) + # exam_view_errors(pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_view(world_size): + spawn(run_dist, world_size) + + +if __name__ == '__main__': + test_view(2) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index ea1c044f5820..8b3ecf8517f7 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -2,7 +2,7 @@ import torch -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize def torch_adam_update( @@ -46,6 +46,7 @@ def assertTrue(condition, msg): assert condition, msg +@clear_cache_before_run() @parameterize('adamw', [True, False]) @parameterize('step', [1, 2]) @parameterize('p_dtype', [torch.float, torch.half]) diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py index f7227c2d57c0..114d5293dad9 100644 --- a/tests/test_optimizer/test_fused_adam.py +++ b/tests/test_optimizer/test_fused_adam.py @@ -1,10 +1,10 @@ import torch import torch.nn as nn -from torch.optim.adam import Adam from torch.optim import AdamW +from torch.optim.adam import Adam from colossalai.nn.optimizer.fused_adam import FusedAdam -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize class FC(nn.Module): @@ -17,6 +17,7 @@ def forward(self, x): return self.fc(x) +@clear_cache_before_run() @parameterize('adamw', [False, True]) @parameterize('p_dtype', [torch.float, torch.half]) @parameterize('g_dtype', [torch.float, torch.half]) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 8ff6618aee2e..4afa13349c1b 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -4,7 +4,7 @@ import torch.nn as nn from numpy import dtype -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize from colossalai.utils import multi_tensor_applier @@ -41,6 +41,7 @@ def torch_adam_update( param.addcdiv_(exp_avg, denom, value=-step_size) +@clear_cache_before_run() @parameterize('adamw', [False, True]) @parameterize('step', [1, 2]) @parameterize('p_dtype', [torch.float, torch.half]) diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py index 2576d8ffee43..d075149dfcb1 100644 --- a/tests/test_optimizer/test_hybrid_adam.py +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -4,11 +4,12 @@ from torch.optim.adam import Adam from colossalai.nn.optimizer.hybrid_adam import HybridAdam -from colossalai.testing import parameterize +from colossalai.testing import clear_cache_before_run, parameterize RE = 3 +@clear_cache_before_run() @parameterize('adamw', [False, True]) @parameterize('device', ['cpu', 'cuda:0']) @parameterize('p_dtype', [torch.float]) diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 243f785adaf9..5d794ac2dd1a 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,7 +1,9 @@ import pytest import torch -from tests.components_to_test.registry import non_distributed_component_funcs + from colossalai.nn.optimizer import CPUAdam, HybridAdam +from colossalai.testing import clear_cache_before_run, parameterize +from tests.components_to_test.registry import non_distributed_component_funcs def move_some_params_to_cuda(model, torch_model): @@ -16,9 +18,10 @@ def check_params_equal(model, torch_model): assert torch.allclose(p, torch_p, atol=1e-3), f'diff: {torch.abs(p - torch_p)}' -@pytest.mark.parametrize('nvme_offload_fraction', [0.0, 0.5, 1.0]) -@pytest.mark.parametrize('nvme_offload_dir', ['./offload', None]) -@pytest.mark.parametrize('adam_cls', [CPUAdam, HybridAdam]) +@clear_cache_before_run() +@parameterize('nvme_offload_fraction', [0.0, 0.5, 1.0]) +@parameterize('nvme_offload_dir', ['./offload', None]) +@parameterize('adam_cls', [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): get_components_func = non_distributed_component_funcs.get_callable('simple_net') model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index 7ce2cd433b12..dab474a4ee21 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -6,13 +6,14 @@ import torch.distributed as dist import torch.distributed.rpc as rpc import torch.multiprocessing as mp -from colossalai import launch -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg from torch import nn from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch.optim import SGD, Adam, Optimizer, RMSprop +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg + rpc_is_initialized = _is_current_rpc_agent_set @@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'): color = color.upper() print(getattr(Back, color), prefix, Style.RESET_ALL, text) + class MLP(nn.Module): + def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -32,8 +35,10 @@ def forward(self, x): for layer in self.layers: x = layer(x) return x.sum() - + + class DAG_MLP(nn.Module): + def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -48,6 +53,7 @@ def forward(self, x, y): y = self.dag_layer(y) return x.sum(), y.sum() + class RpcTestModel(nn.Module): def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py index c4dc617b1683..5b3aad703275 100644 --- a/tests/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -1,27 +1,27 @@ -import torch -import pytest import os -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc +from functools import partial -from torch import nn +import pytest +import torch +import torch.distributed.rpc as rpc +from rpc_test_utils import DAG_MLP, MLP from torch._C._distributed_rpc import _is_current_rpc_agent_set + from colossalai import launch +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.middleware.adaptor import get_fx_topology from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer -from colossalai.pipeline.middleware.adaptor import get_fx_topology -from rpc_test_utils import MLP, DAG_MLP -from functools import partial -from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn # global variable for model created batch_size = 16 dim = 10 rpc_is_initialized = _is_current_rpc_agent_set + def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() tracer = ColoTracer() @@ -34,13 +34,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): setattr(submodule, '_topo', topo) - return split_submodules[pp_rank+1] + return split_submodules[pp_rank + 1] + def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): torch.manual_seed(1024) partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) return partition + def run_master(model_cls, world_size, forward_only): torch.manual_seed(100) @@ -50,23 +52,27 @@ def run_master(model_cls, world_size, forward_only): chunk = 1 num_microbatches = 8 use_checkpoint = 'store_true' - + if model_cls == MLP: + def data_gen(): x = torch.zeros((batch_size, dim)) kwargs = dict(x=x) return kwargs + model = model_cls(dim, stage_num * 3) if forward_only: labels = None else: labels = 1 elif model_cls == DAG_MLP: + def data_gen(): x = torch.zeros((batch_size, dim)) y = torch.zeros((batch_size, dim)) kwargs = dict(x=x, y=y) return kwargs + model = model_cls(dim, stage_num * 3) if forward_only: labels = None @@ -74,15 +80,17 @@ def data_gen(): labels = 1 else: pass - + data_kwargs = data_gen() - - engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint,) + + engine = OneFOneBPipelineEngine( + partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, + num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) if not forward_only: engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) @@ -90,13 +98,14 @@ def data_gen(): input_x = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device) logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) - -def run_worker(rank, model_cls, world_size, forward_only, master_func): + + +def run_worker(rank, world_size, port, model_cls, forward_only, master_func): master_addr = 'localhost' master_port = 29020 os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_PORT'] = str(master_port) - + disable_existing_loggers() launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) @@ -113,7 +122,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): # barrier here if rpc_is_initialized(): rpc.shutdown() - + + @pytest.mark.skip("skip due to CI torch version 1.11") @parameterize('model_cls', [MLP, DAG_MLP]) @parameterize('forward_only', [True, False]) @@ -122,7 +132,14 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): def test_pp_middleware_fwd(model_cls, forward_only): world_size = 4 master_func = run_master - mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size) + spawn( + run_worker, + world_size, + model_cls=model_cls, + forward_only=forward_only, + master_func=master_func, + ) + if __name__ == "__main__": - test_pp_middleware_fwd() \ No newline at end of file + test_pp_middleware_fwd() diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_pipeline/test_pipelinable.py index c99a88550b71..627cb5ac6f51 100644 --- a/tests/test_pipeline/test_pipelinable.py +++ b/tests/test_pipeline/test_pipelinable.py @@ -1,9 +1,7 @@ import torch -import torch.multiprocessing as mp from colossalai.pipeline.pipelinable import PipelinableContext - -from colossalai.testing import rerun_on_exception +from colossalai.testing import rerun_if_address_is_in_use, rerun_on_exception, spawn NUM_CHUNKS = 1 PIPELINE_SIZE = 2 @@ -27,7 +25,7 @@ def forward(self, x): return x -def run_pipelinable(rank): +def run_pipelinable(rank, world_size, port): pipelinable = PipelinableContext() with pipelinable: model = MLP() @@ -50,9 +48,9 @@ def run_pipelinable(rank): assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +@rerun_if_address_is_in_use() def test_pipelinable(): - mp.spawn(run_pipelinable, nprocs=1) + spawn(run_pipelinable, 1) if __name__ == '__main__': diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py index c67e4175df92..2a00e3ac55b1 100644 --- a/tests/test_pipeline/test_pipeline_process_group.py +++ b/tests/test_pipeline/test_pipeline_process_group.py @@ -1,13 +1,12 @@ import os import torch.distributed.rpc as rpc -import torch.multiprocessing as mp -import pytest +from rpc_test_utils import pg_parse_args, rpc_is_initialized -from colossalai.pipeline.pipeline_process_group import ppg from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from rpc_test_utils import pg_parse_args, rpc_is_initialized +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.testing import spawn def run_worker(rank, args): @@ -40,4 +39,4 @@ def run_worker(rank, args): if __name__ == "__main__": args = pg_parse_args() world_size = args.world_size - mp.spawn(run_worker, args=(args,), nprocs=world_size) \ No newline at end of file + spawn(run_worker, world_size, args=args) diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_tensor/core/test_dist_spec_mgr.py index e02f4e7977f6..89476a35b63a 100644 --- a/tests/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_tensor/core/test_dist_spec_mgr.py @@ -1,13 +1,12 @@ import math + +import pytest import torch import torch.distributed as dist -import pytest + import colossalai -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec -from functools import partial +from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn def run(): @@ -58,8 +57,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_dist_spec_mgr(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py index b48d9e9a2dfa..64d198b350a8 100644 --- a/tests/test_tensor/core/test_tensor.py +++ b/tests/test_tensor/core/test_tensor.py @@ -1,17 +1,11 @@ -import torch import pytest -from colossalai.tensor import ColoTensor +import torch from numpy import allclose import colossalai -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec from colossalai.core import global_context as gpc -import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec -from functools import partial +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec +from colossalai.testing import rerun_if_address_is_in_use, spawn def _run_tensor_indexing(): @@ -152,8 +146,7 @@ def run_dist_tests(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_dist_cases(world_size): - run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_tests, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index 0d6a3fe26c2d..337bfa840d5d 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -145,8 +141,7 @@ def run_dist(rank, world_size, port, use_ddp): @pytest.mark.parametrize('use_ddp', [False, True]) @rerun_if_address_is_in_use() def test_gpt(world_size, use_ddp): - run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, use_ddp=use_ddp) if __name__ == '__main__': diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py index 83abc641cbd4..79d70e53c5cb 100644 --- a/tests/test_tensor/model/test_model.py +++ b/tests/test_tensor/model/test_model.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoTensor, ProcessGroup from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import free_port, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -313,8 +309,7 @@ def run_model_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_model(world_size): - run_func = partial(run_model_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_model_dist, world_size) def run_pretrain_load_dist(rank, world_size, port): @@ -329,8 +324,7 @@ def run_pretrain_load_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_pretrain_load(world_size): - run_func = partial(run_pretrain_load_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_pretrain_load_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py index 739bf2b0a641..b50851e5eaf2 100644 --- a/tests/test_tensor/model/test_module_spec.py +++ b/tests/test_tensor/model/test_module_spec.py @@ -1,9 +1,7 @@ from copy import deepcopy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.parallel.layers import check_colo_module, init_colo_module @@ -17,8 +15,7 @@ ShardSpec, distspec, ) -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -207,8 +204,7 @@ def run_dist_check(rank, world_size, port): @pytest.mark.skip("for higher testing speed") @rerun_if_address_is_in_use() def test_module_linear_1d(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) @pytest.mark.dist @@ -216,8 +212,7 @@ def test_module_linear_1d(world_size): @pytest.mark.skip("for higher testing speed") @rerun_if_address_is_in_use() def test_module_model(world_size): - run_func = partial(run_dist_model, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_model, world_size) @pytest.mark.dist @@ -225,8 +220,7 @@ def test_module_model(world_size): @pytest.mark.skip("for higher testing speed") @rerun_if_address_is_in_use() def test_module_check(world_size): - run_func = partial(run_dist_check, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist_check, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py index aa333d55276c..a53a3f37a664 100644 --- a/tests/test_tensor/test_colo_checkpoint_tools.py +++ b/tests/test_tensor/test_colo_checkpoint_tools.py @@ -1,47 +1,41 @@ -import torch -import pytest -from functools import partial - -import torch.multiprocessing as mp -import torch.distributed as dist - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec -from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from tests.test_tensor.common_utils import tensor_shard_equal - - -def run_dist(rank, world_size, port, dp_degree, tp_degree): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) - x = torch.randn(4, 4) - param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) - spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) - param.set_tensor_spec(*spec) - - gather_tensor(param) - if dist.get_rank() == 0: - assert torch.all(x == param) - else: - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - dist.barrier() - - scatter_tensor(param, spec[0]) - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - assert param.requires_grad is True - dist.barrier() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_checkpoint(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_checkpoint(world_size=4) +import pytest +import torch +import torch.distributed as dist + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from tests.test_tensor.common_utils import tensor_shard_equal + + +def run_dist(rank, world_size, port, dp_degree, tp_degree): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) + x = torch.randn(4, 4) + param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) + spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) + param.set_tensor_spec(*spec) + + gather_tensor(param) + if dist.get_rank() == 0: + assert torch.all(x == param) + else: + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + dist.barrier() + + scatter_tensor(param, spec[0]) + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + assert param.requires_grad is True + dist.barrier() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size): + spawn(run_dist, world_size, dp_degree=2, tp_degree=world_size // 2) + + +if __name__ == '__main__': + test_checkpoint(world_size=4) diff --git a/tests/test_tensor/test_comm_spec_apply.py b/tests/test_tensor/test_comm_spec_apply.py index 46eee61f1ecf..2c68633aabc8 100644 --- a/tests/test_tensor/test_comm_spec_apply.py +++ b/tests/test_tensor/test_comm_spec_apply.py @@ -1,10 +1,5 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.distributed import ReduceOp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh @@ -12,8 +7,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_all_gather(device_mesh, rank): @@ -218,8 +212,7 @@ def check_comm(rank, world_size, port): @rerun_if_address_is_in_use() def test_comm_spec(): world_size = 4 - run_func = partial(check_comm, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_comm, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py index 047371f45bda..45def034ba8e 100644 --- a/tests/test_tensor/test_context.py +++ b/tests/test_tensor/test_context.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ( @@ -14,8 +11,7 @@ ReplicaSpec, ShardSpec, ) -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs @@ -61,8 +57,7 @@ def run_colo_init_context(rank: int, world_size: int, port: int): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_colo_init_context(world_size): - run_func = partial(run_colo_init_context, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_colo_init_context, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index 547a96b264dc..d1f5b9299397 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,9 +1,6 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.distributed import ReduceOp from colossalai.core import global_context as gpc @@ -12,8 +9,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_all_gather(process_groups_dict, rank): @@ -182,8 +178,7 @@ def check_comm(rank, world_size, port): @rerun_if_address_is_in_use() def test_comm_spec(): world_size = 4 - run_func = partial(check_comm, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_comm, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index a99ac6e41c5e..3ca369acbf87 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -1,7 +1,4 @@ -from functools import partial - import torch -import torch.multiprocessing as mp from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch @@ -9,7 +6,7 @@ from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn class TestModel(torch.nn.Module): @@ -92,10 +89,10 @@ def check_dtensor(rank, world_size, port): raise ValueError(f'rank {rank} is not in the device mesh') +@rerun_if_address_is_in_use() def test_dtensor(): world_size = 4 - run_func = partial(check_dtensor, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_dtensor, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 70cf8726dbd0..5f56decb5e5d 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -1,9 +1,7 @@ import math -from functools import partial import pytest import torch -import torch.multiprocessing as mp from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch @@ -12,8 +10,7 @@ from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import LayoutConverter from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn entire_shape = torch.Size((64, 32, 16)) layout_converter = LayoutConverter() @@ -192,14 +189,9 @@ def check_layout_converting_apply(rank, world_size, port): @rerun_if_address_is_in_use() def test_layout_converter(): world_size = 4 - run_func = partial(check_one_step_transform, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - run_func = partial(check_layout_converting, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_one_step_transform, world_size) + spawn(check_layout_converting, world_size) + spawn(check_layout_converting_apply, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py index c1ab30601501..9122808eb5a3 100644 --- a/tests/test_tensor/test_mix_gather.py +++ b/tests/test_tensor/test_mix_gather.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh @@ -11,7 +8,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.utils import mix_gather_simulator -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_mix_gather_S0S1(device_mesh, rank): @@ -323,10 +320,10 @@ def check_comm(rank, world_size, port): @pytest.mark.skip(reason="Skip because the check functions assume 8 GPUS but CI only have 4 GPUs") +@rerun_if_address_is_in_use() def test_mix_gather(): world_size = 8 - run_func = partial(check_comm, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_comm, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py index 7c3c4b2132e4..9c3f05da1ffa 100644 --- a/tests/test_tensor/test_parameter.py +++ b/tests/test_tensor/test_parameter.py @@ -1,9 +1,10 @@ -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup -import torch import pytest +import torch from common_utils import tensor_equal + import colossalai -from colossalai.utils import free_port +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup +from colossalai.testing import free_port @pytest.mark.skip diff --git a/tests/test_tensor/test_shape_consistency_apply.py b/tests/test_tensor/test_shape_consistency_apply.py index 4c838bc83fad..b57952df401f 100644 --- a/tests/test_tensor/test_shape_consistency_apply.py +++ b/tests/test_tensor/test_shape_consistency_apply.py @@ -1,16 +1,12 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def check_apply(rank, world_size, port): @@ -73,8 +69,7 @@ def check_apply(rank, world_size, port): @rerun_if_address_is_in_use() def test_apply(): world_size = 4 - run_func = partial(check_apply, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(check_apply, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py index 85008c67a9c2..d66d4fec14d1 100644 --- a/tests/test_tensor/test_sharded_linear.py +++ b/tests/test_tensor/test_sharded_linear.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn.functional as F import colossalai @@ -10,8 +7,7 @@ from colossalai.nn._ops._utils import gather_forward_split_backward from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup from colossalai.tensor.sharding_spec import ShardingSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn def run_dist(rank, world_size, port): @@ -229,8 +225,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [4]) @rerun_if_address_is_in_use() def test_sharded_mlp(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 94e39e5d1546..c636d9442902 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP from colossalai.zero.gemini import search_chunk_configuration @@ -140,8 +136,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py index 72820c6a1f0d..cb7a193d2bfa 100644 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -1,21 +1,26 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp -from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward, - send_backward_recv_forward, send_forward, send_forward_recv_backward, - send_obj_meta) + +from colossalai.communication import ( + recv_backward, + recv_forward, + recv_obj_meta, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, + send_obj_meta, +) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import get_dist_logger -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_on_exception +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 @@ -93,11 +98,10 @@ def run_check(rank, world_size, port): @pytest.mark.dist -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +@rerun_if_address_is_in_use() def test_p2p(): world_size = 4 - run_func = partial(run_check, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_check, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py index 48f729658134..6d7bf6b3d89f 100644 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -1,34 +1,26 @@ # referenced from Megatron and used to testify communication import os -import os.path as osp -from functools import partial from pathlib import Path -import colossalai import pytest import torch import torch.nn as nn -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.initialize import launch -from colossalai.utils import free_port, get_dataloader, print_rank_0 -from colossalai.testing import rerun_on_exception from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_dataloader, print_rank_0 BATCH_SIZE = 8 -CONFIG=dict( - NUM_MICRO_BATCHES=2, - parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=1, mode=None) - ) -) +CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None))) + def run_schedule(rank, world_size, port): launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') @@ -85,11 +77,10 @@ def forward(self, x): @pytest.mark.dist -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +@rerun_if_address_is_in_use() def test_pipeline_schedule(): world_size = 2 - run_func = partial(run_schedule, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_schedule, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py index b013433293cd..753f82222f9d 100644 --- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -1,15 +1,13 @@ -from functools import partial - -import colossalai import pytest import torch -import torch.multiprocessing as mp + +import colossalai from colossalai.amp.amp_type import AMP_TYPE from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, free_port +from colossalai.utils import MultiTimer from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize, rerun_if_address_is_in_use BATCH_SIZE = 4 IMG_SIZE = 32 @@ -54,8 +52,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_trainer_no_pipeline(): world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py index 3698526a8e6c..bb63d51a0b65 100644 --- a/tests/test_trainer/test_trainer_with_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py @@ -1,23 +1,21 @@ import os -from functools import partial from pathlib import Path -import colossalai import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.engine.schedule import PipelineSchedule -from colossalai.logging import get_dist_logger -from colossalai.trainer import Trainer -from colossalai.utils import MultiTimer, free_port, get_dataloader from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 -from colossalai.testing import rerun_if_address_is_in_use + +import colossalai +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.trainer import Trainer +from colossalai.utils import MultiTimer, get_dataloader BATCH_SIZE = 4 IMG_SIZE = 32 @@ -91,8 +89,7 @@ def forward(self, x): @rerun_if_address_is_in_use() def test_trainer_with_pipeline(): world_size = 4 - run_func = partial(run_trainer_with_pipeline, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_trainer_with_pipeline, world_size) if __name__ == '__main__': diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 3ac75fb00c86..59a8acd4b210 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -4,8 +4,10 @@ import pytest import torch import torch.nn.functional as F + from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.random import add_seed, seed, set_mode, reset_seeds +from colossalai.context.random import add_seed, reset_seeds, seed, set_mode +from colossalai.testing import clear_cache_before_run, parameterize from colossalai.utils.activation_checkpoint import checkpoint @@ -39,8 +41,9 @@ def forward_inplace(x, weight): @pytest.mark.gpu -@pytest.mark.parametrize("use_reentrant", [True, False]) -@pytest.mark.parametrize("cpu_offload", [True, False]) +@clear_cache_before_run() +@parameterize("use_reentrant", [True, False]) +@parameterize("cpu_offload", [True, False]) def test_activation_checkpointing(cpu_offload, use_reentrant): # as seed manager is singleton diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py index 8a0fea9ae47a..335be61359ed 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_1d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_1d(): - world_size = 8 - run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_1d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_1d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_1d(): + spawn(check_checkpoint_1d, 8) + + +if __name__ == "__main__": + test_checkpoint_1d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py index 26314290d4de..175d9ef6ceb9 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_2d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_2d(): - world_size = 8 - run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_2d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_2d(): + spawn(check_checkpoint_2d, 8) + + +if __name__ == "__main__": + test_checkpoint_2d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py index 3dbd340fd42d..33cb3a65d184 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_2p5d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_2p5d(): - world_size = 8 - run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_2p5d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2p5d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_2p5d(): + spawn(check_checkpoint_2p5d, 8) + + +if __name__ == "__main__": + test_checkpoint_2p5d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py index 38f650547585..73ac2dd5fe18 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -1,80 +1,77 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_3d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_3d(): - world_size = 8 - run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_3d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint + +import pytest +import torch +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus, spawn +from colossalai.utils import is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_3d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_if_address_is_in_use() +def test_checkpoint_3d(): + spawn(check_checkpoint_3d, 8) + + +if __name__ == "__main__": + test_checkpoint_3d() diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py index 780c13dc534a..b1a741515728 100644 --- a/tests/test_utils/test_checkpoint_io/test_load.py +++ b/tests/test_utils/test_checkpoint_io/test_load.py @@ -3,20 +3,19 @@ from tempfile import TemporaryDirectory from typing import Dict -import colossalai import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.checkpoint_io.io import load, save -from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta) from torch import Tensor from torch.nn import Module from torch.optim import Adam, Optimizer +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta + def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: assert set(a.keys()) == set(b.keys()) @@ -120,14 +119,13 @@ def test_save_global_load_global(max_shard_size_gb: float): check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - func() + test_fn() def launch_dist(fn, world_size: int): - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) def save_dist(dir_name: str, zero: bool): diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py index 04e454dcb713..255c74adf0a2 100644 --- a/tests/test_utils/test_checkpoint_io/test_merge.py +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -1,18 +1,18 @@ -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import save, merge -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from tempfile import TemporaryDirectory -from torch.optim import Adam -from functools import partial -import torch import os +from functools import partial +from tempfile import TemporaryDirectory + import pytest -import colossalai -import torch.nn as nn +import torch import torch.distributed as dist -import torch.multiprocessing as mp +import torch.nn as nn +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import merge, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta class DummyModel(nn.Module): @@ -52,7 +52,7 @@ def test_merge_global(): assert len(os.listdir(output_dir)) == 0 -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={'parallel': { 'tensor': { 'mode': '1d', @@ -64,7 +64,7 @@ def run_dist(rank, world_size, port, func): host='localhost', port=port, backend='nccl') - func() + test_fn() def run_save_dist(dir_name: str, zero: bool): @@ -100,8 +100,7 @@ def test_merge_tp_dp(zero: bool): with TemporaryDirectory() as dir_name: fn = partial(run_save_dist, dir_name, zero) world_size = 4 - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) with TemporaryDirectory() as output_dir: merge(dir_name, output_dir) assert len(os.listdir(output_dir)) == 5 diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py index 6e76f3167e31..144715bdfcca 100644 --- a/tests/test_utils/test_checkpoint_io/test_redist.py +++ b/tests/test_utils/test_checkpoint_io/test_redist.py @@ -2,19 +2,23 @@ from functools import partial from tempfile import TemporaryDirectory -import colossalai import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME from colossalai.utils.checkpoint_io.io import redist, save -from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, - RedistMeta) -from torch.optim import Adam +from colossalai.utils.checkpoint_io.meta import ( + ParamDistMeta, + ParamRedistMeta, + PipelineRedistMeta, + RankRedistMeta, + RedistMeta, +) class DummyModel(nn.Module): @@ -105,7 +109,7 @@ def test_global_to_dist(): check_checkpoint_shape(output_dir) -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={'parallel': { 'tensor': { 'mode': '1d', @@ -117,7 +121,7 @@ def run_dist(rank, world_size, port, func): host='localhost', port=port, backend='nccl') - func() + test_fn() def run_save_dist(dir_name: str, zero: bool): @@ -133,8 +137,7 @@ def test_dist_to_dist(zero: bool): with TemporaryDirectory() as dir_name: fn = partial(run_save_dist, dir_name, zero) world_size = 4 - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) with TemporaryDirectory() as output_dir: redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) if not zero: diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py index 5ff9d0aa2217..e35e566f6ff8 100644 --- a/tests/test_utils/test_checkpoint_io/test_save.py +++ b/tests/test_utils/test_checkpoint_io/test_save.py @@ -3,21 +3,24 @@ from tempfile import TemporaryDirectory from typing import Dict -import colossalai import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, - OTHER_CKPT_FILE_NAME) -from colossalai.utils.checkpoint_io.io import save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta from torch import Tensor from torch.optim import Adam +import colossalai +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils.checkpoint_io.constant import ( + GLOBAL_META_FILE_NAME, + META_CKPT_FILE_NAME, + MODEL_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME, +) +from colossalai.utils.checkpoint_io.io import save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta + def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: assert set(a.keys()) == set(b.keys()) @@ -104,9 +107,9 @@ def test_save_global_shard(): }) -def run_dist(rank, world_size, port, func): +def run_dist(rank, world_size, port, test_fn): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - func() + test_fn() def run_save_dist(dir_name): @@ -124,8 +127,7 @@ def test_save_dist(): with TemporaryDirectory() as dir_name: fn = partial(run_save_dist, dir_name) world_size = 2 - proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) - mp.spawn(proc_fn, nprocs=world_size) + spawn(run_dist, world_size, test_fn=fn) assert len(os.listdir(dir_name)) == 8 global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) assert len(global_meta['meta']) == 2 diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index 7c2ad9078f98..89760a5456e7 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,20 +1,17 @@ import os import shutil from copy import deepcopy -from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR import colossalai from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext @@ -202,13 +199,7 @@ def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler): # @pytest.mark.parametrize('test_scheduler', ['colossalai_cosine_warmup', 'torch_cosine', 'torch_lambda']) @rerun_if_address_is_in_use() def test_checkpoint(world_size, use_ddp, use_mp_reload, test_scheduler=None): - run_func = partial(run_dist, - world_size=world_size, - port=free_port(), - use_ddp=use_ddp, - use_mp_reload=use_mp_reload, - test_scheduler=test_scheduler) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, use_ddp=use_ddp, use_mp_reload=use_mp_reload, test_scheduler=test_scheduler) if __name__ == '__main__': diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 6bfa6f33c812..2633d7da21aa 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,15 +1,13 @@ import torch -import torch.multiprocessing as mp import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline from colossalai.zero.legacy.sharded_param import ShardedTensor -def run_tensor_move(rank): - colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') +def run_tensor_move(rank, world_size, port): + colossalai.launch(config={}, rank=0, world_size=world_size, host='localhost', port=port, backend='nccl') src_t = torch.ones(2, 3).cuda() tgt_t = torch.zeros(2, 3) @@ -36,7 +34,7 @@ def run_tensor_move(rank): @rerun_if_address_is_in_use() def test_tensor_move(): - mp.spawn(run_tensor_move, nprocs=1) + spawn(run_tensor_move, 1) if __name__ == '__main__': diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 441cbbb22ce7..7a28b0157384 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -5,6 +5,7 @@ from einops import rearrange from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN +from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN: from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention @@ -22,7 +23,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -42,7 +44,8 @@ def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -65,7 +68,8 @@ def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, H, D_HEAD', [(6, 8, 4, 16)]) def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD @@ -84,7 +88,8 @@ def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16): @pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available") -@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) +@clear_cache_before_run() +@parameterize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)]) def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16): D = H * D_HEAD diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_utils/test_lazy_init/test_distribute.py index 1e32814ab147..2c15ca84efaa 100644 --- a/tests/test_utils/test_lazy_init/test_distribute.py +++ b/tests/test_utils/test_lazy_init/test_distribute.py @@ -1,17 +1,14 @@ -from functools import partial from typing import Optional import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn import colossalai from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.common import print_rank_0 try: @@ -105,9 +102,7 @@ def run_dist(rank, world_size, port) -> None: @pytest.mark.dist @rerun_if_address_is_in_use() def test_dist_lazy_init(): - world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 4) if __name__ == '__main__': diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py index 46a5aeba505b..c88c2f8ec3c5 100644 --- a/tests/test_utils/test_memory.py +++ b/tests/test_utils/test_memory.py @@ -1,12 +1,9 @@ import pytest import colossalai +from colossalai.testing import spawn from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity -from colossalai.utils import free_port - -from functools import partial -import torch.multiprocessing as mp +from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): @@ -24,8 +21,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @pytest.mark.parametrize("world_size", [3, 4]) def test_memory_utils(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py index 259286663033..c0d678026c5f 100644 --- a/tests/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_utils/test_norm_gradient_clipping.py @@ -1,16 +1,15 @@ -from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup -from colossalai.tensor.colo_parameter import ColoParameter -import colossalai import pytest import torch -import torch.multiprocessing as mp -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device +from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad_norm_ -from functools import partial -from colossalai.testing import parameterize, rerun_if_address_is_in_use + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.utils.common import clip_grad_norm -from torch.nn.parameter import Parameter def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): @@ -71,8 +70,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_zero_clip_grad(world_size: int): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py index 920656726d63..e99cf388e929 100644 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -1,21 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import copy from functools import partial import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import checkpoint, clip_grad_norm_fp32 from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 @@ -106,8 +104,7 @@ def run_dist(rank, world_size, port): @rerun_if_address_is_in_use() def test_zero_clip_grad(): world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_chunk_mgrv2.py b/tests/test_zero/test_gemini/test_chunk_mgrv2.py index ba0945551b18..7ea063877b5c 100644 --- a/tests/test_zero/test_gemini/test_chunk_mgrv2.py +++ b/tests/test_zero/test_gemini/test_chunk_mgrv2.py @@ -1,13 +1,9 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.gemini.chunk import ChunkManager from tests.test_tensor.common_utils import debug_print @@ -64,8 +60,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [2]) @rerun_if_address_is_in_use() def test_chunk_manager(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py index 5f9ba5d3a827..16764aa6b0b1 100644 --- a/tests/test_zero/test_gemini/test_chunkv2.py +++ b/tests/test_zero/test_gemini/test_chunkv2.py @@ -1,15 +1,12 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ColoParameter from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero.gemini import TensorState from colossalai.zero.gemini.chunk import Chunk @@ -117,8 +114,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2, 4]) @rerun_if_address_is_in_use() def test_chunk_function(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 8cfacd018324..697595bc3352 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -1,8 +1,5 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -10,8 +7,7 @@ from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -103,8 +99,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_gpt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 9d5419e9452d..dd580976d8ea 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.gemini_mgr import GeminiManager from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer @@ -98,8 +94,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_gemini_use_rmt(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_get_torch_model.py b/tests/test_zero/test_gemini/test_get_torch_model.py index c014ced975ce..b3e3b2b22fc3 100644 --- a/tests/test_zero/test_gemini/test_get_torch_model.py +++ b/tests/test_zero/test_gemini/test_get_torch_model.py @@ -1,14 +1,9 @@ -import os -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ColoParameter -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, GeminiDDP from colossalai.zero.gemini.utils import get_static_torch_model @@ -50,8 +45,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_convert_torch_module(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 65f252c558d7..38b6e474ea98 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -1,25 +1,20 @@ -from functools import partial -from time import time - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import debug_print, set_seed +from tests.test_tensor.common_utils import set_seed def check_param(model: ZeroDDP, torch_model: torch.nn.Module): @@ -105,8 +100,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_grad_clip(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 12392d6e57ad..790a0611c9dd 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -1,18 +1,15 @@ -from functools import partial from typing import Callable import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration @@ -128,8 +125,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_inference(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 7364e59d10b4..8ce20c16e8f9 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -1,18 +1,13 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter, ColoTensor -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration @@ -157,8 +152,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_optim(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 9a3e93493ec3..0e6f283aa5d2 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -3,12 +3,14 @@ import numpy as np import torch +from colossalai.testing import clear_cache_before_run from colossalai.zero import ColoInitContext from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs +@clear_cache_before_run() def test_runtime_mem_tracer(): test_models = ['gpt2', 'bert', 'simple_net', 'repeated_computed_layers', 'nested_model', 'albert'] diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index 71cdf9a18840..35b3b93ade0c 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs @@ -115,8 +111,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_search(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index 7e759808d1cb..66e05f3ed1ec 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -1,14 +1,9 @@ -from functools import partial - import pytest import torch -import torch.distributed as dist -import torch.multiprocessing as mp from torch.testing import assert_close import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -105,8 +100,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_zero_ddp(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index 996dc4eb8b56..a8af176c5b3d 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration @@ -83,8 +79,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 4]) @rerun_if_address_is_in_use() def test_zero_optim(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py index 03a1a609b672..e90158e0a43b 100644 --- a/tests/test_zero/test_legacy/test_found_inf.py +++ b/tests/test_zero/test_legacy/test_found_inf.py @@ -1,15 +1,11 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from common import CONFIG from test_sharded_optim_v2 import _run_step import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy @@ -64,8 +60,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_found_inf(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py index aee9432532c6..0e956f7cc617 100644 --- a/tests/test_zero/test_legacy/test_gemini_manager.py +++ b/tests/test_zero/test_legacy/test_gemini_manager.py @@ -1,10 +1,12 @@ import pytest import torch +from colossalai.testing import clear_cache_before_run from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState @pytest.mark.dist +@clear_cache_before_run() def test_gemini_manager(): # reset the manager, in case that there exists memory information left manager = StatefulTensor.GST_MGR diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py index 0eb8842de3e3..84493827193e 100644 --- a/tests/test_zero/test_legacy/test_init_context.py +++ b/tests/test_zero/test_legacy/test_init_context.py @@ -1,17 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from common import CONFIG import colossalai from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.utils.memory import colo_device_memory_used from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage @@ -70,8 +66,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 4]) @rerun_if_address_is_in_use() def test_zero_init_context(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py index 9ebacdb70f1a..b91371b98922 100644 --- a/tests/test_zero/test_legacy/test_param_op.py +++ b/tests/test_zero/test_legacy/test_param_op.py @@ -2,6 +2,7 @@ import torch +from colossalai.testing import clear_cache_before_run from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr from tests.components_to_test.registry import non_distributed_component_funcs @@ -49,6 +50,7 @@ def hook(param, grad) -> torch.Tensor or None: return hookwrapper.hook_triggered_times +@clear_cache_before_run() def test_base_param_hook(): test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] # test_models = ['bert'] diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py index 884444adf167..93d624aa2bbd 100644 --- a/tests/test_zero/test_legacy/test_shard_model_v2.py +++ b/tests/test_zero/test_legacy/test_shard_model_v2.py @@ -1,17 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch -import torch.multiprocessing as mp from common import CONFIG, check_grads_padding, run_fwd_bwd from torch.nn.parallel import DistributedDataParallel as DDP import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -61,8 +57,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_shard_model_v2(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py index b76648321451..4ba43edceb5d 100644 --- a/tests/test_zero/test_legacy/test_shard_param.py +++ b/tests/test_zero/test_legacy/test_shard_param.py @@ -1,14 +1,11 @@ from copy import deepcopy -from functools import partial import pytest import torch -import torch.multiprocessing as mp from common import CONFIG, allclose import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_param import ShardedTensor @@ -39,8 +36,7 @@ def _run_shard_tensor(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_shard_tensor(world_size): - run_func = partial(_run_shard_tensor, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_shard_tensor, world_size) def _run_shard_param_v2(rank, world_size, port): @@ -87,8 +83,7 @@ def _run_shard_param_v2(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_shard_param_v2(world_size): - run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_shard_param_v2, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py index d257a02854c6..1ca144662722 100644 --- a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py +++ b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py @@ -1,14 +1,10 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai from colossalai.nn.optimizer import HybridAdam from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import TensorShardStrategy @@ -86,8 +82,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_sharded_optim_state_dist(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py index 3eea13d5d5c2..c6f77995ebcd 100644 --- a/tests/test_zero/test_legacy/test_sharded_optim_v2.py +++ b/tests/test_zero/test_legacy/test_sharded_optim_v2.py @@ -1,17 +1,13 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from common import CONFIG, check_sharded_model_params from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy @@ -107,8 +103,7 @@ def _run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_sharded_optim_v2(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(_run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py index 05512f59a36a..61d850d06080 100644 --- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py @@ -1,19 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torchvision.models import resnet50 import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import TensorShardStrategy @@ -84,9 +80,7 @@ def test_sharded_optim_with_sync_bn(): wanted if we are doing predictions. """ - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py index 40d2820d800a..5f76fff3e5c3 100644 --- a/tests/test_zero/test_legacy/test_state_dict.py +++ b/tests/test_zero/test_legacy/test_state_dict.py @@ -5,12 +5,10 @@ import pytest import torch -import torch.multiprocessing as mp from common import CONFIG import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.legacy.sharded_model import ShardedModelV2 @@ -50,8 +48,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_zero_state_dict(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py index 3114481707c2..238bc3fe1a98 100644 --- a/tests/test_zero/test_legacy/test_tensor_utils.py +++ b/tests/test_zero/test_legacy/test_tensor_utils.py @@ -1,12 +1,8 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils.cuda import get_current_device from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor from colossalai.zero.legacy.gemini.tensor_utils import ( @@ -91,8 +87,7 @@ def run_dist(rank, world_size, port): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_zero_tensor_utils(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size) if __name__ == '__main__': diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py index 1e7f53358526..dc8847ce56ab 100644 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ b/tests/test_zero/test_legacy/test_zero_engine.py @@ -1,19 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.zero.legacy.init_ctx import ZeroInitContext from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from colossalai.zero.low_level._utils import has_inf_or_nan @@ -96,16 +92,14 @@ def run_dist(rank, world_size, port, parallel_config): @pytest.mark.parametrize("world_size", [2, 4]) @rerun_if_address_is_in_use() def test_mp_engine(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=MP_PARALLEL_CONFIG) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG) @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2]) @rerun_if_address_is_in_use() def test_zero_engine(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port(), parallel_config=ZERO_PARALLEL_CONFIG) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG) if __name__ == '__main__': diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 504df202e168..2ae1f3a99d79 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -1,16 +1,14 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai +from colossalai.testing import spawn from colossalai.testing.random import seed_all -from colossalai.utils import free_port from colossalai.zero import LowLevelZeroOptimizer @@ -158,9 +156,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist def test_grad_accumulation(): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2) if __name__ == '__main__': diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index ed76e0171fb4..4086af9d896e 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -1,17 +1,14 @@ import copy -from functools import partial import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all -from colossalai.utils import free_port from colossalai.zero import LowLevelZeroOptimizer @@ -179,9 +176,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_1_2(): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 2) if __name__ == '__main__': diff --git a/tests/test_zero/test_low_level/test_zero_init.py b/tests/test_zero/test_low_level/test_zero_init.py index 803d0021df96..aeeaff5b5cb9 100644 --- a/tests/test_zero/test_low_level/test_zero_init.py +++ b/tests/test_zero/test_low_level/test_zero_init.py @@ -1,14 +1,12 @@ -from functools import partial - import pytest import torch import torch.distributed as dist -import torch.multiprocessing as mp import torch.nn as nn import colossalai from colossalai.tensor import ProcessGroup -from colossalai.utils import free_port, get_current_device +from colossalai.testing import spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer @@ -51,9 +49,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist def test_zero_init(): - world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 4) if __name__ == '__main__': diff --git a/tests/test_zero/test_low_level/test_zero_tp.py b/tests/test_zero/test_low_level/test_zero_tp.py index bb7495583dc9..f0804f4bb5ba 100644 --- a/tests/test_zero/test_low_level/test_zero_tp.py +++ b/tests/test_zero/test_low_level/test_zero_tp.py @@ -1,16 +1,13 @@ -from functools import partial - import pytest import torch -import torch.multiprocessing as mp import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal @@ -89,9 +86,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_with_tp(): - world_size = 4 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) + spawn(run_dist, 4) if __name__ == '__main__': From 6afeb1202aeedc68d9ef1f77b41f3e5f55e0e121 Mon Sep 17 00:00:00 2001 From: Fazzie-Maqianli <55798671+Fazziekey@users.noreply.github.com> Date: Thu, 6 Apr 2023 15:04:48 +0800 Subject: [PATCH 19/27] add community example dictionary (#3465) --- .../Chat/examples/community/README.md | 1 + .../{EasyPeftModel.md => peft/README.md} | 4 +- .../community/{ => peft}/easy_dataset.py | 126 +++++++++--------- .../community/{ => peft}/easy_models.py | 13 +- .../{ => peft}/train_peft_prompts.py | 17 +-- .../community/{ => peft}/train_peft_sft.py | 25 ++-- 6 files changed, 94 insertions(+), 92 deletions(-) create mode 100644 applications/Chat/examples/community/README.md rename applications/Chat/examples/community/{EasyPeftModel.md => peft/README.md} (98%) rename applications/Chat/examples/community/{ => peft}/easy_dataset.py (74%) rename applications/Chat/examples/community/{ => peft}/easy_models.py (94%) rename applications/Chat/examples/community/{ => peft}/train_peft_prompts.py (95%) rename applications/Chat/examples/community/{ => peft}/train_peft_sft.py (90%) diff --git a/applications/Chat/examples/community/README.md b/applications/Chat/examples/community/README.md new file mode 100644 index 000000000000..5e3f47db37b3 --- /dev/null +++ b/applications/Chat/examples/community/README.md @@ -0,0 +1 @@ +# Community Examples diff --git a/applications/Chat/examples/community/EasyPeftModel.md b/applications/Chat/examples/community/peft/README.md similarity index 98% rename from applications/Chat/examples/community/EasyPeftModel.md rename to applications/Chat/examples/community/peft/README.md index 16c4af76b91f..a82f02a87317 100644 --- a/applications/Chat/examples/community/EasyPeftModel.md +++ b/applications/Chat/examples/community/peft/README.md @@ -10,7 +10,7 @@ Since the current pypi peft package(0.2) has some bugs, please install the peft git clone https://github.com/huggingface/peft cd peft pip install . -``` +``` # Usage For SFT training, just call train_peft_sft.py @@ -21,4 +21,4 @@ For stage-3 rlhf training, call train_peft_prompts.py. Its arguments are almost idential to train_prompts.py. The only difference is that I use text files to indicate the prompt and pretrained data file. The models are included in easy_models.py. Currently only bloom models are tested, but technically gpt2/opt/llama should be supported. # Dataformat -Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt. \ No newline at end of file +Please refer the formats in test_sft.txt, test_prompts.txt, test_pretrained.txt. diff --git a/applications/Chat/examples/community/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py similarity index 74% rename from applications/Chat/examples/community/easy_dataset.py rename to applications/Chat/examples/community/peft/easy_dataset.py index 15dd9a3ccd1d..13dceef79145 100644 --- a/applications/Chat/examples/community/easy_dataset.py +++ b/applications/Chat/examples/community/peft/easy_dataset.py @@ -1,19 +1,17 @@ import copy +import json from typing import Dict, Sequence + +import torch from datasets import load_dataset from torch.utils.data import Dataset -from transformers import AutoTokenizer -import torch from tqdm import tqdm -import json - -from tqdm import tqdm -import json +from transformers import AutoTokenizer IGNORE_INDEX = -100 -def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :int = 512) -> Dict: +def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict: """Tokenize a list of strings.""" tokenized_list = [ tokenizer( @@ -36,15 +34,12 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :in ) -def preprocess( - sources: Sequence[str], - targets: Sequence[str], - tokenizer: AutoTokenizer, - max_length :int = 512 -) -> Dict: +def preprocess(sources: Sequence[str], targets: Sequence[str], tokenizer: AutoTokenizer, max_length: int = 512) -> Dict: """Preprocess the data by tokenizing.""" examples = [s + t for s, t in zip(sources, targets)] - examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer,max_length) for strings in (examples, sources)] + examples_tokenized, sources_tokenized = [ + _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources) + ] input_ids = examples_tokenized["input_ids"] labels = copy.deepcopy(input_ids) for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): @@ -53,59 +48,60 @@ def preprocess( class EasySupervisedDataset(Dataset): - def __init__(self, data_file :str, tokenizer :AutoTokenizer,max_length :int = 512) -> None: - super(EasySupervisedDataset,self).__init__() - with open(data_file,"r",encoding="UTF-8") as f: + + def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 512) -> None: + super(EasySupervisedDataset, self).__init__() + with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() #split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:" - sources,targets = [],[] + sources, targets = [], [] for line in all_lines: if "回答:" in line: sep_index = line.index("回答:") - sources.append(line[:sep_index+3]) - targets.append(line[sep_index+3:]+tokenizer.eos_token) + sources.append(line[:sep_index + 3]) + targets.append(line[sep_index + 3:] + tokenizer.eos_token) else: sources.append(line) - targets.append(""+tokenizer.eos_token) - data_dict = preprocess(sources, targets, tokenizer,max_length) + targets.append("" + tokenizer.eos_token) + data_dict = preprocess(sources, targets, tokenizer, max_length) self.input_ids = data_dict["input_ids"] self.labels = data_dict["labels"] self.data_file = data_file - + def __len__(self): return len(self.input_ids) def __getitem__(self, i) -> Dict[str, torch.Tensor]: return dict(input_ids=self.input_ids[i], labels=self.labels[i]) - + def __repr__(self): return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" - + def __str__(self): return f"LawSupervisedDataset(data_file={self.data_file}, input_ids_len={len(self.input_ids)}, labels_len={len(self.labels)})" + class EasyPromptsDataset(Dataset): - def __init__(self,data_file :str, tokenizer :AutoTokenizer, max_length :int = 96) -> None: - super(EasyPromptsDataset,self).__init__() - with open(data_file,"r",encoding="UTF-8") as f: + + def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length: int = 96) -> None: + super(EasyPromptsDataset, self).__init__() + with open(data_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() - all_lines = [line if "回答:" not in line else line[:line.index("回答:")+3] for line in all_lines] + all_lines = [line if "回答:" not in line else line[:line.index("回答:") + 3] for line in all_lines] self.prompts = [ - tokenizer(line, - return_tensors='pt', - max_length=max_length, - padding='max_length', - truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) + tokenizer(line, return_tensors='pt', max_length=max_length, padding='max_length', + truncation=True)['input_ids'].to(torch.cuda.current_device()).squeeze(0) for line in tqdm(all_lines) ] self.data_file = data_file + def __len__(self): return len(self.prompts) - + def __getitem__(self, idx): return self.prompts[idx] - + def __repr__(self): return f"LawPromptsDataset(data_file={self.data_file}, prompts_len={len(self.prompts)})" @@ -114,8 +110,9 @@ def __str__(self): class EasyRewardDataset(Dataset): - def __init__(self,train_file :str,tokenizer :AutoTokenizer, special_token = None,max_length = 512) -> None: - super(EasyRewardDataset,self).__init__() + + def __init__(self, train_file: str, tokenizer: AutoTokenizer, special_token=None, max_length=512) -> None: + super(EasyRewardDataset, self).__init__() self.chosen = [] self.reject = [] if special_token is None: @@ -124,11 +121,11 @@ def __init__(self,train_file :str,tokenizer :AutoTokenizer, special_token = None self.end_token = special_token print(self.end_token) #read all lines in the train_file to a list - with open(train_file,"r",encoding="UTF-8") as f: + with open(train_file, "r", encoding="UTF-8") as f: all_lines = f.readlines() for line in tqdm(all_lines): data = json.loads(line) - prompt = "提问:"+data['prompt']+" 回答:" + prompt = "提问:" + data['prompt'] + " 回答:" chosen = prompt + data['chosen'] + self.end_token chosen_token = tokenizer(chosen, @@ -159,7 +156,7 @@ def __len__(self): def __getitem__(self, idx): return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ "input_ids"], self.reject[idx]["attention_mask"] - + #python representation of the object and the string representation of the object def __repr__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" @@ -167,27 +164,30 @@ def __repr__(self): def __str__(self): return f"LawRewardDataset(chosen_len={len(self.chosen)}, reject_len={len(self.reject)})" + ''' Easy SFT just accept a text file which can be read line by line. However the datasest will group texts together to max_length so LLM will learn the texts meaning better. If individual lines are not related, just set is_group_texts to False. ''' + + class EasySFTDataset(Dataset): - - def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_group_texts = True) -> None: + + def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_group_texts=True) -> None: super().__init__() #read the data_file line by line - with open(data_file,"r",encoding="UTF-8") as f: + with open(data_file, "r", encoding="UTF-8") as f: #encode the text data line by line and put raw python list input_ids only to raw_input_ids list raw_input_ids = [] for line in f: encoded_ids = tokenizer.encode(line) #if the encoded_ids is longer than max_length, then split it into several parts if len(encoded_ids) > max_length: - for i in range(0,len(encoded_ids),max_length): - raw_input_ids.append(encoded_ids[i:i+max_length]) + for i in range(0, len(encoded_ids), max_length): + raw_input_ids.append(encoded_ids[i:i + max_length]) else: raw_input_ids.append(encoded_ids) - + grouped_inpup_ids = [] current_input_ids = [] attention_mask = [] @@ -199,44 +199,42 @@ def __init__(self,data_file :str,tokenizer :AutoTokenizer,max_length = 512,is_gr #pad the current_input_ids to max_length with tokenizer.pad_token_id padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) - grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long)) - attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + attention_mask.append( + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) current_input_ids = [] else: current_input_ids.extend(input_ids) if len(current_input_ids) > 0: padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) - grouped_inpup_ids.append(torch.tensor(current_input_ids,dtype=torch.long)) - attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) + grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + attention_mask.append( + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) else: #just append the raw_input_ids to max_length for input_ids in raw_input_ids: padded_length = max_length - len(input_ids) input_ids.extend([tokenizer.pad_token_id] * padded_length) - attention_mask.append(torch.tensor([1] * (max_length - padded_length) + [0] * padded_length,dtype=torch.long)) - grouped_inpup_ids.append(torch.tensor(input_ids,dtype=torch.long)) + attention_mask.append( + torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) + grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long)) self.input_ids = grouped_inpup_ids self.labels = copy.deepcopy(self.input_ids) self.file_name = data_file self.attention_mask = attention_mask - + def __len__(self): return len(self.input_ids) - + #get item from dataset - def __getitem__(self,idx): - return dict(input_ids=self.input_ids[idx],labels=self.labels[idx],attention_mask=self.attention_mask[idx]) - + def __getitem__(self, idx): + return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx]) + #generate the dataset description to be printed by print in python def __repr__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" - + #generate the dataset description to be printed by print in python def __str__(self): return f"EasySFTDataset(len={len(self)},\nfile_name is {self.file_name})" - - - - - \ No newline at end of file diff --git a/applications/Chat/examples/community/easy_models.py b/applications/Chat/examples/community/peft/easy_models.py similarity index 94% rename from applications/Chat/examples/community/easy_models.py rename to applications/Chat/examples/community/peft/easy_models.py index 080fc1802a02..fe294868159d 100644 --- a/applications/Chat/examples/community/easy_models.py +++ b/applications/Chat/examples/community/peft/easy_models.py @@ -3,12 +3,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.modules import Module - from coati.models.generation import generate -from coati.models.utils import log_probs_from_logits,masked_mean -from transformers import BloomConfig,BloomForCausalLM +from coati.models.utils import log_probs_from_logits, masked_mean from peft import PeftModel +from torch.nn.modules import Module +from transformers import BloomConfig, BloomForCausalLM + class Actor(Module): """ @@ -87,11 +87,10 @@ def __init__(self, else: model = BloomForCausalLM(BloomConfig()) if lora_path is not None: - model = PeftModel.from_pretrained(model,lora_path) + model = PeftModel.from_pretrained(model, lora_path) if checkpoint: model.gradient_checkpointing_enable() super().__init__(model) - + def print_trainable_parameters(self): self.get_base_model().print_trainable_parameters() - diff --git a/applications/Chat/examples/community/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py similarity index 95% rename from applications/Chat/examples/community/train_peft_prompts.py rename to applications/Chat/examples/community/peft/train_peft_prompts.py index b9394c9e4190..0e277021e917 100644 --- a/applications/Chat/examples/community/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -5,21 +5,22 @@ import torch.distributed as dist from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset from coati.models.bloom import BLOOMRM, BLOOMCritic -from easy_models import BLOOMActor from coati.models.gpt import GPTRM, GPTActor, GPTCritic from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM from coati.models.opt import OPTRM, OPTActor, OPTCritic from coati.trainer import PPOTrainer from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from coati.utils import prepare_llama_tokenizer_and_embedding +from easy_dataset import EasyPromptsDataset, EasySupervisedDataset +from easy_models import BLOOMActor +from peft import PeftModel from torch.optim import Adam from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer from colossalai.nn.optimizer import HybridAdam -from peft import PeftModel -from easy_dataset import EasyPromptsDataset,EasySupervisedDataset + def main(args): # configure strategy @@ -41,7 +42,7 @@ def main(args): if args.model == 'bloom': # initial_model = BLOOMActor(pretrained=args.pretrain) print('Using peft lora to load Bloom model as inital_model') - initial_model = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path) + initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) print('Using peft lora to load Bloom model as initial_model (Done)') else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -54,7 +55,7 @@ def main(args): if rm_model_name == 'gpt2': reward_model = GPTRM(pretrained=args.rm_pretrain) elif rm_model_name == 'bloom': - print("load bloom reward model ",args.rm_pretrain) + print("load bloom reward model ", args.rm_pretrain) reward_model = BLOOMRM(pretrained=args.rm_pretrain) elif rm_model_name == 'opt': reward_model = OPTRM(pretrained=args.rm_pretrain) @@ -75,7 +76,7 @@ def main(args): if args.model == 'bloom': # actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank) print('Using peft lora to load Bloom model as Actor') - actor = BLOOMActor(pretrained=args.pretrain,lora_path=args.sft_lora_path) + actor = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) print('Using peft lora to load Bloom model as Actor (Done)') else: raise ValueError(f'Unsupported actor model "{args.model}"') @@ -83,7 +84,7 @@ def main(args): if rm_model_name == 'gpt2': critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) elif rm_model_name == 'bloom': - print("load bloom critic ",args.rm_pretrain," lora_rank ",args.lora_rank," use_action_mask ",True) + print("load bloom critic ", args.rm_pretrain, " lora_rank ", args.lora_rank, " use_action_mask ", True) critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True) print("load bloom critic (Done) ") elif rm_model_name == 'opt': @@ -130,7 +131,7 @@ def main(args): data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) - prompt_dataset = EasyPromptsDataset(args.prompt_path,tokenizer) + prompt_dataset = EasyPromptsDataset(args.prompt_path, tokenizer) if dist.is_initialized() and dist.get_world_size() > 1: prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True) else: diff --git a/applications/Chat/examples/community/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py similarity index 90% rename from applications/Chat/examples/community/train_peft_sft.py rename to applications/Chat/examples/community/peft/train_peft_sft.py index 65d901261bfc..fcc65e24478a 100644 --- a/applications/Chat/examples/community/train_peft_sft.py +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -14,19 +14,19 @@ from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from coati.utils import prepare_llama_tokenizer_and_embedding from datasets import load_dataset +from easy_dataset import EasyDataset +from peft import LoraConfig, PeftModel, TaskType, get_peft_model from torch.optim import Adam from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate from torch.utils.data.distributed import DistributedSampler -from transformers import AutoTokenizer, BloomTokenizerFast,AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer, BloomTokenizerFast from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import HybridAdam from colossalai.tensor import ColoParameter -from torch.utils.data.dataloader import default_collate -from peft import LoraConfig, TaskType,get_peft_model,PeftModel -from easy_dataset import EasyDataset def train(args): # configure strategy @@ -48,17 +48,20 @@ def train(args): #if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json if os.path.exists(args.save_path) and os.path.exists(args.save_path+'/adapter_config.json') \ and os.path.exists(args.save_path+'/adapter_model.bin'): - print("loading from saved peft model ",args.save_path) + print("loading from saved peft model ", args.save_path) model = PeftModel.from_pretrained(model, args.save_path) else: #we'll use peft lora library to do the lora lora_rank = args.lora_rank if args.lora_rank > 0 else 32 #config lora with rank of lora_rank - lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=lora_rank, lora_alpha=32, lora_dropout=0.1) + lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=lora_rank, + lora_alpha=32, + lora_dropout=0.1) model = get_peft_model(model, lora_config) model.print_trainable_parameters() - # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') @@ -103,12 +106,12 @@ def train(args): logger.set_level('WARNING') # configure dataset - law_dataset = EasyDataset(args.dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text) + law_dataset = EasyDataset(args.dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) train_dataset = law_dataset print(train_dataset) eval_dataset = None if args.eval_dataset is not None: - eval_dataset = EasyDataset(args.eval_dataset,tokenizer=tokenizer,is_group_texts=not args.is_short_text) + eval_dataset = EasyDataset(args.eval_dataset, tokenizer=tokenizer, is_group_texts=not args.is_short_text) data_collator = default_collate if dist.is_initialized() and dist.get_world_size() > 1: train_sampler = DistributedSampler(train_dataset, @@ -181,7 +184,7 @@ def train(args): parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") parser.add_argument('--lr', type=float, default=5e-6) parser.add_argument('--accimulation_steps', type=int, default=8) - parser.add_argument('--enable_peft_lora',action='store_true', default=False) - parser.add_argument("--is_short_text",action='store_true', default=False) + parser.add_argument('--enable_peft_lora', action='store_true', default=False) + parser.add_argument("--is_short_text", action='store_true', default=False) args = parser.parse_args() train(args) From 52a933e17509c71811e919b165de38cb3d5d6d41 Mon Sep 17 00:00:00 2001 From: jiangmingyan <37931082+jiangmingyan@users.noreply.github.com> Date: Thu, 6 Apr 2023 16:23:39 +0800 Subject: [PATCH 20/27] [checkpoint] support huggingface style sharded checkpoint (#3461) * [checkpoint] support huggingface style sharded checkpoint * [checkpoint] support huggingface style sharded checkpoint * [checkpoint] support huggingface style sharded checkpoint * [checkpoint] support huggingface style sharded checkpoint * [checkpoint] support huggingface style sharded checkpoint --------- Co-authored-by: luchen --- .../checkpoint_io/general_checkpoint_io.py | 101 ++++++++++--- colossalai/checkpoint_io/index_file.py | 6 + colossalai/checkpoint_io/utils.py | 140 +++++++++++++++++- .../test_general_checkpoint_io.py | 89 ++++++++--- 4 files changed, 291 insertions(+), 45 deletions(-) diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index c779f4c17355..2a76f1718469 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -2,37 +2,35 @@ import torch.nn as nn from torch.optim import Optimizer +import logging +import os +import json +import gc from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile -from .utils import has_index_file, load_state_dict, save_state_dict +from .utils import ( + has_index_file, + load_state_dict, + save_state_dict, + is_safetensors_available, + shard_checkpoint, + load_shard_state_dict, + load_state_dict_into_model + ) +from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME __all__ = ['GeneralCheckpointIO'] class GeneralCheckpointIO(CheckpointIO): - - def load_sharded_model(self, model: nn.Module, index_file_path: Path, strict: bool): - # load the index file - index_file = CheckpointIndexFile.from_file(index_file_path) - - # iterate over the shard checkpoint files - # and load each - index_file.assert_no_dtensor_checkpoint() - checkpoint_file_list, _ = index_file.get_checkpoint_fileanames() - for shard_file in checkpoint_file_list: - shard_checkpoint = load_state_dict(shard_file) - model.load_state_dict(shard_checkpoint, strict=strict) - + """ + Checkpoint IO + """ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool): checkpoint = load_state_dict(checkpoint) model.load_state_dict(checkpoint, strict=strict) - def save_sharded_model(self, model: nn.Module, checkpoint: Path, gather_dtensor: bool, prefix: str, - size_per_shard: int, use_safetensors: bool): - # TODO(FrankLeeeee): implement this method as it can be supported by Huggingface model - raise NotImplementedError("Sharded model checkpoint is not supported yet.") - def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): state_dict = model.state_dict() @@ -68,3 +66,68 @@ def save_unsharded_optimizer( ): # TODO(FrankLeeeee): handle distributed tensors save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) + + + def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dtensor:bool = False, + prefix: str = "", max_shard_size: int = 1024, use_safetensors: bool = False): + """ + implement this method as it can be supported by Huggingface model, + save shard model, save model to multiple files + """ + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + + # shard checkpoint + state_dict = model.state_dict() + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Save the model + for shard_file, shard in shards.items(): + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + # save index file + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(checkpoint_path, save_index_file) + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logging.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + """ + load shard model, load model from multiple files + """ + use_safetensors = False + if "safetensors" in checkpoint_index_file.name: + use_safetensors = True + + if use_safetensors and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.") + + # read checkpoint index file + ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) + checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() + missing_keys = ckpt_index_file.get_all_param_names() + + for shard_file in checkpoint_files: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) + load_state_dict_into_model(model, state_dict, missing_keys, strict) + del state_dict + gc.collect() + + if strict and len(missing_keys) > 0: + error_msgs = 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 32ff1b762e88..89224787a91b 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -148,3 +148,9 @@ def get_checkpoint_file(self, param_name: str) -> str: """ ckpt_path = self.weight_map[param_name] return ckpt_path + + def get_all_param_names(self): + """ + Get all the weight keys. + """ + return list(self.weight_map.keys()) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 76c9db0afaff..81b666da5c78 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1,13 +1,19 @@ +# coding=utf-8 from pathlib import Path -from typing import List, Optional, Tuple - import torch +import torch.nn as nn +from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple +from colossalai.tensor.d_tensor.d_tensor import DTensor + +SAFE_WEIGHTS_NAME = "model.safetensors" +WEIGHTS_NAME = "model.bin" +SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" +WEIGHTS_INDEX_NAME = "model.bin.index.json" # ====================================== # General helper functions # ====================================== - def calculate_tensor_size(tensor: torch.Tensor) -> float: """ Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size. @@ -68,6 +74,130 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: return False +# ====================================== +# Helper functions for saving shard file +# ====================================== +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): + + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + """ + sharded_state_dicts = [] + current_block = {} + current_block_size = 0 + total_size = 0 + + for key, weight in state_dict.items(): + if type(weight) != DTensor: + weight_size = calculate_tensor_size(weight) + + # If this weight is going to tip up over the maximal size, we split. + if current_block_size + weight_size > max_shard_size: + sharded_state_dicts.append(current_block) + current_block = {} + current_block_size = 0 + + current_block[key] = weight + current_block_size += weight_size + total_size += weight_size + + # Add the last block + sharded_state_dicts.append(current_block) + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + +def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): + """ + load shard state dict into model + """ + if use_safetensors and not checkpoint_file.suffix == ".safetensors": + raise Exception("load the model using `safetensors`, but no file endwith .safetensors") + if use_safetensors: + from safetensors.torch import safe_open + from safetensors.torch import load_file as safe_load_file + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata["format"] != "pt": + raise NotImplementedError( + f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet." + ) + return safe_load_file(checkpoint_file) + else: + return torch.load(checkpoint_file) + +def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + """ + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + unexpected_keys: List[str] = [] + sub_missing_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + state_dict._metadata = metadata + + def load(module: nn.Module, state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") + + load(model, state_dict, "") + del load + + # deal with missing key + if len(missing_keys) > 0: + deleted_keys = [] + for key in missing_keys: + if key not in sub_missing_keys: + deleted_keys.append(key) + for key in deleted_keys: + missing_keys.remove(key) + + if strict: + if len(unexpected_keys) > 0: + error_msgs = 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, "\n\t".join(error_msgs))) + # ====================================== # Helper functions for saving state dict # ====================================== @@ -86,8 +216,8 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors assert is_safetensors_available(), "safetensors is not available." assert checkpoint_file_path.endswith('.safetensors'), \ "safetensors only supports .safetensors suffix for checkpoint file." - from safetensors.torch import save_file - save_file(state_dict, checkpoint_file_path) + from safetensors.torch import save_file as safe_save_file + safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"}) else: torch.save(state_dict, checkpoint_file_path) diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index 0f78184f70e1..ca5ce10054f7 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,9 +1,12 @@ import tempfile - import pytest import torch +import logging from torch.optim import Adam from torchvision.models import resnet18 +from pathlib import Path +import os +import subprocess from colossalai.checkpoint_io import GeneralCheckpointIO from colossalai.testing import clear_cache_before_run, parameterize @@ -12,7 +15,7 @@ # Note: # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now # 2. we will test on both sharded and unsharded checkpoints -# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it +# 3. implement sharded checkpoint and test it # ======== @@ -53,27 +56,71 @@ def test_unsharded_checkpoint(use_safetensors: bool): ckpt_io.load_model(new_model, model_ckpt_tempfile.name) ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) - # do recursive check for the optimizer state dict - # if the value is a dict, compare its values - # if the value is a list, comapre all elements one-by-one - # if the value is a torch.Tensor, use torch.equal - # otherwise use assertEqual - def recursive_check(d1, d2): - for k, v in d1.items(): - if isinstance(v, dict): - recursive_check(v, d2[k]) - elif isinstance(v, list): - for i in range(len(v)): - if isinstance(v[i], torch.Tensor): - assert torch.equal(v[i], d2[k][i]) - else: - assert v[i] == d2[k][i] - elif isinstance(v, torch.Tensor): - assert torch.equal(v, d2[k]) - else: - assert v == d2[k] # check for model and optimizer state dict recursively recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + +@pytest.mark.parametrize('use_safetensors', [True, False]) +def test_sharded_checkpoint(use_safetensors: bool): + # create a model and optimizer + model = resnet18() + optimizer = Adam(model.parameters(), lr=0.001) + # create test data sample + x = torch.randn(1, 3, 224, 224) + + # run fwd and bwd + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # create a temp file for checkpoint + if use_safetensors: + suffix = ".safetensors" + SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" + else: + suffix = ".bin" + WEIGHTS_INDEX_NAME = "model.bin.index.json" + + # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) + model_ckpt_dir = tempfile.TemporaryDirectory() + optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() + + # save the model and optimizer + ckpt_io = GeneralCheckpointIO() + + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors) + ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False) + + # create new model + new_model = resnet18() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + + ckpt_io.load_model(new_model, str(model_ckpt_dir.name), strict=True) + ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name) + + # check for model and optimizer state dict recursively + recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) + + +# do recursive check for the optimizer state dict +# if the value is a dict, compare its values +# if the value is a list, comapre all elements one-by-one +# if the value is a torch.Tensor, use torch.equal +# otherwise use assertEqual +def recursive_check(d1, d2): + for k, v in d1.items(): + if isinstance(v, dict): + recursive_check(v, d2[k]) + elif isinstance(v, list): + for i in range(len(v)): + if isinstance(v[i], torch.Tensor): + assert torch.equal(v[i], d2[k][i]) + else: + assert v[i] == d2[k][i] + elif isinstance(v, torch.Tensor): + assert torch.equal(v, d2[k]) + else: + assert v == d2[k] From 4e9989344d20e3f8af44767f0eadeaab5fff8c00 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 6 Apr 2023 17:47:59 +0800 Subject: [PATCH 21/27] [doc] updated contributor list (#3474) --- README.md | 5 +++-- docs/README-zh-Hans.md | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 65c8ae166608..8342a9fa0c9e 100644 --- a/README.md +++ b/README.md @@ -396,9 +396,10 @@ You may contact us or participate in the following ways: Thanks so much to all of our amazing contributors! - + + + -*The order of contributor avatars is randomly shuffled.*

(back to top)

diff --git a/docs/README-zh-Hans.md b/docs/README-zh-Hans.md index 3630e8539a8b..f43a5953022d 100644 --- a/docs/README-zh-Hans.md +++ b/docs/README-zh-Hans.md @@ -396,9 +396,9 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash 真诚感谢所有贡献者! - - -*贡献者头像的展示顺序是随机的。* + + +

(返回顶端)

From c701b77b1131a9095f3dca454da4ec667bcbf182 Mon Sep 17 00:00:00 2001 From: NatalieC323 <127177614+NatalieC323@users.noreply.github.com> Date: Thu, 6 Apr 2023 17:50:52 +0800 Subject: [PATCH 22/27] [dreambooth] fixing the incompatibity in requirements.txt (#3190) (#3378) * Update requirements.txt * Update environment.yaml * Update README.md * Update environment.yaml * Update README.md * Update README.md * Delete requirements_colossalai.txt * Update requirements.txt * Update README.md --- .../Teyvat/train_colossalai_teyvat.yaml | 16 ++-- .../diffusion/configs/train_colossalai.yaml | 22 ++--- .../configs/train_colossalai_cifar10.yaml | 16 ++-- .../images/diffusion/configs/train_ddp.yaml | 14 +-- .../diffusion/ldm/models/autoencoder.py | 5 +- .../ldm/models/diffusion/classifier.py | 9 +- .../diffusion/ldm/models/diffusion/ddpm.py | 22 +++-- examples/images/diffusion/main.py | 94 ++++++++++++------- examples/images/diffusion/scripts/img2img.py | 4 +- examples/images/diffusion/scripts/inpaint.py | 4 +- examples/images/diffusion/scripts/knn2img.py | 5 +- .../diffusion/scripts/sample_diffusion.py | 3 +- .../scripts/tests/test_checkpoint.py | 4 +- examples/images/diffusion/scripts/txt2img.py | 4 +- 14 files changed, 124 insertions(+), 98 deletions(-) diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml index ff0f4c5a0463..fe883cdfd7f8 100644 --- a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion + #target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -20,7 +20,7 @@ model: use_ema: False scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler + #target: ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -30,7 +30,7 @@ model: unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel + #target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -49,7 +49,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + #target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -73,13 +73,13 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + #target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - target: main.DataModuleFromConfig + #target: main.DataModuleFromConfig params: batch_size: 16 num_workers: 4 @@ -105,7 +105,7 @@ lightning: precision: 16 auto_select_gpus: False strategy: - target: strategies.ColossalAIStrategy + #target: strategies.ColossalAIStrategy params: use_chunk: True enable_distributed_storage: True @@ -120,7 +120,7 @@ lightning: logger_config: wandb: - target: loggers.WandbLogger + #target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml index 88432e978a0f..388ab2e8ff94 100644 --- a/examples/images/diffusion/configs/train_colossalai.yaml +++ b/examples/images/diffusion/configs/train_colossalai.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion + #target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -19,7 +19,7 @@ model: use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler + #target: ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -29,7 +29,7 @@ model: unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel + #target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -48,7 +48,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + #target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -69,16 +69,16 @@ model: attn_resolutions: [] dropout: 0.0 lossconfig: - target: torch.nn.Identity + #target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + #target: #ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - target: main.DataModuleFromConfig + #target: #main.DataModuleFromConfig params: batch_size: 128 wrap: False @@ -88,20 +88,20 @@ data: train: target: ldm.data.base.Txt2ImgIterableBaseDataset params: - file_path: # YOUR DATASET_PATH + file_path: /data/scratch/diffuser/laion_part0/ world_size: 1 rank: 0 lightning: trainer: accelerator: 'gpu' - devices: 8 + devices: 2 log_gpu_memory: all max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - target: strategies.ColossalAIStrategy + #target: #strategies.ColossalAIStrategy params: use_chunk: True enable_distributed_storage: True @@ -116,7 +116,7 @@ lightning: logger_config: wandb: - target: loggers.WandbLogger + #target: #loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml index 0ba06f832178..1331f96e34d6 100644 --- a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml +++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion + #target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -19,7 +19,7 @@ model: use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler + #target: ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -29,7 +29,7 @@ model: unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel + #target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -48,7 +48,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + #target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -69,16 +69,16 @@ model: attn_resolutions: [] dropout: 0.0 lossconfig: - target: torch.nn.Identity + #target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + #target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - target: main.DataModuleFromConfig + #target: main.DataModuleFromConfig params: batch_size: 4 num_workers: 4 @@ -105,7 +105,7 @@ lightning: precision: 16 auto_select_gpus: False strategy: - target: strategies.ColossalAIStrategy + #target: strategies.ColossalAIStrategy params: use_chunk: True enable_distributed_storage: True diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index a63df887e719..df591f33d5fd 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion + #target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -29,7 +29,7 @@ model: unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel + #target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -48,7 +48,7 @@ model: legacy: False first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL + #target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -72,13 +72,13 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + #target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - target: main.DataModuleFromConfig + #target: main.DataModuleFromConfig params: batch_size: 128 # num_workwers should be 2 * batch_size, and the total num less than 1024 @@ -100,7 +100,7 @@ lightning: precision: 16 auto_select_gpus: False strategy: - target: strategies.DDPStrategy + #target: strategies.DDPStrategy params: find_unused_parameters: False log_every_n_steps: 2 @@ -111,7 +111,7 @@ lightning: logger_config: wandb: - target: loggers.WandbLogger + #target: loggers.WandbLogger params: name: nowname save_dir: "/data2/tmp/diff_log/" diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index b1bd8377835b..145ccf6fb271 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -6,11 +6,10 @@ import torch.nn.functional as F from contextlib import contextmanager +from torch.nn import Identity from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -from ldm.util import instantiate_from_config from ldm.modules.ema import LitEma @@ -32,7 +31,7 @@ def __init__(self, self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) + self.loss = Identity(**lossconfig.get("params", dict())) assert ddconfig["double_z"] self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py index 612a8371bf20..3cf12f093bea 100644 --- a/examples/images/diffusion/ldm/models/diffusion/classifier.py +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -9,9 +9,10 @@ from einops import rearrange from glob import glob from natsort import natsorted - +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.lr_scheduler import LambdaLinearScheduler from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config +from ldm.util import log_txt_as_img, default, ismap __models__ = { 'class_label': EncoderUNetModel, @@ -86,7 +87,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): print(f"Unexpected Keys: {unexpected}") def load_diffusion(self): - model = instantiate_from_config(self.diffusion_config) + model = LatentDiffusion(**self.diffusion_config.get('params',dict())) self.diffusion_model = model.eval() self.diffusion_model.train = disabled_train for param in self.diffusion_model.parameters(): @@ -221,7 +222,7 @@ def configure_optimizers(self): optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) if self.use_scheduler: - scheduler = instantiate_from_config(self.scheduler_config) + scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict())) print("Setting up LambdaLR scheduler...") scheduler = [ diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index b7315b048c66..11de828732ea 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -22,6 +22,7 @@ from functools import partial from einops import rearrange, repeat +from ldm.lr_scheduler import LambdaLinearScheduler from ldm.models.autoencoder import * from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import * @@ -29,9 +30,10 @@ from ldm.modules.diffusionmodules.model import * from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model from ldm.modules.diffusionmodules.openaimodel import * -from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.ema import LitEma from ldm.modules.encoders.modules import * from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat @@ -39,6 +41,7 @@ from torch.optim.lr_scheduler import LambdaLR from torchvision.utils import make_grid from tqdm import tqdm +from ldm.modules.midas.api import MiDaSInference __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} @@ -690,7 +693,7 @@ def register_schedule(self, self.make_cond_schedule() def instantiate_first_stage(self, config): - model = instantiate_from_config(config) + model = AutoencoderKL(**config.get("params", dict())) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): @@ -706,7 +709,7 @@ def instantiate_cond_stage(self, config): self.cond_stage_model = None # self.be_unconditional = True else: - model = instantiate_from_config(config) + model = FrozenOpenCLIPEmbedder(**config.get("params", dict())) self.cond_stage_model = model.eval() self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): @@ -714,7 +717,7 @@ def instantiate_cond_stage(self, config): else: assert config != '__is_first_stage__' assert config != '__is_unconditional__' - model = instantiate_from_config(config) + model = FrozenOpenCLIPEmbedder(**config.get("params", dict())) self.cond_stage_model = model def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): @@ -1479,8 +1482,7 @@ def configure_optimizers(self): # opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: - assert 'target' in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) + scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict())) rank_zero_info("Setting up LambdaLR scheduler...") scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] @@ -1502,7 +1504,7 @@ class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) - self.diffusion_model = instantiate_from_config(diff_model_config) + self.diffusion_model = UNetModel(**diff_model_config.get("params", dict())) self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] @@ -1551,7 +1553,7 @@ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key= self.noise_level_key = noise_level_key def instantiate_low_stage(self, config): - model = instantiate_from_config(config) + model = ImageConcatWithNoiseAugmentation(**config.get("params", dict())) self.low_scale_model = model.eval() self.low_scale_model.train = disabled_train for param in self.low_scale_model.parameters(): @@ -1933,7 +1935,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): super().__init__(concat_keys=concat_keys, *args, **kwargs) - self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_model = MiDaSInference(**depth_stage_config.get("params", dict())) self.depth_stage_key = concat_keys[0] @torch.no_grad() @@ -2006,7 +2008,7 @@ def __init__(self, self.low_scale_key = low_scale_key def instantiate_low_stage(self, config): - model = instantiate_from_config(config) + model = ImageConcatWithNoiseAugmentation(**config.get("params", dict())) self.low_scale_model = model.eval() self.low_scale_model.train = disabled_train for param in self.low_scale_model.parameters(): diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index 91b809d5a65c..aeed6d5566f5 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -23,19 +23,21 @@ from PIL import Image from prefetch_generator import BackgroundGenerator from torch.utils.data import DataLoader, Dataset, Subset, random_split - -try: - from lightning.pytorch import seed_everything - from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint - from lightning.pytorch.trainer import Trainer - from lightning.pytorch.utilities import rank_zero_info, rank_zero_only - LIGHTNING_PACK_NAME = "lightning.pytorch." -except: - from pytorch_lightning import seed_everything - from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint - from pytorch_lightning.trainer import Trainer - from pytorch_lightning.utilities import rank_zero_info, rank_zero_only - LIGHTNING_PACK_NAME = "pytorch_lightning." +from ldm.models.diffusion.ddpm import LatentDiffusion +#try: +from lightning.pytorch import seed_everything +from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.trainer import Trainer +from lightning.pytorch.utilities import rank_zero_info, rank_zero_only +from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger +from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy +LIGHTNING_PACK_NAME = "lightning.pytorch." +# #except: +# from pytorch_lightning import seed_everything +# from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint +# from pytorch_lightning.trainer import Trainer +# from pytorch_lightning.utilities import rank_zero_info, rank_zero_only +# LIGHTNING_PACK_NAME = "pytorch_lightning." from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config @@ -575,7 +577,7 @@ def on_train_epoch_end(self, trainer, pl_module): # target: path to test dataset # params: # key: value - # lightning: (optional, has sane defaults and can be specified on cmdline) + # lightning: (optional, has same defaults and can be specified on cmdline) # trainer: # additional arguments to trainer # logger: @@ -653,7 +655,7 @@ def on_train_epoch_end(self, trainer, pl_module): # Sets the seed for the random number generator to ensure reproducibility seed_everything(opt.seed) - # Intinalize and save configuratioon using teh OmegaConf library. + # Intinalize and save configuration using the OmegaConf library. try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] @@ -687,7 +689,7 @@ def on_train_epoch_end(self, trainer, pl_module): config.model["params"].update({"ckpt": ckpt}) rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) - model = instantiate_from_config(config.model) + model = LatentDiffusion(**config.model.get("params", dict())) # trainer and callbacks trainer_kwargs = dict() @@ -696,7 +698,7 @@ def on_train_epoch_end(self, trainer, pl_module): # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger. default_logger_cfgs = { "wandb": { - "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", + #"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", "params": { "name": nowname, "save_dir": logdir, @@ -705,7 +707,7 @@ def on_train_epoch_end(self, trainer, pl_module): } }, "tensorboard": { - "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", + #"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", "params": { "save_dir": logdir, "name": "diff_tb", @@ -718,30 +720,32 @@ def on_train_epoch_end(self, trainer, pl_module): default_logger_cfg = default_logger_cfgs["tensorboard"] if "logger" in lightning_config: logger_cfg = lightning_config.logger + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = WandbLogger(**logger_cfg.get("params", dict())) else: logger_cfg = default_logger_cfg - logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) - trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg.get("params", dict())) + # config the strategy, defualt is ddp if "strategy" in trainer_config: strategy_cfg = trainer_config["strategy"] - strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] + trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg.get("params", dict())) else: strategy_cfg = { - "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", + #"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", "params": { "find_unused_parameters": False } } - - trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) + trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg.get("params", dict())) # Set up ModelCheckpoint callback to save best models # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", + #"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{epoch:06}", @@ -759,13 +763,13 @@ def on_train_epoch_end(self, trainer, pl_module): modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) if version.parse(pl.__version__) < version.parse('1.4.0'): - trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) + trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg.get("params", dict())) # Set up various callbacks, including logging, learning rate monitoring, and CUDA management # add callback which sets up log directory default_callbacks_cfg = { "setup_callback": { # callback to set up the training - "target": "main.SetupCallback", + #"target": "main.SetupCallback", "params": { "resume": opt.resume, # resume training if applicable "now": now, @@ -777,7 +781,7 @@ def on_train_epoch_end(self, trainer, pl_module): } }, "image_logger": { # callback to log image data - "target": "main.ImageLogger", + #"target": "main.ImageLogger", "params": { "batch_frequency": 750, # how frequently to log images "max_images": 4, # maximum number of images to log @@ -785,14 +789,14 @@ def on_train_epoch_end(self, trainer, pl_module): } }, "learning_rate_logger": { # callback to log learning rate - "target": "main.LearningRateMonitor", + #"target": "main.LearningRateMonitor", "params": { "logging_interval": "step", # logging frequency (either 'step' or 'epoch') # "log_momentum": True # whether to log momentum (currently commented out) } }, "cuda_callback": { # callback to handle CUDA-related operations - "target": "main.CUDACallback" + #"target": "main.CUDACallback" }, } @@ -810,7 +814,7 @@ def on_train_epoch_end(self, trainer, pl_module): 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') default_metrics_over_trainsteps_ckpt_dict = { 'metrics_over_trainsteps_checkpoint': { - "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', + #"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', 'params': { "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), "filename": "{epoch:06}-{step:09}", @@ -825,15 +829,35 @@ def on_train_epoch_end(self, trainer, pl_module): # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) + + #Instantiate items according to the configs + trainer_kwargs.setdefault("callbacks", []) - trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + if "setup_callback" in callbacks_cfg: + setup_callback_config = callbacks_cfg["setup_callback"] + trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config.get("params", dict()))) - # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory + if "image_logger" in callbacks_cfg: + image_logger_config = callbacks_cfg["image_logger"] + trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config.get("params", dict()))) + + if "learning_rate_logger" in callbacks_cfg: + learning_rate_logger_config = callbacks_cfg["learning_rate_logger"] + trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config.get("params", dict()))) + + if "cuda_callback" in callbacks_cfg: + cuda_callback_config = callbacks_cfg["cuda_callback"] + trainer_kwargs["callbacks"].append(CUDACallback(**cuda_callback_config.get("params", dict()))) + + if "metrics_over_trainsteps_checkpoint" in callbacks_cfg: + metrics_over_config = callbacks_cfg['metrics_over_trainsteps_checkpoint'] + trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_config.get("params", dict()))) + #trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer.logdir = logdir - + # Create a data module based on the configuration file - data = instantiate_from_config(config.data) + data = DataModuleFromConfig(**config.data.get("params", dict())) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index 877538d4733d..a3011005c16a 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -20,8 +20,8 @@ from scripts.txt2img import put_watermark -from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.ddpm import LatentDiffusion from utils import replace_module, getModelSize @@ -36,7 +36,7 @@ def load_model_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) + model = LatentDiffusion(**config.model.get("params", dict())) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py index d6e6387a9a3b..993c67b0e6f7 100644 --- a/examples/images/diffusion/scripts/inpaint.py +++ b/examples/images/diffusion/scripts/inpaint.py @@ -4,7 +4,7 @@ from tqdm import tqdm import numpy as np import torch -from main import instantiate_from_config +from ldm.models.diffusion.ddpm import LatentgDiffusion from ldm.models.diffusion.ddim import DDIMSampler @@ -57,7 +57,7 @@ def make_batch(image, mask, device): print(f"Found {len(masks)} inputs.") config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") - model = instantiate_from_config(config.model) + model = LatentDiffusion(**config.model.get("params", dict())) model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py index e6eaaecab53e..66d9aa57de66 100644 --- a/examples/images/diffusion/scripts/knn2img.py +++ b/examples/images/diffusion/scripts/knn2img.py @@ -13,9 +13,10 @@ import time from multiprocessing import cpu_count -from ldm.util import instantiate_from_config, parallel_data_prefetch +from ldm.util import parallel_data_prefetch from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder DATABASES = [ @@ -44,7 +45,7 @@ def load_model_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) + model = LatentDiffusion(**config.model.get("params", dict())) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py index 876fe3c3642f..a25965ef74de 100644 --- a/examples/images/diffusion/scripts/sample_diffusion.py +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -8,7 +8,6 @@ from PIL import Image from ldm.models.diffusion.ddim import DDIMSampler -from ldm.util import instantiate_from_config rescale = lambda x: (x + 1.) / 2. @@ -218,7 +217,7 @@ def get_parser(): def load_model_from_config(config, sd): - model = instantiate_from_config(config) + model = LatentDiffusion(**config.get("params", dict())) model.load_state_dict(sd,strict=False) model.cuda() model.eval() diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py index a32e66d44cf2..a157d186d6e7 100644 --- a/examples/images/diffusion/scripts/tests/test_checkpoint.py +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -9,7 +9,7 @@ import torch from ldm.util import instantiate_from_config from main import get_parser - +from ldm.modules.diffusionmodules.openaimodel import UNetModel if __name__ == "__main__": with torch.no_grad(): yaml_path = "../../train_colossalai.yaml" @@ -17,7 +17,7 @@ config = f.read() base_config = yaml.load(config, Loader=yaml.FullLoader) unet_config = base_config['model']['params']['unet_config'] - diffusion_model = instantiate_from_config(unet_config).to("cuda:0") + diffusion_model = UNetModel(**unet_config.get("params", dict())).to("cuda:0") pipe = StableDiffusionPipeline.from_pretrained( "/data/scratch/diffuser/stable-diffusion-v1-4" diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index 364ebac6c67b..b198430f6d1c 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -16,9 +16,9 @@ from contextlib import nullcontext from imwatermark import WatermarkEncoder -from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler +from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.models.diffusion.dpm_solver import DPMSolverSampler from utils import replace_module, getModelSize @@ -35,7 +35,7 @@ def load_model_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) + model = LatentDiffusion(**config.model.get("params", dict())) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") From 891b8e7fac993a7fee2dccb5a54ec18f6110d5e0 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Thu, 6 Apr 2023 18:08:16 +0800 Subject: [PATCH 23/27] [chat] fix stage3 PPO sample sh command (#3477) --- applications/Chat/examples/train_prompts.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Chat/examples/train_prompts.sh b/applications/Chat/examples/train_prompts.sh index db73ac8e8e85..b750cf3581a6 100755 --- a/applications/Chat/examples/train_prompts.sh +++ b/applications/Chat/examples/train_prompts.sh @@ -15,4 +15,4 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() { set_n_least_used_CUDA_VISIBLE_DEVICES 2 -torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2 +torchrun --standalone --nproc_per_node=2 train_prompts.py --prompt_path /path/to/data.json --strategy colossalai_zero2 From fb8fae6f2905173b0136031751d0c520c11e7b10 Mon Sep 17 00:00:00 2001 From: NatalieC323 <127177614+NatalieC323@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:22:52 +0800 Subject: [PATCH 24/27] Revert "[dreambooth] fixing the incompatibity in requirements.txt (#3190) (#3378)" (#3481) --- .../Teyvat/train_colossalai_teyvat.yaml | 16 ++-- .../diffusion/configs/train_colossalai.yaml | 22 ++--- .../configs/train_colossalai_cifar10.yaml | 16 ++-- .../images/diffusion/configs/train_ddp.yaml | 14 +-- .../diffusion/ldm/models/autoencoder.py | 5 +- .../ldm/models/diffusion/classifier.py | 9 +- .../diffusion/ldm/models/diffusion/ddpm.py | 22 ++--- examples/images/diffusion/main.py | 94 +++++++------------ examples/images/diffusion/scripts/img2img.py | 4 +- examples/images/diffusion/scripts/inpaint.py | 4 +- examples/images/diffusion/scripts/knn2img.py | 5 +- .../diffusion/scripts/sample_diffusion.py | 3 +- .../scripts/tests/test_checkpoint.py | 4 +- examples/images/diffusion/scripts/txt2img.py | 4 +- 14 files changed, 98 insertions(+), 124 deletions(-) diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml index fe883cdfd7f8..ff0f4c5a0463 100644 --- a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml +++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - #target: ldm.models.diffusion.ddpm.LatentDiffusion + target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -20,7 +20,7 @@ model: use_ema: False scheduler_config: # 10000 warmup steps - #target: ldm.lr_scheduler.LambdaLinearScheduler + target: ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -30,7 +30,7 @@ model: unet_config: - #target: ldm.modules.diffusionmodules.openaimodel.UNetModel + target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -49,7 +49,7 @@ model: legacy: False first_stage_config: - #target: ldm.models.autoencoder.AutoencoderKL + target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -73,13 +73,13 @@ model: target: torch.nn.Identity cond_stage_config: - #target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - #target: main.DataModuleFromConfig + target: main.DataModuleFromConfig params: batch_size: 16 num_workers: 4 @@ -105,7 +105,7 @@ lightning: precision: 16 auto_select_gpus: False strategy: - #target: strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: use_chunk: True enable_distributed_storage: True @@ -120,7 +120,7 @@ lightning: logger_config: wandb: - #target: loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml index 388ab2e8ff94..88432e978a0f 100644 --- a/examples/images/diffusion/configs/train_colossalai.yaml +++ b/examples/images/diffusion/configs/train_colossalai.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - #target: ldm.models.diffusion.ddpm.LatentDiffusion + target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -19,7 +19,7 @@ model: use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps - #target: ldm.lr_scheduler.LambdaLinearScheduler + target: ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -29,7 +29,7 @@ model: unet_config: - #target: ldm.modules.diffusionmodules.openaimodel.UNetModel + target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -48,7 +48,7 @@ model: legacy: False first_stage_config: - #target: ldm.models.autoencoder.AutoencoderKL + target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -69,16 +69,16 @@ model: attn_resolutions: [] dropout: 0.0 lossconfig: - #target: torch.nn.Identity + target: torch.nn.Identity cond_stage_config: - #target: #ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - #target: #main.DataModuleFromConfig + target: main.DataModuleFromConfig params: batch_size: 128 wrap: False @@ -88,20 +88,20 @@ data: train: target: ldm.data.base.Txt2ImgIterableBaseDataset params: - file_path: /data/scratch/diffuser/laion_part0/ + file_path: # YOUR DATASET_PATH world_size: 1 rank: 0 lightning: trainer: accelerator: 'gpu' - devices: 2 + devices: 8 log_gpu_memory: all max_epochs: 2 precision: 16 auto_select_gpus: False strategy: - #target: #strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: use_chunk: True enable_distributed_storage: True @@ -116,7 +116,7 @@ lightning: logger_config: wandb: - #target: #loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/tmp/diff_log/" diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml index 1331f96e34d6..0ba06f832178 100644 --- a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml +++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - #target: ldm.models.diffusion.ddpm.LatentDiffusion + target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -19,7 +19,7 @@ model: use_ema: False # we set this to false because this is an inference only config scheduler_config: # 10000 warmup steps - #target: ldm.lr_scheduler.LambdaLinearScheduler + target: ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases @@ -29,7 +29,7 @@ model: unet_config: - #target: ldm.modules.diffusionmodules.openaimodel.UNetModel + target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -48,7 +48,7 @@ model: legacy: False first_stage_config: - #target: ldm.models.autoencoder.AutoencoderKL + target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -69,16 +69,16 @@ model: attn_resolutions: [] dropout: 0.0 lossconfig: - #target: torch.nn.Identity + target: torch.nn.Identity cond_stage_config: - #target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - #target: main.DataModuleFromConfig + target: main.DataModuleFromConfig params: batch_size: 4 num_workers: 4 @@ -105,7 +105,7 @@ lightning: precision: 16 auto_select_gpus: False strategy: - #target: strategies.ColossalAIStrategy + target: strategies.ColossalAIStrategy params: use_chunk: True enable_distributed_storage: True diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index df591f33d5fd..a63df887e719 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -1,6 +1,6 @@ model: base_learning_rate: 1.0e-4 - #target: ldm.models.diffusion.ddpm.LatentDiffusion + target: ldm.models.diffusion.ddpm.LatentDiffusion params: parameterization: "v" linear_start: 0.00085 @@ -29,7 +29,7 @@ model: unet_config: - #target: ldm.modules.diffusionmodules.openaimodel.UNetModel + target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: use_checkpoint: True use_fp16: True @@ -48,7 +48,7 @@ model: legacy: False first_stage_config: - #target: ldm.models.autoencoder.AutoencoderKL + target: ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss @@ -72,13 +72,13 @@ model: target: torch.nn.Identity cond_stage_config: - #target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder params: freeze: True layer: "penultimate" data: - #target: main.DataModuleFromConfig + target: main.DataModuleFromConfig params: batch_size: 128 # num_workwers should be 2 * batch_size, and the total num less than 1024 @@ -100,7 +100,7 @@ lightning: precision: 16 auto_select_gpus: False strategy: - #target: strategies.DDPStrategy + target: strategies.DDPStrategy params: find_unused_parameters: False log_every_n_steps: 2 @@ -111,7 +111,7 @@ lightning: logger_config: wandb: - #target: loggers.WandbLogger + target: loggers.WandbLogger params: name: nowname save_dir: "/data2/tmp/diff_log/" diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index 145ccf6fb271..b1bd8377835b 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -6,10 +6,11 @@ import torch.nn.functional as F from contextlib import contextmanager -from torch.nn import Identity from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config from ldm.modules.ema import LitEma @@ -31,7 +32,7 @@ def __init__(self, self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) - self.loss = Identity(**lossconfig.get("params", dict())) + self.loss = instantiate_from_config(lossconfig) assert ddconfig["double_z"] self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py index 3cf12f093bea..612a8371bf20 100644 --- a/examples/images/diffusion/ldm/models/diffusion/classifier.py +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -9,10 +9,9 @@ from einops import rearrange from glob import glob from natsort import natsorted -from ldm.models.diffusion.ddpm import LatentDiffusion -from ldm.lr_scheduler import LambdaLinearScheduler + from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config __models__ = { 'class_label': EncoderUNetModel, @@ -87,7 +86,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): print(f"Unexpected Keys: {unexpected}") def load_diffusion(self): - model = LatentDiffusion(**self.diffusion_config.get('params',dict())) + model = instantiate_from_config(self.diffusion_config) self.diffusion_model = model.eval() self.diffusion_model.train = disabled_train for param in self.diffusion_model.parameters(): @@ -222,7 +221,7 @@ def configure_optimizers(self): optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) if self.use_scheduler: - scheduler = LambdaLinearScheduler(**self.scheduler_config.get('params',dict())) + scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py index 11de828732ea..b7315b048c66 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py @@ -22,7 +22,6 @@ from functools import partial from einops import rearrange, repeat -from ldm.lr_scheduler import LambdaLinearScheduler from ldm.models.autoencoder import * from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage from ldm.models.diffusion.ddim import * @@ -30,10 +29,9 @@ from ldm.modules.diffusionmodules.model import * from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model from ldm.modules.diffusionmodules.openaimodel import * -from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d, UNetModel +from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl -from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from ldm.modules.ema import LitEma from ldm.modules.encoders.modules import * from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat @@ -41,7 +39,6 @@ from torch.optim.lr_scheduler import LambdaLR from torchvision.utils import make_grid from tqdm import tqdm -from ldm.modules.midas.api import MiDaSInference __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} @@ -693,7 +690,7 @@ def register_schedule(self, self.make_cond_schedule() def instantiate_first_stage(self, config): - model = AutoencoderKL(**config.get("params", dict())) + model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): @@ -709,7 +706,7 @@ def instantiate_cond_stage(self, config): self.cond_stage_model = None # self.be_unconditional = True else: - model = FrozenOpenCLIPEmbedder(**config.get("params", dict())) + model = instantiate_from_config(config) self.cond_stage_model = model.eval() self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): @@ -717,7 +714,7 @@ def instantiate_cond_stage(self, config): else: assert config != '__is_first_stage__' assert config != '__is_unconditional__' - model = FrozenOpenCLIPEmbedder(**config.get("params", dict())) + model = instantiate_from_config(config) self.cond_stage_model = model def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): @@ -1482,7 +1479,8 @@ def configure_optimizers(self): # opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: - scheduler = LambdaLinearScheduler(**self.scheduler_config.get("params", dict())) + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) rank_zero_info("Setting up LambdaLR scheduler...") scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}] @@ -1504,7 +1502,7 @@ class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): super().__init__() self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) - self.diffusion_model = UNetModel(**diff_model_config.get("params", dict())) + self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] @@ -1553,7 +1551,7 @@ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key= self.noise_level_key = noise_level_key def instantiate_low_stage(self, config): - model = ImageConcatWithNoiseAugmentation(**config.get("params", dict())) + model = instantiate_from_config(config) self.low_scale_model = model.eval() self.low_scale_model.train = disabled_train for param in self.low_scale_model.parameters(): @@ -1935,7 +1933,7 @@ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): super().__init__(concat_keys=concat_keys, *args, **kwargs) - self.depth_model = MiDaSInference(**depth_stage_config.get("params", dict())) + self.depth_model = instantiate_from_config(depth_stage_config) self.depth_stage_key = concat_keys[0] @torch.no_grad() @@ -2008,7 +2006,7 @@ def __init__(self, self.low_scale_key = low_scale_key def instantiate_low_stage(self, config): - model = ImageConcatWithNoiseAugmentation(**config.get("params", dict())) + model = instantiate_from_config(config) self.low_scale_model = model.eval() self.low_scale_model.train = disabled_train for param in self.low_scale_model.parameters(): diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py index aeed6d5566f5..91b809d5a65c 100644 --- a/examples/images/diffusion/main.py +++ b/examples/images/diffusion/main.py @@ -23,21 +23,19 @@ from PIL import Image from prefetch_generator import BackgroundGenerator from torch.utils.data import DataLoader, Dataset, Subset, random_split -from ldm.models.diffusion.ddpm import LatentDiffusion -#try: -from lightning.pytorch import seed_everything -from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.trainer import Trainer -from lightning.pytorch.utilities import rank_zero_info, rank_zero_only -from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger -from lightning.pytorch.strategies import ColossalAIStrategy,DDPStrategy -LIGHTNING_PACK_NAME = "lightning.pytorch." -# #except: -# from pytorch_lightning import seed_everything -# from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint -# from pytorch_lightning.trainer import Trainer -# from pytorch_lightning.utilities import rank_zero_info, rank_zero_only -# LIGHTNING_PACK_NAME = "pytorch_lightning." + +try: + from lightning.pytorch import seed_everything + from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from lightning.pytorch.trainer import Trainer + from lightning.pytorch.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "lightning.pytorch." +except: + from pytorch_lightning import seed_everything + from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint + from pytorch_lightning.trainer import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + LIGHTNING_PACK_NAME = "pytorch_lightning." from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config @@ -577,7 +575,7 @@ def on_train_epoch_end(self, trainer, pl_module): # target: path to test dataset # params: # key: value - # lightning: (optional, has same defaults and can be specified on cmdline) + # lightning: (optional, has sane defaults and can be specified on cmdline) # trainer: # additional arguments to trainer # logger: @@ -655,7 +653,7 @@ def on_train_epoch_end(self, trainer, pl_module): # Sets the seed for the random number generator to ensure reproducibility seed_everything(opt.seed) - # Intinalize and save configuration using the OmegaConf library. + # Intinalize and save configuratioon using teh OmegaConf library. try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] @@ -689,7 +687,7 @@ def on_train_epoch_end(self, trainer, pl_module): config.model["params"].update({"ckpt": ckpt}) rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"])) - model = LatentDiffusion(**config.model.get("params", dict())) + model = instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() @@ -698,7 +696,7 @@ def on_train_epoch_end(self, trainer, pl_module): # These loggers are specified as targets in the dictionary, along with the configuration settings specific to each logger. default_logger_cfgs = { "wandb": { - #"target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", + "target": LIGHTNING_PACK_NAME + "loggers.WandbLogger", "params": { "name": nowname, "save_dir": logdir, @@ -707,7 +705,7 @@ def on_train_epoch_end(self, trainer, pl_module): } }, "tensorboard": { - #"target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", + "target": LIGHTNING_PACK_NAME + "loggers.TensorBoardLogger", "params": { "save_dir": logdir, "name": "diff_tb", @@ -720,32 +718,30 @@ def on_train_epoch_end(self, trainer, pl_module): default_logger_cfg = default_logger_cfgs["tensorboard"] if "logger" in lightning_config: logger_cfg = lightning_config.logger - logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) - trainer_kwargs["logger"] = WandbLogger(**logger_cfg.get("params", dict())) else: logger_cfg = default_logger_cfg - logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) - trainer_kwargs["logger"] = TensorBoardLogger(**logger_cfg.get("params", dict())) - + logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) # config the strategy, defualt is ddp if "strategy" in trainer_config: strategy_cfg = trainer_config["strategy"] - trainer_kwargs["strategy"] = ColossalAIStrategy(**strategy_cfg.get("params", dict())) + strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"] else: strategy_cfg = { - #"target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", + "target": LIGHTNING_PACK_NAME + "strategies.DDPStrategy", "params": { "find_unused_parameters": False } } - trainer_kwargs["strategy"] = DDPStrategy(**strategy_cfg.get("params", dict())) + + trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg) # Set up ModelCheckpoint callback to save best models # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - #"target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", + "target": LIGHTNING_PACK_NAME + "callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{epoch:06}", @@ -763,13 +759,13 @@ def on_train_epoch_end(self, trainer, pl_module): modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) if version.parse(pl.__version__) < version.parse('1.4.0'): - trainer_kwargs["checkpoint_callback"] = ModelCheckpoint(**modelckpt_cfg.get("params", dict())) + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) # Set up various callbacks, including logging, learning rate monitoring, and CUDA management # add callback which sets up log directory default_callbacks_cfg = { "setup_callback": { # callback to set up the training - #"target": "main.SetupCallback", + "target": "main.SetupCallback", "params": { "resume": opt.resume, # resume training if applicable "now": now, @@ -781,7 +777,7 @@ def on_train_epoch_end(self, trainer, pl_module): } }, "image_logger": { # callback to log image data - #"target": "main.ImageLogger", + "target": "main.ImageLogger", "params": { "batch_frequency": 750, # how frequently to log images "max_images": 4, # maximum number of images to log @@ -789,14 +785,14 @@ def on_train_epoch_end(self, trainer, pl_module): } }, "learning_rate_logger": { # callback to log learning rate - #"target": "main.LearningRateMonitor", + "target": "main.LearningRateMonitor", "params": { "logging_interval": "step", # logging frequency (either 'step' or 'epoch') # "log_momentum": True # whether to log momentum (currently commented out) } }, "cuda_callback": { # callback to handle CUDA-related operations - #"target": "main.CUDACallback" + "target": "main.CUDACallback" }, } @@ -814,7 +810,7 @@ def on_train_epoch_end(self, trainer, pl_module): 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') default_metrics_over_trainsteps_ckpt_dict = { 'metrics_over_trainsteps_checkpoint': { - #"target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', + "target": LIGHTNING_PACK_NAME + 'callbacks.ModelCheckpoint', 'params': { "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), "filename": "{epoch:06}-{step:09}", @@ -829,35 +825,15 @@ def on_train_epoch_end(self, trainer, pl_module): # Merge the default callbacks configuration with the specified callbacks configuration, and instantiate the callbacks callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) - - #Instantiate items according to the configs - trainer_kwargs.setdefault("callbacks", []) - if "setup_callback" in callbacks_cfg: - setup_callback_config = callbacks_cfg["setup_callback"] - trainer_kwargs["callbacks"].append(SetupCallback(**setup_callback_config.get("params", dict()))) + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] - if "image_logger" in callbacks_cfg: - image_logger_config = callbacks_cfg["image_logger"] - trainer_kwargs["callbacks"].append(ImageLogger(**image_logger_config.get("params", dict()))) - - if "learning_rate_logger" in callbacks_cfg: - learning_rate_logger_config = callbacks_cfg["learning_rate_logger"] - trainer_kwargs["callbacks"].append(LearningRateMonitor(**learning_rate_logger_config.get("params", dict()))) - - if "cuda_callback" in callbacks_cfg: - cuda_callback_config = callbacks_cfg["cuda_callback"] - trainer_kwargs["callbacks"].append(CUDACallback(**cuda_callback_config.get("params", dict()))) - - if "metrics_over_trainsteps_checkpoint" in callbacks_cfg: - metrics_over_config = callbacks_cfg['metrics_over_trainsteps_checkpoint'] - trainer_kwargs["callbacks"].append(ModelCheckpoint(**metrics_over_config.get("params", dict()))) - #trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + # Create a Trainer object with the specified command-line arguments and keyword arguments, and set the log directory trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer.logdir = logdir - + # Create a data module based on the configuration file - data = DataModuleFromConfig(**config.data.get("params", dict())) + data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. # lightning still takes care of proper multiprocessing though diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index a3011005c16a..877538d4733d 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -20,8 +20,8 @@ from scripts.txt2img import put_watermark +from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.ddpm import LatentDiffusion from utils import replace_module, getModelSize @@ -36,7 +36,7 @@ def load_model_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = LatentDiffusion(**config.model.get("params", dict())) + model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py index 993c67b0e6f7..d6e6387a9a3b 100644 --- a/examples/images/diffusion/scripts/inpaint.py +++ b/examples/images/diffusion/scripts/inpaint.py @@ -4,7 +4,7 @@ from tqdm import tqdm import numpy as np import torch -from ldm.models.diffusion.ddpm import LatentgDiffusion +from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler @@ -57,7 +57,7 @@ def make_batch(image, mask, device): print(f"Found {len(masks)} inputs.") config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") - model = LatentDiffusion(**config.model.get("params", dict())) + model = instantiate_from_config(config.model) model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py index 66d9aa57de66..e6eaaecab53e 100644 --- a/examples/images/diffusion/scripts/knn2img.py +++ b/examples/images/diffusion/scripts/knn2img.py @@ -13,10 +13,9 @@ import time from multiprocessing import cpu_count -from ldm.util import parallel_data_prefetch +from ldm.util import instantiate_from_config, parallel_data_prefetch from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder DATABASES = [ @@ -45,7 +44,7 @@ def load_model_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = LatentDiffusion(**config.model.get("params", dict())) + model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py index a25965ef74de..876fe3c3642f 100644 --- a/examples/images/diffusion/scripts/sample_diffusion.py +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -8,6 +8,7 @@ from PIL import Image from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config rescale = lambda x: (x + 1.) / 2. @@ -217,7 +218,7 @@ def get_parser(): def load_model_from_config(config, sd): - model = LatentDiffusion(**config.get("params", dict())) + model = instantiate_from_config(config) model.load_state_dict(sd,strict=False) model.cuda() model.eval() diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py index a157d186d6e7..a32e66d44cf2 100644 --- a/examples/images/diffusion/scripts/tests/test_checkpoint.py +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -9,7 +9,7 @@ import torch from ldm.util import instantiate_from_config from main import get_parser -from ldm.modules.diffusionmodules.openaimodel import UNetModel + if __name__ == "__main__": with torch.no_grad(): yaml_path = "../../train_colossalai.yaml" @@ -17,7 +17,7 @@ config = f.read() base_config = yaml.load(config, Loader=yaml.FullLoader) unet_config = base_config['model']['params']['unet_config'] - diffusion_model = UNetModel(**unet_config.get("params", dict())).to("cuda:0") + diffusion_model = instantiate_from_config(unet_config).to("cuda:0") pipe = StableDiffusionPipeline.from_pretrained( "/data/scratch/diffuser/stable-diffusion-v1-4" diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index b198430f6d1c..364ebac6c67b 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -16,9 +16,9 @@ from contextlib import nullcontext from imwatermark import WatermarkEncoder +from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.models.diffusion.dpm_solver import DPMSolverSampler from utils import replace_module, getModelSize @@ -35,7 +35,7 @@ def load_model_from_config(config, ckpt, verbose=False): if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] - model = LatentDiffusion(**config.model.get("params", dict())) + model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") From ab5fd127e393f7298c6497fd638d0f89c53a9452 Mon Sep 17 00:00:00 2001 From: mandoxzhang <111039218+mandoxzhang@users.noreply.github.com> Date: Fri, 7 Apr 2023 10:34:51 +0800 Subject: [PATCH 25/27] [example] update roberta with newer ColossalAI (#3472) * update roberta example * update roberta example --- examples/language/roberta/README.md | 16 +- .../roberta/configs/colossalai_ddp.py | 9 -- .../roberta/configs/colossalai_zero.py | 37 ----- .../roberta/preprocessing/get_mask.py | 7 +- .../roberta/preprocessing/sentence_split.py | 19 +-- .../roberta/preprocessing/tokenize_mask.py | 12 +- .../language/roberta/pretraining/arguments.py | 24 +++ .../roberta/pretraining/evaluation.py | 12 +- .../roberta/pretraining/pretrain_utils.py | 4 +- .../roberta/pretraining/run_pretrain.sh | 2 - .../pretraining/run_pretrain_resume.sh | 2 - .../roberta/pretraining/run_pretraining.py | 142 ++++++++++++++++-- examples/language/roberta/requirements.txt | 5 + 13 files changed, 188 insertions(+), 103 deletions(-) delete mode 100644 examples/language/roberta/configs/colossalai_ddp.py delete mode 100644 examples/language/roberta/configs/colossalai_zero.py diff --git a/examples/language/roberta/README.md b/examples/language/roberta/README.md index a42b1935dd85..0e080d00981a 100644 --- a/examples/language/roberta/README.md +++ b/examples/language/roberta/README.md @@ -1,9 +1,9 @@ # Introduction -This repo introduce how to pretrain a chinese roberta-large from scratch, including preprocessing, pretraining, finetune. The repo can help you quickly train a high-quality bert. +This example introduce how to pretrain roberta from scratch, including preprocessing, pretraining, finetune. The example can help you quickly train a high-quality roberta. ## 0. Prerequisite - Install Colossal-AI -- Editing the port from /etc/ssh/sshd_config and /etc/ssh/ssh_config, every host expose the same ssh port of server and client. If you are a root user, you also set the **PermitRootLogin** from /etc/ssh/sshd_config to "yes" +- Editing the port from `/etc/ssh/sshd_config` and `/etc/ssh/ssh_config`, every host expose the same ssh port of server and client. If you are a root user, you also set the **PermitRootLogin** from `/etc/ssh/sshd_config` to "yes" - Ensure that each host can log in to each other without password. If you have n hosts, need to execute n2 times ``` @@ -33,7 +33,7 @@ service ssh restart ```bash cd preprocessing ``` -following the `README.md`, preprocess original corpus to h5py+numpy +following the `README.md`, preprocess original corpus to h5py plus numpy ## 2. Pretrain @@ -47,12 +47,4 @@ following the `README.md`, load the h5py generated by preprocess of step 1 to pr The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from Hugging Face to finetune downstream application. ## Contributors -The repo is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution! - -``` -@misc{ - title={A simple Chinese RoBERTa Example for Whole Word Masked}, - author={Yehua Zhang, Chen Zhang}, - year={2022} -} -``` +The example is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution! diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py deleted file mode 100644 index 3146ffc45eef..000000000000 --- a/examples/language/roberta/configs/colossalai_ddp.py +++ /dev/null @@ -1,9 +0,0 @@ -from colossalai.nn.optimizer import FusedAdam - -try: - from colossalai.zero.shard_utils import TensorShardStrategy -except ImportError: - # colossalai > 0.2.8 - from colossalai.zero.legacy import TensorShardStrategy - -clip_grad_norm = 1.0 diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py deleted file mode 100644 index bae4c723ccc8..000000000000 --- a/examples/language/roberta/configs/colossalai_zero.py +++ /dev/null @@ -1,37 +0,0 @@ -from colossalai.nn.optimizer import FusedAdam - -try: - from colossalai.zero.shard_utils import TensorShardStrategy -except ImportError: - # colossalai > 0.2.8 - from colossalai.zero.legacy import TensorShardStrategy - -# fp16 = dict( -# mode=AMP_TYPE.TORCH, -# ) - -# seed = 2 -zero = dict(model_config=dict(shard_strategy=TensorShardStrategy(), - reduce_scatter_bucket_size_mb=25, - fp32_reduce_scatter=False, - tensor_placement_policy="cuda", - gradient_predivide_factor=1.0, - reuse_fp16_shard=False), - optimizer_config=dict(gpu_margin_mem_ratio=0.8, - initial_scale=2**5, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale=2**32)) - -# gradient_accumulation = 4 -clip_grad_norm = 1.0 -optimizer = dict( - type=FusedAdam, - lr=0.00015, - weight_decay=1e-2, -) - -# 64433 diff --git a/examples/language/roberta/preprocessing/get_mask.py b/examples/language/roberta/preprocessing/get_mask.py index da297f98e6c9..869ef2cb377c 100644 --- a/examples/language/roberta/preprocessing/get_mask.py +++ b/examples/language/roberta/preprocessing/get_mask.py @@ -163,16 +163,15 @@ def create_masked_lm_predictions(self, tokens): def get_new_segment(self, segment): """ - 输入一句话,返回一句经过处理的话: 为了支持中文全称mask,将被分开的词,将上特殊标记("#"),使得后续处理模块,能够知道哪些字是属于同一个词的。 - :param segment: 一句话 - :return: 一句处理过的话 + Input a sentence, return a processed sentence: In order to support the Chinese whole word mask, the words that are separated will be marked with a special mark ("#"), so that the subsequent processing module can know which words belong to the same word. + :param segment: a sentence """ seq_cws = jieba.lcut(''.join(segment)) seq_cws_dict = {x: 1 for x in seq_cws} new_segment = [] i = 0 while i < len(segment): - if len(self.rec.findall(segment[i])) == 0: # 不是中文的,原文加进去。 + if len(self.rec.findall(segment[i])) == 0: new_segment.append(segment[i]) i += 1 continue diff --git a/examples/language/roberta/preprocessing/sentence_split.py b/examples/language/roberta/preprocessing/sentence_split.py index 231be152b067..f0ed83f90114 100644 --- a/examples/language/roberta/preprocessing/sentence_split.py +++ b/examples/language/roberta/preprocessing/sentence_split.py @@ -10,26 +10,19 @@ import functools def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]: - """ - Args: - document: - flag: Type:str, "all" 中英文标点分句,"zh" 中文标点分句,"en" 英文标点分句 - limit: 默认单句最大长度为510个字符 - Returns: Type:list - """ sent_list = [] try: if flag == "zh": - document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) # 单字符断句符 - document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) # 特殊引号 + document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) + document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) elif flag == "en": - document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) # 英文单字符断句符 - document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # 特殊引号 + document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) + document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # Special quotation marks else: - document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) # 单字符断句符 + document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', - document) # 特殊引号 + document) # Special quotation marks sent_list_ori = document.splitlines() for sent in sent_list_ori: diff --git a/examples/language/roberta/preprocessing/tokenize_mask.py b/examples/language/roberta/preprocessing/tokenize_mask.py index b33871d5d037..76c74868e1fc 100644 --- a/examples/language/roberta/preprocessing/tokenize_mask.py +++ b/examples/language/roberta/preprocessing/tokenize_mask.py @@ -15,8 +15,8 @@ def get_raw_instance(document, max_sequence_length=512): """ - 获取初步的训练实例,将整段按照max_sequence_length切分成多个部分,并以多个处理好的实例的形式返回。 - :param document: 一整段 + Get the initial training instances, split the whole segment into multiple parts according to the max_sequence_length, and return as multiple processed instances. + :param document: document :param max_sequence_length: :return: a list. each element is a sequence of text """ @@ -26,10 +26,9 @@ def get_raw_instance(document, max_sequence_length=512): sizes = [len(seq) for seq in document] result_list = [] - curr_seq = [] # 当前处理的序列 + curr_seq = [] sz_idx = 0 while sz_idx < len(sizes): - # 当前句子加上新的句子,如果长度小于最大限制,则合并当前句子和新句子;否则即超过了最大限制,那么做为一个新的序列加到目标列表中 if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: curr_seq += document[sz_idx] @@ -43,14 +42,13 @@ def get_raw_instance(document, max_sequence_length=512): else: result_list.append(curr_seq) curr_seq = [] - # 对最后一个序列进行处理,如果太短的话,丢弃掉。 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 result_list.append(curr_seq) - # # 计算总共可以得到多少份 # num_instance=int(len(big_list)/max_sequence_length_allowed)+1 # print("num_instance:",num_instance) - # # 切分成多份,添加到列表中 + # result_list=[] # for j in range(num_instance): # index=j*max_sequence_length_allowed diff --git a/examples/language/roberta/pretraining/arguments.py b/examples/language/roberta/pretraining/arguments.py index 3a9370e00b0c..87fa8dd8a8ae 100644 --- a/examples/language/roberta/pretraining/arguments.py +++ b/examples/language/roberta/pretraining/arguments.py @@ -6,6 +6,30 @@ def parse_args(): parser = colossalai.get_default_parser() + + parser.add_argument( + "--distplan", + type=str, + default='CAI_Gemini', + help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + action='store_true', + help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) parser.add_argument( '--lr', diff --git a/examples/language/roberta/pretraining/evaluation.py b/examples/language/roberta/pretraining/evaluation.py index 83f94082f6c0..8fc019c121ac 100644 --- a/examples/language/roberta/pretraining/evaluation.py +++ b/examples/language/roberta/pretraining/evaluation.py @@ -5,11 +5,11 @@ from utils.global_vars import get_timers, get_tensorboard_writer from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider -def evaluate(engine, args, logger, global_step): +def evaluate(model, args, logger, global_step, criterion): evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True) start_shard = 0 - engine.eval() + model.eval() timers = get_timers() eval_step = 0 eval_loss = 0 @@ -39,9 +39,9 @@ def evaluate(engine, args, logger, global_step): mlm_label = batch_data[3].cuda() # nsp_label = batch_data[5].cuda() - output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - loss = engine.criterion(output.logits, mlm_label)#prediction_scores + loss = criterion(output.logits, mlm_label)#prediction_scores evaluate_dataset_provider.prefetch_batch() eval_loss += loss.float().item() @@ -67,5 +67,5 @@ def evaluate(engine, args, logger, global_step): logger.info('') evaluate_dataset_provider.release_shard() - engine.train() - return cur_loss + model.train() + return cur_loss \ No newline at end of file diff --git a/examples/language/roberta/pretraining/pretrain_utils.py b/examples/language/roberta/pretraining/pretrain_utils.py index ba17b0f5ee09..54fc2affe632 100644 --- a/examples/language/roberta/pretraining/pretrain_utils.py +++ b/examples/language/roberta/pretraining/pretrain_utils.py @@ -5,7 +5,7 @@ from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig from transformers import GPT2Config, GPT2LMHeadModel from transformers import AutoTokenizer, AutoModelForMaskedLM -from colossalai.nn.optimizer import FusedAdam +from colossalai.nn.optimizer import FusedAdam, HybridAdam from torch.optim import AdamW from colossalai.core import global_context as gpc import torch @@ -83,7 +83,7 @@ def get_optimizer(model, lr): 'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0 }] - optimizer = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, betas=[0.9, 0.95]) return optimizer diff --git a/examples/language/roberta/pretraining/run_pretrain.sh b/examples/language/roberta/pretraining/run_pretrain.sh index 144cd0ab96fd..38fdefe0af8a 100644 --- a/examples/language/roberta/pretraining/run_pretrain.sh +++ b/examples/language/roberta/pretraining/run_pretrain.sh @@ -7,7 +7,6 @@ tensorboard_path="$root_path/tensorboard" log_path="$root_path/exp_log" ckpt_path="$root_path/ckpt" -colossal_config="$root_path/../configs/colossalai_ddp.py" mkdir -p $tensorboard_path mkdir -p $log_path @@ -32,7 +31,6 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ --tensorboard_path $tensorboard_path \ --log_path $log_path \ --ckpt_path $ckpt_path \ - --colossal_config $colossal_config \ --log_interval 50 \ --mlm bert \ --wandb \ diff --git a/examples/language/roberta/pretraining/run_pretrain_resume.sh b/examples/language/roberta/pretraining/run_pretrain_resume.sh index a0704cf7c517..351c98d3e9cb 100644 --- a/examples/language/roberta/pretraining/run_pretrain_resume.sh +++ b/examples/language/roberta/pretraining/run_pretrain_resume.sh @@ -7,7 +7,6 @@ tensorboard_path="$root_path/tensorboard" log_path="$root_path/exp_log" ckpt_path="$root_path/ckpt" -colossal_config="$root_path/../configs/colossalai_ddp.py" mkdir -p $tensorboard_path mkdir -p $log_path @@ -32,7 +31,6 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ --tensorboard_path $tensorboard_path \ --log_path $log_path \ --ckpt_path $ckpt_path \ - --colossal_config $colossal_config \ --log_interval 50 \ --mlm bert \ --wandb \ diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py index eef7bb6ad5cd..6e3cae6ec042 100644 --- a/examples/language/roberta/pretraining/run_pretraining.py +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -4,9 +4,31 @@ from functools import partial import torch +<<<<<<< HEAD +from tqdm import tqdm +import os +import time +from functools import partial +from transformers import AutoTokenizer + +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec + +======= +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 from arguments import parse_args from evaluation import evaluate from loss import LossForPretraining +<<<<<<< HEAD + +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +======= from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt from tqdm import tqdm @@ -27,6 +49,7 @@ from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext from colossalai.zero.legacy.shard_utils import TensorShardStrategy +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 def main(): @@ -36,8 +59,13 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) +<<<<<<< HEAD + # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' + +======= os.environ['CUDA_LAUNCH_BLOCKING'] = '1' +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) if args.vscode_debug: @@ -50,7 +78,11 @@ def main(): args.local_rank = -1 args.log_interval = 1 else: +<<<<<<< HEAD + colossalai.launch_from_torch(config={}) #args.colossal_config +======= colossalai.launch_from_torch(args.colossal_config) # args.colossal_config +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + @@ -61,32 +93,94 @@ def main(): args.tokenizer = tokenizer args.logger = logger set_global_variables(launch_time, args.tensorboard_path) +<<<<<<< HEAD + +======= use_zero = hasattr(gpc.config, 'zero') +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 world_size = torch.distributed.get_world_size() + init_dev = get_current_device() # build model, optimizer and criterion +<<<<<<< HEAD + if args.distplan.startswith("CAI"): + # all param must use the same process group. + world_size = torch.distributed.get_world_size() + shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None + default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None + + if args.shardinit and args.distplan != "CAI_Gemini": + raise RuntimeError("You can only use shardinit with CAI_Gemini") + + # build GPT model + with ColoInitContext(device=get_current_device(), + dtype=torch.half, + default_dist_spec=default_dist_spec, + default_pg=shard_pg): +======= if use_zero: shard_strategy = TensorShardStrategy() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, shard_param=True): +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 config, model, numel = get_model(args, logger) - # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) + + # asign running configurations + gemini_config = None + if args.distplan.startswith("CAI_ZeRO"): + optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True) + elif args.distplan == "CAI_Gemini": + gemini_config = dict(strict_ddp_mode=args.tp_degree == 1, + device=get_current_device(), + placement_policy=args.placement, + pin_memory=True, + hidden_dim=model.config.hidden_size, + search_range_mb=128) + optim_config = dict(gpu_margin_mem_ratio=0.) + else: + raise RuntimeError + + # build a highly optimized gpu/cpu optimizer + optimizer = get_optimizer(model, lr=args.lr) + + if args.distplan == "CAI_ZeRO1": + zero_stage = 1 + elif args.distplan == "CAI_ZeRO2": + zero_stage = 2 + elif args.distplan == "CAI_Gemini": + zero_stage = 3 + else: + raise RuntimeError + + # wrap your model and optimizer + model = zero_model_wrapper(model, zero_stage, gemini_config) + optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config) + + logger.info(get_mem_info(prefix='After init optim, ')) + else: config, model, numel = get_model(args, logger) logger.info("no_zero") + if torch.distributed.get_rank() == 0: os.mkdir(os.path.join(args.ckpt_path, launch_time)) logger.info(f'Model numel: {numel}') get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) +<<<<<<< HEAD + + # 144003367 is is the length of the entire dataset + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) +======= # len(dataloader) steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 total_steps = steps_per_epoch * args.epoch - # build optimizer and lr_scheduler + lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) start_epoch = 0 start_shard = 0 @@ -95,7 +189,6 @@ def main(): assert os.path.exists(args.load_optimizer_lr) o_l_state_dict = torch.load(args.load_optimizer_lr, map_location='cpu') o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 - optimizer = get_optimizer(model, lr=args.lr) optimizer.load_state_dict(o_l_state_dict['optimizer']) # o_l_state_dict['lr_scheduler']['last_epoch'] lr_scheduler = get_lr_scheduler(optimizer, @@ -105,33 +198,38 @@ def main(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") - # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() + # if you want delete the above three code, must move the model to gpu. Because in optimizer.step() lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) start_epoch = o_l_state_dict['epoch'] start_shard = o_l_state_dict['shard'] + 1 # global_step = o_l_state_dict['global_step'] + 1 +<<<<<<< HEAD + logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') +======= logger.info( f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' ) else: optimizer = get_optimizer(model, lr=args.lr) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 - # optimizer = gpc.config.optimizer.pop('type')( - # model.parameters(), **gpc.config.optimizer) - # optimizer = ShardedOptimizerV2(model, optimizer, initial_scale=2**5) criterion = LossForPretraining(config.vocab_size) # build dataloader pretrain_dataset_provider = NvidiaBertDatasetProvider(args) +<<<<<<< HEAD + +======= # initialize with colossalai engine, _, _, lr_scheduelr = colossalai.initialize(model=model, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler) +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 logger.info(get_mem_info(prefix='After init model, ')) best_loss = None @@ -156,9 +254,15 @@ def main(): else: iterator_data = enumerate(dataset_iterator) +<<<<<<< HEAD + model.train() + + for step, batch_data in iterator_data: +======= engine.train() for step, batch_data in iterator_data: +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") @@ -167,18 +271,31 @@ def main(): mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}") # nsp_label = batch_data[5].cuda() +<<<<<<< HEAD + output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) + + loss = criterion(output.logits, mlm_label) +======= output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) loss = engine.criterion(output.logits, mlm_label) +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 pretrain_dataset_provider.prefetch_batch() - engine.backward(loss) + optimizer.backward(loss) train_loss += loss.float().item() # if (step + 1) % args.accumulation_step == 0: +<<<<<<< HEAD + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + +======= engine.step() lr_scheduelr.step() engine.zero_grad() +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 global_step += 1 if global_step % args.log_interval == 0 and global_step != 0 \ @@ -189,7 +306,7 @@ def main(): numel, args, config, elapsed_time, global_step, world_size) cur_loss = train_loss / args.log_interval - current_lr = lr_scheduelr.get_last_lr()[0] + current_lr = lr_scheduler.get_last_lr()[0] log_str = f'| epoch: {epoch} | shard: {shard} | step: {global_step} | lr {current_lr:.7f} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {math.exp(cur_loss):.3f} | TFLOPS: {get_tflops_func(elapsed_time_per_iteration):.3f} or {tflops:.3f}' logger.info(log_str, print_=False) @@ -209,11 +326,18 @@ def main(): logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') logger.info('*' * 100) +<<<<<<< HEAD + eval_loss += evaluate(model, args, logger, global_step, criterion) + save_ckpt(model, optimizer, lr_scheduler, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) + + +======= eval_loss += evaluate(engine, args, logger, global_step) save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) +>>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 eval_loss /= len(os.listdir(args.data_path_prefix)) logger.info( f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' diff --git a/examples/language/roberta/requirements.txt b/examples/language/roberta/requirements.txt index 137a69e80498..d351f362f3f7 100644 --- a/examples/language/roberta/requirements.txt +++ b/examples/language/roberta/requirements.txt @@ -1,2 +1,7 @@ colossalai >= 0.1.12 torch >= 1.8.1 +tqdm +tensorboard +numpy +h5py +wandb \ No newline at end of file From 8f2c55f9c99012e4cbfefa422ab2e91dfece447e Mon Sep 17 00:00:00 2001 From: mandoxzhang <111039218+mandoxzhang@users.noreply.github.com> Date: Fri, 7 Apr 2023 11:33:32 +0800 Subject: [PATCH 26/27] [example] remove redundant texts & update roberta (#3493) * update roberta example * update roberta example * modify conflict & update roberta --- .../roberta/pretraining/run_pretraining.py | 93 ------------------- 1 file changed, 93 deletions(-) diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py index 6e3cae6ec042..a283c44cadbf 100644 --- a/examples/language/roberta/pretraining/run_pretraining.py +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -4,7 +4,6 @@ from functools import partial import torch -<<<<<<< HEAD from tqdm import tqdm import os import time @@ -20,15 +19,9 @@ from colossalai.zero import ZeroOptimizer from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec -======= ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 from arguments import parse_args from evaluation import evaluate from loss import LossForPretraining -<<<<<<< HEAD - -from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider -======= from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt from tqdm import tqdm @@ -37,20 +30,6 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables from utils.logger import Logger -import colossalai -import colossalai.nn as col_nn -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device -from colossalai.zero import ZeroOptimizer -from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager -from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 - def main(): @@ -59,13 +38,8 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) -<<<<<<< HEAD # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' -======= - os.environ['CUDA_LAUNCH_BLOCKING'] = '1' - ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) if args.vscode_debug: @@ -78,11 +52,7 @@ def main(): args.local_rank = -1 args.log_interval = 1 else: -<<<<<<< HEAD colossalai.launch_from_torch(config={}) #args.colossal_config -======= - colossalai.launch_from_torch(args.colossal_config) # args.colossal_config ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 args.local_rank = int(os.environ["LOCAL_RANK"]) logger.info( f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + @@ -93,17 +63,11 @@ def main(): args.tokenizer = tokenizer args.logger = logger set_global_variables(launch_time, args.tensorboard_path) -<<<<<<< HEAD -======= - - use_zero = hasattr(gpc.config, 'zero') ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 world_size = torch.distributed.get_world_size() init_dev = get_current_device() # build model, optimizer and criterion -<<<<<<< HEAD if args.distplan.startswith("CAI"): # all param must use the same process group. world_size = torch.distributed.get_world_size() @@ -118,13 +82,6 @@ def main(): dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): -======= - if use_zero: - shard_strategy = TensorShardStrategy() - with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, - shard_param=True): - ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 config, model, numel = get_model(args, logger) # asign running configurations @@ -170,14 +127,9 @@ def main(): logger.info(f'Model numel: {numel}') get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) -<<<<<<< HEAD # 144003367 is is the length of the entire dataset steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) -======= - # len(dataloader) - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 total_steps = steps_per_epoch * args.epoch lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -204,32 +156,14 @@ def main(): start_epoch = o_l_state_dict['epoch'] start_shard = o_l_state_dict['shard'] + 1 # global_step = o_l_state_dict['global_step'] + 1 -<<<<<<< HEAD logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') -======= - logger.info( - f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' - ) - else: - optimizer = get_optimizer(model, lr=args.lr) - lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 criterion = LossForPretraining(config.vocab_size) # build dataloader pretrain_dataset_provider = NvidiaBertDatasetProvider(args) -<<<<<<< HEAD -======= - # initialize with colossalai - engine, _, _, lr_scheduelr = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) - ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 logger.info(get_mem_info(prefix='After init model, ')) best_loss = None @@ -254,15 +188,9 @@ def main(): else: iterator_data = enumerate(dataset_iterator) -<<<<<<< HEAD model.train() for step, batch_data in iterator_data: -======= - engine.train() - - for step, batch_data in iterator_data: ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") @@ -271,31 +199,18 @@ def main(): mlm_label = batch_data[3].cuda(f"cuda:{torch.cuda.current_device()}") # nsp_label = batch_data[5].cuda() -<<<<<<< HEAD output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) loss = criterion(output.logits, mlm_label) -======= - output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - - loss = engine.criterion(output.logits, mlm_label) ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 pretrain_dataset_provider.prefetch_batch() optimizer.backward(loss) train_loss += loss.float().item() # if (step + 1) % args.accumulation_step == 0: -<<<<<<< HEAD optimizer.step() lr_scheduler.step() optimizer.zero_grad() -======= - engine.step() - lr_scheduelr.step() - engine.zero_grad() - ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 global_step += 1 if global_step % args.log_interval == 0 and global_step != 0 \ @@ -326,18 +241,10 @@ def main(): logger.info(f'epoch {epoch} shard {shard} has cost {timers("shard_time").elapsed() / 60 :.3f} mins') logger.info('*' * 100) -<<<<<<< HEAD eval_loss += evaluate(model, args, logger, global_step, criterion) save_ckpt(model, optimizer, lr_scheduler, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) -======= - eval_loss += evaluate(engine, args, logger, global_step) - save_ckpt(engine.model, optimizer, lr_scheduelr, - os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, - shard, global_step) - ->>>>>>> 52a933e17509c71811e919b165de38cb3d5d6d41 eval_loss /= len(os.listdir(args.data_path_prefix)) logger.info( f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' From a7ca2972810ac784754f0f31e21324687c03b324 Mon Sep 17 00:00:00 2001 From: gongenlei Date: Fri, 7 Apr 2023 11:39:09 +0800 Subject: [PATCH 27/27] [coati] Fix LlamaCritic (#3475) * mv LlamaForCausalLM to LlamaModel * rm unused imports --------- Co-authored-by: gongenlei --- applications/Chat/coati/models/llama/llama_critic.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/applications/Chat/coati/models/llama/llama_critic.py b/applications/Chat/coati/models/llama/llama_critic.py index cd565031e112..dd9e5e7bfa1a 100644 --- a/applications/Chat/coati/models/llama/llama_critic.py +++ b/applications/Chat/coati/models/llama/llama_critic.py @@ -1,8 +1,7 @@ from typing import Optional -import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM +from transformers import LlamaConfig, LlamaModel from ..base import Critic @@ -28,11 +27,11 @@ def __init__(self, **kwargs) -> None: if pretrained is not None: - model = LlamaForCausalLM.from_pretrained(pretrained) + model = LlamaModel.from_pretrained(pretrained) elif config is not None: - model = LlamaForCausalLM(config) + model = LlamaModel(config) else: - model = LlamaForCausalLM(LlamaConfig()) + model = LlamaModel(LlamaConfig()) if checkpoint: model.gradient_checkpointing_enable()