From ce2e0ce8ceb1ac9558183eea8c9833f57d1d3145 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 13 Nov 2020 02:31:40 +0000
Subject: [PATCH 1/2] [ShapeFunc] Handle weights in shape func

---
 src/relay/backend/compile_engine.cc | 73 +++++++++++++++++++----------
 tests/python/relay/test_vm.py       | 25 ++++++++++
 2 files changed, 72 insertions(+), 26 deletions(-)

diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 767cb6f644de..63477989f324 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -420,38 +420,59 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
     using tir::make_const;
     ICHECK(data_dependants_.size());
-    ICHECK(op->is_scalar());
     bool data_dependant = data_dependants_.back();
-    if (data_dependant) {
-      void* data = op->data->data;
-      DataType dtype = DataType(op->data->dtype);
-      auto value = tvm::te::compute(
-          {},
-          [&](const Array<tvm::tir::Var>&) {
-            if (dtype == DataType::Int(32)) {
-              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
-            } else if (dtype == DataType::Int(64)) {
-              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
-            } else if (dtype == DataType::Float(32)) {
-              return make_const(dtype, static_cast<const float*>(data)[0]);
-            } else if (dtype == DataType::Float(64)) {
-              return make_const(dtype, static_cast<const double*>(data)[0]);
-            } else if (dtype == DataType::Bool()) {
-              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
-            } else {
-              LOG(FATAL) << "not handled";
-              return tvm::PrimExpr();
+    if (!op->is_scalar()) {
+      // This is a constant weight; extract the shape of the weight tensor.
+      // This cannot be data dependent.
+      CHECK(!data_dependant);
+      auto ttype = op->checked_type().as<TensorTypeNode>();
+      int ndim = static_cast<int>(ttype->shape.size());
+      Array<PrimExpr> out_shape{ndim};
+      te::Tensor value = tvm::te::compute(
+          out_shape,
+          [&](const Array<tvm::tir::Var>& indices) {
+            auto idx = indices[0];
+            PrimExpr ret = make_const(DataType::Int(64), 0);
+            for (int i = 0; i < ndim; i++) {
+              ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
             }
+            return ret;
           },
-          "data_const", topi::kBroadcast);
-      scalars_.push_back(value);
-      return {value};
-    } else {
-      auto value = tvm::te::compute(
-          {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
           "shape_const", topi::kBroadcast);
       scalars_.push_back(value);
       return {value};
+    } else {
+      if (data_dependant) {
+        void* data = op->data->data;
+        DataType dtype = DataType(op->data->dtype);
+        auto value = tvm::te::compute(
+            {},
+            [&](const Array<tvm::tir::Var>&) {
+              if (dtype == DataType::Int(32)) {
+                return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+              } else if (dtype == DataType::Int(64)) {
+                return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+              } else if (dtype == DataType::Float(32)) {
+                return make_const(dtype, static_cast<const float*>(data)[0]);
+              } else if (dtype == DataType::Float(64)) {
+                return make_const(dtype, static_cast<const double*>(data)[0]);
+              } else if (dtype == DataType::Bool()) {
+                return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+              } else {
+                LOG(FATAL) << "not handled";
+                return tvm::PrimExpr();
+              }
+            },
+            "data_const", topi::kBroadcast);
+        scalars_.push_back(value);
+        return {value};
+      } else {
+        auto value = tvm::te::compute(
+            {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
+            "shape_const", topi::kBroadcast);
+        scalars_.push_back(value);
+        return {value};
+      }
     }
   }
 
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 92d6e8e55db4..6958010176e3 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -770,5 +770,30 @@ def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
     tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1)))
 
 
+def test_constant_shape_with_external_codegen():
+    mod = tvm.IRModule()
+    shape = (relay.Any(), 25)
+    dtype = "float32"
+
+    # external function
+    x = relay.var("x", shape=shape, dtype=dtype)
+    weight = relay.const(np.random.rand(5, 25).astype("float32"), dtype="float32")
+    out = relay.nn.dense(x, weight)
+    f1 = relay.Function([x], out)
+    f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    f1 = f1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+    f1 = f1.with_attr("Compiler", "a")
+    glb_f1 = relay.GlobalVar("f1")
+    mod[glb_f1] = f1
+    mod = relay.transform.InferType()(mod)
+
+    # Main function
+    x = relay.var("x", shape=shape, dtype=dtype)
+    mod["main"] = relay.Function([x], glb_f1(x))
+    comp = relay.vm.VMCompiler()
+    opt_mod, _ = comp.optimize(mod, target="llvm")
+    assert "shape_func" in opt_mod.astext(False)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From 20b844ba48f69f3643b2eaf4e0d4a536b7581f94 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Sat, 14 Nov 2020 01:08:18 +0000
Subject: [PATCH 2/2] Comments

---
 src/relay/backend/compile_engine.cc | 61 ++++++++++++++---------
 1 file changed, 30 insertions(+), 31 deletions(-)

diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 63477989f324..c8327de94232 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -441,38 +441,37 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
           "shape_const", topi::kBroadcast);
       scalars_.push_back(value);
       return {value};
+    }
+    if (data_dependant) {
+      void* data = op->data->data;
+      DataType dtype = DataType(op->data->dtype);
+      auto value = tvm::te::compute(
+          {},
+          [&](const Array<tvm::tir::Var>&) {
+            if (dtype == DataType::Int(32)) {
+              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+            } else if (dtype == DataType::Int(64)) {
+              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+            } else if (dtype == DataType::Float(32)) {
+              return make_const(dtype, static_cast<const float*>(data)[0]);
+            } else if (dtype == DataType::Float(64)) {
+              return make_const(dtype, static_cast<const double*>(data)[0]);
+            } else if (dtype == DataType::Bool()) {
+              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+            } else {
+              LOG(FATAL) << "not handled";
+              return tvm::PrimExpr();
+            }
+          },
+          "data_const", topi::kBroadcast);
+      scalars_.push_back(value);
+      return {value};
     } else {
-      if (data_dependant) {
-        void* data = op->data->data;
-        DataType dtype = DataType(op->data->dtype);
-        auto value = tvm::te::compute(
-            {},
-            [&](const Array<tvm::tir::Var>&) {
-              if (dtype == DataType::Int(32)) {
-                return make_const(dtype, static_cast<const int32_t*>(data)[0]);
-              } else if (dtype == DataType::Int(64)) {
-                return make_const(dtype, static_cast<const int64_t*>(data)[0]);
-              } else if (dtype == DataType::Float(32)) {
-                return make_const(dtype, static_cast<const float*>(data)[0]);
-              } else if (dtype == DataType::Float(64)) {
-                return make_const(dtype, static_cast<const double*>(data)[0]);
-              } else if (dtype == DataType::Bool()) {
-                return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
-              } else {
-                LOG(FATAL) << "not handled";
-                return tvm::PrimExpr();
-              }
-            },
-            "data_const", topi::kBroadcast);
-        scalars_.push_back(value);
-        return {value};
-      } else {
-        auto value = tvm::te::compute(
-            {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
-            "shape_const", topi::kBroadcast);
-        scalars_.push_back(value);
-        return {value};
-      }
+      auto value = tvm::te::compute(
+          {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
+          "shape_const", topi::kBroadcast);
+      scalars_.push_back(value);
+      return {value};
     }
   }
 
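---
Reviewer note (not part of the patches): the non-scalar branch added in PATCH
1/2 lowers a constant's static shape into a rank-1 int64 tensor by folding an
if_then_else chain over the dimensions. The standalone Python sketch below
mirrors that te::compute pattern so it can be tried outside the compile
engine; `constant_shape_func` is an illustrative name, not a TVM API, and the
snippet assumes a TVM build of roughly the same era as these patches (v0.7+).

    import numpy as np
    import tvm
    from tvm import te

    def constant_shape_func(weight_shape):
        # Mirrors MakeShapeFunc::VisitExpr_(ConstantNode*): emit a rank-1
        # int64 tensor whose i-th element is weight_shape[i].
        ndim = len(weight_shape)

        def select_dim(i):
            # Same if_then_else chain as the C++ lambda: start from 0 and
            # replace the value with weight_shape[d] when i == d.
            ret = tvm.tir.IntImm("int64", 0)
            for d in range(ndim):
                ret = tvm.tir.if_then_else(
                    i == d, tvm.tir.IntImm("int64", weight_shape[d]), ret
                )
            return ret

        return te.compute((ndim,), select_dim, name="shape_const")

    # Build and run the shape function for a (5, 25) weight, matching the
    # dense weight used in test_constant_shape_with_external_codegen.
    shape_tensor = constant_shape_func((5, 25))
    s = te.create_schedule(shape_tensor.op)
    f = tvm.build(s, [shape_tensor], "llvm")
    out = tvm.nd.array(np.zeros(2, dtype="int64"))
    f(out)
    assert list(out.asnumpy()) == [5, 25]

The if_then_else chain trades a data-dependent load for a short select chain
over a compile-time-known rank, which keeps the shape function free of any
dependence on the constant's actual data, as the CHECK(!data_dependant)
in the patch requires.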