diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index d7b594503124..01deee8197f9 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -85,11 +85,11 @@ def __init__(self, out_tuple_var: relax.Var, input_params: set) -> None: self.input_params = input_params self.var_liveness_end = {} - def visit_dataflow_block_(self, block: relax.DataflowBlock) -> None: + def visit_binding_block_(self, block: relax.BindingBlock) -> None: for binding in reversed(block.bindings): self.visit_binding(binding) - def visit_dataflow_var_(self, op: relax.DataflowVar) -> None: + def visit_var_(self, op: relax.Var) -> None: if op in self.input_params: self.last_appear_in_var_binding.append(op) self.input_params.remove(op) diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index bfc1d282ab18..0fc08d5ef487 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -45,19 +45,17 @@ def main_transform_params( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): cls = Before - with R.dataflow(): - lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] - lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0] - lv2 = R.call_tir( - cls.transform_layout_IOHW_to_OIHW, - (lv1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), - ) - gv: R.Tuple( - R.Tensor((16, 16, 3, 3), dtype="float32"), - R.Tensor((16, 3, 3, 3), dtype="float32"), - ) = (lv, lv2) - R.output(gv) + lv: R.Tensor((16, 16, 3, 3), dtype="float32") = params[1] + lv1: R.Tensor((3, 16, 3, 3), dtype="float32") = params[0] + lv2 = R.call_tir( + cls.transform_layout_IOHW_to_OIHW, + (lv1,), + out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + ) + gv: R.Tuple( + R.Tensor((16, 16, 3, 3), dtype="float32"), + R.Tensor((16, 3, 3, 3), dtype="float32"), + ) = (lv, lv2) return gv @I.ir_module @@ -77,24 +75,18 @@ def transform_layout_IOHW_to_OIHW( @R.function def main_transform_params() -> R.Tuple(R.Object, R.Object): cls = Expected - with R.dataflow(): - lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) - lv1: R.Object = R.call_packed( - "set_item", R.prim_value(0), lv, sinfo_args=(R.Object,) - ) - lv2: R.Tuple = R.vm.kill_object(lv) - lv1_1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) - lv3 = R.call_tir( - cls.transform_layout_IOHW_to_OIHW, - (lv1_1,), - out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), - ) - lv4: R.Object = R.call_packed( - "set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,) - ) - lv5: R.Tuple = R.vm.kill_object(lv1_1) - gv: R.Tuple(R.Object, R.Object) = (lv1, lv4) - R.output(gv) + lv: R.Object = R.call_packed("get_item", R.prim_value(1), sinfo_args=(R.Object,)) + lv1: R.Object = R.call_packed("set_item", R.prim_value(0), lv, sinfo_args=(R.Object,)) + lv2: R.Tuple = R.vm.kill_object(lv) + lv1_1: R.Object = R.call_packed("get_item", R.prim_value(0), sinfo_args=(R.Object,)) + lv3 = R.call_tir( + cls.transform_layout_IOHW_to_OIHW, + (lv1_1,), + out_sinfo=R.Tensor((16, 3, 3, 3), dtype="float32"), + ) + lv4: R.Object = R.call_packed("set_item", R.prim_value(1), lv3, sinfo_args=(R.Object,)) + lv5: R.Tuple = R.vm.kill_object(lv1_1) + gv: R.Tuple(R.Object, R.Object) = (lv1, lv4) return gv after = LazyTransformParams()(Before)