From fd4e0a02b199f97cc0ceebbc58a7af0018e65087 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 1 Sep 2023 10:59:20 +0800 Subject: [PATCH 1/2] code factors --- .../tensor_parallel/modeling/bloom.py | 5 +- .../tensor_parallel/modeling/llama.py | 5 +- .../tensor_parallel/policies/llama.py | 8 +- .../kernel/triton/self_attention_nofusion.py | 69 ++-- examples/language/bert/bert_finetune.py | 311 ++++++++++++++++++ tests/test_infer/test_llama_infer.py | 31 +- 6 files changed, 378 insertions(+), 51 deletions(-) create mode 100644 examples/language/bert/bert_finetune.py diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 0fd08d3721e6..a6ee58f1e00d 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -197,7 +197,7 @@ def bloom_model_forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - # FIXME: currently our KV cache manager does not handle this condition + # NOTE: currently our KV cache manager does not handle this condition def create_custom_forward(module): def custom_forward(*inputs): @@ -240,7 +240,8 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) # update indices of kv cache block - # TODO: might want to remove this part, instead, better to pass the BatchInferState from model forward, + # NOT READY FOR PRIME TIME + # might want to remove this part, instead, better to pass the BatchInferState from model forward, # and update these information in engine.generate after model foward called infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 1d9e366f99f3..94a13b968d0d 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -77,12 +77,13 @@ def llama_model_forward( past_key_values_length = 0 if past_key_values is not None: - # TODO dummy but work, revise it + # NOT READY FOR PRIME TIME + # dummy but work, revise it past_key_values_length = infer_state.cache_manager.past_key_values_length # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - # FIXME: differentiate with prefill stage + # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage if use_cache and seq_length != 1: # NOTE assuem prefill stage diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index c569a0e3163a..bbd2156b8523 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -1,10 +1,10 @@ from functools import partial + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy -from ..modeling.llama import LlamaInferenceForwards -from ..modeling.llama import get_llama_vllm_rmsnorm_forward +from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward class LlamaModelInferPolicy(LlamaForCausalLMPolicy): @@ -37,8 +37,8 @@ def module_policy(self): 
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) - - # TODO: adding rms_norm caused precision issue, fix @tiandiao123 + + # NOTE: adding rms_norm caused precision issue, fix @tiandiao123 # infer_forward = get_llama_vllm_rmsnorm_forward() # if infer_forward is not None: # method_replacement = {'forward': partial(infer_forward)} diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index a6c9bdfbdff6..6ae54dcb0b38 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -13,8 +13,9 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel - def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): - r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels + def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + input_mask: torch.Tensor, scale: float): + r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels Args: q (torch.Tensor): Q embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) k (torch.Tensor): K embedding in attention layer, shape should be (batch, seq_len, num_heads, head_size) @@ -36,39 +37,49 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t # head_size * num_of_head d_model = q.shape[-1] * q.shape[-2] - score_output = torch.empty( - (batches, H, M, N), device=q.device, dtype=q.dtype) + score_output = torch.empty((batches, H, M, N), device=q.device, dtype=q.dtype) grid = lambda meta: ( batches, H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * - triton.cdiv(N, meta["BLOCK_SIZE_N"]), + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) qkv_gemm_4d_kernel[grid]( - q, k, score_output, - M, N, K, - q.stride(0), q.stride(2), q.stride(1), q.stride(3), - k.stride(0), k.stride(2), k.stride(3), k.stride(1), - score_output.stride(0), score_output.stride(1), score_output.stride(2), score_output.stride(3), + q, + k, + score_output, + M, + N, + K, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + score_output.stride(0), + score_output.stride(1), + score_output.stride(2), + score_output.stride(3), scale=scale, - # currently manually setting, later on we can use auto-tune config to match best setting + # currently manually setting, later on we can use auto-tune config to match best setting BLOCK_SIZE_M=64, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32, GROUP_SIZE_M=8, ) - - softmax_output = torch.empty( - score_output.shape, device=score_output.device, dtype=score_output.dtype) + + softmax_output = torch.empty(score_output.shape, device=score_output.device, dtype=score_output.dtype) score_output_shape = score_output.shape score_output = score_output.view(-1, score_output.shape[-1]) n_rows, n_cols = score_output.shape if n_rows <= 350000: - + block_size = max(triton.next_power_of_2(n_cols), 2) num_warps = 4 if block_size >= 4096: @@ -78,37 +89,39 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t else: num_warps = 4 - softmax_kernel[(n_rows, )]( + softmax_kernel[(n_rows,)]( softmax_output, score_output, score_output.stride(0), n_cols, - mask_ptr = input_mask, + mask_ptr=input_mask, num_warps=num_warps, 
BLOCK_SIZE=block_size, ) else: - #TODO: change softmax kernel functions to make it suitable for large size dimension + # NOTE: change softmax kernel functions to make it suitable for large size dimension softmax_output = torch.nn.functional.softmax(score_output, dim=-1) softmax_output = softmax_output.view(*score_output_shape) batches, H, M, K = softmax_output.shape N = v.shape[-1] - output = torch.empty( - (batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) + output = torch.empty((batches, M, H, N), device=softmax_output.device, dtype=softmax_output.dtype) grid = lambda meta: ( batches, H, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * - triton.cdiv(N, meta["BLOCK_SIZE_N"]), + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), ) qkv_gemm_4d_kernel[grid]( - softmax_output, v, output, - M, N, K, + softmax_output, + v, + output, + M, + N, + K, softmax_output.stride(0), softmax_output.stride(1), softmax_output.stride(2), @@ -129,7 +142,6 @@ def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: t ) return output.view(batches, -1, d_model) - def self_attention_compute_using_triton(qkv, input_mask, layer_past, @@ -152,7 +164,6 @@ def self_attention_compute_using_triton(qkv, k = k.view(batches, -1, num_of_heads, head_size) v = v.view(batches, -1, num_of_heads, head_size) - data_output_triton = self_attention_forward_without_fusion( - q, k, v, input_mask, scale) + data_output_triton = self_attention_forward_without_fusion(q, k, v, input_mask, scale) - return data_output_triton \ No newline at end of file + return data_output_triton diff --git a/examples/language/bert/bert_finetune.py b/examples/language/bert/bert_finetune.py new file mode 100644 index 000000000000..06e2cee40599 --- /dev/null +++ b/examples/language/bert/bert_finetune.py @@ -0,0 +1,311 @@ +import argparse +import copy +from contextlib import nullcontext +from typing import List, Union + +import evaluate +import torch +import torch.distributed as dist +import torch.nn as nn +import transformers +from data import GLUEDataBuilder +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.optim import Adam, Optimizer +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import ( + AlbertForSequenceClassification, + AutoConfig, + BertForSequenceClassification, + get_linear_schedule_with_warmup, +) + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + +# ============================== +# Prepare Hyperparameters +# ============================== +NUM_EPOCHS = 1 +BATCH_SIZE = 32 +LEARNING_RATE = 2.4e-5 +WEIGHT_DECAY = 0.01 +WARMUP_FRACTION = 0.1 + + +def move_to_cuda(batch): + return {k: v.cuda() for k, v in batch.items()} + + +@torch.no_grad() +def evaluate_model(model: nn.Module, optimizer, criterion, test_dataloader: Union[DataLoader, List[DataLoader]], + num_labels: int, task_name: str, eval_splits: List[str], coordinator: DistCoordinator, booster): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=get_current_device()) + for batch in dataloader: + batch = 
move_to_cuda(batch) + labels = batch["labels"] + + if booster.plugin.stage_manager is not None: + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + # outputs = model(**batch) + if booster.plugin.stage_manager.is_last_stage(): + print(outputs) + val_loss = outputs["loss"] + # val_loss, logits = outputs[:2] + logits = outputs["outputs"][0].logits + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + print(dist.get_rank()) + + metric.add_batch(predictions=preds, references=labels) + else: + batch = move_to_cuda(batch) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + # labels = batch["labels"] + + metric.add_batch(predictions=preds, references=labels) + + # results = metric.compute() + results = None + if booster.plugin.stage_manager is not None: + if booster.plugin.stage_manager.is_last_stage(): + results = metric.compute() + else: + results = metric.compute() + # if booster.plugin.stage_manager is not None: + # group = booster.plugin. + # dist.all_reduce(accum_loss.div_(len(dataloader))) + if coordinator.is_master() and results is not None: + results['loss'] = accum_loss.item() / coordinator.world_size + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f'{k}_{split}': v for k, v in results.items()}) + return final_results + + +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion, lr_scheduler, + train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + + model.train() + with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: + for batch in pbar: + # Forward pass + batch = move_to_cuda(batch) + # print("batch:" + str(batch)) + # # outputs = model(**batch) + if booster.plugin.stage_manager is not None: + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + # Backward and optimize + loss = outputs['loss'] + if booster.plugin.stage_manager.is_last_stage(): + print("aaaaaaaaaaaaaaa" + str(loss)) + pbar.set_postfix({'loss': loss}) + else: + outputs = model(**batch) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + pbar.set_postfix({'loss': loss}) + + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step() + + +def main(): + # ============================== + # Parse Arguments + # ============================== + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], + help="plugin to use") + parser.add_argument( + "--model_type", + type=str, + default="bert", + help="bert or albert", + ) + parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. 
Raise exception if not reached") + parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") + args = parser.parse_args() + + if args.model_type == 'bert': + model_name = "bert-base-uncased" + elif args.model_type == 'albert': + model_name = "albert-xxlarge-v2" + else: + raise RuntimeError + # ============================== + # Launch Distributed Environment + # ============================== + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # local_batch_size = BATCH_SIZE // coordinator.world_size + lr = LEARNING_RATE * coordinator.world_size + + # ============================== + # Instantiate Plugin and Booster + # ============================== + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + + # modify the param accordingly for finetuning test cases + plugin = HybridParallelPlugin( + tp_size=1, + pp_size=2, + num_microbatches=2, + enable_all_optimization=True, + #enable_sequence_parallelism=True, + zero_stage=1, + precision='fp16', + initial_scale=1) + + booster = Booster(plugin=plugin, **booster_kwargs) + + # ============================== + # Prepare Dataloader + # ============================== + data_builder = GLUEDataBuilder(model_name, + plugin, + args.task, + train_batch_size=BATCH_SIZE, + eval_batch_size=BATCH_SIZE) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + # ==================================== + # Prepare model, optimizer + # ==================================== + # bert pretrained model + + cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) + if model_name == "bert-base-uncased": + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + elif model_name == "albert-xxlarge-v2": + model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) + else: + raise RuntimeError + + # optimizer + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": WEIGHT_DECAY, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = Adam(model.parameters(), lr=lr, eps=1e-8) + + # lr scheduler + total_steps = len(train_dataloader) * NUM_EPOCHS + num_warmup_steps = int(WARMUP_FRACTION * total_steps) + lr_scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=total_steps, + ) + + # lazy_init + use_lazy_init = args.use_lazy_init + ctx = LazyInitContext() if use_lazy_init else nullcontext() + # with ctx: + # org_model = model + # sharded_model = copy.deepcopy(org_model) + # if use_lazy_init: + # ctx.materialize(org_model) + + output_transform_fn = lambda x: x + criterion = lambda x: x.loss + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + # ============================== + # Boost with ColossalAI + # ============================== + print("===before 
boost===") + # print(optimizer.param_groups[0]['params'][0].dtype) + model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, + optimizer, + criterion=_criterion, + lr_scheduler=lr_scheduler) + + # print(model) + # ============================== + # Train model + # ============================== + for epoch in range(NUM_EPOCHS): + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) + + results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task, + data_builder.eval_splits, coordinator, booster) + + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and 'f1' in results: + assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +if __name__ == '__main__': + main() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 55576e55fd2d..c8f852aef420 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -1,17 +1,17 @@ import os +import numpy as np import pytest import torch -import numpy as np +import torch.distributed as dist +from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai -from colossalai.logging import disable_existing_loggers -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn -from transformers import LlamaForCausalLM, LlamaTokenizer from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.inference.tensor_parallel.engine import TPInferEngine -import torch.distributed as dist +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' TPSIZE = 2 @@ -19,20 +19,22 @@ MAX_INPUT_LEN = 12 MAX_OUTPUT_LEN = 100 + def init_to_get_rotary(self, base=10000): self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads if not hasattr(self.config, "rope_scaling"): rope_scaling_factor = 1.0 else: rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 - if hasattr(self.config,"max_sequence_length"): + if hasattr(self.config, "max_sequence_length"): max_seq_len = self.config.max_sequence_length - elif hasattr(self.config,"max_position_embeddings"): + elif hasattr(self.config, "max_position_embeddings"): max_seq_len = self.config.max_position_embeddings * rope_scaling_factor else: - max_seq_len = 2048 * rope_scaling_factor + max_seq_len = 2048 * rope_scaling_factor base = float(base) - inv_freq = 1.0 / (base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_)) + inv_freq = 1.0 / (base**(torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / + self.config.head_dim_)) t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -40,21 +42,22 @@ def init_to_get_rotary(self, base=10000): self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() return + @parameterize('test_config', [{ 'tp_size': TPSIZE, }]) def run_llama_test(test_config): - + llama_model_path = "/data/scratch/llama-7b-hf" tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) 
tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) init_to_get_rotary(model.model, base=10000) model = model.half() - + text = "how is weather today?" input_ids = tokenizer.encode(text, return_tensors='pt', device='cuda') - + infer_engine = TPInferEngine(model.half(), BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) shardformer = ShardFormer(shard_config=shard_config) @@ -65,7 +68,7 @@ def run_llama_test(test_config): generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) print("outputs.shape: ", outputs.shape) - + print("outputs: ", outputs) output_text = tokenizer.decode(outputs[0]) From ccbdf323145fc263b3434fb38664dee6df800944 Mon Sep 17 00:00:00 2001 From: CjhHa1 Date: Fri, 1 Sep 2023 11:01:45 +0800 Subject: [PATCH 2/2] remove --- examples/language/bert/bert_finetune.py | 311 ------------------------ 1 file changed, 311 deletions(-) delete mode 100644 examples/language/bert/bert_finetune.py diff --git a/examples/language/bert/bert_finetune.py b/examples/language/bert/bert_finetune.py deleted file mode 100644 index 06e2cee40599..000000000000 --- a/examples/language/bert/bert_finetune.py +++ /dev/null @@ -1,311 +0,0 @@ -import argparse -import copy -from contextlib import nullcontext -from typing import List, Union - -import evaluate -import torch -import torch.distributed as dist -import torch.nn as nn -import transformers -from data import GLUEDataBuilder -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.optim import Adam, Optimizer -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import ( - AlbertForSequenceClassification, - AutoConfig, - BertForSequenceClassification, - get_linear_schedule_with_warmup, -) - -import colossalai -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin -from colossalai.cluster import DistCoordinator -from colossalai.lazy import LazyInitContext -from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device - -# ============================== -# Prepare Hyperparameters -# ============================== -NUM_EPOCHS = 1 -BATCH_SIZE = 32 -LEARNING_RATE = 2.4e-5 -WEIGHT_DECAY = 0.01 -WARMUP_FRACTION = 0.1 - - -def move_to_cuda(batch): - return {k: v.cuda() for k, v in batch.items()} - - -@torch.no_grad() -def evaluate_model(model: nn.Module, optimizer, criterion, test_dataloader: Union[DataLoader, List[DataLoader]], - num_labels: int, task_name: str, eval_splits: List[str], coordinator: DistCoordinator, booster): - metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) - model.eval() - - def evaluate_subset(dataloader: DataLoader): - accum_loss = torch.zeros(1, device=get_current_device()) - for batch in dataloader: - batch = move_to_cuda(batch) - labels = batch["labels"] - - if booster.plugin.stage_manager is not None: - batch = iter([batch]) - outputs = booster.execute_pipeline(batch, - model, - criterion, - optimizer, - return_loss=True, - return_outputs=True) - # outputs = model(**batch) - if booster.plugin.stage_manager.is_last_stage(): - print(outputs) - val_loss = outputs["loss"] - # val_loss, logits = outputs[:2] - logits = outputs["outputs"][0].logits - 
accum_loss.add_(val_loss) - - if num_labels > 1: - preds = torch.argmax(logits, axis=1) - elif num_labels == 1: - preds = logits.squeeze() - print(dist.get_rank()) - - metric.add_batch(predictions=preds, references=labels) - else: - batch = move_to_cuda(batch) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss.add_(val_loss) - - if num_labels > 1: - preds = torch.argmax(logits, axis=1) - elif num_labels == 1: - preds = logits.squeeze() - - # labels = batch["labels"] - - metric.add_batch(predictions=preds, references=labels) - - # results = metric.compute() - results = None - if booster.plugin.stage_manager is not None: - if booster.plugin.stage_manager.is_last_stage(): - results = metric.compute() - else: - results = metric.compute() - # if booster.plugin.stage_manager is not None: - # group = booster.plugin. - # dist.all_reduce(accum_loss.div_(len(dataloader))) - if coordinator.is_master() and results is not None: - results['loss'] = accum_loss.item() / coordinator.world_size - return results - - if isinstance(test_dataloader, DataLoader): - return evaluate_subset(test_dataloader) - else: - assert len(test_dataloader) == len(eval_splits) - final_results = {} - for split, sub_loader in zip(eval_splits, test_dataloader): - results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) - return final_results - - -def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion, lr_scheduler, - train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - - model.train() - with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar: - for batch in pbar: - # Forward pass - batch = move_to_cuda(batch) - # print("batch:" + str(batch)) - # # outputs = model(**batch) - if booster.plugin.stage_manager is not None: - batch = iter([batch]) - outputs = booster.execute_pipeline(batch, - model, - _criterion, - optimizer, - return_loss=True, - return_outputs=True) - # Backward and optimize - loss = outputs['loss'] - if booster.plugin.stage_manager.is_last_stage(): - print("aaaaaaaaaaaaaaa" + str(loss)) - pbar.set_postfix({'loss': loss}) - else: - outputs = model(**batch) - loss = _criterion(outputs, None) - # Backward - booster.backward(loss, optimizer) - pbar.set_postfix({'loss': loss}) - - optimizer.step() - optimizer.zero_grad() - lr_scheduler.step() - - -def main(): - # ============================== - # Parse Arguments - # ============================== - parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('-p', - '--plugin', - type=str, - default='torch_ddp', - choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'], - help="plugin to use") - parser.add_argument( - "--model_type", - type=str, - default="bert", - help="bert or albert", - ) - parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. 
Raise exception if not reached") - parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context") - args = parser.parse_args() - - if args.model_type == 'bert': - model_name = "bert-base-uncased" - elif args.model_type == 'albert': - model_name = "albert-xxlarge-v2" - else: - raise RuntimeError - # ============================== - # Launch Distributed Environment - # ============================== - colossalai.launch_from_torch(config={}, seed=42) - coordinator = DistCoordinator() - - # local_batch_size = BATCH_SIZE // coordinator.world_size - lr = LEARNING_RATE * coordinator.world_size - - # ============================== - # Instantiate Plugin and Booster - # ============================== - booster_kwargs = {} - if args.plugin == 'torch_ddp_fp16': - booster_kwargs['mixed_precision'] = 'fp16' - if args.plugin.startswith('torch_ddp'): - plugin = TorchDDPPlugin() - elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) - elif args.plugin == 'low_level_zero': - plugin = LowLevelZeroPlugin(initial_scale=2**5) - elif args.plugin == 'hybrid_parallel': - - # modify the param accordingly for finetuning test cases - plugin = HybridParallelPlugin( - tp_size=1, - pp_size=2, - num_microbatches=2, - enable_all_optimization=True, - #enable_sequence_parallelism=True, - zero_stage=1, - precision='fp16', - initial_scale=1) - - booster = Booster(plugin=plugin, **booster_kwargs) - - # ============================== - # Prepare Dataloader - # ============================== - data_builder = GLUEDataBuilder(model_name, - plugin, - args.task, - train_batch_size=BATCH_SIZE, - eval_batch_size=BATCH_SIZE) - train_dataloader = data_builder.train_dataloader() - test_dataloader = data_builder.test_dataloader() - - # ==================================== - # Prepare model, optimizer - # ==================================== - # bert pretrained model - - cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) - if model_name == "bert-base-uncased": - model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() - elif model_name == "albert-xxlarge-v2": - model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) - else: - raise RuntimeError - - # optimizer - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": WEIGHT_DECAY, - }, - { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - - optimizer = Adam(model.parameters(), lr=lr, eps=1e-8) - - # lr scheduler - total_steps = len(train_dataloader) * NUM_EPOCHS - num_warmup_steps = int(WARMUP_FRACTION * total_steps) - lr_scheduler = get_linear_schedule_with_warmup( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=total_steps, - ) - - # lazy_init - use_lazy_init = args.use_lazy_init - ctx = LazyInitContext() if use_lazy_init else nullcontext() - # with ctx: - # org_model = model - # sharded_model = copy.deepcopy(org_model) - # if use_lazy_init: - # ctx.materialize(org_model) - - output_transform_fn = lambda x: x - criterion = lambda x: x.loss - - def _criterion(outputs, inputs): - outputs = output_transform_fn(outputs) - loss = criterion(outputs) - return loss - - # ============================== - # Boost with ColossalAI - # ============================== - print("===before 
boost===") - # print(optimizer.param_groups[0]['params'][0].dtype) - model, optimizer, _criterion, _, lr_scheduler = booster.boost(model, - optimizer, - criterion=_criterion, - lr_scheduler=lr_scheduler) - - # print(model) - # ============================== - # Train model - # ============================== - for epoch in range(NUM_EPOCHS): - train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - - results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task, - data_builder.eval_splits, coordinator, booster) - - if coordinator.is_master(): - print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' - - -if __name__ == '__main__': - main()