diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index 8b3337995785..a54c0154fc32 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -307,6 +307,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
 
     lowering_passes = tvm.transform.Sequential(
         [
+            relax.transform.LegalizeOps(),
             relax.transform.RewriteDataflowReshape(),
             relax.transform.ToNonDataflow(),
             relax.transform.RemovePurityChecking(),
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index 170967d28281..a557a41f8eb7 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -31,6 +31,8 @@
 namespace tvm {
 namespace relax {
 
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", Bool);
+
 /*!
  * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose
  * values are all known.
@@ -206,7 +208,12 @@ namespace transform {
 Pass LegalizeOps(Optional<Map<String, PackedFunc>> cmap, bool enable_warning) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
       [=](IRModule mod, PassContext pc) {
-        return LegalizeMutator(mod, cmap, enable_warning).Transform();
+        bool apply_legalize_ops =
+            pc->GetConfig<Bool>("relax.transform.apply_legalize_ops").value_or(Bool(true))->value;
+        if (apply_legalize_ops) {
+          mod = LegalizeMutator(mod, cmap, enable_warning).Transform();
+        }
+        return mod;
       };
   return CreateModulePass(/*pass_function=*/pass_func,
                           /*opt_level=*/0,
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index fc2256531eed..6c8f6bc33501 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -42,11 +42,13 @@ def reset_seed():
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
         vm = relax.VirtualMachine(ex, dev)
         f = vm["main"]
diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py
index 5ba638c11c7a..c91355923298 100644
--- a/tests/python/relax/test_codegen_cudnn.py
+++ b/tests/python/relax/test_codegen_cudnn.py
@@ -110,11 +110,13 @@ def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False):
 
 
 def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
         vm = relax.VirtualMachine(ex, dev)
         f = vm["main"]
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 151e05e9b6fa..9c7bb1dbbcd7 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -86,10 +86,12 @@ def main(
 
 
 def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)  # For cpu reference, nop for cutlass.
-
-    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}):
+    with tvm.transform.PassContext(
+        config={
+            "relax.backend.use_cuda_graph": cuda_graph,
+            "relax.transform.apply_legalize_ops": legalize,
+        }
+    ):
         ex = relax.build(mod, target)
 
     dev = tvm.device(target, 0)
diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py
index 66f442f16519..fe4590f85a12 100644
--- a/tests/python/relax/test_codegen_dnnl.py
+++ b/tests/python/relax/test_codegen_dnnl.py
@@ -52,14 +52,12 @@ def main(
 
 
 def build_and_run(mod, inputs, legalize=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     target = tvm.target.Target("llvm")
     dev = tvm.cpu()
     inputs = [tvm.nd.array(inp, dev) for inp in inputs]
 
-    ex = relax.build(mod, target)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     return f(*inputs).numpy()
diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py
index 595103bc5fb7..23dc7d887f4c 100644
--- a/tests/python/relax/test_codegen_tensorrt.py
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -53,11 +53,9 @@ def main(
 
 
 def build_and_run(mod, inputs_np, target, legalize=False):
-    if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
-
     dev = tvm.device(target, 0)
-    ex = relax.build(mod, target)
+    with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
+        ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py
index 9c960ed355d3..a14ca7ac36d7 100644
--- a/tests/python/relax/test_codegen_tir_cutlass.py
+++ b/tests/python/relax/test_codegen_tir_cutlass.py
@@ -65,7 +65,6 @@ def build(mod):
 
 
 def build_and_run_reference(mod, inputs_np):
-    mod = relax.transform.LegalizeOps()(mod)
     dev = tvm.device("llvm", 0)
     ex = relax.build(mod, "llvm")
     vm = relax.VirtualMachine(ex, dev)
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index a8b71aa5ebfe..520fb873221c 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1181,7 +1181,6 @@ def expected(
     # make sure it builds
     mod = tvm.IRModule()
    mod["main"] = rewritten
-    mod = rx.transform.LegalizeOps()(mod)
 
     rx.build(mod, target="llvm")
 
@@ -1279,7 +1278,6 @@ def rewriter(matchings, _):
     # make sure it builds
     mod = tvm.IRModule()
     mod["main"] = rewritten
-    mod = rx.transform.LegalizeOps()(mod)
 
     rx.build(mod, target="llvm")
 
diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py
index 63c71b797915..641469172f97 100644
--- a/tests/python/relax/test_e2e_op_dynamic.py
+++ b/tests/python/relax/test_e2e_op_dynamic.py
@@ -49,8 +49,7 @@ def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"),
             gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
             return gv
     # fmt: on
-    mod = LegalizeOps()(DynamicStridedSlice)
-    vm = build(mod)
+    vm = build(DynamicStridedSlice)
 
     x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
     data_nd = tvm.nd.array(x_np, dev)
@@ -83,8 +82,7 @@ def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: R.Tensor((4,),"int64
             gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides)
             return gv
     # fmt: on
-    mod = LegalizeOps()(DynamicStridedSlice)
-    vm = build(mod)
+    vm = build(DynamicStridedSlice)
 
     x_np = np.random.rand(8, 9, 10, 10).astype(np.float32)
     data_nd = tvm.nd.array(x_np, dev)
diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py
index 4152d50b8da1..d3068f29c73d 100644
--- a/tests/python/relax/test_frontend_stablehlo.py
+++ b/tests/python/relax/test_frontend_stablehlo.py
@@ -115,9 +115,6 @@ def check_correctness(
     # Run the jax jitted model with the input jax numpy data
     jax_output = jax_jit_mod(*inputs_jnp)
 
-    # Legalize the Relax Operators into TensorIR
-    # TODO (relax-team): add LegalizeOps in default seq in vm_build
-    ir_mod = relax.transform.LegalizeOps()(ir_mod)
     # TODO (yongwww): support multiple targets,
     # "llvm" should be good for this check
     target = tvm.target.Target("llvm", host="llvm")
@@ -157,7 +154,6 @@ def get_vm_res(
     out: Union[tvm.nd.NDArray, List[tvm.nd.NDArray]]
         inference result
     """
-    ir_mod = relax.transform.LegalizeOps()(ir_mod)
     target = tvm.target.Target("llvm", host="llvm")
     # Compile and run
     ex = relax.build(ir_mod, target)
diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py
index 4b4c5cabc416..bc5cb0f5bec7 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -122,8 +122,7 @@ def _is_call_no_grad(expr):
             out = forward_bb.emit_output(call)
         forward_bb.emit_func_output(out)
     forward_mod = forward_bb.get()
-    forward_lower_mod = LegalizeOps()(forward_mod)
-    forward_ex = relax.build(forward_lower_mod, target)
+    forward_ex = relax.build(forward_mod, target)
     forward_vm = relax.VirtualMachine(forward_ex, dev)
 
     # Generate weights
@@ -187,8 +186,7 @@ def forward(*inputs):
         grad_bb.emit_func_output(out)
     grad_mod = grad_bb.get()
 
-    grad_lower_mod = LegalizeOps()(grad_mod)
-    grad_ex = relax.build(grad_lower_mod, target)
+    grad_ex = relax.build(grad_mod, target)
     grad_vm = relax.VirtualMachine(grad_ex, dev)
 
     # tvm.runtime.NDArray inputs
diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py
index 23db8987f12d..8acf7ad66b2a 100644
--- a/tests/python/relax/test_training_optimizer_numeric.py
+++ b/tests/python/relax/test_training_optimizer_numeric.py
@@ -23,15 +23,13 @@
 from tvm import relax
 from tvm import IRModule
 from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD
-from tvm.relax.transform import LegalizeOps
 from tvm.script.parser import relax as R
 from tvm.runtime.relax_vm import VirtualMachine
 from tvm.testing import assert_allclose
 
 
 def _legalize_and_build(mod: IRModule, target, dev):
-    lowered_mod = LegalizeOps()(mod)
-    ex = relax.build(lowered_mod, target)
+    ex = relax.build(mod, target)
     vm = VirtualMachine(ex, dev)
     return vm
 
diff --git a/tests/python/relax/test_transform_gradient_numeric.py b/tests/python/relax/test_transform_gradient_numeric.py
index 7585ecf1f6b7..38a63406e88c 100644
--- a/tests/python/relax/test_transform_gradient_numeric.py
+++ b/tests/python/relax/test_transform_gradient_numeric.py
@@ -22,12 +22,10 @@
 from tvm.testing import assert_allclose
 from tvm.testing.utils import check_numerical_grads
 from tvm.script.parser import ir as I, relax as R
-from tvm.relax.transform import LegalizeOps
 
 
 def _legalize_and_build(mod, target, dev):
-    lowered_mod = LegalizeOps()(mod)
-    ex = relax.build(lowered_mod, target)
+    ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     return vm
 
diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py
index 4c15d8013bf3..b2d9edd34661 100644
--- a/tests/python/relax/test_vm_execbuilder.py
+++ b/tests/python/relax/test_vm_execbuilder.py
@@ -277,8 +277,7 @@ def main(inp: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="
                 R.output(gv)
             return gv
 
-    mod = relax.transform.LegalizeOps()(Module)
-    ex = relax.build(mod, "llvm")
+    ex = relax.build(Module, "llvm")
     vm = relax.VirtualMachine(ex, tvm.cpu())
 
     correct_input = tvm.nd.array(np.random.normal(size=(10, 10)).astype("float32"))
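A minimal usage sketch of the behaviour this patch enables (not taken from the diff; `MyModule` is a hypothetical Relax `IRModule` defined elsewhere): `relax.build` now runs `LegalizeOps` as part of its default lowering sequence, and the newly registered `relax.transform.apply_legalize_ops` option, which defaults to `True`, can switch that step off for a particular build.

```python
# Sketch only; MyModule stands in for any Relax IRModule with high-level ops.
import tvm
from tvm import relax

# LegalizeOps is now part of the default relax.build pipeline, so no explicit
# relax.transform.LegalizeOps()(MyModule) call is needed beforehand.
ex = relax.build(MyModule, target="llvm")

# Skip legalization (e.g. when an external codegen should consume the
# high-level ops directly) by disabling the config option, which defaults to True.
with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": False}):
    ex = relax.build(MyModule, target="llvm")
```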