diff --git a/docs/how_to/tutorials/customize_opt.py b/docs/how_to/tutorials/customize_opt.py index 5806d6ce5da1..aab1de1b904d 100644 --- a/docs/how_to/tutorials/customize_opt.py +++ b/docs/how_to/tutorials/customize_opt.py @@ -103,7 +103,7 @@ def forward(self, x): # Import cublas pattern -import tvm.relax.backend.contrib.cublas as _cublas +import tvm.relax.backend.cuda.cublas as _cublas # Define a new pass for CUBLAS dispatch diff --git a/python/tvm/relax/backend/__init__.py b/python/tvm/relax/backend/__init__.py index 6b411d356d7e..2a64ffe27b30 100644 --- a/python/tvm/relax/backend/__init__.py +++ b/python/tvm/relax/backend/__init__.py @@ -16,26 +16,7 @@ # under the License. """Relax backends""" -from tvm.target import Target - -from . import contrib +from . import contrib, cpu_generic, cuda, gpu_generic, metal, rocm from .dispatch_sampling import DispatchSampling from .dispatch_sort_scan import DispatchSortScan from .pattern_registry import get_pattern, get_patterns_with_prefix - - -def get_default_pipeline(target: Target): - """Get the default Relax compilation pipeline for the given target.""" - if target.kind.name == "cuda": - from . import cuda # pylint: disable=import-outside-toplevel - - return cuda.get_default_pipeline(target) - if target.kind.name == "llvm": - from . import cpu_generic # pylint: disable=import-outside-toplevel - - return cpu_generic.get_default_pipeline(target) - # Todo(tvm-team): support gpu-generic - raise ValueError( - f"Target {target} is not yet supported by default pipeline. " - "Please lower and build the IRModule manually." - ) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/cuda/cublas.py similarity index 100% rename from python/tvm/relax/backend/contrib/cublas.py rename to python/tvm/relax/backend/cuda/cublas.py diff --git a/python/tvm/relax/backend/contrib/cudnn.py b/python/tvm/relax/backend/cuda/cudnn.py similarity index 100% rename from python/tvm/relax/backend/contrib/cudnn.py rename to python/tvm/relax/backend/cuda/cudnn.py diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/cuda/cutlass.py similarity index 100% rename from python/tvm/relax/backend/contrib/cutlass.py rename to python/tvm/relax/backend/cuda/cutlass.py diff --git a/python/tvm/relax/backend/gpu_generic/__init__.py b/python/tvm/relax/backend/gpu_generic/__init__.py new file mode 100644 index 000000000000..9c5e65fb49b6 --- /dev/null +++ b/python/tvm/relax/backend/gpu_generic/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax Metal backend compilation pipeline and other passes.""" +from .pipeline import ( + finalize_passes, + get_default_pipeline, + legalize_passes, + library_dispatch_passes, +) diff --git a/python/tvm/relax/backend/gpu_generic/pipeline.py b/python/tvm/relax/backend/gpu_generic/pipeline.py new file mode 100644 index 000000000000..86c60114c699 --- /dev/null +++ b/python/tvm/relax/backend/gpu_generic/pipeline.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax generic GPU backend compilation pipeline and other passes.""" +import tvm +from tvm import dlight as dl +from tvm import relax + + +def library_dispatch_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default library dispatch passes for generic GPU backend.""" + return [ + relax.backend.DispatchSampling(), + relax.backend.DispatchSortScan(), + ] + + +def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default legalization passes for generic GPU backend.""" + return [ + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + ] + + +def dataflow_lower_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default dataflow lowering passes for generic GPU backend.""" + return [ + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + ] + + +def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default finalization passes for generic GPU backend.""" + return [ + relax.transform.StaticPlanBlockMemory(), + relax.transform.LowerAllocTensor(), + relax.transform.KillAfterLastUse(), + relax.transform.LowerRuntimeBuiltin(), + relax.transform.ComputePrimValue(), + relax.transform.VMShapeLower(), + relax.transform.AttachGlobalSymbol(), + ] + + +def get_default_pipeline(target: tvm.target.Target): + """Return the default compilation pipeline for generic GPU.""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext): + with target: + seq = tvm.transform.Sequential( + library_dispatch_passes(target) + + legalize_passes(target) + + dataflow_lower_passes(target) + + finalize_passes(target) + ) + mod = seq(mod) + return mod + + return _pipeline diff --git a/python/tvm/relax/backend/metal/__init__.py b/python/tvm/relax/backend/metal/__init__.py new file mode 100644 index 000000000000..ab432bb6efcd --- /dev/null +++ b/python/tvm/relax/backend/metal/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax Metal backend compilation pipeline and other passes.""" diff --git a/python/tvm/relax/backend/contrib/coreml.py b/python/tvm/relax/backend/metal/coreml.py similarity index 100% rename from python/tvm/relax/backend/contrib/coreml.py rename to python/tvm/relax/backend/metal/coreml.py diff --git a/python/tvm/relax/backend/rocm/__init__.py b/python/tvm/relax/backend/rocm/__init__.py new file mode 100644 index 000000000000..27852abd46e2 --- /dev/null +++ b/python/tvm/relax/backend/rocm/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax ROCm backend compilation pipeline and other passes.""" +from .pipeline import ( + finalize_passes, + get_default_pipeline, + legalize_passes, + library_dispatch_passes, +) diff --git a/python/tvm/relax/backend/contrib/hipblas.py b/python/tvm/relax/backend/rocm/hipblas.py similarity index 100% rename from python/tvm/relax/backend/contrib/hipblas.py rename to python/tvm/relax/backend/rocm/hipblas.py diff --git a/python/tvm/relax/backend/rocm/pipeline.py b/python/tvm/relax/backend/rocm/pipeline.py new file mode 100644 index 000000000000..e74039ca8634 --- /dev/null +++ b/python/tvm/relax/backend/rocm/pipeline.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relax ROCm backend compilation pipeline and other passes.""" +import tvm +from tvm import dlight as dl +from tvm import relax + + +def library_dispatch_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default library dispatch passes for ROCm backend.""" + return [ + relax.backend.DispatchSampling(), + relax.backend.DispatchSortScan(), + ] + + +def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default legalization passes for ROCm backend.""" + return [ + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + ] + + +def dataflow_lower_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default dataflow lowering passes for ROCm backend.""" + return [ + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + ] + + +def finalize_passes(target: tvm.target.Target): # pylint: disable=unused-argument + """The default finalization passes for ROCm backend.""" + return [ + relax.transform.StaticPlanBlockMemory(), + relax.transform.LowerAllocTensor(), + relax.transform.KillAfterLastUse(), + relax.transform.LowerRuntimeBuiltin(), + relax.transform.ComputePrimValue(), + relax.transform.VMShapeLower(), + relax.transform.AttachGlobalSymbol(), + ] + + +def get_default_pipeline(target: tvm.target.Target): + """Return the default compilation pipeline for ROCm.""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext): + with target: + seq = tvm.transform.Sequential( + library_dispatch_passes(target) + + legalize_passes(target) + + dataflow_lower_passes(target) + + finalize_passes(target) + ) + mod = seq(mod) + return mod + + return _pipeline diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index fe3dbc99fc15..ebb61ad3e609 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -22,10 +22,11 @@ """ # pylint: disable=unused-argument from typing import Union + import tvm from tvm import meta_schedule as ms -from . import transform, backend +from . import backend, transform def zero_pipeline(*, enable_warning: bool = False): @@ -237,3 +238,76 @@ def _register(func): return func return _register + + +def library_dispatch_passes(target: tvm.target.Target): + """Get the default library dispatch passes for the given target.""" + if target.kind.name == "cuda": + return backend.cuda.library_dispatch_passes(target) + if target.kind.name == "rocm": + return backend.rocm.library_dispatch_passes(target) + if target.kind.name == "metal": + return backend.gpu_generic.library_dispatch_passes(target) + if target.kind.name == "llvm": + return backend.cpu_generic.library_dispatch_passes(target) + # Todo(tvm-team): support gpu-generic + raise ValueError(f"Target {target} is not yet supported by library dispatch passes.") + + +def legalize_passes(target: tvm.target.Target): + """Get the default legalization passes for the given target.""" + if target.kind.name == "cuda": + return backend.cuda.legalize_passes(target) + if target.kind.name == "rocm": + return backend.rocm.legalize_passes(target) + if target.kind.name == "metal": + return backend.gpu_generic.legalize_passes(target) + if target.kind.name == "llvm": + return backend.cpu_generic.legalize_passes(target) + # Todo(tvm-team): support gpu-generic + raise ValueError(f"Target {target} is not yet supported by library dispatch passes.") + + +def dataflow_lower_passes(target: tvm.target.Target): + """Get the default legalization passes for the given target.""" + if target.kind.name == "cuda": + return backend.cuda.dataflow_lower_passes(target) + if target.kind.name == "rocm": + return backend.rocm.dataflow_lower_passes(target) + if target.kind.name == "metal": + return backend.gpu_generic.dataflow_lower_passes(target) + if target.kind.name == "llvm": + return backend.cpu_generic.dataflow_lower_passes(target) + # Todo(tvm-team): support gpu-generic + raise ValueError(f"Target {target} is not yet supported by dataflow lowering passes.") + + +def finalize_passes(target: tvm.target.Target): + """Get the default legalization passes for the given target.""" + if target.kind.name == "cuda": + return backend.cuda.finalize_passes(target) + if target.kind.name == "rocm": + return backend.rocm.finalize_passes(target) + if target.kind.name == "metal": + return backend.gpu_generic.finalize_passes(target) + if target.kind.name == "llvm": + return backend.cpu_generic.finalize_passes(target) + # Todo(tvm-team): support gpu-generic + raise ValueError(f"Target {target} is not yet supported by finalization passes.") + + +def get_default_pipeline(target: tvm.target.Target): + """Get the default Relax compilation pipeline for the given target.""" + if target.kind.name == "cuda": + return backend.cuda.get_default_pipeline(target) + if target.kind.name == "rocm": + return backend.rocm.get_default_pipeline(target) + if target.kind.name == "metal": + return backend.gpu_generic.get_default_pipeline(target) + if target.kind.name == "llvm": + return backend.cpu_generic.get_default_pipeline(target) + # Todo(tvm-team): support gpu-generic + raise ValueError( + f"Target {target} is not yet supported by default pipeline. " + "Please lower and build the IRModule manually." + ) diff --git a/tests/python/relax/test_codegen_coreml.py b/tests/python/relax/test_codegen_coreml.py index ba8304bca728..0be6f4731635 100644 --- a/tests/python/relax/test_codegen_coreml.py +++ b/tests/python/relax/test_codegen_coreml.py @@ -41,7 +41,7 @@ def _has_xcode(): def verify(mod, inputs): - from tvm.relax.backend.contrib.coreml import partition_for_coreml + from tvm.relax.backend.metal.coreml import partition_for_coreml mod1 = partition_for_coreml(mod) mod1 = relax.transform.RunCodegen()(mod1) diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 8ab97e4f295a..2fbff8433bf7 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -21,7 +21,7 @@ import tvm.testing import tvm.topi.testing from tvm import relax -from tvm.relax.backend.contrib.cublas import partition_for_cublas +from tvm.relax.backend.cuda.cublas import partition_for_cublas from tvm.relax.testing import get_relax_matmul_module from tvm.script import relax as R from tvm.script.ir_builder import IRBuilder diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py index 59f49bfde889..0f9a0bc262a6 100644 --- a/tests/python/relax/test_codegen_cudnn.py +++ b/tests/python/relax/test_codegen_cudnn.py @@ -21,11 +21,10 @@ import tvm.testing import tvm.topi.testing from tvm import relax -from tvm.relax.backend.contrib.cudnn import partition_for_cudnn -from tvm.relax.testing import get_relax_matmul_module, get_relax_stacked_attention_module from tvm.contrib.pickle_memoize import memoize +from tvm.relax.backend.cuda.cudnn import partition_for_cudnn +from tvm.relax.testing import get_relax_stacked_attention_module from tvm.script import relax as R - from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 3fa3f2d914d7..9d31ced08c9d 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -23,10 +23,10 @@ from tvm import relax from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul from tvm.contrib.pickle_memoize import memoize -from tvm.relax.backend.contrib.cutlass import partition_for_cutlass +from tvm.relax.backend.cuda.cutlass import partition_for_cutlass from tvm.relax.testing import ( - get_relax_matmul_module, get_relax_attention_module, + get_relax_matmul_module, get_relax_stacked_attention_module, ) from tvm.script import ir as I @@ -1909,9 +1909,7 @@ def main( ex = relax.build(mod_transform, target="llvm") vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) - (packed_weight, scales,) = vm[ - transform_func_name - ]((tvm.nd.array(y),)) + packed_weight, scales = vm[transform_func_name]((tvm.nd.array(y),)) dev = tvm.device("cuda", 0) ex = relax.build(mod_deploy, target="cuda") @@ -2066,9 +2064,7 @@ def main( ex = relax.build(mod_transform, target="llvm") vm = relax.vm.VirtualMachine(ex, tvm.cpu(0)) - (packed_weight, scales,) = vm[ - transform_func_name - ]((tvm.nd.array(y),)) + packed_weight, scales = vm[transform_func_name]((tvm.nd.array(y),)) dev = tvm.device("cuda", 0) ex = relax.build(mod_deploy, target="cuda") diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py index f43b83802b81..7edbed61bc96 100644 --- a/tests/python/relax/test_codegen_hipblas.py +++ b/tests/python/relax/test_codegen_hipblas.py @@ -21,7 +21,7 @@ import tvm.testing import tvm.topi.testing from tvm import relax -from tvm.relax.backend.contrib.hipblas import partition_for_hipblas +from tvm.relax.backend.rocm.hipblas import partition_for_hipblas from tvm.relax.testing import get_relax_matmul_module from tvm.script import relax as R diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index 86711477b274..6d4d44c9e5a7 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -25,7 +25,7 @@ def test_pipeline_compile(): target = tvm.target.Target("llvm", host="llvm") - pipeline = relax.backend.get_default_pipeline(target) + pipeline = relax.pipeline.get_default_pipeline(target) @tvm.script.ir_module class Mod: @@ -51,7 +51,7 @@ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): def test_pipeline_with_kv_cache(): """A dummy pipline that simulates KV update.""" target = tvm.target.Target("llvm", host="llvm") - pipeline = relax.backend.get_default_pipeline(target) + pipeline = relax.pipeline.get_default_pipeline(target) @tvm.script.ir_module class Mod: diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 6e78a67fd085..cb827e9734e3 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -15,17 +15,21 @@ # specific language governing permissions and limitations # under the License. -import pytest import os +import tempfile + +import numpy as np +import pytest + import tvm import tvm.testing from tvm import relax, tir -import numpy as np -from tvm.script import relax as R, ir as I, tir as T +from tvm.relax.dpl import is_op, wildcard from tvm.relax.testing import transform -import tempfile from tvm.relax.transform.tuning_api import Trace -from tvm.relax.dpl import is_op, wildcard +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) @@ -280,7 +284,7 @@ def rename_main(mod): def test_dynamic_shape(): - import tvm.relax.backend.contrib.cublas + import tvm.relax.backend.cuda.cublas @I.ir_module class Before: diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index a07875fcdae6..67a1d76d9801 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -19,6 +19,8 @@ import tvm from tvm import relax +from tvm.relax.backend.cuda.cublas import partition_for_cublas +from tvm.relax.backend.cuda.cutlass import partition_for_cutlass from tvm.relax.dpl.pattern import ( is_op, is_tuple_get_item, @@ -26,8 +28,6 @@ wildcard, ) from tvm.relax.transform import PatternCheckContext -from tvm.relax.backend.contrib.cutlass import partition_for_cutlass -from tvm.relax.backend.contrib.cublas import partition_for_cublas from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T @@ -57,7 +57,8 @@ def main( cls = Conv2dReLU_composite_annotated with R.dataflow(): gv: R.Tensor( - (1, 64, 56, 56), dtype="float32" + (1, 64, 56, 56), + dtype="float32", ) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) R.output(gv) return gv @@ -121,10 +122,12 @@ def main( cls = Conv2dReLUx2Partitioned with R.dataflow(): lv: R.Tensor( - (1, 64, 56, 56), dtype="float32" + (1, 64, 56, 56), + dtype="float32", ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) gv: R.Tensor( - (1, 64, 54, 54), dtype="float32" + (1, 64, 54, 54), + dtype="float32", ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) R.output(gv) return gv @@ -236,7 +239,8 @@ def main( data, weight1 ) gv: R.Tensor( - (1, 64, 54, 54), dtype="float32" + (1, 64, 54, 54), + dtype="float32", ) = cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2) R.output(gv) return gv