12 changes: 12 additions & 0 deletions src/relax/analysis/well_formed.cc
@@ -402,6 +402,18 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

void VisitStructInfo_(const FuncStructInfoNode* op) final {
if (op->params.defined()) {
WithMode(VisitMode::kMatchVarDef, [&]() {
ICHECK(mode_ == VisitMode::kMatchVarDef);
for (StructInfo param : op->params.value()) {
this->VisitStructInfo(param);
}
});
}
this->VisitStructInfo(op->ret);
}

void VisitStructInfoExprField(const Expr& expr) final {
if (mode_ == VisitMode::kMatchVarDef) {
// populate symbolic var in first occurrence
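The new `VisitStructInfo_` overload treats symbolic vars that first appear in a `FuncStructInfo`'s parameters as definitions, so they may be reused later in the same signature. A minimal sketch of that pattern, assuming the public `tvm.relax` struct-info constructors (`TensorStructInfo`, `FuncStructInfo`):

```python
# Minimal sketch (assumed public tvm.relax API) of the struct info the new
# VisitStructInfo_ overload validates: "m" and "n" are first defined by the
# parameter's shape and may then be reused in the return struct info.
import tvm
from tvm import relax, tir

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")

param = relax.TensorStructInfo([m, n], "float32")
ret = relax.TensorStructInfo([n, m], "float32")

# Symbolic vars defined by the params stay visible for the return struct info,
# which is the case the checker now visits in kMatchVarDef mode.
func_sinfo = relax.FuncStructInfo([param], ret)
print(func_sinfo)
```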
5 changes: 4 additions & 1 deletion src/relax/ir/expr_functor.cc
@@ -577,7 +577,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
for (Var param : op->params) {
Var new_param = this->VisitVarDef(param);
params.push_back(new_param);
all_params_unchanged &= param.same_as(new_param);
if (!param.same_as(new_param)) {
var_remap_[param->vid] = new_param;
all_params_unchanged = false;
}
}

Expr body = this->VisitWithNewScope(op->body, params);
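This change makes `ExprMutator` record a remap entry whenever a function parameter is rewritten, so uses of the old parameter inside the body are redirected to the new one. A conceptual sketch of the pattern in plain Python (not the TVM API):

```python
# Conceptual sketch (not the TVM API) of the parameter-remapping pattern added
# in ExprMutator::VisitExpr_(const FunctionNode*): if visiting a parameter
# definition produces a new var, later uses must be redirected to it.
def visit_function_params(params, visit_var_def, var_remap):
    new_params = []
    all_params_unchanged = True
    for param in params:
        new_param = visit_var_def(param)
        new_params.append(new_param)
        if new_param is not param:
            # Without this entry, the rewritten body would still reference the
            # old parameter var, producing a malformed function.
            var_remap[param] = new_param
            all_params_unchanged = False
    return new_params, all_params_unchanged
```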
48 changes: 44 additions & 4 deletions src/relax/transform/fuse_ops.cc
@@ -33,6 +33,7 @@
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/expr_functor.h>
#include <tvm/tir/function.h>

#include <optional>
@@ -344,6 +345,45 @@ class GraphCreator : public ExprVisitor {
std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
};

/*!
* \brief Renew the definition of symbolic vars in Relax.
* \details This mutator is used to prevent the same symbolic var from being used in different
* functions, which is malformed.
*/
class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
public:
static Function Renew(const Function& function) {
SymbolicVarRenewMutator mutator;
return Downcast<Function>(mutator.VisitExpr(function));
}

private:
SymbolicVarRenewMutator() = default;
using relax::ExprMutator::VisitExpr;
using relax::ExprMutator::VisitExpr_;
using tir::ExprMutator::VisitExpr_;

PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); }

// TODO(Siyuan): enhance the method to the following steps:
// 1. Visit and replace all tir::Vars at the definition point
// 2. Revisit the function again and update the use side.
PrimExpr VisitExpr_(const tir::VarNode* op) final {
auto it = var_map_.find(GetRef<tir::Var>(op));
if (it != var_map_.end()) {
return (*it).second;
} else {
auto n = make_object<tir::VarNode>(*op);
tir::Var v(n);
var_map_.Set(GetRef<tir::Var>(op), v);
return v;
}
}

private:
Map<tir::Var, tir::Var> var_map_;
};

/*!
* \brief The ExprMutator used to create a new grouped function
* \details The workflow of this ExprMutator is:
@@ -466,10 +506,10 @@ class FunctionCreator : public ExprMutator {
body = builder_->Normalize(body);
body = builder_->Normalize(SeqExpr({new_block}, body));
group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
function_ = Function(/*params=*/params_, //
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*attrs=*/DictAttrs(group_attrs));
function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_, //
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*attrs=*/DictAttrs(group_attrs)));
}
}

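`SymbolicVarRenewMutator` is needed because the fused "Primitive" function would otherwise share `tir::Var` objects with the original function, and a symbolic var defined in two different functions is malformed. A conceptual sketch of the renewal in plain Python (not the C++ mutator):

```python
# Conceptual sketch (plain Python, not the C++ mutator) of what
# SymbolicVarRenewMutator does: every symbolic var reached while copying the
# grouped function is replaced by a fresh var with the same name, and repeated
# occurrences of one old var all map to the same fresh var.
from tvm import tir

def renew(var, var_map):
    if var not in var_map:
        # First occurrence: allocate a fresh definition owned by the new function.
        var_map[var] = tir.Var(var.name, var.dtype)
    return var_map[var]

var_map = {}
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")

# The shape [m, n] of the fused function's parameter is rebuilt with fresh vars,
# so the original and the fused function no longer share var definitions.
new_shape = [renew(m, var_map), renew(n, var_map)]
assert renew(m, var_map).same_as(new_shape[0])
```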
4 changes: 2 additions & 2 deletions src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -76,11 +76,11 @@ class CUDAGraphCache : public Object {

/*!
* \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
* \param vm The virutal machine.
* \param vm The virtual machine.
* \param capture_func The function of type (args...) -> Tuple[ObjectRef], where 'args' are the
* static arguments that are the same for all invocations of the capture function, the returned
* tuple contains the intermediate tensors that will be used outside the capture function.
* \params args The static arguments of the capture function
* \param args The static arguments of the capture function
* \param entry_index The unique index of the capture function used for lookup.
* \return The return value of the capture function.
*/
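The documented behaviour is a cache keyed by `entry_index`: the first invocation runs the capture function under CUDA graph capture and keeps its intermediate tensors alive; later invocations replay the cached graph. A conceptual sketch of that control flow in plain Python (not the TVM runtime builtin):

```python
# Conceptual sketch (plain Python, not the TVM runtime builtin) of the
# run-or-capture behaviour documented above.
class CudaGraphCacheSketch:
    def __init__(self):
        self._cache = {}  # entry_index -> (captured_graph, saved_outputs)

    def run_or_capture(self, capture_func, args, entry_index):
        if entry_index not in self._cache:
            graph, outputs = self._capture(capture_func, args)
            self._cache[entry_index] = (graph, outputs)
        graph, outputs = self._cache[entry_index]
        self._launch(graph)  # replay the recorded kernels
        return outputs

    def _capture(self, capture_func, args):
        # Placeholder for running capture_func(*args) under CUDA stream capture;
        # returns the instantiated graph and the tuple of intermediate tensors.
        raise NotImplementedError

    def _launch(self, graph):
        # Placeholder for launching the cached executable graph.
        raise NotImplementedError
```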
13 changes: 13 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
@@ -520,5 +520,18 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtyp
assert not rx.analysis.well_formed(mod)


def test_func_sinfo_well_formed():
@R.function
def foo():
@R.function
def local(x: R.Tensor(["m", "n"], "float32")):
return x

return local

mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo))
assert rx.analysis.well_formed(mod)


if __name__ == "__main__":
tvm.testing.main()
37 changes: 37 additions & 0 deletions tests/python/relax/test_transform_fuse_ops.py
@@ -1254,5 +1254,42 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d
_check(mod, Expected)


def test_symbolic_shape_aware_fuse():
@I.ir_module
class Before:
@R.function
def main(x: R.Tensor(["n", "m"], "float32")):
with R.dataflow():
lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
lv1 = R.emit_te(topi.exp, lv0)
gv = R.emit_te(topi.squeeze, lv1)
R.output(gv)
return gv

@I.ir_module
class Expected:
@R.function
def fused_add_exp_squeeze(
x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
) -> R.Tensor(["n", "m"], dtype="float32"):
R.func_attr({"Primitive": 1})
with R.dataflow():
lv0 = R.emit_te(topi.add, x, p0)
lv1 = R.emit_te(topi.exp, lv0)
gv = R.emit_te(topi.squeeze, lv1)
R.output(gv)
return gv

@R.function
def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
cls = Expected
with R.dataflow():
gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
R.output(gv)
return gv

_check(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()