From acbaa1f4862265f13b1b4763805e91b91e7034af Mon Sep 17 00:00:00 2001 From: sung Date: Sun, 12 Mar 2023 13:19:53 -0700 Subject: [PATCH 01/12] FEAT: Support data-dependent operation of reshape --- python/tvm/relax/op/base.py | 14 +++ .../transform/legalize_ops/manipulate.py | 6 +- python/tvm/script/ir_builder/relax/ir.py | 2 + src/relax/backend/vm/vm_builtin_lower.cc | 45 ++++++++- src/relax/op/op.cc | 26 +++++ src/relax/op/tensor/manipulate.cc | 10 ++ src/runtime/relax_vm/builtin.cc | 34 +++++++ tests/python/relax/test_relax_operators.py | 18 +++- .../test_transform_legalize_ops_manipulate.py | 95 +++++++++++++++++++ 9 files changed, 245 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index aef0e731db51..19cc1bad8afb 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -402,3 +402,17 @@ def shape_of(expr: Expr) -> Expr: A relax Call, which gets the shape of the input """ return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member + + +def tensor_to_shape(expr: Expr) -> Expr: + """Convert tensor to shape expr. + Parameters + ---------- + expr : Expr + The input Expr + Returns + ------- + result : ShapeExpr + ShapeExpr for the tensor values + """ + return _ffi_api.tensor_to_shape(expr) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index e7cae1af3481..144ef04748c5 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -23,7 +23,7 @@ from tvm import topi, tir, relax, te from tvm.tir.expr import IntImm from ...block_builder import BlockBuilder -from ...expr import Call, Expr, Var, Tuple, TupleGetItem +from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr from .common import TEFunc, LegalizeFunc, register_legalize @@ -32,6 +32,10 @@ def _reshape( ) -> LegalizeFunc: def reshape_call_te(bb: BlockBuilder, call: Call): tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1] + # If target shape is Var, pass its bound expr only when it is ShapeExpr + if isinstance(tgt_shape, Var): + tgt_shape = bb.lookup_binding(tgt_shape) + assert isinstance(tgt_shape, ShapeExpr) return bb.call_te(te_func, call.args[0], tgt_shape, primfunc_name_hint=primfunc_name) return reshape_call_te diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 32d6083e8aee..ae0918a0820e 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -97,6 +97,7 @@ prod, repeat, reshape, + tensor_to_shape, round, shape_of, std, @@ -612,6 +613,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "prod", "repeat", "reshape", + "tensor_to_shape", "round", "shape", "shape_of", diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc index 00d8512dc6af..cc0d654f307b 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -51,6 +51,8 @@ class VMBuiltinLowerMutator : public ExprMutator { if (call->op == call_tir_dyn_op_) { return CallTIRDyn(call); + } else if (call->op == tensor_to_shape_op_) { + return TensorToShape(call); } else if (call->op == reshape_op_) { return Reshape(call); } else if (call->op == shape_of_op_) { @@ -129,9 +131,45 @@ class VMBuiltinLowerMutator : public ExprMutator { Expr Reshape(const Call& call_node) { 
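+    // The shape argument is either a ShapeExpr or a Var bound to a ShapeExpr
+    // (e.g. an R.shape(...) binding produced by shape computation).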
ICHECK(call_node->args.size() == 2); ICHECK(call_node->struct_info_.defined()); - CHECK(call_node->args[1]->IsInstance()) - << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr"; - return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + auto arg = call_node->args[1]; + CHECK(arg->IsInstance() || arg->IsInstance()) + << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr or VarNode bound " + "to a ShapeExpr"; + + if (arg->IsInstance()) { + return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + + ICHECK(arg->IsInstance()); + Optional _bound_val = LookupBinding(Downcast(call_node->args[1])); + ICHECK(_bound_val.defined()); + Expr bound_val = _bound_val.value(); + CHECK(bound_val->IsInstance()) + << "VMBuiltinLower expects bound value to be a ShapeExpr"; + return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(), + {GetStructInfo(call_node)}); + } + + ShapeExpr TensorToShape(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + ICHECK(call_node->struct_info_.defined()); + Expr expr = call_node->args[0]; + const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); + ICHECK(sinfo); + // call builtin function that converts tensor to shape tuple + Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, + {GetRef(sinfo)})); + + // define symbolic variables + Array shape_var; + for (int i = 0; i < sinfo->ndim; i++) { + shape_var.push_back(tir::Var("x", DataType::Int(64))); + } + + // bind symbolic variables to the shape tuple + relax::Var var("y", ShapeStructInfo(shape_var)); + builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var))); + return ShapeExpr(shape_var); } Expr ShapeOf(const Call& call_node) { @@ -180,6 +218,7 @@ class VMBuiltinLowerMutator : public ExprMutator { const StructInfo void_sinfo_ = TupleStructInfo(Array({})); // object to pattern match. 
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); + const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index a60304039475..49df881dcb8b 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -315,6 +315,32 @@ Expr MakeShapeOf(Expr expr) { TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf); +// tensor_to_shape + +StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) { + ICHECK(call->args.size() == 1); + ICHECK(call->args[0]->struct_info_.defined()); + const auto* tsinfo = GetStructInfoAs(call->args[0]); + ICHECK(tsinfo && tsinfo->shape.defined()); + ShapeExpr shape_expr = Downcast(tsinfo->shape.value()); + ICHECK(shape_expr->values.size() == 1); + const IntImmNode* ndim = shape_expr->values[0].as(); + ICHECK(ndim); + return ShapeStructInfo(ndim->value); +} + +RELAY_REGISTER_OP("relax.tensor_to_shape") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input expression") + .set_attr("FInferStructInfo", ReturnTensorToShapeStructInfo); + +Expr MakeTensorToShape(Expr expr) { + static const Op& op = Op::Get("relax.tensor_to_shape"); + return Call(op, {expr}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape); + // alloc_tensor StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) { diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index dbeb6f8d5bc7..7ea596dbf0c1 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -751,6 +751,16 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { << new_shape_prod); } } + + if (call->args[1]->IsInstance()) { + auto ssinfo = GetStructInfoAs(call->args[1]); + ICHECK(ssinfo); + if (ssinfo->values.defined()) { + return TensorStructInfo(ShapeExpr(ssinfo->values.value()), data_sinfo->dtype); + } else { + return TensorStructInfo(data_sinfo->dtype, ssinfo->ndim); + } + } return TensorStructInfo(call->args[1], data_sinfo->dtype); } diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 15a4f8702b03..5a7c1d662055 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -380,6 +380,40 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetVal *rv = arr; }); +TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) { + NDArray arr = data; + if (data->device.device_type != kDLCPU) { + arr = data.CopyTo(DLDevice{kDLCPU, 0}); + } + + ICHECK_EQ(arr->ndim, 1); + ICHECK_EQ(arr->dtype.code, kDLInt); + + std::vector out_shape; + for (int i = 0; i < arr.Shape()[0]; ++i) { + int64_t result; + switch (arr->dtype.bits) { + case 16: { + result = reinterpret_cast(arr->data)[i]; + break; + } + case 32: { + result = reinterpret_cast(arr->data)[i]; + break; + } + case 64: { + result = reinterpret_cast(arr->data)[i]; + break; + } + default: + LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype); + throw; + } + out_shape.push_back(result); + } + return ShapeTuple(out_shape); +}); + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index f197eaa9ab7c..69c10853ccd6 100644 --- 
a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -193,5 +193,21 @@ def test_op_shape_of(): assert constrained_shape == tvm.runtime.ShapeTuple([1]) +@tvm.script.ir_module +class TensorToShapeTest: + @R.function + def run_tensor_to_shape(t: R.Tensor(ndim=1, dtype="int64")) -> R.Shape((1, 2, 3)): + gv: R.Shape(ndim=3) = R.tensor_to_shape(t) + return gv + + +def test_op_tensor_to_shape(): + out_shape = run_cpu( + TensorToShapeTest, "run_tensor_to_shape", tvm.nd.array(np.array([1, 2, 3]).astype("int64")) + ) + assert out_shape == tvm.runtime.ShapeTuple([1, 2, 3]) + + if __name__ == "__main__": - tvm.testing.main() + test_op_tensor_to_shape() + # tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index b50ba91089a3..c9c979424ea1 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -498,6 +498,53 @@ def reshape(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64 mod = LegalizeOps()(Reshape) tvm.ir.assert_structural_equal(mod, Expected) + # ShapeExpr might be produced by shape computation + @tvm.script.ir_module + class Reshape2: + @R.function + def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): + lv: R.Shape((8, 3)) = R.shape((8, 3)) + gv: R.Tensor((8, 3), "float32") = R.reshape(x, lv) + return gv + + # After lowering, redundant var might be removed by later dead code elimination + @tvm.script.ir_module + class Expected2: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), + T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(8), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + rxplaceholder[ + T.int64(0), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) // T.int64(12), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) // T.int64(4), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(4), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1]) + T_reshape[v_ax0, v_ax1] = rxplaceholder[ + T.int64(0), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) // T.int64(12), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) // T.int64(4), + (v_ax0 * T.int64(3) + v_ax1) % T.int64(4), + ] + + @R.function + def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((8, 3), dtype="float32"): + lv: R.Shape((8, 3)) = R.shape((8, 3)) + gv = R.call_tir(Expected2.reshape, (x,), out_sinfo=R.Tensor((8, 3), dtype="float32")) + return gv + + mod2 = LegalizeOps()(Reshape2) + tvm.ir.assert_structural_equal(mod2, Expected2) + def test_reshape_symbolic(): # fmt: off @@ -537,6 +584,54 @@ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): mod = LegalizeOps()(Reshape) tvm.ir.assert_structural_equal(mod, Expected) + # ShapeExpr might be produced by shape computation + @tvm.script.ir_module + class Reshape2: + @R.function + def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.int64() + b = T.int64() + lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2)) + gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, lv) + return gv + + # After lowering, redundant var might be removed by later dead code elimination + @tvm.script.ir_module + class Expected2: + @R.function + def main(x: R.Tensor(("a", 
"b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"): + a = T.int64() + b = T.int64() + lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2)) + gv = R.call_tir(Expected2.reshape, (x,), R.Tensor(((a // 2), (b * 2)), dtype="float32")) + return gv + + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + a = T.int64() + b = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") + T_reshape = T.match_buffer( + var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32" + ) + for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads( + rxplaceholder[ + (ax0 * (b * T.int64(2)) + ax1) // b % a, + (ax0 * (b * T.int64(2)) + ax1) % b, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b + ] + + mod2 = LegalizeOps()(Reshape2) + tvm.ir.assert_structural_equal(mod2, Expected2) + def test_split_by_indices(): # fmt: off From 48090808ded7fdadffd17ca9e7bfbec2d8ec4acc Mon Sep 17 00:00:00 2001 From: sung Date: Sun, 12 Mar 2023 13:20:33 -0700 Subject: [PATCH 02/12] FEAT: Support constant folding with data-dependent reshape --- src/relax/transform/fold_constant.cc | 61 ++++++++++++++++--- .../relax/test_transform_fold_constant.py | 37 +++++++++++ 2 files changed, 90 insertions(+), 8 deletions(-) diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 622dd9ad09b7..92125d7f2b8c 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -217,14 +217,59 @@ class ConstantFolder : public ExprMutator { return VisitCallTIR(post_call).value_or(post_call); } - // If we are in a dataflow block, we can fold ops by lowering them to call_tir. - if (builder_->CurrentBlockIsDataFlow() && legalize_map.count(op)) { - // Get the legalized expression - Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call)); - // If the legalized expression is call_tir, try to fold it. - const CallNode* call = legalized_expr.as(); - if (call && call->op.same_as(call_tir_op)) { - return VisitCallTIR(GetRef(call)).value_or(post_call); + // Special logic to fold ShapeExpr between operators + // e.g., + // + // lv: R.Shape([16, 16]) = R.shape([16, 16]) + // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, lv) + // + // gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16, 16])) + // + Array new_args; + for (auto arg : post_call->args) { + if (arg->IsInstance()) { + Optional val = LookupBinding(Downcast(arg)); + if (val.defined() && val.value()->IsInstance()) { + new_args.push_back(val.value()); + continue; + } + } + new_args.push_back(arg); + } + post_call = + Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span); + + // If we are in a dataflow block, we can fold ops. + if (builder_->CurrentBlockIsDataFlow()) { + // Check if we can them to call_tir + if (legalize_map.count(op)) { + // Get the legalized expression + Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call)); + // If the legalized expression is call_tir, try to fold it. 
+ const CallNode* call = legalized_expr.as(); + if (call && call->op.same_as(call_tir_op)) { + return VisitCallTIR(GetRef(call)).value_or(post_call); + } + } else if (op->name == "relax.tensor_to_shape") { + // Special handling for builtin op "relax.tensor_to_shape" + // If its input is constant, we can access its value and create ShapeExpr + ICHECK_EQ(post_call->args.size(), 1); + Expr arg = post_call->args[0]; + if (arg->IsInstance()) { + Constant constant = Downcast(arg); + runtime::NDArray ndarray = constant->data; + ICHECK_EQ(ndarray->device.device_type, kDLCPU); + ICHECK(ndarray->strides == nullptr); + ICHECK_EQ(ndarray->byte_offset, 0); + ICHECK_EQ(ndarray->ndim, 1); + const int64_t* data = static_cast(ndarray->data); + int64_t num_elems = ndarray->shape[0]; + Array shape_values; + for (int64_t i = 0; i < num_elems; i++) { + shape_values.push_back(IntImm(DataType::Int(64), data[i])); + } + return ShapeExpr(shape_values); + } } } diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index 5bf2d3d9ab25..ebd4348b6488 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -349,6 +349,43 @@ def before(c0: R.Tensor((16, 16), "float32")): tvm.ir.assert_structural_equal(after, before) +def test_fold_multiple_relax_ops_with_data_dependent_reshape(): + @tvm.script.ir_module + class Module: + @R.function + def before( + data: R.Tensor((256,), "float32"), + c0: R.Tensor((2,), "int64"), + c1: R.Tensor((2,), "int64"), + ): + with R.dataflow(): + lv0 = R.add(c0, c0) + target_shape = R.multiply(lv0, c1) + lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape) + gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv2) + R.output(gv) + return gv + + @R.function + def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16), dtype="float32"): + R.func_attr({"global_symbol": "main"}) + with R.dataflow(): + gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data, R.shape([16, 16])) + R.output(gv) + return gv + + c0_np = [8, 8] + c1_np = [1, 1] + before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np}) + assert relax.analysis.well_formed(before) + + c2_np = np.multiply(np.add(c0_np, c0_np), c1_np) + expected = gen_mod(Module, "expected", {"c2": c2_np}) + + after = relax.transform.FoldConstant()(before) + tvm.ir.assert_structural_equal(after, expected) + + def test_unsupported_fold_ops_legalized_to_multiple_calls(): @tvm.script.ir_module class Module: From 1a7ae44b9755471c8e70b8fed5acee6b6860d859 Mon Sep 17 00:00:00 2001 From: sung Date: Sun, 12 Mar 2023 19:57:53 -0700 Subject: [PATCH 03/12] fix --- src/relax/op/tensor/manipulate.cc | 9 --------- tests/python/relax/test_relax_operators.py | 3 +-- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 7ea596dbf0c1..f6358f0b0975 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -752,15 +752,6 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { } } - if (call->args[1]->IsInstance()) { - auto ssinfo = GetStructInfoAs(call->args[1]); - ICHECK(ssinfo); - if (ssinfo->values.defined()) { - return TensorStructInfo(ShapeExpr(ssinfo->values.value()), data_sinfo->dtype); - } else { - return TensorStructInfo(data_sinfo->dtype, ssinfo->ndim); - } - } return TensorStructInfo(call->args[1], data_sinfo->dtype); } diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py index 69c10853ccd6..a0bc664d674d 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -209,5 +209,4 @@ def test_op_tensor_to_shape(): if __name__ == "__main__": - test_op_tensor_to_shape() - # tvm.testing.main() + tvm.testing.main() From 302e5259cf8a8c2f7ed5982a36702a7da776f0b0 Mon Sep 17 00:00:00 2001 From: sung Date: Sun, 12 Mar 2023 21:31:06 -0700 Subject: [PATCH 04/12] remove empty line --- src/relax/op/tensor/manipulate.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f6358f0b0975..dbeb6f8d5bc7 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -751,7 +751,6 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { << new_shape_prod); } } - return TensorStructInfo(call->args[1], data_sinfo->dtype); } From 36cc64afca3bd2645eb17b3c15971e2a490258fc Mon Sep 17 00:00:00 2001 From: sung Date: Tue, 14 Mar 2023 17:50:17 -0700 Subject: [PATCH 05/12] reflect feedback --- python/tvm/relax/op/base.py | 4 +- .../test_transform_legalize_ops_manipulate.py | 50 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 19cc1bad8afb..becd3f2a0f57 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -412,7 +412,7 @@ def tensor_to_shape(expr: Expr) -> Expr: The input Expr Returns ------- - result : ShapeExpr - ShapeExpr for the tensor values + result : Expr + A relax Call, which transforms the tensor values to the shape """ return _ffi_api.tensor_to_shape(expr) # type: ignore # pylint: disable=no-member diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index c9c979424ea1..a45c9b49bc41 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -632,6 +632,56 @@ def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): mod2 = LegalizeOps()(Reshape2) tvm.ir.assert_structural_equal(mod2, Expected2) + # ShapeExpr might be produced by shape computation + @I.ir_module + class Reshape3: + @R.function + def main(x: R.Tensor((10, "b"), "float32")) -> R.Tensor((5, "b * 2"), "float32"): + a = T.int64() + b = T.int64() + lv: R.Shape((5, b * 2)) = R.shape((5, b * 2)) + gv: R.Tensor((5, b * 2), "float32") = R.reshape(x, lv) + return gv + + # After lowering, redundant var might be removed by later dead code elimination + @I.ir_module + class Expected3: + @T.prim_func + def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): + T.func_attr({"tir.noalias": True}) + b = T.int64() + rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(10), b)) + T_reshape = T.match_buffer(var_T_reshape, (T.int64(5), b * T.int64(2))) + # with T.block("root"): + for ax0, ax1 in T.grid(T.int64(5), b * T.int64(2)): + with T.block("T_reshape"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads( + rxplaceholder[ + (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10), + (v_ax0 * (b * T.int64(2)) + v_ax1) % b, + ] + ) + T.writes(T_reshape[v_ax0, v_ax1]) + T_reshape[v_ax0, v_ax1] = rxplaceholder[ + (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10), + (v_ax0 * (b * T.int64(2)) + v_ax1) % b, + ] + + @R.function + def main( + x: R.Tensor((10, "b"), dtype="float32") + ) -> R.Tensor((5, "b * 2"), 
dtype="float32"): + b = T.int64() + lv: R.Shape([5, b * 2]) = R.shape([5, b * 2]) + gv = R.call_tir( + Expected3.reshape, (x,), out_sinfo=R.Tensor((5, b * 2), dtype="float32") + ) + return gv + + mod3 = LegalizeOps()(Reshape3) + tvm.ir.assert_structural_equal(mod3, Expected3) + def test_split_by_indices(): # fmt: off From bdfaa6b780381d8abc9b47835a005ae6969a89c8 Mon Sep 17 00:00:00 2001 From: sung Date: Sat, 18 Mar 2023 18:39:52 -0700 Subject: [PATCH 06/12] Lift the lowering of tensor_to_shape from builtin to DecomposeCompositeOps pass --- include/tvm/relax/transform.h | 5 +- python/tvm/relax/transform/transform.py | 11 ++-- src/relax/backend/vm/vm_builtin_lower.cc | 25 --------- src/relax/op/tensor/manipulate.cc | 8 ++- ...nference.cc => decompose_composite_ops.cc} | 53 ++++++++++++++++--- ...test_transform_decompose_composite_ops.py} | 2 +- .../test_transform_legalize_ops_manipulate.py | 43 +++++++++++++++ 7 files changed, 106 insertions(+), 41 deletions(-) rename src/relax/transform/{simplify_norm_inference.cc => decompose_composite_ops.cc} (71%) rename tests/python/relax/{test_transform_simpilify_norm_inference.py => test_transform_decompose_composite_ops.py} (98%) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5a21f76b0b4e..e4fae591f3aa 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -356,12 +356,12 @@ TVM_DLL Pass RunCodegen(Optional>> target_opt Array entry_functions); /*! - * \brief Simplify normalization operators during inference. For example, the result + * \brief Decompose composite operators during inference. For example, the result * of a batch norm which is indexed at tuple index 0 will be unpacked into a * number of simplified operators. * \return The Pass. */ -TVM_DLL Pass SimplifyNormInference(); +TVM_DLL Pass DecomposeCompositeOperator(); /*! * \brief Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in \p @@ -404,7 +404,6 @@ TVM_DLL Pass DeadCodeElimination(Array entry_functions); * \return The Pass. */ TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype); - } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 95f81f7e6cde..59a2ebf37976 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -577,10 +577,11 @@ def MetaScheduleTuneIRMod( return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore -def SimplifyNormInference() -> tvm.ir.transform.Pass: - """Simplify normalization operators during inference. For example, the result - of a batch norm which is indexed at tuple index 0 will be unpacked into a - number of simplified operators. +def DecomposeCompositeOperator() -> tvm.ir.transform.Pass: + """Decompose composite operators that are composed by other operators during inference. + For example, the result of a batch norm which is indexed at tuple index 0 will be unpacked into a + number of simplified operators. Attention, tensor_to_shape, etc. can be also decomposed into a number + of simplified operators as well. 
Returns ------- @@ -588,7 +589,7 @@ def SimplifyNormInference() -> tvm.ir.transform.Pass: The registered pass """ - return _ffi_api.SimplifyNormInference() # type: ignore + return _ffi_api.DecomposeCompositeOperator() # type: ignore def AlterOpImpl( diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc index cc0d654f307b..f558c6f435ea 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -51,8 +51,6 @@ class VMBuiltinLowerMutator : public ExprMutator { if (call->op == call_tir_dyn_op_) { return CallTIRDyn(call); - } else if (call->op == tensor_to_shape_op_) { - return TensorToShape(call); } else if (call->op == reshape_op_) { return Reshape(call); } else if (call->op == shape_of_op_) { @@ -150,28 +148,6 @@ class VMBuiltinLowerMutator : public ExprMutator { {GetStructInfo(call_node)}); } - ShapeExpr TensorToShape(const Call& call_node) { - ICHECK(call_node->args.size() == 1); - ICHECK(call_node->struct_info_.defined()); - Expr expr = call_node->args[0]; - const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); - ICHECK(sinfo); - // call builtin function that converts tensor to shape tuple - Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, - {GetRef(sinfo)})); - - // define symbolic variables - Array shape_var; - for (int i = 0; i < sinfo->ndim; i++) { - shape_var.push_back(tir::Var("x", DataType::Int(64))); - } - - // bind symbolic variables to the shape tuple - relax::Var var("y", ShapeStructInfo(shape_var)); - builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var))); - return ShapeExpr(shape_var); - } - Expr ShapeOf(const Call& call_node) { ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); @@ -218,7 +194,6 @@ class VMBuiltinLowerMutator : public ExprMutator { const StructInfo void_sinfo_ = TupleStructInfo(Array({})); // object to pattern match. 
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); - const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); const Op& make_closure_op_ = Op::Get("relax.make_closure"); diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index dbeb6f8d5bc7..f4e54589cf11 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -751,7 +751,13 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { << new_shape_prod); } } - return TensorStructInfo(call->args[1], data_sinfo->dtype); + + Expr target_shape = call->args[1]; + // If shape values are defined, use them + if (target_shape->IsInstance() && new_shape_sinfo->values.defined()) { + return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype); + } + return TensorStructInfo(target_shape, data_sinfo->dtype); } TVM_REGISTER_OP("relax.reshape") diff --git a/src/relax/transform/simplify_norm_inference.cc b/src/relax/transform/decompose_composite_ops.cc similarity index 71% rename from src/relax/transform/simplify_norm_inference.cc rename to src/relax/transform/decompose_composite_ops.cc index 545098db28bd..2ab84b48f4d0 100644 --- a/src/relax/transform/simplify_norm_inference.cc +++ b/src/relax/transform/decompose_composite_ops.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include "utils.h" @@ -110,21 +111,61 @@ class NormInferenceSimplifier : public ExprMutator { Map batch_norm_map_; }; +class OpDecomposer : public ExprMutator { + public: + static Expr Decompose(Expr expr) { return OpDecomposer()(expr); } + + private: + using ExprMutator::VisitExpr_; + Expr TensorToShape(const Call& call_node) { + ICHECK(call_node->args.size() == 1); + ICHECK(call_node->struct_info_.defined()); + Expr expr = call_node->args[0]; + const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); + ICHECK(sinfo); + // call builtin function that converts tensor to shape tuple + Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, + {GetRef(sinfo)})); + // define symbolic variables + Array shape_var; + for (int i = 0; i < sinfo->ndim; i++) { + shape_var.push_back(tir::Var("x", DataType::Int(64))); + } + // bind symbolic variables to the shape tuple + relax::Var var("y", ShapeStructInfo(shape_var)); + builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var))); + return ShapeExpr(shape_var); + } + + Expr VisitExpr_(const CallNode* call_node) final { + Call call = Downcast(VisitExprPostOrder_(call_node)); + if (call->op == tensor_to_shape_op_) { + return TensorToShape(call); + } else { + return call; + } + } + + const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); +}; + namespace transform { -Pass SimplifyNormInference() { +Pass DecomposeCompositeOperator() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { f = Downcast(NormInferenceSimplifier::Simplify(f)); - // Remove original batch_norm op if it's not used. + f = Downcast(OpDecomposer::Decompose(f)); + // Remove original ops if it's not used. 
return RemoveAllUnused(f); }; - return CreateFunctionPass(/*pass_function=*/pass_func, // - /*opt_level=*/0, // - /*pass_name=*/"SimplifyNormInference", // + return CreateFunctionPass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"DecomposeCompositeOperator", // /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.SimplifyNormInference").set_body_typed(SimplifyNormInference); +TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOperator") + .set_body_typed(DecomposeCompositeOperator); } // namespace transform } // namespace relax diff --git a/tests/python/relax/test_transform_simpilify_norm_inference.py b/tests/python/relax/test_transform_decompose_composite_ops.py similarity index 98% rename from tests/python/relax/test_transform_simpilify_norm_inference.py rename to tests/python/relax/test_transform_decompose_composite_ops.py index 3c981ba0351e..26a8844a2937 100644 --- a/tests/python/relax/test_transform_simpilify_norm_inference.py +++ b/tests/python/relax/test_transform_decompose_composite_ops.py @@ -30,7 +30,7 @@ def _check(before: Union[Function, IRModule], expected: Union[Function, IRModule before = IRModule({"main": before}) if isinstance(expected, Function): expected = IRModule({"main": expected}) - after = relax.transform.SimplifyNormInference()(before) + after = relax.transform.DecomposeCompositeOperator()(before) tvm.ir.assert_structural_equal(expected, after) diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index a45c9b49bc41..5a5fc588fcf7 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -17,6 +17,7 @@ import pytest import tvm +from tvm import relax from tvm.relax.transform import LegalizeOps from tvm.script import relax as R, tir as T, ir as I import tvm.testing @@ -682,7 +683,49 @@ def main( mod3 = LegalizeOps()(Reshape3) tvm.ir.assert_structural_equal(mod3, Expected3) +def test_data_dependent_reshape(): + # fmt: off + @tvm.script.ir_module + class DDReshape: + @R.function + def main(x: R.Tensor((3, ), dtype="int64")): + lv: R.Shape([3,]) = R.tensor_to_shape(x) + gv = R.reshape(x, lv) + return gv + + assert relax.analysis.well_formed(DDReshape) + mod = relax.transform.DecomposeCompositeOperator()(DDReshape) + out_mod = relax.transform.LegalizeOps()(mod) + + @I.ir_module + class Expected: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape: T.handle + ): + T.func_attr({"tir.noalias": True}) + x = T.int64() + T_reshape = T.match_buffer(var_T_reshape, (x,), "int64") + # with T.block("root"): + for ax0 in range(x): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(x, ax0) + T.reads(rxplaceholder[v_ax0 % T.int64(3)]) + T.writes(T_reshape[v_ax0]) + T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] + @R.function + def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): + x_1 = T.int64() + gv: R.Shape([3]) = R.call_packed( + "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),) + ) + y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) + lv: R.Shape([x_1]) = R.shape([x_1]) + gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64")) + return gv_1 + tvm.ir.assert_structural_equal(out_mod, Expected) + def test_split_by_indices(): # fmt: off @tvm.script.ir_module From 8ce3895e4dcc0fa02d819b6ced65dedd1899a037 Mon Sep 17 00:00:00 2001 From: sung Date: Sat, 18 Mar 2023 19:27:55 
-0700 Subject: [PATCH 07/12] fix and comment --- include/tvm/relax/transform.h | 1 + python/tvm/relax/transform/transform.py | 10 ++++---- src/relax/op/tensor/manipulate.cc | 1 - .../transform/decompose_composite_ops.cc | 11 ++++---- src/relax/transform/fold_constant.cc | 11 +++++++- tests/python/relax/test_relax_operators.py | 15 ----------- .../test_transform_decompose_composite_ops.py | 25 +++++++++++++++++-- .../test_transform_legalize_ops_manipulate.py | 13 +++++++--- 8 files changed, 54 insertions(+), 33 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index e4fae591f3aa..5d99946aeff8 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -404,6 +404,7 @@ TVM_DLL Pass DeadCodeElimination(Array entry_functions); * \return The Pass. */ TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 59a2ebf37976..1ec0cb0dde0e 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -577,11 +577,11 @@ def MetaScheduleTuneIRMod( return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) # type: ignore -def DecomposeCompositeOperator() -> tvm.ir.transform.Pass: +def DecomposeCompositeOps() -> tvm.ir.transform.Pass: """Decompose composite operators that are composed by other operators during inference. - For example, the result of a batch norm which is indexed at tuple index 0 will be unpacked into a - number of simplified operators. Attention, tensor_to_shape, etc. can be also decomposed into a number - of simplified operators as well. + For example, the result of a batch norm which is indexed at tuple index 0 will be unpacked + into a number of simplified operators. Attention, tensor_to_shape, etc. can be also + decomposed into a number of simplified operators as well. Returns ------- @@ -589,7 +589,7 @@ def DecomposeCompositeOperator() -> tvm.ir.transform.Pass: The registered pass """ - return _ffi_api.DecomposeCompositeOperator() # type: ignore + return _ffi_api.DecomposeCompositeOps() # type: ignore def AlterOpImpl( diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index f4e54589cf11..faa5ee3bc099 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -751,7 +751,6 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { << new_shape_prod); } } - Expr target_shape = call->args[1]; // If shape values are defined, use them if (target_shape->IsInstance() && new_shape_sinfo->values.defined()) { diff --git a/src/relax/transform/decompose_composite_ops.cc b/src/relax/transform/decompose_composite_ops.cc index 2ab84b48f4d0..59dd9e0a5453 100644 --- a/src/relax/transform/decompose_composite_ops.cc +++ b/src/relax/transform/decompose_composite_ops.cc @@ -150,7 +150,7 @@ class OpDecomposer : public ExprMutator { }; namespace transform { -Pass DecomposeCompositeOperator() { +Pass DecomposeCompositeOps() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { f = Downcast(NormInferenceSimplifier::Simplify(f)); @@ -158,14 +158,13 @@ Pass DecomposeCompositeOperator() { // Remove original ops if it's not used. 
return RemoveAllUnused(f); }; - return CreateFunctionPass(/*pass_function=*/pass_func, // - /*opt_level=*/0, // - /*pass_name=*/"DecomposeCompositeOperator", // + return CreateFunctionPass(/*pass_function=*/pass_func, // + /*opt_level=*/0, // + /*pass_name=*/"DecomposeCompositeOps", // /*required=*/{}); } -TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOperator") - .set_body_typed(DecomposeCompositeOperator); +TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOps").set_body_typed(DecomposeCompositeOps); } // namespace transform } // namespace relax diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 92125d7f2b8c..315b3dc1f2a2 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -196,6 +196,10 @@ class ConstantFolder : public ExprMutator { using ExprMutator::VisitExpr_; + // TODO(@sunggg): + // Next PR will support fold with PackedFunc and MatchCast + // Until then, DecomposeCompositeOps() should be applied after + // this pass to fold `tensor_to_shape` op. Expr VisitExpr_(const CallNode* call) final { // post-order mutation Call post_call = Downcast(VisitExprPostOrder_(call)); @@ -251,8 +255,13 @@ class ConstantFolder : public ExprMutator { return VisitCallTIR(GetRef(call)).value_or(post_call); } } else if (op->name == "relax.tensor_to_shape") { - // Special handling for builtin op "relax.tensor_to_shape" + // Special handling for composite op "relax.tensor_to_shape" // If its input is constant, we can access its value and create ShapeExpr + // TODO(@sunggg): + // currently, we do not have a info map about decomposition. + // Thus, this is a temporary solution until we have a consensus about + // how to deal with composite ops. One possibility is we register the + // decomposition map for each op in a similar way we do for legalization. 
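+        // Read the int64 values out of the constant tensor and rebuild them as a
+        // ShapeExpr so that the data-dependent reshape can be folded at compile time.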
ICHECK_EQ(post_call->args.size(), 1); Expr arg = post_call->args[0]; if (arg->IsInstance()) { diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index a0bc664d674d..f197eaa9ab7c 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -193,20 +193,5 @@ def test_op_shape_of(): assert constrained_shape == tvm.runtime.ShapeTuple([1]) -@tvm.script.ir_module -class TensorToShapeTest: - @R.function - def run_tensor_to_shape(t: R.Tensor(ndim=1, dtype="int64")) -> R.Shape((1, 2, 3)): - gv: R.Shape(ndim=3) = R.tensor_to_shape(t) - return gv - - -def test_op_tensor_to_shape(): - out_shape = run_cpu( - TensorToShapeTest, "run_tensor_to_shape", tvm.nd.array(np.array([1, 2, 3]).astype("int64")) - ) - assert out_shape == tvm.runtime.ShapeTuple([1, 2, 3]) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_decompose_composite_ops.py b/tests/python/relax/test_transform_decompose_composite_ops.py index 26a8844a2937..08483600a3ed 100644 --- a/tests/python/relax/test_transform_decompose_composite_ops.py +++ b/tests/python/relax/test_transform_decompose_composite_ops.py @@ -22,7 +22,7 @@ import tvm.testing from tvm import IRModule, relax from tvm.relax import Function -from tvm.script import relax as R +from tvm.script import relax as R, tir as T def _check(before: Union[Function, IRModule], expected: Union[Function, IRModule]): @@ -30,7 +30,7 @@ def _check(before: Union[Function, IRModule], expected: Union[Function, IRModule before = IRModule({"main": before}) if isinstance(expected, Function): expected = IRModule({"main": expected}) - after = relax.transform.DecomposeCompositeOperator()(before) + after = relax.transform.DecomposeCompositeOps()(before) tvm.ir.assert_structural_equal(expected, after) @@ -149,5 +149,26 @@ def expected( _check(before, expected) +def test_op_tensor_to_shape(): + @R.function + def before(t: R.Tensor(ndim=1, dtype="int64")): + gv: R.Shape(ndim=3) = R.tensor_to_shape(t) + return gv + + @R.function + def expected(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3): + x = T.int64() + x_1 = T.int64() + x_2 = T.int64() + gv: R.Shape(ndim=3) = R.call_packed( + "vm.builtin.tensor_to_shape", t, sinfo_args=(R.Shape(ndim=3),) + ) + y: R.Shape([x, x_1, x_2]) = R.match_cast(gv, R.Shape([x, x_1, x_2])) + gv_1: R.Shape([x, x_1, x_2]) = R.shape([x, x_1, x_2]) + return gv_1 + + _check(before, expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index 5a5fc588fcf7..cce35a90263c 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -499,6 +499,7 @@ def reshape(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64 mod = LegalizeOps()(Reshape) tvm.ir.assert_structural_equal(mod, Expected) + # fmt: off # ShapeExpr might be produced by shape computation @tvm.script.ir_module class Reshape2: @@ -542,6 +543,7 @@ def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((8, 3), dtype=" lv: R.Shape((8, 3)) = R.shape((8, 3)) gv = R.call_tir(Expected2.reshape, (x,), out_sinfo=R.Tensor((8, 3), dtype="float32")) return gv + # fmt: on mod2 = LegalizeOps()(Reshape2) tvm.ir.assert_structural_equal(mod2, Expected2) @@ -683,6 +685,7 @@ def main( mod3 = LegalizeOps()(Reshape3) tvm.ir.assert_structural_equal(mod3, 
Expected3) + def test_data_dependent_reshape(): # fmt: off @tvm.script.ir_module @@ -692,11 +695,13 @@ def main(x: R.Tensor((3, ), dtype="int64")): lv: R.Shape([3,]) = R.tensor_to_shape(x) gv = R.reshape(x, lv) return gv - + # fmt: on + assert relax.analysis.well_formed(DDReshape) - mod = relax.transform.DecomposeCompositeOperator()(DDReshape) + mod = relax.transform.DecomposeCompositeOps()(DDReshape) out_mod = relax.transform.LegalizeOps()(mod) + # fmt: off @I.ir_module class Expected: @T.prim_func @@ -724,8 +729,10 @@ def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): lv: R.Shape([x_1]) = R.shape([x_1]) gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64")) return gv_1 + # fmt: on tvm.ir.assert_structural_equal(out_mod, Expected) - + + def test_split_by_indices(): # fmt: off @tvm.script.ir_module From eb4012964ff3d1fa672ae2af99884d417209922c Mon Sep 17 00:00:00 2001 From: sung Date: Sun, 19 Mar 2023 23:19:46 -0700 Subject: [PATCH 08/12] fix --- tests/python/relax/test_op_manipulate.py | 25 ++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 16bbc04d269a..3edf63764a58 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -88,12 +88,13 @@ def test_reshape_infer_struct_info(): _check_inference( bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), dtype="") ) - _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo(s0, "float32")) - _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo(s0, dtype="")) - _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo(s0, dtype="")) - _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo(s0, dtype="")) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo((3, 8, 5), "float32")) + _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) + _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo((3, 8, 5), dtype="")) _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, "float32")) _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, "float32")) _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, "float32")) @@ -160,7 +161,8 @@ def test_reshape_infer_struct_info_shape_symbolic(): (c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))), "float32" ), ) - _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, "float32")) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, a, d, b), "float32")) _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, "float32")) _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, "float32")) @@ -188,17 +190,20 @@ def test_reshape_infer_struct_info_shape_var(): 
_check_inference( bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 5), "float32") ) - _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo(ns0, "float32")) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) _check_inference(bb, relax.op.reshape(x0, ns1), relax.TensorStructInfo(ns1, "float32")) _check_inference( bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") ) - _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo(ns0, "float32")) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) _check_inference(bb, relax.op.reshape(x1, ns1), relax.TensorStructInfo(ns1, "float32")) _check_inference( bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), "float32") ) - _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo(ns0, "float32")) + # Remove Var from StructInfo when we can + _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo((3, 8, 5), "float32")) _check_inference(bb, relax.op.reshape(x2, ns1), relax.TensorStructInfo(ns1, "float32")) From f4095b62734c1fa9485574dfac91b2c2e2af9f2c Mon Sep 17 00:00:00 2001 From: sung Date: Mon, 20 Mar 2023 13:58:47 -0700 Subject: [PATCH 09/12] add comments --- include/tvm/relax/transform.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5d99946aeff8..4434b802a189 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -358,7 +358,8 @@ TVM_DLL Pass RunCodegen(Optional>> target_opt /*! * \brief Decompose composite operators during inference. For example, the result * of a batch norm which is indexed at tuple index 0 will be unpacked into a - * number of simplified operators. + * number of simplified operators. Operators like Attention, Erf, etc. can be also + * simplified into several operators as well. * \return The Pass. 
*/ TVM_DLL Pass DecomposeCompositeOperator(); From 600a94d996536876b740f8d5f1bf54968059fff7 Mon Sep 17 00:00:00 2001 From: sung Date: Mon, 20 Mar 2023 16:55:45 -0700 Subject: [PATCH 10/12] reflect feedback --- src/relax/backend/vm/vm_builtin_lower.cc | 18 +++++++++--------- src/relax/transform/decompose_composite_ops.cc | 6 +++--- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc index f558c6f435ea..5bf419499714 100644 --- a/src/relax/backend/vm/vm_builtin_lower.cc +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -136,16 +136,16 @@ class VMBuiltinLowerMutator : public ExprMutator { if (arg->IsInstance()) { return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } else { + // Handling the case when arg is VarNode + Optional _bound_val = LookupBinding(Downcast(arg)); + ICHECK(_bound_val.defined()); + Expr bound_val = _bound_val.value(); + CHECK(bound_val->IsInstance()) + << "VMBuiltinLower expects bound value to be a ShapeExpr"; + return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(), + {GetStructInfo(call_node)}); } - - ICHECK(arg->IsInstance()); - Optional _bound_val = LookupBinding(Downcast(call_node->args[1])); - ICHECK(_bound_val.defined()); - Expr bound_val = _bound_val.value(); - CHECK(bound_val->IsInstance()) - << "VMBuiltinLower expects bound value to be a ShapeExpr"; - return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(), - {GetStructInfo(call_node)}); } Expr ShapeOf(const Call& call_node) { diff --git a/src/relax/transform/decompose_composite_ops.cc b/src/relax/transform/decompose_composite_ops.cc index 59dd9e0a5453..aa422ec141b4 100644 --- a/src/relax/transform/decompose_composite_ops.cc +++ b/src/relax/transform/decompose_composite_ops.cc @@ -118,14 +118,14 @@ class OpDecomposer : public ExprMutator { private: using ExprMutator::VisitExpr_; Expr TensorToShape(const Call& call_node) { - ICHECK(call_node->args.size() == 1); ICHECK(call_node->struct_info_.defined()); Expr expr = call_node->args[0]; const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); ICHECK(sinfo); // call builtin function that converts tensor to shape tuple - Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, - {GetRef(sinfo)})); + static const Op& tensor_to_shape_op = Op::Get("relax.builtin.tensor_to_shape"); + Var call = + builder_->Emit(Call(tensor_to_shape_op, {expr}, {}, {GetRef(sinfo)})); // define symbolic variables Array shape_var; for (int i = 0; i < sinfo->ndim; i++) { From 6022718cb723346913cf90e5e89e8de3f8783958 Mon Sep 17 00:00:00 2001 From: sung Date: Mon, 20 Mar 2023 21:25:23 -0700 Subject: [PATCH 11/12] add comment --- src/relax/transform/decompose_composite_ops.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/decompose_composite_ops.cc b/src/relax/transform/decompose_composite_ops.cc index aa422ec141b4..b9b2d1b394cc 100644 --- a/src/relax/transform/decompose_composite_ops.cc +++ b/src/relax/transform/decompose_composite_ops.cc @@ -126,7 +126,10 @@ class OpDecomposer : public ExprMutator { static const Op& tensor_to_shape_op = Op::Get("relax.builtin.tensor_to_shape"); Var call = builder_->Emit(Call(tensor_to_shape_op, {expr}, {}, {GetRef(sinfo)})); - // define symbolic variables + + // Operators like reshape take the output of `TensorToShape` as their output shape. 
+ // Because TOPI expects to have such output shape in symbolic shape at least (i.e., + // Array), we define symbolic variables and returns them as a ShapeExpr. Array shape_var; for (int i = 0; i < sinfo->ndim; i++) { shape_var.push_back(tir::Var("x", DataType::Int(64))); From 28c71b40abdb869e3f858ac85827c8ce2b9324fe Mon Sep 17 00:00:00 2001 From: sung Date: Tue, 21 Mar 2023 11:45:51 -0700 Subject: [PATCH 12/12] fix --- src/relax/transform/decompose_composite_ops.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relax/transform/decompose_composite_ops.cc b/src/relax/transform/decompose_composite_ops.cc index b9b2d1b394cc..36814422216b 100644 --- a/src/relax/transform/decompose_composite_ops.cc +++ b/src/relax/transform/decompose_composite_ops.cc @@ -123,9 +123,9 @@ class OpDecomposer : public ExprMutator { const ShapeStructInfoNode* sinfo = GetStructInfoAs(call_node); ICHECK(sinfo); // call builtin function that converts tensor to shape tuple - static const Op& tensor_to_shape_op = Op::Get("relax.builtin.tensor_to_shape"); - Var call = - builder_->Emit(Call(tensor_to_shape_op, {expr}, {}, {GetRef(sinfo)})); + // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape" + Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {}, + {GetRef(sinfo)})); // Operators like reshape take the output of `TensorToShape` as their output shape. // Because TOPI expects to have such output shape in symbolic shape at least (i.e.,
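
A minimal end-to-end sketch of what this series enables, mirroring the DDReshape and fold-constant tests above. It assumes the DecomposeCompositeOps pass name introduced in PATCH 07/12; the module and variable names below are made up for illustration and are not part of the patches.

import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class DataDependentReshape:
    @R.function
    def main(data: R.Tensor((256,), "float32"), shape_tensor: R.Tensor((2,), "int64")):
        # The target shape is carried by a tensor, so it is only known at run time.
        lv: R.Shape(ndim=2) = R.tensor_to_shape(shape_tensor)
        gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv)
        return gv

# tensor_to_shape is decomposed into a call to vm.builtin.tensor_to_shape plus a
# match_cast that binds symbolic shape variables; reshape is then legalized into a
# TIR PrimFunc that consumes those variables.
mod = relax.transform.DecomposeCompositeOps()(DataDependentReshape)
mod = relax.transform.LegalizeOps()(mod)

The lowered module can then be built and run with the Relax VM, which calls vm.builtin.tensor_to_shape at run time to turn the tensor values into a ShapeTuple.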