4 changes: 3 additions & 1 deletion colossalai/inference/dynamic_batching/ray_dist_init.py
@@ -67,7 +67,9 @@ def setup(self, world_size, rank, port):
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
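The same migration recurs in every file below: boolean switches that used to be top-level `ShardConfig` fields are now passed through the `extra_kwargs` dict. A minimal before/after sketch (values are illustrative and assume the usual `colossalai.shardformer` export of `ShardConfig`, plus an already-initialized distributed context as in the worker above):

```python
from colossalai.shardformer import ShardConfig

# Before this PR (assumed old API): the flag was a dedicated dataclass field.
# shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)

# After this PR: inference-related switches travel in extra_kwargs.
shard_config = ShardConfig(
    enable_tensor_parallelism=True,            # False for single-GPU runs
    extra_kwargs={"inference_only": True},     # consumers read it back with .get()
)
```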
3 changes: 1 addition & 2 deletions colossalai/inference/hybridengine/polices/llama.py
@@ -45,8 +45,7 @@ def __init__(self) -> None:

def module_policy(self):
policy = super().module_policy()

if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

decoder_attribute_replacement = {
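Reading the flag through `dict.get()` with a `False` default keeps configs that never mention GPTQ working unchanged. A hypothetical helper (not part of the PR) showing the lookup pattern:

```python
def wants_gptq(shard_config) -> bool:
    # A missing key means the caller never opted in, so quantized layers stay off.
    return shard_config.extra_kwargs.get("inference_gptq", False)
```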
16 changes: 7 additions & 9 deletions colossalai/inference/tensor_parallel/engine.py
@@ -44,7 +44,7 @@ class TPInferEngine:
>>> # define model and shard config for your inference
>>> model = ...
>>> generate_kwargs = ...
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
>>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
@@ -181,7 +181,7 @@ def _optimize_model(self, model: nn.Module) -> None:
In further generation, use the sharded model instead of original model.
"""
# NOTE we will change to use an inference config later with additional attrs we want
assert self.shard_config.inference_only is True
assert self.shard_config.extra_kwargs["inference_only"] is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer, model)
@@ -203,10 +203,10 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
inference_only=True,
extra_kwargs={"inference_only": True},
)
else:
shard_config.inference_only = True
shard_config.extra_kwargs = {"inference_only": True}
shard_config.pipeline_stage_manager = None
if shard_config.enable_tensor_parallelism:
self.tp_size = shard_config.tensor_parallel_size
@@ -221,13 +221,11 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

model = model.model if self.shard_config.inference_gptq else model
if self.shard_config.extra_kwargs.get("inference_gptq", False):
model = model.model
policy = get_autopolicy(model, shard_config=self.shard_config)

self.model, _ = shardformer.optimize(model, policy)

if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
self._post_init_gptq_buffer(self.model)

self.model = self.model.cuda()
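One detail worth noting in `_prepare_with_shard_config`: rebinding `shard_config.extra_kwargs = {"inference_only": True}` replaces the whole dict, so a flag the caller already set (such as `inference_gptq`) would be discarded. If preserving caller-supplied flags is the intent, an in-place update is the safer shape; this is a sketch of that assumption, not what the merged hunk does:

```python
from colossalai.shardformer import ShardConfig

def mark_inference_only(shard_config: ShardConfig) -> ShardConfig:
    # Update the existing dict instead of rebinding it, so previously set
    # entries like "inference_gptq" survive config normalization.
    shard_config.extra_kwargs["inference_only"] = True
    return shard_config
```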
61 changes: 32 additions & 29 deletions colossalai/inference/tensor_parallel/policies/bloom.py
@@ -4,7 +4,6 @@
from torch.nn import LayerNorm

import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy

@@ -38,35 +37,39 @@ def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel

policy = super().module_policy()
if self.shard_config.inference_gptq:

if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 3}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=ColCaiQuantLinear,
kwargs={'split_num': 1}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=RowCaiQuantLinear,
kwargs={'split_num': 1}),
])

policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size
// self.shard_config.tensor_parallel_size,
"self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=ColCaiQuantLinear,
kwargs={"split_num": 3},
),
SubModuleReplacementDescription(
suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
),
],
)
# NOTE set inference mode to shard config
self.shard_config._infer()

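For the fused `query_key_value` projection the column-parallel quantized linear is created with `split_num=3`; the assumption behind that value is that the fused weight packs the Q, K and V projections along the output dimension, so each chunk has to be partitioned separately across tensor-parallel ranks. A small illustrative calculation with hypothetical sizes:

```python
hidden_size, tp_size = 1024, 2
fused_out = 3 * hidden_size            # fused [W_q | W_k | W_v] output dimension

# A naive split would cut once across the fused dimension, slicing through Q/K/V.
naive_per_rank = fused_out // tp_size                 # 1536

# split_num=3: each of the three chunks is split on its own, keeping Q, K and V
# aligned per rank (3 chunks of 512 columns each on a 2-way TP group).
per_rank_per_chunk = hidden_size // tp_size           # 512
print(naive_per_rank, per_rank_per_chunk)
```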
4 changes: 3 additions & 1 deletion colossalai/inference/tensor_parallel/policies/llama.py
@@ -13,6 +13,7 @@

try:
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward

HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +22,7 @@

def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:

def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

@@ -36,7 +38,7 @@ def __init__(self) -> None:
def module_policy(self):
policy = super().module_policy()

if self.shard_config.inference_gptq:
if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear

decoder_attribute_replacement = {
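The guarded import keeps the Triton kernel optional; a minimal sketch of the same pattern, narrowing the bare `except` to `ImportError` and returning `None` so the caller can keep the eager `LlamaRMSNorm.forward` (the fallback choice is an assumption, not part of the PR):

```python
try:
    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except ImportError:
    HAS_TRITON_RMSNORM = False

def get_triton_rmsnorm_forward():
    if not HAS_TRITON_RMSNORM:
        return None  # caller keeps the module's default eager implementation

    def _triton_rmsnorm_forward(self, hidden_states):
        return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

    return _triton_rmsnorm_forward
```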
3 changes: 0 additions & 3 deletions colossalai/shardformer/README.md
@@ -81,8 +81,6 @@ Following are the description `ShardConfig`'s arguments:

- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.

- `inference_only`: Whether only doing forward passing. Defaults to False.

### Write your own policy

If you have a custom model, you can also use Shardformer to parallelize it by writing your own sharding policy. More information about the sharding policy can be found in [API Design](#-api-design).
@@ -185,7 +183,6 @@ class ShardConfig:

# Some possible future config fields
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
inference_only: bool # only inject inference-suitable sharding policy
use_flash_attention: bool # whether to use flash attention to speed up attention
```

5 changes: 3 additions & 2 deletions colossalai/shardformer/policies/auto_policy.py
@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
if shard_config.inference_only:
inference_only = shard_config.extra_kwargs.get("inference_only", False)
if inference_only:
policy_location = _INFER_POLICY_LIST.get(full_name, None)
else:
policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location, shard_config.inference_only)
policy = import_policy(policy_location, inference_only)
return policy()
9 changes: 3 additions & 6 deletions colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from typing import Dict, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -24,7 +24,6 @@ class ShardConfig:
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
inference_only (bool): Whether only doing forward passing. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -33,10 +32,9 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_all_optimization: bool = False
inference_only: bool = False
inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
extra_kwargs: Dict[str, bool] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -77,4 +75,3 @@ def _infer(self):
Set default params for inference.
"""
# assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
pass
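A note on the new field itself: dataclasses reject a plain mutable default such as `extra_kwargs: Dict[str, bool] = {}`, which is why the change goes through `field(default_factory=dict)`. A stripped-down sketch (not the full `ShardConfig`):

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class _ConfigSketch:
    # default_factory gives every instance its own dict; a bare `= {}` would
    # raise "mutable default ... is not allowed" at class definition time.
    extra_kwargs: Dict[str, bool] = field(default_factory=dict)

cfg = _ConfigSketch()
cfg.extra_kwargs["inference_only"] = True   # per-instance, never shared
```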
4 changes: 3 additions & 1 deletion examples/inference/bench_bloom.py
@@ -28,7 +28,9 @@ def bench_bloom(args):

# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

# prepare data for generation
4 changes: 3 additions & 1 deletion examples/inference/bench_chatglm2.py
@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
model = model.half()
model.config

shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

generate_kwargs = dict(max_new_tokens=1, do_sample=False)
4 changes: 3 additions & 1 deletion examples/inference/bench_llama.py
@@ -30,7 +30,9 @@ def run_llama_test(args):
model = model.half()
model.config

shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

generate_kwargs = dict(max_new_tokens=1, do_sample=False)
7 changes: 5 additions & 2 deletions examples/inference/gptq_bloom.py
@@ -34,7 +34,9 @@ def bench_bloom(args):
model = model.half()

model_config = model.config
shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)

@@ -46,7 +48,8 @@ def bench_bloom(args):
# init TPInferEngine and shard the original model
# To benchmark torch original, comment out the line of optimizing model
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
enable_tensor_parallelism=True if args.tp_size > 1 else False,
extra_kwargs={"inference_only": True, "inference_gptq": True},
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

3 changes: 2 additions & 1 deletion examples/inference/gptq_llama.py
@@ -33,7 +33,8 @@ def run_llama_test(args):

model_config = model.config
shard_config = ShardConfig(
enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
enable_tensor_parallelism=True if args.tp_size > 1 else False,
extra_kwargs={"inference_only": True, "inference_gptq": True},
)
infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

@@ -68,7 +68,9 @@ def setup(self, world_size, rank, port):
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)

shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
@@ -100,7 +100,9 @@ def initialize(self, ctx):

colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
logger.info("Initializing TPInferEngine ...")
shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
shard_config = ShardConfig(
enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
2 changes: 1 addition & 1 deletion tests/test_infer/_utils.py
@@ -19,7 +19,7 @@ def build_model(
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
inference_only=True,
extra_kwargs={"inference_only": True},
)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
10 changes: 6 additions & 4 deletions tests/test_infer/test_bloom_infer.py
@@ -11,11 +11,10 @@
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

try:
import lightllm
HAS_LIGHTLLM_KERNEL = True
except:
HAS_LIGHTLLM_KERNEL = False

TP_SIZE = 2
MAX_BATCH_SIZE = 4
MAX_INPUT_LEN = 16
@@ -38,7 +37,7 @@ def run(test_config):
model = model.half()

shard_config = ShardConfig(
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
@@ -58,7 +57,10 @@ def check_bloom(rank, world_size, port):
run()


@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.skipif(
not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
reason="kv-cache manager engine requires cuda version to be higher than 11.5",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
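The test gates on the optional `lightllm` dependency the same way the policies gate on Triton: a guarded import sets a module-level flag, and `pytest.mark.skipif` combines it with the CUDA check. A minimal standalone sketch of that pattern (test name and reason text are illustrative, not taken from the PR):

```python
import pytest

try:
    import lightllm  # optional dependency providing the fused inference kernels
    HAS_LIGHTLLM_KERNEL = True
except ImportError:
    HAS_LIGHTLLM_KERNEL = False

@pytest.mark.skipif(not HAS_LIGHTLLM_KERNEL, reason="lightllm kernels are not installed")
def test_needs_lightllm():
    assert HAS_LIGHTLLM_KERNEL
```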
2 changes: 1 addition & 1 deletion tests/test_infer/test_chatglm2_infer.py
@@ -49,7 +49,7 @@ def run_chatglm2_test(test_config):
model = model.half()

shard_config = ShardConfig(
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, extra_kwargs={"inference_only": True}
)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)