diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index e90d6e4bc1d1..f4a31853e39c 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -604,7 +604,11 @@ class TIRFuseMutator : public ExprMutator { mutator.builder_->AddFunction(update_func, gv->name_hint); } } - return mutator.builder_->GetContextIRModule(); + + // Step 3. Copy over module attributes and return. + auto modified_mod = mutator.builder_->GetContextIRModule(); + if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict); + return modified_mod; } private: diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 7a8aa4d39f67..356e28d6e910 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -47,7 +47,7 @@ def before(): gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) bb.emit_func_output(gv) - return bb.get() + return bb.get().with_attrs({"foo": "bar"}) def expected(): def fused_add_exp_squeeze(x, p0): @@ -63,7 +63,7 @@ def fused_add_exp_squeeze(x, p0): with bb.dataflow(): gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0)) bb.emit_func_output(gv) - return bb.get() + return bb.get().with_attrs({"foo": "bar"}) _check(before(), expected())