From df85d05e1a629f80bf2f164e708d7925122af3e0 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Tue, 28 Mar 2023 14:56:46 -0700 Subject: [PATCH] [Unity][Fix] Copy over module attrs in FuseTIR --- src/relax/transform/fuse_tir.cc | 6 +++++- tests/python/relax/test_transform_fuse_tir.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index e90d6e4bc1d1..f4a31853e39c 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -604,7 +604,11 @@ class TIRFuseMutator : public ExprMutator { mutator.builder_->AddFunction(update_func, gv->name_hint); } } - return mutator.builder_->GetContextIRModule(); + + // Step 3. Copy over module attributes and return. + auto modified_mod = mutator.builder_->GetContextIRModule(); + if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict); + return modified_mod; } private: diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 7a8aa4d39f67..356e28d6e910 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -47,7 +47,7 @@ def before(): gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) bb.emit_func_output(gv) - return bb.get() + return bb.get().with_attrs({"foo": "bar"}) def expected(): def fused_add_exp_squeeze(x, p0): @@ -63,7 +63,7 @@ def fused_add_exp_squeeze(x, p0): with bb.dataflow(): gv = bb.emit_output(bb.call_te(fused_add_exp_squeeze, x, p0)) bb.emit_func_output(gv) - return bb.get() + return bb.get().with_attrs({"foo": "bar"}) _check(before(), expected())