Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,28 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
using tir::make_const;
ICHECK(data_dependants_.size());
ICHECK(op->is_scalar());
bool data_dependant = data_dependants_.back();
if (!op->is_scalar()) {
// This is a constant weight, extract the shape of the weight tensor.
// This can not be data dependent.
CHECK(!data_dependant);
auto ttype = op->checked_type().as<TensorTypeNode>();
int ndim = static_cast<int>(ttype->shape.size());
Array<PrimExpr> out_shape{ndim};
te::Tensor value = tvm::te::compute(
out_shape,
[&](const Array<tvm::tir::Var>& indices) {
auto idx = indices[0];
PrimExpr ret = make_const(DataType::Int(64), 0);
for (int i = 0; i < ndim; i++) {
ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
}
return ret;
},
"shape_const", topi::kBroadcast);
scalars_.push_back(value);
return {value};
}
if (data_dependant) {
void* data = op->data->data;
DataType dtype = DataType(op->data->dtype);
Expand Down
25 changes: 25 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,5 +770,30 @@ def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1)))


def test_constant_shape_with_external_codegen():
mod = tvm.IRModule()
shape = (relay.Any(), 25)
dtype = "float32"

# external function
x = relay.var("x", shape=shape, dtype=dtype)
weight = relay.const(np.random.rand(5, 25).astype("float32"), dtype="float32")
out = relay.nn.dense(x, weight)
f1 = relay.Function([x], out)
f1 = f1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
f1 = f1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
f1 = f1.with_attr("Compiler", "a")
glb_f1 = relay.GlobalVar("f1")
mod[glb_f1] = f1
mod = relay.transform.InferType()(mod)

# Main function
x = relay.var("x", shape=shape, dtype=dtype)
mod["main"] = relay.Function([x], glb_f1(x))
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, target="llvm")
assert "shape_func" in opt_mod.astext(False)


if __name__ == "__main__":
pytest.main([__file__])