2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
@@ -56,7 +56,7 @@ def _encode(x):
return ("TENSOR", get_const_tuple(x.shape), x.dtype)
if isinstance(x, (tuple, list, container.Array)):
return tuple([_encode(a) for a in x])
if isinstance(x, (str, int, float, np.int, np.float, expr.Var)):
if isinstance(x, (str, int, float, np.int, np.float, expr.Var, expr.Any)):
return x
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
return x.value
21 changes: 18 additions & 3 deletions python/tvm/relay/op/_tensor.py
@@ -19,6 +19,7 @@

from tvm.te.hybrid import script
from tvm import topi
from tvm.runtime import convert

from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
@@ -156,11 +157,22 @@ def _full_shape_func(shape):
return out


@script
def _convert_shape(shape):
out = output_tensor((len(shape),), "int64")
for i in const_range(len(shape)):
out[i] = int64(shape[i])
return out


def full_shape_func(attrs, inputs, out_ndims):
"""
Shape func for full.
"""
return [_full_shape_func(inputs[1])]
if len(inputs) > 1:
return [_full_shape_func(inputs[1])]

return [_convert_shape(convert(attrs.shape))]


def no_data_full_shape_func(attrs, inputs, out_ndims):
@@ -216,9 +228,9 @@ def elemwise_shape_func(attrs, inputs, _):


register_shape_func("cast", False, elemwise_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros", False, no_data_full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones", False, no_data_full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)
@@ -257,3 +269,6 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("floor", False, elemwise_shape_func)
register_shape_func("log", False, elemwise_shape_func)
register_shape_func("device_copy", False, elemwise_shape_func)
register_shape_func("clip", False, elemwise_shape_func)
register_shape_func("log2", False, elemwise_shape_func)
register_shape_func("sigmoid", False, elemwise_shape_func)
50 changes: 50 additions & 0 deletions python/tvm/relay/op/_transform.py
@@ -745,3 +745,53 @@ def adv_index_shape_func(attrs, inputs, _):
Only allow single index tensor.
"""
return [_adv_index_shape_func(inputs)]


@script
def _repeat_shape_func(data_shape, repeats, axis):
out = output_tensor((data_shape.shape[0],), "int64")

for i in const_range(data_shape.shape[0]):
if i == axis:
out[i] = int64(data_shape[i] * repeats)
else:
out[i] = data_shape[i]

return out

@_reg.register_shape_func("repeat", False)
def repeat_shape_func(attrs, inputs, _):
"""
Shape func for repeat.
"""
axis = get_const_int(attrs.axis)
if axis < 0:
axis = inputs[0].shape[0] + axis
return [_repeat_shape_func(inputs[0], attrs.repeats, convert(axis))]


@_reg.register_shape_func("broadcast_to_like", False)
def broadcast_to_like_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[1])]


@script
def _stack_shape_func(data_shape, axis, num_inputs):
out = output_tensor((data_shape.shape[0] + 1,), "int64")

for i in const_range(data_shape.shape[0] + 1):
if i == axis:
out[i] = int64(num_inputs)
elif i < axis:
out[i] = data_shape[i]
else:
out[i] = data_shape[i - 1]

return out

@_reg.register_shape_func("stack", False)
def stack_shape_func(attrs, inputs, _):
axis = get_const_int(attrs.axis)
if axis < 0:
axis += inputs[0].shape[0] + 1
return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]
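For illustration (not part of this diff), the repeat and stack shape funcs above reduce to the following plain-Python logic, which makes the intended output shapes easy to check:

import numpy as np

def repeat_out_shape(data_shape, repeats, axis):
    out = np.array(data_shape, dtype="int64")
    out[axis] *= repeats              # only the repeated axis grows
    return out

def stack_out_shape(data_shape, axis, num_inputs):
    out = list(data_shape)
    out.insert(axis, num_inputs)      # stack inserts a new axis of length num_inputs
    return np.array(out, dtype="int64")

print(repeat_out_shape((2, 3), repeats=2, axis=1))    # -> [2 6]
print(stack_out_shape((4, 5), axis=0, num_inputs=3))  # -> [3 4 5]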
43 changes: 43 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -773,6 +773,49 @@ def conv2d_NCHWc_shape_func(attrs, inputs, _):
]


@script
def _conv2d_transpose_nchw_shape_func(dshape, kshape, strides,
padding, dilation, output_padding):
out = output_tensor((dshape.shape[0],), "int64")
kheight = kshape[2]
kwidth = kshape[3]
dilated_kh = (kheight - 1) * dilation[0] + 1
dilated_kw = (kwidth - 1) * dilation[1] + 1

out_height = strides[0] * (dshape[2] - 1) + dilated_kh - \
2 * padding[0] + output_padding[0]
out_width = strides[1] * (dshape[3] - 1) + dilated_kw - \
2 * padding[1] + output_padding[1]

out[0] = dshape[0]
out[1] = kshape[1]
out[2] = out_height
out[3] = out_width
return out


@reg.register_shape_func("nn.conv2d_transpose", False)
def conv2d_transpose_nchw_shape_func(attrs, inputs, _):
"""
Shape function for conv2d_transpose op.
"""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)
output_padding = get_const_tuple(attrs.output_padding)

return [
_conv2d_transpose_nchw_shape_func(
inputs[0],
inputs[1],
convert(strides),
convert(padding),
convert(dilation),
convert(output_padding)
)
]


@script
def _pool2d_shape_func(data_shape, pool_size, strides, padding, height_axis, width_axis):
out = output_tensor((data_shape.shape[0],), "int64")
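A quick worked example of the output-size formula in _conv2d_transpose_nchw_shape_func (the concrete numbers below are made up for illustration):

def deconv_out_dim(in_dim, kernel, stride, pad, dilation, out_pad):
    # matches: strides * (dshape - 1) + dilated_k - 2 * padding + output_padding
    dilated_k = (kernel - 1) * dilation + 1
    return stride * (in_dim - 1) + dilated_k - 2 * pad + out_pad

# data NCHW (1, 8, 32, 32), kernel IOHW (8, 16, 3, 3), strides (2, 2),
# padding (1, 1), dilation (1, 1), output_padding (1, 1)
print(deconv_out_dim(32, 3, 2, 1, 1, 1))  # -> 64, so the output shape is (1, 16, 64, 64)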
17 changes: 17 additions & 0 deletions python/tvm/relay/op/vision/_vision.py
@@ -20,6 +20,8 @@

from tvm import topi
from tvm.te.hybrid import script
from tvm.runtime import convert

from .. import op as reg
from .. import strategy
from ..op import OpPattern
@@ -81,3 +83,18 @@ def nms_shape_func(attrs, inputs, _):
if attrs.return_indices:
return _nms_shape_func(inputs[0])
return [topi.math.identity(inputs[0])]


@script
def _roi_align_shape_func(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")
out[0] = rois_shape[0]
out[1] = data_shape[1]
out[2] = int64(pooled_size[0])
out[3] = int64(pooled_size[1])
return out

@reg.register_shape_func("vision.roi_align", False)
def roi_align_shape_func(attrs, inputs, _):
return [_roi_align_shape_func(inputs[0], inputs[1],
convert(attrs.pooled_size))]
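For illustration (not part of this diff), _roi_align_shape_func encodes a simple relation: the batch dimension comes from the rois, channels from the data, and the spatial dimensions from pooled_size:

def roi_align_out_shape(data_shape, rois_shape, pooled_size):
    # e.g. data (1, 256, 38, 50), rois (300, 5), pooled_size (7, 7) -> (300, 256, 7, 7)
    return (rois_shape[0], data_shape[1], pooled_size[0], pooled_size[1])

print(roi_align_out_shape((1, 256, 38, 50), (300, 5), (7, 7)))  # -> (300, 256, 7, 7)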
40 changes: 20 additions & 20 deletions python/tvm/topi/scatter.py
@@ -32,8 +32,8 @@ def _scatter_1d(data, indices, updates):
@hybrid.script
def _scatter_2d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
out[i, j] = data[i, j]
if axis == 0:
for i in range(indices.shape[0]):
@@ -54,14 +54,14 @@ def _scatter_2d(data, indices, updates, axis):
@hybrid.script
def _scatter_3d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for k in const_range(data.shape[2]):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
for k in range(data.shape[2]):
out[i, j, k] = data[i, j, k]
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for k in range(indices.shape[2]):
out[
indices[i, j, k]
if indices[i, j, k] >= 0
@@ -72,7 +72,7 @@ def _scatter_3d(data, indices, updates, axis):
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for k in range(indices.shape[2]):
out[
i,
indices[i, j, k]
@@ -83,7 +83,7 @@ def _scatter_3d(data, indices, updates, axis):
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for k in range(indices.shape[2]):
out[
i,
j,
@@ -98,17 +98,17 @@ def _scatter_3d(data, indices, updates, axis):
@hybrid.script
def _scatter_4d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for k in const_range(data.shape[2]):
for l in const_range(data.shape[3]):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
for k in range(data.shape[2]):
for l in range(data.shape[3]):
out[i, j, k, l] = data[i, j, k, l]

if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
indices[i, j, k, l]
if indices[i, j, k, l] >= 0
@@ -120,8 +120,8 @@ def _scatter_4d(data, indices, updates, axis):
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
i,
indices[i, j, k, l]
@@ -133,8 +133,8 @@ def _scatter_4d(data, indices, updates, axis):
elif axis == 2:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
i,
j,
@@ -146,8 +146,8 @@ def _scatter_4d(data, indices, updates, axis):
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
i,
j,
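The scatter changes swap const_range for range in the element-copy loops: in TVM's hybrid script, const_range is expanded while the script is parsed and therefore needs a compile-time-constant extent, whereas range lowers to an ordinary serial loop whose extent may be symbolic — which is what dynamic input shapes need. A minimal sketch of the pattern (illustration only, using the standard hybrid-script builtins, not code from this PR):

from tvm.te import hybrid

@hybrid.script
def _copy_1d(data):
    out = output_tensor(data.shape, data.dtype)
    # range() emits a serial loop, so data.shape[0] may be a symbolic extent;
    # const_range() would require it to be a constant.
    for i in range(data.shape[0]):
        out[i] = data[i]
    return out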
11 changes: 11 additions & 0 deletions python/tvm/topi/x86/conv2d.py
@@ -140,6 +140,17 @@ def _pack_data(cfg, data, kernel):
ic_chunk = ic // ic_bn
oc_chunk = oc // oc_bn

# Handle dynamic shapes so that tuning dispatch can still work.
if isinstance(n, tvm.tir.Any):
n = tvm.te.size_var("n")
if isinstance(ih, tvm.tir.Any):
ih = tvm.te.size_var("ih")
if isinstance(iw, tvm.tir.Any):
iw = tvm.te.size_var("iw")
if isinstance(ic, tvm.tir.Any):
raise RuntimeError("Dynamic input channel is not supported for conv2d.")


data = te.compute(
(n, ic_chunk, ih, iw, ic_bn),
lambda bs, c, h, w, vc: data[bs, c * ic_bn + vc, h, w],
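A minimal sketch of the pattern used above (illustration only, not from the PR): an unknown extent arriving as tvm.tir.Any is replaced by a named symbolic variable so the packed compute can still be built:

import tvm
from tvm import te

n = tvm.tir.Any()  # stand-in for a dynamic batch dimension
extent = te.size_var("n") if isinstance(n, tvm.tir.Any) else n
data = te.placeholder((extent, 64), name="data")
out = te.compute(data.shape, lambda i, j: data[i, j] + 1.0, name="out")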
3 changes: 2 additions & 1 deletion python/tvm/topi/x86/dense.py
@@ -236,7 +236,8 @@ def dense_blas_common(cfg, data, weight, bias, out_dtype, lib):
"""Compute dense using a BLAS library"""
M, K = get_const_tuple(data.shape)
N, _ = get_const_tuple(weight.shape)
cfg.add_flop(M * K * N * 2)
if isinstance(M, int) and isinstance(K, int) and isinstance(N, int):
cfg.add_flop(M * K * N * 2)
if data.dtype == "uint8" and weight.dtype == "int8" and out_dtype == "int32":
if not hasattr(lib, "matmul_u8s8s32"):
raise NotImplementedError(
10 changes: 7 additions & 3 deletions python/tvm/topi/x86/roi_align.py
@@ -25,7 +25,7 @@


@hybrid.script
def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio):
def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_scale, sample_ratio):
"""Hybrid routing fo ROI align operator in NCHW layout.

Parameters
@@ -37,6 +37,10 @@ def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, samp
2-D with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]

num_rois : tvm.tir.IntImm or tvm.tir.Var
Number of rois. We need to pass it in since hybrid script doesn't support
binding a variable to a symbolic dim.

w_pc : tvm.te.Tensor or numpy NDArray
3-D weight pre-calculation buffer

@@ -61,7 +65,6 @@ def roi_align_nchw_ir(data, rois, w_pc, pos_pc, pooled_size, spatial_scale, samp
channels = data.shape[1]
height = data.shape[2]
width = data.shape[3]
num_rois = rois.shape[0]
pooled_size_h = pooled_size[0]
pooled_size_w = pooled_size[1]
output = output_tensor((num_rois, channels, pooled_size_h, pooled_size_w), data.dtype)
@@ -235,6 +238,7 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
_, _, height, width = get_const_tuple(data.shape)
max_roi_bin_grid_h = math.ceil(height / pooled_size[0])
max_roi_bin_grid_w = math.ceil(width / pooled_size[1])
num_rois = rois.shape[0]
max_pc_shape = (
rois.shape[0],
max_roi_bin_grid_h * max_roi_bin_grid_w * pooled_size[0] * pooled_size[1],
@@ -247,5 +251,5 @@ def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
spatial_scale = tvm.tir.const(spatial_scale, "float32")
sample_ratio = tvm.tir.const(sample_ratio, "int32")
return roi_align_nchw_ir(
data, rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio
data, rois, num_rois, w_pc_buffer, pos_pc_buffer, pooled_size, spatial_scale, sample_ratio
)
11 changes: 8 additions & 3 deletions src/relay/op/tensor/transform.cc
@@ -310,7 +310,9 @@ bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
CHECK_EQ(e_dtype, dtype) << "relay.stack requires all tensors have the same dtype";
for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
if (first->shape[j].as<AnyNode>() || e->shape[j].as<AnyNode>() ||
reporter->AssertEQ(first->shape[j], e->shape[j]))
continue;
throw Error(
"relay.stack requires all tensors have the same shape "
"on non-stacking axes");
@@ -1292,7 +1294,11 @@ bool RepeatRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
for (int i = 0; i < pivot; ++i) {
oshape.emplace_back(data->shape[i]);
}
oshape.emplace_back(data->shape[pivot] * repeats);
if (data->shape[pivot].as<AnyNode>()) {
oshape.emplace_back(Any());
} else {
oshape.emplace_back(data->shape[pivot] * repeats);
}
for (int i = pivot + 1; i < ndim; ++i) {
oshape.emplace_back(data->shape[i]);
}
@@ -3243,7 +3249,6 @@ RELAY_REGISTER_OP("adv_index")
.add_type_rel("AdvIndex", AdvIndexRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute", AdvIndexCompute);

} // namespace relay
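The C++ changes relax StackRel and RepeatRel so that Any (dynamic) dimensions pass type checking. A small Python-side sketch of what this enables (illustration only, assuming the usual Relay API, not code from this PR):

import tvm
from tvm import relay

x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
y = relay.var("y", shape=(relay.Any(), 4), dtype="float32")

# With the relaxed StackRel, Any on a non-stacking axis type-checks;
# with the RepeatRel change, repeating along an Any axis stays Any.
out = relay.Tuple([relay.stack([x, y], axis=0), relay.repeat(x, repeats=2, axis=0)])
mod = relay.transform.InferType()(tvm.IRModule.from_expr(out))
print(mod)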