From f3ff0bb0f34fa647c8c3a3b224677005dec9b75a Mon Sep 17 00:00:00 2001 From: Charlie Ruan Date: Sat, 2 Dec 2023 18:10:58 -0500 Subject: [PATCH] Get params from cache by name --- src/runtime/relax_vm/ndarray_cache_support.cc | 8 ++++++++ web/src/runtime.ts | 17 +++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index ea90255fbaf2..25f1fd282e41 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -327,11 +327,19 @@ class ParamModuleNode : public runtime::ModuleNode { return Module(n); } + static Module CreateByName(const Array& names) { + auto n = make_object(); + n->params_ = GetParamByName(names); + return Module(n); + } + private: Array params_; }; TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache").set_body_typed(ParamModuleNode::Create); +TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name") + .set_body_typed(ParamModuleNode::CreateByName); TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams); TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name") .set_body_typed(ParamModuleNode::GetParamByName); diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 453d6240f3f4..f842b2723f81 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -152,6 +152,7 @@ class RuntimeContext implements Disposable { arrayCacheClear: PackedFunc; arrayDecodeStorage: PackedFunc; paramModuleFromCache: PackedFunc; + paramModuleFromCacheByName: PackedFunc; makeShapeTuple: PackedFunc; ndarrayCreateView: PackedFunc; sampleTopPFromLogits: PackedFunc; @@ -173,6 +174,7 @@ class RuntimeContext implements Disposable { this.arrayCacheClear = getGlobalFunc("vm.builtin.ndarray_cache.clear"); this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); + this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name"); this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); @@ -194,6 +196,7 @@ class RuntimeContext implements Disposable { this.arrayCacheClear.dispose(); this.arrayDecodeStorage.dispose(); this.paramModuleFromCache.dispose(); + this.paramModuleFromCacheByName.dispose(); this.makeShapeTuple.dispose(); this.ndarrayCreateView.dispose(); this.sampleTopPFromLogits.dispose(); @@ -1396,6 +1399,20 @@ export class Instance implements Disposable { prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")(); } + /** + * Get parameters based on parameter names provided + * + * @param paramNames Names of the parameters. + * @returns Parameters read. + */ + getParamsFromCacheByName(paramNames: Array): TVMObject { + // Convert Array to Array + const paramNamesTVM: TVMString[] = []; + paramNames.forEach(paramName => { paramNamesTVM.push(this.makeString(paramName)) }); + return (this.ctx.paramModuleFromCacheByName( + this.makeTVMArray(paramNamesTVM)) as Module).getFunction("get_params")(); + } + /** * Get NDArray from cache. * @param name The name of array.