diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index d7363e13eb2..03d36ba084a 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/src/components/ChatMessage.tsx b/tools/server/webui/src/components/ChatMessage.tsx index 08eb423526b..c868c226edf 100644 --- a/tools/server/webui/src/components/ChatMessage.tsx +++ b/tools/server/webui/src/components/ChatMessage.tsx @@ -27,6 +27,7 @@ export default function ChatMessage({ onEditMessage, onChangeSibling, isPending, + onContinueMessage, }: { msg: Message | PendingMessage; siblingLeafNodeIds: Message['id'][]; @@ -34,6 +35,7 @@ export default function ChatMessage({ id?: string; onRegenerateMessage(msg: Message): void; onEditMessage(msg: Message, content: string): void; + onContinueMessage(msg: Message, content: string): void; onChangeSibling(sibling: Message['id']): void; isPending?: boolean; }) { @@ -123,7 +125,11 @@ export default function ChatMessage({ onClick={() => { if (msg.content !== null) { setEditingContent(null); - onEditMessage(msg as Message, editingContent); + if (msg.role === 'user') { + onEditMessage(msg as Message, editingContent); + } else { + onContinueMessage(msg as Message, editingContent); + } } }} > @@ -248,12 +254,22 @@ export default function ChatMessage({ )} + {!isPending && ( + setEditingContent(msg.content)} + disabled={msg.content === null} + tooltipsContent="Edit message" + > + + + )} + )} - )} diff --git a/tools/server/webui/src/components/ChatScreen.tsx b/tools/server/webui/src/components/ChatScreen.tsx index b645a494d68..94fb71296ba 100644 --- a/tools/server/webui/src/components/ChatScreen.tsx +++ b/tools/server/webui/src/components/ChatScreen.tsx @@ -94,6 +94,7 @@ export default function ChatScreen() { pendingMessages, canvasData, replaceMessageAndGenerate, + continueMessageAndGenerate, } = useAppContext(); const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content()); @@ -188,6 +189,20 @@ export default function ChatScreen() { scrollToBottom(false); }; + const handleContinueMessage = async (msg: Message, content: string) => { + if (!viewingChat || !continueMessageAndGenerate) return; + setCurrNodeId(msg.id); + scrollToBottom(false); + await continueMessageAndGenerate( + viewingChat.conv.id, + msg.id, + content, + onChunk + ); + setCurrNodeId(-1); + scrollToBottom(false); + }; + const hasCanvas = !!canvasData; useEffect(() => { @@ -205,7 +220,7 @@ export default function ChatScreen() { // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg) const pendingMsgDisplay: MessageDisplay[] = - pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id + pendingMsg && !messages.some((m) => m.msg.id === pendingMsg.id) // Only show if pendingMsg is not an existing message being continued ? [ { msg: pendingMsg, @@ -244,18 +259,35 @@ export default function ChatScreen() { )} - {[...messages, ...pendingMsgDisplay].map((msg) => ( - - ))} + {[...messages, ...pendingMsgDisplay].map((msgDisplay) => { + const actualMsgObject = msgDisplay.msg; + // Check if the current message from the list is the one actively being generated/continued + const isThisMessageTheActivePendingOne = + pendingMsg?.id === actualMsgObject.id; + + return ( + + ); + })} {/* chat input */} diff --git a/tools/server/webui/src/utils/app.context.tsx b/tools/server/webui/src/utils/app.context.tsx index 96cffd95aba..86f3ec8eee7 100644 --- a/tools/server/webui/src/utils/app.context.tsx +++ b/tools/server/webui/src/utils/app.context.tsx @@ -39,6 +39,12 @@ interface AppContextValue { extra: Message['extra'], onChunk: CallbackGeneratedChunk ) => Promise; + continueMessageAndGenerate: ( + convId: string, + messageIdToContinue: Message['id'], + newContent: string, + onChunk: CallbackGeneratedChunk + ) => Promise; // canvas canvasData: CanvasData | null; @@ -156,7 +162,8 @@ export const AppContextProvider = ({ const generateMessage = async ( convId: string, leafNodeId: Message['id'], - onChunk: CallbackGeneratedChunk + onChunk: CallbackGeneratedChunk, + isContinuation: boolean = false ) => { if (isGenerating(convId)) return; @@ -179,17 +186,36 @@ export const AppContextProvider = ({ } const pendingId = Date.now() + 1; - let pendingMsg: PendingMessage = { - id: pendingId, - convId, - type: 'text', - timestamp: pendingId, - role: 'assistant', - content: null, - parent: leafNodeId, - children: [], - }; - setPending(convId, pendingMsg); + let pendingMsg: Message | PendingMessage; + + if (isContinuation) { + const existingAsstMsg = await StorageUtils.getMessage(convId, leafNodeId); + if (!existingAsstMsg || existingAsstMsg.role !== 'assistant') { + toast.error( + 'Cannot continue: target message not found or not an assistant message.' + ); + throw new Error( + 'Cannot continue: target message not found or not an assistant message.' + ); + } + pendingMsg = { + ...existingAsstMsg, + content: existingAsstMsg.content || '', + }; + setPending(convId, pendingMsg as PendingMessage); + } else { + pendingMsg = { + id: pendingId, + convId, + type: 'text', + timestamp: pendingId, + role: 'assistant', + content: null, + parent: leafNodeId, + children: [], + }; + setPending(convId, pendingMsg as PendingMessage); + } try { // prepare messages for API @@ -272,7 +298,7 @@ export const AppContextProvider = ({ predicted_ms: timings.predicted_ms, }; } - setPending(convId, pendingMsg); + setPending(convId, pendingMsg as PendingMessage); onChunk(); // don't need to switch node for pending message } } catch (err) { @@ -289,10 +315,15 @@ export const AppContextProvider = ({ } if (pendingMsg.content !== null) { - await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId); + if (isContinuation) { + await StorageUtils.updateMessage(pendingMsg as Message); + } else if (pendingMsg.content.trim().length > 0) { + await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId); + } } setPending(convId, null); - onChunk(pendingId); // trigger scroll to bottom and switch to the last node + const finalNodeId = (pendingMsg as Message).id; + onChunk(finalNodeId); }; const sendMessage = async ( @@ -333,7 +364,7 @@ export const AppContextProvider = ({ onChunk(currMsgId); try { - await generateMessage(convId, currMsgId, onChunk); + await generateMessage(convId, currMsgId, onChunk, false); return true; } catch (_) { // TODO: rollback @@ -380,6 +411,47 @@ export const AppContextProvider = ({ await generateMessage(convId, parentNodeId, onChunk); }; + const continueMessageAndGenerate = async ( + convId: string, + messageIdToContinue: Message['id'], + newContent: string, + onChunk: CallbackGeneratedChunk + ) => { + if (isGenerating(convId)) return; + + const existingMessage = await StorageUtils.getMessage( + convId, + messageIdToContinue + ); + if (!existingMessage || existingMessage.role !== 'assistant') { + console.error( + 'Cannot continue non-assistant message or message not found' + ); + toast.error( + 'Failed to continue message: Not an assistant message or not found.' + ); + return; + } + + const updatedAssistantMessage: Message = { + ...existingMessage, + content: newContent, + timestamp: Date.now(), + children: [], // Clear existing children to start a new branch of generation + extra: existingMessage.extra, // Preserve existing extra data + }; + + await StorageUtils.updateMessage(updatedAssistantMessage); + + onChunk(messageIdToContinue); + + try { + await generateMessage(convId, messageIdToContinue, onChunk, true); + } catch (err) { + console.error('Error continuing message'); + } + }; + const saveConfig = (config: typeof CONFIG_DEFAULT) => { StorageUtils.setConfig(config); setConfig(config); @@ -394,6 +466,7 @@ export const AppContextProvider = ({ sendMessage, stopGenerating, replaceMessageAndGenerate, + continueMessageAndGenerate, canvasData, setCanvasData, config, diff --git a/tools/server/webui/src/utils/storage.ts b/tools/server/webui/src/utils/storage.ts index 505693e9272..bf7116b841e 100644 --- a/tools/server/webui/src/utils/storage.ts +++ b/tools/server/webui/src/utils/storage.ts @@ -213,6 +213,22 @@ const StorageUtils = { localStorage.setItem('theme', theme); } }, + async getMessage( + convId: string, + messageId: Message['id'] + ): Promise { + return await db.messages.where({ convId, id: messageId }).first(); + }, + async updateMessage(updatedMessage: Message): Promise { + await db.transaction('rw', db.conversations, db.messages, async () => { + await db.messages.put(updatedMessage); + await db.conversations.update(updatedMessage.convId, { + lastModified: Date.now(), + currNode: updatedMessage.id, + }); + }); + dispatchConversationChange(updatedMessage.convId); + }, }; export default StorageUtils;