diff --git a/src/operator/nn/dnnl/dnnl_convolution.cc b/src/operator/nn/dnnl/dnnl_convolution.cc
index 314bc62175e3..22074ad997ec 100644
--- a/src/operator/nn/dnnl/dnnl_convolution.cc
+++ b/src/operator/nn/dnnl/dnnl_convolution.cc
@@ -116,10 +116,18 @@ std::shared_ptr<dnnl::convolution_forward::primitive_desc> GetConvFwdImpl(
           // suboptimal kernel for computation that has the expected memory size requirements
           auto conv_pd =
              std::make_shared<dnnl::convolution_forward::primitive_desc>(desc, attr, engine);
-          while (conv_pd->dst_desc().get_size() != GetArraySize(output) ||
-                 conv_pd->src_desc().get_size() != GetArraySize(data) ||
-                 (!param.dnnl_param.quantized &&
-                  conv_pd->weights_desc().get_size() != GetArraySize(weights))) {
+          while (
+              conv_pd->dst_desc().get_size() != GetArraySize(output) ||
+              conv_pd->src_desc().get_size() != GetArraySize(data) ||
+              (!param.dnnl_param.quantized &&
+               conv_pd->weights_desc().get_size() != GetArraySize(weights)) ||
+              // With the upgrade of oneDNN to version 2.4+
+              // tests/python/dnnl/subgraphs/test_conv_subgraph.py::test_pos_conv_add[True-data_shape1]
+              // started failing. Switching away from primitive with weight dnnl::format_tag
+              // ABcd4b16a4b in order to temporarily fix the issue until full fix arrives.
+              // Tracking issue: https://github.com/apache/incubator-mxnet/issues/20826.
+              (param.dnnl_param.quantized && conv_pd->weights_desc().dims()[1] < 4 &&
+               conv_pd->weights_desc().data.padded_dims[1] == 16)) {
             // next_impl() will visit desc and engine, please make sure they are still alive here.
             CHECK(conv_pd->next_impl()) << "No convolution implementation for this request.";
           }
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index 37b14c830a7c..3ed526ca56d5 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -42,10 +42,7 @@
   }
 }
 
-DATA_SHAPE=[(64, 4, 10, 10), (4, 4, 24, 24), (1, 16, 32, 32)]
-# Second shape has been temporairly changed from (4, 3, 24, 24) to (4, 4, 24, 24) due to
-# a bug regarding conv+sum fuse with the amount of input channels < 4. It will be reverted
-# as soon as the problem is fixed. Issue: https://github.com/apache/incubator-mxnet/issues/20826.
+DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)]
 
 # Helpers
 class RELU6(nn.HybridBlock):