diff --git a/src/check_jsonschema/cachedownloader.py b/src/check_jsonschema/cachedownloader.py index a31913097..935294688 100644 --- a/src/check_jsonschema/cachedownloader.py +++ b/src/check_jsonschema/cachedownloader.py @@ -24,6 +24,9 @@ class CacheDownloader: # this will let us do any other caching we might need in the future in the same # cache dir (adjacent to "downloads") _CACHEDIR_NAME = os.path.join("check_jsonschema", "downloads") + # Keep list of newly loaded/revalidated schemas in memory to avoid network requests + # Especially useful for schemas making extensive use of refs to remote URLs + _DOWNLOADED_URIS: set[str] = set() def __init__( self, @@ -113,10 +116,14 @@ def _write(self, dest: str, response: requests.Response) -> None: shutil.copy(fp.name, dest) os.remove(fp.name) - def _download(self) -> str: + def _cachefile_path(self) -> str: assert self._cache_dir os.makedirs(self._cache_dir, exist_ok=True) - dest = os.path.join(self._cache_dir, self._filename) + return os.path.join(self._cache_dir, self._filename) + + def _download(self) -> str: + dest = self._cachefile_path() + CacheDownloader._DOWNLOADED_URIS.add(self._file_url) response = self._get_request() # check to see if we have a file which matches the connection @@ -130,6 +137,14 @@ def _download(self) -> str: def open(self) -> t.Iterator[t.IO[bytes]]: if (not self._cache_dir) or self._disable_cache: yield io.BytesIO(self._get_request().content) + else: - with open(self._download(), "rb") as fp: - yield fp + cachefile = self._cachefile_path() + if self._file_url in CacheDownloader._DOWNLOADED_URIS and os.path.exists( + cachefile + ): + with open(cachefile, "rb") as fp: + yield fp + else: + with open(self._download(), "rb") as fp: + yield fp diff --git a/src/check_jsonschema/schema_loader/resolver.py b/src/check_jsonschema/schema_loader/resolver.py index 1ad1248df..8cda9a730 100644 --- a/src/check_jsonschema/schema_loader/resolver.py +++ b/src/check_jsonschema/schema_loader/resolver.py @@ -4,9 +4,10 @@ import urllib.parse import referencing -import requests from referencing.jsonschema import DRAFT202012, Schema +from check_jsonschema.cachedownloader import CacheDownloader + from ..parsers import ParserSet from ..utils import filename2path @@ -62,10 +63,9 @@ def retrieve_reference(uri: str) -> referencing.Resource[Schema]: full_uri_scheme = urllib.parse.urlsplit(full_uri).scheme if full_uri_scheme in ("http", "https"): - data = requests.get(full_uri, stream=True) - parsed_object = parser_set.parse_data_with_path( - data.content, full_uri, "json" - ) + dwl = CacheDownloader(full_uri) + with dwl.open() as file: + parsed_object = parser_set.parse_data_with_path(file, full_uri, "json") else: parsed_object = get_local_file(full_uri)