Merged
1 change: 0 additions & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -169,7 +169,6 @@ class Conv(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
-        # get number of channels
out = AttrCvt(op_name=dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
89 changes: 74 additions & 15 deletions python/tvm/relay/op/nn/_nn.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#pylint: disable=invalid-name, unused-argument
+# pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration"""
from __future__ import absolute_import

@@ -34,16 +34,19 @@ def schedule_softmax(_, outputs, target):
with target:
return topi.generic.schedule_softmax(outputs)


reg.register_pattern("nn.softmax", OpPattern.OPAQUE)

schedule_broadcast = schedule_injective


@reg.register_schedule("nn.log_softmax")
def schedule_log_softmax(_, outputs, target):
"""Schedule definition of log_softmax"""
with target:
return topi.generic.schedule_softmax(outputs)


reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


@@ -53,12 +56,14 @@ def compute_dense(attrs, inputs, out_type, target):
"""Compute definition of dense"""
return [topi.nn.dense(inputs[0], inputs[1])]


@reg.register_schedule("nn.dense")
def schedule_dense(attrs, outputs, target):
"""Schedule definition of dense"""
with target:
return topi.generic.schedule_dense(outputs)


reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


@@ -68,16 +73,29 @@ def compute_batch_matmul(attrs, inputs, out_type, target):
"""Compute definition of batch_matmul"""
return [topi.nn.batch_matmul(inputs[0], inputs[1])]


@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_batch_matmul(outputs)


reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d
def _find_conv2d_op(op):
"""Find the op with conv2d in its tag by traversing."""
if 'conv2d' in op.tag:
return op
for tensor in op.input_tensors:
op_ = _find_conv2d_op(tensor.op)
if op_ is not None:
return op_
return None


@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
"""Compute definition of conv2d"""
@@ -101,14 +119,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype=out_dtype)
elif layout == "NCHW" and \
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
get_const_int(inputs[1].shape[0]) == groups and \
get_const_int(inputs[1].shape[1]) == 1:
out = topi.nn.depthwise_conv2d_nchw(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
elif layout == "NHWC" and \
kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
kernel_layout == "HWOI" and\
get_const_int(inputs[1].shape[2]) == groups and \
get_const_int(inputs[1].shape[3]) == 1:
out = topi.nn.depthwise_conv2d_nhwc(
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
elif layout in ['NCHW', 'NCHW4c']:
@@ -125,6 +143,7 @@ def schedule_conv2d(attrs, outs, target):
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout

with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
@@ -133,13 +152,20 @@ def schedule_conv2d(attrs, outs, target):
if groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
if groups != 1:
if layout == "NCHW":
# TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
return topi.generic.schedule_depthwise_conv2d_nchw(outs)
if layout == "NHWC" and kernel_layout == "HWOI":
return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
if layout == "NCHW4c":
return topi.generic.schedule_group_conv2d_nchw(outs)
# collect in_channels to distinguish depthwise and group conv2d
Contributor Author: Here I updated the logic to distinguish depthwise and group conv2d.

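For context, a minimal sketch of how that traversal behaves, assuming the _find_conv2d_op helper from this diff is in scope (the placeholder shapes and the relu fused on top are illustrative, not from the PR):

    import tvm
    import topi

    data = tvm.placeholder((1, 4, 8, 8), name='data')
    weight = tvm.placeholder((8, 4, 3, 3), name='weight')
    conv = topi.nn.conv2d_nchw(data, weight, stride=1, padding=1, dilation=1)
    relu = topi.nn.relu(conv)  # an elementwise op fused on top of the conv
    # the traversal starts at the fused output's op and follows
    # input_tensors until it finds an op whose tag contains 'conv2d'
    op = _find_conv2d_op(relu.op)
    assert op is not None and 'conv2d' in op.tag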
+            op = _find_conv2d_op(outs[0].op)
+            assert op is not None
+
+            is_depthwise = 'depthwise' in op.tag
Contributor Author: I'm not sure whether this is the best way to go, but checking whether 'depthwise' appears in the tag value seems much easier and more intuitive.

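For reference, a small sketch of what the tag check sees; the shapes are illustrative, and the tag string is the one TOPI's NCHW depthwise compute attaches (group conv2d carries a tag like 'group_conv2d_nchw' instead, so the check separates the two):

    import tvm
    import topi

    data = tvm.placeholder((1, 32, 56, 56), name='data')
    weight = tvm.placeholder((32, 1, 3, 3), name='weight')  # channel_multiplier == 1
    out = topi.nn.depthwise_conv2d_nchw(data, weight, stride=1, padding=0, dilation=1)
    print(out.op.tag)                 # 'depthwise_conv2d_nchw'
    print('depthwise' in out.op.tag)  # True, so dispatch to a depthwise schedule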
+            if is_depthwise:
+                if layout == "NCHW":
+                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
+                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
+                if layout == "NHWC" and kernel_layout == "HWOI":
+                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
+            else:
+                if layout in ["NCHW", "NCHW4c"]:
+                    return topi.generic.schedule_group_conv2d_nchw(outs)
        raise ValueError("No compatible schedule")


@@ -149,6 +175,7 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos):
from ... import op
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)


reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@@ -167,18 +194,21 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
-    out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype)
+    out = topi.nn.conv2d_transpose_nchw(
+        inputs[0], inputs[1], strides, padding, out_dtype)
output_padding = get_const_tuple(attrs.output_padding)
out = topi.nn.pad(out,
[0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
return [out]


@reg.register_schedule("nn.conv2d_transpose")
def schedule_conv2d_transpose(attrs, outs, target):
"""Schedule definition of conv2d_transpose"""
with target:
return topi.generic.schedule_conv2d_transpose_nchw(outs)


reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# bias_add
@@ -194,6 +224,7 @@ def schedule_max_pool2d(attrs, outs, target):
with target:
return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@@ -205,6 +236,7 @@ def schedule_avg_pool2d(attrs, outs, target):
with target:
return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@@ -215,6 +247,7 @@ def schedule_global_max_pool2d(_, outs, target):
with target:
return topi.generic.schedule_global_pool(outs)


reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@@ -225,6 +258,7 @@ def schedule_global_avg_pool2d(_, outs, target):
with target:
return topi.generic.schedule_global_pool(outs)


reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

# leaky_relu
@@ -248,12 +282,14 @@ def compute_lrn(attrs, inputs, out_dtype, target):
return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
attrs.alpha, attrs.beta, attrs.bias)]


@reg.register_schedule("nn.lrn")
def schedule_lrn(attrs, outs, target):
"""Schedule definition of lrn"""
with target:
return topi.generic.schedule_lrn(outs)


reg.register_pattern("nn.lrn", OpPattern.OPAQUE)


@@ -263,20 +299,26 @@ def compute_l2_normalize(attrs, inputs, out_dtype, target):
"""Compute definition of l2 normalize"""
return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]


@reg.register_schedule("nn.l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
"""Schedule definition of l2 normalize"""
with target:
return topi.generic.schedule_l2_normalize(outs)


reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)

# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)


def schedule_upsampling(_, outs, target):
"""Schedule definition of upsampling"""
with target:
return topi.generic.schedule_injective(outs)


# pad
reg.register_schedule("nn.pad", schedule_broadcast)

@@ -302,28 +344,33 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_

return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
"""Schedule definition of conv2d_winograd_without_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of contrib_conv2d_winograd_weight_transform"""
-    out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size'))
+    out = topi.nn.conv2d_winograd_weight_transform(
+        inputs[0], attrs.get_int('tile_size'))
return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
"""Schedule definition of contrib_conv2d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -351,12 +398,14 @@ def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(

return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
"""Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
OpPattern.OPAQUE)

@@ -369,12 +418,14 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
inputs[0], convolution_algorithm, out_dtype)
return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
"""Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
OpPattern.OPAQUE)

@@ -395,15 +446,18 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
data_layout, out_layout, out_dtype)
return [out]


@reg.register_schedule("nn.contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc"""
with target:
return topi.generic.schedule_conv2d_NCHWc(outs)


reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of depthwise conv2d NCHWc"""
Expand All @@ -420,15 +474,18 @@ def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
data_layout, out_layout, out_dtype)
return [out]


@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc"""
with target:
return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)


reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.deformable_conv2d")
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
"""Compute definition of deformable_conv2d"""
@@ -444,10 +501,12 @@ def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
dilation, deformable_groups, groups, out_dtype)
return [out]


@reg.register_schedule("nn.deformable_conv2d")
def schedule_deformable_conv2d(attrs, outs, target):
"""Schedule definition of deformable_conv2d"""
with target:
return topi.generic.schedule_deformable_conv2d_nchw(outs)


reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
2 changes: 1 addition & 1 deletion python/tvm/target.py
@@ -296,7 +296,7 @@ def dispatch_func(func, *args, **kwargs):
def generic_func(fdefault):
"""Wrap a target generic function.

-    Generic function allows registeration of further functions
+    Generic function allows registration of further functions
that can be dispatched on current target context.
If no registered dispatch is matched, the fdefault will be called.
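
For readers unfamiliar with this dispatch mechanism, a minimal usage sketch along the lines of the example in this file's own docstring (my_func and the return values are illustrative):

    import tvm

    @tvm.target.generic_func
    def my_func(a):
        return a + 1              # the fdefault

    @my_func.register("cuda")
    def my_func_cuda(a):
        return a + 2

    print(my_func(2))             # prints 3: no target context, fdefault runs
    with tvm.target.cuda():
        print(my_func(2))         # prints 4: dispatched to the "cuda" registration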
