diff --git a/src/cashet/store.py b/src/cashet/store.py index a80c60a..730b0ef 100644 --- a/src/cashet/store.py +++ b/src/cashet/store.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import hashlib import json import logging @@ -708,12 +709,12 @@ def delete_by_tags(self, tags: dict[str, str | None]) -> int: return deleted def _row_object_refs(self, row: sqlite3.Row) -> list[str]: - refs: list[str] = [] + refs: set[str] = set() if row["output_hash"]: - refs.append(row["output_hash"]) + refs.add(row["output_hash"]) if row["input_refs"]: - refs.extend(set(json.loads(row["input_refs"]))) - return refs + refs.update(json.loads(row["input_refs"])) + return list(refs) def _object_ref_counts(self, conn: sqlite3.Connection) -> dict[str, int]: counts: dict[str, int] = {} @@ -844,16 +845,119 @@ async def __aexit__(self, *args: Any) -> None: self._state.thread_lock.release() +class _SyncSQLiteFingerprintLock: + def __init__(self, lock_path: str) -> None: + self._state = _sqlite_lock_state(lock_path) + + def __enter__(self) -> _SyncSQLiteFingerprintLock: + self._state.thread_lock.acquire() + self._state.file_lock.acquire() + return self + + def __exit__(self, *args: Any) -> None: + self._state.file_lock.release() + self._state.thread_lock.release() + + +class _DirectFingerprintLock: + def __init__(self, lock_path: str) -> None: + self._state = _sqlite_lock_state(lock_path) + + async def __aenter__(self) -> _DirectFingerprintLock: + self._state.thread_lock.acquire() + self._state.file_lock.acquire() + return self + + async def __aexit__(self, *args: Any) -> None: + self._state.file_lock.release() + self._state.thread_lock.release() + + +class _DirectAsyncSQLiteStore: + def __init__(self, core: _SQLiteStoreCore) -> None: + self._core = core + self._lock_paths: set[str] = set() + + def _fingerprint_lock(self, fingerprint: str) -> _DirectFingerprintLock: + import hashlib + + fp_hash = hashlib.sha256(fingerprint.encode()).hexdigest()[:16] + lock_path = str(self._core.root / f".lock-{fp_hash}") + self._lock_paths.add(lock_path) + return _DirectFingerprintLock(lock_path) + + @property + def root(self) -> Path: + return self._core.root + + async def put_blob(self, data: bytes) -> ObjectRef: + return self._core.put_blob(data) + + async def get_blob(self, ref: ObjectRef) -> bytes: + return self._core.get_blob(ref) + + async def put_commit(self, commit: Commit) -> None: + self._core.put_commit(commit) + + async def get_commit(self, hash: str) -> Commit | None: + return self._core.get_commit(hash) + + async def find_by_fingerprint(self, fingerprint: str) -> Commit | None: + return self._core.find_by_fingerprint(fingerprint) + + async def find_running_by_fingerprint(self, fingerprint: str) -> Commit | None: + return self._core.find_running_by_fingerprint(fingerprint) + + async def list_commits( + self, + func_name: str | None = None, + limit: int = 50, + status: TaskStatus | None = None, + tags: dict[str, str | None] | None = None, + ) -> list[Commit]: + return self._core.list_commits( + func_name=func_name, + limit=limit, + status=status, + tags=tags, + ) + + async def get_history(self, hash: str) -> list[Commit]: + return self._core.get_history(hash) + + async def stats(self) -> dict[str, int]: + return self._core.stats() + + async def evict(self, older_than: datetime, max_size_bytes: int | None = None) -> int: + return self._core.evict(older_than, max_size_bytes=max_size_bytes) + + async def delete_commit(self, hash: str) -> bool: + return self._core.delete_commit(hash) + + async def delete_by_tags(self, tags: dict[str, str | None]) -> int: + return self._core.delete_by_tags(tags) + + async def close(self) -> None: + self._core.close() + for lock_path in self._lock_paths: + with contextlib.suppress(OSError): + Path(lock_path).unlink(missing_ok=True) + self._lock_paths.clear() + + class AsyncSQLiteStore: def __init__(self, root: Path) -> None: self._core = _SQLiteStoreCore(root) self._write_lock = asyncio.Lock() + self._lock_paths: set[str] = set() def _fingerprint_lock(self, fingerprint: str) -> _SQLiteFingerprintLock: import hashlib fp_hash = hashlib.sha256(fingerprint.encode()).hexdigest()[:16] - return _SQLiteFingerprintLock(str(self._core.root / f".lock-{fp_hash}")) + lock_path = str(self._core.root / f".lock-{fp_hash}") + self._lock_paths.add(lock_path) + return _SQLiteFingerprintLock(lock_path) @property def root(self) -> Path: @@ -922,6 +1026,10 @@ async def delete_by_tags(self, tags: dict[str, str | None]) -> int: async def close(self) -> None: await asyncio.to_thread(self._core.close) + for lock_path in self._lock_paths: + with contextlib.suppress(OSError): + Path(lock_path).unlink(missing_ok=True) + self._lock_paths.clear() class SQLiteStore: @@ -950,13 +1058,28 @@ def objects_dir(self) -> Path: def db_path(self) -> Path: return self._async_store.db_path + @property + def _core(self) -> _SQLiteStoreCore: + return self._async_store._core # pyright: ignore[reportPrivateUsage] + + @property + def _direct_async_store(self) -> _DirectAsyncSQLiteStore: + return _DirectAsyncSQLiteStore(self._core) + + def _fingerprint_lock_sync(self, fingerprint: str) -> _SyncSQLiteFingerprintLock: + import hashlib + + fp_hash = hashlib.sha256(fingerprint.encode()).hexdigest()[:16] + lock_path = str(self.root / f".lock-{fp_hash}") + self._async_store._lock_paths.add(lock_path) # pyright: ignore[reportPrivateUsage] + return _SyncSQLiteFingerprintLock(lock_path) + def _connect(self, *, immediate: bool = False) -> sqlite3.Connection: - core: Any = self._async_store._core # pyright: ignore[reportPrivateUsage] - return core._connect(immediate=immediate) + core: _SQLiteStoreCore = self._core + return core._connect(immediate=immediate) # pyright: ignore[reportPrivateUsage] def blob_exists(self, hash: str) -> bool: - core: Any = self._async_store._core # pyright: ignore[reportPrivateUsage] - return core.blob_exists(hash) + return self._core.blob_exists(hash) def put_blob(self, data: bytes) -> ObjectRef: return self._runner.call(self._async_store.put_blob(data))