From 43f585798baa177bfed09154e450c0eb68993acc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 29 Nov 2022 08:37:06 -0600 Subject: [PATCH 01/12] [TOPI] Use integer arithmetic for topi.image.resize Prior to this commit, floating point expressions were used to map between different-sized pixel arrays. These floating point expressions are less aggressively optimized by `RewriteSimplifier`, which can prevent some optimizations. This was first noticed during investigation into issue #13508. Benchmarks of `topi.image.resize` showed 1000x and 50x performance improvements using the LLVM and CUDA backends, respectively, by using integer expressions instead of floating point. This performance improvement is partly driven by enabling `tir.transform.VectorizeLoops` to recognize vectorizable indices, where the round-trip through floating point previously prevented that optimization. --- python/tvm/topi/image/resize.py | 303 ++++++++--------- python/tvm/topi/utils.py | 365 ++++++++++++++++++++ src/arith/rewrite_simplify.cc | 11 +- src/tir/transforms/remove_no_op.cc | 1 + tests/python/topi/test_topi_image.py | 487 +++++++++++++-------------- 5 files changed, 745 insertions(+), 422 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 29ed03f62e74..d9b31ece4f98 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -16,27 +16,14 @@ # under the License. # pylint: disable=invalid-name """TVM operator input resize compute.""" + from __future__ import absolute_import + import tvm from tvm import te -from tvm.topi.utils import nchw_pack_layout, nchw_xc_layout -from ..
import tag - +from tvm.topi.utils import nchw_pack_layout, nchw_xc_layout, Fraction -def can_convert_multiply_to_intdiv(origin_size, scaled_size): - """Check whether can convert multiplication to division""" - # Only support IntImm type - if not isinstance(scaled_size, tvm.tir.expr.IntImm): - return False - - div = scaled_size / origin_size.astype("float") - if div.value % 1 != 0: - return False - epsilon = 1e-5 - check = 1 / (epsilon * origin_size + epsilon) - if div > check: - return False - return True +from .. import tag def get_1d_indices(indices, layout="NCW"): @@ -136,71 +123,96 @@ def get_3d_pixel(data, layout, image_depth, image_height, image_width, n, c, z, def get_inx( - x, + target_x, image_width, target_width, coordinate_transformation_mode, start_x=0, end_x=-1, - use_int_div=False, ): """Infer input x from output x with various coordinate transformation methods""" - scale_x = te.div(image_width.astype("float"), target_width.astype("float")) + + non_trivial_target_width = target_width > 1 + + def _as_fraction_or_float(expr): + try: + return Fraction(expr) + except ValueError: + return expr.astype("float") + + image_width = _as_fraction_or_float(image_width) + target_width = _as_fraction_or_float(target_width) + target_x = _as_fraction_or_float(target_x) + + scale_x = image_width / target_width + if coordinate_transformation_mode == "half_pixel": - in_x = (x + 0.5) * scale_x - 0.5 + return (target_x + 0.5) * scale_x - 0.5 elif coordinate_transformation_mode == "align_corners": - in_x = (image_width - 1).astype("float") / (target_width - 1) * x + return (image_width - 1) / (target_width - 1) * target_x elif coordinate_transformation_mode == "asymmetric": - if use_int_div: - in_x = te.div(x, te.div(target_width, image_width)) - else: - in_x = scale_x * x + return scale_x * target_x elif coordinate_transformation_mode == "pytorch_half_pixel": - in_x = te.if_then_else(target_width > 1, (x + 0.5) * scale_x - 0.5, 0.0) + return 
te.if_then_else(non_trivial_target_width, (target_x + 0.5) * scale_x - 0.5, 0.0) elif coordinate_transformation_mode == "tf_half_pixel_for_nn": - in_x = (x + 0.5) * scale_x + return (target_x + 0.5) * scale_x elif coordinate_transformation_mode == "tf_crop_and_resize": - in_x = te.if_then_else( - target_width > 1, + start_x = _as_fraction_or_float(start_x) + end_x = _as_fraction_or_float(end_x) + return te.if_then_else( + non_trivial_target_width, start_x * (image_width - 1) - + x * (end_x - start_x) * (image_width - 1).astype("float") / (target_width - 1), + + target_x * (end_x - start_x) * (image_width - 1) / (target_width - 1), 0.5 * (start_x + end_x) * (image_width - 1), ) else: raise ValueError( f"Unsupported coordinate_transformation_mode: {coordinate_transformation_mode}" ) - return in_x -def get_closest_index(in_x, rounding_method, boxes, use_int_div=False): +def get_closest_index(in_x, rounding_method, boxes): """get the closest index to a value based on a certain rounding method""" - if use_int_div: - closest_x_index = in_x.astype("int32") - return closest_x_index - - if rounding_method == "round" or boxes is not None: - closest_x_index = te.round(in_x).astype("int32") - elif rounding_method == "round_prefer_floor": - closest_x_index = te.ceil(in_x - 0.5).astype("int32") - elif rounding_method == "round_prefer_ceil": - closest_x_index = te.floor(in_x + 0.5).astype("int32") - elif rounding_method == "floor": - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - closest_x_index = te.floor(in_x + epsilon).astype("int32") - elif rounding_method == "ceil": - # Subract epsilon from ceil to prevent gpu rounding errors. - epsilon = 1e-5 - closest_x_index = te.ceil(in_x - epsilon).astype("int32") + if isinstance(in_x, Fraction): + # Preferred path, if the initial sizes were an integer ratio. 
+ + numerator = in_x.numerator + denominator = in_x.denominator + if rounding_method in ("round", "round_prefer_floor") or boxes is not None: + return (numerator + denominator // 2) // denominator + elif rounding_method == "round_prefer_ceil": + return (numerator + (denominator + 1) // 2) // denominator + elif rounding_method == "floor": + return numerator // denominator + elif rounding_method == "ceil": + return (numerator + denominator - 1) // denominator + else: + raise ValueError("Uknown rounding method: {}".format(rounding_method)) + else: - raise ValueError(f"Unknown rounding method: {rounding_method}") - return closest_x_index + # Fallback path, using floating-point values + + if rounding_method == "round" or boxes is not None: + return te.round(in_x).astype("int32") + elif rounding_method == "round_prefer_floor": + return te.ceil(in_x - 0.5).astype("int32") + elif rounding_method == "round_prefer_ceil": + return te.floor(in_x + 0.5).astype("int32") + elif rounding_method == "floor": + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + return te.floor(in_x + epsilon).astype("int32") + elif rounding_method == "ceil": + # Subtract epsilon from ceil to prevent gpu rounding errors. + epsilon = 1e-5 + return te.ceil(in_x - epsilon).astype("int32") + else: + raise ValueError(f"Unknown rounding method: {rounding_method}") def _lerp(A, B, t): """Perform Linear interpolation in 1D""" - return A * (1.0 - t) + B * t + return (1.0 - t) * A + t * B def _cubic_spline_weights(t, alpha): @@ -214,9 +226,9 @@ def _cubic_spline_weights(t, alpha): return [w1, w2, w3, w4] -def _cubic_kernel(inputs, w): +def _sum_products(a, b): """perform cubic interpolation in 1D""" - return sum([a_i * w_i for a_i, w_i in zip(inputs, w)]) + return sum([a_i * b_i for a_i, b_i in zip(a, b)]) def _resize_1d( @@ -236,7 +248,6 @@ def _resize_1d( exclude_outside=0, out_dtype=None, ): - """Perform resize operation on the data with selected method and options.
Parameters @@ -302,12 +313,8 @@ def _resize_1d( The computed result with type out_dtype """ - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) + if out_dtype is None: + out_dtype = data.dtype n, c, x, cc, inum, ic = get_1d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n @@ -327,9 +334,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): value = get_1d_pixel(data, layout, image_width, box_idx, c, closest_x_index, cc, inum, ic) elif method == "linear": - x_int = te.floor(in_x).astype("int32") - - x_lerp = in_x - x_int + x_int, x_lerp = Fraction.split_whole_and_fractional_parts(in_x) p = [0 for i in range(2)] for i in range(2): @@ -338,8 +343,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): value = _lerp(*p, x_lerp) elif method == "cubic": - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + xint, xfract = Fraction.split_whole_and_fractional_parts(in_x) # Get the surrounding values p = [0 for i in range(4)] @@ -354,7 +358,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): ) sum_wx = sum(wx) wx = [w / sum_wx for w in wx] - value = _cubic_kernel(p, wx) + value = _sum_products(wx, p) else: raise ValueError("Unknown resize method:", method) @@ -366,7 +370,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): extrapolation_value, tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, value), ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) + + return value.astype(out_dtype) def resize1d( @@ -510,7 +515,6 @@ def _resize_2d( exclude_outside=0, out_dtype=None, ): - """Perform resize operation on the data with selected method and options. 
Parameters @@ -582,18 +586,8 @@ def _resize_2d( The computed result with type out_dtype """ - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) - - height_use_int_div = False - width_use_int_div = False - if method == "nearest_neighbor" and coordinate_transformation_mode == "asymmetric": - height_use_int_div = can_convert_multiply_to_intdiv(image_height, target_height) - width_use_int_div = can_convert_multiply_to_intdiv(image_width, target_width) + if out_dtype is None: + out_dtype = data.dtype n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n @@ -601,13 +595,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): y1, x1 = boxes(n, 0), boxes(n, 1) y2, x2 = boxes(n, 2), boxes(n, 3) - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = in_h.astype("float") / (target_height - 1) - w_scale = in_w.astype("float") / (target_width - 1) + in_h = Fraction.OrPrimExpr((image_height - 1) * (y2 - y1)) + in_w = Fraction.OrPrimExpr((image_width - 1) * (x2 - x1)) + + h_scale = in_h / (target_height - 1) + w_scale = in_w / (target_width - 1) + + in_y = h_scale * y + y1 * (image_height - 1) + in_x = w_scale * x + x1 * (image_width - 1) - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x else: in_x = get_inx( x, @@ -616,7 +612,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): coordinate_transformation_mode, roi[1], roi[3], - width_use_int_div, ) in_y = get_inx( y, @@ -625,7 +620,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): coordinate_transformation_mode, roi[0], roi[2], - height_use_int_div, ) if method == "nearest_neighbor": @@ -635,8 +629,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): else: rounding_method = "floor" - closest_x_index = get_closest_index(in_x, 
rounding_method, boxes, width_use_int_div) - closest_y_index = get_closest_index(in_y, rounding_method, boxes, height_use_int_div) + closest_x_index = get_closest_index(in_x, rounding_method, boxes) + closest_y_index = get_closest_index(in_y, rounding_method, boxes) value = get_2d_pixel( data, @@ -652,11 +646,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): ic, ) elif method == "linear": - y_int = te.floor(in_y).astype("int32") - x_int = te.floor(in_x).astype("int32") - - y_lerp = in_y - y_int - x_lerp = in_x - x_int + x_int, x_lerp = Fraction.split_whole_and_fractional_parts(in_x) + y_int, y_lerp = Fraction.split_whole_and_fractional_parts(in_y) p = [[0 for i in range(2)] for j in range(2)] for j in range(2): @@ -680,11 +671,11 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): value = _lerp(top, bottom, y_lerp) elif method == "cubic": - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + xint, xfract = Fraction.split_whole_and_fractional_parts(in_x) + yint, yfract = Fraction.split_whole_and_fractional_parts(in_y) - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + wx = _cubic_spline_weights(xfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) # Get the surrounding values p = [[0 for i in range(4)] for j in range(4)] @@ -704,8 +695,6 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): ic, ) - wx = _cubic_spline_weights(xfract, alpha) - wy = _cubic_spline_weights(yfract, alpha) if exclude_outside: for i in range(4): wx[i] = te.if_then_else( @@ -718,11 +707,12 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): sum_wy = sum(wy) wx = [w / sum_wx for w in wx] wy = [w / sum_wy for w in wy] - col0 = _cubic_kernel(p[0], wx) - col1 = _cubic_kernel(p[1], wx) - col2 = _cubic_kernel(p[2], wx) - col3 = _cubic_kernel(p[3], wx) - value = _cubic_kernel([col0, col1, col2, col3], wy) + + col0 = _sum_products(wx, p[0]) + col1 = _sum_products(wx, p[1]) + col2 = 
_sum_products(wx, p[2]) + col3 = _sum_products(wx, p[3]) + value = _sum_products(wy, [col0, col1, col2, col3]) else: raise ValueError("Unknown resize method:", method) @@ -739,7 +729,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): extrapolation_value, tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) + + return value.astype(out_dtype) def resize2d( @@ -976,7 +967,6 @@ def _resize_3d( exclude_outside=0, out_dtype=None, ): - """Perform resize operation on the data with selected method and options. Parameters @@ -1054,12 +1044,8 @@ def _resize_3d( The computed result with type out_dtype """ - def _cast_output(value, data_dtype="float32", out_dtype=None): - if out_dtype: - dtype = out_dtype - else: - dtype = data_dtype - return value.astype(dtype) + if out_dtype is None: + out_dtype = data.dtype n, c, z, y, x, cc = get_3d_indices(indices, layout) box_idx = box_indices(n) if box_indices is not None else n @@ -1095,13 +1081,9 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): cc, ) elif method == "linear": - z_int = te.floor(in_z).astype("int32") - y_int = te.floor(in_y).astype("int32") - x_int = te.floor(in_x).astype("int32") - - z_lerp = in_z - z_int - y_lerp = in_y - y_int - x_lerp = in_x - x_int + x_int, x_lerp = Fraction.split_whole_and_fractional_parts(in_x) + y_int, y_lerp = Fraction.split_whole_and_fractional_parts(in_y) + z_int, z_lerp = Fraction.split_whole_and_fractional_parts(in_z) p = [[[0 for i in range(2)] for j in range(2)] for k in range(2)] for k in range(2): @@ -1130,14 +1112,13 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): value = _lerp(top, bottom, z_lerp) elif method == "cubic": - zint = te.floor(in_z).astype("int32") - zfract = in_z - te.floor(in_z) - - yint = te.floor(in_y).astype("int32") - yfract = in_y - te.floor(in_y) + xint, xfract = Fraction.split_whole_and_fractional_parts(in_x) + yint, yfract = 
Fraction.split_whole_and_fractional_parts(in_y) + zint, zfract = Fraction.split_whole_and_fractional_parts(in_z) - xint = te.floor(in_x).astype("int32") - xfract = in_x - te.floor(in_x) + wz = _cubic_spline_weights(zfract, alpha) + wy = _cubic_spline_weights(yfract, alpha) + wx = _cubic_spline_weights(xfract, alpha) # Get the surrounding values p = [[[0 for i in range(4)] for j in range(4)] for k in range(4)] @@ -1158,36 +1139,33 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): cc, ) - wz = _cubic_spline_weights(zfract, alpha) - wy = _cubic_spline_weights(yfract, alpha) - wx = _cubic_spline_weights(xfract, alpha) - if exclude_outside: - for i in range(4): - wz[i] = te.if_then_else( - te.any(xint - 1 + i < 0, xint + i > image_height), 0.0, wx[i] - ) - wy[i] = te.if_then_else( - te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] - ) - wx[i] = te.if_then_else( - te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] - ) - sum_wz = sum(wz) - sum_wy = sum(wy) - sum_wx = sum(wx) - wz = [w / sum_wz for w in wz] - wy = [w / sum_wy for w in wy] - wx = [w / sum_wx for w in wx] - - l = [[0 for i in range(4)] for j in range(4)] - for j in range(4): - for i in range(4): - l[j][i] = _cubic_kernel(p[j][i], wx) - col0 = _cubic_kernel(l[0], wy) - col1 = _cubic_kernel(l[1], wy) - col2 = _cubic_kernel(l[2], wy) - col3 = _cubic_kernel(l[3], wy) - value = _cubic_kernel([col0, col1, col2, col3], wz) + if exclude_outside: + for i in range(4): + wz[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_height), 0.0, wx[i] + ) + wy[i] = te.if_then_else( + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + ) + wx[i] = te.if_then_else( + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + ) + sum_wz = sum(wz) + sum_wy = sum(wy) + sum_wx = sum(wx) + wz = [w / sum_wz for w in wz] + wy = [w / sum_wy for w in wy] + wx = [w / sum_wx for w in wx] + + l = [[0 for i in range(4)] for j in range(4)] + for j in range(4): + for i in 
range(4): + l[j][i] = _sum_products(p[j][i], wx) + col0 = _sum_products(wy, l[0]) + col1 = _sum_products(wy, l[1]) + col2 = _sum_products(wy, l[2]) + col3 = _sum_products(wy, l[3]) + value = _sum_products(wz, [col0, col1, col2, col3]) else: raise ValueError("Unknown resize method:", method) @@ -1209,7 +1187,8 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): extrapolation_value, tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), ) - return _cast_output(value, data.dtype, out_dtype=out_dtype) + + return value.astype(out_dtype) def resize3d( diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 71599ad74a62..39cc85652aab 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -18,7 +18,10 @@ """Common topi utilities""" from __future__ import absolute_import as _abs +import fractions +import math from numbers import Integral +from typing import Union, Optional, Tuple import numpy as np import tvm @@ -526,3 +529,365 @@ def is_target(names): def is_dynamic_shape(shape): """Checks if any part of a shape is dynamic""" return any([isinstance(x, (Any, SizeVar)) for x in shape]) + + +class Fraction: + """Utility class for representing integer ratios + + TVM's simplifier has specific handling for integer expressions, + especially as they appear in indexing. As the simplifier is used + to check if an optimization is permissible (e.g. vectorized + computations require linear buffer access), use of integer + expressions may provide significant performance benefits. + However, writing the simplified form by hand is tedious and error-prone. + + The `Fraction` class is intended to allow for easier writing of + integer expressions. The operator overloads will attempt to + generate the resulting `Fraction` (e.g. `Fraction(Var('x')+2, 3) * + 0.75` evaluates to `Fraction((Var('x')+2)*3, 12)`). If the result + cannot be expressed as a fraction, the `Fraction` will be + converted to the appropriate `PrimExpr` type for use + (e.g.
`Fraction(3,4) * Var('pi')` evaluates to `tir.Mul(0.75, + Var('pi'))`). This allows integer arguments to be converted into + fractions where possible, and maintained as integer fractions + while generating a TIR expression. + + + Example + ------- + + When resizing an image from `original_width` to `output_width`, + determining the location in the original space for a given output + pixel. + + .. code-block:: + + resized_x = Fraction.OrPrimExpr(resized_x) + original_width = Fraction.OrPrimExpr(original_width) + resized_width = Fraction.OrPrimExpr(resized_width) + original_x = original_width / resized_width * (resized_x + 0.5) - 0.5 + + If `original_width`, `resized_width`, and `resized_x` are all integer + parameters, this will result in a `Fraction` equivalent to + `Fraction(numerator = original_width * (2*resized_x + 1) - resized_width, + denominator = 2*resized_width)`. If any of the parameters cannot be + represented as an integer fraction, the expression will instead use + floating-point arithmetic. + + To return a `PrimExpr` after using a `Fraction`, use the + `.astype(out_dtype)` method. This method is implemented for both + `PrimExpr` and `Fraction`, so the type coercion can be applied for both + usages. + + .. code-block:: + + output = original_x.astype('float32') + + To extract integer/fractional components of the expression, use the + utility method `Fraction.split_whole_and_fractional_parts`. + + .. code-block:: + + int_part, remainder = Fraction.split_whole_and_fractional_parts(original_x) + + """ + + def __init__( + self, + numerator: Union[int, float, "Fraction", tvm.tir.PrimExpr], + denominator: Optional[Union[int, float, "Fraction", tvm.tir.PrimExpr]] = None, + ): + """Initialize the Fraction + + Parameters + ---------- + numerator: Union[int, float, Fraction, PrimExpr] + + The numerator of the fraction. If a `float` or `tir.FloatImm` is + passed, will attempt to convert to a ratio of integers. 
If an exact + representation is not found, will raise a `ValueError`. Any other + `tir.PrimExpr` with a floating-point data-type will also result in a + `ValueError`. + + denominator: Optional[Union[int, float, Fraction, PrimExpr]] + + The denominator of the fraction. + + If a `float` or `tir.FloatImm` is passed, will attempt to convert to + a ratio of integers. If an exact representation is not found, will + raise a `ValueError`. Any other `tir.PrimExpr` with a floating-point + data-type will also result in a `ValueError`. + + If `None`, set equal to 1. + """ + + def _normalize(value): + if isinstance(value, Fraction): + return value + + elif isinstance(value, int): + return tvm.runtime.convert(value) + + elif isinstance(value, tvm.tir.PrimExpr) and "int" in value.dtype: + return value + + elif isinstance(value, (float, tvm.tir.FloatImm)): + # A floating-point number may result from previous division + # of integers. Use python's `fractions.Fraction` class to + # unpack into a rational number, so long it reproduces + # identically the same floating-point number. + as_float = float(value) + as_fraction = fractions.Fraction(as_float).limit_denominator(1024) + if as_fraction.numerator / as_fraction.denominator == as_float: + return Fraction(as_fraction.numerator, as_fraction.denominator) + else: + raise ValueError(f"Could not represent value {value} as a ratio of integers") + + elif isinstance(value, tvm.tir.PrimExpr) and "float" in value.dtype: + # Any other floating-point expressions are forbidden. 
+ raise ValueError(f"Could not represent value {value} as a ratio of integers") + + else: + raise TypeError( + f"Could not represent type {type(value)} (value = {value}) " + "as a ratio of integers" + ) + + numerator = _normalize(numerator) + denominator = 1 if denominator is None else _normalize(denominator) + + if isinstance(numerator, Fraction) and isinstance(denominator, Fraction): + self.numerator, self.denominator = ( + numerator.numerator * denominator * denominator, + denominator.numerator * numerator.denominator, + ) + elif isinstance(numerator, Fraction): + self.numerator, self.denominator = ( + numerator.numerator, + denominator * numerator.denominator, + ) + elif isinstance(denominator, Fraction): + self.numerator, self.denominator = ( + numerator * denominator.denominator, + denominator.numerator, + ) + else: + self.numerator, self.denominator = (numerator, denominator) + + if not isinstance(self.denominator, tvm.tir.PrimExpr): + self.denominator = tvm.tir.IntImm(self.numerator.dtype, self.denominator) + + if isinstance(self.denominator, tvm.tir.IntImm): + assert self.denominator.value != 0 + + def __repr__(self): + return f"Fraction({self.numerator}, {self.denominator})" + + @classmethod + def OrPrimExpr( + cls, value: Union[int, float, "Fraction", tvm.tir.PrimExpr] + ) -> Union[tvm.tir.PrimExpr, "Fraction"]: + """Attempt to generate an integer fraction, with fallback to PrimExpr + + Parameters + ---------- + value: Union[int, float, Fraction, PrimExpr] + + The value to be expressed as a fraction, if possible. + + Returns + ------- + fraction_or_primexpr: Union[PrimExpr, Fraction] + + The resulting fraction if the value can be expressed as an + integer fraction, otherwise the original value. See + docstring of `Fraction` for the allowed conversions. 
+ """ + + try: + return cls(value) + except ValueError: + return tvm.runtime.convert(value) + + @classmethod + def split_whole_and_fractional_parts( + cls, expr: Union["Fraction", tvm.tir.PrimExpr] + ) -> Tuple[tvm.tir.PrimExpr, Union["Fraction", tvm.tir.PrimExpr]]: + """Split the fraction into integer and fractional components + + Parameters + ---------- + + expr: Union[Fraction, PrimExpr] + + The expression to be split + + Returns + ------- + int_part: PrimExpr + + The integer part of the fraction. This is determined + either with integer `tir.floordiv` for a `Fraction`, or + with `tir.floor` for a `PrimExpr`. + + fractional_part: Union[PrimExpr, Fraction] + + The remaining fractional part of the initial fraction. + This is determined either with integer `tir.floormod` for + a `Fraction`, or by subtracting `int_part` for a + `PrimExpr`. + """ + if isinstance(expr, cls): + return (expr.int_part(), expr.fractional_part()) + else: + int_part = tvm.tir.floor(expr).astype("int32") + return int_part, expr - int_part + + def simplify(self, analyzer: Optional["tvm.arith.Analyzer"] = None) -> "Fraction": + """Simplify the fraction + + Parameters + ---------- + analyzer: Optional[arith.Analyzer] + + The analyzer to use for simplification. If None, + construct a temporary analyzer. 
+ + Returns + ------- + simplified: Fraction + + The simplified fraction + """ + if analyzer is None: + analyzer = tvm.arith.Analyzer() + numerator = analyzer.simplify(self.numerator) + denominator = analyzer.simplify(self.denominator) + if numerator == 0: + return Fraction(0, 1) + + def _extract_coef(val): + if isinstance(val, (int, tvm.tir.IntImm)): + return int(val) + elif isinstance(val, tvm.tir.Mul) and isinstance(val.b, tvm.tir.IntImm): + return int(val.b) + else: + return 1 + + gcd = math.gcd(_extract_coef(numerator), _extract_coef(denominator)) + if gcd != 1: + numerator = analyzer.simplify(numerator // gcd) + denominator = analyzer.simplify(denominator // gcd) + + return Fraction(numerator, denominator) + + def astype(self, dtype: str) -> tvm.tir.PrimExpr: + """Convert to a tvm.tir.PrimExpr of the specified type + + The name is deliberately the same as `PrimExpr.astype`, to + allow `expr.astype(out_dtype)` to be valid for both + `tvm.tir.PrimExpr` and `Fraction` expressions. + + Parameters + ---------- + dtype: str + + The TVM datatype to return. + + Returns + ------- + value: PrimExpr + + The resulting PrimExpr + + """ + if "int" in dtype: + return self.int_part().astype(dtype) + else: + frac = self.simplify() + return frac.numerator.astype(dtype) / frac.denominator.astype(dtype) + + def int_part(self) -> tvm.tir.PrimExpr: + """The integer part of the fraction + + Returns + ------- + int_part: PrimExpr + + The integer part of the fraction + """ + return tvm.tir.floordiv(self.numerator, self.denominator) + + def fractional_part(self) -> "Fraction": + """The remainder of the fraction + + Returns + ------- + fractional_part: Fraction + + The remainder of the fraction + """ + return Fraction(tvm.tir.floormod(self.numerator, self.denominator), self.denominator) + + def __neg__(self): + # Disabling the pylint check, since pylint doesn't track the + # __init__ type annotations to determine that self.numerator + # may not be None. 
+ + return Fraction( + -self.numerator, # pylint: disable=invalid-unary-operand-type + self.denominator, + ) + + def __mul__(self, other): + try: + other = Fraction(other) + except ValueError: + return self.astype(other.dtype) * other + else: + return Fraction(self.numerator * other.numerator, self.denominator * other.denominator) + + def __rmul__(self, other): + return self * other + + def __truediv__(self, other): + other = Fraction(other) + return Fraction(self.numerator * other.denominator, self.denominator * other.numerator) + + @staticmethod + def _with_common_denominator(lhs, rhs): + if not isinstance(lhs.denominator, (int, tvm.tir.IntImm)) or not isinstance( + rhs.denominator, (int, tvm.tir.IntImm) + ): + denom = lhs.denominator * rhs.denominator + return Fraction(lhs.numerator * rhs.denominator, denom), Fraction( + rhs.numerator * lhs.denominator, denom + ) + + gcd = math.gcd(int(lhs.denominator), int(rhs.denominator)) + lcm = (int(lhs.denominator) * int(rhs.denominator)) // gcd + return Fraction(lhs.numerator * (lcm // lhs.denominator), lcm), Fraction( + rhs.numerator * (lcm // rhs.denominator), lcm + ) + + def __add__(self, other): + other = Fraction(other) + self, other = Fraction._with_common_denominator(self, other) + return Fraction( + self.numerator + other.numerator, + self.denominator, + ) + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + other = Fraction(other) + self, other = Fraction._with_common_denominator(self, other) + return Fraction( + self.numerator - other.numerator, + self.denominator, + ) + + def __rsub__(self, other): + return Fraction(other) - self diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 3682054e8e4b..d2006ce4af09 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -754,7 +754,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { c1.Eval()->value == -c2.Eval()->value); // canonicalization - 
TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); + TVM_TRY_RECURSIVE_REWRITE(x * (y * c1), (x * y) * c1); + TVM_TRY_RECURSIVE_REWRITE((x * c1) * y, (x * y) * c1); + TVM_TRY_RECURSIVE_REWRITE(x * (y * z), (x * y) * z); TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0); } @@ -2315,6 +2317,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } LOG(FATAL) << "Should not reach here"; } + } else if (op->op.same_as(Op::Get("tir.floor"))) { + PrimExpr floor_arg = op->args[0]; + if (auto arg_int = floor_arg.as()) { + return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); + } else if (auto arg_float = floor_arg.as()) { + return cast(op->dtype, FloatImm(arg_float->dtype, std::floor(arg_float->value))); + } } if (op->op.same_as(tir::builtin::likely())) { diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 3b418aac0cf5..d3c38e5cf3d1 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -29,6 +29,7 @@ #include #include +#include #include #include diff --git a/tests/python/topi/test_topi_image.py b/tests/python/topi/test_topi_image.py index 56f7a2026d33..d224bea640be 100644 --- a/tests/python/topi/test_topi_image.py +++ b/tests/python/topi/test_topi_image.py @@ -15,100 +15,106 @@ # specific language governing permissions and limitations # under the License. 
"""Test code for bilinear scale """ + import numpy as np +import pytest + import tvm -from tvm import te -from tvm import topi import tvm.testing import tvm.topi.testing -from tvm.contrib.pickle_memoize import memoize +from tvm import te, topi +from tvm.contrib.pickle_memoize import memoize -def verify_resize2d( - batch, - in_channel, - in_height, - in_width, - out_height, - out_width, - layout="NCHW", - coord_trans="align_corners", - method="linear", +layout_2d = tvm.testing.parameter("NCHW", "NHWC") +layout_3d = tvm.testing.parameter("NCDHW", "NDHWC") +coordinate_transformation_mode = tvm.testing.parameter("asymmetric", "align_corners", "half_pixel") +interpolation_method = tvm.testing.parameter("nearest_neighbor", "linear") + +resize_2d_test_case = tvm.testing.parameter( + dict(sizes=(4, 16, 32, 32, 50, 50), coord_trans="align_corners", method="linear"), + dict(sizes=(6, 32, 64, 64, 20, 20), coord_trans="align_corners", method="linear"), + dict(sizes=(4, 16, 32, 32, 50, 50), coord_trans="asymmetric", method="nearest_neighbor"), + dict(sizes=(4, 16, 32, 32, 64, 50), coord_trans="asymmetric", method="nearest_neighbor"), + dict(sizes=(4, 16, 32, 32, 50, 96), coord_trans="asymmetric", method="nearest_neighbor"), + dict(sizes=(4, 16, 32, 32, 96, 96), coord_trans="asymmetric", method="nearest_neighbor"), + dict(sizes=(4, 16, 32, 32, 50, 50), coord_trans="align_corners", method="nearest_neighbor"), + dict(sizes=(4, 16, 32, 32, 50, 50), coord_trans="half_pixel", method="nearest_neighbor"), + dict(sizes=(4, 16, 32, 32, 50, 50), coord_trans="asymmetric", method="linear"), + dict(sizes=(4, 16, 32, 32, 50, 50), coord_trans="half_pixel", method="linear"), +) + + +def test_resize2d( + target, + dev, + resize_2d_test_case, + layout_2d, ): - if layout == "NCHW": + (batch, in_channel, in_height, in_width, out_height, out_width) = resize_2d_test_case["sizes"] + coordinate_transformation_mode = resize_2d_test_case["coord_trans"] + interpolation_method = 
resize_2d_test_case["method"] + + if layout_2d == "NCHW": A = te.placeholder((batch, in_channel, in_height, in_width), name="A", dtype="float32") dtype = A.dtype out_shape = (batch, in_channel, out_height, out_width) a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) - elif layout == "NHWC": + elif layout_2d == "NHWC": A = te.placeholder((batch, in_height, in_width, in_channel), name="A", dtype="float32") dtype = A.dtype out_shape = (batch, out_height, out_width, in_channel) a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: - raise NotImplementedError("Layout not supported {} ".format(layout)) + raise NotImplementedError("Layout not supported {} ".format(layout_2d)) + B = topi.image.resize2d( A, [0.0] * 4, (out_height, out_width), - layout=layout, - coordinate_transformation_mode=coord_trans, - method=method, + layout=layout_2d, + coordinate_transformation_mode=coordinate_transformation_mode, + method=interpolation_method, ) scale_h = out_height / in_height scale_w = out_width / in_width - b_np = tvm.topi.testing.resize2d_python(a_np, (scale_h, scale_w), layout, method, coord_trans) + b_np = tvm.topi.testing.resize2d_python( + a_np, (scale_h, scale_w), layout_2d, interpolation_method, coordinate_transformation_mode + ) - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(B) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev) - f = tvm.build(s, [A, B], target) - f(a, b) + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(B) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev) + f = tvm.build(s, [A, B], target) + f(a, b) - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3, atol=1e-3) - for target, dev in 
tvm.testing.enabled_targets(): - check_target(target, dev) +resize_3d_test_case = tvm.testing.parameter((3, 16, 32, 32, 32, 10, 10, 10)) -@tvm.testing.uses_gpu -def test_resize2d(): - # Scale NCHW - verify_resize2d(4, 16, 32, 32, 50, 50, "NCHW") - # Scale NCHW + Align Corners - verify_resize2d(6, 32, 64, 64, 20, 20, "NCHW") - # Scale NHWC - verify_resize2d(4, 16, 32, 32, 50, 50, "NHWC") - # Scale NHWC + Align Corners - verify_resize2d(6, 32, 64, 64, 20, 20, "NHWC") - for layout in ["NCHW", "NHWC"]: - verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 64, 50, layout, "asymmetric", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 50, 96, layout, "asymmetric", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 96, 96, layout, "asymmetric", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 50, 50, layout, "align_corners", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="nearest_neighbor") - verify_resize2d(4, 16, 32, 32, 50, 50, layout, "asymmetric", method="linear") - verify_resize2d(4, 16, 32, 32, 50, 50, layout, "half_pixel", method="linear") - - -def verify_resize3d( - batch, - in_channel, - in_depth, - in_height, - in_width, - out_depth, - out_height, - out_width, - layout="NCDHW", - coordinate_transformation_mode="asymmetric", - method="linear", + +def test_resize3d( + target, + dev, + resize_3d_test_case, + layout_3d, + coordinate_transformation_mode, + interpolation_method, ): - if layout == "NCDHW": + ( + batch, + in_channel, + in_depth, + in_height, + in_width, + out_depth, + out_height, + out_width, + ) = resize_3d_test_case + + if layout_3d == "NCDHW": A = te.placeholder( (batch, in_channel, in_depth, in_height, in_width), name="A", dtype="float32" ) @@ -117,7 +123,7 @@ def verify_resize3d( a_np = np.random.uniform(size=(batch, in_channel, in_depth, in_height, in_width)).astype( dtype ) - elif layout == 
"NDHWC": + elif layout_3d == "NDHWC": A = te.placeholder( (batch, in_depth, in_height, in_width, in_channel), name="A", dtype="float32" ) @@ -127,224 +133,187 @@ def verify_resize3d( dtype ) else: - raise NotImplementedError("Layout not supported {} ".format(layout)) + raise NotImplementedError("Layout not supported {} ".format(layout_3d)) B = topi.image.resize3d( A, [0.0] * 6, (out_depth, out_height, out_width), - layout=layout, + layout=layout_3d, coordinate_transformation_mode=coordinate_transformation_mode, - method=method, + method=interpolation_method, ) scale_d = out_depth / in_depth scale_h = out_height / in_height scale_w = out_width / in_width b_np = tvm.topi.testing.resize3d_python( - a_np, (scale_d, scale_h, scale_w), layout, method, coordinate_transformation_mode + a_np, + (scale_d, scale_h, scale_w), + layout_3d, + interpolation_method, + coordinate_transformation_mode, ) - def check_target(target, dev): - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(B) - a = tvm.nd.array(a_np, dev) - b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev) - f = tvm.build(s, [A, B], target) - f(a, b) - - tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3, atol=1e-3) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(B) + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev) + f = tvm.build(s, [A, B], target) + f(a, b) + + tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3, atol=1e-3) + + +box_set_1 = dict( + boxes=np.array([[0.2, 0.3, 0.7, 0.9]], dtype="float32"), + indices=np.array([0], dtype="int32"), + crop_size=(7, 11), +) +box_set_2 = dict( + boxes=np.array([[0.2, 0.3, 0.7, 0.9], [0, 0.1, 0.8, 1]], dtype="float32"), + indices=np.array([1, 0], dtype="int32"), + crop_size=(90, 60), +) +crop_and_resize_test_case = tvm.testing.parameter( + dict(image_shape=(1, 255, 255, 
3), **box_set_1), + dict(image_shape=(1, 100, 100, 3), **box_set_1, method="nearest_neighbor"), + dict(image_shape=(1, 3, 224, 224), **box_set_1, layout="NCHW"), + dict(image_shape=(10, 224, 224, 5), **box_set_2, extrapolation_value=0.3), +) @tvm.testing.uses_gpu -def test_resize3d(): - # Trilinear - for method in ["nearest_neighbor", "linear"]: - for coord_trans in ["asymmetric", "align_corners", "half_pixel"]: - for layout in ["NCDHW", "NDHWC"]: - verify_resize3d(3, 16, 32, 32, 32, 10, 10, 10, layout, coord_trans, method) - +def test_crop_and_resize(target, dev, crop_and_resize_test_case): + image_shape = crop_and_resize_test_case["image_shape"] + np_boxes = crop_and_resize_test_case["boxes"] + np_box_indices = crop_and_resize_test_case["indices"] + np_crop_size = crop_and_resize_test_case["crop_size"] + method = crop_and_resize_test_case.get("method", "bilinear") + extrapolation_value = crop_and_resize_test_case.get("extrapolation_value", 0.0) + layout = crop_and_resize_test_case.get("layout", "NHWC") + + images = te.placeholder(image_shape, name="images", dtype="float32") + np_images = np.random.uniform(size=image_shape).astype("float32") + boxes = te.placeholder(np_boxes.shape, name="boxes", dtype="float32") + box_ind = te.placeholder(np_box_indices.shape, name="box_ind", dtype="int32") + + batch = len(np_box_indices) + target_height, target_width = np_crop_size[0], np_crop_size[1] + if layout == "NHWC": + channel = image_shape[3] + out_shape = (batch, target_height, target_width, channel) + elif layout == "NCHW": + channel = image_shape[1] + out_shape = (batch, channel, target_height, target_width) + else: + raise NotImplementedError("Layout {} is not supported.".format(layout)) -@tvm.testing.uses_gpu -def test_crop_and_resize(): - def verify_crop_and_resize( - image_shape, - np_boxes, - np_box_indices, + out = topi.image.crop_and_resize( + images, + boxes, + box_ind, np_crop_size, - layout="NHWC", - method="bilinear", - extrapolation_value=0.0, - ): - - 
images = te.placeholder(image_shape, name="images", dtype="float32") - np_images = np.random.uniform(size=image_shape).astype("float32") - boxes = te.placeholder(np_boxes.shape, name="boxes", dtype="float32") - box_ind = te.placeholder(np_box_indices.shape, name="box_ind", dtype="int32") - - batch = len(np_box_indices) - target_height, target_width = np_crop_size[0], np_crop_size[1] - if layout == "NHWC": - channel = image_shape[3] - out_shape = (batch, target_height, target_width, channel) - elif layout == "NCHW": - channel = image_shape[1] - out_shape = (batch, channel, target_height, target_width) - else: - raise NotImplementedError("Layout {} is not supported.".format(layout)) - - out = topi.image.crop_and_resize( - images, - boxes, - box_ind, - np_crop_size, - layout=layout, - method=method, - extrapolation_value=extrapolation_value, - ) - - baseline_np = tvm.topi.testing.crop_and_resize_python( - np_images, np_boxes, np_box_indices, np_crop_size, layout, method, extrapolation_value - ) - - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(out) - tvm_images = tvm.nd.array(np_images, dev) - tvm_boxes = tvm.nd.array(np_boxes, dev) - tvm_indices = tvm.nd.array(np_box_indices, dev) - tvm_out = tvm.nd.array(np.zeros(out_shape, dtype="float32"), dev) - f = tvm.build(s, [images, boxes, box_ind, out], target, name="crop_and_resize") - f(tvm_images, tvm_boxes, tvm_indices, tvm_out) - - tvm.testing.assert_allclose(tvm_out.numpy(), baseline_np, rtol=1e-3, atol=1e-3) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - boxes_1 = np.array([[0.2, 0.3, 0.7, 0.9]], dtype="float32") - boxes_2 = np.array([[0.2, 0.3, 0.7, 0.9], [0, 0.1, 0.8, 1]], dtype="float32") - indices_1 = np.array([0], dtype="int32") - indices_2 = np.array([1, 0], dtype="int32") - size_1 = (7, 11) - size_2 = (90, 60) - - verify_crop_and_resize((1, 255, 255, 3), 
boxes_1, indices_1, size_1, layout="NHWC") - verify_crop_and_resize( - (10, 224, 224, 5), boxes_2, indices_2, size_2, extrapolation_value=0.3, layout="NHWC" + layout=layout, + method=method, + extrapolation_value=extrapolation_value, ) - verify_crop_and_resize((1, 100, 100, 3), boxes_1, indices_1, size_1, method="nearest_neighbor") - verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW") + baseline_np = tvm.topi.testing.crop_and_resize_python( + np_images, np_boxes, np_box_indices, np_crop_size, layout, method, extrapolation_value + ) -@tvm.testing.uses_gpu -def test_affine_grid(): - def verify_affine_grid(num_batch, target_shape): - dtype = "float32" - data_shape = (num_batch, 2, 3) - data = te.placeholder(data_shape, dtype=dtype) - out = topi.image.affine_grid(data, target_shape) - - @memoize("topi.tests.test_affine_grid.verify_affine_grid") - def get_ref_data(): - data_np = np.random.uniform(size=data_shape).astype(dtype) - out_np = tvm.topi.testing.affine_grid_python(data_np, target_shape) - return data_np, out_np - - data_np, out_np = get_ref_data() - - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(out) - tvm_data = tvm.nd.array(data_np, dev) - tvm_out = tvm.nd.empty(out_np.shape, dtype, dev) - f = tvm.build(s, [data, out], target) - f(tvm_data, tvm_out) - - tvm.testing.assert_allclose(tvm_out.numpy(), out_np, rtol=1e-5, atol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(out) + tvm_images = tvm.nd.array(np_images, dev) + tvm_boxes = tvm.nd.array(np_boxes, dev) + tvm_indices = tvm.nd.array(np_box_indices, dev) + tvm_out = tvm.nd.array(np.zeros(out_shape, dtype="float32"), dev) + f = tvm.build(s, [images, boxes, box_ind, out], target, name="crop_and_resize") + f(tvm_images, tvm_boxes, 
tvm_indices, tvm_out) + + tvm.testing.assert_allclose(tvm_out.numpy(), baseline_np, rtol=1e-3, atol=1e-3) + + +affine_grid_test_case = tvm.testing.parameter( + (1, (16, 32)), + (4, (16, 32)), +) + + +def test_affine_grid(target, dev, affine_grid_test_case): + num_batch, target_shape = affine_grid_test_case + + dtype = "float32" + data_shape = (num_batch, 2, 3) + data = te.placeholder(data_shape, dtype=dtype) + out = topi.image.affine_grid(data, target_shape) + + @memoize("topi.tests.test_affine_grid.verify_affine_grid") + def get_ref_data(): + data_np = np.random.uniform(size=data_shape).astype(dtype) + out_np = tvm.topi.testing.affine_grid_python(data_np, target_shape) + return data_np, out_np + + data_np, out_np = get_ref_data() + + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(out) + tvm_data = tvm.nd.array(data_np, dev) + tvm_out = tvm.nd.empty(out_np.shape, dtype, dev) + f = tvm.build(s, [data, out], target) + f(tvm_data, tvm_out) + + tvm.testing.assert_allclose(tvm_out.numpy(), out_np, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("method", ["nearest", "bilinear", "bicubic"]) +@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"]) +@pytest.mark.parametrize("align_corners", [True, False]) +@pytest.mark.parametrize("dimension", ["2d", "3d"]) +def test_grid_sample(target, dev, method, padding_mode, align_corners, dimension): + if dimension == "3d" and method == "bicubic": + pytest.skip('3D "bicubic"(tricubic) is not supported in pytorch') + + if dimension == "2d": + data_shape = (4, 4, 8, 8) + grid_shape = (4, 2, 16, 16) + layout = "NCHW" + elif dimension == "3d": + # choosing smaller sizes to be testable on weaker GPUs + data_shape = (4, 4, 4, 4, 4) + grid_shape = (4, 3, 8, 8, 8) + layout = "NCDHW" + else: + raise ValueError(f"Unknown dimension: {dimension}") + + dtype = "float32" + data = te.placeholder(data_shape, dtype=dtype) + grid = te.placeholder(grid_shape, dtype=dtype) + out = 
topi.image.grid_sample(data, grid, method, layout, padding_mode, align_corners) + + @memoize("topi.tests.test_grid_sample.verify_grid_sample") + def get_ref_data(): + data_np = np.random.uniform(size=data_shape).astype(dtype) + # allow grid values to be out-of-bound + grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) + out_np = tvm.topi.testing.grid_sample_python( + data_np, grid_np, method, layout, padding_mode, align_corners + ) + return data_np, grid_np, out_np - verify_affine_grid(1, (16, 32)) - verify_affine_grid(4, (16, 32)) + data_np, grid_np, out_np = get_ref_data() + with tvm.target.Target(target): + s = tvm.topi.testing.get_injective_schedule(target)(out) + tvm_data = tvm.nd.array(data_np, dev) + tvm_grid = tvm.nd.array(grid_np, dev) + tvm_out = tvm.nd.empty(out_np.shape, dtype, dev) + f = tvm.build(s, [data, grid, out], target) + f(tvm_data, tvm_grid, tvm_out) -@tvm.testing.uses_gpu -def test_grid_sample(): - def verify_grid_sample( - data_shape, - grid_shape, - method="bilinear", - layout="NCHW", - padding_mode="zeros", - align_corners=True, - ): - dtype = "float32" - data = te.placeholder(data_shape, dtype=dtype) - grid = te.placeholder(grid_shape, dtype=dtype) - out = topi.image.grid_sample(data, grid, method, layout, padding_mode, align_corners) - - @memoize("topi.tests.test_grid_sample.verify_grid_sample") - def get_ref_data(): - data_np = np.random.uniform(size=data_shape).astype(dtype) - # allow grid values to be out-of-bound - grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) - out_np = tvm.topi.testing.grid_sample_python( - data_np, grid_np, method, layout, padding_mode, align_corners - ) - return data_np, grid_np, out_np - - data_np, grid_np, out_np = get_ref_data() - - def check_target(target, dev): - print("Running on target: %s" % target) - with tvm.target.Target(target): - s = tvm.topi.testing.get_injective_schedule(target)(out) - tvm_data = tvm.nd.array(data_np, dev) - tvm_grid = 
tvm.nd.array(grid_np, dev) - tvm_out = tvm.nd.empty(out_np.shape, dtype, dev) - f = tvm.build(s, [data, grid, out], target) - f(tvm_data, tvm_grid, tvm_out) - - tvm.testing.assert_allclose(tvm_out.numpy(), out_np, rtol=1e-5, atol=1e-5) - - for target, dev in tvm.testing.enabled_targets(): - check_target(target, dev) - - methods = ["nearest", "bilinear", "bicubic"] - padding_modes = ["zeros", "border", "reflection"] - align_corners = [True, False] - data_2D_shape = (4, 4, 8, 8) - grid_2D_shape = (4, 2, 16, 16) - layout_2D = "NCHW" - # choosing smaller sizes to be testable on weaker GPUs - data_3D_shape = (4, 4, 4, 4, 4) - grid_3D_shape = (4, 3, 8, 8, 8) - layout_3D = "NCDHW" - - for _method in methods: - for _padding in padding_modes: - for _align in align_corners: - verify_grid_sample( - data_2D_shape, grid_2D_shape, _method, layout_2D, _padding, _align - ) - - # 3D "bicubic"(tricubic) is not supported in pytorch - if _method != "bicubic": - verify_grid_sample( - data_3D_shape, grid_3D_shape, _method, layout_3D, _padding, _align - ) + tvm.testing.assert_allclose(tvm_out.numpy(), out_np, rtol=1e-5, atol=1e-5) if __name__ == "__main__": - test_resize2d() - test_resize3d() - test_crop_and_resize() - test_affine_grid() - test_grid_sample() + tvm.testing.main() From 0901475019b314e8044303eb7ea3f6a6ca85b9e3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 1 Dec 2022 15:35:03 -0600 Subject: [PATCH 02/12] Infer out_dtype for empty string Relay uses `DataType::Void()` to represent unspecified data types, the FFI converts `DataType` objects to strings, and `DataType::Void()` is represented as the empty string. 
--- python/tvm/topi/image/resize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index d9b31ece4f98..622842885015 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -586,7 +586,7 @@ def _resize_2d( The computed result with type out_dtype """ - if out_dtype is None: + if not out_dtype: out_dtype = data.dtype n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout) From c9614e5fbe32b73e718e2ce9eef534ec2e43a1b3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Dec 2022 09:02:40 -0600 Subject: [PATCH 03/12] Removed unnecessary #include --- src/tir/transforms/remove_no_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index d3c38e5cf3d1..3b418aac0cf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -29,7 +29,6 @@ #include #include -#include #include #include From 3783b225164ccfa823e4099afb4adba450e2116e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Dec 2022 09:09:19 -0600 Subject: [PATCH 04/12] Completed incomplete sentence --- python/tvm/topi/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 39cc85652aab..040effe9fa3b 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -539,7 +539,9 @@ class Fraction: to check if an optimization is permissible (e.g. vectorized computations require linear buffer access), use of integer expressions may provide significant performance benefits. - However, writing the simplified form + However, directly writing the resulting integer expression would + be tedious in many cases, or may depend on a user-specified + fractional value. The `Fraction` class is intended to allow for easier writing of integer expressions. The operator overloads will attempt to
The operator overloads will attempt to From 9501daef2ec2331a217f3a1f4d991030faafd10e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 2 Dec 2022 15:05:37 -0600 Subject: [PATCH 05/12] Allow use of integer te.Tensor, te.TensorSlice in Fraction These are passed into the topi resize library for dynamic relay shapes, and should be supported as possible integer types. --- python/tvm/topi/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 040effe9fa3b..e5393c7a6b44 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -630,7 +630,10 @@ def _normalize(value): elif isinstance(value, int): return tvm.runtime.convert(value) - elif isinstance(value, tvm.tir.PrimExpr) and "int" in value.dtype: + elif ( + isinstance(value, (tvm.tir.PrimExpr, tvm.te.Tensor, tvm.te.TensorSlice)) + and "int" in value.dtype + ): return value elif isinstance(value, (float, tvm.tir.FloatImm)): @@ -645,7 +648,10 @@ def _normalize(value): else: raise ValueError(f"Could not represent value {value} as a ratio of integers") - elif isinstance(value, tvm.tir.PrimExpr) and "float" in value.dtype: + elif ( + isinstance(value, (tvm.tir.PrimExpr, tvm.te.Tensor, tvm.te.TensorSlice)) + and "float" in value.dtype + ): # Any other floating-point expressions are forbidden. 
raise ValueError(f"Could not represent value {value} as a ratio of integers") From fcaa709191f16c00a21a5f68a56ba192233bf3bd Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Thu, 5 Jan 2023 15:50:10 -0600 Subject: [PATCH 06/12] Infer out_dtype for empty string --- python/tvm/topi/image/resize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 622842885015..de4ac8e284d6 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -313,7 +313,7 @@ def _resize_1d( The computed result with type out_dtype """ - if out_dtype is None: + if not out_dtype: out_dtype = data.dtype n, c, x, cc, inum, ic = get_1d_indices(indices, layout) @@ -1044,7 +1044,7 @@ def _resize_3d( The computed result with type out_dtype """ - if out_dtype is None: + if not out_dtype: out_dtype = data.dtype n, c, z, y, x, cc = get_3d_indices(indices, layout) From 0f93558236624f3e6398a19a9e4460fa30d5eb52 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Tue, 10 Jan 2023 09:16:38 -0600 Subject: [PATCH 07/12] Corrected ordering of _sum_products arguments --- python/tvm/topi/image/resize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index de4ac8e284d6..1337c4248fcc 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -1160,7 +1160,7 @@ def _resize_3d( l = [[0 for i in range(4)] for j in range(4)] for j in range(4): for i in range(4): - l[j][i] = _sum_products(p[j][i], wx) + l[j][i] = _sum_products(wx, p[j][i]) col0 = _sum_products(wy, l[0]) col1 = _sum_products(wy, l[1]) col2 = _sum_products(wy, l[2]) From 4db65224896c011a40bc35dd98e70805dca8b14b Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Thu, 12 Jan 2023 09:06:36 -0600 Subject: [PATCH 08/12] Correction to "round_prefer_floor" --- python/tvm/topi/image/resize.py | 4 +++- 1 file changed, 3 insertions(+), 
1 deletion(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index 1337c4248fcc..fb784c226c00 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -178,8 +178,10 @@ def get_closest_index(in_x, rounding_method, boxes): numerator = in_x.numerator denominator = in_x.denominator - if rounding_method in ("round", "round_prefer_floor") or boxes is not None: + if rounding_method == "round" or boxes is not None: return (numerator + denominator // 2) // denominator + if rounding_method == "round_prefer_floor": + return (numerator + (denominator - 1) // 2) // denominator elif rounding_method == "round_prefer_ceil": return (numerator + (denominator + 1) // 2) // denominator elif rounding_method == "floor": From a118d7e415e271ce29a75c4e43aaa65b2041f79b Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Thu, 12 Jan 2023 09:36:38 -0600 Subject: [PATCH 09/12] Avoid calling `te.if_then_else` with internal `Fraction` utility --- python/tvm/topi/image/resize.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index fb784c226c00..a555be6dcbc4 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -356,7 +356,7 @@ def _resize_1d( if exclude_outside: for i in range(4): wx[i] = te.if_then_else( - te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i].as_type(out_dtype) ) sum_wx = sum(wx) wx = [w / sum_wx for w in wx] @@ -700,10 +700,10 @@ def _resize_2d( if exclude_outside: for i in range(4): wx[i] = te.if_then_else( - te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i].astype(out_dtype) ) wy[i] = te.if_then_else( - te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i].astype(out_dtype) 
) sum_wx = sum(wx) sum_wy = sum(wy) @@ -1144,13 +1144,13 @@ def _resize_3d( if exclude_outside: for i in range(4): wz[i] = te.if_then_else( - te.any(xint - 1 + i < 0, xint + i > image_height), 0.0, wx[i] + te.any(xint - 1 + i < 0, xint + i > image_height), 0.0, wx[i].as_type(out_dtype) ) wy[i] = te.if_then_else( - te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i] + te.any(yint - 1 + i < 0, yint + i > image_height), 0.0, wy[i].as_type(out_dtype) ) wx[i] = te.if_then_else( - te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i] + te.any(xint - 1 + i < 0, xint + i > image_width), 0.0, wx[i].as_type(out_dtype) ) sum_wz = sum(wz) sum_wy = sum(wy) From ff60a88729009036ec42861016efd90dc1320c64 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 6 Jan 2023 15:05:12 -0600 Subject: [PATCH 10/12] [Debug] Print statements for RewriteSimplifier --- src/arith/rewrite_simplify.cc | 224 ++++++++++++++++++++++++++++++---- 1 file changed, 197 insertions(+), 27 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d2006ce4af09..8996eed1eda5 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -385,7 +385,12 @@ void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm @@ -536,7 +541,12 @@ RewriteSimplifier::Extension RewriteSimplifier::Impl::GetEnabledExtensions() con PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr ret = 
IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm @@ -725,7 +735,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; // Pattern var match IntImm @@ -766,7 +781,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold
(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold
(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm @@ -926,7 +946,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1; @@ -1016,7 +1041,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1; // Pattern var match IntImm @@ -1034,7 +1064,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { int64_t c2val = c2.Eval()->value; ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { - return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); + PrimExpr out = ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval(); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return out; } // If all possible indices in ramp are the same. 
if (!arith::ExtractVscaleFactor(lanes.Eval())) { @@ -1069,7 +1103,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; PrimExpr yval = y.EvalOr(Integer(0)); - if (c2val == 0) return ret; + if (c2val == 0) { + std::cout << "\t\t" + << "Found denominator of zero in " << ret << ", bailing out" << std::endl; + return ret; + }; // try eliminate residue part PrimExpr residue = @@ -1077,7 +1115,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val); auto bound = analyzer_->const_int_bound(residue); if (bound.defined() && bound->max_value == bound->min_value) { - return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + PrimExpr out = x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return RecursiveRewrite(out); } // try simplify divisor @@ -1089,7 +1131,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // ==> x' + d + (b * c1 + e) // c2 // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 // ==> x // (c2 // c1) + (y // c2) - return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; + PrimExpr out = floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return RecursiveRewrite(out); } } @@ -1161,7 +1207,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout 
<< "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, b1; @@ -1181,7 +1232,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { int64_t c2val = c2.Eval()->value; ICHECK(c2val != 0) << "division by zero"; if (c1val % c2val == 0) { - return broadcast(floormod(b1, c2), lanes).Eval(); + PrimExpr out = broadcast(floormod(b1, c2), lanes).Eval(); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return out; } // If all possible indices in ramp are the same. ModularSet bmod = analyzer_->modular_set(b1.Eval()); @@ -1201,6 +1256,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { } // If b1 can divide c2 if (bmod->coeff % c2val == 0) { +<<<<<<< HEAD return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } else { /* scalable vectors */ @@ -1208,6 +1264,31 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } +======= + PrimExpr out = ramp(floormod(bmod->base, c2), c1, lanes).Eval(); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return out; + } + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { + PrimExpr out = ramp(floormod(b1, c2), c1, lanes).Eval(); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return out; + } + } + if (bmod->coeff % c2val == 0) { + PrimExpr out = + floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << 
" using rule on line " << __LINE__ + << std::endl; + return out; + } +>>>>>>> 99950669ad ([Debug] Print statements for RewriteSimplifier) } } @@ -1260,7 +1341,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // try modular analysis ModularSet mod = analyzer_->modular_set(x.Eval()); if (mod->coeff % c1val == 0) { - return floormod(mod->base, c1).Eval(); + auto out = floormod(mod->base, c1).Eval(); + std::cout << "\t\t" + << "Rewriting " << ret << " to " << out << " using rule on line " << __LINE__ + << std::endl; + return out; } // floormod(x,c1) is a no-op when x is already in the @@ -1278,7 +1363,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, s1, s2; @@ -1462,7 +1552,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } // Pattern var to match any expression PVar x, y, z, s1, s2; @@ -1672,9 +1767,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { op = ret.get(); if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << 
const_res.value() << " using rule on line " + << __LINE__ << std::endl; return const_res.value(); } if (auto match = TryMatchLiteralConstraint(ret)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; return match.value(); } @@ -1727,8 +1828,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } + if (auto match = TryMatchLiteralConstraint(ret)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; + return match.value(); + } if (IsIndexType(op->a.dtype())) { CompareResult result = TryCompare(op->a, op->b); @@ -1764,8 +1875,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { op = ret.as(); ICHECK(op); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } + if (auto match = TryMatchLiteralConstraint(ret)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; + return match.value(); + } // Check for applicable rewrites before attempting to prove/disprove // the inequality. 
This preserves earlier behavior, where (A<=B*x) @@ -1815,8 +1936,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { LT node = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); op = node.get(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(node)) return match.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << node << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } + if (auto match = TryMatchLiteralConstraint(node)) { + std::cout << "\t\t" + << "Rewriting " << node << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; + return match.value(); + } return ApplyRewriteRules(node); } @@ -1986,8 +2117,18 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { Not ret = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); - if (auto const_res = TryConstFold(ret->a)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if (auto const_res = TryConstFold(ret->a)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } + if (auto match = TryMatchLiteralConstraint(ret)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; + return match.value(); + } return ApplyRewriteRules(ret); } @@ -2059,8 +2200,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << 
const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } + if (auto match = TryMatchLiteralConstraint(ret)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; + return match.value(); + } if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) && !recursively_visiting_boolean_) { return SimplifyAsAndOfOrs(ret, analyzer_); @@ -2207,8 +2358,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { }(); op = ret.as(); - if (auto const_res = TryConstFold(op->a, op->b)) return const_res.value(); - if (auto match = TryMatchLiteralConstraint(ret)) return match.value(); + if (auto const_res = TryConstFold(op->a, op->b)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << const_res.value() << " using rule on line " + << __LINE__ << std::endl; + return const_res.value(); + } + if (auto match = TryMatchLiteralConstraint(ret)) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; + return match.value(); + } if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) && !recursively_visiting_boolean_) { return SimplifyAsAndOfOrs(ret, analyzer_); @@ -2329,6 +2490,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { if (op->op.same_as(tir::builtin::likely())) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. 
} } if (auto match = TryMatchLiteralConstraint(op->args[0])) { + std::cout << "\t\t" + << "Rewriting " << ret << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; return match.value(); } } @@ -2360,12 +2524,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = GetRef(op); if (op->dtype == DataType::Bool()) { if (auto match = TryMatchLiteralConstraint(var)) { + std::cout << "\t\t" + << "Rewriting " << var << " to " << match.value() << " using rule on line " + << __LINE__ << std::endl; return match.value(); } } auto it = var_map_.find(var); if (it != var_map_.end()) { + std::cout << "\t\t" + << "Rewriting " << var << " to " << it->second << " using rule on line " << __LINE__ + << std::endl; return it->second; } return GetRef(op); From 884229cdf37dfff04ed244622c0dfc3c11b5fead Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 6 Jan 2023 15:12:31 -0600 Subject: [PATCH 11/12] [Debug] Print statements in iter affine --- src/arith/iter_affine_map.cc | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 77b20fcdf203..d5ef630b1118 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -651,7 +651,15 @@ class IterMapRewriter : public ExprMutator { return Array(); } } else { - ErrorLogger(this) << "The lower factor is not divisible by the full iter space extent"; + ErrorLogger(this) << "The lower factor " << mark->extent + << " is not divisible by the full iter space extent"; + ErrorLogger(this) << "This mark's extent " << mark->extent + << " is neither equal to, nor divisible by, the expected factor " + << expected_lower_factor << "."; + std::cout << "start" << std::endl; + ErrorLogger(this) << "Expected floormod(extent,divisor) == 0, but was " + << analyzer_->Simplify(floormod(mark->extent, expected_lower_factor)); + std::cout << "end" << std::endl; return {}; } } @@ -1758,6 +1766,12 
@@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl // on the iter mark is compatible with it's own left padding. requires_padding_ = true; PrimExpr mark_left_pad = left_pad * split->lower_factor; + std::cout << "Defining info.left_pad based on max(" << info.left_pad << ", " << mark_left_pad + << ")" << std::endl; + std::cout << "\t" + << "base = " << base << std::endl; + std::cout << "\t" + << "divisor = " << divisor << std::endl; info.left_pad = max(info.left_pad, mark_left_pad); // Since we only care the extent in the first pass's result @@ -1779,6 +1793,11 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl return {split, left_pad}; } + std::cout << "Couldn't prove that extent " << mark->extent << " is divisible by " + << info.padding_factor << std::endl; + std::cout << "\t" + << "info.left_pad = " << info.left_pad << std::endl; + // check that padding factor is compatible with current split and divisor ICHECK(CanProveDivisible(info.padding_factor, split->lower_factor)) << "The padding factor " << info.padding_factor << " is not divisible by " @@ -1817,6 +1836,10 @@ std::pair IterMapRewriter::PadDividendToDivisor(IterSpl info.padded = IterMark(IterSumExpr({IterSplitExpr(mark)}, mark_left_pad), padded_extent); padded_origin_map_[info.padded] = mark; + std::cout << "Created padded version of " << mark << " with (left_pad,right_pad) = (" + << mark_left_pad << ", " << mark_right_pad << "), extent = " << padded_extent + << std::endl; + auto left_padding_introduced = (mark_left_pad != 0); // Equivalent to (0 <= split < left_pad), but easier to simplify in From dc07d5d166e8798eb3754818e79f6f5c53e0d885 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Fri, 6 Jan 2023 15:25:00 -0600 Subject: [PATCH 12/12] [Debug] Print statements in Analyzer::Simplify --- src/arith/analyzer.cc | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 08d5e9379dc6..507bcbd65705 100644 
--- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -254,17 +254,35 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { // of an expression might be destroyed by rewrite simplification. res = this->canonical_simplify(res); + std::cout << "Simplifying " << res << std::endl; + for (int i = 0; i < steps; ++i) { if (tir::is_const_int(res)) { - return res; + break; } if (i % 2 == 0) { - res = this->rewrite_simplify(res); + std::cout << "\t" + << "Simplifying " << res << " using rewrite simplifier" << std::endl; + auto after = this->rewrite_simplify(res); + std::cout << "\t" + << "Simplified from " << res << " to " << after << " using rewrite simplifier" + << std::endl; + res = after; + // res = this->rewrite_simplify(res); } else { - res = this->canonical_simplify(res); + std::cout << "\t" + << "Simplifying " << res << " using canonical simplifier" << std::endl; + auto after = this->canonical_simplify(res); + std::cout << "\t" + << "Simplified from " << res << " to " << after << " using canonical simplifier" + << std::endl; + res = after; + // res = this->canonical_simplify(res); } } + std::cout << "Simplified into " << res << std::endl; + return res; }