From 110314bfec28023b68b11f83584509053d33f351 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Jul 2024 07:58:14 -0500 Subject: [PATCH 1/2] [Bugfix] Allow import of TVM when current directory is read-only Prior to this commit, TVM could only be imported if the current directory had write privileges. This was due to the use of `tvm.contrib.pickle_memoize` to cache the winograd transformation matrices. This commit makes multiple related fixes, to ensure that (1) TVM can be imported regardless of directory permissions, (2) the working directory is not left in a cluttered state, and (3) cache files are generated in an expected location to be reused later. * The cache directory is only generated when required, just prior to saving a cache. * The cache directory defaults to `$HOME/.cache/tvm/pkl_memoize`, rather than `.pkl_memorize_py3` in the working directory. * The cache directory respects `XDG_CACHE_HOME`, using `$XDG_CACHE_HOME/tvm/pkl_memoize` if set. --- python/tvm/contrib/pickle_memoize.py | 56 +++++--- tests/python/contrib/pickle_memoize_script.py | 48 +++++++ tests/python/contrib/test_memoize.py | 126 ++++++++++++++++++ 3 files changed, 212 insertions(+), 18 deletions(-) create mode 100755 tests/python/contrib/pickle_memoize_script.py create mode 100644 tests/python/contrib/test_memoize.py diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index 6d2ffbac0673..10f3f7c6df04 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. """Memoize result of function via pickle, used for cache testcases.""" + # pylint: disable=broad-except,superfluous-parens +import atexit import os +import pathlib import sys -import atexit + from decorator import decorate from .._ffi.base import string_types @@ -28,6 +31,17 @@ import pickle +def _get_global_cache_dir() -> pathlib.Path: + if "XDG_CACHE_HOME" in os.environ: + cache_home = pathlib.Path(os.environ.get("XDG_CACHE_HOME")) + else: + cache_home = pathlib.Path.home().joinpath(".cache") + return cache_home.joinpath("tvm", f"pkl_memoize_py{sys.version_info[0]}") + + +GLOBAL_CACHE_DIR = _get_global_cache_dir() + + class Cache(object): """A cache object for result cache. @@ -42,28 +56,34 @@ class Cache(object): cache_by_key = {} def __init__(self, key, save_at_exit): - cache_dir = f".pkl_memoize_py{sys.version_info[0]}" - try: - os.mkdir(cache_dir) - except FileExistsError: - pass - else: - self.cache = {} - self.path = os.path.join(cache_dir, key) - if os.path.exists(self.path): - try: - self.cache = pickle.load(open(self.path, "rb")) - except Exception: - self.cache = {} - else: - self.cache = {} + self._cache = None + + self.path = GLOBAL_CACHE_DIR.joinpath(key) self.dirty = False self.save_at_exit = save_at_exit + @property + def cache(self): + if self._cache is not None: + return self._cache + + if self.path.exists(): + with self.path.open("rb") as cache_file: + try: + cache = pickle.load(cache_file) + except pickle.UnpicklingError: + cache = {} + else: + cache = {} + + self._cache = cache + return self._cache + def save(self): if self.dirty: - print(f"Save memoize result to {self.path}") - with open(self.path, "wb") as out_file: + self.path.parent.mkdir(parents=True, exist_ok=True) + + with self.path.open("wb") as out_file: pickle.dump(self.cache, out_file, pickle.HIGHEST_PROTOCOL) diff --git a/tests/python/contrib/pickle_memoize_script.py b/tests/python/contrib/pickle_memoize_script.py new file mode 100755 index 000000000000..f0d73e391066 --- /dev/null +++ b/tests/python/contrib/pickle_memoize_script.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import tvm + + +@tvm.contrib.pickle_memoize.memoize("test_memoize_save_data", save_at_exit=True) +def get_data_saved(): + return 42 + + +@tvm.contrib.pickle_memoize.memoize("test_memoize_transient_data", save_at_exit=False) +def get_data_transient(): + return 42 + + +def main(): + assert len(sys.argv) == 3, "Expect arguments SCRIPT NUM_SAVED NUM_TRANSIENT" + + num_iter_saved = int(sys.argv[1]) + num_iter_transient = int(sys.argv[2]) + + for _ in range(num_iter_saved): + get_data_saved() + for _ in range(num_iter_transient): + get_data_transient() + + +if __name__ == "__main__": + main() diff --git a/tests/python/contrib/test_memoize.py b/tests/python/contrib/test_memoize.py new file mode 100644 index 000000000000..6881940e5062 --- /dev/null +++ b/tests/python/contrib/test_memoize.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for tvm.contrib.pickle_memoize""" + +import os +import pathlib +import tempfile +import subprocess +import sys + +import tvm.testing + +TEST_SCRIPT_FILE = pathlib.Path(__file__).with_name("pickle_memoize_script.py").resolve() + + +def test_cache_dir_not_in_current_working_dir(): + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + subprocess.check_call([TEST_SCRIPT_FILE, "1", "1"], cwd=temp_dir) + + new_files = list(temp_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + +def test_current_directory_is_not_required_to_be_writable(): + """TVM may be imported without directory permissions + + This is a regression test. In previous implementations, the + `tvm.contrib.pickle_memoize.memoize` function would write to the + current directory when importing TVM. Import of a Python module + should not write to any directory. + + """ + + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + # User may read/cd into the temp dir, nobody may write to temp + # dir. + temp_dir.chmod(0o500) + subprocess.check_call([sys.executable, "-c", "import tvm"], cwd=temp_dir) + + +def test_cache_dir_defaults_to_home_config_cache(): + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + temp_dir = pathlib.Path(temp_dir) + + subprocess.check_call([TEST_SCRIPT_FILE, "1", "0"], cwd=temp_dir) + + new_files = list(temp_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + cache_dir = pathlib.Path.home().joinpath(".cache", "tvm", "pkl_memoize_py3") + assert cache_dir.exists() + cache_files = list(cache_dir.iterdir()) + assert len(cache_files) >= 1 + + +def test_cache_dir_respects_xdg_cache_home(): + with tempfile.TemporaryDirectory( + prefix="tvm_" + ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: + temp_cache_dir = pathlib.Path(temp_cache_dir) + temp_working_dir = pathlib.Path(temp_working_dir) + + subprocess.check_call( + [TEST_SCRIPT_FILE, "1", "0"], + cwd=temp_working_dir, + env={ + **os.environ, + "XDG_CACHE_HOME": temp_cache_dir.as_posix(), + }, + ) + + new_files = list(temp_working_dir.iterdir()) + assert ( + not new_files + ), "Use of tvm.contrib.pickle_memorize may not write to current directory." + + cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") + assert cache_dir.exists() + cache_files = list(cache_dir.iterdir()) + assert len(cache_files) == 1 + + +def test_cache_dir_only_created_when_used(): + with tempfile.TemporaryDirectory( + prefix="tvm_" + ) as temp_working_dir, tempfile.TemporaryDirectory(prefix="tvm_") as temp_cache_dir: + temp_cache_dir = pathlib.Path(temp_cache_dir) + temp_working_dir = pathlib.Path(temp_working_dir) + + subprocess.check_call( + [TEST_SCRIPT_FILE, "0", "1"], + cwd=temp_working_dir, + env={ + **os.environ, + "XDG_CACHE_HOME": temp_cache_dir.as_posix(), + }, + ) + + cache_dir = temp_cache_dir.joinpath("tvm", "pkl_memoize_py3") + assert not cache_dir.exists() + + +if __name__ == "__main__": + tvm.testing.main() From 4bacbf6a7d35ad7284084a59646dc9333af7d91a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Jul 2024 14:23:38 -0500 Subject: [PATCH 2/2] lint fix --- python/tvm/contrib/pickle_memoize.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/contrib/pickle_memoize.py b/python/tvm/contrib/pickle_memoize.py index 10f3f7c6df04..4f3aff8fb5b0 100644 --- a/python/tvm/contrib/pickle_memoize.py +++ b/python/tvm/contrib/pickle_memoize.py @@ -64,6 +64,8 @@ def __init__(self, key, save_at_exit): @property def cache(self): + """Return the cache, initializing on first use.""" + if self._cache is not None: return self._cache