Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions colossalai/inference/tensor_parallel/kvcache_manager.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
#
# Copyright 2023 ModelTC Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

Expand Down
27 changes: 6 additions & 21 deletions colossalai/inference/tensor_parallel/modeling/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import torch.distributed as dist
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers.models.bloom.modeling_bloom import (
BaseModelOutputWithPastAndCrossAttentions,
Expand All @@ -30,21 +30,6 @@ def generate_alibi(n_head, dtype=torch.float16):
This method was originally the `build_alibi_tensor` function
in `transformers/models/bloom/modeling_bloom.py`
of the huggingface/transformers GitHub repository.

Copyright 2023 ModelTC Team
Copyright 2022 HuggingFace Inc. team and BigScience workshop

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

def get_slopes_power_of_2(n):
Expand All @@ -67,7 +52,11 @@ def get_slopes(n):

class BloomInferenceForwards:
"""
This class serves a micro library for bloom inference forwards
This class serves as a micro library for bloom inference forwards.
We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
as well as the prepare_inputs_for_generation method for BloomForCausalLM.
For future improvement, we might want to skip replacing methods for BloomForCausalLM,
and call BloomModel.forward iteratively in TpInferEngine.
"""

@staticmethod
Expand Down Expand Up @@ -372,8 +361,6 @@ def bloom_for_causal_lm_prepare_inputs_for_generation(
})
return model_inputs

# replace decoder layer forward:
# used to replace BloomBlock.forward
@staticmethod
def bloom_block_forward(
self: BloomBlock,
Expand Down Expand Up @@ -432,8 +419,6 @@ def bloom_block_forward(

return outputs # hidden_states, present, attentions

# replace attention forward:
# used to replace BloomAttention.forward
@staticmethod
def bloom_attention_forward(
self: BloomAttention,
Expand Down