Skip to content

Commit d91fe45

Browse files
authored
[Transform][Bugfix] Handle non-composite lambda functions in FuseOps (#16598)
Prior to this commit, calling `FuseOpsByPattern` with `annotate_codegen=True` would cause an error when encountering a lambda function. This was caused by the `CompositeFunctionAnnotator` asserting that all `relax::Function` encountered must have the `kComposite` attribute. While this is true for all lambda functions produced by `FuseOpsByPattern`, the user may have defined other lambda functions as well. This commit updates `CompositeFunctionAnnotator` to ignore lambda functions that do not have a `kComposite` attribute.
1 parent e5bfb02 commit d91fe45

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

src/relax/transform/fuse_ops.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,10 +1238,14 @@ class CompositeFunctionAnnotator : public ExprMutator {
12381238

12391239
Expr VisitExpr_(const FunctionNode* func_node) final {
12401240
Function f_inner = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
1241-
auto composite_name = func_node->GetAttr<String>(attr::kComposite);
1241+
1242+
if (!func_node->GetAttr<String>(attr::kComposite)) {
1243+
// This lambda function doesn't have `attr::kComposite`, so it
1244+
// was not produced by FuseOps.
1245+
return std::move(f_inner);
1246+
}
12421247

12431248
f_inner = WithoutAttr(std::move(f_inner), tvm::relax::attr::kPrimitive);
1244-
ICHECK(composite_name);
12451249

12461250
Array<Var> param_vars;
12471251
Array<Expr> params;

tests/python/relax/test_transform_fuse_ops_by_pattern.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,60 @@ def test_annotate_codegen():
530530
)
531531

532532

533+
@pytest.mark.parametrize("annotate_codegen", [True, False])
534+
def test_no_op_if_no_patterns_match(annotate_codegen):
535+
"""If no matches occur, FuseOpsByPattern is a no-op"""
536+
check(
537+
Conv2dReLU,
538+
[],
539+
Conv2dReLU,
540+
annotate_codegen=annotate_codegen,
541+
)
542+
543+
544+
@pytest.mark.parametrize("annotate_codegen", [True, False])
545+
def test_unmatched_calls_may_include_lambda_functions(annotate_codegen):
546+
"""If no matches occur, FuseOpsByPattern is a no-op
547+
548+
This is a regression test. Previous implementations of
549+
CompositeFunctionAnnotator assumed that all lambda functions
550+
resulted from FuseOps, and would contain the `kComposite`
551+
attribute.
552+
"""
553+
554+
@tvm.script.ir_module
555+
class Module:
556+
@R.function
557+
def main(
558+
data: R.Tensor((1, 64, 56, 56), "float32"),
559+
weight1: R.Tensor((64, 64, 3, 3), "float32"),
560+
):
561+
with R.dataflow():
562+
conv1 = R.nn.relu(R.nn.conv2d(data, weight1, padding=(1, 1)))
563+
R.output(conv1)
564+
565+
return conv1
566+
567+
@R.function
568+
def unrelated_function(A: R.Tensor([16, 16], dtype="float16")):
569+
@R.function
570+
def inner_func(B: R.Tensor([16, 16], dtype="float16")):
571+
with R.dataflow():
572+
C = R.multiply(B, R.const(2, "float16"))
573+
R.output(C)
574+
return C
575+
576+
D = inner_func(A)
577+
return D
578+
579+
check(
580+
Module,
581+
[],
582+
Module,
583+
annotate_codegen=annotate_codegen,
584+
)
585+
586+
533587
def test_compare_with_merge_composite_path():
534588
x = relax.Var("x", relax.TensorStructInfo([10, 10], "float32"))
535589
y = relax.Var("y", relax.TensorStructInfo([10, 10], "float32"))

0 commit comments

Comments
 (0)