+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_softmax_output.cc
+ * \brief Integrate the MKL-DNN softmax primitive into the SoftmaxOutput forward pass
+ * \author Zhang Rong A
+ */
+
+#if MXNET_USE_ONEDNN == 1
+#include "../../softmax_output-inl.h"
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+namespace mxnet {
+namespace op {
+
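+// Builds the MKL-DNN primitive descriptor for a softmax forward pass. The
+// memory descriptor comes straight from the input memory, so the resulting
+// primitive is specific to that input's shape, data type, and layout.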
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl(
+    const SoftmaxOutputParam& param,
+    bool is_train,
+    const int axis,
+    const mkldnn::memory& input_mem) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
+  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
+}
+
+typedef ParamOpSign<SoftmaxOutputParam> MKLDNNSoftmaxOutputSignature;
+
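+// Owns both the primitive descriptor and the constructed softmax primitive,
+// so a fully initialized forward operation can be cached and reused across
+// calls that share the same signature.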
+class MKLDNNSoftmaxOutputFwd {
+  std::shared_ptr<mkldnn::softmax_forward> fwd_;
+
+ public:
+  const mkldnn::softmax_forward::primitive_desc fwd_pd;
+
+  MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param,
+                         bool is_train,
+                         const int axis,
+                         const mkldnn::memory& mem)
+      : fwd_pd(GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) {
+    fwd_ = std::make_shared<mkldnn::softmax_forward>(fwd_pd);
+  }
+
+  const inline mkldnn::softmax_forward& GetFwd() const {
+    return *fwd_;
+  }
+};
+
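+// Returns a cached forward primitive for the given parameters and input. The
+// cache is thread-local, matching the pattern used by the other MKL-DNN
+// operators, which keeps lookups lock-free on the hot path.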
+static MKLDNNSoftmaxOutputFwd& GetSoftmaxOutputForward(const SoftmaxOutputParam& param,
+                                                       const OpContext& ctx,
+                                                       const NDArray& in_data) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::
+      unordered_map<MKLDNNSoftmaxOutputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL
+      std::unordered_map<MKLDNNSoftmaxOutputSignature, MKLDNNSoftmaxOutputFwd, OpHash> fwds;
+#endif
+  MKLDNNSoftmaxOutputSignature key(param);
+  key.AddSign(ctx.is_train);
+  key.AddSign(in_data);
+
+  // SoftmaxOutput has no axis parameter, so, as in the original implementation,
+  // apply softmax along the last axis.
+  int axis = in_data.shape().ndim() - 1;
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    auto in_mem = *(in_data.GetMKLDNNData());
+    MKLDNNSoftmaxOutputFwd fwd(param, ctx.is_train, axis, in_mem);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+// This is only used for the forward pass; compatibility for the backward pass
+// still needs to be double checked before MKL-DNN is used there as well.
+bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam& param) {
+  return !param.multi_output;
+}
+
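+// Forward compute function for SoftmaxOutput when MKL-DNN is enabled: look up
+// (or create) the cached primitive, wire up the source and destination memory,
+// and execute it through the MKLDNNStream.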
+void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& in_data,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& out_data) {
+  const SoftmaxOutputParam& param = nnvm::get<SoftmaxOutputParam>(attrs.parsed);
+
+  NDArray idata = in_data[softmaxout_enum::kData];
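+  // A view of MKL-DNN memory may be backed by an internal (blocked) layout
+  // that does not match the view's shape, so reorder such inputs back to the
+  // default layout before handing them to the primitive.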
+  if (in_data[softmaxout_enum::kData].IsView() && in_data[softmaxout_enum::kData].IsMKLDNNData()) {
+    idata = in_data[softmaxout_enum::kData].Reorder2Default();
+  }
+
+  auto input_mem = idata.GetMKLDNNData();
+  auto out_mem = CreateMKLDNNMem(
+      out_data[softmaxout_enum::kOut], input_mem->get_desc(), req[softmaxout_enum::kOut]);
+
+  MKLDNNSoftmaxOutputFwd& fwd = GetSoftmaxOutputForward(param, ctx, idata);
+
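+  // RegisterPrimArgs only queues the primitive together with its argument map;
+  // CommitOutput resolves the output request (e.g. write-to vs. add-to), and
+  // nothing executes until Submit() is called on the stream.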
+  MKLDNNStream* stream = MKLDNNStream::Get();
+  stream->RegisterPrimArgs(fwd.GetFwd(),
+                           {{MKLDNN_ARG_SRC, *input_mem}, {MKLDNN_ARG_DST, *out_mem.second}});
+  CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
+  stream->Submit();
+}
+} // namespace op
+} // namespace mxnet
+#endif