Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f363da3

Browse files
author
Sylwester Fraczek
authored
[BUGFIX] fix npi_concatenate quantization dim/axis (#20383)
* fix npi_concatenate quantization dim/axis
1 parent cb5bd4e commit f363da3

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/operator/quantization/quantized_concat.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
*/
2525

2626
#include "../nn/concat-inl.h"
27+
#include "../numpy/np_matrix_op-inl.h"
2728

2829
namespace mxnet {
2930
namespace op {
@@ -157,5 +158,27 @@ NNVM_REGISTER_OP(Concat)
157158
return node;
158159
});
159160

161+
NNVM_REGISTER_OP(_npi_concatenate)
162+
.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
163+
const NumpyConcatenateParam& param = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
164+
nnvm::ObjectPtr node = nnvm::Node::Create();
165+
if (param.axis.has_value() && param.axis.value() > 0) {
166+
node->attrs.op = Op::Get("_contrib_quantized_concat");
167+
node->attrs.name = "quantized_" + attrs.name;
168+
} else {
169+
LOG(INFO) << "Currently, quantized numpy concatenate only supports axis>0, exclude "
170+
<< attrs.name << " which axis is " << param.axis;
171+
node->attrs.op = nullptr;
172+
node->attrs.name = attrs.name;
173+
}
174+
node->attrs.dict = attrs.dict;
175+
node->attrs.dict["dim"] = node->attrs.dict["axis"];
176+
node->attrs.dict.erase("axis");
177+
if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
178+
node->op()->attr_parser(&(node->attrs));
179+
}
180+
return node;
181+
});
182+
160183
} // namespace op
161184
} // namespace mxnet

tests/python/mkl/subgraphs/test_conv_subgraph.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,6 @@ def forward(self, x):
289289
@mx.util.use_np
290290
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
291291
@pytest.mark.parametrize('out_type', ['int8', 'auto'])
292-
@pytest.mark.skip("Scale doesn't align in numpy for numpy operators")
293292
def test_pos_concat_scale_align(data_shape, out_type):
294293
# concat scale alignment case
295294
class ConcatScaleAlign(nn.HybridBlock):

0 commit comments

Comments
 (0)