From 18dd367b12e2e399605183d3fc8c4bbe7fe579c7 Mon Sep 17 00:00:00 2001 From: Alyssa Coghlan Date: Tue, 29 Jul 2025 03:58:03 +1000 Subject: [PATCH] Refactor sync API demultiplexing logic Demultiplexing in the sync API had become quite tangled, with the foreground thread responsible for allocating multiplexing IDs and receive queues, while the background thread handling the actual demultiplexing process. Now the sync API is using async queues for its demultiplexing, move multiplexing ID and queue management entirely to the background thread. --- src/lmstudio/_ws_impl.py | 276 +++++++++++-------------------------- src/lmstudio/_ws_thread.py | 213 ++++++++++++++++++++++++++++ src/lmstudio/async_api.py | 22 +-- src/lmstudio/json_api.py | 121 +++++++++++----- src/lmstudio/sync_api.py | 26 ++-- tests/test_sessions.py | 2 +- 6 files changed, 402 insertions(+), 258 deletions(-) create mode 100644 src/lmstudio/_ws_thread.py diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index a9d917c..34217f0 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -1,7 +1,7 @@ """Shared core async websocket implementation for the LM Studio remote access API.""" -# Sync API: runs in background thread with sync queues -# Async convenience API: runs in background thread with async queues +# Sync API: runs in dedicated background thread +# Async convenience API (once implemented): runs in dedicated background thread # Async structured API: runs in foreground event loop # Callback handling rules: @@ -10,36 +10,37 @@ # * All callbacks must be invoked from the *foreground* thread/event loop import asyncio -import threading -import weakref from concurrent.futures import Future as SyncFuture -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, contextmanager from functools import partial from typing import ( Any, Awaitable, Coroutine, Callable, - Iterable, + Generator, + TypeAlias, TypeVar, ) -# Synchronous API still uses an async websocket (just in a background thread) from anyio import create_task_group, move_on_after from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException from .schemas import DictObject -from .json_api import LMStudioWebsocket, LMStudioWebsocketError - -from ._logging import new_logger, LogEventContext +from .json_api import ( + LMStudioWebsocket, + LMStudioWebsocketError, + MultiplexingManager, + RxQueue, +) +from ._logging import LogEventContext, new_logger # Allow the core client websocket management to be shared across all SDK interaction APIs # See https://discuss.python.org/t/daemon-threads-and-background-task-termination/77604 # (Note: this implementation has the elements needed to run on *current* Python versions # and omits the generalised features that the SDK doesn't need) -# Already used by the sync API, async client is still to be migrated T = TypeVar("T") @@ -194,6 +195,17 @@ def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T raise RuntimeError(f"{self!r} is currently inactive.") return asyncio.run_coroutine_threadsafe(coro, loop) + def call_threadsafe(self, func: Callable[[], T]) -> SyncFuture[T]: + """Call non-blocking function in the background event loop and make the result available. + + Important: function must NOT access any scoped resources from the calling scope. + """ + + async def coro() -> T: + return func() + + return self.run_coroutine_threadsafe(coro()) + def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle: """Call given non-blocking function in the background event loop.""" loop = self._event_loop @@ -202,126 +214,13 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle: return loop.call_soon_threadsafe(func) -class BackgroundThread(threading.Thread): - """Background async event loop thread.""" - - def __init__( - self, - task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None, - name: str | None = None, - ) -> None: - # Accepts the same args as `threading.Thread`, *except*: - # * a `task_target` coroutine replaces the `target` function - # * No `daemon` option (always runs as a daemon) - # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run` - # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run` - self._task_target = task_target - self._loop_started = loop_started = threading.Event() - self._task_manager = AsyncTaskManager(on_activation=loop_started.set) - # Annoyingly, we have to mark the background thread as a daemon thread to - # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate - # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695 - # TODO: skip thread daemonization when running in a subinterpreter - # (and also disable the convenience API in subinterpreters to avoid hanging on shutdown) - super().__init__(name=name, daemon=True) - weakref.finalize(self, self.terminate) - - @property - def task_manager(self) -> AsyncTaskManager: - return self._task_manager - - def start(self, wait_for_loop: bool = True) -> None: - """Start background thread and (optionally) wait for the event loop to be ready.""" - super().start() - if wait_for_loop: - self.wait_for_loop() - - def run(self) -> None: - """Run an async event loop in the background thread.""" - # Only public to override threading.Thread.run - asyncio.run(self._task_manager.run_until_terminated(self._task_target)) - - def wait_for_loop(self) -> asyncio.AbstractEventLoop | None: - """Wait for the event loop to start from a synchronous foreground thread.""" - if self._task_manager._event_loop is None and not self._task_manager.activated: - self._loop_started.wait() - return self._task_manager._event_loop - - async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None: - """Wait for the event loop to start from an asynchronous foreground thread.""" - return await asyncio.to_thread(self.wait_for_loop) - - def terminate(self) -> bool: - """Request termination of the event loop from a synchronous foreground thread.""" - return self._task_manager.request_termination_threadsafe().result() - - async def terminate_async(self) -> bool: - """Request termination of the event loop from an asynchronous foreground thread.""" - return await asyncio.to_thread(self.terminate) - - def schedule_background_task(self, func: Callable[[], Any]) -> None: - """Schedule given task in the event loop from a synchronous foreground thread.""" - self._task_manager.schedule_task_threadsafe(func) - - async def schedule_background_task_async(self, func: Callable[[], Any]) -> None: - """Schedule given task in the event loop from an asynchronous foreground thread.""" - return await asyncio.to_thread(self.schedule_background_task, func) - - def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: - """Run given coroutine in the event loop and wait for the result.""" - return self._task_manager.run_coroutine_threadsafe(coro).result() - - async def run_background_coroutine_async(self, coro: Coroutine[Any, Any, T]) -> T: - """Run given coroutine in the event loop and await the result.""" - return await asyncio.to_thread(self.run_background_coroutine, coro) - - def call_in_background(self, func: Callable[[], Any]) -> None: - """Call given non-blocking function in the background event loop.""" - self._task_manager.call_soon_threadsafe(func) - - -# By default, the weakref finalization atexit hook is registered lazily. -# This can lead to shutdown sequencing issues if SDK users attempt to access -# client instances (such as the default sync client) from atexit hooks -# registered at import time (so they may end up running after the weakref -# finalization hook has already terminated background threads) -# Creating this finalizer here ensures the weakref finalization hook is -# registered at import time, and hence runs *after* any such hooks -# (assuming the lmstudio SDK is imported before the hooks are registered) -def _register_weakref_atexit_hook() -> None: - class C: - pass - - weakref.finalize(C(), int) - - -_register_weakref_atexit_hook() -del _register_weakref_atexit_hook - - -class AsyncWebsocketThread(BackgroundThread): - def __init__(self, log_context: LogEventContext | None = None) -> None: - super().__init__(task_target=self._log_thread_execution) - self._logger = logger = new_logger(type(self).__name__) - logger.update_context(log_context, thread_id=self.name) - - async def _log_thread_execution(self) -> None: - self._logger.info("Websocket handling thread started") - never_set = asyncio.Event() - try: - # Run the event loop until termination is requested - await never_set.wait() - except (asyncio.CancelledError, GeneratorExit): - raise - except BaseException: - err_msg = "Terminating websocket thread due to exception" - self._logger.debug(err_msg, exc_info=True) - finally: - self._logger.info("Websocket thread terminated") +AsyncChannelInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]] +AsyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]] # TODO: Improve code sharing between AsyncWebsocketHandler and # the async-native AsyncLMStudioWebsocket implementation +# (likely by migrating the websocket over to using the handler) class AsyncWebsocketHandler: """Async task handler for a single websocket connection.""" @@ -332,7 +231,6 @@ def __init__( task_manager: AsyncTaskManager, ws_url: str, auth_details: DictObject, - enqueue_message: Callable[[DictObject | None], Awaitable[bool]], log_context: LogEventContext | None = None, ) -> None: self._auth_details = auth_details @@ -344,9 +242,9 @@ def __init__( self._ws: AsyncWebSocketSession | None = None self._ws_disconnected = asyncio.Event() self._rx_task: asyncio.Task[None] | None = None - self._enqueue_message = enqueue_message self._logger = logger = new_logger(type(self).__name__) logger.update_context(log_context) + self._mux = MultiplexingManager(logger) async def connect(self) -> bool: """Connect websocket from the task manager's event loop.""" @@ -386,7 +284,7 @@ async def _logged_ws_handler(self) -> None: err_msg = "Terminating websocket task due to exception" self._logger.debug(err_msg, exc_info=True) finally: - # Ensure the foreground thread is unblocked even if the + # Ensure connections attempt are unblocked even if the # background async task errors out completely self._connection_attempted.set() self._logger.info("Websocket task terminated") @@ -403,9 +301,7 @@ async def _handle_ws(self) -> None: raise def _clear_task_state() -> None: - # Break the reference cycle with the foreground thread - del self._enqueue_message - # Websocket is about to be disconnected + # Websocket is about to be disconnected (if it isn't already) self._ws = None resources.callback(_clear_task_state) @@ -423,7 +319,7 @@ def _clear_task_state() -> None: self._logger.info("Websocket demultiplexing task terminated.") # Notify foreground thread of background thread termination # (this covers termination due to link failure) - await self._enqueue_message(None) + await self.notify_client_termination() dc_timeout = self.WS_DISCONNECT_TIMEOUT with move_on_after(dc_timeout, shield=True) as cancel_scope: # Workaround an anyio/httpx-ws issue with task cancellation: @@ -447,6 +343,9 @@ async def send_json(self, message: DictObject) -> None: ws = self._ws if ws is None: # Assume app is shutting down and the owning task has already been cancelled + rx_queue = self._mux.map_tx_message(message) + if rx_queue is not None: + await rx_queue.put(None) return try: await ws.send_json(message) @@ -464,9 +363,47 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: """Run given coroutine in the event loop and wait for the result.""" return self._task_manager.run_coroutine_threadsafe(coro).result() - def rx_queue_get_threadsafe( - self, rx_queue: asyncio.Queue[Any], timeout: float | None - ) -> Any: + @contextmanager + def open_channel(self) -> Generator[AsyncChannelInfo, None, None]: + assert self._task_manager.check_running_in_task_loop() + rx_queue: RxQueue = asyncio.Queue() + with self._mux.assign_channel_id(rx_queue) as call_id: + yield call_id, rx_queue.get + + @contextmanager + def start_call(self) -> Generator[AsyncRemoteCallInfo, None, None]: + assert self._task_manager.check_running_in_task_loop() + rx_queue: RxQueue = asyncio.Queue() + with self._mux.assign_call_id(rx_queue) as call_id: + yield call_id, rx_queue.get + + def new_threadsafe_rx_queue(self) -> tuple[RxQueue, Callable[[float | None], Any]]: + rx_queue: RxQueue = asyncio.Queue() + return rx_queue, partial(self._rx_queue_get_threadsafe, rx_queue) + + def acquire_channel_id_threadsafe(self, rx_queue: RxQueue) -> int: + future = self._task_manager.call_threadsafe( + partial(self._mux.acquire_channel_id, rx_queue) + ) + return future.result() # Wait for background thread to assign the ID + + def release_channel_id_threadsafe(self, channel_id: int, rx_queue: RxQueue) -> None: + self._task_manager.call_soon_threadsafe( + partial(self._mux.release_channel_id, channel_id, rx_queue) + ) + + def acquire_call_id_threadsafe(self, rx_queue: RxQueue) -> int: + future = self._task_manager.call_threadsafe( + partial(self._mux.acquire_call_id, rx_queue) + ) + return future.result() # Wait for background thread to assign the ID + + def release_call_id_threadsafe(self, call_id: int, rx_queue: RxQueue) -> None: + self._task_manager.call_soon_threadsafe( + partial(self._mux.release_call_id, call_id, rx_queue) + ) + + def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) -> Any: future = self._task_manager.run_coroutine_threadsafe(rx_queue.get()) try: return future.result(timeout) @@ -474,12 +411,6 @@ def rx_queue_get_threadsafe( future.cancel() raise - def rx_queue_put_threadsafe( - self, rx_queue: asyncio.Queue[Any], message: Any - ) -> None: - future = self._task_manager.run_coroutine_threadsafe(rx_queue.put(message)) - return future.result() - async def _receive_json(self) -> Any: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() @@ -536,46 +467,12 @@ async def _receive_messages(self) -> None: self._logger.error("Websocket failed, terminating session.") break - -class SyncToAsyncWebsocketBridge: - def __init__( - self, - ws_thread: AsyncWebsocketThread, - ws_url: str, - auth_details: DictObject, - get_queue: Callable[[DictObject | None], asyncio.Queue[Any] | None], - iter_queues: Callable[[], Iterable[asyncio.Queue[Any]]], - log_context: LogEventContext, - ) -> None: - self._get_queue = get_queue - self._iter_queues = iter_queues - self._ws_handler = AsyncWebsocketHandler( - ws_thread.task_manager, - ws_url, - auth_details, - self._enqueue_message, - log_context, - ) - self._logger = logger = new_logger(type(self).__name__) - logger.update_context(log_context) - - def connect(self) -> bool: - return self._ws_handler.connect_threadsafe() - - def disconnect(self) -> None: - self._ws_handler.disconnect_threadsafe() - - def send_json(self, message: DictObject) -> None: - self._ws_handler.send_json_threadsafe(message) - - def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[float | None], Any]]: - rx_queue: asyncio.Queue[Any] = asyncio.Queue() - return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue) - async def _enqueue_message(self, message: Any) -> bool: - rx_queue = self._get_queue(message) if message is None: + self._logger.info(f"Websocket session failed ({self._ws_url})") + self._ws = None return await self.notify_client_termination() > 0 + rx_queue = self._mux.map_rx_message(message) if rx_queue is None: return False await rx_queue.put(message) @@ -584,7 +481,7 @@ async def _enqueue_message(self, message: Any) -> bool: async def notify_client_termination(self) -> int: """Send None to all clients with open receive queues (from background thread).""" num_clients = 0 - for rx_queue in self._iter_queues(): + for rx_queue in self._mux.all_queues(): await rx_queue.put(None) num_clients += 1 self._logger.debug( @@ -595,19 +492,4 @@ async def notify_client_termination(self) -> int: def notify_client_termination_threadsafe(self) -> int: """Send None to all clients with open receive queues (from foreground thread).""" - return self._ws_handler.run_background_coroutine( - self.notify_client_termination() - ) - - # These attributes are currently accessed directly... - @property - def _ws(self) -> AsyncWebSocketSession | None: - return self._ws_handler._ws - - @property - def _connection_failure(self) -> Exception | None: - return self._ws_handler._connection_failure - - @property - def _auth_failure(self) -> Any | None: - return self._ws_handler._auth_failure + return self.run_background_coroutine(self.notify_client_termination()) diff --git a/src/lmstudio/_ws_thread.py b/src/lmstudio/_ws_thread.py new file mode 100644 index 0000000..acfed06 --- /dev/null +++ b/src/lmstudio/_ws_thread.py @@ -0,0 +1,213 @@ +"""Background thread async websocket implementation for the LM Studio remote access API.""" + +# Sync API +# Async convenience API (once implemented) + +import asyncio +import threading +import weakref + +from contextlib import contextmanager +from typing import ( + Any, + Coroutine, + Callable, + Generator, + TypeAlias, + TypeVar, +) + +from httpx_ws import AsyncWebSocketSession + +from .schemas import DictObject + +from ._logging import new_logger, LogEventContext +from ._ws_impl import AsyncTaskManager, AsyncWebsocketHandler + +T = TypeVar("T") + + +class BackgroundThread(threading.Thread): + """Background async event loop thread.""" + + def __init__( + self, + task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None, + name: str | None = None, + ) -> None: + # Accepts the same args as `threading.Thread`, *except*: + # * a `task_target` coroutine replaces the `target` function + # * No `daemon` option (always runs as a daemon) + # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run` + # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run` + self._task_target = task_target + self._loop_started = loop_started = threading.Event() + self._task_manager = AsyncTaskManager(on_activation=loop_started.set) + # Annoyingly, we have to mark the background thread as a daemon thread to + # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate + # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695 + # TODO: skip thread daemonization when running in a subinterpreter + # (and also disable the convenience API in subinterpreters to avoid hanging on shutdown) + super().__init__(name=name, daemon=True) + weakref.finalize(self, self.terminate) + + @property + def task_manager(self) -> AsyncTaskManager: + return self._task_manager + + def start(self, wait_for_loop: bool = True) -> None: + """Start background thread and (optionally) wait for the event loop to be ready.""" + super().start() + if wait_for_loop: + self.wait_for_loop() + + def run(self) -> None: + """Run an async event loop in the background thread.""" + # Only public to override threading.Thread.run + asyncio.run(self._task_manager.run_until_terminated(self._task_target)) + + def wait_for_loop(self) -> asyncio.AbstractEventLoop | None: + """Wait for the event loop to start from a synchronous foreground thread.""" + if self._task_manager._event_loop is None and not self._task_manager.activated: + self._loop_started.wait() + return self._task_manager._event_loop + + async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None: + """Wait for the event loop to start from an asynchronous foreground thread.""" + return await asyncio.to_thread(self.wait_for_loop) + + def terminate(self) -> bool: + """Request termination of the event loop from a synchronous foreground thread.""" + return self._task_manager.request_termination_threadsafe().result() + + async def terminate_async(self) -> bool: + """Request termination of the event loop from an asynchronous foreground thread.""" + return await asyncio.to_thread(self.terminate) + + def schedule_background_task(self, func: Callable[[], Any]) -> None: + """Schedule given task in the event loop from a synchronous foreground thread.""" + self._task_manager.schedule_task_threadsafe(func) + + async def schedule_background_task_async(self, func: Callable[[], Any]) -> None: + """Schedule given task in the event loop from an asynchronous foreground thread.""" + return await asyncio.to_thread(self.schedule_background_task, func) + + def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: + """Run given coroutine in the event loop and wait for the result.""" + return self._task_manager.run_coroutine_threadsafe(coro).result() + + async def run_background_coroutine_async(self, coro: Coroutine[Any, Any, T]) -> T: + """Run given coroutine in the event loop and await the result.""" + return await asyncio.to_thread(self.run_background_coroutine, coro) + + def call_in_background(self, func: Callable[[], Any]) -> None: + """Call given non-blocking function in the background event loop.""" + self._task_manager.call_soon_threadsafe(func) + + +# By default, the weakref finalization atexit hook is registered lazily. +# This can lead to shutdown sequencing issues if SDK users attempt to access +# client instances (such as the default sync client) from atexit hooks +# registered at import time (so they may end up running after the weakref +# finalization hook has already terminated background threads) +# Creating this finalizer here ensures the weakref finalization hook is +# registered at import time, and hence runs *after* any such hooks +# (assuming the lmstudio SDK is imported before the hooks are registered) +def _register_weakref_atexit_hook() -> None: + class C: + pass + + weakref.finalize(C(), int) + + +_register_weakref_atexit_hook() +del _register_weakref_atexit_hook + + +class AsyncWebsocketThread(BackgroundThread): + def __init__(self, log_context: LogEventContext | None = None) -> None: + super().__init__(task_target=self._log_thread_execution) + self._logger = logger = new_logger(type(self).__name__) + logger.update_context(log_context, thread_id=self.name) + + async def _log_thread_execution(self) -> None: + self._logger.info("Websocket handling thread started") + never_set = asyncio.Event() + try: + # Run the event loop until termination is requested + await never_set.wait() + except (asyncio.CancelledError, GeneratorExit): + raise + except BaseException: + err_msg = "Terminating websocket thread due to exception" + self._logger.debug(err_msg, exc_info=True) + finally: + self._logger.info("Websocket thread terminated") + + +SyncChannelInfo: TypeAlias = tuple[int, Callable[[float | None], Any]] +SyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[float | None], Any]] + + +class SyncToAsyncWebsocketBridge: + def __init__( + self, + ws_thread: AsyncWebsocketThread, + ws_url: str, + auth_details: DictObject, + log_context: LogEventContext, + ) -> None: + self._ws_handler = AsyncWebsocketHandler( + ws_thread.task_manager, + ws_url, + auth_details, + log_context, + ) + self._logger = logger = new_logger(type(self).__name__) + logger.update_context(log_context) + + def connect(self) -> bool: + return self._ws_handler.connect_threadsafe() + + def disconnect(self) -> None: + self._ws_handler.disconnect_threadsafe() + + def send_json(self, message: DictObject) -> None: + self._ws_handler.send_json_threadsafe(message) + + @contextmanager + def open_channel(self) -> Generator[SyncChannelInfo, None, None]: + ws_handler = self._ws_handler + rx_queue, getter = ws_handler.new_threadsafe_rx_queue() + channel_id = ws_handler.acquire_channel_id_threadsafe(rx_queue) + try: + yield channel_id, getter + finally: + ws_handler.release_channel_id_threadsafe(channel_id, rx_queue) + + @contextmanager + def start_call(self) -> Generator[SyncRemoteCallInfo, None, None]: + ws_handler = self._ws_handler + rx_queue, getter = ws_handler.new_threadsafe_rx_queue() + call_id = ws_handler.acquire_call_id_threadsafe(rx_queue) + try: + yield call_id, getter + finally: + ws_handler.release_call_id_threadsafe(call_id, rx_queue) + + def notify_client_termination_threadsafe(self) -> int: + """Send None to all clients with open receive queues (from foreground thread).""" + return self._ws_handler.notify_client_termination_threadsafe() + + # These attributes are currently accessed directly... + @property + def _ws(self) -> AsyncWebSocketSession | None: + return self._ws_handler._ws + + @property + def _connection_failure(self) -> Exception | None: + return self._ws_handler._connection_failure + + @property + def _auth_failure(self) -> Any | None: + return self._ws_handler._auth_failure diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index 775e9ee..09c7359 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -77,6 +77,7 @@ ModelSessionTypes, ModelTypesEmbedding, ModelTypesLlm, + MultiplexingManager, # Temporary until migration to AsyncWebsocketHandler PredictionStreamBase, PredictionEndpoint, PredictionFirstTokenCallback, @@ -87,6 +88,7 @@ PromptProcessingCallback, RemoteCallHandler, ResponseSchema, + RxQueue, TModelInfo, check_model_namespace, load_struct, @@ -133,7 +135,7 @@ class AsyncChannel(Generic[T]): def __init__( self, channel_id: int, - rx_queue: asyncio.Queue[Any], + rx_queue: RxQueue, endpoint: ChannelEndpoint[T, Any, Any], send_json: Callable[[DictObject], Awaitable[None]], log_context: LogEventContext, @@ -170,9 +172,8 @@ async def rx_stream( # (we can't easily suppress the SDK's own frames for iterators) message = await self._rx_queue.get() if message is None: - contents = None - else: - contents = self._api_channel.handle_rx_message(message) + raise LMStudioRuntimeError("Client unexpectedly disconnected.") + contents = self._api_channel.handle_rx_message(message) if contents is None: self._is_finished = True break @@ -194,7 +195,7 @@ class AsyncRemoteCall: def __init__( self, call_id: int, - rx_queue: asyncio.Queue[Any], + rx_queue: RxQueue, log_context: LogEventContext, notice_prefix: str = "RPC", ) -> None: @@ -214,13 +215,11 @@ async def receive_result(self) -> Any: """Receive call response on the receive queue.""" message = await self._rx_queue.get() if message is None: - return None + raise LMStudioRuntimeError("Client unexpectedly disconnected.") return self._rpc.handle_rx_message(message) -class AsyncLMStudioWebsocket( - LMStudioWebsocket[AsyncWebSocketSession, asyncio.Queue[Any]] -): +class AsyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebSocketSession]): """Asynchronous websocket client that handles demultiplexing of reply messages.""" def __init__( @@ -235,6 +234,7 @@ def __init__( rm.push_async_callback(self._notify_client_termination) self._rx_task: asyncio.Task[None] | None = None self._terminate = asyncio.Event() + self._mux = MultiplexingManager(self._logger) @property def _httpx_ws(self) -> AsyncWebSocketSession | None: @@ -386,7 +386,7 @@ async def open_channel( endpoint: ChannelEndpoint[T, Any, Any], ) -> AsyncGenerator[AsyncChannel[T], None]: """Open a streaming channel over the websocket.""" - rx_queue: asyncio.Queue[Any] = asyncio.Queue() + rx_queue: RxQueue = asyncio.Queue() with self._mux.assign_channel_id(rx_queue) as channel_id: channel = AsyncChannel( channel_id, @@ -427,7 +427,7 @@ async def remote_call( notice_prefix: str = "RPC", ) -> Any: """Make a remote procedure call over the websocket.""" - rx_queue: asyncio.Queue[Any] = asyncio.Queue() + rx_queue: RxQueue = asyncio.Queue() with self._mux.assign_call_id(rx_queue) as call_id: rpc = AsyncRemoteCall( call_id, rx_queue, self._logger.event_context, notice_prefix diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 5eb90a2..39f9d02 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -1,6 +1,15 @@ -"""Sans I/O protocol implementation for the LM Studio remote access API.""" +"""Common protocol implementation for the LM Studio remote access API.""" -# TODO: Migrate additional protocol details from the [a]sync APIs to the sans I/O API +# In order to simplify the websocket demultiplexing logic, this is NOT +# a full sans I/O protocol implementation. Instead, it is an async +# protocol implementation that supports both async interaction +# (from the same event loop or from one running in another thread) +# *and* sync interaction (by blocking on threadsafe futures) +# +# The I/O *transport* layer is still abstracted out, but the internal +# use of asynchronous queues for message demultiplexing is assumed. + +import asyncio import copy import json import uuid @@ -123,7 +132,7 @@ # From here, we publish everything that might be needed # for API type hints, error handling, defining custom # structured responses, and other expected activities. -# The "sans I/O" API itself is *not* automatically exported. +# The shared API itself is *not* automatically exported. # If API consumers want to use that, they need to access it # explicitly via `lmstudio.json_api`, it isn't exported # implicitly as part of the top-level `lmstudio` API. @@ -532,25 +541,23 @@ def _redact_json(data: DictObject | None) -> DictObject | None: return redacted -# TODO: Now that even the sync API uses asyncio.Queue, -# change the multiplexing manager to no longer be generic -TQueue = TypeVar("TQueue") +RxQueue: TypeAlias = asyncio.Queue[Any] -class MultiplexingManager(Generic[TQueue]): +class MultiplexingManager: """Helper class to allocate distinct protocol multiplexing IDs.""" def __init__(self, logger: StructuredLogger) -> None: """Initialize ID multiplexer.""" - self._open_channels: dict[int, TQueue] = {} + self._open_channels: dict[int, RxQueue] = {} self._last_channel_id = 0 - self._pending_calls: dict[int, TQueue] = {} + self._pending_calls: dict[int, RxQueue] = {} self._last_call_id = 0 # `_active_subscriptions` (if we add signal support) # `_last_subscriber_id` (if we add signal support) self._logger = logger - def all_queues(self) -> Iterator[TQueue]: + def all_queues(self) -> Iterator[asyncio.Queue[Any]]: """Iterate over all queues (for example, to send a shutdown message).""" yield from self._open_channels.values() yield from self._pending_calls.values() @@ -562,18 +569,30 @@ def _get_next_channel_id(self) -> int: self._last_channel_id = next_id return next_id - @contextmanager - def assign_channel_id(self, rx_queue: TQueue) -> Generator[int, None, None]: - """Assign distinct streaming channel ID to given queue.""" + def acquire_channel_id(self, rx_queue: RxQueue) -> int: + """Acquire a distinct streaming channel ID for the given queue.""" channel_id = self._get_next_channel_id() self._open_channels[channel_id] = rx_queue + return channel_id + + def release_channel_id(self, channel_id: int, rx_queue: RxQueue) -> None: + """Release a previously acquired streaming channel ID.""" + open_channels = self._open_channels + assigned_queue = open_channels.get(channel_id) + if rx_queue is not assigned_queue: + raise LMStudioRuntimeError( + f"Unexpected change to reply queue for channel ({channel_id} in {self!r})" + ) + del open_channels[channel_id] + + @contextmanager + def assign_channel_id(self, rx_queue: RxQueue) -> Generator[int, None, None]: + """Assign distinct streaming channel ID to given queue.""" + channel_id = self.acquire_channel_id(rx_queue) try: yield channel_id finally: - dropped_queue = self._open_channels.pop(channel_id, None) - assert dropped_queue is rx_queue, ( - f"Unexpected change to reply queue for channel ({channel_id} in {self!r})" - ) + self.release_channel_id(channel_id, rx_queue) def _get_next_call_id(self) -> int: """Get next distinct RPC ID.""" @@ -581,24 +600,36 @@ def _get_next_call_id(self) -> int: self._last_call_id = next_id return next_id - @contextmanager - def assign_call_id(self, rx_queue: TQueue) -> Generator[int, None, None]: - """Assign distinct remote call ID to given queue.""" + def acquire_call_id(self, rx_queue: RxQueue) -> int: + """Acquire a distinct remote call ID for the given queue.""" call_id = self._get_next_call_id() self._pending_calls[call_id] = rx_queue + return call_id + + def release_call_id(self, call_id: int, rx_queue: RxQueue) -> None: + """Release a previously acquired remote call ID.""" + pending_calls = self._pending_calls + assigned_queue = pending_calls.get(call_id) + if rx_queue is not assigned_queue: + raise LMStudioRuntimeError( + f"Unexpected change to reply queue for remote call ({call_id} in {self!r})" + ) + del pending_calls[call_id] + + @contextmanager + def assign_call_id(self, rx_queue: RxQueue) -> Generator[int, None, None]: + """Assign distinct remote call ID to given queue.""" + call_id = self.acquire_call_id(rx_queue) try: yield call_id finally: - dropped_queue = self._pending_calls.pop(call_id, None) - assert dropped_queue is rx_queue, ( - f"Unexpected change to reply queue for remote call ({call_id} in {self!r})" - ) + self.release_call_id(call_id, rx_queue) - def map_rx_message(self, message: DictObject) -> TQueue | None: + def map_rx_message(self, message: DictObject) -> RxQueue | None: """Map received message to the relevant demultiplexing queue.""" # TODO: Define an even-spammier-than-debug trace logging level for this # self._logger.trace("Incoming websocket message", json=message) - rx_queue: TQueue | None = None + rx_queue: RxQueue | None = None match message: case {"channelId": channel_id}: rx_queue = self._open_channels.get(channel_id, None) @@ -631,6 +662,37 @@ def map_rx_message(self, message: DictObject) -> TQueue | None: raise LMStudioClientError(f"Unexpected message: {unmatched}") return rx_queue + def map_tx_message(self, message: DictObject) -> RxQueue | None: + """Map failed message transmission to the relevant demultiplexing queue.""" + # TODO: Define an even-spammier-than-debug trace logging level for this + # self._logger.trace("Failed to send websocket message", json=message) + rx_queue: RxQueue | None = None + match message: + case {"channelId": channel_id}: + rx_queue = self._open_channels.get(channel_id, None) + if rx_queue is None: + if channel_id <= self._last_channel_id: + self._logger.warn( + "Attempted to send message on already closed channel", + channel_id=channel_id, + ) + else: + self._logger.warn( + "Attempted to send message on not yet used channel", + channel_id=channel_id, + ) + case {"callId": call_id}: + rx_queue = self._pending_calls.get(call_id, None) + if rx_queue is None: + self._logger.warn( + "Attempted to send remote call with unknown ID", call_id=call_id + ) + case _: + self._logger.warn( + "Attempted to send top level message on closed session" + ) + return rx_queue + # Channel events are processed via structural pattern matching, so it would be nice # to define them as tuples to make them as lightweight as possible at runtime. @@ -1772,7 +1834,7 @@ def _format_exc(exc: Exception) -> str: return exc_name -class LMStudioWebsocket(Generic[TWebsocket, TQueue]): +class LMStudioWebsocket(Generic[TWebsocket]): """Common base class for LM Studio websocket clients.""" # The common websocket API is narrow due to the sync/async split, @@ -1780,8 +1842,6 @@ class LMStudioWebsocket(Generic[TWebsocket, TQueue]): # Subclasses will declare a specific underlying websocket type _ws: TWebsocket | None - # Subclasses will declare a specific receive queue type - _mux: MultiplexingManager[TQueue] def __init__( self, @@ -1794,7 +1854,6 @@ def __init__( self._auth_details = auth_details self._logger = logger = new_logger(type(self).__name__) logger.update_context(log_context, ws_url=ws_url) - self._mux = MultiplexingManager(logger) # Subclasses handle actually creating a websocket instance self._ws = None @@ -1842,7 +1901,7 @@ def _ensure_connected(self, usage: str) -> None | NoReturn: return None -TLMStudioWebsocket = TypeVar("TLMStudioWebsocket", bound=LMStudioWebsocket[Any, Any]) +TLMStudioWebsocket = TypeVar("TLMStudioWebsocket", bound=LMStudioWebsocket[Any]) class ClientBase: diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index ee11ef7..cfe6c40 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -1,6 +1,5 @@ """Sync I/O protocol implementation for the LM Studio remote access API.""" -import asyncio import itertools import time import weakref @@ -110,7 +109,7 @@ _model_spec_to_api_dict, _redact_json, ) -from ._ws_impl import AsyncWebsocketThread, SyncToAsyncWebsocketBridge +from ._ws_thread import AsyncWebsocketThread, SyncToAsyncWebsocketBridge from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config from ._sdk_models import ( EmbeddingRpcCountTokensParameter, @@ -230,6 +229,8 @@ def rx_stream( message = self._get_message(self.timeout) except TimeoutError: raise LMStudioTimeoutError from None + if message is None: + raise LMStudioRuntimeError("Client unexpectedly disconnected.") contents = self._api_channel.handle_rx_message(message) if contents is None: self._is_finished = True @@ -284,12 +285,12 @@ def receive_result(self) -> Any: message = self._get_message(self.timeout) except TimeoutError: raise LMStudioTimeoutError from None + if message is None: + raise LMStudioRuntimeError("Client unexpectedly disconnected.") return self._rpc.handle_rx_message(message) -class SyncLMStudioWebsocket( - LMStudioWebsocket[SyncToAsyncWebsocketBridge, asyncio.Queue[Any]] -): +class SyncLMStudioWebsocket(LMStudioWebsocket[SyncToAsyncWebsocketBridge]): """Synchronous websocket client that handles demultiplexing of reply messages.""" def __init__( @@ -328,8 +329,6 @@ def connect(self) -> Self: self._ws_thread, self._ws_url, self._auth_details, - self._get_rx_queue, - self._mux.all_queues, self._logger.event_context, ) if not ws.connect(): @@ -361,13 +360,6 @@ def _send_json(self, message: DictObject) -> None: # Background thread handles the exception conversion ws.send_json(message) - def _get_rx_queue(self, message: Any) -> asyncio.Queue[Any] | None: - if message is None: - self._logger.info(f"Websocket session failed ({self._ws_url})") - self._ws = None - return None - return self._mux.map_rx_message(message) - def _connect_to_endpoint(self, channel: SyncChannel[Any]) -> None: """Connect channel to specified endpoint.""" self._ensure_connected("open channel endpoints") @@ -383,8 +375,7 @@ def open_channel( """Open a streaming channel over the websocket.""" ws = self._ws assert ws is not None - rx_queue, getter = ws.new_rx_queue() - with self._mux.assign_channel_id(rx_queue) as channel_id: + with ws.open_channel() as (channel_id, getter): channel = SyncChannel( channel_id, getter, @@ -423,8 +414,7 @@ def remote_call( """Make a remote procedure call over the websocket.""" ws = self._ws assert ws is not None - rx_queue, getter = ws.new_rx_queue() - with self._mux.assign_call_id(rx_queue) as call_id: + with ws.start_call() as (call_id, getter): rpc = SyncRemoteCall( call_id, getter, self._logger.event_context, notice_prefix ) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index a592437..6e046ea 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -21,7 +21,7 @@ SyncSession, SyncSessionSystem, ) -from lmstudio._ws_impl import AsyncWebsocketThread +from lmstudio._ws_thread import AsyncWebsocketThread from .support import LOCAL_API_HOST