diff --git a/backend_service/state.py b/backend_service/state.py index 4c09aa9..63582f9 100644 --- a/backend_service/state.py +++ b/backend_service/state.py @@ -196,6 +196,7 @@ def __init__( self._library_scan_started: bool = False self._library_scan_done: threading.Event = threading.Event() self._library_scan_generation: int = 0 + self._library_scan_threads: list[threading.Thread] = [] self._library_fingerprint: dict[str, float] = {} if library_provider is None: cached_payload = _load_library_cache(self._library_cache_path) @@ -312,6 +313,18 @@ def _kick_library_scan(self, *, force: bool = False) -> None: daemon=True, ) thread.start() + with self._lock: + self._library_scan_threads = [ + t for t in self._library_scan_threads if t.is_alive() + ] + self._library_scan_threads.append(thread) + + def shutdown(self, timeout: float = 5.0) -> None: + with self._lock: + threads = list(self._library_scan_threads) + for t in threads: + if t.is_alive(): + t.join(timeout=timeout) def _scan_library_into_cache( self, diff --git a/tests/test_state_async_library.py b/tests/test_state_async_library.py index 0aa7e22..eab8e41 100644 --- a/tests/test_state_async_library.py +++ b/tests/test_state_async_library.py @@ -46,8 +46,11 @@ def setUp(self) -> None: chat_sessions_path=tmpdir / "chat-sessions.json", library_cache_path=tmpdir / "library_cache.json", ) + self.state: ChaosEngineState | None = None def tearDown(self) -> None: + if self.state is not None: + self.state.shutdown() self.tmp.cleanup() def test_library_provider_short_circuits_async_scan(self): @@ -57,10 +60,10 @@ def provider(): provider_calls.append(True) return [{"name": "fake/model", "path": "/tmp/fake"}] - state = ChaosEngineState(library_provider=provider, **self.kwargs) - self.assertTrue(state._library_scan_done.is_set()) - self.assertFalse(state._library_scan_started) - result = state._library() + self.state = ChaosEngineState(library_provider=provider, **self.kwargs) + self.assertTrue(self.state._library_scan_done.is_set()) + self.assertFalse(self.state._library_scan_started) + result = self.state._library() self.assertEqual(len(result), 1) self.assertGreaterEqual(len(provider_calls), 1) @@ -76,14 +79,14 @@ def slow_scan(directories): "backend_service.state._discover_local_models", side_effect=slow_scan, ): - state = ChaosEngineState(**self.kwargs) + self.state = ChaosEngineState(**self.kwargs) try: - self.assertTrue(state._library_scan_done.wait(2.0)) - workspace = state.workspace() + self.assertTrue(self.state._library_scan_done.wait(2.0)) + workspace = self.state.workspace() self.assertEqual(workspace["libraryStatus"], "ready") self.assertEqual(len(workspace["library"]), 1) finally: - state._library_scan_done.set() + self.state._library_scan_done.set() def test_kick_scan_is_idempotent(self): scan_calls = [] @@ -96,12 +99,12 @@ def fake_scan(directories): "backend_service.state._discover_local_models", side_effect=fake_scan, ): - state = ChaosEngineState(**self.kwargs) - self.assertTrue(state._library_scan_done.wait(2.0)) + self.state = ChaosEngineState(**self.kwargs) + self.assertTrue(self.state._library_scan_done.wait(2.0)) initial_calls = len(scan_calls) - state._kick_library_scan() - state._kick_library_scan() - state._library_scan_done.wait(2.0) + self.state._kick_library_scan() + self.state._kick_library_scan() + self.state._library_scan_done.wait(2.0) self.assertGreaterEqual(len(scan_calls), initial_calls) @@ -117,22 +120,25 @@ def setUp(self) -> None: chat_sessions_path=tmpdir / "chat-sessions.json", library_cache_path=tmpdir / "library_cache.json", ) + self.state: ChaosEngineState | None = None def tearDown(self) -> None: + if self.state is not None: + self.state.shutdown() self.tmp.cleanup() def test_runtimes_unconstructed_after_init(self): - state = ChaosEngineState(**self.kwargs) - self.assertIsNone(state._image_runtime) - self.assertIsNone(state._video_runtime) + self.state = ChaosEngineState(**self.kwargs) + self.assertIsNone(self.state._image_runtime) + self.assertIsNone(self.state._video_runtime) def test_runtime_setter_keeps_test_compat(self): - state = ChaosEngineState(**self.kwargs) + self.state = ChaosEngineState(**self.kwargs) marker = object() - state.image_runtime = marker - self.assertIs(state.image_runtime, marker) - state.video_runtime = marker - self.assertIs(state.video_runtime, marker) + self.state.image_runtime = marker + self.assertIs(self.state.image_runtime, marker) + self.state.video_runtime = marker + self.assertIs(self.state.video_runtime, marker) if __name__ == "__main__":