From 2f3747cc05aaceaa4306f4dace8bd68a71143f7a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Mar 2023 04:46:50 +0900 Subject: [PATCH 1/6] add load_params_from_file --- python/tvm/runtime/__init__.py | 2 +- python/tvm/runtime/params.py | 16 ++++++++++++++++ src/runtime/file_utils.cc | 9 +++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 71f71e6c8427..5b1499614e9d 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -32,6 +32,6 @@ from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library from .container import String, ShapeTuple -from .params import save_param_dict, load_param_dict +from .params import save_param_dict, load_param_dict, load_param_dict_from_file from . import executor diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py index 78e745686c95..d853f787810d 100644 --- a/python/tvm/runtime/params.py +++ b/python/tvm/runtime/params.py @@ -67,3 +67,19 @@ def load_param_dict(param_bytes): if isinstance(param_bytes, (bytes, str)): param_bytes = bytearray(param_bytes) return _ffi_api.LoadParams(param_bytes) + + +def load_param_dict_from_file(file_name): + """Load parameter dictionary to binary bytes. + + Parameters + ---------- + param_bytes: bytearray + Serialized parameters. + + Returns + ------- + params : dict of str to NDArray + The parameter dictionary. + """ + return _ffi_api.LoadParamsFromFile(file_name) diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 1e7cc6ad44e7..793c817de2bd 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -207,6 +207,11 @@ Map LoadParams(dmlc::Stream* strm) { return params; } +Map LoadParamsFromFile(const std::string& file_name) { + tvm::runtime::SimpleBinaryFileStream strm(file_name, "rb"); + return LoadParams(&strm); +} + void SaveParams(dmlc::Stream* strm, const Map& params) { std::vector names; std::vector arrays; @@ -247,5 +252,9 @@ TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) { return ::tvm::runtime::LoadParams(s); }); +TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& file_name) { + return ::tvm::runtime::LoadParamsFromFile(file_name); +}); + } // namespace runtime } // namespace tvm From 10dcd0dac8ad6c7a21b4118fca5e244eda42c723 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Mar 2023 04:58:10 +0900 Subject: [PATCH 2/6] add save_params_to_file --- python/tvm/runtime/__init__.py | 2 +- python/tvm/runtime/params.py | 36 ++++++++++++++++++++++++++++++++-- src/runtime/file_utils.cc | 17 +++++++++------- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 5b1499614e9d..88d80fe9a387 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -32,6 +32,6 @@ from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library from .container import String, ShapeTuple -from .params import save_param_dict, load_param_dict, load_param_dict_from_file +from .params import save_param_dict, load_param_dict, save_param_dict_to_file, load_param_dict_from_file from . import executor diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py index d853f787810d..7a6171d8a038 100644 --- a/python/tvm/runtime/params.py +++ b/python/tvm/runtime/params.py @@ -51,6 +51,38 @@ def save_param_dict(params): return _ffi_api.SaveParams(transformed) +def save_param_dict_to_file(params, path): + """Save parameter dictionary to binary bytes. + + The result binary bytes can be loaded by the + GraphModule with API "load_params". + + Parameters + ---------- + params : dict of str to NDArray + The parameter dictionary. + + Returns + ------- + param_bytes: bytearray + Serialized parameters. + + Examples + -------- + .. code-block:: python + + # set up the parameter dict + params = {"param0": arr0, "param1": arr1} + # save the parameters as byte array + param_bytes = tvm.runtime.save_param_dict(params) + # We can serialize the param_bytes and load it back later. + # Pass in byte array to module to directly set parameters + tvm.runtime.load_param_dict(param_bytes) + """ + transformed = {k: ndarray.array(v) for (k, v) in params.items()} + return _ffi_api.SaveParamsToFile(transformed, path) + + def load_param_dict(param_bytes): """Load parameter dictionary to binary bytes. @@ -69,7 +101,7 @@ def load_param_dict(param_bytes): return _ffi_api.LoadParams(param_bytes) -def load_param_dict_from_file(file_name): +def load_param_dict_from_file(path): """Load parameter dictionary to binary bytes. Parameters @@ -82,4 +114,4 @@ def load_param_dict_from_file(file_name): params : dict of str to NDArray The parameter dictionary. """ - return _ffi_api.LoadParamsFromFile(file_name) + return _ffi_api.LoadParamsFromFile(path) diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index 793c817de2bd..1c0e16dbe1a8 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -207,11 +207,6 @@ Map LoadParams(dmlc::Stream* strm) { return params; } -Map LoadParamsFromFile(const std::string& file_name) { - tvm::runtime::SimpleBinaryFileStream strm(file_name, "rb"); - return LoadParams(&strm); -} - void SaveParams(dmlc::Stream* strm, const Map& params) { std::vector names; std::vector arrays; @@ -248,12 +243,20 @@ TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map& params, const String& path) { + tvm::runtime::SimpleBinaryFileStream strm(path, "wb"); + SaveParams(&strm, params); + }); + TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) { return ::tvm::runtime::LoadParams(s); }); -TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& file_name) { - return ::tvm::runtime::LoadParamsFromFile(file_name); +TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) { + tvm::runtime::SimpleBinaryFileStream strm(path, "rb"); + return LoadParams(&strm); }); } // namespace runtime From 65147ae8c430f434e5cb390adf253a20b89a786f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Mar 2023 05:12:45 +0900 Subject: [PATCH 3/6] avoid making another copy in save_params --- python/tvm/runtime/params.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py index 7a6171d8a038..212a161e2066 100644 --- a/python/tvm/runtime/params.py +++ b/python/tvm/runtime/params.py @@ -16,7 +16,19 @@ # under the License. # pylint: disable=invalid-name """Helper utility to save and load parameter dicts.""" -from . import _ffi_api, ndarray +from . import _ffi_api, ndarray, NDArray + + +def _to_ndarray(params): + transformed = {} + + for (k, v) in params.items(): + if not isinstance(v, NDArray): + transformed[k] = ndarray.array(v) + else: + transformed[k] = v + + return transformed def save_param_dict(params): @@ -47,8 +59,7 @@ def save_param_dict(params): # Pass in byte array to module to directly set parameters tvm.runtime.load_param_dict(param_bytes) """ - transformed = {k: ndarray.array(v) for (k, v) in params.items()} - return _ffi_api.SaveParams(transformed) + return _ffi_api.SaveParams(_to_ndarray(params)) def save_param_dict_to_file(params, path): @@ -79,8 +90,7 @@ def save_param_dict_to_file(params, path): # Pass in byte array to module to directly set parameters tvm.runtime.load_param_dict(param_bytes) """ - transformed = {k: ndarray.array(v) for (k, v) in params.items()} - return _ffi_api.SaveParamsToFile(transformed, path) + return _ffi_api.SaveParamsToFile(_to_ndarray(params), path) def load_param_dict(param_bytes): From bd551f7e57451e961053205ea7156466c1e4808f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Mar 2023 05:16:39 +0900 Subject: [PATCH 4/6] black --- python/tvm/runtime/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 88d80fe9a387..eccdcbad9520 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -32,6 +32,11 @@ from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library from .container import String, ShapeTuple -from .params import save_param_dict, load_param_dict, save_param_dict_to_file, load_param_dict_from_file +from .params import ( + save_param_dict, + load_param_dict, + save_param_dict_to_file, + load_param_dict_from_file, +) from . import executor From 1be1ca06c4c8a120da6dc8fb8d2620ab6fe92346 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Mar 2023 05:32:47 +0900 Subject: [PATCH 5/6] add test --- tests/python/unittest/test_runtime_graph.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py index 458952fb5641..108784de7eb1 100644 --- a/tests/python/unittest/test_runtime_graph.py +++ b/tests/python/unittest/test_runtime_graph.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tempfile import tvm import tvm.testing from tvm import te, runtime @@ -138,6 +139,17 @@ def test_load_unexpected_params(): rt_mod.load_params(runtime.save_param_dict(new_params)) +def test_save_load_file(): + p = np.random.randn(10) + params = {"x": p} + + with tempfile.NamedTemporaryFile() as fp: + tvm.runtime.save_param_dict_to_file(params, fp.name) + params_loaded = tvm.runtime.load_param_dict_from_file(fp.name) + + assert "x" in params_loaded + np.testing.assert_equal(p, params_loaded["x"].numpy()) + + if __name__ == "__main__": - test_graph_simple() - test_load_unexpected_params() + tvm.testing.main() From 28ea63135e559746bcd550cb4a91f01d3a32d5ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 1 Mar 2023 05:35:28 +0900 Subject: [PATCH 6/6] update doc --- python/tvm/runtime/params.py | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/python/tvm/runtime/params.py b/python/tvm/runtime/params.py index 212a161e2066..4362a4b6a841 100644 --- a/python/tvm/runtime/params.py +++ b/python/tvm/runtime/params.py @@ -63,38 +63,21 @@ def save_param_dict(params): def save_param_dict_to_file(params, path): - """Save parameter dictionary to binary bytes. - - The result binary bytes can be loaded by the - GraphModule with API "load_params". + """Save parameter dictionary to file. Parameters ---------- params : dict of str to NDArray The parameter dictionary. - Returns - ------- - param_bytes: bytearray - Serialized parameters. - - Examples - -------- - .. code-block:: python - - # set up the parameter dict - params = {"param0": arr0, "param1": arr1} - # save the parameters as byte array - param_bytes = tvm.runtime.save_param_dict(params) - # We can serialize the param_bytes and load it back later. - # Pass in byte array to module to directly set parameters - tvm.runtime.load_param_dict(param_bytes) + path: str + The path to the parameter file. """ return _ffi_api.SaveParamsToFile(_to_ndarray(params), path) def load_param_dict(param_bytes): - """Load parameter dictionary to binary bytes. + """Load parameter dictionary from binary bytes. Parameters ---------- @@ -112,12 +95,12 @@ def load_param_dict(param_bytes): def load_param_dict_from_file(path): - """Load parameter dictionary to binary bytes. + """Load parameter dictionary from file. Parameters ---------- - param_bytes: bytearray - Serialized parameters. + path: str + The path to the parameter file to load from. Returns -------