diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index ab4f457..9d16411 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -2,6 +2,7 @@ import os import pickle import sqlite3 +from pathlib import Path import asyncstdlib as a USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False @@ -14,6 +15,12 @@ ) +def _ensure_dir(): + path = Path(CACHE_LOCATION).parent + if not path.exists(): + path.mkdir(parents=True, exist_ok=True) + + def _get_table_name(func): """Convert "ClassName.method_name" to "ClassName_method_name""" return func.__qualname__.replace(".", "_") @@ -74,22 +81,32 @@ def _insert_into_cache(c, conn, table_name, key, result, chain): pass -def sql_lru_cache(maxsize=None): - def decorator(func): +def _shared_inner_fn_logic(func, self, args, kwargs): + chain = self.url + if not (local_chain := _check_if_local(chain)) or not USE_CACHE: + _ensure_dir() conn = sqlite3.connect(CACHE_LOCATION) c = conn.cursor() table_name = _get_table_name(func) _create_table(c, conn, table_name) + key = pickle.dumps((args, kwargs)) + result = _retrieve_from_cache(c, table_name, key, chain) + else: + result = None + c = None + conn = None + table_name = None + key = None + return c, conn, table_name, key, result, chain, local_chain + +def sql_lru_cache(maxsize=None): + def decorator(func): @functools.lru_cache(maxsize=maxsize) def inner(self, *args, **kwargs): - c = conn.cursor() - key = pickle.dumps((args, kwargs)) - chain = self.url - if not (local_chain := _check_if_local(chain)) or not USE_CACHE: - result = _retrieve_from_cache(c, table_name, key, chain) - if result is not None: - return result + c, conn, table_name, key, result, chain, local_chain = ( + _shared_inner_fn_logic(func, self, args, kwargs) + ) # If not in DB, call func and store in DB result = func(self, *args, **kwargs) @@ -106,21 +123,11 @@ def inner(self, *args, **kwargs): def async_sql_lru_cache(maxsize=None): def decorator(func): - conn = sqlite3.connect(CACHE_LOCATION) - c = conn.cursor() - table_name = _get_table_name(func) - _create_table(c, conn, table_name) - @a.lru_cache(maxsize=maxsize) async def inner(self, *args, **kwargs): - c = conn.cursor() - key = pickle.dumps((args, kwargs)) - chain = self.url - - if not (local_chain := _check_if_local(chain)) or not USE_CACHE: - result = _retrieve_from_cache(c, table_name, key, chain) - if result is not None: - return result + c, conn, table_name, key, result, chain, local_chain = ( + _shared_inner_fn_logic(func, self, args, kwargs) + ) # If not in DB, call func and store in DB result = await func(self, *args, **kwargs) diff --git a/tests/integration_tests/test_disk_cache.py b/tests/integration_tests/test_disk_cache.py new file mode 100644 index 0000000..3e379ab --- /dev/null +++ b/tests/integration_tests/test_disk_cache.py @@ -0,0 +1,74 @@ +import pytest + +from async_substrate_interface.async_substrate import ( + DiskCachedAsyncSubstrateInterface, + AsyncSubstrateInterface, +) +from async_substrate_interface.sync_substrate import SubstrateInterface + + +@pytest.mark.asyncio +async def test_disk_cache(): + entrypoint = "wss://entrypoint-finney.opentensor.ai:443" + async with DiskCachedAsyncSubstrateInterface(entrypoint) as disk_cached_substrate: + current_block = await disk_cached_substrate.get_block_number(None) + block_hash = await disk_cached_substrate.get_block_hash(current_block) + parent_block_hash = await disk_cached_substrate.get_parent_block_hash( + block_hash + ) + block_runtime_info = await disk_cached_substrate.get_block_runtime_info( + block_hash + ) + block_runtime_version_for = ( + await disk_cached_substrate.get_block_runtime_version_for(block_hash) + ) + block_hash_from_cache = await disk_cached_substrate.get_block_hash( + current_block + ) + parent_block_hash_from_cache = ( + await disk_cached_substrate.get_parent_block_hash(block_hash_from_cache) + ) + block_runtime_info_from_cache = ( + await disk_cached_substrate.get_block_runtime_info(block_hash_from_cache) + ) + block_runtime_version_from_cache = ( + await disk_cached_substrate.get_block_runtime_version_for( + block_hash_from_cache + ) + ) + assert block_hash == block_hash_from_cache + assert parent_block_hash == parent_block_hash_from_cache + assert block_runtime_info == block_runtime_info_from_cache + assert block_runtime_version_for == block_runtime_version_from_cache + async with AsyncSubstrateInterface(entrypoint) as non_cache_substrate: + block_hash_non_cache = await non_cache_substrate.get_block_hash(current_block) + parent_block_hash_non_cache = await non_cache_substrate.get_parent_block_hash( + block_hash_non_cache + ) + block_runtime_info_non_cache = await non_cache_substrate.get_block_runtime_info( + block_hash_non_cache + ) + block_runtime_version_for_non_cache = ( + await non_cache_substrate.get_block_runtime_version_for( + block_hash_non_cache + ) + ) + assert block_hash == block_hash_non_cache + assert parent_block_hash == parent_block_hash_non_cache + assert block_runtime_info == block_runtime_info_non_cache + assert block_runtime_version_for == block_runtime_version_for_non_cache + with SubstrateInterface(entrypoint) as sync_substrate: + block_hash_sync = sync_substrate.get_block_hash(current_block) + parent_block_hash_sync = sync_substrate.get_parent_block_hash( + block_hash_non_cache + ) + block_runtime_info_sync = sync_substrate.get_block_runtime_info( + block_hash_non_cache + ) + block_runtime_version_for_sync = sync_substrate.get_block_runtime_version_for( + block_hash_non_cache + ) + assert block_hash == block_hash_sync + assert parent_block_hash == parent_block_hash_sync + assert block_runtime_info == block_runtime_info_sync + assert block_runtime_version_for == block_runtime_version_for_sync