From 5324c9306461f41384d3d76df885f677d2858b34 Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Fri, 19 Apr 2019 10:02:47 +0800 Subject: [PATCH 01/21] trigger the ci --- src/operator/rnn-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 37f21ce6d126..afda78db6cd7 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -44,6 +44,7 @@ #include "./operator_common.h" #include "./rnn_impl.h" + namespace mxnet { namespace op { From b1c3d546b6d44937274a79c7718ab4d9ba70ee28 Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Fri, 19 Apr 2019 13:53:01 +0800 Subject: [PATCH 02/21] integrate mkldnn rnn fp32 inference(LSTM and vRNN with tanh and relu) --- 3rdparty/mkldnn | 2 +- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 704 +++++++++++++++++++++++ src/operator/rnn-inl.h | 561 +++++++++++++++++- src/operator/rnn.cc | 20 + src/operator/rnn_impl.h | 7 + 5 files changed, 1269 insertions(+), 25 deletions(-) create mode 100644 src/operator/nn/mkldnn/mkldnn_rnn_impl.h diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn index 7de7e5d02bf6..57e1203092f6 160000 --- a/3rdparty/mkldnn +++ b/3rdparty/mkldnn @@ -1 +1 @@ -Subproject commit 7de7e5d02bf687f971e7668963649728356e0c20 +Subproject commit 57e1203092f63941475ec4088ccd3cf609ed9d7a diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h new file mode 100644 index 000000000000..3f664a9d50ce --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -0,0 +1,704 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../../math.h" +#include "../../math_functions-inl.h" +#include "../../operator_common.h" +#include "../../rnn_impl.h" +#include "../../rnn-inl.h" +#include "mkldnn.hpp" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +algorithm GetMKLDNNRNNAlgo(int mode, + int* ngates, + int* nstates) { + algorithm algo = algorithm::vanilla_rnn; + switch (mode) { + case rnn_enum::kLstm: + *ngates = 4; + *nstates = 2; + algo = algorithm::vanilla_lstm; + break; + case rnn_enum::kGru: + *ngates = 3; + *nstates = 1; + algo = algorithm::vanilla_gru; + break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + *ngates = 1; + *nstates = 1; + algo = algorithm::vanilla_rnn; + break; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + break; + } + return algo; +} + +void ConcatData(mkldnn::memory::format src_format, + mkldnn::memory::format dst_format, + std::vector srcs_cds, + mkldnn::memory::dims dst_cds, + mkldnn::memory::data_type mkldnn_dtype, + int concat_dimension, + std::vector srcs_data, + const mkldnn::memory &dst) { + auto cpu_engine = CpuEngine::Get()->get_engine(); + std::vector srcs_pd; + std::vector srcs; + for (size_t i = 0; i < srcs_cds.size(); i++) { + auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format); + auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine); + auto src_memory = mkldnn::memory(mpd, srcs_data[i]); + srcs_pd.push_back(mpd); + srcs.push_back(src_memory); + } + std::vector inputs; + for (size_t i = 0; i < srcs_cds.size(); i++) { + inputs.push_back(srcs[i]); + } + auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format); + auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd); + MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst)); + MKLDNNStream::Get()->Submit(); +} + +inline size_t GetMKLDNNRNNCacheMemorySize(int L, + int D, + int T, + int N, + int I, + int H, + int mode) { + size_t size = 0; + switch (mode) { + case rnn_enum::kLstm: + size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H + + L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H + + 6 * D * (I + H + 2) * 4 * H + T * N * I * 2; + break; + case rnn_enum::kGru: + size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H + + L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H + + 6 * D * (I + H + 2) * 3 * H + T * N * I * 2; + break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H + + L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H + + 6 * D * (I + H + 2) * 1 * H + T * N * I * 2; + break; + default: + LOG(FATAL) << "unknown RNN mode " << mode; + break; + } + return size; +} + +template +void AdjustGruGateOrder(DType* weight, + const int I, + const int H) { + // mxnet gru gate order is reset, update and new gates + // mkldnn gru gate order is update, reset and new gates + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + DType* weight_reset = weight; + DType* weight_update = weight + I * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < I * H; i++) { + DType tmp = weight_update[i]; + weight_update[i] = weight_reset[i]; + 
weight_reset[i] = tmp; + } +} +// since there are different semantics between MKLDNN's Fused RNN and MXNet's FusedRNN, +// bidirectional will be fused layer by layer, +// unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers (when I == H) + +template +void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + mkldnn::memory user_src_layer_memory, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + int layer_index, + bool *has_cache, + int lvalue, + int dtype, + bool is_train, + int mode) { + int ngates = 0, nstates = 0; + algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + const int single_cell_size = N * H; + const int single_b_size = ngates * H; + DType* wx = w_ptr; // ngates * H, I + DType* wh = w_ptr + I * H * ngates; // ngates * H, H + DType* back_wx = w_ptr + ngates * H * (I + H); + DType* back_wh = back_wx + I * H * ngates; + DType* bx = b_ptr; + DType* bh = b_ptr + H * ngates; + DType* back_bx = b_ptr + single_b_size * 2; + DType* back_bh = back_bx + H * ngates; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto null_memory_ = null_memory(cpu_engine); + int offset1 = 0, offset2 = 0; + bool cached = *has_cache; + mkldnn::memory::dims src_layer_tz = {T, N, I}; + mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H}; + mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder + mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder + mkldnn::memory::dims bias_tz = {1, 2, ngates, H}; + mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc + mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc + + std::vector weights_scales(ngates * H); + if (!cached) { + if (mode == rnn_enum::kGru) { + AdjustGruGateOrder(wx, I, H); + AdjustGruGateOrder(back_wx, I, H); + AdjustGruGateOrder(wh, H, H); + AdjustGruGateOrder(back_wh, H, H); + } + auto src_wx = (*concat_weight_memory)[2 * layer_index]; + auto src_wh = (*concat_weight_memory)[2 * layer_index + 1]; + std::vector srcs_data1; + srcs_data1.push_back(wx); + srcs_data1.push_back(back_wx); + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz, + mkldnn_dtype, 1, srcs_data1, src_wx); + srcs_data1.clear(); + srcs_data1.push_back(wh); + srcs_data1.push_back(back_wh); + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz, + mkldnn_dtype, 1, srcs_data1, src_wh); + int tmpvalue = 0; + if (lvalue > 0) { + tmpvalue = lvalue + 1; + } + MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue])); + MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue])); + + DType* user_bias = reinterpret_cast +
((*bias_memory)[tmpvalue].get_data_handle()); + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < single_b_size; j++) { + user_bias[j] = bx[j] + bh[j]; + user_bias[single_b_size + j] = back_bx[j] + back_bh[j]; + } + } + if (lvalue > 0) { + (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle()); + (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle()); + (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle()); + } + + auto src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto dst_layer_md = mkldnn::memory::desc( + { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto dst_iter_md = mkldnn::memory::desc( + { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto src_iter_md = mkldnn::memory::desc( + {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto bias_md = mkldnn::memory::desc({bias_tz}, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + auto user_src_iter_memory = (*concat_iter_memory)[2]; + if (mode == rnn_enum::kLstm) { + std::vector srcs_data1; + srcs_data1.push_back(hx_ptr); + srcs_data1.push_back(cx_ptr); + auto tmp1_src_iter_memory = (*concat_iter_memory)[0]; + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, + srcs_data1, tmp1_src_iter_memory); + std::vector srcs_data2; + srcs_data2.push_back(hx_ptr + single_cell_size); + srcs_data2.push_back(cx_ptr + single_cell_size); + auto tmp2_src_iter_memory = (*concat_iter_memory)[1]; + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, + srcs_data2, tmp2_src_iter_memory); + std::vector srcs_data3; + srcs_data3.push_back(reinterpret_cast(tmp1_src_iter_memory.get_data_handle())); + srcs_data3.push_back(reinterpret_cast(tmp2_src_iter_memory.get_data_handle())); + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H}, + mkldnn_dtype, 1, srcs_data3, user_src_iter_memory); + } else { + user_src_iter_memory.set_data_handle(hx_ptr); + } + (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); + + rnn_cell::desc rnn_cell(nalgorithm, + mode == rnn_enum::kRnnRelu ? 
algorithm::eltwise_relu : algorithm::eltwise_tanh); + + rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, + rnn_direction::bidirectional_concat, src_layer_md, + src_iter_md, weight_layer_md, weight_iter_md, + bias_md, dst_layer_md, dst_iter_md); + + auto prim_desc + = rnn_forward::primitive_desc(layer_desc, cpu_engine); + + if (x_ptr && layer_index == 0) { + (*x_memory)[layer_index].set_data_handle(x_ptr); + } else { + (*x_memory)[layer_index].set_data_handle(user_src_layer_memory.get_data_handle()); + } + (*y_memory)[layer_index].set_data_handle(y_ptr); + + if (rnn_forward_prim->size() <= layer_index) { + primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], + (*hcx_memory)[layer_index], (*wx_memory)[layer_index], + (*wh_memory)[layer_index], (*bias_memory)[layer_index], + (*y_memory)[layer_index], + (*hcy_memory)[layer_index], null_memory_); + rnn_forward_prim->push_back(rnn_prim); + } + MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); + MKLDNNStream::Get()->Submit(); + + if (state_outputs) { + DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); + if (mode == rnn_enum::kLstm) { + offset1 = nstates * single_cell_size; + offset2 = (nstates + 1) * single_cell_size; + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < single_cell_size; n++) { + hy_ptr[n] = dst_hcy[n]; + hy_ptr[n + single_cell_size] = dst_hcy[n + offset1]; + cy_ptr[n] = dst_hcy[n + single_cell_size]; + cy_ptr[n + single_cell_size] = dst_hcy[n + offset2]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < 2 * single_cell_size; n++) { + hy_ptr[n] = dst_hcy[n]; + } + } + } +} + + +template +void MKLDNNRNNForwardUnidi(bool state_outputs, + const int L, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + mkldnn::memory user_src_layer_memory, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + int layer_index, + bool *has_cache, + int dtype, + bool is_train, + int mode) { + int ngates = 0, nstates = 0; + algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + const int cell_size = N * H; + const int single_cell_size = N * H; + const int single_b_size = ngates * H; + int w_size = (I + H) * H * ngates; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto null_memory_ = null_memory(cpu_engine); + int offset1 = 0, offset2 = 0; + bool cached = *has_cache; + + mkldnn::memory::dims src_layer_tz = {T, N, I}; + mkldnn::memory::dims dst_layer_tz = {T, N, H}; + mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; + mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H}; // ldsnc + mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H}; // ldsnc + mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder + mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder 
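// For reference (editorial annotation, not part of the original patch): the
// MKL-DNN v0.x format tags noted in the comments above name the dimension
// order of each buffer:
//   tnc   - {seq_length, batch, channels}: src/dst layer data
//   ldigo - {layers, directions, input, gates, output}: the weight layout the
//           RNN primitive consumes
//   ldgoi - gates/output-major weight layout; MXNet's flat weight blob is
//           read in this format and reordered to ldigo further down
//   ldgo  - {layers, directions, gates, output}: bias
//   ldsnc - {layers, directions, states, batch, channels}: hidden/cell state
// e.g. weights_layer_tz = {L, 1, I, ngates, H} above describes ldigo weights
// for L fused unidirectional layers.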
+ + auto weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto dst_layer_md = mkldnn::memory::desc( + {dst_layer_tz}, mkldnn_dtype, mkldnn::memory::format::tnc); + auto src_iter_md = mkldnn::memory::desc( + {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto bias_md = mkldnn::memory::desc({bias_tz}, + mkldnn_dtype, mkldnn::memory::format::ldgo); + auto dst_iter_md = mkldnn::memory::desc( + {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); + + for (int l = 0; l < L; l++) { + if (mode == rnn_enum::kLstm) { + std::vector srcs_data; + srcs_data.push_back(hx_ptr); + srcs_data.push_back(cx_ptr); + auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index]; + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, + 2, srcs_data, tmp_src_iter_memory); + } else { + (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr); + } + hx_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cx_ptr += cell_size; + } + } + + auto user_src_iter_memory = null_memory_; + if (L == 1) { + user_src_iter_memory = (*concat_iter_memory)[layer_index]; + } else { + user_src_iter_memory = (*concat_iter_memory)[L + layer_index]; + std::vector src_l_data; + std::vector src_l_dim; + for (int l = 0; l < L; l++) { + src_l_data.push_back(reinterpret_cast + ((*concat_iter_memory)[l + layer_index].get_data_handle())); + src_l_dim.push_back({1, 1, nstates, N, H}); + } + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim, + {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory); + } + (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); + + auto src_wx_f = (*concat_weight_memory)[2 * layer_index]; + auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1]; + + std::vector srcs_data_x; + std::vector srcs_data_h; + std::vector src_l_dim_x; + std::vector src_l_dim_h; + std::vector weights_scales(ngates * H); + if (!cached) { + if (L == 1) { + DType* wx = w_ptr; + DType* wh = w_ptr + I * H * ngates; + if (mode == rnn_enum::kGru) { + AdjustGruGateOrder(wx, I, H); + AdjustGruGateOrder(wh, H, H); + } + src_wx_f.set_data_handle(wx); + src_wh_f.set_data_handle(wh); + } else { + for (int l = 0; l < L; l++) { + DType* wx = w_ptr; + DType* wh = w_ptr + I * H * ngates; + if (mode == rnn_enum::kGru) { + AdjustGruGateOrder(wx, I, H); + AdjustGruGateOrder(wh, H, H); + } + srcs_data_x.push_back(wx); + srcs_data_h.push_back(wh); + src_l_dim_x.push_back(weights_layer_r_tz); + src_l_dim_h.push_back(weights_iter_r_tz); + w_ptr = w_ptr + w_size; + } + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f); + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f); + } + MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index])); + MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index])); + + DType* user_bias_f = reinterpret_cast ((*bias_memory)[layer_index].get_data_handle()); + #pragma omp parallel for num_threads(omp_threads) + for 
(int j = 0; j < L * single_b_size; j++) { + int k = j / single_b_size; + user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size]; + } + } + + rnn_cell::desc rnn_cell(nalgorithm, + mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh); + + rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, + rnn_direction::unidirectional, src_layer_md, + src_iter_md, weight_layer_md, weight_iter_md, + bias_md, dst_layer_md, dst_iter_md); + + auto prim_desc + = rnn_forward::primitive_desc(layer_desc, cpu_engine); + + if (x_ptr && layer_index == 0) { + (*x_memory)[layer_index].set_data_handle(x_ptr); + } else { + (*x_memory)[layer_index].set_data_handle(user_src_layer_memory.get_data_handle()); + } + (*y_memory)[layer_index].set_data_handle(y_ptr); + + if (rnn_forward_prim->size() <= layer_index) { + primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], + (*hcx_memory)[layer_index], (*wx_memory)[layer_index], + (*wh_memory)[layer_index], (*bias_memory)[layer_index], + (*y_memory)[layer_index], + (*hcy_memory)[layer_index], null_memory_); + rnn_forward_prim->push_back(rnn_prim); + } + MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); + MKLDNNStream::Get()->Submit(); + + if (state_outputs) { + DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); + if (mode == rnn_enum::kLstm) { + for (int l = 0; l < L; l++) { + offset1 = l * single_cell_size; + offset2 = l * nstates * single_cell_size; + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < single_cell_size; n++) { + hy_ptr[offset1 + n] = dst_hcy[offset2 + n]; + cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size]; + } + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < L * single_cell_size; n++) { + hy_ptr[n] = dst_hcy[n]; + } + } + } +} + +template +void MKLDNNRNNForward(bool state_outputs, + const int L, + const int D, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + bool *has_cache, + int dtype, + bool is_train, + int mode) { + int ngates = 0, nstates = 0; + GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + const int b_size = 2 * H * ngates * D; + const int cell_size = N * H * D; + // First layer + int w_size = (I + H) * H * ngates * D; + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto null_memory_ = null_memory(cpu_engine); + DType* tmpNull = NULL; + // when D = 1 and I == H, L layers can be fused together + if (D == 1 && I == H) { + MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, null_memory_, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + 0, has_cache, dtype, is_train, mode); + } else { + auto user_src_layer_memory_l = null_memory_; + if (D == 2) { + MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, user_src_layer_memory_l, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, 
wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + 0, has_cache, 0, dtype, is_train, mode); + } else { + MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, user_src_layer_memory_l, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + 0, has_cache, dtype, is_train, mode); + } + if (L > 1) { + user_src_layer_memory_l = (*y_memory)[0]; + // go to next L - 1 layers. + // If D = 2, do it layer by layer. If D = 1, fused L - 1 layers + w_ptr += w_size; + b_ptr += b_size; + if (D == 2) { + w_size = (H * D + H) * H * ngates * D; + for (int l = 0; l < L - 1; l++) { + if (state_outputs) { + hy_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cy_ptr += cell_size; + } + } + hx_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cx_ptr += cell_size; + } + MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull, + user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, + cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, + hcx_memory, wx_memory, wh_memory, bias_memory, + y_memory, hcy_memory, rnn_forward_prim, + 1, has_cache, l + 1, dtype, is_train, mode); + user_src_layer_memory_l = (*y_memory)[1]; + w_ptr += w_size; + b_ptr += b_size; + } + } + if (D == 1) { + w_size = (H + H) * H * ngates; + MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, user_src_layer_memory_l, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, + wh_memory, bias_memory, y_memory, hcy_memory, + rnn_forward_prim, 1, has_cache, dtype, is_train, mode); + } + } + } + *has_cache = true; +} + +template +void MKLDNNRNNForwardInference(bool state_outputs, + const int num_layers, + const int direction, + const int seq_length, + const int batch_size, + const int input_size, + const int state_size, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector* concat_weight_memory, + std::vector* concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + bool *has_cache, + int dtype, + bool is_train, + int mode) { + switch (mode) { + case rnn_enum::kLstm: + case rnn_enum::kGru: + case rnn_enum::kRnnTanh: + case rnn_enum::kRnnRelu: + MKLDNNRNNForward(state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, + cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, + concat_weight_memory, concat_iter_memory, x_memory, + hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + has_cache, dtype, is_train, mode); + break; + default: + LOG(FATAL) << "unknown RNN mode" << mode; + break; + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index afda78db6cd7..cf080fef188a 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -43,18 +43,13 @@ #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" - +#if MXNET_USE_MKLDNN == 1 +#include "./nn/mkldnn/mkldnn_rnn_impl.h" +#endif namespace mxnet { namespace op { -namespace 
rnn_enum { - enum RNNOpInputs {kData, kParams, kState, kStateCell}; - enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; - enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; - enum RNNOpResource {kCuDNNDropoutDescSpace}; -} - inline int GetRnnParamSize(int num_layer, int input_size, int state_size, @@ -385,9 +380,29 @@ class RNNOp { public: RNNParam param_; Context ctx_; + #if MXNET_USE_MKLDNN == 1 + std::vector concat_weight_memory; + std::vector concat_iter_memory; + std::vector rnn_forward_prim; + std::vector x_memory; + std::vector hcx_memory; + std::vector wx_memory; + std::vector wh_memory; + std::vector bias_memory; + std::vector y_memory; + std::vector hcy_memory; + bool has_cache; + bool init_mem_; + size_t reserve_mem_size_; + Storage::Handle mem_space_; + #endif explicit RNNOp(RNNParam param, Context ctx) { this->param_ = param; this->ctx_ = ctx; + #if MXNET_USE_MKLDNN == 1 + init_mem_ = false; + reserve_mem_size_ = 0; + #endif #if MXNET_USE_CUDNN_RNN init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; @@ -477,7 +492,6 @@ class RNNOp { this->temp_init_space_ = false; this->reserve_cpu_space_size_ = 0; this->temp_cpu_space_size_ = 0; - if (param_.projection_size.has_value()) { LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1"; @@ -490,6 +504,12 @@ class RNNOp { } ~RNNOp() { + #if MXNET_USE_MKLDNN == 1 + if (init_mem_) { + Storage::Get()->Free(mem_space_); + init_mem_ = false; + } + #endif #if MXNET_USE_CUDNN_RNN CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); @@ -746,22 +766,23 @@ class RNNOp { #endif if (ctx_.dev_type == kCPU) { - // allocate temp space - const size_t work_cpu_space_size = - GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); - if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { - Storage::Get()->Free(temp_cpu_space_); - temp_init_space_ = false; - } - if (!temp_init_space_) { - temp_cpu_space_ = Storage::Get()->Alloc - (work_cpu_space_size * sizeof(DType), Context::CPU()); - temp_cpu_space_size_ = work_cpu_space_size; - temp_init_space_ = true; - } - DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); if (ctx.is_train) { + // allocate temp space + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; + } + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc + (work_cpu_space_size * sizeof(DType), Context::CPU()); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, param_.seq_length_, param_.batch_size_, param_.state_size, param_.mode); @@ -797,6 +818,89 @@ class RNNOp { param_.p, param_.mode); } else { + #if MXNET_USE_MKLDNN == 1 + if (param_.mode != rnn_enum::kGru) { + // mkldnn Gru has precision issue + int dtype = in_data[rnn_enum::kData].type_flag_; + MKLDNNRNNForwardInference(param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + &concat_weight_memory, + &concat_iter_memory, + &x_memory, + 
&hcx_memory, + &wx_memory, + &wh_memory, + &bias_memory, + &y_memory, + &hcy_memory, + &rnn_forward_prim, + &has_cache, + dtype, + ctx.is_train, + param_.mode); + } else { + // Before integrating MKLDNN GRU fp32 inference + // using below code for keep func being OK" + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; + } + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc + (work_cpu_space_size * sizeof(DType), Context::CPU()); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + RNNForwardInference(work_cpu_space, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); + } + #else + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; + } + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc + (work_cpu_space_size * sizeof(DType), Context::CPU()); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); RNNForwardInference(work_cpu_space, param_.state_outputs, param_.num_layers, @@ -814,6 +918,7 @@ class RNNOp { hy_ptr, cy_ptr, param_.mode); + #endif } } } @@ -1411,6 +1516,414 @@ void RNNStatefulCompute(const OpStatePtr& state, }); } +#if MXNET_USE_MKLDNN == 1 +static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + std::vector in_blobs; + std::vector out_blobs; + std::vector temp_ndarrays_i; + std::vector temp_ndarrays_o; + for (const NDArray& in : inputs) { + if (in.storage_type() == kDefaultStorage) { + temp_ndarrays_i.push_back(in.Reorder2Default()); + in_blobs.emplace_back(temp_ndarrays_i.back().data()); + } else { + in_blobs.emplace_back(in.data()); + } + } + + for (const NDArray& out : outputs) { + if (out.storage_type() == kDefaultStorage) { + temp_ndarrays_o.push_back(out.Reorder2Default()); + out_blobs.emplace_back(temp_ndarrays_o.back().data()); + } else { + out_blobs.emplace_back(out.data()); + } + } + int dtype = in_blobs[rnn_enum::kData].type_flag_; + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + Stream *s = ctx.get_stream(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + RNNOp& op = state_ptr.get_state>(); + const RNNParam& param = op.param_; + int ngates = 0, nstates = 0; + GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); + int D = param.bidirectional ? 
2 : 1; + Tensor x = in_blobs[rnn_enum::kData].get(s); + int T = x.shape_[0]; + int N = x.shape_[1]; + int I = x.shape_[2]; + int H = param.state_size; + int L = param.num_layers; + + const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); + if (op.init_mem_ && op.reserve_mem_size_ < r_size) { + Storage::Get()->Free(op.mem_space_); + op.init_mem_ = false; + } + if (!op.init_mem_) { + op.mem_space_ = Storage::Get()->Alloc( + r_size * sizeof(DType), + Context::CPU()); + op.reserve_mem_size_ = r_size; + op.init_mem_ = true; + op.has_cache = false; + } + if (op.has_cache && op.x_memory.size() == 0) { + op.has_cache = false; + } + + DType* workptr = static_cast(op.mem_space_.dptr); + mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; + mkldnn::memory::dims src_layer_tz = {T, N, D * H}; + mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; + auto dst_layer_md = mkldnn::memory::desc( + { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + if (op.x_memory.size() == 0) { + if (D == 1 && I == H) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_n); + + mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + DType* weight_layer_n = workptr; // L * I * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + L * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + for (int l = 0; l < L; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = 
src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // L * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + workptr = dst_iter_n + L * nstates * N * H; + + } else { + auto user_src_layer_md_0 = mkldnn::memory::desc( + { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_0); + + mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; + auto user_weight_layer_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_0 = workptr; // D * I * ngates * H + auto user_weight_layer_memory_0 + = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); + op.wx_memory.push_back(user_weight_layer_memory_0); + + DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_0 + = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); + op.wh_memory.push_back(user_weight_iter_memory_0); + + DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_0 = + mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); + op.bias_memory.push_back(user_bias_memory_0); + workptr = bias_0 + D * ngates * H; + + auto wx_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wx_memory_0 = + mkldnn::memory({ wx_md_0, cpu_engine }); + auto wh_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_0 = + mkldnn::memory({ wh_md_0, cpu_engine }); + if (D == 2) { + DType* wx_0 = workptr; // D * ngates * I * H + wx_memory_0.set_data_handle(wx_0); + DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H + wh_memory_0.set_data_handle(wh_0); + workptr = wh_0 + D * ngates * H * H; + } + op.concat_weight_memory.push_back(wx_memory_0); + op.concat_weight_memory.push_back(wh_memory_0); + + mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md_0 = mkldnn::memory::desc( + { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* 
src_iter_undi_0 = workptr; // nstates * N * H + auto src_iter_undi_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + workptr = src_iter_undi_0 + nstates * N * H; + if (D == 1) { + op.hcx_memory.push_back(src_iter_undi_memory_0); + } else { + DType* src_iter_undi2_0 = workptr; // nstates * N * H + auto src_iter_undi2_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); + op.concat_iter_memory.push_back(src_iter_undi2_memory_0); + + mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md_0 = mkldnn::memory::desc( + { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory_0 = + mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); + op.concat_iter_memory.push_back(src_iter_memory_0); + op.hcx_memory.push_back(src_iter_memory_0); + workptr = src_iter_0 + D * nstates * N * H; + } + + DType* dst_layer_0 = workptr; // T * N * D * H + auto dst_layer_memory_0 + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); + op.y_memory.push_back(dst_layer_memory_0); + + mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_0 = mkldnn::memory::desc( + { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_0 = + mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); + op.hcy_memory.push_back(dst_iter_memory_0); + workptr = dst_iter_0 + D * nstates * N * H; + + // next L - 1 layers + if (L > 1 && D == 1) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * 
ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + (L - 1) * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + for (int l = 0; l < L - 1; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // (L - 1) * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + + DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } + + if (L > 1 && D == 2) { + mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {1, D, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + + for (int l = 0; l < L; l++) { + DType* weight_layer_n = workptr; // D * (H * D) * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + D * (H * D) * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, 
cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + workptr = bias_n + D * ngates * H; + } + + DType* wx_n = workptr; // D * ngates * (D * H) * H + DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + + mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md = mkldnn::memory::desc( + { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H + auto src_iter_undi_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + + DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H + auto src_iter_undi2_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2); + op.concat_iter_memory.push_back(src_iter_undi2_memory); + + mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md = mkldnn::memory::desc( + { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory = + mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); + op.concat_iter_memory.push_back(src_iter_memory); + op.hcx_memory.push_back(src_iter_memory); + + DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } + } + } + op.Forward(ctx, in_blobs, req, out_blobs); + }); +} + +#endif /* index description 0: x diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 7012a3c22f50..6f670b240237 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -138,6 +138,21 @@ static bool RNNType(const nnvm::NodeAttrs& attrs, return true; } +inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + DispatchMode wanted_mode = DispatchMode::kFCompute; + + #if MXNET_USE_MKLDNN == 1 + wanted_mode = DispatchMode::kFComputeEx; + #endif + + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, wanted_mode); +} + struct RNNGrad { const char *op_name; std::vector operator()(const nnvm::NodePtr &n, @@ -240,8 +255,13 @@ The definition of GRU here is slightly different from paper but compatible with }) .set_attr("FInferShape", RNNShape) .set_attr("FInferType", RNNType) +.set_attr("FInferStorageType", RNNStorageType) .set_attr("FCreateOpState", CreateRNNState) .set_attr("FStatefulCompute", RNNStatefulCompute) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +.set_attr("FStatefulComputeEx", RNNStatefulComputeCPU) +#endif .set_attr("FGradient", RNNGrad{"_backward_RNN"}) .set_attr("FResourceRequestEx", [](const NodeAttrs& attrs, const int 
dev_mask, const DispatchMode dispatch_mode) { diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h index e1b4a2b79c0a..3c8112c51661 100644 --- a/src/operator/rnn_impl.h +++ b/src/operator/rnn_impl.h @@ -44,6 +44,13 @@ namespace mxnet { namespace op { +namespace rnn_enum { + enum RNNOpInputs {kData, kParams, kState, kStateCell}; + enum RNNOpOutputs {kOut, kStateOut, kStateCellOut}; + enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; + enum RNNOpResource {kCuDNNDropoutDescSpace}; +} + template inline DType sigmoid(DType x) { return 1.0f / (1.0f + exp(-x)); From d398a1c7aabb1014d85fd5676d75a1700af88242 Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Fri, 19 Apr 2019 14:15:07 +0800 Subject: [PATCH 03/21] fix bug about comparison between signed and unsigned integer expressions --- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index 3f664a9d50ce..e971122f703f 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -310,7 +310,7 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, } (*y_memory)[layer_index].set_data_handle(y_ptr); - if (rnn_forward_prim->size() <= layer_index) { + if (rnn_forward_prim->size() <= (size_t)layer_index) { primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], (*hcx_memory)[layer_index], (*wx_memory)[layer_index], (*wh_memory)[layer_index], (*bias_memory)[layer_index], @@ -513,7 +513,7 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, } (*y_memory)[layer_index].set_data_handle(y_ptr); - if (rnn_forward_prim->size() <= layer_index) { + if (rnn_forward_prim->size() <= (size_t)layer_index) { primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], (*hcx_memory)[layer_index], (*wx_memory)[layer_index], (*wh_memory)[layer_index], (*bias_memory)[layer_index], From 23a6e10f90384f84e652e47938bf22bdd9030dc5 Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Fri, 19 Apr 2019 15:20:29 +0800 Subject: [PATCH 04/21] fix unix-gpu issue --- src/operator/rnn-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index cf080fef188a..0be22ca97afd 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -26,7 +26,7 @@ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ -#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 +#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 && MXNET_USE_MKLDNN == 0 #define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 #include From 18f51ac9ae51f99606620188e18acc0ddd5e99dc Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Fri, 19 Apr 2019 16:32:12 +0800 Subject: [PATCH 05/21] fix unix gpu bug --- src/operator/rnn-inl.h | 9 +++++---- src/operator/rnn.cc | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 0be22ca97afd..267d0037e092 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -26,8 +26,9 @@ #ifndef MXNET_OPERATOR_RNN_INL_H_ #define MXNET_OPERATOR_RNN_INL_H_ -#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 && MXNET_USE_MKLDNN == 0 +#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 #define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 +#define MXNET_USE_MKLDNN_RNN MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) #include #include 
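A note on the guard introduced above: __CUDACC__ is predefined whenever nvcc is the compiler, so MXNET_USE_MKLDNN_RNN evaluates to 0 during the CUDA compilation pass over rnn-inl.h (nvcc also compiles this header via rnn.cu) and to 1 only in the host C++ pass of an MKLDNN-enabled build, which is what the unix-gpu fixes in this series rely on. A minimal sketch of how the two passes diverge, assuming MXNET_USE_MKLDNN == 1 (illustrative only, not part of the patch):

#if MXNET_USE_MKLDNN_RNN
  // host compiler building rnn.cc: the guard is 1 and the MKL-DNN RNN
  // inference path is compiled in
#else
  // nvcc building rnn.cu: __CUDACC__ is defined, the guard is 0, and the
  // MKL-DNN RNN code is never seen by the CUDA compiler
#endif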
@@ -43,7 +44,7 @@ #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN_RNN #include "./nn/mkldnn/mkldnn_rnn_impl.h" #endif @@ -818,7 +819,7 @@ class RNNOp { param_.p, param_.mode); } else { - #if MXNET_USE_MKLDNN == 1 + #if MXNET_USE_MKLDNN_RNN if (param_.mode != rnn_enum::kGru) { // mkldnn Gru has precision issue int dtype = in_data[rnn_enum::kData].type_flag_; @@ -1516,7 +1517,7 @@ void RNNStatefulCompute(const OpStatePtr& state, }); } -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN_RNN static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 6f670b240237..78cd4791c9d3 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -145,7 +145,7 @@ inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { DispatchMode wanted_mode = DispatchMode::kFCompute; - #if MXNET_USE_MKLDNN == 1 + #if MXNET_USE_MKLDNN_RNN wanted_mode = DispatchMode::kFComputeEx; #endif @@ -258,7 +258,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr("FInferStorageType", RNNStorageType) .set_attr("FCreateOpState", CreateRNNState) .set_attr("FStatefulCompute", RNNStatefulCompute) -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN_RNN .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", RNNStatefulComputeCPU) #endif From 6621343f18c191324bacbca88f51955fbda86a7b Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Fri, 19 Apr 2019 23:27:42 +0800 Subject: [PATCH 06/21] fix unix-gpu issues --- 3rdparty/mkldnn | 2 +- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 6 ++++++ src/operator/rnn-inl.h | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn index 57e1203092f6..7de7e5d02bf6 160000 --- a/3rdparty/mkldnn +++ b/3rdparty/mkldnn @@ -1 +1 @@ -Subproject commit 57e1203092f63941475ec4088ccd3cf609ed9d7a +Subproject commit 7de7e5d02bf687f971e7668963649728356e0c20 diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index e971122f703f..2f36da9adc33 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -637,6 +637,12 @@ void MKLDNNRNNForward(bool state_outputs, } } if (D == 1) { + if (state_outputs) { + hy_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cy_ptr += cell_size; + } + } w_size = (H + H) * H * ngates; MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 267d0037e092..3541cded1997 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -854,7 +854,7 @@ class RNNOp { param_.mode); } else { // Before integrating MKLDNN GRU fp32 inference - // using below code for keep func being OK" + // using below code for keep func being OK const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size, direction, param_.mode); From 4b45093af508d23dab820409f5f3a18a8288e2d3 Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Tue, 23 Apr 2019 13:49:09 +0800 Subject: [PATCH 07/21] fix some comments --- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 27 ++++++++++++++---------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h 
b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index 2f36da9adc33..9600c5afcc87 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -96,6 +96,13 @@ void ConcatData(mkldnn::memory::format src_format, MKLDNNStream::Get()->Submit(); } +// cached mkldnn memory +// first layer wx, wh with next L - 1 layers wx and wh +// with L layers hx and cx, src and dst data/iter etc. +// it will prepare memory on before and after reorder and concat. +// for unidirectional, it will fused as dim like 1 + (L - 1) when I != H. +// for bidirectional, it will fused as data + back_data (weight, bias, iter etc), +// also need to identify first layer and next layers inline size_t GetMKLDNNRNNCacheMemorySize(int L, int D, int T, @@ -155,7 +162,7 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, const int I, const int H, DType* x_ptr, - mkldnn::memory user_src_layer_memory, + mkldnn::memory *user_src_layer_memory, DType* hx_ptr, DType* cx_ptr, DType* w_ptr, @@ -207,7 +214,6 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc - std::vector weights_scales(ngates * H); if (!cached) { if (mode == rnn_enum::kGru) { AdjustGruGateOrder(wx, I, H); @@ -306,7 +312,7 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, if (x_ptr && layer_index == 0) { (*x_memory)[layer_index].set_data_handle(x_ptr); } else { - (*x_memory)[layer_index].set_data_handle(user_src_layer_memory.get_data_handle()); + (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); } (*y_memory)[layer_index].set_data_handle(y_ptr); @@ -351,7 +357,7 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, const int I, const int H, DType* x_ptr, - mkldnn::memory user_src_layer_memory, + mkldnn::memory *user_src_layer_memory, DType* hx_ptr, DType* cx_ptr, DType* w_ptr, @@ -454,7 +460,6 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, std::vector srcs_data_h; std::vector src_l_dim_x; std::vector src_l_dim_h; - std::vector weights_scales(ngates * H); if (!cached) { if (L == 1) { DType* wx = w_ptr; @@ -509,7 +514,7 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, if (x_ptr && layer_index == 0) { (*x_memory)[layer_index].set_data_handle(x_ptr); } else { - (*x_memory)[layer_index].set_data_handle(user_src_layer_memory.get_data_handle()); + (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); } (*y_memory)[layer_index].set_data_handle(y_ptr); @@ -586,7 +591,7 @@ void MKLDNNRNNForward(bool state_outputs, DType* tmpNull = NULL; // when D = 1 and I == H, L layers can be fused together if (D == 1 && I == H) { - MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, null_memory_, + MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, bias_memory, y_memory, hcy_memory, rnn_forward_prim, @@ -594,13 +599,13 @@ void MKLDNNRNNForward(bool state_outputs, } else { auto user_src_layer_memory_l = null_memory_; if (D == 2) { - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, user_src_layer_memory_l, + MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, bias_memory, y_memory, 
hcy_memory, rnn_forward_prim, 0, has_cache, 0, dtype, is_train, mode); } else { - MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, user_src_layer_memory_l, + MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, bias_memory, y_memory, hcy_memory, rnn_forward_prim, @@ -626,7 +631,7 @@ void MKLDNNRNNForward(bool state_outputs, cx_ptr += cell_size; } MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull, - user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, + &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, bias_memory, y_memory, hcy_memory, rnn_forward_prim, @@ -644,7 +649,7 @@ void MKLDNNRNNForward(bool state_outputs, } } w_size = (H + H) * H * ngates; - MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, user_src_layer_memory_l, + MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, bias_memory, y_memory, hcy_memory, From 48c58085d2ab8e596c3281409a7df703f068263a Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Wed, 24 Apr 2019 12:14:27 +0800 Subject: [PATCH 08/21] fix issue --- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 46 ++++++++++++++++++------ 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index 9600c5afcc87..e8876c1a75c4 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -136,9 +136,9 @@ inline size_t GetMKLDNNRNNCacheMemorySize(int L, } template -void AdjustGruGateOrder(DType* weight, - const int I, - const int H) { +void AdjustGruWeightGateOrder(DType* weight, + const int I, + const int H) { // mxnet gru gate order is reset, update and new gates // mkldnn gru gate order is update, reset and new gates const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); @@ -151,6 +151,22 @@ void AdjustGruGateOrder(DType* weight, weight_reset[i] = tmp; } } + +template +void AdjustGruBiasGateOrder(DType* bias, + const int H) { + // mxnet gru gate order is reset, update and new gates + // mkldnn gru gate order is update, reset and new gates + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + DType* bias_reset = bias; + DType* bias_update = bias + H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; i++) { + DType tmp = bias_update[i]; + bias_update[i] = bias_reset[i]; + bias_reset[i] = tmp; + } +} // since there is different sematics of MKLDNN's Fused RNN and Mxnet FusedRNN, // bidirectional will be fused layer by layer, // unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) @@ -216,10 +232,14 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, if (!cached) { if (mode == rnn_enum::kGru) { - AdjustGruGateOrder(wx, I, H); - AdjustGruGateOrder(back_wx, I, H); - AdjustGruGateOrder(wh, H, H); - AdjustGruGateOrder(back_wh, H, H); + AdjustGruWeightGateOrder(wx, I, H); + AdjustGruWeightGateOrder(back_wx, I, H); + AdjustGruWeightGateOrder(wh, H, H); + AdjustGruWeightGateOrder(back_wh, H, H); + 
AdjustGruBiasGateOrder(bx, H); + AdjustGruBiasGateOrder(back_bx, H); + AdjustGruBiasGateOrder(bh, H); + AdjustGruBiasGateOrder(back_bh, H); } auto src_wx = (*concat_weight_memory)[2 * layer_index]; auto src_wh = (*concat_weight_memory)[2 * layer_index + 1]; @@ -465,8 +485,8 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, DType* wx = w_ptr; DType* wh = w_ptr + I * H * ngates; if (mode == rnn_enum::kGru) { - AdjustGruGateOrder(wx, I, H); - AdjustGruGateOrder(wh, H, H); + AdjustGruWeightGateOrder(wx, I, H); + AdjustGruWeightGateOrder(wh, H, H); } src_wx_f.set_data_handle(wx); src_wh_f.set_data_handle(wh); @@ -474,9 +494,13 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, for (int l = 0; l < L; l++) { DType* wx = w_ptr; DType* wh = w_ptr + I * H * ngates; + DType* bx = b_ptr + l * ngates * H * 2; + DType* bh = b_ptr + l * ngates * H * 2 + H * ngates; if (mode == rnn_enum::kGru) { - AdjustGruGateOrder(wx, I, H); - AdjustGruGateOrder(wh, H, H); + AdjustGruWeightGateOrder(wx, I, H); + AdjustGruWeightGateOrder(wh, H, H); + AdjustGruBiasGateOrder(bx, H); + AdjustGruBiasGateOrder(bh, H); } srcs_data_x.push_back(wx); srcs_data_h.push_back(wh); From 82445ee81db46112c7d36e469273dc689db81e1c Mon Sep 17 00:00:00 2001 From: "Li, Hao H" Date: Thu, 25 Apr 2019 09:53:10 +0800 Subject: [PATCH 09/21] fix comment --- src/operator/rnn-inl.h | 7 +++---- src/operator/rnn.cc | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 3541cded1997..8f4f32df1773 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -28,7 +28,6 @@ #define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 #define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200 -#define MXNET_USE_MKLDNN_RNN MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) #include #include @@ -44,7 +43,7 @@ #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" -#if MXNET_USE_MKLDNN_RNN +#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) #include "./nn/mkldnn/mkldnn_rnn_impl.h" #endif @@ -819,7 +818,7 @@ class RNNOp { param_.p, param_.mode); } else { - #if MXNET_USE_MKLDNN_RNN + #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) if (param_.mode != rnn_enum::kGru) { // mkldnn Gru has precision issue int dtype = in_data[rnn_enum::kData].type_flag_; @@ -1517,7 +1516,7 @@ void RNNStatefulCompute(const OpStatePtr& state, }); } -#if MXNET_USE_MKLDNN_RNN +#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 78cd4791c9d3..bb0acfb0277e 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -145,7 +145,7 @@ inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { DispatchMode wanted_mode = DispatchMode::kFCompute; - #if MXNET_USE_MKLDNN_RNN + #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) wanted_mode = DispatchMode::kFComputeEx; #endif @@ -258,7 +258,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr("FInferStorageType", RNNStorageType) .set_attr("FCreateOpState", CreateRNNState) .set_attr("FStatefulCompute", RNNStatefulCompute) -#if MXNET_USE_MKLDNN_RNN +#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", RNNStatefulComputeCPU) #endif From 0135b1fd608fd094c89c8991da299cc85369b706 Mon Sep 17 00:00:00 2001 From: Wei Date: Mon, 20 May 
2019 10:22:51 +0800 Subject: [PATCH 10/21] rename `cached` to `initialized` --- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index e8876c1a75c4..7975b392afa0 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -219,7 +219,7 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, auto cpu_engine = CpuEngine::Get()->get_engine(); auto null_memory_ = null_memory(cpu_engine); int offset1 = 0, offset2 = 0; - bool cached = *has_cache; + bool initialized = *has_cache; mkldnn::memory::dims src_layer_tz = {T, N, I}; mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H}; mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo @@ -230,7 +230,7 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc - if (!cached) { + if (!initialized) { if (mode == rnn_enum::kGru) { AdjustGruWeightGateOrder(wx, I, H); AdjustGruWeightGateOrder(back_wx, I, H); @@ -411,7 +411,7 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, auto cpu_engine = CpuEngine::Get()->get_engine(); auto null_memory_ = null_memory(cpu_engine); int offset1 = 0, offset2 = 0; - bool cached = *has_cache; + bool initialized = *has_cache; mkldnn::memory::dims src_layer_tz = {T, N, I}; mkldnn::memory::dims dst_layer_tz = {T, N, H}; @@ -480,7 +480,7 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, std::vector srcs_data_h; std::vector src_l_dim_x; std::vector src_l_dim_h; - if (!cached) { + if (!initialized) { if (L == 1) { DType* wx = w_ptr; DType* wh = w_ptr + I * H * ngates; From 75e4803a90f527e458e31944338ee55af3531955 Mon Sep 17 00:00:00 2001 From: xinyu Date: Mon, 20 May 2019 10:24:09 +0800 Subject: [PATCH 11/21] support IType --- src/operator/rnn-inl.h | 589 +++++++++++++++++++++-------------------- 1 file changed, 296 insertions(+), 293 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 8b88604d146c..9034c5746b8d 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -1669,253 +1669,84 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, } } int dtype = in_blobs[rnn_enum::kData].type_flag_; + int itype = in_blobs[inputs.size()-1].type_flag_; mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); Stream *s = ctx.get_stream(); auto cpu_engine = CpuEngine::Get()->get_engine(); MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - RNNOp& op = state_ptr.get_state>(); - const RNNParam& param = op.param_; - int ngates = 0, nstates = 0; - GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); - int D = param.bidirectional ? 
2 : 1; - Tensor x = in_blobs[rnn_enum::kData].get(s); - int T = x.shape_[0]; - int N = x.shape_[1]; - int I = x.shape_[2]; - int H = param.state_size; - int L = param.num_layers; - - const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); - if (op.init_mem_ && op.reserve_mem_size_ < r_size) { - Storage::Get()->Free(op.mem_space_); - op.init_mem_ = false; - } - if (!op.init_mem_) { - op.mem_space_ = Storage::Get()->Alloc( - r_size * sizeof(DType), - Context::CPU()); - op.reserve_mem_size_ = r_size; - op.init_mem_ = true; - op.has_cache = false; - } - if (op.has_cache && op.x_memory.size() == 0) { - op.has_cache = false; - } - - DType* workptr = static_cast(op.mem_space_.dptr); - mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; - mkldnn::memory::dims src_layer_tz = {T, N, D * H}; - mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; - auto dst_layer_md = mkldnn::memory::desc( - { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - if (op.x_memory.size() == 0) { - if (D == 1 && I == H) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_n); - - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - DType* weight_layer_n = workptr; // L * I * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + L * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = 
src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // L * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - workptr = dst_iter_n + L * nstates * N * H; - - } else { - auto user_src_layer_md_0 = mkldnn::memory::desc( - { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_0); - - mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; - auto user_weight_layer_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_0 = workptr; // D * I * ngates * H - auto user_weight_layer_memory_0 - = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); - op.wx_memory.push_back(user_weight_layer_memory_0); - - DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_0 - = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); - op.wh_memory.push_back(user_weight_iter_memory_0); - - DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_0 = - mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); - op.bias_memory.push_back(user_bias_memory_0); - workptr = bias_0 + D * ngates * H; - - auto wx_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wx_memory_0 = - mkldnn::memory({ wx_md_0, cpu_engine }); - auto wh_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_0 = - mkldnn::memory({ wh_md_0, cpu_engine }); - if (D == 2) { - DType* wx_0 = workptr; // D * ngates * I * H - wx_memory_0.set_data_handle(wx_0); - DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H - wh_memory_0.set_data_handle(wh_0); - workptr = wh_0 + D * ngates * H * H; - } - op.concat_weight_memory.push_back(wx_memory_0); - op.concat_weight_memory.push_back(wh_memory_0); - - mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md_0 = mkldnn::memory::desc( - { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* 
src_iter_undi_0 = workptr; // nstates * N * H - auto src_iter_undi_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); - workptr = src_iter_undi_0 + nstates * N * H; - if (D == 1) { - op.hcx_memory.push_back(src_iter_undi_memory_0); - } else { - DType* src_iter_undi2_0 = workptr; // nstates * N * H - auto src_iter_undi2_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); - op.concat_iter_memory.push_back(src_iter_undi2_memory_0); - - mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md_0 = mkldnn::memory::desc( - { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory_0 = - mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); - op.concat_iter_memory.push_back(src_iter_memory_0); - op.hcx_memory.push_back(src_iter_memory_0); - workptr = src_iter_0 + D * nstates * N * H; - } + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state_ptr.get_state>(); + const RNNParam& param = op.param_; + int ngates = 0, nstates = 0; + GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); + int D = param.bidirectional ? 2 : 1; + Tensor x = in_blobs[rnn_enum::kData].get(s); + int T = x.shape_[0]; + int N = x.shape_[1]; + int I = x.shape_[2]; + int H = param.state_size; + int L = param.num_layers; + + const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); + if (op.init_mem_ && op.reserve_mem_size_ < r_size) { + Storage::Get()->Free(op.mem_space_); + op.init_mem_ = false; + } + if (!op.init_mem_) { + op.mem_space_ = Storage::Get()->Alloc( + r_size * sizeof(DType), + Context::CPU()); + op.reserve_mem_size_ = r_size; + op.init_mem_ = true; + op.has_cache = false; + } + if (op.has_cache && op.x_memory.size() == 0) { + op.has_cache = false; + } - DType* dst_layer_0 = workptr; // T * N * D * H - auto dst_layer_memory_0 - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); - op.y_memory.push_back(dst_layer_memory_0); - - mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_0 = mkldnn::memory::desc( - { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_0 = - mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); - op.hcy_memory.push_back(dst_iter_memory_0); - workptr = dst_iter_0 + D * nstates * N * H; - - // next L - 1 layers - if (L > 1 && D == 1) { + DType* workptr = static_cast(op.mem_space_.dptr); + mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; + mkldnn::memory::dims src_layer_tz = {T, N, D * H}; + mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; + auto dst_layer_md = mkldnn::memory::desc( + { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + if (op.x_memory.size() == 0) { + if (D == 1 && I == H) { auto user_src_layer_md = mkldnn::memory::desc( { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); + auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_n); - mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L - 
1, 1, ngates, H}; + mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; auto user_weight_layer_md = mkldnn::memory::desc( { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); auto user_weight_iter_md = mkldnn::memory::desc( { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); auto user_bias_md = mkldnn::memory::desc({ bias_tz }, mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H + DType* weight_layer_n = workptr; // L * I * ngates * H auto user_weight_layer_memory_n = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); op.wx_memory.push_back(user_weight_layer_memory_n); - DType* weight_iter_n = weight_layer_n + - (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H + DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H auto user_weight_iter_memory_n = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); op.wh_memory.push_back(user_weight_iter_memory_n); - DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H + DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H auto user_bias_memory_n = mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); op.bias_memory.push_back(user_bias_memory_n); auto wx_md_n = mkldnn::memory::desc( { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H + DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H auto wx_memory_n = mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H + DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H auto wh_md_n = mkldnn::memory::desc( { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); auto wh_memory_n = @@ -1923,128 +1754,300 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, op.concat_weight_memory.push_back(wx_memory_n); op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + (L - 1) * ngates * H * H; + workptr = wh_n + L * ngates * H * H; mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc auto src_iter_md_n1 = mkldnn::memory::desc( { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L - 1; l++) { + for (int l = 0; l < L; l++) { DType* src_iter_n1 = workptr; // nstates * N * H auto src_iter_memory_n1 = mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); op.concat_iter_memory.push_back(src_iter_memory_n1); workptr = src_iter_n1 + nstates * N * H; } - mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc auto src_iter_md_n = mkldnn::memory::desc( { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // (L - 1) * nstates * N * H + DType* src_iter_n = workptr; // L * nstates * N * H auto src_iter_memory_n = mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); op.concat_iter_memory.push_back(src_iter_memory_n); op.hcx_memory.push_back(src_iter_memory_n); - - DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H + DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H auto dst_layer_memory_n = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); 
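The block above carves one flat cache allocation into per-tensor regions by repeatedly advancing workptr; a simplified, self-contained sketch of that bump-pointer pattern (the helper name is hypothetical):

#include <cstddef>

// Bump allocator over a preallocated workspace: return the current cursor,
// then advance it by `count` elements so the next region starts right after.
// The regions end up laid out back to back, matching the offset comments
// above (L * I * ngates * H, L * H * ngates * H, L * ngates * H, ...).
template <typename DType>
static DType* Carve(DType** cursor, std::size_t count) {
  DType* region = *cursor;
  *cursor += count;
  return region;
}

// Usage against a buffer sized by GetMKLDNNRNNCacheMemorySize:
//   DType* workptr      = static_cast<DType*>(mem_space.dptr);
//   DType* weight_layer = Carve(&workptr, L * I * ngates * H);
//   DType* weight_iter  = Carve(&workptr, L * H * ngates * H);
//   DType* bias         = Carve(&workptr, L * ngates * H);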
op.y_memory.push_back(dst_layer_memory_n); - mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc auto dst_iter_md_n = mkldnn::memory::desc( { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H + DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H auto dst_iter_memory_n = mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); op.hcy_memory.push_back(dst_iter_memory_n); - } + workptr = dst_iter_n + L * nstates * N * H; - if (L > 1 && D == 2) { - mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {1, D, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + } else { + auto user_src_layer_md_0 = mkldnn::memory::desc( + { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_0); + + mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; + auto user_weight_layer_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgo); - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* weight_layer_0 = workptr; // D * I * ngates * H + auto user_weight_layer_memory_0 + = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); + op.wx_memory.push_back(user_weight_layer_memory_0); + + DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_0 + = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); + op.wh_memory.push_back(user_weight_iter_memory_0); + + DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_0 = + mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); + op.bias_memory.push_back(user_bias_memory_0); + workptr = bias_0 + D * ngates * H; + + auto wx_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wx_memory_0 = + mkldnn::memory({ wx_md_0, cpu_engine }); + auto wh_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_0 = + mkldnn::memory({ wh_md_0, cpu_engine }); + if (D == 2) { + DType* wx_0 = workptr; // D * 
ngates * I * H + wx_memory_0.set_data_handle(wx_0); + DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H + wh_memory_0.set_data_handle(wh_0); + workptr = wh_0 + D * ngates * H * H; + } + op.concat_weight_memory.push_back(wx_memory_0); + op.concat_weight_memory.push_back(wh_memory_0); + + mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md_0 = mkldnn::memory::desc( + { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi_0 = workptr; // nstates * N * H + auto src_iter_undi_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + workptr = src_iter_undi_0 + nstates * N * H; + if (D == 1) { + op.hcx_memory.push_back(src_iter_undi_memory_0); + } else { + DType* src_iter_undi2_0 = workptr; // nstates * N * H + auto src_iter_undi2_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); + op.concat_iter_memory.push_back(src_iter_undi2_memory_0); + + mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md_0 = mkldnn::memory::desc( + { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory_0 = + mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); + op.concat_iter_memory.push_back(src_iter_memory_0); + op.hcx_memory.push_back(src_iter_memory_0); + workptr = src_iter_0 + D * nstates * N * H; + } - for (int l = 0; l < L; l++) { - DType* weight_layer_n = workptr; // D * (H * D) * ngates * H + DType* dst_layer_0 = workptr; // T * N * D * H + auto dst_layer_memory_0 + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); + op.y_memory.push_back(dst_layer_memory_0); + + mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_0 = mkldnn::memory::desc( + { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_0 = + mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); + op.hcy_memory.push_back(dst_iter_memory_0); + workptr = dst_iter_0 + D * nstates * N * H; + + // next L - 1 layers + if (L > 1 && D == 1) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H auto user_weight_layer_memory_n = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); op.wx_memory.push_back(user_weight_layer_memory_n); DType* weight_iter_n = weight_layer_n + - D * (H * D) * ngates * H; // D * H * ngates * H + (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H auto user_weight_iter_memory_n = mkldnn::memory({ 
user_weight_iter_md, cpu_engine }, weight_iter_n); op.wh_memory.push_back(user_weight_iter_memory_n); - DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H + DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H auto user_bias_memory_n = mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); op.bias_memory.push_back(user_bias_memory_n); - workptr = bias_n + D * ngates * H; - } - - DType* wx_n = workptr; // D * ngates * (D * H) * H - DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md = mkldnn::memory::desc( - { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H - auto src_iter_undi_memory = - mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + (L - 1) * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + for (int l = 0; l < L - 1; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // (L - 1) * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + + DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } - DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H - auto src_iter_undi2_memory = - mkldnn::memory({ src_iter_undi_md, cpu_engine }, 
src_iter_undi2); - op.concat_iter_memory.push_back(src_iter_undi2_memory); - - mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md = mkldnn::memory::desc( - { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory = - mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); - op.concat_iter_memory.push_back(src_iter_memory); - op.hcx_memory.push_back(src_iter_memory); - - DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); + if (L > 1 && D == 2) { + mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {1, D, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + + for (int l = 0; l < L; l++) { + DType* weight_layer_n = workptr; // D * (H * D) * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + D * (H * D) * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + workptr = bias_n + D * ngates * H; + } - mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); + DType* wx_n = workptr; // D * ngates * (D * H) * H + DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + + mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md = mkldnn::memory::desc( + { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H + auto 
src_iter_undi_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + + DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H + auto src_iter_undi2_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2); + op.concat_iter_memory.push_back(src_iter_undi2_memory); + + mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md = mkldnn::memory::desc( + { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory = + mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); + op.concat_iter_memory.push_back(src_iter_memory); + op.hcx_memory.push_back(src_iter_memory); + + DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } } } - } - op.Forward(ctx, in_blobs, req, out_blobs); + op.Forward(ctx, in_blobs, req, out_blobs); + }); }); } From 78f252ee19993c12d7b772888684ec062023f9ed Mon Sep 17 00:00:00 2001 From: Wei Date: Mon, 20 May 2019 13:00:33 +0800 Subject: [PATCH 12/21] TODO for MKLDNN GRU --- src/operator/rnn-inl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 9034c5746b8d..0caf42147bd3 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -904,7 +904,8 @@ class RNNOp { } else { #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) if (param_.mode != rnn_enum::kGru) { - // mkldnn Gru has precision issue + // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one + // will be added to MXNet when it passes the unit test. 
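Patch 11 above nests a MSHADOW_TYPE_SWITCH on the last input blob's dtype inside the existing MSHADOW_REAL_TYPE_SWITCH, so the stateful op is looked up per (DType, IType) pair; a stand-alone sketch of that two-level dispatch (the enum values follow mshadow's type flags, everything else is illustrative):

#include <cstdint>
#include <iostream>

enum TypeFlag { kFloat32 = 0, kInt32 = 4, kInt64 = 6 };  // mshadow-style flags

// Kernel instantiated for every (DType, IType) combination, analogous to
// the RNN op after the itype switch is added.
template <typename DType, typename IType>
void RunRNN() { /* forward body parameterized on both types */ }

template <typename DType>
void SwitchIType(int itype) {  // inner MSHADOW_TYPE_SWITCH analogue
  switch (itype) {
    case kInt32: RunRNN<DType, int32_t>(); break;
    case kInt64: RunRNN<DType, int64_t>(); break;
    default: std::cerr << "unsupported itype " << itype << '\n'; break;
  }
}

void Dispatch(int dtype, int itype) {  // outer MSHADOW_REAL_TYPE_SWITCH analogue
  switch (dtype) {
    case kFloat32: SwitchIType<float>(itype); break;
    default: std::cerr << "unsupported dtype " << dtype << '\n'; break;
  }
}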
int dtype = in_data[rnn_enum::kData].type_flag_; MKLDNNRNNForwardInference(param_.state_outputs, param_.num_layers, From d926794355d1cd6d1fa38d49b69d24ff9ef6c04f Mon Sep 17 00:00:00 2001 From: Wei Date: Mon, 20 May 2019 13:01:46 +0800 Subject: [PATCH 13/21] fix bugs in memory adjustment --- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index 7975b392afa0..8532c379ffc8 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -29,7 +29,6 @@ #include #include #include -#include "../../math.h" #include "../../math_functions-inl.h" #include "../../operator_common.h" #include "../../rnn_impl.h" @@ -167,7 +166,7 @@ void AdjustGruBiasGateOrder(DType* bias, bias_reset[i] = tmp; } } -// since there is different sematics of MKLDNN's Fused RNN and Mxnet FusedRNN, +// since there is different sematics of MKLDNN's Fused RNN and MXNet FusedRNN, // bidirectional will be fused layer by layer, // unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) @@ -487,6 +486,8 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, if (mode == rnn_enum::kGru) { AdjustGruWeightGateOrder(wx, I, H); AdjustGruWeightGateOrder(wh, H, H); + AdjustGruBiasGateOrder(b_ptr, H); + AdjustGruBiasGateOrder(b_ptr + H * ngates, H); } src_wx_f.set_data_handle(wx); src_wh_f.set_data_handle(wh); From 980ce85b3c072107bc0fcb2406ea9aa5ba8c18d5 Mon Sep 17 00:00:00 2001 From: Wei Date: Tue, 21 May 2019 12:38:16 +0800 Subject: [PATCH 14/21] Reformat TODO for MKLDNN GRU --- src/operator/rnn-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 0caf42147bd3..9b9911ff007f 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -904,8 +904,8 @@ class RNNOp { } else { #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) if (param_.mode != rnn_enum::kGru) { - // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one - // will be added to MXNet when it passes the unit test. + // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one + // will be added to MXNet when we figure out the issue. int dtype = in_data[rnn_enum::kData].type_flag_; MKLDNNRNNForwardInference(param_.state_outputs, param_.num_layers, From 620bad3b15272a7395c7d31f47dcc678c4181af6 Mon Sep 17 00:00:00 2001 From: Wei Date: Wed, 22 May 2019 16:14:48 +0800 Subject: [PATCH 15/21] Reserve original RNN path --- src/operator/rnn-inl.h | 73 ++++++++++++------------------------------ 1 file changed, 20 insertions(+), 53 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 9b9911ff007f..7eda8b612896 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -903,7 +903,7 @@ class RNNOp { param_.mode); } else { #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) - if (param_.mode != rnn_enum::kGru) { + if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one // will be added to MXNet when we figure out the issue. 
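The gate-reorder fix above extends the weight swap to the two bias rows; a single-threaded sketch of what the AdjustGru*GateOrder helpers do (the real loops are parallelized with OpenMP, and the helper name here is hypothetical):

#include <algorithm>

// MXNet stores GRU gate parameters as [reset | update | new]; MKL-DNN
// expects [update | reset | new]. Swapping the first two equal-sized
// blocks in place converts one layout to the other. block_size is I*H or
// H*H for the weight matrices and H for each bias row.
template <typename DType>
static void SwapFirstTwoGates(DType* data, int block_size) {
  std::swap_ranges(data, data + block_size, data + block_size);
}

// e.g. for the fused bias layout used above:
//   SwapFirstTwoGates(b_ptr, H);               // bx row
//   SwapFirstTwoGates(b_ptr + H * ngates, H);  // bh row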
int dtype = in_data[rnn_enum::kData].type_flag_; @@ -937,8 +937,7 @@ class RNNOp { ctx.is_train, param_.mode); } else { - // Before integrating MKLDNN GRU fp32 inference - // using below code for keep func being OK + #endif const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size, direction, param_.mode); @@ -954,57 +953,25 @@ class RNNOp { } DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); RNNForwardInference(work_cpu_space, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); - } - #else - const size_t work_cpu_space_size = - GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); - if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { - Storage::Get()->Free(temp_cpu_space_); - temp_init_space_ = false; - } - if (!temp_init_space_) { - temp_cpu_space_ = Storage::Get()->Alloc - (work_cpu_space_size * sizeof(DType), Context::CPU()); - temp_cpu_space_size_ = work_cpu_space_size; - temp_init_space_ = true; - } - DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); - RNNForwardInference(work_cpu_space, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); - #endif + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); + #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) } + #endif } } From bc66cdf2f86d1514764f57e6858834496c7c3f77 Mon Sep 17 00:00:00 2001 From: Wei Date: Wed, 22 May 2019 16:31:53 +0800 Subject: [PATCH 16/21] Remove MKLDNN GRU --- src/operator/rnn-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 7eda8b612896..9a50baf74e85 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -903,7 +903,7 @@ class RNNOp { param_.mode); } else { #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) - if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { + if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one // will be added to MXNet when we figure out the issue. 
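The rewritten condition above makes the MKL-DNN path the default but opt-out at runtime; a small sketch of the decision, relying on dmlc::GetEnv returning its second argument when the variable is unset (rnn_enum comes from the operator headers):

#include <dmlc/parameter.h>

// MKL-DNN inference runs unless the user exports MXNET_USE_MKLDNN_RNN=0,
// in which case the native CPU kernels in the `else` branch run instead.
// GRU is excluded unconditionally until its precision issue is resolved.
inline bool UseMKLDNNRNN(int mode) {
  return dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) != 0 &&
         mode != rnn_enum::kGru;
}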
int dtype = in_data[rnn_enum::kData].type_flag_; From 02ee00574693965b2d3d860ce6698a80f580e9ad Mon Sep 17 00:00:00 2001 From: Wei Date: Wed, 22 May 2019 17:45:56 +0800 Subject: [PATCH 17/21] Fix bug for rnn forward --- src/operator/rnn-inl.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 9a50baf74e85..6ce17e4cbdfd 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -969,9 +969,10 @@ class RNNOp { hy_ptr, cy_ptr, param_.mode); - #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) + #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) + } + #endif } - #endif } } From 41998f9fdcac6f30e7194152488984f406df1136 Mon Sep 17 00:00:00 2001 From: Wei Date: Thu, 23 May 2019 15:37:22 +0800 Subject: [PATCH 18/21] Remove `__CUDAACC__` --- src/operator/nn/mkldnn/mkldnn_rnn_impl.h | 270 +++++++++++------------ src/operator/rnn-inl.h | 50 +++-- src/operator/rnn.cc | 4 +- 3 files changed, 163 insertions(+), 161 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index 8532c379ffc8..ea8e07ea617c 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -39,9 +39,9 @@ namespace mxnet { namespace op { -algorithm GetMKLDNNRNNAlgo(int mode, - int* ngates, - int* nstates) { +static algorithm GetMKLDNNRNNAlgo(int mode, + int* ngates, + int* nstates) { algorithm algo = algorithm::vanilla_rnn; switch (mode) { case rnn_enum::kLstm: @@ -67,14 +67,14 @@ algorithm GetMKLDNNRNNAlgo(int mode, return algo; } -void ConcatData(mkldnn::memory::format src_format, - mkldnn::memory::format dst_format, - std::vector srcs_cds, - mkldnn::memory::dims dst_cds, - mkldnn::memory::data_type mkldnn_dtype, - int concat_dimension, - std::vector srcs_data, - const mkldnn::memory &dst) { +static void ConcatData(mkldnn::memory::format src_format, + mkldnn::memory::format dst_format, + std::vector srcs_cds, + mkldnn::memory::dims dst_cds, + mkldnn::memory::data_type mkldnn_dtype, + int concat_dimension, + std::vector srcs_data, + const mkldnn::memory &dst) { auto cpu_engine = CpuEngine::Get()->get_engine(); std::vector srcs_pd; std::vector srcs; @@ -102,7 +102,7 @@ void ConcatData(mkldnn::memory::format src_format, // for unidirectional, it will fused as dim like 1 + (L - 1) when I != H. 
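Marking the header-defined helpers static appears to be what lets this patch drop the !defined(__CUDACC__) guards: once mkldnn_rnn_impl.h is reachable from both the .cc and the nvcc-compiled translation units, every function defined in it needs internal linkage to avoid duplicate-symbol link errors. A minimal illustration of the failure mode (file and function names hypothetical):

// helper.h -- sketch only
#ifndef HELPER_H_
#define HELPER_H_
// Without `static` (or `inline`), this body is emitted as an external
// definition in every translation unit that includes the header, and
// linking two such units together fails with "multiple definition of
// NumGates". Internal linkage gives each unit its own private copy.
static int NumGates(int mode) { return mode == 2 ? 4 : (mode == 3 ? 3 : 1); }
#endif  // HELPER_H_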
// for bidirectional, it will fused as data + back_data (weight, bias, iter etc), // also need to identify first layer and next layers -inline size_t GetMKLDNNRNNCacheMemorySize(int L, +static size_t GetMKLDNNRNNCacheMemorySize(int L, int D, int T, int N, @@ -135,9 +135,9 @@ inline size_t GetMKLDNNRNNCacheMemorySize(int L, } template -void AdjustGruWeightGateOrder(DType* weight, - const int I, - const int H) { +static void AdjustGruWeightGateOrder(DType* weight, + const int I, + const int H) { // mxnet gru gate order is reset, update and new gates // mkldnn gru gate order is update, reset and new gates const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); @@ -152,8 +152,8 @@ void AdjustGruWeightGateOrder(DType* weight, } template -void AdjustGruBiasGateOrder(DType* bias, - const int H) { +static void AdjustGruBiasGateOrder(DType* bias, + const int H) { // mxnet gru gate order is reset, update and new gates // mkldnn gru gate order is update, reset and new gates const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); @@ -171,36 +171,36 @@ void AdjustGruBiasGateOrder(DType* bias, // unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) template -void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - mkldnn::memory *user_src_layer_memory, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - int layer_index, - bool *has_cache, - int lvalue, - int dtype, - bool is_train, - int mode) { +static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + mkldnn::memory *user_src_layer_memory, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + int layer_index, + bool *has_cache, + int lvalue, + int dtype, + bool is_train, + int mode) { int ngates = 0, nstates = 0; algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); @@ -369,36 +369,36 @@ void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, template -void MKLDNNRNNForwardUnidi(bool state_outputs, - const int L, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - mkldnn::memory *user_src_layer_memory, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - int layer_index, - bool *has_cache, - int dtype, - bool is_train, - int mode) { +static void 
MKLDNNRNNForwardUnidi(bool state_outputs, + const int L, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + mkldnn::memory *user_src_layer_memory, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + int layer_index, + bool *has_cache, + int dtype, + bool is_train, + int mode) { int ngates = 0, nstates = 0; algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); @@ -576,35 +576,35 @@ void MKLDNNRNNForwardUnidi(bool state_outputs, } template -void MKLDNNRNNForward(bool state_outputs, - const int L, - const int D, - const int T, - const int N, - const int I, - const int H, - DType* x_ptr, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - bool *has_cache, - int dtype, - bool is_train, - int mode) { +static void MKLDNNRNNForward(bool state_outputs, + const int L, + const int D, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + bool *has_cache, + int dtype, + bool is_train, + int mode) { int ngates = 0, nstates = 0; GetMKLDNNRNNAlgo(mode, &ngates, &nstates); const int b_size = 2 * H * ngates * D; @@ -686,35 +686,35 @@ void MKLDNNRNNForward(bool state_outputs, } template -void MKLDNNRNNForwardInference(bool state_outputs, - const int num_layers, - const int direction, - const int seq_length, - const int batch_size, - const int input_size, - const int state_size, - DType* x_ptr, - DType* hx_ptr, - DType* cx_ptr, - DType* w_ptr, - DType* b_ptr, - DType* y_ptr, - DType* hy_ptr, - DType* cy_ptr, - std::vector* concat_weight_memory, - std::vector* concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, - std::vector *rnn_forward_prim, - bool *has_cache, - int dtype, - bool is_train, - int mode) { +static void MKLDNNRNNForwardInference(bool state_outputs, + const int num_layers, + const int direction, + const int seq_length, + const int batch_size, + const int input_size, + const int state_size, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector* concat_weight_memory, + std::vector* concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + 
std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + bool *has_cache, + int dtype, + bool is_train, + int mode) { switch (mode) { case rnn_enum::kLstm: case rnn_enum::kGru: diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 6ce17e4cbdfd..9681a47c82d6 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -44,7 +44,7 @@ #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" -#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) +#if MXNET_USE_MKLDNN == 1 #include "./nn/mkldnn/mkldnn_rnn_impl.h" #endif @@ -426,8 +426,8 @@ class RNNOp { // No tests in place for fp16 RNNs, so leave TensorCore disabled for now. cudnn_tensor_core_ = false; // When fp16 RNN tests are introduced, we can enable TensorCore as follows: -// cudnn_tensor_core = -// mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); + // cudnn_tensor_core = + // mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); // Defaults input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet // RNN Mode @@ -938,12 +938,14 @@ class RNNOp { param_.mode); } else { #endif + // Until MKLDNN GRU fp32 inference is integrated, fall back to + // the original code path below to keep the operator functional const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); + param_.state_size, direction, param_.mode); if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { - Storage::Get()->Free(temp_cpu_space_); - temp_init_space_ = false; + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; } if (!temp_init_space_) { temp_cpu_space_ = Storage::Get()->Alloc @@ -953,22 +955,22 @@ class RNNOp { } DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); RNNForwardInference(work_cpu_space, - param_.state_outputs, - param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) } #endif @@ -1610,7 +1612,7 @@ void RNNStatefulCompute(const OpStatePtr& state, }); } -#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) +#if MXNET_USE_MKLDNN == 1 static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { @@ -2019,8 +2021,8 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, }); }); } - #endif + /* index description 0: x diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 142c77404b41..90d8da6f84b0 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -174,7 +174,7 @@ inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { DispatchMode wanted_mode = DispatchMode::kFCompute; - #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) + #if MXNET_USE_MKLDNN == 1 wanted_mode = DispatchMode::kFComputeEx; #endif @@ -287,7 +287,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr("FInferStorageType", RNNStorageType) .set_attr("FCreateOpState", CreateRNNState) .set_attr("FStatefulCompute", RNNStatefulCompute) -#if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) +#if MXNET_USE_MKLDNN
== 1 .set_attr("TIsMKLDNN", true) .set_attr("FStatefulComputeEx", RNNStatefulComputeCPU) #endif From f165f400d8f43608c2b84ed33d497441118e7c0c Mon Sep 17 00:00:00 2001 From: Wei Date: Thu, 23 May 2019 16:19:10 +0800 Subject: [PATCH 19/21] Move `RNNStatefulComputeCPU` to rnn.cc --- src/operator/rnn-inl.h | 411 ----------------------------------------- src/operator/rnn.cc | 411 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 411 insertions(+), 411 deletions(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 9681a47c82d6..64397e9aab31 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -1612,417 +1612,6 @@ void RNNStatefulCompute(const OpStatePtr& state, }); } -#if MXNET_USE_MKLDNN == 1 -static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - std::vector in_blobs; - std::vector out_blobs; - std::vector temp_ndarrays_i; - std::vector temp_ndarrays_o; - for (const NDArray& in : inputs) { - if (in.storage_type() == kDefaultStorage) { - temp_ndarrays_i.push_back(in.Reorder2Default()); - in_blobs.emplace_back(temp_ndarrays_i.back().data()); - } else { - in_blobs.emplace_back(in.data()); - } - } - - for (const NDArray& out : outputs) { - if (out.storage_type() == kDefaultStorage) { - temp_ndarrays_o.push_back(out.Reorder2Default()); - out_blobs.emplace_back(temp_ndarrays_o.back().data()); - } else { - out_blobs.emplace_back(out.data()); - } - } - int dtype = in_blobs[rnn_enum::kData].type_flag_; - int itype = in_blobs[inputs.size()-1].type_flag_; - mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - Stream *s = ctx.get_stream(); - auto cpu_engine = CpuEngine::Get()->get_engine(); - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - MSHADOW_TYPE_SWITCH(itype, IType, { - RNNOp& op = state_ptr.get_state>(); - const RNNParam& param = op.param_; - int ngates = 0, nstates = 0; - GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); - int D = param.bidirectional ? 
2 : 1; - Tensor x = in_blobs[rnn_enum::kData].get(s); - int T = x.shape_[0]; - int N = x.shape_[1]; - int I = x.shape_[2]; - int H = param.state_size; - int L = param.num_layers; - - const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); - if (op.init_mem_ && op.reserve_mem_size_ < r_size) { - Storage::Get()->Free(op.mem_space_); - op.init_mem_ = false; - } - if (!op.init_mem_) { - op.mem_space_ = Storage::Get()->Alloc( - r_size * sizeof(DType), - Context::CPU()); - op.reserve_mem_size_ = r_size; - op.init_mem_ = true; - op.has_cache = false; - } - if (op.has_cache && op.x_memory.size() == 0) { - op.has_cache = false; - } - - DType* workptr = static_cast(op.mem_space_.dptr); - mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; - mkldnn::memory::dims src_layer_tz = {T, N, D * H}; - mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; - auto dst_layer_md = mkldnn::memory::desc( - { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - if (op.x_memory.size() == 0) { - if (D == 1 && I == H) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_n); - - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - DType* weight_layer_n = workptr; // L * I * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + L * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = 
src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // L * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - workptr = dst_iter_n + L * nstates * N * H; - - } else { - auto user_src_layer_md_0 = mkldnn::memory::desc( - { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_0); - - mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; - auto user_weight_layer_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_0 = workptr; // D * I * ngates * H - auto user_weight_layer_memory_0 - = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); - op.wx_memory.push_back(user_weight_layer_memory_0); - - DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_0 - = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); - op.wh_memory.push_back(user_weight_iter_memory_0); - - DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_0 = - mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); - op.bias_memory.push_back(user_bias_memory_0); - workptr = bias_0 + D * ngates * H; - - auto wx_md_0 = mkldnn::memory::desc( - { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wx_memory_0 = - mkldnn::memory({ wx_md_0, cpu_engine }); - auto wh_md_0 = mkldnn::memory::desc( - { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_0 = - mkldnn::memory({ wh_md_0, cpu_engine }); - if (D == 2) { - DType* wx_0 = workptr; // D * ngates * I * H - wx_memory_0.set_data_handle(wx_0); - DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H - wh_memory_0.set_data_handle(wh_0); - workptr = wh_0 + D * ngates * H * H; - } - op.concat_weight_memory.push_back(wx_memory_0); - op.concat_weight_memory.push_back(wh_memory_0); - - mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md_0 = mkldnn::memory::desc( - { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* 
src_iter_undi_0 = workptr; // nstates * N * H - auto src_iter_undi_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); - workptr = src_iter_undi_0 + nstates * N * H; - if (D == 1) { - op.hcx_memory.push_back(src_iter_undi_memory_0); - } else { - DType* src_iter_undi2_0 = workptr; // nstates * N * H - auto src_iter_undi2_memory_0 = - mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); - op.concat_iter_memory.push_back(src_iter_undi2_memory_0); - - mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md_0 = mkldnn::memory::desc( - { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory_0 = - mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); - op.concat_iter_memory.push_back(src_iter_memory_0); - op.hcx_memory.push_back(src_iter_memory_0); - workptr = src_iter_0 + D * nstates * N * H; - } - - DType* dst_layer_0 = workptr; // T * N * D * H - auto dst_layer_memory_0 - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); - op.y_memory.push_back(dst_layer_memory_0); - - mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_0 = mkldnn::memory::desc( - { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_0 = - mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); - op.hcy_memory.push_back(dst_iter_memory_0); - workptr = dst_iter_0 + D * nstates * N * H; - - // next L - 1 layers - if (L > 1 && D == 1) { - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + - (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * 
ngates * H * H - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - workptr = wh_n + (L - 1) * ngates * H * H; - - mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n1 = mkldnn::memory::desc( - { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L - 1; l++) { - DType* src_iter_n1 = workptr; // nstates * N * H - auto src_iter_memory_n1 = - mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); - workptr = src_iter_n1 + nstates * N * H; - } - mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc - auto src_iter_md_n = mkldnn::memory::desc( - { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_n = workptr; // (L - 1) * nstates * N * H - auto src_iter_memory_n = - mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); - - DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - } - - if (L > 1 && D == 2) { - mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {1, D, ngates, H}; - auto user_weight_layer_md = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_weight_iter_md = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); - auto user_bias_md = mkldnn::memory::desc({ bias_tz }, - mkldnn_dtype, mkldnn::memory::format::ldgo); - - auto user_src_layer_md = mkldnn::memory::desc( - { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory); - - auto wx_md_n = mkldnn::memory::desc( - { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - auto wh_md_n = mkldnn::memory::desc( - { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - - for (int l = 0; l < L; l++) { - DType* weight_layer_n = workptr; // D * (H * D) * ngates * H - auto user_weight_layer_memory_n - = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); - - DType* weight_iter_n = weight_layer_n + - D * (H * D) * ngates * H; // D * H * ngates * H - auto user_weight_iter_memory_n - = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - op.wh_memory.push_back(user_weight_iter_memory_n); - - DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H - auto user_bias_memory_n = - mkldnn::memory({ user_bias_md, 
cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); - workptr = bias_n + D * ngates * H; - } - - DType* wx_n = workptr; // D * ngates * (D * H) * H - DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H - auto wx_memory_n = - mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); - auto wh_memory_n = - mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); - - mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc - auto src_iter_undi_md = mkldnn::memory::desc( - { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H - auto src_iter_undi_memory = - mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); - - DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H - auto src_iter_undi2_memory = - mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2); - op.concat_iter_memory.push_back(src_iter_undi2_memory); - - mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc - auto src_iter_md = mkldnn::memory::desc( - { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H - auto src_iter_memory = - mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); - op.concat_iter_memory.push_back(src_iter_memory); - op.hcx_memory.push_back(src_iter_memory); - - DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H - auto dst_layer_memory_n - = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); - - mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc - auto dst_iter_md_n = mkldnn::memory::desc( - { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); - DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H - auto dst_iter_memory_n = - mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); - } - } - } - op.Forward(ctx, in_blobs, req, out_blobs); - }); - }); -} -#endif - /* index description 0: x diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 90d8da6f84b0..e9eddcd27a41 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -206,6 +206,417 @@ struct RNNGrad { } }; +#if MXNET_USE_MKLDNN == 1 +static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + std::vector in_blobs; + std::vector out_blobs; + std::vector temp_ndarrays_i; + std::vector temp_ndarrays_o; + for (const NDArray& in : inputs) { + if (in.storage_type() == kDefaultStorage) { + temp_ndarrays_i.push_back(in.Reorder2Default()); + in_blobs.emplace_back(temp_ndarrays_i.back().data()); + } else { + in_blobs.emplace_back(in.data()); + } + } + + for (const NDArray& out : outputs) { + if (out.storage_type() == kDefaultStorage) { + temp_ndarrays_o.push_back(out.Reorder2Default()); + out_blobs.emplace_back(temp_ndarrays_o.back().data()); + } else { + out_blobs.emplace_back(out.data()); + } + } + int dtype = in_blobs[rnn_enum::kData].type_flag_; + int itype = in_blobs[inputs.size()-1].type_flag_; + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + Stream *s = ctx.get_stream(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + 
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state_ptr.get_state>(); + const RNNParam& param = op.param_; + int ngates = 0, nstates = 0; + GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); + int D = param.bidirectional ? 2 : 1; + Tensor x = in_blobs[rnn_enum::kData].get(s); + int T = x.shape_[0]; + int N = x.shape_[1]; + int I = x.shape_[2]; + int H = param.state_size; + int L = param.num_layers; + + const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); + if (op.init_mem_ && op.reserve_mem_size_ < r_size) { + Storage::Get()->Free(op.mem_space_); + op.init_mem_ = false; + } + if (!op.init_mem_) { + op.mem_space_ = Storage::Get()->Alloc( + r_size * sizeof(DType), + Context::CPU()); + op.reserve_mem_size_ = r_size; + op.init_mem_ = true; + op.has_cache = false; + } + if (op.has_cache && op.x_memory.size() == 0) { + op.has_cache = false; + } + + DType* workptr = static_cast(op.mem_space_.dptr); + mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; + mkldnn::memory::dims src_layer_tz = {T, N, D * H}; + mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; + auto dst_layer_md = mkldnn::memory::desc( + { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + if (op.x_memory.size() == 0) { + if (D == 1 && I == H) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_n); + + mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + DType* weight_layer_n = workptr; // L * I * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + L * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, 
mkldnn::memory::format::ldsnc); + for (int l = 0; l < L; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // L * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + workptr = dst_iter_n + L * nstates * N * H; + + } else { + auto user_src_layer_md_0 = mkldnn::memory::desc( + { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_0); + + mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; + auto user_weight_layer_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_0 = workptr; // D * I * ngates * H + auto user_weight_layer_memory_0 + = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); + op.wx_memory.push_back(user_weight_layer_memory_0); + + DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_0 + = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); + op.wh_memory.push_back(user_weight_iter_memory_0); + + DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_0 = + mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); + op.bias_memory.push_back(user_bias_memory_0); + workptr = bias_0 + D * ngates * H; + + auto wx_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wx_memory_0 = + mkldnn::memory({ wx_md_0, cpu_engine }); + auto wh_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_0 = + mkldnn::memory({ wh_md_0, cpu_engine }); + if (D == 2) { + DType* wx_0 = workptr; // D * ngates * I * H + wx_memory_0.set_data_handle(wx_0); + DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H + wh_memory_0.set_data_handle(wh_0); + workptr = wh_0 + D * ngates * H * H; + } + 
op.concat_weight_memory.push_back(wx_memory_0); + op.concat_weight_memory.push_back(wh_memory_0); + + mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md_0 = mkldnn::memory::desc( + { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi_0 = workptr; // nstates * N * H + auto src_iter_undi_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + workptr = src_iter_undi_0 + nstates * N * H; + if (D == 1) { + op.hcx_memory.push_back(src_iter_undi_memory_0); + } else { + DType* src_iter_undi2_0 = workptr; // nstates * N * H + auto src_iter_undi2_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); + op.concat_iter_memory.push_back(src_iter_undi2_memory_0); + + mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md_0 = mkldnn::memory::desc( + { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory_0 = + mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); + op.concat_iter_memory.push_back(src_iter_memory_0); + op.hcx_memory.push_back(src_iter_memory_0); + workptr = src_iter_0 + D * nstates * N * H; + } + + DType* dst_layer_0 = workptr; // T * N * D * H + auto dst_layer_memory_0 + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); + op.y_memory.push_back(dst_layer_memory_0); + + mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_0 = mkldnn::memory::desc( + { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_0 = + mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); + op.hcy_memory.push_back(dst_iter_memory_0); + workptr = dst_iter_0 + D * nstates * N * H; + + // next L - 1 layers + if (L > 1 && D == 1) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + + auto 
wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + (L - 1) * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + for (int l = 0; l < L - 1; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // (L - 1) * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + + DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } + + if (L > 1 && D == 2) { + mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {1, D, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + + for (int l = 0; l < L; l++) { + DType* weight_layer_n = workptr; // D * (H * D) * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + D * (H * D) * ngates * H; // D * H * 
ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + workptr = bias_n + D * ngates * H; + } + + DType* wx_n = workptr; // D * ngates * (D * H) * H + DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + + mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md = mkldnn::memory::desc( + { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H + auto src_iter_undi_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); + op.concat_iter_memory.push_back(src_iter_undi_memory); + + DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H + auto src_iter_undi2_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2); + op.concat_iter_memory.push_back(src_iter_undi2_memory); + + mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md = mkldnn::memory::desc( + { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory = + mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); + op.concat_iter_memory.push_back(src_iter_memory); + op.hcx_memory.push_back(src_iter_memory); + + DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } + } + } + op.Forward(ctx, in_blobs, req, out_blobs); + }); + }); +} +#endif + NNVM_REGISTER_OP(RNN) .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support. From 5e8c20e996dd08493ed3894e0becf2de24a89183 Mon Sep 17 00:00:00 2001 From: Wei Date: Thu, 23 May 2019 22:44:50 +0800 Subject: [PATCH 20/21] Remove redundant `__CUDACC__` macro --- src/operator/rnn-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 64397e9aab31..56fe7a3b3829 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -902,7 +902,7 @@ class RNNOp { param_.p, param_.mode); } else { - #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) + #if MXNET_USE_MKLDNN == 1 if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one // will be added to MXNet when we figure out the issue.
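The hunk above also shows the runtime side of the dispatch: even in an MKLDNN build, the MKLDNN path runs only when the MXNET_USE_MKLDNN_RNN environment variable (read through dmlc::GetEnv with a default of 1) is enabled, and GRU is always routed to the native implementation until its precision issue is resolved. A minimal standalone sketch of that two-level dispatch pattern, with illustrative names (ForwardDispatch, MKLDNNRNNEnabled) standing in for the real operator plumbing:

#include <cstdlib>
#include <cstring>

enum RnnMode { kRnnRelu, kRnnTanh, kLstm, kGru };  // stand-in for rnn_enum

// Stand-in for dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1): on unless set to "0".
static bool MKLDNNRNNEnabled() {
  const char* val = std::getenv("MXNET_USE_MKLDNN_RNN");
  return val == nullptr || std::strcmp(val, "0") != 0;
}

void ForwardDispatch(RnnMode mode) {
#if MXNET_USE_MKLDNN == 1
  // Compile-time guard, runtime opt-out, and the GRU exclusion noted above.
  if (MKLDNNRNNEnabled() && mode != kGru) {
    // MKLDNNRNNForwardInference(...) would run here.
    return;
  }
#endif
  // RNNForwardInference(...): the native CPU fallback path.
}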
From 41772c97375a10e89a0faca7154347e3a3b7cd10 Mon Sep 17 00:00:00 2001 From: Wei Date: Thu, 23 May 2019 23:12:06 +0800 Subject: [PATCH 21/21] Remove the last `__CUDACC__` macro from rnn* --- src/operator/rnn-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 56fe7a3b3829..9785be209b7d 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -971,7 +971,7 @@ class RNNOp { hy_ptr, cy_ptr, param_.mode); - #if MXNET_USE_MKLDNN == 1 && !defined(__CUDACC__) + #if MXNET_USE_MKLDNN == 1 } #endif }
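A detail worth noting about these last two patches: the guards they touch form a matched pair around a single else branch in RNNOp::Forward. The opening #if (changed in PATCH 20) admits the MKLDNN dispatch and the `} else {` token, while the closing #if (changed in PATCH 21) supplies the matching `}`; if the two conditions ever diverged, some build configuration would compile unbalanced braces. Condensed from the hunks above into a schematic (not the literal surrounding code):

#if MXNET_USE_MKLDNN == 1
    if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) {
      // ... MKLDNN inference path ...
    } else {
#endif
      // ... native path: the only branch compiled when MKLDNN is off ...
      RNNForwardInference(/* ... */);
#if MXNET_USE_MKLDNN == 1
    }
#endif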