This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 6f807a6

Fix backport of SoftmaxOutput implementation using onednn kernels (#20459)
1 parent: 1155c9e

File tree: 4 files changed, +136 −3 lines

src/operator/nn/mkldnn/mkldnn_base-inl.h

Lines changed: 2 additions & 0 deletions
@@ -200,6 +200,7 @@ struct LeakyReLUParam;
 struct ConvolutionParam;
 struct DeconvolutionParam;
 struct SoftmaxParam;
+struct SoftmaxOutputParam;
 struct TransposeParam;
 struct ReshapeParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
@@ -212,6 +213,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
                              const NDArray &output);
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
 bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output);
 }  // namespace op

src/operator/nn/mkldnn/mkldnn_ops-inl.h

Lines changed: 5 additions & 0 deletions
@@ -107,6 +107,11 @@ void MKLDNNLogSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
                               const std::vector<OpReqType> &req,
                               const std::vector<NDArray> &out_data);

+/* For softmax_output */
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                                const std::vector<NDArray>& in_data,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& out_data);

 /* For sum */
 void MKLDNNSumForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
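
The declaration added above follows MXNet's standard FComputeEx calling convention, which is what allows the new kernel to be wired into the operator registration at the bottom of this commit. For reference (not part of this commit), MXNet defines that functor type in include/mxnet/op_attr_types.h roughly as:

// Reference sketch, not part of this diff: the FComputeEx functor type
// that MKLDNNSoftmaxOutputForward is shaped to match.
using FComputeEx = std::function<void(const nnvm::NodeAttrs& attrs,
                                      const OpContext& ctx,
                                      const std::vector<NDArray>& inputs,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<NDArray>& outputs)>;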
src/operator/nn/mkldnn/mkldnn_softmax_output.cc (new file)

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_softmax_output.cc
+ * \brief integrate mkldnn softmax into the softmax_output forward pass
+ * \author Zhang Rong A
+ */
+
+#if MXNET_USE_ONEDNN == 1
+#include "../../softmax_output-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl(
+    const SoftmaxOutputParam& param,
+    bool is_train,
+    const int axis,
+    const mkldnn::memory& input_mem) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
+  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
+}
+
+typedef ParamOpSign<SoftmaxOutputParam> MKLDNNSoftmaxOuputSignature;
+
+class MKLDNNSoftmaxOutputFwd {
+  std::shared_ptr<mkldnn::softmax_forward> fwd_;
+
+ public:
+  const mkldnn::softmax_forward::primitive_desc fwd_pd;
+
+  MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param,
+                         bool is_train,
+                         const int axis,
+                         const mkldnn::memory& mem)
+      : fwd_pd(GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) {
+    fwd_ = std::make_shared<mkldnn::softmax_forward>(fwd_pd);
+  }
+
+  const inline mkldnn::softmax_forward& GetFwd() const {
+    return *fwd_;
+  }
+};
+
+static MKLDNNSoftmaxOutputFwd& GetSoftmaxOutputForward(const SoftmaxOutputParam& param,
+                                                       const OpContext& ctx,
+                                                       const NDArray& in_data) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash>
+      fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxOuputSignature, MKLDNNSoftmaxOutputFwd, OpHash>
+      fwds;
+#endif
+  MKLDNNSoftmaxOuputSignature key(param);
+  key.AddSign(ctx.is_train);
+  key.AddSign(in_data);
+
+  // SoftmaxOutput has no axis parameter, so, as in the original implementation,
+  // softmax is applied over the last axis.
+  int axis = in_data.shape().ndim() - 1;
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    auto in_mem = *(in_data.GetMKLDNNData());
+    MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+// This is only used for the forward pass. For backward, compatibility still
+// needs to be double-checked.
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param) {
+  return param.multi_output ? false : true;
+}
+
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& in_data,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& out_data) {
+  const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+
+  NDArray idata = in_data[softmaxout_enum::kData];
+  NDArray odata = out_data[softmaxout_enum::kOut];
+  if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) {
+    idata = in_data[softmaxout_enum::kData].Reorder2Default();
+  }
+
+  auto input_mem = idata.GetMKLDNNData();
+  auto out_mem = CreateMKLDNNMem(
+      out_data[softmaxout_enum::kOut], input_mem->get_desc(), req[softmaxout_enum::kOut]);
+
+  MKLDNNSoftmaxOutputFwd& fwd = GetSoftmaxOutputForward(param, ctx, idata);
+
+  MKLDNNStream* stream = MKLDNNStream::Get();
+  stream->RegisterPrimArgs(fwd.GetFwd(),
+                           {{MKLDNN_ARG_SRC, *input_mem}, {MKLDNN_ARG_DST, *out_mem.second}});
+  CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
+  stream->Submit();
+}
+}  // namespace op
+}  // namespace mxnet
+#endif
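
For readers unfamiliar with the oneDNN calls used in the new file, the following self-contained sketch (not part of this commit) runs the same softmax primitive directly, assuming the pre-v3.0 dnnl.hpp API that this code targets; the buffer shape and values are arbitrary illustrations:

// Standalone sketch: a oneDNN softmax over the last axis of a 2x3 float
// buffer, mirroring what the cached primitive above executes.
#include <dnnl.hpp>
#include <vector>
#include <cstdio>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // 2x3 float input in plain row-major ("nc") layout.
  memory::desc md({2, 3}, memory::data_type::f32, memory::format_tag::nc);
  std::vector<float> buf = {1.f, 2.f, 3.f, 0.f, 0.f, 0.f};
  memory src(md, eng, buf.data());
  std::vector<float> out(6);
  memory dst(md, eng, out.data());

  // Same construction sequence as GetSoftmaxOutputFwdDescImpl: op desc,
  // then primitive desc, then the primitive itself.
  auto desc = softmax_forward::desc(prop_kind::forward_inference, md, /*axis=*/1);
  auto pd = softmax_forward::primitive_desc(desc, eng);
  softmax_forward(pd).execute(s, {{DNNL_ARG_SRC, src}, {DNNL_ARG_DST, dst}});
  s.wait();

  for (float v : out) std::printf("%f\n", v);
  return 0;
}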

src/operator/softmax_output.cc

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@
  * \author Bing Xu, Zhang Rong A
  */
 #include "./softmax_output-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_ONEDNN == 1
 #include "./nn/mkldnn/mkldnn_ops-inl.h"
 #include "./nn/mkldnn/mkldnn_base-inl.h"
 #endif
@@ -134,7 +134,7 @@ static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs,
   return true;
 }

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_ONEDNN == 1
 inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs,
                                             const int dev_mask,
                                             DispatchMode* dispatch_mode,
@@ -244,7 +244,7 @@ NNVM_REGISTER_OP(SoftmaxOutput)
 .set_num_inputs(2)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<SoftmaxOutputParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_ONEDNN == 1
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxOutputStorageType)
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxOutputComputeExCPU)
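
The registration above points "FComputeEx<cpu>" at SoftmaxOutputComputeExCPU, whose body is not among the hunks shown in this excerpt. MXNet's usual oneDNN pattern is a thin wrapper that checks the Support* predicates and otherwise falls back to the plain CPU kernel; a hedged sketch follows (SupportMKLDNN, FallBackCompute, and SoftmaxOutputCompute are existing MXNet helpers, but the commit's actual body may differ):

// Sketch only: the dispatch wrapper registered above, reconstructed from
// MXNet's typical oneDNN fallback pattern, not copied from this commit.
static void SoftmaxOutputComputeExCPU(const nnvm::NodeAttrs& attrs,
                                      const OpContext& ctx,
                                      const std::vector<NDArray>& inputs,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<NDArray>& outputs) {
  const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
  if (SupportMKLDNN(inputs[0]) && SupportMKLDNNSoftmaxOutput(param)) {
    MKLDNNSoftmaxOutputForward(attrs, ctx, inputs, req, outputs);  // oneDNN path
    return;
  }
  // Convert NDArrays to TBlobs and run the reference CPU implementation.
  FallBackCompute(SoftmaxOutputCompute<cpu>, attrs, ctx, inputs, req, outputs);
}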
