asdf/_tests/test_schema.py (7 additions, 0 deletions)

@@ -814,6 +814,13 @@ def test_self_reference_resolution(test_data_path):
     assert s["anyOf"][1] == s["anyOf"][0]
 
 
+def test_resolve_references():
+    s = schema.load_schema(
+        "http://stsci.edu/schemas/asdf/core/ndarray-1.0.0", resolve_references=True, allow_recursion=True
+    )
+    assert "$ref" not in repr(s)
+
+
 def test_schema_resolved_via_entry_points():
     """Test that entry points mappings to core schema works"""
     tag = "tag:stsci.edu:asdf/fits/fits-1.0.0"
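For orientation, the sketch below shows the usage the new test exercises. It is illustrative, not part of the diff, and assumes only what the test itself shows (the asdf.schema module and the ndarray-1.0.0 schema URI):

    from asdf import schema

    # Default behavior: local $refs are expanded only one level deep,
    # so the returned schema stays finite but may still contain $refs.
    shallow = schema.load_schema(
        "http://stsci.edu/schemas/asdf/core/ndarray-1.0.0",
        resolve_references=True,
    )

    # Opt in to full local resolution; the result may contain cycles.
    full = schema.load_schema(
        "http://stsci.edu/schemas/asdf/core/ndarray-1.0.0",
        resolve_references=True,
        allow_recursion=True,
    )
    assert "$ref" not in repr(full)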
asdf/schema.py (12 additions, 8 deletions)

@@ -380,7 +380,7 @@ def get_schema(url):
     )
 
 
-def load_schema(url, resolve_references=False):
+def load_schema(url, resolve_references=False, allow_recursion=False):
     """
     Load a schema from the given URL.
 
@@ -390,13 +390,18 @@ def load_schema(url, resolve_references=False):
         The path to the schema
 
     resolve_references : bool, optional
-        If ``True``, resolve all ``$ref`` references.
+        If ``True``, resolve ``$ref`` references. Note that local
+        references will only be resolved one level deep to prevent
+        returning recursive schemas (see ``allow_recursion``).
 
+    allow_recursion : bool, optional
+        If ``True``, resolve all local references, possibly producing
+        a recursive schema.
     """
     # We want to cache the work that went into constructing the schema, but returning
     # the same object is treacherous, because users who mutate the result will not
     # expect that they're changing the schema everywhere.
-    return copy.deepcopy(_load_schema_cached(url, resolve_references))
+    return copy.deepcopy(_load_schema_cached(url, resolve_references, allow_recursion))
 
 
@@ -427,7 +432,7 @@ def _safe_resolve(json_id, uri):
 
 @lru_cache
-def _load_schema_cached(url, resolve_references):
+def _load_schema_cached(url, resolve_references, allow_recursion=False):
     loader = _make_schema_loader()
     schema, url = loader(url)
 
@@ -442,15 +447,14 @@ def resolve_refs(node, json_id):
 
             if suburl_base == url or suburl_base == schema.get("id"):
                 # This is a local ref, which we'll resolve in both cases.
-                subschema = schema
+                return reference.resolve_fragment(schema, suburl_fragment)
             else:
                 subschema = load_schema(suburl_base, True)
-
-            return reference.resolve_fragment(subschema, suburl_fragment)
+                return reference.resolve_fragment(subschema, suburl_fragment)
 
         return node
 
-    schema = treeutil.walk_and_modify(schema, resolve_refs)
+    schema = treeutil.walk_and_modify(schema, resolve_refs, in_place=allow_recursion)
 
     return schema
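The asymmetry described in the docstring falls out of the walk: a copying walk replaces a local $ref with the original, unwalked schema object, so references nested inside it survive one level down, while an in-place walk replaces it with the very object being rewritten, so everything resolves but the result can contain cycles. A minimal sketch of that distinction, using a hypothetical self-referential schema rather than a real ASDF one:

    # Hypothetical self-referential schema, for illustration only.
    s = {"id": "http://example.com/schemas/tree-1.0.0", "items": {"$ref": "#"}}

    # Copying walk: the "$ref" node is swapped for the original dict,
    # whose own nested "$ref" was never rewritten.
    resolved_once = {"id": s["id"], "items": s}
    assert "$ref" in repr(resolved_once)

    # In-place walk: the "$ref" node is swapped for the dict being
    # rewritten itself, producing a cyclic, $ref-free structure.
    s["items"] = s
    assert "$ref" not in repr(s)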
asdf/treeutil.py (36 additions, 15 deletions)

@@ -219,7 +219,7 @@ def __repr__(self):
 RemoveNode = _RemoveNode()
 
 
-def walk_and_modify(top, callback, postorder=True, _context=None):
+def walk_and_modify(top, callback, postorder=True, _context=None, in_place=False):
     """Modify a tree by walking it with a callback function. It also has
     the effect of doing a deep copy.
 
@@ -279,14 +279,17 @@ def _handle_callback(node, json_id):
         return _handle_generator(result)
 
     def _handle_mapping(node, json_id):
-        if isinstance(node, lazy_nodes.AsdfOrderedDictNode):
-            result = collections.OrderedDict()
-        elif isinstance(node, lazy_nodes.AsdfDictNode):
-            result = {}
+        if in_place:
+            result = node
         else:
-            result = node.__class__()
-            if isinstance(node, tagged.Tagged):
-                result._tag = node._tag
+            if isinstance(node, lazy_nodes.AsdfOrderedDictNode):
+                result = collections.OrderedDict()
+            elif isinstance(node, lazy_nodes.AsdfDictNode):
+                result = {}
+            else:
+                result = node.__class__()
+                if isinstance(node, tagged.Tagged):
+                    result._tag = node._tag
 
         pending_items = {}
         for key, value in node.items():
@@ -300,13 +303,15 @@ def _handle_mapping(node, json_id):
 
             elif (val := _recurse(value, json_id)) is not RemoveNode:
                 result[key] = val
+            # TODO handle RemoveNode
 
         yield result
 
         if len(pending_items) > 0:
             # Now that we've yielded, the pending children should
             # be available.
             for key, value in pending_items.items():
+                # TODO handle RemoveNode
                 if (val := _recurse(value, json_id)) is not RemoveNode:
                     result[key] = val
                 else:
@@ -315,12 +320,23 @@ def _handle_mapping(node, json_id):
                     del result[key]
 
     def _handle_mutable_sequence(node, json_id):
-        if isinstance(node, lazy_nodes.AsdfListNode):
-            result = []
+        if in_place:
+            result = node
+
+            def setter(i, v):
+                result[i] = v
+
         else:
-            result = node.__class__()
-            if isinstance(node, tagged.Tagged):
-                result._tag = node._tag
+
+            def setter(i, v):
+                result.append(v)
+
+            if isinstance(node, lazy_nodes.AsdfListNode):
+                result = []
+            else:
+                result = node.__class__()
+                if isinstance(node, tagged.Tagged):
+                    result._tag = node._tag
 
         pending_items = {}
         for i, value in enumerate(node):
@@ -330,9 +346,11 @@ def _handle_mutable_sequence(node, json_id):
                 # PendingValue instance for now, and note that we'll
                 # need to fill in the real value later.
                 pending_items[i] = value
-                result.append(PendingValue)
+                setter(i, PendingValue)
+                # result.append(PendingValue)
             else:
-                result.append(_recurse(value, json_id))
+                setter(i, _recurse(value, json_id))
+                # result.append(_recurse(value, json_id))
 
         yield result
 
@@ -346,6 +364,9 @@ def _handle_immutable_sequence(node, json_id):
         # to construct (well, maybe possible in a C extension, but
         # we're not going to worry about that), so we don't need
         # to yield here.
+        if in_place:
+            # TODO better error
+            raise Exception("fail")
        contents = [_recurse(value, json_id) for value in node]
 
        result = node.__class__(contents)
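A note on the setter indirection above: a freshly built list starts empty and grows by append, while an in-place walk rewrites slots that already exist, so the two modes need different write operations. The sketch below reproduces just that pattern outside asdf (hypothetical helper, illustration only):

    def make_setter(result, in_place):
        # Return a function that stores value v at logical index i.
        if in_place:
            def setter(i, v):
                result[i] = v  # slot already exists; overwrite it
        else:
            def setter(i, v):
                result.append(v)  # values arrive in index order
        return setter

    src = [1, 2, 3]

    # Copying mode: build a new list alongside the original.
    dst = []
    set_copy = make_setter(dst, in_place=False)
    for i, v in enumerate(src):
        set_copy(i, v * 10)
    assert dst == [10, 20, 30] and src == [1, 2, 3]

    # In-place mode: overwrite the original list's existing slots.
    set_inplace = make_setter(src, in_place=True)
    for i, v in enumerate(list(src)):
        set_inplace(i, v * 10)
    assert src == [10, 20, 30]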