diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h new file mode 100644 index 000000000000..ea8e07ea617c --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -0,0 +1,740 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ +#if MXNET_USE_MKLDNN == 1 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../../math_functions-inl.h" +#include "../../operator_common.h" +#include "../../rnn_impl.h" +#include "../../rnn-inl.h" +#include "mkldnn.hpp" +#include "./mkldnn_base-inl.h" + +namespace mxnet { +namespace op { + +static algorithm GetMKLDNNRNNAlgo(int mode, + int* ngates, + int* nstates) { + algorithm algo = algorithm::vanilla_rnn; + switch (mode) { + case rnn_enum::kLstm: + *ngates = 4; + *nstates = 2; + algo = algorithm::vanilla_lstm; + break; + case rnn_enum::kGru: + *ngates = 3; + *nstates = 1; + algo = algorithm::vanilla_gru; + break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + *ngates = 1; + *nstates = 1; + algo = algorithm::vanilla_rnn; + break; + default: + LOG(FATAL) << "unsupported RNN mode:" << mode; + break; + } + return algo; +} + +static void ConcatData(mkldnn::memory::format src_format, + mkldnn::memory::format dst_format, + std::vector srcs_cds, + mkldnn::memory::dims dst_cds, + mkldnn::memory::data_type mkldnn_dtype, + int concat_dimension, + std::vector srcs_data, + const mkldnn::memory &dst) { + auto cpu_engine = CpuEngine::Get()->get_engine(); + std::vector srcs_pd; + std::vector srcs; + for (size_t i = 0; i < srcs_cds.size(); i++) { + auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format); + auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine); + auto src_memory = mkldnn::memory(mpd, srcs_data[i]); + srcs_pd.push_back(mpd); + srcs.push_back(src_memory); + } + std::vector inputs; + for (size_t i = 0; i < srcs_cds.size(); i++) { + inputs.push_back(srcs[i]); + } + auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format); + auto concat_pd = concat::primitive_desc(dst_desc, concat_dimension, srcs_pd); + MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst)); + MKLDNNStream::Get()->Submit(); +} + +// cached mkldnn memory +// first layer wx, wh with next L - 1 layers wx and wh +// with L layers hx and cx, src and dst data/iter etc. +// it will prepare memory on before and after reorder and concat. +// for unidirectional, it will fused as dim like 1 + (L - 1) when I != H. 
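+// e.g. a 3-layer unidirectional LSTM with I != H caches one fused weight block for
+// layer 0 (wx: I x 4H, wh: H x 4H) and one fused block for layers 1..L-1 (each H x 4H).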
+// for bidirectional, it will fused as data + back_data (weight, bias, iter etc), +// also need to identify first layer and next layers +static size_t GetMKLDNNRNNCacheMemorySize(int L, + int D, + int T, + int N, + int I, + int H, + int mode) { + size_t size = 0; + switch (mode) { + case rnn_enum::kLstm: + size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H + + L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H + + 6 * D * (I + H + 2) * 4 * H + T * N * I * 2; + break; + case rnn_enum::kGru: + size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H + + L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H + + 6 * D * (I + H + 2) * 3 * H + T * N * I * 2; + break; + case rnn_enum::kRnnRelu: + case rnn_enum::kRnnTanh: + size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H + + L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H + + 6 * D * (I + H + 2) * 1 * H + T * N * I * 2; + break; + default: + LOG(FATAL) << "unknown RNN mode " << mode; + break; + } + return size; +} + +template +static void AdjustGruWeightGateOrder(DType* weight, + const int I, + const int H) { + // mxnet gru gate order is reset, update and new gates + // mkldnn gru gate order is update, reset and new gates + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + DType* weight_reset = weight; + DType* weight_update = weight + I * H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < I * H; i++) { + DType tmp = weight_update[i]; + weight_update[i] = weight_reset[i]; + weight_reset[i] = tmp; + } +} + +template +static void AdjustGruBiasGateOrder(DType* bias, + const int H) { + // mxnet gru gate order is reset, update and new gates + // mkldnn gru gate order is update, reset and new gates + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + DType* bias_reset = bias; + DType* bias_update = bias + H; + #pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < H; i++) { + DType tmp = bias_update[i]; + bias_update[i] = bias_reset[i]; + bias_reset[i] = tmp; + } +} +// since there is different sematics of MKLDNN's Fused RNN and MXNet FusedRNN, +// bidirectional will be fused layer by layer, +// unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) + +template +static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + mkldnn::memory *user_src_layer_memory, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + int layer_index, + bool *has_cache, + int lvalue, + int dtype, + bool is_train, + int mode) { + int ngates = 0, nstates = 0; + algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + const int single_cell_size = N * H; + const int single_b_size = ngates * H; + DType* wx = w_ptr; // ngates * H, I + DType* wh = w_ptr + I * H * ngates; // ngates * H, H + DType* back_wx = w_ptr + ngates * H * (I + H); + DType* back_wh = back_wx + I * H * ngates; + DType* bx = 
b_ptr; + DType* bh = b_ptr + H * ngates; + DType* back_bx = b_ptr + single_b_size * 2; + DType* back_bh = back_bx + H * ngates; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto null_memory_ = null_memory(cpu_engine); + int offset1 = 0, offset2 = 0; + bool initialized = *has_cache; + mkldnn::memory::dims src_layer_tz = {T, N, I}; + mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H}; + mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder + mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder + mkldnn::memory::dims bias_tz = {1, 2, ngates, H}; + mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc + mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc + + if (!initialized) { + if (mode == rnn_enum::kGru) { + AdjustGruWeightGateOrder(wx, I, H); + AdjustGruWeightGateOrder(back_wx, I, H); + AdjustGruWeightGateOrder(wh, H, H); + AdjustGruWeightGateOrder(back_wh, H, H); + AdjustGruBiasGateOrder(bx, H); + AdjustGruBiasGateOrder(back_bx, H); + AdjustGruBiasGateOrder(bh, H); + AdjustGruBiasGateOrder(back_bh, H); + } + auto src_wx = (*concat_weight_memory)[2 * layer_index]; + auto src_wh = (*concat_weight_memory)[2 * layer_index + 1]; + std::vector srcs_data1; + srcs_data1.push_back(wx); + srcs_data1.push_back(back_wx); + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz, + mkldnn_dtype, 1, srcs_data1, src_wx); + srcs_data1.clear(); + srcs_data1.push_back(wh); + srcs_data1.push_back(back_wh); + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz, + mkldnn_dtype, 1, srcs_data1, src_wh); + int tmpvalue = 0; + if (lvalue > 0) { + tmpvalue = lvalue + 1; + } + MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue])); + MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue])); + + DType* user_bias = reinterpret_cast + ((*bias_memory)[tmpvalue].get_data_handle()); + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < single_b_size; j++) { + user_bias[j] = bx[j] + bh[j]; + user_bias[single_b_size + j] = back_bx[j] + back_bh[j]; + } + } + if (lvalue > 0) { + (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle()); + (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle()); + (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle()); + } + + auto src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto dst_layer_md = mkldnn::memory::desc( + { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto dst_iter_md = mkldnn::memory::desc( + { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto src_iter_md = mkldnn::memory::desc( + {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto bias_md = mkldnn::memory::desc({bias_tz}, + mkldnn_dtype, 
mkldnn::memory::format::ldgo); + + auto user_src_iter_memory = (*concat_iter_memory)[2]; + if (mode == rnn_enum::kLstm) { + std::vector srcs_data1; + srcs_data1.push_back(hx_ptr); + srcs_data1.push_back(cx_ptr); + auto tmp1_src_iter_memory = (*concat_iter_memory)[0]; + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, + srcs_data1, tmp1_src_iter_memory); + std::vector srcs_data2; + srcs_data2.push_back(hx_ptr + single_cell_size); + srcs_data2.push_back(cx_ptr + single_cell_size); + auto tmp2_src_iter_memory = (*concat_iter_memory)[1]; + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, + srcs_data2, tmp2_src_iter_memory); + std::vector srcs_data3; + srcs_data3.push_back(reinterpret_cast(tmp1_src_iter_memory.get_data_handle())); + srcs_data3.push_back(reinterpret_cast(tmp2_src_iter_memory.get_data_handle())); + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H}, + mkldnn_dtype, 1, srcs_data3, user_src_iter_memory); + } else { + user_src_iter_memory.set_data_handle(hx_ptr); + } + (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); + + rnn_cell::desc rnn_cell(nalgorithm, + mode == rnn_enum::kRnnRelu ? algorithm::eltwise_relu : algorithm::eltwise_tanh); + + rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, + rnn_direction::bidirectional_concat, src_layer_md, + src_iter_md, weight_layer_md, weight_iter_md, + bias_md, dst_layer_md, dst_iter_md); + + auto prim_desc + = rnn_forward::primitive_desc(layer_desc, cpu_engine); + + if (x_ptr && layer_index == 0) { + (*x_memory)[layer_index].set_data_handle(x_ptr); + } else { + (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); + } + (*y_memory)[layer_index].set_data_handle(y_ptr); + + if (rnn_forward_prim->size() <= (size_t)layer_index) { + primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], + (*hcx_memory)[layer_index], (*wx_memory)[layer_index], + (*wh_memory)[layer_index], (*bias_memory)[layer_index], + (*y_memory)[layer_index], + (*hcy_memory)[layer_index], null_memory_); + rnn_forward_prim->push_back(rnn_prim); + } + MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); + MKLDNNStream::Get()->Submit(); + + if (state_outputs) { + DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); + if (mode == rnn_enum::kLstm) { + offset1 = nstates * single_cell_size; + offset2 = (nstates + 1) * single_cell_size; + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < single_cell_size; n++) { + hy_ptr[n] = dst_hcy[n]; + hy_ptr[n + single_cell_size] = dst_hcy[n + offset1]; + cy_ptr[n] = dst_hcy[n + single_cell_size]; + cy_ptr[n + single_cell_size] = dst_hcy[n + offset2]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < 2 * single_cell_size; n++) { + hy_ptr[n] = dst_hcy[n]; + } + } + } +} + + +template +static void MKLDNNRNNForwardUnidi(bool state_outputs, + const int L, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + mkldnn::memory *user_src_layer_memory, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + 
std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + int layer_index, + bool *has_cache, + int dtype, + bool is_train, + int mode) { + int ngates = 0, nstates = 0; + algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + const int cell_size = N * H; + const int single_cell_size = N * H; + const int single_b_size = ngates * H; + int w_size = (I + H) * H * ngates; + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto null_memory_ = null_memory(cpu_engine); + int offset1 = 0, offset2 = 0; + bool initialized = *has_cache; + + mkldnn::memory::dims src_layer_tz = {T, N, I}; + mkldnn::memory::dims dst_layer_tz = {T, N, H}; + mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; + mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H}; // ldsnc + mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H}; // ldsnc + mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder + mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder + + auto weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto dst_layer_md = mkldnn::memory::desc( + {dst_layer_tz}, mkldnn_dtype, mkldnn::memory::format::tnc); + auto src_iter_md = mkldnn::memory::desc( + {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto bias_md = mkldnn::memory::desc({bias_tz}, + mkldnn_dtype, mkldnn::memory::format::ldgo); + auto dst_iter_md = mkldnn::memory::desc( + {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); + + for (int l = 0; l < L; l++) { + if (mode == rnn_enum::kLstm) { + std::vector srcs_data; + srcs_data.push_back(hx_ptr); + srcs_data.push_back(cx_ptr); + auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index]; + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, + {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, + 2, srcs_data, tmp_src_iter_memory); + } else { + (*concat_iter_memory)[l + layer_index].set_data_handle(hx_ptr); + } + hx_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cx_ptr += cell_size; + } + } + + auto user_src_iter_memory = null_memory_; + if (L == 1) { + user_src_iter_memory = (*concat_iter_memory)[layer_index]; + } else { + user_src_iter_memory = (*concat_iter_memory)[L + layer_index]; + std::vector src_l_data; + std::vector src_l_dim; + for (int l = 0; l < L; l++) { + src_l_data.push_back(reinterpret_cast + ((*concat_iter_memory)[l + layer_index].get_data_handle())); + src_l_dim.push_back({1, 1, nstates, N, H}); + } + ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim, + {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory); + } + (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); + + auto src_wx_f = (*concat_weight_memory)[2 * layer_index]; 
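+  // (*concat_weight_memory)[2 * layer_index] and [2 * layer_index + 1] are ldgoi-format
+  // staging buffers: the native wx/wh blobs are concatenated into them first and then
+  // reordered into the ldigo wx_memory/wh_memory entries consumed by the RNN primitive.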
+ auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1]; + + std::vector srcs_data_x; + std::vector srcs_data_h; + std::vector src_l_dim_x; + std::vector src_l_dim_h; + if (!initialized) { + if (L == 1) { + DType* wx = w_ptr; + DType* wh = w_ptr + I * H * ngates; + if (mode == rnn_enum::kGru) { + AdjustGruWeightGateOrder(wx, I, H); + AdjustGruWeightGateOrder(wh, H, H); + AdjustGruBiasGateOrder(b_ptr, H); + AdjustGruBiasGateOrder(b_ptr + H * ngates, H); + } + src_wx_f.set_data_handle(wx); + src_wh_f.set_data_handle(wh); + } else { + for (int l = 0; l < L; l++) { + DType* wx = w_ptr; + DType* wh = w_ptr + I * H * ngates; + DType* bx = b_ptr + l * ngates * H * 2; + DType* bh = b_ptr + l * ngates * H * 2 + H * ngates; + if (mode == rnn_enum::kGru) { + AdjustGruWeightGateOrder(wx, I, H); + AdjustGruWeightGateOrder(wh, H, H); + AdjustGruBiasGateOrder(bx, H); + AdjustGruBiasGateOrder(bh, H); + } + srcs_data_x.push_back(wx); + srcs_data_h.push_back(wh); + src_l_dim_x.push_back(weights_layer_r_tz); + src_l_dim_h.push_back(weights_iter_r_tz); + w_ptr = w_ptr + w_size; + } + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f); + ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, + src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f); + } + MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index])); + MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index])); + + DType* user_bias_f = reinterpret_cast ((*bias_memory)[layer_index].get_data_handle()); + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < L * single_b_size; j++) { + int k = j / single_b_size; + user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size]; + } + } + + rnn_cell::desc rnn_cell(nalgorithm, + mode == rnn_enum::kRnnRelu ? 
algorithm::eltwise_relu : algorithm::eltwise_tanh); + + rnn_forward::desc layer_desc(prop_kind::forward_inference, rnn_cell, + rnn_direction::unidirectional, src_layer_md, + src_iter_md, weight_layer_md, weight_iter_md, + bias_md, dst_layer_md, dst_iter_md); + + auto prim_desc + = rnn_forward::primitive_desc(layer_desc, cpu_engine); + + if (x_ptr && layer_index == 0) { + (*x_memory)[layer_index].set_data_handle(x_ptr); + } else { + (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); + } + (*y_memory)[layer_index].set_data_handle(y_ptr); + + if (rnn_forward_prim->size() <= (size_t)layer_index) { + primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], + (*hcx_memory)[layer_index], (*wx_memory)[layer_index], + (*wh_memory)[layer_index], (*bias_memory)[layer_index], + (*y_memory)[layer_index], + (*hcy_memory)[layer_index], null_memory_); + rnn_forward_prim->push_back(rnn_prim); + } + MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); + MKLDNNStream::Get()->Submit(); + + if (state_outputs) { + DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); + if (mode == rnn_enum::kLstm) { + for (int l = 0; l < L; l++) { + offset1 = l * single_cell_size; + offset2 = l * nstates * single_cell_size; + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < single_cell_size; n++) { + hy_ptr[offset1 + n] = dst_hcy[offset2 + n]; + cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size]; + } + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int n = 0; n < L * single_cell_size; n++) { + hy_ptr[n] = dst_hcy[n]; + } + } + } +} + +template +static void MKLDNNRNNForward(bool state_outputs, + const int L, + const int D, + const int T, + const int N, + const int I, + const int H, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector *concat_weight_memory, + std::vector *concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + bool *has_cache, + int dtype, + bool is_train, + int mode) { + int ngates = 0, nstates = 0; + GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + const int b_size = 2 * H * ngates * D; + const int cell_size = N * H * D; + // First layer + int w_size = (I + H) * H * ngates * D; + auto cpu_engine = CpuEngine::Get()->get_engine(); + auto null_memory_ = null_memory(cpu_engine); + DType* tmpNull = NULL; + // when D = 1 and I == H, L layers can be fused together + if (D == 1 && I == H) { + MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + 0, has_cache, dtype, is_train, mode); + } else { + auto user_src_layer_memory_l = null_memory_; + if (D == 2) { + MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + 0, has_cache, 0, dtype, is_train, mode); + } else { + MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l, + hx_ptr, 
cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + 0, has_cache, dtype, is_train, mode); + } + if (L > 1) { + user_src_layer_memory_l = (*y_memory)[0]; + // go to next L - 1 layers. + // If D = 2, do it layer by layer. If D = 1, fused L - 1 layers + w_ptr += w_size; + b_ptr += b_size; + if (D == 2) { + w_size = (H * D + H) * H * ngates * D; + for (int l = 0; l < L - 1; l++) { + if (state_outputs) { + hy_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cy_ptr += cell_size; + } + } + hx_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cx_ptr += cell_size; + } + MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull, + &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, + cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, + hcx_memory, wx_memory, wh_memory, bias_memory, + y_memory, hcy_memory, rnn_forward_prim, + 1, has_cache, l + 1, dtype, is_train, mode); + user_src_layer_memory_l = (*y_memory)[1]; + w_ptr += w_size; + b_ptr += b_size; + } + } + if (D == 1) { + if (state_outputs) { + hy_ptr += cell_size; + if (mode == rnn_enum::kLstm) { + cy_ptr += cell_size; + } + } + w_size = (H + H) * H * ngates; + MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, + concat_iter_memory, x_memory, hcx_memory, wx_memory, + wh_memory, bias_memory, y_memory, hcy_memory, + rnn_forward_prim, 1, has_cache, dtype, is_train, mode); + } + } + } + *has_cache = true; +} + +template +static void MKLDNNRNNForwardInference(bool state_outputs, + const int num_layers, + const int direction, + const int seq_length, + const int batch_size, + const int input_size, + const int state_size, + DType* x_ptr, + DType* hx_ptr, + DType* cx_ptr, + DType* w_ptr, + DType* b_ptr, + DType* y_ptr, + DType* hy_ptr, + DType* cy_ptr, + std::vector* concat_weight_memory, + std::vector* concat_iter_memory, + std::vector *x_memory, + std::vector *hcx_memory, + std::vector *wx_memory, + std::vector *wh_memory, + std::vector *bias_memory, + std::vector *y_memory, + std::vector *hcy_memory, + std::vector *rnn_forward_prim, + bool *has_cache, + int dtype, + bool is_train, + int mode) { + switch (mode) { + case rnn_enum::kLstm: + case rnn_enum::kGru: + case rnn_enum::kRnnTanh: + case rnn_enum::kRnnRelu: + MKLDNNRNNForward(state_outputs, num_layers, direction, seq_length, + batch_size, input_size, state_size, x_ptr, hx_ptr, + cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, + concat_weight_memory, concat_iter_memory, x_memory, + hcx_memory, wx_memory, wh_memory, + bias_memory, y_memory, hcy_memory, rnn_forward_prim, + has_cache, dtype, is_train, mode); + break; + default: + LOG(FATAL) << "unknown RNN mode" << mode; + break; + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RNN_IMPL_H_ diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index e43b3c9b5131..9785be209b7d 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -44,17 +44,13 @@ #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" +#if MXNET_USE_MKLDNN == 1 +#include "./nn/mkldnn/mkldnn_rnn_impl.h" +#endif namespace mxnet { namespace op { -namespace rnn_enum { - enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength}; - enum RNNOpOutputs {kOut, 
kStateOut, kStateCellOut}; - enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru}; - enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace}; -} - inline int GetRnnParamSize(int num_layer, int input_size, int state_size, @@ -400,9 +396,29 @@ class RNNOp { public: RNNParam param_; Context ctx_; + #if MXNET_USE_MKLDNN == 1 + std::vector concat_weight_memory; + std::vector concat_iter_memory; + std::vector rnn_forward_prim; + std::vector x_memory; + std::vector hcx_memory; + std::vector wx_memory; + std::vector wh_memory; + std::vector bias_memory; + std::vector y_memory; + std::vector hcy_memory; + bool has_cache; + bool init_mem_; + size_t reserve_mem_size_; + Storage::Handle mem_space_; + #endif explicit RNNOp(RNNParam param, Context ctx) { this->param_ = param; this->ctx_ = ctx; + #if MXNET_USE_MKLDNN == 1 + init_mem_ = false; + reserve_mem_size_ = 0; + #endif #if MXNET_USE_CUDNN_RNN init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; @@ -410,8 +426,8 @@ class RNNOp { // No tests in place for fp16 RNNs, so leave TensorCore disabled for now. cudnn_tensor_core_ = false; // When fp16 RNN tests are introduced, we can enable TensorCore as follows: -// cudnn_tensor_core = -// mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); + // cudnn_tensor_core = + // mshadow::DataType::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore(); // Defaults input_mode_ = CUDNN_LINEAR_INPUT; // Don't support this yet // RNN Mode @@ -492,7 +508,6 @@ class RNNOp { this->temp_init_space_ = false; this->reserve_cpu_space_size_ = 0; this->temp_cpu_space_size_ = 0; - if (param_.projection_size.has_value()) { LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1"; @@ -505,6 +520,12 @@ class RNNOp { } ~RNNOp() { + #if MXNET_USE_MKLDNN == 1 + if (init_mem_) { + Storage::Get()->Free(mem_space_); + init_mem_ = false; + } + #endif #if MXNET_USE_CUDNN_RNN CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); @@ -829,22 +850,23 @@ class RNNOp { #endif if (ctx_.dev_type == kCPU) { - // allocate temp space - const size_t work_cpu_space_size = - GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, - param_.state_size, direction, param_.mode); - if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { - Storage::Get()->Free(temp_cpu_space_); - temp_init_space_ = false; - } - if (!temp_init_space_) { - temp_cpu_space_ = Storage::Get()->Alloc - (work_cpu_space_size * sizeof(DType), Context::CPU()); - temp_cpu_space_size_ = work_cpu_space_size; - temp_init_space_ = true; - } - DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); if (ctx.is_train) { + // allocate temp space + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; + } + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc + (work_cpu_space_size * sizeof(DType), Context::CPU()); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction, param_.seq_length_, param_.batch_size_, param_.state_size, param_.mode); @@ -880,23 +902,78 @@ class RNNOp { param_.p, param_.mode); } else { - RNNForwardInference(work_cpu_space, - param_.state_outputs, 
- param_.num_layers, - direction, - param_.seq_length_, - param_.batch_size_, - param_.input_size_, - param_.state_size, - x.dptr_, - hx.dptr_, - cx_ptr, - w.dptr_, - b_ptr, - y.dptr_, - hy_ptr, - cy_ptr, - param_.mode); + #if MXNET_USE_MKLDNN == 1 + if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { + // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one + // will be added to MXNet when we figure out the issue. + int dtype = in_data[rnn_enum::kData].type_flag_; + MKLDNNRNNForwardInference(param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + &concat_weight_memory, + &concat_iter_memory, + &x_memory, + &hcx_memory, + &wx_memory, + &wh_memory, + &bias_memory, + &y_memory, + &hcy_memory, + &rnn_forward_prim, + &has_cache, + dtype, + ctx.is_train, + param_.mode); + } else { + #endif + // Before integrating MKLDNN GRU fp32 inference + // using below code for keep func being OK + const size_t work_cpu_space_size = + GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, + param_.state_size, direction, param_.mode); + if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) { + Storage::Get()->Free(temp_cpu_space_); + temp_init_space_ = false; + } + if (!temp_init_space_) { + temp_cpu_space_ = Storage::Get()->Alloc + (work_cpu_space_size * sizeof(DType), Context::CPU()); + temp_cpu_space_size_ = work_cpu_space_size; + temp_init_space_ = true; + } + DType* work_cpu_space = static_cast(temp_cpu_space_.dptr); + RNNForwardInference(work_cpu_space, + param_.state_outputs, + param_.num_layers, + direction, + param_.seq_length_, + param_.batch_size_, + param_.input_size_, + param_.state_size, + x.dptr_, + hx.dptr_, + cx_ptr, + w.dptr_, + b_ptr, + y.dptr_, + hy_ptr, + cy_ptr, + param_.mode); + #if MXNET_USE_MKLDNN == 1 + } + #endif } } } diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 9b412a2575a1..32184943cac0 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -167,6 +167,21 @@ static bool RNNType(const nnvm::NodeAttrs& attrs, return true; } +inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + DispatchMode wanted_mode = DispatchMode::kFCompute; + + #if MXNET_USE_MKLDNN == 1 + wanted_mode = DispatchMode::kFComputeEx; + #endif + + return storage_type_assign(out_attrs, mxnet::kDefaultStorage, + dispatch_mode, wanted_mode); +} + struct RNNGrad { const char *op_name; std::vector operator()(const nnvm::NodePtr &n, @@ -191,6 +206,417 @@ struct RNNGrad { } }; +#if MXNET_USE_MKLDNN == 1 +static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + std::vector in_blobs; + std::vector out_blobs; + std::vector temp_ndarrays_i; + std::vector temp_ndarrays_o; + for (const NDArray& in : inputs) { + if (in.storage_type() == kDefaultStorage) { + temp_ndarrays_i.push_back(in.Reorder2Default()); + in_blobs.emplace_back(temp_ndarrays_i.back().data()); + } else { + in_blobs.emplace_back(in.data()); + } + } + + for (const NDArray& out : outputs) { + if (out.storage_type() == kDefaultStorage) { + temp_ndarrays_o.push_back(out.Reorder2Default()); + out_blobs.emplace_back(temp_ndarrays_o.back().data()); + } else { + 
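+      // non-default storage: hand the NDArray's blob to the blob-based Forward as-is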
out_blobs.emplace_back(out.data()); + } + } + int dtype = in_blobs[rnn_enum::kData].type_flag_; + int itype = in_blobs[inputs.size()-1].type_flag_; + mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); + Stream *s = ctx.get_stream(); + auto cpu_engine = CpuEngine::Get()->get_engine(); + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + MSHADOW_TYPE_SWITCH(itype, IType, { + RNNOp& op = state_ptr.get_state>(); + const RNNParam& param = op.param_; + int ngates = 0, nstates = 0; + GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); + int D = param.bidirectional ? 2 : 1; + Tensor x = in_blobs[rnn_enum::kData].get(s); + int T = x.shape_[0]; + int N = x.shape_[1]; + int I = x.shape_[2]; + int H = param.state_size; + int L = param.num_layers; + + const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); + if (op.init_mem_ && op.reserve_mem_size_ < r_size) { + Storage::Get()->Free(op.mem_space_); + op.init_mem_ = false; + } + if (!op.init_mem_) { + op.mem_space_ = Storage::Get()->Alloc( + r_size * sizeof(DType), + Context::CPU()); + op.reserve_mem_size_ = r_size; + op.init_mem_ = true; + op.has_cache = false; + } + if (op.has_cache && op.x_memory.size() == 0) { + op.has_cache = false; + } + + DType* workptr = static_cast(op.mem_space_.dptr); + mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; + mkldnn::memory::dims src_layer_tz = {T, N, D * H}; + mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; + auto dst_layer_md = mkldnn::memory::desc( + { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + if (op.x_memory.size() == 0) { + if (D == 1 && I == H) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_n); + + mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + DType* weight_layer_n = workptr; // L * I * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + 
op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + L * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + for (int l = 0; l < L; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // L * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + workptr = dst_iter_n + L * nstates * N * H; + + } else { + auto user_src_layer_md_0 = mkldnn::memory::desc( + { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory_0); + + mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; + auto user_weight_layer_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md_0 = mkldnn::memory::desc({ bias_tz_0 }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_0 = workptr; // D * I * ngates * H + auto user_weight_layer_memory_0 + = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); + op.wx_memory.push_back(user_weight_layer_memory_0); + + DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_0 + = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); + op.wh_memory.push_back(user_weight_iter_memory_0); + + DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_0 = + mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); + op.bias_memory.push_back(user_bias_memory_0); + workptr = bias_0 + D * ngates * H; + + auto wx_md_0 = mkldnn::memory::desc( + { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wx_memory_0 = + mkldnn::memory({ wx_md_0, cpu_engine }); + auto wh_md_0 = mkldnn::memory::desc( + { weights_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_0 = + mkldnn::memory({ wh_md_0, 
cpu_engine }); + if (D == 2) { + DType* wx_0 = workptr; // D * ngates * I * H + wx_memory_0.set_data_handle(wx_0); + DType* wh_0 = wx_0 + D * ngates * I * H; // D * ngates * H * H + wh_memory_0.set_data_handle(wh_0); + workptr = wh_0 + D * ngates * H * H; + } + op.concat_weight_memory.push_back(wx_memory_0); + op.concat_weight_memory.push_back(wh_memory_0); + + mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md_0 = mkldnn::memory::desc( + { src_iter_undi_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi_0 = workptr; // nstates * N * H + auto src_iter_undi_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + workptr = src_iter_undi_0 + nstates * N * H; + if (D == 1) { + op.hcx_memory.push_back(src_iter_undi_memory_0); + } else { + DType* src_iter_undi2_0 = workptr; // nstates * N * H + auto src_iter_undi2_memory_0 = + mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); + op.concat_iter_memory.push_back(src_iter_undi2_memory_0); + + mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md_0 = mkldnn::memory::desc( + { src_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory_0 = + mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); + op.concat_iter_memory.push_back(src_iter_memory_0); + op.hcx_memory.push_back(src_iter_memory_0); + workptr = src_iter_0 + D * nstates * N * H; + } + + DType* dst_layer_0 = workptr; // T * N * D * H + auto dst_layer_memory_0 + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); + op.y_memory.push_back(dst_layer_memory_0); + + mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_0 = mkldnn::memory::desc( + { dst_iter_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_0 = dst_layer_0 + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_0 = + mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0); + op.hcy_memory.push_back(dst_iter_memory_0); + workptr = dst_iter_0 + D * nstates * N * H; + + // next L - 1 layers + if (L > 1 && D == 1) { + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + DType* weight_layer_n = workptr; // (L - 1) * H * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + (L - 1) * H * ngates * H; // (L - 1) * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + 
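+            // these user memories alias the flat mem_space_ arena (via workptr), so the
+            // cached handles remain valid across later forward calls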
op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // (L - 1) * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + DType* wx_n = bias_n + (L - 1) * ngates * H; // (L - 1) * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + DType* wh_n = wx_n + (L - 1) * ngates * H * H; // (L - 1) * ngates * H * H + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + workptr = wh_n + (L - 1) * ngates * H * H; + + mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n1 = mkldnn::memory::desc( + { src_iter_tz_n1 }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + for (int l = 0; l < L - 1; l++) { + DType* src_iter_n1 = workptr; // nstates * N * H + auto src_iter_memory_n1 = + mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); + op.concat_iter_memory.push_back(src_iter_memory_n1); + workptr = src_iter_n1 + nstates * N * H; + } + mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto src_iter_md_n = mkldnn::memory::desc( + { src_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_n = workptr; // (L - 1) * nstates * N * H + auto src_iter_memory_n = + mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); + op.concat_iter_memory.push_back(src_iter_memory_n); + op.hcx_memory.push_back(src_iter_memory_n); + + DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // (L - 1) * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } + + if (L > 1 && D == 2) { + mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H}; // ldigo + mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H}; // ldigo + mkldnn::memory::dims bias_tz = {1, D, ngates, H}; + auto user_weight_layer_md = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_weight_iter_md = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); + auto user_bias_md = mkldnn::memory::desc({ bias_tz }, + mkldnn_dtype, mkldnn::memory::format::ldgo); + + auto user_src_layer_md = mkldnn::memory::desc( + { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); + auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine }); + op.x_memory.push_back(user_src_layer_memory); + + auto wx_md_n = mkldnn::memory::desc( + { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + auto wh_md_n = mkldnn::memory::desc( + { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); + + for (int l = 0; l < L; l++) { + DType* weight_layer_n = 
workptr; // D * (H * D) * ngates * H + auto user_weight_layer_memory_n + = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); + op.wx_memory.push_back(user_weight_layer_memory_n); + + DType* weight_iter_n = weight_layer_n + + D * (H * D) * ngates * H; // D * H * ngates * H + auto user_weight_iter_memory_n + = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); + op.wh_memory.push_back(user_weight_iter_memory_n); + + DType* bias_n = weight_iter_n + D * H * ngates * H; // D * ngates * H + auto user_bias_memory_n = + mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); + op.bias_memory.push_back(user_bias_memory_n); + workptr = bias_n + D * ngates * H; + } + + DType* wx_n = workptr; // D * ngates * (D * H) * H + DType* wh_n = wx_n + D * ngates * (D * H) * H; // D * ngates * H * H + auto wx_memory_n = + mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); + auto wh_memory_n = + mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); + op.concat_weight_memory.push_back(wx_memory_n); + op.concat_weight_memory.push_back(wh_memory_n); + + mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H}; // ldsnc + auto src_iter_undi_md = mkldnn::memory::desc( + { src_iter_undi_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter_undi = wh_n + D * ngates * H * H; // nstates * N * H + auto src_iter_undi_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi); + op.concat_iter_memory.push_back(src_iter_undi_memory_0); + + DType* src_iter_undi2 = src_iter_undi + nstates * N * H; // nstates * N * H + auto src_iter_undi2_memory = + mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2); + op.concat_iter_memory.push_back(src_iter_undi2_memory); + + mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H}; // ldsnc + auto src_iter_md = mkldnn::memory::desc( + { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* src_iter = src_iter_undi2 + nstates * N * H; // D * nstates * N * H + auto src_iter_memory = + mkldnn::memory({ src_iter_md, cpu_engine }, src_iter); + op.concat_iter_memory.push_back(src_iter_memory); + op.hcx_memory.push_back(src_iter_memory); + + DType* dst_layer_n = src_iter + D * nstates * N * H; // T * N * D * H + auto dst_layer_memory_n + = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); + op.y_memory.push_back(dst_layer_memory_n); + + mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H}; // ldsnc + auto dst_iter_md_n = mkldnn::memory::desc( + { dst_iter_tz_n }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + DType* dst_iter_n = dst_layer_n + T * N * D * H; // D * nstates * N * H + auto dst_iter_memory_n = + mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); + op.hcy_memory.push_back(dst_iter_memory_n); + } + } + } + op.Forward(ctx, in_blobs, req, out_blobs); + }); + }); +} +#endif + NNVM_REGISTER_OP(RNN) .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support. 
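(Illustrative aside, not part of the patch: the wx/wh/bias/y/hcy memories created in RNNStatefulComputeCPU above are all carved out of the single mem_space_ allocation by advancing a raw cursor, workptr, so later forward calls reuse the same handles without reallocating. A minimal C++ sketch of that bump-pointer pattern, using plain floats and hypothetical names such as Arena and carve, is:

#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the workptr bookkeeping above: one flat buffer,
// slices handed out front-to-back, no further allocation afterwards.
struct Arena {
  std::vector<float> buf;
  std::size_t cursor = 0;
  explicit Arena(std::size_t total) : buf(total) {}
  float* carve(std::size_t n) {  // next n-element slice of the flat buffer
    float* p = buf.data() + cursor;
    cursor += n;
    return p;
  }
};

int main() {
  // Shapes follow the D == 1, I == H branch above (LSTM: ngates = 4, nstates = 2).
  const int L = 2, T = 5, N = 4, I = 8, H = 8, D = 1, ngates = 4, nstates = 2;
  Arena arena(1u << 20);                                    // stands in for mem_space_
  float* weight_layer = arena.carve(L * I * ngates * H);    // backs wx_memory
  float* weight_iter  = arena.carve(L * H * ngates * H);    // backs wh_memory
  float* bias         = arena.carve(L * ngates * H);        // backs bias_memory
  float* dst_layer    = arena.carve(T * N * D * H);         // backs y_memory
  float* dst_iter     = arena.carve(L * nstates * N * H);   // backs hcy_memory
  (void)weight_layer; (void)weight_iter; (void)bias; (void)dst_layer; (void)dst_iter;
  return 0;
}

The real code additionally reserves ldgoi staging buffers and per-layer src/dst iter slices from the same arena.)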
@@ -269,8 +695,13 @@ The definition of GRU here is slightly different from paper but compatible with
 })
 .set_attr("FInferShape", RNNShape)
 .set_attr("FInferType", RNNType)
+.set_attr("FInferStorageType", RNNStorageType)
 .set_attr("FCreateOpState", CreateRNNState)
 .set_attr("FStatefulCompute", RNNStatefulCompute)
+#if MXNET_USE_MKLDNN == 1
+.set_attr("TIsMKLDNN", true)
+.set_attr("FStatefulComputeEx", RNNStatefulComputeCPU)
+#endif
 .set_attr("FGradient", RNNGrad{"_backward_RNN"})
 .set_attr("FResourceRequestEx",
   [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
diff --git a/src/operator/rnn_impl.h b/src/operator/rnn_impl.h
index e1b4a2b79c0a..425ea4a3c6ab 100644
--- a/src/operator/rnn_impl.h
+++ b/src/operator/rnn_impl.h
@@ -44,6 +44,13 @@
 namespace mxnet {
 namespace op {
 
+namespace rnn_enum {
+  enum RNNOpInputs {kData, kParams, kState, kStateCell, kSequenceLength};
+  enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
+  enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
+  enum RNNOpResource {kTempSpace, kCuDNNDropoutDescSpace};
+}
+
 template<typename DType>
 inline DType sigmoid(DType x) {
   return 1.0f / (1.0f + exp(-x));
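(Illustrative aside, not part of the patch: AdjustGruWeightGateOrder and AdjustGruBiasGateOrder exist because MXNet stores GRU gates as (reset, update, new) while MKLDNN expects (update, reset, new), so the first two gate blocks of each weight/bias buffer are swapped in place before the weights are cached. A single-threaded sketch of that swap, with hypothetical names (the real helpers parallelize the loop with OpenMP), is:

#include <cassert>
#include <utility>
#include <vector>

// Swap the reset and update gate blocks of a flat GRU buffer laid out as
// [reset | update | new], each gate block holding `block` elements.
// Mirrors the logic of AdjustGruWeightGateOrder, minus the OpenMP pragma.
void SwapGruGateBlocks(float* data, int block) {
  for (int i = 0; i < block; ++i)
    std::swap(data[i], data[block + i]);
}

int main() {
  const int I = 2, H = 3;                     // input size, hidden size
  std::vector<float> wx(3 * I * H);           // weights for reset, update, new gates
  for (int i = 0; i < static_cast<int>(wx.size()); ++i) wx[i] = static_cast<float>(i);
  SwapGruGateBlocks(wx.data(), I * H);        // layer weights use block = I * H
  assert(wx[0] == static_cast<float>(I * H)); // update-gate weights now come first
  SwapGruGateBlocks(wx.data(), I * H);        // swapping again restores the original order
  assert(wx[0] == 0.0f);
  return 0;
}

Bias buffers use block = H, matching AdjustGruBiasGateOrder.)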