From d78b93249bbf2872a13fb31bf37d643879178507 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 3 Oct 2023 11:24:19 -0500 Subject: [PATCH] [Unity] Include LegalizeOps in the default relax.build lowering flow Prior to this commit, `relax.transform.LegalizeOps` had to be called explicitly before `relax.build`. This commit adds `LegalizeOps` to the lowering flow, to simplify the calling steps for an end-user. If the `IRModule` contains no legalizable functions, a second legalization pass has no effect. Some test cases relied on the previous behavior as an implicit assertion that operator fusion patterns were applied. That is, by omitting `LegalizeOps`, a successful compilation with `relax.build` would only occur if all legalizable operators had already been removed, and so an incorrect fusion pattern would result in a failure to build the module. While these tests would be better expressed by comparing against an expected fused pattern, updating the tests is outside the scope of this PR. To allow these tests to keep their implicit assertions, a `"relax.transform.apply_legalize_ops"` config can be used to disable the `LegalizeOps` pass. 
--- python/tvm/relax/vm_build.py | 1 + src/relax/transform/legalize_ops.cc | 9 ++++++++- tests/python/relax/test_codegen_cublas.py | 10 ++++++---- tests/python/relax/test_codegen_cudnn.py | 10 ++++++---- tests/python/relax/test_codegen_cutlass.py | 10 ++++++---- tests/python/relax/test_codegen_dnnl.py | 6 ++---- tests/python/relax/test_codegen_tensorrt.py | 6 ++---- tests/python/relax/test_codegen_tir_cutlass.py | 1 - tests/python/relax/test_dataflow_pattern.py | 2 -- tests/python/relax/test_e2e_op_dynamic.py | 6 ++---- tests/python/relax/test_frontend_stablehlo.py | 4 ---- tests/python/relax/test_op_gradient_numeric.py | 6 ++---- tests/python/relax/test_training_optimizer_numeric.py | 4 +--- tests/python/relax/test_transform_gradient_numeric.py | 4 +--- tests/python/relax/test_vm_execbuilder.py | 3 +-- 15 files changed, 38 insertions(+), 44 deletions(-) diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 8b3337995785..a54c0154fc32 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -307,6 +307,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): lowering_passes = tvm.transform.Sequential( [ + relax.transform.LegalizeOps(), relax.transform.RewriteDataflowReshape(), relax.transform.ToNonDataflow(), relax.transform.RemovePurityChecking(), diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 170967d28281..a557a41f8eb7 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -31,6 +31,8 @@ namespace tvm { namespace relax { +TVM_REGISTER_PASS_CONFIG_OPTION("relax.transform.apply_legalize_ops", Bool); + /*! * \brief Check if a given Tensor/Shape/TupleStructInfo contains shapes whose * values are all known. 
@@ -206,7 +208,12 @@ namespace transform { Pass LegalizeOps(Optional> cmap, bool enable_warning) { runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext pc) { - return LegalizeMutator(mod, cmap, enable_warning).Transform(); + bool apply_legalize_ops = + pc->GetConfig("relax.transform.apply_legalize_ops").value_or(Bool(true))->value; + if (apply_legalize_ops) { + mod = LegalizeMutator(mod, cmap, enable_warning).Transform(); + } + return mod; }; return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/0, diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index fc2256531eed..6c8f6bc33501 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -42,11 +42,13 @@ def reset_seed(): def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): - if legalize: - mod = relax.transform.LegalizeOps()(mod) - dev = tvm.device(target, 0) - with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}): + with tvm.transform.PassContext( + config={ + "relax.backend.use_cuda_graph": cuda_graph, + "relax.transform.apply_legalize_ops": legalize, + } + ): ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py index 5ba638c11c7a..c91355923298 100644 --- a/tests/python/relax/test_codegen_cudnn.py +++ b/tests/python/relax/test_codegen_cudnn.py @@ -110,11 +110,13 @@ def get_result_with_relax_cudnn_offload(mod, np_inputs, cuda_graph=False): def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): - if legalize: - mod = relax.transform.LegalizeOps()(mod) - dev = tvm.device(target, 0) - with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}): + with tvm.transform.PassContext( + config={ + "relax.backend.use_cuda_graph": cuda_graph, + "relax.transform.apply_legalize_ops": 
legalize, + } + ): ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 151e05e9b6fa..9c7bb1dbbcd7 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -86,10 +86,12 @@ def main( def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False): - if legalize: - mod = relax.transform.LegalizeOps()(mod) # For cpu reference, nop for cutlass. - - with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": cuda_graph}): + with tvm.transform.PassContext( + config={ + "relax.backend.use_cuda_graph": cuda_graph, + "relax.transform.apply_legalize_ops": legalize, + } + ): ex = relax.build(mod, target) dev = tvm.device(target, 0) diff --git a/tests/python/relax/test_codegen_dnnl.py b/tests/python/relax/test_codegen_dnnl.py index 66f442f16519..fe4590f85a12 100644 --- a/tests/python/relax/test_codegen_dnnl.py +++ b/tests/python/relax/test_codegen_dnnl.py @@ -52,14 +52,12 @@ def main( def build_and_run(mod, inputs, legalize=False): - if legalize: - mod = relax.transform.LegalizeOps()(mod) - target = tvm.target.Target("llvm") dev = tvm.cpu() inputs = [tvm.nd.array(inp, dev) for inp in inputs] - ex = relax.build(mod, target) + with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}): + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] return f(*inputs).numpy() diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py index 595103bc5fb7..23dc7d887f4c 100644 --- a/tests/python/relax/test_codegen_tensorrt.py +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -53,11 +53,9 @@ def main( def build_and_run(mod, inputs_np, target, legalize=False): - if legalize: - mod = relax.transform.LegalizeOps()(mod) - dev = tvm.device(target, 0) - ex = relax.build(mod, target) + with 
tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}): + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) f = vm["main"] inputs = [tvm.nd.array(inp, dev) for inp in inputs_np] diff --git a/tests/python/relax/test_codegen_tir_cutlass.py b/tests/python/relax/test_codegen_tir_cutlass.py index 9c960ed355d3..a14ca7ac36d7 100644 --- a/tests/python/relax/test_codegen_tir_cutlass.py +++ b/tests/python/relax/test_codegen_tir_cutlass.py @@ -65,7 +65,6 @@ def build(mod): def build_and_run_reference(mod, inputs_np): - mod = relax.transform.LegalizeOps()(mod) dev = tvm.device("llvm", 0) ex = relax.build(mod, "llvm") vm = relax.VirtualMachine(ex, dev) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index a8b71aa5ebfe..520fb873221c 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1181,7 +1181,6 @@ def expected( # make sure it builds mod = tvm.IRModule() mod["main"] = rewritten - mod = rx.transform.LegalizeOps()(mod) rx.build(mod, target="llvm") @@ -1279,7 +1278,6 @@ def rewriter(matchings, _): # make sure it builds mod = tvm.IRModule() mod["main"] = rewritten - mod = rx.transform.LegalizeOps()(mod) rx.build(mod, target="llvm") diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py index 63c71b797915..641469172f97 100644 --- a/tests/python/relax/test_e2e_op_dynamic.py +++ b/tests/python/relax/test_e2e_op_dynamic.py @@ -49,8 +49,7 @@ def main(x: R.Tensor((8, 9, 10, 10), "float32"), begin: R.Tensor((4,),"int64"), gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides) return gv # fmt: on - mod = LegalizeOps()(DynamicStridedSlice) - vm = build(mod) + vm = build(DynamicStridedSlice) x_np = np.random.rand(8, 9, 10, 10).astype(np.float32) data_nd = tvm.nd.array(x_np, dev) @@ -83,8 +82,7 @@ def main(x: R.Tensor(("m", "n", 10, 10), "float32"), begin: 
R.Tensor((4,),"int64 gv: R.Tensor("float32", ndim=4) = R.dynamic_strided_slice(x, begin, end, strides) return gv # fmt: on - mod = LegalizeOps()(DynamicStridedSlice) - vm = build(mod) + vm = build(DynamicStridedSlice) x_np = np.random.rand(8, 9, 10, 10).astype(np.float32) data_nd = tvm.nd.array(x_np, dev) diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py index 4152d50b8da1..d3068f29c73d 100644 --- a/tests/python/relax/test_frontend_stablehlo.py +++ b/tests/python/relax/test_frontend_stablehlo.py @@ -115,9 +115,6 @@ def check_correctness( # Run the jax jitted model with the input jax numpy data jax_output = jax_jit_mod(*inputs_jnp) - # Legalize the Relax Operators into TensorIR - # TODO (relax-team): add LegalizeOps in default seq in vm_build - ir_mod = relax.transform.LegalizeOps()(ir_mod) # TODO (yongwww): support multiple targets, # "llvm" should be good for this check target = tvm.target.Target("llvm", host="llvm") @@ -157,7 +154,6 @@ def get_vm_res( out: Union[tvm.nd.NDArray, List[tvm.nd.NDArray]] inference result """ - ir_mod = relax.transform.LegalizeOps()(ir_mod) target = tvm.target.Target("llvm", host="llvm") # Compile and run ex = relax.build(ir_mod, target) diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py index 4b4c5cabc416..bc5cb0f5bec7 100644 --- a/tests/python/relax/test_op_gradient_numeric.py +++ b/tests/python/relax/test_op_gradient_numeric.py @@ -122,8 +122,7 @@ def _is_call_no_grad(expr): out = forward_bb.emit_output(call) forward_bb.emit_func_output(out) forward_mod = forward_bb.get() - forward_lower_mod = LegalizeOps()(forward_mod) - forward_ex = relax.build(forward_lower_mod, target) + forward_ex = relax.build(forward_mod, target) forward_vm = relax.VirtualMachine(forward_ex, dev) # Generate weights @@ -187,8 +186,7 @@ def forward(*inputs): grad_bb.emit_func_output(out) grad_mod = grad_bb.get() - grad_lower_mod = 
LegalizeOps()(grad_mod) - grad_ex = relax.build(grad_lower_mod, target) + grad_ex = relax.build(grad_mod, target) grad_vm = relax.VirtualMachine(grad_ex, dev) # tvm.runtime.NDArray inputs diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index 23db8987f12d..8acf7ad66b2a 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -23,15 +23,13 @@ from tvm import relax from tvm import IRModule from tvm.relax.training.optimizer import Adam, SGD, MomentumSGD -from tvm.relax.transform import LegalizeOps from tvm.script.parser import relax as R from tvm.runtime.relax_vm import VirtualMachine from tvm.testing import assert_allclose def _legalize_and_build(mod: IRModule, target, dev): - lowered_mod = LegalizeOps()(mod) - ex = relax.build(lowered_mod, target) + ex = relax.build(mod, target) vm = VirtualMachine(ex, dev) return vm diff --git a/tests/python/relax/test_transform_gradient_numeric.py b/tests/python/relax/test_transform_gradient_numeric.py index 7585ecf1f6b7..38a63406e88c 100644 --- a/tests/python/relax/test_transform_gradient_numeric.py +++ b/tests/python/relax/test_transform_gradient_numeric.py @@ -22,12 +22,10 @@ from tvm.testing import assert_allclose from tvm.testing.utils import check_numerical_grads from tvm.script.parser import ir as I, relax as R -from tvm.relax.transform import LegalizeOps def _legalize_and_build(mod, target, dev): - lowered_mod = LegalizeOps()(mod) - ex = relax.build(lowered_mod, target) + ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) return vm diff --git a/tests/python/relax/test_vm_execbuilder.py b/tests/python/relax/test_vm_execbuilder.py index 4c15d8013bf3..b2d9edd34661 100644 --- a/tests/python/relax/test_vm_execbuilder.py +++ b/tests/python/relax/test_vm_execbuilder.py @@ -277,8 +277,7 @@ def main(inp: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype=" 
R.output(gv) return gv - mod = relax.transform.LegalizeOps()(Module) - ex = relax.build(mod, "llvm") + ex = relax.build(Module, "llvm") vm = relax.VirtualMachine(ex, tvm.cpu()) correct_input = tvm.nd.array(np.random.normal(size=(10, 10)).astype("float32"))