Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions notebook_intelligence/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,14 +1876,34 @@ class WebsocketCopilotHandler(websocket.WebSocketHandler):

def __init__(self, application, request, context_factory=None, **kwargs):
    """Create the per-connection state for one websocket.

    ``application``, ``request`` and any extra ``kwargs`` are forwarded
    unchanged to the parent handler. ``context_factory`` is optional; a
    falsy value is replaced with a fresh ``RuleContextFactory``.
    """
    super().__init__(application, request, **kwargs)

    # Registry of requests currently in flight, keyed by the request's
    # messageId. A chat / generate-code / inline-completion request adds
    # its entry when it starts, and `_run_request_thread` pops it again
    # as soon as the worker thread returns, so the dict stays bounded
    # over a long session. Each entry bundles the response emitter (used
    # to route ChatUserInput and RunUICommandResponse messages) with the
    # request's cancel token.
    self._messageCallbackHandlers: dict[str, MessageCallbackHandlers] = {}
    self.chat_history = ChatHistory()
    self._context_factory = context_factory or RuleContextFactory()

    # One connector wrapping this handler is shared by both services;
    # targets are assigned left to right, AI service manager first.
    ai_service_manager.websocket_connector = github_copilot.websocket_connector = (
        ThreadSafeWebSocketConnector(self)
    )

def _run_request_thread(self, coro, message_id):
    """Run ``coro`` to completion, then unregister ``message_id``.

    This is the target of the per-request worker thread. The entry in
    `_messageCallbackHandlers` only matters while the request is in
    flight: the client addresses ChatUserInput / RunUICommandResponse /
    Cancel messages to it by messageId and stops sending them once the
    response stream has ended.

    Any exception raised by ``coro`` propagates to the caller after the
    entry has been removed.
    """
    try:
        asyncio.run(coro)
    finally:
        # Runs on success and on failure alike. The `None` default makes
        # the pop a no-op when the entry is already gone.
        self._messageCallbackHandlers.pop(message_id, None)

def open(self):
    """Websocket-opened hook; this handler needs no setup at that point."""
    return None

Expand Down Expand Up @@ -2057,7 +2077,8 @@ def on_message(self, message):

# last prompt is added later
request_chat_history = chat_history[chat_history_initial_size:-1] if is_claude_code_mode else chat_history[:-1]
thread = threading.Thread(target=asyncio.run, args=(ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context), response_emitter),))
coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context), response_emitter)
thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId))
thread.start()
elif messageType == RequestDataType.GenerateCode:
data = msg['data']
Expand Down Expand Up @@ -2091,7 +2112,8 @@ def on_message(self, message):
root_dir=NotebookIntelligence.root_dir
)

thread = threading.Thread(target=asyncio.run, args=(ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token, rule_context=rule_context), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}),))
coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token, rule_context=rule_context), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."})
thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId))
thread.start()
elif messageType == RequestDataType.InlineCompletionRequest:
data = msg['data']
Expand All @@ -2106,7 +2128,8 @@ def on_message(self, message):
cancel_token = CancelTokenImpl()
self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token)

thread = threading.Thread(target=asyncio.run, args=(WebsocketCopilotHandler.handle_inline_completions(prefix, suffix, language, filename, response_emitter, cancel_token),))
coro = WebsocketCopilotHandler.handle_inline_completions(prefix, suffix, language, filename, response_emitter, cancel_token)
thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId))
thread.start()
elif messageType == RequestDataType.ChatUserInput:
handlers = self._messageCallbackHandlers.get(messageId)
Expand All @@ -2132,7 +2155,12 @@ def on_message(self, message):
handlers.cancel_token.cancel_request()

def on_close(self):
    """Cancel and forget every request still in flight at disconnect.

    Emptying `_messageCallbackHandlers` on its own does not stop the
    work or release anything: each request coroutine was handed its own
    references to the response emitter and cancel token when it was
    created, so the worker keeps running to completion for a client
    that is gone. Once the registry is empty there is also no remaining
    way to reach those tokens. Each token is therefore signalled with
    ``cancel_request()`` — the same call `on_message` makes on the
    token it looks up by messageId — before the entries are dropped.

    Calling this again after the registry is empty is a no-op.
    """
    # Take a private snapshot first: worker threads pop their own
    # entries from this dict (see `_run_request_thread`), so iterating
    # the live dict could hit "changed size during iteration".
    pending = list(self._messageCallbackHandlers.values())

    # Clear before signalling so the registry is already empty even if
    # a cancel callback raises.
    self._messageCallbackHandlers.clear()

    for handlers in pending:
        token = handlers.cancel_token
        if token is not None:
            token.cancel_request()

async def handle_inline_completions(prefix, suffix, language, filename, response_emitter, cancel_token):
if ai_service_manager.inline_completion_model is None:
Expand Down
126 changes: 126 additions & 0 deletions tests/test_ws_callback_handler_leak.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) Mehmet Bektas <mbektasgh@outlook.com>

"""Pin the contract that `_messageCallbackHandlers` does not grow
unbounded for the lifetime of the websocket connection.

Pre-fix: every chat / generate-code / inline-completion request added an
entry keyed by messageId and nothing ever removed them. Across a long
chat session this leaked one response emitter + cancel token per turn.

Fix: the worker thread is wrapped in `_run_request_thread`, which pops
the entry once the request coroutine returns. `on_close` clears the
whole dict so requests still in flight at disconnect don't pin their
state for the worker's lifetime.
"""

import asyncio
import threading
from unittest.mock import MagicMock

from notebook_intelligence.extension import (
MessageCallbackHandlers,
WebsocketCopilotHandler,
)


def _make_handler():
    """Return a WebsocketCopilotHandler that never ran ``__init__``.

    The tests only exercise the `_messageCallbackHandlers` bookkeeping,
    so the instance is allocated with ``__new__`` (no Tornado
    application or request needed) and given an empty registry by hand.
    """
    handler = WebsocketCopilotHandler.__new__(WebsocketCopilotHandler)
    handler._messageCallbackHandlers = {}
    return handler


class TestRunRequestThreadPopsHandler:
    """`_run_request_thread` must remove its own registry entry, and
    only its own, however the request coroutine ends."""

    def test_pops_entry_after_coro_returns_normally(self):
        h = _make_handler()
        h._messageCallbackHandlers["m1"] = MessageCallbackHandlers(MagicMock(), MagicMock())
        # A second, unrelated request: it must survive m1's cleanup.
        other = MessageCallbackHandlers(MagicMock(), MagicMock())
        h._messageCallbackHandlers["m2"] = other

        async def coro():
            return "ok"

        h._run_request_thread(coro(), "m1")

        assert "m1" not in h._messageCallbackHandlers
        assert set(h._messageCallbackHandlers) == {"m2"}
        assert h._messageCallbackHandlers["m2"] is other

    def test_pops_entry_when_coro_raises(self):
        # A worker exception must not leak the entry. The user may keep
        # the chat session open after a failed turn and start another;
        # repeated failures must not grow the dict. The wrapper
        # deliberately re-raises (asyncio.run propagates) so the upstream
        # error surfaces in thread-level logging; pytest.raises pins that
        # contract.
        import pytest

        h = _make_handler()
        h._messageCallbackHandlers["m1"] = MessageCallbackHandlers(MagicMock(), MagicMock())

        async def boom():
            raise RuntimeError("upstream failure")

        with pytest.raises(RuntimeError, match="upstream failure"):
            h._run_request_thread(boom(), "m1")

        assert "m1" not in h._messageCallbackHandlers

    def test_unknown_message_id_is_safe(self):
        # Pop with a default short-circuits cleanly, so a race where the
        # cleanup runs twice (or the entry was already dropped) does not
        # crash and does not create an entry.
        h = _make_handler()

        async def coro():
            return None

        h._run_request_thread(coro(), "never-registered")
        # Second call is a no-op as well.
        h._run_request_thread(coro(), "never-registered")

        assert h._messageCallbackHandlers == {}

    def test_multiple_concurrent_requests_each_clean_up(self):
        # The realistic concurrency pattern: several inline-completion
        # requests in flight at once. Each thread's cleanup pops only
        # its own entry; no global clear. The bystander entry is what
        # makes a global clear detectable: it belongs to no thread
        # started here and must still be registered afterwards.
        h = _make_handler()
        message_ids = [f"m{i}" for i in range(5)]
        for message_id in message_ids:
            h._messageCallbackHandlers[message_id] = MessageCallbackHandlers(
                MagicMock(), MagicMock()
            )
        bystander = MessageCallbackHandlers(MagicMock(), MagicMock())
        h._messageCallbackHandlers["still-running"] = bystander

        async def coro():
            return None

        threads = [
            threading.Thread(target=h._run_request_thread, args=(coro(), message_id))
            for message_id in message_ids
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert set(h._messageCallbackHandlers) == {"still-running"}
        assert h._messageCallbackHandlers["still-running"] is bystander


class TestOnCloseClearsHandlers:
    """`on_close` must leave the in-flight registry empty."""

    def test_on_close_drops_all_in_flight_entries(self):
        # Simulates a disconnect while three requests are still
        # registered; none of their entries may remain afterwards.
        handler = _make_handler()
        for idx in range(3):
            entry = MessageCallbackHandlers(MagicMock(), MagicMock())
            handler._messageCallbackHandlers[f"m{idx}"] = entry

        handler.on_close()

        assert handler._messageCallbackHandlers == {}

    def test_on_close_is_idempotent(self):
        # Closing twice on an already-empty registry must not raise.
        handler = _make_handler()
        for _ in range(2):
            handler.on_close()
        assert handler._messageCallbackHandlers == {}
Loading