diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py index 6d0ca302018c..6b411d356d7e 100644 --- a/python/tvm/relax/backend/__init__.py +++ b/python/tvm/relax/backend/__init__.py @@ -16,7 +16,26 @@ # under the License. """Relax backends""" +from tvm.target import Target + from . import contrib from .dispatch_sampling import DispatchSampling from .dispatch_sort_scan import DispatchSortScan from .pattern_registry import get_pattern, get_patterns_with_prefix + + +def get_default_pipeline(target: Target): + """Get the default Relax compilation pipeline for the given target.""" + if target.kind.name == "cuda": + from . import cuda # pylint: disable=import-outside-toplevel + + return cuda.get_default_pipeline(target) + if target.kind.name == "llvm": + from . import cpu_generic # pylint: disable=import-outside-toplevel + + return cpu_generic.get_default_pipeline(target) + # Todo(tvm-team): support gpu-generic + raise ValueError( + f"Target {target} is not yet supported by default pipeline. " + "Please lower and build the IRModule manually." + ) diff --git a/python/tvm/relax/backend/cpu_generic/__init__.py b/python/tvm/relax/backend/cpu_generic/__init__.py new file mode 100644 index 000000000000..e1cd26686cd7 --- /dev/null +++ b/python/tvm/relax/backend/cpu_generic/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax CPU backend compilation pipeline and other passes.""" +from .pipeline import ( + finalize_passes, + get_default_pipeline, + legalize_passes, + library_dispatch_passes, +) diff --git a/python/tvm/relax/backend/cpu_generic/pipeline.py b/python/tvm/relax/backend/cpu_generic/pipeline.py new file mode 100644 index 000000000000..74d951b817b1 --- /dev/null +++ b/python/tvm/relax/backend/cpu_generic/pipeline.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax CPU backend compilation pipeline and other passes.""" +import tvm +from tvm import relax + + +def library_dispatch_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default library dispatch passes for CPU backend.""" + return [] + + +def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default legalization passes for CPU backend.""" + return [ + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + ] + + +def dataflow_lower_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default dataflow lowering passes for CPU backend.""" + return [ + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + ] + + +def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default finalization passes for CPU backend.""" + return [ + relax.transform.StaticPlanBlockMemory(), + relax.transform.LowerAllocTensor(), + relax.transform.KillAfterLastUse(), + relax.transform.LowerRuntimeBuiltin(), + relax.transform.ComputePrimValue(), + relax.transform.VMShapeLower(), + relax.transform.AttachGlobalSymbol(), + ] + + +def get_default_pipeline(target: tvm.target.Target): + """Return the default compilation pipeline for CPU.""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext): + with target: + seq = tvm.transform.Sequential( + library_dispatch_passes(target) + + legalize_passes(target) + + dataflow_lower_passes(target) + + finalize_passes(target) + ) + mod = seq(mod) + return mod + + return _pipeline diff --git a/python/tvm/relax/backend/cuda/__init__.py b/python/tvm/relax/backend/cuda/__init__.py new file mode 100644 index 000000000000..f4458f4b55d1 --- /dev/null +++ b/python/tvm/relax/backend/cuda/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax CUDA backend compilation pipeline and other passes.""" +from .pipeline import ( + finalize_passes, + get_default_pipeline, + legalize_passes, + library_dispatch_passes, +) diff --git a/python/tvm/relax/backend/cuda/pipeline.py b/python/tvm/relax/backend/cuda/pipeline.py new file mode 100644 index 000000000000..d5c4c0856165 --- /dev/null +++ b/python/tvm/relax/backend/cuda/pipeline.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax CUDA backend compilation pipeline and other passes.""" +import tvm +from tvm import dlight as dl +from tvm import relax + + +def library_dispatch_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default library dispatch passes for CUDA backend.""" + return [ + relax.backend.DispatchSampling(), + relax.backend.DispatchSortScan(), + ] + + +def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default legalization passes for CUDA backend.""" + return [ + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + ] + + +def dataflow_lower_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default dataflow lowering passes for CUDA backend.""" + return [ + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + ] + + +def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default finalization passes for CUDA backend.""" + return [ + relax.transform.StaticPlanBlockMemory(), + relax.transform.RewriteCUDAGraph(), + relax.transform.LowerAllocTensor(), + relax.transform.KillAfterLastUse(), + relax.transform.LowerRuntimeBuiltin(), + relax.transform.ComputePrimValue(), + relax.transform.VMShapeLower(), + relax.transform.AttachGlobalSymbol(), + ] + + +def get_default_pipeline(target: tvm.target.Target): + """Return the default compilation pipeline for CUDA.""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext): + with target: + seq = tvm.transform.Sequential( + library_dispatch_passes(target) + + legalize_passes(target) + + dataflow_lower_passes(target) + + finalize_passes(target) + ) + mod = seq(mod) + return mod + + return _pipeline diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index 8aa41490c66b..86711477b274 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -15,14 +15,17 @@ # specific language governing permissions and limitations # under the License. import numpy as np + import tvm import tvm.testing from tvm import relax -from tvm.script import relax as R, tir as T +from tvm.script import relax as R +from tvm.script import tir as T def test_pipeline_compile(): - pipeline = relax.get_pipeline() + target = tvm.target.Target("llvm", host="llvm") + pipeline = relax.backend.get_default_pipeline(target) @tvm.script.ir_module class Mod: @@ -33,7 +36,6 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): mod = Mod mod = pipeline(mod) - target = tvm.target.Target("llvm", host="llvm") ex = relax.build(mod, target) x_np = np.random.rand(3, 4).astype(np.float32) @@ -48,7 +50,8 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): def test_pipeline_with_kv_cache(): """A dummy pipline that simulates KV update.""" - pipeline = relax.get_pipeline() + target = tvm.target.Target("llvm", host="llvm") + pipeline = relax.backend.get_default_pipeline(target) @tvm.script.ir_module class Mod: @@ -92,8 +95,6 @@ def main( mod = Mod mod = pipeline(mod) - target = tvm.target.Target("llvm", host="llvm") - ex = relax.build(mod, target) num_steps = 8