From dc0dc0b91d215f000a2e98037000e5bee8cc974b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 22 Sep 2023 16:32:09 +0800 Subject: [PATCH 01/46] [colossalai]fix typo --- .../quant/gptq/cai_gptq/cai_quant_linear.py | 202 +++++----- .../tensor_parallel/policies/bloom.py | 58 +-- .../cuda_native/csrc/gptq/linear_gptq.cpp | 362 ++++++++---------- .../kernel/cuda_native/csrc/gptq/q4_matrix.cu | 2 +- .../cuda_native/csrc/gptq/q4_matrix.cuh | 2 +- examples/inference/gptq_bloom.py | 43 +-- examples/inference/gptq_llama.py | 24 ++ tests/test_gptq/test_gptq_linear.py | 30 +- 8 files changed, 344 insertions(+), 379 deletions(-) diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py index ca12c34ed958..36339ac88486 100644 --- a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -18,15 +18,15 @@ HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False class CaiQuantLinear(nn.Module): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): super().__init__() if bits not in [2, 4, 8]: @@ -37,23 +37,28 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp self.maxq = 2**self.bits - 1 self.groupsize = groupsize if groupsize != -1 else infeatures - self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer( + "qzeros", + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), + ) self.register_buffer( - 'qzeros', - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) - self.register_buffer('scales', - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16) + ) if row_split: self.register_buffer( - 'g_idx', - torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], - dtype=torch.int32)) + "g_idx", + torch.tensor( + [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32 + ), + ) else: - self.register_buffer('g_idx', - torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) + self.register_buffer( + "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) + ) if bias: - self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) else: self.bias = None @@ -66,9 +71,11 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp self.row_split = row_split def pack(self, linear, scales, zeros, g_idx=None): - - g_idx = g_idx.clone() if g_idx is not None else torch.tensor( - [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + g_idx = ( + g_idx.clone() + if g_idx is not None + else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) + ) scales = scales.t().contiguous() zeros = zeros.t().contiguous() @@ -79,7 +86,6 @@ def pack(self, linear, scales, zeros, g_idx=None): if linear.bias is not None: self.bias = linear.bias.clone().half() - wn = 8 pbits = 32 ptype = torch.int32 unsign_type = np.uint32 @@ -88,9 +94,10 @@ def pack(self, linear, scales, zeros, g_idx=None): intweight = [] for idx in range(self.infeatures): intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, - None]) + torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[ + :, None + ] + ) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(unsign_type) @@ -109,7 +116,7 @@ def pack(self, linear, scales, zeros, g_idx=None): raise NotImplementedError("Only 2,4,8 bits are supported.") qweight = qweight.astype(sign_type) qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous() #.to("cuda") + qweight1 = qweight1.contiguous() # .to("cuda") self.qweight.data.copy_(qweight1) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) @@ -140,17 +147,20 @@ def init_q4(self): self.q4_width = self.qweight.shape[1] if self.g_idx is not None: if self.row_split and torch.equal( - self.g_idx, - torch.tensor( - [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device, + ), + ): self.g_idx = None elif torch.equal( - self.g_idx, - torch.tensor([i // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device)): + self.g_idx, + torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device + ), + ): self.g_idx = None if self.g_idx is not None: @@ -165,7 +175,6 @@ def forward(self, x): outshape = x.shape[:-1] + (self.outfeatures,) if HAS_GPTQ_CUDA and self.bits == 4: - if self.q4 is None: self.init_q4() @@ -191,7 +200,6 @@ def forward(self, x): def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) @@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1 zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num for i in range(split_num): - cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * - cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] - cai_linear.qzeros[:, i * zero_split_block:(i + 1) * - zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] - cai_linear.scales[:, i * cai_split_out_features:(i + 1) * - cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] + cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] + cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][ + :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block + ] + cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][ + :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] if cai_linear.bias is not None: - cai_linear.bias[i * cai_split_out_features:(i + 1) * - cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * - cai_split_out_features] + cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][ + tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features + ] cai_linear.g_idx.copy_(g_idx) def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): - qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) @@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): idx_split_features = cai_linear.infeatures // split_num for i in range(split_num): - cai_linear.qweight[i * cai_split_in_features:(i + 1) * - cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * - cai_split_in_features, :] - cai_linear.qzeros[i * zero_split_block:(i + 1) * - zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * - zero_split_block, :] - cai_linear.scales[i * zero_split_block:(i + 1) * - zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * - zero_split_block, :] - cai_linear.g_idx[i * idx_split_features:(i + 1) * - idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * - idx_split_features] + cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][ + tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, : + ] + cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][ + tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : + ] + cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][ + tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features + ] if cai_linear.bias is not None: cai_linear.bias.copy_(gptq_linear.bias) class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - - super().__init__(bits, - groupsize, - infeatures, - outfeatures, - bias, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=row_split) + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) self.process_group = None @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -282,15 +283,18 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - linear_1d = RowCaiQuantLinear(module.bits, - module.group_size, - module.in_features // tp_size, - module.out_features, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=True) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = RowCaiQuantLinear( + module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True, + ) linear_1d.process_group = process_group split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) @@ -306,30 +310,23 @@ def forward(self, x): class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): - def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - - super().__init__(bits, - groupsize, - infeatures, - outfeatures, - bias, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=row_split) + super().__init__( + bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split + ) self.process_group = None @staticmethod - def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, - **kwargs) -> ParallelModule: + def from_native_module( + module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs + ) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, \ - f'Expected only one process group, got {len(process_group)}.' + assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -340,14 +337,17 @@ def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, Lis if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") - linear_1d = ColCaiQuantLinear(module.bits, - module.group_size, - module.in_features, - module.out_features // tp_size, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" + ) + linear_1d = ColCaiQuantLinear( + module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + ) linear_1d.process_group = process_group split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index 3d6df2097000..fba83a08175d 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -4,7 +4,6 @@ from torch.nn import LayerNorm import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy @@ -40,33 +39,36 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.inference_gptq: from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 3}), - SubModuleReplacementDescription( - suffix="self_attention.dense", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", - target_module=ColCaiQuantLinear, - kwargs={'split_num': 1}), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", - target_module=RowCaiQuantLinear, - kwargs={'split_num': 1}), - ]) + + policy[BloomBlock] = ModulePolicyDescription( + attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size + // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 3}, + ), + SubModuleReplacementDescription( + suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1} + ), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} + ), + ], + ) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp index bcc0e43901de..8f17723cbd1b 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -1,254 +1,202 @@ // Adapted from turboderp exllama: https://github.com/turboderp/exllama -#include -#include #include -#include +#include #include +#include +#include + #include #include -#include "util.cuh" -#include "tuning.h" + +#include "column_remap.cuh" #include "cuda_buffers.cuh" -#include "q4_matrix.cuh" #include "q4_matmul.cuh" -#include "column_remap.cuh" +#include "q4_matrix.cuh" +#include "tuning.h" +#include "util.cuh" -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } +// Check CUDA return code. We don't want to include Torch headers in the .cu +// files because parsing them adds almost a minute to the compile time on a +// 12900K. Also passing exceptions back to Python is super tricky, so in place +// of exceptions, CUDA functions return with a cudaError_t which we can parse +// and dump to the console. + +void check_cuda(cudaError_t ret) { + switch (ret) { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); + printf(" **** %s\n", cudaGetErrorString(ret)); + TORCH_CHECK(false, "CUDA error"); + break; + } } // Some decluttering macros #define STRINGIFY_(__x) #__x #define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ +#define TORCH_CHECK_DTYPE(__x, __dtype) \ + TORCH_CHECK((__x).dtype() == torch::__dtype, \ + #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \ + TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \ + #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \ + TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \ + #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \ + TORCH_CHECK((__x).device().is_meta() || \ + (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \ + #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \ + TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \ + #__x ".shape[" STRINGIFY( \ + __dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \ + TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ + do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) + } while (0) #define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; + do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ + } while (0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) { + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, + "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; } - // Tuning parameters ExLlamaTuning tuningParams; -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; +void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap, + bool matmul_no_half2) { + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; } - // Release all unmanaged objects allocated by the extension -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); +void cleanup() { + cleanup_buffers_cuda(); + g_q4_free_matrices(); } - // Prepare buffers for forward pass -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); - - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} +void prepare_buffers(torch::Device device, torch::Tensor temp_state, + torch::Tensor temp_dq) { + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + prepare_buffers_cuda(device_index, + // buffer size used for sanity checks + temp_state.numel(), (half*)temp_state.data_ptr(), + (half*)temp_dq.data_ptr()); +} // Create Q4Matrix, return handle -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), - - device - ); - - g_q4_keep_matrix(m); - return reinterpret_cast (m); -} +uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros, + torch::Tensor scales, torch::Tensor g_idx, int device) { + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); -// Matmul half @ quant -> half + Q4Matrix* m = new Q4Matrix( + height, width, groups, + + (uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(), + (half*)scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(), -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } + device); + + g_q4_keep_matrix(m); + return reinterpret_cast(m); } +// Matmul half @ quant -> half + +void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) { + Q4Matrix* wm = reinterpret_cast(w); -// Remap columns in half tensor + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || + x_height < tuningParams.matmul_recons_thd) { + q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm, + (half*)out.data_ptr()); + } else { + q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm, + (half*)out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle()); + } } +// Remap columns in half tensor + +void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) { + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width, + (uint32_t*)x_map.data_ptr()); +} -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); } diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu index 9c61143f565e..bd595ee6f86c 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu @@ -184,7 +184,7 @@ __global__ void reconstruct_kernel int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; if (column >= width) return; - + // Views MatrixView_q4_column w_(w, height, width); diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh index 50cb72a41518..49431dc95876 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh @@ -50,4 +50,4 @@ private: void g_q4_keep_matrix(Q4Matrix* m); void g_q4_free_matrices(); -#endif \ No newline at end of file +#endif diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 43e118cc0aa5..f5413e31682d 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -1,12 +1,10 @@ import argparse -import logging import os import time import torch -from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig -from auto_gptq.nn_modules.qlinear import GeneralQuantLinear -from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer +from auto_gptq import AutoGPTQForCausalLM +from transformers import BloomTokenizerFast import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine @@ -14,7 +12,7 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" def print_perf_stats(latency_set, config, bs, warmup=3): @@ -28,7 +26,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): avg = sum(latency_set) / count num_layers = getattr(config, "num_layers", config.num_hidden_layers) num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 + num_bytes = 2 # float16 print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) @@ -37,7 +35,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3): def bench_bloom(args): - pretrained_model_dir = args.path quantized_model_dir = args.quantized_path max_batch_size = args.batch_size @@ -48,9 +45,9 @@ def bench_bloom(args): tokenizer.pad_token = tokenizer.eos_token # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, - device=torch.cuda.current_device(), - inject_fused_attention=False) + model = AutoGPTQForCausalLM.from_quantized( + quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False + ) model = model.half() @@ -60,22 +57,22 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, - inference_only=True, - inference_gptq=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)) + "attention_mask": torch.ones((max_batch_size, max_input_len)), } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -99,7 +96,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") bench_bloom(args) @@ -111,12 +108,12 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 1bdee448c742..bbfbf1bc8b43 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -16,6 +16,30 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" +def init_to_get_rotary(self, base=10000): + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.config.head_dim_, 2, device="cpu", dtype=torch.float32) / self.config.head_dim_) + ) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() + return + + def print_perf_stats(latency_set, config, bs, warmup=3): # trim warmup queries latency_set = list(latency_set) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index 9b650aa78112..ded70fa43c30 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -1,16 +1,8 @@ -import math -import time - -import numpy as np import pytest import torch -import torch.nn as nn -import transformers from packaging import version try: - import triton - import triton.language as tl HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -22,6 +14,7 @@ from exllama_kernels import prepare_buffers, set_tuning_params from colossalai.inference.quant.gptq import CaiQuantLinear + HAS_AUTO_GPTQ = True except: HAS_AUTO_GPTQ = False @@ -32,13 +25,14 @@ HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") max_inner_outer_dim = 1 max_input_len = 1 @@ -64,9 +58,9 @@ def init_buffer(cai_linear, use_act_order=False): max_input_len = 4096 # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), - dtype=torch.float16, - device=torch.cuda.current_device()) + gptq_temp_state_buffer = torch.zeros( + (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) @@ -77,10 +71,11 @@ def init_buffer(cai_linear, use_act_order=False): gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, - reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq", +) def test_gptq_linear(): - infeature = 1024 outfeature = 1024 group_size = 128 @@ -120,7 +115,7 @@ def test_gptq_linear(): max_input_len = 2048 buffers = { "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), - "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device), } prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) @@ -146,5 +141,4 @@ def test_gptq_linear(): if __name__ == "__main__": - test_gptq_linear() From dd59ca209f3fea8b6ad6d2d96c8c58b849c0b79b Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Mon, 16 Oct 2023 11:28:44 +0800 Subject: [PATCH 02/46] [inference] Add smmoothquant for llama (#4904) * [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license --- LICENSE | 50 ++ .../inference/quant/smoothquant/__init__.py | 0 .../quant/smoothquant/models/__init__.py | 12 + .../quant/smoothquant/models/base_model.py | 482 ++++++++++ .../quant/smoothquant/models/linear.py | 177 ++++ .../quant/smoothquant/models/llama.py | 846 ++++++++++++++++++ .../cuda_native/csrc/smoothquant/binding.cpp | 8 + .../cuda_native/csrc/smoothquant/linear.cu | 162 ++++ .../cuda_native/csrc/smoothquant/linear.h | 12 + colossalai/kernel/triton/__init__.py | 5 + .../triton/int8_rotary_embedding_kernel.py | 117 +++ colossalai/kernel/triton/smooth_attention.py | 652 ++++++++++++++ examples/inference/smoothquant_llama.py | 69 ++ op_builder/smoothquant.py | 52 ++ .../test_smoothquant/test_llama_attention.py | 136 +++ tests/test_smoothquant/test_llama_mlp.py | 84 ++ .../test_smoothquant_linear.py | 39 + .../test_sq_rotary_embedding.py | 59 ++ 18 files changed, 2962 insertions(+) create mode 100644 colossalai/inference/quant/smoothquant/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/models/__init__.py create mode 100644 colossalai/inference/quant/smoothquant/models/base_model.py create mode 100644 colossalai/inference/quant/smoothquant/models/linear.py create mode 100644 colossalai/inference/quant/smoothquant/models/llama.py create mode 100644 colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp create mode 100644 colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu create mode 100644 colossalai/kernel/cuda_native/csrc/smoothquant/linear.h create mode 100644 colossalai/kernel/triton/int8_rotary_embedding_kernel.py create mode 100644 colossalai/kernel/triton/smooth_attention.py create mode 100644 examples/inference/smoothquant_llama.py create mode 100644 op_builder/smoothquant.py create mode 100644 tests/test_smoothquant/test_llama_attention.py create mode 100644 tests/test_smoothquant/test_llama_mlp.py create mode 100644 tests/test_smoothquant/test_smoothquant_linear.py create mode 100644 tests/test_smoothquant/test_sq_rotary_embedding.py diff --git a/LICENSE b/LICENSE index 59d456c5b8a1..b3eb43520a6f 100644 --- a/LICENSE +++ b/LICENSE @@ -477,3 +477,53 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + ---------------- LICENSE FOR torch-int ---------------- + + MIT License + + Copyright (c) 2022 Guangxuan Xiao + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + + ---------------- LICENSE FOR smoothquant ---------------- + + MIT License + + Copyright (c) 2022 MIT HAN Lab + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/colossalai/inference/quant/smoothquant/__init__.py b/colossalai/inference/quant/smoothquant/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py new file mode 100644 index 000000000000..77541d8610c5 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/__init__.py @@ -0,0 +1,12 @@ +try: + import torch_int + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + raise ImportError( + "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int" + ) + +if HAS_TORCH_INT: + from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py new file mode 100644 index 000000000000..180e6c6e8fa6 --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -0,0 +1,482 @@ +# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py +# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py + +import os +import warnings +from abc import abstractmethod +from functools import partial +from os.path import isdir, isfile, join +from typing import Dict, List, Optional, Union + +import accelerate +import numpy as np +import torch +import torch.nn as nn +import transformers +from safetensors.torch import save_file as safe_save +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel +from transformers.modeling_utils import no_init_weights +from transformers.utils.generic import ContextManagers +from transformers.utils.hub import PushToHubMixin, cached_file + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager + +SUPPORTED_MODELS = ["llama"] + + +class BaseSmoothForCausalLM(nn.Module, PushToHubMixin): + layer_type: str = None + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__() + + self.model = model + self.model_type = self.model.config.model_type + self._quantized = quantized + self.config = self.model.config + self.cache_manager = None + self.max_total_token_num = 0 + + @property + def quantized(self): + return self._quantized + + def init_cache_manager(self, max_total_token_num=2048): + if self.config.model_type == "llama": + head_num = self.config.num_key_value_heads + layer_num = self.config.num_hidden_layers + head_dim = self.config.hidden_size // head_num + + self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num) + self.max_total_token_num = max_total_token_num + + def init_batch_state(self, max_output_len=256, **kwargs): + input_ids = kwargs["input_ids"] + batch_size = len(input_ids) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + max_len_in_batch = -1 + + for i in range(batch_size): + seq_len = len(input_ids[i]) + seq_lengths[i] = seq_len + seq_start_indexes[i] = start_index + start_index += seq_len + max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch + + if "max_total_token_num" in kwargs.keys(): + max_total_token_num = kwargs["max_total_token_num"] + self.init_cache_manager(max_total_token_num) + + if "max_new_tokens" in kwargs.keys(): + max_output_len = kwargs["max_new_tokens"] + + if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num: + max_total_token_num = batch_size * (max_len_in_batch + max_output_len) + warnings.warn(f"reset max tokens to {max_total_token_num}") + self.init_cache_manager(max_total_token_num) + + block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda") + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to("cuda") + batch_infer_state.start_loc = seq_start_indexes.to("cuda") + batch_infer_state.block_loc = block_loc + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.cache_manager.free_all() + return batch_infer_state + + @abstractmethod + @torch.inference_mode() + def quantize( + self, + examples: List[Dict[str, Union[List[int], torch.LongTensor]]], + ): + if self.quantized: + raise EnvironmentError("can't execute quantize because the model is quantized.") + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def generate(self, **kwargs): + """shortcut for model.generate""" + + batch_infer_state = self.init_batch_state(**kwargs) + if self.config.model_type == "llama": + setattr(self.model.model, "infer_state", batch_infer_state) + + with torch.inference_mode(): + return self.model.generate(**kwargs) + + def prepare_inputs_for_generation(self, *args, **kwargs): + """shortcut for model.prepare_inputs_for_generation""" + return self.model.prepare_inputs_for_generation(*args, **kwargs) + + def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512): + for text in tqdm(dataset): + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + + def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512): + pbar = tqdm(dataset) + for text in pbar: + input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device) + model(input_ids) + mean_scale = np.mean([v["input"] for v in act_dict.values()]) + pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + + def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): + model.eval() + device = next(model.parameters()).device + act_scales = {} + + def stat_tensor(name, tensor): + hidden_dim = tensor.shape[-1] + tensor = tensor.view(-1, hidden_dim).abs().detach() + comming_max = torch.max(tensor, dim=0)[0].float().cpu() + if name in act_scales: + act_scales[name] = torch.max(act_scales[name], comming_max) + else: + act_scales[name] = comming_max + + def stat_input_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + stat_tensor(name, x) + + hooks = [] + for name, m in model.named_modules(): + if isinstance(m, nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name))) + + self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len) + + for h in hooks: + h.remove() + + return act_scales + + @torch.no_grad() + def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): + if not isinstance(fcs, list): + fcs = [fcs] + for fc in fcs: + assert isinstance(fc, nn.Linear) + assert ln.weight.numel() == fc.in_features == act_scales.numel() + + device, dtype = fcs[0].weight.device, fcs[0].weight.dtype + act_scales = act_scales.to(device=device, dtype=dtype) + weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0) + weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5) + + scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype) + + ln.weight.div_(scales) + if hasattr(ln, "bias"): + ln.bias.div_(scales) + + for fc in fcs: + fc.weight.mul_(scales.view(1, -1)) + + @classmethod + def create_quantized_model(model): + raise NotImplementedError("Not implement create_quantized_model method") + + def save_quantized( + self, + save_dir: str, + model_basename: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + ): + """save quantized model and configs to local disk""" + os.makedirs(save_dir, exist_ok=True) + + if not self.quantized: + raise EnvironmentError("can only save quantized model, please execute .quantize first.") + + self.model.to("cpu") + + model_base_name = model_basename # or f"smooth-" + if use_safetensors: + model_save_name = model_base_name + ".safetensors" + state_dict = self.model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + if safetensors_metadata is None: + safetensors_metadata = {} + elif not isinstance(safetensors_metadata, dict): + raise TypeError("safetensors_metadata must be a dictionary.") + else: + print(f"Received safetensors_metadata: {safetensors_metadata}") + new_safetensors_metadata = {} + converted_keys = False + for key, value in safetensors_metadata.items(): + if not isinstance(key, str) or not isinstance(value, str): + converted_keys = True + try: + new_key = str(key) + new_value = str(value) + except Exception as e: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" + ) + if new_key in new_safetensors_metadata: + print( + f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." + ) + new_safetensors_metadata[new_key] = new_value + safetensors_metadata = new_safetensors_metadata + if converted_keys: + print( + f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}" + ) + + # Format is required to enable Accelerate to load the metadata + # otherwise it raises an OSError + safetensors_metadata["format"] = "pt" + + safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) + else: + model_save_name = model_base_name + ".bin" + torch.save(self.model.state_dict(), join(save_dir, model_save_name)) + + self.model.config.save_pretrained(save_dir) + + def save_pretrained( + self, + save_dir: str, + use_safetensors: bool = False, + safetensors_metadata: Optional[Dict[str, str]] = None, + **kwargs, + ): + """alias of save_quantized""" + warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") + self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + max_memory: Optional[dict] = None, + trust_remote_code: bool = False, + torch_dtype: torch.dtype = torch.float16, + **model_init_kwargs, + ): + if not torch.cuda.is_available(): + raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.") + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + # Parameters related to loading from Hugging Face Hub + cache_dir = model_init_kwargs.pop("cache_dir", None) + force_download = model_init_kwargs.pop("force_download", False) + resume_download = model_init_kwargs.pop("resume_download", False) + proxies = model_init_kwargs.pop("proxies", None) + local_files_only = model_init_kwargs.pop("local_files_only", False) + use_auth_token = model_init_kwargs.pop("use_auth_token", None) + revision = model_init_kwargs.pop("revision", None) + subfolder = model_init_kwargs.pop("subfolder", "") + model_init_kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + } + + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs) + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + # enforce some values despite user specified + model_init_kwargs["torch_dtype"] = torch_dtype + model_init_kwargs["trust_remote_code"] = trust_remote_code + if max_memory: + if "disk" in max_memory: + raise NotImplementedError("disk offload not support yet.") + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model.tie_weights() + + max_memory = accelerate.utils.get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + low_zero=False, + ) + model_init_kwargs["device_map"] = accelerate.infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=[cls.layer_type], + dtype=model_init_kwargs["torch_dtype"], + ) + model_init_kwargs["low_cpu_mem_usage"] = True + + del model + else: + model_init_kwargs["device_map"] = None + model_init_kwargs["low_cpu_mem_usage"] = False + + torch.cuda.empty_cache() + + merged_kwargs = {**model_init_kwargs, **cached_file_kwargs} + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs) + + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + model.eval() + + return cls(model, False) + + @classmethod + def from_quantized( + cls, + model_name_or_path: Optional[str], + model_basename: Optional[str] = None, + device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, + max_memory: Optional[dict] = None, + device: Optional[Union[str, int]] = None, + low_cpu_mem_usage: bool = False, + torch_dtype: Optional[torch.dtype] = None, + use_safetensors: bool = False, + trust_remote_code: bool = False, + **kwargs, + ): + """load quantized model from local disk""" + + # Parameters related to loading from Hugging Face Hub + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + + # == step1: prepare configs and file names == # + config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs + ) + + if config.model_type not in SUPPORTED_MODELS: + raise TypeError(f"{config.model_type} isn't supported yet.") + + extensions = [] + if use_safetensors: + extensions.append(".safetensors") + else: + extensions += [".bin", ".pt"] + + model_name_or_path = str(model_name_or_path) + is_local = isdir(model_name_or_path) + + resolved_archive_file = None + if is_local: + model_save_name = join(model_name_or_path, model_basename) + for ext in extensions: + if isfile(model_save_name + ext): + resolved_archive_file = model_save_name + ext + break + else: # remote + for ext in extensions: + resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs) + if resolved_archive_file is not None: + break + + if resolved_archive_file is None: # Could not find a model file to use + raise FileNotFoundError(f"Could not find model in {model_name_or_path}") + + model_save_name = resolved_archive_file + + # == step2: convert model to quantized-model (replace Linear) == # + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + transformers.modeling_utils._init_weights = False + + init_contexts = [no_init_weights()] + if low_cpu_mem_usage: + init_contexts.append(accelerate.init_empty_weights(include_buffers=True)) + + with ContextManagers(init_contexts): + model = AutoModelForCausalLM.from_config( + config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype + ) + cls.create_quantized_model(model) + model.tie_weights() + + # == step3: load checkpoint to quantized-model == # + accelerate.utils.modeling.load_checkpoint_in_model( + model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True + ) + + # == step4: set seqlen == # + model_config = model.config.to_dict() + seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] + if any([k in model_config for k in seq_len_keys]): + for key in seq_len_keys: + if key in model_config: + model.seqlen = model_config[key] + break + else: + warnings.warn("can't get model's sequence length from model config, will set to 4096.") + model.seqlen = 4096 + + return cls( + model, + True, + ) + + def __getattr__(self, item): + try: + return super().__getattr__(item) + except: + return getattr(self.model, item) + + +__all__ = ["BaseSmoothForCausalLM"] diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py new file mode 100644 index 000000000000..048565bfbf5e --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -0,0 +1,177 @@ +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py + +import torch +from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32 +from torch_int.functional.quantization import quantize_per_tensor_absmax + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except ImportError: + HAS_SMOOTHQUANT_CUDA = False + raise ImportError("CUDA smoothquant linear is not installed") + + +class W8A8BFP32O32LinearSiLU(torch.nn.Module): + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + if module.bias is not None: + int8_module.bias.data.copy_(module.bias.to(torch.float)) + int8_module.a = alpha + return int8_module + + +class W8A8B8O8Linear(torch.nn.Module): + # For qkv_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + self.register_buffer("b", torch.tensor(beta)) + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item()) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale, output_scale): + int8_module = W8A8B8O8Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale / output_scale + int8_module.weight = int8_weight + int8_module.a = alpha + + if module.bias is not None: + int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias) + int8_module.bias = int8_bias + beta = bias_scale / output_scale + int8_module.b = beta + + return int8_module + + +class W8A8BFP32OFP32Linear(torch.nn.Module): + # For fc2 and out_proj + def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + self.register_buffer( + "weight", + torch.randint( + -127, + 127, + (self.out_features, self.in_features), + dtype=torch.int8, + requires_grad=False, + ), + ) + self.register_buffer( + "bias", + torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False), + ) + self.register_buffer("a", torch.tensor(alpha)) + + def _apply(self, fn): + # prevent the bias from being converted to half + super()._apply(fn) + self.bias = self.bias.to(torch.float32) + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.weight = self.weight.to(*args, **kwargs) + self.bias = self.bias.to(*args, **kwargs) + self.bias = self.bias.to(torch.float32) + return self + + @torch.no_grad() + def forward(self, x): + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1) + y = y.view(*x_shape[:-1], -1) + return y + + @staticmethod + def from_float(module: torch.nn.Linear, input_scale): + int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features) + int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight) + alpha = input_scale * weight_scale + int8_module.weight = int8_weight + int8_module.a = alpha + int8_module.input_scale = input_scale + int8_module.weight_scale = weight_scale + + if module.bias is not None: + int8_module.bias = module.bias.to(torch.float32) + + return int8_module diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py new file mode 100644 index 000000000000..9c77feeb346e --- /dev/null +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -0,0 +1,846 @@ +import math +import os +import types +from collections import defaultdict +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T +from transformers import PreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import ( + LLAMA_INPUTS_DOCSTRING, + LlamaAttention, + LlamaDecoderLayer, + LlamaMLP, + LlamaRotaryEmbedding, + repeat_kv, + rotate_half, +) +from transformers.utils import add_start_docstrings_to_model_forward + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import ( + copy_kv_cache_to_dest, + int8_rotary_embedding_fwd, + smooth_llama_context_attn_fwd, + smooth_token_attention_fwd, +) + +from .base_model import BaseSmoothForCausalLM +from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear + + +class LLamaSmoothquantAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads})." + ) + + self.qk_bmm = BMM_S8T_S8N_F32T(1.0) + self.pv_bmm = BMM_S8T_S8N_S8T(1.0) + + self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size) + self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size) + + self.register_buffer("q_output_scale", torch.tensor([1.0])) + self.register_buffer("k_output_scale", torch.tensor([1.0])) + self.register_buffer("v_output_scale", torch.tensor([1.0])) + self.register_buffer("q_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("k_rotary_output_scale", torch.tensor([1.0])) + self.register_buffer("out_input_scale", torch.tensor([1.0])) + self.register_buffer("attn_input_scale", torch.tensor([1.0])) + + self._init_rope() + self.num_key_value_heads = num_heads + + def _init_rope(self): + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=2048, + base=10000.0, + ) + + @staticmethod + def pack( + module: LlamaAttention, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + ): + int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads) + + int8_module.attn_input_scale = torch.tensor([attn_input_scale]) + + int8_module.q_output_scale = torch.tensor([q_output_scale]) + int8_module.k_output_scale = torch.tensor([k_output_scale]) + int8_module.v_output_scale = torch.tensor([v_output_scale]) + + int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale]) + int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale]) + + int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale) + int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale) + int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale) + int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale) + + int8_module.out_input_scale = torch.tensor([out_input_scale]) + + return int8_module + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + @torch.no_grad() + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + cos = rotary_emb[0] + sin = rotary_emb[1] + + int8_rotary_embedding_fwd( + query_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.q_output_scale.item(), + self.q_rotary_output_scale.item(), + ) + int8_rotary_embedding_fwd( + key_states.view(-1, self.num_heads, self.head_dim), + cos, + sin, + self.k_output_scale.item(), + self.k_rotary_output_scale.item(), + ) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + return + + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # first token generation + + # copy key and value calculated in current step to memory manager + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) + + attn_output = torch.empty_like(query_states) + + smooth_llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.start_loc, + infer_state.seq_len, + q_len, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + _copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + smooth_token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + self.q_rotary_output_scale.item(), + self.k_rotary_output_scale.item(), + self.v_output_scale.item(), + self.out_input_scale.item(), + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, None + + +class LlamaLayerNormQ(torch.nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.input_scale = 1.0 + self.variance_epsilon = eps + self.register_buffer("weight", torch.ones(dim, dtype=torch.float32)) + + def forward(self, x): + ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon) + ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8) + return ln_output_int8 + + @staticmethod + def from_float(module: torch.nn.LayerNorm, output_scale: float): + assert module.weight.shape[0] == module.weight.numel() + q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon) + q_module.weight = module.weight / output_scale + return q_module + + +class LlamaSmoothquantMLP(nn.Module): + def __init__(self, intermediate_size, hidden_size): + super().__init__() + self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size) + self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size) + self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size) + self.register_buffer("down_proj_input_scale", torch.tensor([1.0])) + + @staticmethod + def pack( + mlp_module: LlamaMLP, + gate_proj_input_scale: float, + up_proj_input_scale: float, + down_proj_input_scale: float, + ): + int8_module = LlamaSmoothquantMLP( + mlp_module.intermediate_size, + mlp_module.hidden_size, + ) + + int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale) + int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale) + int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale]) + return int8_module + + def forward( + self, + hidden_states: torch.Tensor, + ): + x_shape = hidden_states.shape + gate_out = self.gate_proj(hidden_states) + up_out = self.up_proj(hidden_states) + inter_out = gate_out * up_out + inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8) + down_out = self.down_proj(inter_out) + down_out = down_out.view(*x_shape[:-1], -1) + return down_out + + +class LlamaSmoothquantDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads) + + self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size) + self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps) + + @staticmethod + def pack( + module: LlamaDecoderLayer, + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + q_rotary_output_scale: float, + k_rotary_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + up_input_scale: float, + down_input_scale: float, + ): + config = module.self_attn.config + int8_decoder_layer = LlamaSmoothquantDecoderLayer(config) + + int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale) + int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack( + module.self_attn, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + q_rotary_output_scale, + k_rotary_output_scale, + out_input_scale, + ) + + int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float( + module.post_attention_layernorm, gate_input_scale + ) + + int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack( + module.mlp, + gate_input_scale, + up_input_scale, + down_input_scale, + ) + + return int8_decoder_layer + + def forward( + self, + hidden_states: torch.Tensor, + rotary_emb: Tuple[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_emb=rotary_emb, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None, None + + +class LlamaApplyRotary(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (rotate_half(x) * sin) + + return x_embed + + +def llama_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states = self.q_apply_rotary(query_states, cos, sin, position_ids) + key_states = self.k_apply_rotary(key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def init_to_get_rotary(config, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + base : calculation arg + use_elem : activated when using chatglm-based models + """ + config.head_dim_ = config.hidden_size // config.num_attention_heads + if not hasattr(config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0 + + if hasattr(config, "max_sequence_length"): + max_seq_len = config.max_sequence_length + elif hasattr(config, "max_position_embeddings"): + max_seq_len = config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + try: + ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula + except: + pass + + n_elem = config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + _cos_cached = torch.cos(freqs).to(torch.float) + _sin_cached = torch.sin(freqs).to(torch.float) + return _cos_cached, _sin_cached + + +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +def llama_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + infer_state = self.infer_state + + if past_key_values is not None: + # NOT READY FOR PRIME TIME + # dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if infer_state.is_context_stage: + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + raise NotImplementedError("not implement gradient_checkpointing and training options ") + + if past_key_values_length == 0: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1) + position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + infer_state.decode_layer_id = 0 + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + rotary_emb=(position_cos, position_sin), + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + infer_state=infer_state, + ) + + hidden_states = layer_outputs[0] + infer_state.decode_layer_id += 1 + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + infer_state.is_context_stage = False + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): + layer_type = "LlamaDecoderLayer" + + def __init__(self, model: PreTrainedModel, quantized: bool = False): + super().__init__(model, quantized) + + def get_act_dict( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + ): + llama_model = self.model + + llama_model.eval() + device = next(llama_model.parameters()).device + # print("model:", llama_model) + act_dict = defaultdict(dict) + + def stat_io_hook(m, x, y, name): + if isinstance(x, tuple): + x = x[0] + if name not in act_dict or "input" not in act_dict[name]: + act_dict[name]["input"] = x.detach().abs().max().item() + else: + act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item()) + if isinstance(y, tuple): + y = y[0] + if name not in act_dict or "output" not in act_dict[name]: + act_dict[name]["output"] = y.detach().abs().max().item() + else: + act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item()) + + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaAttention): + setattr(m, "q_apply_rotary", LlamaApplyRotary()) + setattr(m, "k_apply_rotary", LlamaApplyRotary()) + m.forward = types.MethodType(llama_decoder_layer_forward, m) + + hooks = [] + for name, m in llama_model.named_modules(): + if isinstance(m, LlamaApplyRotary): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + if isinstance(m, torch.nn.Linear): + hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name))) + + self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len) + + for hook in hooks: + hook.remove() + return act_dict + + def smooth_fn(self, scales, alpha=0.5): + model = self.model + for name, module in model.named_modules(): + if isinstance(module, LlamaDecoderLayer): + attn_ln = module.input_layernorm + qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + ".self_attn.q_proj"] + self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha) + + def create_quantized_model(model): + llama_config = model.config + for i, layer in enumerate(model.model.layers): + model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config) + + model.model.forward = types.MethodType(llama_model_forward, model.model) + cos, sin = init_to_get_rotary(llama_config) + model.model.register_buffer("_cos_cached", cos) + model.model.register_buffer("_sin_cached", sin) + + def quantized( + self, + tokenizer, + dataset, + num_samples=512, + seq_len=512, + alpha=0.5, + ): + llama_model = self.model + llama_config = llama_model.config + + act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len) + + self.smooth_fn(act_scales, alpha) + + act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len) + decoder_layer_scales = [] + + for idx in range(llama_config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127 + scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127 + scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127 + scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127 + + scale_dict["q_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127 + ) + scale_dict["k_rotary_output_scale"] = ( + act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127 + ) + + scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127 + + scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127 + scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127 + scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + + decoder_layer_scales.append(scale_dict) + + for i, layer in enumerate(llama_model.model.layers): + orig_layer = layer + llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i]) + + llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model) + + cos, sin = init_to_get_rotary(llama_config) + llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device)) + llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device)) diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp new file mode 100644 index 000000000000..8444272940b4 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp @@ -0,0 +1,8 @@ +#include + +#include "linear.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, + "Linear SiLU (INT8)"); +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu new file mode 100644 index 000000000000..a30d02a4cf42 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu @@ -0,0 +1,162 @@ +// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu + +#include "linear.h" +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +) { + auto M = input.size(0); + auto N = weight.size(0); + auto K = input.size(1); + + using ElementOutput = float; + using ElementAccumulator = int32_t; + using ElementComputeEpilogue = float; + using ElementInputA = int8_t; // <- data type of elements in input matrix A + using ElementInputB = int8_t; // <- data type of elements in input matrix B + + // The code section below describes matrix layout of input and output + // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major + // for Matrix C + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + +#if CUDA_ARCH >= 800 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; +#elif CUDA_ARCH >= 750 + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits< + ElementOutput>::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This + // becomes the vector width of math + // instructions in epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue // <- data type for alpha in linear combination + // function + >; + + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + EpilogueOp>; +#elif CUDA_ARCH >= 700 + #define USE_TORCH_SILU + using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, + DefaultGemmCfg::InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; +#else + #error "Unsupported cuda arch" +#endif + + auto input_size = cutlass::MatrixCoord(M, K); + auto weight_size = cutlass::MatrixCoord(K, N); + auto output_size = cutlass::MatrixCoord(M, N); + + auto device = input.device(); + // use the broadcasted bias as the output + auto out = bias.to(device).view({1, -1}).repeat({M, 1}); + + // constexpr int kSparse = Gemm::kSparse; + // How many elements of A are covered per ElementE + // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + // The size of individual meta data + // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + cutlass::gemm::GemmCoord problem_size(M, N, K); + + cutlass::TensorRef input_ref( + input.data_ptr(), LayoutInputA::packed(input_size)); + cutlass::TensorRef weight_ref( + weight.data_ptr(), LayoutInputB::packed(weight_size)); + cutlass::TensorRef out_ref( + out.data_ptr(), LayoutOutput::packed(output_size)); + + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + input_ref, // <- reference to matrix A on device + weight_ref, // <- reference to matrix B on device + out_ref, // <- reference to matrix C on device + out_ref, // <- reference to matrix D on device + {alpha, beta}, 1}; + Gemm gemm_op; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot run"); + } +#ifdef USE_TORCH_SILU +#undef USE_TORCH_SILU + out = torch::silu(out); +#endif + return out; +} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h new file mode 100644 index 000000000000..b62a27f3f8f3 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h @@ -0,0 +1,12 @@ +#include +#include + +#include +#include + +torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 + torch::Tensor weight, // INT8 + torch::Tensor bias, // FP32 + float alpha, // FP32 + float beta // FP32 +); diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index f065b2100fa8..27351a686d2f 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -13,8 +13,10 @@ from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton + from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd from .rms_norm import rmsnorm_forward from .rotary_embedding_kernel import rotary_embedding_fwd + from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd @@ -29,4 +31,7 @@ "rotary_embedding_fwd", "token_attention_fwd", "gptq_fused_linear_triton", + "int8_rotary_embedding_fwd", + "smooth_llama_context_attn_fwd", + "smooth_token_attention_fwd", ] diff --git a/colossalai/kernel/triton/int8_rotary_embedding_kernel.py b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py new file mode 100644 index 000000000000..537dd164d1ab --- /dev/null +++ b/colossalai/kernel/triton/int8_rotary_embedding_kernel.py @@ -0,0 +1,117 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm +import torch +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + q, + input_scale, + output_scale, + Cos, + Sin, + q_bs_stride, + q_h_stride, + q_d_stride, + cos_bs_stride, + cos_d_stride, + total_len, + HEAD_NUM: tl.constexpr, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + current_head_index = tl.program_id(0) + current_seq_index = tl.program_id(1) + + dim_range0 = tl.arange(0, HEAD_DIM // 2) + dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) + + current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + off_q0 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range0[None, None, :] * q_d_stride + ) + off_q1 = ( + current_seq_range[:, None, None] * q_bs_stride + + current_head_range[None, :, None] * q_h_stride + + dim_range1[None, None, :] * q_d_stride + ) + + off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride + + q0 = tl.load( + q + off_q0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + q1 = tl.load( + q + off_q1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) + + q0 = q0.to(tl.float32) * input_scale + q1 = q1.to(tl.float32) * input_scale + + out0 = (q0 * cos - q1 * sin) / output_scale + out1 = (q0 * sin + q1 * cos) / output_scale + + out0 = out0.to(tl.int8) + out1 = out1.to(tl.int8) + + tl.store( + q + off_q0, + out0, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + tl.store( + q + off_q1, + out1, + mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), + ) + + return + + +@torch.no_grad() +def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale): + total_len = q.shape[0] + head_num = q.shape[1] + head_dim = q.shape[2] + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + BLOCK_HEAD = 4 + BLOCK_SEQ = 32 + grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + + _rotary_kernel[grid]( + q, + input_scale, + output_scale, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + cos.stride(0), + cos.stride(1), + total_len, + HEAD_NUM=head_num, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + HEAD_DIM=head_dim, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py new file mode 100644 index 000000000000..ee0df6a74eaa --- /dev/null +++ b/colossalai/kernel/triton/smooth_attention.py @@ -0,0 +1,652 @@ +import math + +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + """ + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + """ + + @triton.jit + def _context_flash_attention_kernel( + Q, + K, + V, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + B_Start_Loc, + B_Seqlen, + TMP, + alibi_ptr, + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + # suggtest set-up 64, 128, 256, 512 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + batch_id = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + + load_p_ptrs = ( + Q + + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k = tl.load( + k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + other=0.0, + ) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + if alibi_ptr is not None: + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + other=0.0, + ) + + v = v.to(tl.float16) * v_input_scale.to(tl.float16) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + + + @torch.no_grad() + def smooth_llama_context_attn_fwd( + q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len + ): + + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + BLOCK_N = 128 + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _context_flash_attention_kernel[grid]( + q, + k, + v, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_1_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @triton.jit + def _token_attn_1_alibi_kernel( + Q, + K, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + q_batch_stride, + q_head_stride, + q_head_dim_stride, + k_batch_stride, + k_head_stride, + k_head_dim_stride, + attn_head_stride, + attn_batch_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + start_n = tl.program_id(2) + + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_end_index = max_kv_cache_len + + off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride + + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + + block_stard_index = start_n * BLOCK_N + block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) + + for start_mark in range(0, block_mask, 1): + alibi_m = tl.load(alibi + current_head) + q = tl.load(Q + off_q + start_mark) + q = q.to(tl.float16) * q_input_scale.to(tl.float16) + + offs_n_new = current_batch_start_index + offs_n + k_loc = tl.load( + kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, + mask=offs_n_new < current_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride + k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) + k = k.to(tl.float16) * k_input_scale.to(tl.float16) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) + off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride + tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) + return + + @torch.no_grad() + def token_attn_fwd_1( + q, + k, + attn_out, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + alibi=None, + ): + BLOCK = 32 + # shape constraints + q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] + assert q_head_dim == k_head_dim + assert k_head_dim in {16, 32, 64, 128} + sm_scale = 1.0 / (k_head_dim**0.5) + + batch, head_num = kv_cache_loc.shape[0], q.shape[1] + + grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + + num_warps = 4 if k_head_dim <= 64 else 8 + num_warps = 2 + + if alibi is not None: + _token_attn_1_alibi_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + alibi, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + _token_attn_1_kernel[grid]( + q, + k, + q_input_scale, + k_input_scale, + sm_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + attn_out, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + attn_out.stride(0), + attn_out.stride(1), + HEAD_DIM=k_head_dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @triton.jit + def _token_attn_softmax_fwd( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + logics_head_dim_stride, + logics_batch_stride, + prob_head_dim_stride, + prob_batch_stride, + BLOCK_SIZE: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + col_offsets = tl.arange(0, BLOCK_SIZE) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + row = tl.load( + softmax_logics + + current_head * logics_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, + mask=col_offsets < current_batch_seq_len, + other=-float("inf"), + ).to(tl.float32) + + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + tl.store( + softmax_prob_out + + current_head * prob_head_dim_stride + + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, + softmax_output, + mask=col_offsets < current_batch_seq_len, + ) + return + + @torch.no_grad() + def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): + BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) + batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] + + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + _token_attn_softmax_fwd[(batch, head_num)]( + softmax_logics, + kv_cache_start_loc, + kv_cache_seqlen, + softmax_prob_out, + softmax_logics.stride(0), + softmax_logics.stride(1), + softmax_prob_out.stride(0), + softmax_prob_out.stride(1), + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return + + @triton.jit + def _token_attn_2_kernel( + Prob, + V, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc_b_stride, + kv_cache_loc_s_stride, + prob_head_dim_stride, + prob_batch_stride, + v_batch_stride, + v_head_stride, + v_head_dim_stride, + attn_out_batch_stride, + attn_out_head_stride, + attn_out_head_dim_stride, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + current_batch = tl.program_id(0) + current_head = tl.program_id(1) + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, HEAD_DIM) + current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) + current_batch_start_index = max_kv_cache_len - current_batch_seq_len + current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) + + v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride + p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride + v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + for start_n in range(0, current_batch_seq_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + p_value = tl.load( + Prob + p_offs + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_loc = tl.load( + kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, + mask=(start_n + offs_n) < current_batch_seq_len, + other=0.0, + ) + v_value = tl.load( + V + v_offs + v_loc[:, None] * v_batch_stride, + mask=(start_n + offs_n[:, None]) < current_batch_seq_len, + other=0.0, + ) + v_value = v_value.to(tl.float16) * v_input_scale.to(tl.float16) + acc += tl.sum(p_value[:, None] * v_value, 0) + + acc = (acc / pv_output_scale.to(tl.float16)).to(tl.int8) + off_o = ( + current_batch * attn_out_batch_stride + + current_head * attn_out_head_stride + + offs_d * attn_out_head_dim_stride + ) + out_ptrs = attn_out + off_o + tl.store(out_ptrs, acc) + return + + @torch.no_grad() + def token_attn_fwd_2( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + ): + if triton.__version__ >= "2.1.0": + BLOCK = 128 + else: + BLOCK = 64 + batch, head = kv_cache_loc.shape[0], v.shape[1] + grid = (batch, head) + num_warps = 4 + dim = v.shape[-1] + + _token_attn_2_kernel[grid]( + prob, + v, + attn_out, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seqlen, + max_kv_cache_len, + kv_cache_loc.stride(0), + kv_cache_loc.stride(1), + prob.stride(0), + prob.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + attn_out.stride(0), + attn_out.stride(1), + attn_out.stride(2), + HEAD_DIM=dim, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def smooth_token_attention_fwd( + q, + k, + v, + attn_out, + q_input_scale, + k_input_scale, + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=None, + ): + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] + + att_m_tensor = torch.empty((head_num, total_token_num), dtype=torch.float32, device="cuda") + + token_attn_fwd_1( + q.view(calcu_shape1), + k, + att_m_tensor, + q_input_scale, + k_input_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + alibi=alibi, + ) + + prob = torch.empty_like(att_m_tensor) + + token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + att_m_tensor = None + token_attn_fwd_2( + prob, + v, + attn_out.view(calcu_shape1), + v_input_scale, + pv_output_scale, + kv_cache_loc, + kv_cache_start_loc, + kv_cache_seq_len, + max_len_in_batch, + ) + + prob = None + + return diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/smoothquant_llama.py new file mode 100644 index 000000000000..ce7a00aa2739 --- /dev/null +++ b/examples/inference/smoothquant_llama.py @@ -0,0 +1,69 @@ +import argparse +import os + +import torch +from datasets import load_dataset +from transformers import LlamaTokenizer + +from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM + + +def build_model_and_tokenizer(model_name): + tokenizer = LlamaTokenizer.from_pretrained(model_name, model_max_length=512) + kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"} + model = SmoothLlamaForCausalLM.from_pretrained(model_name, **kwargs) + model = model.to(torch.float32) + return model, tokenizer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-name", type=str, help="model name") + parser.add_argument( + "--output-path", + type=str, + help="where to save the checkpoint", + ) + parser.add_argument( + "--dataset-path", + type=str, + help="location of the calibration dataset", + ) + parser.add_argument("--num-samples", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + args = parser.parse_args() + return args + + +@torch.no_grad() +def main(): + args = parse_args() + model_path = args.model_name + dataset_path = args.dataset_path + output_path = args.output_path + num_samples = 10 + seq_len = 512 + + model, tokenizer = build_model_and_tokenizer(model_path) + if not os.path.exists(dataset_path): + print(f"Cannot find the dataset at {args.dataset_path}") + raise FileNotFoundError + dataset = load_dataset("json", data_files=dataset_path, split="train") + + model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len) + model = model.cuda() + + model.save_quantized(output_path, model_basename="llama-7b") + + model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b") + model = model.cuda() + + generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True) + input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda") + out = model.generate(**input_tokens, **generate_kwargs) + text = tokenizer.batch_decode(out) + print("out is:", text) + + +if __name__ == "__main__": + main() diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py new file mode 100644 index 000000000000..d562a4c4f626 --- /dev/null +++ b/op_builder/smoothquant.py @@ -0,0 +1,52 @@ +import torch + +from .builder import Builder +from .utils import append_nvcc_threads + + +class SmoothquantBuilder(Builder): + NAME = "cu_smoothquant" + PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" + + def __init__(self): + super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) + + def include_dirs(self): + ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] + return ret + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "smoothquant/binding.cpp", + "smoothquant/linear.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + compute_capability = torch.cuda.get_device_capability() + cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 + + extra_cuda_flags = [ + "-v", + f"-DCUDA_ARCH={cuda_arch}", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-DTHRUST_IGNORE_CUB_VERSION_CHECK", + ] + + ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags + return append_nvcc_threads(ret) + + def builder(self): + try: + super().builder() + except: + warnings.warn("build smoothquant lib not successful") diff --git a/tests/test_smoothquant/test_llama_attention.py b/tests/test_smoothquant/test_llama_attention.py new file mode 100644 index 000000000000..f8c79145c952 --- /dev/null +++ b/tests/test_smoothquant/test_llama_attention.py @@ -0,0 +1,136 @@ +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +try: + from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention + + HAS_TORCH_INT = True +except ImportError: + HAS_TORCH_INT = False + print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + +import math + +import torch +from torch.nn import functional as F + + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + """ + adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + """ + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + + return output + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT, + reason="triton requires cuda version to be higher than 11.4 or not install torch_int", +) +def test_llama_context_attention(): + head_num = 2 + seq_len = 32 + head_dim = 64 + dtype = torch.float + hidden_size = head_num * head_dim + + smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num) + + smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8) + smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8) + + qkv_weight_scale = 1.0 + + ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda") + + smooth_attn = smooth_attn.to("cuda") + + input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda") + input_scale = 1 / 20.0 + + output = torch.matmul(input.to(torch.float) * input_scale, ones) + qkv_max_out = torch.max(torch.abs(output)) / 127 + smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out) + + q = smooth_attn.q_proj(input) + k = smooth_attn.k_proj(input) + v = smooth_attn.v_proj(input) + + cos_shape = (seq_len, head_dim // 2) + cos = torch.ones(cos_shape, dtype=dtype, device="cuda") + sin = torch.zeros(cos_shape, dtype=dtype, device="cuda") + in_scale = torch.tensor([qkv_max_out], device="cuda") + out_scale = torch.tensor([qkv_max_out], device="cuda") + int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item()) + + q = q.to(torch.float) * out_scale + k = k.to(torch.float) * out_scale + v = v.to(torch.float) * out_scale + torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim) + attn_out_max = torch.max(torch.abs(torch_out)) / 127 + + output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones) + smooth_attn.q_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.v_output_scale = torch.tensor(qkv_max_out) + smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out) + smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out) + + smooth_attn.attn_output_scale = torch.tensor(attn_out_max) + smooth_attn.out_proj.a = torch.tensor([attn_out_max]) + + torch_out = ( + (torch_out / smooth_attn.attn_output_scale) + .round() + .clamp(-128, 127) + .to(torch.int8) + .view(-1, seq_len, head_num * head_dim) + ) + + torch_out = smooth_attn.out_proj(torch_out) + torch_out = torch_out.to(torch.float) + + smooth_attn = smooth_attn.to("cuda") + smooth_out, _, _ = smooth_attn(input, (cos, sin)) + smooth_out = smooth_out.to(torch.float) + + assert torch.allclose( + torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1 + ), "outputs from triton and torch are not matched" + + +if __name__ == "__main__": + test_llama_context_attention() diff --git a/tests/test_smoothquant/test_llama_mlp.py b/tests/test_smoothquant/test_llama_mlp.py new file mode 100644 index 000000000000..236edb10cb7f --- /dev/null +++ b/tests/test_smoothquant/test_llama_mlp.py @@ -0,0 +1,84 @@ +import warnings + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +try: + from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP + + HAS_TORCH_INT = True +except: + HAS_TORCH_INT = False + warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int") + + +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_llama_mlp(gate_proj, up_proj, down_proj, x): + gate_out = torch.mm(x, gate_proj) + silu = torch.nn.SiLU() + gate_out = silu(gate_out) + up_out = torch.mm(x, up_proj) + + o_out = gate_out * up_out + + max_up = torch.max(torch.abs(o_out)) + min_up = torch.min(torch.abs(o_out)) + + torch_out = torch.mm(o_out, down_proj) + + return (torch_out, max_up, min_up) + + +@pytest.mark.skipif( + not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT, + reason="smoothquant linear not installed properly or not install torch_int", +) +def test_llama_mlp(): + hidden_size = 256 + intermediate_size = 512 + + smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size) + + smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda") + + smooth_mlp.up_proj.weight = torch.randint( + -10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda" + ) + smooth_mlp.down_proj.weight = torch.randint( + -10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda" + ) + + x = torch.ones((1, 256), dtype=torch.int8, device="cuda") + + torch_out, max_inter, min_inter = torch_llama_mlp( + smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size, + smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127, + smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127, + x.to(torch.float), + ) + + smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127) + smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size) + smooth_mlp.up_proj.a = torch.tensor(1 / 127) + smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127)) + + smooth_out = smooth_mlp(x) + + assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01) + + +if __name__ == "__main__": + test_llama_mlp() diff --git a/tests/test_smoothquant/test_smoothquant_linear.py b/tests/test_smoothquant/test_smoothquant_linear.py new file mode 100644 index 000000000000..58a0b82f6759 --- /dev/null +++ b/tests/test_smoothquant/test_smoothquant_linear.py @@ -0,0 +1,39 @@ +import warnings + +import pytest +import torch + +try: + from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder + + smoothquant_cuda = SmoothquantBuilder().load() + HAS_SMOOTHQUANT_CUDA = True +except: + warnings.warn("CUDA smoothquant linear is not installed") + HAS_SMOOTHQUANT_CUDA = False + + +@pytest.mark.skipif( + not HAS_SMOOTHQUANT_CUDA, + reason="smoothquant linear not installed properly", +) +def test_linear(): + a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda") + b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda") + c = torch.rand(256, dtype=torch.float, device="cuda") + + alpha = 1 / 127 + beta = 1.0 + torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c + + silu = torch.nn.SiLU() + torch_out = silu(torch_out) + + b = b.transpose(0, 1).contiguous() + cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta) + + assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02) + + +if __name__ == "__main__": + test_linear() diff --git a/tests/test_smoothquant/test_sq_rotary_embedding.py b/tests/test_smoothquant/test_sq_rotary_embedding.py new file mode 100644 index 000000000000..4cc76f00474d --- /dev/null +++ b/tests/test_smoothquant/test_sq_rotary_embedding.py @@ -0,0 +1,59 @@ +# Adapted from ModelTC https://github.com/ModelTC/lightllm + + +import pytest +import torch +from packaging import version + +try: + from colossalai.kernel.triton import int8_rotary_embedding_fwd + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +@pytest.mark.skipif( + not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" +) +def test_rotary_emb(): + SEQ_LEN = 1 + HEAD_NUM = 32 + HEAD_DIM = 128 + dtype = torch.float + # create data + x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + cos_shape = (SEQ_LEN, HEAD_DIM // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + # forward pass + y_torch = torch_rotary_emb(x, cos, sin) + + input_scale = torch.max(torch.abs(x)) / 127 + output_scale = torch.max(torch.abs(y_torch)) / 127 + + x = x / input_scale + x = x.to(torch.int8) + + int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item()) + y_triton = x.to(torch.float) * output_scale + assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True) + + +if __name__ == "__main__": + test_rotary_emb() From 52707c6328b25961b27be9af443e74f8e5883d87 Mon Sep 17 00:00:00 2001 From: "Zian(Andy) Zheng" <62330719+Orion-Zheng@users.noreply.github.com> Date: Fri, 13 Oct 2023 16:46:33 +0800 Subject: [PATCH 03/46] Update flash_attention_patch.py To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to forward function of attention layer. https://github.com/huggingface/transformers/pull/25598 --- .../colossal_llama2/utils/flash_attention_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 6c58c59307a6..111659b2d928 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -65,6 +65,7 @@ def attention_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. From 61ec9f72edf24083698def2aa11e9ddfac950881 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 16 Oct 2023 21:56:53 +0800 Subject: [PATCH 04/46] [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921) * [kernel] support pure fp16 for cpu adam (#4896) * [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919) * [kernel] fix cpu adam * [test] update gemini optim test --- .../kernel/cuda_native/csrc/cpu_adam.cpp | 201 ++++++++---------- colossalai/kernel/cuda_native/csrc/cpu_adam.h | 41 +++- colossalai/nn/optimizer/cpu_adam.py | 3 +- colossalai/nn/optimizer/hybrid_adam.py | 3 +- tests/test_optimizer/test_adam_kernel.py | 7 +- tests/test_optimizer/test_adam_optim.py | 2 - tests/test_zero/test_gemini/test_grad_clip.py | 12 +- tests/test_zero/test_gemini/test_optim.py | 15 +- 8 files changed, 148 insertions(+), 136 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp index 0ab250218da3..be9300c545c2 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp @@ -35,23 +35,19 @@ SOFTWARE void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); float betta1_minus1 = 1 - _betta1; float betta2_minus1 = 1 - _betta2; float step_size = -1 * _alpha / _bias_correction1; float w_decay = -1 * _alpha * _weight_decay; - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -77,7 +73,6 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -87,28 +82,23 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, #pragma omp parallel for for (size_t i = t; i < offset; i += SIMD_WIDTH) { AVX_Data grad_4; - if (grad_half_precision) { - grad_4.data = SIMD_LOAD_HALF(grads_cast_h + i); - } else { - grad_4.data = SIMD_LOAD(grads + i); - } + this->simd_load(grad_half_precision, grads + i, grads_cast_h + i, grad_4); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4.data = SIMD_DIV(grad_4.data, loss_scale_vec.data); } AVX_Data momentum_4; - momentum_4.data = SIMD_LOAD(_exp_avg + i); + this->simd_load(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); AVX_Data variance_4; - variance_4.data = SIMD_LOAD(_exp_avg_sq + i); + this->simd_load(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); AVX_Data param_4; - if (param_half_precision) { - param_4.data = SIMD_LOAD_HALF(params_cast_h + i); - } else { - param_4.data = SIMD_LOAD(_params + i); - } + this->simd_load(param_half_precision, _params + i, params_cast_h + i, + param_4); if (_weight_decay > 0 && !_adamw_mode) { grad_4.data = SIMD_FMA(param_4.data, weight_decay_4.data, grad_4.data); @@ -130,13 +120,12 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i), param_4.data); - } else { - SIMD_STORE(_params + i, param_4.data); - } - SIMD_STORE(_exp_avg + i, momentum_4.data); - SIMD_STORE(_exp_avg_sq + i, variance_4.data); + this->simd_store(param_half_precision, _params + i, params_cast_h + i, + param_4); + this->simd_store(momentum_half_precision, _exp_avg + i, + momentum_cast_h + i, momentum_4); + this->simd_store(variance_half_precision, _exp_avg_sq + i, + variance_cast_h + i, variance_4); } } #endif @@ -154,8 +143,10 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, } float param = param_half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; + float momentum = + momentum_half_precision ? (float)momentum_cast_h[k] : _exp_avg[k]; + float variance = variance_half_precision ? (float)variance_cast_h[k] + : _exp_avg_sq[k]; if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } @@ -178,8 +169,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, params_cast_h[k] = (__half)param; else _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; + if (momentum_half_precision) + momentum_cast_h[k] = (__half)(momentum); + else + _exp_avg[k] = momentum; + if (variance_half_precision) + variance_cast_h[k] = (__half)(variance); + else + _exp_avg_sq[k] = variance; } } } @@ -188,17 +185,14 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; @@ -228,7 +222,6 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -243,26 +236,21 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, AVX_Data param_4[4]; #pragma unroll 4 for (int j = 0; j < 4; j++) { - if (grad_half_precision) { - grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); - } else { - grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); - } + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); - - if (param_half_precision) { - param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); - } else { - param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); - } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); if (_weight_decay > 0 && !_adamw_mode) { grad_4[j].data = @@ -285,14 +273,13 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, } param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), - param_4[j].data); - } else { - SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); - } - SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); } } } @@ -302,24 +289,26 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, size_t _param_size, bool param_half_precision, bool grad_half_precision, - float loss_scale) { - size_t rounded_size = 0; - __half *params_cast_h = NULL; - __half *grads_cast_h = NULL; - if (param_half_precision) { - params_cast_h = reinterpret_cast<__half *>(_params); - } - if (grad_half_precision) { - grads_cast_h = reinterpret_cast<__half *>(grads); - } + bool momentum_half_precision, + bool variance_half_precision, float loss_scale) { + size_t rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); + __half *params_cast_h = reinterpret_cast<__half *>(_params); + __half *grads_cast_h = reinterpret_cast<__half *>(grads); + __half *momentum_cast_h = reinterpret_cast<__half *>(_exp_avg); + __half *variance_cast_h = reinterpret_cast<__half *>(_exp_avg_sq); + #if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) AVX_Data betta1_4; betta1_4.data = SIMD_SET(_betta1); @@ -348,7 +337,6 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, if (_weight_decay > 0) weight_decay_4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); for (size_t t = 0; t < rounded_size; t += TILE) { size_t copy_size = TILE; @@ -363,26 +351,21 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, AVX_Data param_4[8]; #pragma unroll 8 for (int j = 0; j < 8; j++) { - if (grad_half_precision) { - grad_4[j].data = SIMD_LOAD_HALF(grads_cast_h + i + SIMD_WIDTH * j); - } else { - grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j); - } + this->simd_load(grad_half_precision, grads + i + SIMD_WIDTH * j, + grads_cast_h + i + SIMD_WIDTH * j, grad_4[j]); if (loss_scale > 0) { AVX_Data loss_scale_vec; loss_scale_vec.data = SIMD_SET(loss_scale); grad_4[j].data = SIMD_DIV(grad_4[j].data, loss_scale_vec.data); } - - momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j); - variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j); - - if (param_half_precision) { - param_4[j].data = SIMD_LOAD_HALF(params_cast_h + i + SIMD_WIDTH * j); - } else { - param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j); - } + this->simd_load(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_load(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); + this->simd_load(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); if (_weight_decay > 0 && !_adamw_mode) { grad_4[j].data = @@ -405,15 +388,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data); - if (param_half_precision) { - SIMD_STORE_HALF((float *)(params_cast_h + i + SIMD_WIDTH * j), - param_4[j].data); - } else { - SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data); - } - - SIMD_STORE(_exp_avg + i + (SIMD_WIDTH * j), momentum_4[j].data); - SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH * j), variance_4[j].data); + this->simd_store(param_half_precision, _params + i + SIMD_WIDTH * j, + params_cast_h + i + SIMD_WIDTH * j, param_4[j]); + this->simd_store(momentum_half_precision, _exp_avg + i + SIMD_WIDTH * j, + momentum_cast_h + i + SIMD_WIDTH * j, momentum_4[j]); + this->simd_store(variance_half_precision, + _exp_avg_sq + i + SIMD_WIDTH * j, + variance_cast_h + i + SIMD_WIDTH * j, variance_4[j]); } } } @@ -423,9 +404,13 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, : _params + rounded_size), (grad_half_precision ? (float *)(grads_cast_h + rounded_size) : grads + rounded_size), - (_exp_avg + rounded_size), (_exp_avg_sq + rounded_size), + (momentum_half_precision ? (float *)(momentum_cast_h + rounded_size) + : _exp_avg + rounded_size), + (variance_half_precision ? (float *)(variance_cast_h + rounded_size) + : _exp_avg_sq + rounded_size), (_param_size - rounded_size), param_half_precision, - grad_half_precision, loss_scale); + grad_half_precision, momentum_half_precision, + variance_half_precision, loss_scale); } void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, @@ -447,7 +432,9 @@ void Adam_Optimizer::step(size_t step, float lr, float beta1, float beta2, this->update_state(lr, epsilon, weight_decay, bias_correction); this->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.numel(), (params.options().dtype() == at::kHalf), - (grads.options().dtype() == at::kHalf), loss_scale); + (grads.options().dtype() == at::kHalf), + (exp_avg.options().dtype() == at::kHalf), + (exp_avg_sq.options().dtype() == at::kHalf), loss_scale); } namespace py = pybind11; diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h index 4247da942775..bf9b85997c78 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -50,9 +50,9 @@ SOFTWARE #define SIMD_DIV(x, y) _mm512_div_ps(x, y) #define SIMD_LOAD_HALF(x) \ _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm256_store_ps( \ - x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) +#define SIMD_STORE_HALF(x, d) \ + _mm256_storeu_ps((float *)(x), _mm256_castsi256_ps(_mm512_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) #elif defined(__AVX256__) or defined(__AVX2__) #define SIMD_WIDTH 8 @@ -66,9 +66,9 @@ SOFTWARE #define SIMD_SQRT(x) _mm256_sqrt_ps(x) #define SIMD_DIV(x, y) _mm256_div_ps(x, y) #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define SIMD_STORE_HALF(x, d) \ - _mm_store_ps( \ - x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) +#define SIMD_STORE_HALF(x, d) \ + _mm_storeu_ps((float *)(x), _mm_castsi128_ps(_mm256_cvtps_ph( \ + d, _MM_FROUND_TO_NEAREST_INT))) #endif @@ -83,11 +83,12 @@ union AVX_Data { #endif -#define STEP(SPAN) \ - void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \ - float *_exp_avg_sq, size_t _param_size, \ - bool param_half_precision = false, \ - bool grad_half_precision = false, float loss_scale = -1); +#define STEP(SPAN) \ + void Step_##SPAN( \ + float *_params, float *grads, float *_exp_avg, float *_exp_avg_sq, \ + size_t _param_size, bool param_half_precision = false, \ + bool grad_half_precision = false, bool momentum_half_precision = false, \ + bool variance_half_precision = false, float loss_scale = -1); class Adam_Optimizer { public: @@ -141,6 +142,24 @@ class Adam_Optimizer { } } + inline void simd_load(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + data.data = SIMD_LOAD_HALF(h_ptr); + } else { + data.data = SIMD_LOAD(ptr); + } + } + + inline void simd_store(bool is_half, float *ptr, __half *h_ptr, + AVX_Data &data) { + if (is_half) { + SIMD_STORE_HALF(h_ptr, data.data); + } else { + SIMD_STORE(ptr, data.data); + } + } + void step(size_t step, float lr, float beta1, float beta2, float epsilon, float weight_decay, bool bias_correction, torch::Tensor ¶ms, torch::Tensor &grads, torch::Tensor &exp_avg, diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 1bdb81e2d6ec..238ba366da43 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -146,8 +146,7 @@ def step(self, closure=None, div_scale: float = -1): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - # FIXME(ver217): CPU adam kernel only supports fp32 states now - if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: + if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 7dc4590dc3f2..c7a309b872ce 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -122,8 +122,7 @@ def step(self, closure=None, div_scale: float = -1): assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - # FIXME(ver217): CPU adam kernel only supports fp32 states now - if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float: + if p.grad.dtype is torch.bfloat16: # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 8131ea3234d8..6bbe3e4e8172 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -13,9 +13,7 @@ _FUSED_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), - (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), (torch.bfloat16, torch.bfloat16), ] @@ -23,7 +21,6 @@ _CPU_ALLOWED_P_G_TYPES = [ (torch.float, torch.half), (torch.float, torch.float), - (torch.half, torch.float), (torch.half, torch.half), ] @@ -138,8 +135,8 @@ def check_adam_kernel( master_exp_avg_sq = torch.zeros_like(master_p) p = master_p.clone().to(p_dtype) g = master_g.clone().to(g_dtype) - exp_avg = master_exp_avg.clone() - exp_avg_sq = master_exp_avg_sq.clone() + exp_avg = master_exp_avg.clone().to(p_dtype) + exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype) for step in range(1, 1 + n_steps): torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py index 59b40a0afa3c..68d71e3c4194 100644 --- a/tests/test_optimizer/test_adam_optim.py +++ b/tests/test_optimizer/test_adam_optim.py @@ -21,8 +21,6 @@ (torch.float, torch.float), # pure fp32 (torch.float, torch.half), # fp16 amp (torch.float, torch.bfloat16), # bfloat16 amp - # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 - # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 ] N_STEPS = 3 diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index a3af81646a18..4c84e9e5a89a 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -52,7 +52,8 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", ["gpt2"]) -def exam_grad_clipping(placement_config, model_name: str): +@parameterize("master_weights", [True, False]) +def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): set_seed(1912) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -82,6 +83,7 @@ def exam_grad_clipping(placement_config, model_name: str): chunk_config_dict=config_dict, chunk_init_device=init_device, pin_memory=True, + master_weights=master_weights, **placement_config, ) @@ -103,7 +105,10 @@ def exam_grad_clipping(placement_config, model_name: str): torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) loss = run_fwd_bwd(model, data, label, criterion, zero_optim) - assert_close(torch_loss, loss) + + # as no master weights leads to error accumulation, we don't check the loss + if master_weights: + assert_close(torch_loss, loss) import apex.amp as apex_amp @@ -111,7 +116,8 @@ def exam_grad_clipping(placement_config, model_name: str): torch_optim.step() zero_optim.step() - check_param(model, torch_model) + if master_weights: + check_param(model, torch_model) def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 8e8e508ff483..9b84d68f3c7a 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -70,12 +70,14 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("model_name", TEST_MODELS) @parameterize("mixed_precision", [torch.half, torch.bfloat16]) -def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype): +@parameterize("master_weights", [True, False]) +def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() torch_model = model_builder().cuda() + # apex no master weights leads to nan, so we don't use it amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) @@ -90,7 +92,9 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 config_dict[world_size]["keep_gathered"] = False - model = GeminiDDP(model, config_dict, **placement_config, mixed_precision=mixed_precision) + model = GeminiDDP( + model, config_dict, **placement_config, mixed_precision=mixed_precision, master_weights=master_weights + ) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = GeminiOptimizer(optimizer, model, initial_scale=128) @@ -109,12 +113,15 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) + # as no master weights leads to error accumulation, we don't check the loss + if master_weights: + assert_close(torch_loss, loss, rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() - check_param(model, torch_model, mixed_precision) + if master_weights: + check_param(model, torch_model, mixed_precision) @parameterize("placement_config", PLACEMENT_CONFIGS) From 561553b2e2424cb4690fba8dc6e96fd2bcd6b3bd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 10:48:24 +0800 Subject: [PATCH 05/46] [format] applied code formatting on changed files in pull request 4908 (#4918) Co-authored-by: github-actions --- .../utils/flash_attention_patch.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py index 111659b2d928..1926ec78aba8 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py @@ -6,25 +6,20 @@ import torch import torch.nn.functional as F +from einops import rearrange +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func +from flash_attn.ops.rms_norm import rms_norm from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, LlamaAttention, - LlamaModel, LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, apply_rotary_pos_emb, repeat_kv, ) from colossalai.logging import get_dist_logger -from einops import rearrange - -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import ( - flash_attn_func, - flash_attn_varlen_kvpacked_func, -) -from flash_attn.ops.rms_norm import rms_norm - logger = get_dist_logger() @@ -65,7 +60,7 @@ def attention_forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - **kwargs + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention. From 8d420025780ed90c7ba9871320a9bcd83592ec56 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 17 Oct 2023 14:07:21 +0800 Subject: [PATCH 06/46] [gemini] support gradient accumulation (#4869) * add test * fix no_sync bug in low level zero plugin * fix test * add argument for grad accum * add grad accum in backward hook for gemini * finish implementation, rewrite tests * fix test * skip stuck model in low level zero test * update doc * optimize communication & fix gradient checkpoint * modify doc * cleaning codes * update cpu adam fp16 case --- colossalai/booster/plugin/gemini_plugin.py | 5 +- .../booster/plugin/low_level_zero_plugin.py | 2 +- colossalai/zero/gemini/chunk/chunk.py | 15 ++ colossalai/zero/gemini/chunk/manager.py | 36 ++++- colossalai/zero/gemini/gemini_ddp.py | 25 ++- colossalai/zero/gemini/gemini_optimizer.py | 1 + .../gradient_accumulation_with_booster.md | 29 +++- .../gradient_accumulation_with_booster.md | 28 +++- tests/components_to_test/bert.py | 1 - .../test_plugin/test_low_level_zero_plugin.py | 4 +- .../test_zero/test_gemini/test_grad_accum.py | 147 ++++++++++++++++++ 11 files changed, 283 insertions(+), 10 deletions(-) create mode 100644 tests/test_zero/test_gemini/test_grad_accum.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 6c165857506c..20a931b816ea 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -245,6 +245,7 @@ class GeminiPlugin(DPPluginBase): chunk_config_dict (dict, optional): chunk configuration dictionary. chunk_init_device (torch.device, optional): device to initialize the chunk. placement_policy (str, optional): "static" and "auto". Defaults to "static". + enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False. shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement. If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0. offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement. @@ -257,7 +258,7 @@ class GeminiPlugin(DPPluginBase): warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8. steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9. precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. - master_weights (bool, optional): master weights. Defaults to True. + master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. @@ -291,6 +292,7 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: Optional[torch.device] = None, placement_policy: str = "static", + enable_gradient_accumulation: bool = False, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -323,6 +325,7 @@ def __init__( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), placement_policy=placement_policy, + enable_gradient_accumulation=enable_gradient_accumulation, shard_param_frac=shard_param_frac, offload_optim_frac=offload_optim_frac, offload_param_frac=offload_param_frac, diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 088b67c8c533..dc78fe8c094c 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -335,4 +335,4 @@ def get_checkpoint_io(self) -> CheckpointIO: def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert isinstance(optimizer, LowLevelZeroOptimizer) - return optimizer.optim.no_sync() + return optimizer.no_sync() diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index c8be773b2c4f..d3309fc5364f 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -434,6 +434,21 @@ def copy_tensor_to_chunk_slice( if update_ptr: tensor.data = self.cuda_global_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape) + def add_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: + """ + Add data slice to the memory space indexed by the input tensor in the chunk. + Only used when accumulating gradient chunks. + + Args: + tensor (torch.Tensor): the tensor used to retrieve meta information + data_slice (torch.Tensor): the tensor to be added to the chunk + """ + # sanity check + assert self.is_gathered + + tensor_info = self.tensors_info[tensor] + self.cuda_global_chunk[tensor_info.offset : tensor_info.end].add_(data_slice.data.flatten()) + def get_valid_length(self) -> int: """Get the valid length of the chunk's payload.""" if self.keep_gathered: diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 713c11742e15..d3c512fe978d 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -5,7 +5,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup -from colossalai.utils import get_current_device +from colossalai.utils import free_storage, get_current_device from .chunk import Chunk, ChunkFullError, TensorState @@ -255,3 +255,37 @@ def init_grad_chunk(self, chunk: Chunk) -> Chunk: self.accessed_chunks.add(grad_chunk) self.accessed_mem += grad_chunk.chunk_mem return grad_chunk + + def rearrange_accumulated_grad_chunk(self, chunk: Chunk) -> Chunk: + """Rearrange gradients accumulated in chunk.grad_chunk, and getP prepared for gradient reduction.""" + + assert chunk.grad_chunk is not None + + # Make a backup for gradient accumulated before. + # Here backup gradients should be multiplied, since it will be divided after gradient reduction. + if chunk.grad_chunk.is_gathered: + accumulated_grad = chunk.grad_chunk.cuda_global_chunk.clone().detach().mul_(chunk.pg_size) + accumulated_grad_gathered = True + else: + if chunk.grad_chunk.cuda_shard is not None: + accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size) + else: + accumulated_grad = ( + chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size) + ) + accumulated_grad_gathered = False + + # Reset grad_chunk, and chunk.grad_chunk will be accessed. + grad_chunk = self.init_grad_chunk(chunk) + grad_chunk.cuda_global_chunk.zero_() + + # Add backup gradients to grad_chunk. + if accumulated_grad_gathered: + grad_chunk.cuda_global_chunk.add_(accumulated_grad) + else: + grad_chunk.cuda_global_chunk[grad_chunk.shard_begin : grad_chunk.shard_end].add_(accumulated_grad) + + # Release accumulated_grad + free_storage(accumulated_grad) + + return grad_chunk diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index a4871f7e4b40..df7e1163c3d9 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -59,6 +59,7 @@ def __init__( chunk_config_dict: Optional[dict] = None, chunk_init_device: torch.device = torch.device("cpu"), placement_policy: str = "static", + enable_gradient_accumulation: bool = False, shard_param_frac: float = 1.0, # only for static placement offload_optim_frac: float = 0.0, # only for static placement offload_param_frac: float = 0.0, # only for static placement @@ -119,6 +120,11 @@ def __init__( self.reuse_fp16_chunk = master_weights self.master_weights = master_weights + self.enable_gradient_accumulation = enable_gradient_accumulation + if self.enable_gradient_accumulation: + self.reuse_fp16_chunk = False + self.accumulating_grads = False # Whether model is accumulating gradients + self._logger = get_dist_logger() if self.gemini_manager._premade_memstats_: @@ -298,6 +304,8 @@ def _post_backward(self): f"{error_str}", ) self._setup_grads_ptr() + if self.enable_gradient_accumulation and not self.accumulating_grads: + self.accumulating_grads = True # Turn on the state of gradient accumulation. self._logger.debug( f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}" ) @@ -327,7 +335,15 @@ def grad_handle(self, p, grad): ) grad_chunk = chunk if not self.reuse_fp16_chunk: - grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + if not self.accumulating_grads: + grad_chunk = self.chunk_manager.init_grad_chunk(chunk) + else: + assert chunk.grad_chunk is not None + if chunk.grad_chunk not in self.chunk_manager.accessed_chunks: + grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk) + else: + grad_chunk = chunk.grad_chunk + # hold -> compute -> hold after bwd grad_chunk.tensor_trans_state(p, TensorState.COMPUTE) grad_chunk.tensor_trans_state(p, TensorState.HOLD_AFTER_BWD) @@ -336,7 +352,10 @@ def grad_handle(self, p, grad): chunk.tensor_trans_state(p, TensorState.HOLD) grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE) - grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + if not self.accumulating_grads: + grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk) + else: + grad_chunk.add_tensor_to_chunk_slice(p, grad) reduced = self.chunk_manager.reduce_chunk(grad_chunk) if reduced: if not self.reuse_fp16_chunk: @@ -354,7 +373,7 @@ def grad_handle(self, p, grad): if chunk.l2_norm_flag: grad_chunk.set_l2_norm() self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True) - if not self.master_weights: + if not (self.master_weights) or (self.enable_gradient_accumulation): self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) return empty_grad diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 3c42e96cb803..0d0298e067f3 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -263,6 +263,7 @@ def step(self, *args, **kwargs): self.zero_grad() if self.module.master_weights: self._update_fp16_params() + self.module.accumulating_grads = False return ret def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0): diff --git a/docs/source/en/features/gradient_accumulation_with_booster.md b/docs/source/en/features/gradient_accumulation_with_booster.md index 347cd6e519bb..ea97dd92e885 100644 --- a/docs/source/en/features/gradient_accumulation_with_booster.md +++ b/docs/source/en/features/gradient_accumulation_with_booster.md @@ -1,6 +1,6 @@ # Gradient Accumulation -Author: [Mingyan Jiang](https://github.com/jiangmingyan) +Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003) **Prerequisite** - [Training Booster](../basics/booster_api.md) @@ -126,6 +126,7 @@ for idx, (img, label) in enumerate(train_dataloader): ``` + ### Step 6. Invoke Training Scripts To verify gradient accumulation, we can just check the change of parameter values. When gradient accumulation is set, parameters are only updated in the last step. You can run the script using this command: ```shell @@ -142,4 +143,30 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0 iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) ``` + +## Gradient Accumulation on GeminiPlugin + +Currently the plugins supporting `no_sync()` method include `TorchDDPPlugin` and `LowLevelZeroPlugin` set to stage 1. `GeminiPlugin` doesn't support `no_sync()` method, but it can enable synchronized gradient accumulation in a torch-like way. + +To enable gradient accumulation feature, the argument `enable_gradient_accumulation` should be set to `True` when initializing `GeminiPlugin`. Following is the pseudocode snippet of enabling gradient accumulation for `GeminiPlugin`: + +```python +... +plugin = GeminiPlugin(..., enable_gradient_accumulation=True) +booster = Booster(plugin=plugin) +... + +... +for idx, (input, label) in enumerate(train_dataloader): + output = gemini_model(input.cuda()) + train_loss = criterion(output, label.cuda()) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, gemini_optimizer) + + if idx % (GRADIENT_ACCUMULATION - 1) == 0: + gemini_optimizer.step() # zero_grad is automatically done +... +``` + + diff --git a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md index 3ad9b2e07a95..824308f94654 100644 --- a/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md +++ b/docs/source/zh-Hans/features/gradient_accumulation_with_booster.md @@ -1,6 +1,6 @@ # 梯度累积 -作者: [Mingyan Jiang](https://github.com/jiangmingyan) +作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Baizhou Zhang](https://github.com/Fridge003) **前置教程** - [训练中使用Booster](../basics/booster_api.md) @@ -93,6 +93,7 @@ model, optimizer, criterion, train_dataloader, _ = booster.boost(model=model, dataloader=train_dataloader) ``` + ### 步骤 5. 使用booster训练 使用booster构建一个普通的训练循环,验证梯度累积。 `param_by_iter` 记录分布训练的信息。 ```python @@ -144,4 +145,29 @@ iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0 iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=) ``` +## 在Gemini插件中使用梯度累积 + +目前支持`no_sync()`方法的插件包括 `TorchDDPPlugin` 和 `LowLevelZeroPlugin`(需要设置参数`stage`为1). `GeminiPlugin` 不支持 `no_sync()` 方法, 但是它可以通过和`pytorch`类似的方式来使用同步的梯度累积。 + +为了开启梯度累积功能,在初始化`GeminiPlugin`的时候需要将参数`enable_gradient_accumulation`设置为`True`。以下是 `GeminiPlugin` 进行梯度累积的伪代码片段: + +```python +... +plugin = GeminiPlugin(..., enable_gradient_accumulation=True) +booster = Booster(plugin=plugin) +... + +... +for idx, (input, label) in enumerate(train_dataloader): + output = gemini_model(input.cuda()) + train_loss = criterion(output, label.cuda()) + train_loss = train_loss / GRADIENT_ACCUMULATION + booster.backward(train_loss, gemini_optimizer) + + if idx % (GRADIENT_ACCUMULATION - 1) == 0: + gemini_optimizer.step() # zero_grad is automatically done +... +``` + + diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py index f0061ad18c84..9f0eef75ae93 100644 --- a/tests/components_to_test/bert.py +++ b/tests/components_to_test/bert.py @@ -52,7 +52,6 @@ def bert_model_builder(checkpoint: bool = False): hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, ) - print("building BertForSequenceClassification model") # adapting huggingface BertForSequenceClassification for single unittest calling interface class ModelAdaptor(BertForSequenceClassification): diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 9cc12f96bd4d..104ca254c572 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -14,6 +14,8 @@ _AMP_ERR_MODELS = ["timm_convit", "deepfm_interactionarch"] # These models have no parameters _LOW_LEVEL_ZERO_ERR_MODELS = ["dlrm_interactionarch"] +# These models will cause stuck, to be fixed +_STUCK_MODELS = ["transformers_albert_for_multiple_choice"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -53,7 +55,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """ passed_models = [] failed_info = {} # (model_name, error) pair - ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS skipped_models = [] for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py new file mode 100644 index 000000000000..334a57410817 --- /dev/null +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -0,0 +1,147 @@ +import pytest +import torch +import torch.distributed as dist +from apex import amp +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.utils import set_seed +from colossalai.utils.cuda import get_current_device +from colossalai.zero import GeminiDDP, GeminiOptimizer +from colossalai.zero.gemini.chunk import search_chunk_configuration +from tests.components_to_test import run_fwd +from tests.components_to_test.registry import non_distributed_component_funcs + +PLACEMENT_CONFIGS = [ + {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 + {"placement_policy": "static", "shard_param_frac": 1.0}, # zero3 + {"placement_policy": "static", "shard_param_frac": 0.5}, # zero3-half + {"placement_policy": "auto"}, +] + + +def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): + chunk_manager = model.chunk_manager + grad_chunk_list = [] + device_list = [] + + # Access gradient chunks. + for p in model.parameters(): + grad_chunk = chunk_manager.get_chunk(p).grad_chunk + if grad_chunk not in grad_chunk_list: + chunk_manager.access_chunk(grad_chunk) + grad_chunk_list.append(grad_chunk) + device_list.append(model.grads_device[p]) + + # Compare gradients. + for p0, p1 in zip(model.parameters(), torch_model.parameters()): + assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) + + # Release gradient chunks and move them to gradient device. + for grad_chunk, device in zip(grad_chunk_list, device_list): + chunk_manager.release_chunk(grad_chunk) + chunk_manager.move_chunk(grad_chunk, device, force_copy=True) + + +@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("keep_gathered", [False, True]) +@parameterize("model_name", ["gpt2", "bert"]) +@parameterize("use_grad_checkpoint", [False, True]) +@parameterize("master_weights", [False, True]) +def exam_gemini_grad_acc( + placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool +): + init_device = get_current_device() + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, _, criterion = get_components_func() + + set_seed(42) + gemini_model = model_builder(use_grad_checkpoint) + + set_seed(42) + torch_model = model_builder(use_grad_checkpoint).cuda() + for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()): + torch_p.data.copy_(p.data) + + world_size = torch.distributed.get_world_size() + config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100) + config_dict[world_size]["chunk_size"] = 5000 + config_dict[world_size]["keep_gathered"] = keep_gathered + gemini_model = GeminiDDP( + gemini_model, + config_dict, + init_device, + pin_memory=True, + enable_gradient_accumulation=True, + master_weights=master_weights, + **placement_config, + ) + optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3) + gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1) + + rank = dist.get_rank() + + # setting master_weights to False will cause overflow after optimizer.step() + amp_config = dict( + opt_level="O2", keep_batchnorm_fp32=False, loss_scale=1, min_loss_scale=1, max_loss_scale=1, master_weights=True + ) + torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3) + torch_model, torch_optim = amp.initialize(torch_model, torch_optim, **amp_config) + torch_model = DDP(torch_model, device_ids=[rank]) + + set_seed(rank) + accum_iter = 4 + for i, (input_ids, label) in enumerate(train_dataloader): + delay_unscale = False if (i + 1) % accum_iter == 0 else True + input_ids, label = input_ids.cuda(), label.cuda() + + set_seed(42 + rank) + torch_loss = run_fwd(torch_model, input_ids, label, criterion) + torch_loss = torch_loss / accum_iter + with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss: + scaled_loss.backward() + + set_seed(42 + rank) + gemini_loss = run_fwd(gemini_model, input_ids, label, criterion) + gemini_loss = gemini_loss / accum_iter + gemini_optim.backward(gemini_loss) + + assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5) + + check_grad(gemini_model, torch_model) + + if (i + 1) % accum_iter == 0: + torch_optim.step() + gemini_optim.step() + torch_optim.zero_grad() + + # check updated param + torch_dict = torch_model.state_dict() + gemini_dict = gemini_model.state_dict(only_rank_0=False) + + for key, value in gemini_dict.items(): + torch_key = "module." + key + torch_value = torch_dict[torch_key].to(value.device).to(value.dtype) + assert_close(value, torch_value, rtol=1e-3, atol=2e-3) + + if i == accum_iter: + break + + +def run_dist(rank, world_size, port): + config = {} + colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_gemini_grad_acc() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_grad_accumulation(): + spawn(run_dist, 2) + + +if __name__ == "__main__": + test_grad_accumulation() From da55732d2fc0660431116df8fa51ec5a84d9b1dc Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 18 Oct 2023 11:05:25 +0800 Subject: [PATCH 07/46] [hotfix] fix torch 2.0 compatibility (#4936) * [hotfix] fix launch * [test] fix test gemini optim * [shardformer] fix vit --- colossalai/legacy/context/parallel_context.py | 16 ++++++--- colossalai/legacy/tensor/process_group.py | 5 ++- colossalai/shardformer/modeling/vit.py | 33 +++++++------------ tests/test_shardformer/test_model/_utils.py | 23 ++++--------- .../test_model/test_shard_vit.py | 13 ++------ tests/test_zero/test_gemini/test_optim.py | 4 +++ 6 files changed, 39 insertions(+), 55 deletions(-) diff --git a/colossalai/legacy/context/parallel_context.py b/colossalai/legacy/context/parallel_context.py index 48bf8ab279e8..b95405a33092 100644 --- a/colossalai/legacy/context/parallel_context.py +++ b/colossalai/legacy/context/parallel_context.py @@ -54,7 +54,7 @@ def __init__(self): # logging self._verbose = False - self._logger = get_dist_logger() + self._logger = None @property def config(self): @@ -68,6 +68,12 @@ def verbose(self): def verbose(self, verbose_: bool): self._verbose = verbose_ + @property + def logger(self): + if self._logger is None: + self._logger = get_dist_logger() + return self._logger + def load_config(self, config: Union[dict, str]): """Loads the configuration from either a dict or a file. @@ -527,7 +533,7 @@ def set_device(self, device_ordinal: int = None): torch.cuda.set_device(device_ordinal) if self._verbose: - self._logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") + self.logger.info(f"process rank {global_rank} is bound to device {device_ordinal}") def set_seed(self, seed: int): """Sets seeds for all random libraries. @@ -563,19 +569,19 @@ def set_seed(self, seed: int): seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()]) if self._verbose: - self._logger.info( + self.logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, {seed_str}," f"the default parallel seed is {ParallelMode.DATA}." ) else: if self._verbose: - self._logger.info( + self.logger.info( f"initialized seed on rank {global_rank}, " f"numpy: {seed}, python random: {seed}, pytorch: {seed}", ranks=[0], ) - self._logger.info( + self.logger.info( "WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states", ranks=[0], ) diff --git a/colossalai/legacy/tensor/process_group.py b/colossalai/legacy/tensor/process_group.py index ec6043163336..230849f17576 100644 --- a/colossalai/legacy/tensor/process_group.py +++ b/colossalai/legacy/tensor/process_group.py @@ -31,7 +31,7 @@ def get(self, rank_list: List[int], backend: str = "nccl"): return self.dict[processgroup_key] -PYTORCHPGDICT_ = PyTorchProcessGroupDict() +PYTORCHPGDICT_ = None class ProcessGroup: @@ -59,6 +59,9 @@ def __init__( if not torch.distributed.is_initialized(): self.is_init = False return + global PYTORCHPGDICT_ + if PYTORCHPGDICT_ is None: + PYTORCHPGDICT_ = PyTorchProcessGroupDict() assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 2db83b912112..5a50e7379cdc 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -100,35 +100,24 @@ def pp_forward( embedding_output = self.embeddings( pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) + hidden_states = embedding_output else: assert ( hidden_states is not None ), f"Current stage is {stage_manager.stage}, hidden_states should not be None" - # Go through encoder + encoder_outputs = _encoder_forward( + encoder=self.encoder, + start_idx=stage_index[0], + end_idx=stage_index[1], + hidden_states=hidden_states, + head_mask=head_mask, + return_dict=return_dict, + stage_manager=stage_manager, + ) if not stage_manager.is_last_stage(): - hidden_states = _encoder_forward( - encoder=self.encoder, - start_idx=stage_index[0], - end_idx=stage_index[1], - hidden_states=embedding_output, - head_mask=head_mask, - return_dict=return_dict, - stage_manager=stage_manager, - ) - return {"hidden_states": hidden_states} - else: - encoder_outputs = _encoder_forward( - encoder=self.encoder, - start_idx=stage_index[0], - end_idx=stage_index[1], - hidden_states=hidden_states, - head_mask=head_mask, - return_dict=return_dict, - stage_manager=stage_manager, - ) + return {"hidden_states": encoder_outputs} - # Go through rest layers sequence_output = encoder_outputs[0] sequence_output = self.layernorm(sequence_output) pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 66d77b48aa0c..6acbe4ff523d 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup from torch.nn import Module from torch.optim import Adam, Optimizer +from torch.testing import assert_close from colossalai.booster import Booster from colossalai.booster.plugin import HybridParallelPlugin @@ -160,7 +161,7 @@ def _criterion(outputs, inputs): input_shape = data["input_ids"].shape for k, v in data.items(): if v.shape == input_shape: - data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,)) + data[k] = v.repeat((1,) * (v.dim() - 1) + (times,)) sharded_model.train() if booster.plugin.stage_manager is not None: @@ -207,15 +208,11 @@ def check_output_hidden_state( else: sharded_hidden_state = sharded_output.last_hidden_state - assert torch.allclose( - org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol - ), f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}" + assert_close(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol) def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose( - org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol - ), f"shard model loss is not equal to origin model loss\n{org_loss}\n{sharded_loss}" + assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) def check_weight( @@ -242,9 +239,7 @@ def check_weight( if verbose and dist.get_rank() == 0: print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") - assert torch.allclose( - org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol - ), f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}" + assert_close(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol) def get_grad_tensors_for_check( @@ -310,9 +305,7 @@ def check_grad( if verbose and dist.get_rank() == 0: print(f"'{suffix}' grad: {org_grad}, {shard_grad}") - assert torch.allclose( - org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + assert_close(org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol) def unwrap_model( @@ -337,6 +330,4 @@ def check_all_grad_tensors(check_tensors): shard_grad = check_info["shard_grad"] rtol = check_info["rtol"] atol = check_info["atol"] - assert torch.allclose( - org_grad, shard_grad, atol=atol, rtol=rtol - ), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}" + assert_close(org_grad, shard_grad, atol=atol, rtol=rtol) diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py index 1c934bd22340..3a8af2d6d481 100644 --- a/tests/test_shardformer/test_model/test_shard_vit.py +++ b/tests/test_shardformer/test_model/test_shard_vit.py @@ -43,7 +43,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, grads_to_check = {} if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0: if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 + atol, rtol = 2e-5, 1e-3 else: atol, rtol = 5e-3, 5e-3 row_layer_grads = get_grad_tensors_for_check( @@ -62,7 +62,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, # check last hidden state & loss if stage_manager is None or stage_manager.is_last_stage(): if test_config["precision"] == "fp32": - atol, rtol = 1e-5, 1e-3 + atol, rtol = 2e-3, 1e-3 else: atol, rtol = 5e-3, 5e-3 @@ -154,15 +154,6 @@ def run_vit_test(test_config): "precision": "fp32", "initial_scale": 1, }, - { - "tp_size": 2, - "pp_size": 2, - "num_microbatches": 2, - "enable_all_optimization": False, - "use_lazy_init": False, - "precision": "fp32", - "initial_scale": 1, - }, ], ) def run_vit_3d_test(test_config): diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 9b84d68f3c7a..0cf9aa073f9f 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -1,6 +1,7 @@ import pytest import torch import torch.distributed as dist +from packaging.version import Version from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -161,6 +162,9 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. rtol, atol = 1.5e-6, 2e-5 if mixed_precision is torch.bfloat16: rtol, atol = 2e-3, 2e-3 + elif Version(torch.__version__) >= Version("2.0.0"): + rtol, atol = 4e-5, 3e-5 + for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break From 775ea1bc0b881b8f902bf7a1ca2566b5267ea762 Mon Sep 17 00:00:00 2001 From: Zhongkai Zhao Date: Wed, 18 Oct 2023 11:41:23 +0800 Subject: [PATCH 08/46] [test] add no master test for low level zero plugin (#4934) --- colossalai/nn/optimizer/cpu_adam.py | 3 ++- tests/test_zero/test_low_level/test_zero1_2.py | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 238ba366da43..c3c0180e8516 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -9,7 +9,8 @@ class CPUAdam(NVMeOptimizer): - """Implements Adam algorithm. + """ + Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depending on the device of parameters. But the parameters and gradients should on the same device: diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index ebda9f6f25c5..e2196cfbf0f2 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -106,7 +106,8 @@ def exam_zero_1_2(): @parameterize("dtype", [torch.float16, torch.bfloat16]) -def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): +@parameterize("master_weights", [True, False]) +def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -131,7 +132,11 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, overlap_communication=True, initial_scale=1, reduce_bucket_size=1024 * 1024 + zero_optimizer, + overlap_communication=True, + initial_scale=1, + reduce_bucket_size=1024 * 1024, + master_weights=master_weights, ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) From 0074178f332ea26d80fe62151ff0d41dd2898a35 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:46:37 +0800 Subject: [PATCH 09/46] [format] applied code formatting on changed files in pull request 4820 (#4886) Co-authored-by: github-actions --- colossalai/inference/__init__.py | 2 +- colossalai/inference/pipeline/__init__.py | 2 +- .../inference/pipeline/benchmark/benchmark.py | 97 +++++++----- colossalai/inference/pipeline/engine.py | 21 ++- .../inference/pipeline/microbatch_manager.py | 36 ++--- .../inference/pipeline/modeling/gpt2.py | 144 +++++++++--------- .../inference/pipeline/modeling/llama.py | 78 +++++----- .../inference/pipeline/policy/gpt2_ppinfer.py | 27 ++-- .../pipeline/policy/llama_ppinfer.py | 26 ++-- colossalai/inference/pipeline/utils.py | 6 +- colossalai/pipeline/p2p.py | 12 +- colossalai/pipeline/schedule/generate.py | 77 ++++++---- tests/test_infer/test_pipeline_infer.py | 29 ++-- 13 files changed, 298 insertions(+), 259 deletions(-) diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index db33ae6fe998..35891307e754 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,3 +1,3 @@ from .pipeline import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index aff4568f7d08..41af9f3ef948 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ from .engine import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 97dfc6336bea..9c47909f70f0 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -1,28 +1,32 @@ +import argparse +import time + import torch import torch.distributed as dist import transformers import colossalai -import time from colossalai.inference import PPInferEngine from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy -import argparse -GIGABYTE = 1024 ** 3 + +GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 colossalai.launch_from_torch(config={}) -def data_gen(batch_size: int=4, seq_len: int=512): + +def data_gen(batch_size: int = 4, seq_len: int = 512): input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) attention_mask = torch.ones((1, seq_len), dtype=torch.int32) data = dict(input_ids=input_ids, attention_mask=attention_mask) for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = batch_size - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) return data + def print_details_info(timestamps, model_config, args, whole_end2end): if dist.get_rank() == 0: prefill = [] @@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end): for timestamp in timestamps: prefill.append(timestamp[1] - timestamp[0]) encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2)) + sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) + ) end2end.append(timestamp[-1] - timestamp[0]) print(whole_end2end) - with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f: - mb_avg_end2end = sum(end2end)/len(end2end) - mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size) - whole_avg_latency = whole_end2end/(args.new_length * args.batch_size) + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "w+", + ) as f: + mb_avg_end2end = sum(end2end) / len(end2end) + mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size) + whole_avg_latency = whole_end2end / (args.new_length * args.batch_size) num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size - if args.dtype in ['fp16','bf16']: + if args.dtype in ["fp16", "bf16"]: num_bytes = 2 else: num_bytes = 4 - f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n") - f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000)) - f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000)) - f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000)) + f.write( + f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n" + ) + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000)) f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) - f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000)) f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) - f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000)))) - f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12)) f.write("----------------------------------------------------------\n") - if torch.cuda.is_available(): current_device = torch.cuda.current_device() @@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end): max_memory_allocated = torch.cuda.max_memory_allocated() memory_reserved = torch.cuda.memory_reserved() max_memory_reserved = torch.cuda.max_memory_reserved() - with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f: + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "a", + ) as f: f.write( f"\nCurrently using GPU: {current_device}\n" f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" @@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end): f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" ) -if __name__ == '__main__': + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='toy', help='the size of model') - parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size') - parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length') - parser.add_argument('--new_length', type=int, default=4, help='new tokens length') - parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size') - parser.add_argument('--pp_size', type=int, default=2, help='pipeline size') - parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log') - parser.add_argument('--dtype', type=str, default='fp16', help='data type') + parser.add_argument("--model", default="toy", help="the size of model") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("--new_length", type=int, default=4, help="new tokens length") + parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") + parser.add_argument("--pp_size", type=int, default=2, help="pipeline size") + parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log") + parser.add_argument("--dtype", type=str, default="fp16", help="data type") args = parser.parse_args() - if args.model == 'toy': + if args.model == "toy": model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) - elif args.model == '7b': - model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf')) - elif args.model == '13b': - model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf')) + elif args.model == "7b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")) + elif args.model == "13b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf")) else: raise NotImplementedError - - - engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True) + + engine = PPInferEngine( + pp_size=args.pp_size, + dtype=args.dtype, + micro_batch_size=args.mb_size, + new_length=args.new_length, + model=model, + model_policy=LlamaForCausalLMPipelinePolicy(), + verbose=True, + ) data = data_gen(args.batch_size, args.seq_len) torch.cuda.synchronize() @@ -109,4 +129,3 @@ def print_details_info(timestamps, model_config, args, whole_end2end): whole_end2end = time.time() - whole_end2end print_details_info(timestamps, model.config, args, whole_end2end) - diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 048ead2bccda..4f42385caf8f 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,5 +1,3 @@ -from typing import Callable, List, Optional, Set, Union - import torch import torch.nn as nn @@ -13,7 +11,7 @@ class PPInferEngine: - ''' + """ PPInferEngine is a class that handles the pipeline parallel inference. Args: @@ -41,12 +39,12 @@ class PPInferEngine: output = engine.inference([tokenized_input]) ``` - ''' + """ def __init__( self, pp_size: int, - dtype: str = 'fp16', + dtype: str = "fp16", pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, @@ -54,7 +52,7 @@ def __init__( micro_batch_size: int = 1, micro_batch_buffer_size: int = None, verbose: bool = False, - # TODO: implement early_stopping, and various gerneration options + # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, do_sample: bool = False, num_beams: int = 1, @@ -63,15 +61,16 @@ def __init__( self.pp_size = pp_size self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, - micro_batch_buffer_size or pp_size) + self.mb_manager = MicroBatchManager( + self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size + ) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) - assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" - if dtype == 'fp16': + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + if dtype == "fp16": model.half() - elif dtype == 'bf16': + elif dtype == "bf16": model.to(torch.bfloat16) self.model = pp_model or self._shardformer(model, model_policy) diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index b6b008442cfd..49d1bf3f42cb 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -3,7 +3,7 @@ import torch -__all__ = 'MicroBatchManager' +__all__ = "MicroBatchManager" class Status(Enum): @@ -13,7 +13,7 @@ class Status(Enum): COOLDOWN = 4 -class MicroBatchDescription(): +class MicroBatchDescription: """ This is the class to record the infomation of each microbatch, and also do some update operation. This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more @@ -30,14 +30,14 @@ def __init__( output_dict: Dict[str, torch.Tensor], new_length: int, ) -> None: - assert output_dict.get('hidden_states') is not None - self.mb_length = output_dict['hidden_states'].shape[-2] + assert output_dict.get("hidden_states") is not None + self.mb_length = output_dict["hidden_states"].shape[-2] self.target_length = self.mb_length + new_length self.kv_cache = () def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): if output_dict is not None: - self._update_kvcache(output_dict['past_key_values']) + self._update_kvcache(output_dict["past_key_values"]) def _update_kvcache(self, kv_cache: Tuple): assert type(kv_cache) == tuple @@ -64,7 +64,6 @@ def cur_length(self): Return the current sequnence length of micro batch """ - pass class HeadMicroBatchDescription(MicroBatchDescription): @@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ - def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], - new_length: int) -> None: + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: super().__init__(inputs_dict, output_dict, new_length) assert inputs_dict is not None - assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None - self.input_ids = inputs_dict['input_ids'] - self.attn_mask = inputs_dict['attention_mask'] + assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None + self.input_ids = inputs_dict["input_ids"] + self.attn_mask = inputs_dict["attention_mask"] self.new_tokens = None def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): @@ -104,7 +104,8 @@ def _update_newtokens(self, new_token: torch.Tensor): def _update_attnmask(self): self.attn_mask = torch.cat( - (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) + (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1 + ) @property def cur_length(self): @@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription): output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. """ - def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], - new_length: int) -> None: + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: super().__init__(inputs_dict, output_dict, new_length) def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): @@ -146,8 +148,8 @@ def cur_length(self): return self.kv_cache[0][0].shape[-2] + 1 -class MicroBatchManager(): - ''' +class MicroBatchManager: + """ MicroBatchManager is a class that manages the micro batch. Args: @@ -156,7 +158,7 @@ class MicroBatchManager(): micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - ''' + """ def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): self.stage = stage diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py index f490710c1f7f..d2bfcb8b6842 100644 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model from transformers.utils import logging @@ -10,41 +9,41 @@ class GPT2PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of GPT2 models under pipeline setting. - ''' + """ @staticmethod def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. logger = logging.get_logger(__name__) # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -96,7 +95,7 @@ def gpt2_model_forward( # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -137,7 +136,8 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False presents = () if use_cache else None @@ -166,7 +166,6 @@ def gpt2_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -218,61 +217,64 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - return {'hidden_states': hidden_states, 'past_key_values': presents} + return {"hidden_states": hidden_states, "past_key_values": presents} @staticmethod def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # Not first stage or before warmup, go through gpt2 model - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index eeda96df25fd..f46e1fbdd7b3 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,8 +1,6 @@ -from typing import List, Optional, Tuple +from typing import List, Optional import torch -from torch.nn import CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel from transformers.utils import logging @@ -10,10 +8,10 @@ class LlamaPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ def llama_model_forward( self: LlamaModel, @@ -34,10 +32,10 @@ def llama_model_forward( # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -70,10 +68,9 @@ def llama_model_forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -81,16 +78,18 @@ def llama_model_forward( # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # decoder layers @@ -112,7 +111,6 @@ def llama_model_forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -152,7 +150,7 @@ def custom_forward(*inputs): next_cache = next_decoder_cache if use_cache else None # always return dict for imediate stage - return {'hidden_states': hidden_states, 'past_key_values': next_cache} + return {"hidden_states": hidden_states, "past_key_values": next_cache} def llama_for_causal_lm_forward( self: LlamaForCausalLM, @@ -171,45 +169,45 @@ def llama_for_causal_lm_forward( stage_index: Optional[List[int]] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py index e51090200f83..51e6425b113e 100644 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -11,7 +11,6 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -22,18 +21,22 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -45,7 +48,7 @@ def get_held_layers(self) -> List[nn.Module]: return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -56,16 +59,16 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py index bb359de0bb6f..6e12ed61bf7b 100644 --- a/colossalai/inference/pipeline/policy/llama_ppinfer.py +++ b/colossalai/inference/pipeline/policy/llama_ppinfer.py @@ -1,19 +1,15 @@ -from functools import partial -from typing import Callable, Dict, List, Union +from typing import List -import torch.nn as nn -from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.layer import Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.llama import LlamaPolicy from ..modeling.llama import LlamaPipelineForwards class LlamaForCausalLMPipelinePolicy(LlamaPolicy): - def __init__(self) -> None: super().__init__() @@ -25,19 +21,21 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForCausalLM, - new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) return policy diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py index 1a6e8a519397..c26aa4e40b71 100644 --- a/colossalai/inference/pipeline/utils.py +++ b/colossalai/inference/pipeline/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Set +from typing import Set import torch.nn as nn @@ -30,6 +30,6 @@ def get_suffix_name(suffix: str, name: str): suffix (str): The suffix of the suffix module name (str): The name of the current module """ - point = '' if suffix is '' else '.' - suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}' + point = "" if suffix is "" else "." + suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" return suffix_name diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 67e198ca0347..f822c1819adc 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -167,7 +167,7 @@ def _p2p_comm( group: ProcessGroup, comm_dtype: torch.dtype = torch.float16, ): - """ + """ Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. Agrs: @@ -176,7 +176,7 @@ def _p2p_comm( peer (int): rank of the peer group (ProcessGroup): process group comm_dtype (torch.dtype): dtype of the tensor to be sent - + Returns: torch.Tensor: tensor received from previous stage """ @@ -302,7 +302,9 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) - def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None: + def p2p_communicate( + self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 + ) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -313,5 +315,7 @@ def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, if peer is None: peer = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype) + recv_tensor = _p2p_comm( + output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype + ) return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 8f6acd5fcf4b..1f4bbe9f8dad 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -1,6 +1,6 @@ import time from functools import partial -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Iterable, Optional, Union import torch import torch.cuda @@ -16,7 +16,7 @@ from .base import PipelineSchedule -class ActionIntervalBuffer(): +class ActionIntervalBuffer: """ The buffer to save the interval hidden states and new token for stage to use. @@ -70,8 +70,9 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - assert self.batch_size % self.microbatch_size == 0, \ - f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" + assert ( + self.batch_size % self.microbatch_size == 0 + ), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" self.num_microbatches = self.batch_size // self.microbatch_size self.round = self.num_microbatches // self.stage_manager.num_stages @@ -86,26 +87,26 @@ def load_micro_batch(self) -> Any: return tree_map(partial(to_device, device=get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): - ''' + """ Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` - ''' - model_inputs = { - 'past_key_values': self.mb_manager.cur_kv_cache - } if self.mb_manager.cur_kv_cache is not None else None + """ + model_inputs = ( + {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None + ) return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): - ''' + """ Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values` `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end, `past_key_values` is the past_key_values save in the micro batch manager Returns: dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` - ''' + """ new_mask = self.mb_manager.cur_descrption.attn_mask past_key_values = self.mb_manager.cur_descrption.kv_cache @@ -117,12 +118,12 @@ def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: return input_ids def _recv_pre_stage(self) -> Any: - ''' + """ Receive the output from previous stage Returns: Any: The output from previous stage - ''' + """ if self.stage_manager.num_stages == 2: return self.comm.p2p_recv() return self.comm.recv_forward() @@ -138,7 +139,7 @@ def _load_stage_action(self, model: Module) -> None: output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _gen_token_action(self, model: Module): """ @@ -146,13 +147,15 @@ def _gen_token_action(self, model: Module): """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" - hidden_states = {'hidden_states': hidden_states} + hidden_states = {"hidden_states": hidden_states} logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits['logits']) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(None, None, new_token) self.action_interval_buffer.new_token = new_token @@ -168,17 +171,17 @@ def _head_encoding_action(self, model: Module): output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" inputs_dict = self._prepare_inputs_for_interval_stage() - hidden_states = {'hidden_states': hidden_states} + hidden_states = {"hidden_states": hidden_states} output_dict = model_forward(model, inputs_dict, hidden_states) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -246,10 +249,13 @@ def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.T whole_timestamp = [] - #run by round + # run by round for _ in range(self.round): - self.timestamps = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: actions = self._gen_action(model) @@ -286,8 +292,11 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t whole_timestamp = [] # run by round for _ in range(self.round): - self.timestamps = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -307,13 +316,17 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t hidden_states = self.comm.recv_forward() if self.stage_manager.is_first_stage(): # First just generate a new token - assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + assert ( + hidden_states is not None + ), "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits['logits']) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): @@ -327,9 +340,11 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t self.mb_manager.step(inputs_dict, output_dict, None) # Current microbatch is not DONE, send hidden_state to next stage - if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, - Status.COOLDOWN): - self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( + Status.GENERATE, + Status.COOLDOWN, + ): + self.comm.send_forward({"hidden_states": output_dict["hidden_states"]}) self.mb_manager.next() diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 47cf9e78d138..ad8e32b48bae 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -1,9 +1,6 @@ -from copy import deepcopy - import pytest import torch import torch.distributed as dist -import torch.nn as nn import transformers import colossalai @@ -20,27 +17,29 @@ def data_gen(): inputs = data_gen() for k, v in inputs.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 16 - inputs[k] = v.to('cuda').repeat(*new_shape) + inputs[k] = v.to("cuda").repeat(*new_shape) def pipeline_inference_test(pp_size, new_length, micro_batch_size): model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) - engine = PPInferEngine(pp_size=pp_size, - model=model, - model_policy=GPT2LMHeadModelPipelinePolicy(), - new_length=new_length, - micro_batch_size=micro_batch_size) + engine = PPInferEngine( + pp_size=pp_size, + model=model, + model_policy=GPT2LMHeadModelPipelinePolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size, + ) output = engine.inference([inputs]) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize('pp_size', [4]) -@parameterize('new_length', [4, 8, 16]) -@parameterize('micro_batch_size', [1, 4]) +@parameterize("pp_size", [4]) +@parameterize("new_length", [4, 8, 16]) +@parameterize("micro_batch_size", [1, 4]) @clear_cache_before_run() def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): pipeline_inference_test(pp_size, new_length, micro_batch_size) @@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): def check_pipeline_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_pipeline_inference_test() @@ -59,5 +58,5 @@ def test_pipeline_inference(): spawn(check_pipeline_inference, nprocs=4) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_inference() From 907aa98d3fbf5fae353872e812b6a2fd4298deef Mon Sep 17 00:00:00 2001 From: digger yu Date: Wed, 18 Oct 2023 15:44:04 +0800 Subject: [PATCH 10/46] [nfc] fix some typo with colossalai/ docs/ etc. (#4920) --- colossalai/inference/README.md | 2 +- colossalai/shardformer/README.md | 2 +- docs/source/en/basics/booster_plugins.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 9a965dc982a4..ba6c95ce8832 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -94,7 +94,7 @@ For various models, experiments were conducted using multiple batch sizes under ### Single GPU Performance: -Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to furthur optimize the performance of LLM models. Please stay tuned. +Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned. #### Llama diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 4bd7d5208a64..63b28701e879 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -77,7 +77,7 @@ Following are the description `ShardConfig`'s arguments: - `enable_sequence_parallelism`: Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. -- `enable_sequence_overlap`: Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. +- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False. - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False. diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index feb37fc15de2..fa360a4b9213 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -15,7 +15,7 @@ We currently provide the following plugins: - [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp. - [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. - [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management. -- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. +- [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below. More plugins are coming soon. From 31fddbc554aeec4db99e70cd37e6e6e00e9b6275 Mon Sep 17 00:00:00 2001 From: Cuiqing Li Date: Thu, 19 Oct 2023 22:22:47 +0800 Subject: [PATCH 11/46] [Refactor] Integrated some lightllm kernels into token-attention (#4946) * add some req for inference * clean codes * add codes * add some lightllm deps * clean codes * hello * delete rms files * add some comments * add comments * add doc * add lightllm deps * add lightllm cahtglm2 kernels * add lightllm cahtglm2 kernels * replace rotary embedding with lightllm kernel * add some commnets * add some comments * add some comments * add * replace fwd kernel att1 * fix a arg * add * add * fix token attention * add some comments * clean codes * modify comments * fix readme * fix bug * fix bug --------- Co-authored-by: cuiqing.li Co-authored-by: CjhHa1 --- colossalai/inference/README.md | 19 +- .../tensor_parallel/batch_infer_state.py | 3 +- .../tensor_parallel/kvcache_manager.py | 10 +- .../tensor_parallel/modeling/chatglm2.py | 18 +- .../tensor_parallel/modeling/llama.py | 25 +- .../tensor_parallel/policies/llama.py | 6 +- colossalai/kernel/triton/__init__.py | 7 +- colossalai/kernel/triton/context_attention.py | 328 +-------- .../kernel/triton/copy_kv_cache_dest.py | 2 + colossalai/kernel/triton/rms_norm.py | 71 -- .../kernel/triton/rotary_embedding_kernel.py | 212 ------ .../kernel/triton/self_attention_nofusion.py | 2 + .../kernel/triton/token_attention_kernel.py | 689 ++---------------- examples/inference/bench_llama.py | 56 +- requirements/requirements.txt | 3 + .../triton/test_llama2_token_attn.py | 63 -- .../triton/test_rotary_embedding.py | 55 -- .../triton/test_token_attn_1.py | 74 -- .../triton/test_token_attn_2.py | 63 -- .../triton/test_token_attn_fwd.py | 5 +- 20 files changed, 158 insertions(+), 1553 deletions(-) delete mode 100644 colossalai/kernel/triton/rms_norm.py delete mode 100644 colossalai/kernel/triton/rotary_embedding_kernel.py delete mode 100644 tests/test_infer_ops/triton/test_llama2_token_attn.py delete mode 100644 tests/test_infer_ops/triton/test_rotary_embedding.py delete mode 100644 tests/test_infer_ops/triton/test_token_attn_1.py delete mode 100644 tests/test_infer_ops/triton/test_token_attn_2.py diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index ba6c95ce8832..d0c281e057b3 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -4,7 +4,7 @@ ## Introduction -`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. +`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users. ## Design @@ -62,6 +62,12 @@ triton==2.0.0.dev20221202 vllm # for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c flash-attention + +# install lightllm since we depend on lightllm triton kernels +git clone https://github.com/ModelTC/lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +cd lightllm +pip3 install -e . ``` ### Docker @@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment docker pull hpcaitech/colossalai-inference:v2 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash +# enter into docker container +cd /path/to/CollossalAI +pip install -e . + +# install lightllm +git clone https://github.com/ModelTC/lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 +cd lightllm +pip3 install -e . + + ``` ### Dive into fast-inference! diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index ac185f1b6529..de150311cc08 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -5,7 +5,7 @@ from .kvcache_manager import MemoryManager - +# adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: r""" @@ -41,6 +41,7 @@ def total_token_num(self): def set_cache_manager(self, manager: MemoryManager): self.cache_manager = manager + # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1 @staticmethod def init_block_loc( b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index e74a3a491a7b..c9e7aaae0844 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -1,7 +1,9 @@ -# Adapted from lightllm/common/mem_manager.py -# of the ModelTC/lightllm GitHub repository -# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py - +""" +Refered/Modified from lightllm/common/mem_manager.py +of the ModelTC/lightllm GitHub repository +https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py +we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design. +""" import torch from transformers.utils import logging diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index 4b1bc601f436..b8274d3c660f 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -6,8 +6,6 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd -from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ( ChatGLMForConditionalGeneration, @@ -20,6 +18,14 @@ from ._utils import copy_kv_to_mem_cache +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd + from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + # This func is same as Llama model init_to_get_rotary, we should move them into _utils.py def _init_to_get_rotary(self, base=10000): @@ -433,17 +439,17 @@ def chatglm_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin ) if self.multi_query_attention: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head), cos, sin, ) else: - Llama2Forwards.rotary_emb_fwd( + chatglm2_rotary_emb_fwd( key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin, @@ -474,7 +480,7 @@ def chatglm_flash_attn_kvcache_forward( attn_output = torch.empty_like(query_layer.view(-1, self.projection_size)) # NOTE: no bug in context attn fwd (del it ) - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_layer, key_layer, value_layer, diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index ac4ae72f3d18..a3937f6f10ba 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -5,12 +5,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState -from colossalai.kernel.triton import ( - llama2_context_attn_fwd, - llama_context_attn_fwd, - rotary_embedding_fwd, - token_attention_fwd, -) +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from ._utils import copy_kv_to_mem_cache @@ -29,6 +24,17 @@ ) HAS_VLLM_KERNERL = False +try: + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -280,8 +286,8 @@ def llama_flash_attn_kvcache_forward( cos, sin = infer_state.position_cos, infer_state.position_sin # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) - rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) @@ -312,7 +318,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length, ) else: - llama2_context_attn_fwd( + lightllm_llama2_context_attention_fwd( query_states, key_states, value_states, @@ -371,6 +377,7 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager.past_key_values_length, infer_state.other_kv_index, ) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 507c1203dd6b..7e163efe0173 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -12,8 +12,7 @@ from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward try: - from colossalai.kernel.triton import rmsnorm_forward - + from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward HAS_TRITON_RMSNORM = True except: print("you should install triton from https://github.com/openai/triton") @@ -22,9 +21,8 @@ def get_triton_rmsnorm_forward(): if HAS_TRITON_RMSNORM: - def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) return _triton_rmsnorm_forward else: diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 27351a686d2f..1fe292289f3d 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -9,26 +9,21 @@ # There may exist import error even if we have triton installed. if HAS_TRITON: - from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd + from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd from .copy_kv_cache_dest import copy_kv_cache_to_dest from .fused_layernorm import layer_norm from .gptq_triton import gptq_fused_linear_triton from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd - from .rms_norm import rmsnorm_forward - from .rotary_embedding_kernel import rotary_embedding_fwd from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd from .softmax import softmax from .token_attention_kernel import token_attention_fwd __all__ = [ "llama_context_attn_fwd", - "llama2_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", - "rmsnorm_forward", "copy_kv_cache_to_dest", - "rotary_embedding_fwd", "token_attention_fwd", "gptq_fused_linear_triton", "int8_rotary_embedding_fwd", diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 01d54566483a..1b4f6e44b0f2 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -238,329 +238,5 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): num_warps=num_warps, num_stages=1, ) - return - - @triton.jit - def _fwd_kernel_latest( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - @triton.jit - def _fwd_kernel_old( - Q, - K, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - kv_group_num, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs - + cur_head * stride_qh - + offs_d[None, :] * stride_qd - ) - off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - # t_ptrs = TMP + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - return - - @torch.no_grad() - def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel_latest[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - elif triton.__version__ == "2.0.0": - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - kv_group_num = q.shape[1] // k.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - _fwd_kernel_old[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - kv_group_num=kv_group_num, - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return + + return \ No newline at end of file diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 02edcc9a903a..0ce6b09e54dc 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -11,6 +11,7 @@ if HAS_TRITON: + # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( kv_cache_ptr, @@ -42,6 +43,7 @@ def _fwd_copy_kv_cache_dest( tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) return + # adepted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @torch.no_grad() def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): seq_len = dest_index_ptr.shape[0] diff --git a/colossalai/kernel/triton/rms_norm.py b/colossalai/kernel/triton/rms_norm.py deleted file mode 100644 index d5d6f9d85df1..000000000000 --- a/colossalai/kernel/triton/rms_norm.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - - -if HAS_TRITON: - """ - this kernel function is modified from - https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py - """ - - @triton.jit - def _rms_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, - ): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) - x_hat = x * rstd - y = x_hat * w - # Write output - tl.store(Y + cols, y.to(tl.float16), mask=mask) - - def rmsnorm_forward(x, weight, eps): - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.view(-1, x.shape[-1]) - M, N = x_arg.shape - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) - # print("BLOCK_SIZE:", BLOCK_SIZE) - if N > BLOCK_SIZE: - raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") - # heuristics for number of warps - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - # print(BLOCK_SIZE, num_warps, "block_size, numwarps") - BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2 - num_warps = 8 - # enqueue kernel - _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps) - return y diff --git a/colossalai/kernel/triton/rotary_embedding_kernel.py b/colossalai/kernel/triton/rotary_embedding_kernel.py deleted file mode 100644 index fd74ba817551..000000000000 --- a/colossalai/kernel/triton/rotary_embedding_kernel.py +++ /dev/null @@ -1,212 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm -import torch -import triton -import triton.language as tl - - -@triton.jit -def _rotary_kernel( - q, - Cos, - Sin, - q_bs_stride, - q_h_stride, - q_d_stride, - cos_bs_stride, - cos_d_stride, - total_len, - HEAD_NUM: tl.constexpr, - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - HEAD_DIM: tl.constexpr, -): - current_head_index = tl.program_id(0) - current_seq_index = tl.program_id(1) - - current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, HEAD_DIM // 2) - dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM) - - off_q0 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range0[None, None, :] * q_d_stride - ) - off_q1 = ( - current_seq_range[:, None, None] * q_bs_stride - + current_head_range[None, :, None] * q_h_stride - + dim_range1[None, None, :] * q_d_stride - ) - - off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride - - q0 = tl.load( - q + off_q0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - q1 = tl.load( - q + off_q1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - q + off_q0, - out0, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - tl.store( - q + off_q1, - out1, - mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM), - ) - - return - - -@torch.no_grad() -def rotary_embedding_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - _rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - total_len, - HEAD_NUM=head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - HEAD_DIM=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return - - -class Llama2Forwards: - @staticmethod - @triton.jit - def _rotary_kernel( - Q, - Cos, - Sin, - stride_qbs, - stride_qh, - stride_qd, - stride_cosbs, - stride_cosd, - stride_sinbs, - stride_sind, - max_total_len, - H, # N_CTX - BLOCK_HEAD: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - ): - cur_head_index = tl.program_id(0) - cur_seq_index = tl.program_id(1) - - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) - cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) - - dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 - dim_range1 = dim_range0 + 1 - off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd - ) - off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd - ) - - cos_range = tl.arange(0, BLOCK_DMODEL // 2) - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd - - q0 = tl.load( - Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - q1 = tl.load( - Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H), - other=0.0, - ) - - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out0 = q0 * cos - q1 * sin - out1 = q0 * sin + q1 * cos - - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H) - ) - - return - - @staticmethod - @torch.no_grad() - def rotary_emb_fwd(q, cos, sin): - total_len = q.shape[0] - head_num = q.shape[1] - head_dim = q.shape[2] // 2 - assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - BLOCK_HEAD = 4 - BLOCK_SEQ = 32 - grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - if head_dim >= 128: - num_warps = 8 - else: - num_warps = 4 - - Llama2Forwards._rotary_kernel[grid]( - q, - cos, - sin, - q.stride(0), - q.stride(1), - q.stride(2), - cos.stride(0), - cos.stride(1), - sin.stride(0), - sin.stride(1), - total_len, - head_num, - BLOCK_HEAD=BLOCK_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=head_dim, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/colossalai/kernel/triton/self_attention_nofusion.py b/colossalai/kernel/triton/self_attention_nofusion.py index 4b56c8afd67f..50d6786bd940 100644 --- a/colossalai/kernel/triton/self_attention_nofusion.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -12,6 +12,7 @@ from .qkv_matmul_kernel import qkv_gemm_4d_kernel from .softmax import softmax_kernel + # adpeted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312 def self_attention_forward_without_fusion( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float ): @@ -141,6 +142,7 @@ def self_attention_forward_without_fusion( ) return output.view(batches, -1, d_model) + # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212 def self_attention_compute_using_triton( qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False ): diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py index c27394f0f9cf..8dc919bad125 100644 --- a/colossalai/kernel/triton/token_attention_kernel.py +++ b/colossalai/kernel/triton/token_attention_kernel.py @@ -12,401 +12,78 @@ HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -if HAS_TRITON: - - @triton.jit - def _token_attn_1_kernel( - Q, - K, - sm_scale, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return - - @triton.jit - def _token_attn_1_alibi_kernel( - Q, - K, - sm_scale, - alibi, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - q_batch_stride, - q_head_stride, - q_head_dim_stride, - k_batch_stride, - k_head_stride, - k_head_dim_stride, - attn_head_stride, - attn_batch_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - start_n = tl.program_id(2) - - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_end_index = max_kv_cache_len - - off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0) +try: + from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import ( + token_att_fwd as lightllm_llama2_token_att_fwd, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import ( + token_att_fwd2 as lightllm_llama2_token_att_fwd2, + ) + from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import ( + token_softmax_fwd as lightllm_llama2_token_softmax_fwd, + ) + + from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2 + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd + from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd + from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd + + HAS_TRITON_TOKEN_ATTENTION = True +except ImportError: + print("unable to import lightllm kernels") + HAS_TRITON_TOKEN_ATTENTION = False - for start_mark in range(0, block_mask, 1): - alibi_m = tl.load(alibi + current_head) - q = tl.load(Q + off_q + start_mark) - offs_n_new = current_batch_start_index + offs_n - k_loc = tl.load( - kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new, - mask=offs_n_new < current_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride - k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n) - off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride - tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) - return +if HAS_TRITON: @torch.no_grad() - def token_attn_fwd_1( - q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None + def token_attention_fwd( + q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None ): - BLOCK = 32 - # shape constraints - q_head_dim, k_head_dim = q.shape[-1], k.shape[-1] - assert q_head_dim == k_head_dim - assert k_head_dim in {16, 32, 64, 128} - sm_scale = 1.0 / (k_head_dim**0.5) - - batch, head_num = kv_cache_loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK)) + head_num = k.shape[1] + batch_size = kv_cache_seq_len.shape[0] + calcu_shape1 = (batch_size, head_num, k.shape[2]) + total_token_num = k.shape[0] - num_warps = 4 if k_head_dim <= 64 else 8 - num_warps = 2 + att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - if alibi is not None: - _token_attn_1_alibi_kernel[grid]( - q, + if alibi is None: + lightllm_llama_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, - alibi, + att_m_tensor, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, + kv_cache_seq_len, + max_len_in_batch, ) else: - _token_attn_1_kernel[grid]( - q, + lightllm_bloom_token_att_fwd( + q.view(calcu_shape1), k, - sm_scale, + att_m_tensor, + alibi, kv_cache_loc, kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - attn_out, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - attn_out.stride(0), - attn_out.stride(1), - HEAD_DIM=k_head_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @triton.jit - def _token_attn_softmax_fwd( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - logics_head_dim_stride, - logics_batch_stride, - prob_head_dim_stride, - prob_batch_stride, - BLOCK_SIZE: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - row = tl.load( - softmax_logics - + current_head * logics_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * logics_batch_stride, - mask=col_offsets < current_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - softmax_prob_out - + current_head * prob_head_dim_stride - + (current_batch_in_all_start_index + col_offsets) * prob_batch_stride, - softmax_output, - mask=col_offsets < current_batch_seq_len, - ) - return - - @torch.no_grad() - def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len): - BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len) - batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - _token_attn_softmax_fwd[(batch, head_num)]( - softmax_logics, - kv_cache_start_loc, - kv_cache_seqlen, - softmax_prob_out, - softmax_logics.stride(0), - softmax_logics.stride(1), - softmax_prob_out.stride(0), - softmax_prob_out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @triton.jit - def _token_attn_2_kernel( - Prob, - V, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc_b_stride, - kv_cache_loc_s_stride, - prob_head_dim_stride, - prob_batch_stride, - v_batch_stride, - v_head_stride, - v_head_dim_stride, - attn_out_batch_stride, - attn_out_head_stride, - attn_out_head_dim_stride, - HEAD_DIM: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - current_batch = tl.program_id(0) - current_head = tl.program_id(1) - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, HEAD_DIM) - current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch) - current_batch_start_index = max_kv_cache_len - current_batch_seq_len - current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch) - - v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride - p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride - v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride - - acc = tl.zeros([HEAD_DIM], dtype=tl.float32) - for start_n in range(0, current_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_loc = tl.load( - kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride, - mask=(start_n + offs_n) < current_batch_seq_len, - other=0.0, - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * v_batch_stride, - mask=(start_n + offs_n[:, None]) < current_batch_seq_len, - other=0.0, + kv_cache_seq_len, + max_len_in_batch, ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = ( - current_batch * attn_out_batch_stride - + current_head * attn_out_head_stride - + offs_d * attn_out_head_dim_stride - ) - out_ptrs = attn_out + off_o - tl.store(out_ptrs, acc) - return - - @torch.no_grad() - def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = kv_cache_loc.shape[0], v.shape[1] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - _token_attn_2_kernel[grid]( - prob, - v, - attn_out, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seqlen, - max_kv_cache_len, - kv_cache_loc.stride(0), - kv_cache_loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - attn_out.stride(0), - attn_out.stride(1), - attn_out.stride(2), - HEAD_DIM=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @torch.no_grad() - def token_attention_fwd( - q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None - ): - head_num = k.shape[1] - batch_size = kv_cache_seq_len.shape[0] - calcu_shape1 = (batch_size, head_num, k.shape[2]) - total_token_num = k.shape[0] - - att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - - token_attn_fwd_1( - q.view(calcu_shape1), - k, - att_m_tensor, - kv_cache_loc, - kv_cache_start_loc, - kv_cache_seq_len, - max_len_in_batch, - alibi=alibi, - ) prob = torch.empty_like(att_m_tensor) - token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) + lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch) att_m_tensor = None - token_attn_fwd_2( + lightllm_llama_token_att_fw2( prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch ) - prob = None - return class Llama2TokenAttentionForwards: @staticmethod @triton.jit + + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8 def _fwd_kernel( Logics, V, @@ -478,6 +155,7 @@ def _fwd_kernel( tl.store(out_ptrs, acc) return + # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36 @staticmethod @torch.no_grad() def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index): @@ -514,277 +192,6 @@ def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_i ) return - @staticmethod - @triton.jit - def _fwd_kernel_token_softmax( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - stride_logic_h, - stride_logic_bs, - stride_prob_h, - stride_prob_bs, - BLOCK_SIZE: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - col_offsets = tl.arange(0, BLOCK_SIZE) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - row = tl.load( - Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs, - mask=col_offsets < cur_batch_seq_len, - other=-float("inf"), - ).to(tl.float32) - - row_minus_max = row - tl.max(row, axis=0) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - - tl.store( - Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs, - softmax_output, - mask=col_offsets < cur_batch_seq_len, - ) - return - - @staticmethod - @torch.no_grad() - def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len): - BLOCK_SIZE = triton.next_power_of_2(max_input_len) - batch, head_num = B_Start_Loc.shape[0], Logics.shape[0] - - num_warps = 4 - if BLOCK_SIZE >= 2048: - num_warps = 8 - if BLOCK_SIZE >= 4096: - num_warps = 16 - - Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)]( - Logics, - B_Start_Loc, - B_Seqlen, - Prob_Out, - Logics.stride(0), - Logics.stride(1), - Prob_Out.stride(0), - Prob_Out.stride(1), - num_warps=num_warps, - BLOCK_SIZE=BLOCK_SIZE, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att1( - Q, - K, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - Att_Out, - stride_b_loc_b, - stride_b_loc_s, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - att_stride_h, - att_stride_bs, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_n = tl.program_id(2) - - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_end_index = max_input_len - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd - - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) - - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) - offs_n_new = cur_batch_start_index + offs_n - k_loc = tl.load( - B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd - k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - BLOCK = 32 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk**0.5) - - batch, head_num = B_Loc.shape[0], q.shape[1] - - grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK)) - kv_group_num = q.shape[1] // k.shape[1] - - num_warps = 4 if Lk <= 64 else 8 - num_warps = 2 - - Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid]( - q, - k, - sm_scale, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - att_out, - B_Loc.stride(0), - B_Loc.stride(1), - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - att_out.stride(0), - att_out.stride(1), - kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - - @staticmethod - @triton.jit - def _fwd_kernel_token_att2( - Prob, - V, - Out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, # B_Start_Loc cumsum of input lens if continuous - stride_b_loc_b, - stride_b_loc_s, - stride_ph, - stride_pbs, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - kv_group_num, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_kv_head = cur_head // kv_group_num - - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = max_input_len - cur_batch_seq_len - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s - p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs - v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - p_value = tl.load( - Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_loc = tl.load( - B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0 - ) - v_value = tl.load( - V + v_offs + v_loc[:, None] * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - acc += tl.sum(p_value[:, None] * v_value, 0) - - acc = acc.to(tl.float16) - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc) - return - - @staticmethod - @torch.no_grad() - def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len): - if triton.__version__ >= "2.1.0": - BLOCK = 128 - else: - BLOCK = 64 - batch, head = B_Loc.shape[0], prob.shape[0] - grid = (batch, head) - num_warps = 4 - dim = v.shape[-1] - - kv_group_num = prob.shape[0] // v.shape[1] - - Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid]( - prob, - v, - out, - B_Loc, - B_Start_Loc, - B_Seqlen, - max_input_len, - B_Loc.stride(0), - B_Loc.stride(1), - prob.stride(0), - prob.stride(1), - v.stride(0), - v.stride(1), - v.stride(2), - out.stride(0), - out.stride(1), - out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return - # this is the interface of llama2 attn forward @staticmethod @torch.no_grad() @@ -796,7 +203,7 @@ def token_attn( calcu_shape1 = (batch_size, head_num, head_dim) att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda") - Llama2TokenAttentionForwards.token_att_fwd( + lightllm_llama2_token_att_fwd( q, k, att_m_tensor, @@ -808,12 +215,12 @@ def token_attn( if triton.__version__ == "2.0.0": prob = torch.empty_like(att_m_tensor) - Llama2TokenAttentionForwards.token_softmax_fwd( + lightllm_llama2_token_softmax_fwd( att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch ) att_m_tensor = None - Llama2TokenAttentionForwards.token_att_fwd2( + lightllm_llama2_token_att_fwd2( prob, v, attn_out.view(calcu_shape1), diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 90d49f6a264a..0ca1953c6a41 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -3,7 +3,6 @@ import time import torch -from torch.profiler import ProfilerActivity, profile, record_function from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai @@ -16,6 +15,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): + torch.cuda.empty_cache() # trim warmup queries latency_set = list(latency_set) latency_set = latency_set[warmup:] @@ -38,24 +38,29 @@ def run_llama_test(args): max_batch_size = args.batch_size max_input_len = args.input_len max_output_len = args.output_len + args.test_mode + + print("max_batch_size : " + str(max_batch_size)) tokenizer = LlamaTokenizer.from_pretrained(llama_model_path) tokenizer.pad_token_id = tokenizer.unk_token_id model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id) model = model.half() - model_config = model.config + model.config shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) - generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + generate_kwargs = dict(max_new_tokens=1, do_sample=False) input_tokens = { "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } iters = 10 - times = [] + prefill_times = [] + + warmup = 3 for i in range(iters): torch.cuda.synchronize() @@ -65,17 +70,39 @@ def run_llama_test(args): end = time.time() out_len = outputs.shape[1] print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + prefill_times.append((end - start) / (out_len - max_input_len)) + + prefill_times = prefill_times[warmup:] + prefill_time_avg = sum(prefill_times) / len(prefill_times) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + times = [] + decoder_times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) times.append((end - start) / (out_len - max_input_len)) + if args.test_mode == "decoder_test": + decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1)) + + times = times[warmup:] + latency = sum(times) / len(times) + print("total process latency is : " + str(latency) + " s") + print("total throughput is : " + str(1 / latency * max_batch_size)) - print("outputs, ", len(outputs)) - print_perf_stats(times, model_config, max_batch_size) + if args.test_mode == "decoder_test": + decoder_times = decoder_times[warmup:] + latency = sum(decoder_times) / len(decoder_times) - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("model_inference"): - torch.cuda.synchronize() - outputs = infer_engine.generate(input_tokens, **generate_kwargs) - torch.cuda.synchronize() - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print("decoder process latency is : " + str(latency) + " s") + print("decoder throughput is : " + str(1 / latency * max_batch_size)) def check_llama(rank, world_size, port, args): @@ -95,8 +122,11 @@ def test_llama(args): parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) args = parser.parse_args() diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9aa5f2822e40..19cb7a154a01 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,3 +11,6 @@ ninja torch>=1.12 safetensors einops +sentencepiece +google +protobuf diff --git a/tests/test_infer_ops/triton/test_llama2_token_attn.py b/tests/test_infer_ops/triton/test_llama2_token_attn.py deleted file mode 100644 index 0537a3d76129..000000000000 --- a/tests/test_infer_ops/triton/test_llama2_token_attn.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - - logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5) - prob = torch.softmax(logics, dim=1) - prob = prob.view(bs, seqlen, num_head, 1) - - return torch.sum(prob * xv, dim=1, keepdim=False) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test(): - Z, head_num, seq_len, head_dim = 2, 32, 2048, 128 - dtype = torch.float16 - - # attn out: 2,4096 - q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) - v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) - o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda") - max_kv_cache_len = seq_len - kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") - other_kv_index = 2048 - - kv_cache_seq_len[:] = seq_len - kv_cache_start_loc[0] = 0 - kv_cache_start_loc[1] = seq_len - - for i in range(Z): - kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda") - - Llama2TokenAttentionForwards.token_attn( - q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index - ) - torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim) - assert torch.allclose(torch_out, o, atol=1e-3, rtol=0) - - -if __name__ == "__main__": - test() diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py deleted file mode 100644 index 7e05ccafbfc4..000000000000 --- a/tests/test_infer_ops/triton/test_rotary_embedding.py +++ /dev/null @@ -1,55 +0,0 @@ -# Adapted from ModelTC https://github.com/ModelTC/lightllm - - -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_rotary_emb(x, cos, sin): - seq_len, h, dim = x.shape - x0 = x[:, :, 0 : dim // 2] - x1 = x[:, :, dim // 2 : dim] - cos = cos.view((seq_len, 1, dim // 2)) - sin = sin.view((seq_len, 1, dim // 2)) - o0 = x0 * cos - x1 * sin - o1 = x0 * sin + x1 * cos - return torch.cat((o0, o1), dim=-1) - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_rotary_emb(): - SEQ_LEN = 1 - HEAD_NUM = 32 - HEAD_DIM = 128 - dtype = torch.half - # create data - x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") - cos_shape = (SEQ_LEN, HEAD_DIM // 2) - cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") - # forward pass - y_torch = torch_rotary_emb(x, cos, sin) - rotary_embedding_fwd(x, cos, sin) - y_triton = x - # compare - assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_rotary_emb() diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py deleted file mode 100644 index fc5f8cd6c9dc..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_1.py +++ /dev/null @@ -1,74 +0,0 @@ -import math - -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(xq, xk, bs, seqlen, num_head, head_dim): - xq = xq.view(bs, 1, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - keys = xk - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - scores = ( - (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1) - ) - return scores - - -def torch_attn_1(xq, xk, seqlen, num_head, head_dim): - xq = xq.view(1, num_head, head_dim) - xk = xk.view(seqlen, num_head, head_dim) - logics = torch.sum(xq * xk, dim=-1, keepdim=False) - - logics = logics.transpose(0, 1) / math.sqrt(head_dim) - return logics - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_attn_1(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - - dtype = torch.float16 - - q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) - attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - - b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out.squeeze() - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_attn_1() diff --git a/tests/test_infer_ops/triton/test_token_attn_2.py b/tests/test_infer_ops/triton/test_token_attn_2.py deleted file mode 100644 index 2dd756f2ba91..000000000000 --- a/tests/test_infer_ops/triton/test_token_attn_2.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -import torch -from packaging import version - -try: - pass - - from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2 - - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") - - -def torch_attn(V, P, bs, seqlen, num_head, head_dim): - V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2) - P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1) - attn_out = torch.matmul(P, V) - - return attn_out - - -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4" -) -def test_token_attn_2(): - pass - - batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128 - dtype = torch.float16 - - V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10) - Prob = ( - torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda") - .normal_(mean=0.4, std=0.2) - .reshape(head_num, batch_size, seq_len) - .softmax(-1) - .reshape(head_num, batch_size * seq_len) - ) - attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda") - - kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") - kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda") - for i in range(batch_size): - kv_cache_start_loc[i] = i * seq_len - kv_cache_seq_len[i] = seq_len - kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda") - - token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len) - - torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze() - o = attn_out - print("max ", torch.max(torch.abs(torch_out - o))) - print("mean ", torch.mean(torch.abs(torch_out - o))) - assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) - - -if __name__ == "__main__": - test_token_attn_2() diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py index 9c7a53798317..a7fc3d29b77a 100644 --- a/tests/test_infer_ops/triton/test_token_attn_fwd.py +++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py @@ -3,16 +3,13 @@ from packaging import version try: - pass - from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd - HAS_TRITON = True except ImportError: HAS_TRITON = False print("please install triton from https://github.com/openai/triton") -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6") def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): From 8633a87a873f84ea2dedd268b5399c90c35a0fef Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Fri, 20 Oct 2023 10:35:08 +0800 Subject: [PATCH 12/46] [test] merge old components to test to model zoo (#4945) * [test] add custom models in model zoo * [test] update legacy test * [test] update model zoo * [test] update gemini test * [test] remove components to test --- colossalai/testing/__init__.py | 2 + colossalai/testing/utils.py | 21 +++++ tests/components_to_test/__init__.py | 29 ------ tests/components_to_test/albert.py | 62 ------------- tests/components_to_test/beit.py | 44 --------- tests/components_to_test/bert.py | 88 ------------------ tests/components_to_test/gpt2.py | 92 ------------------- .../components_to_test/hanging_param_model.py | 48 ---------- tests/components_to_test/inline_op_model.py | 49 ---------- tests/components_to_test/registry.py | 38 -------- .../repeated_computed_layers.py | 47 ---------- tests/components_to_test/resnet.py | 37 -------- tests/components_to_test/simple_net.py | 53 ----------- tests/components_to_test/utils/__init__.py | 2 - .../utils/dummy_data_generator.py | 24 ----- tests/kit/model_zoo/__init__.py | 5 +- tests/kit/model_zoo/custom/__init__.py | 4 + tests/kit/model_zoo/custom/base.py | 26 ++++++ .../model_zoo/custom/hanging_param_model.py | 48 ++++++++++ .../model_zoo/custom}/nested_model.py | 36 ++++---- .../custom/repeated_computed_layers.py | 48 ++++++++++ tests/kit/model_zoo/custom/simple_net.py | 53 +++++++++++ .../utils => kit/model_zoo}/executor.py | 30 ++++-- tests/kit/model_zoo/transformers/bert.py | 4 +- tests/kit/model_zoo/transformers/blip2.py | 2 +- tests/kit/model_zoo/transformers/bloom.py | 8 +- tests/kit/model_zoo/transformers/chatglm2.py | 4 +- tests/kit/model_zoo/transformers/gpt.py | 5 +- tests/kit/model_zoo/transformers/llama.py | 6 +- tests/kit/model_zoo/transformers/opt.py | 4 +- tests/kit/model_zoo/transformers/sam.py | 2 +- tests/kit/model_zoo/transformers/t5.py | 6 +- tests/kit/model_zoo/transformers/vit.py | 6 +- tests/kit/model_zoo/transformers/whisper.py | 4 +- tests/test_legacy/test_amp/test_naive_fp16.py | 18 ++-- tests/test_legacy/test_amp/test_torch_fp16.py | 18 ++-- tests/test_legacy/test_engine/test_engine.py | 25 ++--- .../test_trainer_with_non_pipe_schedule.py | 14 +-- tests/test_optimizer/test_nvme.py | 5 +- tests/test_zero/test_gemini/test_fwd_bwd.py | 45 ++++----- .../test_gemini/test_gemini_use_rmt.py | 33 +++---- .../test_zero/test_gemini/test_grad_accum.py | 34 ++++--- tests/test_zero/test_gemini/test_grad_clip.py | 26 +++--- tests/test_zero/test_gemini/test_inference.py | 34 +++---- tests/test_zero/test_gemini/test_optim.py | 63 ++++++------- .../test_gemini/test_runtime_mem_tracer.py | 20 ++-- tests/test_zero/test_gemini/test_search.py | 25 +++-- .../test_gemini/test_zeroddp_state_dict.py | 53 ++--------- .../test_gemini/test_zerooptim_state_dict.py | 25 ++--- 49 files changed, 461 insertions(+), 914 deletions(-) delete mode 100644 tests/components_to_test/__init__.py delete mode 100644 tests/components_to_test/albert.py delete mode 100644 tests/components_to_test/beit.py delete mode 100644 tests/components_to_test/bert.py delete mode 100644 tests/components_to_test/gpt2.py delete mode 100644 tests/components_to_test/hanging_param_model.py delete mode 100644 tests/components_to_test/inline_op_model.py delete mode 100644 tests/components_to_test/registry.py delete mode 100644 tests/components_to_test/repeated_computed_layers.py delete mode 100644 tests/components_to_test/resnet.py delete mode 100644 tests/components_to_test/simple_net.py delete mode 100644 tests/components_to_test/utils/__init__.py delete mode 100644 tests/components_to_test/utils/dummy_data_generator.py create mode 100644 tests/kit/model_zoo/custom/__init__.py create mode 100644 tests/kit/model_zoo/custom/base.py create mode 100644 tests/kit/model_zoo/custom/hanging_param_model.py rename tests/{components_to_test => kit/model_zoo/custom}/nested_model.py (50%) create mode 100644 tests/kit/model_zoo/custom/repeated_computed_layers.py create mode 100644 tests/kit/model_zoo/custom/simple_net.py rename tests/{components_to_test/utils => kit/model_zoo}/executor.py (51%) diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index c6956e81fbde..b84ba55a7a13 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -9,6 +9,7 @@ ) from .pytest_wrapper import run_on_environment_flag from .utils import ( + DummyDataloader, clear_cache_before_run, free_port, parameterize, @@ -34,4 +35,5 @@ "run_on_environment_flag", "check_state_dict_equal", "assert_hf_output_close", + "DummyDataloader", ] diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index fdbda9a598bf..839e7aab3567 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -273,3 +273,24 @@ def _clear_cache(*args, **kwargs): return _clear_cache return _wrap_func + + +class DummyDataloader: + def __init__(self, data_gen_fn: Callable, length: int = 10): + self.data_gen_fn = data_gen_fn + self.length = length + self.step = 0 + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.data_gen_fn() + else: + raise StopIteration + + def __len__(self): + return self.length diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py deleted file mode 100644 index 65eaa72d6e84..000000000000 --- a/tests/components_to_test/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -from . import ( - beit, - bert, - gpt2, - hanging_param_model, - inline_op_model, - nested_model, - repeated_computed_layers, - resnet, - simple_net, -) -from .utils import run_fwd, run_fwd_bwd - -from . import albert # isort:skip - -__all__ = [ - "bert", - "gpt2", - "hanging_param_model", - "inline_op_model", - "nested_model", - "repeated_computed_layers", - "resnet", - "simple_net", - "run_fwd_bwd", - "albert", - "beit", - "run_fwd", -] diff --git a/tests/components_to_test/albert.py b/tests/components_to_test/albert.py deleted file mode 100644 index 0ba4d19655cd..000000000000 --- a/tests/components_to_test/albert.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -from transformers import AlbertConfig, AlbertForSequenceClassification - -from .bert import get_bert_data_loader -from .registry import non_distributed_component_funcs - - -@non_distributed_component_funcs.register(name="albert") -def get_training_components(): - hidden_dim = 8 - num_head = 4 - sequence_length = 12 - num_layer = 2 - vocab_size = 32 - - def bert_model_builder(checkpoint: bool = False): - config = AlbertConfig( - vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - ) - print("building AlbertForSequenceClassification model") - - # adapting huggingface BertForSequenceClassification for single unittest calling interface - class ModelAdaptor(AlbertForSequenceClassification): - def forward(self, input_ids, labels): - """ - inputs: data, label - outputs: loss - """ - return super().forward(input_ids=input_ids, labels=labels)[0] - - model = ModelAdaptor(config) - # if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): - # model.gradient_checkpointing_enable() - - return model - - is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - testloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - - criterion = None - return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/beit.py b/tests/components_to_test/beit.py deleted file mode 100644 index d33474ea9a6b..000000000000 --- a/tests/components_to_test/beit.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -from timm.models.beit import Beit - -from colossalai.utils.cuda import get_current_device - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class DummyDataLoader(DummyDataGenerator): - img_size = 64 - num_channel = 3 - num_class = 10 - batch_size = 4 - - def generate(self): - data = torch.randn( - ( - DummyDataLoader.batch_size, - DummyDataLoader.num_channel, - DummyDataLoader.img_size, - DummyDataLoader.img_size, - ), - device=get_current_device(), - ) - label = torch.randint( - low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device() - ) - return data, label - - -@non_distributed_component_funcs.register(name="beit") -def get_training_components(): - def model_builder(checkpoint=False): - model = Beit( - img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4 - ) - return model - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py deleted file mode 100644 index 9f0eef75ae93..000000000000 --- a/tests/components_to_test/bert.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import transformers -from packaging import version -from torch.utils.data import SequentialSampler -from transformers import BertConfig, BertForSequenceClassification - -from .registry import non_distributed_component_funcs - - -def get_bert_data_loader( - n_class, - batch_size, - total_samples, - sequence_length, - device=torch.device("cpu:0"), - is_distributed=False, -): - train_data = torch.randint( - low=0, - high=n_class, - size=(total_samples, sequence_length), - device=device, - dtype=torch.long, - ) - train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long) - train_dataset = torch.utils.data.TensorDataset(train_data, train_label) - if is_distributed: - sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - else: - sampler = SequentialSampler(train_dataset) - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler) - return train_loader - - -@non_distributed_component_funcs.register(name="bert") -def get_training_components(): - hidden_dim = 8 - num_head = 4 - sequence_length = 12 - num_layer = 2 - vocab_size = 32 - - def bert_model_builder(checkpoint: bool = False): - config = BertConfig( - vocab_size=vocab_size, - gradient_checkpointing=checkpoint, - hidden_size=hidden_dim, - intermediate_size=hidden_dim * 4, - num_attention_heads=num_head, - max_position_embeddings=sequence_length, - num_hidden_layers=num_layer, - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - ) - - # adapting huggingface BertForSequenceClassification for single unittest calling interface - class ModelAdaptor(BertForSequenceClassification): - def forward(self, input_ids, labels): - """ - inputs: data, label - outputs: loss - """ - return super().forward(input_ids=input_ids, labels=labels)[0] - - model = ModelAdaptor(config) - if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"): - model.gradient_checkpointing_enable() - - return model - - is_distributed = torch.distributed.is_initialized() - trainloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - testloader = get_bert_data_loader( - n_class=vocab_size, - batch_size=2, - total_samples=10000, - sequence_length=sequence_length, - is_distributed=is_distributed, - ) - - criterion = None - return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/gpt2.py b/tests/components_to_test/gpt2.py deleted file mode 100644 index 7f826497d2ab..000000000000 --- a/tests/components_to_test/gpt2.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel - -from colossalai.utils.cuda import get_current_device - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class DummyDataLoader(DummyDataGenerator): - vocab_size = 128 - batch_size = 4 - seq_len = 64 - - def generate(self): - input_ids = torch.randint( - 0, - DummyDataLoader.vocab_size, - (DummyDataLoader.batch_size, DummyDataLoader.seq_len), - device=get_current_device(), - ) - return input_ids, input_ids - - -class GPTLMModel(nn.Module): - def __init__( - self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50304, - checkpoint=False, - ): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel( - GPT2Config( - n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0, - ) - ) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids): - # Only return lm_logits - attention_mask = torch.ones_like(input_ids) - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -def gpt2_micro(checkpoint=True): - return GPTLMModel( - checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128 - ) - - -def gpt2_s(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint) - - -def gpt2_m(checkpoint=True): - return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) - - -class GPTLMLoss(nn.Module): - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - -@non_distributed_component_funcs.register(name="gpt2") -def get_training_components(): - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = GPTLMLoss() - return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/hanging_param_model.py b/tests/components_to_test/hanging_param_model.py deleted file mode 100644 index 5531c8d081a0..000000000000 --- a/tests/components_to_test/hanging_param_model.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class HangingParamModule(CheckpointModule): - """ - Hanging Parameter: a parameter dose not belong to a leaf Module. - It has subordinate nn.modules and a nn.Parameter. - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.proj1 = nn.Linear(4, 8) - self.weight = nn.Parameter(torch.randn(8, 8)) - self.proj2 = nn.Linear(8, 4) - - def forward(self, x): - x = self.proj1(x) - x = F.linear(x, self.weight) - x = self.proj2(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 4) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name="hanging_param_model") -def get_training_components(): - def model_builder(checkpoint=False): - return HangingParamModule(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - - return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py deleted file mode 100644 index 8bfa9cf34353..000000000000 --- a/tests/components_to_test/inline_op_model.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn as nn - -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class InlineOpModule(CheckpointModule): - """ - a module with inline Ops - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.proj1 = nn.Linear(4, 8) - self.proj2 = nn.Linear(8, 8) - - def forward(self, x): - x = self.proj1(x) - # inline add_ - x.add_(10) - x = self.proj2(x) - # inline relu_ - x = torch.relu_(x) - x = self.proj2(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 4) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name="inline_op_model") -def get_training_components(): - def model_builder(checkpoint=False): - return InlineOpModule(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - - return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/registry.py b/tests/components_to_test/registry.py deleted file mode 100644 index ec561b7831ad..000000000000 --- a/tests/components_to_test/registry.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python - - -class Registry: - def __init__(self): - self._registry = dict() - - def register(self, name): - assert name not in self._registry - - def _register(callable_): - self._registry[name] = callable_ - - return _register - - def get_callable(self, name: str): - return self._registry[name] - - def __iter__(self): - self._idx = 0 - self._len = len(self._registry) - self._names = list(self._registry.keys()) - return self - - def __next__(self): - if self._idx < self._len: - key = self._names[self._idx] - callable_ = self._registry[key] - self._idx += 1 - return callable_ - else: - raise StopIteration - - -non_distributed_component_funcs = Registry() -model_parallel_component_funcs = Registry() - -__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"] diff --git a/tests/components_to_test/repeated_computed_layers.py b/tests/components_to_test/repeated_computed_layers.py deleted file mode 100644 index 3da64de3fb64..000000000000 --- a/tests/components_to_test/repeated_computed_layers.py +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env python - -import torch -import torch.nn as nn - -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class NetWithRepeatedlyComputedLayers(CheckpointModule): - """ - This model is to test with layers which go through forward pass multiple times. - In this model, the fc1 and fc2 call forward twice - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.fc1 = nn.Linear(5, 5) - self.fc2 = nn.Linear(5, 5) - self.fc3 = nn.Linear(5, 2) - self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] - - def forward(self, x): - for layer in self.layers: - x = layer(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 5) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label - - -@non_distributed_component_funcs.register(name="repeated_computed_layers") -def get_training_components(): - def model_builder(checkpoint=False): - return NetWithRepeatedlyComputedLayers(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py deleted file mode 100644 index a43becc16233..000000000000 --- a/tests/components_to_test/resnet.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -from pathlib import Path - -import torch -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet18 -from torchvision.transforms import transforms - -from colossalai.legacy.utils import get_dataloader - -from .registry import non_distributed_component_funcs - - -def get_cifar10_dataloader(train): - # build dataloaders - dataset = CIFAR10( - root=Path(os.environ["DATA"]), - download=True, - train=train, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))] - ), - ) - dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True) - return dataloader - - -@non_distributed_component_funcs.register(name="resnet18") -def get_resnet_training_components(): - def model_builder(checkpoint=False): - return resnet18(num_classes=10) - - trainloader = get_cifar10_dataloader(train=True) - testloader = get_cifar10_dataloader(train=False) - - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py deleted file mode 100644 index 0f0ac5cff49a..000000000000 --- a/tests/components_to_test/simple_net.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -import torch.nn as nn - -from colossalai.legacy.nn import CheckpointModule -from colossalai.utils.cuda import get_current_device - -from .registry import non_distributed_component_funcs -from .utils.dummy_data_generator import DummyDataGenerator - - -class SimpleNet(CheckpointModule): - """ - In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. - """ - - def __init__(self, checkpoint=False) -> None: - super().__init__(checkpoint=checkpoint) - self.embed = nn.Embedding(20, 4) - self.proj1 = nn.Linear(4, 8) - self.ln1 = nn.LayerNorm(8) - self.proj2 = nn.Linear(8, 4) - self.ln2 = nn.LayerNorm(4) - self.classifier = nn.Linear(4, 4) - - def forward(self, x): - x = self.embed(x) - x = self.proj1(x) - x = self.ln1(x) - x = self.proj2(x) - x = self.ln2(x) - x = self.classifier(x) - return x - - -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.randint(low=0, high=20, size=(16,), device=get_current_device()) - label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) - return data, label - - -@non_distributed_component_funcs.register(name="simple_net") -def get_training_components(): - def model_builder(checkpoint=False): - return SimpleNet(checkpoint) - - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - - criterion = torch.nn.CrossEntropyLoss() - from colossalai.nn.optimizer import HybridAdam - - return model_builder, trainloader, testloader, HybridAdam, criterion diff --git a/tests/components_to_test/utils/__init__.py b/tests/components_to_test/utils/__init__.py deleted file mode 100644 index 150124b58800..000000000000 --- a/tests/components_to_test/utils/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .dummy_data_generator import DummyDataGenerator -from .executor import run_fwd, run_fwd_bwd diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py deleted file mode 100644 index 7b3af46c8f35..000000000000 --- a/tests/components_to_test/utils/dummy_data_generator.py +++ /dev/null @@ -1,24 +0,0 @@ -from abc import ABC, abstractmethod - - -class DummyDataGenerator(ABC): - def __init__(self, length=10): - self.length = length - - @abstractmethod - def generate(self): - pass - - def __iter__(self): - self.step = 0 - return self - - def __next__(self): - if self.step < self.length: - self.step += 1 - return self.generate() - else: - raise StopIteration - - def __len__(self): - return self.length diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py index c08fd365d871..62b9123b59b0 100644 --- a/tests/kit/model_zoo/__init__.py +++ b/tests/kit/model_zoo/__init__.py @@ -1,4 +1,5 @@ -from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers +from . import custom, diffusers, timm, torchaudio, torchrec, torchvision, transformers +from .executor import run_fwd, run_fwd_bwd from .registry import model_zoo -__all__ = ["model_zoo"] +__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd"] diff --git a/tests/kit/model_zoo/custom/__init__.py b/tests/kit/model_zoo/custom/__init__.py new file mode 100644 index 000000000000..1f8ac324d4d6 --- /dev/null +++ b/tests/kit/model_zoo/custom/__init__.py @@ -0,0 +1,4 @@ +from .hanging_param_model import * +from .nested_model import * +from .repeated_computed_layers import * +from .simple_net import * diff --git a/tests/kit/model_zoo/custom/base.py b/tests/kit/model_zoo/custom/base.py new file mode 100644 index 000000000000..4a0f505826f1 --- /dev/null +++ b/tests/kit/model_zoo/custom/base.py @@ -0,0 +1,26 @@ +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + + +class CheckpointModule(nn.Module): + def __init__(self, checkpoint: bool = False): + super().__init__() + self.checkpoint = checkpoint + self._use_checkpoint = checkpoint + + def _forward(self, *args, **kwargs): + raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward") + + def forward(self, *args, **kwargs): + if self._use_checkpoint: + return checkpoint(self._forward, *args, **kwargs) + else: + return self._forward(*args, **kwargs) + + def train(self, mode: bool = True): + self._use_checkpoint = self.checkpoint + return super().train(mode=mode) + + def eval(self): + self._use_checkpoint = False + return super().eval() diff --git a/tests/kit/model_zoo/custom/hanging_param_model.py b/tests/kit/model_zoo/custom/hanging_param_model.py new file mode 100644 index 000000000000..a8ace5f35e6a --- /dev/null +++ b/tests/kit/model_zoo/custom/hanging_param_model.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import model_zoo +from .base import CheckpointModule + + +class HangingParamModule(CheckpointModule): + """ + Hanging Parameter: a parameter dose not belong to a leaf Module. + It has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.proj1 = nn.Linear(4, 8) + self.weight = nn.Parameter(torch.randn(8, 8)) + self.proj2 = nn.Linear(8, 4) + + def forward(self, x): + x = self.proj1(x) + x = F.linear(x, self.weight) + x = self.proj2(x) + return x + + +def data_gen(): + return dict(x=torch.rand(16, 4)) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) + + +def output_transform(x: torch.Tensor): + return dict(x=x) + + +model_zoo.register( + name="custom_hanging_param_model", + model_fn=HangingParamModule, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/components_to_test/nested_model.py b/tests/kit/model_zoo/custom/nested_model.py similarity index 50% rename from tests/components_to_test/nested_model.py rename to tests/kit/model_zoo/custom/nested_model.py index 44577456dec5..2eb1c8398a29 100644 --- a/tests/components_to_test/nested_model.py +++ b/tests/kit/model_zoo/custom/nested_model.py @@ -2,10 +2,8 @@ import torch.nn as nn import torch.nn.functional as F -from colossalai.legacy.nn import CheckpointModule - -from .registry import non_distributed_component_funcs -from .utils import DummyDataGenerator +from ..registry import model_zoo +from .base import CheckpointModule class SubNet(nn.Module): @@ -32,20 +30,24 @@ def forward(self, x): return x -class DummyDataLoader(DummyDataGenerator): - def generate(self): - data = torch.rand(16, 5) - label = torch.randint(low=0, high=2, size=(16,)) - return data, label +def data_gen(): + return dict(x=torch.rand(16, 5)) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) -@non_distributed_component_funcs.register(name="nested_model") -def get_training_components(): - def model_builder(checkpoint=False): - return NestedNet(checkpoint) +def output_transform(x: torch.Tensor): + return dict(x=x) - trainloader = DummyDataLoader() - testloader = DummyDataLoader() - criterion = torch.nn.CrossEntropyLoss() - return model_builder, trainloader, testloader, torch.optim.Adam, criterion +model_zoo.register( + name="custom_nested_model", + model_fn=NestedNet, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/kit/model_zoo/custom/repeated_computed_layers.py b/tests/kit/model_zoo/custom/repeated_computed_layers.py new file mode 100644 index 000000000000..781fecc51427 --- /dev/null +++ b/tests/kit/model_zoo/custom/repeated_computed_layers.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import model_zoo +from .base import CheckpointModule + + +class NetWithRepeatedlyComputedLayers(CheckpointModule): + """ + This model is to test with layers which go through forward pass multiple times. + In this model, the fc1 and fc2 call forward twice + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.fc1 = nn.Linear(5, 5) + self.fc2 = nn.Linear(5, 5) + self.fc3 = nn.Linear(5, 2) + self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3] + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def data_gen(): + return dict(x=torch.rand(16, 5)) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) + + +def output_transform(x: torch.Tensor): + return dict(x=x) + + +model_zoo.register( + name="custom_repeated_computed_layers", + model_fn=NetWithRepeatedlyComputedLayers, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/kit/model_zoo/custom/simple_net.py b/tests/kit/model_zoo/custom/simple_net.py new file mode 100644 index 000000000000..ae68fccf9c61 --- /dev/null +++ b/tests/kit/model_zoo/custom/simple_net.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..registry import model_zoo +from .base import CheckpointModule + + +class SimpleNet(CheckpointModule): + """ + In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. + """ + + def __init__(self, checkpoint=False) -> None: + super().__init__(checkpoint=checkpoint) + self.embed = nn.Embedding(20, 4) + self.proj1 = nn.Linear(4, 8) + self.ln1 = nn.LayerNorm(8) + self.proj2 = nn.Linear(8, 4) + self.ln2 = nn.LayerNorm(4) + self.classifier = nn.Linear(4, 4) + + def forward(self, x): + x = self.embed(x) + x = self.proj1(x) + x = self.ln1(x) + x = self.proj2(x) + x = self.ln2(x) + x = self.classifier(x) + return x + + +def data_gen(): + return dict(x=torch.randint(low=0, high=20, size=(16,))) + + +def loss_fn(x): + outputs = x["x"] + label = torch.randint(low=0, high=2, size=(16,), device=outputs.device) + return F.cross_entropy(x["x"], label) + + +def output_transform(x: torch.Tensor): + return dict(x=x) + + +model_zoo.register( + name="custom_simple_net", + model_fn=SimpleNet, + data_gen_fn=data_gen, + output_transform_fn=output_transform, + loss_fn=loss_fn, +) diff --git a/tests/components_to_test/utils/executor.py b/tests/kit/model_zoo/executor.py similarity index 51% rename from tests/components_to_test/utils/executor.py rename to tests/kit/model_zoo/executor.py index 631401e022e6..033d6d12dd07 100644 --- a/tests/components_to_test/utils/executor.py +++ b/tests/kit/model_zoo/executor.py @@ -1,7 +1,15 @@ +from typing import Callable, Dict, Optional, Union + import torch +from torch.nn import Module +from torch.optim import Optimizer + +from colossalai.interface import OptimizerWrapper -def run_fwd(model, data, label, criterion) -> torch.Tensor: +def run_fwd( + model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None +) -> torch.Tensor: """run_fwd run fwd for the model @@ -14,18 +22,22 @@ def run_fwd(model, data, label, criterion) -> torch.Tensor: Returns: torch.Tensor: loss of fwd """ + outputs = model(**data) + outputs = output_transform_fn(outputs) if criterion: - y = model(data) - y = y.float() - loss = criterion(y, label) + loss = criterion(outputs) else: - loss = model(data, label) - - loss = loss.float() + loss = next(iter(outputs.values())).sum() return loss -def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: +def run_fwd_bwd( + model: Module, + data: Dict, + output_transform_fn: Callable, + criterion: Optional[Callable] = None, + optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None, +) -> torch.Tensor: """run_fwd_bwd run fwd and bwd for the model @@ -38,7 +50,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor: Returns: torch.Tensor: loss of fwd """ - loss = run_fwd(model, data, label, criterion) + loss = run_fwd(model, data, output_transform_fn, criterion) if optimizer: optimizer.backward(loss) else: diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py index 8b90a3c7372c..6dd3e102c20f 100644 --- a/tests/kit/model_zoo/transformers/bert.py +++ b/tests/kit/model_zoo/transformers/bert.py @@ -359,9 +359,9 @@ def data_gen_for_qa(): # define loss funciton loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn = lambda x: x.loss +loss_fn = lambda x: x["loss"] config = transformers.BertConfig( hidden_size=128, diff --git a/tests/kit/model_zoo/transformers/blip2.py b/tests/kit/model_zoo/transformers/blip2.py index 887b11c7f54e..0be9268307ce 100644 --- a/tests/kit/model_zoo/transformers/blip2.py +++ b/tests/kit/model_zoo/transformers/blip2.py @@ -35,7 +35,7 @@ def data_gen(): output_transform_fn = lambda x: x # define loss funciton -loss_fn_blip2_model = lambda x: x.loss +loss_fn_blip2_model = lambda x: x["loss"] config = transformers.Blip2Config() config.vision_config.patch_size = 14 diff --git a/tests/kit/model_zoo/transformers/bloom.py b/tests/kit/model_zoo/transformers/bloom.py index 12dcd71d5d1b..07f1d497777d 100644 --- a/tests/kit/model_zoo/transformers/bloom.py +++ b/tests/kit/model_zoo/transformers/bloom.py @@ -69,11 +69,11 @@ def data_gen_for_question_answering(): # define loss function loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn_for_causal_lm = lambda x: x.loss -loss_fn_for_classification = lambda x: x.loss -loss_fn_for_question_answering = lambda x: x.loss +loss_fn_for_causal_lm = lambda x: x["loss"] +loss_fn_for_classification = lambda x: x["loss"] +loss_fn_for_question_answering = lambda x: x["loss"] config = transformers.BloomConfig( n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256 diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index f4369cb7d171..0b178d58ce33 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -30,9 +30,9 @@ def data_gen_for_conditional_generation(): # define loss function loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn = lambda x: x.loss +loss_fn = lambda x: x["loss"] config = ChatGLMConfig( num_layers=2, diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py index 2af6176fbe4a..5e98c02fd4fc 100644 --- a/tests/kit/model_zoo/transformers/gpt.py +++ b/tests/kit/model_zoo/transformers/gpt.py @@ -87,13 +87,14 @@ def date_gen_for_double_heads(): # define loss function loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn = lambda x: x.loss +loss_fn = lambda x: x["loss"] config = transformers.GPT2Config( n_layer=2, n_head=4, + n_embd=128, vocab_size=50258, attn_pdrop=0, embd_pdrop=0, diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index bc229b17e08c..041de6b90f8d 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -42,9 +42,9 @@ def data_gen_for_casual_lm(): output_transform_fn = lambda x: x # function to get the loss - loss_fn = lambda output: output.last_hidden_state.mean() - loss_fn_for_casual_lm = lambda output: output.loss - loss_fn_for_seq_classification = lambda output: output.logits.mean() + loss_fn = lambda output: output["last_hidden_state"].mean() + loss_fn_for_casual_lm = lambda output: output["loss"] + loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( num_hidden_layers=4, diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py index 07ca41ef21ae..2da94a4fcc0f 100644 --- a/tests/kit/model_zoo/transformers/opt.py +++ b/tests/kit/model_zoo/transformers/opt.py @@ -45,9 +45,9 @@ def data_gen_for_question_answering(): output_transform_fn = lambda x: x loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss( - x.last_hidden_state, torch.ones_like(x.last_hidden_state) + x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]) ) -loss_fn_for_lm = lambda x: x.loss +loss_fn_for_lm = lambda x: x["loss"] config = transformers.OPTConfig( hidden_size=128, num_hidden_layers=2, diff --git a/tests/kit/model_zoo/transformers/sam.py b/tests/kit/model_zoo/transformers/sam.py index b928a8f14e75..7e756abe91b8 100644 --- a/tests/kit/model_zoo/transformers/sam.py +++ b/tests/kit/model_zoo/transformers/sam.py @@ -40,7 +40,7 @@ def data_gen(): output_transform_fn = lambda x: x # define loss funciton -loss_fn = lambda x: x.iou_scores.mean() +loss_fn = lambda x: x["iou_scores"].mean() config = transformers.SamConfig() config.vision_config.num_hidden_layers = 2 diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py index 1b63cccc42ee..2ccfb0356c2b 100644 --- a/tests/kit/model_zoo/transformers/t5.py +++ b/tests/kit/model_zoo/transformers/t5.py @@ -44,9 +44,9 @@ def data_gen_for_t5_model(): output_transform_fn = lambda x: x # define loss function -loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean() -loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean() -loss_fn_for_conditional_generation = lambda x: x.loss +loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean() +loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean() +loss_fn_for_conditional_generation = lambda x: x["loss"] # define model config config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0) diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py index f1990751b016..223559d73a55 100644 --- a/tests/kit/model_zoo/transformers/vit.py +++ b/tests/kit/model_zoo/transformers/vit.py @@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling(): output_transform_fn = lambda x: x # function to get the loss -loss_fn_for_vit_model = lambda x: x.pooler_output.mean() -loss_fn_for_image_classification = lambda x: x.logits.mean() -loss_fn_for_masked_image_modeling = lambda x: x.loss +loss_fn_for_vit_model = lambda x: x["pooler_output"].mean() +loss_fn_for_image_classification = lambda x: x["logits"].mean() +loss_fn_for_masked_image_modeling = lambda x: x["loss"] # register the following models # transformers.ViTModel, diff --git a/tests/kit/model_zoo/transformers/whisper.py b/tests/kit/model_zoo/transformers/whisper.py index 928be4468c01..d69bebe6cc04 100644 --- a/tests/kit/model_zoo/transformers/whisper.py +++ b/tests/kit/model_zoo/transformers/whisper.py @@ -53,8 +53,8 @@ def data_gen_for_audio_classification(): output_transform_fn = lambda x: x # define loss funciton -loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)) -loss_fn_attr = lambda x: x.loss +loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])) +loss_fn_attr = lambda x: x["loss"] config = transformers.WhisperConfig( classifier_proj_size=256, diff --git a/tests/test_legacy/test_amp/test_naive_fp16.py b/tests/test_legacy/test_amp/test_naive_fp16.py index 76f9ff07407f..fe16bc4d480a 100644 --- a/tests/test_legacy/test_amp/test_naive_fp16.py +++ b/tests/test_legacy/test_amp/test_naive_fp16.py @@ -6,7 +6,7 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def check_equal(a, b): @@ -25,13 +25,12 @@ def run_naive_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ["repeated_computed_layers", "nested_model", "resnet18"] + test_models = ["custom_repeated_computed_layers", "custom_nested_model", "torchvision_resnet18"] for test_name in test_models: - get_component_func = non_distributed_component_funcs.get_callable(test_name) - model_builder, train_dataloader, _, optim_class, _ = get_component_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values())) # create model - naive_amp_model = model_builder(checkpoint=True).cuda() + naive_amp_model = model_builder().cuda() apex_amp_model = copy.deepcopy(naive_amp_model) # create optimizer @@ -48,13 +47,12 @@ def run_naive_amp(): apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data - data_iter = iter(train_dataloader) - data, label = next(data_iter) - data = data.cuda() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} # forward pass - naive_amp_output = naive_amp_model(data) - apex_amp_output = apex_amp_model(data) + naive_amp_output = naive_amp_model(**data) + apex_amp_output = apex_amp_model(**data) assert_close_loose(naive_amp_output, apex_amp_output) # backward diff --git a/tests/test_legacy/test_amp/test_torch_fp16.py b/tests/test_legacy/test_amp/test_torch_fp16.py index 47b303745e4e..5e2e1ede5725 100644 --- a/tests/test_legacy/test_amp/test_torch_fp16.py +++ b/tests/test_legacy/test_amp/test_torch_fp16.py @@ -6,7 +6,7 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def run_torch_amp(): @@ -18,13 +18,12 @@ def run_torch_amp(): torch.backends.cudnn.deterministic = True # create layer - test_models = ["resnet18", "simple_net"] + test_models = ["torchvision_resnet18", "custom_simple_net"] for test_name in test_models: - get_component_func = non_distributed_component_funcs.get_callable(test_name) - model_builder, train_dataloader, _, optim_class, _ = get_component_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values())) # create model - torch_amp_model = model_builder(checkpoint=True).cuda() + torch_amp_model = model_builder().cuda() apex_amp_model = copy.deepcopy(torch_amp_model) # create optimizer @@ -41,13 +40,12 @@ def run_torch_amp(): apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config) # create data - data_iter = iter(train_dataloader) - data, label = next(data_iter) - data = data.cuda() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} # forward pass - torch_amp_output = torch_amp_model(data) - apex_amp_output = apex_amp_model(data) + torch_amp_output = torch_amp_model(**data) + apex_amp_output = apex_amp_model(**data) assert_close_loose(torch_amp_output, apex_amp_output) for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()): diff --git a/tests/test_legacy/test_engine/test_engine.py b/tests/test_legacy/test_engine/test_engine.py index b07fe8abe86e..1bb0b49c5362 100644 --- a/tests/test_legacy/test_engine/test_engine.py +++ b/tests/test_legacy/test_engine/test_engine.py @@ -1,10 +1,11 @@ import pytest +import torch import colossalai from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.core import global_context as gpc -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo CONFIG = dict( parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0 @@ -15,29 +16,29 @@ @parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None]) def run_train(model_name, amp_mode): # FIXME: test bert - get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) + train_dataloader = DummyDataloader(data_gen_fn) + criterion = lambda x: x.sum() gpc.config.fp16["mode"] = amp_mode - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - model = model_builder(checkpoint=False) + model = model_builder() engine, train_dataloader, *args = colossalai.legacy.initialize( model=model, - optimizer=optimizer_class(model.parameters(), lr=1e-3), + optimizer=torch.optim.Adam(model.parameters(), lr=1e-3), criterion=criterion, train_dataloader=train_dataloader, ) try: engine.train() - for data, label in train_dataloader: + for data in train_dataloader: engine.zero_grad() - data = data.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} if criterion: - output = engine(data) - loss = engine.criterion(output, label) + output = engine(**data) + loss = engine.criterion(output) else: - loss = engine(data, label) + loss = engine(**data) engine.backward(loss) engine.step() break diff --git a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py index d19b12a5b044..d75ddbff7cf3 100644 --- a/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_legacy/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -5,9 +5,9 @@ from colossalai.legacy.amp.amp_type import AMP_TYPE from colossalai.legacy.trainer import Trainer from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import MultiTimer -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo BATCH_SIZE = 4 IMG_SIZE = 32 @@ -16,12 +16,14 @@ CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH)) -@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"]) +@parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"]) def run_trainer(model_name): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) model = model_builder() - optimizer = optimizer_class(model.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + train_dataloader = DummyDataloader(data_gen_fn) + test_dataloader = DummyDataloader(data_gen_fn) + criterion = lambda x: x.sum() engine, train_dataloader, *_ = colossalai.legacy.initialize( model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader ) diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index a68a9c51855f..4ff16bb9b7c9 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -2,7 +2,7 @@ from colossalai.nn.optimizer import CPUAdam, HybridAdam from colossalai.testing import clear_cache_before_run, parameterize -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def move_some_params_to_cuda(model, torch_model): @@ -22,8 +22,7 @@ def check_params_equal(model, torch_model): @parameterize("nvme_offload_dir", ["./offload", None]) @parameterize("adam_cls", [CPUAdam, HybridAdam]) def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls): - get_components_func = non_distributed_component_funcs.get_callable("simple_net") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values())) model = model_builder() torch_model = model_builder() move_some_params_to_cuda(model, torch_model) diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index 2fb2bcbc851a..b8d3f45e0f34 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -12,8 +12,7 @@ from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gather", [False, True]) -@parameterize("model_name", ["gpt2", "bert"]) +@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("use_grad_checkpoint", [False, True]) @parameterize("master_weights", [False, True]) def exam_gpt_fwd_bwd( @@ -49,17 +48,22 @@ def exam_gpt_fwd_bwd( master_weights: bool = True, ): init_device = get_current_device() - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) set_seed(42) - model = model_builder(use_grad_checkpoint) + model = model_builder() set_seed(42) - torch_model = model_builder(use_grad_checkpoint).cuda() + torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p.data) + if use_grad_checkpoint: + model.gradient_checkpointing_enable() + torch_model.gradient_checkpointing_enable() + world_size = torch.distributed.get_world_size() config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) config_dict[world_size]["chunk_size"] = 5000 @@ -77,25 +81,22 @@ def exam_gpt_fwd_bwd( torch_model = DDP(torch_model, device_ids=[rank]) set_seed(rank) - for i, (input_ids, label) in enumerate(train_dataloader): - # you can only test a single fwd + bwd. - # after bwd param is grad for Gemini, due to the chunk reuse optimization. - if i > 0: - break - input_ids, label = input_ids.cuda(), label.cuda() - torch_optim.zero_grad() - zero_optim.zero_grad() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + torch_optim.zero_grad() + zero_optim.zero_grad() - # set random seed is same as torch_model.eval() - set_seed(42) - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - set_seed(42) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + # set random seed is same as torch_model.eval() + set_seed(42) + torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + set_seed(42) + loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) - assert torch.equal(torch_loss, loss) + assert_close(torch_loss.float(), loss.float()) - check_grad(model, torch_model) + check_grad(model, torch_model) def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_gemini_use_rmt.py b/tests/test_zero/test_gemini/test_gemini_use_rmt.py index 2fa2d50a6caa..90ad62d1ac78 100644 --- a/tests/test_zero/test_gemini/test_gemini_use_rmt.py +++ b/tests/test_zero/test_gemini/test_gemini_use_rmt.py @@ -3,38 +3,34 @@ import torch.distributed as dist import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import GeminiDDP from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd # run gemini use the runtime memory tracer @parameterize("placement_policy", ["auto"]) @parameterize("keep_gather", [False]) -@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"]) +@parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_grad_checkpoint", [False, True]) def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False): set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) - model = model_builder(use_grad_checkpoint).cuda() + model = model_builder().cuda() + if use_grad_checkpoint: + model.gradient_checkpointing_enable() print(f"model_name {model_name}") - runtime_mem_tracer = RuntimeMemTracer(model) - for i, (input_ids, label) in enumerate(train_dataloader): - if i > 0: - break - input_ids, label = input_ids.cuda(), label.cuda() - # mem tracing - if i == 0: - run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer) + runtime_mem_tracer = RuntimeMemTracer(model) + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer) memstats = runtime_mem_tracer.memstats() runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data)) @@ -62,16 +58,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_ ) set_seed(dist.get_rank()) - for i, (input_ids, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): # you can only test a single fwd + bwd. # after bwd param is grad for Gemini, due to the chunk reuse optimization. # print(f'iteration {i}') if i > 4: break - input_ids, label = input_ids.cuda(), label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} set_seed(42) - run_fwd_bwd(model, input_ids, label, criterion, model) + run_fwd_bwd(model, data, output_transform_fn, optimizer=model) gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda") diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index 334a57410817..5e36b18389b1 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -7,13 +7,12 @@ import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): # Compare gradients. for p0, p1 in zip(model.parameters(), torch_model.parameters()): - assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5) + assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2) # Release gradient chunks and move them to gradient device. for grad_chunk, device in zip(grad_chunk_list, device_list): @@ -48,21 +47,19 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [False, True]) -@parameterize("model_name", ["gpt2", "bert"]) -@parameterize("use_grad_checkpoint", [False, True]) +@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [False, True]) -def exam_gemini_grad_acc( - placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool -): +def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool): init_device = get_current_device() - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) set_seed(42) - gemini_model = model_builder(use_grad_checkpoint) + gemini_model = model_builder() set_seed(42) - torch_model = model_builder(use_grad_checkpoint).cuda() + torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()): torch_p.data.copy_(p.data) @@ -94,22 +91,23 @@ def exam_gemini_grad_acc( set_seed(rank) accum_iter = 4 - for i, (input_ids, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): delay_unscale = False if (i + 1) % accum_iter == 0 else True - input_ids, label = input_ids.cuda(), label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} set_seed(42 + rank) - torch_loss = run_fwd(torch_model, input_ids, label, criterion) + torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn) torch_loss = torch_loss / accum_iter with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss: scaled_loss.backward() set_seed(42 + rank) - gemini_loss = run_fwd(gemini_model, input_ids, label, criterion) + gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn) gemini_loss = gemini_loss / accum_iter gemini_optim.backward(gemini_loss) - assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5) + assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5) check_grad(gemini_model, torch_model) diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py index 4c84e9e5a89a..c3a36d3bafa1 100644 --- a/tests/test_zero/test_gemini/test_grad_clip.py +++ b/tests/test_zero/test_gemini/test_grad_clip.py @@ -7,12 +7,11 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd PLACEMENT_CONFIGS = [ { @@ -51,12 +50,13 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("model_name", ["gpt2"]) +@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("master_weights", [True, False]) def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): set_seed(1912) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) torch_model = model_builder().cuda() amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32) @@ -94,21 +94,17 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool): torch_model.train() set_seed(dist.get_rank() * 3 + 128) - for i, (data, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): if i > 2: break - data = data.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim) - loss = run_fwd_bwd(model, data, label, criterion, zero_optim) - - # as no master weights leads to error accumulation, we don't check the loss - if master_weights: - assert_close(torch_loss, loss) + run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) import apex.amp as apex_amp diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index 2b2b246a9f54..e20428b67b41 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -9,13 +9,12 @@ import colossalai from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -53,12 +52,11 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): @parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("model_name", ["gpt2"]) +@parameterize("model_name", ["transformers_gpt_lm"]) @parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): set_seed(19360226) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) torch_model = model_builder().cuda() amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) @@ -79,29 +77,27 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - train_dataloader = iter(train_dataloader) + train_dataloader = iter(DummyDataloader(data_gen_fn)) def train_iter(): - input_ids, label = next(train_dataloader) - input_ids, label = input_ids.cuda(), label.cuda() + data = next(train_dataloader) + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5) + torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, optimizer=torch_optim) + loss = run_fwd_bwd(model, data, output_transform_fn, optimizer=zero_optim) + assert_close(torch_loss.float(), loss.float(), rtol=1e-5, atol=1e-5) zero_optim.step() torch_optim.step() check_param(model, torch_model) def inference_iter(): - input_ids, label = next(train_dataloader) - input_ids, label = input_ids.cuda(), label.cuda() + data = next(train_dataloader) + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} with torch.no_grad(): - torch_output = torch_model(input_ids) - torch_loss = criterion(torch_output.float(), label) - zero_output = model(input_ids) - zero_loss = criterion(zero_output.float(), label) - assert_close(torch_loss, zero_loss) + torch_loss = run_fwd(torch_model, data, output_transform_fn) + zero_loss = run_fwd(model, data, output_transform_fn) + assert_close(torch_loss.float(), zero_loss.float(), rtol=1e-5, atol=1e-5) train_iter() inference_iter() diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 0cf9aa073f9f..887e495e6187 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -1,20 +1,18 @@ import pytest import torch import torch.distributed as dist -from packaging.version import Version from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed from colossalai.utils.cuda import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 @@ -32,14 +30,17 @@ ] # this model is large enough to slice to chunks -TEST_MODELS = ["gpt2"] +TEST_MODELS = ["transformers_gpt_lm"] # these models are too small, all parameters in these models are compacted into one chunk -EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] +EXAMPLE_MODELS = [ + "transformers_bert_for_sequence_classification", + "custom_hanging_param_model", + "custom_nested_model", + "custom_repeated_computed_layers", +] # bfloat16 cannot represent them exactly BF16_IGNORED_KEYS = [ - "albert.embeddings.word_embeddings.weight", - "albert.embeddings.position_embeddings.weight", "masked_bias", ] @@ -55,7 +56,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty temp_zero_value = zero_dict[key].to(device=value.device) if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): continue - rtol, atol = 1e-3, 4e-3 + rtol, atol = 2e-3, 6e-3 if dtype is torch.bfloat16: rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) @@ -74,8 +75,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty @parameterize("master_weights", [True, False]) def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): set_seed(42) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) torch_model = model_builder().cuda() # apex no master weights leads to nan, so we don't use it @@ -104,19 +106,20 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - rtol, atol = 1e-4, 1e-5 - for i, (input_ids, label) in enumerate(train_dataloader): + rtol, atol = 4e-2, 4e-2 + train_dataloader = iter(DummyDataloader(data_gen_fn)) + for i, data in enumerate(train_dataloader): if i > 2: break - input_ids, label = input_ids.cuda(), label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) + torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) # as no master weights leads to error accumulation, we don't check the loss if master_weights: - assert_close(torch_loss, loss, rtol=rtol, atol=atol) + assert_close(torch_loss.float(), loss.float(), rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() @@ -125,13 +128,14 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt check_param(model, torch_model, mixed_precision) -@parameterize("placement_config", PLACEMENT_CONFIGS) +@parameterize("placement_config", [PLACEMENT_CONFIGS[3]]) @parameterize("model_name", EXAMPLE_MODELS) -@parameterize("mixed_precision", [torch.half, torch.bfloat16]) +@parameterize("mixed_precision", [torch.half]) def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): set_seed(2008) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) torch_model = model_builder().cuda() amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) @@ -159,26 +163,19 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. torch_model.eval() set_seed(dist.get_rank() * 3 + 128) - rtol, atol = 1.5e-6, 2e-5 - if mixed_precision is torch.bfloat16: - rtol, atol = 2e-3, 2e-3 - elif Version(torch.__version__) >= Version("2.0.0"): - rtol, atol = 4e-5, 3e-5 - for i, (input_ids, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): if i > 2: break - input_ids = input_ids.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} zero_optim.zero_grad() torch_optim.zero_grad() - torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) - loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 - + run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim) + run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim) zero_optim.step() torch_optim.step() diff --git a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py index 8e0f6ae36c46..9d00521be694 100644 --- a/tests/test_zero/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_zero/test_gemini/test_runtime_mem_tracer.py @@ -4,10 +4,9 @@ import pytest import torch -from colossalai.testing import clear_cache_before_run +from colossalai.testing import DummyDataloader, clear_cache_before_run from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from tests.components_to_test import run_fwd_bwd -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo, run_fwd_bwd @pytest.mark.skip("this is not used") @@ -16,21 +15,22 @@ def test_runtime_mem_tracer(): test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) - model = model_builder(checkpoint=False).cuda() + model = model_builder().cuda() model_bk = deepcopy(model) runtime_mem_tracer = RuntimeMemTracer(model) - for i, (data, label) in enumerate(train_dataloader): + train_dataloader = DummyDataloader(data_gen_fn) + for i, data in enumerate(train_dataloader): if i > 1: break - data = data.cuda() - label = label.cuda() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} - run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) + run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer) for p1, p2 in zip(model_bk.parameters(), model.parameters()): torch.allclose(p1.to(torch.half), p2) diff --git a/tests/test_zero/test_gemini/test_search.py b/tests/test_zero/test_gemini/test_search.py index e22e5ece42a5..e99f6d59ba8e 100644 --- a/tests/test_zero/test_gemini/test_search.py +++ b/tests/test_zero/test_gemini/test_search.py @@ -5,40 +5,37 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.utils import get_current_device from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo def exam_search_chunk_size(): - world_size = torch.distributed.get_world_size() - - get_components_func = non_distributed_component_funcs.get_callable("gpt2") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) + ) # make sure torch_model and model has the same parameter values model = model_builder() config_dict, *_ = search_chunk_configuration( - model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True + model, search_range_m=1, search_interval=128, min_chunk_size_m=0, filter_exlarge_params=True ) for key in config_dict: chunk_size = config_dict[key]["chunk_size"] - if world_size == 1 or True: - assert chunk_size == 31616 - else: - assert chunk_size == 1024 + assert chunk_size == 527872 def exam_chunk_manager(): world_size = torch.distributed.get_world_size() - get_components_func = non_distributed_component_funcs.get_callable("gpt2") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) + ) sharded_ddp_model = model_builder() chunk_manager = init_chunk_manager( sharded_ddp_model, get_current_device(), - hidden_dim=16, + hidden_dim=128, search_range_m=1, min_chunk_size_m=0, filter_exlarge_params=True, @@ -46,7 +43,7 @@ def exam_chunk_manager(): ) config_dict = chunk_manager.dp_degree_chunk_size_dict assert len(config_dict) == 1 - assert config_dict[world_size] == 31616 + assert config_dict[world_size] == 527872 def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py index bf16a301cd8a..cbf5169fc621 100644 --- a/tests/test_zero/test_gemini/test_zeroddp_state_dict.py +++ b/tests/test_zero/test_gemini/test_zeroddp_state_dict.py @@ -7,7 +7,7 @@ from colossalai.utils import set_seed from colossalai.zero import GeminiDDP from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 @@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module): @parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("keep_gathered", [True, False]) -@parameterize("model_name", ["gpt2", "bert"]) +@parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"]) @parameterize("master_weights", [False, True]) def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values())) model = model_builder() + model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 + torch_model = model_builder() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p.data) @@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) - -@parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("keep_gathered", [True, False]) -@parameterize("model_name", ["gpt2", "bert"]) -@parameterize("master_weights", [False, True]) -def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): - set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - model = model_builder() - - set_seed(451) - torch_model = model_builder() # get a different model - - world_size = torch.distributed.get_world_size() - config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - config_dict[world_size]["chunk_size"] = 5000 - config_dict[world_size]["keep_gathered"] = keep_gathered - - model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights) - - torch_dict = torch_model.state_dict() + # check load state dict model.load_state_dict(torch_dict, strict=False) zero_dict = model.state_dict(only_rank_0=False) @@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) - -@parameterize("placement_config", PLACEMENT_CONFIGS) -@parameterize("model_name", ["gpt2", "bert"]) -@parameterize("master_weights", [False, True]) -def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool): - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - model = model_builder() - - model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2 - - config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100) - model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights) - model.train() - - zero_dict = model.state_dict(only_rank_0=False) + # check state dict shard accumulated_keys = set() # ensure number of shards > 1 for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): @@ -116,8 +79,6 @@ def run_dist(rank, world_size, port): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") exam_state_dict() - exam_load_state_dict() - exam_state_dict_shard() @pytest.mark.dist diff --git a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py index c65c6d292467..87cb1cdfe43f 100644 --- a/tests/test_zero/test_gemini/test_zerooptim_state_dict.py +++ b/tests/test_zero/test_gemini/test_zerooptim_state_dict.py @@ -8,7 +8,7 @@ from colossalai.utils import set_seed from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration -from tests.components_to_test.registry import non_distributed_component_funcs +from tests.kit.model_zoo import model_zoo PLACEMENT_CONFIGS = [ {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 @@ -22,8 +22,9 @@ @parameterize("keep_gathered", [True, False]) def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(431) - get_components_func = non_distributed_component_funcs.get_callable("gpt2") - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + model_builder, data_gen_fn, output_transform_fn, *_ = next( + iter(model_zoo.get_sub_registry("transformers_gpt_lm").values()) + ) model = model_builder() @@ -41,15 +42,15 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): set_seed(dist.get_rank() * 3 + 128) model.train() - for i, (input_ids, label) in enumerate(train_dataloader): - if i > 0: - break - optim.zero_grad() - logits = model(input_ids) - logits = logits.float() - loss = criterion(logits, input_ids) - optim.backward(loss) - optim.step() + data = data_gen_fn() + data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + + optim.zero_grad() + outputs = model(**data) + outputs = output_transform_fn(outputs) + loss = next(iter(outputs.values())).sum() + optim.backward(loss) + optim.step() optim_state_dict = optim.state_dict() optim.load_state_dict(optim_state_dict) From 9d543af5d02576fe8eb8bb242d0eb5738edd508d Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Fri, 20 Oct 2023 13:39:34 +0800 Subject: [PATCH 13/46] [inference] add reference and fix some bugs (#4937) * add reference and fix some bugs * update gptq init --------- Co-authored-by: Xu Kai --- .../inference/quant/smoothquant/models/base_model.py | 6 ++++++ .../inference/quant/smoothquant/models/linear.py | 2 ++ .../inference/quant/smoothquant/models/llama.py | 3 +++ colossalai/inference/tensor_parallel/engine.py | 7 ++++++- colossalai/kernel/triton/gptq_triton.py | 1 + colossalai/kernel/triton/smooth_attention.py | 12 ++++++------ examples/inference/gptq_llama.py | 3 --- 7 files changed, 24 insertions(+), 10 deletions(-) diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 180e6c6e8fa6..6a1d96ecec8f 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -132,6 +132,7 @@ def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samp mean_scale = np.mean([v["input"] for v in act_dict.values()]) pbar.set_description(f"Mean input scale: {mean_scale:.2f}") + # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512): model.eval() device = next(model.parameters()).device @@ -163,6 +164,7 @@ def stat_input_hook(m, x, y, name): return act_scales + # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py @torch.no_grad() def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): if not isinstance(fcs, list): @@ -189,6 +191,7 @@ def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5): def create_quantized_model(model): raise NotImplementedError("Not implement create_quantized_model method") + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py def save_quantized( self, save_dir: str, @@ -249,6 +252,7 @@ def save_quantized( self.model.config.save_pretrained(save_dir) + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py def save_pretrained( self, save_dir: str, @@ -260,6 +264,7 @@ def save_pretrained( warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.") self.save_quantized(save_dir, use_safetensors, safetensors_metadata) + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py @classmethod def from_pretrained( cls, @@ -354,6 +359,7 @@ def skip(*args, **kwargs): return cls(model, False) + # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py @classmethod def from_quantized( cls, diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py index 048565bfbf5e..969c390a0849 100644 --- a/colossalai/inference/quant/smoothquant/models/linear.py +++ b/colossalai/inference/quant/smoothquant/models/linear.py @@ -62,6 +62,7 @@ def from_float(module: torch.nn.Linear, input_scale): return int8_module +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py class W8A8B8O8Linear(torch.nn.Module): # For qkv_proj def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): @@ -117,6 +118,7 @@ def from_float(module: torch.nn.Linear, input_scale, output_scale): return int8_module +# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py class W8A8BFP32OFP32Linear(torch.nn.Module): # For fc2 and out_proj def __init__(self, in_features, out_features, alpha=1.0, beta=1.0): diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 9c77feeb346e..4c3d6dcc0b23 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -419,6 +419,7 @@ def forward(self, x, cos, sin, position_ids): return x_embed +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py def llama_decoder_layer_forward( self, hidden_states: torch.Tensor, @@ -559,6 +560,7 @@ def init_to_get_rotary(config, base=10000, use_elem=False): return _cos_cached, _sin_cached +# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def llama_model_forward( self, @@ -729,6 +731,7 @@ class SmoothLlamaForCausalLM(BaseSmoothForCausalLM): def __init__(self, model: PreTrainedModel, quantized: bool = False): super().__init__(model, quantized) + # Adatped from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py def get_act_dict( self, tokenizer, diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index e4c4a2d70cd7..216b134f5fab 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -21,6 +21,8 @@ "BloomForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration", + "LlamaGPTQForCausalLM", + "BloomGPTQForCausalLM", ] @@ -213,11 +215,14 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None: ), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config" model_name = model.__class__.__name__ assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference." + + model = model.model if self.shard_config.inference_gptq else model + policy = get_autopolicy(model, inference_only=True) self.model, _ = shardformer.optimize(model, policy) if self.shard_config.inference_gptq: - self._post_init_gptq_buffer(model) + self._post_init_gptq_buffer(self.model) self.model = self.model.cuda() diff --git a/colossalai/kernel/triton/gptq_triton.py b/colossalai/kernel/triton/gptq_triton.py index 8460103e261d..2dc1fe04438a 100644 --- a/colossalai/kernel/triton/gptq_triton.py +++ b/colossalai/kernel/triton/gptq_triton.py @@ -267,6 +267,7 @@ def cai_gptq_matmul_248_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ @autotune( configs=[ triton.Config( diff --git a/colossalai/kernel/triton/smooth_attention.py b/colossalai/kernel/triton/smooth_attention.py index ee0df6a74eaa..071de58e20c0 100644 --- a/colossalai/kernel/triton/smooth_attention.py +++ b/colossalai/kernel/triton/smooth_attention.py @@ -13,10 +13,10 @@ if HAS_TRITON: """ - this function is modified from - https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + this functions are modified from https://github.com/ModelTC/lightllm """ + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py @triton.jit def _context_flash_attention_kernel( Q, @@ -145,20 +145,16 @@ def _context_flash_attention_kernel( tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return - - @torch.no_grad() def smooth_llama_context_attn_fwd( q, k, v, o, q_input_scale, k_input_scale, v_input_scale, pv_output_scale, b_start_loc, b_seq_len, max_input_len ): - BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk, "context process only supports equal query, key, value length" assert Lk == Lv, "context process only supports equal query, key, value length" assert Lk in {16, 32, 64, 128} - BLOCK_N = 128 sm_scale = 1.0 / math.sqrt(Lk) batch, head = b_seq_len.shape[0], q.shape[1] grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) @@ -203,6 +199,7 @@ def smooth_llama_context_attn_fwd( ) return + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @triton.jit def _token_attn_1_kernel( Q, @@ -264,6 +261,7 @@ def _token_attn_1_kernel( tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index) return + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @triton.jit def _token_attn_1_alibi_kernel( Q, @@ -413,6 +411,7 @@ def token_attn_fwd_1( ) return + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py @triton.jit def _token_attn_softmax_fwd( softmax_logics, @@ -479,6 +478,7 @@ def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, ) return + # Adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py @triton.jit def _token_attn_2_kernel( Prob, diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index bbfbf1bc8b43..4357dc6c026d 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -8,7 +8,6 @@ import colossalai from colossalai.inference.tensor_parallel.engine import TPInferEngine -from colossalai.inference.tensor_parallel.modeling._utils import init_to_get_rotary from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn @@ -74,8 +73,6 @@ def run_llama_test(args): quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False ) - init_to_get_rotary(model.model.model, base=10000) - model_config = model.config shard_config = ShardConfig( enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True From fe795600474f5d5eb41a4a73563d9153eb3cbbbe Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Tue, 24 Oct 2023 13:11:15 +0800 Subject: [PATCH 14/46] [Inference]ADD Bench Chatglm2 script (#4963) * add bench chatglm * fix bug and make utils --------- Co-authored-by: CjhHa1 --- examples/inference/_utils.py | 19 +++++ examples/inference/bench_bloom.py | 20 +---- examples/inference/bench_chatglm2.py | 116 +++++++++++++++++++++++++++ examples/inference/bench_llama.py | 22 +---- examples/inference/gptq_bloom.py | 1 + examples/inference/gptq_llama.py | 1 + 6 files changed, 141 insertions(+), 38 deletions(-) create mode 100644 examples/inference/_utils.py create mode 100644 examples/inference/bench_chatglm2.py diff --git a/examples/inference/_utils.py b/examples/inference/_utils.py new file mode 100644 index 000000000000..67d897836113 --- /dev/null +++ b/examples/inference/_utils.py @@ -0,0 +1,19 @@ +def print_perf_stats(latency_set, config, bs, warmup=3): + # trim warmup queries + latency_set = list(latency_set) + latency_set = latency_set[warmup:] + count = len(latency_set) + + if count > 0: + latency_set.sort() + avg = sum(latency_set) / count + num_layers = ( + getattr(config, "num_layers") if hasattr(config, "num_layers") else getattr(config, "num_hidden_layers") + ) + num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 + num_bytes = 2 # float16 + + print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) + print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) + print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) + print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py index 738f43dc0619..054641f6eebf 100644 --- a/examples/inference/bench_bloom.py +++ b/examples/inference/bench_bloom.py @@ -3,6 +3,7 @@ import time import torch +from _utils import print_perf_stats from transformers import BloomForCausalLM, BloomTokenizerFast import colossalai @@ -14,25 +15,6 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -def print_perf_stats(latency_set, config, bs, warmup=3): - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) - print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs)) - - def bench_bloom(args): model_path = args.path max_batch_size = args.batch_size diff --git a/examples/inference/bench_chatglm2.py b/examples/inference/bench_chatglm2.py new file mode 100644 index 000000000000..f3678d29ff93 --- /dev/null +++ b/examples/inference/bench_chatglm2.py @@ -0,0 +1,116 @@ +import argparse +import os +import time + +import torch +from _utils import print_perf_stats +from transformers import AutoTokenizer + +import colossalai +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig +from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +def run_chatglm2_test(args): + chatglm2_model_path = args.path + max_batch_size = args.batch_size + max_input_len = args.input_len + max_output_len = args.output_len + args.test_mode + + print("max_batch_size : " + str(max_batch_size)) + + tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) + model = ChatGLMForConditionalGeneration.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + model.config + + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) + + generate_kwargs = dict(max_new_tokens=1, do_sample=False) + input_tokens = { + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), + } + + iters = 10 + prefill_times = [] + + warmup = 3 + + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + prefill_times.append((end - start) / (out_len - max_input_len)) + + prefill_times = prefill_times[warmup:] + prefill_time_avg = sum(prefill_times) / len(prefill_times) + generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) + + times = [] + decoder_times = [] + for i in range(iters): + torch.cuda.synchronize() + start = time.time() + outputs = infer_engine.generate(input_tokens, **generate_kwargs) + torch.cuda.synchronize() + end = time.time() + out_len = outputs.shape[1] + print("generation time {} s".format(str(end - start))) + print(out_len - max_input_len) + times.append((end - start) / (out_len - max_input_len)) + if args.test_mode == "decoder_test": + decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1)) + + times = times[warmup:] + latency = sum(times) / len(times) + print("total process latency is : " + str(latency) + " s") + print("total throughput is : " + str(1 / latency * max_batch_size)) + + if args.test_mode == "decoder_test": + decoder_times = decoder_times[warmup:] + latency = sum(decoder_times) / len(decoder_times) + + print("decoder process latency is : " + str(latency) + " s") + print("decoder throughput is : " + str(1 / latency * max_batch_size)) + + print_perf_stats(times, model.config, max_batch_size) + + +def check_chatglm2(rank, world_size, port, args): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_chatglm2_test(args) + + +@rerun_if_address_is_in_use() +def test_chatglm2(args): + spawn(check_chatglm2, args.tp_size, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument( + "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] + ) + + args = parser.parse_args() + + test_chatglm2(args) diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index 0ca1953c6a41..f3e742dfbb59 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -3,6 +3,7 @@ import time import torch +from _utils import print_perf_stats from transformers import LlamaForCausalLM, LlamaTokenizer import colossalai @@ -14,25 +15,6 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -def print_perf_stats(latency_set, config, bs, warmup=3): - torch.cuda.empty_cache() - # trim warmup queries - latency_set = list(latency_set) - latency_set = latency_set[warmup:] - count = len(latency_set) - - if count > 0: - latency_set.sort() - avg = sum(latency_set) / count - num_layers = getattr(config, "num_layers", config.num_hidden_layers) - num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 - - print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) - print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) - print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12)) - - def run_llama_test(args): llama_model_path = args.path max_batch_size = args.batch_size @@ -104,6 +86,8 @@ def run_llama_test(args): print("decoder process latency is : " + str(latency) + " s") print("decoder throughput is : " + str(1 / latency * max_batch_size)) + print_perf_stats(times, model.config, max_batch_size) + def check_llama(rank, world_size, port, args): disable_existing_loggers() diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index f5413e31682d..bb92e5471d89 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -3,6 +3,7 @@ import time import torch +from _utils import print_perf_stats from auto_gptq import AutoGPTQForCausalLM from transformers import BloomTokenizerFast diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py index 4357dc6c026d..e730522c0096 100644 --- a/examples/inference/gptq_llama.py +++ b/examples/inference/gptq_llama.py @@ -3,6 +3,7 @@ import time import torch +from _utils import print_perf_stats from auto_gptq import AutoGPTQForCausalLM from transformers import LlamaTokenizer From a6100461bbaf63ebe82c77945b0e3c9a44d1c966 Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Fri, 27 Oct 2023 16:19:54 +0800 Subject: [PATCH 15/46] [Pipeline inference] Combine kvcache with pipeline inference (#4938) * merge kvcache with pipeline inference and refactor the code structure * support ppsize > 2 * refactor pipeline code * do pre-commit * modify benchmark * fix bench mark * polish code * add docstring and update readme * refactor the code * fix some logic bug of ppinfer * polish readme * fix typo * skip infer test --- colossalai/inference/__init__.py | 3 +- colossalai/inference/pipeline/README.md | 75 ++- .../inference/pipeline/benchmark/benchmark.py | 7 +- .../inference/pipeline/benchmark/run.sh | 6 +- colossalai/inference/pipeline/engine.py | 91 +++- .../inference/pipeline/microbatch_manager.py | 100 ++-- .../inference/pipeline/modeling/__init__.py | 3 + .../inference/pipeline/modeling/_utils.py | 67 +++ .../inference/pipeline/modeling/gpt2.py | 280 ---------- .../inference/pipeline/modeling/llama.py | 483 +++++++++++++----- .../inference/pipeline/policies/__init__.py | 3 + .../inference/pipeline/policies/llama.py | 145 ++++++ .../inference/pipeline/policy/gpt2_ppinfer.py | 74 --- .../pipeline/policy/llama_ppinfer.py | 48 -- colossalai/inference/pipeline/utils.py | 35 -- .../tensor_parallel/batch_infer_state.py | 61 +++ colossalai/pipeline/schedule/generate.py | 82 +-- colossalai/shardformer/shard/shard_config.py | 3 +- tests/test_infer/test_pipeline_infer.py | 19 +- 19 files changed, 881 insertions(+), 704 deletions(-) create mode 100644 colossalai/inference/pipeline/modeling/_utils.py delete mode 100644 colossalai/inference/pipeline/modeling/gpt2.py create mode 100644 colossalai/inference/pipeline/policies/__init__.py create mode 100644 colossalai/inference/pipeline/policies/llama.py delete mode 100644 colossalai/inference/pipeline/policy/gpt2_ppinfer.py delete mode 100644 colossalai/inference/pipeline/policy/llama_ppinfer.py delete mode 100644 colossalai/inference/pipeline/utils.py diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 35891307e754..761e48e5917a 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,3 +1,4 @@ from .pipeline import PPInferEngine -__all__ = ["PPInferEngine"] + +__all__ = ['PPInferEngine'] diff --git a/colossalai/inference/pipeline/README.md b/colossalai/inference/pipeline/README.md index a90d5d6da182..f9bb35cc4d4c 100644 --- a/colossalai/inference/pipeline/README.md +++ b/colossalai/inference/pipeline/README.md @@ -17,7 +17,7 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py). 1. `PPInderEngine` is the High-Level API for users to use. It is responsible for the following tasks: - - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`. + - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`. - Run the pipeline inference model. 2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks: @@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag ### Example ```python -from colossalai.pipeline import PPInferEngine -# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example. -model = LlamaForCausalLM.from_pretrained('/path/to/model') -inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt") -engine = PPInferEngine( - pp_size=2, - dtype='fp16', - micro_batch_size=1, - new_length=10, - model=model, - model_policy=LlamaForCausalLMPipelinePolicy()) - -output = engine.inference([inputs]) +from colossalai.inference import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy +import colossalai +from transformers import LlamaForCausalLM, LlamaTokenizer -``` +colossalai.launch_from_torch(config={}) + +model = LlamaForCausalLM.from_pretrained("/path/to/model") +tokenizer = LlamaTokenizer.from_pretrained("/path/to/model") -### Quick start -```shell -cd benchmark -sh run.sh +# assume the model is inferred with 2 pipeline stages +inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32) + +input = ["Introduce a landmark in London","Introduce a landmark in Singapore"] +data = tokenizer(input, return_tensors='pt') +output = inferengine.inference(data.to('cuda')) +print(tokenizer.batch_decode(output)) ``` ## Performance -We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G. +We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. -### Llama Throughput(tokens/s) +### Llama Throughput (tokens/s) | input length=1024, output length=128 -#### 7b, fp16 +#### A10 7b, fp16 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)| | :---: | :---: | :---: | :---: | :---: | :---: | :---:| -| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM | -| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | -| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 | -| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM | +| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM | +| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM | -#### 7b, fp32 +#### A10 13b, fp16 | batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | | :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 | -| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM | -| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 | -| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM | +| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 | +| Hugging Face | 23.48 | 37.59 | 53.44 | OOM | -#### 13b, fp16 -| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) | -| :---: | :---: | :---: | :---: | :---: | -| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 | -| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM | -| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 | -| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM | + +#### A800 7b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 | +| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 | + + +#### A800 13b, fp16 +| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) | +| :---: | :---: | :---: | :---: | :---: | :---: | +| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 | +| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 | diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 9c47909f70f0..8392d0a1e579 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -7,7 +7,7 @@ import colossalai from colossalai.inference import PPInferEngine -from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 @@ -117,8 +117,11 @@ def print_details_info(timestamps, model_config, args, whole_end2end): micro_batch_size=args.mb_size, new_length=args.new_length, model=model, - model_policy=LlamaForCausalLMPipelinePolicy(), + model_policy=LlamaModelInferPolicy(), verbose=True, + max_batch_size=args.mb_size, + max_input_len=args.seq_len, + max_output_len=args.seq_len + args.new_length + 256, ) data = data_gen(args.batch_size, args.seq_len) diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/inference/pipeline/benchmark/run.sh index 7d8da858692f..e3c33bb88db1 100644 --- a/colossalai/inference/pipeline/benchmark/run.sh +++ b/colossalai/inference/pipeline/benchmark/run.sh @@ -1,7 +1,7 @@ script_dir=$(cd "$(dirname "$0")" && pwd) cd "${script_dir}" -# 7b, fp32, 2 gpu, 1024, 128 +# 7b, fp16, 2 gpu, 1024, 128 for BATCH_SIZE in 2 4 8 16; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="7b" \ @@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do --pp_size=2 done -# 7b, fp32, 2 gpu, 512, 512 +# 7b, fp16, 2 gpu, 512, 512 for BATCH_SIZE in 2 4 8 16 32; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="7b" \ @@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do --pp_size=2 done -# 7b, fp32, 2 gpu, 1024, 128 +# 7b, fp16, 2 gpu, 1024, 128 for BATCH_SIZE in 2 4 8; do CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \ --model="13b" \ diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 4f42385caf8f..480ac5dc79fb 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from transformers.tokenization_utils_base import BatchEncoding from colossalai.cluster import ProcessGroupMesh from colossalai.pipeline.schedule.generate import GenerateSchedule @@ -7,6 +8,7 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy +from ..tensor_parallel.kvcache_manager import MemoryManager from .microbatch_manager import MicroBatchManager @@ -23,20 +25,29 @@ class PPInferEngine: micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. new_length (int): the new length of the input sequence. early_stopping (bool): whether to stop early. + max_batch_size (int): the maximum batch size. + max_input_len (int): the maximum input length. + max_output_len (int): the maximum output length. Example: ```python - from colossalai.ppinference import PPInferEngine - from transformers import GPT2LMHeadModel, GPT2Tokenizer + from colossalai.inference import PPInferEngine + from colossalai.inference.pipeline.policies import LlamaModelInferPolicy + import colossalai + from transformers import LlamaForCausalLM, LlamaTokenizer - model = transformers.GPT2LMHeadModel.from_pretrained('gpt2') - # assume the model is infered with 4 pipeline stages - inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding}) + colossalai.launch_from_torch(config={}) + + model = LlamaForCausalLM.from_pretrained("your_path_to_model") + tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") + # assume the model is infered with 2 pipeline stages + inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) + + input = ["Introduce a landmark in China ","Introduce a landmark in China "] + data = tokenizer(input, return_tensors='pt') + output = inferengine.inference([data.to('cuda').data]) - input = ["Hello, my dog is cute, and I like"] - tokenized_input = tokenizer(input, return_tensors='pt') - output = engine.inference([tokenized_input]) ``` """ @@ -51,6 +62,9 @@ def __init__( new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, + max_batch_size: int = 4, + max_input_len: int = 32, + max_output_len: int = 32, verbose: bool = False, # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, @@ -58,24 +72,53 @@ def __init__( num_beams: int = 1, ) -> None: assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + + max_output_len = max(max_output_len, max_input_len + new_length) + self.pp_size = pp_size + if dtype == "fp16": + self.dtype = torch.float16 + model.half() + elif dtype == "bf16": + self.dtype = torch.bfloat16 + model.to(torch.bfloat16) + else: + self.dtype = torch.float32 self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) + self.model = pp_model or self._shardformer(model, model_policy) + self.cache_manager_list = [ + self._init_manager(max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] self.mb_manager = MicroBatchManager( - self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size + self.stage_manager.stage, + new_length, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, ) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) - assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - if dtype == "fp16": - model.half() - elif dtype == "bf16": - model.to(torch.bfloat16) - self.model = pp_model or self._shardformer(model, model_policy) - def inference(self, input_list): - out, timestamp = self.schedule.generate_step(self.model, iter(input_list)) + """ + Args: + input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`. + + Returns: + out (list): a list of output data, each element is a list of token. + timestamp (float): the time cost of the inference, only return when verbose is `True`. + """ + assert isinstance( + input_list, (BatchEncoding, dict) + ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}." + if isinstance(input_list, BatchEncoding): + input_list = input_list.data + out, timestamp = self.schedule.generate_step(self.model, iter([input_list])) if self.verbose: return out, timestamp else: @@ -95,3 +138,17 @@ def _shardformer(self, model, model_policy): shardformer = ShardFormer(shard_config=shardconfig) shard_model, _ = shardformer.optimize(model, model_policy) return shard_model.cuda() + + def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + max_total_token_num = max_batch_size * (max_input_len + max_output_len) + head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + head_num = self.model.config.num_attention_heads + num_hidden_layers = ( + self.model.config.num_hidden_layers + if hasattr(self.model.config, "num_hidden_layers") + else self.model.config.num_layers + ) + layer_num = num_hidden_layers // self.pp_size + + cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num) + return cache_manager diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 49d1bf3f42cb..2bf52161d611 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -1,8 +1,11 @@ from enum import Enum -from typing import Dict, Tuple +from typing import Dict import torch +from ..tensor_parallel.batch_infer_state import BatchInferState +from ..tensor_parallel.kvcache_manager import MemoryManager + __all__ = "MicroBatchManager" @@ -27,21 +30,20 @@ class MicroBatchDescription: def __init__( self, inputs_dict: Dict[str, torch.Tensor], - output_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, new_length: int, ) -> None: - assert output_dict.get("hidden_states") is not None - self.mb_length = output_dict["hidden_states"].shape[-2] + self.mb_length = inputs_dict["input_ids"].shape[-1] self.target_length = self.mb_length + new_length - self.kv_cache = () - - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - if output_dict is not None: - self._update_kvcache(output_dict["past_key_values"]) + self.infer_state = BatchInferState.init_from_batch( + batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager + ) + # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}") - def _update_kvcache(self, kv_cache: Tuple): - assert type(kv_cache) == tuple - self.kv_cache = kv_cache + def update(self, *args, **kwargs): + pass @property def state(self): @@ -80,17 +82,21 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ def __init__( - self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int, ) -> None: - super().__init__(inputs_dict, output_dict, new_length) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) assert inputs_dict is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None self.input_ids = inputs_dict["input_ids"] self.attn_mask = inputs_dict["attention_mask"] self.new_tokens = None - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - super().update(output_dict, new_token) + def update(self, new_token: torch.Tensor = None): if new_token is not None: self._update_newtokens(new_token) if self.state is not Status.DONE and new_token is not None: @@ -125,16 +131,17 @@ class BodyMicroBatchDescription(MicroBatchDescription): Args: inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage. - output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. """ def __init__( - self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + self, + inputs_dict: Dict[str, torch.Tensor], + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + new_length: int, ) -> None: - super().__init__(inputs_dict, output_dict, new_length) - - def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): - super().update(output_dict, new_token) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) @property def cur_length(self): @@ -142,10 +149,7 @@ def cur_length(self): When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1 """ - if len(self.kv_cache) == 0: - return self.mb_length - else: - return self.kv_cache[0][0].shape[-2] + 1 + return self.infer_state.seq_len.max().item() class MicroBatchManager: @@ -160,16 +164,38 @@ class MicroBatchManager: """ - def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): + def __init__( + self, + stage: int, + new_length: int, + micro_batch_size: int, + micro_batch_buffer_size: int, + max_input_len: int, + max_output_len: int, + cache_manager_list: MemoryManager, + ): self.stage = stage self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.cache_manager_list = cache_manager_list self.mb_descrption_buffer = {} self.new_tokens_buffer = {} self.idx = 0 - def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): + def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): + if self.stage == 0: + self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + ) + else: + self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + ) + + def step(self, new_token: torch.Tensor = None): """ Update the state if microbatch manager, 2 conditions. 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs. @@ -181,11 +207,7 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne new_token (torch.Tensor): the new token generated by current stage. """ # Add descrption first if the descrption is None - if inputs_dict is None and output_dict is None and new_token is None: - return Status.PREFILL - if self.mb_descrption_buffer.get(self.idx) is None: - self._add_descrption(inputs_dict, output_dict) - self.cur_descrption.update(output_dict, new_token) + self.cur_descrption.update(new_token) return self.cur_state def export_new_tokens(self): @@ -204,16 +226,12 @@ def is_micro_batch_done(self): def clear(self): self.mb_descrption_buffer.clear() + for cache in self.cache_manager_list: + cache.free_all() def next(self): self.idx = (self.idx + 1) % self.buffer_size - def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]): - if self.stage == 0: - self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length) - else: - self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length) - def _remove_descrption(self): self.mb_descrption_buffer.pop(self.idx) @@ -222,10 +240,10 @@ def cur_descrption(self) -> MicroBatchDescription: return self.mb_descrption_buffer.get(self.idx) @property - def cur_kv_cache(self): + def cur_infer_state(self): if self.cur_descrption is None: return None - return self.cur_descrption.kv_cache + return self.cur_descrption.infer_state @property def cur_state(self): diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/pipeline/modeling/__init__.py index e69de29bb2d1..239bdebd7efd 100644 --- a/colossalai/inference/pipeline/modeling/__init__.py +++ b/colossalai/inference/pipeline/modeling/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaInferenceForwards + +__all__ = ["LlamaInferenceForwards"] diff --git a/colossalai/inference/pipeline/modeling/_utils.py b/colossalai/inference/pipeline/modeling/_utils.py new file mode 100644 index 000000000000..068b64b4f829 --- /dev/null +++ b/colossalai/inference/pipeline/modeling/_utils.py @@ -0,0 +1,67 @@ +""" +Utils for model inference +""" +import os + +import torch + +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + + +def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): + """ + This function copies the key and value cache to the memory cache + Args: + layer_id : id of current layer + key_buffer : key cache + value_buffer : value cache + context_mem_index : index of memory cache in kv cache manager + mem_manager : cache manager + """ + copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) + + +def init_to_get_rotary(self, base=10000, use_elem=False): + """ + This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer + Args: + self : Model that holds the rotary positional embedding + base : calculation arg + use_elem : activated when using chatglm-based models + """ + self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads + if not hasattr(self.config, "rope_scaling"): + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0 + + if hasattr(self.config, "max_sequence_length"): + max_seq_len = self.config.max_sequence_length + elif hasattr(self.config, "max_position_embeddings"): + max_seq_len = self.config.max_position_embeddings * rope_scaling_factor + else: + max_seq_len = 2048 * rope_scaling_factor + base = float(base) + + # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) + + if ntk_alpha is not None: + ntk_alpha = float(ntk_alpha) + assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1" + if ntk_alpha > 1: + print(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula + + n_elem = self.config.head_dim_ + if use_elem: + n_elem //= 2 + + inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem)) + t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor + freqs = torch.outer(t, inv_freq) + + self._cos_cached = torch.cos(freqs).to(torch.float16).cuda() + self._sin_cached = torch.sin(freqs).to(torch.float16).cuda() diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py deleted file mode 100644 index d2bfcb8b6842..000000000000 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ /dev/null @@ -1,280 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions -from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model -from transformers.utils import logging - -from colossalai.pipeline.stage_manager import PipelineStageManager - - -class GPT2PipelineForwards: - """ - This class serves as a micro library for forward function substitution of GPT2 models - under pipeline setting. - """ - - @staticmethod - def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. - # Please refer to original code of transformers for more details. - logger = logging.get_logger(__name__) - - # Preprocess passed in arguments - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False - - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - - if stage_manager.is_first_stage(): - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - else: - if hidden_states is None: - raise ValueError("hidden_states shouldn't be None for stages other than the first stage.") - input_shape = hidden_states.size()[:-1] - batch_size, seq_length = input_shape[0], input_shape[1] - device = hidden_states.device - - # GPT2Attention mask. - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # head_mask has shape n_layer x batch x n_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if stage_manager.is_first_stage(): - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - else: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - hidden_states = self.drop(hidden_states) - - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - - # Going through held blocks. - start_idx, end_idx = stage_index[0], stage_index[1] - for i, layer_past in zip(range(start_idx, end_idx), past_key_values): - block = self.h[i] - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) - # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - if stage_manager.is_last_stage(): - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - return {"hidden_states": hidden_states, "past_key_values": presents} - - @staticmethod - def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # If is first stage and after warmup, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} - - # Not first stage or before warmup, go through gpt2 model - outputs = GPT2PipelineForwards.gpt2_model_forward( - self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) - - return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index f46e1fbdd7b3..9c72b02ccef8 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,36 +1,100 @@ -from typing import List, Optional +# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +from typing import List, Optional, Tuple import torch -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) from transformers.utils import logging +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager +from ._utils import copy_kv_to_mem_cache -class LlamaPipelineForwards: +try: + from vllm import layernorm_ops, pos_encoding_ops + + rms_norm = layernorm_ops.rms_norm + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print( + "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + ) + HAS_VLLM_KERNERL = False + +try: + from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd + + HAS_LIGHTLLM_KERNEL = True +except: + print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") + HAS_LIGHTLLM_KERNEL = False + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaInferenceForwards: """ - This class serves as a micro library for forward function substitution of Llama models - under pipeline setting. + This class holds forwards for llama inference. + We intend to replace the forward methods for LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM. """ - def llama_model_forward( - self: LlamaModel, + @staticmethod + def llama_causal_lm_forward( + self: LlamaForCausalLM, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + """ logger = logging.get_logger(__name__) - # Preprocess passed in arguments + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if output_attentions: logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False @@ -38,11 +102,57 @@ def llama_model_forward( logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False + # If is first stage and after warmup, go throught lm_head first + if stage_manager.is_first_stage() and hidden_states is not None: + lm_logits = self.lm_head(hidden_states) + return {"logits": lm_logits} + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = LlamaInferenceForwards.llama_model_forward( + self.model, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) + + return outputs + + @staticmethod + def llama_model_forward( + self: LlamaModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: BatchInferState = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ): + # batch_size = input_ids.shape[0] # input_ids.shape[0] + # print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}") + + # infer_state = self.infer_state use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if stage_manager.is_first_stage(): + if stage_manager is None or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: @@ -56,6 +166,8 @@ def llama_model_forward( inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds else: + assert stage_manager is not None + assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}" input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape device = hidden_states.device @@ -63,167 +175,292 @@ def llama_model_forward( seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] + if infer_state.is_context_stage is False: + past_key_values_length = infer_state.cache_manager.past_key_values_length seq_length_with_past = seq_length_with_past + past_key_values_length + # NOTE: differentiate with prefill stage + # block_loc require different value-assigning method for two different stage + if use_cache and seq_length != 1: + # NOTE assume prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + infer_state.init_block_loc( + infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index + ) + else: + infer_state.is_context_stage = False + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index if position_ids is None: position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + position_ids = position_ids.unsqueeze(0) + new_shape = [1] * position_ids.dim() + new_shape[0] = batch_size + position_ids = position_ids.repeat(*new_shape).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - # embed positions, for the first stage, hidden_states is the input embeddings, - # for the other stages, hidden_states is the output of the previous stage + if infer_state.is_context_stage: + infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( + position_ids.view(-1).shape[0], -1 + ) + else: + seq_len = infer_state.seq_len + infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + + # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device ) + attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length ) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None + () if output_hidden_states else None + () if output_attentions else None next_decoder_cache = () if use_cache else None + infer_state.decode_layer_id = 0 + start_idx, end_idx = stage_index[0], stage_index[1] if past_key_values is None: past_key_values = tuple([None] * (end_idx - start_idx + 1)) for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values): decoder_layer = self.layers[idx] - if output_hidden_states: - all_hidden_states += (hidden_states,) - - # past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - + # NOTE: modify here for passing args to decoder layer + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) + infer_state.decode_layer_id += 1 hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_attentions: - all_self_attns += (layer_outputs[1],) if stage_manager.is_last_stage(): hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - # always return dict for imediate stage + # update indices + # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + + # TODO: fix this to necessary return + # if not return_dict: + # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + # return BaseModelOutputWithPast( + # last_hidden_state=hidden_states, + # past_key_values=next_cache, + # hidden_states=all_hidden_states, + # attentions=all_self_attns, + # ) + # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}") return {"hidden_states": hidden_states, "past_key_values": next_cache} - def llama_for_causal_lm_forward( - self: LlamaForCausalLM, - input_ids: torch.LongTensor = None, + @staticmethod + def llama_decoder_layer_forward( + self: LlamaDecoderLayer, + hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None, - ): - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states - Returns: + hidden_states = self.input_layernorm(hidden_states) - Example: + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + infer_state=infer_state, + ) - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + hidden_states = residual + hidden_states - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + outputs = (hidden_states,) - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - logger = logging.get_logger(__name__) + if output_attentions: + outputs += (self_attn_weights,) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if use_cache: + outputs += (present_key_value,) - if output_attentions: - logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") - output_attentions = False - if output_hidden_states: - logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") - output_hidden_states = False + return outputs - # If is first stage and after warmup, go throught lm_head first - if stage_manager.is_first_stage() and hidden_states is not None: - lm_logits = self.lm_head(hidden_states) - return {"logits": lm_logits} + @staticmethod + def llama_flash_attn_kvcache_forward( + self: LlamaAttention, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + infer_state: Optional[BatchInferState] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + assert use_cache is True, "use_cache should be set to True using this llama attention" + + bsz, q_len, _ = hidden_states.size() + + # NOTE might think about better way to handle transposed k and v + # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head] + # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + + # NOTE might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if infer_state.decode_layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length += q_len # seq_len + + cos, sin = infer_state.position_cos, infer_state.position_sin + + llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + + if infer_state.is_context_stage: + # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") + # first token generation + + # copy key and value calculated in current step to memory manager + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.context_mem_index, + infer_state.cache_manager, + ) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = LlamaPipelineForwards.llama_model_forward( - self.model, - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index, - ) + attn_output = torch.empty_like(query_states) - return outputs + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + else: + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][ + infer_state.decode_mem_start : infer_state.decode_mem_end, :, : + ] + cache_k.copy_(key_states) + cache_v.copy_(value_states) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head + copy_kv_to_mem_cache( + infer_state.decode_layer_id, + key_states, + value_states, + infer_state.decode_mem_index, + infer_state.cache_manager, + ) + + # second token and follows + # kv = torch.stack((key_states, value_states), dim=2) + # (batch_size, seqlen, nheads, headdim) + attn_output = torch.empty_like(query_states) + + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + infer_state.cache_manager.past_key_values_length, + ) + + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + # print(f"rank:{torch.distributed.get_rank()}, {attn_output}") + attn_output = self.o_proj(attn_output) + + # return past_key_value as None + return attn_output, None, None + + +def get_llama_vllm_rmsnorm_forward(): + if HAS_VLLM_KERNERL: + + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None diff --git a/colossalai/inference/pipeline/policies/__init__.py b/colossalai/inference/pipeline/policies/__init__.py new file mode 100644 index 000000000000..7271812c5366 --- /dev/null +++ b/colossalai/inference/pipeline/policies/__init__.py @@ -0,0 +1,3 @@ +from .llama import LlamaModelInferPolicy + +__all__ = ["LlamaModelInferPolicy"] diff --git a/colossalai/inference/pipeline/policies/llama.py b/colossalai/inference/pipeline/policies/llama.py new file mode 100644 index 000000000000..9f8c93c61234 --- /dev/null +++ b/colossalai/inference/pipeline/policies/llama.py @@ -0,0 +1,145 @@ +from functools import partial +from typing import List + +import torch +from torch.nn import Module +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) + +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai +from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy + +from ..modeling._utils import init_to_get_rotary +from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward + +try: + from colossalai.kernel.triton import rmsnorm_forward + + HAS_TRITON_RMSNORM = True +except: + print("you should install triton from https://github.com/openai/triton") + HAS_TRITON_RMSNORM = False + + +def get_triton_rmsnorm_forward(): + if HAS_TRITON_RMSNORM: + + def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon) + + return _triton_rmsnorm_forward + else: + return None + + +class LlamaModelInferPolicy(LlamaForCausalLMPolicy): + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + policy = super().module_policy() + + if self.shard_config.inference_gptq: + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + } + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + self.shard_config._infer() + + infer_forward = LlamaInferenceForwards.llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + if self.pipeline_stage_manager: + # set None as default + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy + ) + infer_forward = None + if HAS_TRITON_RMSNORM: + infer_forward = get_triton_rmsnorm_forward() + else: + # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 + infer_forward = get_llama_vllm_rmsnorm_forward() + + if infer_forward is not None: + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaRMSNorm + ) + + return policy + + def postprocess(self): + init_to_get_rotary(self.model.model) + return self.model + + def get_held_layers(self) -> List[Module]: + """Get pipeline layers for current stage.""" + stage_manager = self.pipeline_stage_manager + held_layers = super().get_held_layers() + if stage_manager.is_first_stage(): + held_layers.append(self.model.lm_head) + return held_layers diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py deleted file mode 100644 index 51e6425b113e..000000000000 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ /dev/null @@ -1,74 +0,0 @@ -from functools import partial -from typing import Callable, Dict, List - -from torch import Tensor, nn - -import colossalai.shardformer.layer as col_nn -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -from colossalai.shardformer.policies.gpt2 import GPT2Policy - -from ..modeling.gpt2 import GPT2PipelineForwards - - -class GPT2LMHeadModelPipelinePolicy(GPT2Policy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel - - module_policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - addon_module = { - GPT2LMHeadModel: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} - ) - ] - ) - } - module_policy.update(addon_module) - - if self.pipeline_stage_manager is not None: - self.set_pipeline_forward( - model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy, - ) - return module_policy - - def get_held_layers(self) -> List[nn.Module]: - held_layers = super().get_held_layers() - # make the tie weight lm_head and embedding in the same device to save memory - # if self.pipeline_stage_manager.is_first_stage(): - if self.pipeline_stage_manager.is_first_stage(): - held_layers.append(self.model.lm_head) - return held_layers - - def get_shared_params(self) -> List[Dict[int, Tensor]]: - """The weights of wte and lm_head are shared.""" - module = self.model - stage_manager = self.pipeline_stage_manager - if stage_manager is not None: - if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight): - first_stage, last_stage = 0, stage_manager.num_stages - 1 - return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}] - return [] - - def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: - """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" - if not self.pipeline_stage_manager: - raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") - stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == "GPT2Model": - module = self.model - else: - module = self.model.transformer - - layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py deleted file mode 100644 index 6e12ed61bf7b..000000000000 --- a/colossalai/inference/pipeline/policy/llama_ppinfer.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import List - -from torch.nn import Module - -from colossalai.shardformer.layer import Linear1D_Col -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription -from colossalai.shardformer.policies.llama import LlamaPolicy - -from ..modeling.llama import LlamaPipelineForwards - - -class LlamaForCausalLMPipelinePolicy(LlamaPolicy): - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers import LlamaForCausalLM - - policy = super().module_policy() - - if self.shard_config.enable_tensor_parallelism: - # add a new item for casual lm - new_item = { - LlamaForCausalLM: ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) - ) - ] - ) - } - policy.update(new_item) - - if self.pipeline_stage_manager: - # set None as default - self.set_pipeline_forward( - model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy - ) - - return policy - - def get_held_layers(self) -> List[Module]: - """Get pipeline layers for current stage.""" - stage_manager = self.pipeline_stage_manager - held_layers = super().get_held_layers() - if stage_manager.is_first_stage(): - held_layers.append(self.model.lm_head) - return held_layers diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py deleted file mode 100644 index c26aa4e40b71..000000000000 --- a/colossalai/inference/pipeline/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Set - -import torch.nn as nn - -from colossalai.shardformer._utils import getattr_, setattr_ - - -def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None: - """ - Set all parameters and buffers of model to None - - Args: - model (nn.Module): The model to set - """ - for module_suffix in include: - set_module = getattr_(model, module_suffix) - for n, p in set_module.named_parameters(): - setattr_(set_module, n, None) - for n, buf in set_module.named_buffers(): - setattr_(set_module, n, None) - setattr_(model, module_suffix, None) - - -def get_suffix_name(suffix: str, name: str): - """ - Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit, - and 'name' when `suffix` is empty. - - Args: - suffix (str): The suffix of the suffix module - name (str): The name of the current module - """ - point = "" if suffix is "" else "." - suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" - return suffix_name diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py index de150311cc08..f707a86df37e 100644 --- a/colossalai/inference/tensor_parallel/batch_infer_state.py +++ b/colossalai/inference/tensor_parallel/batch_infer_state.py @@ -2,9 +2,11 @@ from dataclasses import dataclass import torch +from transformers.tokenization_utils_base import BatchEncoding from .kvcache_manager import MemoryManager + # adapted from: lightllm/server/router/model_infer/infer_batch.py @dataclass class BatchInferState: @@ -55,3 +57,62 @@ def init_block_loc( ] start_index += cur_seq_len return + + @classmethod + def init_from_batch( + cls, + batch: torch.Tensor, + max_input_len: int, + max_output_len: int, + cache_manager: MemoryManager, + ): + if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)): + raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state") + + input_ids_list = None + attention_mask = None + + if isinstance(batch, (BatchEncoding, dict)): + input_ids_list = batch["input_ids"] + attention_mask = batch["attention_mask"] + else: + input_ids_list = batch + if isinstance(input_ids_list[0], int): # for a single input + input_ids_list = [input_ids_list] + attention_mask = [attention_mask] if attention_mask is not None else attention_mask + + batch_size = len(input_ids_list) + + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + + max_len_in_batch = -1 + if isinstance(batch, (BatchEncoding, dict)): + for i, attn_mask in enumerate(attention_mask): + curr_seq_len = len(attn_mask) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + else: + length = max(len(input_id) for input_id in input_ids_list) + for i, input_ids in enumerate(input_ids_list): + curr_seq_len = length + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda") + + return cls( + batch_size=batch_size, + max_len_in_batch=max_len_in_batch, + seq_len=seq_lengths.to("cuda"), + start_loc=seq_start_indexes.to("cuda"), + block_loc=block_loc, + decode_layer_id=0, + past_key_values_len=0, + is_context_stage=True, + cache_manager=cache_manager, + ) diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 1f4bbe9f8dad..db02dab59ca6 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -93,9 +93,9 @@ def _prepare_inputs_for_interval_stage(self): Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` """ - model_inputs = ( - {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None - ) + model_inputs = { + 'infer_state': self.mb_manager.cur_descrption.infer_state + } return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): @@ -108,9 +108,8 @@ def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` """ new_mask = self.mb_manager.cur_descrption.attn_mask - past_key_values = self.mb_manager.cur_descrption.kv_cache - return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values) + return dict(input_ids=new_token, attention_mask=new_mask) def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor: last_hidden_state = hidden_state[:, -1] @@ -128,27 +127,38 @@ def _recv_pre_stage(self) -> Any: return self.comm.p2p_recv() return self.comm.recv_forward() + def _init_infer_state_action(self) -> None: + """ + This action is only for no first stage, to load batch and init infer_state. + 1.Load micro_batch 2.Use the current micro_batch to init the current infer_state + """ + inputs_dict = self.load_micro_batch() + self.mb_manager.add_descrption(inputs_dict) + def _load_stage_action(self, model: Module) -> None: """ - In this action, 1.load micro_batch 2.do the forward 3.step to update + This action is only for first stage, load, init and do forward. + 1.load micro_batch 2.do the forward 3.step to update """ inputs_dict = self.load_micro_batch() + self.mb_manager.add_descrption(inputs_dict) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - output_dict = model_forward(model, inputs_dict, None) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) - self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _gen_token_action(self, model: Module): """ - In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update + This action is only for first stage + 1.do the forward with hidden_states to generate new tokens 2.step to update """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" - hidden_states = {"hidden_states": hidden_states} - logits = model_forward(model, None, hidden_states) + interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} + logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) @@ -157,7 +167,7 @@ def _gen_token_action(self, model: Module): ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" new_token = self._get_token_id(logits["logits"]) - self.mb_manager.step(None, None, new_token) + self.mb_manager.step(new_token) self.action_interval_buffer.new_token = new_token self.action_interval_buffer.hidden_states = None @@ -168,20 +178,18 @@ def _head_encoding_action(self, model: Module): new_token = self.action_interval_buffer.new_token assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None" inputs_dict = self._prepare_inputs_for_new_token(new_token) - output_dict = model_forward(model, inputs_dict, None) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) - self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" - inputs_dict = self._prepare_inputs_for_interval_stage() - hidden_states = {"hidden_states": hidden_states} - output_dict = model_forward(model, inputs_dict, hidden_states) + interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, None, interval_inputs) - self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict["hidden_states"] + self.action_interval_buffer.hidden_states = output_dict['hidden_states'] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -218,6 +226,8 @@ def _gen_action(self, model: Module): actions.append(partial(self._gen_token_action, model)) # other stage else: + if self.mb_manager.cur_state is Status.PREFILL: + actions.append(partial(self._init_infer_state_action)) actions.append(partial(self._comm_action, True)) actions.append(partial(self._body_encoding_action, model)) @@ -308,8 +318,9 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - output_dict = model_forward(model, inputs_dict, None) - self.mb_manager.step(inputs_dict, output_dict, None) + self.mb_manager.add_descrption(inputs_dict) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) # In GENERATE phase else: # Get hidden_states from previous stage @@ -319,25 +330,28 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t assert ( hidden_states is not None ), "When first stage in GENERATE phase, the hidden states should not be None" - logits = model_forward(model, None, hidden_states) + interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} + logits = model_forward(model, None, interval_inputs) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert ( - "logits" in logits - ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits["logits"]) - self.mb_manager.step(None, None, new_token) + assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits['logits']) + self.mb_manager.step(new_token) # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): inputs_dict = self._prepare_inputs_for_new_token(new_token) - output_dict = model_forward(model, inputs_dict, None) - self.mb_manager.step(inputs_dict, output_dict, None) + interval_inputs = {'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) else: assert hidden_states is not None, "When not first stage, the hidden states should not be None" - inputs_dict = self._prepare_inputs_for_interval_stage() - output_dict = model_forward(model, inputs_dict, hidden_states) - self.mb_manager.step(inputs_dict, output_dict, None) + # inputs_dict = self._prepare_inputs_for_interval_stage() + inputs_dict = None + if self.mb_manager.cur_state is Status.PREFILL: + inputs_dict = self.load_micro_batch() + self.mb_manager.add_descrption(inputs_dict) + interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state} + output_dict = model_forward(model, inputs_dict, interval_inputs) # Current microbatch is not DONE, send hidden_state to next stage if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index a285874d218b..2aa6139836a5 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -76,4 +76,5 @@ def _infer(self): """ Set default params for inference. """ - assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" + # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now" + pass diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index ad8e32b48bae..6d02f2b326b4 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -2,12 +2,15 @@ import torch import torch.distributed as dist import transformers +from packaging import version import colossalai -from colossalai.inference.pipeline.engine import PPInferEngine -from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy +from colossalai.inference.pipeline import PPInferEngine +from colossalai.inference.pipeline.policies import LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + def data_gen(): input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64) @@ -24,20 +27,21 @@ def data_gen(): def pipeline_inference_test(pp_size, new_length, micro_batch_size): - model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) + model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) + engine = PPInferEngine( pp_size=pp_size, model=model, - model_policy=GPT2LMHeadModelPipelinePolicy(), + model_policy=LlamaModelInferPolicy(), new_length=new_length, micro_batch_size=micro_batch_size, ) - output = engine.inference([inputs]) + output = engine.inference(inputs) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize("pp_size", [4]) +@parameterize("pp_size", [2]) @parameterize("new_length", [4, 8, 16]) @parameterize("micro_batch_size", [1, 4]) @clear_cache_before_run() @@ -51,11 +55,12 @@ def check_pipeline_inference(rank, world_size, port): run_pipeline_inference_test() +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_pipeline_inference(): - spawn(check_pipeline_inference, nprocs=4) + spawn(check_pipeline_inference, nprocs=2) if __name__ == "__main__": From 3b8137d1837bb247a423abaa508cd400477dc6b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Fri, 27 Oct 2023 18:19:56 +0800 Subject: [PATCH 16/46] updated c++17 compiler flags (#4983) --- examples/community/roberta/preprocessing/Makefile | 2 +- op_builder/cpu_adam.py | 12 +++++++++++- op_builder/gptq.py | 2 +- op_builder/multi_head_attn.py | 1 + op_builder/scaled_masked_softmax.py | 1 + 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/examples/community/roberta/preprocessing/Makefile b/examples/community/roberta/preprocessing/Makefile index 82ee4e1c5b31..81478dd49213 100644 --- a/examples/community/roberta/preprocessing/Makefile +++ b/examples/community/roberta/preprocessing/Makefile @@ -1,4 +1,4 @@ -CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color +CXXFLAGS += -O3 -Wall -shared -std=c++14 -std=c++17 -fPIC -fdiagnostics-color CPPFLAGS += $(shell python3 -m pybind11 --includes) LIBNAME = mask LIBEXT = $(shell python3-config --extension-suffix) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 5a2a2e3e6a56..7988aae4be12 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -21,12 +21,22 @@ def include_dirs(self): return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] def cxx_flags(self): - extra_cxx_flags = ["-std=c++14", "-lcudart", "-lcublas", "-g", "-Wno-reorder", "-fopenmp", "-march=native"] + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-lcudart", + "-lcublas", + "-g", + "-Wno-reorder", + "-fopenmp", + "-march=native", + ] return ["-O3"] + self.version_dependent_macros + extra_cxx_flags def nvcc_flags(self): extra_cuda_flags = [ "-std=c++14", + "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", diff --git a/op_builder/gptq.py b/op_builder/gptq.py index bc4f445de067..a17801f8783c 100644 --- a/op_builder/gptq.py +++ b/op_builder/gptq.py @@ -37,12 +37,12 @@ def nvcc_flags(self): extra_cuda_flags = [ "-v", "-std=c++14", + "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", "-DTHRUST_IGNORE_CUB_VERSION_CHECK", "-lcublas", - "-std=c++17", ] for arch in torch.cuda.get_arch_list(): diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py index b70f041db7d6..cb8fc489ced1 100644 --- a/op_builder/multi_head_attn.py +++ b/op_builder/multi_head_attn.py @@ -35,6 +35,7 @@ def cxx_flags(self): def nvcc_flags(self): extra_cuda_flags = [ "-std=c++14", + "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py index b2f1de7792c8..d9239a80eef6 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/op_builder/scaled_masked_softmax.py @@ -25,6 +25,7 @@ def cxx_flags(self): def nvcc_flags(self): extra_cuda_flags = [ "-std=c++14", + "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", From 9fce43bb2da4198d1b617a5c386b9ea6e98c46d7 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Mon, 30 Oct 2023 10:52:19 +0800 Subject: [PATCH 17/46] [Inference] Dynamic Batching Inference, online and offline (#4953) * [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e673794ed18c91d4bab1a7dfea052e95a. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced14025043e29ce816b315f440601188f7f79f. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li Co-authored-by: CjhHa1 --- colossalai/inference/async_engine.py | 133 +++++++ colossalai/inference/async_manager.py | 151 ++++++++ .../inference/dynamic_batching/__init__.py | 0 .../dynamic_batching/get_tokenizer.py | 40 ++ .../inference/dynamic_batching/infer_batch.py | 346 ++++++++++++++++++ .../inference/dynamic_batching/io_struct.py | 166 +++++++++ .../dynamic_batching/ray_dist_init.py | 152 ++++++++ .../dynamic_batching/ray_init_config.py | 58 +++ .../inference/dynamic_batching/req_queue.py | 73 ++++ .../dynamic_batching/sampling_params.py | 83 +++++ .../inference/dynamic_batching/stats.py | 45 +++ colossalai/inference/manager.py | 296 +++++++++++++++ .../quant/smoothquant/models/base_model.py | 1 - .../quant/smoothquant/models/llama.py | 27 +- .../inference/tensor_parallel/engine.py | 86 ++++- .../tensor_parallel/kvcache_manager.py | 4 +- .../tensor_parallel/modeling/bloom.py | 36 +- .../tensor_parallel/modeling/chatglm2.py | 17 +- .../tensor_parallel/modeling/llama.py | 49 +-- colossalai/kernel/triton/__init__.py | 1 - .../kernel/triton/copy_kv_cache_dest.py | 2 - requirements/requirements-test.txt | 2 + requirements/requirements.txt | 2 + tests/kit/model_zoo/transformers/llama.py | 6 +- tests/test_infer/test_chatglm2_infer.py | 1 - .../test_dynamic_batching/config.yaml | 14 + .../test_async_engine.py | 61 +++ .../test_dynamic_batching_manager.py | 95 +++++ .../test_offline_dynamic_batching.py | 84 +++++ .../test_dynamic_batching/test_ray_dist.py | 66 ++++ 30 files changed, 2005 insertions(+), 92 deletions(-) create mode 100644 colossalai/inference/async_engine.py create mode 100644 colossalai/inference/async_manager.py create mode 100644 colossalai/inference/dynamic_batching/__init__.py create mode 100644 colossalai/inference/dynamic_batching/get_tokenizer.py create mode 100644 colossalai/inference/dynamic_batching/infer_batch.py create mode 100644 colossalai/inference/dynamic_batching/io_struct.py create mode 100644 colossalai/inference/dynamic_batching/ray_dist_init.py create mode 100644 colossalai/inference/dynamic_batching/ray_init_config.py create mode 100644 colossalai/inference/dynamic_batching/req_queue.py create mode 100644 colossalai/inference/dynamic_batching/sampling_params.py create mode 100644 colossalai/inference/dynamic_batching/stats.py create mode 100644 colossalai/inference/manager.py create mode 100644 tests/test_infer/test_dynamic_batching/config.yaml create mode 100644 tests/test_infer/test_dynamic_batching/test_async_engine.py create mode 100644 tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py create mode 100644 tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py create mode 100644 tests/test_infer/test_dynamic_batching/test_ray_dist.py diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py new file mode 100644 index 000000000000..d0890ba3e9fc --- /dev/null +++ b/colossalai/inference/async_engine.py @@ -0,0 +1,133 @@ +import asyncio + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver + +from .dynamic_batching.io_struct import RequestOutput +from .dynamic_batching.sampling_params import SamplingParams + + +class RequestTracker: + """ + A class for trace down all the requests, abstraction for async + """ + + def __init__(self) -> None: + self._requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._requests + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def add_request(self, request_id: str): + """Add a request to be sent to the engine on the next background + loop iteration.""" + self._requests.put_nowait(request_id) + self.new_requests_event.set() # NOTE: we may find a better way to clear this event + + def add_stop(self): + """ + Add a StopIteration flag to stop async generator. + """ + self._finished_requests.put_nowait(StopIteration) + self.new_requests_event.clear() + + def process_request_output(self, request_output: RequestOutput) -> None: + """Process a request output from the engine.""" + self._finished_requests.put_nowait(request_output) + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._finished_requests.get() + # print("result of ", result) + if result is StopIteration: + raise StopAsyncIteration + return result + + +class Async_Engine: + + """ + Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager + Background loop: inference reqs in waiting list (Listen) + Request Tracker: manage incoming requests and restore finished ones + Generate: exposed func for add new input and return finished ones + """ + + def __init__( + self, + router_config, + engine_config, + start_engine_loop: bool = True, + ) -> None: + self.driver = Driver(router_config=router_config, engine_config=engine_config) + self.background_loop = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + def _step(self): + """ + Logic for handling requests + """ + request_outputs = self.driver.step() + if request_outputs is not None: + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output) + self._request_tracker.add_stop() + + def abort_request(self, request_id: str): + self.driver.abort(request_id) + + def _has_requests_in_progress(self): + return self.driver.is_running() + + async def run_loop_fwd(self): + has_requests_in_progress = self._has_requests_in_progress() + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + self._step() + await asyncio.sleep(0) + + @property + def is_running(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.is_running: + raise RuntimeError("Background loop is already running.") + + self._request_tracker.init_event() + + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) + self.background_loop = asyncio.shield(self.background_loop_unshielded) + + async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.driver.add_input(request_id, prompt, sampling_params) + self._request_tracker.add_request(request_id) + + async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + """ + The only exposed func, adding new request and return a async generator that yields the existing results. + """ + try: + if not self.is_running: + self.start_background_loop() + + await self.add_request(request_id, prompt, sampling_params) + + async for request_output in self._request_tracker: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the request. + self.abort_request(request_id) + raise e diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py new file mode 100644 index 000000000000..60440a792f1c --- /dev/null +++ b/colossalai/inference/async_manager.py @@ -0,0 +1,151 @@ +from typing import List + +from .dynamic_batching.io_struct import Batch, Req, RequestOutput +from .manager import DynamicBatchManager +from .tensor_parallel import TPInferEngine + + +class Async_DynamicBatchManager(DynamicBatchManager): + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num: int, + batch_max_tokens: int, + model: str, + tokenizer=None, + eos_id=None, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + super().__init__( + tp_engine, + max_total_token_num, + batch_max_tokens, + model, + tokenizer, + eos_id, + log_stats, + log_stats_interval, + running_batch, + waiting_req_list, + ) + + def _step(self): + """ + Logic for handling requests + """ + has_new_finished = False + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + has_new_finished, outputs = self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + + else: + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + if has_new_finished: + return outputs + return None + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + return self._output_process(finished_reqs) + return None + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + outputs = [] + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) + return outputs + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = Async_DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + raise Exception + + return batch_manager diff --git a/colossalai/inference/dynamic_batching/__init__.py b/colossalai/inference/dynamic_batching/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py new file mode 100644 index 000000000000..94aa3f24393f --- /dev/null +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -0,0 +1,40 @@ +""" +Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. + +license: MIT, see LICENSE for more details. +""" + +from transformers import AutoTokenizer + +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + + +def get_tokenizer( + tokenizer=None, + tokenizer_name: str = "", + trust_remote_code: bool = False, + use_fast: bool = True, +): + if tokenizer is not None: + tokenizer = tokenizer + else: + if "llama" in tokenizer_name.lower() and use_fast == True: + print( + "For some LLaMA-based models, initializing the fast tokenizer may " + "take a long time. To eliminate the initialization time, consider " + f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " + "tokenizer. This is done automatically in Colossalai." + ) + + tokenizer_name = _FAST_LLAMA_TOKENIZER + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + except TypeError: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code + ) + return tokenizer diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py new file mode 100644 index 000000000000..112784c15f84 --- /dev/null +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -0,0 +1,346 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import collections +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from colossalai.inference.tensor_parallel import MemoryManager + + +# make batch infer state an attr of InferBatch +class InferSamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + vocab_size: int = -1, + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + if self.top_k == -1: + self.top_k = vocab_size + return + + +@dataclass +class InferBatch: + batch_id: int + requests: List + requests_idx_mapping: Dict[int, int] + + input_ids: torch.Tensor + + all_input_ids: List[List[int]] + input_lengths: List[int] + + out_token_id_counts: List + sampling_param_list: List[InferSamplingParams] + + nopad_total_token_num: int + nopad_max_len_in_batch: int + nopad_b_loc: torch.Tensor + nopad_b_start_loc: torch.Tensor + nopad_b_seq_len: torch.Tensor + cache_manager: MemoryManager + max_total_len: int + + @classmethod + @torch.no_grad() + def init_batch( + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + cache_manager: MemoryManager, + vocab_size: int, + max_total_len: int, + ) -> "InferBatch": + input_lengths = [] + all_input_ids = [] + requests_idx_mapping = {} + + out_token_id_counts = [] + sampling_param_list = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") + # to avoid memory leak , we pre-allocate 12 more space for each batch. + nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") + for i, r in enumerate(requests): + # request id -> idx in list mapping + requests_idx_mapping[r["request_id"]] = i + + tokenized_input = r["input_id"] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + out_token_id_counts.append(collections.defaultdict(int)) + + # postprocessor + sampling_param = r["sampling_param"] + sampling_param["vocab_size"] = vocab_size + sampling_param_list.append(InferSamplingParams(**sampling_param)) + + nopad_total_token_num += input_length + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) + + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + + if len(requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + else: + input_ids = all_input_ids[0] + + # Create tensors on device + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + return cls( + batch_id=batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=cache_manager, + max_total_len=max_total_len, + ) + + @torch.no_grad() + def free_self(self) -> None: + """ + Free the memory of the InferBatch itself + """ + remove_index = [] + for idx in range(len(self)): + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + @torch.no_grad() + def filter(self, request_ids: List[int]) -> "InferBatch": + """ + Filter finished batch and return a new InferBatch with left ones. + """ + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + requests_idx_mapping = {} + indices = [] + requests = [] + all_input_ids = [] + input_lengths = [] + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + + left_idx = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + left_idx.append(idx) + + left_idx_set = set(left_idx) + remove_index = [] + for idx in range(len(self)): + if idx not in left_idx_set: + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + nopad_max_len_in_batch = 0 + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] + nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ + indices, + (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), + ] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + requests.append(self.requests[idx]) + all_input_ids.append(self.all_input_ids[idx]) + input_lengths.append(self.input_lengths[idx]) + + input_ids = self.input_ids[indices] + + return InferBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], + sampling_param_list=[self.sampling_param_list[_i] for _i in indices], + cache_manager=self.cache_manager, + max_total_len=self.max_total_len, + ) + + @classmethod + @torch.no_grad() + def merge(cls, batch1, batch2) -> "InferBatch": + """ + Return megerd new InferBatch + """ + requests = batch1.requests + batch2.requests + requests_idx_mapping = {} + new_batch_size = len(batch1) + len(batch2) + + input_ids = batch1.input_ids.new_empty(new_batch_size) + all_input_ids = [] + input_lengths = [] + out_token_id_counts = [] + sampling_param_list = [] + + cumulative_batch_size = 0 + nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) + max_total_len = max(batch1.max_total_len, batch2.max_total_len) + nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_start_loc_len_temp = 0 + batches = [batch1, batch2] + for i, batch in enumerate(batches): + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + input_ids[start_index:end_index] = batch.input_ids + nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_b_loc[ + start_index:end_index, + nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, + ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] + + all_input_ids.extend(batch.all_input_ids) + + input_lengths.extend(batch.input_lengths) + out_token_id_counts.extend(batch.out_token_id_counts) + sampling_param_list.extend(batch.sampling_param_list) + # Update + cumulative_batch_size += len(batch) + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( + nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") + ) + return InferBatch( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=batches[0].cache_manager, + max_total_len=max_total_len, + ) + + def __len__(self): + return len(self.requests) + + def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + top_ks: List[int] = [] + p_token_ids: List[int] = [] + p_token_counts: List[int] = [] + p_seq_len: List[int] = [ + 0, + ] + p_max_len_in_batch: int = 0 + for i, id_to_count in enumerate(self.out_token_id_counts): + sample_param = self.sampling_param_list[i] + presence_penalties.append(sample_param.presence_penalty) + frequency_penalties.append(sample_param.frequency_penalty) + temperatures.append(sample_param.temperature) + top_ps.append(sample_param.top_p) + top_ks.append(sample_param.top_k) + + for token_id, count in id_to_count.items(): + p_token_ids.append(token_id) + p_token_counts.append(count) + p_seq_len.append(len(id_to_count)) + p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) + + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") + frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") + temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") + top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") + top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") + p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") + p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") + p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") + p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32) + return ( + presence_penalties, + frequency_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + ) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py new file mode 100644 index 000000000000..fc5ecfe5796b --- /dev/null +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -0,0 +1,166 @@ +# Adapted from https://github.com/ModelTC/lightllm + +from typing import Dict, List, Tuple + +from .sampling_params import SamplingParams + + +class Req: + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): + self.request_id = request_id + self.prompt_ids = prompt_ids + self.input_len = len(prompt_ids) + self.max_output_len = sample_params.max_new_tokens + self.sample_params = sample_params + self.output_ids = [] + self.output_metadata_list = [] + self.has_generate_finished = False + self.aborted = False + self.prompts = prompts + + def to_rpc_obj(self): + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + } + + def stop_sequences_matched(self): + # should we add stpp sequences to the sample params? + if self.sample_params.stop_sequences is not None: + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if ( + stop_len > 0 + and len(self.output_ids) >= stop_len + and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)) + ): + return True + return False + + def __repr__(self): + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + + +class Batch: + def __init__(self, batch_id, reqs: List[Req]): + self.batch_id = batch_id + self.reqs = reqs + self.id_to_reqs = {req.request_id: req for req in reqs} + + def input_tokens(self): + batch_input_tokens = 0 + for req in self.reqs: + batch_input_tokens += req.input_len + return batch_input_tokens + + def calcu_max_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + req.max_output_len + return tokens + + def calcu_used_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + len(req.output_ids) + return tokens + + def mark_finished_req(self, eos_id, engine_max_output_len): + has_new_finish = False + for req in self.reqs: + if req.stop_sequences_matched(): + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= engine_max_output_len: + req.has_generate_finished = True + has_new_finish = True + if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= req.max_output_len or req.aborted: + req.has_generate_finished = True + has_new_finish = True + return has_new_finish + + def filter_finished(self) -> List[Req]: + """ + Filter finished requests from the batch, the finished ones will be removed from 'reqs'. + """ + # TODO: the logic of return should be defined here. + unfinished_req = [] + finished_req = [] + for req in self.reqs: + if not req.has_generate_finished: + unfinished_req.append(req) + else: + finished_req.append(req) + self.reqs = unfinished_req + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return finished_req + + def is_clear(self): + return len(self.reqs) == 0 + + def merge(self, mini_batch): + for _req in mini_batch.reqs: + self.reqs.append(_req) + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return + + def __repr__(self): + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + def __len__(self): + return len(self.reqs) + + +class BatchTokenIdOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, int, Dict, bool, bool] + ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + + +class BatchStrOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, str, Dict, bool, bool] + ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + + +class AbortReq: + def __init__(self, req_id): + self.req_id = req_id + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + outputs: The output sequences of the request. + """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"outputs={self.outputs}, " + ) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py new file mode 100644 index 000000000000..70ef489d3b70 --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -0,0 +1,152 @@ +import logging +import os +from typing import List + +import ray +import ray.util.collective as collective +import torch +from transformers import AutoModelForCausalLM + +import colossalai +from colossalai.inference.async_manager import start_dynamic_batching +from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.io_struct import RequestOutput +from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.tensor_parallel.engine import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import free_port + +ray_serve_logger = logging.getLogger("ray.serve") + + +def log_cuda_info(scope_name: str): + ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}") + ray_serve_logger.info( + f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}" + ) + if torch.cuda.is_available(): + ray_serve_logger.info( + f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}" + ) + else: + ray_serve_logger.info(f" {scope_name}: cuda is not available!") + + +@ray.remote(num_gpus=1) +class Worker: + def __init__( + self, + model_path: str, + tensor_parallel_size: int, + max_batch_size: int, + max_input_len: int, + max_output_len: int, + router_config: RooterArgsClass, + ): + log_cuda_info("Worker.init") + self.tensor_parallel_size = tensor_parallel_size + self.model_path = model_path + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.router_config = router_config + + def setup(self, world_size, rank, port): + # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully + collective.init_collective_group(world_size, rank, "nccl", "default") + # initialize and set distributed environment + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..") + log_cuda_info("Worker.setup") + + # Load model + self.tokenizer = get_tokenizer(tokenizer_name=self.model_path) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.model = AutoModelForCausalLM.from_pretrained( + self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16 + ) + shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True) + self.infer_engine = TPInferEngine( + self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len + ) + self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, []) + + return True + + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: + # ray_serve_logger.info(f"text: {prompt}") + + # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + + # return final_outputs + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) + + def abort(self, request_id: str): + self.start_dynamic_batching.abort(request_id) + + def step(self) -> List[RequestOutput]: + return self.start_dynamic_batching._step() + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): + self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) + + def is_running(self): + return self.start_dynamic_batching.is_running() + + +class Driver: + def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass): + log_cuda_info("Driver:init") + model_path = engine_config.model + tensor_parallel_size = engine_config.tensor_parallel_size + + self.num_workers = tensor_parallel_size + self.workers = [] + init_rets = [] + + # Just grab a free port on localhost + # NOTE workers in this communication group listen to the same port + available_port = free_port() + + for i in range(self.num_workers): + worker_name = "worker_idx_{}".format(i) + w = Worker.options(name=worker_name).remote( + model_path, + self.num_workers, + engine_config.max_batch_size, + engine_config.max_input_len, + engine_config.max_output_len, + router_config, + ) + self.workers.append(w) + init_rets.append(w.setup.remote(self.num_workers, i, available_port)) + _options = { + "group_name": "default_driver", + "world_size": self.num_workers, + "ranks": [i for i in range(self.num_workers)], + "backend": "nccl", + } + collective.create_collective_group(self.workers, **_options) + _ = ray.get(init_rets) + + def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): + ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) + + def abort(self, request_id: str): + ray.get([w.abort.remote(request_id) for w in self.workers]) + + def step(self): + results = ray.get([w.step.remote() for w in self.workers]) + outputs = results[0] # get any one of the copies + return outputs + + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str): + ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) + + def is_running(self): + results = ray.get([w.is_running.remote() for w in self.workers]) + return any(results) diff --git a/colossalai/inference/dynamic_batching/ray_init_config.py b/colossalai/inference/dynamic_batching/ray_init_config.py new file mode 100644 index 000000000000..471f07330aec --- /dev/null +++ b/colossalai/inference/dynamic_batching/ray_init_config.py @@ -0,0 +1,58 @@ +import logging + +import yaml +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class EngineArgsClass(BaseModel): + """Config for Engine""" + + model: str + tensor_parallel_size: int = 2 + max_batch_size: int = 4 + max_input_len: int = 128 + max_output_len: int = 32 + + +class RooterArgsClass(BaseModel): + """Config for Rooter""" + + max_total_token_num: int = 42 + batch_max_tokens: int = 42 + eos_id: int = 0 + disable_log_stats: bool = False + log_stats_interval: int = 10 + model: str + + +class RayInitConfig(BaseModel): + """All-together configs without app router config""" + + engine_config_data: EngineArgsClass + router_config_data: RooterArgsClass + + @classmethod + def from_yaml_path(cls, path: str): + try: + with open(path, "r") as yaml_file: + try: + config = yaml.safe_load(yaml_file) + # serve deployment config + engine_config = config.get("engine_config", {}) + router_config = config.get("router_config", {}) + + return cls( + engine_config_data=engine_config, + router_config_data=router_config, + ) + except yaml.YAMLError as e: + logger.error(f"An Error occurred when parsing yaml: {e}") + raise + except FileNotFoundError: + logger.error(f"The file '{path}' does not exist!") + raise + except OSError as e: + logger.error(f"An Error occurred: {e}") + raise diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py new file mode 100644 index 000000000000..0de43bd1a21f --- /dev/null +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -0,0 +1,73 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import uuid +from typing import List + +import numpy as np + +from .io_struct import Batch, Req + + +class ReqQueue: + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: + self.max_total_tokens = max_total_tokens + assert batch_max_tokens is not None + self.batch_max_tokens = batch_max_tokens + self.running_max_req_size = running_max_req_size + self.waiting_req_list: List[Req] = waiting_req_list + + def append(self, req): + self.waiting_req_list.append(req) + return + + def _init_cache_list(self, current_batch: Batch): + if current_batch is not None: + self.cache_len_list = [ + (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) + for req in current_batch.reqs + ] + else: + self.cache_len_list = [] + + # @calculate_time(show=True, min_cost_ms=0.1) + def _can_add_new_req(self, req): + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.sort(key=lambda x: -x[1]) + + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + # assert left_out_len_array.min() >= 0 + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + # NOTE: change here < to <= + return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size + + def generate_new_batch(self, current_batch: Batch = None): + if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: + return None + self._init_cache_list(current_batch) + can_run_list = [] + new_batch_total_tokens = 0 + aborted_count = 0 + for req in self.waiting_req_list: + flag = self._can_add_new_req(req) + if req.aborted: + aborted_count += 1 + continue + if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + can_run_list.append(req) + new_batch_total_tokens += req.input_len + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] + return new_batch + else: + return None + + def __len__(self): + return self.waiting_req_list.__len__() diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py new file mode 100644 index 000000000000..a37a83390021 --- /dev/null +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -0,0 +1,83 @@ +# Adapted from https://github.com/ModelTC/lightllm + +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-5 + + +class SamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, # -1 is for all + ignore_eos: bool = False, + max_new_tokens: int = 256, + stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.ignore_eos = ignore_eos + self.max_new_tokens = max_new_tokens + self.stop_sequences = stop_sequences + if self.do_sample == False: + self.temperature = 1.0 + self.top_p = 1.0 + self.top_k = 1 + if ( + self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS + ): # temperature is too slow, change to greedy search + self.temperature = 1.0 + self.top_k = 1 + return + + def verify(self): + if self.presence_penalty < 0.0: + raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") + if self.frequency_penalty < 0.0: + raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") + if self.temperature <= 0.0: + raise ValueError(f"temperature must > 0.0, got {self.temperature}") + if self.top_p <= 0.0 or self.top_p > 1.0: + raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") + if self.max_new_tokens < 1: + raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") + return + + def stop_sentences_to_token_ids(self, tokenizer): + if self.stop_sequences is None: + self.stop_sequences = [] + else: + if isinstance(self.stop_sequences, str): + self.stop_sequences = [self.stop_sequences] + new_stop_sequences = [] + for stop_str in self.stop_sequences: + stop_str_ids = tokenizer.encode(stop_str) + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + stop_str_ids = stop_str_ids[1:] + if len(stop_str_ids) > 0: + new_stop_sequences.append(stop_str_ids) + self.stop_sequences = new_stop_sequences + return + + def to_dict(self): + ret = {} + ret["do_sample"] = self.do_sample + ret["presence_penalty"] = self.presence_penalty + ret["frequency_penalty"] = self.frequency_penalty + ret["temperature"] = self.temperature + ret["top_p"] = self.top_p + ret["top_k"] = self.top_k + # if self.ignore_eos is not None: + # ret["ignore_eos"] = self.ignore_eos + return ret diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py new file mode 100644 index 000000000000..524072861a3f --- /dev/null +++ b/colossalai/inference/dynamic_batching/stats.py @@ -0,0 +1,45 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import time + + +class Stats: + def __init__(self, log_status, log_stats_interval) -> None: + self.log_stats = log_status + self.log_stats_interval = log_stats_interval + self.last_log_time = time.time() + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + return + + def count_prompt_tokens(self, run_batch): + if self.log_stats: + tokens = run_batch.input_tokens() + self.prompt_tokens += tokens + self.all_tokens += tokens + return + + def count_output_tokens(self, run_batch): + if self.log_stats: + tokens = len(run_batch.reqs) + self.output_tokens += tokens + self.all_tokens += tokens + return + + def print_stats(self): + if not self.log_stats: + return + + now = time.time() + if now - self.last_log_time > self.log_stats_interval: + print( + f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" + ) + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + self.last_log_time = now + return diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py new file mode 100644 index 000000000000..9672a50141a0 --- /dev/null +++ b/colossalai/inference/manager.py @@ -0,0 +1,296 @@ +# Adapted from https://github.com/ModelTC/lightllm + +import time +from typing import List + +from .dynamic_batching.get_tokenizer import get_tokenizer +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stats import Stats +from .tensor_parallel import TPInferEngine + + +class DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + model, + tokenizer=None, + eos_id=None, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + self.engine = tp_engine + self.max_total_token_num = max_total_token_num + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) + # all the inputs should be put into req_queue: waiting req list + assert max_total_token_num >= self.engine.max_batch_size * ( + self.engine.max_input_len + self.engine.max_output_len + ), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)" + assert ( + batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len + ), "batch_max_tokens should be greater than (max_input_len+max_output_len)" + self.running_batch: Batch = running_batch + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + self.model = model + + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 + self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer + if self.eos_id == None: + self.eos_id = self.tokenizer.eos_token_id + + def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""): + """ + Add new request to req queue, during initialization all requests are held in waiting list. + """ + sampling_params.max_new_tokens = ( + self.engine.max_output_len + if sampling_params.max_new_tokens > self.engine.max_output_len + else sampling_params.max_new_tokens + ) + req = Req(request_id, prompt_ids, sampling_params, prompts) + self.req_queue.append(req) + return + + def add_input(self, request_id, prompts, sampling_params): + """ + Encode and Add new input to req queue. support one sequence input for now. + """ + prompt_ids = self.tokenizer.encode(prompts) + prompt_len = len(prompt_ids) + if prompt_len > self.engine.max_input_len: + raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}") + sampling_params.stop_sentences_to_token_ids(self.tokenizer) + self.add_req(request_id, prompt_ids, sampling_params, prompts) + return + + def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. + """ + counter_count = 0 + # self.running_batch is not None or self.req_queue.waiting_req_list + while self.running_batch is not None or self.req_queue.waiting_req_list: + yield from self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % self.mem_usage_interval == 0: + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) + self.stats_tool.print_stats() + + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _step(self): + """ + Logic for handling requests + """ + + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + yield from self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + return + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + yield from self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + return + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + yield from self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + yield from self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + return + + def _init_batch(self, batch: Batch, dtype="fp16"): + reqs = [r.to_rpc_obj() for r in batch.reqs] + batch_id = batch.batch_id + + import torch + + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.cache_manager, + self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, + ) + self.engine.cache[batch_id] = batch_data + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + yield from self._handle_finish_req(batch, has_new_finished_req) + + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len) + yield from self._handle_finish_req(batch, has_new_finished_req) + + def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id + req_id_list = [r.request_id for r in batch.reqs] + batch = self.engine.cache.pop(batch_id) + filter_batch = batch.filter(req_id_list) + del batch + self.engine.cache[batch_id] = filter_batch + + def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. + """ + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) + + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 + del batch2 + + def _remove_batch(self, batch): + """ + Remove finished batch. + """ + batch = self.engine.cache.pop(batch.batch_id) + batch.free_self() + del batch + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + yield from self._output_process(finished_reqs) + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + yield req.prompts + output + + def clean_up(self): + # this logic should be implemented in the future. + pass + + def generate(self, request_id, prompts, sampling_params): + """ + Generate the output of a request. + """ + self.add_input(request_id, prompts, sampling_params) + return self.loop_for_fwd() + + def is_running(self): + return self.running_batch is not None or self.req_queue.waiting_req_list + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + raise Exception + + return batch_manager diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py index 6a1d96ecec8f..9554be9ea96b 100644 --- a/colossalai/inference/quant/smoothquant/models/base_model.py +++ b/colossalai/inference/quant/smoothquant/models/base_model.py @@ -87,7 +87,6 @@ def init_batch_state(self, max_output_len=256, **kwargs): batch_infer_state.start_loc = seq_start_indexes.to("cuda") batch_infer_state.block_loc = block_loc batch_infer_state.decode_layer_id = 0 - batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) batch_infer_state.cache_manager.free_all() diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py index 4c3d6dcc0b23..30063857ac30 100644 --- a/colossalai/inference/quant/smoothquant/models/llama.py +++ b/colossalai/inference/quant/smoothquant/models/llama.py @@ -149,12 +149,6 @@ def forward( self.k_rotary_output_scale.item(), ) - # NOTE might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len - def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager): copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id]) @@ -229,7 +223,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) @@ -592,17 +586,13 @@ def llama_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - infer_state = self.infer_state + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + seq_length_with_past = seq_length + past_key_values_length # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -623,9 +613,7 @@ def llama_model_forward( infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem @@ -713,6 +701,7 @@ def llama_model_forward( infer_state.is_context_stage = False infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 next_cache = next_decoder_cache if use_cache else None if not return_dict: diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index 216b134f5fab..e410532d83eb 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -13,6 +13,8 @@ from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager +# from dynamic_batching.infer_batch import InferBatch + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 _supported_models = [ @@ -61,7 +63,6 @@ def __init__( self.max_input_len = max_input_len self.max_output_len = max_output_len self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) - # Constraints relatable with specs of devices and model # This may change into an optional arg in the future assert self.max_batch_size <= 64, "Max batch size exceeds the constraint" @@ -96,6 +97,8 @@ def __init__( self.shard_config = shard_config self.model = None + self.cache = {} + # optimize the original model by sharding with ShardFormer self._optimize_model(model=model.to(device)) @@ -284,7 +287,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState: attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -318,6 +320,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state @torch.no_grad() @@ -381,6 +384,85 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) infer_state.seq_len += 1 + @torch.no_grad() + def forward(self, batch_id, is_prefill): + """ + Forward is used in Dynamic Batching Manager + """ + batch = self.cache.pop(batch_id) + if is_prefill: + input_ = torch.tensor(batch.all_input_ids).cuda() + else: + input_ = batch.input_ids.reshape(len(batch), 1) + + batch_args = { + "batch_size": len(batch), + "max_len_in_batch": batch.nopad_max_len_in_batch, + "block_loc": batch.nopad_b_loc, + "start_loc": batch.nopad_b_start_loc, + "seq_len": batch.nopad_b_seq_len, + "cache_manager": batch.cache_manager, + "is_context_stage": is_prefill, + } + + infer_state = BatchInferState(**batch_args) + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + + setattr(model, "infer_state", infer_state) + output = self.model.forward(input_ids=input_) + logits = output.logits + # bsz, seq_len, vocab_size + prob_out = torch.softmax( + logits[ + :, + -1, + ], + dim=-1, + ).squeeze(1) + # prob_out: bsz, vocab_size + predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) + prob_out = torch.log(prob_out).detach().cpu().numpy() + predict_ids = predict_ids.detach().cpu().numpy() + # [ batch_size, 1 ] + + output_dict = {} + new_input_ids = [] + for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( + zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) + ): + next_token_id = int(next_token_id) + next_token_logprob = next_token_logprob[next_token_id] + # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") + all_input_ids.append(next_token_id) + # all_input_ids_tensor = None + new_input_ids.append(next_token_id) + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] += 1 + batch.out_token_id_counts[i][next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": float(next_token_logprob), + } + output_dict[r["request_id"]] = (int(next_token_id), metadata) + + batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() + batch.nopad_total_token_num += len(batch) + batch.nopad_max_len_in_batch += 1 # NOTE: we may repalce this + self.cache[batch.batch_id] = batch + return output_dict + + @torch.no_grad() + def _prefill_batch(self, batch_id): + return self.forward(batch_id, is_prefill=True) + + @torch.no_grad() + def _decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + # might want to create a sequence pool # add a single request/sequence/input text at a time and record its length # In other words, store the actual length of input tokens representing a single input text diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/inference/tensor_parallel/kvcache_manager.py index c9e7aaae0844..91bb96a1f1f0 100644 --- a/colossalai/inference/tensor_parallel/kvcache_manager.py +++ b/colossalai/inference/tensor_parallel/kvcache_manager.py @@ -32,7 +32,7 @@ def __init__( ): self.logger = logging.get_logger(__name__) self.available_size = size - self.past_key_values_length = 0 + self.max_len_in_batch = 0 self._init_mem_states(size, device) self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num) @@ -102,5 +102,5 @@ def free_all(self): """free all memory by updating memory states""" self.available_size = len(self.mem_state) self.mem_state[:] = 1 - self.past_key_values_length = 0 + self.max_len_in_batch = 0 self.logger.info("freed all space of memory manager") diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 27a26caabefa..d84c567ead0c 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -133,17 +133,11 @@ def bloom_model_forward( assert hasattr(self, "infer_state") infer_state = self.infer_state - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - # if self.cache_manager.past_key_values_length > 0: - if infer_state.cache_manager.past_key_values_length > 0: - # update the past key values length in cache manager, - # NOTE use BatchInferState.past_key_values_length instead the one in cache manager - past_key_values_length = infer_state.cache_manager.past_key_values_length - seq_length_with_past = seq_length_with_past + past_key_values_length - # infer_state.cache_manager = self.cache_manager + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 if use_cache and seq_length != 1: # prefill stage @@ -160,21 +154,19 @@ def bloom_model_forward( infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) @@ -195,6 +187,7 @@ def bloom_model_forward( past_key_values_length=past_key_values_length, ) + infer_state.decode_layer_id = 0 for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -228,6 +221,7 @@ def custom_forward(*inputs): infer_state=infer_state, ) + infer_state.decode_layer_id += 1 hidden_states = outputs[0] if use_cache is True: presents = presents + (outputs[1],) @@ -247,7 +241,7 @@ def custom_forward(*inputs): # and update these information in engine.generate after model foward called infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 - infer_state.decode_layer_id = 0 + infer_state.max_len_in_batch += 1 if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) @@ -453,9 +447,6 @@ def bloom_attention_forward( mem_manager = infer_state.cache_manager layer_id = infer_state.decode_layer_id - if layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_length # += 1 - if infer_state.is_context_stage: # context process max_input_len = q_length @@ -506,15 +497,12 @@ def bloom_attention_forward( b_loc, b_start_loc, b_seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, alibi, ) context_layer = output.view(batch_size, q_length, H * D_HEAD) - # update layer id - infer_state.decode_layer_id += 1 - # NOTE: always set present as none for now, instead of returning past key value to the next decoding, # we create the past key value pair from the cache manager present = None diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/inference/tensor_parallel/modeling/chatglm2.py index b8274d3c660f..69a92c4fe746 100644 --- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py +++ b/colossalai/inference/tensor_parallel/modeling/chatglm2.py @@ -19,8 +19,11 @@ from ._utils import copy_kv_to_mem_cache try: - from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + HAS_LIGHTLLM_KERNEL = True except: print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") @@ -118,13 +121,12 @@ def chatglm_for_conditional_generation_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_key_values_length = 0 + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length seq_length_with_past = seq_length + past_key_values_length - infer_state.seq_length_with_past = seq_length_with_past # prefill stage at first if use_cache and seq_length != 1: @@ -272,7 +274,6 @@ def chatglm_model_forward( infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 infer_state.max_len_in_batch += 1 - infer_state.cache_manager.past_key_values_length += seq_length if not return_dict: return tuple( @@ -487,7 +488,7 @@ def chatglm_flash_attn_kvcache_forward( attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), infer_state.start_loc, infer_state.seq_len, - infer_state.seq_length_with_past, + infer_state.max_len_in_batch, ) else: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a3937f6f10ba..a17b901dc7fd 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -74,12 +74,11 @@ def llama_model_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - batch_size = input_ids.shape[0] # input_ids.shape[0] - infer_state = self.infer_state return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -90,15 +89,10 @@ def llama_model_forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -118,23 +112,23 @@ def llama_model_forward( infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) + position_ids = position_ids.repeat(batch_size, 1) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -146,11 +140,12 @@ def llama_model_forward( infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) - infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item() + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -158,7 +153,7 @@ def llama_model_forward( # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -173,7 +168,6 @@ def llama_model_forward( next_decoder_cache = () if use_cache else None infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None # NOTE: modify here for passing args to decoder layer @@ -197,8 +191,9 @@ def llama_model_forward( # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -224,7 +219,6 @@ def llama_decoder_layer_forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -280,11 +274,8 @@ def llama_flash_attn_kvcache_forward( # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin - # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) @@ -295,7 +286,6 @@ def llama_flash_attn_kvcache_forward( if infer_state.is_context_stage: # first token generation - # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -304,7 +294,6 @@ def llama_flash_attn_kvcache_forward( infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) if self.num_key_value_groups == 1: @@ -315,7 +304,7 @@ def llama_flash_attn_kvcache_forward( attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: lightllm_llama2_context_attention_fwd( @@ -325,7 +314,7 @@ def llama_flash_attn_kvcache_forward( attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: if infer_state.decode_is_contiguous: @@ -363,7 +352,7 @@ def llama_flash_attn_kvcache_forward( infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: Llama2TokenAttentionForwards.token_attn( @@ -374,7 +363,7 @@ def llama_flash_attn_kvcache_forward( infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, infer_state.other_kv_index, ) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 1fe292289f3d..20da71d394bd 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -2,7 +2,6 @@ import triton HAS_TRITON = True - except ImportError: HAS_TRITON = False print("Triton is not installed. Please install Triton to use Triton kernels.") diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 0ce6b09e54dc..b8e6ab1d05ad 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -10,7 +10,6 @@ print("please install triton from https://github.com/openai/triton") if HAS_TRITON: - # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py @triton.jit def _fwd_copy_kv_cache_dest( @@ -53,7 +52,6 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( k_ptr, dest_index_ptr, diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 467f83610eb0..f54b13c7e43c 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -18,4 +18,6 @@ SentencePiece ninja flash_attn==2.0.5 datasets +pydantic +ray #auto-gptq now not support torch1.12 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 19cb7a154a01..095617d76355 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -11,6 +11,8 @@ ninja torch>=1.12 safetensors einops +pydantic +ray sentencepiece google protobuf diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 041de6b90f8d..4730642705ff 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,8 +27,10 @@ def data_gen(): # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- - input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + input_ids = torch.Tensor( + [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + ).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index 399b70e1460e..f9f7670c4500 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -52,7 +52,6 @@ def run_chatglm2_test(test_config): "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"), } outputs = infer_engine.generate(input_tokens, **generate_kwargs) - assert outputs is not None diff --git a/tests/test_infer/test_dynamic_batching/config.yaml b/tests/test_infer/test_dynamic_batching/config.yaml new file mode 100644 index 000000000000..0ac778a3c7b3 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/config.yaml @@ -0,0 +1,14 @@ +engine_config: + model: MODEL_PATH + tensor_parallel_size: 1 + max_batch_size: 2 + max_input_len: 1024 + max_output_len: 512 +# config for app router deployment +# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig. +router_config: + max_total_token_num: 4096 + batch_max_tokens: 4096 + disable_log_stats: False + log_stats_interval: 10 + model: MODEL_PATH diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py new file mode 100644 index 000000000000..512aa7430983 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -0,0 +1,61 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.async_engine import Async_Engine +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_async_engine(path: str): + if not os.path.exists(path): + return + + config = RayInitConfig.from_yaml_path(path) + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + + prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10" + sampling_params = SamplingParams() + asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) + + +async def get_result(engine, prompt, sampling_params): + request_id = str(uuid.uuid4().hex) + results = engine.generate(request_id, prompt, sampling_params) + async for result in results: + # print(result) + assert result is not None + + +async def asy_for_loop_test(config, prompt, sampling_params): + router_config = config.router_config_data + engine_config = config.engine_config_data + engine = Async_Engine(router_config=router_config, engine_config=engine_config) + for i in range(10): + print("in for loop", i) + await get_result(engine, prompt, sampling_params) + + +def check_async_engine(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_async_engine(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_async_engine(): + spawn(check_async_engine, 1) + + +if __name__ == "__main__": + test_async_engine() diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py new file mode 100644 index 000000000000..78df0d304096 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -0,0 +1,95 @@ +import pytest +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import DynamicBatchManager +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +BATCH_SIZE = 2 +MAX_INPUT_LEN = 48 +MAX_OUTPUT_LEN = 256 + + +def run(): + sampling_params = SamplingParams() + + req1 = Req(0, [1], sampling_params) + req2 = Req(1, [2], sampling_params) + req3 = Req(2, [3], sampling_params) + # req 1-3 are initiliazed as token forward requests + req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + + # init model and tp engine + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + dynamic_batch_manager = DynamicBatchManager( + tp_engine=infer_engine, + max_total_token_num=640, + batch_max_tokens=608, + eos_id=0, + log_stats=False, + log_stats_interval=10, + waiting_req_list=waiting_list, + model="llama", + ) + before_add = len(dynamic_batch_manager.req_queue) + + # test add req function + dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params) + assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 + + # test abort function + dynamic_batch_manager.abort(req4.request_id) + assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True + + # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested + batch = dynamic_batch_manager.req_queue.generate_new_batch() + assert len(batch) == 2 + + dynamic_batch_manager._init_batch(batch) + assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None + + batch.reqs[0].has_generate_finished = True + # filter one finished + batch.filter_finished() + dynamic_batch_manager._filter_batch(batch) + assert len(dynamic_batch_manager.engine.cache) == 1 + + # test merge batch + new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch) + assert len(new_batch) == 1 + dynamic_batch_manager._init_batch(new_batch) + dynamic_batch_manager._merge_batch(batch, new_batch) + + assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2 + + +def check_dynamic_batching_manager(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching_manager(): + spawn(check_dynamic_batching_manager, 1) + + +if __name__ == "__main__": + test_dynamic_batching_manager() diff --git a/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py new file mode 100644 index 000000000000..9925a80b6e77 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_offline_dynamic_batching.py @@ -0,0 +1,84 @@ +from dataclasses import dataclass + +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +MAX_BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + + +@dataclass +class args: + max_total_token_num: int + batch_max_tokens: int + model: str + eos_id: int + disable_log_stats: bool + log_stats_interval: int + + +def run(): + arg = args( + max_total_token_num=42, + model="llama", + batch_max_tokens=42, + eos_id=0, + disable_log_stats=False, + log_stats_interval=10, + ) + sampling_params = SamplingParams() + + req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) + req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) + req3 = Req(2, [0, 0, 10, 10, 10], sampling_params) + req4 = Req(3, [0, 0, 10, 10, 10], sampling_params) + + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + waiting_list.append(req4) + + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) + + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + + ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params) + for result in ans_gen: + assert result is not None + + +def check_dynamic_forward(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching(): + spawn(check_dynamic_forward, TP_SIZE) + + +if __name__ == "__main__": + test_dynamic_batching() diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py new file mode 100644 index 000000000000..a840407d5867 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -0,0 +1,66 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.dynamic_batching.ray_dist_init import Driver +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_ray_dist(path: str): + if not os.path.exists(path): + return + config = RayInitConfig.from_yaml_path(path) + router_config = config.router_config_data + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + driver = Driver(router_config=router_config, engine_config=engine_config) + prompt = "Introduce some landmarks in Beijing" + + request_id = str(uuid.uuid4().hex) + sampling_params = SamplingParams() + print("sampling_params: ", sampling_params) + + async def get_result(request_id, prompt, sampling_params): + return await driver.async_generate(request_id, prompt, sampling_params) + + for test_async in [True, False]: + if test_async: + print("test_async: ", test_async) + result = asyncio.run(get_result(request_id, prompt, sampling_params)) + assert result is not None + print("result: ", result) + else: + print("test_async: ", test_async) + result = driver.generate(request_id, prompt, sampling_params) + assert result is not None + print("result: ", result) + + is_running = None + is_running = driver.is_running() + assert is_running is not None + print("is_running: ", is_running) + + +def check_ray_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_ray_dist(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_ray_dist(): + spawn(check_ray_dist, 1) + + +if __name__ == "__main__": + test_ray_dist() From 62eb99f8f9d23f5a14cbdcf2cb0ec653f025cef5 Mon Sep 17 00:00:00 2001 From: Cuiqing Li Date: Mon, 30 Oct 2023 14:04:37 +0800 Subject: [PATCH 18/46] [Kernels]Updated Triton kernels into 2.1.0 and adding flash-decoding for llama token attention (#4965) * adding flash-decoding * clean * adding kernel * adding flash-decoding * add integration * add * adding kernel * adding kernel * adding triton 2.1.0 features for inference * update bloom triton kernel * remove useless vllm kernels * clean codes * fix * adding files * fix readme * update llama flash-decoding --------- Co-authored-by: cuiqing.li --- colossalai/inference/README.md | 15 +- .../inference/tensor_parallel/engine.py | 1 + .../tensor_parallel/modeling/bloom.py | 11 +- .../tensor_parallel/modeling/llama.py | 182 ++++++++++-------- .../tensor_parallel/policies/llama.py | 5 +- colossalai/kernel/triton/context_attention.py | 137 ++++++------- examples/inference/bench_llama.py | 4 +- tests/test_infer/test_bloom_infer.py | 8 +- tests/test_infer/test_chatglm2_infer.py | 8 +- tests/test_infer/test_llama2_infer.py | 8 +- tests/test_infer/test_llama_infer.py | 8 +- .../test_infer_ops/cuda/test_vllm_rmsnorm.py | 60 ------ .../cuda/test_vllm_rotary_embedding.py | 153 --------------- 13 files changed, 226 insertions(+), 374 deletions(-) delete mode 100644 tests/test_infer_ops/cuda/test_vllm_rmsnorm.py delete mode 100644 tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index d0c281e057b3..4aca7aeb0582 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with - [x] policy - [x] context forward - [x] token forward + - [x] support flash-decoding - [ ] Replace the kernels with `faster-transformer` in token-forward stage - [ ] Support all models - [x] Llama + - [x] Llama-2 - [x] Bloom - - [ ] Chatglm2 + - [x] Chatglm2 - [ ] Benchmarking for all models ## Get started @@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . + +# also, install xformers from source: +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers + ``` ### Docker @@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . - +# install xformers from source +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers ``` ### Dive into fast-inference! diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index e410532d83eb..1c203140cc3d 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -311,6 +311,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to("cuda") diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index d84c567ead0c..0ad3994b0194 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -19,6 +19,12 @@ from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd +try: + from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + def generate_alibi(n_head, dtype=torch.float16): """ @@ -460,7 +466,10 @@ def bloom_attention_forward( # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) + if HAS_LIGHTLLM_KERNEL: + lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) + else: + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) else: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a17b901dc7fd..8573bb965ea6 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,4 +1,6 @@ from typing import List, Optional, Tuple +import math +import copy import torch from transformers.modeling_outputs import BaseModelOutputWithPast @@ -10,24 +12,11 @@ from ._utils import copy_kv_to_mem_cache -try: - from vllm import layernorm_ops, pos_encoding_ops - - rms_norm = layernorm_ops.rms_norm - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print( - "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" - ) - HAS_VLLM_KERNERL = False - try: from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama2_context_attention_fwd, ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -35,6 +24,13 @@ print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") HAS_LIGHTLLM_KERNEL = False +try: + from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True +except: + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -54,6 +50,71 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed +def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + class LlamaInferenceForwards: """ @@ -204,7 +265,8 @@ def llama_model_forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - + + @staticmethod def llama_decoder_layer_forward( self: LlamaDecoderLayer, @@ -247,6 +309,7 @@ def llama_decoder_layer_forward( outputs += (present_key_value,) return outputs + @staticmethod def llama_flash_attn_kvcache_forward( @@ -295,27 +358,8 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) attn_output = torch.empty_like(query_states) - - if self.num_key_value_groups == 1: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - else: - lightllm_llama2_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) + + llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -337,35 +381,26 @@ def llama_flash_attn_kvcache_forward( infer_state.decode_mem_index, infer_state.cache_manager, ) - - # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) - - if self.num_key_value_groups == 1: - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) + + HAS_LIGHTLLM_KERNEL = False + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) else: - Llama2TokenAttentionForwards.token_attn( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - infer_state.other_kv_index, - ) + heads_per_group = self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache(q = query_states, + k_cache = copy_cache_k, + v_cache = copy_cache_v, + softmax_scale = 1/ math.sqrt(self.head_dim), + causal = True) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -374,22 +409,3 @@ def llama_flash_attn_kvcache_forward( # return past_key_value as None return attn_output, None, None - -def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: - - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 7e163efe0173..d6c072c747b7 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -9,7 +9,7 @@ from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward +from ..modeling.llama import LlamaInferenceForwards try: from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward @@ -105,9 +105,6 @@ def module_policy(self): infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() - else: - # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 - infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: method_replacement = {"forward": partial(infer_forward)} diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 1b4f6e44b0f2..5ce6f2c21385 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -5,7 +5,6 @@ try: import triton import triton.language as tl - HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -155,39 +154,43 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al num_warps = 4 if Lk <= 64 else 8 tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - alibi, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + + if triton.__version__ < "2.1.0": + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + alibi, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + return @torch.no_grad() @@ -207,36 +210,40 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + if triton.__version__ < "2.1.0": + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + return \ No newline at end of file diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index f3e742dfbb59..56bf062e2e68 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -105,8 +105,8 @@ def test_llama(args): parser = argparse.ArgumentParser() parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") + parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") parser.add_argument( "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index ba978ad9bf0d..d4366758d6a3 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -10,6 +10,12 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + TP_SIZE = 2 MAX_BATCH_SIZE = 4 MAX_INPUT_LEN = 16 @@ -52,7 +58,7 @@ def check_bloom(rank, world_size, port): run() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index f9f7670c4500..09bb8a94994d 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 1 BATCH_SIZE = 8 @@ -61,7 +67,7 @@ def check_chatglm2(rank, world_size, port): run_chatglm2_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py index 0eebed8892ea..13e7a61826ab 100644 --- a/tests/test_infer/test_llama2_infer.py +++ b/tests/test_infer/test_llama2_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -57,7 +63,7 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index b424525a3719..a4f54d197065 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -55,7 +61,7 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py deleted file mode 100644 index a4d893f8e830..000000000000 --- a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -import pytest -import torch -from torch import nn - -try: - from vllm import layernorm_ops - - rms_norm = layernorm_ops.rms_norm - HAS_VLLM_KERNERL = True -except: - print("please install vllm kernels to install rmsnorm") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - HAS_VLLM_KERNERL = False - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - weight, - variance_epsilon, - ) - return out - - -@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") -def test_rmsnorm(): - data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") - hg_rms = LlamaRMSNorm(64) - hg_rms = hg_rms.half().cuda() - out_torch = hg_rms(data) - out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) - - check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) - assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" - - -if __name__ == "__main__": - test_rmsnorm() diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py deleted file mode 100644 index 40451ef6636d..000000000000 --- a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -from typing import Tuple - -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half - -try: - from vllm import pos_encoding_ops - - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - HAS_VLLM_KERNERL = False - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class RefRotaryEmbeddingNeox(nn.Module): - """Reference implementation of the GPT-NeoX style rotary embedding.""" - - def __init__( - self, - dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - ) -> None: - super().__init__() - self.rotary_dim = dim - self.max_position_embeddings = max_position_embeddings - - # Create cos and sin embeddings. - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) - t = torch.arange(max_position_embeddings).float() - freqs = torch.einsum("i,j->ij", t, inv_freq.float()) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=inv_freq.dtype) - sin = emb.sin().to(dtype=inv_freq.dtype) - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) - - def forward( - self, - positions: torch.Tensor, # [num_tokens] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key: torch.Tensor, # [num_tokens, num_heads, head_size] - ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - - query_rot = query_rot.transpose(0, 1) - key_rot = key_rot.transpose(0, 1) - cos = F.embedding(positions, self.cos_cached) - sin = F.embedding(positions, self.sin_cached) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - query_rot = query_rot.transpose(0, 1).contiguous() - key_rot = key_rot.transpose(0, 1).contiguous() - - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - - # Output query/key shape: [num_tokens, num_tokens, head_size] - return query, key - - -def run_rotary_embedding_neox( - num_tokens: int, - num_heads: int, - head_size: int, - max_position: int, - rotary_dim: int, - dtype: torch.dtype, - base: int = 10000, -) -> None: - positions = torch.randint(0, max_position, (num_tokens,), device="cuda") - query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - - # Create the rotary embedding. - inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) - t = torch.arange(max_position).float() - freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) - cos = freqs.cos() - sin = freqs.sin() - cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") - - # Run the kernel. The kernel is in-place, so we need to clone the inputs. - out_query = query.clone() - out_key = key.clone() - rotary_embedding_neox( - positions, - out_query, - out_key, - head_size, - cos_sin_cache, - ) - - # Run the reference implementation. - ref_rotary_embedding = RefRotaryEmbeddingNeox( - dim=rotary_dim, - max_position_embeddings=max_position, - base=base, - ).to(dtype=dtype, device="cuda") - ref_query, ref_key = ref_rotary_embedding( - positions, - query.view(num_tokens, num_heads, head_size), - key.view(num_tokens, num_heads, head_size), - ) - ref_query = ref_query.view(num_tokens, num_heads * head_size) - ref_key = ref_key.view(num_tokens, num_heads * head_size) - - # Compare the results. - assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) - assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) - - -@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") -def test_rotary_embedding(): - run_rotary_embedding_neox( - num_tokens=1024, - num_heads=8, - head_size=64, - max_position=8192, - rotary_dim=64, - dtype=torch.float16, - ) - - -if __name__ == "__main__": - test_rotary_embedding() From fa1cbd3ffd1feed16c8b353519db51d3516502ee Mon Sep 17 00:00:00 2001 From: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:30:03 +0800 Subject: [PATCH 19/46] fix ColossalEval (#4992) Co-authored-by: Xu Yuanchen --- .../dataset_evaluator/dataset_evaluator.py | 5 +++++ .../evaluate/dataset_evaluator/metrics.py | 14 ++++++++++++++ .../colossal_eval/models/huggingface.py | 2 +- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py index c70988707a15..22de56b93c81 100644 --- a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py @@ -60,6 +60,11 @@ def _calculate_label_metrics(self, metric: str, category: str): sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"] ), ) + + score = max( + score, + metric_helper.accuracy_by_options(sample["input"], sample["output"], ref), + ) softmaxs.append(references[i] if score == 1 else -1) else: softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values())))) diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py index 914465478dec..45a12756de69 100644 --- a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py +++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py @@ -443,6 +443,20 @@ def multi_choice_accuracy(prediction, reference, **kwargs): return score +def accuracy_by_options(question, prediction, reference): + pattern = r"[A-Z]\. [^\n]+" + options = re.findall(pattern, question) + answer = prediction.split("\n\n")[0] + + for option in options: + choice, content = option.split(". ", 1) + + if choice == reference and content == answer: + return 1 + + return 0 + + def combined_single_choice_accuracy(prediction, reference, **kwargs): return single_choice_accuracy(prediction, reference, **kwargs) diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 9f785a6aa9d1..47259c1db758 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -96,7 +96,7 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kw self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.") if self.tokenizer.eos_token: self.tokenizer.pad_token = self.tokenizer.eos_token - elif self.tokenizer.eod_id: + elif hasattr(self.tokenizer, "eod_id"): # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id From 3209431029efffae33c6c54b021f0c3f8c9534a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cuiqing=20Li=20=28=E6=9D=8E=E5=B4=94=E5=8D=BF=29?= Date: Tue, 31 Oct 2023 10:48:07 +0800 Subject: [PATCH 20/46] [doc]Update doc for colossal-inference (#4989) * update doc * Update README.md --------- Co-authored-by: cuiqing.li --- colossalai/inference/README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index 4aca7aeb0582..cf5dbf245205 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -59,16 +59,14 @@ dependencies pytorch= 1.13.1 (gpu) cuda>= 11.6 transformers= 4.30.2 -triton==2.0.0.dev20221202 -# for install vllm, please use this branch to install https://github.com/tiandiao123/vllm/tree/setup_branch -vllm -# for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c +triton +# for install flash-attention flash-attention # install lightllm since we depend on lightllm triton kernels git clone https://github.com/ModelTC/lightllm -git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 pip3 install -e . # also, install xformers from source: @@ -93,8 +91,8 @@ pip install -e . # install lightllm git clone https://github.com/ModelTC/lightllm -git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm +git checkout 28c1267cfca536b7b4f28e921e03de735b003039 pip3 install -e . # install xformers from source From f0482f4934c095ab5d50aca575c2ae90a26f6603 Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Tue, 31 Oct 2023 14:47:30 +0800 Subject: [PATCH 21/46] [hotfix] Fix the bug where process groups were not being properly released. (#4940) * Fix the bug where process groups were not being properly released. * test * Revert "test" This reverts commit 479900c1398637310abf92eefa3cd168038ea02f. --- colossalai/cluster/process_group_mesh.py | 19 +++++++ .../tensor/d_tensor/layout_converter.py | 52 ++++++++++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 3885bc962561..eb4532194a26 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -1,3 +1,4 @@ +import gc import itertools from functools import reduce from operator import mul @@ -44,6 +45,24 @@ def __init__(self, *size: int) -> None: self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} + def __del__(self): + r""" + Destructor method for the ProcessGroupMesh class. + + When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for + cleaning up any process groups that were created during the lifetime of the object. + + Note: + All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed + when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release + system resources. + """ + for group in self._ranks_to_group.values(): + dist.destroy_process_group(group) + + # Manually clear all process groups to save memory + gc.collect() + @property def shape(self) -> Tuple[int, ...]: return self._shape diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index e031e0472b0b..abe4a86d8198 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple import torch +import torch.distributed as dist from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * @@ -438,11 +439,58 @@ def layout_converting( MAX_TRANSFORM_STEPS = 20 total_steps = 0 transform_path = [] - comm_action_sequence = [] + comm_action_sequence: List[CommSpec] = [] spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) if spec_pairs in self.cached_solution: - return self.cached_solution[spec_pairs] + # Solution Cache hit + + def _group_alive_check(cached_comm_action_sequence): + r""" + Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method. + If not deleted, return True; otherwise, return False. + + Args: + cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions. + + Returns: + bool: True if all process groups are still registered, False if at least one has been deleted. + + Raises: + RuntimeError: If there is an error while checking the status of a process group. + """ + + # Collect all process groups used in communication actions from the cached sequence + used_process_groups = [ + pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values() + ] + + # Check if each process group is still alive + for process_group in used_process_groups: + try: + dist.get_rank(process_group) + except RuntimeError as e: + # If the group is not registered, it means it has been deleted + if str(e) == ( + f"Group {process_group} is not registered, please create group with torch.distributed.new_group API" + ): + return False + elif str(e) == "The given group does not exist": + return False + else: + # Re-raise the exception if it's not related to group deletion + raise e + # All process groups are alive + return True + + cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs] + + if _group_alive_check(cached_comm_action_sequence): + # If all process groups have not been deleted, the cache is valid + return cached_transform_path, cached_comm_action_sequence + else: + # If at least one process group has been deleted, the cache is invalid, so delete it + del self.cached_solution[spec_pairs] # We do nothing if the sharding spec is all the same. if source_spec.spec_diff(target_spec) == 0: From cd8ad65f5a2ce53f88d321b82dfbb5b198beb009 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 31 Oct 2023 14:48:01 +0800 Subject: [PATCH 22/46] [hotfix] fix the bug of repeatedly storing param group (#4951) --- colossalai/booster/plugin/gemini_plugin.py | 12 ++++++------ colossalai/booster/plugin/low_level_zero_plugin.py | 7 ++++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 20a931b816ea..d1a9bc2623a3 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -150,24 +150,24 @@ def save_sharded_optimizer( # Preparing file paths and index file. states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) + index_file.append_meta_data("param_groups", param_group_file) # Store the information of param groups to param_group_file. - index_file.append_meta_data("param_groups", param_group_file) - group_file_path = os.path.join(checkpoint, param_group_file) - param_groups = optimizer.get_param_groups_for_saving() - torch.save(param_groups, group_file_path) + if self.coordinator.is_master(): + group_file_path = os.path.join(checkpoint, param_group_file) + param_groups = optimizer.get_param_groups_for_saving() + torch.save(param_groups, group_file_path) # States are broken into shards within max_shard_size. state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) # Save shards of optimizer states. - is_master = self.coordinator.is_master() total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint, index_file=index_file, base_filename=states_name, - is_master=is_master, + is_master=self.coordinator.is_master(), use_safetensors=False, ) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index dc78fe8c094c..09343138f5ff 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -119,11 +119,12 @@ def save_sharded_optimizer( # Preparing file paths and index file. states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) index_file = CheckpointIndexFile(checkpoint) + index_file.append_meta_data("param_groups", param_group_file) # Store the information of param groups to param_group_file. - index_file.append_meta_data("param_groups", param_group_file) - group_file_path = os.path.join(checkpoint, param_group_file) - save_param_groups(state_dict, group_file_path) + if self.coordinator.is_master(): + group_file_path = os.path.join(checkpoint, param_group_file) + save_param_groups(state_dict, group_file_path) # Save shards of optimizer states. total_size = 0 From 5266946c89e0f20e2a3cc99819ac516b1c46651d Mon Sep 17 00:00:00 2001 From: ppt0011 <143150326+ppt0011@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:56:42 +0800 Subject: [PATCH 23/46] [doc] add supported feature diagram for hybrid parallel plugin (#4996) --- docs/source/en/basics/booster_plugins.md | 6 +++++- docs/source/zh-Hans/basics/booster_plugins.md | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md index fa360a4b9213..55f1b4f53721 100644 --- a/docs/source/en/basics/booster_plugins.md +++ b/docs/source/en/basics/booster_plugins.md @@ -58,7 +58,11 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts: -1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer. More details can be found in chapter [Shardformer Doc](../features/shardformer.md). +1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer. More details can be found in chapter [Shardformer Doc](../features/shardformer.md). The diagram below shows the features supported by shardformer together with hybrid parallel plugin. + +
+ +
2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its arguments configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md). diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md index 70352a7b9af3..c810d4ce40b0 100644 --- a/docs/source/zh-Hans/basics/booster_plugins.md +++ b/docs/source/zh-Hans/basics/booster_plugins.md @@ -55,7 +55,11 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累 这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分: -1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。 +1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。下图展示了Shardformer与Hybrid Parallel插件所支持的功能。 + +
+ +
2. 混合精度训练:插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。 From ab8468c0b115c869075d3f0e967e4945b4b6388d Mon Sep 17 00:00:00 2001 From: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Date: Wed, 1 Nov 2023 12:46:21 +0800 Subject: [PATCH 24/46] [Pipeline Inference] Merge pp with tp (#4993) * refactor pipeline into new CaiInferEngine * updata llama modeling forward * merge tp with pp * update docstring * optimize test workflow and example * fix typo * add assert and todo --- colossalai/inference/__init__.py | 6 +- colossalai/inference/hybridengine/__init__.py | 3 + .../{pipeline => hybridengine}/engine.py | 98 ++++--- .../modeling/__init__.py | 0 .../modeling/_utils.py | 0 .../modeling/llama.py | 241 ++++++++++-------- .../polices}/__init__.py | 0 .../polices}/llama.py | 5 +- colossalai/inference/pipeline/__init__.py | 4 +- .../inference/pipeline/microbatch_manager.py | 17 +- .../tensor_parallel/modeling/llama.py | 56 ++-- tests/test_infer/test_pipeline_infer.py | 43 +++- 12 files changed, 269 insertions(+), 204 deletions(-) create mode 100644 colossalai/inference/hybridengine/__init__.py rename colossalai/inference/{pipeline => hybridengine}/engine.py (60%) rename colossalai/inference/{pipeline => hybridengine}/modeling/__init__.py (100%) rename colossalai/inference/{pipeline => hybridengine}/modeling/_utils.py (100%) rename colossalai/inference/{pipeline => hybridengine}/modeling/llama.py (74%) rename colossalai/inference/{pipeline/policies => hybridengine/polices}/__init__.py (100%) rename colossalai/inference/{pipeline/policies => hybridengine/polices}/llama.py (95%) diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index 761e48e5917a..d5a988cfc6f0 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,4 +1,4 @@ -from .pipeline import PPInferEngine +from .hybridengine import CaiInferEngine +from .hybridengine.polices import LlamaModelInferPolicy - -__all__ = ['PPInferEngine'] +__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"] diff --git a/colossalai/inference/hybridengine/__init__.py b/colossalai/inference/hybridengine/__init__.py new file mode 100644 index 000000000000..6377ef817301 --- /dev/null +++ b/colossalai/inference/hybridengine/__init__.py @@ -0,0 +1,3 @@ +from .engine import CaiInferEngine + +__all__ = ["CaiInferEngine"] diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/hybridengine/engine.py similarity index 60% rename from colossalai/inference/pipeline/engine.py rename to colossalai/inference/hybridengine/engine.py index 480ac5dc79fb..bb0b4c77a2a7 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/hybridengine/engine.py @@ -1,4 +1,5 @@ import torch +import torch.distributed as dist import torch.nn as nn from transformers.tokenization_utils_base import BatchEncoding @@ -8,23 +9,27 @@ from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.shardformer.policies.base_policy import Policy +from ..pipeline.microbatch_manager import MicroBatchManager from ..tensor_parallel.kvcache_manager import MemoryManager -from .microbatch_manager import MicroBatchManager +PP_AXIS, TP_AXIS = 0, 1 -class PPInferEngine: +_supported_models = [ + "LlamaForCausalLM", +] + + +class CaiInferEngine: """ - PPInferEngine is a class that handles the pipeline parallel inference. + CaiInferEngine is a class that handles the pipeline parallel inference. Args: - pp_size (int): the number of pipeline stages. - pp_model (`nn.Module`): the model already in pipeline parallelism style. + tp_size (int): the size of tensor parallelism. + pp_size (int): the size of pipeline parallelism. model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`. model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - new_length (int): the new length of the input sequence. - early_stopping (bool): whether to stop early. max_batch_size (int): the maximum batch size. max_input_len (int): the maximum input length. max_output_len (int): the maximum output length. @@ -32,7 +37,7 @@ class PPInferEngine: Example: ```python - from colossalai.inference import PPInferEngine + from colossalai.inference import InferEngine from colossalai.inference.pipeline.policies import LlamaModelInferPolicy import colossalai from transformers import LlamaForCausalLM, LlamaTokenizer @@ -42,7 +47,7 @@ class PPInferEngine: model = LlamaForCausalLM.from_pretrained("your_path_to_model") tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf") # assume the model is infered with 2 pipeline stages - inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=8) + inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy()) input = ["Introduce a landmark in China ","Introduce a landmark in China "] data = tokenizer(input, return_tensors='pt') @@ -54,12 +59,11 @@ class PPInferEngine: def __init__( self, - pp_size: int, + tp_size: int = 1, + pp_size: int = 1, dtype: str = "fp16", - pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, - new_length: int = 32, micro_batch_size: int = 1, micro_batch_buffer_size: int = None, max_batch_size: int = 4, @@ -71,12 +75,21 @@ def __init__( do_sample: bool = False, num_beams: int = 1, ) -> None: - assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided." + assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported." + assert ( + tp_size * pp_size == dist.get_world_size() + ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})" + assert model and model_policy, "Model with model_policy should be provided." assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" - max_output_len = max(max_output_len, max_input_len + new_length) + assert max_batch_size <= 64, "Max batch size exceeds the constraint" + assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint" + # TODO: support only tensor parallel inference + assert pp_size > 1, "Not support only tensor parallel inference." self.pp_size = pp_size + self.tp_size = tp_size + if dtype == "fp16": self.dtype = torch.float16 model.half() @@ -85,24 +98,29 @@ def __init__( model.to(torch.bfloat16) else: self.dtype = torch.float32 - self.pg_mesh = ProcessGroupMesh(pp_size) - self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.model = pp_model or self._shardformer(model, model_policy) - self.cache_manager_list = [ - self._init_manager(max_batch_size, max_input_len, max_output_len) - for _ in range(micro_batch_buffer_size or pp_size) - ] - self.mb_manager = MicroBatchManager( - self.stage_manager.stage, - new_length, - micro_batch_size, - micro_batch_buffer_size or pp_size, - max_input_len, - max_output_len, - self.cache_manager_list, - ) - self.verbose = verbose - self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) + + # Init pg mesh + pg_mesh = ProcessGroupMesh(pp_size, tp_size) + + stage_manager = None + if pp_size > 1: + stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True) + self.cache_manager_list = [ + self._init_manager(model, max_batch_size, max_input_len, max_output_len) + for _ in range(micro_batch_buffer_size or pp_size) + ] + self.mb_manager = MicroBatchManager( + stage_manager.stage, + micro_batch_size, + micro_batch_buffer_size or pp_size, + max_input_len, + max_output_len, + self.cache_manager_list, + ) + self.verbose = verbose + self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose) + + self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS)) def inference(self, input_list): """ @@ -124,10 +142,10 @@ def inference(self, input_list): else: return out - def _shardformer(self, model, model_policy): + def _shardformer(self, model, model_policy, stage_manager, tp_group): shardconfig = ShardConfig( - tensor_parallel_process_group=None, - pipeline_stage_manager=self.stage_manager, + tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, enable_tensor_parallelism=False, enable_fused_normalization=False, enable_all_optimization=False, @@ -139,14 +157,12 @@ def _shardformer(self, model, model_policy): shard_model, _ = shardformer.optimize(model, model_policy) return shard_model.cuda() - def _init_manager(self, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: + def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> None: max_total_token_num = max_batch_size * (max_input_len + max_output_len) - head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads - head_num = self.model.config.num_attention_heads + head_dim = model.config.hidden_size // model.config.num_attention_heads + head_num = model.config.num_attention_heads num_hidden_layers = ( - self.model.config.num_hidden_layers - if hasattr(self.model.config, "num_hidden_layers") - else self.model.config.num_layers + model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers ) layer_num = num_hidden_layers // self.pp_size diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/colossalai/inference/hybridengine/modeling/__init__.py similarity index 100% rename from colossalai/inference/pipeline/modeling/__init__.py rename to colossalai/inference/hybridengine/modeling/__init__.py diff --git a/colossalai/inference/pipeline/modeling/_utils.py b/colossalai/inference/hybridengine/modeling/_utils.py similarity index 100% rename from colossalai/inference/pipeline/modeling/_utils.py rename to colossalai/inference/hybridengine/modeling/_utils.py diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/hybridengine/modeling/llama.py similarity index 74% rename from colossalai/inference/pipeline/modeling/llama.py rename to colossalai/inference/hybridengine/modeling/llama.py index 9c72b02ccef8..34474d115c8f 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/hybridengine/modeling/llama.py @@ -1,37 +1,25 @@ # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py +import math from typing import List, Optional, Tuple import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - LlamaRMSNorm, -) +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel from transformers.utils import logging from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd +from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards from colossalai.pipeline.stage_manager import PipelineStageManager from ._utils import copy_kv_to_mem_cache try: - from vllm import layernorm_ops, pos_encoding_ops - - rms_norm = layernorm_ops.rms_norm - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print( - "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" + from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_llama2_context_attention_fwd, + ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_context_attention_fwd, ) - HAS_VLLM_KERNERL = False - -try: from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -39,6 +27,14 @@ print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") HAS_LIGHTLLM_KERNEL = False +try: + from flash_attn import flash_attn_with_kvcache + + HAS_FLASH_KERNEL = True +except: + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -59,6 +55,75 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed +def llama_triton_context_attention( + query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 +): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + + +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + + class LlamaInferenceForwards: """ This class holds forwards for llama inference. @@ -144,13 +209,9 @@ def llama_model_forward( hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, ): - # batch_size = input_ids.shape[0] # input_ids.shape[0] - # print(f"[Before] rank:{torch.distributed.get_rank()}\n->{infer_state}") - - # infer_state = self.infer_state - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache # retrieve input_ids and inputs_embeds if stage_manager is None or stage_manager.is_first_stage(): if input_ids is not None and inputs_embeds is not None: @@ -172,12 +233,10 @@ def llama_model_forward( batch_size, seq_length = input_shape device = hidden_states.device - seq_length_with_past = seq_length - past_key_values_length = 0 - - if infer_state.is_context_stage is False: - past_key_values_length = infer_state.cache_manager.past_key_values_length - seq_length_with_past = seq_length_with_past + past_key_values_length + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -197,26 +256,19 @@ def llama_model_forward( infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: - print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + if position_ids is None: position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0) - new_shape = [1] * position_ids.dim() - new_shape[0] = batch_size - position_ids = position_ids.repeat(*new_shape).view(-1, seq_length) + position_ids = position_ids.repeat(batch_size, 1) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -227,15 +279,17 @@ def llama_model_forward( infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) + infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item() # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -243,10 +297,6 @@ def llama_model_forward( ) # decoder layers - () if output_hidden_states else None - () if output_attentions else None - next_decoder_cache = () if use_cache else None - infer_state.decode_layer_id = 0 start_idx, end_idx = stage_index[0], stage_index[1] @@ -268,19 +318,15 @@ def llama_model_forward( infer_state.decode_layer_id += 1 hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if stage_manager.is_last_stage(): + if stage_manager.is_last_stage() or stage_manager.num_stages == 1: hidden_states = self.norm(hidden_states) - next_cache = next_decoder_cache if use_cache else None # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 + infer_state.max_len_in_batch += 1 - # TODO: fix this to necessary return # if not return_dict: # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -290,8 +336,7 @@ def llama_model_forward( # hidden_states=all_hidden_states, # attentions=all_self_attns, # ) - # print(f"[After] rank:{torch.distributed.get_rank()}\n->{infer_state}") - return {"hidden_states": hidden_states, "past_key_values": next_cache} + return {"hidden_states": hidden_states} @staticmethod def llama_decoder_layer_forward( @@ -307,7 +352,6 @@ def llama_decoder_layer_forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -357,28 +401,24 @@ def llama_flash_attn_kvcache_forward( # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head] query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim) # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) - llama_rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) + llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin) query_states = query_states.reshape(-1, self.num_heads, self.head_dim) - key_states = key_states.reshape(-1, self.num_heads, self.head_dim) - value_states = value_states.reshape(-1, self.num_heads, self.head_dim) + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim) if infer_state.is_context_stage: - # print(f"rank:{torch.distributed.get_rank()}, {infer_state}") # first token generation - # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -387,19 +427,16 @@ def llama_flash_attn_kvcache_forward( infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) - llama_context_attn_fwd( + llama_triton_context_attention( query_states, key_states, value_states, attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state, + num_key_value_groups=self.num_key_value_groups, ) - else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -422,45 +459,31 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) - # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) - - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, - ) + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups + ) + else: + self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache( + q=query_states, + k_cache=copy_cache_k, + v_cache=copy_cache_v, + softmax_scale=1 / math.sqrt(self.head_dim), + causal=True, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) - # print(f"rank:{torch.distributed.get_rank()}, {attn_output}") + attn_output = self.o_proj(attn_output) # return past_key_value as None return attn_output, None, None - - -def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: - - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None diff --git a/colossalai/inference/pipeline/policies/__init__.py b/colossalai/inference/hybridengine/polices/__init__.py similarity index 100% rename from colossalai/inference/pipeline/policies/__init__.py rename to colossalai/inference/hybridengine/polices/__init__.py diff --git a/colossalai/inference/pipeline/policies/llama.py b/colossalai/inference/hybridengine/polices/llama.py similarity index 95% rename from colossalai/inference/pipeline/policies/llama.py rename to colossalai/inference/hybridengine/polices/llama.py index 9f8c93c61234..992299714bd1 100644 --- a/colossalai/inference/pipeline/policies/llama.py +++ b/colossalai/inference/hybridengine/polices/llama.py @@ -17,7 +17,7 @@ from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward +from ..modeling.llama import LlamaInferenceForwards try: from colossalai.kernel.triton import rmsnorm_forward @@ -120,9 +120,6 @@ def module_policy(self): infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() - else: - # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 - infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: method_replacement = {"forward": partial(infer_forward)} diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index 41af9f3ef948..f43e4a847448 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ -from .engine import PPInferEngine +from .microbatch_manager import MicroBatchManager -__all__ = ["PPInferEngine"] +__all__ = ["MicroBatchManager"] diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index 2bf52161d611..441cf603985c 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -33,10 +33,9 @@ def __init__( max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - new_length: int, ) -> None: self.mb_length = inputs_dict["input_ids"].shape[-1] - self.target_length = self.mb_length + new_length + self.target_length = self.mb_length + max_output_len self.infer_state = BatchInferState.init_from_batch( batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager ) @@ -77,7 +76,6 @@ class HeadMicroBatchDescription(MicroBatchDescription): Args: inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. - new_length (int): the new length of the input sequence. """ @@ -87,9 +85,8 @@ def __init__( max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - new_length: int, ) -> None: - super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) assert inputs_dict is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None self.input_ids = inputs_dict["input_ids"] @@ -139,9 +136,8 @@ def __init__( max_input_len: int, max_output_len: int, cache_manager: MemoryManager, - new_length: int, ) -> None: - super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager, new_length) + super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager) @property def cur_length(self): @@ -158,7 +154,6 @@ class MicroBatchManager: Args: stage (int): stage id of current stage. - new_length (int): the new length of the input sequence. micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. @@ -167,7 +162,6 @@ class MicroBatchManager: def __init__( self, stage: int, - new_length: int, micro_batch_size: int, micro_batch_buffer_size: int, max_input_len: int, @@ -175,7 +169,6 @@ def __init__( cache_manager_list: MemoryManager, ): self.stage = stage - self.new_length = new_length self.micro_batch_size = micro_batch_size self.buffer_size = micro_batch_buffer_size self.max_input_len = max_input_len @@ -188,11 +181,11 @@ def __init__( def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]): if self.stage == 0: self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription( - inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) else: self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription( - inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx], self.new_length + inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx] ) def step(self, new_token: torch.Tensor = None): diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index 8573bb965ea6..62c2aad3c055 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,10 +1,9 @@ -from typing import List, Optional, Tuple import math -import copy +from typing import List, Optional, Tuple import torch from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd @@ -16,7 +15,9 @@ from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama2_context_attention_fwd, ) - from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd as lightllm_context_attention_fwd, + ) from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -26,6 +27,7 @@ try: from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True except: HAS_FLASH_KERNEL = False @@ -50,7 +52,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed -def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1): + +def llama_triton_context_attention( + query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1 +): if num_key_value_groups == 1: if HAS_LIGHTLLM_KERNEL is False: llama_context_attn_fwd( @@ -87,6 +92,7 @@ def llama_triton_context_attention(query_states, key_states, value_states, attn_ infer_state.max_len_in_batch, ) + def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" if num_key_value_groups == 1: @@ -265,8 +271,7 @@ def llama_model_forward( hidden_states=all_hidden_states, attentions=all_self_attns, ) - - + @staticmethod def llama_decoder_layer_forward( self: LlamaDecoderLayer, @@ -309,7 +314,6 @@ def llama_decoder_layer_forward( outputs += (present_key_value,) return outputs - @staticmethod def llama_flash_attn_kvcache_forward( @@ -358,8 +362,15 @@ def llama_flash_attn_kvcache_forward( infer_state.cache_manager, ) attn_output = torch.empty_like(query_states) - - llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) + + llama_triton_context_attention( + query_states, + key_states, + value_states, + attn_output, + infer_state, + num_key_value_groups=self.num_key_value_groups, + ) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -381,26 +392,28 @@ def llama_flash_attn_kvcache_forward( infer_state.decode_mem_index, infer_state.cache_manager, ) - - HAS_LIGHTLLM_KERNEL = False + if HAS_LIGHTLLM_KERNEL: attn_output = torch.empty_like(query_states) - llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) + llama_triton_token_attention( + query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups + ) else: - heads_per_group = self.num_heads // self.num_key_value_heads + self.num_heads // self.num_key_value_heads cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] - + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) - - attn_output = flash_attn_with_kvcache(q = query_states, - k_cache = copy_cache_k, - v_cache = copy_cache_v, - softmax_scale = 1/ math.sqrt(self.head_dim), - causal = True) + attn_output = flash_attn_with_kvcache( + q=query_states, + k_cache=copy_cache_k, + v_cache=copy_cache_v, + softmax_scale=1 / math.sqrt(self.head_dim), + causal=True, + ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -408,4 +421,3 @@ def llama_flash_attn_kvcache_forward( # return past_key_value as None return attn_output, None, None - diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 6d02f2b326b4..3544153da857 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -5,8 +5,7 @@ from packaging import version import colossalai -from colossalai.inference.pipeline import PPInferEngine -from colossalai.inference.pipeline.policies import LlamaModelInferPolicy +from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") @@ -26,27 +25,43 @@ def data_gen(): inputs[k] = v.to("cuda").repeat(*new_shape) -def pipeline_inference_test(pp_size, new_length, micro_batch_size): - model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=4)) +def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + model = transformers.LlamaForCausalLM( + transformers.LlamaConfig( + vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 + ) + ) - engine = PPInferEngine( + engine = CaiInferEngine( + tp_size=tp_size, pp_size=pp_size, model=model, model_policy=LlamaModelInferPolicy(), - new_length=new_length, + max_output_len=max_output_len, micro_batch_size=micro_batch_size, ) output = engine.inference(inputs) if dist.get_rank() == 0: - assert len(output[0]) == new_length, f"{len(output)}, {new_length}" + assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}" +@parameterize("tp_size", [1]) @parameterize("pp_size", [2]) -@parameterize("new_length", [4, 8, 16]) -@parameterize("micro_batch_size", [1, 4]) +@parameterize("max_output_len", [4]) +@parameterize("micro_batch_size", [1]) @clear_cache_before_run() -def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): - pipeline_inference_test(pp_size, new_length, micro_batch_size) +def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) + torch.cuda.empty_cache() + + +@parameterize("tp_size", [2]) +@parameterize("pp_size", [2]) +@parameterize("max_output_len", [4]) +@parameterize("micro_batch_size", [1]) +@clear_cache_before_run() +def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size): + pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size) torch.cuda.empty_cache() @@ -55,12 +70,18 @@ def check_pipeline_inference(rank, world_size, port): run_pipeline_inference_test() +def check_tp_pipeline_inference(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_tp_pipeline_inference_test() + + @pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() def test_pipeline_inference(): spawn(check_pipeline_inference, nprocs=2) + spawn(check_tp_pipeline_inference, nprocs=4) if __name__ == "__main__": From f9c192042d6b82eb093d5478798dab97c58ab22f Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 1 Nov 2023 13:41:22 +0800 Subject: [PATCH 25/46] [release] update version (#4995) * [release] update version * [hotfix] fix ci --- .github/workflows/release_test_pypi_before_merge.yml | 4 +++- version.txt | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release_test_pypi_before_merge.yml b/.github/workflows/release_test_pypi_before_merge.yml index 49c626265175..284ab4d1afb0 100644 --- a/.github/workflows/release_test_pypi_before_merge.yml +++ b/.github/workflows/release_test_pypi_before_merge.yml @@ -27,7 +27,9 @@ jobs: echo $new_version > ./version.txt echo "version=$new_version" >> $GITHUB_OUTPUT - - run: python setup.py sdist build + - run: | + pip install --upgrade pip + python setup.py sdist build # publish to PyPI if executed on the main branch - name: Publish package to PyPI diff --git a/version.txt b/version.txt index 1c09c74e221c..42045acae20f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.3.3 +0.3.4 From 2043b9d5013c699c6c977cb79bf12dd900c041c4 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 18 Oct 2023 20:14:34 +0800 Subject: [PATCH 26/46] [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp [gemini] gemini support tp --- colossalai/booster/plugin/gemini_plugin.py | 15 +- .../quant/gptq/cai_gptq/cai_quant_linear.py | 202 +++++----- .../tensor_parallel/policies/bloom.py | 58 ++- .../cuda_native/csrc/gptq/linear_gptq.cpp | 362 ++++++++++-------- .../kernel/cuda_native/csrc/gptq/q4_matrix.cu | 2 +- .../cuda_native/csrc/gptq/q4_matrix.cuh | 2 +- colossalai/shardformer/layer/_operation.py | 16 +- colossalai/shardformer/layer/embedding.py | 5 +- colossalai/zero/gemini/gemini_ddp.py | 1 + examples/inference/gptq_bloom.py | 36 +- .../test_plugin/test_gemini_plugin.py | 14 +- tests/test_gptq/test_gptq_linear.py | 30 +- 12 files changed, 415 insertions(+), 328 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index d1a9bc2623a3..8fcb0f1a57d7 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -21,6 +21,7 @@ ) from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -317,7 +318,8 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - verbose: bool = False, + use_tp_pipeline: bool = False, + verbose: bool = False ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" @@ -355,6 +357,7 @@ def __init__( max_norm=max_norm, norm_type=norm_type, ) + self.use_tp_pipeline = use_tp_pipeline self.verbose = verbose def support_no_sync(self) -> bool: @@ -391,6 +394,16 @@ def configure( # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini + + if self.use_tp_pipeline: + try: + shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=False) + shardformer = ShardFormer(shard_config) + model, _ = shardformer.optimize(model) + optimizer.param_groups[0]["params"] = model.parameters() + except NotImplementedError as e: + print(f"Auto policy for {model.__class__} is not implemented yet\n.") + model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): diff --git a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py index 36339ac88486..ca12c34ed958 100644 --- a/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py +++ b/colossalai/inference/quant/gptq/cai_gptq/cai_quant_linear.py @@ -18,15 +18,15 @@ HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder - gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn("CUDA gptq is not installed") + warnings.warn('CUDA gptq is not installed') HAS_GPTQ_CUDA = False class CaiQuantLinear(nn.Module): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): super().__init__() if bits not in [2, 4, 8]: @@ -37,28 +37,23 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp self.maxq = 2**self.bits - 1 self.groupsize = groupsize if groupsize != -1 else infeatures - self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) - self.register_buffer( - "qzeros", - torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32), - ) + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) self.register_buffer( - "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16) - ) + 'qzeros', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', + torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) if row_split: self.register_buffer( - "g_idx", - torch.tensor( - [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32 - ), - ) + 'g_idx', + torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], + dtype=torch.int32)) else: - self.register_buffer( - "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32) - ) + self.register_buffer('g_idx', + torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)) if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) else: self.bias = None @@ -71,11 +66,9 @@ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp self.row_split = row_split def pack(self, linear, scales, zeros, g_idx=None): - g_idx = ( - g_idx.clone() - if g_idx is not None - else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) - ) + + g_idx = g_idx.clone() if g_idx is not None else torch.tensor( + [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32) scales = scales.t().contiguous() zeros = zeros.t().contiguous() @@ -86,6 +79,7 @@ def pack(self, linear, scales, zeros, g_idx=None): if linear.bias is not None: self.bias = linear.bias.clone().half() + wn = 8 pbits = 32 ptype = torch.int32 unsign_type = np.uint32 @@ -94,10 +88,9 @@ def pack(self, linear, scales, zeros, g_idx=None): intweight = [] for idx in range(self.infeatures): intweight.append( - torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[ - :, None - ] - ) + torch.round( + (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, + None]) intweight = torch.cat(intweight, dim=1) intweight = intweight.t().contiguous() intweight = intweight.numpy().astype(unsign_type) @@ -116,7 +109,7 @@ def pack(self, linear, scales, zeros, g_idx=None): raise NotImplementedError("Only 2,4,8 bits are supported.") qweight = qweight.astype(sign_type) qweight1 = torch.from_numpy(qweight) - qweight1 = qweight1.contiguous() # .to("cuda") + qweight1 = qweight1.contiguous() #.to("cuda") self.qweight.data.copy_(qweight1) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type) @@ -147,20 +140,17 @@ def init_q4(self): self.q4_width = self.qweight.shape[1] if self.g_idx is not None: if self.row_split and torch.equal( - self.g_idx, - torch.tensor( - [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], - dtype=torch.int32, - device=self.g_idx.device, - ), - ): + self.g_idx, + torch.tensor( + [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): self.g_idx = None elif torch.equal( - self.g_idx, - torch.tensor( - [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device - ), - ): + self.g_idx, + torch.tensor([i // self.groupsize for i in range(self.infeatures)], + dtype=torch.int32, + device=self.g_idx.device)): self.g_idx = None if self.g_idx is not None: @@ -175,6 +165,7 @@ def forward(self, x): outshape = x.shape[:-1] + (self.outfeatures,) if HAS_GPTQ_CUDA and self.bits == 4: + if self.q4 is None: self.init_q4() @@ -200,6 +191,7 @@ def forward(self, x): def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1): + qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1) qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1) scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1) @@ -211,24 +203,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1 zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num for i in range(split_num): - cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][ - :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features - ] - cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][ - :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block - ] - cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][ - :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features - ] + cai_linear.qweight[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] + cai_linear.qzeros[:, i * zero_split_block:(i + 1) * + zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block] + cai_linear.scales[:, i * cai_split_out_features:(i + 1) * + cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] if cai_linear.bias is not None: - cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][ - tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features - ] + cai_linear.bias[i * cai_split_out_features:(i + 1) * + cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) * + cai_split_out_features] cai_linear.g_idx.copy_(g_idx) def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): + qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0) qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0) scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0) @@ -239,40 +231,47 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1): idx_split_features = cai_linear.infeatures // split_num for i in range(split_num): - cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][ - tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, : - ] - cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][ - tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : - ] - cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][ - tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, : - ] - cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][ - tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features - ] + cai_linear.qweight[i * cai_split_in_features:(i + 1) * + cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) * + cai_split_in_features, :] + cai_linear.qzeros[i * zero_split_block:(i + 1) * + zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.scales[i * zero_split_block:(i + 1) * + zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) * + zero_split_block, :] + cai_linear.g_idx[i * idx_split_features:(i + 1) * + idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) * + idx_split_features] if cai_linear.bias is not None: cai_linear.bias.copy_(gptq_linear.bias) class RowCaiQuantLinear(CaiQuantLinear, ParallelModule): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - super().__init__( - bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split - ) + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) self.process_group = None @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -283,18 +282,15 @@ def from_native_module( if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = RowCaiQuantLinear( - module.bits, - module.group_size, - module.in_features // tp_size, - module.out_features, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - row_split=True, - ) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = RowCaiQuantLinear(module.bits, + module.group_size, + module.in_features // tp_size, + module.out_features, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=True) linear_1d.process_group = process_group split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) @@ -310,23 +306,30 @@ def forward(self, x): class ColCaiQuantLinear(CaiQuantLinear, ParallelModule): + def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False): - super().__init__( - bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split - ) + + super().__init__(bits, + groupsize, + infeatures, + outfeatures, + bias, + tp_size=tp_size, + tp_rank=tp_rank, + row_split=row_split) self.process_group = None @staticmethod - def from_native_module( - module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs - ) -> ParallelModule: + def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, + **kwargs) -> ParallelModule: LazyInitContext.materialize(module) # get the attributes in_features = module.in_features # ensure only one process group is passed if isinstance(process_group, (list, tuple)): - assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}." + assert len(process_group) == 1, \ + f'Expected only one process group, got {len(process_group)}.' process_group = process_group[0] tp_size = dist.get_world_size(process_group) @@ -337,17 +340,14 @@ def from_native_module( if in_features % tp_size != 0: raise ValueError( - f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!" - ) - linear_1d = ColCaiQuantLinear( - module.bits, - module.group_size, - module.in_features, - module.out_features // tp_size, - module.bias is not None, - tp_size=tp_size, - tp_rank=tp_rank, - ) + f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!") + linear_1d = ColCaiQuantLinear(module.bits, + module.group_size, + module.in_features, + module.out_features // tp_size, + module.bias is not None, + tp_size=tp_size, + tp_rank=tp_rank) linear_1d.process_group = process_group split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs) diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py index fba83a08175d..3d6df2097000 100644 --- a/colossalai/inference/tensor_parallel/policies/bloom.py +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -4,6 +4,7 @@ from torch.nn import LayerNorm import colossalai.shardformer.layer as col_nn +from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy @@ -39,36 +40,33 @@ def module_policy(self): policy = super().module_policy() if self.shard_config.inference_gptq: from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear - - policy[BloomBlock] = ModulePolicyDescription( - attribute_replacement={ - "self_attention.hidden_size": self.model.config.hidden_size - // self.shard_config.tensor_parallel_size, - "self_attention.split_size": self.model.config.hidden_size - // self.shard_config.tensor_parallel_size, - "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="self_attention.query_key_value", - target_module=ColCaiQuantLinear, - kwargs={"split_num": 3}, - ), - SubModuleReplacementDescription( - suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} - ), - SubModuleReplacementDescription( - suffix="self_attention.attention_dropout", - target_module=col_nn.DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1} - ), - SubModuleReplacementDescription( - suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1} - ), - ], - ) + policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ + "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size, + }, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attention.query_key_value", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 3}), + SubModuleReplacementDescription( + suffix="self_attention.dense", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="self_attention.attention_dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="mlp.dense_h_to_4h", + target_module=ColCaiQuantLinear, + kwargs={'split_num': 1}), + SubModuleReplacementDescription( + suffix="mlp.dense_4h_to_h", + target_module=RowCaiQuantLinear, + kwargs={'split_num': 1}), + ]) # NOTE set inference mode to shard config self.shard_config._infer() diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp index 8f17723cbd1b..bcc0e43901de 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp @@ -1,202 +1,254 @@ // Adapted from turboderp exllama: https://github.com/turboderp/exllama -#include +#include #include -#include +#include #include -#include - +#include #include #include - -#include "column_remap.cuh" +#include "util.cuh" +#include "tuning.h" #include "cuda_buffers.cuh" -#include "q4_matmul.cuh" #include "q4_matrix.cuh" -#include "tuning.h" -#include "util.cuh" +#include "q4_matmul.cuh" +#include "column_remap.cuh" -// Check CUDA return code. We don't want to include Torch headers in the .cu -// files because parsing them adds almost a minute to the compile time on a -// 12900K. Also passing exceptions back to Python is super tricky, so in place -// of exceptions, CUDA functions return with a cudaError_t which we can parse -// and dump to the console. - -void check_cuda(cudaError_t ret) { - switch (ret) { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); - printf(" **** %s\n", cudaGetErrorString(ret)); - TORCH_CHECK(false, "CUDA error"); - break; - } +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. + +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } } // Some decluttering macros #define STRINGIFY_(__x) #__x #define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) \ - TORCH_CHECK((__x).dtype() == torch::__dtype, \ - #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) \ - TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, \ - #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) \ - TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \ - #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) \ - TORCH_CHECK((__x).device().is_meta() || \ - (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, \ - #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) \ - TORCH_CHECK((__x).size(__dim_x) % __mod == 0, \ - #__x ".shape[" STRINGIFY( \ - __dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) \ - TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ - do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ - } while (0) +} while(0) #define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ - do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ - } while (0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) { - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, - "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; } + // Tuning parameters ExLlamaTuning tuningParams; -void set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap, - bool matmul_no_half2) { - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; +void set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; } + // Release all unmanaged objects allocated by the extension -void cleanup() { - cleanup_buffers_cuda(); - g_q4_free_matrices(); +void cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); } -// Prepare buffers for forward pass -void prepare_buffers(torch::Device device, torch::Tensor temp_state, - torch::Tensor temp_dq) { - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); +// Prepare buffers for forward pass - prepare_buffers_cuda(device_index, - // buffer size used for sanity checks - temp_state.numel(), (half*)temp_state.data_ptr(), - (half*)temp_dq.data_ptr()); +void prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + // buffer size used for sanity checks + temp_state.numel(), + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); } -// Create Q4Matrix, return handle - -uintptr_t make_q4(torch::Tensor qweight, torch::Tensor qzeros, - torch::Tensor scales, torch::Tensor g_idx, int device) { - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix( - height, width, groups, - (uint32_t*)qweight.data_ptr(), (uint32_t*)qzeros.data_ptr(), - (half*)scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*)g_idx.data_ptr(), - - device); +// Create Q4Matrix, return handle - g_q4_keep_matrix(m); - return reinterpret_cast(m); +uintptr_t make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); } -// Matmul half @ quant -> half - -void q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out) { - Q4Matrix* wm = reinterpret_cast(w); - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); +// Matmul half @ quant -> half - if (tuningParams.matmul_recons_thd == 0 || - x_height < tuningParams.matmul_recons_thd) { - q4_matmul_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm, - (half*)out.data_ptr()); - } else { - q4_matmul_recons_cuda(&tuningParams, (half*)x.data_ptr(), x_height, wm, - (half*)out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle()); - } +void q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } } -// Remap columns in half tensor - -void column_remap(torch::Tensor x, torch::Tensor x_new, torch::Tensor x_map) { - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); +// Remap columns in half tensor - column_remap_cuda((half*)x.data_ptr(), (half*)x_new.data_ptr(), height, width, - (uint32_t*)x_map.data_ptr()); +void column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); + m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); + m.def("cleanup", &cleanup, "cleanup"); + m.def("make_q4", &make_q4, "make_q4"); + m.def("q4_matmul", &q4_matmul, "q4_matmul"); } diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu index bd595ee6f86c..9c61143f565e 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu @@ -184,7 +184,7 @@ __global__ void reconstruct_kernel int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; if (column >= width) return; - + // Views MatrixView_q4_column w_(w, height, width); diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh index 49431dc95876..50cb72a41518 100644 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh @@ -50,4 +50,4 @@ private: void g_q4_keep_matrix(Q4Matrix* m); void g_q4_free_matrices(); -#endif +#endif \ No newline at end of file diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 5ec48096183b..c55155da34e6 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -53,7 +53,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce @@ -66,9 +66,13 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): @staticmethod def backward(ctx, grad_output): - input, weight = ctx.saved_tensors + input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. + weight = weight.view(weight.shape) + bias = bias.view(bias.shape) + total_input = input grad_input = grad_output.matmul(weight.T) grad_output = grad_output.contiguous() @@ -100,7 +104,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_allreduce = async_grad_allreduce @@ -113,9 +117,13 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): @staticmethod def backward(ctx, grad_output): - input, weight = ctx.saved_tensors + input, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias. + if use_bias: + bias.view(bias.shape) + total_input = input grad_input = grad_output.matmul(weight) grad_output = grad_output.contiguous() diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 62163cb009aa..d081b204093b 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -309,7 +309,8 @@ def forward(self, input_: Tensor) -> Tensor: ) # Mask the output embedding. - output_parallel[input_mask, :] = 0.0 + embedding_output = output_parallel.clone() + embedding_output[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. - output = reduce_forward(output_parallel, self.process_group) + output = reduce_forward(embedding_output, self.process_group) return output diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index df7e1163c3d9..7d887b9766e6 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -318,6 +318,7 @@ def backward(self, loss: torch.Tensor): self._post_backward() def backward_by_grad(self, tensor, grad): + self._pre_backward() with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): torch.autograd.backward(tensor, grad) self._post_backward() diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index bb92e5471d89..9afa438dc1a5 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -1,4 +1,5 @@ import argparse +import logging import os import time @@ -27,7 +28,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): avg = sum(latency_set) / count num_layers = getattr(config, "num_layers", config.num_hidden_layers) num_parameters = num_layers * config.hidden_size * config.hidden_size * 12 - num_bytes = 2 # float16 + num_bytes = 2 # float16 print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000)) print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9)) @@ -36,6 +37,7 @@ def print_perf_stats(latency_set, config, bs, warmup=3): def bench_bloom(args): + pretrained_model_dir = args.path quantized_model_dir = args.quantized_path max_batch_size = args.batch_size @@ -46,9 +48,9 @@ def bench_bloom(args): tokenizer.pad_token = tokenizer.eos_token # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized( - quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False - ) + model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, + device=torch.cuda.current_device(), + inject_fused_attention=False) model = model.half() @@ -58,22 +60,22 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), - "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), + "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') } # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig( - enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True - ) + shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, + inference_only=True, + inference_gptq=True) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)), + "attention_mask": torch.ones((max_batch_size, max_input_len)) } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -97,7 +99,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') bench_bloom(args) @@ -109,12 +111,12 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-p", "--path", type=str, help="Model path", required=True) - parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) - parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") - parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") + parser.add_argument('-p', '--path', type=str, help='Model path', required=True) + parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) + parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') + parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') + parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') + parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') args = parser.parse_args() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 00ff6cb37d2a..71bcea334273 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -15,13 +15,13 @@ from tests.kit.model_zoo import model_zoo -def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tp_pipeline) -> Optional[str]: try: if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5) + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tp_pipeline=use_tp_pipeline) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -57,7 +57,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[ @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True): +@parameterize("use_tp_pipeline", [True]) +def check_gemini_plugin(subset: str, init_method: str = "none", use_tp_pipeline: bool = True, early_stop: bool = True): """check gemini plugin over model zoo Args: @@ -116,7 +117,12 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool "torchvision_efficientnet_v2_s", ]: continue - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) + + # TODO debug blip2 when using tp, something wrong with shift_logits's shape + if "transformers_blip2" in name: + use_tp_pipeline = False + + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tp_pipeline) torch.cuda.empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_gptq/test_gptq_linear.py b/tests/test_gptq/test_gptq_linear.py index ded70fa43c30..9b650aa78112 100644 --- a/tests/test_gptq/test_gptq_linear.py +++ b/tests/test_gptq/test_gptq_linear.py @@ -1,8 +1,16 @@ +import math +import time + +import numpy as np import pytest import torch +import torch.nn as nn +import transformers from packaging import version try: + import triton + import triton.language as tl HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -14,7 +22,6 @@ from exllama_kernels import prepare_buffers, set_tuning_params from colossalai.inference.quant.gptq import CaiQuantLinear - HAS_AUTO_GPTQ = True except: HAS_AUTO_GPTQ = False @@ -25,14 +32,13 @@ HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder - gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn("CUDA gptq is not installed") + warnings.warn('CUDA gptq is not installed') HAS_GPTQ_CUDA = False -TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4") +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') max_inner_outer_dim = 1 max_input_len = 1 @@ -58,9 +64,9 @@ def init_buffer(cai_linear, use_act_order=False): max_input_len = 4096 # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - gptq_temp_state_buffer = torch.zeros( - (max_input_len, max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() - ) + gptq_temp_state_buffer = torch.zeros((max_input_len, max_inner_outer_dim), + dtype=torch.float16, + device=torch.cuda.current_device()) gptq_temp_dq_buffer = torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()) gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), gptq_temp_state_buffer, gptq_temp_dq_buffer) @@ -71,11 +77,10 @@ def init_buffer(cai_linear, use_act_order=False): gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) -@pytest.mark.skipif( - not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, - reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq", -) +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_AUTO_GPTQ, + reason="triton requires cuda version to be higher than 11.4 or not install auto-gptq") def test_gptq_linear(): + infeature = 1024 outfeature = 1024 group_size = 128 @@ -115,7 +120,7 @@ def test_gptq_linear(): max_input_len = 2048 buffers = { "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), - "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device), + "temp_dq": torch.zeros((1, max_dq_buffer_size), dtype=torch.float16, device=device) } prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) @@ -141,4 +146,5 @@ def test_gptq_linear(): if __name__ == "__main__": + test_gptq_linear() From da1915dcf816db4b68af3824b43f989b1f1d2382 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 19 Oct 2023 13:41:32 +0800 Subject: [PATCH 27/46] fix fix fix --- colossalai/booster/plugin/gemini_plugin.py | 24 +++++++++++++------ .../test_plugin/test_gemini_plugin.py | 14 +++++------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 8fcb0f1a57d7..84c4e8546a43 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -25,6 +26,7 @@ from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats +from colossalai.cluster import ProcessGroupMesh from .dp_plugin_base import DPPluginBase @@ -285,6 +287,8 @@ class GeminiPlugin(DPPluginBase): max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. + use_tensor_parallel (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. + tp_size (int, optional): If 'use_tensor_parallel' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. """ @@ -318,7 +322,8 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - use_tp_pipeline: bool = False, + use_tensor_parallel: bool = False, + tp_size: int = 1, verbose: bool = False ) -> None: super().__init__() @@ -357,7 +362,8 @@ def __init__( max_norm=max_norm, norm_type=norm_type, ) - self.use_tp_pipeline = use_tp_pipeline + self.use_tensor_parallel = use_tensor_parallel + self.tp_size = tp_size self.verbose = verbose def support_no_sync(self) -> bool: @@ -394,17 +400,21 @@ def configure( # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini - - if self.use_tp_pipeline: + self.dp_group = None + if self.use_tensor_parallel: try: - shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=False) + dp_size = dist.get_world_size() // self.tp_size + self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(0) + self.tp_group = self.pg_mesh.get_group_along_axis(1) + shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, enable_tensor_parallelism=True) shardformer = ShardFormer(shard_config) model, _ = shardformer.optimize(model) optimizer.param_groups[0]["params"] = model.parameters() except NotImplementedError as e: - print(f"Auto policy for {model.__class__} is not implemented yet\n.") + print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") - model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose) + model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 71bcea334273..0a2ac52dc0b8 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -15,13 +15,13 @@ from tests.kit.model_zoo import model_zoo -def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tp_pipeline) -> Optional[str]: +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) -> Optional[str]: try: if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tp_pipeline=use_tp_pipeline) + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tensor_parallel=use_tensor_parallel) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -47,7 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tp_pipel optimizer.step() except Exception as e: - # raise e + raise e return repr(e) @@ -57,8 +57,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tp_pipel @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -@parameterize("use_tp_pipeline", [True]) -def check_gemini_plugin(subset: str, init_method: str = "none", use_tp_pipeline: bool = True, early_stop: bool = True): +@parameterize("use_tensor_parallel", [True]) +def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_parallel: bool = True, early_stop: bool = True): """check gemini plugin over model zoo Args: @@ -120,9 +120,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", use_tp_pipeline: # TODO debug blip2 when using tp, something wrong with shift_logits's shape if "transformers_blip2" in name: - use_tp_pipeline = False + use_tensor_parallel = False - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tp_pipeline) + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) torch.cuda.empty_cache() if err is None: passed_models.append(name) From 9fd9e690c097873d2e6aae177882ccdee01d3b94 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 20 Oct 2023 15:39:26 +0800 Subject: [PATCH 28/46] update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO update checkpointIO --- colossalai/booster/plugin/gemini_plugin.py | 24 +++++- colossalai/tensor/d_tensor/__init__.py | 4 + colossalai/tensor/d_tensor/api.py | 59 ++++++++++++++ colossalai/zero/gemini/gemini_ddp.py | 49 +++++++++++- colossalai/zero/gemini/gemini_optimizer.py | 77 +++++++++++++++---- .../test_plugin/test_gemini_plugin.py | 4 +- .../test_gemini_checkpoint_io.py | 16 ++-- 7 files changed, 205 insertions(+), 28 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 84c4e8546a43..7e46bbc11172 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -36,6 +36,23 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} +def get_param_info(optim: Optimizer): + # Get a backup of necessary information of parameters for future use, which includes: + # 1. A mapping from integer param_id to param32 shape. + + if optim is None: + return {} + param_info = {"id2shape": {}} + start_index = 0 + for group in optim.param_groups: + for param_id, param in enumerate(group["params"], start_index): + original_shape = param.shape if isinstance(param, torch.Tensor) else None + param_info["id2shape"][param_id] = original_shape + + start_index += len(group["params"]) + + return param_info + class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -363,7 +380,7 @@ def __init__( norm_type=norm_type, ) self.use_tensor_parallel = use_tensor_parallel - self.tp_size = tp_size + self.tp_size = tp_size if self.use_tensor_parallel else 1 self.verbose = verbose def support_no_sync(self) -> bool: @@ -389,6 +406,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: + param_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -401,6 +419,7 @@ def configure( # wrap the model with Gemini self.dp_group = None + self.tp_group = None if self.use_tensor_parallel: try: dp_size = dist.get_world_size() // self.tp_size @@ -410,7 +429,6 @@ def configure( shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, enable_tensor_parallelism=True) shardformer = ShardFormer(shard_config) model, _ = shardformer.optimize(model) - optimizer.param_groups[0]["params"] = model.parameters() except NotImplementedError as e: print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") @@ -418,7 +436,7 @@ def configure( if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose + optimizer, model, **self.zero_optim_config, **self.optim_kwargs, param_info=param_info, tp_group=self.tp_group, verbose=self.verbose ) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index fad5101d380c..6f8097735d57 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -2,7 +2,9 @@ compute_global_numel, customized_distributed_tensor_to_param, distribute_tensor, + init_as_dtensor, distribute_tensor_with_customization, + init_tensor_as_customization_distributed, get_device_mesh, get_global_shape, get_layout, @@ -23,6 +25,7 @@ __all__ = [ "is_distributed_tensor", "distribute_tensor", + "init_as_dtensor", "to_global", "is_sharded", "shard_rowwise", @@ -36,6 +39,7 @@ "get_layout", "is_customized_distributed_tensor", "distribute_tensor_with_customization", + "init_tensor_as_customization_distributed", "to_global_for_customized_distributed_tensor", "customized_distributed_tensor_to_param", "Layout", diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index 178bac428ea9..74a785f2dcd4 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -128,6 +128,17 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp return sharded_tensor +def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor: + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + + # shard tensor + tensor.dist_layout = dist_layout + + # hack some tensor methods + _hijack_detach_and_clone(tensor) + + return tensor def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: """ @@ -420,6 +431,54 @@ def gather_fn(tensor): return sharded_tensor +def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable): + """ + Distribute the given tensor with the given shard_fn and gather_fn. + + Example: + + ```python + # define shard and gather functions + def shard_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + return tensor.chunk(world_size, dim=0)[rank] + + def gather_fn(tensor): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + shard_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(shard_list, tensor) + return torch.cat(shard_list, dim=0) + + # create a distributed tensor + tensor = torch.rand(4, 4) + dtensor = init_tensor_as_customization_distributed(tensor, shard_fn, gather_fn) + ``` + + Args: + tensor (torch.Tensor): The tensor to be distributed. + shard_fn (callable): The function to shard the tensor. + gather_fn (callable): The function to gather the tensor. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert callable(shard_fn), "The shard_fn must be callable." + assert callable(gather_fn), "The gather_fn must be callable." + assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor." + + + # set the shard_fn and gather_fn as attributes of the distributed tensor + tensor.shard_fn = shard_fn + tensor.gather_fn = gather_fn + + # set the shard_fn and gather_fn as attributes of the distributed tensor + _hijack_detach_and_clone_for_customized_distributed_tensor(tensor) + + return tensor + + def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor: """ Gather the given tensor to the global tensor. diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 7d887b9766e6..c6db50585dcf 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -17,6 +17,7 @@ from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored +from colossalai.checkpoint_io.utils import gather_distributed_param from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager from .gemini_hook import GeminiZeROHook @@ -24,6 +25,18 @@ from .memory_tracer import MemStats, OrderedParamGenerator from .utils import get_temp_total_chunk_on_cuda +from colossalai.tensor.d_tensor import ( + distribute_tensor, + distribute_tensor_with_customization, + init_tensor_as_customization_distributed, + get_device_mesh, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + get_global_shape, + init_as_dtensor +) + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: @@ -431,7 +444,18 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu() + record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cuda() + if is_distributed_tensor(tensor): + global_shape = get_global_shape(tensor) + device_mesh = get_device_mesh(tensor) + shard_spec = get_sharding_spec(tensor) + record_tensor = init_as_dtensor(record_tensor, + device_mesh=device_mesh, + sharding_spec=shard_spec, + global_shape = global_shape) + elif is_customized_distributed_tensor(tensor): + init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) + record_tensor = gather_distributed_param(record_tensor, keep_vars=False) assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -606,10 +630,16 @@ def _load_from_state_dict( local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} - def load(param_name, dest_tensor, copy_func): + def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None): state_key = prefix + param_name if state_key in state_dict: input_param = state_dict[state_key] + + if source_device_mesh is not None and source_sharding_spec is not None: + input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) + elif shard_fn is not None and gather_fn is not None: + input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] @@ -653,9 +683,19 @@ def load_parameter(chunk_slice, data): temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) for tensor, tensor_info in chunk.tensors_info.items(): + + source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None + if is_distributed_tensor(tensor): + # shard the input param + source_device_mesh = get_device_mesh(tensor) + source_sharding_spec = get_sharding_spec(tensor) + elif is_customized_distributed_tensor(tensor): + shard_fn = tensor.shard_fn + gather_fn = tensor.gather_fn + parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] - load(parameter_name, tensor, partial(load_parameter, parameter_slice)) + load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) @@ -724,7 +764,8 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi if self.master_weights: # create a fp32 parameter - fp32_p = p.data.float() + fp32_p = p.clone() + fp32_p.data = fp32_p.data.float() self.chunk_manager.register_tensor( tensor=fp32_p, group_type="fp32_param", diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 0d0298e067f3..7f97c0a82ed9 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -9,6 +9,7 @@ from packaging.version import Version from torch.nn import Parameter from torch.optim import Optimizer +from torch.distributed import ProcessGroup from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.checkpoint_io.utils import StateDictSharder @@ -19,6 +20,7 @@ from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP +from colossalai.checkpoint_io.utils import search_tp_partition_dim __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] @@ -93,6 +95,8 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, + param_info: OrderedDict = None, + tp_group: ProcessGroup = None, verbose: bool = False, **defaults: Any, ): @@ -109,6 +113,10 @@ def __init__( self.chunk16_set: Set[Chunk] = set() self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm + self.param_info = param_info + self.tp_group = tp_group + self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 + self.tp_rank = dist.get_rank(tp_group) self.verbose = verbose self.param_groups_backup = list() @@ -406,8 +414,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: param = self.id_to_real_params[param_id] fake_param = self.id_to_fake_params.get(param_id, None) chunk = self.chunk_manager.get_chunk(param) - process_group = chunk.torch_pg - rank = dist.get_rank(process_group) + dp_group = chunk.torch_pg + rank = dist.get_rank(dp_group) master_rank = 0 collected_states = {} @@ -415,9 +423,9 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: local_state_names = None if fake_param is not None: local_state_names = list(self.optim.state[fake_param].keys()) - gathered_state_names = [None for _ in range(dist.get_world_size(process_group))] + gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))] dist.barrier() - dist.all_gather_object(gathered_state_names, local_state_names) + dist.all_gather_object(gathered_state_names, local_state_names, dp_group) state_names = None for names in gathered_state_names: if names is not None: @@ -439,6 +447,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # If the chunk is kept gathered, # the parameteres are treated the same as that of those in strict DDP during training. # So states can be directly fetched from current device. + tp_shard_info = {} + current_shape = param.shape + original_shape = self.param_info["id2shape"][param_id] + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + tp_shard_info["shared"] = (current_shape, original_shape, partition_dim) + + if chunk.keep_gathered: assert param_id in self.id_to_fake_params if is_collector: @@ -450,8 +465,14 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: states["step"], dtype=torch.float32, requires_grad=False ).cpu() else: - state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() - collected_states[state_name] = torch.reshape(state_tensor, param.shape) + state_tensor = states[state_name].detach().clone().to(torch.float32).cuda() + state_tensor = state_tensor.view(current_shape) + if partition_dim is not None: + gather_tensor = [torch.zeros_like(state_tensor) for _ in range(self.tp_size)] + dist.all_gather(gather_tensor, state_tensor, group=self.tp_group) + state_tensor = torch.cat(gather_tensor, dim=partition_dim) + + collected_states[state_name] = torch.reshape(state_tensor, original_shape).cpu() return collected_states # Check whether the param with given id is managed by current process. @@ -464,16 +485,21 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu() else: + tensor_size = param.numel() + if partition_dim is not None: + tensor_size * self.tp_size collected_states[state_name] = torch.zeros( - param.numel(), dtype=torch.float32, requires_grad=False + tensor_size, dtype=torch.float32, requires_grad=False ).cpu() # Materials for gathering, including compacted state tensors, and the offset of shard inside each state. - compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None + compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names, tp_shard_info) if own_param else None _, shard_offset, shard_size = self.get_offsets(param_id) + if partition_dim is not None: + shard_size = shard_size * self.tp_size # Collectors gather state shards through all_gathering. - gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))] + gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))] dist.barrier() dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) @@ -493,7 +519,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: if is_collector: for state_name, state_tensor in collected_states.items(): if state_tensor.numel() == param.numel(): - collected_states[state_name] = torch.reshape(state_tensor, param.shape) + collected_states[state_name] = torch.reshape(state_tensor, original_shape) return collected_states @@ -501,6 +527,7 @@ def pack_optimizer_states_to_tensor( self, param_id: int, state_names: list, + tp_shard_info: dict, device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: @@ -515,6 +542,11 @@ def pack_optimizer_states_to_tensor( states = self.optim.state[fake_param] shard_size = param_range[1] - param_range[0] compacted_size = 0 + + (current_shape, original_shape, partition_dim) = tp_shard_info["shared"] + if partition_dim is not None: + shard_size = shard_size * self.tp_size + for name in state_names: if name == "step": compacted_size += 1 @@ -533,6 +565,11 @@ def pack_optimizer_states_to_tensor( compacted_states[next_state_offset] = state_tensor next_state_offset += 1 else: + if partition_dim is not None: + gather_tensor = [torch.zeros(current_shape, dtype=state_tensor.dtype, device=state_tensor.device) for _ in range(self.tp_size)] + dist.all_gather(gather_tensor, state_tensor, group=self.tp_group) + state_tensor = torch.cat(gather_tensor, dim=partition_dim).view(original_shape) + assert state_tensor.numel() == shard_size compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor) next_state_offset += shard_size @@ -545,7 +582,7 @@ def load_from_compacted_states( collected_states: dict, state_names: list, shard_start: int, - shard_size: int, + shard_size: int ): """ Given a tensor carrying compacted optimizer states, @@ -644,7 +681,7 @@ def load_single_param_states(self, param_id: int, saved_states: dict): Load saved optimizer states into parameter with given id. """ - def cast(param, state_range, value, key=None): + def cast(param, state_range, value, key=None, tp_shard_info=None): """ Make a copy of the needed segment of value and cast it to device of param. """ @@ -658,6 +695,13 @@ def cast(param, state_range, value, key=None): ret_val = torch.zeros( state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False ) + + (current_shape, _, partition_dim) = tp_shard_info["shard"] + if partition_dim is not None: + slice_size = current_shape[partition_dim] + value = value.split(slice_size, dim=partition_dim)[self.tp_rank] + + ret_val.copy_(value.flatten()[state_start:state_end]) return ret_val @@ -668,8 +712,15 @@ def cast(param, state_range, value, key=None): # Copy states assigned to param (and cast tensors to appropriate types). updated_states = dict() + + tp_shard_info = {} + current_shape = self.id_to_real_params[param_id].shape + original_shape = self.param_info["id2shape"][param_id] + partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) + tp_shard_info["shard"] = (current_shape, original_shape, partition_dim) + for k, v in saved_states.items(): - updated_states[k] = cast(fake_param, state_range, v, k) + updated_states[k] = cast(fake_param, state_range, v, k, tp_shard_info) del v # clean loaded states self.optim.state[fake_param].update(updated_states) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 0a2ac52dc0b8..66c624257b58 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -47,7 +47,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p optimizer.step() except Exception as e: - raise e + # raise e return repr(e) @@ -57,7 +57,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -@parameterize("use_tensor_parallel", [True]) +@parameterize("use_tensor_parallel", [True, False]) def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_parallel: bool = True, early_stop: bool = True): """check gemini plugin over model zoo diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index f876040384b3..bc6582f57371 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -37,7 +37,9 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool): +@parameterize("use_tensor_parallel", [True, False]) +@parameterize("tp_size", [2]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, use_tensor_parallel: bool, tp_size: int): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -47,7 +49,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(**placement_config) + plugin = GeminiPlugin(**placement_config, use_tensor_parallel=use_tensor_parallel, tp_size=tp_size) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -63,13 +65,15 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @clear_cache_before_run() @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS) -@parameterize("shard", [False, True]) +@parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int): +@parameterize("use_tensor_parallel", [True, False]) +@parameterize("tp_size", [2]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, use_tensor_parallel: bool, tp_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14)) + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), use_tensor_parallel=use_tensor_parallel, tp_size=tp_size) booster = Booster(plugin=plugin) model = model_fn() @@ -148,7 +152,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist -@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("world_size", [4]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size) From a89f2fdb83efae6c2114a97b20178f04378c6550 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Oct 2023 16:39:17 +0800 Subject: [PATCH 29/46] support fused layernorm support fused layernorm support fused layernorm --- colossalai/booster/plugin/gemini_plugin.py | 11 +++++- colossalai/shardformer/layer/_operation.py | 25 ++++++++++++ colossalai/shardformer/layer/normalization.py | 39 ++++++++++++------- colossalai/zero/gemini/gemini_ddp.py | 1 + .../test_plugin/test_gemini_plugin.py | 2 +- 5 files changed, 62 insertions(+), 16 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 7e46bbc11172..f3bd987099df 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -306,6 +306,8 @@ class GeminiPlugin(DPPluginBase): norm_type (float, optional): norm_type used for `clip_grad_norm`. use_tensor_parallel (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. tp_size (int, optional): If 'use_tensor_parallel' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. + use_fused_layernorm (bool, optional): Whether to use fused layernorm operator, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False. + use_flash_attention (bool, optional): Whether to use flash attention, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. """ @@ -341,6 +343,8 @@ def __init__( norm_type: float = 2.0, use_tensor_parallel: bool = False, tp_size: int = 1, + use_fused_layernorm: bool = False, + use_flash_attention: bool = False, verbose: bool = False ) -> None: super().__init__() @@ -381,6 +385,8 @@ def __init__( ) self.use_tensor_parallel = use_tensor_parallel self.tp_size = tp_size if self.use_tensor_parallel else 1 + self.use_fused_layernorm = use_fused_layernorm if self.use_tensor_parallel else False + self.use_flash_attention = use_flash_attention if self.use_tensor_parallel else False self.verbose = verbose def support_no_sync(self) -> bool: @@ -426,7 +432,10 @@ def configure( self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) self.dp_group = self.pg_mesh.get_group_along_axis(0) self.tp_group = self.pg_mesh.get_group_along_axis(1) - shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, enable_tensor_parallelism=True) + shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, + enable_tensor_parallelism=True, + enable_fused_normalization=self.use_fused_layernorm, + enable_flash_attention=self.use_flash_attention) shardformer = ShardFormer(shard_config) model, _ = shardformer.optimize(model) except NotImplementedError as e: diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index c55155da34e6..92014064433d 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -62,6 +62,7 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): if bias is not None: output = output + bias + return output @staticmethod @@ -113,6 +114,7 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce): output = F.linear(input_, weight, bias) else: output = F.linear(input_, weight) + return output @staticmethod @@ -462,6 +464,29 @@ def forward(ctx, input_, dim, process_group): @staticmethod def backward(ctx, grad_output): return _split(grad_output, ctx.dim, ctx.process_group), None, None + + +class HookParameter(torch.autograd.Function): + "In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm" + @staticmethod + def forward(ctx, input, weight, bias): + ctx.save_for_backward(weight, bias) + output = input + return output + + @staticmethod + def backward(ctx, grad_output): + weight, bias = ctx.saved_tensors + if weight is not None: + weight = weight.view(weight.shape) + if bias is not None: + bias = bias.view(bias.shape) + return grad_output, None, None + + +def hook_paramter_in_backward(input, weight=None, bias=None): + return HookParameter.apply(input, weight, bias) + def _reduce(input_, process_group): diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 19b973be8679..8fb1154f8a3a 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -4,6 +4,7 @@ import torch.nn as nn from colossalai.lazy import LazyInitContext +from ._operation import hook_paramter_in_backward __all__ = ["FusedLayerNorm", "FusedRMSNorm"] @@ -35,16 +36,14 @@ ] -class FusedLayerNorm: +class FusedLayerNorm(nn.Module): r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ - def __init__(self) -> None: - raise NotImplementedError( - "FusedLayerNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex." - ) + def __init__(self, layernorm=None) -> None: + super().__init__() + self.layernorm = layernorm @staticmethod def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: @@ -79,25 +78,31 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: else: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + layernorm = ( ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ) layernorm.weight = module.weight layernorm.bias = module.bias - return layernorm + return FusedLayerNorm(layernorm=layernorm) + + def forward(self, input): + weight = self.layernorm.weight + bias = self.layernorm.bias + layernorm_output = self.layernorm(input) + output = hook_paramter_in_backward(layernorm_output, weight, bias) + return output -class FusedRMSNorm: +class FusedRMSNorm(nn.Module): """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ - def __init__(self) -> None: - raise NotImplementedError( - "FusedRMSNorm is not implemented as a physical class. " - "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex." - ) + def __init__(self, rmsnorm=None) -> None: + super().__init__() + self.rmsnorm = rmsnorm @staticmethod def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: @@ -124,4 +129,10 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: rmsnorm.weight = module.weight - return rmsnorm + return FusedRMSNorm(rmsnorm=rmsnorm) + + def forward(self, input): + weight = self.rmsnorm.weight + rmsnorm_output = self.rmsnorm(input) + output = hook_paramter_in_backward(rmsnorm_output, weight) + return output diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c6db50585dcf..a1aafcb49139 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -343,6 +343,7 @@ def grad_handle(self, p, grad): with torch._C.DisableTorchFunction(): chunk = self.chunk_manager.get_chunk(p) if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: + print(chunk.tensors_info[p].state) raise RuntimeError( f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " "Some unsupported torch function is operated upon this parameter." diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index 66c624257b58..f5f9aee7a8b6 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -21,7 +21,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tensor_parallel=use_tensor_parallel) + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tensor_parallel=use_tensor_parallel, use_fused_layernorm=True, use_flash_attention=True) booster = Booster(plugin=plugin) with ctx: model = model_fn() From 2406cb0b07118cde29bb4f1f1b76009fec1d968e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 23 Oct 2023 18:48:18 +0800 Subject: [PATCH 30/46] update fusedlayernorm update fusedlayernorm update fusedlayernorm --- colossalai/shardformer/layer/normalization.py | 55 ++++++++++++++++++- .../test_gemini_checkpoint_io.py | 4 +- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 8fb1154f8a3a..2b45ab9dd4ed 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -2,6 +2,9 @@ # -*- encoding: utf-8 -*- import torch.nn as nn +from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module +from typing import Tuple, Iterator, Set, Optional +from torch.nn import Parameter from colossalai.lazy import LazyInitContext from ._operation import hook_paramter_in_backward @@ -35,7 +38,6 @@ 65536, ] - class FusedLayerNorm(nn.Module): r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. @@ -44,6 +46,8 @@ class FusedLayerNorm(nn.Module): def __init__(self, layernorm=None) -> None: super().__init__() self.layernorm = layernorm + self._parameters = layernorm._parameters + self._buffers = layernorm._buffers @staticmethod def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: @@ -78,15 +82,15 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: else: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm - layernorm = ( ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ) layernorm.weight = module.weight layernorm.bias = module.bias + return FusedLayerNorm(layernorm=layernorm) - + def forward(self, input): weight = self.layernorm.weight bias = self.layernorm.bias @@ -94,6 +98,27 @@ def forward(self, input): output = hook_paramter_in_backward(layernorm_output, weight, bias) return output + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + r"""Returns an iterator over module parameters, yielding both the + name of the parameter as well as the parameter itself. + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, recurse=False) + for elem in gen: + yield elem + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + r"""Returns an iterator over all modules in the network, yielding + both the name of the module as well as the module itself. + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + class FusedRMSNorm(nn.Module): """ @@ -103,6 +128,8 @@ class FusedRMSNorm(nn.Module): def __init__(self, rmsnorm=None) -> None: super().__init__() self.rmsnorm = rmsnorm + self._parameters = rmsnorm._parameters + self._buffers = rmsnorm._buffers @staticmethod def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: @@ -136,3 +163,25 @@ def forward(self, input): rmsnorm_output = self.rmsnorm(input) output = hook_paramter_in_backward(rmsnorm_output, weight) return output + + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + r"""Returns an iterator over module parameters, yielding both the + name of the parameter as well as the parameter itself. + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, recurse=False) + for elem in gen: + yield elem + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + r"""Returns an iterator over all modules in the network, yielding + both the name of the module as well as the module itself. + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index bc6582f57371..3f5d663431d4 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -49,7 +49,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(**placement_config, use_tensor_parallel=use_tensor_parallel, tp_size=tp_size) + plugin = GeminiPlugin(**placement_config, use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -73,7 +73,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, use_tensor_parallel: bool, tp_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), use_tensor_parallel=use_tensor_parallel, tp_size=tp_size) + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True) booster = Booster(plugin=plugin) model = model_fn() From a0509a6e6f9f68d6a3b5def8c369b4750b29e741 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 24 Oct 2023 16:58:44 +0800 Subject: [PATCH 31/46] add sequence parallel to gemini add sequence parallel to gemini --- colossalai/booster/plugin/gemini_plugin.py | 68 ++++++++++++------- colossalai/shardformer/layer/_operation.py | 19 ++++-- colossalai/shardformer/modeling/bloom.py | 3 +- colossalai/zero/gemini/gemini_optimizer.py | 2 +- .../test_plugin/test_gemini_plugin.py | 13 ++-- .../test_gemini_checkpoint_io.py | 14 ++-- 6 files changed, 75 insertions(+), 44 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index f3bd987099df..96118e1b6889 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -304,10 +304,16 @@ class GeminiPlugin(DPPluginBase): max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm. norm_type (float, optional): norm_type used for `clip_grad_norm`. - use_tensor_parallel (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. - tp_size (int, optional): If 'use_tensor_parallel' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. - use_fused_layernorm (bool, optional): Whether to use fused layernorm operator, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False. - use_flash_attention (bool, optional): Whether to use flash attention, which is implemented in Shardformer. Used when 'use_tensor_parallel' is True. Default to False. + enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False. + tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1. + enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer. + Currently all the optimization methods include fused normalization, flash attention and JIT. + Defaults to False. + enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False. + enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False. + enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. + enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. + enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. """ @@ -341,10 +347,14 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - use_tensor_parallel: bool = False, + enable_tensor_parallelism: bool = False, tp_size: int = 1, - use_fused_layernorm: bool = False, - use_flash_attention: bool = False, + enable_all_optimization: bool = False, + enable_fused_normalization: bool = False, + enable_flash_attention: bool = False, + enable_sequence_parallelism: bool = False, + enable_jit_fused: bool = False, + enable_sequence_overlap: bool = False, verbose: bool = False ) -> None: super().__init__() @@ -383,10 +393,14 @@ def __init__( max_norm=max_norm, norm_type=norm_type, ) - self.use_tensor_parallel = use_tensor_parallel - self.tp_size = tp_size if self.use_tensor_parallel else 1 - self.use_fused_layernorm = use_fused_layernorm if self.use_tensor_parallel else False - self.use_flash_attention = use_flash_attention if self.use_tensor_parallel else False + self.enable_tensor_parallelism = enable_tensor_parallelism + self.tp_size = tp_size if self.enable_tensor_parallelism else 1 + self.enable_all_optimization = enable_all_optimization + self.enable_fused_normalization = enable_fused_normalization + self.enable_flash_attention = enable_flash_attention + self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False + self.enable_jit_fused = enable_jit_fused + self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose def support_no_sync(self) -> bool: @@ -426,20 +440,24 @@ def configure( # wrap the model with Gemini self.dp_group = None self.tp_group = None - if self.use_tensor_parallel: - try: - dp_size = dist.get_world_size() // self.tp_size - self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(0) - self.tp_group = self.pg_mesh.get_group_along_axis(1) - shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, - enable_tensor_parallelism=True, - enable_fused_normalization=self.use_fused_layernorm, - enable_flash_attention=self.use_flash_attention) - shardformer = ShardFormer(shard_config) - model, _ = shardformer.optimize(model) - except NotImplementedError as e: - print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") + try: + dp_size = dist.get_world_size() // self.tp_size + assert dp_size > 1, f"the size of DP group should greater than 1. Please reduce the TP group size." + self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(0) + self.tp_group = self.pg_mesh.get_group_along_axis(1) + shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap) + shardformer = ShardFormer(shard_config) + model, _ = shardformer.optimize(model) + except NotImplementedError as e: + print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 92014064433d..0d8c3d453ce1 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -162,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter @@ -180,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, @staticmethod def backward(ctx, grad_output): - input_, weight = ctx.saved_tensors + input_, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + if use_bias: + bias = bias.view(bias.shape) + if not overlap: input_parallel = _gather(input_, dim, process_group) @@ -299,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap): - ctx.save_for_backward(input_, weight) + ctx.save_for_backward(input_, weight, bias) ctx.use_bias = bias is not None ctx.process_group = process_group ctx.async_grad_reduce_scatter = async_grad_reduce_scatter @@ -316,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, @staticmethod def backward(ctx, grad_output): - input_, weight = ctx.saved_tensors + input_, weight, bias = ctx.saved_tensors use_bias = ctx.use_bias dim = ctx.dim process_group = ctx.process_group overlap = ctx.overlap + # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm + weight = weight.view(weight.shape) + if use_bias: + bias = bias.view(bias.shape) + if not overlap: input_parallel = _gather(input_, dim, process_group) @@ -467,7 +476,7 @@ def backward(ctx, grad_output): class HookParameter(torch.autograd.Function): - "In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm" + """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm""" @staticmethod def forward(ctx, input, weight, bias): ctx.save_for_backward(weight, bias) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 1bf87e80a461..cd8a023306dc 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -719,7 +719,7 @@ def forward( ): fused_qkv = self.query_key_value(hidden_states) (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, tgt_len, _ = query_layer.size() + batch_size, tgt_len, _, _ = query_layer.size() _, kv_length, _, _ = key_layer.size() @@ -755,6 +755,7 @@ def forward( attention_numerical_mask = torch.masked_fill( attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min ) + attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype) context_layer = me_attention( query_layer, diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 7f97c0a82ed9..ae13866eb6aa 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -116,7 +116,7 @@ def __init__( self.param_info = param_info self.tp_group = tp_group self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 - self.tp_rank = dist.get_rank(tp_group) + self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else None self.verbose = verbose self.param_groups_backup = list() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index f5f9aee7a8b6..ee1c8326b0a3 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -15,13 +15,14 @@ from tests.kit.model_zoo import model_zoo -def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) -> Optional[str]: +def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]: try: if init_method == "lazy": ctx = LazyInitContext() else: ctx = nullcontext() - plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, use_tensor_parallel=use_tensor_parallel, use_fused_layernorm=True, use_flash_attention=True) + enable_all_optimization = True if enable_tensor_parallelism else False + plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) with ctx: model = model_fn() @@ -57,8 +58,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_p @parameterize("subset", ["torchvision", "transformers", "diffusers"]) @parameterize("init_method", ["none"]) -@parameterize("use_tensor_parallel", [True, False]) -def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_parallel: bool = True, early_stop: bool = True): +@parameterize("enable_tensor_parallelism", [True, False]) +def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True): """check gemini plugin over model zoo Args: @@ -120,9 +121,9 @@ def check_gemini_plugin(subset: str, init_method: str = "none", use_tensor_paral # TODO debug blip2 when using tp, something wrong with shift_logits's shape if "transformers_blip2" in name: - use_tensor_parallel = False + enable_tensor_parallelism = False - err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, use_tensor_parallel) + err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) torch.cuda.empty_cache() if err is None: passed_models.append(name) diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 3f5d663431d4..29d2bd9ce36e 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -37,9 +37,10 @@ @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS) @parameterize("model_name", ["transformers_bert_for_sequence_classification"]) @parameterize("use_safetensors", [False, True]) -@parameterize("use_tensor_parallel", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) @parameterize("tp_size", [2]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, use_tensor_parallel: bool, tp_size: int): +@parameterize("enable_all_optimization", [True, False]) +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) @@ -49,7 +50,7 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, "pretrained") bert_model.config.save_pretrained(save_directory=pretrained_path) - plugin = GeminiPlugin(**placement_config, use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True) + plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -68,12 +69,13 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @parameterize("shard", [True, False]) @parameterize("model_name", ["transformers_gpt"]) @parameterize("size_per_shard", [32]) -@parameterize("use_tensor_parallel", [True, False]) +@parameterize("enable_tensor_parallelism", [True, False]) @parameterize("tp_size", [2]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, use_tensor_parallel: bool, tp_size: int): +@parameterize("enable_all_optimization", [True, False]) +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), use_tensor_parallel=use_tensor_parallel, tp_size=tp_size, use_fused_layernorm=True) + plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) model = model_fn() From 12cd78018fd669e339dfc8a2f49edcfc1800d3f1 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 25 Oct 2023 10:15:53 +0800 Subject: [PATCH 32/46] fix --- colossalai/zero/gemini/gemini_ddp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index a1aafcb49139..c6db50585dcf 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -343,7 +343,6 @@ def grad_handle(self, p, grad): with torch._C.DisableTorchFunction(): chunk = self.chunk_manager.get_chunk(p) if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: - print(chunk.tensors_info[p].state) raise RuntimeError( f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " "Some unsupported torch function is operated upon this parameter." From 0110902082ddd65dd07c282a7fe7aef7fd8acd6c Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 25 Oct 2023 17:06:00 +0800 Subject: [PATCH 33/46] fix comments fix comments fix comments --- colossalai/booster/plugin/gemini_plugin.py | 8 +++++--- colossalai/zero/gemini/gemini_ddp.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 96118e1b6889..dbb943da80d6 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -35,6 +35,8 @@ SUPPORTED_PRECISION = ["fp16", "bf16"] PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} +DP_AXIS = 0 +TP_AXIS = 1 def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: @@ -442,10 +444,10 @@ def configure( self.tp_group = None try: dp_size = dist.get_world_size() // self.tp_size - assert dp_size > 1, f"the size of DP group should greater than 1. Please reduce the TP group size." + assert dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(0) - self.tp_group = self.pg_mesh.get_group_along_axis(1) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, enable_tensor_parallelism=self.enable_tensor_parallelism, enable_all_optimization=self.enable_all_optimization, diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index c6db50585dcf..fdc5ccef0e43 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -455,7 +455,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: global_shape = global_shape) elif is_customized_distributed_tensor(tensor): init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) - record_tensor = gather_distributed_param(record_tensor, keep_vars=False) + record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor From 86a5eca26b7e3eaa30cc2228816cbc85ce0001cd Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 30 Oct 2023 14:19:40 +0800 Subject: [PATCH 34/46] fix --- colossalai/shardformer/layer/normalization.py | 195 +++++++++++++----- 1 file changed, 149 insertions(+), 46 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 2b45ab9dd4ed..4be6d0db829f 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,10 +1,13 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import warnings import torch.nn as nn from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module -from typing import Tuple, Iterator, Set, Optional +from typing import Tuple, Iterator, Set, Optional, Mapping, List, Any from torch.nn import Parameter +from collections import OrderedDict +from torch.nn.modules.module import _IncompatibleKeys from colossalai.lazy import LazyInitContext from ._operation import hook_paramter_in_backward @@ -38,7 +41,150 @@ 65536, ] -class FusedLayerNorm(nn.Module): +class LayerNormBase(nn.Module): + + def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: + r"""Returns an iterator over module parameters, yielding both the + name of the parameter as well as the parameter itself. + """ + gen = self._named_members( + lambda module: module._parameters.items(), + prefix=prefix, recurse=False) + for elem in gen: + yield elem + + def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): + r"""Returns an iterator over all modules in the network, yielding + both the name of the module as well as the module itself. + """ + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + r"""Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. + + .. warning:: + Please avoid the use of argument ``destination`` as it is not + designed for end-users. + + Args: + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. + prefix (str, optional): a prefix added to parameter and buffer + names to compose the keys in state_dict. Default: ``''``. + keep_vars (bool, optional): by default the :class:`~torch.Tensor` s + returned in the state dict are detached from autograd. If it's + set to ``True``, detaching will not be performed. + Default: ``False``. + + Returns: + dict: + a dictionary containing a whole state of the module + + Example:: + + >>> module.state_dict().keys() + ['bias', 'weight'] + + """ + + # TODO: Remove `args` and the parsing logic when BC allows. + if len(args) > 0: + if destination is None: + destination = args[0] + if len(args) > 1 and prefix == '': + prefix = args[1] + if len(args) > 2 and keep_vars is False: + keep_vars = args[2] + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead. Refer to " + "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict" + " for details.") + + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata + + self._save_to_state_dict(destination, prefix, keep_vars) + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def load_state_dict(self, state_dict: Mapping[str, Any], + strict: bool = True): + r"""Don't + """ + if not isinstance(state_dict, Mapping): + raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) + + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = OrderedDict(state_dict) + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + + # Note that the hook can modify missing_keys and unexpected_keys. + incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) + for hook in module._load_state_dict_post_hooks.values(): + out = hook(module, incompatible_keys) + assert out is None, ( + "Hooks registered with ``register_load_state_dict_post_hook`` are not" + "expected to return new values, if incompatible_keys need to be modified," + "it should be done inplace." + ) + + load(self) + del load + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join('"{}"'.format(k) for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + +class FusedLayerNorm(LayerNormBase): r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ @@ -98,29 +244,8 @@ def forward(self, input): output = hook_paramter_in_backward(layernorm_output, weight, bias) return output - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - r"""Returns an iterator over module parameters, yielding both the - name of the parameter as well as the parameter itself. - """ - gen = self._named_members( - lambda module: module._parameters.items(), - prefix=prefix, recurse=False) - for elem in gen: - yield elem - - def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): - r"""Returns an iterator over all modules in the network, yielding - both the name of the module as well as the module itself. - """ - if memo is None: - memo = set() - if self not in memo: - if remove_duplicate: - memo.add(self) - yield prefix, self - -class FusedRMSNorm(nn.Module): +class FusedRMSNorm(LayerNormBase): """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ @@ -163,25 +288,3 @@ def forward(self, input): rmsnorm_output = self.rmsnorm(input) output = hook_paramter_in_backward(rmsnorm_output, weight) return output - - - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - r"""Returns an iterator over module parameters, yielding both the - name of the parameter as well as the parameter itself. - """ - gen = self._named_members( - lambda module: module._parameters.items(), - prefix=prefix, recurse=False) - for elem in gen: - yield elem - - def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): - r"""Returns an iterator over all modules in the network, yielding - both the name of the module as well as the module itself. - """ - if memo is None: - memo = set() - if self not in memo: - if remove_duplicate: - memo.add(self) - yield prefix, self From 6f13876a0121c11eadc459caf86b78f314fd225d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 30 Oct 2023 15:43:51 +0800 Subject: [PATCH 35/46] fix t5 --- colossalai/shardformer/policies/t5.py | 8 -------- tests/test_checkpoint_io/test_gemini_checkpoint_io.py | 8 ++++---- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 74cc7337e9f1..9fe3ef260d75 100644 --- a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -170,14 +170,6 @@ def module_policy(self): # optimization configuration if self.shard_config.enable_fused_normalization: - self.append_or_create_submodule_replacement( - description=SubModuleReplacementDescription( - suffix="layer_norm", - target_module=FusedRMSNorm, - ), - policy=policy, - target_key=T5LayerFF, - ) self.append_or_create_submodule_replacement( description=SubModuleReplacementDescription( suffix="layer_norm", diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 29d2bd9ce36e..821ce9fbbbd9 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -39,12 +39,12 @@ @parameterize("use_safetensors", [False, True]) @parameterize("enable_tensor_parallelism", [True, False]) @parameterize("tp_size", [2]) -@parameterize("enable_all_optimization", [True, False]) -def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool): +def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int): from transformers import BertForSequenceClassification (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) bert_model = model_fn() + enable_all_optimization = True if enable_tensor_parallelism else False with shared_tempdir() as tempdir: pretrained_path = os.path.join(tempdir, "pretrained") @@ -71,10 +71,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b @parameterize("size_per_shard", [32]) @parameterize("enable_tensor_parallelism", [True, False]) @parameterize("tp_size", [2]) -@parameterize("enable_all_optimization", [True, False]) -def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int, enable_all_optimization: bool): +def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() + enable_all_optimization = True if enable_tensor_parallelism else False plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization) booster = Booster(plugin=plugin) From 5f16e4fa4f4d47c5f4af3217bab03be49f8cc16d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 30 Oct 2023 17:55:11 +0800 Subject: [PATCH 36/46] clear cache --- tests/test_booster/test_plugin/test_gemini_plugin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index ee1c8326b0a3..c27a0fd7a2e3 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -10,6 +10,8 @@ from colossalai.fx import is_compatible_with_meta from colossalai.lazy.lazy_init import LazyInitContext from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -124,6 +126,8 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa enable_tensor_parallelism = False err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) + clear_layout_converter() + Randomizer.reset_index() torch.cuda.empty_cache() if err is None: passed_models.append(name) From adead5051b283d010430135cd34409a26609fdf1 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 31 Oct 2023 17:35:47 +0800 Subject: [PATCH 37/46] fix --- colossalai/zero/gemini/gemini_ddp.py | 2 +- colossalai/zero/gemini/gemini_optimizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index fdc5ccef0e43..59a00bb9f8d5 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -331,7 +331,7 @@ def backward(self, loss: torch.Tensor): self._post_backward() def backward_by_grad(self, tensor, grad): - self._pre_backward() + raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.") with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): torch.autograd.backward(tensor, grad) self._post_backward() diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index ae13866eb6aa..0d993246a51a 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -487,7 +487,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: else: tensor_size = param.numel() if partition_dim is not None: - tensor_size * self.tp_size + tensor_size = tensor_size * self.tp_size collected_states[state_name] = torch.zeros( tensor_size, dtype=torch.float32, requires_grad=False ).cpu() From ed825dc78a3a30715ca9dd12c47bdb5d69d37279 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 31 Oct 2023 21:17:10 +0800 Subject: [PATCH 38/46] activate ci --- colossalai/shardformer/layer/normalization.py | 40 +------------------ 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 4be6d0db829f..86c869bf2f4e 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -65,43 +65,7 @@ def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', yield prefix, self def state_dict(self, *args, destination=None, prefix='', keep_vars=False): - r"""Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are - included. Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. - - .. warning:: - Currently ``state_dict()`` also accepts positional arguments for - ``destination``, ``prefix`` and ``keep_vars`` in order. However, - this is being deprecated and keyword arguments will be enforced in - future releases. - - .. warning:: - Please avoid the use of argument ``destination`` as it is not - designed for end-users. - - Args: - destination (dict, optional): If provided, the state of module will - be updated into the dict and the same object is returned. - Otherwise, an ``OrderedDict`` will be created and returned. - Default: ``None``. - prefix (str, optional): a prefix added to parameter and buffer - names to compose the keys in state_dict. Default: ``''``. - keep_vars (bool, optional): by default the :class:`~torch.Tensor` s - returned in the state dict are detached from autograd. If it's - set to ``True``, detaching will not be performed. - Default: ``False``. - - Returns: - dict: - a dictionary containing a whole state of the module - - Example:: - - >>> module.state_dict().keys() - ['bias', 'weight'] - + r"""Don't recursive process self._module """ # TODO: Remove `args` and the parsing logic when BC allows. @@ -135,7 +99,7 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - r"""Don't + r"""Don't recursive process self._module """ if not isinstance(state_dict, Mapping): raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) From 37494c372e83f3590110111b49e8693843de008e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 1 Nov 2023 07:10:35 +0800 Subject: [PATCH 39/46] activate ci --- colossalai/shardformer/layer/normalization.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 86c869bf2f4e..6c8dd7c9cb9c 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -65,9 +65,7 @@ def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', yield prefix, self def state_dict(self, *args, destination=None, prefix='', keep_vars=False): - r"""Don't recursive process self._module - """ - + r"""Don't recursive process self._module""" # TODO: Remove `args` and the parsing logic when BC allows. if len(args) > 0: if destination is None: @@ -99,8 +97,7 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): - r"""Don't recursive process self._module - """ + r"""Don't recursive process self._module""" if not isinstance(state_dict, Mapping): raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) From 73da4ca91df6b2409ee6e1e9dbc1edc38ddc9fbc Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 1 Nov 2023 16:59:24 +0800 Subject: [PATCH 40/46] fix --- colossalai/booster/plugin/gemini_plugin.py | 56 ++++++++++++---------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index dbb943da80d6..8d8cc6283770 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -5,8 +5,8 @@ from typing import Callable, Iterator, List, Optional, Tuple import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -20,13 +20,12 @@ save_state_dict, save_state_dict_shards, ) -from colossalai.cluster import DistCoordinator +from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats -from colossalai.cluster import ProcessGroupMesh from .dp_plugin_base import DPPluginBase @@ -38,6 +37,7 @@ DP_AXIS = 0 TP_AXIS = 1 + def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. @@ -55,6 +55,7 @@ def get_param_info(optim: Optimizer): return param_info + class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -357,7 +358,7 @@ def __init__( enable_sequence_parallelism: bool = False, enable_jit_fused: bool = False, enable_sequence_overlap: bool = False, - verbose: bool = False + verbose: bool = False, ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" @@ -396,7 +397,6 @@ def __init__( norm_type=norm_type, ) self.enable_tensor_parallelism = enable_tensor_parallelism - self.tp_size = tp_size if self.enable_tensor_parallelism else 1 self.enable_all_optimization = enable_all_optimization self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention @@ -405,6 +405,23 @@ def __init__( self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose + self.tp_size = tp_size if self.enable_tensor_parallelism else 1 + self.dp_size = dist.get_world_size() // self.tp_size + assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap, + ) + def support_no_sync(self) -> bool: return False @@ -440,32 +457,23 @@ def configure( # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini - self.dp_group = None - self.tp_group = None try: - dp_size = dist.get_world_size() // self.tp_size - assert dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." - self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, - enable_tensor_parallelism=self.enable_tensor_parallelism, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=self.enable_sequence_parallelism, - enable_sequence_overlap=self.enable_sequence_overlap) - shardformer = ShardFormer(shard_config) + shardformer = ShardFormer(self.shard_config) model, _ = shardformer.optimize(model) - except NotImplementedError as e: + except NotImplementedError: print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, model, **self.zero_optim_config, **self.optim_kwargs, param_info=param_info, tp_group=self.tp_group, verbose=self.verbose + optimizer, + model, + **self.zero_optim_config, + **self.optim_kwargs, + param_info=param_info, + tp_group=self.tp_group, + verbose=self.verbose, ) return model, optimizer, criterion, dataloader, lr_scheduler @@ -477,4 +485,4 @@ def get_checkpoint_io(self) -> CheckpointIO: return GeminiCheckpointIO() def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - raise NotImplementedError + raise NotImplementedError \ No newline at end of file From cf2bc63b7ff5bbbda008fb4c9750a93b7f7224da Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 1 Nov 2023 21:54:08 +0800 Subject: [PATCH 41/46] fix --- colossalai/booster/plugin/gemini_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 8d8cc6283770..0db98970fc13 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -110,7 +110,7 @@ def save_sharded_model( use_safetensors: bool = False, ): """ - Save sharded model. + Save sharded model. As there is communication when getting state dict, model.state_dict() must be called on all processes. """ assert isinstance(model, GeminiDDP), "Please boost the model before saving!" From 6c85a9e4b161d2a3440790e7c48e14167ed65251 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 1 Nov 2023 22:41:59 +0800 Subject: [PATCH 42/46] fix --- colossalai/booster/plugin/gemini_plugin.py | 56 ++++++++----------- colossalai/cluster/process_group_mesh.py | 2 +- .../tensor/d_tensor/layout_converter.py | 2 +- 3 files changed, 26 insertions(+), 34 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 0db98970fc13..ccac8a859de7 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -5,8 +5,8 @@ from typing import Callable, Iterator, List, Optional, Tuple import torch -import torch.distributed as dist import torch.nn as nn +import torch.distributed as dist from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -20,12 +20,13 @@ save_state_dict, save_state_dict_shards, ) -from colossalai.cluster import DistCoordinator, ProcessGroupMesh +from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats +from colossalai.cluster import ProcessGroupMesh from .dp_plugin_base import DPPluginBase @@ -37,7 +38,6 @@ DP_AXIS = 0 TP_AXIS = 1 - def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. @@ -55,7 +55,6 @@ def get_param_info(optim: Optimizer): return param_info - class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -110,7 +109,7 @@ def save_sharded_model( use_safetensors: bool = False, ): """ - Save sharded model. + Save sharded model. As there is communication when getting state dict, model.state_dict() must be called on all processes. """ assert isinstance(model, GeminiDDP), "Please boost the model before saving!" @@ -358,7 +357,7 @@ def __init__( enable_sequence_parallelism: bool = False, enable_jit_fused: bool = False, enable_sequence_overlap: bool = False, - verbose: bool = False, + verbose: bool = False ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" @@ -397,6 +396,7 @@ def __init__( norm_type=norm_type, ) self.enable_tensor_parallelism = enable_tensor_parallelism + self.tp_size = tp_size if self.enable_tensor_parallelism else 1 self.enable_all_optimization = enable_all_optimization self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention @@ -405,23 +405,6 @@ def __init__( self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose - self.tp_size = tp_size if self.enable_tensor_parallelism else 1 - self.dp_size = dist.get_world_size() // self.tp_size - assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." - self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - self.shard_config = ShardConfig( - tensor_parallel_process_group=self.tp_group, - enable_tensor_parallelism=self.enable_tensor_parallelism, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=self.enable_sequence_parallelism, - enable_sequence_overlap=self.enable_sequence_overlap, - ) - def support_no_sync(self) -> bool: return False @@ -457,23 +440,32 @@ def configure( # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini + self.dp_group = None + self.tp_group = None try: - shardformer = ShardFormer(self.shard_config) + dp_size = dist.get_world_size() // self.tp_size + assert dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." + self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap) + shardformer = ShardFormer(shard_config) model, _ = shardformer.optimize(model) - except NotImplementedError: + except NotImplementedError as e: print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, - model, - **self.zero_optim_config, - **self.optim_kwargs, - param_info=param_info, - tp_group=self.tp_group, - verbose=self.verbose, + optimizer, model, **self.zero_optim_config, **self.optim_kwargs, param_info=param_info, tp_group=self.tp_group, verbose=self.verbose ) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index eb4532194a26..ca2b7c710b75 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -224,4 +224,4 @@ def get_group_along_axis( if ranks_in_group not in self._ranks_to_group: # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) - return self._ranks_to_group[ranks_in_group] + return self._ranks_to_group[ranks_in_group] \ No newline at end of file diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index abe4a86d8198..fe43af788d4b 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -608,4 +608,4 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) tensor.dist_layout = target_layout - return tensor + return tensor \ No newline at end of file From 8dd4b415558fdd5a601ee98ddb6a3928d6b25c93 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 1 Nov 2023 22:53:42 +0800 Subject: [PATCH 43/46] fix --- .../tensor/d_tensor/layout_converter.py | 52 +------------------ 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index fe43af788d4b..9526a0b1da9f 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -4,7 +4,6 @@ from typing import Dict, List, Tuple import torch -import torch.distributed as dist from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * @@ -439,58 +438,11 @@ def layout_converting( MAX_TRANSFORM_STEPS = 20 total_steps = 0 transform_path = [] - comm_action_sequence: List[CommSpec] = [] + comm_action_sequence = [] spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) if spec_pairs in self.cached_solution: - # Solution Cache hit - - def _group_alive_check(cached_comm_action_sequence): - r""" - Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method. - If not deleted, return True; otherwise, return False. - - Args: - cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions. - - Returns: - bool: True if all process groups are still registered, False if at least one has been deleted. - - Raises: - RuntimeError: If there is an error while checking the status of a process group. - """ - - # Collect all process groups used in communication actions from the cached sequence - used_process_groups = [ - pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values() - ] - - # Check if each process group is still alive - for process_group in used_process_groups: - try: - dist.get_rank(process_group) - except RuntimeError as e: - # If the group is not registered, it means it has been deleted - if str(e) == ( - f"Group {process_group} is not registered, please create group with torch.distributed.new_group API" - ): - return False - elif str(e) == "The given group does not exist": - return False - else: - # Re-raise the exception if it's not related to group deletion - raise e - # All process groups are alive - return True - - cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs] - - if _group_alive_check(cached_comm_action_sequence): - # If all process groups have not been deleted, the cache is valid - return cached_transform_path, cached_comm_action_sequence - else: - # If at least one process group has been deleted, the cache is invalid, so delete it - del self.cached_solution[spec_pairs] + return self.cached_solution[spec_pairs] # We do nothing if the sharding spec is all the same. if source_spec.spec_diff(target_spec) == 0: From 3d8319ea5af98ad384a83d4f027e597bc801ea22 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 1 Nov 2023 23:05:57 +0800 Subject: [PATCH 44/46] revert --- colossalai/cluster/process_group_mesh.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index ca2b7c710b75..2d33a466e874 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -1,4 +1,3 @@ -import gc import itertools from functools import reduce from operator import mul @@ -45,24 +44,6 @@ def __init__(self, *size: int) -> None: self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} - def __del__(self): - r""" - Destructor method for the ProcessGroupMesh class. - - When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for - cleaning up any process groups that were created during the lifetime of the object. - - Note: - All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed - when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release - system resources. - """ - for group in self._ranks_to_group.values(): - dist.destroy_process_group(group) - - # Manually clear all process groups to save memory - gc.collect() - @property def shape(self) -> Tuple[int, ...]: return self._shape From 66ffed59eb7a270258e81308db61cabe398a6032 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 6 Nov 2023 19:10:47 +0800 Subject: [PATCH 45/46] modify tp gather method modify tp gather method modify tp gather method modify tp gather method --- colossalai/booster/plugin/gemini_plugin.py | 57 ++--- colossalai/cluster/process_group_mesh.py | 22 +- colossalai/shardformer/layer/normalization.py | 223 +++++------------- .../tensor/d_tensor/layout_converter.py | 54 ++++- colossalai/zero/gemini/gemini_ddp.py | 7 +- colossalai/zero/gemini/gemini_optimizer.py | 119 ++++++---- examples/inference/gptq_bloom.py | 34 ++- .../test_plugin/test_gemini_plugin.py | 4 +- 8 files changed, 254 insertions(+), 266 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index ccac8a859de7..9c7dc6836c1e 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -5,8 +5,8 @@ from typing import Callable, Iterator, List, Optional, Tuple import torch -import torch.nn as nn import torch.distributed as dist +import torch.nn as nn from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -20,13 +20,12 @@ save_state_dict, save_state_dict_shards, ) -from colossalai.cluster import DistCoordinator +from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats -from colossalai.cluster import ProcessGroupMesh from .dp_plugin_base import DPPluginBase @@ -54,7 +53,6 @@ def get_param_info(optim: Optimizer): start_index += len(group["params"]) return param_info - class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -357,7 +355,7 @@ def __init__( enable_sequence_parallelism: bool = False, enable_jit_fused: bool = False, enable_sequence_overlap: bool = False, - verbose: bool = False + verbose: bool = False, ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" @@ -396,7 +394,6 @@ def __init__( norm_type=norm_type, ) self.enable_tensor_parallelism = enable_tensor_parallelism - self.tp_size = tp_size if self.enable_tensor_parallelism else 1 self.enable_all_optimization = enable_all_optimization self.enable_fused_normalization = enable_fused_normalization self.enable_flash_attention = enable_flash_attention @@ -405,6 +402,23 @@ def __init__( self.enable_sequence_overlap = enable_sequence_overlap self.verbose = verbose + self.tp_size = tp_size if self.enable_tensor_parallelism else 1 + self.dp_size = dist.get_world_size() // self.tp_size + assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." + self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size) + self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) + self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + self.shard_config = ShardConfig( + tensor_parallel_process_group=self.tp_group, + enable_tensor_parallelism=self.enable_tensor_parallelism, + enable_all_optimization=self.enable_all_optimization, + enable_fused_normalization=self.enable_fused_normalization, + enable_flash_attention=self.enable_flash_attention, + enable_jit_fused=self.enable_jit_fused, + enable_sequence_parallelism=self.enable_sequence_parallelism, + enable_sequence_overlap=self.enable_sequence_overlap, + ) + def support_no_sync(self) -> bool: return False @@ -428,7 +442,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - param_info = get_param_info(optimizer) + optimizer_params_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -440,32 +454,21 @@ def configure( # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None) # wrap the model with Gemini - self.dp_group = None - self.tp_group = None - try: - dp_size = dist.get_world_size() // self.tp_size - assert dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size." - self.pg_mesh = ProcessGroupMesh(dp_size, self.tp_size) - self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) - self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - shard_config = ShardConfig(tensor_parallel_process_group = self.tp_group, - enable_tensor_parallelism=self.enable_tensor_parallelism, - enable_all_optimization=self.enable_all_optimization, - enable_fused_normalization=self.enable_fused_normalization, - enable_flash_attention=self.enable_flash_attention, - enable_jit_fused=self.enable_jit_fused, - enable_sequence_parallelism=self.enable_sequence_parallelism, - enable_sequence_overlap=self.enable_sequence_overlap) - shardformer = ShardFormer(shard_config) + if self.enable_tensor_parallelism: + shardformer = ShardFormer(self.shard_config) model, _ = shardformer.optimize(model) - except NotImplementedError as e: - print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): optimizer = GeminiOptimizer( - optimizer, model, **self.zero_optim_config, **self.optim_kwargs, param_info=param_info, tp_group=self.tp_group, verbose=self.verbose + optimizer, + model, + **self.zero_optim_config, + **self.optim_kwargs, + tp_group=self.tp_group, + optimizer_params_info=optimizer_params_info, + verbose=self.verbose, ) return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index 2d33a466e874..7a3bde44869c 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -1,3 +1,4 @@ +import gc import itertools from functools import reduce from operator import mul @@ -44,6 +45,24 @@ def __init__(self, *size: int) -> None: self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} + def __del__(self): + r""" + Destructor method for the ProcessGroupMesh class. + + When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for + cleaning up any process groups that were created during the lifetime of the object. + + Note: + All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed + when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release + system resources. + """ + for group in self._ranks_to_group.values(): + dist.destroy_process_group(group) + + # Manually clear all process groups to save memory + gc.collect() + @property def shape(self) -> Tuple[int, ...]: return self._shape @@ -205,4 +224,5 @@ def get_group_along_axis( if ranks_in_group not in self._ranks_to_group: # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) - return self._ranks_to_group[ranks_in_group] \ No newline at end of file + return self._ranks_to_group[ranks_in_group] + \ No newline at end of file diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 6c8dd7c9cb9c..3ced5e7d5d4a 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,19 +1,26 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - import warnings import torch.nn as nn -from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module -from typing import Tuple, Iterator, Set, Optional, Mapping, List, Any -from torch.nn import Parameter -from collections import OrderedDict -from torch.nn.modules.module import _IncompatibleKeys - from colossalai.lazy import LazyInitContext from ._operation import hook_paramter_in_backward __all__ = ["FusedLayerNorm", "FusedRMSNorm"] +try: + from apex.contrib.layer_norm.layer_norm import FastLayerNorm + EnableFastLayerNorm = True +except ImportError: + EnableFastLayerNorm = False + +try: + from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm +except ImportError: + warnings.warn( + "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" + ) + FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, 1536, @@ -41,133 +48,51 @@ 65536, ] -class LayerNormBase(nn.Module): +if EnableFastLayerNorm: + class FastLayerNormWithHook(FastLayerNorm): + def __init__(self, hidden_size, eps=0.00001): + super().__init__(hidden_size, eps) - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]: - r"""Returns an iterator over module parameters, yielding both the - name of the parameter as well as the parameter itself. - """ - gen = self._named_members( - lambda module: module._parameters.items(), - prefix=prefix, recurse=False) - for elem in gen: - yield elem - - def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True): - r"""Returns an iterator over all modules in the network, yielding - both the name of the module as well as the module itself. - """ - if memo is None: - memo = set() - if self not in memo: - if remove_duplicate: - memo.add(self) - yield prefix, self - - def state_dict(self, *args, destination=None, prefix='', keep_vars=False): - r"""Don't recursive process self._module""" - # TODO: Remove `args` and the parsing logic when BC allows. - if len(args) > 0: - if destination is None: - destination = args[0] - if len(args) > 1 and prefix == '': - prefix = args[1] - if len(args) > 2 and keep_vars is False: - keep_vars = args[2] - # DeprecationWarning is ignored by default - warnings.warn( - "Positional args are being deprecated, use kwargs instead. Refer to " - "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict" - " for details.") - - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - - local_metadata = dict(version=self._version) - if hasattr(destination, "_metadata"): - destination._metadata[prefix[:-1]] = local_metadata - - self._save_to_state_dict(destination, prefix, keep_vars) - for hook in self._state_dict_hooks.values(): - hook_result = hook(self, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - def load_state_dict(self, state_dict: Mapping[str, Any], - strict: bool = True): - r"""Don't recursive process self._module""" - if not isinstance(state_dict, Mapping): - raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) - - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = OrderedDict(state_dict) - if metadata is not None: - # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - - # Note that the hook can modify missing_keys and unexpected_keys. - incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) - for hook in module._load_state_dict_post_hooks.values(): - out = hook(module, incompatible_keys) - assert out is None, ( - "Hooks registered with ``register_load_state_dict_post_hook`` are not" - "expected to return new values, if incompatible_keys need to be modified," - "it should be done inplace." - ) - - load(self) - del load - - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in missing_keys))) - - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) - return _IncompatibleKeys(missing_keys, unexpected_keys) - -class FusedLayerNorm(LayerNormBase): + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight, self.bias) + return output + +class FusedLayerNormWithHook(ApexFusedLayerNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight, self.bias) + return output + +class FusedRMSNormWithHook(ApexFusedRMSNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight) + return output + + +class FusedLayerNorm: r""" This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface. """ - def __init__(self, layernorm=None) -> None: - super().__init__() - self.layernorm = layernorm - self._parameters = layernorm._parameters - self._buffers = layernorm._buffers + def __init__(self) -> None: + raise NotImplementedError( + "FusedLayerNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex." + ) @staticmethod def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: r""" Convert a native pytorch layer norm module to colossalai layer norm module """ - # check if apex is installed - try: - pass - except ImportError: - raise ImportError( - "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" - ) LazyInitContext.materialize(module) # get the attributes of the module @@ -181,51 +106,35 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module: use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE if use_fast_ln: - try: - from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm - except ImportError: + if EnableFastLayerNorm: + ApexFusedLayerNorm = FastLayerNormWithHook + else: # fall back to the normal fused layernorm is not built - from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + ApexFusedLayerNorm = FusedLayerNormWithHook else: - from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm + ApexFusedLayerNorm = FusedLayerNormWithHook layernorm = ( ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device) ) - layernorm.weight = module.weight layernorm.bias = module.bias - return FusedLayerNorm(layernorm=layernorm) - - def forward(self, input): - weight = self.layernorm.weight - bias = self.layernorm.bias - layernorm_output = self.layernorm(input) - output = hook_paramter_in_backward(layernorm_output, weight, bias) - return output + return layernorm -class FusedRMSNorm(LayerNormBase): +class FusedRMSNorm: """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ - - def __init__(self, rmsnorm=None) -> None: - super().__init__() - self.rmsnorm = rmsnorm - self._parameters = rmsnorm._parameters - self._buffers = rmsnorm._buffers - + def __init__(self) -> None: + raise NotImplementedError( + "FusedRMSNorm is not implemented as a physical class. " + "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex." + ) + @staticmethod def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: - try: - from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm - except ImportError: - raise ImportError( - "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" - ) - LazyInitContext.materialize(module) # to check if it is huggingface LlamaRMSNorm if module.__class__.__name__ == "LlamaRMSNorm": @@ -238,14 +147,8 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module: eps = module.eps elementwise_affine = module.elementwise_affine - rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) rmsnorm.weight = module.weight - return FusedRMSNorm(rmsnorm=rmsnorm) - - def forward(self, input): - weight = self.rmsnorm.weight - rmsnorm_output = self.rmsnorm(input) - output = hook_paramter_in_backward(rmsnorm_output, weight) - return output + return rmsnorm diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index 9526a0b1da9f..abe4a86d8198 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -4,6 +4,7 @@ from typing import Dict, List, Tuple import torch +import torch.distributed as dist from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.d_tensor.comm_spec import * @@ -438,11 +439,58 @@ def layout_converting( MAX_TRANSFORM_STEPS = 20 total_steps = 0 transform_path = [] - comm_action_sequence = [] + comm_action_sequence: List[CommSpec] = [] spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) if spec_pairs in self.cached_solution: - return self.cached_solution[spec_pairs] + # Solution Cache hit + + def _group_alive_check(cached_comm_action_sequence): + r""" + Check if the process groups required for sharding have been deleted by torch.distributed.destroy_process_group method. + If not deleted, return True; otherwise, return False. + + Args: + cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions. + + Returns: + bool: True if all process groups are still registered, False if at least one has been deleted. + + Raises: + RuntimeError: If there is an error while checking the status of a process group. + """ + + # Collect all process groups used in communication actions from the cached sequence + used_process_groups = [ + pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values() + ] + + # Check if each process group is still alive + for process_group in used_process_groups: + try: + dist.get_rank(process_group) + except RuntimeError as e: + # If the group is not registered, it means it has been deleted + if str(e) == ( + f"Group {process_group} is not registered, please create group with torch.distributed.new_group API" + ): + return False + elif str(e) == "The given group does not exist": + return False + else: + # Re-raise the exception if it's not related to group deletion + raise e + # All process groups are alive + return True + + cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs] + + if _group_alive_check(cached_comm_action_sequence): + # If all process groups have not been deleted, the cache is valid + return cached_transform_path, cached_comm_action_sequence + else: + # If at least one process group has been deleted, the cache is invalid, so delete it + del self.cached_solution[spec_pairs] # We do nothing if the sharding spec is all the same. if source_spec.spec_diff(target_spec) == 0: @@ -560,4 +608,4 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) tensor.dist_layout = target_layout - return tensor \ No newline at end of file + return tensor diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 59a00bb9f8d5..4a6542bd6a64 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -332,9 +332,6 @@ def backward(self, loss: torch.Tensor): def backward_by_grad(self, tensor, grad): raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.") - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): - torch.autograd.backward(tensor, grad) - self._post_backward() def grad_handle(self, p, grad): setattr(p, "_gemini_reduced", True) @@ -444,7 +441,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: record_tensor = torch.empty([0]) record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) if record_flag: - record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cuda() + record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).to(tensor.device) if is_distributed_tensor(tensor): global_shape = get_global_shape(tensor) device_mesh = get_device_mesh(tensor) @@ -455,7 +452,7 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: global_shape = global_shape) elif is_customized_distributed_tensor(tensor): init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) - record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() + record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 0d993246a51a..93f181ff4c73 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -20,7 +20,18 @@ from .chunk import Chunk, ChunkManager from .gemini_ddp import GeminiDDP -from colossalai.checkpoint_io.utils import search_tp_partition_dim +from colossalai.checkpoint_io.utils import gather_distributed_param +from colossalai.tensor.d_tensor import ( + distribute_tensor, + distribute_tensor_with_customization, + init_tensor_as_customization_distributed, + get_device_mesh, + get_sharding_spec, + is_customized_distributed_tensor, + is_distributed_tensor, + get_global_shape, + init_as_dtensor +) __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] @@ -95,8 +106,8 @@ def __init__( max_scale: float = 2**32, max_norm: float = 0.0, norm_type: float = 2.0, - param_info: OrderedDict = None, tp_group: ProcessGroup = None, + optimizer_params_info=None, verbose: bool = False, **defaults: Any, ): @@ -113,10 +124,10 @@ def __init__( self.chunk16_set: Set[Chunk] = set() self.clipping_flag = max_norm > 0.0 self.max_norm = max_norm - self.param_info = param_info self.tp_group = tp_group + self.optimizer_params_info = optimizer_params_info self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1 - self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else None + self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0 self.verbose = verbose self.param_groups_backup = list() @@ -444,16 +455,16 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # Every rank is collector when only_rank_0 is False. is_collector = (rank == master_rank) or (not only_rank_0) + # get tensor parallelism information + is_dtensor = is_distributed_tensor(param) + is_customized_distributed = is_customized_distributed_tensor(param) + shard_spec = get_sharding_spec(param) if is_dtensor else None + device_mesh = get_device_mesh(param) if is_dtensor else None + global_shape = get_global_shape(param) if is_dtensor else None + # If the chunk is kept gathered, # the parameteres are treated the same as that of those in strict DDP during training. # So states can be directly fetched from current device. - tp_shard_info = {} - current_shape = param.shape - original_shape = self.param_info["id2shape"][param_id] - partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) - tp_shard_info["shared"] = (current_shape, original_shape, partition_dim) - - if chunk.keep_gathered: assert param_id in self.id_to_fake_params if is_collector: @@ -465,14 +476,19 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: states["step"], dtype=torch.float32, requires_grad=False ).cpu() else: - state_tensor = states[state_name].detach().clone().to(torch.float32).cuda() - state_tensor = state_tensor.view(current_shape) - if partition_dim is not None: - gather_tensor = [torch.zeros_like(state_tensor) for _ in range(self.tp_size)] - dist.all_gather(gather_tensor, state_tensor, group=self.tp_group) - state_tensor = torch.cat(gather_tensor, dim=partition_dim) + state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() + if is_dtensor: + state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) + state_tensor = init_as_dtensor(state_tensor, + device_mesh=device_mesh, + sharding_spec=shard_spec, + global_shape = global_shape) + elif is_customized_distributed: + state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) + init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - collected_states[state_name] = torch.reshape(state_tensor, original_shape).cpu() + collected_states[state_name] = state_tensor return collected_states # Check whether the param with given id is managed by current process. @@ -485,18 +501,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu() else: - tensor_size = param.numel() - if partition_dim is not None: - tensor_size = tensor_size * self.tp_size collected_states[state_name] = torch.zeros( - tensor_size, dtype=torch.float32, requires_grad=False + param.numel(), dtype=torch.float32, requires_grad=False ).cpu() # Materials for gathering, including compacted state tensors, and the offset of shard inside each state. - compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names, tp_shard_info) if own_param else None + compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None _, shard_offset, shard_size = self.get_offsets(param_id) - if partition_dim is not None: - shard_size = shard_size * self.tp_size # Collectors gather state shards through all_gathering. gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))] @@ -519,7 +530,17 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: if is_collector: for state_name, state_tensor in collected_states.items(): if state_tensor.numel() == param.numel(): - collected_states[state_name] = torch.reshape(state_tensor, original_shape) + collected_states[state_name] = torch.reshape(state_tensor, param.shape) + if is_dtensor: + state_tensor = state_tensor.to(param.device) + state_tensor = init_as_dtensor(state_tensor, + sharding_spec=shard_spec, + device_mesh=device_mesh, + global_shape=global_shape) + elif is_customized_distributed: + state_tensor = state_tensor.to(param.device) + init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() return collected_states @@ -527,7 +548,6 @@ def pack_optimizer_states_to_tensor( self, param_id: int, state_names: list, - tp_shard_info: dict, device: torch.device = torch.device("cuda"), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: @@ -542,11 +562,6 @@ def pack_optimizer_states_to_tensor( states = self.optim.state[fake_param] shard_size = param_range[1] - param_range[0] compacted_size = 0 - - (current_shape, original_shape, partition_dim) = tp_shard_info["shared"] - if partition_dim is not None: - shard_size = shard_size * self.tp_size - for name in state_names: if name == "step": compacted_size += 1 @@ -565,11 +580,6 @@ def pack_optimizer_states_to_tensor( compacted_states[next_state_offset] = state_tensor next_state_offset += 1 else: - if partition_dim is not None: - gather_tensor = [torch.zeros(current_shape, dtype=state_tensor.dtype, device=state_tensor.device) for _ in range(self.tp_size)] - dist.all_gather(gather_tensor, state_tensor, group=self.tp_group) - state_tensor = torch.cat(gather_tensor, dim=partition_dim).view(original_shape) - assert state_tensor.numel() == shard_size compacted_states[next_state_offset : next_state_offset + shard_size].copy_(state_tensor) next_state_offset += shard_size @@ -582,7 +592,7 @@ def load_from_compacted_states( collected_states: dict, state_names: list, shard_start: int, - shard_size: int + shard_size: int, ): """ Given a tensor carrying compacted optimizer states, @@ -681,7 +691,7 @@ def load_single_param_states(self, param_id: int, saved_states: dict): Load saved optimizer states into parameter with given id. """ - def cast(param, state_range, value, key=None, tp_shard_info=None): + def cast(param, state_range, value, key=None): """ Make a copy of the needed segment of value and cast it to device of param. """ @@ -695,12 +705,13 @@ def cast(param, state_range, value, key=None, tp_shard_info=None): ret_val = torch.zeros( state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False ) - - (current_shape, _, partition_dim) = tp_shard_info["shard"] - if partition_dim is not None: - slice_size = current_shape[partition_dim] - value = value.split(slice_size, dim=partition_dim)[self.tp_rank] + if is_dtensor: + value = torch.reshape(value, global_shape) + value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) + elif is_customized_distributed: + value = torch.reshape(value, global_shape) + value = distribute_tensor_with_customization(value, real_param.shard_fn, real_param.gather_fn) ret_val.copy_(value.flatten()[state_start:state_end]) return ret_val @@ -713,14 +724,22 @@ def cast(param, state_range, value, key=None, tp_shard_info=None): # Copy states assigned to param (and cast tensors to appropriate types). updated_states = dict() - tp_shard_info = {} - current_shape = self.id_to_real_params[param_id].shape - original_shape = self.param_info["id2shape"][param_id] - partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) - tp_shard_info["shard"] = (current_shape, original_shape, partition_dim) + # get tensor parallelism information + real_param = self.id_to_real_params[param_id] + is_dtensor = is_distributed_tensor(real_param) + is_customized_distributed = is_customized_distributed_tensor(real_param) + shard_spec = get_sharding_spec(real_param) if is_dtensor else None + device_mesh = get_device_mesh(real_param) if is_dtensor else None + if is_dtensor: + global_shape = get_global_shape(real_param) + elif is_customized_distributed: + global_shape = self.optimizer_params_info["id2shape"][param_id] + else: + global_shape = None + for k, v in saved_states.items(): - updated_states[k] = cast(fake_param, state_range, v, k, tp_shard_info) + updated_states[k] = cast(fake_param, state_range, v, k) del v # clean loaded states self.optim.state[fake_param].update(updated_states) diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py index 9afa438dc1a5..7f4f9973d300 100644 --- a/examples/inference/gptq_bloom.py +++ b/examples/inference/gptq_bloom.py @@ -1,5 +1,4 @@ import argparse -import logging import os import time @@ -37,7 +36,6 @@ def print_perf_stats(latency_set, config, bs, warmup=3): def bench_bloom(args): - pretrained_model_dir = args.path quantized_model_dir = args.quantized_path max_batch_size = args.batch_size @@ -48,9 +46,9 @@ def bench_bloom(args): tokenizer.pad_token = tokenizer.eos_token # load quantized model to the first GPU - model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, - device=torch.cuda.current_device(), - inject_fused_attention=False) + model = AutoGPTQForCausalLM.from_quantized( + quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False + ) model = model.half() @@ -60,22 +58,22 @@ def bench_bloom(args): generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { - "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'), - "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda') + "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"), + "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"), } # init TPInferEngine and shard the original model # To benchmark torch original, comment out the line of optimizing model - shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, - inference_only=True, - inference_gptq=True) + shard_config = ShardConfig( + enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True + ) infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len) # prepare data for generation generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False) input_tokens = { "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)), - "attention_mask": torch.ones((max_batch_size, max_input_len)) + "attention_mask": torch.ones((max_batch_size, max_input_len)), } for t in input_tokens: if torch.is_tensor(input_tokens[t]): @@ -99,7 +97,7 @@ def bench_bloom(args): def check_bloom(rank, world_size, port, args): disable_existing_loggers() - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") bench_bloom(args) @@ -111,12 +109,12 @@ def test_bloom(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, help='Model path', required=True) - parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True) - parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size') - parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size') - parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length') - parser.add_argument('--output_len', type=int, default=128, help='Maximum output length') + parser.add_argument("-p", "--path", type=str, help="Model path", required=True) + parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True) + parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") + parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") args = parser.parse_args() diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index c27a0fd7a2e3..97ec0233f766 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -49,6 +49,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tenso booster.backward(loss, optimizer) optimizer.step() + except NotImplementedError: + print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.") except Exception as e: # raise e return repr(e) @@ -126,8 +128,6 @@ def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_pa enable_tensor_parallelism = False err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) - clear_layout_converter() - Randomizer.reset_index() torch.cuda.empty_cache() if err is None: passed_models.append(name) From c40c459cd391ae15b83efac6babf9d79c7f72605 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 8 Nov 2023 13:07:12 +0800 Subject: [PATCH 46/46] fix test --- colossalai/zero/gemini/gemini_optimizer.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 93f181ff4c73..e20d846f1071 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -460,7 +460,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: is_customized_distributed = is_customized_distributed_tensor(param) shard_spec = get_sharding_spec(param) if is_dtensor else None device_mesh = get_device_mesh(param) if is_dtensor else None - global_shape = get_global_shape(param) if is_dtensor else None + global_shape = self.optimizer_params_info["id2shape"][param_id] # If the chunk is kept gathered, # the parameteres are treated the same as that of those in strict DDP during training. @@ -488,7 +488,7 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - collected_states[state_name] = state_tensor + collected_states[state_name] = state_tensor.reshape(global_shape) return collected_states # Check whether the param with given id is managed by current process. @@ -705,7 +705,7 @@ def cast(param, state_range, value, key=None): ret_val = torch.zeros( state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False ) - + if is_dtensor: value = torch.reshape(value, global_shape) value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) @@ -730,14 +730,8 @@ def cast(param, state_range, value, key=None): is_customized_distributed = is_customized_distributed_tensor(real_param) shard_spec = get_sharding_spec(real_param) if is_dtensor else None device_mesh = get_device_mesh(real_param) if is_dtensor else None - if is_dtensor: - global_shape = get_global_shape(real_param) - elif is_customized_distributed: - global_shape = self.optimizer_params_info["id2shape"][param_id] - else: - global_shape = None + global_shape = self.optimizer_params_info["id2shape"][param_id] - for k, v in saved_states.items(): updated_states[k] = cast(fake_param, state_range, v, k) del v # clean loaded states