From 555f55b771eb95fc2399b05cf48bb095b1c4333f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 23 May 2024 10:38:55 -0500 Subject: [PATCH] [Relax][Bugfix] Bind symbolic variables in R.match_cast Prior to this commit, variable replacement by `BindSymbolicVars` would fail to replace variables that occur within a `relax::MatchCast` node. This pattern is rare, because the `bind_symbolic_vars` method can only replace variables that are exposed as part of the function signature, and most uses of `relax::MatchCast` act as a definition for symbolic variables that are not exposed through the function signature. This pattern is well-formed, though, since the `relax::MatchCast` node can also act as a user of previously-defined symbolic variables. The root cause for this bug was in the `ExprMutator` visitor for `relax::MatchCast`, which did not visit the struct info field. As a result, the virtual `ExprMutator::VisitPrimExpr` function was not called for expressions that occur within the `StructInfo` of a `relax::MatchCast`. This commit updates `ExprMutator` to resolve this bug, and applies an analogous fix for `ExprVisitor`. Co-authored-by: Chris Sullivan --- src/relax/ir/expr_functor.cc | 22 ++++++++++++++----- tests/python/relax/test_bind_symbolic_vars.py | 22 +++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index dbfaf60fecfc..63c74db7e33e 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -257,6 +257,7 @@ RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); + this->VisitExprDepStructInfoField(binding->struct_info); this->VisitVarDef(binding->var); } @@ -690,16 +691,25 @@ void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { } void ExprMutator::VisitBinding_(const MatchCastNode* binding) { - Var new_var = this->VisitVarDef(binding->var); Expr new_value = this->VisitExpr(binding->value); + StructInfo new_struct_info = this->VisitExprDepStructInfoField(binding->struct_info); - // re-emit old binding if nothing changes - if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { + Var new_var = this->VisitVarDef(binding->var); + + if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && + new_struct_info.same_as(binding->struct_info)) { + // re-emit old binding if nothing changes builder_->EmitNormalized(GetRef(binding)); - } else { - new_value = builder_->NormalizeArgument(new_value); - builder_->EmitNormalized(MatchCast(new_var, new_value, binding->struct_info, binding->span)); + return; } + + new_value = builder_->NormalizeArgument(new_value); + new_var = WithStructInfo(new_var, new_struct_info); + + var_remap_[binding->var->vid] = new_var; + var_remap_[new_var->vid] = new_var; + + builder_->EmitNormalized(MatchCast(new_var, new_value, new_struct_info, binding->span)); } BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py index 82798c56dfff..18246d224b65 100644 --- a/tests/python/relax/test_bind_symbolic_vars.py +++ b/tests/python/relax/test_bind_symbolic_vars.py @@ -286,5 +286,27 @@ def expected(A: R.Tensor(["M", 32])): tvm.ir.assert_structural_equal(expected, after) +def test_bind_inside_match_cast(): + """Symbolic variables may occur within R.match_cast""" + + @R.function(private=True) + def before(A: R.Tensor(["M", "N"]), B: R.Tensor(ndim=2)): + M = T.int64() + N = T.int64() + C = R.match_cast(B, R.Tensor([M, N])) + D = R.add(A, C) + return D + + @R.function(private=True) + def expected(A: R.Tensor(["M", 32]), B: R.Tensor(ndim=2)): + M = T.int64() + C = R.match_cast(B, R.Tensor([M, 32])) + D = R.add(A, C) + return D + + after = before.bind_symbolic_vars({"N": 32}) + tvm.ir.assert_structural_equal(expected, after) + + if __name__ == "__main__": tvm.testing.main()