diff --git a/CHANGES.rst b/CHANGES.rst index 6e3266ef..7e0225fe 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,10 @@ v5.10.0 files was renamed from 'package' to 'anchor', with a compatibility shim for those passing by keyword. +* #259: ``files`` no longer requires the anchor to be + specified and can infer the anchor from the caller's scope + (defaults to the caller's module). + v5.9.0 ====== diff --git a/importlib_resources/_common.py b/importlib_resources/_common.py index 52af4a13..9f19784d 100644 --- a/importlib_resources/_common.py +++ b/importlib_resources/_common.py @@ -5,7 +5,9 @@ import contextlib import types import importlib +import inspect import warnings +import itertools from typing import Union, Optional, cast from .abc import ResourceReader, Traversable @@ -22,12 +24,9 @@ def package_to_anchor(func): Other errors should fall through. - >>> files() - Traceback (most recent call last): - TypeError: files() missing 1 required positional argument: 'anchor' >>> files('a', 'b') Traceback (most recent call last): - TypeError: files() takes 1 positional argument but 2 were given + TypeError: files() takes from 0 to 1 positional arguments but 2 were given """ undefined = object() @@ -50,7 +49,7 @@ def wrapper(anchor=undefined, package=undefined): @package_to_anchor -def files(anchor: Anchor) -> Traversable: +def files(anchor: Optional[Anchor] = None) -> Traversable: """ Get a Traversable resource for an anchor. """ @@ -74,7 +73,7 @@ def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]: @functools.singledispatch -def resolve(cand: Anchor) -> types.ModuleType: +def resolve(cand: Optional[Anchor]) -> types.ModuleType: return cast(types.ModuleType, cand) @@ -83,6 +82,28 @@ def _(cand: str) -> types.ModuleType: return importlib.import_module(cand) +@resolve.register +def _(cand: None) -> types.ModuleType: + return resolve(_infer_caller().f_globals['__name__']) + + +def _infer_caller(): + """ + Walk the stack and find the frame of the first caller not in this module. + """ + + def is_this_file(frame_info): + return frame_info.filename == __file__ + + def is_wrapper(frame_info): + return frame_info.function == 'wrapper' + + not_this_file = itertools.filterfalse(is_this_file, inspect.stack()) + # also exclude 'wrapper' due to singledispatch in the call stack + callers = itertools.filterfalse(is_wrapper, not_this_file) + return next(callers).frame + + def from_package(package: types.ModuleType): """ Return a Traversable object for the given package. diff --git a/importlib_resources/tests/test_files.py b/importlib_resources/tests/test_files.py index dac08024..d258fb5f 100644 --- a/importlib_resources/tests/test_files.py +++ b/importlib_resources/tests/test_files.py @@ -1,6 +1,8 @@ import typing +import textwrap import unittest import warnings +import importlib import contextlib import importlib_resources as resources @@ -61,7 +63,7 @@ def setUp(self): self.data = namespacedata01 -class ModulesFilesTests(unittest.TestCase): +class SiteDir: def setUp(self): self.fixtures = contextlib.ExitStack() self.addCleanup(self.fixtures.close) @@ -69,6 +71,8 @@ def setUp(self): self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir)) self.fixtures.enter_context(import_helper.CleanImport()) + +class ModulesFilesTests(SiteDir, unittest.TestCase): def test_module_resources(self): """ A module can have resources found adjacent to the module. @@ -84,5 +88,25 @@ def test_module_resources(self): assert actual == spec['res.txt'] +class ImplicitContextFilesTests(SiteDir, unittest.TestCase): + def test_implicit_files(self): + """ + Without any parameter, files() will infer the location as the caller. + """ + spec = { + 'somepkg': { + '__init__.py': textwrap.dedent( + """ + import importlib_resources as res + val = res.files().joinpath('res.txt').read_text() + """ + ), + 'res.txt': 'resources are the best', + }, + } + _path.build(spec, self.site_dir) + assert importlib.import_module('somepkg').val == 'resources are the best' + + if __name__ == '__main__': unittest.main()