diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py
index 01763c850381..986177555b06 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/inference/tensor_parallel/engine.py
@@ -97,6 +97,10 @@ def _supported_models() -> List[str]:
     def generate(self, input_tokens, generate_kwargs) -> torch.Tensor:
         if isinstance(input_tokens, torch.Tensor):
             input_tokens = dict(input_ids=input_tokens, attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
+        for t in input_tokens:
+            if torch.is_tensor(input_tokens[t]):
+                input_tokens[t] = input_tokens[t].cuda()
+
         if self.sharded_model is not None:
             return self.generate_by_set_infer_state(input_tokens, generate_kwargs)
 
@@ -132,13 +136,6 @@ def generate_by_set_infer_state(self, input_tokens, generate_kwargs) -> torch.Te
         setattr(model, 'infer_state', batch_infer_state)
 
         generate_kwargs.update(max_new_tokens=self.max_output_len)
-
-        if isinstance(input_tokens, torch.Tensor):
-            input_tokens = dict(input_ids=input_tokens)
-        for t in input_tokens:
-            if torch.is_tensor(input_tokens[t]):
-                input_tokens[t] = input_tokens[t].cuda()
-
         outputs = self.sharded_model.generate(**input_tokens, **generate_kwargs, early_stopping=False)
 
         return outputs
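
In effect, the patch hoists the input normalization and device transfer out of `generate_by_set_infer_state` and into `generate`, so every tensor in `input_tokens` is moved to the GPU before the code branches on `self.sharded_model`. Below is a minimal sketch of that hoisted step; the helper name `normalize_inputs` is hypothetical (the patch inlines this logic in `generate`), the loop body is taken from the diff, and a CUDA device is assumed to be available:

```python
import torch

def normalize_inputs(input_tokens):
    # Hypothetical helper wrapping what generate() now does up front.
    # A bare tensor is wrapped into the dict form expected by an
    # HF-style generate(), with an all-ones boolean attention mask.
    if isinstance(input_tokens, torch.Tensor):
        input_tokens = dict(input_ids=input_tokens,
                            attention_mask=torch.ones_like(input_tokens, dtype=torch.bool))
    # Move every tensor value to the GPU, whichever generation path follows.
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].cuda()
    return input_tokens

tokens = normalize_inputs(torch.randint(0, 32000, (1, 16)))  # CPU input, as from a tokenizer
assert all(v.is_cuda for v in tokens.values())
```

One observable effect of the change: the removed copy in `generate_by_set_infer_state` built the dict as `dict(input_ids=input_tokens)` with no attention mask, so the two code paths previously normalized inputs differently; after the patch there is a single normalization point, with one consistent default mask, before either path runs.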