96 changes: 22 additions & 74 deletions src/relax/transform/fuse_ops.cc
@@ -411,21 +411,11 @@ class FunctionCreator : public ExprMutator {

for (const Expr& arg : call->args) {
CheckDefAndUpdateParam(arg);
if (GetStructInfoAs<TupleStructInfoNode>(arg) != nullptr) {
// The argument is fully referenced. Thus we remove it from the mapping.
partially_used_tuple_params_.erase(arg.get());
}
}
}
} else if (var_binding->value.as<TupleGetItemNode>()) {
const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
CheckDefAndUpdateParam(tuple_item->tuple);

if (partially_used_tuple_params_.find(tuple_item->tuple.get()) !=
partially_used_tuple_params_.end()) {
// Appending get-item index to the mapping.
partially_used_tuple_params_[tuple_item->tuple.get()].push_back(tuple_item->index);
}
}

// Mark the binding variable as defined.
@@ -461,51 +451,9 @@ class FunctionCreator : public ExprMutator {
// Step 1. Start constructing a new dataflow block.
builder_->BeginDataflowBlock();

// Step 2. Handing partially used tuple parameters: replacing entire tuple
// parameters with the parameters of its fields that are accessed in the
// function.
std::unordered_map<const ExprNode*, std::unordered_map<int, Var>> tuple_get_item_remap;
for (auto& [tuple_arg, item_indices] : partially_used_tuple_params_) {
ICHECK(!item_indices.empty());
int param_idx = tuple_param_idx_[tuple_arg];
Var param = params_[param_idx];
String param_name = params_[param_idx]->name_hint();
TupleStructInfo param_sinfo = Downcast<TupleStructInfo>(tuple_arg->struct_info_);

Array<Expr> item_args;
Array<Var> item_params;
item_args.reserve(item_indices.size());
item_params.reserve(item_indices.size());
for (int item_idx : item_indices) {
Var item_param(param_name + "_" + std::to_string(item_idx), param_sinfo->fields[item_idx]);
item_args.push_back(TupleGetItem(GetRef<Expr>(tuple_arg), item_idx));
item_params.push_back(item_param);
tuple_get_item_remap[tuple_arg][item_idx] = item_param;
}
arguments_.erase(arguments_.begin() + param_idx);
arguments_.insert(arguments_.begin() + param_idx, item_args.begin(), item_args.end());
params_.erase(params_.begin() + param_idx);
params_.insert(params_.begin() + param_idx, item_params.begin(), item_params.end());
}

// Step 3. Visit each binding and collect outputs one by one.
Array<Expr> outputs(output_vars_.size(), Expr());
for (const Binding& binding : bindings_) {
// Special handing for TupleGetItem.
if (const auto* var_binding = binding.as<VarBindingNode>()) {
if (const auto* tuple_get_item = var_binding->value.as<TupleGetItemNode>()) {
auto it = tuple_get_item_remap.find(tuple_get_item->tuple.get());
if (it != tuple_get_item_remap.end()) {
ICHECK(it->second.find(tuple_get_item->index) != it->second.end());
var_remap_[var_binding->var->vid] = it->second[tuple_get_item->index];
if (auto output_idx = GetOutputIndex(binding->var)) {
outputs.Set(*output_idx, it->second[tuple_get_item->index]);
}
continue;
}
}
}

if (auto output_idx = GetOutputIndex(binding->var)) {
// Case 1. It is an output binding
// We only allow VarBinding as output.
@@ -602,13 +550,6 @@ class FunctionCreator : public ExprMutator {
arguments_.push_back(expr);
params_.push_back(param);
}

// Mark the tuple parameter is partially referenced in the beginning.
// We will remove it from the mapping once we find it is fully referenced.
if (param_sinfo->IsInstance<TupleStructInfoNode>()) {
partially_used_tuple_params_[expr.get()] = {};
tuple_param_idx_[expr.get()] = static_cast<int>(arguments_.size()) - 1;
}
}
}

@@ -631,13 +572,6 @@ class FunctionCreator : public ExprMutator {
std::vector<const VarNode*> output_vars_;
/*! \brief Whether or not to lift bound constants to parameters */
bool lift_constant_;
/*! \brief Mapping from tuple parameter of the function to its position index */
std::unordered_map<const ExprNode*, int> tuple_param_idx_;
/*!
* \brief Mapping from partially referenced tuple parameter to the list of
* indices that the parameter is referred by TupleGetItem
*/
std::unordered_map<const ExprNode*, std::vector<int>> partially_used_tuple_params_;
};

/*!
@@ -1323,10 +1257,17 @@ Pass FuseOps(int fuse_opt_level) {
auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", Integer(kMaxFusedOps));
return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue());
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseOps", //
/*required=*/{});
auto inner_pass = CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseOpsInner", //
/*required=*/{});
return tvm::transform::Sequential(
{
inner_pass,
ExpandTupleArguments(),
RemoveUnusedParameters(),
},
"FuseOpsInner");
}

TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
@@ -1337,10 +1278,17 @@ Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_const
[=](IRModule m, PassContext pc) {
return relax::FuseOpsByPattern(patterns, m, bind_constants, annotate_codegen);
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseOpsByPattern", //
/*required=*/{});
auto inner_pass = CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"FuseOpsByPatternInner", //
/*required=*/{});
return tvm::transform::Sequential(
{
inner_pass,
ExpandTupleArguments(),
RemoveUnusedParameters(),
},
"FuseOpsByPattern");
}

TVM_REGISTER_GLOBAL("relax.transform.FuseOpsByPattern").set_body_typed(FuseOpsByPattern);
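For reference, a minimal Python sketch of the pass composition introduced above. This is not something callers need to write themselves, since relax.transform.FuseOps now returns the bundled pipeline; it assumes ExpandTupleArguments and RemoveUnusedParameters are exposed under tvm.relax.transform, as most Relax passes are, and the function name is illustrative only.

import tvm
from tvm import relax


def fuse_ops_pipeline(fuse_opt_level: int = -1) -> tvm.transform.Pass:
    # Fusion proper, followed by the two clean-up passes that replace the
    # removed in-pass handling of partially used tuple parameters.
    return tvm.transform.Sequential(
        [
            relax.transform.FuseOps(fuse_opt_level),
            relax.transform.ExpandTupleArguments(),
            relax.transform.RemoveUnusedParameters(),
        ],
        name="FuseOpsPipeline",
    )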
108 changes: 52 additions & 56 deletions tests/python/relax/test_transform_fuse_ops.py
@@ -1122,8 +1122,7 @@ def main(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), inp_1: R.Tensor((2,
lv31 = R.call_tir(cls.transpose, (w2,), out_sinfo=R.Tensor((1280, 320), dtype="float32"))
lv: R.Tensor((2, 320), dtype="float32") = cls.fused_matmul_add1(inp_1, lv31, b2)
lv35 = R.call_tir(cls.reshape1, (lv,), out_sinfo=R.Tensor((2, 320, 1, 1), dtype="float32"))
lv1: R.Tensor((2, 320, 64, 64), dtype="float32") = cls.fused_conv2d_add_add2(inp_0, w1, lv28, lv35)
gv: R.Tensor((2, 320, 64, 64), dtype="float32") = lv1
gv: R.Tensor((2, 320, 64, 64), dtype="float32") = cls.fused_conv2d_add_add2(inp_0, w1, lv28, lv35)
R.output(gv)
return gv
# fmt: on
@@ -1156,16 +1155,16 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d

@I.ir_module
class Expected:
@T.prim_func(private=True)
def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
T.func_attr({"op_pattern": 0, "tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]
# @T.prim_func(private=True)
# def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")):
# T.func_attr({"op_pattern": 0, "tir.noalias": True})
# # with T.block("root"):
# for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
# with T.block("T_add"):
# v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
# T.reads(rxplaceholder[v_ax0, v_ax1], rxplaceholder_1[v_ax1])
# T.writes(T_add[v_ax0, v_ax1])
# T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

@T.prim_func(private=True)
def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")):
@@ -1178,18 +1177,18 @@ def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceh
T.writes(T_add[v_ax0, v_ax1])
T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + rxplaceholder_1[v_ax1]

@T.prim_func(private=True)
def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")):
T.func_attr({"op_pattern": 4, "tir.noalias": True})
# with T.block("root"):
for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1])
T.writes(matmul_1[v_i0, v_i1])
with T.init():
matmul_1[v_i0, v_i1] = T.float32(0)
matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
# @T.prim_func(private=True)
# def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")):
# T.func_attr({"op_pattern": 4, "tir.noalias": True})
# # with T.block("root"):
# for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
# with T.block("matmul"):
# v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
# T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, v_i1])
# T.writes(matmul_1[v_i0, v_i1])
# with T.init():
# matmul_1[v_i0, v_i1] = T.float32(0)
# matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]

@T.prim_func(private=True)
def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
@@ -1204,27 +1203,27 @@ def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxpl
matmul[v_i0, v_i1] = T.float32(0)
matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]

@T.prim_func(private=True)
def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
T.func_attr({"op_pattern": 0, "tir.noalias": True})
# with T.block("root"):
for i0, i1 in T.grid(T.int64(1), T.int64(128)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(rxplaceholder[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

@T.prim_func(private=True)
def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": True})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
# @T.prim_func(private=True)
# def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
# T.func_attr({"op_pattern": 0, "tir.noalias": True})
# # with T.block("root"):
# for i0, i1 in T.grid(T.int64(1), T.int64(128)):
# with T.block("compute"):
# v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
# T.reads(rxplaceholder[v_i0, v_i1])
# T.writes(compute[v_i0, v_i1])
# compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], T.float32(0))

# @T.prim_func(private=True)
# def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
# T.func_attr({"op_pattern": 2, "tir.noalias": True})
# # with T.block("root"):
# for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
# with T.block("T_transpose"):
# v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
# T.reads(rxplaceholder[v_ax1, v_ax0])
# T.writes(T_transpose[v_ax0, v_ax1])
# T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]

@T.prim_func(private=True)
def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
@@ -1252,10 +1251,9 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d
R.func_attr({"num_input": 1})
cls = Expected
with R.dataflow():
lv = R.call_tir(cls.transpose, (linear1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
# lv = R.call_tir(cls.transpose, (linear1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
lv4 = R.call_tir(cls.transpose1, (linear2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
lv_1: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul1_add1(inp_1, lv4, linear2_bias)
gv: R.Tensor((1, 10), dtype="float32") = lv_1
gv: R.Tensor((1, 10), dtype="float32") = cls.fused_matmul1_add1(inp_1, lv4, linear2_bias)
R.output(gv)
return gv

@@ -1319,7 +1317,7 @@ def main(s: R.Shape(["n"])):
class Expected:
@R.function(private=True)
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
s: R.Prim(value="n"),
) -> R.Tensor([1, 1, "n", "n"], "float32"):
R.func_attr({"Primitive": 1})
n = T.int64()
@@ -1336,7 +1334,7 @@ def main(s: R.Shape(["n"])) -> R.Tensor((1, 1, "n", "n"), dtype="float32"):
n = T.int64()
with R.dataflow():
gv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to(
R.shape([n])
R.prim_value(n)
)
R.output(gv)
return gv
@@ -1367,7 +1365,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object):
class Expected:
@R.function(private=True)
def fused_full_trilu_broadcast_to(
s: R.Shape(["n"]),
s: R.Prim(value="n"),
) -> R.Tensor([1, 1, "n", "n"], "float32"):
R.func_attr({"Primitive": 1})
n = T.int64()
@@ -1384,7 +1382,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object):
n = T.int64()
with R.dataflow():
lv: R.Tensor([1, 1, n, n], "float32") = cls.fused_full_trilu_broadcast_to(
R.shape([n])
R.prim_value(n)
)
gv = R.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
@@ -1406,10 +1404,9 @@ def main(A: R.Tensor((10, 20), dtype="float32")) -> R.Tensor(dtype="float32", nd
m = T.int64()
n = T.int64()
with R.dataflow():
lv: R.Tensor((m, n), dtype="float32") = R.match_cast(
gv: R.Tensor((m, n), dtype="float32") = R.match_cast(
A, R.Tensor((m, n), dtype="float32")
)
gv: R.Tensor((m, n), dtype="float32") = lv
R.output(gv)
return gv

@@ -1491,10 +1488,9 @@ def main(
cls = Expected
with R.dataflow():
lv: R.Tensor((2,), dtype="float32") = x[0]
lv1: R.Tensor((2,), dtype="float32") = cls.fused_add_divide(
gv: R.Tensor((2,), dtype="float32") = cls.fused_add_divide(
lv, R.const(1, "float32"), R.const(1, "float32")
)
gv: R.Tensor((2,), dtype="float32") = lv1
R.output(gv)
return gv

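The Expected modules in this file are consumed by the test suite's structural-equality check. A hedged sketch of that pattern, assuming the usual helper shape (the real helper in this file may differ in details such as its name or whether it runs AnnotateTIROpPattern first):

import tvm
from tvm import relax


def _check(mod_actual, mod_expected):
    # Annotate TIR op patterns, run the FuseOps pipeline (which now also expands
    # tuple arguments and drops unused parameters), then compare structurally.
    mod_actual = relax.transform.AnnotateTIROpPattern()(mod_actual)
    mod_actual = relax.transform.FuseOps()(mod_actual)
    tvm.ir.assert_structural_equal(mod_actual, mod_expected)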
1 change: 0 additions & 1 deletion tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -758,7 +758,6 @@ def main(
gv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d(
data, weight1
)
relu: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(data)
R.output(gv)
return gv

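For completeness, a hedged sketch of invoking FuseOpsByPattern, which now returns the same post-processing Sequential. The pattern and composite name below are illustrative, not taken from this test file.

import tvm
from tvm import relax
from tvm.relax.dpl import is_op, wildcard


def fuse_conv2d_by_pattern(mod: tvm.IRModule) -> tvm.IRModule:
    # Offload bare conv2d calls as "dnnl.conv2d" composite functions;
    # ExpandTupleArguments and RemoveUnusedParameters run on the result.
    pattern = is_op("relax.nn.conv2d")(wildcard(), wildcard())
    return relax.transform.FuseOpsByPattern(
        [("dnnl.conv2d", pattern)],
        bind_constants=True,
        annotate_codegen=False,
    )(mod)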