36 changes: 16 additions & 20 deletions colossalai/inference/hybridengine/engine.py
@@ -85,8 +85,6 @@ def __init__(
assert max_batch_size <= 64, "Max batch size exceeds the constraint"
assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"

# TODO: support only tensor parallel inference
assert pp_size > 1, "Not support only tensor parallel inference."
self.pp_size = pp_size
self.tp_size = tp_size

@@ -102,23 +100,21 @@ def __init__(
# Init pg mesh
pg_mesh = ProcessGroupMesh(pp_size, tp_size)

stage_manager = None
if pp_size > 1:
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
self.cache_manager_list = [
self._init_manager(model, max_batch_size, max_input_len, max_output_len)
for _ in range(micro_batch_buffer_size or pp_size)
]
self.mb_manager = MicroBatchManager(
stage_manager.stage,
micro_batch_size,
micro_batch_buffer_size or pp_size,
max_input_len,
max_output_len,
self.cache_manager_list,
)
self.verbose = verbose
self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)

self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))

@@ -146,7 +142,7 @@ def _shardformer(self, model, model_policy, stage_manager, tp_group):
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=False,
enable_tensor_parallelism=True if self.tp_size > 1 else False,
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
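With the pp_size > 1 assertion removed and enable_tensor_parallelism now derived from tp_size, the hybrid engine accepts tensor-parallel-only configurations. A minimal construction sketch under that assumption follows; the engine class, policy name, and import paths are illustrative placeholders (not confirmed by this diff), and only the keyword arguments mirror the constructor shown above.

import torch
from transformers import LlamaForCausalLM

# NOTE: the engine class and policy below are placeholder names used for
# illustration; only the keyword arguments mirror the constructor in the hunk above.
from colossalai.inference import CaiInferEngine                                # assumed export
from colossalai.inference.hybridengine.policies import LlamaModelInferPolicy   # assumed policy

model = LlamaForCausalLM.from_pretrained("path/to/llama", torch_dtype=torch.float16)

# Tensor-parallel-only inference (pp_size=1, tp_size=2) is a valid layout now that
# the pp_size > 1 assertion is gone; the remaining constraints are
# max_batch_size <= 64 and max_input_len + max_output_len <= 4096.
engine = CaiInferEngine(
    pp_size=1,
    tp_size=2,
    model=model,
    model_policy=LlamaModelInferPolicy(),
    max_batch_size=8,
    max_input_len=1024,
    max_output_len=256,
    micro_batch_size=1,
)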
2 changes: 1 addition & 1 deletion colossalai/inference/tensor_parallel/kvcache_manager.py
@@ -103,4 +103,4 @@ def free_all(self):
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.max_len_in_batch = 0
self.logger.info("freed all space of memory manager")
# self.logger.info("freed all space of memory manager")
107 changes: 87 additions & 20 deletions colossalai/pipeline/schedule/generate.py
@@ -93,9 +93,7 @@ def _prepare_inputs_for_interval_stage(self):
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
"""
model_inputs = {
'infer_state': self.mb_manager.cur_descrption.infer_state
}
model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
return model_inputs

def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -129,8 +127,8 @@ def _recv_pre_stage(self) -> Any:

def _init_infer_state_action(self) -> None:
"""
This action is only for no first stage, to load batch and init infer_state.
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
This action is only for no first stage, to load batch and init infer_state.
1.Load micro_batch 2.Use the current micro_batch to init the current infer_state
"""
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
@@ -145,19 +143,19 @@ def _load_stage_action(self, model: Module) -> None:
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)

self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

def _gen_token_action(self, model: Module):
"""
This action is only for first stage
This action is only for first stage
1.do the forward with hidden_states to generate new tokens 2.step to update
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
@@ -178,18 +176,18 @@ def _head_encoding_action(self, model: Module):
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)

self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
interval_inputs = {'hidden_states': hidden_states, 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, None, interval_inputs)

self.action_interval_buffer.hidden_states = output_dict['hidden_states']
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]

def _comm_action(self, recv_pre: bool) -> torch.Tensor:
"""
@@ -233,12 +231,73 @@ def _gen_action(self, model: Module):

return actions

def _gen_one_stage_action(self, model: Module):
"""
In this function, it will generate a sequence action for current state, and do the action one by one.

Args:
model (Module): Model to be run.

Returns:
List[Callable]: A list of action, each action is a callable function, and it will be called in order.
"""
actions = []

if self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._load_stage_action, model))
elif self.mb_manager.cur_state is Status.GENERATE:
actions.append(partial(self._gen_token_action, model))
actions.append(partial(self._head_encoding_action, model))
elif self.mb_manager.cur_state is Status.COOLDOWN:
actions.append(partial(self._gen_token_action, model))

return actions

def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
if self.stage_manager.num_stages == 2:
if self.stage_manager.num_stages == 1:
return self.generate_step_one_stage(model, data_iter)
elif self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)

@torch.no_grad()
def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
Forward one step of the pipeline, when pipeline size is 1.

Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.

Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
self.comm_dtype = model.dtype

whole_timestamp = []

# run by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
self.action_interval_buffer.clear()
while self.mb_manager.is_micro_batch_done() is False:
actions = self._gen_one_stage_action(model)
for action in actions:
action()
self.mb_manager.next()
# All microbatch in current round is DONE
output_sequence.extend(self.mb_manager.export_new_tokens())

self.mb_manager.clear()
if self.verbose:
whole_timestamp.extend(self.timestamps)

return output_sequence, whole_timestamp

@torch.no_grad()
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
@@ -319,7 +378,7 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
self.mb_manager.add_descrption(inputs_dict)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
@@ -330,18 +389,23 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {
"hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state,
}
logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
assert (
"logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
interval_inputs = {'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
@@ -350,7 +414,10 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
if self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
self.mb_manager.add_descrption(inputs_dict)
interval_inputs = {'hidden_states': hidden_states['hidden_states'], 'infer_state': self.mb_manager.cur_infer_state}
interval_inputs = {
"hidden_states": hidden_states["hidden_states"],
"infer_state": self.mb_manager.cur_infer_state,
}
output_dict = model_forward(model, inputs_dict, interval_inputs)

# Current microbatch is not DONE, send hidden_state to next stage
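Taken together, generate_step now dispatches on num_stages (1 selects the new single-stage path, 2 the p2p path, anything larger the broadcast path), and the single-stage path drives each micro batch through the same PREFILL/GENERATE/COOLDOWN states without any send/receive. The toy walk-through below re-enacts the action sequence _gen_one_stage_action emits; the Status values and the state trace are assumptions based on the diff, and the real actions call model_forward instead of logging their names.

from enum import Enum, auto
from functools import partial

class Status(Enum):
    # mirrors the states referenced by the schedule above
    PREFILL = auto()
    GENERATE = auto()
    COOLDOWN = auto()
    DONE = auto()

def gen_one_stage_actions(state, run):
    # same branching as _gen_one_stage_action, with named stubs instead of model calls
    if state is Status.PREFILL:
        return [partial(run, "load_stage")]        # prompt forward, fills the KV cache
    if state is Status.GENERATE:
        return [partial(run, "gen_token"),         # sample the next token from hidden states
                partial(run, "head_encoding")]     # decoder forward for that single token
    if state is Status.COOLDOWN:
        return [partial(run, "gen_token")]         # last token only, no further forward
    return []

log = []
# illustrative state trace for one micro batch; the real trace comes from MicroBatchManager
for state in (Status.PREFILL, Status.GENERATE, Status.GENERATE, Status.COOLDOWN):
    for action in gen_one_stage_actions(state, log.append):
        action()
print(log)  # ['load_stage', 'gen_token', 'head_encoding', 'gen_token', 'head_encoding', 'gen_token']

The remaining hunks extend the pipeline-inference tests with a tensor-parallel-only case (tp_size=2, pp_size=1) spawned on two processes.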
@@ -65,6 +65,16 @@ def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
torch.cuda.empty_cache()


@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
torch.cuda.empty_cache()


def check_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_pipeline_inference_test()
@@ -75,13 +85,19 @@ def check_tp_pipeline_inference(rank, world_size, port):
run_tp_pipeline_inference_test()


def check_tp_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_tp_inference_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=2)
spawn(check_tp_pipeline_inference, nprocs=4)
spawn(check_tp_inference, nprocs=2)


if __name__ == "__main__":
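For reference, the TP-only check composes the same launch pattern as the existing tests: spawn starts one process per tensor-parallel rank, each process joins an NCCL group via colossalai.launch and then runs the tp_size=2 / pp_size=1 inference path. A standalone sketch of that pattern follows; the import path for spawn is assumed (the test file's imports are not shown), and the worker body stands in for check_tp_inference / run_tp_inference_test.

import colossalai
from colossalai.testing import spawn   # assumed import path

def worker(rank, world_size, port):
    # each spawned process joins the same NCCL group on localhost:<port> ...
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # ... and would then run the tp_size=2 / pp_size=1 inference test body here

if __name__ == "__main__":
    spawn(worker, nprocs=2)  # one process per tensor-parallel rank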