Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions backend_service/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
self._library_scan_started: bool = False
self._library_scan_done: threading.Event = threading.Event()
self._library_scan_generation: int = 0
self._library_scan_threads: list[threading.Thread] = []
self._library_fingerprint: dict[str, float] = {}
if library_provider is None:
cached_payload = _load_library_cache(self._library_cache_path)
Expand Down Expand Up @@ -312,6 +313,18 @@ def _kick_library_scan(self, *, force: bool = False) -> None:
daemon=True,
)
thread.start()
with self._lock:
self._library_scan_threads = [
t for t in self._library_scan_threads if t.is_alive()
]
self._library_scan_threads.append(thread)

def shutdown(self, timeout: float = 5.0) -> None:
with self._lock:
threads = list(self._library_scan_threads)
for t in threads:
if t.is_alive():
t.join(timeout=timeout)

def _scan_library_into_cache(
self,
Expand Down
48 changes: 27 additions & 21 deletions tests/test_state_async_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def setUp(self) -> None:
chat_sessions_path=tmpdir / "chat-sessions.json",
library_cache_path=tmpdir / "library_cache.json",
)
self.state: ChaosEngineState | None = None

def tearDown(self) -> None:
if self.state is not None:
self.state.shutdown()
self.tmp.cleanup()

def test_library_provider_short_circuits_async_scan(self):
Expand All @@ -57,10 +60,10 @@ def provider():
provider_calls.append(True)
return [{"name": "fake/model", "path": "/tmp/fake"}]

state = ChaosEngineState(library_provider=provider, **self.kwargs)
self.assertTrue(state._library_scan_done.is_set())
self.assertFalse(state._library_scan_started)
result = state._library()
self.state = ChaosEngineState(library_provider=provider, **self.kwargs)
self.assertTrue(self.state._library_scan_done.is_set())
self.assertFalse(self.state._library_scan_started)
result = self.state._library()
self.assertEqual(len(result), 1)
self.assertGreaterEqual(len(provider_calls), 1)

Expand All @@ -76,14 +79,14 @@ def slow_scan(directories):
"backend_service.state._discover_local_models",
side_effect=slow_scan,
):
state = ChaosEngineState(**self.kwargs)
self.state = ChaosEngineState(**self.kwargs)
try:
self.assertTrue(state._library_scan_done.wait(2.0))
workspace = state.workspace()
self.assertTrue(self.state._library_scan_done.wait(2.0))
workspace = self.state.workspace()
self.assertEqual(workspace["libraryStatus"], "ready")
self.assertEqual(len(workspace["library"]), 1)
finally:
state._library_scan_done.set()
self.state._library_scan_done.set()

def test_kick_scan_is_idempotent(self):
scan_calls = []
Expand All @@ -96,12 +99,12 @@ def fake_scan(directories):
"backend_service.state._discover_local_models",
side_effect=fake_scan,
):
state = ChaosEngineState(**self.kwargs)
self.assertTrue(state._library_scan_done.wait(2.0))
self.state = ChaosEngineState(**self.kwargs)
self.assertTrue(self.state._library_scan_done.wait(2.0))
initial_calls = len(scan_calls)
state._kick_library_scan()
state._kick_library_scan()
state._library_scan_done.wait(2.0)
self.state._kick_library_scan()
self.state._kick_library_scan()
self.state._library_scan_done.wait(2.0)
self.assertGreaterEqual(len(scan_calls), initial_calls)


Expand All @@ -117,22 +120,25 @@ def setUp(self) -> None:
chat_sessions_path=tmpdir / "chat-sessions.json",
library_cache_path=tmpdir / "library_cache.json",
)
self.state: ChaosEngineState | None = None

def tearDown(self) -> None:
if self.state is not None:
self.state.shutdown()
self.tmp.cleanup()

def test_runtimes_unconstructed_after_init(self):
state = ChaosEngineState(**self.kwargs)
self.assertIsNone(state._image_runtime)
self.assertIsNone(state._video_runtime)
self.state = ChaosEngineState(**self.kwargs)
self.assertIsNone(self.state._image_runtime)
self.assertIsNone(self.state._video_runtime)

def test_runtime_setter_keeps_test_compat(self):
state = ChaosEngineState(**self.kwargs)
self.state = ChaosEngineState(**self.kwargs)
marker = object()
state.image_runtime = marker
self.assertIs(state.image_runtime, marker)
state.video_runtime = marker
self.assertIs(state.video_runtime, marker)
self.state.image_runtime = marker
self.assertIs(self.state.image_runtime, marker)
self.state.video_runtime = marker
self.assertIs(self.state.video_runtime, marker)


if __name__ == "__main__":
Expand Down