Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/android_camera/models/prepare_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def main(model_str, output_path):
f.write(graph)
print("dumping params...")
with open(output_path_str + "/" + "deploy_param.params", "wb") as f:
f.write(relay.save_param_dict(params))
f.write(runtime.save_param_dict(params))
print("dumping labels...")
synset_url = "".join(
[
Expand Down
6 changes: 3 additions & 3 deletions apps/bundle_deploy/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
from tvm import relay
import tvm
from tvm import te
from tvm import te, runtime
import logging
import json
from tvm.contrib import cc as _cc
Expand Down Expand Up @@ -70,7 +70,7 @@ def build_module(opts):
with open(
os.path.join(build_dir, file_format_str.format(name="params", ext="bin")), "wb"
) as f_params:
f_params.write(relay.save_param_dict(params))
f_params.write(runtime.save_param_dict(params))


def build_test_module(opts):
Expand Down Expand Up @@ -113,7 +113,7 @@ def build_test_module(opts):
with open(
os.path.join(build_dir, file_format_str.format(name="test_params", ext="bin")), "wb"
) as f_params:
f_params.write(relay.save_param_dict(lowered_params))
f_params.write(runtime.save_param_dict(lowered_params))
with open(
os.path.join(build_dir, file_format_str.format(name="test_data", ext="bin")), "wb"
) as fp:
Expand Down
1 change: 1 addition & 0 deletions apps/bundle_deploy/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/runtime/registry.h>

#include "../../src/runtime/c_runtime_api.cc"
#include "../../src/runtime/container.cc"
#include "../../src/runtime/cpu_device_api.cc"
#include "../../src/runtime/file_utils.cc"
#include "../../src/runtime/graph/graph_runtime.cc"
Expand Down
4 changes: 2 additions & 2 deletions apps/sgx/src/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from os import path as osp
import sys

from tvm import relay
from tvm import relay, runtime
from tvm.relay import testing
import tvm
from tvm import te
Expand All @@ -49,7 +49,7 @@ def main():
with open(osp.join(build_dir, "graph.json"), "w") as f_graph_json:
f_graph_json.write(graph)
with open(osp.join(build_dir, "params.bin"), "wb") as f_params:
f_params.write(relay.save_param_dict(params))
f_params.write(runtime.save_param_dict(params))


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import onnx
import tvm
from tvm import relay
from tvm import relay, runtime


def _get_mod_and_params(model_file):
Expand Down Expand Up @@ -60,7 +60,7 @@ def build_graph_lib(model_file, opt_level):
f_graph.write(graph_json)

with open(os.path.join(out_dir, "graph.params"), "wb") as f_params:
f_params.write(relay.save_param_dict(params))
f_params.write(runtime.save_param_dict(params))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion docs/deploy/android.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ The code below will save the compilation output which is required on android tar
with open("deploy_graph.json", "w") as fo:
fo.write(graph.json())
with open("deploy_param.params", "wb") as fo:
fo.write(relay.save_param_dict(params))
fo.write(runtime.save_param_dict(params))

deploy_lib.so, deploy_graph.json, deploy_param.params will go to android target.

Expand Down
4 changes: 2 additions & 2 deletions golang/sample/gen_mobilenet_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import os
from tvm import relay, transform
from tvm import relay, transform, runtime
from tvm.contrib.download import download_testdata


Expand Down Expand Up @@ -94,4 +94,4 @@ def extract(path):
fo.write(graph)

with open("./mobilenet.params", "wb") as fo:
fo.write(relay.save_param_dict(params))
fo.write(runtime.save_param_dict(params))
6 changes: 1 addition & 5 deletions python/tvm/contrib/debugger/debug_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,4 @@ def save_tensors(params):
"""
_save_tensors = tvm.get_global_func("tvm.relay._save_param_dict")

args = []
for k, v in params.items():
args.append(k)
args.append(tvm.nd.array(v))
return _save_tensors(*args)
return _save_tensors(params)
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import tvm
from tvm import autotvm, auto_scheduler
from tvm import relay
from tvm import relay, runtime
from tvm.contrib import cc
from tvm.contrib import utils

Expand Down Expand Up @@ -282,7 +282,7 @@ def save_module(module_path, graph, lib, params, cross=None):

with open(temp.relpath(param_name), "wb") as params_file:
logger.debug("writing params to file to %s", params_file.name)
params_file.write(relay.save_param_dict(params))
params_file.write(runtime.save_param_dict(params))

logger.debug("saving module as tar file to %s", module_path)
with tarfile.open(module_path, "w") as tar:
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/driver/tvmc/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import tempfile

import numpy as np
import tvm
from tvm import rpc
from tvm.autotvm.measure import request_remote
from tvm.contrib import graph_runtime as runtime
from tvm.contrib.debugger import debug_runtime
from tvm.relay import load_param_dict

from . import common
from .common import TVMCException
Expand Down Expand Up @@ -163,9 +163,8 @@ def get_input_info(graph_str, params):

shape_dict = {}
dtype_dict = {}
# Use a special function to load the binary params back into a dict
load_arr = tvm.get_global_func("tvm.relay._load_param_dict")(params)
param_names = [v.name for v in load_arr]
params_dict = load_param_dict(params)
param_names = [k for (k, v) in params_dict.items()]
graph = json.loads(graph_str)
for node_id in graph["arg_nodes"]:
node = graph["nodes"][node_id]
Expand Down
28 changes: 11 additions & 17 deletions python/tvm/relay/param_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save parameter dicts."""
import tvm
import tvm._ffi


_save_param_dict = tvm._ffi.get_global_func("tvm.relay._save_param_dict")
_load_param_dict = tvm._ffi.get_global_func("tvm.relay._load_param_dict")
import tvm.runtime


def save_param_dict(params):
Expand All @@ -30,6 +25,9 @@ def save_param_dict(params):
The result binary bytes can be loaded by the
GraphModule with API "load_params".

.. deprecated:: 0.9.0
Use :py:func:`tvm.runtime.save_param_dict` instead.

Parameters
----------
params : dict of str to NDArray
Expand All @@ -47,21 +45,20 @@ def save_param_dict(params):
# set up the parameter dict
params = {"param0": arr0, "param1": arr1}
# save the parameters as byte array
param_bytes = tvm.relay.save_param_dict(params)
param_bytes = tvm.runtime.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
graph_runtime_mod.load_params(param_bytes)
tvm.runtime.load_param_dict(param_bytes)
"""
args = []
for k, v in params.items():
args.append(k)
args.append(tvm.nd.array(v))
return _save_param_dict(*args)
return tvm.runtime.save_param_dict(params)


def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.

.. deprecated:: 0.9.0
Use :py:func:`tvm.runtime.load_param_dict` instead.

Parameters
----------
param_bytes: bytearray
Expand All @@ -72,7 +69,4 @@ def load_param_dict(param_bytes):
params : dict of str to NDArray
The parameter dictionary.
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
load_arr = _load_param_dict(param_bytes)
return {v.name: v.array for v in load_arr}
return tvm.runtime.load_param_dict(param_bytes)
1 change: 1 addition & 0 deletions python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@
from .ndarray import vpi, rocm, ext_dev, micro_dev
from .module import load_module, enabled, system_lib
from .container import String
from .params import save_param_dict, load_param_dict
69 changes: 69 additions & 0 deletions python/tvm/runtime/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save and load parameter dicts."""
from . import _ffi_api, ndarray


def save_param_dict(params):
"""Save parameter dictionary to binary bytes.

The result binary bytes can be loaded by the
GraphModule with API "load_params".

Parameters
----------
params : dict of str to NDArray
The parameter dictionary.

Returns
-------
param_bytes: bytearray
Serialized parameters.

Examples
--------
.. code-block:: python

# set up the parameter dict
params = {"param0": arr0, "param1": arr1}
# save the parameters as byte array
param_bytes = tvm.runtime.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
tvm.runtime.load_param_dict(param_bytes)
"""
transformed = {k: ndarray.array(v) for (k, v) in params.items()}
return _ffi_api.SaveParams(transformed)


def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.

Parameters
----------
param_bytes: bytearray
Serialized parameters.

Returns
-------
params : dict of str to NDArray
The parameter dictionary.
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
return _ffi_api.LoadParams(param_bytes)
2 changes: 1 addition & 1 deletion rust/tvm-graph-rt/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ named! {
)
}

/// Loads a param dict saved using `relay.save_param_dict`.
/// Loads a param dict saved using `runtime.save_param_dict`.
pub fn load_param_dict(bytes: &[u8]) -> Result<HashMap<String, Tensor>, GraphFormatError> {
match parse_param_dict(bytes) {
Ok((remaining_bytes, param_dict)) => {
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-graph-rt/tests/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm import relay, runtime
from tvm.relay import testing

CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
Expand All @@ -47,7 +47,7 @@ def main():
with open(osp.join(CWD, "graph.json"), "w") as f_resnet:
f_resnet.write(graph)
with open(osp.join(CWD, "graph.params"), "wb") as f_params:
f_params.write(relay.save_param_dict(params))
f_params.write(runtime.save_param_dict(params))


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-graph-rt/tests/test_nn/src/build_test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import tvm
from tvm import te
from tvm import te, runtime
from tvm import relay
from tvm.relay import testing

Expand All @@ -49,7 +49,7 @@ def main():
f_resnet.write(graph)

with open(osp.join(out_dir, "graph.params"), "wb") as f_params:
f_params.write(relay.save_param_dict(params))
f_params.write(runtime.save_param_dict(params))


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm/examples/resnet/src/build_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import tvm
from tvm import te
from tvm import relay
from tvm import relay, runtime
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc
from PIL import Image
Expand Down Expand Up @@ -88,7 +88,7 @@ def build(target_dir):
fo.write(graph)

with open(osp.join(target_dir, "deploy_param.params"), "wb") as fo:
fo.write(relay.save_param_dict(params))
fo.write(runtime.save_param_dict(params))


def download_img_labels():
Expand Down
Loading