169 changes: 149 additions & 20 deletions python/tvm/topi/cuda/conv2d_alter_op.py
@@ -270,6 +270,60 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     return None
 
 
+def _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor):
+    # Pad batch size
+    if db != 0:
+        data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, db), (0, 0)))
+
+    # Pad input channel
+    if di != 0:
+        data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+
+    # Pad output channel
+    if do != 0:
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, do), (0, 0)))
+
+    if do != 0:
+        new_out_channel = out_channel + do
+        new_attrs["channels"] = new_out_channel
+
+    out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+    if db != 0 or do != 0:
+        original_out_shape = [x.value for x in output_tensor.shape]
+        out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+    return out
+
+
+def _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor):
+    # Pad batch size
+    if db != 0:
+        data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+
+    # Pad input channel
+    if di != 0:
+        data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+
+    # Pad output channel
+    if do != 0:
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+
+    if do != 0:
+        new_out_channel = out_channel + do
+        new_attrs["channels"] = new_out_channel
+
+    out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+    if db != 0 or do != 0:
+        original_out_shape = [x.value for x in output_tensor.shape]
+        out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+    return out
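Both helpers follow the same pad, convolve, slice pattern: zero-pad batch / in_channel / out_channel up to tensor-core-friendly multiples, run conv2d on the padded operands, then strided_slice the output back to its original shape. The standalone numpy sketch below is not part of the PR (pad_up and the sizes are made-up illustrations); it uses a plain matmul as a stand-in for conv2d's channel reduction to show why pad-then-slice preserves the result: zeros added along the reduction axis contribute nothing, and padding on the other axes is cut away by the final slice.

import numpy as np

def pad_up(x, multiple):
    # padding needed to reach the next multiple (0 if already aligned)
    return (multiple - x % multiple) % multiple

M, K, N = 30, 20, 50  # stand-ins for batch*spatial, in_channel, out_channel
a = np.random.rand(M, K).astype("float32")
b = np.random.rand(K, N).astype("float32")

dm, dk, dn = pad_up(M, 16), pad_up(K, 16), pad_up(N, 16)  # candidate (16, 16, 16)
a_pad = np.pad(a, ((0, dm), (0, dk)))  # zero-pad rows and the reduction axis
b_pad = np.pad(b, ((0, dk), (0, dn)))  # zero-pad the reduction axis and columns

out = (a_pad @ b_pad)[:M, :N]  # slice back, as _pad_conv2d_* does via strided_slice
assert np.allclose(out, a @ b)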
+
+
 @conv2d_legalize.register("cuda")
 def _conv2d_legalize(attrs, inputs, arg_types):
     """Legalizes Conv2D op.
@@ -347,7 +401,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
             else:
                 out = relay.nn.conv2d(data, kernel, **new_attrs)
             return out
-    elif data_dtype in ["float16"]: # todo: support int8/int4
+
         if data_layout == "NHWC" and kernel_layout == "HWIO":
             batch = data_tensor.shape[0].value
             in_channel = data_tensor.shape[3].value
@@ -361,36 +415,111 @@ def _conv2d_legalize(attrs, inputs, arg_types):
                 # no need to pad
                 return None
 
-            (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)
+            candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
 
             if extra_flops > 2:
                 logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
                 return None
 
             logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
 
-            # Pad batch size
-            if db != 0:
-                data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+            return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
 
-            # Pad input channel
-            if di != 0:
-                data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
-                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+        if data_layout == "HWNC" and kernel_layout == "HWOI":
+            batch = data_tensor.shape[2].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[2].value
 
-            # Pad output channel
-            if do != 0:
-                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+            if batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0:
+                return None
 
-            if do != 0:
-                new_out_channel = out_channel + do
-                new_attrs["channels"] = new_out_channel
+            candidates = [(8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
 
-            out = relay.nn.conv2d(data, kernel, **new_attrs)
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
 
-            if db != 0 or do != 0:
-                original_out_shape = [x.value for x in output_tensor.shape]
-                out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+            return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
 
+    elif data_dtype in ["float16"]:
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
+
+    elif data_dtype in ["int4", "uint4"]:
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
+
+        if data_layout == "HWNC" and kernel_layout == "HWOI":
+            batch = data_tensor.shape[2].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[2].value
+
+            if batch % 8 == 0 and in_channel % 32 == 0 and out_channel % 8 == 0:
+                return None
+
+            candidates = [(8, 32, 8)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
 
-            return out
+    return None
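For context on the candidate tuples: each (batch, in_channel, out_channel) multiple appears to line up with a CUDA wmma fragment shape m x n x k, with m drawn from batch, n from out_channel, and k from in_channel. The summary below is inferred from the shapes above, not stated in the PR:

# Inferred mapping of legalize candidates to wmma fragment shapes (assumption):
CANDIDATES_BY_DTYPE = {
    "float16/int8 (NHWC)": [(16, 16, 16), (32, 16, 8), (8, 16, 32)],  # m16n16k16, m32n8k16, m8n32k16
    "int8 (HWNC)": [(8, 16, 32)],  # m8n32k16
    "int4 (HWNC)": [(8, 32, 8)],   # m8n8k32
}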
10 changes: 5 additions & 5 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -71,7 +71,8 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
             # no need to pad
             return None
 
-        (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N)
+        candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+        (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates)
 
         if extra_flops > 2:
             logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops)
@@ -145,7 +146,8 @@ def _dense_legalize(attrs, inputs, arg_types):
             # no need to pad
             return None
 
-        (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)
+        candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+        (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates)
 
         if extra_flops_ratio > 2:
             logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
@@ -171,10 +173,8 @@
     return None
 
 
-def pad_to_tensorcore(M, K, N):
+def pad_to_tensorcore(M, K, N, candidates):
     """pad shape to enable tensorcore"""
-    candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
-
     flops = M * K * N
     extra_flops = math.inf
     best_pad = (0, 0, 0)
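The diff truncates the rest of the function body here. Below is a minimal self-contained sketch of the candidate search, assuming it rounds each dimension up to the candidate's multiples, keeps the cheapest padding, and reports the extra work as a fraction of the original flops (a reconstruction under those assumptions, not the PR's exact code):

import math

def _round_up_delta(x, multiple):
    # padding needed to reach the next multiple of `multiple` (0 if aligned)
    return (multiple - x % multiple) % multiple

def pad_to_tensorcore_sketch(M, K, N, candidates):
    flops = M * K * N
    extra_flops = math.inf
    best_pad = (0, 0, 0)
    for cm, ck, cn in candidates:
        dm = _round_up_delta(M, cm)
        dk = _round_up_delta(K, ck)
        dn = _round_up_delta(N, cn)
        extra = (M + dm) * (K + dk) * (N + dn) - flops
        if extra < extra_flops:
            extra_flops = extra
            best_pad = (dm, dk, dn)
    return best_pad, extra_flops / flops

# Example: M=30, K=16, N=16 with candidate (16, 16, 16) pads only M (dm=2), so
# the reported ratio is (32*16*16 - 30*16*16) / (30*16*16) ~= 0.067, far below
# the 2x cutoff the callers above use to skip legalization.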