Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions applications/Chat/inference/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from slowapi.util import get_remote_address
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json

CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
MAX_LEN = 512
Expand Down Expand Up @@ -111,14 +111,19 @@ def generate(data: GenerationTaskReq, request: Request):
@limiter.limit('1/second')
def generate_no_stream(data: GenerationTaskReq, request: Request):
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
if prompt_processor.has_censored_words(prompt):
return prompt_processor.SAFE_RESPONSE
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
with running_lock:
output = model.generate(**inputs, **data.dict(exclude={'history'}))
output = output.cpu()
prompt_len = inputs['input_ids'].size(1)
response = output[0, prompt_len:]
out_string = tokenizer.decode(response, skip_special_tokens=True)
return prompt_processor.postprocess_output(out_string)
out_string = prompt_processor.postprocess_output(out_string)
if prompt_processor.has_censored_words(out_string):
return prompt_processor.SAFE_RESPONSE
return out_string


if __name__ == '__main__':
Expand All @@ -140,13 +145,19 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
parser.add_argument('--http_host', default='0.0.0.0')
parser.add_argument('--http_port', type=int, default=7070)
parser.add_argument('--profanity_file', default=None, help='Path to profanity words list. It should be a JSON file containing a list of words.')
args = parser.parse_args()

if args.quant == '4bit':
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'

tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)

if args.profanity_file is not None:
censored_words = load_json(args.profanity_file)
else:
censored_words = []
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

if args.quant == '4bit':
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
Expand Down
16 changes: 15 additions & 1 deletion applications/Chat/inference/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
import json

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -123,11 +124,16 @@ def _format_dialogue(instruction: str, response: str = ''):


class ChatPromptProcessor:
SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'

def __init__(self, tokenizer, context: str, max_len: int = 2048):
def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str]=[]):
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
if len(censored_words) > 0:
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
else:
self.censored_pat = None
# These will be initialized after the first call of preprocess_prompt()
self.context_len: Optional[int] = None
self.dialogue_placeholder_len: Optional[int] = None
Expand Down Expand Up @@ -172,6 +178,10 @@ def postprocess_output(self, output: str) -> str:
output = STOP_PAT.sub('', output)
return output.strip()

def has_censored_words(self, text: str) -> bool:
if self.censored_pat is None:
return False
return self.censored_pat.search(text) is not None

class LockedIterator:

Expand All @@ -185,3 +195,7 @@ def __iter__(self):
def __next__(self):
with self.lock:
return next(self.it)

def load_json(path: str):
with open(path) as f:
return json.load(f)