Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/runtime/relax_vm/ndarray_cache_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,19 @@ class ParamModuleNode : public runtime::ModuleNode {
return Module(n);
}

static Module CreateByName(const Array<String>& names) {
auto n = make_object<ParamModuleNode>();
n->params_ = GetParamByName(names);
return Module(n);
}

private:
Array<NDArray> params_;
};

TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache").set_body_typed(ParamModuleNode::Create);
TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache_by_name")
.set_body_typed(ParamModuleNode::CreateByName);
TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams);
TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache_by_name")
.set_body_typed(ParamModuleNode::GetParamByName);
Expand Down
17 changes: 17 additions & 0 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class RuntimeContext implements Disposable {
arrayCacheClear: PackedFunc;
arrayDecodeStorage: PackedFunc;
paramModuleFromCache: PackedFunc;
paramModuleFromCacheByName: PackedFunc;
makeShapeTuple: PackedFunc;
ndarrayCreateView: PackedFunc;
sampleTopPFromLogits: PackedFunc;
Expand All @@ -173,6 +174,7 @@ class RuntimeContext implements Disposable {
this.arrayCacheClear = getGlobalFunc("vm.builtin.ndarray_cache.clear");
this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache");
this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name");
this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
Expand All @@ -194,6 +196,7 @@ class RuntimeContext implements Disposable {
this.arrayCacheClear.dispose();
this.arrayDecodeStorage.dispose();
this.paramModuleFromCache.dispose();
this.paramModuleFromCacheByName.dispose();
this.makeShapeTuple.dispose();
this.ndarrayCreateView.dispose();
this.sampleTopPFromLogits.dispose();
Expand Down Expand Up @@ -1396,6 +1399,20 @@ export class Instance implements Disposable {
prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")();
}

/**
* Get parameters based on parameter names provided
*
* @param paramNames Names of the parameters.
* @returns Parameters read.
*/
getParamsFromCacheByName(paramNames: Array<string>): TVMObject {
// Convert Array<string> to Array<TVMString>
const paramNamesTVM: TVMString[] = [];
paramNames.forEach(paramName => { paramNamesTVM.push(this.makeString(paramName)) });
return (this.ctx.paramModuleFromCacheByName(
this.makeTVMArray(paramNamesTVM)) as Module).getFunction("get_params")();
}

/**
* Get NDArray from cache.
* @param name The name of array.
Expand Down