diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 71f71e6c8427..eccdcbad9520 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -32,6 +32,11 @@ from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library from .container import String, ShapeTuple -from .params import save_param_dict, load_param_dict +from .params import ( + save_param_dict, + load_param_dict, + save_param_dict_to_file, + load_param_dict_from_file, +) from . import executor diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py index 78e745686c95..4362a4b6a841 100644 --- a/python/tvm/runtime/params.py +++ b/python/tvm/runtime/params.py @@ -16,7 +16,19 @@ # under the License. # pylint: disable=invalid-name """Helper utility to save and load parameter dicts.""" -from . import _ffi_api, ndarray +from . import _ffi_api, ndarray, NDArray + + +def _to_ndarray(params): + transformed = {} + + for (k, v) in params.items(): + if not isinstance(v, NDArray): + transformed[k] = ndarray.array(v) + else: + transformed[k] = v + + return transformed def save_param_dict(params): @@ -47,12 +59,25 @@ def save_param_dict(params): # Pass in byte array to module to directly set parameters tvm.runtime.load_param_dict(param_bytes) """ - transformed = {k: ndarray.array(v) for (k, v) in params.items()} - return _ffi_api.SaveParams(transformed) + return _ffi_api.SaveParams(_to_ndarray(params)) + + +def save_param_dict_to_file(params, path): + """Save parameter dictionary to file. + + Parameters + ---------- + params : dict of str to NDArray + The parameter dictionary. + + path: str + The path to the parameter file. + """ + return _ffi_api.SaveParamsToFile(_to_ndarray(params), path) def load_param_dict(param_bytes): - """Load parameter dictionary to binary bytes. + """Load parameter dictionary from binary bytes. Parameters ---------- @@ -67,3 +92,19 @@ def load_param_dict(param_bytes): if isinstance(param_bytes, (bytes, str)): param_bytes = bytearray(param_bytes) return _ffi_api.LoadParams(param_bytes) + + +def load_param_dict_from_file(path): + """Load parameter dictionary from file. + + Parameters + ---------- + path: str + The path to the parameter file to load from. + + Returns + ------- + params : dict of str to NDArray + The parameter dictionary. + """ + return _ffi_api.LoadParamsFromFile(path) diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 1e7cc6ad44e7..1c0e16dbe1a8 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -243,9 +243,21 @@ TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map& params, const String& path) { + tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); + SaveParams(&strm, params); + }); + TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) { return ::tvm::runtime::LoadParams(s); }); +TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { + tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); + return LoadParams(&strm); +}); + } // namespace runtime } // namespace tvm diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py index 458952fb5641..108784de7eb1 100644 --- a/tests/python/unittest/test_runtime_graph.py +++ b/tests/python/unittest/test_runtime_graph.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tempfile import tvm import tvm.testing from tvm import te, runtime @@ -138,6 +139,17 @@ def test_load_unexpected_params(): rt_mod.load_params(runtime.save_param_dict(new_params)) +def test_save_load_file(): + p = np.random.randn(10) + params = {"x": p} + + with tempfile.NamedTemporaryFile() as fp: + tvm.runtime.save_param_dict_to_file(params, fp.name) + params_loaded = tvm.runtime.load_param_dict_from_file(fp.name) + + assert "x" in params_loaded + np.testing.assert_equal(p, params_loaded["x"].numpy()) + + if __name__ == "__main__": - test_graph_simple() - test_load_unexpected_params() + tvm.testing.main()