-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Fixed issue #3069 by checking op tag #3070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
d008622
Fixed issue #3069 by adding in_channels
kumasento c1d47dc
Registerd group_conv2d_nchw as topi compute
kumasento e524d62
Improved by checking tag value
kumasento a56fb77
Removed group_conv2d_nchw topi registration
kumasento a25dfcd
Added test for relay group_conv2d_nchw
kumasento fd24e4b
Added assertions to forbid small group size
kumasento 4774f28
Removed hard-coded oc_block_factor
kumasento 626a5e6
Added explanatory comments to group_conv2d_nchw_cuda
kumasento 8de44b2
Updated group_conv2d_nchw_cuda schedule
kumasento 4866a17
Reverted an accidental change in a conv2d test
kumasento c42b3bf
Fixed indentation problems
kumasento 2c3f7c5
Fixed a mis-commented line
kumasento db8b2a9
Reverted change in group_conv2d_nchw tag
kumasento b7ed247
Removed commented int8 group_conv2d test
kumasento e901c7c
Fixed group size assertions in group_conv2d_nchw_cuda
kumasento File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| #pylint: disable=invalid-name, unused-argument | ||
| # pylint: disable=invalid-name, unused-argument | ||
| """Backend compiler related feature registration""" | ||
| from __future__ import absolute_import | ||
|
|
||
|
|
@@ -34,16 +34,19 @@ def schedule_softmax(_, outputs, target): | |
| with target: | ||
| return topi.generic.schedule_softmax(outputs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.softmax", OpPattern.OPAQUE) | ||
|
|
||
| schedule_broadcast = schedule_injective | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.log_softmax") | ||
| def schedule_log_softmax(_, outputs, target): | ||
| """Schedule definition of log_softmax""" | ||
| with target: | ||
| return topi.generic.schedule_softmax(outputs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) | ||
|
|
||
|
|
||
|
|
@@ -53,12 +56,14 @@ def compute_dense(attrs, inputs, out_type, target): | |
| """Compute definition of dense""" | ||
| return [topi.nn.dense(inputs[0], inputs[1])] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.dense") | ||
| def schedule_dense(attrs, outputs, target): | ||
| """Schedule definition of dense""" | ||
| with target: | ||
| return topi.generic.schedule_dense(outputs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
|
|
@@ -68,16 +73,29 @@ def compute_batch_matmul(attrs, inputs, out_type, target): | |
| """Compute definition of batch_matmul""" | ||
| return [topi.nn.batch_matmul(inputs[0], inputs[1])] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.batch_matmul") | ||
| def schedule_batch_matmul(attrs, outputs, target): | ||
| """Schedule definition of batch_matmul""" | ||
| with target: | ||
| return topi.generic.schedule_batch_matmul(outputs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
| # conv2d | ||
| def _find_conv2d_op(op): | ||
| """Find the op with conv2d in its tag by traversing.""" | ||
| if 'conv2d' in op.tag: | ||
| return op | ||
| for tensor in op.input_tensors: | ||
| op_ = _find_conv2d_op(tensor.op) | ||
| if op_ is not None: | ||
| return op_ | ||
| return None | ||
|
|
||
|
|
||
| @reg.register_compute("nn.conv2d") | ||
| def compute_conv2d(attrs, inputs, out_type, target): | ||
| """Compute definition of conv2d""" | ||
|
|
@@ -101,14 +119,14 @@ def compute_conv2d(attrs, inputs, out_type, target): | |
| inputs[0], inputs[1], strides, padding, | ||
| dilation, layout, out_dtype=out_dtype) | ||
| elif layout == "NCHW" and \ | ||
| get_const_int(inputs[1].shape[0]) == groups and \ | ||
| get_const_int(inputs[1].shape[1]) == 1: | ||
| get_const_int(inputs[1].shape[0]) == groups and \ | ||
| get_const_int(inputs[1].shape[1]) == 1: | ||
| out = topi.nn.depthwise_conv2d_nchw( | ||
| inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) | ||
| elif layout == "NHWC" and \ | ||
| kernel_layout == "HWOI" and\ | ||
| get_const_int(inputs[1].shape[2]) == groups and \ | ||
| get_const_int(inputs[1].shape[3]) == 1: | ||
| kernel_layout == "HWOI" and\ | ||
| get_const_int(inputs[1].shape[2]) == groups and \ | ||
| get_const_int(inputs[1].shape[3]) == 1: | ||
| out = topi.nn.depthwise_conv2d_nhwc( | ||
| inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype) | ||
| elif layout in ['NCHW', 'NCHW4c']: | ||
|
|
@@ -125,6 +143,7 @@ def schedule_conv2d(attrs, outs, target): | |
| groups = attrs.groups | ||
| layout = attrs.data_layout | ||
| kernel_layout = attrs.kernel_layout | ||
|
|
||
| with target: | ||
| if groups == 1 and layout == "NCHW": | ||
| return topi.generic.schedule_conv2d_nchw(outs) | ||
|
|
@@ -133,13 +152,20 @@ def schedule_conv2d(attrs, outs, target): | |
| if groups == 1 and layout == "NHWC": | ||
| return topi.generic.schedule_conv2d_nhwc(outs) | ||
| if groups != 1: | ||
| if layout == "NCHW": | ||
| # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. | ||
| return topi.generic.schedule_depthwise_conv2d_nchw(outs) | ||
| if layout == "NHWC" and kernel_layout == "HWOI": | ||
| return topi.generic.schedule_depthwise_conv2d_nhwc(outs) | ||
| if layout == "NCHW4c": | ||
| return topi.generic.schedule_group_conv2d_nchw(outs) | ||
| # collect in_channels to distinguish depthwise and group conv2d | ||
| op = _find_conv2d_op(outs[0].op) | ||
| assert op is not None | ||
|
|
||
| is_depthwise = 'depthwise' in op.tag | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure whether it is a good way to go but it seems that checking whether |
||
| if is_depthwise: | ||
| if layout == "NCHW": | ||
| # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. | ||
| return topi.generic.schedule_depthwise_conv2d_nchw(outs) | ||
| if layout == "NHWC" and kernel_layout == "HWOI": | ||
| return topi.generic.schedule_depthwise_conv2d_nhwc(outs) | ||
| else: | ||
| if layout in ["NCHW", "NCHW4c"]: | ||
| return topi.generic.schedule_group_conv2d_nchw(outs) | ||
| raise ValueError("No compatible schedule") | ||
|
|
||
|
|
||
|
|
@@ -149,6 +175,7 @@ def alter_op_layout_conv2d(attrs, inputs, tinfos): | |
| from ... import op | ||
| return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
|
|
@@ -167,18 +194,21 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target): | |
| assert layout == "NCHW", "only support nchw for now" | ||
| assert dilation == (1, 1), "not support dilate now" | ||
| assert groups == 1, "only support groups == 1 for now" | ||
| out = topi.nn.conv2d_transpose_nchw(inputs[0], inputs[1], strides, padding, out_dtype) | ||
| out = topi.nn.conv2d_transpose_nchw( | ||
| inputs[0], inputs[1], strides, padding, out_dtype) | ||
| output_padding = get_const_tuple(attrs.output_padding) | ||
| out = topi.nn.pad(out, | ||
| [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]]) | ||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.conv2d_transpose") | ||
| def schedule_conv2d_transpose(attrs, outs, target): | ||
| """Schedule definition of conv2d_transpose""" | ||
| with target: | ||
| return topi.generic.schedule_conv2d_transpose_nchw(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
| # bias_add | ||
|
|
@@ -194,6 +224,7 @@ def schedule_max_pool2d(attrs, outs, target): | |
| with target: | ||
| return topi.generic.schedule_pool(outs, layout) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
|
|
@@ -205,6 +236,7 @@ def schedule_avg_pool2d(attrs, outs, target): | |
| with target: | ||
| return topi.generic.schedule_pool(outs, layout) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
|
|
@@ -215,6 +247,7 @@ def schedule_global_max_pool2d(_, outs, target): | |
| with target: | ||
| return topi.generic.schedule_global_pool(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
|
|
@@ -225,6 +258,7 @@ def schedule_global_avg_pool2d(_, outs, target): | |
| with target: | ||
| return topi.generic.schedule_global_pool(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
| # leaky_relu | ||
|
|
@@ -248,12 +282,14 @@ def compute_lrn(attrs, inputs, out_dtype, target): | |
| return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, | ||
| attrs.alpha, attrs.beta, attrs.bias)] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.lrn") | ||
| def schedule_lrn(attrs, outs, target): | ||
| """Schedule definition of lrn""" | ||
| with target: | ||
| return topi.generic.schedule_lrn(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.lrn", OpPattern.OPAQUE) | ||
|
|
||
|
|
||
|
|
@@ -263,20 +299,26 @@ def compute_l2_normalize(attrs, inputs, out_dtype, target): | |
| """Compute definition of l2 normalize""" | ||
| return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.l2_normalize") | ||
| def schedule_l2_normalize(attrs, outs, target): | ||
| """Schedule definition of l2 normalize""" | ||
| with target: | ||
| return topi.generic.schedule_l2_normalize(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
| # upsampling | ||
| reg.register_schedule("nn.upsampling", reg.schedule_injective) | ||
|
|
||
|
|
||
| def schedule_upsampling(_, outs, target): | ||
| """Schedule definition of upsampling""" | ||
| with target: | ||
| return topi.generic.schedule_injective(outs) | ||
|
|
||
|
|
||
| # pad | ||
| reg.register_schedule("nn.pad", schedule_broadcast) | ||
|
|
||
|
|
@@ -302,28 +344,33 @@ def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_ | |
|
|
||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform") | ||
| def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target): | ||
| """Schedule definition of conv2d_winograd_without_weight_transform""" | ||
| with target: | ||
| return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform", | ||
| OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
| @reg.register_compute("nn.contrib_conv2d_winograd_weight_transform") | ||
| def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target): | ||
| """Compute definition of contrib_conv2d_winograd_weight_transform""" | ||
| out = topi.nn.conv2d_winograd_weight_transform(inputs[0], attrs.get_int('tile_size')) | ||
| out = topi.nn.conv2d_winograd_weight_transform( | ||
| inputs[0], attrs.get_int('tile_size')) | ||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform") | ||
| def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target): | ||
| """Schedule definition of contrib_conv2d_winograd_weight_transform""" | ||
| with target: | ||
| return topi.generic.schedule_conv2d_winograd_weight_transform(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", | ||
| OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
@@ -351,12 +398,14 @@ def compute_contrib_conv2d_winograd_nnpack_without_weight_transform( | |
|
|
||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") | ||
| def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target): | ||
| """Schedule definition of conv2d_winograd_nnpack_without_weight_transform""" | ||
| with target: | ||
| return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform", | ||
| OpPattern.OPAQUE) | ||
|
|
||
|
|
@@ -369,12 +418,14 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d | |
| inputs[0], convolution_algorithm, out_dtype) | ||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform") | ||
| def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target): | ||
| """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform""" | ||
| with target: | ||
| return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform", | ||
| OpPattern.OPAQUE) | ||
|
|
||
|
|
@@ -395,15 +446,18 @@ def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target): | |
| data_layout, out_layout, out_dtype) | ||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.contrib_conv2d_NCHWc") | ||
| def schedule_contrib_conv2d_NCHWc(attrs, outs, target): | ||
| """Schedule definition of contrib_conv2d_NCHWc""" | ||
| with target: | ||
| return topi.generic.schedule_conv2d_NCHWc(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.contrib_conv2d_NCHWc", | ||
| OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
| @reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc") | ||
| def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): | ||
| """Compute definition of depthwise conv2d NCHWc""" | ||
|
|
@@ -420,15 +474,18 @@ def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): | |
| data_layout, out_layout, out_dtype) | ||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc") | ||
| def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target): | ||
| """Schedule definition of contrib_conv2d_NCHWc""" | ||
| with target: | ||
| return topi.generic.schedule_depthwise_conv2d_NCHWc(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", | ||
| OpPattern.OUT_ELEMWISE_FUSABLE) | ||
|
|
||
|
|
||
| @reg.register_compute("nn.deformable_conv2d") | ||
| def compute_deformable_conv2d(attrs, inputs, out_dtype, target): | ||
| """Compute definition of deformable_conv2d""" | ||
|
|
@@ -444,10 +501,12 @@ def compute_deformable_conv2d(attrs, inputs, out_dtype, target): | |
| dilation, deformable_groups, groups, out_dtype) | ||
| return [out] | ||
|
|
||
|
|
||
| @reg.register_schedule("nn.deformable_conv2d") | ||
| def schedule_deformable_conv2d(attrs, outs, target): | ||
| """Schedule definition of deformable_conv2d""" | ||
| with target: | ||
| return topi.generic.schedule_deformable_conv2d_nchw(outs) | ||
|
|
||
|
|
||
| reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I updated the logic to distinguish depthwise and group conv2d