Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions applications/Chat/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ The data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tar
| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |

## General setup

```shell
pip install -r requirements.txt
```

## 8-bit setup

8-bit quantization is supported natively by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.
Expand Down
4 changes: 3 additions & 1 deletion applications/Chat/inference/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
fastapi
locustio
locust
numpy
pydantic
safetensors
Expand All @@ -8,3 +8,5 @@ sse_starlette
torch
uvicorn
git+https://github.com/huggingface/transformers
accelerate
bitsandbytes
4 changes: 2 additions & 2 deletions applications/Chat/inference/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn

CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 2048
MAX_LEN = 512
running_lock = Lock()


Expand Down Expand Up @@ -116,7 +116,7 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt_len = inputs['input_ids'].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
return out_string.lstrip()
return prompt_processor.postprocess_output(out_string)


if __name__ == '__main__':
Expand Down
8 changes: 8 additions & 0 deletions applications/Chat/inference/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional

Expand Down Expand Up @@ -118,6 +119,9 @@ def _format_dialogue(instruction: str, response: str = ''):
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'


# Matches everything from the first '###' or 'instruction:' marker to the end
# of the string (re.I: case-insensitive; re.S: '.' also spans newlines).
# Used to cut off any follow-up instruction the model hallucinates after its answer.
STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))


class ChatPromptProcessor:

def __init__(self, tokenizer, context: str, max_len: int = 2048):
Expand Down Expand Up @@ -164,6 +168,10 @@ def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
return prompt

def postprocess_output(self, output: str) -> str:
    """Clean up raw model output: drop any stop-marker tail and trim whitespace."""
    return STOP_PAT.sub('', output).strip()


class LockedIterator:

Expand Down