2 changes: 1 addition & 1 deletion .github/workflows/build_on_pr.yml
@@ -142,7 +142,7 @@ jobs:
     container:
       image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
       options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10 -v /data/scratch/llama-tiny:/data/scratch/llama-tiny
-    timeout-minutes: 60
+    timeout-minutes: 100
     defaults:
       run:
         shell: bash
4 changes: 2 additions & 2 deletions colossalai/inference/hybridengine/polices/llama.py
@@ -51,7 +51,7 @@ def module_policy(self):
                 "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
                 // self.shard_config.tensor_parallel_size,
             }
-        if self.shard_config.quant == "gptq":
+        if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
             from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
 
             policy[LlamaDecoderLayer] = ModulePolicyDescription(
@@ -95,7 +95,7 @@ def module_policy(self):
                 ],
             )
 
-        elif self.shard_config.quant == "smoothquant":
+        elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
             from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
             from colossalai.inference.quant.smoothquant.models.parallel_linear import (
                 ColW8A8BFP32OFP32Linear,
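
Reading `quant` through `extra_kwargs.get("quant", None)` keeps the policy working for callers that never populate `extra_kwargs`: a missing key behaves exactly like the old default of `quant=None`. A minimal standalone sketch of that dispatch, with placeholder strings standing in for the real GPTQ/SmoothQuant module replacements:

```python
from typing import Any, Dict


def select_decoder_policy(extra_kwargs: Dict[str, Any]) -> str:
    # A missing "quant" key behaves exactly like quant=None did before.
    quant = extra_kwargs.get("quant", None)
    if quant == "gptq":
        return "gptq"         # stands in for the ColCaiQuantLinear / RowCaiQuantLinear wiring
    elif quant == "smoothquant":
        return "smoothquant"  # stands in for the LlamaSmoothquantDecoderLayer wiring
    return "default"


assert select_decoder_policy({}) == "default"
assert select_decoder_policy({"quant": "gptq"}) == "gptq"
```
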
4 changes: 2 additions & 2 deletions colossalai/shardformer/README.md
@@ -81,7 +81,7 @@ Following are the description `ShardConfig`'s arguments:
 
 - `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
 
-- `inference_only`: Whether only doing forward passing. Defaults to False.
+- `extra_kwargs`: A dict to store extra kwargs for Shardformer.
 
 ### Write your own policy
 
@@ -185,8 +185,8 @@ class ShardConfig:
 
     # Some possible future config fields
     tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d']  # support different tensor parallel mode
-    inference_only: bool  # only inject inference-suitable sharding policy
     use_flash_attention: bool  # whether to use flash attention to speed up attention
+    extra_kwargs: Dict[str, Any]  # extra kwargs for the shardformer
 ```
 
 ### Policy
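
With `inference_only` dropped from the documented arguments, inference-specific switches now travel through `extra_kwargs`. A construction sketch based on the example scripts further down in this PR (the import path is assumed to be the usual `from colossalai.shardformer import ShardConfig`; the commented-out call shows the keywords this PR removes):

```python
from colossalai.shardformer import ShardConfig

# Before this PR:
#   ShardConfig(enable_tensor_parallelism=True, inference_only=True, inference_gptq=True)
# After this PR, the same intent is expressed through extra_kwargs:
shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    extra_kwargs={"inference_only": True, "quant": "gptq"},
)
```
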
5 changes: 3 additions & 2 deletions colossalai/shardformer/policies/auto_policy.py
@@ -209,7 +209,8 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
         :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    if shard_config.inference_only:
+    inference_only = shard_config.extra_kwargs.get("inference_only", None)
+    if inference_only:
         policy_location = _INFER_POLICY_LIST.get(full_name, None)
     else:
         policy_location = _POLICY_LIST.get(full_name, None)
@@ -219,5 +220,5 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy
             f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
         )
     else:
-        policy = import_policy(policy_location, shard_config.inference_only)
+        policy = import_policy(policy_location, inference_only)
     return policy()
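
The same fallback applies here: an empty `extra_kwargs` is equivalent to the old `inference_only=False`, because `.get("inference_only", None)` returns `None`, which is falsy, so the training policy table is used. A self-contained sketch of that lookup (the two dicts are stand-ins for the real `_POLICY_LIST` / `_INFER_POLICY_LIST` registries):

```python
from typing import Any, Dict, Optional

_POLICY_LIST = {"transformers.LlamaForCausalLM": "training policy"}
_INFER_POLICY_LIST = {"transformers.LlamaForCausalLM": "inference policy"}


def lookup(full_name: str, extra_kwargs: Dict[str, Any]) -> Optional[str]:
    # Missing key -> None -> falsy -> fall back to the training policies.
    if extra_kwargs.get("inference_only", None):
        return _INFER_POLICY_LIST.get(full_name, None)
    return _POLICY_LIST.get(full_name, None)


assert lookup("transformers.LlamaForCausalLM", {}) == "training policy"
assert lookup("transformers.LlamaForCausalLM", {"inference_only": True}) == "inference policy"
```
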
8 changes: 3 additions & 5 deletions colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
 
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -33,11 +33,9 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
     enable_all_optimization: bool = False
-    inference_only: bool = False
-    inference_gptq: bool = False
     enable_sequence_parallelism: bool = False
     enable_sequence_overlap: bool = False
-    quant: str = None
+    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
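
The added `field` import is what makes the dict default safe: `dataclasses` rejects a bare `extra_kwargs: Dict[str, Any] = {}` because a mutable default would be shared by every instance. A trimmed-down sketch of just this part of the config (all other `ShardConfig` fields omitted):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class MiniShardConfig:
    enable_tensor_parallelism: bool = True
    # default_factory gives each instance its own dict; a bare `= {}`
    # default is rejected by dataclasses as a mutable default.
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)


a, b = MiniShardConfig(), MiniShardConfig()
a.extra_kwargs["quant"] = "gptq"
assert b.extra_kwargs == {}  # instances do not share state
```
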
4 changes: 3 additions & 1 deletion examples/inference/bench_bloom.py
@@ -28,7 +28,9 @@ def bench_bloom(args):
 
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     # prepare data for generation
4 changes: 3 additions & 1 deletion examples/inference/bench_chatglm2.py
@@ -30,7 +30,9 @@ def run_chatglm2_test(args):
     model = model.half()
     model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
4 changes: 3 additions & 1 deletion examples/inference/bench_llama.py
@@ -30,7 +30,9 @@ def run_llama_test(args):
     model = model.half()
     model.config
 
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
     generate_kwargs = dict(max_new_tokens=1, do_sample=False)
7 changes: 5 additions & 2 deletions examples/inference/gptq_bloom.py
@@ -34,7 +34,9 @@ def bench_bloom(args):
     model = model.half()
 
     model_config = model.config
-    shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
+    shard_config = ShardConfig(
+        enable_tensor_parallelism=True if args.tp_size > 1 else False, extra_kwargs={"inference_only": True}
+    )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
     generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
 
@@ -46,7 +48,8 @@ def bench_bloom(args):
     # init TPInferEngine and shard the original model
     # To benchmark torch original, comment out the line of optimizing model
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
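
Across the example scripts the migration is mechanical: `inference_only=True` becomes `extra_kwargs={"inference_only": True}`, and `inference_gptq=True` becomes the `"quant": "gptq"` entry. A hypothetical helper sketching that mapping for older call sites (illustrative only, not part of this PR):

```python
from typing import Any, Dict


def legacy_flags_to_extra_kwargs(inference_only: bool = False, inference_gptq: bool = False) -> Dict[str, Any]:
    """Translate the removed ShardConfig keywords into the new extra_kwargs dict."""
    extra: Dict[str, Any] = {}
    if inference_only:
        extra["inference_only"] = True
    if inference_gptq:
        extra["quant"] = "gptq"  # inference_gptq=True previously implied GPTQ quantization
    return extra


assert legacy_flags_to_extra_kwargs(inference_only=True, inference_gptq=True) == {
    "inference_only": True,
    "quant": "gptq",
}
```
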
3 changes: 2 additions & 1 deletion examples/inference/gptq_llama.py
@@ -33,7 +33,8 @@ def run_llama_test(args):
 
     model_config = model.config
     shard_config = ShardConfig(
-        enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+        enable_tensor_parallelism=True if args.tp_size > 1 else False,
+        extra_kwargs={"inference_only": True, "quant": "gptq"},
     )
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 