From 734a966e64d625ba6e8b0ac01064f906d97ad9ed Mon Sep 17 00:00:00 2001 From: Brett Date: Mon, 13 Oct 2025 18:00:28 -0400 Subject: [PATCH 1/2] resolve all refs for load_schema (can produce a recursive schema) --- asdf/_tests/test_schema.py | 9 +++++-- asdf/schema.py | 7 +++--- asdf/treeutil.py | 51 +++++++++++++++++++++++++++----------- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/asdf/_tests/test_schema.py b/asdf/_tests/test_schema.py index cceb4394e..35862e3d4 100644 --- a/asdf/_tests/test_schema.py +++ b/asdf/_tests/test_schema.py @@ -132,7 +132,7 @@ def test_load_schema(tmp_path): schema_path = tmp_path / "nugatory.yaml" schema_path.write_bytes(schema_def.encode()) - schema_tree = schema.load_schema(str(schema_path), resolve_references=True) + schema_tree = schema.load_schema(str(schema_path)) schema.check_schema(schema_tree) @@ -156,7 +156,7 @@ def test_load_schema_with_file_url(tmp_path): schema_path = tmp_path / "nugatory.yaml" schema_path.write_bytes(schema_def.encode()) - schema_tree = schema.load_schema(str(schema_path), resolve_references=True) + schema_tree = schema.load_schema(str(schema_path)) schema.check_schema(schema_tree) @@ -814,6 +814,11 @@ def test_self_reference_resolution(test_data_path): assert s["anyOf"][1] == s["anyOf"][0] +def test_resolve_references(): + s = schema.load_schema("http://stsci.edu/schemas/asdf/core/ndarray-1.0.0", resolve_references=True) + assert "$ref" not in repr(s) + + def test_schema_resolved_via_entry_points(): """Test that entry points mappings to core schema works""" tag = "tag:stsci.edu:asdf/fits/fits-1.0.0" diff --git a/asdf/schema.py b/asdf/schema.py index 5fd6d5278..92bbee123 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -442,15 +442,14 @@ def resolve_refs(node, json_id): if suburl_base == url or suburl_base == schema.get("id"): # This is a local ref, which we'll resolve in both cases. 
- subschema = schema + return reference.resolve_fragment(schema, suburl_fragment) else: subschema = load_schema(suburl_base, True) - - return reference.resolve_fragment(subschema, suburl_fragment) + return reference.resolve_fragment(subschema, suburl_fragment) return node - schema = treeutil.walk_and_modify(schema, resolve_refs) + schema = treeutil.walk_and_modify(schema, resolve_refs, in_place=True) return schema diff --git a/asdf/treeutil.py b/asdf/treeutil.py index 96ae4cc41..bca29fa2a 100644 --- a/asdf/treeutil.py +++ b/asdf/treeutil.py @@ -219,7 +219,7 @@ def __repr__(self): RemoveNode = _RemoveNode() -def walk_and_modify(top, callback, postorder=True, _context=None): +def walk_and_modify(top, callback, postorder=True, _context=None, in_place=False): """Modify a tree by walking it with a callback function. It also has the effect of doing a deep copy. @@ -279,14 +279,17 @@ def _handle_callback(node, json_id): return _handle_generator(result) def _handle_mapping(node, json_id): - if isinstance(node, lazy_nodes.AsdfOrderedDictNode): - result = collections.OrderedDict() - elif isinstance(node, lazy_nodes.AsdfDictNode): - result = {} + if in_place: + result = node else: - result = node.__class__() - if isinstance(node, tagged.Tagged): - result._tag = node._tag + if isinstance(node, lazy_nodes.AsdfOrderedDictNode): + result = collections.OrderedDict() + elif isinstance(node, lazy_nodes.AsdfDictNode): + result = {} + else: + result = node.__class__() + if isinstance(node, tagged.Tagged): + result._tag = node._tag pending_items = {} for key, value in node.items(): @@ -300,6 +303,7 @@ def _handle_mapping(node, json_id): elif (val := _recurse(value, json_id)) is not RemoveNode: result[key] = val + # TODO handle RemoveNode yield result @@ -307,6 +311,7 @@ def _handle_mapping(node, json_id): # Now that we've yielded, the pending children should # be available. 
for key, value in pending_items.items(): + # TODO handle RemoveNode if (val := _recurse(value, json_id)) is not RemoveNode: result[key] = val else: @@ -315,12 +320,23 @@ def _handle_mapping(node, json_id): del result[key] def _handle_mutable_sequence(node, json_id): - if isinstance(node, lazy_nodes.AsdfListNode): - result = [] + if in_place: + result = node + + def setter(i, v): + result[i] = v + else: - result = node.__class__() - if isinstance(node, tagged.Tagged): - result._tag = node._tag + + def setter(i, v): + result.append(v) + + if isinstance(node, lazy_nodes.AsdfListNode): + result = [] + else: + result = node.__class__() + if isinstance(node, tagged.Tagged): + result._tag = node._tag pending_items = {} for i, value in enumerate(node): @@ -330,9 +346,11 @@ def _handle_mutable_sequence(node, json_id): # PendingValue instance for now, and note that we'll # need to fill in the real value later. pending_items[i] = value - result.append(PendingValue) + setter(i, PendingValue) + # result.append(PendingValue) else: - result.append(_recurse(value, json_id)) + setter(i, _recurse(value, json_id)) + # result.append(_recurse(value, json_id)) yield result @@ -346,6 +364,9 @@ def _handle_immutable_sequence(node, json_id): # to construct (well, maybe possible in a C extension, but # we're not going to worry about that), so we don't need # to yield here. 
+ if in_place: + # TODO better error + raise Exception("fail") contents = [_recurse(value, json_id) for value in node] result = node.__class__(contents) From a13377fc5cfb5d35e84f739a3216e343fdce7d92 Mon Sep 17 00:00:00 2001 From: Brett Date: Thu, 16 Oct 2025 09:11:52 -0400 Subject: [PATCH 2/2] add allow_recursion to load_schema --- asdf/_tests/test_schema.py | 8 +++++--- asdf/schema.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/asdf/_tests/test_schema.py b/asdf/_tests/test_schema.py index 35862e3d4..c3a1cb66b 100644 --- a/asdf/_tests/test_schema.py +++ b/asdf/_tests/test_schema.py @@ -132,7 +132,7 @@ def test_load_schema(tmp_path): schema_path = tmp_path / "nugatory.yaml" schema_path.write_bytes(schema_def.encode()) - schema_tree = schema.load_schema(str(schema_path)) + schema_tree = schema.load_schema(str(schema_path), resolve_references=True) schema.check_schema(schema_tree) @@ -156,7 +156,7 @@ def test_load_schema_with_file_url(tmp_path): schema_path = tmp_path / "nugatory.yaml" schema_path.write_bytes(schema_def.encode()) - schema_tree = schema.load_schema(str(schema_path)) + schema_tree = schema.load_schema(str(schema_path), resolve_references=True) schema.check_schema(schema_tree) @@ -815,7 +815,9 @@ def test_self_reference_resolution(test_data_path): def test_resolve_references(): - s = schema.load_schema("http://stsci.edu/schemas/asdf/core/ndarray-1.0.0", resolve_references=True) + s = schema.load_schema( + "http://stsci.edu/schemas/asdf/core/ndarray-1.0.0", resolve_references=True, allow_recursion=True + ) assert "$ref" not in repr(s) diff --git a/asdf/schema.py b/asdf/schema.py index 92bbee123..1071e4a4f 100644 --- a/asdf/schema.py +++ b/asdf/schema.py @@ -380,7 +380,7 @@ def get_schema(url): ) -def load_schema(url, resolve_references=False): +def load_schema(url, resolve_references=False, allow_recursion=False): """ Load a schema from the given URL. 
@@ -390,13 +390,18 @@ def load_schema(url, resolve_references=False): The path to the schema resolve_references : bool, optional - If ``True``, resolve all ``$ref`` references. + If ``True``, resolve ``$ref`` references. Note that local + references will only be resolved 1 level deep to prevent + returning recursive schemas (see allow_recursion). + allow_recursion : bool, optional + If ``True``, resolve all local references, possibly producing + a recursive schema. """ # We want to cache the work that went into constructing the schema, but returning # the same object is treacherous, because users who mutate the result will not # expect that they're changing the schema everywhere. - return copy.deepcopy(_load_schema_cached(url, resolve_references)) + return copy.deepcopy(_load_schema_cached(url, resolve_references, allow_recursion)) def _safe_resolve(json_id, uri): @@ -427,7 +432,7 @@ def _safe_resolve(json_id, uri): @lru_cache -def _load_schema_cached(url, resolve_references): +def _load_schema_cached(url, resolve_references, allow_recursion=False): loader = _make_schema_loader() schema, url = loader(url) @@ -449,7 +454,7 @@ def resolve_refs(node, json_id): return node - schema = treeutil.walk_and_modify(schema, resolve_refs, in_place=True) + schema = treeutil.walk_and_modify(schema, resolve_refs, in_place=allow_recursion) return schema