4 changes: 2 additions & 2 deletions colossalai/inference/tensor_parallel/engine.py
@@ -218,7 +218,7 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."

model = model.model if self.shard_config.inference_gptq else model
policy = get_autopolicy(model, shard_config=self.shard_config)
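A note on this hunk: when inference_gptq is enabled, the policy lookup runs against the wrapped backbone (model.model) rather than the quantized wrapper itself. A minimal, self-contained sketch of that unwrap-then-select pattern follows; the toy classes and the unwrap_for_policy helper are illustrative assumptions, not part of this PR.

import torch.nn as nn

class _Backbone(nn.Module):
    # Stand-in for the underlying transformer backbone.
    pass

class _GPTQWrapper(nn.Module):
    # Stand-in for a GPTQ-quantized causal-LM wrapper holding the backbone in .model.
    def __init__(self):
        super().__init__()
        self.model = _Backbone()

def unwrap_for_policy(model: nn.Module, inference_gptq: bool) -> nn.Module:
    # Mirrors the hunk above: the auto-policy must see the backbone, not the wrapper.
    return model.model if inference_gptq else model

assert isinstance(unwrap_for_policy(_GPTQWrapper(), inference_gptq=True), _Backbone)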

@@ -311,7 +311,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch

block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to("cuda")
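For reference, the loop in this hunk accumulates per-sequence start offsets and the maximum sequence length in the batch. A vectorized equivalent is sketched below; batch_offsets is a hypothetical helper used only for illustration.

import torch

def batch_offsets(seq_lengths: torch.Tensor):
    # Start index of each sequence when the batch is packed back to back,
    # plus the longest sequence length in the batch.
    seq_start_indexes = torch.cumsum(seq_lengths, dim=0) - seq_lengths
    max_len_in_batch = int(seq_lengths.max())
    return seq_start_indexes, max_len_in_batch

starts, max_len = batch_offsets(torch.tensor([3, 5, 2]))
# starts == tensor([0, 3, 8]), max_len == 5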
3 changes: 1 addition & 2 deletions tests/test_infer/test_pipeline_infer.py
@@ -24,6 +24,7 @@ def data_gen():
new_shape[0] = 16
inputs[k] = v.to("cuda").repeat(*new_shape)
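The data_gen snippet above tiles a single example along dimension 0 to build a batch of 16. A minimal sketch of that repeat pattern on CPU is shown below; the batch size and tensor shape are illustrative only.

import torch

v = torch.ones(1, 8, dtype=torch.long)   # a single tokenized example
new_shape = [1] * v.dim()
new_shape[0] = 16                        # repeat only along the batch dimension
batched = v.repeat(*new_shape)           # shape: (16, 8)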


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
model = transformers.LlamaForCausalLM(
transformers.LlamaConfig(
@@ -58,7 +59,6 @@ def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])

@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
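pipeline_inference_test above instantiates a small LlamaForCausalLM from a transformers.LlamaConfig (the config arguments are truncated in this diff). A sketch of a comparably tiny test model follows; the specific sizes are assumptions for illustration, not the values used in this test.

import transformers

# Keep the model deliberately small so the pipeline-inference test stays fast.
tiny_config = transformers.LlamaConfig(
    vocab_size=256,
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_hidden_layers=4,
)
model = transformers.LlamaForCausalLM(tiny_config)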
@@ -76,7 +76,6 @@ def check_tp_pipeline_inference(rank, world_size, port):


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")

@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
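The skipif above gates the distributed test on a CUDA_SUPPORT flag defined elsewhere in the test module. One plausible way such a flag could be computed is sketched below; this is an assumption about the helper, not code from this PR.

import torch
from packaging import version

# True only when a CUDA runtime newer than 11.5 is available,
# which the kv-cache manager engine requires.
CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.5")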