From 0f20125af8b26194b20fdfc3ad67a6da47e2038f Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 23 Apr 2020 15:05:18 +0200 Subject: [PATCH] fix miopen pad --- python/tvm/relay/op/strategy/rocm.py | 4 +++- topi/python/topi/rocm/conv2d.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 6cda346e5068..b1213f1acbf1 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -36,6 +36,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): layout = attrs.data_layout stride_h, stride_w = attrs.get_int_tuple("strides") kernel_layout = attrs.kernel_layout + padding = attrs.get_int_tuple("padding") if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") @@ -77,7 +78,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) # add miopen implementation - if "miopen" in target.libs and layout == "NCHW": + if "miopen" in target.libs and layout == "NCHW" and padding[0] == padding[2] and \ + padding[1] == padding[3]: strategy.add_implementation( wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py index 4ee18775b938..bc5d5c3c0688 100644 --- a/topi/python/topi/rocm/conv2d.py +++ b/topi/python/topi/rocm/conv2d.py @@ -66,7 +66,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) pad_h, pad_w = pt + pb, pl + pr dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation - + assert (pt == pb) and (pl == pr) OH = (H + 2 * pad_h - KH) // stride_h + 1 OW = (W + 2 * pad_w - KW) // stride_w + 1 cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\ @@ -76,8 +76,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, kernel, stride_h, stride_w, - pad_h, - pad_w, + pt, + pl, dilation_h, dilation_w, conv_mode=0,