Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
from .ndarray import vpi, rocm, ext_dev
from .module import load_module, enabled, system_lib, load_static_library
from .container import String, ShapeTuple
from .params import save_param_dict, load_param_dict
from .params import (
save_param_dict,
load_param_dict,
save_param_dict_to_file,
load_param_dict_from_file,
)

from . import executor
49 changes: 45 additions & 4 deletions python/tvm/runtime/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save and load parameter dicts."""
from . import _ffi_api, ndarray
from . import _ffi_api, ndarray, NDArray


def _to_ndarray(params):
transformed = {}

for (k, v) in params.items():
if not isinstance(v, NDArray):
transformed[k] = ndarray.array(v)
else:
transformed[k] = v

return transformed


def save_param_dict(params):
Expand Down Expand Up @@ -47,12 +59,25 @@ def save_param_dict(params):
# Pass in byte array to module to directly set parameters
tvm.runtime.load_param_dict(param_bytes)
"""
transformed = {k: ndarray.array(v) for (k, v) in params.items()}
return _ffi_api.SaveParams(transformed)
return _ffi_api.SaveParams(_to_ndarray(params))


def save_param_dict_to_file(params, path):
"""Save parameter dictionary to file.

Parameters
----------
params : dict of str to NDArray
The parameter dictionary.

path: str
The path to the parameter file.
"""
return _ffi_api.SaveParamsToFile(_to_ndarray(params), path)


def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.
"""Load parameter dictionary from binary bytes.

Parameters
----------
Expand All @@ -67,3 +92,19 @@ def load_param_dict(param_bytes):
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
return _ffi_api.LoadParams(param_bytes)


def load_param_dict_from_file(path):
"""Load parameter dictionary from file.

Parameters
----------
path: str
The path to the parameter file to load from.

Returns
-------
params : dict of str to NDArray
The parameter dictionary.
"""
return _ffi_api.LoadParamsFromFile(path)
12 changes: 12 additions & 0 deletions src/runtime/file_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,21 @@ TVM_REGISTER_GLOBAL("runtime.SaveParams").set_body_typed([](const Map<String, ND
rv = TVMByteArray{s.data(), s.size()};
return rv;
});

TVM_REGISTER_GLOBAL("runtime.SaveParamsToFile")
.set_body_typed([](const Map<String, NDArray>& params, const String& path) {
tvm::runtime::SimpleBinaryFileStream strm(path, "wb");
SaveParams(&strm, params);
});

TVM_REGISTER_GLOBAL("runtime.LoadParams").set_body_typed([](const String& s) {
return ::tvm::runtime::LoadParams(s);
});

TVM_REGISTER_GLOBAL("runtime.LoadParamsFromFile").set_body_typed([](const String& path) {
tvm::runtime::SimpleBinaryFileStream strm(path, "rb");
return LoadParams(&strm);
});

} // namespace runtime
} // namespace tvm
16 changes: 14 additions & 2 deletions tests/python/unittest/test_runtime_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tempfile
import tvm
import tvm.testing
from tvm import te, runtime
Expand Down Expand Up @@ -138,6 +139,17 @@ def test_load_unexpected_params():
rt_mod.load_params(runtime.save_param_dict(new_params))


def test_save_load_file():
p = np.random.randn(10)
params = {"x": p}

with tempfile.NamedTemporaryFile() as fp:
tvm.runtime.save_param_dict_to_file(params, fp.name)
params_loaded = tvm.runtime.load_param_dict_from_file(fp.name)

assert "x" in params_loaded
np.testing.assert_equal(p, params_loaded["x"].numpy())


if __name__ == "__main__":
test_graph_simple()
test_load_unexpected_params()
tvm.testing.main()