Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions python/tvm/contrib/tflite_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, module):
self._set_input = module["set_input"]
self._invoke = module["invoke"]
self._get_output = module["get_output"]
self._set_num_threads = module["set_num_threads"]

def set_input(self, index, value):
"""Set inputs to the module via kwargs
Expand Down Expand Up @@ -109,3 +110,12 @@ def get_output(self, index):
The output index
"""
return self._get_output(index)

def set_num_threads(self, num_threads):
"""Set the number of threads via kwargs
Parameters
----------
num_threads : int
The number of threads
"""
self._set_num_threads(num_threads)
8 changes: 8 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ void TFLiteRuntime::SetInput(int index, DLTensor* data_in) {
});
}

void TFLiteRuntime::SetNumThreads(int num_threads) { interpreter_->SetNumThreads(num_threads); }

NDArray TFLiteRuntime::GetOutput(int index) const {
TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[index]);
DataType dtype = TfLiteDType2TVMDType(output->type);
Expand Down Expand Up @@ -163,6 +165,12 @@ PackedFunc TFLiteRuntime::GetFunction(const std::string& name,
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); });
} else if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); });
} else if (name == "set_num_threads") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int num_threads = args[0];
CHECK_GE(num_threads, 1);
this->SetNumThreads(num_threads);
});
} else {
return PackedFunc();
}
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/contrib/tflite/tflite_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ class TFLiteRuntime : public ModuleNode {
* \return NDArray corresponding to given output node index.
*/
NDArray GetOutput(int index) const;
/*!
* \brief Set the number of threads available to the interpreter.
* \param num_threads The number of threads to be set.
*/
void SetNumThreads(int num_threads);

// Buffer backing the interpreter's model
std::unique_ptr<char[]> flatBuffersBuffer_;
Expand Down