4 changes: 2 additions & 2 deletions colossalai/inference/__init__.py
@@ -1,4 +1,4 @@
from .engine import CaiInferEngine
from .engine import InferenceEngine
from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy

__all__ = ["CaiInferEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
4 changes: 2 additions & 2 deletions colossalai/inference/engine/__init__.py
@@ -1,3 +1,3 @@
from .engine import CaiInferEngine
from .engine import InferenceEngine

__all__ = ["CaiInferEngine"]
__all__ = ["InferenceEngine"]
34 changes: 4 additions & 30 deletions colossalai/inference/engine/engine.py
@@ -3,7 +3,6 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import logging

from colossalai.cluster import ProcessGroupMesh
@@ -27,9 +26,9 @@
]


class CaiInferEngine:
class InferenceEngine:
"""
CaiInferEngine is a class that handles the pipeline parallel inference.
InferenceEngine is a class that handles the pipeline parallel inference.

Args:
tp_size (int): the size of tensor parallelism.
@@ -42,27 +41,6 @@ class CaiInferEngine:
max_input_len (int): the maximum input length.
max_output_len (int): the maximum output length.

Example:

```python
from colossalai.inference import InferEngine
from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")
tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
# assume the model is infered with 2 pipeline stages
inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

input = ["Introduce a landmark in China ","Introduce a landmark in China "]
data = tokenizer(input, return_tensors='pt')
output = inferengine.inference([data.to('cuda').data])

```

"""

def __init__(
@@ -146,7 +124,7 @@ def __init__(
if quant == "gptq":
self.gptq_manager.post_init_gptq_buffer(self.model)

def generate(self, input_list: Union[BatchEncoding, dict]):
def generate(self, input_list: Union[list, dict]):
"""
Args:
input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
@@ -155,11 +133,7 @@ def generate(self, input_list: Union[BatchEncoding, dict]):
out (list): a list of output data, each element is a list of token.
timestamp (float): the time cost of the inference, only return when verbose is `True`.
"""
assert isinstance(
input_list, (BatchEncoding, dict)
), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
if isinstance(input_list, BatchEncoding):
input_list = input_list.data

out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
if self.verbose:
return out, timestamp
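The PR deletes the old docstring example without adding a replacement. Here is a hedged sketch of end-to-end usage under the new name, adapted from the deleted example above; the model path is a placeholder, and the input is passed as the plain dict that `BatchEncoding.data` yields, since `generate()` no longer unwraps `BatchEncoding` itself:

```python
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import InferenceEngine, LlamaModelInferPolicy

colossalai.launch_from_torch(config={})

model = LlamaForCausalLM.from_pretrained("your_path_to_model")  # placeholder path
tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")  # placeholder path

# Assume the model runs with 2 pipeline stages, as in the deleted example.
engine = InferenceEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())

inputs = tokenizer(["Introduce a landmark in China"], return_tensors="pt")
# generate() now accepts only list or dict: pass the underlying dict explicitly
# rather than relying on the removed BatchEncoding handling.
output = engine.generate(inputs.to("cuda").data)
```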
@@ -29,7 +29,7 @@ def parse_args():
type=str,
help="location of the calibration dataset",
)
parser.add_argument("--num-samples", type=int, default=512)
parser.add_argument("--num-samples", type=int, default=10)
parser.add_argument("--seq-len", type=int, default=512)
args = parser.parse_args()
return args
@@ -41,13 +41,12 @@ def main():
model_path = args.model_name
dataset_path = args.dataset_path
output_path = args.output_path
num_samples = 10
seq_len = 512
num_samples = args.num_samples
seq_len = args.seq_len

model, tokenizer = build_model_and_tokenizer(model_path)
if not os.path.exists(dataset_path):
print(f"Cannot find the dataset at {args.dataset_path}")
raise FileNotFoundError
raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}")
dataset = load_dataset("json", data_files=dataset_path, split="train")

model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
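Two fixes land here: the hard-coded `num_samples`/`seq_len` locals that shadowed the CLI flags are removed, and the dataset check now raises `FileNotFoundError` with its message attached instead of printing it separately. A small self-contained sketch of the resulting pattern; the `--dataset-path` flag spelling is inferred from `args.dataset_path` in the diff, not shown in it:

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--dataset-path", type=str, help="location of the calibration dataset")
parser.add_argument("--num-samples", type=int, default=10)
parser.add_argument("--seq-len", type=int, default=512)
args = parser.parse_args()

# The flags now actually drive calibration instead of being shadowed
# by hard-coded locals.
num_samples = args.num_samples
seq_len = args.seq_len

if not os.path.exists(args.dataset_path):
    # Attaching the message to the exception keeps it visible in tracebacks
    # and to any caller that catches it, unlike the old separate print().
    raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}")
```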
72 changes: 0 additions & 72 deletions examples/inference/hybrid_gptq_llama.py

This file was deleted.

86 changes: 0 additions & 86 deletions examples/inference/hybrid_llama.py

This file was deleted.

69 changes: 0 additions & 69 deletions examples/inference/hybrid_smoothquant_llama.py

This file was deleted.
