
Commit 4646456

Author: Siyuan Feng (committed)
[Unity] Support simple dynamic-shape-aware fusion
This PR adds support for simple dynamic-shape-aware fusion, which is the first step towards supporting dynamic shapes. The main changes are as follows:

- Fix FuncStructInfo in well-formed checks
- Renew symbolic var defs in fuse_ops to prevent malformed functions
1 parent 559fac7 commit 4646456
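
As a quick illustration of the feature (an aside, not part of the commit): with this change, FuseOps can group operators whose shapes contain symbolic vars. A minimal sketch, assuming the unity-branch Python API and mirroring the new test added below:

import tvm
from tvm import relax, topi
from tvm.script import ir as I, relax as R

@I.ir_module
class Mod:
    @R.function
    def main(x: R.Tensor(["n", "m"], "float32")):
        with R.dataflow():
            lv = R.emit_te(topi.add, x, R.const(1, "float32"))
            gv = R.emit_te(topi.exp, lv)
            R.output(gv)
        return gv

# The symbolic dims n and m flow through the fused function unchanged.
fused = relax.transform.FuseOps()(Mod)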

File tree

6 files changed: +105 -5 lines changed


3rdparty/cutlass (submodule; no textual diff shown)

src/relax/analysis/well_formed.cc

Lines changed: 12 additions & 0 deletions
@@ -402,6 +402,18 @@ class WellFormedChecker : public relax::ExprVisitor,
     }
   }

+  void VisitStructInfo_(const FuncStructInfoNode* op) final {
+    if (op->params.defined()) {
+      WithMode(VisitMode::kMatchVarDef, [&]() {
+        ICHECK(mode_ == VisitMode::kMatchVarDef);
+        for (StructInfo param : op->params.value()) {
+          this->VisitStructInfo(param);
+        }
+      });
+    }
+    this->VisitStructInfo(op->ret);
+  }
+
   void VisitStructInfoExprField(const Expr& expr) final {
     if (mode_ == VisitMode::kMatchVarDef) {
       // populate symbolic var in first occurrence
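
For orientation (an aside, not part of the diff): the new VisitStructInfo_ overload visits a FuncStructInfo's parameters in kMatchVarDef mode, so symbolic vars whose first occurrence is in a parameter count as definitions rather than undefined uses. A minimal sketch of such a struct info, assuming the unity-branch Python API:

import tvm
from tvm import relax as rx, tir

# "m" and "n" first occur in the parameter struct info; the checker now
# treats that first occurrence as their definition site.
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
param = rx.TensorStructInfo([m, n], "float32")
fsinfo = rx.FuncStructInfo([param], param)  # ret may then use m and n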

src/relax/ir/expr_functor.cc

Lines changed: 1 addition & 0 deletions
@@ -752,6 +752,7 @@ Var ExprMutator::VisitVarDef(const Var& var) {
   } else {
     LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
   }
+  this->var_remap_[var->vid] = ret;
   return ret;
 }

src/relax/transform/fuse_ops.cc

Lines changed: 41 additions & 4 deletions
@@ -33,6 +33,7 @@
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
+#include <tvm/tir/expr_functor.h>
 #include <tvm/tir/function.h>

 #include <optional>
@@ -344,6 +345,42 @@ class GraphCreator : public ExprVisitor {
   std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
 };

+/*!
+ * \brief Renew the definition of symbolic vars in Relax.
+ * \details This mutator is used to prevent the same symbolic var from being used in different
+ * functions, which is malformed.
+ */
+class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
+ public:
+  static Function Renew(const Function& function) {
+    SymbolicVarRenewMutator mutator;
+    return Downcast<Function>(mutator.VisitExpr(function));
+  }
+
+ private:
+  SymbolicVarRenewMutator() = default;
+  using relax::ExprMutator::VisitExpr;
+  using relax::ExprMutator::VisitExpr_;
+  using tir::ExprMutator::VisitExpr_;
+
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const tir::VarNode* op) final {
+    auto it = var_map_.find(GetRef<tir::Var>(op));
+    if (it != var_map_.end()) {
+      return (*it).second;
+    } else {
+      auto n = make_object<tir::VarNode>(*op);
+      tir::Var v(n);
+      var_map_.Set(GetRef<tir::Var>(op), v);
+      return v;
+    }
+  }
+
+ private:
+  Map<tir::Var, tir::Var> var_map_;
+};
+
 /*!
  * \brief The ExprMutator used to create a new grouped function
  * \details The workflow of this ExprMutator is:
@@ -466,10 +503,10 @@ class FunctionCreator : public ExprMutator {
       body = builder_->Normalize(body);
       body = builder_->Normalize(SeqExpr({new_block}, body));
       group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
-      function_ = Function(/*params=*/params_,           //
-                           /*body=*/body,                //
-                           /*ret_struct_info=*/NullOpt,  //
-                           /*attrs=*/DictAttrs(group_attrs));
+      function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_,           //
+                                                          /*body=*/body,                //
+                                                          /*ret_struct_info=*/NullOpt,  //
+                                                          /*attrs=*/DictAttrs(group_attrs)));
     }
   }
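
An illustrative aside (not part of the diff): symbolic shape vars are identified by object identity, not by name, so hoisting a fused function without renewing its vars would leave two functions sharing one definition. A minimal sketch of the identity semantics:

import tvm
from tvm import tir

# Two vars named "n" are nevertheless distinct definitions.
n1 = tir.Var("n", "int64")
n2 = tir.Var("n", "int64")  # fresh var, same name
assert not n1.same_as(n2)

SymbolicVarRenewMutator performs exactly this kind of freshening: each var it meets is replaced by a new var with the same name, and the mapping is recorded in var_map_ so repeated occurrences stay consistent within one function.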

tests/python/relax/test_analysis_well_formed.py

Lines changed: 13 additions & 0 deletions
@@ -520,5 +520,18 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"):
     assert not rx.analysis.well_formed(mod)


+def test_func_sinfo_well_formed():
+    @R.function
+    def foo():
+        @R.function
+        def local(x: R.Tensor(["m", "n"], "float32")):
+            return x
+
+        return local
+
+    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo))
+    assert rx.analysis.well_formed(mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

tests/python/relax/test_transform_fuse_ops.py

Lines changed: 37 additions & 0 deletions
@@ -1254,5 +1254,42 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d
     _check(mod, Expected)


+def test_symbolic_shape_aware_fuse():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")):
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_add_exp_squeeze(
+            x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+        ) -> R.Tensor(["n", "m"], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, p0)
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
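
A hedged note on the harness: _check is defined earlier in this test file and is not shown in the diff; presumably it applies the fusion pass and compares the result structurally, along the lines of this sketch (an assumption, not the file's actual definition):

import tvm

def _check(mod_before, mod_expected):
    # Apply operator fusion, then require exact structural equality.
    mod_after = tvm.relax.transform.FuseOps()(mod_before)
    tvm.ir.assert_structural_equal(mod_after, mod_expected)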
