diff --git a/applications/Chat/README.md b/applications/Chat/README.md index 2a9c916d45c9..9ba831973b6c 100644 --- a/applications/Chat/README.md +++ b/applications/Chat/README.md @@ -59,7 +59,7 @@ The Coati package provides a unified large language model framework that has imp Image source: https://openai.com/blog/chatgpt -**As Colossa-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.** +**As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.** More details can be found in the latest news. diff --git a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py index a991e8558aee..7a47624f74d8 100644 --- a/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py +++ b/applications/Chat/benchmarks/benchmark_opt_lora_dummy.py @@ -101,6 +101,11 @@ def main(args): initial_model = deepcopy(actor).cuda().half() reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda().half() + if args.use_kernels: + from coati.kernels import convert_to_xformer_model + actor, critic, initial_model, reward_model = map(convert_to_xformer_model, + (actor, critic, initial_model, reward_model)) + actor_numel = get_model_numel(actor, strategy) critic_numel = get_model_numel(critic, strategy) initial_model_numel = get_model_numel(initial_model, strategy) @@ -184,5 +189,6 @@ def main(args): parser.add_argument('--lora_rank', type=int, default=0) parser.add_argument('--cuda_mem_frac', type=float, default=1.0) parser.add_argument('--offload_inference_models', action='store_true', default=False) + parser.add_argument('--use_kernels', action='store_true', default=False) args = parser.parse_args() main(args) diff --git a/applications/Chat/coati/kernels/__init__.py b/applications/Chat/coati/kernels/__init__.py new file mode 100644 index 000000000000..230eedf7ecba 
--- /dev/null +++ b/applications/Chat/coati/kernels/__init__.py @@ -0,0 +1,6 @@ +from .wrapper import convert_to_xformer_model, recover_from_xformer_model + +__all__ = [ + 'convert_to_xformer_model', + 'recover_from_xformer_model', +] diff --git a/applications/Chat/coati/kernels/opt_attn.py b/applications/Chat/coati/kernels/opt_attn.py new file mode 100644 index 000000000000..c10f341e94a3 --- /dev/null +++ b/applications/Chat/coati/kernels/opt_attn.py @@ -0,0 +1,87 @@ +from typing import Optional, Tuple + +import torch +import xformers.ops as xops +from torch import Tensor +from transformers.models.opt.modeling_opt import OPTAttention + + +# This is modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py +class XOPTAttention(OPTAttention): + # def _shape(self, tensor: Tensor, seq_len: int, bsz: int): + # return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).contiguous() + + def forward( + self, + hidden_states: Tensor, + key_value_states: Optional[Tensor] = None, + past_key_value: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + layer_head_mask: Optional[Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]: + if not self.training: + return super().forward(hidden_states, key_value_states, past_key_value, attention_mask, layer_head_mask, + output_attentions) + """Input shape: Batch x Time x Channel""" + assert layer_head_mask is None, 'Xformers attention does not support layer_head_mask' + assert not output_attentions, 'Xformers attention does not support output_attentions' + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # 
reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_states = self._shape(query_states, tgt_len, bsz).transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = xops.memory_efficient_attention(query_states, + key_states, + value_states, + attn_bias=xops.LowerTriangularMask(), + p=self.dropout if self.training else 0.0, + scale=self.scaling) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. 
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value diff --git a/applications/Chat/coati/kernels/wrapper.py b/applications/Chat/coati/kernels/wrapper.py new file mode 100644 index 000000000000..c55bda600230 --- /dev/null +++ b/applications/Chat/coati/kernels/wrapper.py @@ -0,0 +1,18 @@ +import torch.nn as nn +from transformers.models.opt.modeling_opt import OPTAttention + +from .opt_attn import XOPTAttention + + +def convert_to_xformer_model(model: nn.Module) -> nn.Module: + for module in model.modules(): + if isinstance(module, OPTAttention): + module.__class__ = XOPTAttention + return model + + +def recover_from_xformer_model(model: nn.Module) -> nn.Module: + for module in model.modules(): + if isinstance(module, XOPTAttention): + module.__class__ = OPTAttention + return model diff --git a/applications/Chat/evaluate/evaluate.py b/applications/Chat/evaluate/evaluate.py index 9f17704426e2..2f9c9ce8e10d 100644 --- a/applications/Chat/evaluate/evaluate.py +++ b/applications/Chat/evaluate/evaluate.py @@ -130,7 +130,7 @@ def evaluate(args): assert answer1_jsons[i]['id'] == answer2_jsons[i]['id'] answer_id = answer1_jsons[i]['id'] - ques = answer1_jsons[i]['instruction'] if answer1_jsons[i]['input'] == "" else answer1_jsons[i]['instuction'] + \ + ques = answer1_jsons[i]['instruction'] if answer1_jsons[i]['input'] == "" else answer1_jsons[i]['instruction'] + \ " " + answer1_jsons[i]['input'] cat = answer1_jsons[i]['category'] ans1 = answer1_jsons[i]['output'] diff --git a/applications/Chat/evaluate/generate_gpt35_answers.py b/applications/Chat/evaluate/generate_gpt35_answers.py index 852a7cb19dfa..db95cd2febf4 100644 --- a/applications/Chat/evaluate/generate_gpt35_answers.py +++ b/applications/Chat/evaluate/generate_gpt35_answers.py @@ -35,7 +35,7 @@ def get_answer(question: str, max_tokens: int): answer = question - 
prompt = question['instruction'] if question['input'] == "" else question['instuction'] + \ + prompt = question['instruction'] if question['input'] == "" else question['instruction'] + \ " " + question['input'] for _ in range(MAX_API_RETRY): try: diff --git a/applications/Chat/examples/README.md b/applications/Chat/examples/README.md index 3e85bfe2d170..2a2128e25a62 100644 --- a/applications/Chat/examples/README.md +++ b/applications/Chat/examples/README.md @@ -24,7 +24,6 @@ - [LLaMA](#llama) - [Add your own models](#add-your-own-models) - [Actor model](#actor-model) - - [LM model](#lm-model) - [Reward model](#reward-model) - [Critic model](#critic-model) @@ -150,11 +149,11 @@ torchrun --standalone --nproc_per_node=4 train_prompts.py \ --strategy colossalai_zero2 \ --prompt_dataset /path/to/your/prompt_dataset \ --pretrain_dataset /path/to/your/pretrain_dataset \ - --rm_pretrain /your/pretrain/rm/defination \ + --rm_pretrain /your/pretrain/rm/definition \ --rm_path /your/rm/model/path ``` -Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild. +Prompt dataset: the instruction dataset mentioned in the above figure which includes the instructions, e.g. you can use the [script](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/example_data_reformat.py) to reformat [seed_prompts_ch.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_ch.jsonl) or [seed_prompts_en.jsonl](https://github.com/XueFuzhao/InstructionWild/blob/main/data/seed_prompts_en.jsonl) in InstructionWild. Pretrain dataset: the pretrain dataset including the instruction and corresponding response, e.g. 
you can use the [InstructWild Data](https://github.com/XueFuzhao/InstructionWild/tree/main/data) in stage 1 supervised instructs tuning. ### Arg List @@ -233,7 +232,7 @@ If you want to support your own model in Coati, please refer the pull request fo You should complete the implementation of four model classes, including Reward model, Critic model, LM model, Actor model here are some example code for a NewModel named `Coati`. -if it is supported in huggingaface [transformers](https://github.com/huggingface/transformers), you can load it by `from_pretrained`, o +if it is supported in huggingface [transformers](https://github.com/huggingface/transformers), you can load it by `from_pretrained`, o r you can build your own model by yourself. ### Actor model diff --git a/applications/Chat/examples/community/peft/easy_dataset.py b/applications/Chat/examples/community/peft/easy_dataset.py index 24ea4f0a8618..2fe293957079 100644 --- a/applications/Chat/examples/community/peft/easy_dataset.py +++ b/applications/Chat/examples/community/peft/easy_dataset.py @@ -188,7 +188,7 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ else: raw_input_ids.append(encoded_ids) - grouped_inpup_ids = [] + grouped_input_ids = [] current_input_ids = [] attention_mask = [] if tokenizer.pad_token_id is None: @@ -199,7 +199,7 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ #pad the current_input_ids to max_length with tokenizer.pad_token_id padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) - grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) current_input_ids = [] @@ -208,7 +208,7 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, 
max_length=512, is_ if len(current_input_ids) > 0: padded_length = max_length - len(current_input_ids) current_input_ids.extend([tokenizer.pad_token_id] * padded_length) - grouped_inpup_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) + grouped_input_ids.append(torch.tensor(current_input_ids, dtype=torch.long)) attention_mask.append( torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) else: @@ -218,8 +218,8 @@ def __init__(self, data_file: str, tokenizer: AutoTokenizer, max_length=512, is_ input_ids.extend([tokenizer.pad_token_id] * padded_length) attention_mask.append( torch.tensor([1] * (max_length - padded_length) + [0] * padded_length, dtype=torch.long)) - grouped_inpup_ids.append(torch.tensor(input_ids, dtype=torch.long)) - self.input_ids = grouped_inpup_ids + grouped_input_ids.append(torch.tensor(input_ids, dtype=torch.long)) + self.input_ids = grouped_input_ids self.labels = copy.deepcopy(self.input_ids) self.file_name = data_file self.attention_mask = attention_mask diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py index 0e277021e917..ba8470f38fad 100644 --- a/applications/Chat/examples/community/peft/train_peft_prompts.py +++ b/applications/Chat/examples/community/peft/train_peft_prompts.py @@ -41,7 +41,7 @@ def main(args): # configure model if args.model == 'bloom': # initial_model = BLOOMActor(pretrained=args.pretrain) - print('Using peft lora to load Bloom model as inital_model') + print('Using peft lora to load Bloom model as initial_model') initial_model = BLOOMActor(pretrained=args.pretrain, lora_path=args.sft_lora_path) print('Using peft lora to load Bloom model as initial_model (Done)') else: diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py index 9bd0ebc12a83..d2b08b72ca95 100644 --- 
a/applications/Chat/examples/community/peft/train_peft_sft.py +++ b/applications/Chat/examples/community/peft/train_peft_sft.py @@ -86,7 +86,7 @@ def train(args): if args.strategy == 'colossalai_gemini': # this is a hack to deal with the resized embedding - # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility for name, param in model.named_parameters(): if not isinstance(param, ColoParameter): sub_module_name = '.'.join(name.split('.')[:-1]) diff --git a/applications/Chat/examples/example_data_reformat.py b/applications/Chat/examples/example_data_reformat.py new file mode 100644 index 000000000000..dc83b29b525b --- /dev/null +++ b/applications/Chat/examples/example_data_reformat.py @@ -0,0 +1,12 @@ +jsonl_file = 'seed_prompts_xx.jsonl' # seed_prompts_en.jsonl or seed_prompts_ch.json from InstructionWild +reformat_file = 'prompts_xx.jsonl' # reformat jsonl file used as Prompt dataset in Stage3 + +data = '' +with open(jsonl_file, 'r', encoding="utf-8") as f1: + for jsonstr in f1.readlines(): + jsonstr = '\t' + jsonstr.strip('\n') + ',\n' + data = data + jsonstr + data = '[\n' + data + ']' + +with open(reformat_file, 'w') as f2: + f2.write(data) \ No newline at end of file diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py index a584991cd34e..134f21f80ef1 100644 --- a/applications/Chat/examples/train_prompts.py +++ b/applications/Chat/examples/train_prompts.py @@ -36,45 +36,45 @@ def main(args): if args.rm_path is not None: state_dict = torch.load(args.rm_path, map_location='cpu') - # configure model - if args.model == 'gpt2': - initial_model = GPTActor(pretrained=args.pretrain) - elif args.model == 'bloom': - initial_model = BLOOMActor(pretrained=args.pretrain) - elif args.model == 'opt': - initial_model = OPTActor(pretrained=args.pretrain) - elif args.model == 'llama': - initial_model = 
LlamaActor(pretrained=args.pretrain) - elif args.model == 'roberta': - initial_model = RoBERTaActor(pretrained=args.pretrain) - else: - raise ValueError(f'Unsupported actor model "{args.model}"') + with strategy.model_init_context(): + # configure model + if args.model == 'gpt2': + initial_model = GPTActor(pretrained=args.pretrain) + elif args.model == 'bloom': + initial_model = BLOOMActor(pretrained=args.pretrain) + elif args.model == 'opt': + initial_model = OPTActor(pretrained=args.pretrain) + elif args.model == 'llama': + initial_model = LlamaActor(pretrained=args.pretrain) + elif args.model == 'roberta': + initial_model = RoBERTaActor(pretrained=args.pretrain) + else: + raise ValueError(f'Unsupported actor model "{args.model}"') - if args.rm_model == None: - rm_model_name = args.model - else: - rm_model_name = args.rm_model - - if rm_model_name == 'gpt2': - reward_model = GPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'bloom': - reward_model = BLOOMRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'opt': - reward_model = OPTRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'llama': - reward_model = LlamaRM(pretrained=args.rm_pretrain) - elif rm_model_name == 'roberta': - reward_model = RoBERTaRM(pretrained=args.rm_pretrain) - else: - raise ValueError(f'Unsupported reward model "{rm_model_name}"') + if args.rm_model == None: + rm_model_name = args.model + else: + rm_model_name = args.rm_model - if args.rm_path is not None: - reward_model.load_state_dict(state_dict) + if rm_model_name == 'gpt2': + reward_model = GPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'bloom': + reward_model = BLOOMRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'opt': + reward_model = OPTRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'llama': + reward_model = LlamaRM(pretrained=args.rm_pretrain) + elif rm_model_name == 'roberta': + reward_model = RoBERTaRM(pretrained=args.rm_pretrain) + else: + raise ValueError(f'Unsupported reward 
model "{rm_model_name}"') - initial_model.to(torch.float16).to(torch.cuda.current_device()) - reward_model.to(torch.float16).to(torch.cuda.current_device()) + if args.rm_path is not None: + reward_model.load_state_dict(state_dict) + + initial_model.to(torch.float16).to(torch.cuda.current_device()) + reward_model.to(torch.float16).to(torch.cuda.current_device()) - with strategy.model_init_context(): if args.model == 'gpt2': actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) elif args.model == 'bloom': diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py index da499f068b17..7fcd026fb538 100644 --- a/applications/Chat/examples/train_sft.py +++ b/applications/Chat/examples/train_sft.py @@ -84,7 +84,7 @@ def train(args): if args.strategy == 'colossalai_gemini': # this is a hack to deal with the resized embedding - # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity + # to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatibility for name, param in model.named_parameters(): if not isinstance(param, ColoParameter): sub_module_name = '.'.join(name.split('.')[:-1]) diff --git a/colossalai/_analyzer/__init__.py b/colossalai/_analyzer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/booster/plugin/dp_plugin_base.py b/colossalai/booster/plugin/dp_plugin_base.py new file mode 100644 index 000000000000..d5da5938bfd9 --- /dev/null +++ b/colossalai/booster/plugin/dp_plugin_base.py @@ -0,0 +1,70 @@ +import random + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from .plugin_base import Plugin + + +class DPPluginBase(Plugin): + """This is a base class for all DP plugins. It sets up world size and rank, and provides data loader creation. 
+ """ + + def __init__(self) -> None: + super().__init__() + assert dist.is_initialized( + ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + + def prepare_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index deda00d8a7b3..4850b52defaf 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,26 +1,25 @@ -import random +import logging +import os import warnings +from pathlib import Path from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO -from colossalai.checkpoint_io.utils import save_state_dict +from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO +from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.zero.gemini.memory_tracer import MemStats -from .plugin_base import Plugin +from .dp_plugin_base import DPPluginBase __all__ = ['GeminiPlugin'] @@ -62,6 +61,48 @@ def save_lr_scheduler(self, lr_scheduler: 
LRScheduler, checkpoint: str): if self.coordinator.is_master(): super().save_lr_scheduler(lr_scheduler, checkpoint) + def save_sharded_model(self, + model: GeminiDDP, + checkpoint_path: str, + gather_dtensor: bool = False, + variant: Optional[str] = None, + max_shard_size: int = 1024, + use_safetensors: bool = False): + """ + Save sharded model + """ + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) + weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + total_size = 0 + index_file = CheckpointIndexFile(checkpoint_path) + for idx, shard_pair in enumerate(state_dict_shard): + if not self.coordinator.is_master(): + continue + shard = shard_pair[0] + shard_file = get_shard_filename(weights_name, idx) + total_size = total_size + shard_pair[1] + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) + save_state_dict(shard, checkpoint_file_path, use_safetensors) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + logging.info(f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.") + + def load_sharded_model(self, + model: GeminiDDP, + checkpoint_index_file: Path, + strict: bool = False, + use_safetensors: bool = False): + """ + load shard model, load model from multiple files + """ + return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) + class GeminiModel(ModelWrapper): @@ -104,7 +145,7 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: raise NotImplementedError('Gemini does not support clip_grad_by_value') -class GeminiPlugin(Plugin): +class GeminiPlugin(DPPluginBase): """ Plugin for Gemini. 
@@ -115,7 +156,7 @@ class GeminiPlugin(Plugin): >>> model, train_dataset, optimizer, criterion = ... >>> plugin = GeminiPlugin() - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) >>> booster = Booster(plugin=plugin) >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) @@ -173,11 +214,7 @@ def __init__( norm_type: float = 2.0, verbose: bool = False, ) -> None: - - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() + super().__init__() self.gemini_config = dict( device=(device or get_current_device()), placement_policy=placement_policy, @@ -216,57 +253,6 @@ def control_device(self) -> bool: def supported_devices(self) -> List[str]: return ['cuda'] - def prepare_train_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - Note: - 1. Evaluation datasets should not be passed to this function. - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. 
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) - def configure( self, model: nn.Module, diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 969c430bd317..f0f5768560a7 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,24 +1,20 @@ -import random import warnings from typing import Callable, List, Optional, Tuple, Union -import numpy as np import torch -import torch.distributed as dist import torch.nn as nn from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils._pytree import tree_map from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from colossalai.checkpoint_io import CheckpointIO from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.utils import get_current_device from colossalai.zero import zero_model_wrapper, zero_optim_wrapper -from .plugin_base import Plugin +from .dp_plugin_base import DPPluginBase from 
.torch_ddp_plugin import TorchDDPCheckpointIO __all__ = ['LowLevelZeroPlugin'] @@ -88,7 +84,7 @@ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: raise NotImplementedError('LowLevelZero does not support clip_grad_by_value') -class LowLevelZeroPlugin(Plugin): +class LowLevelZeroPlugin(DPPluginBase): """ Plugin for low level zero. @@ -99,7 +95,7 @@ class LowLevelZeroPlugin(Plugin): >>> model, train_dataset, optimizer, criterion = ... >>> plugin = LowLevelZeroPlugin() - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) >>> booster = Booster(plugin=plugin) >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) @@ -142,15 +138,10 @@ def __init__( cpu_offload: bool = False, verbose: bool = False, ) -> None: - - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + super().__init__() assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training' - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() - self.stage = stage self.precision = precision self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024, @@ -183,57 +174,6 @@ def control_device(self) -> bool: def supported_devices(self) -> List[str]: return ['cuda'] - def prepare_train_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - Note: - 1. Evaluation datasets should not be passed to this function. 
- - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. - pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) - def configure( self, model: nn.Module, diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py index 7a222022c1b2..eb5478595542 100644 --- a/colossalai/booster/plugin/plugin_base.py +++ b/colossalai/booster/plugin/plugin_base.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset 
from colossalai.checkpoint_io import CheckpointIO from colossalai.interface import OptimizerWrapper @@ -59,3 +59,18 @@ def get_checkpoint_io(self) -> CheckpointIO: Get checkpoint io object for this plugin, only invoked when control_checkpoint_io is True. """ pass + + @abstractmethod + def prepare_dataloader(self, + dataset: Dataset, + batch_size: int, + shuffle: bool = False, + seed: int = 1024, + drop_last: bool = False, + pin_memory: bool = False, + num_workers: int = 0, + **kwargs): + """Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` + """ + pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index c5e310c7e769..76906d844ef1 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -1,21 +1,16 @@ -import random from typing import Callable, List, Tuple, Union -import numpy as np -import torch -import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper -from .plugin_base import Plugin +from .dp_plugin_base import DPPluginBase __all__ = ['TorchDDPPlugin'] @@ -66,7 +61,7 @@ def unwrap(self): return self.module.module -class TorchDDPPlugin(Plugin): +class TorchDDPPlugin(DPPluginBase): """ Plugin for PyTorch DDP. @@ -77,7 +72,7 @@ class TorchDDPPlugin(Plugin): >>> model, train_dataset, optimizer, criterion = ... 
>>> plugin = TorchDDPPlugin() - >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8) >>> booster = Booster(plugin=plugin) >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) @@ -97,11 +92,7 @@ def __init__(self, check_reduction: bool = False, gradient_as_bucket_view: bool = False, static_graph: bool = False) -> None: - - assert dist.is_initialized( - ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' - self.rank = dist.get_rank() - self.world_size = dist.get_world_size() + super().__init__() self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, bucket_cap_mb=bucket_cap_mb, find_unused_parameters=find_unused_parameters, @@ -124,57 +115,6 @@ def control_device(self) -> bool: def supported_devices(self) -> List[str]: return ['cuda'] - def prepare_train_dataloader(self, - dataset, - batch_size, - shuffle=False, - seed=1024, - drop_last=False, - pin_memory=False, - num_workers=0, - **kwargs): - r""" - Prepare a dataloader for distributed training. The dataloader will be wrapped by - `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. - - Note: - 1. Evaluation datasets should not be passed to this function. - - Args: - dataset (`torch.utils.data.Dataset`): The dataset to be loaded. - shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. - seed (int, optional): Random worker seed for sampling, defaults to 1024. - add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. - drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size - is not divisible by the batch size. If False and the size of dataset is not divisible by - the batch size, then the last batch will be smaller, defaults to False. 
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. - num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. - kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in - `DataLoader `_. - - Returns: - :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. - """ - _kwargs = kwargs.copy() - sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) - - # Deterministic dataloader - def seed_worker(worker_id): - worker_seed = seed - np.random.seed(worker_seed) - torch.manual_seed(worker_seed) - random.seed(worker_seed) - - return DataLoader(dataset, - batch_size=batch_size, - sampler=sampler, - worker_init_fn=seed_worker, - drop_last=drop_last, - pin_memory=pin_memory, - num_workers=num_workers, - **_kwargs) - def configure( self, model: nn.Module, diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index cb853559c48c..9cf344ecc41b 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -86,7 +86,7 @@ def load_model(self, # the existence of index file means it is a sharded checkpoint ckpt_path = Path(checkpoint) index_file_exists, index_file_path = has_index_file(checkpoint) - + # return the origin model instead of the unwrapped model origin_model = model diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index bf584f45d045..96a883fdb42a 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -1,12 +1,12 @@ from pathlib import Path +from functools import reduce import torch.nn as nn from torch.optim import Optimizer import logging import os -import json import gc -from typing import Optional +from typing import Optional, Iterator, 
OrderedDict, Tuple from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile @@ -18,10 +18,9 @@ shard_checkpoint, load_shard_state_dict, load_state_dict_into_model, - add_variant + get_shard_filename, + get_base_filenames ) -from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME - __all__ = ['GeneralCheckpointIO'] @@ -85,30 +84,32 @@ def save_sharded_model(self, model: nn.Module, checkpoint_path: str, gather_dten # shard checkpoint state_dict = model.state_dict() - weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME - weights_name = add_variant(weights_name, variant) - shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) - - # Save the model - for shard_file, shard in shards.items(): + state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size) + + weights_name, save_index_file = get_base_filenames(variant, use_safetensors) + total_size = 0 + index_file = CheckpointIndexFile(checkpoint_path) + for idx, shard_pair in enumerate(state_dict_shard): + shard = shard_pair[0] + shard_file = get_shard_filename(weights_name, idx) + total_size = total_size + shard_pair[1] + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint_path, shard_file) save_state_dict(shard, checkpoint_file_path, use_safetensors) - - # save index file - save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME - - save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant)) - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) + + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) logging.info( - f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be 
" - f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"The model is going to be split to checkpoint shards. " + f"You can find where each parameters has been saved in the " f"index located at {save_index_file}." ) - def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False): + def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, + use_safetensors: bool = False, load_sub_module: bool = True): """ load shard model, load model from multiple files """ @@ -122,17 +123,21 @@ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, stri # read checkpoint index file ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file) checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames() - missing_keys = ckpt_index_file.get_all_param_names() + missing_keys = [] for shard_file in checkpoint_files: state_dict = load_shard_state_dict(Path(shard_file), use_safetensors) - load_state_dict_into_model(model, state_dict, missing_keys, strict) + load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module) del state_dict gc.collect() - if strict and len(missing_keys) > 0: - error_msgs = 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys)) - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) + if strict: + remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys)) + if len(remain_keys) > 0: + error_msgs = 'Missing key(s) in state_dict: {}. 
'.format( + ', '.join('"{}"'.format(k) for k in missing_keys)) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + + diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py index 89224787a91b..15a6d09f3b5e 100644 --- a/colossalai/checkpoint_io/index_file.py +++ b/colossalai/checkpoint_io/index_file.py @@ -1,6 +1,8 @@ import json from pathlib import Path from typing import Any, List, Union +import os +import json from .utils import is_dtensor_checkpoint @@ -18,8 +20,8 @@ class CheckpointIndexFile: >>> index.export('new_index.json') """ - def __init__(self) -> None: - self.root_path = None + def __init__(self, root_path=None) -> None: + self.root_path = root_path self.metadata: dict = dict() self.weight_map: dict = dict() @@ -154,3 +156,13 @@ def get_all_param_names(self): Get all the weight keys. """ return list(self.weight_map.keys()) + + def write_index_file(self, save_index_file): + """ + Write index file. 
+ """ + save_index_file = os.path.join(self.root_path, save_index_file) + index = {"metadata": self.metadata, "weight_map": self.weight_map} + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 37d22d08df40..16e41631f0d5 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -2,7 +2,7 @@ from pathlib import Path import torch import torch.nn as nn -from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple +from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator from colossalai.tensor.d_tensor.d_tensor import DTensor import re @@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== -def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME): +def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. """ - sharded_state_dicts = [] current_block = {} current_block_size = 0 - total_size = 0 for key, weight in state_dict.items(): + ret_block = None + ret_block_size = 0 if type(weight) != DTensor: weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. 
if current_block_size + weight_size > max_shard_size: - sharded_state_dicts.append(current_block) + ret_block = current_block + ret_block_size = current_block_size current_block = {} current_block_size = 0 - current_block[key] = weight current_block_size += weight_size - total_size += weight_size + + if ret_block != None: + yield ret_block, ret_block_size - # Add the last block - sharded_state_dicts.append(current_block) + yield current_block, current_block_size - # If we only have one shard, we return it - if len(sharded_state_dicts) == 1: - return {weights_name: sharded_state_dicts[0]}, None - - # Otherwise, let's build the index - weight_map = {} - shards = {} - - for idx, shard in enumerate(sharded_state_dicts): - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") - shard_file = shard_file.replace( - ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" - ) - shards[shard_file] = shard - for key in shard.keys(): - weight_map[key] = shard_file - - # Add the metadata - metadata = {"total_size": total_size} - index = {"metadata": metadata, "weight_map": weight_map} - return shards, index def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): """ @@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False): else: return torch.load(checkpoint_file) -def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False): +def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. 
@@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi if metadata is not None: state_dict._metadata = metadata - def load(module: nn.Module, state_dict, prefix=""): + def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs) # Parameters of module and children will start with prefix. We can exit early if there are none in this # state_dict if len([key for key in state_dict if key.startswith(prefix)]) > 0: module._load_from_state_dict(*args) + if load_sub_module: + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") - for name, child in module._modules.items(): - if child is not None: - load(child, state_dict, prefix + name + ".") - - load(model, state_dict, "") + load(model, state_dict, "", load_sub_module) del load - # deal with missing key - if len(missing_keys) > 0: - deleted_keys = [] - for key in missing_keys: - if key not in sub_missing_keys: - deleted_keys.append(key) - for key in deleted_keys: - missing_keys.remove(key) + missing_keys = missing_keys.append(sub_missing_keys) if strict: if len(unexpected_keys) > 0: @@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str: weights_name = ".".join(splits) return weights_name + + +def get_base_filenames(variant: str=None, use_safetensors: bool=False): + """ + generate base weight filenames + """ + weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME + weights_name = add_variant(weights_name, variant) + + save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME + save_index_file = add_variant(save_index_file, variant) + + return weights_name, save_index_file + +def 
get_shard_filename(weights_name: str, idx: int): + """ + get shard file name + """ + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors") + return shard_file \ No newline at end of file diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py index 8657989235db..c968050de49d 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/tensor/dist_spec_mgr.py @@ -4,10 +4,8 @@ import torch.distributed as dist # from colossalai.nn.layer.utils import divide from numpy import prod -from packaging import version -from colossalai.logging import get_dist_logger -from colossalai.tensor.distspec import _DistSpec +from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec from colossalai.tensor.process_group import ProcessGroup @@ -171,11 +169,21 @@ def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: pg: ProcessGroup) -> torch.Tensor: assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec" assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec" - forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}') + + trans_func_key = (old_dist_spec.placement, dist_spec.placement) + trans_funcs = { + (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r, + (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s, + (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r, + (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s + } + + forward_trans_handle = trans_funcs[trans_func_key] if not DistSpecManager._use_autograd_function: return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg) - backward_trans_handle = getattr(DistSpecManager, - 
f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}') + + backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)] + return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 8a001b114e9a..878c25be7094 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -2,7 +2,7 @@ from collections import OrderedDict from contextlib import nullcontext from functools import partial -from typing import Dict, Iterator, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union, Tuple, Set import torch import torch.distributed as dist @@ -96,8 +96,35 @@ def __init__(self, param_name = m_name + '.' + p_name if m_name else p_name self.name2param[param_name] = p_var super().__init__(module, process_group=ColoProcessGroup()) + self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module) self._cast_buffers() + def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True): + + r""" + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + """ + + if memo is None: + memo = set() + self_non_persistent_set = set() + if module not in memo: + if remove_duplicate: + memo.add(module) + self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set)) + for name, sub_module in module._modules.items(): + if sub_module is None: + continue + submodule_prefix = prefix + ('.' 
if prefix else '') + name + child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate) + self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set) + return self_non_persistent_set + + def _post_forward(self): """This function is only triggered for inference. """ @@ -604,7 +631,7 @@ def state_dict_shard(self, keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True, - dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]: + dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. Both parameters and persistent buffers (e.g. running averages) are included. @@ -644,9 +671,9 @@ def state_dict_shard(self, gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype)) gathered_param = gathered_param_buffer.pop(fp32_param) - block = sharder.append(prefix + name, gathered_param) + block, block_size = sharder.append(prefix + name, gathered_param) if block is not None: - yield block + yield block, block_size del fp16_to_fp32 del gathered_param_buffer @@ -655,19 +682,19 @@ def state_dict_shard(self, for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() - block = sharder.append(prefix + name, buffer) + block, block_size = sharder.append(prefix + name, buffer) if block is not None: - yield block + yield block, block_size # save extra states extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX if getattr(self.__class__, "get_extra_state", torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: extra_state = self.get_extra_state() - block = sharder.append(extra_state_key, extra_state) + block, block_size = sharder.append(extra_state_key, extra_state) if block is not None: - 
yield block + yield block, block_size - yield sharder.current_block + yield sharder.current_block, sharder.current_block_size class _StateDictSharder: @@ -677,16 +704,18 @@ def __init__(self, max_shard_size: int) -> None: self.current_block = OrderedDict() self.current_block_size = 0 - def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]: + def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]: tensor_size = calculate_tensor_size(tensor) ret_block = None + ret_block_size = 0 if self.current_block_size + tensor_size > self.max_shard_size: ret_block = self.current_block + ret_block_size = self.current_block_size self.current_block = OrderedDict() self.current_block_size = 0 self.current_block[name] = tensor self.current_block_size += tensor_size - return ret_block + return ret_block, ret_block_size class GeminiDDP(ZeroDDP): diff --git a/examples/tutorial/new_api/torch_ddp/.gitignore b/examples/tutorial/new_api/cifar_resnet/.gitignore similarity index 100% rename from examples/tutorial/new_api/torch_ddp/.gitignore rename to examples/tutorial/new_api/cifar_resnet/.gitignore diff --git a/examples/tutorial/new_api/cifar_resnet/README.md b/examples/tutorial/new_api/cifar_resnet/README.md new file mode 100644 index 000000000000..4ed86aa7a0ad --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/README.md @@ -0,0 +1,56 @@ +# Train ResNet on CIFAR-10 from scratch + +## 🚀 Quick Start + +This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. + +- Training Arguments + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`. + - `-r`, `--resume`: Resume from checkpoint file path. Defaults to `-1`, which means not resuming. + - `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`. 
+ - `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved. + - `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`. + +- Eval Arguments + - `-e`, `--epoch`: select the epoch to evaluate + - `-c`, `--checkpoint`: the folder where checkpoints are found + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 -p torch_ddp_fp16 + +# train with low level zero +colossalai run --nproc_per_node 2 train.py -c ./ckpt-low_level_zero -p low_level_zero +``` + +### Eval + +```bash +# evaluate fp32 training +python eval.py -c ./ckpt-fp32 -e 80 + +# evaluate fp16 mixed precision training +python eval.py -c ./ckpt-fp16 -e 80 + +# evaluate low level zero training +python eval.py -c ./ckpt-low_level_zero -e 80 +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | +| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | + +**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/tutorial/new_api/torch_ddp/eval.py b/examples/tutorial/new_api/cifar_resnet/eval.py similarity index 100% rename from examples/tutorial/new_api/torch_ddp/eval.py rename to examples/tutorial/new_api/cifar_resnet/eval.py diff --git a/examples/tutorial/new_api/cifar_resnet/requirements.txt b/examples/tutorial/new_api/cifar_resnet/requirements.txt new file mode 100644 index 000000000000..85522f4129c4 --- 
/dev/null +++ b/examples/tutorial/new_api/cifar_resnet/requirements.txt @@ -0,0 +1,4 @@ +colossalai +torch +torchvision +tqdm diff --git a/examples/tutorial/new_api/cifar_resnet/test_ci.sh b/examples/tutorial/new_api/cifar_resnet/test_ci.sh new file mode 100755 index 000000000000..3954b84ff1ba --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/test_ci.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -xe + +export DATA=/data/scratch/cifar-10 + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do + colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.84 --plugin $plugin +done diff --git a/examples/tutorial/new_api/cifar_resnet/train.py b/examples/tutorial/new_api/cifar_resnet/train.py new file mode 100644 index 000000000000..a96a4b640a22 --- /dev/null +++ b/examples/tutorial/new_api/cifar_resnet/train.py @@ -0,0 +1,204 @@ +import argparse +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from torch.optim import Optimizer +from torch.optim.lr_scheduler import MultiStepLR +from torch.utils.data import DataLoader +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 80 +LEARNING_RATE = 1e-3 + + +def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): + # transform + transform_train = transforms.Compose( + [transforms.Pad(4), + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32), + transforms.ToTensor()]) + 
transform_test = transforms.ToTensor() + + # CIFAR-10 dataset + data_path = os.environ.get('DATA', './data') + with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10(root=data_path, + train=True, + transform=transform_train, + download=True) + test_dataset = torchvision.datasets.CIFAR10(root=data_path, + train=False, + transform=transform_test, + download=True) + + # Data loader + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + return train_dataloader, test_dataloader + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: + model.eval() + correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + for images, labels in test_dataloader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + dist.all_reduce(correct) + dist.all_reduce(total) + accuracy = correct.item() / total.item() + if coordinator.is_master(): + print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + return accuracy + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, + booster: Booster, coordinator: DistCoordinator): + model.train() + with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + for images, labels in pbar: + images = images.cuda() + labels = labels.cuda() + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + 
optimizer.zero_grad() + + # Print log info + pbar.set_postfix({'loss': loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + # FIXME(ver217): gemini is not supported resnet now + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], + help="plugin to use") + parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") + parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") + parser.add_argument('--target_acc', + type=float, + default=None, + help="target accuracy. Raise exception if not reached") + args = parser.parse_args() + + # ============================== + # Prepare Checkpoint Directory + # ============================== + if args.interval > 0: + Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}) + coordinator = DistCoordinator() + + # update the learning rate with linear scaling + # old_gpu_num / old_lr = new_gpu_num / new_lr + global LEARNING_RATE + LEARNING_RATE *= coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + 
+ # ============================== + # Prepare Dataloader + # ============================== + train_dataloader, test_dataloader = build_dataloader(100, coordinator, plugin) + + # ==================================== + # Prepare model, optimizer, criterion + # ==================================== + # resnet18 + model = torchvision.models.resnet18(num_classes=10) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + + # lr scheduler + lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + + # ============================== + # Resume from checkpoint + # ============================== + if args.resume >= 0: + booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') + booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') + booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + + # ============================== + # Train model + # ============================== + start_epoch = args.resume if args.resume >= 0 else 0 + for epoch in range(start_epoch, NUM_EPOCHS): + train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator) + lr_scheduler.step() + + # save checkpoint + if args.interval > 0 and (epoch + 1) % args.interval == 0: + booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') + booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') + booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + + accuracy = evaluate(model, test_dataloader, coordinator) + if args.target_acc is not None: + assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower 
than target accuracy {args.target_acc}' + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/new_api/cifar_vit/README.md b/examples/tutorial/new_api/cifar_vit/README.md new file mode 100644 index 000000000000..fa76447c508f --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/README.md @@ -0,0 +1,37 @@ +# Train ViT on CIFAR-10 from scratch + +## 🚀 Quick Start + +This example provides a script for training ViT on the CIFAR-10 dataset from scratch. + +- Training Arguments + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `low_level_zero`. Defaults to `torch_ddp`. + - `-r`, `--resume`: Epoch number of the checkpoint to resume from. Defaults to `-1`, which means not resuming. + - `-c`, `--checkpoint`: The folder to save checkpoints. Defaults to `./checkpoint`. + - `-i`, `--interval`: Epoch interval to save checkpoints. Defaults to `5`. If set to `0`, no checkpoint will be saved. + - `--target_acc`: Target accuracy. Raise exception if not reached. Defaults to `None`.
+ +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 4 train.py -c ./ckpt-fp32 + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 4 train.py -c ./ckpt-fp16 -p torch_ddp_fp16 + +# train with low level zero +colossalai run --nproc_per_node 4 train.py -c ./ckpt-low_level_zero -p low_level_zero +``` + +Expected accuracy performance will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | +| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | +| ViT | 83.00% | 84.03% | 84.00% | 84.43% | diff --git a/examples/tutorial/new_api/cifar_vit/requirements.txt b/examples/tutorial/new_api/cifar_vit/requirements.txt new file mode 100644 index 000000000000..6d53ce7b5a7d --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/requirements.txt @@ -0,0 +1,5 @@ +colossalai +timm +torch +torchvision +tqdm diff --git a/examples/tutorial/new_api/cifar_vit/test_ci.sh b/examples/tutorial/new_api/cifar_vit/test_ci.sh new file mode 100755 index 000000000000..43239d400586 --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/test_ci.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -xe + +export DATA=/data/scratch/cifar-10 + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "low_level_zero"; do + colossalai run --nproc_per_node 4 train.py --interval 0 --target_acc 0.83 --plugin $plugin +done diff --git a/examples/tutorial/new_api/cifar_vit/train.py b/examples/tutorial/new_api/cifar_vit/train.py new file mode 100644 index 000000000000..2405fdfc60d5 --- /dev/null +++ b/examples/tutorial/new_api/cifar_vit/train.py @@ -0,0 +1,219 @@ +import argparse +import os +from pathlib import Path + +import torch +import torch.distributed as dist +import torch.nn as nn +import torchvision +import 
torchvision.transforms as transforms +from timm.models.vision_transformer import _cfg, _create_vision_transformer +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.cluster import DistCoordinator +from colossalai.nn.lr_scheduler import LinearWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 60 +WARMUP_EPOCSH = 5 +LEARNING_RATE = 1e-3 + + +def vit_cifar(**kwargs): + pretrained_cfg = _cfg(num_classes=10, input_size=(3, 32, 32), crop_pct=1.0) + model_kwargs = dict(patch_size=4, embed_dim=512, depth=6, num_heads=8, drop_rate=0.1, mlp_ratio=1.0, **kwargs) + model = _create_vision_transformer('vit_cifar', pretrained_cfg=pretrained_cfg, **model_kwargs) + return model + + +def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPluginBase): + # trainsform + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ]) + transform_test = transforms.Compose([ + transforms.Resize(32), + transforms.ToTensor(), + transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)), + ]) + + # CIFAR-10 dataset + data_path = os.environ.get('DATA', './data') + with coordinator.priority_execution(): + train_dataset = torchvision.datasets.CIFAR10(root=data_path, + train=True, + transform=transform_train, + download=True) + test_dataset = torchvision.datasets.CIFAR10(root=data_path, + train=False, + 
transform=transform_test, + download=True) + + # Data loader + train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) + test_dataloader = plugin.prepare_dataloader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False) + return train_dataloader, test_dataloader + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float: + model.eval() + correct = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + total = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + for images, labels in test_dataloader: + images = images.cuda() + labels = labels.cuda() + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + dist.all_reduce(correct) + dist.all_reduce(total) + accuracy = correct.item() / total.item() + if coordinator.is_master(): + print(f'Accuracy of the model on the test images: {accuracy * 100:.2f} %') + return accuracy + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: nn.Module, train_dataloader: DataLoader, + booster: Booster, coordinator: DistCoordinator): + model.train() + with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + for images, labels in pbar: + images = images.cuda() + labels = labels.cuda() + # Forward pass + outputs = model(images) + loss = criterion(outputs, labels) + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + + # Print log info + pbar.set_postfix({'loss': loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + # FIXME(ver217): gemini is not supported resnet now + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + 
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], + help="plugin to use") + parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") + parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") + parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") + parser.add_argument('--target_acc', + type=float, + default=None, + help="target accuracy. Raise exception if not reached") + args = parser.parse_args() + + # ============================== + # Prepare Checkpoint Directory + # ============================== + if args.interval > 0: + Path(args.checkpoint).mkdir(parents=True, exist_ok=True) + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}) + coordinator = DistCoordinator() + + # update the learning rate with linear scaling + # old_gpu_num / old_lr = new_gpu_num / new_lr + global LEARNING_RATE + LEARNING_RATE *= coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + train_dataloader, test_dataloader = build_dataloader(512, coordinator, plugin) + + # ==================================== + # Prepare model, optimizer, criterion + # ==================================== + # resent50 + model = 
torchvision.models.resnet18(num_classes=10) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE) + + # lr scheduler + lr_scheduler = LinearWarmupLR(optimizer, NUM_EPOCHS, WARMUP_EPOCSH) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, + optimizer, + criterion=criterion, + dataloader=train_dataloader, + lr_scheduler=lr_scheduler) + + # ============================== + # Resume from checkpoint + # ============================== + if args.resume >= 0: + booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') + booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') + booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') + + # ============================== + # Train model + # ============================== + start_epoch = args.resume if args.resume >= 0 else 0 + for epoch in range(start_epoch, NUM_EPOCHS): + train_epoch(epoch, model, optimizer, criterion, train_dataloader, booster, coordinator) + lr_scheduler.step() + + # save checkpoint + if args.interval > 0 and (epoch + 1) % args.interval == 0: + booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') + booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') + booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') + + accuracy = evaluate(model, test_dataloader, coordinator) + if args.target_acc is not None: + assert accuracy >= args.target_acc, f'Accuracy {accuracy} is lower than target accuracy {args.target_acc}' + + +if __name__ == '__main__': + main() diff --git a/examples/tutorial/new_api/glue_bert/README.md b/examples/tutorial/new_api/glue_bert/README.md new file mode 100644 index 000000000000..0030eead9f5b --- /dev/null +++ 
b/examples/tutorial/new_api/glue_bert/README.md @@ -0,0 +1,39 @@ +# Finetune BERT on GLUE + +## 🚀 Quick Start + +This example provides a training script, which provides an example of finetuning BERT on GLUE dataset. + +- Training Arguments + - `-t`, `--task`: GLUE task to run. Defaults to `mrpc`. + - `-p`, `--plugin`: Plugin to use. Choices: `torch_ddp`, `torch_ddp_fp16`, `gemini`, `low_level_zero`. Defaults to `torch_ddp`. + - `--target_f1`: Target f1 score. Raise exception if not reached. Defaults to `None`. + + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Train + +```bash +# train with torch DDP with fp32 +colossalai run --nproc_per_node 4 finetune.py + +# train with torch DDP with mixed precision training +colossalai run --nproc_per_node 4 finetune.py -p torch_ddp_fp16 + +# train with gemini +colossalai run --nproc_per_node 4 finetune.py -p gemini + +# train with low level zero +colossalai run --nproc_per_node 4 finetune.py -p low_level_zero +``` + +Expected F1-score will be: + +| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Gemini | Booster Low Level Zero | +| ----------------- | ------------------------ | --------------------- | --------------------- |--------------- | ---------------------- | +| bert-base-uncased | 0.86 | 0.88 | 0.87 | 0.88 | 0.89 | diff --git a/examples/tutorial/new_api/glue_bert/data.py b/examples/tutorial/new_api/glue_bert/data.py new file mode 100644 index 000000000000..981cedcca8c2 --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/data.py @@ -0,0 +1,127 @@ +import datasets +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], 
+ "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + return self.plugin.prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + + def val_dataloader(self): + if 
len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, + max_length=self.max_seq_length, + padding='max_length', + truncation=True) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features diff --git a/examples/tutorial/new_api/glue_bert/finetune.py b/examples/tutorial/new_api/glue_bert/finetune.py new file mode 100644 index 000000000000..63bdfc5d02cf --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/finetune.py @@ -0,0 +1,198 @@ +import argparse +from typing import List, Union + +import datasets +import torch +import torch.distributed as dist +import torch.nn as nn +from data import GLUEDataBuilder +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, 
TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 3 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +@torch.no_grad() +def evaluate(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str, + eval_splits: List[str], coordinator: DistCoordinator): + metric = datasets.load_metric("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=get_current_device()) + for batch in dataloader: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + labels = batch["labels"] + + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + dist.all_reduce(accum_loss.div_(len(dataloader))) + if coordinator.is_master(): + results['loss'] = accum_loss.item() / coordinator.world_size + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f'{k}_{split}': v for k, v in results.items()}) + return final_results + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader, + booster: Booster, coordinator: DistCoordinator): + model.train() + with tqdm(train_dataloader, 
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + for batch in pbar: + # Forward pass + batch = move_to_cuda(batch) + outputs = model(**batch) + loss = outputs[0] + + # Backward and optimize + booster.backward(loss, optimizer) + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + # Print log info + pbar.set_postfix({'loss': loss.item()}) + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + help="plugin to use") + parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached") + args = parser.parse_args() + + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # local_batch_size = BATCH_SIZE // coordinator.world_size + lr = LEARNING_RATE * coordinator.world_size + model_name = 'bert-base-uncased' + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + 
train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + # ==================================== + # Prepare model, optimizer + # ==================================== + # bert pretrained model + config = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + model = BertForSequenceClassification.from_pretrained(model_name, config=config) + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + # ============================== + # Boost with ColossalAI + # ============================== + model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler) + + # ============================== + # Train model + # ============================== + for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) + + results = evaluate(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, + coordinator) + + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and 'f1' in results: + assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +if __name__ == '__main__': + main() diff --git 
a/examples/tutorial/new_api/glue_bert/requirements.txt b/examples/tutorial/new_api/glue_bert/requirements.txt new file mode 100644 index 000000000000..950c2d378f08 --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/requirements.txt @@ -0,0 +1,7 @@ +colossalai +datasets +torch +tqdm +transformers +scipy +scikit-learn diff --git a/examples/tutorial/new_api/glue_bert/test_ci.sh b/examples/tutorial/new_api/glue_bert/test_ci.sh new file mode 100755 index 000000000000..c2c097f8d026 --- /dev/null +++ b/examples/tutorial/new_api/glue_bert/test_ci.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -xe + +pip install -r requirements.txt + +for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do + torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin +done diff --git a/examples/tutorial/new_api/test_ci.sh b/examples/tutorial/new_api/test_ci.sh index 8b4475e9f147..a08844dbe5fa 100644 --- a/examples/tutorial/new_api/test_ci.sh +++ b/examples/tutorial/new_api/test_ci.sh @@ -1,2 +1,6 @@ -#!/usr/bin/env -echo "The CI integration will be completed when the API is stable" +#!/bin/bash +set -xe + +# FIXME(ver217): only run bert finetune to save time + +cd glue_bert && bash ./test_ci.sh && cd .. diff --git a/examples/tutorial/new_api/torch_ddp/README.md b/examples/tutorial/new_api/torch_ddp/README.md deleted file mode 100644 index e120bacb0c84..000000000000 --- a/examples/tutorial/new_api/torch_ddp/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# Distributed Data Parallel - -## 🚀 Quick Start - -This example provides a training script and an evaluation script. The training script provides an example of training ResNet on CIFAR10 dataset from scratch. 
- -- Training Arguments - - `-r`, `--resume`: resume from checkpoint file path - - `-c`, `--checkpoint`: the folder to save checkpoints - - `-i`, `--interval`: epoch interval to save checkpoints - - `-f`, `--fp16`: use fp16 - -- Eval Arguments - - `-e`, `--epoch`: select the epoch to evaluate - - `-c`, `--checkpoint`: the folder where checkpoints are found - - -### Train - -```bash -# train with torch DDP with fp32 -colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp32 - -# train with torch DDP with mixed precision training -colossalai run --nproc_per_node 2 train.py -c ./ckpt-fp16 --fp16 -``` - -### Eval - -```bash -# evaluate fp32 training -python eval.py -c ./ckpt-fp32 -e 80 - -# evaluate fp16 mixed precision training -python eval.py -c ./ckpt-fp16 -e 80 -``` - -Expected accuracy performance will be: - -| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | -| --------- | ------------------------ | --------------------- | --------------------- | -| ResNet-18 | 85.85% | 85.03% | 85.12% | - -**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** diff --git a/examples/tutorial/new_api/torch_ddp/train.py b/examples/tutorial/new_api/torch_ddp/train.py deleted file mode 100644 index 4741c3151cbb..000000000000 --- a/examples/tutorial/new_api/torch_ddp/train.py +++ /dev/null @@ -1,128 +0,0 @@ -import argparse -from pathlib import Path - -import torch -import torch.nn as nn -import torchvision -import torchvision.transforms as transforms -from torch.optim.lr_scheduler import MultiStepLR - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import TorchDDPPlugin -from colossalai.cluster import DistCoordinator - -# ============================== -# Parse Arguments -# ============================== -parser = argparse.ArgumentParser() -parser.add_argument('-r', 
'--resume', type=int, default=-1, help="resume from the epoch's checkpoint") -parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") -parser.add_argument('-i', '--interval', type=int, default=5, help="interval of saving checkpoint") -parser.add_argument('-f', '--fp16', action='store_true', help="use fp16") -args = parser.parse_args() - -# ============================== -# Prepare Checkpoint Directory -# ============================== -Path(args.checkpoint).mkdir(parents=True, exist_ok=True) - -# ============================== -# Prepare Hyperparameters -# ============================== -NUM_EPOCHS = 80 -LEARNING_RATE = 1e-3 -START_EPOCH = args.resume if args.resume >= 0 else 0 - -# ============================== -# Launch Distributed Environment -# ============================== -colossalai.launch_from_torch(config={}) -coordinator = DistCoordinator() - -# update the learning rate with linear scaling -# old_gpu_num / old_lr = new_gpu_num / new_lr -LEARNING_RATE *= coordinator.world_size - -# ============================== -# Prepare Booster -# ============================== -plugin = TorchDDPPlugin() -if args.fp16: - booster = Booster(mixed_precision='fp16', plugin=plugin) -else: - booster = Booster(plugin=plugin) - -# ============================== -# Prepare Train Dataset -# ============================== -transform = transforms.Compose( - [transforms.Pad(4), - transforms.RandomHorizontalFlip(), - transforms.RandomCrop(32), - transforms.ToTensor()]) - -# CIFAR-10 dataset -with coordinator.priority_execution(): - train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True) - -# ==================================== -# Prepare model, optimizer, criterion -# ==================================== -# resent50 -model = torchvision.models.resnet18(num_classes=10).cuda() - -# Loss and optimizer -criterion = nn.CrossEntropyLoss() -optimizer = 
torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) - -# lr scheduler -lr_scheduler = MultiStepLR(optimizer, milestones=[20, 40, 60, 80], gamma=1 / 3) - -# prepare dataloader with torch ddp plugin -train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=100, shuffle=True) - -# ============================== -# Resume from checkpoint -# ============================== -if args.resume >= 0: - booster.load_model(model, f'{args.checkpoint}/model_{args.resume}.pth') - booster.load_optimizer(optimizer, f'{args.checkpoint}/optimizer_{args.resume}.pth') - booster.load_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{args.resume}.pth') - -# ============================== -# Boost with ColossalAI -# ============================== -model, optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, - train_dataloader, lr_scheduler) - -# ============================== -# Train model -# ============================== -total_step = len(train_dataloader) - -for epoch in range(START_EPOCH, NUM_EPOCHS): - for i, (images, labels) in enumerate(train_dataloader): - images = images.cuda() - labels = labels.cuda() - - # Forward pass - outputs = model(images) - loss = criterion(outputs, labels) - - # Backward and optimize - optimizer.zero_grad() - booster.backward(loss, optimizer) - optimizer.step() - - if (i + 1) % 100 == 0: - print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, NUM_EPOCHS, i + 1, total_step, - loss.item())) - - lr_scheduler.step() - - # save checkpoint every 5 epoch - if (epoch + 1) % args.interval == 0: - booster.save_model(model, f'{args.checkpoint}/model_{epoch + 1}.pth') - booster.save_optimizer(optimizer, f'{args.checkpoint}/optimizer_{epoch + 1}.pth') - booster.save_lr_scheduler(lr_scheduler, f'{args.checkpoint}/lr_scheduler_{epoch + 1}.pth') diff --git a/pytest.ini b/pytest.ini index ac31ace4bfae..01e5cd217c5d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,4 +3,4 @@ markers = 
cpu: tests which can run on CPU gpu: tests which requires a single GPU dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features \ No newline at end of file + experiment: tests for experimental features diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py new file mode 100644 index 000000000000..eab949828db9 --- /dev/null +++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py @@ -0,0 +1,85 @@ +from typing import Callable, List, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader, TensorDataset + +import colossalai +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase +from colossalai.checkpoint_io import CheckpointIO +from colossalai.interface import OptimizerWrapper +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +class DPPluginWrapper(DPPluginBase): + """This is a wrapper class for testing DP plugin initialization and dataloader creation. 
+ """ + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + pass + + def control_checkpoint_io(self) -> bool: + pass + + def control_device(self) -> bool: + pass + + def control_precision(self) -> bool: + pass + + def get_checkpoint_io(self) -> CheckpointIO: + pass + + def support_no_sync(self) -> bool: + pass + + def supported_devices(self) -> List[str]: + pass + + def supported_precisions(self) -> List[str]: + pass + + +def check_dataloader_sharding(): + plugin = DPPluginWrapper() + + # create a custom dasetset with 0 to 10 + dataset = TensorDataset(torch.arange(0, 10)) + train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2) + + # get the first batch of data + batch = next(iter(train_dataloader))[0].cuda() + is_rank_0 = dist.get_rank() == 0 + + if is_rank_0: + batch_to_compare = batch.clone() + else: + batch_to_compare = batch + # pass to the rank 1 value to rank 0 + dist.broadcast(batch_to_compare, src=1) + + # compare on rank 0 + if is_rank_0: + assert not torch.equal(batch, + batch_to_compare), 'Same number was found across ranks but expected it to be different' + + +def run_dist(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + check_dataloader_sharding() + + +@rerun_if_address_is_in_use() +def test_dp_plugin_dataloader(): + spawn(run_dist, 2) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 985d7989fc9d..c7b3676fb478 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -117,34 +117,9 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' 
for k, v in failed_info.items()]) -def check_dataloader_sharding(): - plugin = GeminiPlugin() - - # create a custom dasetset with 0 to 10 - dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) - train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) - - # get the first batch of data - batch = next(iter(train_dataloader))[0].cuda() - is_rank_0 = dist.get_rank() == 0 - - if is_rank_0: - batch_to_compare = batch.clone() - else: - batch_to_compare = batch - # pass to the rank 1 value to rank 0 - dist.broadcast(batch_to_compare, src=1) - - # compare on rank 0 - if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' - - def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - check_dataloader_sharding() check_gemini_plugin(early_stop=early_stop) diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index e24196a14917..d84b96f77a75 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -83,30 +83,6 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): assert len(failed_info) == 0, '\n'.join([f'{k}: {v}' for k, v in failed_info.items()]) -def check_dataloader_sharding(): - plugin = LowLevelZeroPlugin() - - # create a custom dasetset with 0 to 10 - dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) - train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) - - # get the first batch of data - batch = next(iter(train_dataloader))[0].cuda() - is_rank_0 = dist.get_rank() == 0 - - if is_rank_0: - batch_to_compare = batch.clone() - else: - batch_to_compare = batch - # pass to the rank 1 value to rank 0 - dist.broadcast(batch_to_compare, 
src=1) - - # compare on rank 0 - if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' - - def run_dist(rank, world_size, port, early_stop: bool = True): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index 5354eae01d40..30c4db12309f 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -44,57 +44,9 @@ def check_torch_ddp_plugin(): torch.cuda.empty_cache() -def check_dataloader_sharding(): - plugin = TorchDDPPlugin() - - # create a custom dasetset with 0 to 10 - dataset = torch.utils.data.TensorDataset(torch.arange(0, 10)) - train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2) - - # get the first batch of data - batch = next(iter(train_dataloader))[0].cuda() - is_rank_0 = dist.get_rank() == 0 - - if is_rank_0: - batch_to_compare = batch.clone() - else: - batch_to_compare = batch - # pass to the rank 1 value to rank 0 - dist.broadcast(batch_to_compare, src=1) - - # compare on rank 0 - if is_rank_0: - assert not torch.equal(batch, - batch_to_compare), 'Same number was found across ranks but expected it to be different' - - -def check_checkpoint_save_and_load(): - model_fn, data_gen_fn, output_transform_fn, _ = model_zoo['timm_resnet'] - - plugin = TorchDDPPlugin() - booster = Booster(plugin=plugin) - - model = model_fn() - optimizer = SGD(model.parameters(), lr=1e-3) - criterion = lambda x: x.mean() - data = data_gen_fn() - - data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} - - model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - - output = model(**data) - output = output_transform_fn(output) - output_key = 
list(output.keys())[0] - loss = criterion(output[output_key]) - - booster.backward(loss, optimizer) - - def run_dist(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - check_dataloader_sharding() check_torch_ddp_plugin() diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py index ca5ce10054f7..752ca706bfd4 100644 --- a/tests/test_checkpoint_io/test_general_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py @@ -1,16 +1,21 @@ import tempfile import pytest import torch -import logging from torch.optim import Adam from torchvision.models import resnet18 -from pathlib import Path -import os -import subprocess from colossalai.checkpoint_io import GeneralCheckpointIO +from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO from colossalai.testing import clear_cache_before_run, parameterize +import colossalai +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from tests.components_to_test.registry import non_distributed_component_funcs + # ======== # Note: # 1. 
due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now @@ -83,7 +88,6 @@ def test_sharded_checkpoint(use_safetensors: bool): suffix = ".bin" WEIGHTS_INDEX_NAME = "model.bin.index.json" - # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix) model_ckpt_dir = tempfile.TemporaryDirectory() optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile() @@ -104,6 +108,87 @@ def test_sharded_checkpoint(use_safetensors: bool): recursive_check(model.state_dict(), new_model.state_dict()) recursive_check(optimizer.state_dict(), new_optimizer.state_dict()) +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['bert']) +@parameterize('use_safetensors', [True, False]) +def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool): + from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification + + model_ckpt_dir = tempfile.TemporaryDirectory() + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, *_ = get_components_func() + + with ColoInitContext(device=get_current_device()): + bert_model = model_builder() + bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name) + config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + bert_model = ZeroDDP(bert_model, gemini_manager) + bert_model.train() + + ckpt_io = GeminiCheckpointIO() + if ckpt_io.coordinator.is_master(): + model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 + ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors) + new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name) + recursive_check(bert_model.state_dict(only_rank_0=True, 
dtype=torch.float32), new_bert_model.state_dict()) + + model_ckpt_dir.cleanup() + + + +@parameterize('placement_policy', ['cuda', 'cpu']) +@parameterize('model_name', ['gpt2', 'bert']) +@parameterize('use_safetensors', [True, False]) +def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool): + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, *_ = get_components_func() + + with ColoInitContext(device=get_current_device()): + model = model_builder() + new_model = model_builder() + + config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100) + chunk_manager = ChunkManager(config_dict) + gemini_manager = GeminiManager(placement_policy, chunk_manager) + model = ZeroDDP(model, gemini_manager) + model.train() + + new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100) + new_chunk_manager = ChunkManager(new_config_dict) + new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager) + new_model = ZeroDDP(new_model, new_gemini_manager) + + model_ckpt_dir = tempfile.TemporaryDirectory() + + ckpt_io = GeminiCheckpointIO() + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors) + + # load model + if ckpt_io.coordinator.is_master(): + ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True) + model_dict = model.state_dict(only_rank_0=True) + new_model_dict = new_model.state_dict(only_rank_0=True) + recursive_check(model_dict, new_model_dict) + + model_ckpt_dir.cleanup() + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + exam_state_dict() + hf_load_colossalai_checkpoint() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4, 4]) 
+@rerun_if_address_is_in_use() +def test_gemini_ckpIO(world_size): + spawn(run_dist, world_size) + # do recursive check for the optimizer state dict # if the value is a dict, compare its values @@ -117,10 +202,14 @@ def recursive_check(d1, d2): elif isinstance(v, list): for i in range(len(v)): if isinstance(v[i], torch.Tensor): + v[i] = v[i].to("cpu") + d2[k][i] = d2[k][i].to("cpu") assert torch.equal(v[i], d2[k][i]) else: assert v[i] == d2[k][i] elif isinstance(v, torch.Tensor): + v = v.to("cpu") + d2[k] = d2[k].to("cpu") assert torch.equal(v, d2[k]) else: assert v == d2[k] diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py index 96c26a1de4df..ad7d3a5a4859 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict_shard.py @@ -31,14 +31,13 @@ def exam_state_dict(placement_policy, model_name: str): zero_dict = model.state_dict(only_rank_0=False) accumulated_keys = set() # ensure number of shards > 1 - for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): + for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): for key, value in shard.items(): assert key not in accumulated_keys, f"key `{key}` is duplicated." accumulated_keys.add(key) assert key in zero_dict, f"{key} not in ZeRO dictionary." assert torch.equal(value, zero_dict[key]), f"{key} not equal." 
- def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py index 61d850d06080..0223f18c29d6 100644 --- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py @@ -20,7 +20,7 @@ def run_dist(rank, world_size, port): # need to configure cudnn deterministic so that # randomness of convolution layers will be disabled zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) - colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), + colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False), rank=rank, world_size=world_size, host='localhost',