From 7ea5d51f7e6452cc1d43fc529290402be8396e5a Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Thu, 24 Aug 2023 19:12:04 +0800 Subject: [PATCH 01/16] add bloom inference methods and policy --- colossalai/shardformer/modeling/bloom.py | 433 ++++++++++++++++++ .../shardformer/policies/auto_policy.py | 4 +- colossalai/shardformer/policies/bloom.py | 23 + 3 files changed, 458 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 12276635ecfa..d4dafaaa66d7 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -1,3 +1,4 @@ +import math import warnings from typing import List, Optional, Tuple, Union @@ -14,6 +15,8 @@ TokenClassifierOutput, ) from transformers.models.bloom.modeling_bloom import ( + BloomAttention, + BloomBlock, BloomForCausalLM, BloomForQuestionAnswering, BloomForSequenceClassification, @@ -22,7 +25,11 @@ ) from transformers.utils import logging +from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.shardformer.inference import BatchInferState def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -91,6 +98,67 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, return build_bloom_alibi_tensor +def generate_alibi(n_head, dtype=torch.float16): + """ + This method is originally the `build_alibi_tensor` function + in `transformers/models/bloom/modeling_bloom.py` + of the huggingface/transformers GitHub repository. + + Copyright 2023 ModelTC Team + Copyright 2022 HuggingFace Inc. team and BigScience workshop + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ """ + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.Tensor(get_slopes(n_head)) + head_alibi = slopes.to(dtype) + return head_alibi # 1 * num_heads + + +def generate_alibi_2(n_head, dtype=torch.float16): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + return [start * start**i for i in range(n)] + + def get_slopes(n): + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) + slopes_double = get_slopes(2 * closest_power_of_2) + slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] + return slopes_combined + + slopes = torch.tensor(get_slopes(n_head), dtype=dtype) + return slopes + + class BloomPipelineForwards: ''' This class serves as a micro library for bloom pipeline forwards. @@ -678,6 +746,371 @@ def bloom_for_question_answering_forward( return {'hidden_states': hidden_states} +class BloomInferenceForwards: + """ + This class serves a micro library for bloom inference forwards + """ + + @staticmethod + def bloom_model_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # initialize BatchInferState to track necessary states during current model forward + infer_state = BatchInferState() + infer_state.batch_size = batch_size + # TODO: dummy implementation here for testing, assume all inputs same length + infer_state.total_token_num = batch_size * seq_length + infer_state.block_loc = self.block_loc + infer_state.start_loc = self.b_start_loc + infer_state.seq_len = self.b_seq_len + + # still need to keep past_key_values to fit original forward flow + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + if self.cache_manager.past_key_values_length > 0: + # TODO dummy but work, revise it + past_key_values_length = self.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + infer_state.cache_manager = self.cache_manager + + if use_cache and seq_length != 1: + # NOTE assuem prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, + infer_state.prefill_mem_index) + else: + # TODO handle the condition that no contiguous memory presents + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + # NOTE we might want to store a single 1D alibi(length is #heads) in model + alibi = generate_alibi(self.num_heads).contiguous().cuda() + # alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + # FIXME: currently our KV cache manager does not handle this condition + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + infer_state=infer_state, + ) + + hidden_states = outputs[0] + if use_cache is True: + 
presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # NOTE: remember to update indices of kv cache block + self.b_start_loc = self.b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + self.b_seq_len += 1 + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, # should always be (None, None, ..., None) + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + # replace decoder layer forward: + # used to replace BloomBlock.forward + def bloom_block_forward( + self, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + infer_state=infer_state, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. 
+ output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + # replace attention forward: + # used to replace BloomAttention.forward + def bloom_attention_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, q_length, H, D_HEAD = query_layer.shape + k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + + mem_manager = infer_state.cache_manager + + if infer_state.is_context_stage: + # context process + max_input_len = q_length + b_start_loc = infer_state.start_loc + b_seq_len = infer_state.seq_len[:batch_size] + q = query_layer.reshape(-1, H, D_HEAD) + + copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[self.layer_id]) + copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[self.layer_id]) + + # output = self.output[:batch_size*q_length, :, :] + output = torch.empty_like(q) + + context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + # FIXME might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if self.layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length = q_length # seq_len + else: + # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) + assert q_length == 1, "for non-context process, we only support q_length == 1" + q = query_layer.reshape(-1, H, D_HEAD) + + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + cache_k = infer_state.cache_manager.key_buffer[self.layer_id][ + infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + cache_v = infer_state.cache_manager.value_buffer[self.layer_id][ + infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + cache_k.copy_(k) + cache_v.copy_(v) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] + # TODO clean comments + # destindex_copy_kv(k, infer_state.decode_mem_index, mem_manager.key_buffer[self.layer_id]) + # destindex_copy_kv(v, infer_state.decode_mem_index, mem_manager.value_buffer[self.layer_id]) + copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[self.layer_id]) + copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[self.layer_id]) + + b_start_loc = infer_state.start_loc[:batch_size] + b_loc = infer_state.block_loc[:batch_size, :] + b_seq_len = 
infer_state.seq_len[:batch_size] + max_len_in_batch = mem_manager.past_key_values_length + q_length + output = torch.empty_like(q) + token_attention_fwd(q, mem_manager.key_buffer[self.layer_id], mem_manager.value_buffer[self.layer_id], + output, alibi, b_loc, b_start_loc, b_seq_len, max_len_in_batch) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + # FIXME might want to revise (same as above one) + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if self.layer_id == 0: # once per model.forward + assert infer_state.cache_manager.past_key_values_length != 0 + infer_state.cache_manager.past_key_values_length += q_length # += 1 + + # NOTE: always set present as none for now, instead of returning past key value to the next decoding, + # we create the past key value pair from the cache manager + present = None + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # dropout is not required here during inference + output_tensor = residual + output_tensor + + outputs = (output_tensor, present) + assert output_attentions is False, "we do not support output_attentions at this time" + + return outputs + + def get_bloom_flash_attention_forward(enabel_jit_fused=False): try: diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0ffa7fbeeab1..71d58e3e210a 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -135,8 +135,8 @@ class PolicyLocation: # LlaMa "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), - "transformers.models.llama.modeling_llama.LlamaForCausalLM": - PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.bloom.modeling_bloom.BloomModel": + PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy") } diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index b35764db3870..147cbea35a43 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -9,6 +9,7 @@ from .._utils import getattr_, setattr_ from ..modeling.bloom import ( + BloomInferenceForwards, BloomPipelineForwards, build_bloom_alibi_tensor_fn, get_bloom_flash_attention_forward, @@ -209,6 +210,28 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] +class BloomModelInferPolicy(BloomPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomModel + policy = super().module_policy() + # TODO might want to set inference config to shard config + + # NOTE ignore tp, pp at this moment? 
+ if self.shard_config.enable_tensor_parallelism: + policy[BloomModel] = ModulePolicyDescription( + method_replacement={"forward": BloomInferenceForwards.bloom_model_forward}) + policy[BloomBlock] = ModulePolicyDescription( + method_replacement={"forward": BloomInferenceForwards.bloom_block_forward}) + policy[BloomAttention] = ModulePolicyDescription( + method_replacement={"forward": BloomInferenceForwards.bloom_attention_forward}) + + return policy + + class BloomForCausalLMPolicy(BloomPolicy): def module_policy(self): From 5f2e841fc6f46855d434b6f585c6cbb89a92bb54 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Fri, 25 Aug 2023 14:14:02 +0800 Subject: [PATCH 02/16] enable pass BatchInferState from model forward --- colossalai/shardformer/modeling/bloom.py | 134 ++++++++++++++++++++--- colossalai/shardformer/policies/bloom.py | 12 +- 2 files changed, 130 insertions(+), 16 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index d4dafaaa66d7..5ba307be75a1 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -753,7 +753,7 @@ class BloomInferenceForwards: @staticmethod def bloom_model_forward( - self, + self: BloomModel, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, @@ -763,6 +763,7 @@ def bloom_model_forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: @@ -793,16 +794,16 @@ def bloom_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - # initialize BatchInferState to track necessary states during current model forward - infer_state = BatchInferState() - infer_state.batch_size = batch_size - # TODO: dummy implementation here for testing, assume all inputs same length - infer_state.total_token_num = batch_size * seq_length - infer_state.block_loc = self.block_loc - infer_state.start_loc = self.b_start_loc - infer_state.seq_len = self.b_seq_len + # # initialize BatchInferState to track necessary states during current model forward + # infer_state = BatchInferState() + # infer_state.batch_size = batch_size + # # TODO: dummy implementation here for testing, assume all inputs same length + # infer_state.total_token_num = batch_size * seq_length + # infer_state.block_loc = self.block_loc + # infer_state.start_loc = self.b_start_loc + # infer_state.seq_len = self.b_seq_len - # still need to keep past_key_values to fit original forward flow + # still need to keep past_key_values to fit original forward flow· if past_key_values is None: past_key_values = tuple([None] * len(self.h)) @@ -836,7 +837,7 @@ def bloom_model_forward( # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - infer_state.cache_manager = self.cache_manager + # infer_state.cache_manager = self.cache_manager if use_cache and seq_length != 1: # NOTE assuem prefill stage @@ -942,10 +943,116 @@ def custom_forward(*inputs): attentions=all_self_attentions, ) + @staticmethod + def bloom_for_causal_lm_forward(self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = 
None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def bloom_for_causal_lm_prepare_inputs_for_generation( + self: BloomForCausalLM, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # NOTE we won't use past key values here + # the cache may be in the stardard format (e.g. 
in contrastive search), convert to bloom's format if needed + # if past_key_values[0][0].shape[0] == input_ids.shape[0]: + # past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) + return model_inputs + # replace decoder layer forward: # used to replace BloomBlock.forward + @staticmethod def bloom_block_forward( - self, + self: BloomBlock, hidden_states: torch.Tensor, alibi: torch.Tensor, attention_mask: torch.Tensor, @@ -1003,8 +1110,9 @@ def bloom_block_forward( # replace attention forward: # used to replace BloomAttention.forward + @staticmethod def bloom_attention_forward( - self, + self: BloomAttention, hidden_states: torch.Tensor, residual: torch.Tensor, alibi: torch.Tensor, diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 147cbea35a43..7260ab263f8c 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -216,7 +216,7 @@ def __init__(self) -> None: super().__init__() def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomModel + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel policy = super().module_policy() # TODO might want to set inference config to shard config @@ -224,11 +224,17 @@ def module_policy(self): if self.shard_config.enable_tensor_parallelism: policy[BloomModel] = ModulePolicyDescription( method_replacement={"forward": BloomInferenceForwards.bloom_model_forward}) + policy[BloomForCausalLM] = ModulePolicyDescription( + method_replacement={"forward": BloomInferenceForwards.bloom_for_causal_lm_forward}) policy[BloomBlock] = ModulePolicyDescription( - method_replacement={"forward": BloomInferenceForwards.bloom_block_forward}) + method_replacement={ + "forward": + BloomInferenceForwards.bloom_block_forward, + "prepare_inputs_for_generation": + BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation + }) policy[BloomAttention] = ModulePolicyDescription( method_replacement={"forward": BloomInferenceForwards.bloom_attention_forward}) - return policy From 686e56e7ccada4e0b9c8910b81a32ed1514f722b Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Fri, 25 Aug 2023 20:10:53 +0800 Subject: [PATCH 03/16] revise bloom infer layers/policies --- colossalai/shardformer/modeling/bloom.py | 56 ++++++++++++------- .../shardformer/policies/auto_policy.py | 4 +- colossalai/shardformer/policies/bloom.py | 5 ++ 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 5ba307be75a1..7eb9663ae68f 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -828,12 +828,19 @@ def bloom_model_forward( "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") use_cache = False + # NOTE determine if BatchInferState is passed in via arg + # if not, get the attr binded to the model + # We might wantto remove setattr later + if infer_state is None: + infer_state = self.infer_state + # Compute alibi tensor: check build_alibi_tensor documentation seq_length_with_past = seq_length past_key_values_length = 0 - if self.cache_manager.past_key_values_length > 0: + # if self.cache_manager.past_key_values_length > 0: + if infer_state.cache_manager.past_key_values_length > 0: # TODO dummy but work, revise it - past_key_values_length = self.cache_manager.past_key_values_length + past_key_values_length = infer_state.cache_manager.past_key_values_length # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length @@ -845,7 +852,7 @@ def bloom_model_forward( infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.prefill_mem_index) + infer_state.context_mem_index) else: # TODO handle the condition that no contiguous memory presents alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) @@ -874,6 +881,7 @@ def bloom_model_forward( # NOTE we might want to store a single 1D alibi(length is #heads) in model alibi = generate_alibi(self.num_heads).contiguous().cuda() + print(f" self.num_heads = {self.num_heads}") # alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) causal_mask = self._prepare_attn_mask( @@ -929,9 +937,12 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - # NOTE: remember to update indices of kv cache block - self.b_start_loc = self.b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - self.b_seq_len += 1 + # NOTE: here we still to update indices of kv cache block + # TODO: remove this part, instead, better to pass the BatchInferState from model forward, + # and update these information in engine.generate after model foward called + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.decode_layer_id = 0 if not return_dict: return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) @@ -1133,6 +1144,7 @@ def bloom_attention_forward( v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 mem_manager = infer_state.cache_manager + layer_id = infer_state.decode_layer_id if infer_state.is_context_stage: # context process @@ -1141,19 +1153,22 @@ def bloom_attention_forward( b_seq_len = infer_state.seq_len[:batch_size] q = query_layer.reshape(-1, H, D_HEAD) - copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[self.layer_id]) - copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[self.layer_id]) + print(f" k.shape: {k.shape}") + print(f" infer_state.context_mem_index: {infer_state.context_mem_index}") + print(f" mem_manager.key_buffer[layer_id].shape: {mem_manager.key_buffer[layer_id].shape}") + copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) # 
output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) - context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) # FIXME might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if self.layer_id == 0: # once per model.forward + if layer_id == 0: # once per model.forward infer_state.cache_manager.past_key_values_length = q_length # seq_len else: # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) @@ -1163,9 +1178,9 @@ def bloom_attention_forward( if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - cache_k = infer_state.cache_manager.key_buffer[self.layer_id][ + cache_k = infer_state.cache_manager.key_buffer[layer_id][ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] - cache_v = infer_state.cache_manager.value_buffer[self.layer_id][ + cache_v = infer_state.cache_manager.value_buffer[layer_id][ infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] cache_k.copy_(k) cache_v.copy_(v) @@ -1173,27 +1188,30 @@ def bloom_attention_forward( # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] # TODO clean comments - # destindex_copy_kv(k, infer_state.decode_mem_index, mem_manager.key_buffer[self.layer_id]) - # destindex_copy_kv(v, infer_state.decode_mem_index, mem_manager.value_buffer[self.layer_id]) - copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[self.layer_id]) - copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[self.layer_id]) + # destindex_copy_kv(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) + # destindex_copy_kv(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) + copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) b_start_loc = infer_state.start_loc[:batch_size] b_loc = infer_state.block_loc[:batch_size, :] b_seq_len = infer_state.seq_len[:batch_size] max_len_in_batch = mem_manager.past_key_values_length + q_length output = torch.empty_like(q) - token_attention_fwd(q, mem_manager.key_buffer[self.layer_id], mem_manager.value_buffer[self.layer_id], - output, alibi, b_loc, b_start_loc, b_seq_len, max_len_in_batch) + token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc, + b_start_loc, b_seq_len, max_len_in_batch, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) # FIXME might want to revise (same as above one) # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if self.layer_id == 0: # once per model.forward + if layer_id == 0: # once per model.forward assert infer_state.cache_manager.past_key_values_length != 0 infer_state.cache_manager.past_key_values_length += q_length # += 1 + # update layer id + infer_state.decode_layer_id += 1 + # NOTE: always set present as none for now, instead of returning past key value to the next decoding, # we create the past key value pair from the cache manager present = None 
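
Taken together, the forwards patched above expect a BatchInferState object to be prepared outside the model, attached to the (sharded) transformer, and backed by a shared MemoryManager KV-cache pool. Below is a minimal usage sketch of that wiring; it is illustrative only, is not part of the patch, uses hypothetical model sizes, and assumes the constructor signatures and field names that appear in the engine code added later in this series (the import path shown is the one used by these commits and is moved under colossalai.inference in a later patch).

    import torch
    from colossalai.shardformer.inference import BatchInferState, MemoryManager  # path as of this series

    batch_size, max_input_len, max_output_len = 2, 16, 8
    head_num, head_dim, layer_num = 32, 128, 30          # hypothetical BLOOM-like sizes
    max_total_tokens = batch_size * (max_input_len + max_output_len)

    # one shared KV-cache pool for the whole model (signature as used in engine.py below)
    cache_manager = MemoryManager(max_total_tokens, torch.float16, head_num, head_dim, layer_num)

    # per-batch bookkeeping consumed by bloom_model_forward / bloom_attention_forward above
    seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda")     # actual prompt lengths
    infer_state = BatchInferState(batch_size, int(seq_lens.max()))
    infer_state.seq_len = seq_lens
    infer_state.start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
    infer_state.block_loc = torch.empty((batch_size, max_input_len + max_output_len),
                                        dtype=torch.long, device="cuda")
    infer_state.set_cache_manager(cache_manager)
    infer_state.is_context_stage = True      # first forward is the prefill pass
    infer_state.decode_layer_id = 0
    infer_state.past_key_values_len = 0
    # total_token_num is assumed to be derived from seq_len (see the NOTE in prepare_batch_state)

    # the patched forwards look the state up on the base model when it is not passed explicitly
    model.transformer.infer_state = infer_state
    output_ids = model.generate(input_ids.cuda(), max_new_tokens=max_output_len, do_sample=False)
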
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 71d58e3e210a..9a1be4d146d1 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -136,7 +136,9 @@ class PolicyLocation: "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), "transformers.models.bloom.modeling_bloom.BloomModel": - PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy") + PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), + "transformers.models.bloom.modeling_bloom.BloomForCausalLM": + PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), } diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index 7260ab263f8c..f45ab7002b27 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -44,6 +44,11 @@ def module_policy(self): policy = {} + print(f" BloomPolicyt.module_policy:") + print(f" self.shard_config.enable_tensor_parallelism: {self.shard_config.enable_tensor_parallelism}") + print(f" self.shard_config.tensor_parallel_size: {self.shard_config.tensor_parallel_size}") + print(f" self.model.config.n_head: {self.model.config.n_head}") + if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, From 0cbb25127648c4e2ea5de56e9eb8973afcc89742 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Fri, 25 Aug 2023 20:12:32 +0800 Subject: [PATCH 04/16] add engine for inference (draft) --- colossalai/shardformer/inference/engine.py | 245 +++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100644 colossalai/shardformer/inference/engine.py diff --git a/colossalai/shardformer/inference/engine.py b/colossalai/shardformer/inference/engine.py new file mode 100644 index 000000000000..5b257b1bf5b9 --- /dev/null +++ b/colossalai/shardformer/inference/engine.py @@ -0,0 +1,245 @@ +from functools import partial +from types import MethodType +from typing import Any, Callable, List, Optional, Set, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from transformers.generation import GenerationConfig +from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList +from transformers.tokenization_utils_base import BatchEncoding + +from colossalai.cluster import ProcessGroupMesh +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.inference import BatchInferState, MemoryManager +# from colossalai.shardformer.policies.bloom import BloomModelInferPolicy +from colossalai.shardformer.policies.auto_policy import get_autopolicy + +DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 + + +class InferenceEngine: + + def __init__(self, model: nn.Module, max_batch_size, max_input_len, max_output_len, tp_size=1) -> None: + self.model = model + self.sharded_model = None + + self.max_batch_size = max_batch_size + self.max_input_len = max_input_len + self.max_output_len = max_output_len + self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) + assert self.max_batch_size <= 64 + assert self.max_input_len + self.max_output_len <= 2048 + + self.tp_size = tp_size + self.pp_size = 1 # only consider tp for now + self.dp_size = 1 # only consider tp for 
now + + self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads + self.head_num = self.model.config.num_attention_heads // self.tp_size + self.layer_num = self.model.config.num_hidden_layers + self.cache_manager = MemoryManager(self.max_total_token_num, torch.float16, self.head_num, self.head_dim, + self.layer_num) + + # self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) + # self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) + + def shard_model_by(self, shardformer: ShardFormer) -> None: + # TODO Might want to use infer policy only when bs >= 4 + assert self.tp_size == shardformer.shard_config.tensor_parallel_size, "Engine tp size != shardformer tp size" + # shardformer.shard_config.tensor_parallel_process_group = self.tp_group + model_name = self.model.__class__.__name__ + policy = get_autopolicy(self.model, inference_only=True) + if model_name == 'LlamaForCausalLM': + self.sharded_model, _ = shardformer.optimize(self.model, policy) + elif model_name == 'BloomForCausalLM': + self.sharded_model, _ = shardformer.optimize(self.model, policy) + else: + raise ValueError(f'Unsupported model "{model_name}" for inference') + self.sharded_model = self.sharded_model.cuda() + + # NOTE input_tokens is expected to be BatchEncoding, + # instead of only input token ids + @torch.no_grad() + def generate_by_pass_infer_state(self, + input_tokens, + max_out_length: int, + generation_config: Optional[GenerationConfig] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, + **model_kwargs) -> torch.Tensor: + + input_ids = input_tokens['input_ids'] + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + if batch_size >= 4: + assert self.sharded_model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + batch_size = batch_infer_state.batch_size + assert batch_infer_state.max_len_in_batch <= self.max_input_len + + # record sequences finish status, add early stopping, etc, + + for _ in range(min(max_out_length, self.max_output_len)): + # ... 
+ self.sharded_model.forward(..., **model_kwargs) + else: + # Use original model + orig_model = self.model + + for _ in range(min(max_out_length, self.max_output_len)): + + if prepare_inputs_fn is None and hasattr(orig_model, 'prepare_inputs_for_generation'): + prepare_inputs_fn = orig_model.prepare_inputs_for_generation + + model_inputs = prepare_inputs_fn(input_ids, ** + model_kwargs) if prepare_inputs_fn is not None else input_tokens + outputs = orig_model(**model_inputs) + + # next_token_logits = outputs['logits'][:, -1, :] + next_token_logits = outputs.logits[:, -1, :] + # pre-process distribution + # next_token_logits = logits_processor(input_ids, next_token_logits) + + # sample + # probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) + # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # consider greedy only for now + next_tokens = torch.argmax(next_token_logits, dim=-1) + + # finished sentences should have their next token be a padding token + + # if eos_token_id is not None: + # if pad_token_id is None: + # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # # update generated ids, model inputs for next step + # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + # if update_model_kwargs_fn is not None: + # model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) + + # # if eos_token was found in one sentence, set sentence to finished + # if eos_token_id is not None: + # unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + # # stop when each sentence is finished if early_stopping=True + # if early_stopping and _is_sequence_finished(unfinished_sequences): + # break + + return input_ids + + @torch.no_grad() + def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopping=False): + + # for testing, always use sharded model + assert self.sharded_model is not None, "sharded model does not exist" + + batch_infer_state = self.prepare_batch_state(input_tokens) + assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" + + # set BatchInferState for the current batch as attr to model + # NOTE this is not an expectable way to pass BatchInferState during inference + # we might want to rewrite generate function (e.g. generate_by_pass_infer_state) + # and pass BatchInferState via model forward + if hasattr(self.sharded_model, 'model'): + model = self.sharded_model.model + elif hasattr(self.sharded_model, 'transformer'): + model = self.sharded_model.transformer + setattr(model, 'infer_state', batch_infer_state) + + # add logging + generate_kwargs.update(max_new_tokens=self.max_output_len) + + # convert to dict + if isinstance(input_tokens, torch.Tensor): + input_tokens = dict(input_ids=input_tokens) + for t in input_tokens: + if torch.is_tensor(input_tokens[t]): + input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) + print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") + + outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=early_stopping) + + print(f"outputs.shape {outputs.shape}") + return outputs + + # inputs should be one of the following types + # 1. BatchEncoding (e.g. tokenizer batch_encode) + # 2. list of input token ids (e.g. appended result of tokenizer encode) + # 3. torch.Tensor (e.g. 
tokenizer encode with return_tensors='pt') + # NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve + # the actual length (e.g. number of tokens) of each input without attention mask + # Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume + # all the inputs in the batch has the maximum length l + def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInferState: + # records length based on attention mask + # Any better method? + if not isinstance(inputs, (BatchEncoding, list, torch.Tensor)): + raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") + + if isinstance(inputs, BatchEncoding): + attn_masks = inputs['attention_mask'] + batch_size = attn_masks.shape[0] + max_len_in_batch = attn_masks.shape[1] + elif isinstance(inputs, list): + batch_size = len(inputs) + else: + batch_size = inputs.shape[0] + + block_loc = torch.empty(batch_size, self.max_input_len + self.max_output_len, dtype=torch.long, device="cuda") + seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + start_index = 0 + if isinstance(inputs, BatchEncoding): + for i, attn_mask in enumerate(attn_masks): + curr_seq_len = torch.sum(attn_mask) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + else: + max_len_in_batch = -1 + for i, input_ids in enumerate(inputs): + curr_seq_len = len(input_ids) + seq_lengths[i] = curr_seq_len + seq_start_indexes[i] = start_index + start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + + batch_infer_state = BatchInferState(batch_size, max_len_in_batch) + batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device + batch_infer_state.start_loc = seq_start_indexes.to('cuda') + batch_infer_state.block_loc = block_loc + # NOTE BatchInferState.total_token_num revised (not pushed yet) + # Now we want actual total token num based on seq_len, instead of dummy ones in test + # (Could still use the dummy one for testing usage) + batch_infer_state.set_cache_manager(self.cache_manager) + batch_infer_state.decode_layer_id = 0 + batch_infer_state.past_key_values_len = 0 + batch_infer_state.is_context_stage = True + return batch_infer_state + + # BatchInferState is created and kept during generation + # after each iter of model forward, we should update BatchInferState + # NOTE use in rewritten generate method: use after model.forward + def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: + # self.b_start_loc = self.b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + # self.b_seq_len += 1 + batch_size = infer_state.batch_size + device = infer_state.start_loc.device + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) + infer_state.seq_len += 1 + + # TODO might want to create a sequence pool + # add a single request/sequence/input text at a time and record its length + # In other words, store the actual length of input tokens representing a single input text + # E.g. 
"Introduce landmarks in Beijing" + # => add request + # => record token length and other necessary information to be used + # => engine hold all these necessary information until `generate` (or other name) is called, + # => put information already recorded in batchinferstate and pass it to model forward + # => clear records in engine + def add_request(): + pass From 0ec07ca6c8fb482880c6603f98de93fc782fd0d4 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Fri, 25 Aug 2023 20:13:14 +0800 Subject: [PATCH 05/16] add test for bloom infer --- tests/test_infer/test_bloom_infer.py | 59 ++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/test_infer/test_bloom_infer.py diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py new file mode 100644 index 000000000000..bfe234198c34 --- /dev/null +++ b/tests/test_infer/test_bloom_infer.py @@ -0,0 +1,59 @@ +import pytest +import torch +from transformers import AutoTokenizer, BloomForCausalLM + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer.inference import InferenceEngine +from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn + +TP_SIZE = 2 + + +# @parameterize +def run(): + # dummly set the model path, will revise later + # bloom model + model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b" + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + + text = "Introduce some landmarks in Beijing" + input_ids = tokenizer.encode(text, return_tensors='pt') + + pg_mesh = ProcessGroupMesh(1, 1, TP_SIZE) + shardconfig = ShardConfig( + tensor_parallel_process_group=pg_mesh.get_group_along_axis(2), + enable_tensor_parallelism=True, + inference_only=True, + ) + shardformer = ShardFormer(shard_config=shardconfig) + + infer_engine = InferenceEngine(model.half(), 4, 12, 8, tp_size=TP_SIZE) + infer_engine.shard_model_by(shardformer) + + generate_kwargs = dict(do_sample=False) + outputs = infer_engine.generate_by_set_infer_state(input_ids, generate_kwargs) + + output_text = tokenizer.decode(outputs) + print(output_text) + + +def check_bloom(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_bloom_infer(): + spawn(check_bloom, TP_SIZE) + + +if __name__ == '__main__': + test_bloom_infer() From 960eea381ce744c638649b61c26776b627e6fa2b Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Mon, 28 Aug 2023 10:52:05 +0800 Subject: [PATCH 06/16] fix bloom infer policy and flow --- colossalai/shardformer/inference/engine.py | 11 +++- colossalai/shardformer/modeling/bloom.py | 16 +++-- colossalai/shardformer/policies/bloom.py | 69 +++++++++++++--------- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/colossalai/shardformer/inference/engine.py b/colossalai/shardformer/inference/engine.py index 5b257b1bf5b9..0189f6d61b7c 100644 --- a/colossalai/shardformer/inference/engine.py +++ b/colossalai/shardformer/inference/engine.py @@ -11,10 +11,12 @@ from colossalai.cluster import ProcessGroupMesh from colossalai.shardformer 
import ShardConfig, ShardFormer -from colossalai.shardformer.inference import BatchInferState, MemoryManager # from colossalai.shardformer.policies.bloom import BloomModelInferPolicy from colossalai.shardformer.policies.auto_policy import get_autopolicy +from .batch_infer_state import BatchInferState +from .kvcache_manager import MemoryManager + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 @@ -189,18 +191,20 @@ def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInf else: batch_size = inputs.shape[0] - block_loc = torch.empty(batch_size, self.max_input_len + self.max_output_len, dtype=torch.long, device="cuda") + # block_loc = torch.empty(batch_size, self.max_input_len + self.max_output_len, dtype=torch.long, device="cuda") seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 + + max_len_in_batch = -1 if isinstance(inputs, BatchEncoding): for i, attn_mask in enumerate(attn_masks): curr_seq_len = torch.sum(attn_mask) seq_lengths[i] = curr_seq_len seq_start_indexes[i] = start_index start_index += curr_seq_len + max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch else: - max_len_in_batch = -1 for i, input_ids in enumerate(inputs): curr_seq_len = len(input_ids) seq_lengths[i] = curr_seq_len @@ -208,6 +212,7 @@ def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInf start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device batch_infer_state.start_loc = seq_start_indexes.to('cuda') diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 7eb9663ae68f..493323438f1e 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -881,7 +881,6 @@ def bloom_model_forward( # NOTE we might want to store a single 1D alibi(length is #heads) in model alibi = generate_alibi(self.num_heads).contiguous().cuda() - print(f" self.num_heads = {self.num_heads}") # alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) causal_mask = self._prepare_attn_mask( @@ -1153,15 +1152,24 @@ def bloom_attention_forward( b_seq_len = infer_state.seq_len[:batch_size] q = query_layer.reshape(-1, H, D_HEAD) - print(f" k.shape: {k.shape}") - print(f" infer_state.context_mem_index: {infer_state.context_mem_index}") - print(f" mem_manager.key_buffer[layer_id].shape: {mem_manager.key_buffer[layer_id].shape}") copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) + # temp for testing + if not dist.is_initialized() or dist.get_rank() == 0: + print(f" cp3") + print(f" q.shape: {q.shape}") + print(f" k.shape: {k.shape}") + print(f" v.shape: {v.shape}") + print(f" output.shape: {output.shape}") + print(f" b_start_loc: {b_start_loc}") + print(f" b_seq_len: {b_seq_len}") + print(f" max_input_len: {max_input_len}") + print(f" alibi: {alibi}") + bloom_context_attn_fwd(q, k, v, output, b_start_loc, 
b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index f45ab7002b27..d8bbaeeb7fd9 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -215,34 +215,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] -class BloomModelInferPolicy(BloomPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel - policy = super().module_policy() - # TODO might want to set inference config to shard config - - # NOTE ignore tp, pp at this moment? - if self.shard_config.enable_tensor_parallelism: - policy[BloomModel] = ModulePolicyDescription( - method_replacement={"forward": BloomInferenceForwards.bloom_model_forward}) - policy[BloomForCausalLM] = ModulePolicyDescription( - method_replacement={"forward": BloomInferenceForwards.bloom_for_causal_lm_forward}) - policy[BloomBlock] = ModulePolicyDescription( - method_replacement={ - "forward": - BloomInferenceForwards.bloom_block_forward, - "prepare_inputs_for_generation": - BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation - }) - policy[BloomAttention] = ModulePolicyDescription( - method_replacement={"forward": BloomInferenceForwards.bloom_attention_forward}) - return policy - - class BloomForCausalLMPolicy(BloomPolicy): def module_policy(self): @@ -281,6 +253,47 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] +class BloomModelInferPolicy(BloomForCausalLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel + policy = super().module_policy() + # NOTE set inference mode to shard config + self.shard_config._infer() + + if self.shard_config.enable_tensor_parallelism: + + method_replacement = { + 'forward': + BloomInferenceForwards.bloom_for_causal_lm_forward, + 'prepare_inputs_for_generation': + BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomForCausalLM) + + method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomModel) + + method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomBlock) + + method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomAttention) + + return policy + + class BloomForSequenceClassificationPolicy(BloomPolicy): def module_policy(self): From f511cb1ce4d494a416c542a16efa3ecf68fd745a Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 14:10:34 +0800 Subject: [PATCH 07/16] revise bloom test --- tests/test_infer/test_bloom_infer.py | 45 +++++++++++++++------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index bfe234198c34..6381323555b6 100644 --- 
a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -1,48 +1,51 @@ import pytest import torch -from transformers import AutoTokenizer, BloomForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM import colossalai -from colossalai.cluster import ProcessGroupMesh +from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.shardformer.inference import InferenceEngine from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn TP_SIZE = 2 +MAX_BATCH_SIZE = 4 +MAX_INPUT_LEN = 16 +MAX_OUTPUT_LEN = 8 -# @parameterize def run(): - # dummly set the model path, will revise later - # bloom model + model_path = "/data3/data/model_eval_for_commerical_use/phoenix-inst-chat-7b" tokenizer = AutoTokenizer.from_pretrained(model_path) tokenizer.pad_token = tokenizer.eos_token - model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) text = "Introduce some landmarks in Beijing" input_ids = tokenizer.encode(text, return_tensors='pt') - pg_mesh = ProcessGroupMesh(1, 1, TP_SIZE) - shardconfig = ShardConfig( - tensor_parallel_process_group=pg_mesh.get_group_along_axis(2), - enable_tensor_parallelism=True, - inference_only=True, - ) - shardformer = ShardFormer(shard_config=shardconfig) + # model_config = BloomConfig() + # model = BloomForCausalLM(model_config) + # model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) + model = model.half() + model.to(torch.cuda.current_device()) + + shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True) + shardformer = ShardFormer(shard_config=shard_config) - infer_engine = InferenceEngine(model.half(), 4, 12, 8, tp_size=TP_SIZE) + infer_engine = TPInferEngine(model, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + infer_engine.prepare_with_shard_config(shard_config=shard_config) infer_engine.shard_model_by(shardformer) generate_kwargs = dict(do_sample=False) - outputs = infer_engine.generate_by_set_infer_state(input_ids, generate_kwargs) + outputs = infer_engine.generate(input_ids, generate_kwargs) - output_text = tokenizer.decode(outputs) + print(outputs) + output_text = tokenizer.decode(outputs[0]) print(output_text) -def check_bloom(rank, world_size, port): +def check_engine(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run() @@ -51,9 +54,9 @@ def check_bloom(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() -def test_bloom_infer(): - spawn(check_bloom, TP_SIZE) +def test_engine_infer(): + spawn(check_engine, TP_SIZE) if __name__ == '__main__': - test_bloom_infer() + test_engine_infer() From 41a3bf52c57dda70dc8d77323dda1238496d9139 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 14:29:55 +0800 Subject: [PATCH 08/16] fix bloom file path --- .../tensor_parallel/modeling/__init__.py | 3 + .../tensor_parallel/modeling/bloom.py | 582 ++++++++++++++++++ .../tensor_parallel/policies/__init__.py | 3 + .../tensor_parallel/policies/bloom.py | 44 ++ colossalai/shardformer/modeling/bloom.py | 567 ----------------- colossalai/shardformer/policies/bloom.py | 42 -- 6 files 
changed, 632 insertions(+), 609 deletions(-) create mode 100644 colossalai/inference/tensor_parallel/modeling/__init__.py create mode 100644 colossalai/inference/tensor_parallel/modeling/bloom.py create mode 100644 colossalai/inference/tensor_parallel/policies/__init__.py create mode 100644 colossalai/inference/tensor_parallel/policies/bloom.py diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/inference/tensor_parallel/modeling/__init__.py new file mode 100644 index 000000000000..5b9ed338335f --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/__init__.py @@ -0,0 +1,3 @@ +from .bloom import BloomInferenceForwards + +__all__ = ['BloomInferenceForwards'] diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py new file mode 100644 index 000000000000..01dfc3dee7ef --- /dev/null +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -0,0 +1,582 @@ +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.distributed as dist +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import functional as F +from transformers.models.bloom.modeling_bloom import ( + BaseModelOutputWithPastAndCrossAttentions, + BloomAttention, + BloomBlock, + BloomForCausalLM, + BloomModel, + CausalLMOutputWithCrossAttentions, +) +from transformers.utils import logging + +from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState +from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd + + +def generate_alibi(n_head, dtype=torch.float16): + """ + This method is originally the `build_alibi_tensor` function + in `transformers/models/bloom/modeling_bloom.py` + of the huggingface/transformers GitHub repository. + + Copyright 2023 ModelTC Team + Copyright 2022 HuggingFace Inc. team and BigScience workshop + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
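+
+    As a quick sanity check: for n_head = 8 the slopes returned here are the geometric
+    sequence 1/2, 1/4, 1/8, ..., 1/256 (start = ratio = 2**(-1)), i.e. one slope per head.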
+ """ + + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + slopes = torch.Tensor(get_slopes(n_head)) + head_alibi = slopes.to(dtype) + return head_alibi # 1 * num_heads + + +def generate_alibi_2(n_head, dtype=torch.float16): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + return [start * start**i for i in range(n)] + + def get_slopes(n): + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) + slopes_double = get_slopes(2 * closest_power_of_2) + slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] + return slopes_combined + + slopes = torch.tensor(get_slopes(n_head), dtype=dtype) + return slopes + + +class BloomInferenceForwards: + """ + This class serves a micro library for bloom inference forwards + """ + + @staticmethod + def bloom_model_forward( + self: BloomModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # # initialize BatchInferState to track necessary states during current model forward + # infer_state = BatchInferState() + # infer_state.batch_size = batch_size + # # TODO: dummy implementation here for testing, assume all inputs same length + # infer_state.total_token_num = batch_size * seq_length + # infer_state.block_loc = self.block_loc + # infer_state.start_loc = self.b_start_loc + # infer_state.seq_len = self.b_seq_len + + # still need to keep past_key_values to fit original forward flow· + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape batch_size x num_heads x N x N + # head_mask has shape n_layer x batch x num_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") + use_cache = False + + # NOTE determine if BatchInferState is passed in via arg + # if not, get the attr binded to the model + # We might wantto remove setattr later + if infer_state is None: + infer_state = self.infer_state + + # Compute alibi tensor: check build_alibi_tensor documentation + seq_length_with_past = seq_length + past_key_values_length = 0 + # if self.cache_manager.past_key_values_length > 0: + if infer_state.cache_manager.past_key_values_length > 0: + # TODO dummy but work, revise it + past_key_values_length = infer_state.cache_manager.past_key_values_length + # past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + # infer_state.cache_manager = self.cache_manager + + if use_cache and seq_length != 1: + # NOTE assuem prefill stage + # allocate memory block + infer_state.is_context_stage = True # set prefill stage, notify attention layer + infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) + BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, + infer_state.context_mem_index) + else: + # TODO handle the condition that no contiguous memory presents + alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) + if alloc_mem is not None: + infer_state.decode_is_contiguous = True + infer_state.decode_mem_index = alloc_mem[0] + infer_state.decode_mem_start = alloc_mem[1] + infer_state.decode_mem_end = alloc_mem[2] + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + else: + print(f" *** Encountered allocation non-contiguous") + print( + f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" + ) + infer_state.decode_is_contiguous = False + alloc_mem = infer_state.cache_manager.alloc(batch_size) + infer_state.decode_mem_index = alloc_mem + # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") + infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + # NOTE we might want to store a single 1D alibi(length is #heads) in model + alibi = generate_alibi(self.num_heads).contiguous().cuda() + # alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + # FIXME: currently our KV cache manager does not handle this condition + def create_custom_forward(module): + + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + alibi, + causal_mask, + layer_past, + head_mask[i], + ) + else: + outputs = 
block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + infer_state=infer_state, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Add last hidden state + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # NOTE: here we still to update indices of kv cache block + # TODO: remove this part, instead, better to pass the BatchInferState from model forward, + # and update these information in engine.generate after model foward called + infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.seq_len += 1 + infer_state.decode_layer_id = 0 + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, # should always be (None, None, ..., None) + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + @staticmethod + def bloom_for_causal_lm_forward(self: BloomForCausalLM, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + infer_state: Optional[BatchInferState] = None, + **deprecated_arguments): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + logger = logging.get_logger(__name__) + + if deprecated_arguments.pop("position_ids", False) is not False: + # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` + warnings.warn( + "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore" + " passing `position_ids`.", + FutureWarning, + ) + if len(deprecated_arguments) > 0: + raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + infer_state=infer_state) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + batch_size, seq_length, vocab_size = shift_logits.shape + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), + shift_labels.view(batch_size * seq_length)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def bloom_for_causal_lm_prepare_inputs_for_generation( + self: BloomForCausalLM, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> dict: + # only last token for input_ids if past is not None + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # NOTE we won't use past key values here + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + # if past_key_values[0][0].shape[0] == input_ids.shape[0]: + # past_key_values = self._convert_to_bloom_cache(past_key_values) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update({ + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + }) + return model_inputs + + # replace decoder layer forward: + # used to replace BloomBlock.forward + @staticmethod + def bloom_block_forward( + self: BloomBlock, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + # hidden_states: [batch_size, seq_length, hidden_size] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. 
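+        # With BloomModelInferPolicy applied, `self.self_attention.forward` is the
+        # replaced `bloom_attention_forward` below: `infer_state` carries the KV cache
+        # manager and block indices, the new K/V are written with copy_kv_cache_to_dest,
+        # and attention runs through bloom_context_attn_fwd (prefill) or
+        # token_attention_fwd (decode) instead of growing `past_key_values`.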
+ attn_outputs = self.self_attention( + layernorm_output, + residual, + layer_past=layer_past, + attention_mask=attention_mask, + alibi=alibi, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + infer_state=infer_state, + ) + + attention_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. + output = self.mlp(layernorm_output, residual) + + if use_cache: + outputs = (output,) + outputs + else: + outputs = (output,) + outputs[1:] + + return outputs # hidden_states, present, attentions + + # replace attention forward: + # used to replace BloomAttention.forward + @staticmethod + def bloom_attention_forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + infer_state: Optional[BatchInferState] = None, + ): + + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + batch_size, q_length, H, D_HEAD = query_layer.shape + k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 + + mem_manager = infer_state.cache_manager + layer_id = infer_state.decode_layer_id + + if infer_state.is_context_stage: + # context process + max_input_len = q_length + b_start_loc = infer_state.start_loc + b_seq_len = infer_state.seq_len[:batch_size] + q = query_layer.reshape(-1, H, D_HEAD) + + copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) + + # output = self.output[:batch_size*q_length, :, :] + output = torch.empty_like(q) + + # temp for testing + if not dist.is_initialized() or dist.get_rank() == 0: + print(f" cp3") + print(f" q.shape: {q.shape}") + print(f" k.shape: {k.shape}") + print(f" v.shape: {v.shape}") + print(f" output.shape: {output.shape}") + print(f" b_start_loc: {b_start_loc}") + print(f" b_seq_len: {b_seq_len}") + print(f" max_input_len: {max_input_len}") + print(f" alibi: {alibi}") + + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + # FIXME might want to revise + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if layer_id == 0: # once per model.forward + infer_state.cache_manager.past_key_values_length = q_length # seq_len + else: + # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) + assert q_length == 1, "for non-context process, we only support q_length == 1" + q = query_layer.reshape(-1, H, D_HEAD) + + if infer_state.decode_is_contiguous: + # if decode is contiguous, then we copy to key cache and value cache in cache manager directly + 
cache_k = infer_state.cache_manager.key_buffer[layer_id][ + infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + cache_v = infer_state.cache_manager.value_buffer[layer_id][ + infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] + cache_k.copy_(k) + cache_v.copy_(v) + else: + # if decode is not contiguous, use triton kernel to copy key and value cache + # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] + # TODO clean comments + # destindex_copy_kv(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) + # destindex_copy_kv(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) + copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) + copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) + + b_start_loc = infer_state.start_loc[:batch_size] + b_loc = infer_state.block_loc[:batch_size, :] + b_seq_len = infer_state.seq_len[:batch_size] + max_len_in_batch = mem_manager.past_key_values_length + q_length + output = torch.empty_like(q) + token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc, + b_start_loc, b_seq_len, max_len_in_batch, alibi) + + context_layer = output.view(batch_size, q_length, H * D_HEAD) + # FIXME might want to revise (same as above one) + # need some way to record the length of past key values cache + # since we won't return past_key_value_cache right now + if layer_id == 0: # once per model.forward + assert infer_state.cache_manager.past_key_values_length != 0 + infer_state.cache_manager.past_key_values_length += q_length # += 1 + + # update layer id + infer_state.decode_layer_id += 1 + + # NOTE: always set present as none for now, instead of returning past key value to the next decoding, + # we create the past key value pair from the cache manager + present = None + + # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices):int((i + 1) * slices)], + self.dense.weight[:, int(i * slices):int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + # dropout is not required here during inference + output_tensor = residual + output_tensor + + outputs = (output_tensor, present) + assert output_attentions is False, "we do not support output_attentions at this time" + + return outputs diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py new file mode 100644 index 000000000000..dfce99240775 --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -0,0 +1,3 @@ +from .bloom import BloomModelInferPolicy + +__all__ = ['BloomModelInferPolicy'] diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/inference/tensor_parallel/policies/bloom.py new file mode 100644 index 000000000000..d9dc2982d040 --- /dev/null +++ b/colossalai/inference/tensor_parallel/policies/bloom.py @@ -0,0 +1,44 @@ +from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy + +from ..modeling.bloom import BloomInferenceForwards + + +class BloomModelInferPolicy(BloomForCausalLMPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel + policy = super().module_policy() + # NOTE set inference mode to shard config + self.shard_config._infer() + + if self.shard_config.enable_tensor_parallelism: + + method_replacement = { + 'forward': + BloomInferenceForwards.bloom_for_causal_lm_forward, + 'prepare_inputs_for_generation': + BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation + } + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomForCausalLM) + + method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomModel) + + method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomBlock) + + method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=BloomAttention) + + return policy diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 493323438f1e..12276635ecfa 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -1,4 +1,3 @@ -import math import warnings from typing import List, Optional, Tuple, Union @@ -15,8 +14,6 @@ TokenClassifierOutput, ) from transformers.models.bloom.modeling_bloom import ( - BloomAttention, - BloomBlock, BloomForCausalLM, BloomForQuestionAnswering, BloomForSequenceClassification, @@ -25,11 +22,7 @@ ) from transformers.utils import logging -from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd -from 
colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest -from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer.inference import BatchInferState def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor: @@ -98,67 +91,6 @@ def build_bloom_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, return build_bloom_alibi_tensor -def generate_alibi(n_head, dtype=torch.float16): - """ - This method is originally the `build_alibi_tensor` function - in `transformers/models/bloom/modeling_bloom.py` - of the huggingface/transformers GitHub repository. - - Copyright 2023 ModelTC Team - Copyright 2022 HuggingFace Inc. team and BigScience workshop - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - - def get_slopes(n): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + - get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.Tensor(get_slopes(n_head)) - head_alibi = slopes.to(dtype) - return head_alibi # 1 * num_heads - - -def generate_alibi_2(n_head, dtype=torch.float16): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - return [start * start**i for i in range(n)] - - def get_slopes(n): - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2) - slopes_double = get_slopes(2 * closest_power_of_2) - slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] - return slopes_combined - - slopes = torch.tensor(get_slopes(n_head), dtype=dtype) - return slopes - - class BloomPipelineForwards: ''' This class serves as a micro library for bloom pipeline forwards. 
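With the inference forwards and policy relocated above to `colossalai/inference/tensor_parallel/`, end-to-end usage presumably mirrors the revised test in this series: build a `TPInferEngine`, attach an inference `ShardConfig`, shard the model, then call `generate`. A minimal per-rank sketch (run under `colossalai.launch`, as the test does via `spawn`; the model path is a placeholder and the paths and call signatures are assumed to be exactly those introduced by this patch):

    import torch
    from transformers import AutoTokenizer, BloomForCausalLM
    from colossalai.inference.tensor_parallel import TPInferEngine
    from colossalai.shardformer import ShardConfig, ShardFormer

    model_path = "bigscience/bloom-560m"    # or any local BLOOM checkpoint, as in the test
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
    model = model.half().to(torch.cuda.current_device())

    shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
    engine = TPInferEngine(model, 4, 16, 8)    # max batch size, max input len, max output len
    engine.prepare_with_shard_config(shard_config=shard_config)
    engine.shard_model_by(ShardFormer(shard_config=shard_config))

    input_ids = tokenizer.encode("Introduce some landmarks in Beijing", return_tensors='pt')
    outputs = engine.generate(input_ids, dict(do_sample=False))
    print(tokenizer.decode(outputs[0]))
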
@@ -746,505 +678,6 @@ def bloom_for_question_answering_forward( return {'hidden_states': hidden_states} -class BloomInferenceForwards: - """ - This class serves a micro library for bloom inference forwards - """ - - @staticmethod - def bloom_model_forward( - self: BloomModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: Optional[BatchInferState] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: - - logger = logging.get_logger(__name__) - - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - # # initialize BatchInferState to track necessary states during current model forward - # infer_state = BatchInferState() - # infer_state.batch_size = batch_size - # # TODO: dummy implementation here for testing, assume all inputs same length - # infer_state.total_token_num = batch_size * seq_length - # infer_state.block_loc = self.block_loc - # infer_state.start_loc = self.b_start_loc - # infer_state.seq_len = self.b_seq_len - - # still need to keep past_key_values to fit original forward flow· - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape batch_size x num_heads x N x N - # head_mask has shape n_layer x batch x num_heads x N x N - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") - use_cache = False - - # NOTE determine if BatchInferState is passed in via arg - # if not, get the attr binded to the model - # We might wantto remove setattr later - if infer_state is None: - infer_state = self.infer_state - - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - # if self.cache_manager.past_key_values_length > 0: - if infer_state.cache_manager.past_key_values_length > 0: - # TODO dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - # infer_state.cache_manager = self.cache_manager - - if use_cache and seq_length != 1: - # NOTE assuem prefill stage - # allocate memory block - infer_state.is_context_stage = True # set prefill stage, notify attention layer - infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) - BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, - infer_state.context_mem_index) - else: - # TODO handle the condition that no contiguous memory presents - alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) - if alloc_mem is not None: - infer_state.decode_is_contiguous = True - infer_state.decode_mem_index = alloc_mem[0] - infer_state.decode_mem_start = alloc_mem[1] - infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - else: - print(f" *** Encountered allocation non-contiguous") - print( - f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) - infer_state.decode_is_contiguous = False - alloc_mem = infer_state.cache_manager.alloc(batch_size) - infer_state.decode_mem_index = alloc_mem - # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index - - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) - else: - attention_mask = attention_mask.to(hidden_states.device) - - # NOTE we might want to store a single 1D alibi(length is #heads) in model - alibi = generate_alibi(self.num_heads).contiguous().cuda() - # alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, - ) - - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - # FIXME: currently our KV cache manager does not handle this condition - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - alibi, - causal_mask, - layer_past, - head_mask[i], - ) - else: - outputs = 
block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - infer_state=infer_state, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - # Add last hidden state - hidden_states = self.ln_f(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # NOTE: here we still to update indices of kv cache block - # TODO: remove this part, instead, better to pass the BatchInferState from model forward, - # and update these information in engine.generate after model foward called - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.seq_len += 1 - infer_state.decode_layer_id = 0 - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=presents, # should always be (None, None, ..., None) - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - @staticmethod - def bloom_for_causal_lm_forward(self: BloomForCausalLM, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - infer_state: Optional[BatchInferState] = None, - **deprecated_arguments): - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - """ - logger = logging.get_logger(__name__) - - if deprecated_arguments.pop("position_ids", False) is not False: - # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None` - warnings.warn( - "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. 
You can safely ignore" - " passing `position_ids`.", - FutureWarning, - ) - if len(deprecated_arguments) > 0: - raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = BloomInferenceForwards.bloom_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - infer_state=infer_state) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(lm_logits.device) - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - batch_size, seq_length, vocab_size = shift_logits.shape - # Flatten the tokens - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), - shift_labels.view(batch_size * seq_length)) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def bloom_for_causal_lm_prepare_inputs_for_generation( - self: BloomForCausalLM, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, - ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - - # NOTE we won't use past key values here - # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed - # if past_key_values[0][0].shape[0] == input_ids.shape[0]: - # past_key_values = self._convert_to_bloom_cache(past_key_values) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update({ - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - }) - return model_inputs - - # replace decoder layer forward: - # used to replace BloomBlock.forward - @staticmethod - def bloom_block_forward( - self: BloomBlock, - hidden_states: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - infer_state: Optional[BatchInferState] = None, - ): - # hidden_states: [batch_size, seq_length, hidden_size] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - - # Layer norm post the self attention. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - # Self attention. 
- attn_outputs = self.self_attention( - layernorm_output, - residual, - layer_past=layer_past, - attention_mask=attention_mask, - alibi=alibi, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - infer_state=infer_state, - ) - - attention_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - layernorm_output = self.post_attention_layernorm(attention_output) - - # Get residual - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = attention_output - - # MLP. - output = self.mlp(layernorm_output, residual) - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - # replace attention forward: - # used to replace BloomAttention.forward - @staticmethod - def bloom_attention_forward( - self: BloomAttention, - hidden_states: torch.Tensor, - residual: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - infer_state: Optional[BatchInferState] = None, - ): - - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - batch_size, q_length, H, D_HEAD = query_layer.shape - k = key_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - v = value_layer.reshape(-1, H, D_HEAD) # batch_size * q_length, H, D_HEAD, q_lenth == 1 - - mem_manager = infer_state.cache_manager - layer_id = infer_state.decode_layer_id - - if infer_state.is_context_stage: - # context process - max_input_len = q_length - b_start_loc = infer_state.start_loc - b_seq_len = infer_state.seq_len[:batch_size] - q = query_layer.reshape(-1, H, D_HEAD) - - copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id]) - - # output = self.output[:batch_size*q_length, :, :] - output = torch.empty_like(q) - - # temp for testing - if not dist.is_initialized() or dist.get_rank() == 0: - print(f" cp3") - print(f" q.shape: {q.shape}") - print(f" k.shape: {k.shape}") - print(f" v.shape: {v.shape}") - print(f" output.shape: {output.shape}") - print(f" b_start_loc: {b_start_loc}") - print(f" b_seq_len: {b_seq_len}") - print(f" max_input_len: {max_input_len}") - print(f" alibi: {alibi}") - - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) - - context_layer = output.view(batch_size, q_length, H * D_HEAD) - # FIXME might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - if layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length = q_length # seq_len - else: - # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD) - assert q_length == 1, "for non-context process, we only support q_length == 1" - q = query_layer.reshape(-1, H, D_HEAD) - - if infer_state.decode_is_contiguous: - # if decode is contiguous, then we copy to key cache and value cache in cache manager directly - 
cache_k = infer_state.cache_manager.key_buffer[layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] - cache_v = infer_state.cache_manager.value_buffer[layer_id][ - infer_state.decode_mem_start:infer_state.decode_mem_end, :, :] - cache_k.copy_(k) - cache_v.copy_(v) - else: - # if decode is not contiguous, use triton kernel to copy key and value cache - # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] - # TODO clean comments - # destindex_copy_kv(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) - # destindex_copy_kv(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) - copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) - copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) - - b_start_loc = infer_state.start_loc[:batch_size] - b_loc = infer_state.block_loc[:batch_size, :] - b_seq_len = infer_state.seq_len[:batch_size] - max_len_in_batch = mem_manager.past_key_values_length + q_length - output = torch.empty_like(q) - token_attention_fwd(q, mem_manager.key_buffer[layer_id], mem_manager.value_buffer[layer_id], output, b_loc, - b_start_loc, b_seq_len, max_len_in_batch, alibi) - - context_layer = output.view(batch_size, q_length, H * D_HEAD) - # FIXME might want to revise (same as above one) - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now - if layer_id == 0: # once per model.forward - assert infer_state.cache_manager.past_key_values_length != 0 - infer_state.cache_manager.past_key_values_length += q_length # += 1 - - # update layer id - infer_state.decode_layer_id += 1 - - # NOTE: always set present as none for now, instead of returning past key value to the next decoding, - # we create the past key value pair from the cache manager - present = None - - # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 - if self.pretraining_tp > 1 and self.slow_but_exact: - slices = self.hidden_size / self.pretraining_tp - output_tensor = torch.zeros_like(context_layer) - for i in range(self.pretraining_tp): - output_tensor = output_tensor + F.linear( - context_layer[:, :, int(i * slices):int((i + 1) * slices)], - self.dense.weight[:, int(i * slices):int((i + 1) * slices)], - ) - else: - output_tensor = self.dense(context_layer) - - # dropout is not required here during inference - output_tensor = residual + output_tensor - - outputs = (output_tensor, present) - assert output_attentions is False, "we do not support output_attentions at this time" - - return outputs - - def get_bloom_flash_attention_forward(enabel_jit_fused=False): try: diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index d8bbaeeb7fd9..d5c4ba62d563 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -9,7 +9,6 @@ from .._utils import getattr_, setattr_ from ..modeling.bloom import ( - BloomInferenceForwards, BloomPipelineForwards, build_bloom_alibi_tensor_fn, get_bloom_flash_attention_forward, @@ -253,47 +252,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: return [] -class BloomModelInferPolicy(BloomForCausalLMPolicy): - - def __init__(self) -> None: - super().__init__() - - def module_policy(self): - from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel - policy = super().module_policy() - # NOTE set inference mode to shard config - self.shard_config._infer() - - if self.shard_config.enable_tensor_parallelism: - - method_replacement = { - 'forward': - BloomInferenceForwards.bloom_for_causal_lm_forward, - 'prepare_inputs_for_generation': - BloomInferenceForwards.bloom_for_causal_lm_prepare_inputs_for_generation - } - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomForCausalLM) - - method_replacement = {'forward': BloomInferenceForwards.bloom_model_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomModel) - - method_replacement = {'forward': BloomInferenceForwards.bloom_block_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomBlock) - - method_replacement = {'forward': BloomInferenceForwards.bloom_attention_forward} - self.append_or_create_method_replacement(description=method_replacement, - policy=policy, - target_key=BloomAttention) - - return policy - - class BloomForSequenceClassificationPolicy(BloomPolicy): def module_policy(self): From 77674d19de0633f79199eb9e048b75db34917c9e Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 14:32:50 +0800 Subject: [PATCH 09/16] remove unused codes --- colossalai/shardformer/inference/engine.py | 250 --------------------- colossalai/shardformer/policies/bloom.py | 5 - 2 files changed, 255 deletions(-) delete mode 100644 colossalai/shardformer/inference/engine.py diff --git a/colossalai/shardformer/inference/engine.py b/colossalai/shardformer/inference/engine.py deleted file mode 100644 index 0189f6d61b7c..000000000000 --- a/colossalai/shardformer/inference/engine.py +++ /dev/null @@ -1,250 +0,0 @@ -from functools import partial -from types import MethodType -from typing import Any, Callable, List, Optional, Set, Union - -import torch -import torch.distributed as dist 
-import torch.nn as nn -from transformers.generation import GenerationConfig -from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList -from transformers.tokenization_utils_base import BatchEncoding - -from colossalai.cluster import ProcessGroupMesh -from colossalai.shardformer import ShardConfig, ShardFormer -# from colossalai.shardformer.policies.bloom import BloomModelInferPolicy -from colossalai.shardformer.policies.auto_policy import get_autopolicy - -from .batch_infer_state import BatchInferState -from .kvcache_manager import MemoryManager - -DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 - - -class InferenceEngine: - - def __init__(self, model: nn.Module, max_batch_size, max_input_len, max_output_len, tp_size=1) -> None: - self.model = model - self.sharded_model = None - - self.max_batch_size = max_batch_size - self.max_input_len = max_input_len - self.max_output_len = max_output_len - self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len) - assert self.max_batch_size <= 64 - assert self.max_input_len + self.max_output_len <= 2048 - - self.tp_size = tp_size - self.pp_size = 1 # only consider tp for now - self.dp_size = 1 # only consider tp for now - - self.head_dim = self.model.config.hidden_size // self.model.config.num_attention_heads - self.head_num = self.model.config.num_attention_heads // self.tp_size - self.layer_num = self.model.config.num_hidden_layers - self.cache_manager = MemoryManager(self.max_total_token_num, torch.float16, self.head_num, self.head_dim, - self.layer_num) - - # self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size) - # self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) - - def shard_model_by(self, shardformer: ShardFormer) -> None: - # TODO Might want to use infer policy only when bs >= 4 - assert self.tp_size == shardformer.shard_config.tensor_parallel_size, "Engine tp size != shardformer tp size" - # shardformer.shard_config.tensor_parallel_process_group = self.tp_group - model_name = self.model.__class__.__name__ - policy = get_autopolicy(self.model, inference_only=True) - if model_name == 'LlamaForCausalLM': - self.sharded_model, _ = shardformer.optimize(self.model, policy) - elif model_name == 'BloomForCausalLM': - self.sharded_model, _ = shardformer.optimize(self.model, policy) - else: - raise ValueError(f'Unsupported model "{model_name}" for inference') - self.sharded_model = self.sharded_model.cuda() - - # NOTE input_tokens is expected to be BatchEncoding, - # instead of only input token ids - @torch.no_grad() - def generate_by_pass_infer_state(self, - input_tokens, - max_out_length: int, - generation_config: Optional[GenerationConfig] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None, - **model_kwargs) -> torch.Tensor: - - input_ids = input_tokens['input_ids'] - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - - if batch_size >= 4: - assert self.sharded_model is not None, "sharded model does not exist" - - batch_infer_state = self.prepare_batch_state(input_tokens) - batch_size = batch_infer_state.batch_size - assert batch_infer_state.max_len_in_batch <= self.max_input_len - - # record sequences finish status, add early stopping, etc, - - for _ in range(min(max_out_length, self.max_output_len)): - # ... 
- self.sharded_model.forward(..., **model_kwargs) - else: - # Use original model - orig_model = self.model - - for _ in range(min(max_out_length, self.max_output_len)): - - if prepare_inputs_fn is None and hasattr(orig_model, 'prepare_inputs_for_generation'): - prepare_inputs_fn = orig_model.prepare_inputs_for_generation - - model_inputs = prepare_inputs_fn(input_ids, ** - model_kwargs) if prepare_inputs_fn is not None else input_tokens - outputs = orig_model(**model_inputs) - - # next_token_logits = outputs['logits'][:, -1, :] - next_token_logits = outputs.logits[:, -1, :] - # pre-process distribution - # next_token_logits = logits_processor(input_ids, next_token_logits) - - # sample - # probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float) - # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # consider greedy only for now - next_tokens = torch.argmax(next_token_logits, dim=-1) - - # finished sentences should have their next token be a padding token - - # if eos_token_id is not None: - # if pad_token_id is None: - # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # # update generated ids, model inputs for next step - # input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - # if update_model_kwargs_fn is not None: - # model_kwargs = update_model_kwargs_fn(outputs, model_kwargs) - - # # if eos_token was found in one sentence, set sentence to finished - # if eos_token_id is not None: - # unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - - # # stop when each sentence is finished if early_stopping=True - # if early_stopping and _is_sequence_finished(unfinished_sequences): - # break - - return input_ids - - @torch.no_grad() - def generate_by_set_infer_state(self, input_tokens, generate_kwargs, early_stopping=False): - - # for testing, always use sharded model - assert self.sharded_model is not None, "sharded model does not exist" - - batch_infer_state = self.prepare_batch_state(input_tokens) - assert batch_infer_state.max_len_in_batch <= self.max_input_len, "max length in batch exceeds limit" - - # set BatchInferState for the current batch as attr to model - # NOTE this is not an expectable way to pass BatchInferState during inference - # we might want to rewrite generate function (e.g. generate_by_pass_infer_state) - # and pass BatchInferState via model forward - if hasattr(self.sharded_model, 'model'): - model = self.sharded_model.model - elif hasattr(self.sharded_model, 'transformer'): - model = self.sharded_model.transformer - setattr(model, 'infer_state', batch_infer_state) - - # add logging - generate_kwargs.update(max_new_tokens=self.max_output_len) - - # convert to dict - if isinstance(input_tokens, torch.Tensor): - input_tokens = dict(input_ids=input_tokens) - for t in input_tokens: - if torch.is_tensor(input_tokens[t]): - input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) - print(f" input_tokens[{t}].shape: {input_tokens[t].shape}") - - outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=early_stopping) - - print(f"outputs.shape {outputs.shape}") - return outputs - - # inputs should be one of the following types - # 1. BatchEncoding (e.g. tokenizer batch_encode) - # 2. list of input token ids (e.g. appended result of tokenizer encode) - # 3. torch.Tensor (e.g. 
tokenizer encode with return_tensors='pt') - # NOTE For torch.Tensor inputs representing a batch of inputs, we are unable to retrieve - # the actual length (e.g. number of tokens) of each input without attention mask - # Hence, for torch.Tensor with shape [bs, l] where bs > 1, we will assume - # all the inputs in the batch has the maximum length l - def prepare_batch_state(self, inputs: [BatchEncoding, torch.Tensor]) -> BatchInferState: - # records length based on attention mask - # Any better method? - if not isinstance(inputs, (BatchEncoding, list, torch.Tensor)): - raise TypeError(f"inputs type {type(inputs)} is not supported in prepare_batch_state") - - if isinstance(inputs, BatchEncoding): - attn_masks = inputs['attention_mask'] - batch_size = attn_masks.shape[0] - max_len_in_batch = attn_masks.shape[1] - elif isinstance(inputs, list): - batch_size = len(inputs) - else: - batch_size = inputs.shape[0] - - # block_loc = torch.empty(batch_size, self.max_input_len + self.max_output_len, dtype=torch.long, device="cuda") - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") - start_index = 0 - - max_len_in_batch = -1 - if isinstance(inputs, BatchEncoding): - for i, attn_mask in enumerate(attn_masks): - curr_seq_len = torch.sum(attn_mask) - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - else: - for i, input_ids in enumerate(inputs): - curr_seq_len = len(input_ids) - seq_lengths[i] = curr_seq_len - seq_start_indexes[i] = start_index - start_index += curr_seq_len - max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") - batch_infer_state = BatchInferState(batch_size, max_len_in_batch) - batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device - batch_infer_state.start_loc = seq_start_indexes.to('cuda') - batch_infer_state.block_loc = block_loc - # NOTE BatchInferState.total_token_num revised (not pushed yet) - # Now we want actual total token num based on seq_len, instead of dummy ones in test - # (Could still use the dummy one for testing usage) - batch_infer_state.set_cache_manager(self.cache_manager) - batch_infer_state.decode_layer_id = 0 - batch_infer_state.past_key_values_len = 0 - batch_infer_state.is_context_stage = True - return batch_infer_state - - # BatchInferState is created and kept during generation - # after each iter of model forward, we should update BatchInferState - # NOTE use in rewritten generate method: use after model.forward - def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: - # self.b_start_loc = self.b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - # self.b_seq_len += 1 - batch_size = infer_state.batch_size - device = infer_state.start_loc.device - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) - infer_state.seq_len += 1 - - # TODO might want to create a sequence pool - # add a single request/sequence/input text at a time and record its length - # In other words, store the actual length of input tokens representing a single input text - # E.g. 
"Introduce landmarks in Beijing" - # => add request - # => record token length and other necessary information to be used - # => engine hold all these necessary information until `generate` (or other name) is called, - # => put information already recorded in batchinferstate and pass it to model forward - # => clear records in engine - def add_request(): - pass diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index d5c4ba62d563..b35764db3870 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -43,11 +43,6 @@ def module_policy(self): policy = {} - print(f" BloomPolicyt.module_policy:") - print(f" self.shard_config.enable_tensor_parallelism: {self.shard_config.enable_tensor_parallelism}") - print(f" self.shard_config.tensor_parallel_size: {self.shard_config.tensor_parallel_size}") - print(f" self.model.config.n_head: {self.model.config.n_head}") - if self.shard_config.enable_tensor_parallelism: policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={ "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, From cb88328e6fc607cdfb604e207bc97cb7bf6ff76c Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 14:58:39 +0800 Subject: [PATCH 10/16] fix bloom modeling --- .../tensor_parallel/modeling/bloom.py | 47 +++++++------------ .../shardformer/policies/auto_policy.py | 2 +- 2 files changed, 19 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 01dfc3dee7ef..18d6887f2768 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -131,16 +131,7 @@ def bloom_model_forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - # # initialize BatchInferState to track necessary states during current model forward - # infer_state = BatchInferState() - # infer_state.batch_size = batch_size - # # TODO: dummy implementation here for testing, assume all inputs same length - # infer_state.total_token_num = batch_size * seq_length - # infer_state.block_loc = self.block_loc - # infer_state.start_loc = self.b_start_loc - # infer_state.seq_len = self.b_seq_len - - # still need to keep past_key_values to fit original forward flow· + # still need to keep past_key_values to fit original forward flow if past_key_values is None: past_key_values = tuple([None] * len(self.h)) @@ -169,6 +160,7 @@ def bloom_model_forward( # if not, get the attr binded to the model # We might wantto remove setattr later if infer_state is None: + assert hasattr(self, 'infer_state') infer_state = self.infer_state # Compute alibi tensor: check build_alibi_tensor documentation @@ -176,22 +168,20 @@ def bloom_model_forward( past_key_values_length = 0 # if self.cache_manager.past_key_values_length > 0: if infer_state.cache_manager.past_key_values_length > 0: - # TODO dummy but work, revise it + # update the past key values length in cache manager, + # TODO use BatchInferState.past_key_values_length instead the one in cache manager past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length # infer_state.cache_manager = self.cache_manager if use_cache and seq_length != 1: - # NOTE assuem prefill stage - # allocate memory 
block + # prefill stage infer_state.is_context_stage = True # set prefill stage, notify attention layer infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num) BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) else: - # TODO handle the condition that no contiguous memory presents alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: infer_state.decode_is_contiguous = True @@ -216,9 +206,14 @@ def bloom_model_forward( else: attention_mask = attention_mask.to(hidden_states.device) - # NOTE we might want to store a single 1D alibi(length is #heads) in model - alibi = generate_alibi(self.num_heads).contiguous().cuda() - # alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + # TODO revise: we might want to store a single 1D alibi(length is #heads) in model, + # or store to BatchInferState to prevent re-calculating + # When we have multiple process group (e.g. dp together with tp), we need to pass the pg to here + # alibi = generate_alibi(self.num_heads).contiguous().cuda() + tp_size = dist.get_world_size() + curr_tp_rank = dist.get_rank() + alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) * + self.num_heads].cuda() causal_mask = self._prepare_attn_mask( attention_mask, @@ -273,8 +268,8 @@ def custom_forward(*inputs): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - # NOTE: here we still to update indices of kv cache block - # TODO: remove this part, instead, better to pass the BatchInferState from model forward, + # update indices of kv cache block + # TODO: might want to remove this part, instead, better to pass the BatchInferState from model forward, # and update these information in engine.generate after model foward called infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 @@ -510,9 +505,8 @@ def bloom_attention_forward( bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) - # FIXME might want to revise - # need some way to record the length of past key values cache - # since we won't return past_key_value_cache right now + # record the length of past key values cache when entering the first attention layer in bloom block, + # since we won't return past_key_value_cache right now if layer_id == 0: # once per model.forward infer_state.cache_manager.past_key_values_length = q_length # seq_len else: @@ -532,9 +526,6 @@ def bloom_attention_forward( else: # if decode is not contiguous, use triton kernel to copy key and value cache # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head] - # TODO clean comments - # destindex_copy_kv(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) - # destindex_copy_kv(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id]) copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id]) @@ -547,9 +538,7 @@ def bloom_attention_forward( b_start_loc, b_seq_len, max_len_in_batch, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) - # FIXME might want to revise (same as above one) - # need some way to record the length of past key values cache - 
# since we won't return past_key_value_cache right now + if layer_id == 0: # once per model.forward assert infer_state.cache_manager.past_key_values_length != 0 infer_state.cache_manager.past_key_values_length += q_length # += 1 diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index ef5145acde4c..05474f46a8a7 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -135,6 +135,7 @@ class PolicyLocation: # LlaMa "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + # Bloom "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), "transformers.models.bloom.modeling_bloom.BloomForCausalLM": @@ -146,7 +147,6 @@ def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool """ Dynamically import a Policy class based on the policy location. """ - if inference_only: module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}" else: From df406714c26f50488f4ec8a2e79637864b39157d Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 15:10:24 +0800 Subject: [PATCH 11/16] fix dir typo --- .../inference/tensor_parallel/__init__.py | 8 +++++--- .../tensor_parallel/policies/__init__.py | 3 ++- .../{pollcies => policies}/llama.py | 17 +++++++++++------ .../tensor_parallel/pollcies/__init__.py | 3 --- colossalai/shardformer/policies/auto_policy.py | 2 +- 5 files changed, 19 insertions(+), 14 deletions(-) rename colossalai/inference/tensor_parallel/{pollcies => policies}/llama.py (77%) delete mode 100644 colossalai/inference/tensor_parallel/pollcies/__init__.py diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index 1535db4c1ff9..8c8fa2960429 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -1,6 +1,8 @@ -from .modeling.llama import LlamaInferenceForwards -from .pollcies.llama import LlamaModelInferPolicy from .engine import TPInferEngine from .kvcache_manager import MemoryManager - +from .modeling.bloom import BloomInferenceForwards +from .modeling.llama import LlamaInferenceForwards +from .policies.bloom import BloomModelInferPolicy +from .policies.llama import LlamaModelInferPolicy + __all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine'] diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/inference/tensor_parallel/policies/__init__.py index dfce99240775..48f8db62c32a 100644 --- a/colossalai/inference/tensor_parallel/policies/__init__.py +++ b/colossalai/inference/tensor_parallel/policies/__init__.py @@ -1,3 +1,4 @@ from .bloom import BloomModelInferPolicy +from .llama import LlamaModelInferPolicy -__all__ = ['BloomModelInferPolicy'] +__all__ = ['BloomModelInferPolicy', 'LlamaModelInferPolicy'] diff --git a/colossalai/inference/tensor_parallel/pollcies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py similarity index 77% rename from colossalai/inference/tensor_parallel/pollcies/llama.py rename to colossalai/inference/tensor_parallel/policies/llama.py index 570e10ba3010..997f5fe48a54 100644 --- a/colossalai/inference/tensor_parallel/pollcies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -2,7 +2,8 @@ from colossalai.shardformer.policies.llama 
import LlamaForCausalLMPolicy -from ..modeling.llama import LlamaInferenceForwards +from ..modeling.llama import LlamaInferenceForwards + class LlamaModelInferPolicy(LlamaForCausalLMPolicy): @@ -23,13 +24,17 @@ def module_policy(self): infer_forward = LlamaInferenceForwards.llama_model_forward method_replacement = {'forward': partial(infer_forward)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) - + infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaDecoderLayer) - + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaDecoderLayer) + infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward method_replacement = {'forward': partial(infer_forward)} - self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaAttention) + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=LlamaAttention) - return policy \ No newline at end of file + return policy diff --git a/colossalai/inference/tensor_parallel/pollcies/__init__.py b/colossalai/inference/tensor_parallel/pollcies/__init__.py deleted file mode 100644 index d92a3e84d097..000000000000 --- a/colossalai/inference/tensor_parallel/pollcies/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .llama import LlamaModelInferPolicy - -__all__ = ['LlamaModelInferPolicy'] \ No newline at end of file diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 05474f46a8a7..064dcf2cd47e 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -148,7 +148,7 @@ def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool Dynamically import a Policy class based on the policy location. 
""" if inference_only: - module_name = f"colossalai.inference.tensor_parallel.pollcies.{policy_location.file_name}" + module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}" else: module_name = f"colossalai.shardformer.policies.{policy_location.file_name}" module = importlib.import_module(module_name) From 1e4ecbb7deab730f42a1638a84ac2a4c86ac773e Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 15:22:16 +0800 Subject: [PATCH 12/16] fix trivial --- colossalai/inference/tensor_parallel/modeling/bloom.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index 18d6887f2768..eefc68197e5e 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -182,6 +182,7 @@ def bloom_model_forward( BatchInferState.init_block_loc(infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index) else: + infer_state.is_context_stage = False alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size) if alloc_mem is not None: infer_state.decode_is_contiguous = True @@ -214,6 +215,9 @@ def bloom_model_forward( curr_tp_rank = dist.get_rank() alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) * self.num_heads].cuda() + print(f" bloom_model_forward:") + print(f" alibi.shape: {alibi.shape}") + print(alibi) causal_mask = self._prepare_attn_mask( attention_mask, From 63bcb944f0bd8bbf27fc2e2df1488ae013ab13b9 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 15:29:23 +0800 Subject: [PATCH 13/16] fix policy --- colossalai/shardformer/policies/auto_policy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 064dcf2cd47e..d23261ce237c 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -135,6 +135,8 @@ class PolicyLocation: # LlaMa "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), + "transformers.models.llama.modeling_llama.LlamaForCausalLM": + PolicyLocation(file_name="llama", class_name="LlamaModelInferPolicy"), # Bloom "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(file_name="bloom", class_name="BloomModelInferPolicy"), From 274fbf5aa1d62fc03f3bae7882be864b10e599e5 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 17:38:55 +0800 Subject: [PATCH 14/16] clean pr --- colossalai/inference/tensor_parallel/engine.py | 9 ++------- .../inference/tensor_parallel/modeling/bloom.py | 16 ---------------- tests/test_infer/test_bloom_infer.py | 10 +++------- 3 files changed, 5 insertions(+), 30 deletions(-) diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index e833ef3bdb7e..52d2fc05ffbb 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -141,7 +141,6 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Te outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False) - print(f"outputs.shape {outputs.shape}") return outputs def prepare_batch_state(self, inputs) -> BatchInferState: @@ -193,11 +192,7 @@ def 
prepare_batch_state(self, inputs) -> BatchInferState: start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch - print(" 666 ", max_len_in_batch) - - block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), - dtype=torch.long, - device='cuda') + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device='cuda') batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to('cuda') # might want to assign specific device batch_infer_state.start_loc = seq_start_indexes.to('cuda') @@ -251,4 +246,4 @@ def update_batch_state(self, infer_state: Optional[BatchInferState]) -> None: # => put information already recorded in batchinferstate and pass it to model forward # => clear records in engine def add_request(): - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index eefc68197e5e..e5fafa703919 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -215,10 +215,6 @@ def bloom_model_forward( curr_tp_rank = dist.get_rank() alibi = generate_alibi(self.num_heads * tp_size).contiguous()[curr_tp_rank * self.num_heads:(curr_tp_rank + 1) * self.num_heads].cuda() - print(f" bloom_model_forward:") - print(f" alibi.shape: {alibi.shape}") - print(alibi) - causal_mask = self._prepare_attn_mask( attention_mask, input_shape=(batch_size, seq_length), @@ -494,18 +490,6 @@ def bloom_attention_forward( # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) - # temp for testing - if not dist.is_initialized() or dist.get_rank() == 0: - print(f" cp3") - print(f" q.shape: {q.shape}") - print(f" k.shape: {k.shape}") - print(f" v.shape: {v.shape}") - print(f" output.shape: {output.shape}") - print(f" b_start_loc: {b_start_loc}") - print(f" b_seq_len: {b_seq_len}") - print(f" max_input_len: {max_input_len}") - print(f" alibi: {alibi}") - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index 6381323555b6..bb63d1a88d5b 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -6,12 +6,12 @@ from colossalai.inference.tensor_parallel import TPInferEngine from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn TP_SIZE = 2 MAX_BATCH_SIZE = 4 MAX_INPUT_LEN = 16 -MAX_OUTPUT_LEN = 8 +MAX_OUTPUT_LEN = 32 def run(): @@ -21,11 +21,8 @@ def run(): tokenizer.pad_token = tokenizer.eos_token text = "Introduce some landmarks in Beijing" - input_ids = tokenizer.encode(text, return_tensors='pt') + input_ids = tokenizer.batch_encode_plus([text], return_tensors='pt') - # model_config = BloomConfig() - # model = BloomForCausalLM(model_config) - # model = AutoModelForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id) model = 
model.half() model.to(torch.cuda.current_device()) @@ -40,7 +37,6 @@ def run(): generate_kwargs = dict(do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) - print(outputs) output_text = tokenizer.decode(outputs[0]) print(output_text) From 25a9a2165469e0cbaf5910a263ebc00de7ecad0d Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 17:44:34 +0800 Subject: [PATCH 15/16] trivial fix --- colossalai/inference/tensor_parallel/__init__.py | 6 +----- tests/test_infer/test_bloom_infer.py | 6 ++++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/inference/tensor_parallel/__init__.py index 8c8fa2960429..e467b4c73e6b 100644 --- a/colossalai/inference/tensor_parallel/__init__.py +++ b/colossalai/inference/tensor_parallel/__init__.py @@ -1,8 +1,4 @@ from .engine import TPInferEngine from .kvcache_manager import MemoryManager -from .modeling.bloom import BloomInferenceForwards -from .modeling.llama import LlamaInferenceForwards -from .policies.bloom import BloomModelInferPolicy -from .policies.llama import LlamaModelInferPolicy -__all__ = ['LlamaInferenceForwards', 'LlamaModelInferPolicy', 'MemoryManager', 'TPInferEngine'] +__all__ = ['MemoryManager', 'TPInferEngine'] diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index bb63d1a88d5b..95ab7d5c451e 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.distributed as dist from transformers import AutoModelForCausalLM, AutoTokenizer, BloomForCausalLM import colossalai @@ -37,8 +38,9 @@ def run(): generate_kwargs = dict(do_sample=False) outputs = infer_engine.generate(input_ids, generate_kwargs) - output_text = tokenizer.decode(outputs[0]) - print(output_text) + if not dist.is_initialized() or dist.get_rank() == 0: + output_text = tokenizer.decode(outputs[0]) + print(output_text) def check_engine(rank, world_size, port): From 474c45d06d00194bc96b8b858f87d0cf77a85456 Mon Sep 17 00:00:00 2001 From: yuanheng-zhao Date: Wed, 30 Aug 2023 17:53:08 +0800 Subject: [PATCH 16/16] trivial --- .../tensor_parallel/modeling/bloom.py | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index e5fafa703919..1a5dbf4b5a1b 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -24,6 +24,9 @@ def generate_alibi(n_head, dtype=torch.float16): """ + This method is adapted from `_generate_alibi` function + in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py` + of the ModelTC/lightllm GitHub repository. This method is originally the `build_alibi_tensor` function in `transformers/models/bloom/modeling_bloom.py` of the huggingface/transformers GitHub repository. @@ -44,27 +47,6 @@ def generate_alibi(n_head, dtype=torch.float16): limitations under the License. 
""" - def get_slopes(n): - - def get_slopes_power_of_2(n): - start = 2**(-(2**-(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2**math.floor(math.log2(n)) - return (get_slopes_power_of_2(closest_power_of_2) + - get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) - - slopes = torch.Tensor(get_slopes(n_head)) - head_alibi = slopes.to(dtype) - return head_alibi # 1 * num_heads - - -def generate_alibi_2(n_head, dtype=torch.float16): - def get_slopes_power_of_2(n): start = 2**(-(2**-(math.log2(n) - 3))) return [start * start**i for i in range(n)] @@ -79,8 +61,8 @@ def get_slopes(n): slopes_combined = slopes_power_of_2 + slopes_double[0::2][:n - closest_power_of_2] return slopes_combined - slopes = torch.tensor(get_slopes(n_head), dtype=dtype) - return slopes + slopes = get_slopes(n_head) + return torch.tensor(slopes, dtype=dtype) class BloomInferenceForwards: