From d46bb766978de9c37a931b00fd3be54213c83585 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Jul 2021 14:21:34 -0500 Subject: [PATCH 1/6] [Topi][Testing] Minor cleanup for python reference implementations - Use input dtype for dilate/conv2d accumulate in python impl. Previously, the python implementations of dilation and conv2d would use numpy default dtype in some cases, rather than the input data's dtype. - Added fallback for datatypes not supported by scipy.signal.convolve2d (e.g. float16). - Refactored to avoid duplication, use common get_pad_tuple functionality. --- python/tvm/topi/testing/common.py | 51 ++++++++ python/tvm/topi/testing/conv2d_nchw_python.py | 55 +++++++- .../topi/testing/depthwise_conv2d_python.py | 118 +++++------------- python/tvm/topi/testing/dilate_python.py | 34 +++-- 4 files changed, 159 insertions(+), 99 deletions(-) diff --git a/python/tvm/topi/testing/common.py b/python/tvm/topi/testing/common.py index 785a6d11d8a7..d040310ccc8f 100644 --- a/python/tvm/topi/testing/common.py +++ b/python/tvm/topi/testing/common.py @@ -18,6 +18,8 @@ """Common utility for topi test""" import numpy as np +import scipy.signal + import tvm from tvm import topi from tvm.testing import assert_allclose @@ -108,3 +110,52 @@ def compare_numpy_tvm(inputs, output, target, device, compute, schedule): arys = [tvm.nd.array(x, device=device) for x in inputs] func(*(arys + [te_out])) assert_allclose(te_out.numpy(), output, atol=1e-4, rtol=1e-4) + + +def _convolve2d(data, weights): + """2d convolution operator in HW layout. + + This is intended to be used as a replacement for + scipy.signals.convolve2d, with wider support for different dtypes. + scipy.signal.convolve2d does not support all TVM-supported + dtypes (e.g. float16). Where possible, this function uses + scipy.signal.convolve2d to take advantage of compiled scipy + routines, falling back to an explicit loop only where needed. + + Parameters + ---------- + data : numpy.ndarray + 2-D with shape [in_height, in_width] + + weights : numpy.ndarray + 2-D with shape [filter_height, filter_width]. 
+ + Returns + ------- + b_np : np.ndarray + 2-D with shape [out_height, out_width] + + Return value and layout conventions are matched to + ``scipy.signal.convolve2d(data, weights, mode="valid")`` + """ + + try: + return scipy.signal.convolve2d(data, weights, mode="valid") + except ValueError: + pass + + weights = np.rot90(weights, k=2) + + assert len(data.shape) == len(weights.shape) == 2 + + dtype = data.dtype + kernel_h, kernel_w = weights.shape + + output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(data.shape, weights.shape)] + output = np.zeros(output_shape, dtype=dtype) + + for y in range(output_shape[0]): + for x in range(output_shape[1]): + output[y][x] = np.sum(data[y : y + kernel_h, x : x + kernel_w] * weights) + + return output diff --git a/python/tvm/topi/testing/conv2d_nchw_python.py b/python/tvm/topi/testing/conv2d_nchw_python.py index ce5d981cc651..4214ee4a2459 100644 --- a/python/tvm/topi/testing/conv2d_nchw_python.py +++ b/python/tvm/topi/testing/conv2d_nchw_python.py @@ -17,7 +17,8 @@ # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches """Convolution in python""" import numpy as np -import scipy.signal +import scipy + from tvm.topi.nn.utils import get_pad_tuple @@ -58,21 +59,67 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding): out_channel = num_filter out_height = (in_height - kernel_h + pad_h) // stride_h + 1 out_width = (in_width - kernel_w + pad_w) // stride_w + 1 - b_np = np.zeros((batch, out_channel, out_height, out_width)) + b_np = np.zeros((batch, out_channel, out_height, out_width), dtype=a_np.dtype) # computation for n in range(batch): for f in range(out_channel): for c in range(in_channel): if pad_h > 0 or pad_w > 0: - apad = np.zeros((in_height + pad_h, in_width + pad_w)) + apad = np.zeros((in_height + pad_h, in_width + pad_w), dtype=a_np.dtype) apad[pad_top : pad_top + in_height, pad_left : pad_left + in_width] = a_np[n, c] else: apad = a_np[n, c] - out = scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np[f, c])), mode="valid") + + out = _conv2d_hw(apad, w_np[f, c]) b_np[n, f] += out[::stride_h, ::stride_w] return b_np +def _conv2d_hw(apad, w_np_fc): + """2d convolution operator in HW layout. + + This is intended to be used as a subroutine from + _conv2d_nchw_python. Using scipy.signal.convolve2d directly does + not work for all dtypes (e.g. float16). Where possible, this + function uses scipy.signal.convolve2d to take advantage of + compiled scipy routines, falling back to an explicit loop only + where needed + + Parameters + ---------- + a_np : numpy.ndarray + 2-D with shape [in_height, in_width] + + w_np : numpy.ndarray + 2-D with shape [filter_height, filter_width]. + + Returns + ------- + b_np : np.ndarray + 2-D with shape [out_height, out_width] + """ + + try: + return scipy.signal.convolve2d(apad, np.rot90(np.rot90(w_np_fc)), mode="valid") + except ValueError: + pass + + assert len(apad.shape) == len(w_np_fc.shape) == 2 + + dtype = apad.dtype + in_height, in_width = apad.shape + kernel_h, kernel_w = w_np_fc.shape + + output_shape = [a_dim - w_dim + 1 for a_dim, w_dim in zip(apad.shape, w_np_fc.shape)] + output = np.zeros(output_shape, dtype=apad.dtype) + + for y in range(output_shape[0]): + for x in range(output_shape[1]): + output[y][x] = np.sum(apad[y : y + kernel_h, x : x + kernel_w] * w_np_fc) + + return output + + def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1): """Convolution operator in NCHW layout. 
diff --git a/python/tvm/topi/testing/depthwise_conv2d_python.py b/python/tvm/topi/testing/depthwise_conv2d_python.py index 1ec64b7e7b82..a6247e9f92cc 100644 --- a/python/tvm/topi/testing/depthwise_conv2d_python.py +++ b/python/tvm/topi/testing/depthwise_conv2d_python.py @@ -17,7 +17,9 @@ # pylint: disable=invalid-name, unused-variable, line-too-long """Depthwise convolution in python""" import numpy as np -from scipy import signal + +from tvm.topi.nn.utils import get_pad_tuple +from .common import _convolve2d def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): @@ -49,42 +51,29 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding): else: stride_h, stride_w = stride - # calculate output shape - if padding == "VALID": - out_channel = in_channel * channel_multiplier - out_height = (in_height - filter_height) // stride_h + 1 - out_width = (in_width - filter_width) // stride_w + 1 - output_np = np.zeros((batch, out_channel, out_height, out_width)) - for i in range(batch): - for j in range(out_channel): - output_np[i, j, :, :] = signal.convolve2d( - input_np[i, j // channel_multiplier, :, :], - np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2), - mode="valid", - )[ - 0 : (in_height - filter_height + 1) : stride_h, - 0 : (in_width - filter_width + 1) : stride_w, - ] - elif padding == "SAME": - out_channel = in_channel * channel_multiplier - out_height = int(np.ceil(float(in_height) / float(stride_h))) - out_width = int(np.ceil(float(in_width) / float(stride_w))) - output_np = np.zeros((batch, out_channel, out_height, out_width)) - pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) - pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - pad_top_tvm = int(np.ceil(float(pad_along_height) / 2)) - pad_left_tvm = int(np.ceil(float(pad_along_width) / 2)) - pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2)) - pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2)) - index_h = pad_top_scipy - pad_top_tvm - index_w = pad_left_scipy - pad_left_tvm - for i in range(batch): - for j in range(out_channel): - output_np[i, j, :, :] = signal.convolve2d( - input_np[i, j // channel_multiplier, :, :], - np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], 2), - mode="same", - )[index_h:in_height:stride_h, index_w:in_width:stride_w] + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (filter_height, filter_width)) + pad_h = pad_top + pad_bottom + pad_w = pad_left + pad_right + + out_channel = in_channel * channel_multiplier + out_height = (in_height - filter_height + pad_h) // stride_h + 1 + out_width = (in_width - filter_width + pad_w) // stride_w + 1 + output_np = np.zeros((batch, out_channel, out_height, out_width)) + + for i in range(batch): + for j in range(out_channel): + apad = input_np[i, j // channel_multiplier, :, :] + if pad_h or pad_w: + apad = np.pad(apad, [(pad_top, pad_bottom), (pad_left, pad_right)]) + + conv = _convolve2d( + apad, + np.rot90(filter_np[j // channel_multiplier, j % channel_multiplier, :, :], k=2), + ) + output_np[i, j, :, :] = conv[ + ::stride_h, + ::stride_w, + ] return output_np @@ -139,7 +128,9 @@ def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding): # Perform conv2d output_np = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding) - # Transform back + # Transform back to NCHWc + + # pylint: disable=unpacking-non-sequence batch_size, out_channel, out_height, 
out_width = output_np.shape return output_np.reshape( (batch_size, out_channel_chunk, out_channel_block, out_height, out_width) @@ -147,7 +138,7 @@ def depthwise_conv2d_python_nchwc(input_np, filter_np, stride, padding): def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): - """Depthwise convolution operator in nchw layout. + """Depthwise convolution operator in nhwc layout. Parameters ---------- @@ -168,48 +159,7 @@ def depthwise_conv2d_python_nhwc(input_np, filter_np, stride, padding): output_np : np.ndarray 4-D with shape [batch, out_height, out_width, out_channel] """ - batch, in_height, in_width, in_channel = input_np.shape - filter_height, filter_width, _, channel_multiplier = filter_np.shape - if isinstance(stride, int): - stride_h = stride_w = stride - else: - stride_h, stride_w = stride - - # calculate output shape - if padding == "VALID": - out_channel = in_channel * channel_multiplier - out_height = (in_height - filter_height) // stride_h + 1 - out_width = (in_width - filter_width) // stride_w + 1 - output_np = np.zeros((batch, out_height, out_width, out_channel)) - for i in range(batch): - for j in range(out_channel): - output_np[i, :, :, j] = signal.convolve2d( - input_np[i, :, :, j // channel_multiplier], - np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2), - mode="valid", - )[ - 0 : (in_height - filter_height + 1) : stride_h, - 0 : (in_width - filter_width + 1) : stride_w, - ] - if padding == "SAME": - out_channel = in_channel * channel_multiplier - out_height = int(np.ceil(float(in_height) / float(stride_h))) - out_width = int(np.ceil(float(in_width) / float(stride_w))) - output_np = np.zeros((batch, out_height, out_width, out_channel)) - pad_along_height = int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0)) - pad_along_width = int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0)) - pad_top_tvm = int(np.ceil(float(pad_along_height) / 2)) - pad_left_tvm = int(np.ceil(float(pad_along_width) / 2)) - pad_top_scipy = int(np.ceil(float(filter_height - 1) / 2)) - pad_left_scipy = int(np.ceil(float(filter_width - 1) / 2)) - index_h = pad_top_scipy - pad_top_tvm - index_w = pad_left_scipy - pad_left_tvm - for i in range(batch): - for j in range(out_channel): - output_np[i, :, :, j] = signal.convolve2d( - input_np[i, :, :, j // channel_multiplier], - np.rot90(filter_np[:, :, j // channel_multiplier, j % channel_multiplier], 2), - mode="same", - )[index_h:in_height:stride_h, index_w:in_width:stride_w] - - return output_np + input_nchw = input_np.transpose(0, 3, 1, 2) + filter_nchw = filter_np.transpose(2, 3, 0, 1) + output_nchw = depthwise_conv2d_python_nchw(input_nchw, filter_nchw, stride, padding) + return output_nchw.transpose(0, 2, 3, 1) diff --git a/python/tvm/topi/testing/dilate_python.py b/python/tvm/topi/testing/dilate_python.py index 0ae611559729..8cbd01af7f71 100644 --- a/python/tvm/topi/testing/dilate_python.py +++ b/python/tvm/topi/testing/dilate_python.py @@ -19,7 +19,7 @@ import numpy as np -def dilate_python(input_np, strides, dilation_value=0.0): +def dilate_python(input_np, strides, dilation_value=0.0, out_dtype=None): """Dilate operation. Parameters @@ -33,23 +33,35 @@ def dilate_python(input_np, strides, dilation_value=0.0): dilation_value : int/float, optional Value used to dilate the input. + out_dtype : Option[str] + The datatype of the dilated array. If unspecified, will use + the same dtype as the input array. 
+ Returns ------- output_np : numpy.ndarray n-D, the same layout as Input. + """ - n = len(input_np.shape) - assert len(strides) == n, "Input dimension and strides size dismatch : %d vs %d" % ( - n, + assert len(input_np.shape) == len( + strides + ), "Input dimension and strides size dismatch : %d vs %d" % ( + len(input_np.shape), len(strides), ) - output_size = () - no_zero = () - for i in range(n): - output_size += ((input_np.shape[i] - 1) * strides[i] + 1,) - no_zero += ((range(0, output_size[i], strides[i])),) - output_np = np.ones(shape=output_size) + + if out_dtype is None: + out_dtype = input_np.dtype + + output_size = [ + (input_dim - 1) * stride + 1 for input_dim, stride in zip(input_np.shape, strides) + ] + non_zero_elements = np.ix_( + *[range(0, output_dim, stride) for output_dim, stride in zip(output_size, strides)] + ) + + output_np = np.ones(shape=output_size, dtype=out_dtype) output_np = dilation_value * output_np - output_np[np.ix_(*no_zero)] = input_np + output_np[non_zero_elements] = input_np return output_np From 05287445417ecc164fbde07c46d7d9b74505c27b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 30 Jun 2021 07:53:34 -0700 Subject: [PATCH 2/6] [Topi][UnitTests] Added float16 tests to test_topi_dense.py --- tests/python/topi/python/test_topi_dense.py | 41 ++++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 235a09400387..195c938fb171 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -33,6 +33,7 @@ in_dim, out_dim = tvm.testing.parameters((1024, 1000)) in_dtype, out_dtype = tvm.testing.parameters( ("float32", "float32"), + ("float16", "float16"), ("int8", "int32"), ) @@ -90,19 +91,24 @@ def test_dense( ): target = tvm.target.Target(target) - if ( - in_dtype == "int8" - and target.kind.name == "cuda" - and not tvm.contrib.nvcc.have_int8(dev.compute_version) - ): - pytest.xfail("CUDA int8 intrinsics not available") - - if ( - in_dtype == "int8" - and target.kind.name == "vulkan" - and not target.attrs.get("supports_int8", False) - ): - pytest.xfail("Vulkan int8 driver support not available") + if target.kind.name == "cuda": + if in_dtype == "int8" and not tvm.contrib.nvcc.have_int8(dev.compute_version): + pytest.xfail("CUDA int8 intrinsics not available") + + if in_dtype == "float16" and not tvm.contrib.nvcc.have_fp16(dev.compute_version): + pytest.xfail("CUDA float16 intrinsics not available") + + if target.kind.name == "vulkan": + if in_dtype == "int8" and ( + not target.attrs.get("supports_int8", False) + or not target.attrs.get("supports_8bit_buffer", False) + ): + pytest.xfail("Vulkan int8 driver support not available") + if in_dtype == "float16" and ( + not target.attrs.get("supports_float16", False) + or not target.attrs.get("supports_16bit_buffer", False) + ): + pytest.xfail("Vulkan float16 driver support not available") if ( target.kind.name not in ["llvm", "c"] @@ -110,6 +116,13 @@ def test_dense( ): pytest.xfail("No implementation for tvm.topi.testing.dispatch to find") + if "int" in in_dtype: + tol = {"atol": 0, "rtol": 0} + elif in_dtype == "float32": + tol = {"rtol": 1e-5, "atol": 1e-5} + elif in_dtype == "float16": + tol = {"rtol": 5e-2, "atol": 1e-5} + A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype) B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype) C = te.placeholder((out_dim,), name="C", dtype=out_dtype) @@ -131,7 +144,7 @@ 
def test_dense( d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev) f = tvm.build(s, [A, B, C, D], target, name="dense") f(a, b, c, d) - tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-5) + tvm.testing.assert_allclose(d.numpy(), d_np, **tol) @pytest.mark.parametrize("target,in_dtype,out_dtype", [("cuda", "int8", "int32")]) From 944ad511fac0b4cb2df7cd32c881faafbaa572b3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Jul 2021 14:25:03 -0500 Subject: [PATCH 3/6] [Topi][UnitTests] Added float16 to test_topi_conv2d_nchw.py --- .../topi/python/test_topi_conv2d_nchw.py | 59 ++++++++++++++++--- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index 2a4865c6dd8d..83df22f4025d 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -32,7 +32,7 @@ import tvm.testing -dtype = tvm.testing.parameter("float32") +dtype = tvm.testing.parameter("float16", "float32") @tvm.testing.fixture @@ -62,11 +62,19 @@ def ref_data( add_bias, apply_relu, ): + # scipy.signal.convolve2d does not support float16 data types, and + # the python fallback is too slow for general use. Computing + # ref_data in float32 will have fewer rounding errors than the TVM + # float16 compute, but those vary based on schedule anyways. + conv_dtype = "float32" if dtype == "float16" else dtype + a_np = np.random.uniform(size=input_shape).astype(dtype) w_np = np.random.uniform(size=weight_shape).astype(dtype) b_np = np.random.uniform(size=bias_shape).astype(dtype) dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation)) - c_np = tvm.topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + c_np = tvm.topi.testing.conv2d_nchw_python( + a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding + ).astype(dtype) if add_bias: c_np = c_np + b_np @@ -101,6 +109,19 @@ def test_conv2d_nchw( target = tvm.target.Target(target) is_cudnn_target = target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []) + if target.kind.name == "vulkan" and dtype == "float16": + if not target.attrs.get("supports_float16", False) or not target.attrs.get( + "supports_16bit_buffer", False + ): + pytest.xfail("Vulkan device does not support float16") + + if ( + target.kind.name == "cuda" + and dtype == "float16" + and not tvm.contrib.nvcc.have_fp16(dev.compute_version) + ): + pytest.xfail("CUDA float16 intrinsics not available") + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) padding_sum = pad_top + pad_left + pad_bottom + pad_right @@ -108,7 +129,20 @@ def test_conv2d_nchw( A = te.placeholder(a_np.shape, name="A", dtype=dtype) W = te.placeholder(w_np.shape, name="W", dtype=dtype) - bias = te.placeholder(b_np.shape, name="bias") + bias = te.placeholder(b_np.shape, name="bias", dtype=dtype) + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 1e-5} + elif dtype == "float16": + # A summation in float16 with a single accumulator very + # quickly runs into large rounding errors. At some point, + # this tolerance should be schedule-dependent for to avoid + # false negatives. 
+ num_values_summed = in_channel * kernel * kernel + gap_size = np.nextafter(c_np.max(), np.inf, dtype=c_np.dtype) - c_np.max() + tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2} with autotvm.tophub.context(target): # load tophub pre-tuned parameters if is_cudnn_target: @@ -138,11 +172,20 @@ def test_conv2d_nchw( s, [A, W, bias, C], target, - name="conv2d_%d_%d_%d_%d_%d_%d_%d_%d" - % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), + name="conv2d_{}_{}_{}_{}_{}_{}_{}_{}_{}".format( + dtype, + batch, + in_channel, + in_size, + num_filter, + kernel, + stride, + padding_sum, + dilation, + ), ) func(a, w, b, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, **tol) @tvm.testing.parametrize_targets("llvm") def test_workload_padding( @@ -288,8 +331,8 @@ class TestBatchSize(BaseConv2DTests): class TestBiasRelu(BaseConv2DTests): - add_relu = tvm.testing.parameter(True, False) - add_bias = tvm.testing.parameter(True, False) + apply_relu = tvm.testing.parameter(True, False, ids=["relu", "no_relu"]) + add_bias = tvm.testing.parameter(True, False, ids=["bias", "no_bias"]) in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters( (64, 56, 64, 3, 1, 1), (64, 8, 64, 3, 1, (1, 2, 2, 1)), From 8e5e006060e9b3fd9cf3d62ef7d8e5f376d8fd51 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Jul 2021 14:12:59 -0500 Subject: [PATCH 4/6] [Topi][Float16] Added float16 tests for depthwise conv2d. --- .../topi/python/test_topi_depthwise_conv2d.py | 50 +++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 092ac9df5f9a..775536cc65b1 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -68,7 +68,10 @@ } -in_dtype, out_dtype = tvm.testing.parameters(("float32", "float32")) +in_dtype, out_dtype = tvm.testing.parameters( + ("float32", "float32"), + ("float16", "float16"), +) @tvm.testing.fixture @@ -133,6 +136,12 @@ def ref_data( use_scale_shift, apply_relu, ): + # scipy.signal.convolve2d does not support float16 data types, and + # the python fallback is too slow for general use. Computing + # ref_data in float32 will have fewer rounding errors than the TVM + # float16 compute, but those vary based on schedule anyways. 
+ conv_dtype = "float32" if in_dtype == "float16" else in_dtype + input_np = np.random.uniform(size=input_shape).astype(in_dtype) filter_np = np.random.uniform(size=filter_shape).astype(in_dtype) scale_np = np.random.uniform(size=scale_shape).astype(out_dtype) @@ -151,7 +160,9 @@ def ref_data( reshape = (1, scale_shape[0], 1, 1, scale_shape[1]) dilated_filter_np = tvm.topi.testing.dilate_python(filter_np, dilation) - output_np = np_depthwise_conv2d(input_np, dilated_filter_np, stride, padding) + output_np = np_depthwise_conv2d( + input_np.astype(conv_dtype), dilated_filter_np.astype(conv_dtype), stride, padding + ).astype(out_dtype) if use_scale_shift: output_np = output_np * scale_np.reshape(reshape) + shift_np.reshape(reshape) @@ -211,6 +222,23 @@ def test_conv2d( padding, dilation, ): + if ( + target.kind.name == "cuda" + and in_dtype == "float16" + and not tvm.contrib.nvcc.have_fp16(dev.compute_version) + ): + pytest.xfail("CUDA float16 intrinsics not available") + + if ( + target.kind.name == "vulkan" + and in_dtype == "float16" + and ( + not target.attrs.get("supports_float16", False) + or not target.attrs.get("supports_16bit_buffer", False) + ) + ): + pytest.xfail("Vulkan float16 driver support not available") + # Transform the padding argument from 'str' to 'tuple' to # match the "workload" tuple in TopHub. Which padding_args to # use for each layout chosen to reproduce previous behavior. @@ -275,6 +303,22 @@ def test_conv2d( input_np, filter_np, scale_np, shift_np, output_np = request.getfixturevalue( "ref_data" ) + if "int" in out_dtype: + tol = {"atol": 0, "rtol": 0} + elif out_dtype == "float32": + tol = {"rtol": 1e-4, "atol": 1e-5} + elif out_dtype == "float16": + # A summation in float16 with a single accumulator very + # quickly runs into large rounding errors. At some point, + # this tolerance should be schedule-dependent for to avoid + # false negatives. + num_values_summed = kernel * kernel + gap_size = ( + np.nextafter(output_np.max(), np.inf, dtype=output_np.dtype) + - output_np.max() + ) + tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2} + input_tvm = tvm.nd.array(input_np, dev) filter_tvm = tvm.nd.array(filter_np, dev) scale_tvm = tvm.nd.array(scale_np, dev) @@ -285,7 +329,7 @@ def test_conv2d( ) f(input_tvm, filter_tvm, scale_tvm, shift_tvm, output_tvm) - tvm.testing.assert_allclose(output_np, output_tvm.numpy(), rtol=1e-5) + tvm.testing.assert_allclose(output_np, output_tvm.numpy(), **tol) class TestDepthwiseConv2D(BaseDepthwiseConv2D): From 211a7152161b373ae36774d28a40a31e9192a742 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 22 Jul 2021 13:03:09 -0500 Subject: [PATCH 5/6] [UnitTests] Explicitly set seed for float16 tests Intended to avoid flaky test failures later due to rounding errors. 
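The rounding error of a float16 accumulation depends on the random input
data, and the absolute tolerance used by the float16 conv2d tests is itself
derived from the magnitude of the reference output, so without a fixed seed
the comparison can pass or fail from run to run. A minimal standalone sketch
of that tolerance heuristic (shapes and values below are illustrative, not
taken from the tests):

    import numpy as np

    np.random.seed(0)  # pinning the seed keeps the derived tolerance reproducible

    in_channel, kernel = 3, 3
    # stand-in for the float16 reference output of one conv2d test case
    c_np = np.random.uniform(size=(8, 8)).astype("float16")

    # one accumulator sums in_channel * kernel * kernel values; budget half a
    # float16 ULP (at the largest output value) of error per summed value
    num_values_summed = in_channel * kernel * kernel
    gap_size = np.nextafter(c_np.max(), np.inf, dtype=c_np.dtype) - c_np.max()
    tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2}
    print(tol)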
--- tests/python/topi/python/test_topi_conv2d_nchw.py | 4 ++++ tests/python/topi/python/test_topi_dense.py | 6 +++++- tests/python/topi/python/test_topi_depthwise_conv2d.py | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/python/topi/python/test_topi_conv2d_nchw.py b/tests/python/topi/python/test_topi_conv2d_nchw.py index 83df22f4025d..cfebdee906f4 100644 --- a/tests/python/topi/python/test_topi_conv2d_nchw.py +++ b/tests/python/topi/python/test_topi_conv2d_nchw.py @@ -33,6 +33,7 @@ import tvm.testing dtype = tvm.testing.parameter("float16", "float32") +random_seed = tvm.testing.parameter(0) @tvm.testing.fixture @@ -52,6 +53,7 @@ def bias_shape(num_filter): @tvm.testing.fixture(cache_return_value=True) def ref_data( + random_seed, input_shape, weight_shape, bias_shape, @@ -62,6 +64,8 @@ def ref_data( add_bias, apply_relu, ): + np.random.seed(random_seed) + # scipy.signal.convolve2d does not support float16 data types, and # the python fallback is too slow for general use. Computing # ref_data in float32 will have fewer rounding errors than the TVM diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py index 195c938fb171..8f58415da329 100644 --- a/tests/python/topi/python/test_topi_dense.py +++ b/tests/python/topi/python/test_topi_dense.py @@ -28,6 +28,8 @@ from common import Int8Fallback +random_seed = tvm.testing.parameter(0) + use_bias = tvm.testing.parameter(True, False) batch_size = tvm.testing.parameter(1, 2, 128) in_dim, out_dim = tvm.testing.parameters((1024, 1000)) @@ -56,7 +58,9 @@ @tvm.testing.fixture(cache_return_value=True) -def dense_ref_data(batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype): +def dense_ref_data(random_seed, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype): + np.random.seed(random_seed) + if "float" in in_dtype: a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype) b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype) diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 775536cc65b1..085f7c725813 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -67,6 +67,7 @@ }, } +random_seed = tvm.testing.parameter(0) in_dtype, out_dtype = tvm.testing.parameters( ("float32", "float32"), @@ -123,6 +124,7 @@ def shift_shape(scale_shape): @tvm.testing.fixture(cache_return_value=True) def ref_data( + random_seed, in_dtype, out_dtype, layout, @@ -136,6 +138,8 @@ def ref_data( use_scale_shift, apply_relu, ): + np.random.seed(random_seed) + # scipy.signal.convolve2d does not support float16 data types, and # the python fallback is too slow for general use. Computing # ref_data in float32 will have fewer rounding errors than the TVM From e061b0da4b7e725f649e74fa21ba78b61ded95c3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 29 Jul 2021 09:16:55 -0500 Subject: [PATCH 6/6] [UnitTests] Fixed a few failing unit tests. - ref_data must be a test fixture, not acquired through request.getfixturevalue, in order to have the random_seed be known. - dilate_python's return value didn't follow `out_dtype`. - The test_topi_conv3d tests had the reference results computed in float64, due to dilate_python() not respecting the input data type. With the correct dtype, the tolerances needed to be slightly widened. 
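The intended dtype behavior after the fix, as a small illustrative check
(hypothetical snippet, not part of this patch):

    import numpy as np
    import tvm.topi.testing

    x = np.random.uniform(size=(2, 2)).astype("float16")
    y = tvm.topi.testing.dilate_python(x, (2, 2))

    # the dilated array keeps the input's dtype instead of being silently
    # promoted to numpy's default float64
    assert y.dtype == x.dtype
    assert y.shape == (3, 3)  # (dim - 1) * stride + 1 along each axis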
--- python/tvm/topi/testing/dilate_python.py | 3 +-- tests/python/topi/python/test_topi_conv3d_ncdhw.py | 2 +- tests/python/topi/python/test_topi_conv3d_winograd.py | 2 +- tests/python/topi/python/test_topi_depthwise_conv2d.py | 7 +++---- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/python/tvm/topi/testing/dilate_python.py b/python/tvm/topi/testing/dilate_python.py index 8cbd01af7f71..43559e3cee12 100644 --- a/python/tvm/topi/testing/dilate_python.py +++ b/python/tvm/topi/testing/dilate_python.py @@ -60,8 +60,7 @@ def dilate_python(input_np, strides, dilation_value=0.0, out_dtype=None): *[range(0, output_dim, stride) for output_dim, stride in zip(output_size, strides)] ) - output_np = np.ones(shape=output_size, dtype=out_dtype) - output_np = dilation_value * output_np + output_np = np.full(shape=output_size, fill_value=dilation_value, dtype=out_dtype) output_np[non_zero_elements] = input_np return output_np diff --git a/tests/python/topi/python/test_topi_conv3d_ncdhw.py b/tests/python/topi/python/test_topi_conv3d_ncdhw.py index 056ef7fc146a..c45aaa188834 100644 --- a/tests/python/topi/python/test_topi_conv3d_ncdhw.py +++ b/tests/python/topi/python/test_topi_conv3d_ncdhw.py @@ -116,7 +116,7 @@ def check_target(target, dev): % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation), ) func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-6) for target, dev in tvm.testing.enabled_targets(): with autotvm.tophub.context(target): # load tophub pre-tuned parameters diff --git a/tests/python/topi/python/test_topi_conv3d_winograd.py b/tests/python/topi/python/test_topi_conv3d_winograd.py index 0b9d579287c3..54dd72a2f544 100644 --- a/tests/python/topi/python/test_topi_conv3d_winograd.py +++ b/tests/python/topi/python/test_topi_conv3d_winograd.py @@ -138,7 +138,7 @@ def check_device(device): ), ) func(a, w, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-6) for device in ["cuda"]: with autotvm.tophub.context(device): # load tophub pre-tuned parameters diff --git a/tests/python/topi/python/test_topi_depthwise_conv2d.py b/tests/python/topi/python/test_topi_depthwise_conv2d.py index 085f7c725813..5952e624b708 100644 --- a/tests/python/topi/python/test_topi_depthwise_conv2d.py +++ b/tests/python/topi/python/test_topi_depthwise_conv2d.py @@ -206,7 +206,6 @@ class BaseDepthwiseConv2D: def test_conv2d( self, - request, target, dev, in_dtype, @@ -225,7 +224,9 @@ def test_conv2d( stride, padding, dilation, + ref_data, ): + target = tvm.target.Target(target) if ( target.kind.name == "cuda" and in_dtype == "float16" @@ -304,9 +305,7 @@ def test_conv2d( f = tvm.build(s, [Input, Filter, Scale, Shift, C], target) if self.run_after_compile: - input_np, filter_np, scale_np, shift_np, output_np = request.getfixturevalue( - "ref_data" - ) + input_np, filter_np, scale_np, shift_np, output_np = ref_data if "int" in out_dtype: tol = {"atol": 0, "rtol": 0} elif out_dtype == "float32":