
Commit fa0d630

Handle exceptions from extension APIs + expose chat_completion APIs
1 parent f8e925d commit fa0d630

1 file changed

llms/main.py

Lines changed: 81 additions & 35 deletions
@@ -491,6 +491,7 @@ def __init__(self, status, reason, body, headers=None):
         self.headers = headers
         super().__init__(f"HTTP {status} {reason}")
 
+
 def save_bytes_to_cache(base64_data, filename, file_info):
     ext = filename.split(".")[-1]
     mimetype = get_file_mime_type(filename)
@@ -519,6 +520,7 @@ def save_bytes_to_cache(base64_data, filename, file_info):
     info.update(file_info)
     return url, info
 
+
 def save_image_to_cache(base64_data, filename, image_info):
     ext = filename.split(".")[-1]
     mimetype = get_file_mime_type(filename)
@@ -1078,7 +1080,41 @@ def api_providers():
     return ret
 
 
-async def chat_completion(chat):
+def to_error_response(e, stacktrace=False):
+    status = {"errorCode": "Exception", "message": str(e)}
+    if stacktrace:
+        status["stackTrace"] = traceback.format_exc()
+    return {"responseStatus": status}
+
+
+def g_chat_request(template=None, text=None, model=None, system_prompt=None):
+    chat_template = g_config["defaults"].get(template or "text")
+    if not chat_template:
+        raise Exception(f"Chat template '{template}' not found")
+
+    chat = chat_template.copy()
+    if model:
+        chat["model"] = model
+    if system_prompt is not None:
+        chat["messages"].insert(0, {"role": "system", "content": system_prompt})
+    if text is not None:
+        if not chat["messages"] or len(chat["messages"]) == 0:
+            chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
+
+        # replace content of last message if exists, else add
+        last_msg = chat["messages"][-1] if "messages" in chat else None
+        if last_msg and last_msg["role"] == "user":
+            if isinstance(last_msg["content"], list):
+                last_msg["content"][-1]["text"] = text
+            else:
+                last_msg["content"] = text
+        else:
+            chat["messages"].append({"role": "user", "content": text})
+
+    return chat
+
+
+async def g_chat_completion(chat):
     model = chat["model"]
     # get first provider that has the model
     candidate_providers = [name for name, provider in g_handlers.items() if provider.provider_model(model)]
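The two helpers above give callers a single way to build a chat request from a configured default template and run it against whichever provider serves the requested model. A minimal usage sketch, assuming it runs inside llms/main.py where g_config, g_handlers and the functions above are already in scope; the model id and prompt text are illustrative only:

import asyncio

async def demo():
    # Start from the default "text" template, override the model and prepend
    # a system message; text replaces the last user message's text.
    chat = g_chat_request(
        text="Summarise the latest release notes",
        model="example-model",              # hypothetical model id
        system_prompt="Answer in one sentence.",
    )
    # Dispatch to the first provider that reports it serves chat["model"].
    response = await g_chat_completion(chat)
    print(response)

asyncio.run(demo())

If anything in that pipeline raises, to_error_response(e, stacktrace=True) wraps the exception as {"responseStatus": {"errorCode": "Exception", "message": "...", "stackTrace": "..."}}, which is the payload the managed extension routes further down return with HTTP 500.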
@@ -1243,16 +1279,8 @@ async def cli_chat(chat, image=None, audio=None, file=None, args=None, raw=False
         printdump(chat)
 
     try:
-        # Apply pre-chat filters
-        context = {"chat": chat}
-        for filter_func in g_app.chat_request_filters:
-            chat = await filter_func(chat, context)
+        response = await g_app.chat_completion(chat)
 
-        response = await chat_completion(chat)
-
-        # Apply post-chat filters
-        for filter_func in g_app.chat_response_filters:
-            response = await filter_func(response, context)
         if raw:
             print(json.dumps(response, indent=2))
             exit(0)
@@ -2010,6 +2038,24 @@ def __init__(self, cli_args, extra_args):
             "21:9": "1536×672",
         }
 
+    def chat_request(self, template=None, text=None, model=None, system_prompt=None):
+        return g_chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
+
+    async def chat_completion(self, chat, context=None):
+        # Apply pre-chat filters
+        if context is None:
+            context = {"chat": chat}
+        for filter_func in self.chat_request_filters:
+            chat = await filter_func(chat, context)
+
+        response = await g_chat_completion(chat)
+
+        # Apply post-chat filters
+        for filter_func in self.chat_response_filters:
+            response = await filter_func(response, context)
+
+        return response
+
 
 class ExtensionContext:
     def __init__(self, app, path):
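Moving the filter passes into App.chat_completion means the CLI path, the HTTP chat_handler and extensions all share the same pre/post hooks. A sketch of what a filter pair could look like, assuming filters are plain async callables kept in the chat_request_filters / chat_response_filters lists; only the iteration is shown in this commit, so the append-style registration below is an assumption:

async def add_default_temperature(chat, context):
    # Pre-chat filter: may rewrite the request before it reaches a provider.
    chat.setdefault("temperature", 0.7)
    return chat

async def tag_response(response, context):
    # Post-chat filter: may rewrite the provider response before it is returned.
    response.setdefault("metadata", {})["filtered"] = True
    return response

# Hypothetical registration; the commit only shows the lists being iterated.
g_app.chat_request_filters.append(add_default_temperature)
g_app.chat_response_filters.append(tag_response)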
@@ -2098,8 +2144,11 @@ def add_post(self, path, handler, **kwargs):
     def get_config(self):
         return g_config
 
-    def chat_completion(self, chat):
-        return chat_completion(chat)
+    def chat_request(self, template=None, text=None, model=None, system_prompt=None):
+        return self.app.chat_request(template=template, text=text, model=model, system_prompt=system_prompt)
+
+    def chat_completion(self, chat, context=None):
+        return self.app.chat_completion(chat, context=context)
 
     def get_providers(self):
         return g_handlers
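With chat_request and chat_completion now exposed on ExtensionContext, an extension can build and run a completion without reaching for module globals, and its requests flow through the same filter pipeline as the built-in handler. A rough sketch of an extension route, assuming an aiohttp-style request object and a ctx handed to the extension at load time; the entry point name, route path and body fields are illustrative:

from aiohttp import web

def register(ctx):
    async def summarise_handler(request):
        body = await request.json()
        chat = ctx.chat_request(text=body.get("text", ""), model=body.get("model"))
        response = await ctx.chat_completion(chat, context={"request": request, "chat": chat})
        return web.json_response(response)

    # Exceptions raised in this handler are caught by the managed wrapper
    # registered below and returned as a to_error_response() payload.
    ctx.add_post("/ext/summarise", summarise_handler)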
@@ -2735,19 +2784,8 @@ async def chat_handler(request):
 
     try:
         chat = await request.json()
-
-        # Apply pre-chat filters
         context = {"request": request, "chat": chat}
-        for filter_func in g_app.chat_request_filters:
-            chat = await filter_func(chat, context)
-
-        response = await chat_completion(chat)
-
-        # Apply post-chat filters
-        # Apply post-chat filters
-        for filter_func in g_app.chat_response_filters:
-            response = await filter_func(response, context)
-
+        response = await g_app.chat_completion(chat, context)
         return web.json_response(response)
     except Exception as e:
         return web.json_response({"error": str(e)}, status=500)
@@ -3222,9 +3260,25 @@ async def not_found_handler(request):
 
     # go through and register all g_app extensions
     for handler in g_app.server_add_get:
-        app.router.add_get(handler[0], handler[1], **handler[2])
+        handler_fn = handler[1]
+
+        async def managed_handler(request, handler_fn=handler_fn):
+            try:
+                return await handler_fn(request)
+            except Exception as e:
+                return web.json_response(to_error_response(e, stacktrace=True), status=500)
+
+        app.router.add_get(handler[0], managed_handler, **handler[2])
     for handler in g_app.server_add_post:
-        app.router.add_post(handler[0], handler[1], **handler[2])
+        handler_fn = handler[1]
+
+        async def managed_handler(request, handler_fn=handler_fn):
+            try:
+                return await handler_fn(request)
+            except Exception as e:
+                return web.json_response(to_error_response(e, stacktrace=True), status=500)
+
+        app.router.add_post(handler[0], managed_handler, **handler[2])
 
     # Serve index.html from root
     async def index_handler(request):
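One detail worth noting in the wrapper above: handler_fn=handler_fn binds each extension's handler at definition time. A closure that read the loop variable directly would be late-bound, so every registered route would end up calling the last handler in the list. A small standalone illustration of the difference:

funcs_late = [lambda: name for name in ["a", "b", "c"]]
funcs_bound = [lambda name=name: name for name in ["a", "b", "c"]]

print([f() for f in funcs_late])   # ['c', 'c', 'c'] - closures share the loop variable
print([f() for f in funcs_bound])  # ['a', 'b', 'c'] - default argument captures each value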
@@ -3356,15 +3410,7 @@ async def start_background_tasks(app):
 if len(extra_args) > 0:
     prompt = " ".join(extra_args)
     if not chat["messages"] or len(chat["messages"]) == 0:
-        chat["messages"] = [{
-            "role": "user",
-            "content": [
-                {
-                    "type": "text",
-                    "text": ""
-                }
-            ]
-        }]
+        chat["messages"] = [{"role": "user", "content": [{"type": "text", "text": ""}]}]
 
     # replace content of last message if exists, else add
     last_msg = chat["messages"][-1] if "messages" in chat else None
