diff --git a/docs/faq.rst b/docs/faq.rst index 5054c2ced..1d78f78c4 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -115,3 +115,46 @@ To resolve, quote the boolean: steps: - bash: echo "{{ parameters.myBoolean}}" + +Caching +------- + +What data gets cached? +~~~~~~~~~~~~~~~~~~~~~~ + +``check-jsonschema`` will cache all downloaded schemas by default. +The schemas are stored in the ``downloads/`` directory in your cache dir, and any +downloaded refs are stored in the ``refs/`` directory. + +Where is the cache dir? +~~~~~~~~~~~~~~~~~~~~~~~ + +``check-jsonschema`` detects an appropriate cache directory based on your +platform and environment variables. + +On Windows, the cache dir is ``%LOCALAPPDATA%/check_jsonschema/`` and falls back +to ``%APPDATA%/check_jsonschema/`` if ``LOCALAPPDATA`` is unset. + +On macOS, the cache dir is ``~/Library/Caches/check_jsonschema/``. + +On Linux, the cache dir is ``$XDG_CACHE_HOME/check_jsonschema/`` and falls back +to ``~/.cache/check_jsonschema/`` if ``XDG_CACHE_HOME`` is unset. + +How does check-jsonschema decide what is a cache hit vs miss? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +``check-jsonschema`` checks for cache hits by comparing local file modification +times to the ``Last-Modified`` header present in the headers on an HTTP GET +request. If the local last modified time is older than the header, the rest of +the request will be streamed and written to replace the file. + +How do I clear the cache? +~~~~~~~~~~~~~~~~~~~~~~~~~ + +There is no special command for clearing the cache. Simply find the cache +directory based on the information above and remove it or any of its contents. + +Can I disable caching? +~~~~~~~~~~~~~~~~~~~~~~ + +Yes! Just use the ``--no-cache`` CLI option. 
diff --git a/src/check_jsonschema/cachedownloader.py b/src/check_jsonschema/cachedownloader.py index 268db55ed..9dd4c0491 100644 --- a/src/check_jsonschema/cachedownloader.py +++ b/src/check_jsonschema/cachedownloader.py @@ -11,139 +11,186 @@ import requests +_LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z" + + +def _base_cache_dir() -> str | None: + sysname = platform.system() + + # on windows, try to get the appdata env var + # this *could* result in cache_dir=None, which is fine, just skip caching in + # that case + if sysname == "Windows": + cache_dir = os.getenv("LOCALAPPDATA", os.getenv("APPDATA")) + # macOS -> app support dir + elif sysname == "Darwin": + cache_dir = os.path.expanduser("~/Library/Caches") + # default for unknown platforms, namely linux behavior + # use XDG env var and default to ~/.cache/ + else: + cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + + return cache_dir + + +def _resolve_cache_dir(dirname: str = "downloads") -> str | None: + cache_dir = _base_cache_dir() + if cache_dir: + cache_dir = os.path.join(cache_dir, "check_jsonschema", dirname) + return cache_dir + + +def _lastmod_from_response(response: requests.Response) -> float: + try: + return time.mktime( + time.strptime(response.headers["last-modified"], _LASTMOD_FMT) + ) + # OverflowError: time outside of platform-specific bounds + # ValueError: malformed/unparseable + # LookupError: no such header + except (OverflowError, ValueError, LookupError): + return 0.0 + + +def _get_request( + file_url: str, *, response_ok: t.Callable[[requests.Response], bool] +) -> requests.Response: + num_retries = 2 + r: requests.Response | None = None + for _attempt in range(num_retries + 1): + try: + r = requests.get(file_url, stream=True) + except requests.RequestException as e: + if _attempt == num_retries: + raise FailedDownloadError("encountered error during download") from e + continue + if r.ok and response_ok(r): + return r + assert r is not None + raise FailedDownloadError( 
+ f"got response with status={r.status_code}, retries exhausted" + ) + + +def _atomic_write(dest: str, content: bytes) -> None: + # download to a temp file and then move to the dest + # this makes the download safe if run in parallel (parallel runs + # won't create a new empty file for writing and cause failures) + fp = tempfile.NamedTemporaryFile(mode="wb", delete=False) + fp.write(content) + fp.close() + shutil.copy(fp.name, dest) + os.remove(fp.name) + + +def _cache_hit(cachefile: str, response: requests.Response) -> bool: + # no file? miss + if not os.path.exists(cachefile): + return False + + # compare mtime on any cached file against the remote last-modified time + # it is considered a hit if the local file is at least as new as the remote file + local_mtime = os.path.getmtime(cachefile) + remote_mtime = _lastmod_from_response(response) + return local_mtime >= remote_mtime + class FailedDownloadError(Exception): pass class CacheDownloader: - _LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z" - - # changed in v0.5.0 - # original cache dir was "jsonschema_validate" - # this will let us do any other caching we might need in the future in the same - # cache dir (adjacent to "downloads") - _CACHEDIR_NAME = os.path.join("check_jsonschema", "downloads") + def __init__(self, cache_dir: str | None = None, disable_cache: bool = False): + if cache_dir is None: + self._cache_dir = _resolve_cache_dir() + else: + self._cache_dir = _resolve_cache_dir(cache_dir) + self._disable_cache = disable_cache - def __init__( + def _download( self, file_url: str, - filename: str | None = None, - cache_dir: str | None = None, - disable_cache: bool = False, - validation_callback: t.Callable[[bytes], t.Any] | None = None, - ): - self._file_url = file_url - self._filename = filename or file_url.split("/")[-1] - self._cache_dir = cache_dir or self._compute_default_cache_dir() - self._disable_cache = disable_cache - self._validation_callback = validation_callback - - def 
_compute_default_cache_dir(self) -> str | None: - sysname = platform.system() - - # on windows, try to get the appdata env var - # this *could* result in cache_dir=None, which is fine, just skip caching in - # that case - if sysname == "Windows": - cache_dir = os.getenv("LOCALAPPDATA", os.getenv("APPDATA")) - # macOS -> app support dir - elif sysname == "Darwin": - cache_dir = os.path.expanduser("~/Library/Caches") - # default for unknown platforms, namely linux behavior - # use XDG env var and default to ~/.cache/ - else: - cache_dir = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - - if cache_dir: - cache_dir = os.path.join(cache_dir, self._CACHEDIR_NAME) - - return cache_dir - - def _get_request( - self, *, response_ok: t.Callable[[requests.Response], bool] - ) -> requests.Response: - try: - r: requests.Response | None = None - for _attempt in range(3): - r = requests.get(self._file_url, stream=True) - if r.ok and response_ok(r): - return r - assert r is not None - raise FailedDownloadError( - f"got response with status={r.status_code}, retries exhausted" - ) - except requests.RequestException as e: - raise FailedDownloadError("encountered error during download") from e - - def _lastmod_from_response(self, response: requests.Response) -> float: - try: - return time.mktime( - time.strptime(response.headers["last-modified"], self._LASTMOD_FMT) - ) - # OverflowError: time outside of platform-specific bounds - # ValueError: malformed/unparseable - # LookupError: no such header - except (OverflowError, ValueError, LookupError): - return 0.0 - - def _cache_hit(self, cachefile: str, response: requests.Response) -> bool: - # no file? 
miss - if not os.path.exists(cachefile): - return False - - # compare mtime on any cached file against the remote last-modified time - # it is considered a hit if the local file is at least as new as the remote file - local_mtime = os.path.getmtime(cachefile) - remote_mtime = self._lastmod_from_response(response) - return local_mtime >= remote_mtime - - def _write(self, dest: str, response: requests.Response) -> None: - # download to a temp file and then move to the dest - # this makes the download safe if run in parallel (parallel runs - # won't create a new empty file for writing and cause failures) - fp = tempfile.NamedTemporaryFile(mode="wb", delete=False) - fp.write(response.content) - fp.close() - shutil.copy(fp.name, dest) - os.remove(fp.name) - - def _validate(self, response: requests.Response) -> bool: - if not self._validation_callback: - return True - - try: - self._validation_callback(response.content) - return True - except ValueError: - return False - - def _download(self) -> str: - assert self._cache_dir + filename: str, + response_ok: t.Callable[[requests.Response], bool], + ) -> str: + assert self._cache_dir is not None os.makedirs(self._cache_dir, exist_ok=True) - dest = os.path.join(self._cache_dir, self._filename) + dest = os.path.join(self._cache_dir, filename) def check_response_for_download(r: requests.Response) -> bool: # if the response indicates a cache hit, treat it as valid # this ensures that we short-circuit any further evaluation immediately on # a hit - if self._cache_hit(dest, r): + if _cache_hit(dest, r): return True # we now know it's not a hit, so validate the content (forces download) - return self._validate(r) + return response_ok(r) - response = self._get_request(response_ok=check_response_for_download) + response = _get_request(file_url, response_ok=check_response_for_download) # check to see if we have a file which matches the connection # only download if we do not (cache miss, vs hit) - if not self._cache_hit(dest, 
response): - self._write(dest, response) + if not _cache_hit(dest, response): + _atomic_write(dest, response.content) return dest @contextlib.contextmanager - def open(self) -> t.Iterator[t.IO[bytes]]: + def open( + self, + file_url: str, + filename: str, + validate_response: t.Callable[[requests.Response], bool], + ) -> t.Iterator[t.IO[bytes]]: if (not self._cache_dir) or self._disable_cache: - yield io.BytesIO(self._get_request(response_ok=self._validate).content) + yield io.BytesIO( + _get_request(file_url, response_ok=validate_response).content + ) else: - with open(self._download(), "rb") as fp: + with open( + self._download(file_url, filename, response_ok=validate_response), "rb" + ) as fp: yield fp + + def bind( + self, + file_url: str, + filename: str | None = None, + validation_callback: t.Callable[[bytes], t.Any] | None = None, + ) -> BoundCacheDownloader: + return BoundCacheDownloader( + file_url, filename, self, validation_callback=validation_callback + ) + + +class BoundCacheDownloader: + def __init__( + self, + file_url: str, + filename: str | None, + downloader: CacheDownloader, + *, + validation_callback: t.Callable[[bytes], t.Any] | None = None, + ): + self._file_url = file_url + self._filename = filename or file_url.split("/")[-1] + self._downloader = downloader + self._validation_callback = validation_callback + + @contextlib.contextmanager + def open(self) -> t.Iterator[t.IO[bytes]]: + with self._downloader.open( + self._file_url, + self._filename, + validate_response=self._validate_response, + ) as fp: + yield fp + + def _validate_response(self, response: requests.Response) -> bool: + if not self._validation_callback: + return True + + try: + self._validation_callback(response.content) + return True + except ValueError: + return False diff --git a/src/check_jsonschema/cli/main_command.py b/src/check_jsonschema/cli/main_command.py index 4d9fc6ec0..3145019e8 100644 --- a/src/check_jsonschema/cli/main_command.py +++ 
b/src/check_jsonschema/cli/main_command.py @@ -300,8 +300,8 @@ def build_schema_loader(args: ParseResult) -> SchemaLoaderBase: assert args.schema_path is not None return SchemaLoader( args.schema_path, - args.cache_filename, - args.disable_cache, + cache_filename=args.cache_filename, + disable_cache=args.disable_cache, base_uri=args.base_uri, validator_class=args.validator_class, ) diff --git a/src/check_jsonschema/schema_loader/main.py b/src/check_jsonschema/schema_loader/main.py index 070c1f6f7..4ce95c9e5 100644 --- a/src/check_jsonschema/schema_loader/main.py +++ b/src/check_jsonschema/schema_loader/main.py @@ -57,14 +57,16 @@ def get_validator( class SchemaLoader(SchemaLoaderBase): validator_class: type[jsonschema.protocols.Validator] | None = None + disable_cache: bool = True def __init__( self, schemafile: str, + *, cache_filename: str | None = None, - disable_cache: bool = False, base_uri: str | None = None, validator_class: type[jsonschema.protocols.Validator] | None = None, + disable_cache: bool = True, ) -> None: # record input parameters (these are not to be modified) self.schemafile = schemafile @@ -140,7 +142,7 @@ def get_validator( # reference resolution # with support for YAML, TOML, and other formats from the parsers reference_registry = make_reference_registry( - self._parsers, retrieval_uri, schema + self._parsers, retrieval_uri, schema, self.disable_cache ) if self.validator_class is None: @@ -171,7 +173,7 @@ def get_validator( class BuiltinSchemaLoader(SchemaLoader): - def __init__(self, schema_name: str, base_uri: str | None = None) -> None: + def __init__(self, schema_name: str, *, base_uri: str | None = None) -> None: self.schema_name = schema_name self.base_uri = base_uri self._parsers = ParserSet() @@ -187,7 +189,7 @@ def get_schema(self) -> dict[str, t.Any]: class MetaSchemaLoader(SchemaLoaderBase): - def __init__(self, base_uri: str | None = None) -> None: + def __init__(self, *, base_uri: str | None = None) -> None: if base_uri is not 
None: raise NotImplementedError( "'--base-uri' was used with '--metaschema'. " diff --git a/src/check_jsonschema/schema_loader/readers.py b/src/check_jsonschema/schema_loader/readers.py index 907ce6936..65280808b 100644 --- a/src/check_jsonschema/schema_loader/readers.py +++ b/src/check_jsonschema/schema_loader/readers.py @@ -79,11 +79,8 @@ def __init__( self.url = url self.parsers = ParserSet() self.downloader = CacheDownloader( - url, - cache_filename, disable_cache=disable_cache, - validation_callback=self._parse, - ) + ).bind(url, cache_filename, validation_callback=self._parse) self._parsed_schema: dict | _UnsetType = _UNSET def _parse(self, schema_bytes: bytes) -> t.Any: diff --git a/src/check_jsonschema/schema_loader/resolver.py b/src/check_jsonschema/schema_loader/resolver.py index 1ad1248df..c63b7bb4d 100644 --- a/src/check_jsonschema/schema_loader/resolver.py +++ b/src/check_jsonschema/schema_loader/resolver.py @@ -1,18 +1,34 @@ from __future__ import annotations +import hashlib import typing as t import urllib.parse import referencing -import requests from referencing.jsonschema import DRAFT202012, Schema +from ..cachedownloader import CacheDownloader from ..parsers import ParserSet from ..utils import filename2path +def ref_url_to_cache_filename(ref_url: str) -> str: + """ + Given a $ref URL, convert it to the filename in the refs/ cache dir. + Rules are as follows: + - the base filename is an md5 hash of the URL + - if the filename ends in an extension (.json, .yaml, etc) that extension + is appended to the hash + """ + filename = hashlib.md5(ref_url.encode()).hexdigest() + if "." 
in (last_part := ref_url.rpartition("/")[-1]): + _, _, extension = last_part.rpartition(".") + filename = f"{filename}.{extension}" + return filename + + + def make_reference_registry( - parsers: ParserSet, retrieval_uri: str | None, schema: dict + parsers: ParserSet, retrieval_uri: str | None, schema: dict, disable_cache: bool ) -> referencing.Registry: id_attribute_: t.Any = schema.get("$id") if isinstance(id_attribute_, str): @@ -26,7 +42,9 @@ def make_reference_registry( # mypy does not recognize that Registry is an `attrs` class and has `retrieve` as an # argument to its implicit initializer registry: referencing.Registry = referencing.Registry( # type: ignore[call-arg] - retrieve=create_retrieve_callable(parsers, retrieval_uri, id_attribute) + retrieve=create_retrieve_callable( + parsers, retrieval_uri, id_attribute, disable_cache + ) ) if retrieval_uri is not None: @@ -38,13 +56,17 @@ def make_reference_registry( def create_retrieve_callable( - parser_set: ParserSet, retrieval_uri: str | None, id_attribute: str | None + parser_set: ParserSet, + retrieval_uri: str | None, + id_attribute: str | None, + disable_cache: bool, ) -> t.Callable[[str], referencing.Resource[Schema]]: base_uri = id_attribute if base_uri is None: base_uri = retrieval_uri cache = ResourceCache() + downloader = CacheDownloader("refs", disable_cache) def get_local_file(uri: str) -> t.Any: path = filename2path(uri) @@ -62,10 +84,19 @@ def retrieve_reference(uri: str) -> referencing.Resource[Schema]: full_uri_scheme = urllib.parse.urlsplit(full_uri).scheme if full_uri_scheme in ("http", "https"): - data = requests.get(full_uri, stream=True) - parsed_object = parser_set.parse_data_with_path( - data.content, full_uri, "json" + + def validation_callback(content: bytes) -> None: + parser_set.parse_data_with_path(content, full_uri, "json") + + bound_downloader = downloader.bind( + full_uri, + ref_url_to_cache_filename(full_uri), + validation_callback, + ) + with bound_downloader.open() as fp: + data 
= fp.read() + + parsed_object = parser_set.parse_data_with_path(data, full_uri, "json") else: parsed_object = get_local_file(full_uri) diff --git a/tests/acceptance/test_nonjson_schema_handling.py b/tests/acceptance/test_nonjson_schema_handling.py index d64ac42e7..4e56d25e2 100644 --- a/tests/acceptance/test_nonjson_schema_handling.py +++ b/tests/acceptance/test_nonjson_schema_handling.py @@ -1,6 +1,7 @@ import json import pytest +import responses from check_jsonschema.parsers.json5 import ENABLED as JSON5_ENABLED @@ -87,3 +88,78 @@ def test_can_load_json5_schema(run_line, tmp_path, passing_data): ["check-jsonschema", "--schemafile", str(main_schemafile), str(doc)] ) assert result.exit_code == (0 if passing_data else 1) + + +@pytest.mark.parametrize("passing_data", [True, False]) +def test_can_load_remote_yaml_schema(run_line, tmp_path, passing_data): + retrieval_uri = "https://example.org/retrieval/schemas/main.yaml" + responses.add( + "GET", + retrieval_uri, + body="""\ +"$schema": "http://json-schema.org/draft-07/schema" +properties: + title: {"type": "string"} +additionalProperties: false +""", + ) + + doc = tmp_path / "doc.json" + doc.write_text(json.dumps(PASSING_DOCUMENT if passing_data else FAILING_DOCUMENT)) + + result = run_line(["check-jsonschema", "--schemafile", retrieval_uri, str(doc)]) + assert result.exit_code == (0 if passing_data else 1) + + +@pytest.mark.parametrize("passing_data", [True, False]) +def test_can_load_remote_yaml_schema_ref(run_line, tmp_path, passing_data): + retrieval_uri = "https://example.org/retrieval/schemas/main.yaml" + responses.add( + "GET", + retrieval_uri, + body="""\ +"$schema": "http://json-schema.org/draft-07/schema" +properties: + "title": {"$ref": "./title_schema.yaml"} +additionalProperties: false +""", + ) + responses.add( + "GET", + "https://example.org/retrieval/schemas/title_schema.yaml", + body="type: string", + ) + + doc = tmp_path / "doc.json" + doc.write_text(json.dumps(PASSING_DOCUMENT if passing_data else 
FAILING_DOCUMENT)) + + result = run_line(["check-jsonschema", "--schemafile", retrieval_uri, str(doc)]) + assert result.exit_code == (0 if passing_data else 1) + + +def test_can_load_remote_yaml_schema_ref_from_cache( + run_line, inject_cached_ref, tmp_path +): + retrieval_uri = "https://example.org/retrieval/schemas/main.yaml" + responses.add( + "GET", + retrieval_uri, + body="""\ +"$schema": "http://json-schema.org/draft-07/schema" +properties: + "title": {"$ref": "./title_schema.yaml"} +additionalProperties: false +""", + ) + + ref_loc = "https://example.org/retrieval/schemas/title_schema.yaml" + # populate a bad schema, but then "override" that with a good cache value + # this can only pass (in the success case) if the cache loading really works + responses.add("GET", ref_loc, body="false") + inject_cached_ref(ref_loc, "type: string") + + doc = tmp_path / "doc.json" + doc.write_text(json.dumps(PASSING_DOCUMENT)) + + result = run_line(["check-jsonschema", "--schemafile", retrieval_uri, str(doc)]) + assert result.exit_code == 0 diff --git a/tests/acceptance/test_remote_ref_resolution.py b/tests/acceptance/test_remote_ref_resolution.py index e73593d1e..d95fba555 100644 --- a/tests/acceptance/test_remote_ref_resolution.py +++ b/tests/acceptance/test_remote_ref_resolution.py @@ -3,8 +3,6 @@ import pytest import responses -from check_jsonschema import cachedownloader - CASES = { "case1": { "main_schema": { @@ -37,26 +35,90 @@ } -@pytest.fixture(autouse=True) -def _mock_schema_cache_dir(monkeypatch, tmp_path): - def _fake_compute_default_cache_dir(self): - return str(tmp_path) +@pytest.mark.parametrize("check_passes", (True, False)) +@pytest.mark.parametrize("casename", ("case1", "case2")) +def test_remote_ref_resolution_simple_case(run_line, check_passes, casename, tmp_path): + main_schema_loc = "https://example.com/main.json" + responses.add("GET", main_schema_loc, json=CASES[casename]["main_schema"]) + for name, subschema in 
CASES[casename]["other_schemas"].items(): + other_schema_loc = f"https://example.com/{name}.json" + responses.add("GET", other_schema_loc, json=subschema) + + instance_path = tmp_path / "instance.json" + instance_path.write_text( + json.dumps( + CASES[casename]["passing_document"] + if check_passes + else CASES[casename]["failing_document"] + ) + ) - monkeypatch.setattr( - cachedownloader.CacheDownloader, - "_compute_default_cache_dir", - _fake_compute_default_cache_dir, + result = run_line( + ["check-jsonschema", "--schemafile", main_schema_loc, str(instance_path)] ) + output = f"\nstdout:\n{result.stdout}\n\nstderr:\n{result.stderr}" + if check_passes: + assert result.exit_code == 0, output + else: + assert result.exit_code == 1, output -@pytest.mark.parametrize("check_passes", (True, False)) @pytest.mark.parametrize("casename", ("case1", "case2")) -def test_remote_ref_resolution_simple_case(run_line, check_passes, casename, tmp_path): +@pytest.mark.parametrize("disable_cache", (True, False)) +def test_remote_ref_resolution_cache_control( + run_line, tmp_path, get_ref_cache_loc, casename, disable_cache +): main_schema_loc = "https://example.com/main.json" responses.add("GET", main_schema_loc, json=CASES[casename]["main_schema"]) + + ref_locs = [] for name, subschema in CASES[casename]["other_schemas"].items(): other_schema_loc = f"https://example.com/{name}.json" responses.add("GET", other_schema_loc, json=subschema) + ref_locs.append(other_schema_loc) + + instance_path = tmp_path / "instance.json" + instance_path.write_text(json.dumps(CASES[casename]["passing_document"])) + + # run the command + result = run_line( + ["check-jsonschema", "--schemafile", main_schema_loc, str(instance_path)] + + (["--no-cache"] if disable_cache else []) + ) + output = f"\nstdout:\n{result.stdout}\n\nstderr:\n{result.stderr}" + assert result.exit_code == 0, output + + cache_locs = [] + for ref_loc in ref_locs: + cache_locs.append(get_ref_cache_loc(ref_loc)) + assert cache_locs # 
sanity check + if disable_cache: + for loc in cache_locs: + assert not loc.exists() + else: + for loc in cache_locs: + assert loc.exists() + + +@pytest.mark.parametrize("casename", ("case1", "case2")) +@pytest.mark.parametrize("check_passes", (True, False)) +def test_remote_ref_resolution_loads_from_cache( + run_line, tmp_path, get_ref_cache_loc, inject_cached_ref, casename, check_passes +): + main_schema_loc = "https://example.com/main.json" + responses.add("GET", main_schema_loc, json=CASES[casename]["main_schema"]) + + ref_locs = [] + cache_locs = [] + for name, subschema in CASES[casename]["other_schemas"].items(): + other_schema_loc = f"https://example.com/{name}.json" + # intentionally populate the HTTP location with "bad data" + responses.add("GET", other_schema_loc, json="{}") + ref_locs.append(other_schema_loc) + + # but populate the cache with "good data" + inject_cached_ref(other_schema_loc, json.dumps(subschema)) + cache_locs.append(get_ref_cache_loc(other_schema_loc)) instance_path = tmp_path / "instance.json" instance_path.write_text( @@ -67,6 +129,7 @@ def test_remote_ref_resolution_simple_case(run_line, check_passes, casename, tmp ) ) + # run the command result = run_line( ["check-jsonschema", "--schemafile", main_schema_loc, str(instance_path)] ) diff --git a/tests/acceptance/test_special_filetypes.py b/tests/acceptance/test_special_filetypes.py index 70ca1cdcf..9036e51ae 100644 --- a/tests/acceptance/test_special_filetypes.py +++ b/tests/acceptance/test_special_filetypes.py @@ -6,8 +6,6 @@ import pytest import responses -from check_jsonschema import cachedownloader - @pytest.mark.skipif( platform.system() != "Linux", reason="test requires /proc/self/ mechanism" @@ -81,21 +79,11 @@ def test_schema_and_instance_in_fifos(tmp_path, run_line, check_succeeds): @pytest.mark.parametrize("check_passes", (True, False)) -def test_remote_schema_requiring_retry(run_line, check_passes, tmp_path, monkeypatch): +def test_remote_schema_requiring_retry(run_line, 
check_passes, tmp_path): """ a "remote schema" (meaning HTTPS) with bad data, therefore requiring that a retry fires in order to parse """ - - def _fake_compute_default_cache_dir(self): - return str(tmp_path) - - monkeypatch.setattr( - cachedownloader.CacheDownloader, - "_compute_default_cache_dir", - _fake_compute_default_cache_dir, - ) - schema_loc = "https://example.com/schema1.json" responses.add("GET", schema_loc, body="", match_querystring=None) responses.add( diff --git a/tests/conftest.py b/tests/conftest.py index b2cd6969c..9179f6eed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,3 +46,67 @@ def in_tmp_dir(request, tmp_path): os.chdir(str(tmp_path)) yield os.chdir(request.config.invocation_dir) + + +@pytest.fixture +def cache_dir(tmp_path): + return tmp_path / ".cache" + + +@pytest.fixture(autouse=True) +def patch_cache_dir(monkeypatch, cache_dir): + with monkeypatch.context() as m: + m.setattr( + "check_jsonschema.cachedownloader._base_cache_dir", lambda: str(cache_dir) + ) + yield m + + +@pytest.fixture +def downloads_cache_dir(tmp_path): + return tmp_path / ".cache" / "check_jsonschema" / "downloads" + + +@pytest.fixture +def get_download_cache_loc(downloads_cache_dir): + def _get(uri): + return downloads_cache_dir / uri.split("/")[-1] + + return _get + + +@pytest.fixture +def inject_cached_download(downloads_cache_dir, get_download_cache_loc): + def _write(uri, content): + downloads_cache_dir.mkdir(parents=True) + path = get_download_cache_loc(uri) + if isinstance(content, str): + path.write_text(content) + else: + path.write_bytes(content) + + return _write + + +@pytest.fixture +def refs_cache_dir(tmp_path): + return tmp_path / ".cache" / "check_jsonschema" / "refs" + + +@pytest.fixture +def get_ref_cache_loc(refs_cache_dir): + from check_jsonschema.schema_loader.resolver import ref_url_to_cache_filename + + def _get(uri): + return refs_cache_dir / ref_url_to_cache_filename(uri) + + return _get + + +@pytest.fixture +def 
inject_cached_ref(refs_cache_dir, get_ref_cache_loc): + def _write(uri, content): + refs_cache_dir.mkdir(parents=True) + get_ref_cache_loc(uri).write_text(content) + + return _write diff --git a/tests/unit/test_cachedownloader.py b/tests/unit/test_cachedownloader.py index 32098db91..30d47052b 100644 --- a/tests/unit/test_cachedownloader.py +++ b/tests/unit/test_cachedownloader.py @@ -7,7 +7,11 @@ import requests import responses -from check_jsonschema.cachedownloader import CacheDownloader, FailedDownloadError +from check_jsonschema.cachedownloader import ( + CacheDownloader, + FailedDownloadError, + _cache_hit, +) def add_default_response(): @@ -26,7 +30,7 @@ def default_response(): def test_default_filename_from_uri(default_response): - cd = CacheDownloader("https://example.com/schema1.json") + cd = CacheDownloader().bind("https://example.com/schema1.json") assert cd._filename == "schema1.json" @@ -47,8 +51,11 @@ def test_default_filename_from_uri(default_response): ], ) def test_default_cache_dir( - monkeypatch, default_response, sysname, fakeenv, expect_value + patch_cache_dir, monkeypatch, default_response, sysname, fakeenv, expect_value ): + # undo the patch which typically overrides resolution of the cache dir + patch_cache_dir.undo() + for var in ["LOCALAPPDATA", "APPDATA", "XDG_CACHE_HOME"]: monkeypatch.delenv(var, raising=False) for k, v in fakeenv.items(): @@ -69,7 +76,7 @@ def fake_expanduser(path): monkeypatch.setattr(platform, "system", fakesystem) monkeypatch.setattr(os.path, "expanduser", fake_expanduser) - cd = CacheDownloader("https://example.com/schema1.json") + cd = CacheDownloader() assert cd._cache_dir == expect_value if sysname == "Darwin": @@ -85,17 +92,15 @@ def test_cache_hit_by_mtime(monkeypatch, default_response): # local mtime = NOW, cache hit monkeypatch.setattr(os.path, "getmtime", lambda x: time.time()) - cd = CacheDownloader("https://example.com/schema1.json") - assert cd._cache_hit( + assert _cache_hit( "/tmp/schema1.json", 
requests.get("https://example.com/schema1.json", stream=True), ) # local mtime = 0, cache miss monkeypatch.setattr(os.path, "getmtime", lambda x: 0) - cd = CacheDownloader("https://example.com/schema1.json") assert ( - cd._cache_hit( + _cache_hit( "/tmp/schema1.json", requests.get("https://example.com/schema1.json", stream=True), ) @@ -109,19 +114,46 @@ def test_cachedownloader_cached_file(tmp_path, monkeypatch, default_response): f.write_text("{}") # set the cache_dir to the tmp dir (so that cache_dir will always be set) - cd = CacheDownloader(str(f), cache_dir=tmp_path) + cd = CacheDownloader(cache_dir=tmp_path).bind(str(f)) # patch the downloader to skip any download "work" - monkeypatch.setattr(cd, "_download", lambda: str(f)) + monkeypatch.setattr( + cd._downloader, "_download", lambda file_uri, filename, response_ok: str(f) + ) with cd.open() as fp: assert fp.read() == b"{}" -@pytest.mark.parametrize( - "mode", ["filename", "filename_otherdir", "cache_dir", "disable_cache"] -) -@pytest.mark.parametrize("failures", (0, 1, 10, requests.ConnectionError)) -def test_cachedownloader_e2e(tmp_path, mode, failures): +@pytest.mark.parametrize("disable_cache", (True, False)) +def test_cachedownloader_on_success(get_download_cache_loc, disable_cache): + add_default_response() + f = get_download_cache_loc("schema1.json") + cd = CacheDownloader(disable_cache=disable_cache).bind( + "https://example.com/schema1.json" + ) + + with cd.open() as fp: + assert fp.read() == b"{}" + if disable_cache: + assert not f.exists() + else: + assert f.exists() + + +def test_cachedownloader_using_alternate_target_dir(cache_dir): + add_default_response() + f = cache_dir / "check_jsonschema" / "otherdir" / "schema1.json" + cd = CacheDownloader("otherdir").bind("https://example.com/schema1.json") + with cd.open() as fp: + assert fp.read() == b"{}" + assert f.exists() + + +@pytest.mark.parametrize("disable_cache", (True, False)) +@pytest.mark.parametrize("failures", (1, 2, 
requests.ConnectionError)) +def test_cachedownloader_succeeds_after_few_errors( + get_download_cache_loc, disable_cache, failures +): if isinstance(failures, int): for _i in range(failures): responses.add( @@ -138,49 +170,52 @@ def test_cachedownloader_e2e(tmp_path, mode, failures): match_querystring=None, ) add_default_response() - f = tmp_path / "schema1.json" - if mode == "filename": - cd = CacheDownloader( - "https://example.com/schema1.json", filename=str(f), cache_dir=str(tmp_path) - ) - elif mode == "filename_otherdir": - otherdir = tmp_path / "otherdir" - cd = CacheDownloader( - "https://example.com/schema1.json", filename=str(f), cache_dir=str(otherdir) - ) - elif mode == "cache_dir": - cd = CacheDownloader( - "https://example.com/schema1.json", cache_dir=str(tmp_path) - ) - elif mode == "disable_cache": - cd = CacheDownloader("https://example.com/schema1.json", disable_cache=True) + f = get_download_cache_loc("schema1.json") + cd = CacheDownloader(disable_cache=disable_cache).bind( + "https://example.com/schema1.json" + ) + + with cd.open() as fp: + assert fp.read() == b"{}" + if disable_cache: + assert not f.exists() else: - raise NotImplementedError + assert f.exists() - if isinstance(failures, int) and failures < 3: - with cd.open() as fp: - assert fp.read() == b"{}" - if mode == "filename": - assert f.exists() - elif mode == "filename_otherdir": - otherdir = f.exists() - elif mode == "cache_dir": - assert (tmp_path / "schema1.json").exists() - elif mode == "disable_cache": - assert not (tmp_path / "schema1.json").exists() - assert not f.exists() + +@pytest.mark.parametrize("disable_cache", (True, False)) +@pytest.mark.parametrize("connection_error", (True, False)) +def test_cachedownloader_fails_after_many_errors( + get_download_cache_loc, disable_cache, connection_error +): + for _i in range(10): + if connection_error: + responses.add( + "GET", + "https://example.com/schema1.json", + body=requests.ConnectionError(), + match_querystring=None, + ) 
else: - raise NotImplementedError - else: - with pytest.raises(FailedDownloadError): - with cd.open() as fp: - pass - assert not (tmp_path / "schema1.json").exists() - assert not f.exists() + responses.add( + "GET", + "https://example.com/schema1.json", + status=500, + match_querystring=None, + ) + add_default_response() # never reached, the 11th response + f = get_download_cache_loc("schema1.json") + cd = CacheDownloader(disable_cache=disable_cache).bind( + "https://example.com/schema1.json" + ) + with pytest.raises(FailedDownloadError): + with cd.open(): + pass + assert not f.exists() @pytest.mark.parametrize("disable_cache", (True, False)) -def test_cachedownloader_retries_on_bad_data(tmp_path, disable_cache): +def test_cachedownloader_retries_on_bad_data(get_download_cache_loc, disable_cache): responses.add( "GET", "https://example.com/schema1.json", @@ -189,12 +224,12 @@ def test_cachedownloader_retries_on_bad_data(tmp_path, disable_cache): match_querystring=None, ) add_default_response() - f = tmp_path / "schema1.json" + f = get_download_cache_loc("schema1.json") cd = CacheDownloader( + disable_cache=disable_cache, + ).bind( "https://example.com/schema1.json", filename=str(f), - cache_dir=str(tmp_path), - disable_cache=disable_cache, validation_callback=json.loads, ) @@ -212,20 +247,19 @@ def test_cachedownloader_retries_on_bad_data(tmp_path, disable_cache): "failure_mode", ("header_missing", "header_malformed", "time_overflow") ) def test_cachedownloader_handles_bad_lastmod_header( - monkeypatch, tmp_path, file_exists, failure_mode + monkeypatch, + get_download_cache_loc, + inject_cached_download, + file_exists, + failure_mode, ): + uri = "https://example.com/schema1.json" if failure_mode == "header_missing": - responses.add( - "GET", - "https://example.com/schema1.json", - headers={}, - json={}, - match_querystring=None, - ) + responses.add("GET", uri, headers={}, json={}, match_querystring=None) elif failure_mode == "header_malformed": responses.add( 
"GET", - "https://example.com/schema1.json", + uri, headers={"Last-Modified": "Jan 2000 00:00:01"}, json={}, match_querystring=None, @@ -241,16 +275,13 @@ def fake_mktime(*args): raise NotImplementedError original_file_contents = b'{"foo": "bar"}' - f = tmp_path / "schema1.json" + file_path = get_download_cache_loc(uri) + assert not file_path.exists() if file_exists: - f.write_bytes(original_file_contents) - else: - assert not f.exists() + inject_cached_download(uri, original_file_contents) - cd = CacheDownloader( - "https://example.com/schema1.json", filename=str(f), cache_dir=str(tmp_path) - ) + cd = CacheDownloader().bind(uri) # if the file already existed, it will not be overwritten by the cachedownloader # so the returned value for both the downloader and a direct file read should be the @@ -258,30 +289,33 @@ def fake_mktime(*args): if file_exists: with cd.open() as fp: assert fp.read() == original_file_contents - assert f.read_bytes() == original_file_contents + assert file_path.read_bytes() == original_file_contents # otherwise, the file will have been created with new content # both reads will show that new content else: with cd.open() as fp: assert fp.read() == b"{}" - assert f.read_bytes() == b"{}" + assert file_path.read_bytes() == b"{}" # at the end, the file always exists on disk - assert f.exists() + assert file_path.exists() -def test_cachedownloader_validation_is_not_invoked_on_hit(monkeypatch, tmp_path): +def test_cachedownloader_validation_is_not_invoked_on_hit( + monkeypatch, inject_cached_download +): """ Regression test for https://github.com/python-jsonschema/check-jsonschema/issues/453 This was a bug in which the validation callback was invoked eagerly, even on a cache hit. As a result, cache hits did not demonstrate their expected performance gain. 
""" + uri = "https://example.com/schema1.json" + # 1: construct some perfectly good data (it doesn't really matter what it is) add_default_response() # 2: put equivalent data on disk - f = tmp_path / "schema1.json" - f.write_text("{}") + inject_cached_download(uri, "{}") # 3: construct a validator which marks that it ran in a variable validator_ran = False @@ -292,10 +326,8 @@ def dummy_validate_bytes(data): # construct a downloader pointed at the schema and file, expecting a cache hit # and use the above validation method - cd = CacheDownloader( + cd = CacheDownloader().bind( "https://example.com/schema1.json", - filename=str(f), - cache_dir=str(tmp_path), validation_callback=dummy_validate_bytes, )