Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit fd890eb

Browse files
author
Bartlomiej Gawrych
committed
Review
1 parent 7f491b8 commit fd890eb

File tree

4 files changed

+18
-17
lines changed

4 files changed

+18
-17
lines changed

src/operator/nn/dnnl/dnnl_reduce-inl.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ class DNNLReduceFwd {
6767
};
6868

6969
template <class T>
70-
NumpyReduceAxesParam ConvertParamsToNumpy(const T& original_param,
71-
const NDArray& in_data,
72-
const NDArray& out_data);
70+
NumpyReduceAxesParam ConvertReduceParamsToNumpy(const T& original_param,
71+
const NDArray& in_data,
72+
const NDArray& out_data);
7373

7474
void DNNLReduceForwardImpl(const NumpyReduceAxesParam& param,
7575
const OpContext& ctx,
@@ -85,7 +85,7 @@ void DNNLReduceForward(const nnvm::NodeAttrs& attrs,
8585
const OpReqType& req,
8686
const NDArray& out_data) {
8787
const ParamType& org_param = nnvm::get<ParamType>(attrs.parsed);
88-
auto param = ConvertParamsToNumpy<ParamType>(org_param, in_data, out_data);
88+
auto param = ConvertReduceParamsToNumpy<ParamType>(org_param, in_data, out_data);
8989
DNNLReduceForwardImpl(param, ctx, in_data, req, out_data, reduction_alg);
9090
}
9191

@@ -98,7 +98,7 @@ bool SupportDNNLReduce(const nnvm::NodeAttrs& attrs,
9898
const NDArray& in_data,
9999
const NDArray& out_data) {
100100
const T& org_param = nnvm::get<T>(attrs.parsed);
101-
auto param = ConvertParamsToNumpy<T>(org_param, in_data, out_data);
101+
auto param = ConvertReduceParamsToNumpy<T>(org_param, in_data, out_data);
102102
return SupportDNNLReduceImpl(param, in_data, out_data);
103103
}
104104

src/operator/nn/dnnl/dnnl_reduce.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ namespace mxnet {
3131
namespace op {
3232

3333
template <>
34-
NumpyReduceAxesParam ConvertParamsToNumpy<ReduceAxesParam>(const ReduceAxesParam& original_param,
35-
const NDArray& input,
36-
const NDArray& output) {
34+
NumpyReduceAxesParam ConvertReduceParamsToNumpy<ReduceAxesParam>(
35+
const ReduceAxesParam& original_param,
36+
const NDArray& input,
37+
const NDArray& output) {
3738
NumpyReduceAxesParam numpy_param;
3839
numpy_param.axis = dmlc::optional<mxnet::Tuple<int>>();
3940
if (original_param.axis.has_value()) {
@@ -61,7 +62,7 @@ NumpyReduceAxesParam ConvertParamsToNumpy<ReduceAxesParam>(const ReduceAxesParam
6162
}
6263

6364
template <>
64-
NumpyReduceAxesParam ConvertParamsToNumpy<NumpyReduceAxesParam>(
65+
NumpyReduceAxesParam ConvertReduceParamsToNumpy<NumpyReduceAxesParam>(
6566
const NumpyReduceAxesParam& original_param,
6667
const NDArray& input,
6768
const NDArray& output) {
@@ -73,8 +74,7 @@ mxnet::Tuple<int> CanonicalizeAndSortAxes(const NDArray& input,
7374
mxnet::Tuple<int> original_axes) {
7475
int in_ndim = input.shape().ndim();
7576
mxnet::Tuple<int> axes(param.axis.value());
76-
// canonicalize
77-
for (index_t i = 0; i < axes.ndim(); i++) {
77+
for (int i = 0; i < axes.ndim(); i++) {
7878
if (axes[i] < 0) {
7979
axes[i] += in_ndim;
8080
}
@@ -94,7 +94,7 @@ bool SupportDNNLReduceImpl(const NumpyReduceAxesParam& param,
9494
auto axes = CanonicalizeAndSortAxes(input, param, param.axis.value());
9595
int last_dim = *(axes.end() - 1);
9696
if (last_dim != input.shape().ndim() - 1) {
97-
// oneDNN not optimized case
97+
// oneDNN (v2.3.2) not optimized case
9898
return false;
9999
} else {
100100
for (int i = 0; i < axes.ndim(); i++) {
@@ -106,7 +106,7 @@ bool SupportDNNLReduceImpl(const NumpyReduceAxesParam& param,
106106
}
107107
}
108108

109-
// if axis = () it's identity op and it is not supported by oneDNN
109+
// if `axis = ()` it is identity op and it is not supported by oneDNN
110110
param_supported = param.axis.value().ndim() > 0;
111111
}
112112
// initial value not supported by oneDNN

src/operator/nn/dnnl/dnnl_transpose-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class DNNLTransposeFwd {
5151
DNNLTransposeFwd& GetTransposeForward(const NumpyTransposeParam& param, const NDArray& data);
5252

5353
template <class ParamType>
54-
NumpyTransposeParam ConvertParamsToNumpy(const ParamType& param);
54+
NumpyTransposeParam ConvertTransposeParamsToNumpy(const ParamType& param);
5555

5656
template <class ParamType>
5757
void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
@@ -60,7 +60,7 @@ void DNNLTransposeForward(const nnvm::NodeAttrs& attrs,
6060
const OpReqType& req,
6161
const NDArray& output) {
6262
const ParamType& org_param = nnvm::get<ParamType>(attrs.parsed);
63-
auto param = ConvertParamsToNumpy<ParamType>(org_param);
63+
auto param = ConvertTransposeParamsToNumpy<ParamType>(org_param);
6464
auto fwd = GetTransposeForward(param, data);
6565
fwd.SetNewMem(data, output);
6666
fwd.Execute();

src/operator/nn/dnnl/dnnl_transpose.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,15 @@ DNNLTransposeFwd& GetTransposeForward(const NumpyTransposeParam& param, const ND
123123
}
124124

125125
template <>
126-
NumpyTransposeParam ConvertParamsToNumpy<NumpyTransposeParam>(const NumpyTransposeParam& param) {
126+
NumpyTransposeParam ConvertTransposeParamsToNumpy<NumpyTransposeParam>(
127+
const NumpyTransposeParam& param) {
127128
NumpyTransposeParam numpy_param;
128129
numpy_param.axes = common::CanonicalizeAxes(param.axes);
129130
return numpy_param;
130131
}
131132

132133
template <>
133-
NumpyTransposeParam ConvertParamsToNumpy<TransposeParam>(const TransposeParam& param) {
134+
NumpyTransposeParam ConvertTransposeParamsToNumpy<TransposeParam>(const TransposeParam& param) {
134135
NumpyTransposeParam numpy_param;
135136
if (param.axes.ndim() == 0) {
136137
numpy_param.axes = mxnet::TShape(-1, 0);

0 commit comments

Comments
 (0)