diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 918c36c20079..da7cbd5cec10 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -183,9 +183,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.cuda",
+                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
+                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.gpu",
             )
             N, H, W, _ = get_const_tuple(data.shape)
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 8d9c28ba714b..1453128eeb67 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -69,9 +69,9 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
         elif layout == "NHWC":
             assert kernel_layout == "HWIO"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.cuda",
+                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
+                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.gpu",
             )
             N, H, W, _ = get_const_tuple(data.shape)
             KH, KW, CI, CO = get_const_tuple(kernel.shape)
diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py
index 8338208dd968..bd8d7ec19bb3 100644
--- a/python/tvm/topi/cuda/conv2d.py
+++ b/python/tvm/topi/cuda/conv2d.py
@@ -25,7 +25,6 @@
 from ..nn.utils import get_pad_tuple
 from ..utils import get_const_tuple, traverse_inline
 from .conv2d_direct import schedule_direct_cuda
-from .conv2d_nhwc import schedule_conv2d_nhwc_direct
 
 
 @autotvm.register_topi_compute("conv2d_nchw.cuda")
@@ -48,26 +47,6 @@ def _callback(op):
     return s
 
 
-@autotvm.register_topi_compute("conv2d_nhwc.cuda")
-def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
-    """Compute conv2d with NHWC layout"""
-    return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
-
-
-@autotvm.register_topi_schedule("conv2d_nhwc.cuda")
-def schedule_conv2d_nhwc(cfg, outs):
-    """Create the schedule for conv2d_nhwc"""
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if op.tag == "conv2d_nhwc":
-            schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
 @autotvm.register_topi_compute("conv2d_cudnn.cuda")
 def conv2d_cudnn(
     cfg, data, kernel, strides, padding, dilation, groups=1, layout="NCHW", out_dtype="float32"
diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py
index 6d9fd39e16b8..8ed9362a3cf2 100644
--- a/python/tvm/topi/gpu/__init__.py
+++ b/python/tvm/topi/gpu/__init__.py
@@ -18,3 +18,4 @@
 # pylint: disable=redefined-builtin, wildcard-import
 """GPU specific declaration and schedules."""
 from .dense import *
+from .conv2d import *
diff --git a/python/tvm/topi/gpu/conv2d.py b/python/tvm/topi/gpu/conv2d.py
new file mode 100644
index 000000000000..87c900e1d4d7
--- /dev/null
+++ b/python/tvm/topi/gpu/conv2d.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""Schedule for conv2d operator"""
+from tvm import te, autotvm
+
+from .. import nn
+from ..utils import traverse_inline
+from .conv2d_nhwc import schedule_conv2d_nhwc_direct
+
+
+@autotvm.register_topi_compute("conv2d_nhwc.gpu")
+def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
+    """Compute conv2d with NHWC layout"""
+    return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc.gpu")
+def schedule_conv2d_nhwc(cfg, outs):
+    """Create the schedule for conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "conv2d_nhwc":
+            schedule_conv2d_nhwc_direct(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/gpu/conv2d_nhwc.py
similarity index 91%
rename from python/tvm/topi/cuda/conv2d_nhwc.py
rename to python/tvm/topi/gpu/conv2d_nhwc.py
index f8115830ce50..ff0610394eac 100644
--- a/python/tvm/topi/cuda/conv2d_nhwc.py
+++ b/python/tvm/topi/gpu/conv2d_nhwc.py
@@ -54,12 +54,13 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2])
     cfg.define_knob("vthread_c", [1, 2])
     cfg.define_knob("step", [16, 3, 32, 64])
+    cfg.define_knob("vectorize", [1, 2, 4, 8])
 
     # fallback support
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.kind.name, target.model, "conv2d_nhwc.cuda"
+            target.kind.name, target.model, "conv2d_nhwc.gpu"
         )
         cfg.fallback_with_reference_log(ref_log)
 
@@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     vthread_n = cfg["vthread_n"].val
     vthread_c = cfg["vthread_c"].val
     step = cfg["step"].val
+    vec_factor = cfg["vectorize"].val
     block_factor_c = tile_c * num_thread_c * vthread_c
 
     offset = 8
@@ -85,15 +87,17 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     thread_yz = te.thread_axis((0, vthread_n), "vthread", name="vy")
 
     # Schedule for output
-    ni, hi, wi, fi = s[output].op.axis
-    bx = s[output].fuse(hi, wi)
+    ni, _, wi, fi = s[output].op.axis
+    bx = wi
+    fi, vec = s[output].split(fi, factor=vec_factor)
+    s[output].vectorize(vec)
     tx, fi = s[output].split(fi, factor=tile_c)
     txz, tx = s[output].split(tx, factor=num_thread_c)
     bz, txz = s[output].split(txz, factor=vthread_c)
     ty, ni = s[output].split(ni, factor=tile_n)
     tyz, ty = s[output].split(ty, factor=num_thread_n)
     by, tyz = s[output].split(tyz, factor=vthread_n)
-    s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi)
+    s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi, vec)
     s[output].bind(bz, block_z)
     s[output].bind(by, block_y)
     s[output].bind(bx, block_x)
@@ -106,6 +110,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     ni, yi, xi, fi = s[OL].op.axis
     ry, rx, rc = s[OL].op.reduce_axis
     rco, rci = s[OL].split(rc, factor=step)
+    s[OL].vectorize(fi)
     s[OL].reorder(rco, ry, rx, rci, ni, fi)
 
     s[AA].compute_at(s[OL], rx)
@@ -125,6 +130,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):
     _, _, ic, o = s[WW].op.axis
     t = s[WW].fuse(ic, o)
     s[WW].storage_align(ic, W_align - 1, W_align)
+    t, vec = s[WW].split(t, factor=vec_factor)
+    s[WW].vectorize(vec)
     ty, tx = s[WW].split(t, factor=num_thread_c)
     _, ty = s[WW].split(ty, factor=num_thread_n)
     s[WW].bind(tx, thread_x)
diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py
index 96359860f569..8c125af72163 100644
--- a/tests/python/topi/python/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -28,7 +28,7 @@
 
 _conv2d_nhwc_implement = {
     "generic": (topi.nn.conv2d_nhwc, topi.generic.schedule_conv2d_nhwc),
-    "gpu": (topi.cuda.conv2d_nhwc, topi.cuda.schedule_conv2d_nhwc),
+    "gpu": (topi.gpu.conv2d_nhwc, topi.gpu.schedule_conv2d_nhwc),
     "cpu": (topi.nn.conv2d_nhwc, topi.x86.schedule_conv2d_nhwc),
     "arm_cpu": (
         topi.arm_cpu.conv2d_nhwc_spatial_pack,
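
Not part of the diff: a minimal smoke-test sketch of how the relocated compute/schedule pair can be driven directly, in the style of tests/python/topi/python/test_topi_conv2d_nhwc.py. The shapes, the "cuda" target, stride/padding/dilation values, and the tolerance are illustrative assumptions; it needs a CUDA-enabled TVM build and device.

# Assumed example, not from the patch: exercises topi.gpu.conv2d_nhwc / schedule_conv2d_nhwc.
import numpy as np

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import te, topi
from tvm.topi.utils import get_const_tuple

target = tvm.target.Target("cuda")
dev = tvm.device("cuda", 0)

# NHWC data and HWIO kernel (illustrative shapes), stride 1, padding 1, dilation 1.
A = te.placeholder((1, 14, 14, 64), name="A")
W = te.placeholder((3, 3, 64, 128), name="W")

with target:
    # Without a tuned AutoTVM log, the fallback config (including the new "vectorize" knob) is used.
    B = topi.gpu.conv2d_nhwc(A, W, 1, 1, 1)
    s = topi.gpu.schedule_conv2d_nhwc([B])
func = tvm.build(s, [A, W, B], target)

a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype("float32")
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype("float32")
b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, w_np, 1, 1)  # reference result

a = tvm.nd.array(a_np, dev)
w = tvm.nd.array(w_np, dev)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
func(a, w, b)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)

The "vectorize" knob added by the patch lets AutoTVM choose the vector width used for the innermost channel axis of the output store and the shared-memory weight load, with 1 as the fallback value.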