Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
// check the consistency of output
CheckExternalDLTensor(data_ref, output_node_eid);

if (nodes_[output_node.node_id].op_type == "tvm_op" &&
nodes_[output_node.node_id].param.func_name == "__nop") {
const NodeEntry& input_node = nodes_[output_node.node_id].inputs[0];
output_node_eid = this->entry_id(input_node);
ICHECK_NE(node_output_dltensors_[output_node_eid].size(), 0);
for (DLTensor* t : node_output_dltensors_[output_node_eid]) {
t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
}
}

// Update the data pointer for output op
for (DLTensor* t : output_dltensors_[output_node_eid]) {
t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
Expand Down Expand Up @@ -540,6 +550,13 @@ void GraphExecutor::SetupOpExecs() {
input_dltensors_[input_eid].push_back(
const_cast<DLTensor*>(data_entry_[eid].operator->()));
}
} else {
const auto& arg_node = nodes_[inode.inputs[i].node_id];
if (arg_node.op_type == "tvm_op" && arg_node.param.func_name == "__nop") {
uint32_t arg_input_eid = this->entry_id(arg_node.inputs[0]);
input_dltensors_[arg_input_eid].push_back(
static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
}
}
// check if any model output is the input of the op
if (output_node_eids.count(input_eid) > 0) {
Expand All @@ -554,6 +571,11 @@ void GraphExecutor::SetupOpExecs() {
if (output_node_eids.count(output_eid) > 0) {
output_dltensors_[output_eid].push_back(
static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
} else {
// If the node is not an output, keep its output for record and support set_output_zero_copy
// of reshape __nop nodes.
node_output_dltensors_[output_eid].push_back(
static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/graph_executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ class TVM_DLL GraphExecutor : public ModuleNode {
std::vector<std::vector<DLTensor*>> output_dltensors_;
/*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */
std::vector<std::vector<DLTensor*>> both_output_opinput_dltensors_;
/*! \brief Used for quick node output DLTensor* lookup given a nop's input eid. */
std::unordered_map<int, std::vector<DLTensor*>> node_output_dltensors_;
/*! \brief Used for quick entry_id lookup given an storage_id. */
std::vector<std::vector<uint32_t>> sid_to_eid_;
/*! \brief Used for quick entry indexing. */
Expand Down
49 changes: 49 additions & 0 deletions tests/python/runtime/test_runtime_module_based_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,54 @@ def test_graph_module_zero_copy():
tvm.testing.assert_allclose(gm.get_output(0).numpy(), z_torch.numpy())


@tvm.testing.requires_llvm
def test_reshape_zero_copy():
    """set_output_zero_copy must work when the graph output is produced by a reshape.

    Reshape nodes compile to __nop ops in the graph executor; this builds
    reshape(lhs) @ rhs flattened to 1-D, computes a reference result through
    the ordinary set_input/get_output path, then reruns with zero-copy inputs
    and a zero-copy output buffer and checks both results agree.
    """
    from tvm.relay.frontend.common import infer_shape

    first_shape = (56, 224)
    second_shape = (112, 112)
    lhs_name = "infeats0"
    rhs_name = "infeats1"

    # Build the relay graph: reshape the first input, matmul, flatten result.
    lhs = relay.reshape(relay.var(lhs_name, shape=first_shape, dtype="float32"), second_shape)
    rhs = relay.var(rhs_name, shape=second_shape, dtype="float32")
    flattened = relay.reshape(relay.nn.matmul(lhs, rhs), (-1))
    module = tvm.IRModule.from_expr(
        relay.Function(relay.analysis.free_vars(flattened), flattened)
    )

    with tvm.transform.PassContext(opt_level=3):
        built = relay.build(module, target="llvm")
    runner = graph_executor.GraphModule(built["default"](tvm.cpu(0)))

    dev = tvm.device("llvm", 0)
    feed0 = tvm.nd.array(np.random.random(first_shape).astype(np.float32), device=dev)
    feed1 = tvm.nd.array(np.random.random(second_shape).astype(np.float32), device=dev)

    # Reference run through the regular (copying) input/output path.
    runner.set_input(lhs_name, feed0)
    runner.set_input(rhs_name, feed1)
    runner.run()
    reference = runner.get_output(0).numpy()

    # Zero-copy run: the runtime writes straight into `result_buf`.
    result_buf = tvm.nd.empty(infer_shape(flattened), device=dev)
    runner.set_input_zero_copy(lhs_name, feed0)
    runner.set_input_zero_copy(rhs_name, feed1)
    runner.set_output_zero_copy(0, result_buf)
    runner.run()

    tvm.testing.assert_allclose(reference, result_buf.numpy())


if __name__ == "__main__":
test_legacy_compatibility()
test_cpu()
Expand All @@ -747,3 +795,4 @@ def test_graph_module_zero_copy():
test_cpu_get_graph_params_run()
test_cpu_get_graph_params_compare()
test_graph_module_zero_copy()
test_reshape_zero_copy()