diff --git a/ci/build_windows.py b/ci/build_windows.py index 7ec24395e22e..8e7ed3b5376d 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -36,8 +36,8 @@ from util import * KNOWN_VCVARS = { - 'VS 2015': r'C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\x86_amd64\vcvarsx86_amd64.bat', - 'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat' + 'VS 2015': r'C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64\vcvarsamd64.bat', + 'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat' } diff --git a/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 b/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 index 32f0a0a8d862..3f806482402c 100644 --- a/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 +++ b/ci/docker/Dockerfile.build.ubuntu_gpu_cu101 @@ -67,7 +67,7 @@ RUN /work/ubuntu_docs.sh COPY install/ubuntu_tutorials.sh /work/ RUN /work/ubuntu_tutorials.sh -ENV CUDNN_VERSION=7.5.1.10 +ENV CUDNN_VERSION=7.6.0.64 COPY install/ubuntu_cudnn.sh /work/ RUN /work/ubuntu_cudnn.sh diff --git a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h index ea8e07ea617c..bd72038125ae 100644 --- a/src/operator/nn/mkldnn/mkldnn_rnn_impl.h +++ b/src/operator/nn/mkldnn/mkldnn_rnn_impl.h @@ -39,6 +39,26 @@ namespace mxnet { namespace op { +struct MKLDNNRNNMemory { + std::vector<mkldnn::memory> concat_weight_memory; + std::vector<mkldnn::memory> concat_iter_memory; + std::vector<mkldnn::memory> x_memory; + std::vector<mkldnn::memory> hcx_memory; + std::vector<mkldnn::memory> wx_memory; + std::vector<mkldnn::memory> wh_memory; + std::vector<mkldnn::memory> bias_memory; + std::vector<mkldnn::memory> y_memory; + std::vector<mkldnn::memory> hcy_memory; + std::vector<mkldnn::memory> uni_states_memory; + std::vector<mkldnn::memory> concat_states_memory; + std::vector<mkldnn::memory> weight_layer_mems; + std::vector<mkldnn::memory> weight_iter_mems; + mkldnn::memory user_src_layer_memory_l; + + MKLDNNRNNMemory() : user_src_layer_memory_l( + null_memory(CpuEngine::Get()->get_engine())) {} +}; + static algorithm GetMKLDNNRNNAlgo(int mode, int* ngates, int* nstates) { @@ -52,7 +72,7 @@ static algorithm GetMKLDNNRNNAlgo(int mode, case rnn_enum::kGru: *ngates = 3; *nstates = 1; - algo = algorithm::vanilla_gru; + algo = algorithm::gru_linear_before_reset; break; case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: @@ -72,112 +92,102 @@ static void ConcatData(mkldnn::memory::format src_format, std::vector<mkldnn::memory::dims> srcs_cds, mkldnn::memory::dims dst_cds, mkldnn::memory::data_type mkldnn_dtype, - int concat_dimension, - std::vector<void*> srcs_data, - const mkldnn::memory &dst) { + const int concat_dimension, + const std::vector<void*> &srcs_data, + const mkldnn::memory &dst, + std::vector<mkldnn::memory> *tmp_src_mems) { auto cpu_engine = CpuEngine::Get()->get_engine(); std::vector<mkldnn::memory::primitive_desc> srcs_pd; - std::vector<mkldnn::memory> srcs; + const bool initialized = tmp_src_mems->size() > 0; for (size_t i = 0; i < srcs_cds.size(); i++) { auto desc = mkldnn::memory::desc(srcs_cds[i], mkldnn_dtype, src_format); auto mpd = mkldnn::memory::primitive_desc(desc, cpu_engine); - auto src_memory = mkldnn::memory(mpd, srcs_data[i]); srcs_pd.push_back(mpd); - srcs.push_back(src_memory); - } - std::vector<primitive::at> inputs; - for (size_t i = 0; i < srcs_cds.size(); i++) { - inputs.push_back(srcs[i]); + if (initialized) { + tmp_src_mems->at(i).set_data_handle(srcs_data[i]); + } else { + auto src_memory = mkldnn::memory(mpd, srcs_data[i]); + tmp_src_mems->push_back(src_memory); + } } + std::vector<primitive::at> inputs(tmp_src_mems->begin(), tmp_src_mems->end()); auto dst_desc = mkldnn::memory::desc(dst_cds, mkldnn_dtype, dst_format); auto concat_pd = concat::primitive_desc(dst_desc,
concat_dimension, srcs_pd); MKLDNNStream::Get()->RegisterPrim(concat(concat_pd, inputs, dst)); - MKLDNNStream::Get()->Submit(); } -// cached mkldnn memory -// first layer wx, wh with next L - 1 layers wx and wh -// with L layers hx and cx, src and dst data/iter etc. -// it will prepare memory on before and after reorder and concat. -// for unidirectional, it will fused as dim like 1 + (L - 1) when I != H. -// for bidirectional, it will fused as data + back_data (weight, bias, iter etc), -// also need to identify first layer and next layers -static size_t GetMKLDNNRNNCacheMemorySize(int L, - int D, - int T, - int N, - int I, - int H, - int mode) { - size_t size = 0; - switch (mode) { - case rnn_enum::kLstm: - size = 2 * (D * (I + H) * 4 * H + (L - 1) * D * (D * H + H) * 4 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 4 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 4 * H + T * N * I * 2; - break; - case rnn_enum::kGru: - size = 2 * (D * (I + H) * 3 * H + (L - 1) * D * (D * H + H) * 3 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 3 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 3 * H + T * N * I * 2; - break; - case rnn_enum::kRnnRelu: - case rnn_enum::kRnnTanh: - size = 2 * (D * (I + H) * 1 * H + (L - 1) * D * (D * H + H) * 1 * H + - L * D * 2 * N * H) + T * N * D * H + L * 2 * D * 1 * H + (L + 2) * D * 2 * N * H + - 6 * D * (I + H + 2) * 1 * H + T * N * I * 2; - break; - default: - LOG(FATAL) << "unknown RNN mode " << mode; - break; - } +/** + * Size of cached memory + * + * Cache memory of wx, wh from the first layer and the next num_layer - 1 layers + * separately, as well as the layer and iter memory for src and dst. + * Output states memory hy, cy and bias memory are also cached. Memory is + * prepared both before and after reorder and concat. For unidirectional + * layers, it will be fused as 1 + (num_layer - 1) when + * input_size != hidden_size. For bidirectional layers, it will be fused as + * data + back_data (weight, bias, iter, etc.). + * + * @param num_layer Number of layers. + * @param direction Direction of the RNN implementation. It should be 1 or 2. + * @param seq_len The maximum sequence length. + * @param batch_size Batch size. + * @param input_size Input channel. Also the dimension of the input feature. + * @param hidden_size Hidden state size. + * @param mode RNN mode: kLstm, kGru, kRnnRelu, or kRnnTanh. + * @return The required cache size. + */ +static size_t GetMKLDNNRNNCacheMemorySize(const size_t num_layer, + const size_t direction, + const size_t seq_len, + const size_t batch_size, + const size_t input_size, + const size_t hidden_size, + const size_t mode) { + int n_gates = 0, n_states = 0; + GetMKLDNNRNNAlgo(mode, &n_gates, &n_states); + const size_t n_bias = mode == rnn_enum::kGru ?
n_gates + 1 : n_gates; + // sizes of single gates from a single cell + const size_t weights_size_0 = direction * (input_size + hidden_size) * hidden_size; + const size_t weights_size_n = direction * (direction * hidden_size + hidden_size) * hidden_size; + const size_t bias_size = direction * hidden_size; + const size_t src_iter_size = direction * batch_size * hidden_size; + const size_t dst_iter_size = direction * batch_size * hidden_size; + const size_t dst_layer_size = seq_len * batch_size * direction * hidden_size; + + size_t size = (weights_size_0 + weights_size_n * (num_layer - 1)) * n_gates * 2 + + bias_size * num_layer * n_bias + src_iter_size * num_layer * n_states * 2 + + dst_iter_size * num_layer * n_states + dst_layer_size * 2; return size; } template static void AdjustGruWeightGateOrder(DType* weight, - const int I, - const int H) { + const int input_size, + const int hidden_size) { // mxnet gru gate order is reset, update and new gates // mkldnn gru gate order is update, reset and new gates const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + size_t single_weight_size = input_size * hidden_size; DType* weight_reset = weight; - DType* weight_update = weight + I * H; + DType* weight_update = weight + single_weight_size; #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < I * H; i++) { + for (int i = 0; i < static_cast(single_weight_size); i++) { DType tmp = weight_update[i]; weight_update[i] = weight_reset[i]; weight_reset[i] = tmp; } } -template -static void AdjustGruBiasGateOrder(DType* bias, - const int H) { - // mxnet gru gate order is reset, update and new gates - // mkldnn gru gate order is update, reset and new gates - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType* bias_reset = bias; - DType* bias_update = bias + H; - #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < H; i++) { - DType tmp = bias_update[i]; - bias_update[i] = bias_reset[i]; - bias_reset[i] = tmp; - } -} // since there is different sematics of MKLDNN's Fused RNN and MXNet FusedRNN, // bidirectional will be fused layer by layer, -// unidirectional will be done by fused 1 + fused (L - 1) layers or fused L layers(when I = H) - +// unidirectional will be done by fused 1 + fused (num_layer - 1) layers or fused num_layer +// layers(when input_size = hidden_size) template static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, - const int T, - const int N, - const int I, - const int H, + const int seq_len, + const int batch_size, + const int input_size, + const int hidden_size, DType* x_ptr, - mkldnn::memory *user_src_layer_memory, DType* hx_ptr, DType* cx_ptr, DType* w_ptr, @@ -185,95 +195,96 @@ static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, DType* y_ptr, DType* hy_ptr, DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, + MKLDNNRNNMemory *mkldnn_mems, std::vector *rnn_forward_prim, int layer_index, bool *has_cache, - int lvalue, int dtype, bool is_train, int mode) { int ngates = 0, nstates = 0; algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + const int nbias = mode == rnn_enum::kGru ? 
ngates + 1 : ngates; mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - const int single_cell_size = N * H; - const int single_b_size = ngates * H; - DType* wx = w_ptr; // ngates * H, I - DType* wh = w_ptr + I * H * ngates; // ngates * H, H - DType* back_wx = w_ptr + ngates * H * (I + H); - DType* back_wh = back_wx + I * H * ngates; + const size_t single_cell_size = batch_size * hidden_size; + const size_t mx_single_b_sz = ngates * hidden_size; + DType* wx = w_ptr; // ngates * hidden_size, input_size + DType* wh = w_ptr + input_size * hidden_size * ngates; // ngates * hidden_size, hidden_size + DType* back_wx = w_ptr + ngates * hidden_size * (input_size + hidden_size); + DType* back_wh = back_wx + input_size * hidden_size * ngates; DType* bx = b_ptr; - DType* bh = b_ptr + H * ngates; - DType* back_bx = b_ptr + single_b_size * 2; - DType* back_bh = back_bx + H * ngates; + DType* bh = b_ptr + hidden_size * ngates; + DType* back_bx = b_ptr + mx_single_b_sz * 2; + DType* back_bh = back_bx + hidden_size * ngates; const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); auto cpu_engine = CpuEngine::Get()->get_engine(); auto null_memory_ = null_memory(cpu_engine); - int offset1 = 0, offset2 = 0; bool initialized = *has_cache; - mkldnn::memory::dims src_layer_tz = {T, N, I}; - mkldnn::memory::dims dst_layer_tz = {T, N, 2 * H}; - mkldnn::memory::dims weights_layer_tz = {1, 2, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder - mkldnn::memory::dims weights_iter_tz = {1, 2, H, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder - mkldnn::memory::dims bias_tz = {1, 2, ngates, H}; - mkldnn::memory::dims src_iter_tz = {1, 2, nstates, N, H}; // ldsnc - mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, N, H}; // ldsnc - - if (!initialized) { + mkldnn::memory::dims src_layer_tz = {seq_len, batch_size, input_size}; + mkldnn::memory::dims dst_layer_tz = {seq_len, batch_size, 2 * hidden_size}; + mkldnn::memory::dims weights_layer_tz = {1, 2, input_size, ngates, hidden_size}; // ldigo + mkldnn::memory::dims weights_iter_tz = {1, 2, hidden_size, ngates, hidden_size}; // ldigo + mkldnn::memory::dims bias_tz = {1, 2, nbias, hidden_size}; // ldgo + mkldnn::memory::dims src_iter_tz = {1, 2, nstates, batch_size, hidden_size}; // ldsnc + mkldnn::memory::dims dst_iter_tz = {1, 2, nstates, batch_size, hidden_size}; // ldsnc + mkldnn::memory::dims weights_layer_r_tz = {1, 1, input_size, ngates, hidden_size}; + mkldnn::memory::dims weights_iter_r_tz = {1, 1, hidden_size, ngates, hidden_size}; + + bool has_adjusted = false; + if (!initialized || is_train) { if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(back_wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruWeightGateOrder(back_wh, H, H); - AdjustGruBiasGateOrder(bx, H); - AdjustGruBiasGateOrder(back_bx, H); - AdjustGruBiasGateOrder(bh, H); - AdjustGruBiasGateOrder(back_bh, H); + AdjustGruWeightGateOrder(wx, input_size, hidden_size); + AdjustGruWeightGateOrder(back_wx, input_size, hidden_size); + AdjustGruWeightGateOrder(wh, hidden_size, hidden_size); + AdjustGruWeightGateOrder(back_wh, hidden_size, hidden_size); + has_adjusted = true; } - auto src_wx = (*concat_weight_memory)[2 * layer_index]; - auto src_wh = (*concat_weight_memory)[2 * layer_index + 1]; + mkldnn::memory& src_wx = mkldnn_mems->concat_weight_memory[2 * layer_index]; + mkldnn::memory& 
src_wh = mkldnn_mems->concat_weight_memory[2 * layer_index + 1]; std::vector srcs_data1; srcs_data1.push_back(wx); srcs_data1.push_back(back_wx); ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, {weights_layer_r_tz, weights_layer_r_tz}, weights_layer_tz, - mkldnn_dtype, 1, srcs_data1, src_wx); + mkldnn_dtype, 1, srcs_data1, src_wx, &(mkldnn_mems->weight_layer_mems)); srcs_data1.clear(); srcs_data1.push_back(wh); srcs_data1.push_back(back_wh); ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, {weights_iter_r_tz, weights_iter_r_tz}, weights_iter_tz, - mkldnn_dtype, 1, srcs_data1, src_wh); - int tmpvalue = 0; - if (lvalue > 0) { - tmpvalue = lvalue + 1; - } - MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, (*wx_memory)[tmpvalue])); - MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, (*wh_memory)[tmpvalue])); + mkldnn_dtype, 1, srcs_data1, src_wh, &(mkldnn_mems->weight_iter_mems)); + + MKLDNNStream::Get()->RegisterPrim(reorder(src_wx, mkldnn_mems->wx_memory[layer_index])); + MKLDNNStream::Get()->RegisterPrim(reorder(src_wh, mkldnn_mems->wh_memory[layer_index])); DType* user_bias = reinterpret_cast - ((*bias_memory)[tmpvalue].get_data_handle()); - #pragma omp parallel for num_threads(omp_threads) - for (int j = 0; j < single_b_size; j++) { - user_bias[j] = bx[j] + bh[j]; - user_bias[single_b_size + j] = back_bx[j] + back_bh[j]; + (mkldnn_mems->bias_memory[layer_index].get_data_handle()); + if (mode == rnn_enum::kGru) { + // While mxnet gru gate order is reset, update and new gates, + // mkldnn gru gate order is update, reset and new gates. So + // we need to swap the order of reset and update from mxnet. + const size_t single_b_sz = nbias * hidden_size; + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < hidden_size; j++) { + user_bias[j + hidden_size] = bx[j] + bh[j]; + user_bias[single_b_sz + j + hidden_size] = back_bx[j] + back_bh[j]; + + user_bias[j] = bx[j + hidden_size] + bh[j + hidden_size]; + user_bias[single_b_sz + j] = back_bx[j + hidden_size] + back_bh[j + hidden_size]; + + user_bias[j + 2 * hidden_size] = bx[j + 2 * hidden_size]; + user_bias[j + 3 * hidden_size] = bh[j + 2 * hidden_size]; + user_bias[single_b_sz + j + 2 * hidden_size] = back_bx[j + 2 * hidden_size]; + user_bias[single_b_sz + j + 3 * hidden_size] = back_bh[j + 2 * hidden_size]; + } + } else { + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < static_cast(mx_single_b_sz); j++) { + user_bias[j] = bx[j] + bh[j]; + user_bias[mx_single_b_sz + j] = back_bx[j] + back_bh[j]; + } } } - if (lvalue > 0) { - (*wx_memory)[layer_index].set_data_handle((*wx_memory)[lvalue + 1].get_data_handle()); - (*wh_memory)[layer_index].set_data_handle((*wh_memory)[lvalue + 1].get_data_handle()); - (*bias_memory)[layer_index].set_data_handle((*bias_memory)[lvalue + 1].get_data_handle()); - } auto src_layer_md = mkldnn::memory::desc( { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); @@ -286,36 +297,39 @@ static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, auto dst_iter_md = mkldnn::memory::desc( { dst_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); auto src_iter_md = mkldnn::memory::desc( - {src_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - auto bias_md = mkldnn::memory::desc({bias_tz}, - mkldnn_dtype, mkldnn::memory::format::ldgo); + { src_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldsnc); + auto bias_md = mkldnn::memory::desc( + { bias_tz }, mkldnn_dtype, mkldnn::memory::format::ldgo); - auto 
user_src_iter_memory = (*concat_iter_memory)[2]; + mkldnn::memory& user_src_iter_memory = mkldnn_mems->concat_iter_memory[2]; if (mode == rnn_enum::kLstm) { std::vector srcs_data1; srcs_data1.push_back(hx_ptr); srcs_data1.push_back(cx_ptr); - auto tmp1_src_iter_memory = (*concat_iter_memory)[0]; + mkldnn::memory& tmp1_src_iter_memory = mkldnn_mems->concat_iter_memory[0]; ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, - srcs_data1, tmp1_src_iter_memory); + {{1, 1, 1, batch_size, hidden_size}, {1, 1, 1, batch_size, hidden_size}}, + {1, 1, nstates, batch_size, hidden_size}, mkldnn_dtype, 2, srcs_data1, + tmp1_src_iter_memory, &(mkldnn_mems->uni_states_memory)); std::vector srcs_data2; srcs_data2.push_back(hx_ptr + single_cell_size); srcs_data2.push_back(cx_ptr + single_cell_size); - auto tmp2_src_iter_memory = (*concat_iter_memory)[1]; + mkldnn::memory& tmp2_src_iter_memory = mkldnn_mems->concat_iter_memory[1]; ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, 2, - srcs_data2, tmp2_src_iter_memory); + {{1, 1, 1, batch_size, hidden_size}, {1, 1, 1, batch_size, hidden_size}}, + {1, 1, nstates, batch_size, hidden_size}, mkldnn_dtype, 2, srcs_data2, + tmp2_src_iter_memory, &(mkldnn_mems->uni_states_memory)); std::vector srcs_data3; srcs_data3.push_back(reinterpret_cast(tmp1_src_iter_memory.get_data_handle())); srcs_data3.push_back(reinterpret_cast(tmp2_src_iter_memory.get_data_handle())); ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, nstates, N, H}, {1, 1, nstates, N, H}}, {1, 2, nstates, N, H}, - mkldnn_dtype, 1, srcs_data3, user_src_iter_memory); + {{1, 1, nstates, batch_size, hidden_size}, {1, 1, nstates, batch_size, hidden_size}}, + {1, 2, nstates, batch_size, hidden_size}, mkldnn_dtype, 1, srcs_data3, + user_src_iter_memory, &(mkldnn_mems->concat_states_memory)); } else { user_src_iter_memory.set_data_handle(hx_ptr); } - (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); + mkldnn_mems->hcx_memory[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); rnn_cell::desc rnn_cell(nalgorithm, mode == rnn_enum::kRnnRelu ? 
algorithm::eltwise_relu : algorithm::eltwise_tanh); @@ -329,54 +343,60 @@ static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs, = rnn_forward::primitive_desc(layer_desc, cpu_engine); if (x_ptr && layer_index == 0) { - (*x_memory)[layer_index].set_data_handle(x_ptr); + mkldnn_mems->x_memory[layer_index].set_data_handle(x_ptr); } else { - (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); + mkldnn_mems->x_memory[layer_index].set_data_handle( + mkldnn_mems->user_src_layer_memory_l.get_data_handle()); } - (*y_memory)[layer_index].set_data_handle(y_ptr); - + mkldnn_mems->y_memory[layer_index].set_data_handle(y_ptr); if (rnn_forward_prim->size() <= (size_t)layer_index) { - primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], - (*hcx_memory)[layer_index], (*wx_memory)[layer_index], - (*wh_memory)[layer_index], (*bias_memory)[layer_index], - (*y_memory)[layer_index], - (*hcy_memory)[layer_index], null_memory_); + primitive rnn_prim = rnn_forward(prim_desc, mkldnn_mems->x_memory[layer_index], + mkldnn_mems->hcx_memory[layer_index], mkldnn_mems->wx_memory[layer_index], + mkldnn_mems->wh_memory[layer_index], mkldnn_mems->bias_memory[layer_index], + mkldnn_mems->y_memory[layer_index], + mkldnn_mems->hcy_memory[layer_index], null_memory_); rnn_forward_prim->push_back(rnn_prim); } MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); MKLDNNStream::Get()->Submit(); if (state_outputs) { - DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); + DType* dst_hcy = reinterpret_cast( + mkldnn_mems->hcy_memory[layer_index].get_data_handle()); if (mode == rnn_enum::kLstm) { - offset1 = nstates * single_cell_size; - offset2 = (nstates + 1) * single_cell_size; + size_t back_hy_offset = nstates * single_cell_size; + size_t back_cy_offset = (nstates + 1) * single_cell_size; #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < single_cell_size; n++) { + for (int n = 0; n < static_cast(single_cell_size); n++) { hy_ptr[n] = dst_hcy[n]; - hy_ptr[n + single_cell_size] = dst_hcy[n + offset1]; + hy_ptr[n + single_cell_size] = dst_hcy[n + back_hy_offset]; cy_ptr[n] = dst_hcy[n + single_cell_size]; - cy_ptr[n + single_cell_size] = dst_hcy[n + offset2]; + cy_ptr[n + single_cell_size] = dst_hcy[n + back_cy_offset]; } } else { #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < 2 * single_cell_size; n++) { + for (int n = 0; n < static_cast(2 * single_cell_size); n++) { hy_ptr[n] = dst_hcy[n]; } } } + if (has_adjusted) { + AdjustGruWeightGateOrder(wx, input_size, hidden_size); + AdjustGruWeightGateOrder(back_wx, input_size, hidden_size); + AdjustGruWeightGateOrder(wh, hidden_size, hidden_size); + AdjustGruWeightGateOrder(back_wh, hidden_size, hidden_size); + } } template -static void MKLDNNRNNForwardUnidi(bool state_outputs, - const int L, - const int T, - const int N, - const int I, - const int H, +static void MKLDNNRNNForwardUnidi(const bool state_outputs, + const int num_layer, + const int seq_len, + const int batch_size, + const int input_size, + const int hidden_size, DType* x_ptr, - mkldnn::memory *user_src_layer_memory, DType* hx_ptr, DType* cx_ptr, DType* w_ptr, @@ -384,15 +404,7 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs, DType* y_ptr, DType* hy_ptr, DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - 
std::vector *y_memory, - std::vector *hcy_memory, + MKLDNNRNNMemory *mkldnn_mems, std::vector *rnn_forward_prim, int layer_index, bool *has_cache, @@ -401,26 +413,26 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs, int mode) { int ngates = 0, nstates = 0; algorithm nalgorithm = GetMKLDNNRNNAlgo(mode, &ngates, &nstates); + const int nbias = (mode == rnn_enum::kGru ? ngates + 1 : ngates); mkldnn::memory::data_type mkldnn_dtype = get_mkldnn_type(dtype); - const int cell_size = N * H; - const int single_cell_size = N * H; - const int single_b_size = ngates * H; - int w_size = (I + H) * H * ngates; + const size_t cell_size = batch_size * hidden_size; + const size_t single_cell_size = batch_size * hidden_size; + const size_t single_b_size = nbias * hidden_size; + const size_t w_size = (input_size + hidden_size) * hidden_size * ngates; const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); auto cpu_engine = CpuEngine::Get()->get_engine(); auto null_memory_ = null_memory(cpu_engine); - int offset1 = 0, offset2 = 0; bool initialized = *has_cache; - mkldnn::memory::dims src_layer_tz = {T, N, I}; - mkldnn::memory::dims dst_layer_tz = {T, N, H}; - mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo - mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; - mkldnn::memory::dims src_iter_tz = {L, 1, nstates, N, H}; // ldsnc - mkldnn::memory::dims dst_iter_tz = {L, 1, nstates, N, H}; // ldsnc - mkldnn::memory::dims weights_layer_r_tz = {1, 1, I, ngates, H}; // ldigo for reorder - mkldnn::memory::dims weights_iter_r_tz = {1, 1, H, ngates, H}; // ldigo for reorder + mkldnn::memory::dims src_layer_tz = {seq_len, batch_size, input_size}; + mkldnn::memory::dims dst_layer_tz = {seq_len, batch_size, hidden_size}; + mkldnn::memory::dims weights_layer_tz = {num_layer, 1, input_size, ngates, hidden_size}; // ldigo + mkldnn::memory::dims weights_iter_tz = {num_layer, 1, hidden_size, ngates, hidden_size}; // ldigo + mkldnn::memory::dims bias_tz = {num_layer, 1, nbias, hidden_size}; // ldgo + mkldnn::memory::dims src_iter_tz = {num_layer, 1, nstates, batch_size, hidden_size}; // ldsnc + mkldnn::memory::dims dst_iter_tz = {num_layer, 1, nstates, batch_size, hidden_size}; // ldsnc + mkldnn::memory::dims weights_layer_r_tz = {1, 1, input_size, ngates, hidden_size}; + mkldnn::memory::dims weights_iter_r_tz = {1, 1, hidden_size, ngates, hidden_size}; auto weight_layer_md = mkldnn::memory::desc( { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); @@ -437,17 +449,18 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs, auto dst_iter_md = mkldnn::memory::desc( {dst_iter_tz}, mkldnn_dtype, mkldnn::memory::format::ldsnc); - for (int l = 0; l < L; l++) { + for (int l = 0; l < num_layer; l++) { if (mode == rnn_enum::kLstm) { std::vector srcs_data; srcs_data.push_back(hx_ptr); srcs_data.push_back(cx_ptr); - auto tmp_src_iter_memory = (*concat_iter_memory)[l + layer_index]; + mkldnn::memory& tmp_src_iter_memory = mkldnn_mems->concat_iter_memory[l + layer_index]; ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, - {{1, 1, 1, N, H}, {1, 1, 1, N, H}}, {1, 1, nstates, N, H}, mkldnn_dtype, - 2, srcs_data, tmp_src_iter_memory); + {{1, 1, 1, batch_size, hidden_size}, {1, 1, 1, batch_size, hidden_size}}, + {1, 1, nstates, batch_size, hidden_size}, mkldnn_dtype, 2, srcs_data, + tmp_src_iter_memory, &(mkldnn_mems->uni_states_memory)); } else { - (*concat_iter_memory)[l + 
layer_index].set_data_handle(hx_ptr); + mkldnn_mems->concat_iter_memory[l + layer_index].set_data_handle(hx_ptr); } hx_ptr += cell_size; if (mode == rnn_enum::kLstm) { @@ -455,73 +468,107 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs, } } - auto user_src_iter_memory = null_memory_; - if (L == 1) { - user_src_iter_memory = (*concat_iter_memory)[layer_index]; + mkldnn::memory* user_src_iter_memory; + if (num_layer == 1) { + user_src_iter_memory = &(mkldnn_mems->concat_iter_memory[layer_index]); } else { - user_src_iter_memory = (*concat_iter_memory)[L + layer_index]; + user_src_iter_memory = &(mkldnn_mems->concat_iter_memory[num_layer + layer_index]); std::vector src_l_data; std::vector src_l_dim; - for (int l = 0; l < L; l++) { + for (int l = 0; l < num_layer; l++) { src_l_data.push_back(reinterpret_cast - ((*concat_iter_memory)[l + layer_index].get_data_handle())); - src_l_dim.push_back({1, 1, nstates, N, H}); + (mkldnn_mems->concat_iter_memory[l + layer_index].get_data_handle())); + src_l_dim.push_back({1, 1, nstates, batch_size, hidden_size}); } ConcatData(mkldnn::memory::format::ldsnc, mkldnn::memory::format::ldsnc, src_l_dim, - {L, 1, nstates, N, H}, mkldnn_dtype, 0, src_l_data, user_src_iter_memory); + {num_layer, 1, nstates, batch_size, hidden_size}, mkldnn_dtype, 0, src_l_data, + *user_src_iter_memory, &(mkldnn_mems->concat_states_memory)); } - (*hcx_memory)[layer_index].set_data_handle(user_src_iter_memory.get_data_handle()); + mkldnn_mems->hcx_memory[layer_index].set_data_handle(user_src_iter_memory->get_data_handle()); - auto src_wx_f = (*concat_weight_memory)[2 * layer_index]; - auto src_wh_f = (*concat_weight_memory)[2 * layer_index + 1]; + mkldnn::memory& src_wx_f = mkldnn_mems->concat_weight_memory[2 * layer_index]; + mkldnn::memory& src_wh_f = mkldnn_mems->concat_weight_memory[2 * layer_index + 1]; std::vector srcs_data_x; std::vector srcs_data_h; std::vector src_l_dim_x; std::vector src_l_dim_h; - if (!initialized) { - if (L == 1) { + + bool has_adjusted = false; + if (!initialized || is_train) { + if (num_layer == 1) { DType* wx = w_ptr; - DType* wh = w_ptr + I * H * ngates; + DType* wh = wx + input_size * hidden_size * ngates; if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruBiasGateOrder(b_ptr, H); - AdjustGruBiasGateOrder(b_ptr + H * ngates, H); + AdjustGruWeightGateOrder(wx, input_size, hidden_size); + AdjustGruWeightGateOrder(wh, hidden_size, hidden_size); + has_adjusted = true; } src_wx_f.set_data_handle(wx); src_wh_f.set_data_handle(wh); } else { - for (int l = 0; l < L; l++) { - DType* wx = w_ptr; - DType* wh = w_ptr + I * H * ngates; - DType* bx = b_ptr + l * ngates * H * 2; - DType* bh = b_ptr + l * ngates * H * 2 + H * ngates; + for (int l = 0; l < num_layer; l++) { + DType* wx = w_ptr + l * w_size; + DType* wh = wx + input_size * hidden_size * ngates; if (mode == rnn_enum::kGru) { - AdjustGruWeightGateOrder(wx, I, H); - AdjustGruWeightGateOrder(wh, H, H); - AdjustGruBiasGateOrder(bx, H); - AdjustGruBiasGateOrder(bh, H); + AdjustGruWeightGateOrder(wx, input_size, hidden_size); + AdjustGruWeightGateOrder(wh, hidden_size, hidden_size); + has_adjusted = true; } srcs_data_x.push_back(wx); srcs_data_h.push_back(wh); src_l_dim_x.push_back(weights_layer_r_tz); src_l_dim_h.push_back(weights_iter_r_tz); - w_ptr = w_ptr + w_size; } ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - src_l_dim_x, weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f); + src_l_dim_x, 
weights_layer_tz, mkldnn_dtype, 0, srcs_data_x, src_wx_f, + &(mkldnn_mems->weight_layer_mems)); ConcatData(mkldnn::memory::format::ldgoi, mkldnn::memory::format::ldgoi, - src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f); + src_l_dim_h, weights_iter_tz, mkldnn_dtype, 0, srcs_data_h, src_wh_f, + &(mkldnn_mems->weight_iter_mems)); } - MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, (*wx_memory)[layer_index])); - MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, (*wh_memory)[layer_index])); - - DType* user_bias_f = reinterpret_cast ((*bias_memory)[layer_index].get_data_handle()); - #pragma omp parallel for num_threads(omp_threads) - for (int j = 0; j < L * single_b_size; j++) { - int k = j / single_b_size; - user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size]; + MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, mkldnn_mems->wx_memory[layer_index])); + MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, mkldnn_mems->wh_memory[layer_index])); + + DType* user_bias_f = reinterpret_cast( + mkldnn_mems->bias_memory[layer_index].get_data_handle()); + if (mode == rnn_enum::kGru) { + const size_t mx_single_b_sz = ngates * hidden_size; + //* NOTES: According to the instructions from https://bit.ly/2yMp8Cd, the collapse() + // directive is only supported in OpenMP 3.0 and higher. OpenMP 3.0 was released in + // May 2008 (hence the version number). + #if _OPENMP >= 200805 + # pragma omp parallel for num_threads(omp_threads) collapse(2) + #else + # pragma omp parallel for num_threads(omp_threads) + #endif + for (int l = 0; l < num_layer; l++) { + for (int g = 0; g < hidden_size; g++) { + // While mxnet gru gate order is reset, update and new gates, + // mkldnn gru gate order is update, reset and new gates. So + // we need to swap the order of reset and update from mxnet. 
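+ // Per layer, b_ptr holds bx then bh, each laid out as [reset | update | new] * hidden_size, + // while user_bias_f is packed as [update | reset | new_x | new_h] * hidden_size: the update + // and reset biases are summed, and the two "new"-gate biases are kept separate because the + // linear-before-reset GRU adds the hidden-side bias before applying the reset gate.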
+ user_bias_f[g + hidden_size + l * single_b_size] = + b_ptr[g + l * mx_single_b_sz * 2] + + b_ptr[g + l * mx_single_b_sz * 2 + mx_single_b_sz]; + + user_bias_f[g + l * single_b_size] = + b_ptr[g + hidden_size + l * mx_single_b_sz * 2] + + b_ptr[g + hidden_size + l * mx_single_b_sz * 2 + mx_single_b_sz]; + + user_bias_f[g + l * single_b_size + 2 * hidden_size] = + b_ptr[g + l * mx_single_b_sz * 2 + 2 * hidden_size]; + user_bias_f[g + l * single_b_size + 3 * hidden_size] = + b_ptr[g + 2 * hidden_size + l * mx_single_b_sz * 2 + mx_single_b_sz]; + } + } + } else { + const size_t b_size = num_layer * single_b_size; + #pragma omp parallel for num_threads(omp_threads) + for (int j = 0; j < static_cast(b_size); j++) { + int k = j / single_b_size; + user_bias_f[j] = b_ptr[j + k * single_b_size] + + b_ptr[j + k * single_b_size + single_b_size]; + } } } @@ -537,52 +584,69 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs, = rnn_forward::primitive_desc(layer_desc, cpu_engine); if (x_ptr && layer_index == 0) { - (*x_memory)[layer_index].set_data_handle(x_ptr); + mkldnn_mems->x_memory[layer_index].set_data_handle(x_ptr); } else { - (*x_memory)[layer_index].set_data_handle((*user_src_layer_memory).get_data_handle()); + mkldnn_mems->x_memory[layer_index].set_data_handle( + mkldnn_mems->user_src_layer_memory_l.get_data_handle()); } - (*y_memory)[layer_index].set_data_handle(y_ptr); - + mkldnn_mems->y_memory[layer_index].set_data_handle(y_ptr); if (rnn_forward_prim->size() <= (size_t)layer_index) { - primitive rnn_prim = rnn_forward(prim_desc, (*x_memory)[layer_index], - (*hcx_memory)[layer_index], (*wx_memory)[layer_index], - (*wh_memory)[layer_index], (*bias_memory)[layer_index], - (*y_memory)[layer_index], - (*hcy_memory)[layer_index], null_memory_); + primitive rnn_prim = rnn_forward(prim_desc, mkldnn_mems->x_memory[layer_index], + mkldnn_mems->hcx_memory[layer_index], mkldnn_mems->wx_memory[layer_index], + mkldnn_mems->wh_memory[layer_index], mkldnn_mems->bias_memory[layer_index], + mkldnn_mems->y_memory[layer_index], + mkldnn_mems->hcy_memory[layer_index], null_memory_); rnn_forward_prim->push_back(rnn_prim); } MKLDNNStream::Get()->RegisterPrim((*rnn_forward_prim)[layer_index]); MKLDNNStream::Get()->Submit(); if (state_outputs) { - DType* dst_hcy = reinterpret_cast ((*hcy_memory)[layer_index].get_data_handle()); + DType* dst_hcy = reinterpret_cast( + mkldnn_mems->hcy_memory[layer_index].get_data_handle()); if (mode == rnn_enum::kLstm) { - for (int l = 0; l < L; l++) { - offset1 = l * single_cell_size; - offset2 = l * nstates * single_cell_size; - #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < single_cell_size; n++) { - hy_ptr[offset1 + n] = dst_hcy[offset2 + n]; - cy_ptr[offset1 + n] = dst_hcy[offset2 + n + single_cell_size]; + //* NOTES: According to the instructions from https://bit.ly/2yMp8Cd, the collapse() + // directive is only supported in OpenMP 3.0 and higher. OpenMP 3.0 was released in + // May 2008 (hence the version number). 
+ #if _OPENMP >= 200805 + # pragma omp parallel for num_threads(omp_threads) collapse(2) + #else + # pragma omp parallel for num_threads(omp_threads) + #endif + for (int l = 0; l < num_layer; l++) { + for (int n = 0; n < static_cast(single_cell_size); n++) { + const size_t single_state_offset = l * single_cell_size; + const size_t concat_state_offset = l * nstates * single_cell_size; + hy_ptr[single_state_offset + n] = dst_hcy[concat_state_offset + n]; + cy_ptr[single_state_offset + n] = dst_hcy[concat_state_offset + n + single_cell_size]; } } } else { + const size_t cell_size = num_layer * single_cell_size; #pragma omp parallel for num_threads(omp_threads) - for (int n = 0; n < L * single_cell_size; n++) { + for (int n = 0; n < static_cast(cell_size); n++) { hy_ptr[n] = dst_hcy[n]; } } } + if (has_adjusted) { + for (int l = 0; l < num_layer; l++) { + DType* wx = w_ptr + l * w_size; + DType* wh = wx + input_size * hidden_size * ngates; + AdjustGruWeightGateOrder(wx, input_size, hidden_size); + AdjustGruWeightGateOrder(wh, hidden_size, hidden_size); + } + } } template -static void MKLDNNRNNForward(bool state_outputs, - const int L, - const int D, - const int T, - const int N, - const int I, - const int H, +static void MKLDNNRNNForward(const bool state_outputs, + const int num_layer, + const int direction, + const int seq_len, + const int batch_size, + const int input_size, + const int hidden_size, DType* x_ptr, DType* hx_ptr, DType* cx_ptr, @@ -591,15 +655,7 @@ static void MKLDNNRNNForward(bool state_outputs, DType* y_ptr, DType* hy_ptr, DType* cy_ptr, - std::vector *concat_weight_memory, - std::vector *concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, + MKLDNNRNNMemory *mkldnn_mems, std::vector *rnn_forward_prim, bool *has_cache, int dtype, @@ -607,44 +663,35 @@ static void MKLDNNRNNForward(bool state_outputs, int mode) { int ngates = 0, nstates = 0; GetMKLDNNRNNAlgo(mode, &ngates, &nstates); - const int b_size = 2 * H * ngates * D; - const int cell_size = N * H * D; + const int b_size = 2 * hidden_size * ngates * direction; + const int cell_size = batch_size * hidden_size * direction; // First layer - int w_size = (I + H) * H * ngates * D; - auto cpu_engine = CpuEngine::Get()->get_engine(); - auto null_memory_ = null_memory(cpu_engine); + int w_size = (input_size + hidden_size) * hidden_size * ngates * direction; DType* tmpNull = NULL; - // when D = 1 and I == H, L layers can be fused together - if (D == 1 && I == H) { - MKLDNNRNNForwardUnidi(state_outputs, L, T, N, I, H, x_ptr, &null_memory_, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, dtype, is_train, mode); + // when direction = 1 and input_size == hidden_size, num_layer layers can be fused together + if (direction == 1 && input_size == hidden_size) { + MKLDNNRNNForwardUnidi(state_outputs, num_layer, seq_len, batch_size, input_size, + hidden_size, x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, + mkldnn_mems, rnn_forward_prim, 0, has_cache, dtype, is_train, mode); } else { - auto user_src_layer_memory_l = null_memory_; - if (D == 2) { - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, I, H, x_ptr, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - 
concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, - 0, has_cache, 0, dtype, is_train, mode); + if (direction == 2) { + MKLDNNRNNForwardSingleLayerBi(state_outputs, seq_len, batch_size, input_size, + hidden_size, x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, + mkldnn_mems, rnn_forward_prim, 0, has_cache, dtype, is_train, mode); } else { - MKLDNNRNNForwardUnidi(state_outputs, 1, T, N, I, H, x_ptr, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, + MKLDNNRNNForwardUnidi(state_outputs, 1, seq_len, batch_size, input_size, hidden_size, x_ptr, + hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, mkldnn_mems, rnn_forward_prim, 0, has_cache, dtype, is_train, mode); } - if (L > 1) { - user_src_layer_memory_l = (*y_memory)[0]; - // go to next L - 1 layers. - // If D = 2, do it layer by layer. If D = 1, fused L - 1 layers + if (num_layer > 1) { + mkldnn_mems->user_src_layer_memory_l = mkldnn_mems->y_memory[0]; + // go to next num_layer - 1 layers. + // If direction = 2, do it layer by layer. If direction = 1, fused num_layer - 1 layers w_ptr += w_size; b_ptr += b_size; - if (D == 2) { - w_size = (H * D + H) * H * ngates * D; - for (int l = 0; l < L - 1; l++) { + if (direction == 2) { + w_size = (hidden_size * direction + hidden_size) * hidden_size * ngates * direction; + for (int l = 0; l < num_layer - 1; l++) { if (state_outputs) { hy_ptr += cell_size; if (mode == rnn_enum::kLstm) { @@ -655,30 +702,27 @@ static void MKLDNNRNNForward(bool state_outputs, if (mode == rnn_enum::kLstm) { cx_ptr += cell_size; } - MKLDNNRNNForwardSingleLayerBi(state_outputs, T, N, D * H, H, tmpNull, - &user_src_layer_memory_l, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, - cy_ptr, concat_weight_memory, concat_iter_memory, x_memory, - hcx_memory, wx_memory, wh_memory, bias_memory, - y_memory, hcy_memory, rnn_forward_prim, - 1, has_cache, l + 1, dtype, is_train, mode); - user_src_layer_memory_l = (*y_memory)[1]; + MKLDNNRNNForwardSingleLayerBi(state_outputs, seq_len, batch_size, + direction * hidden_size, hidden_size, tmpNull, hx_ptr, cx_ptr, w_ptr, b_ptr, + y_ptr, hy_ptr, cy_ptr, mkldnn_mems, rnn_forward_prim, 1, has_cache, dtype, + is_train, mode); + mkldnn_mems->user_src_layer_memory_l = mkldnn_mems->y_memory[1]; w_ptr += w_size; b_ptr += b_size; } } - if (D == 1) { + if (direction == 1) { if (state_outputs) { hy_ptr += cell_size; if (mode == rnn_enum::kLstm) { cy_ptr += cell_size; } } - w_size = (H + H) * H * ngates; - MKLDNNRNNForwardUnidi(state_outputs, L - 1, T, N, H, H, tmpNull, &user_src_layer_memory_l, - hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, concat_weight_memory, - concat_iter_memory, x_memory, hcx_memory, wx_memory, - wh_memory, bias_memory, y_memory, hcy_memory, - rnn_forward_prim, 1, has_cache, dtype, is_train, mode); + w_size = (hidden_size + hidden_size) * hidden_size * ngates; + MKLDNNRNNForwardUnidi(state_outputs, num_layer - 1, seq_len, batch_size, + hidden_size, hidden_size, tmpNull, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, + hy_ptr, cy_ptr, mkldnn_mems, rnn_forward_prim, 1, has_cache, dtype, + is_train, mode); } } } @@ -686,7 +730,7 @@ static void MKLDNNRNNForward(bool state_outputs, } template -static void MKLDNNRNNForwardInference(bool state_outputs, +static void MKLDNNRNNForwardInference(const bool state_outputs, const int num_layers, 
const int direction, const int seq_length, @@ -701,15 +745,7 @@ static void MKLDNNRNNForwardInference(bool state_outputs, DType* y_ptr, DType* hy_ptr, DType* cy_ptr, - std::vector* concat_weight_memory, - std::vector* concat_iter_memory, - std::vector *x_memory, - std::vector *hcx_memory, - std::vector *wx_memory, - std::vector *wh_memory, - std::vector *bias_memory, - std::vector *y_memory, - std::vector *hcy_memory, + MKLDNNRNNMemory *mkldnn_mems, std::vector *rnn_forward_prim, bool *has_cache, int dtype, @@ -723,9 +759,7 @@ static void MKLDNNRNNForwardInference(bool state_outputs, MKLDNNRNNForward(state_outputs, num_layers, direction, seq_length, batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr, - concat_weight_memory, concat_iter_memory, x_memory, - hcx_memory, wx_memory, wh_memory, - bias_memory, y_memory, hcy_memory, rnn_forward_prim, + mkldnn_mems, rnn_forward_prim, has_cache, dtype, is_train, mode); break; default: diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index 328e28de8537..c3c22ef463bd 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -39,8 +39,8 @@ #include #include #include +#include -#include "./math.h" #include "./math_functions-inl.h" #include "./operator_common.h" #include "./rnn_impl.h" @@ -396,22 +396,16 @@ class RNNOp { public: RNNParam param_; Context ctx_; + #if MXNET_USE_MKLDNN == 1 - std::vector concat_weight_memory; - std::vector concat_iter_memory; - std::vector rnn_forward_prim; - std::vector x_memory; - std::vector hcx_memory; - std::vector wx_memory; - std::vector wh_memory; - std::vector bias_memory; - std::vector y_memory; - std::vector hcy_memory; bool has_cache; bool init_mem_; size_t reserve_mem_size_; - Storage::Handle mem_space_; + std::shared_ptr > mem_space_; + MKLDNNRNNMemory mkldnn_mems; + std::vector rnn_forward_prim; #endif + explicit RNNOp(RNNParam param, Context ctx) { this->param_ = param; this->ctx_ = ctx; @@ -522,7 +516,6 @@ class RNNOp { ~RNNOp() { #if MXNET_USE_MKLDNN == 1 if (init_mem_) { - Storage::Get()->Free(mem_space_); init_mem_ = false; } #endif @@ -908,9 +901,7 @@ class RNNOp { param_.mode); } else { #if MXNET_USE_MKLDNN == 1 - if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1) && param_.mode != rnn_enum::kGru) { - // TODO(zixuanweeei): MKLDNN GRU has precision issue. A stable one - // will be added to MXNet when we figure out the issue. 
+ if (dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) { int dtype = in_data[rnn_enum::kData].type_flag_; MKLDNNRNNForwardInference(param_.state_outputs, param_.num_layers, @@ -927,15 +918,7 @@ class RNNOp { y.dptr_, hy_ptr, cy_ptr, - &concat_weight_memory, - &concat_iter_memory, - &x_memory, - &hcx_memory, - &wx_memory, - &wh_memory, - &bias_memory, - &y_memory, - &hcy_memory, + &mkldnn_mems, &rnn_forward_prim, &has_cache, dtype, @@ -943,8 +926,6 @@ class RNNOp { param_.mode); } else { #endif - // Before integrating MKLDNN GRU fp32 inference - // using below code for keep func being OK const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size, direction, param_.mode); @@ -1213,6 +1194,41 @@ class RNNOp { } private: + bool init_space_, temp_init_space_; + size_t reserve_cpu_space_size_, temp_cpu_space_size_; + Storage::Handle reserve_cpu_space_, temp_cpu_space_; + + #if MXNET_USE_CUDNN_RNN + cudnnDataType_t dtype_; + bool init_cudnn_; + cudnnRNNDescriptor_t rnn_desc_; + cudnnRNNMode_t mode_; + cudnnDirectionMode_t direction_; + cudnnRNNInputMode_t input_mode_; + cudnnDropoutDescriptor_t dropout_desc_; + Storage::Handle reserve_space_; + uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) + size_t workspace_byte_, reserve_space_byte_; + int workspace_size_; + std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_; + #if MXNET_USE_CUDNN_GE_7200 + cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; + DType padding_fill_ = 0; + #endif + cudnnTensorDescriptor_t hx_desc_, cx_desc_; + cudnnTensorDescriptor_t hy_desc_, cy_desc_; + cudnnTensorDescriptor_t dhx_desc_, dcx_desc_; + cudnnTensorDescriptor_t dhy_desc_, dcy_desc_; + + cudnnFilterDescriptor_t w_desc_, dw_desc_; + // Allow TensorCore algo policy + bool cudnn_tensor_core_; + + #if CUDNN_MAJOR >= 5 + cudnnTensorFormat_t format_; + #endif + #endif + inline void Init(const OpContext &ctx, mshadow::Stream *s, const std::vector &in_data, @@ -1539,39 +1555,6 @@ class RNNOp { } #endif } - #if MXNET_USE_CUDNN_RNN - cudnnDataType_t dtype_; - bool init_cudnn_; - cudnnRNNDescriptor_t rnn_desc_; - cudnnRNNMode_t mode_; - cudnnDirectionMode_t direction_; - cudnnRNNInputMode_t input_mode_; - cudnnDropoutDescriptor_t dropout_desc_; - Storage::Handle reserve_space_; - uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) - size_t workspace_byte_, reserve_space_byte_; - int workspace_size_; - std::vector x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_; - #if MXNET_USE_CUDNN_GE_7200 - cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_; - DType padding_fill_ = 0; - #endif - cudnnTensorDescriptor_t hx_desc_, cx_desc_; - cudnnTensorDescriptor_t hy_desc_, cy_desc_; - cudnnTensorDescriptor_t dhx_desc_, dcx_desc_; - cudnnTensorDescriptor_t dhy_desc_, dcy_desc_; - - cudnnFilterDescriptor_t w_desc_, dw_desc_; - // Allow TensorCore algo policy - bool cudnn_tensor_core_; - - #if CUDNN_MAJOR >= 5 - cudnnTensorFormat_t format_; - #endif - #endif - bool init_space_, temp_init_space_; - size_t reserve_cpu_space_size_, temp_cpu_space_size_; - Storage::Handle reserve_cpu_space_, temp_cpu_space_; }; // class RNNOp static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs, diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 244e39335a91..1d05598cbd59 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -170,10 +170,9 @@ static bool RNNType(const nnvm::NodeAttrs& attrs, static std::vector 
RNNResourceEx(const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) { std::vector request; + request.emplace_back(ResourceRequest::kTempSpace); if (dev_mask == kGPU) { #if MXNET_USE_CUDNN_RNN - request.emplace_back(ResourceRequest::kTempSpace); - const RNNParam& param = nnvm::get(attrs.parsed); if (param.p != 0 && 1.0f - param.p > 0) { request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); @@ -260,47 +259,44 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, const RNNParam& param = op.param_; int ngates = 0, nstates = 0; GetMKLDNNRNNAlgo(param.mode, &ngates, &nstates); - int D = param.bidirectional ? 2 : 1; + const int D = param.bidirectional ? 2 : 1; Tensor x = in_blobs[rnn_enum::kData].get(s); - int T = x.shape_[0]; - int N = x.shape_[1]; - int I = x.shape_[2]; - int H = param.state_size; - int L = param.num_layers; + const int T = x.shape_[0]; + const int N = x.shape_[1]; + const int I = x.shape_[2]; + const int H = param.state_size; + const int L = param.num_layers; + const int nbias = param.mode == rnn_enum::kGru ? ngates + 1 : ngates; const size_t r_size = GetMKLDNNRNNCacheMemorySize(L, D, T, N, I, H, param.mode); - if (op.init_mem_ && op.reserve_mem_size_ < r_size) { - Storage::Get()->Free(op.mem_space_); - op.init_mem_ = false; - } - if (!op.init_mem_) { - op.mem_space_ = Storage::Get()->Alloc( - r_size * sizeof(DType), - Context::CPU()); + if (!op.init_mem_ || op.reserve_mem_size_ < r_size) { + op.mem_space_ = std::make_shared >( + ctx.requested[rnn_enum::kTempSpace].get_space_typed( + Shape1(r_size), s)); op.reserve_mem_size_ = r_size; op.init_mem_ = true; op.has_cache = false; } - if (op.has_cache && op.x_memory.size() == 0) { + if (op.has_cache && op.mkldnn_mems.x_memory.size() == 0) { op.has_cache = false; } - DType* workptr = static_cast(op.mem_space_.dptr); + DType* workptr = static_cast(op.mem_space_->dptr_); mkldnn::memory::dims src_layer_tz_0 = {T, N, I}; mkldnn::memory::dims src_layer_tz = {T, N, D * H}; mkldnn::memory::dims dst_layer_tz = {T, N, D * H}; auto dst_layer_md = mkldnn::memory::desc( { dst_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); - if (op.x_memory.size() == 0) { + if (op.mkldnn_mems.x_memory.size() == 0) { if (D == 1 && I == H) { auto user_src_layer_md = mkldnn::memory::desc( { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc); auto user_src_layer_memory_n = mkldnn::memory({ user_src_layer_md, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_n); + op.mkldnn_mems.x_memory.push_back(user_src_layer_memory_n); mkldnn::memory::dims weights_layer_tz = {L, 1, I, ngates, H}; // ldigo mkldnn::memory::dims weights_iter_tz = {L, 1, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz = {L, 1, ngates, H}; + mkldnn::memory::dims bias_tz = {L, 1, nbias, H}; auto user_weight_layer_md = mkldnn::memory::desc( { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo); auto user_weight_iter_md = mkldnn::memory::desc( @@ -310,21 +306,22 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* weight_layer_n = workptr; // L * I * ngates * H auto user_weight_layer_memory_n = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n); - op.wx_memory.push_back(user_weight_layer_memory_n); + op.mkldnn_mems.wx_memory.push_back(user_weight_layer_memory_n); DType* weight_iter_n = weight_layer_n + L * I * ngates * H; // L * H * ngates * H auto user_weight_iter_memory_n = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n); - 
op.wh_memory.push_back(user_weight_iter_memory_n); + op.mkldnn_mems.wh_memory.push_back(user_weight_iter_memory_n); - DType* bias_n = weight_iter_n + L * H * ngates * H; // L * ngates * H + DType* bias_n = weight_iter_n + L * H * ngates * H; // Generally, L * ngates * H + // LBR-Gru, L * (ngates + 1) * H auto user_bias_memory_n = mkldnn::memory({ user_bias_md, cpu_engine }, bias_n); - op.bias_memory.push_back(user_bias_memory_n); + op.mkldnn_mems.bias_memory.push_back(user_bias_memory_n); auto wx_md_n = mkldnn::memory::desc( { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi); - DType* wx_n = bias_n + L * ngates * H; // L * ngates * I * H + DType* wx_n = bias_n + L * nbias * H; // L * ngates * I * H auto wx_memory_n = mkldnn::memory({ wx_md_n, cpu_engine }, wx_n); DType* wh_n = wx_n + L * ngates * I * H; // L * ngates * H * H @@ -333,8 +330,8 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, auto wh_memory_n = mkldnn::memory({ wh_md_n, cpu_engine }, wh_n); - op.concat_weight_memory.push_back(wx_memory_n); - op.concat_weight_memory.push_back(wh_memory_n); + op.mkldnn_mems.concat_weight_memory.push_back(wx_memory_n); + op.mkldnn_mems.concat_weight_memory.push_back(wh_memory_n); workptr = wh_n + L * ngates * H * H; mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H}; // ldsnc @@ -344,7 +341,7 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* src_iter_n1 = workptr; // nstates * N * H auto src_iter_memory_n1 = mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1); - op.concat_iter_memory.push_back(src_iter_memory_n1); + op.mkldnn_mems.concat_iter_memory.push_back(src_iter_memory_n1); workptr = src_iter_n1 + nstates * N * H; } mkldnn::memory::dims src_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc @@ -353,12 +350,12 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* src_iter_n = workptr; // L * nstates * N * H auto src_iter_memory_n = mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n); - op.concat_iter_memory.push_back(src_iter_memory_n); - op.hcx_memory.push_back(src_iter_memory_n); + op.mkldnn_mems.concat_iter_memory.push_back(src_iter_memory_n); + op.mkldnn_mems.hcx_memory.push_back(src_iter_memory_n); DType* dst_layer_n = src_iter_n + L * nstates * N * H; // T * N * D * H auto dst_layer_memory_n = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n); - op.y_memory.push_back(dst_layer_memory_n); + op.mkldnn_mems.y_memory.push_back(dst_layer_memory_n); mkldnn::memory::dims dst_iter_tz_n = {L, 1, nstates, N, H}; // ldsnc auto dst_iter_md_n = mkldnn::memory::desc( @@ -366,18 +363,18 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* dst_iter_n = dst_layer_n + T * N * D * H; // L * nstates * N * H auto dst_iter_memory_n = mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n); - op.hcy_memory.push_back(dst_iter_memory_n); + op.mkldnn_mems.hcy_memory.push_back(dst_iter_memory_n); workptr = dst_iter_n + L * nstates * N * H; } else { auto user_src_layer_md_0 = mkldnn::memory::desc( { src_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::tnc); auto user_src_layer_memory_0 = mkldnn::memory({ user_src_layer_md_0, cpu_engine }); - op.x_memory.push_back(user_src_layer_memory_0); + op.mkldnn_mems.x_memory.push_back(user_src_layer_memory_0); mkldnn::memory::dims weights_layer_tz_0 = {1, D, I, ngates, H}; // ldigo mkldnn::memory::dims weights_iter_tz_0 = {1, D, H, ngates, H}; // ldigo - mkldnn::memory::dims bias_tz_0 = {1, D, ngates, H}; + mkldnn::memory::dims bias_tz_0 = {1, D, 
nbias, H}; auto user_weight_layer_md_0 = mkldnn::memory::desc( { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldigo); auto user_weight_iter_md_0 = mkldnn::memory::desc( @@ -388,18 +385,19 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* weight_layer_0 = workptr; // D * I * ngates * H auto user_weight_layer_memory_0 = mkldnn::memory({ user_weight_layer_md_0, cpu_engine }, weight_layer_0); - op.wx_memory.push_back(user_weight_layer_memory_0); + op.mkldnn_mems.wx_memory.push_back(user_weight_layer_memory_0); DType* weight_iter_0 = weight_layer_0 + D * I * ngates * H; // D * H * ngates * H auto user_weight_iter_memory_0 = mkldnn::memory({ user_weight_iter_md_0, cpu_engine }, weight_iter_0); - op.wh_memory.push_back(user_weight_iter_memory_0); + op.mkldnn_mems.wh_memory.push_back(user_weight_iter_memory_0); - DType* bias_0 = weight_iter_0 + D * H * ngates * H; // D * ngates * H + DType* bias_0 = weight_iter_0 + D * H * ngates * H; // Generally, D * ngates * H + // LBR-Gru, D * (ngates + 1) * H auto user_bias_memory_0 = mkldnn::memory({ user_bias_md_0, cpu_engine }, bias_0); - op.bias_memory.push_back(user_bias_memory_0); - workptr = bias_0 + D * ngates * H; + op.mkldnn_mems.bias_memory.push_back(user_bias_memory_0); + workptr = bias_0 + D * nbias * H; auto wx_md_0 = mkldnn::memory::desc( { weights_layer_tz_0 }, mkldnn_dtype, mkldnn::memory::format::ldgoi); @@ -416,8 +414,8 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, wh_memory_0.set_data_handle(wh_0); workptr = wh_0 + D * ngates * H * H; } - op.concat_weight_memory.push_back(wx_memory_0); - op.concat_weight_memory.push_back(wh_memory_0); + op.mkldnn_mems.concat_weight_memory.push_back(wx_memory_0); + op.mkldnn_mems.concat_weight_memory.push_back(wh_memory_0); mkldnn::memory::dims src_iter_undi_tz_0 = {1, 1, nstates, N, H}; // ldsnc auto src_iter_undi_md_0 = mkldnn::memory::desc( @@ -425,15 +423,15 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* src_iter_undi_0 = workptr; // nstates * N * H auto src_iter_undi_memory_0 = mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi_0); - op.concat_iter_memory.push_back(src_iter_undi_memory_0); + op.mkldnn_mems.concat_iter_memory.push_back(src_iter_undi_memory_0); workptr = src_iter_undi_0 + nstates * N * H; if (D == 1) { - op.hcx_memory.push_back(src_iter_undi_memory_0); + op.mkldnn_mems.hcx_memory.push_back(src_iter_undi_memory_0); } else { DType* src_iter_undi2_0 = workptr; // nstates * N * H auto src_iter_undi2_memory_0 = mkldnn::memory({ src_iter_undi_md_0, cpu_engine }, src_iter_undi2_0); - op.concat_iter_memory.push_back(src_iter_undi2_memory_0); + op.mkldnn_mems.concat_iter_memory.push_back(src_iter_undi2_memory_0); mkldnn::memory::dims src_iter_tz_0 = {1, D, nstates, N, H}; // ldsnc auto src_iter_md_0 = mkldnn::memory::desc( @@ -441,15 +439,15 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr, DType* src_iter_0 = src_iter_undi2_0 + nstates * N * H; // D * nstates * N * H auto src_iter_memory_0 = mkldnn::memory({ src_iter_md_0, cpu_engine }, src_iter_0); - op.concat_iter_memory.push_back(src_iter_memory_0); - op.hcx_memory.push_back(src_iter_memory_0); + op.mkldnn_mems.concat_iter_memory.push_back(src_iter_memory_0); + op.mkldnn_mems.hcx_memory.push_back(src_iter_memory_0); workptr = src_iter_0 + D * nstates * N * H; } DType* dst_layer_0 = workptr; // T * N * D * H auto dst_layer_memory_0 = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_0); - 
-      op.y_memory.push_back(dst_layer_memory_0);
+      op.mkldnn_mems.y_memory.push_back(dst_layer_memory_0);
       mkldnn::memory::dims dst_iter_tz_0 = {1, D, nstates, N, H};  // ldsnc
       auto dst_iter_md_0 = mkldnn::memory::desc(
@@ -457,7 +455,7 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* dst_iter_0 = dst_layer_0 + T * N * D * H;  // D * nstates * N * H
       auto dst_iter_memory_0 = mkldnn::memory({ dst_iter_md_0, cpu_engine }, dst_iter_0);
-      op.hcy_memory.push_back(dst_iter_memory_0);
+      op.mkldnn_mems.hcy_memory.push_back(dst_iter_memory_0);
       workptr = dst_iter_0 + D * nstates * N * H;
       // next L - 1 layers
@@ -465,11 +463,11 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       auto user_src_layer_md = mkldnn::memory::desc(
           { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
       auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
-      op.x_memory.push_back(user_src_layer_memory);
+      op.mkldnn_mems.x_memory.push_back(user_src_layer_memory);
       mkldnn::memory::dims weights_layer_tz = {L - 1, 1, H, ngates, H};  // ldigo
       mkldnn::memory::dims weights_iter_tz = {L - 1, 1, H, ngates, H};  // ldigo
-      mkldnn::memory::dims bias_tz = {L - 1, 1, ngates, H};
+      mkldnn::memory::dims bias_tz = {L - 1, 1, nbias, H};
       auto user_weight_layer_md = mkldnn::memory::desc(
           { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
       auto user_weight_iter_md = mkldnn::memory::desc(
@@ -480,22 +478,24 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* weight_layer_n = workptr;  // (L - 1) * H * ngates * H
       auto user_weight_layer_memory_n
           = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
-      op.wx_memory.push_back(user_weight_layer_memory_n);
+      op.mkldnn_mems.wx_memory.push_back(user_weight_layer_memory_n);
       DType* weight_iter_n = weight_layer_n + (L - 1) * H * ngates * H;  // (L - 1) * H * ngates * H
       auto user_weight_iter_memory_n
           = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
-      op.wh_memory.push_back(user_weight_iter_memory_n);
+      op.mkldnn_mems.wh_memory.push_back(user_weight_iter_memory_n);
-      DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H;  // (L - 1) * ngates * H
+      DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H;  // Generally, (L - 1) * ngates * H
+                                                                 // LBR-Gru, (L - 1) * (ngates + 1) * H
       auto user_bias_memory_n = mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
-      op.bias_memory.push_back(user_bias_memory_n);
+      op.mkldnn_mems.bias_memory.push_back(user_bias_memory_n);
       auto wx_md_n = mkldnn::memory::desc(
           { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
-      DType* wx_n = bias_n + (L - 1) * ngates * H;  // (L - 1) * ngates * H * H
+      DType* wx_n = bias_n + (L - 1) * nbias * H;  // (L - 1) * ngates * H * H
       auto wx_memory_n = mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
       DType* wh_n = wx_n + (L - 1) * ngates * H * H;  // (L - 1) * ngates * H * H
@@ -504,8 +504,8 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       auto wh_memory_n = mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
-      op.concat_weight_memory.push_back(wx_memory_n);
-      op.concat_weight_memory.push_back(wh_memory_n);
+      op.mkldnn_mems.concat_weight_memory.push_back(wx_memory_n);
+      op.mkldnn_mems.concat_weight_memory.push_back(wh_memory_n);
       workptr = wh_n + (L - 1) * ngates * H * H;
       mkldnn::memory::dims src_iter_tz_n1 = {1, 1, nstates, N, H};  // ldsnc
@@ -515,7 +515,7 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
         DType* src_iter_n1 = workptr;  // nstates * N * H
         auto src_iter_memory_n1 = mkldnn::memory({ src_iter_md_n1, cpu_engine }, src_iter_n1);
-        op.concat_iter_memory.push_back(src_iter_memory_n1);
+        op.mkldnn_mems.concat_iter_memory.push_back(src_iter_memory_n1);
         workptr = src_iter_n1 + nstates * N * H;
       }
       mkldnn::memory::dims src_iter_tz_n = {L - 1, 1, nstates, N, H};  // ldsnc
@@ -524,13 +524,13 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* src_iter_n = workptr;  // (L - 1) * nstates * N * H
       auto src_iter_memory_n = mkldnn::memory({ src_iter_md_n, cpu_engine }, src_iter_n);
-      op.concat_iter_memory.push_back(src_iter_memory_n);
-      op.hcx_memory.push_back(src_iter_memory_n);
+      op.mkldnn_mems.concat_iter_memory.push_back(src_iter_memory_n);
+      op.mkldnn_mems.hcx_memory.push_back(src_iter_memory_n);
       DType* dst_layer_n = src_iter_n + (L - 1) * nstates * N * H;  // T * N * D * H
       auto dst_layer_memory_n = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
-      op.y_memory.push_back(dst_layer_memory_n);
+      op.mkldnn_mems.y_memory.push_back(dst_layer_memory_n);
       mkldnn::memory::dims dst_iter_tz_n = {L - 1, 1, nstates, N, H};  // ldsnc
       auto dst_iter_md_n = mkldnn::memory::desc(
@@ -538,13 +538,14 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* dst_iter_n = dst_layer_n + T * N * D * H;  // (L - 1) * nstates * N * H
       auto dst_iter_memory_n = mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
-      op.hcy_memory.push_back(dst_iter_memory_n);
+      op.mkldnn_mems.hcy_memory.push_back(dst_iter_memory_n);
+      workptr = dst_iter_n + (L - 1) * nstates * N * H;
     }
     if (L > 1 && D == 2) {
       mkldnn::memory::dims weights_layer_tz = {1, D, H * D, ngates, H};  // ldigo
       mkldnn::memory::dims weights_iter_tz = {1, D, H, ngates, H};  // ldigo
-      mkldnn::memory::dims bias_tz = {1, D, ngates, H};
+      mkldnn::memory::dims bias_tz = {1, D, nbias, H};
      auto user_weight_layer_md = mkldnn::memory::desc(
          { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldigo);
      auto user_weight_iter_md = mkldnn::memory::desc(
@@ -555,31 +556,30 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
      auto user_src_layer_md = mkldnn::memory::desc(
          { src_layer_tz }, mkldnn_dtype, mkldnn::memory::format::tnc);
      auto user_src_layer_memory = mkldnn::memory({ user_src_layer_md, cpu_engine });
-      op.x_memory.push_back(user_src_layer_memory);
+      op.mkldnn_mems.x_memory.push_back(user_src_layer_memory);
       auto wx_md_n = mkldnn::memory::desc(
           { weights_layer_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
       auto wh_md_n = mkldnn::memory::desc(
           { weights_iter_tz }, mkldnn_dtype, mkldnn::memory::format::ldgoi);
-      for (int l = 0; l < L; l++) {
-        DType* weight_layer_n = workptr;  // D * (H * D) * ngates * H
-        auto user_weight_layer_memory_n
-            = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
-        op.wx_memory.push_back(user_weight_layer_memory_n);
-
-        DType* weight_iter_n = weight_layer_n +
-            D * (H * D) * ngates * H;  // D * H * ngates * H
-        auto user_weight_iter_memory_n
-            = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
-        op.wh_memory.push_back(user_weight_iter_memory_n);
-
-        DType* bias_n = weight_iter_n + D * H * ngates * H;  // D * ngates * H
-        auto user_bias_memory_n =
-            mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
-        op.bias_memory.push_back(user_bias_memory_n);
-        workptr = bias_n + D * ngates * H;
-      }
+      DType* weight_layer_n = workptr;  // D * (H * D) * ngates * H
+      auto user_weight_layer_memory_n
+          = mkldnn::memory({ user_weight_layer_md, cpu_engine }, weight_layer_n);
+      op.mkldnn_mems.wx_memory.push_back(user_weight_layer_memory_n);
+
+      DType* weight_iter_n = weight_layer_n +
+          D * (H * D) * ngates * H;  // D * H * ngates * H
+      auto user_weight_iter_memory_n
+          = mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
+      op.mkldnn_mems.wh_memory.push_back(user_weight_iter_memory_n);
+
+      DType* bias_n = weight_iter_n + D * H * ngates * H;  // Generally, D * ngates * H
+                                                           // LBR-Gru, D * (ngates + 1) * H
+      auto user_bias_memory_n =
+          mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
+      op.mkldnn_mems.bias_memory.push_back(user_bias_memory_n);
+      workptr = bias_n + D * nbias * H;
       DType* wx_n = workptr;  // D * ngates * (D * H) * H
       DType* wh_n = wx_n + D * ngates * (D * H) * H;  // D * ngates * H * H
@@ -587,8 +587,8 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
           mkldnn::memory({ wx_md_n, cpu_engine }, wx_n);
       auto wh_memory_n =
           mkldnn::memory({ wh_md_n, cpu_engine }, wh_n);
-      op.concat_weight_memory.push_back(wx_memory_n);
-      op.concat_weight_memory.push_back(wh_memory_n);
+      op.mkldnn_mems.concat_weight_memory.push_back(wx_memory_n);
+      op.mkldnn_mems.concat_weight_memory.push_back(wh_memory_n);
       mkldnn::memory::dims src_iter_undi_tz = {1, 1, nstates, N, H};  // ldsnc
       auto src_iter_undi_md = mkldnn::memory::desc(
@@ -596,12 +596,12 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* src_iter_undi = wh_n + D * ngates * H * H;  // nstates * N * H
       auto src_iter_undi_memory =
           mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi);
-      op.concat_iter_memory.push_back(src_iter_undi_memory_0);
+      op.mkldnn_mems.concat_iter_memory.push_back(src_iter_undi_memory);
       DType* src_iter_undi2 = src_iter_undi + nstates * N * H;  // nstates * N * H
       auto src_iter_undi2_memory =
           mkldnn::memory({ src_iter_undi_md, cpu_engine }, src_iter_undi2);
-      op.concat_iter_memory.push_back(src_iter_undi2_memory);
+      op.mkldnn_mems.concat_iter_memory.push_back(src_iter_undi2_memory);
       mkldnn::memory::dims src_iter_tz = {1, D, nstates, N, H};  // ldsnc
       auto src_iter_md = mkldnn::memory::desc(
@@ -609,13 +609,13 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* src_iter = src_iter_undi2 + nstates * N * H;  // D * nstates * N * H
       auto src_iter_memory = mkldnn::memory({ src_iter_md, cpu_engine }, src_iter);
-      op.concat_iter_memory.push_back(src_iter_memory);
-      op.hcx_memory.push_back(src_iter_memory);
+      op.mkldnn_mems.concat_iter_memory.push_back(src_iter_memory);
+      op.mkldnn_mems.hcx_memory.push_back(src_iter_memory);
       DType* dst_layer_n = src_iter + D * nstates * N * H;  // T * N * D * H
       auto dst_layer_memory_n = mkldnn::memory({ dst_layer_md, cpu_engine }, dst_layer_n);
-      op.y_memory.push_back(dst_layer_memory_n);
+      op.mkldnn_mems.y_memory.push_back(dst_layer_memory_n);
       mkldnn::memory::dims dst_iter_tz_n = {1, D, nstates, N, H};  // ldsnc
       auto dst_iter_md_n = mkldnn::memory::desc(
@@ -623,7 +623,8 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
       DType* dst_iter_n = dst_layer_n + T * N * D * H;  // D * nstates * N * H
       auto dst_iter_memory_n = mkldnn::memory({ dst_iter_md_n, cpu_engine }, dst_iter_n);
-      op.hcy_memory.push_back(dst_iter_memory_n);
+      op.mkldnn_mems.hcy_memory.push_back(dst_iter_memory_n);
+      workptr = dst_iter_n + D * nstates * N * H;
     }
   }
 }
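Note on the ngates/nbias distinction used throughout the hunks above: switching the GRU to algorithm::gru_linear_before_reset means MKL-DNN expects one extra bias row per layer and direction (a separate bias applied to the candidate state before the reset gate), so a cached bias block holds (ngates + 1) * H elements for LBR-Gru while every weight block keeps its ngates-based size. The standalone C++ sketch below illustrates that sizing rule; it is not part of the patch, and the helper name BiasBlockElems is invented for illustration.

    #include <cstddef>
    #include <iostream>

    // Gates per cell type, as reported via GetMKLDNNRNNAlgo:
    // vanilla RNN = 1, GRU = 3, LSTM = 4.
    enum Mode { kRnnRelu = 1, kGru = 3, kLstm = 4 };

    // Number of cached bias elements for a block spanning `layers` layers and
    // `dirs` directions. Only the linear-before-reset GRU carries the extra
    // (ngates + 1)-th bias row; the wx/wh weight blocks stay at ngates rows.
    static std::size_t BiasBlockElems(Mode mode, std::size_t layers,
                                      std::size_t dirs, std::size_t hidden) {
      const std::size_t ngates = static_cast<std::size_t>(mode);
      const std::size_t nbias = (mode == kGru) ? ngates + 1 : ngates;
      return layers * dirs * nbias * hidden;
    }

    int main() {
      // One unidirectional layer with H = 512: LBR-Gru needs (3 + 1) * 512
      // bias elements, the same total as a 4-gate LSTM layer.
      std::cout << BiasBlockElems(kGru, 1, 1, 512) << "\n";   // 2048
      std::cout << BiasBlockElems(kLstm, 1, 1, 512) << "\n";  // 2048
      return 0;
    }

This is why every advance of workptr past a bias region changes from ngates * H to nbias * H elements above, while the wx/wh offset arithmetic is left untouched.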
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1991b16be317..c8677a5b0c9a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -79,148 +79,175 @@ def check_rnn_consistency(cell1, cell2, T, N, I, H, grad_req, rtol=1e-2, atol=1e
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_sym():
-    T, N, I, H = 5, 32, 800, 800
-    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
-    stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
-    stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='lstm', get_next_state=True, prefix='')
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.LSTMCell(H, prefix='l0_'))
+        stack.add(mx.rnn.LSTMCell(H, prefix='l1_'))
+        stack.add(mx.rnn.LSTMCell(H, prefix='l2_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_lstm_bidirectional():
-    T, N, I, H = 5, 20, 800, 800
-    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
-                                bidirectional=True, get_next_state=True, prefix='')
-
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.LSTMCell(H, prefix='l0_'),
-        mx.rnn.LSTMCell(H, prefix='r0_'),
-        output_prefix='bi_lstm_0_'))
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.LSTMCell(H, prefix='l1_'),
-        mx.rnn.LSTMCell(H, prefix='r1_'),
-        output_prefix='bi_lstm_1_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='lstm',
+                                    bidirectional=True, get_next_state=True, prefix='')
+
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.LSTMCell(H, prefix='l0_'),
+            mx.rnn.LSTMCell(H, prefix='r0_'),
+            output_prefix='bi_lstm_0_'))
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.LSTMCell(H, prefix='l1_'),
+            mx.rnn.LSTMCell(H, prefix='r1_'),
+            output_prefix='bi_lstm_1_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru_sym():
-    T, N, I, H = 5, 32, 800, 800
-    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='')
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.GRUCell(H, prefix='l0_'))
-    stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
-    stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='gru', get_next_state=True, prefix='')
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.GRUCell(H, prefix='l0_'))
+        stack.add(mx.rnn.GRUCell(H, prefix='l1_'))
+        stack.add(mx.rnn.GRUCell(H, prefix='l2_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_gru_bidirectional():
-    T, N, I, H = 5, 20, 800, 800
-
-    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru',
-                                bidirectional=True, get_next_state=True, prefix='')
-
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.GRUCell(H, prefix='l0_'),
-        mx.rnn.GRUCell(H, prefix='r0_'),
-        output_prefix='bi_gru_0_'))
-
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.GRUCell(H, prefix='l1_'),
-        mx.rnn.GRUCell(H, prefix='r1_'),
-        output_prefix='bi_gru_1_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='gru',
+                                    bidirectional=True, get_next_state=True, prefix='')
+
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.GRUCell(H, prefix='l0_'),
+            mx.rnn.GRUCell(H, prefix='r0_'),
+            output_prefix='bi_gru_0_'))
+
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.GRUCell(H, prefix='l1_'),
+            mx.rnn.GRUCell(H, prefix='r1_'),
+            output_prefix='bi_gru_1_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnntanh_sym():
-    T, N, I, H = 5, 32, 800, 800
-
-    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='')
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'))
-    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'))
-    stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_tanh', get_next_state=True, prefix='')
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'))
+        stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'))
+        stack.add(mx.rnn.RNNCell(H, activation='tanh', prefix='l2_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnntanh_bidirectional():
-    T, N, I, H = 5, 20, 800, 800
-
-    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh',
-                                bidirectional=True, get_next_state=True, prefix='')
-
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'),
-        mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'),
-        output_prefix='bi_rnntanh_0_'))
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'),
-        mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'),
-        output_prefix='bi_rnntanh_1_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_tanh',
+                                    bidirectional=True, get_next_state=True, prefix='')
+
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.RNNCell(H, activation='tanh', prefix='l0_'),
+            mx.rnn.RNNCell(H, activation='tanh', prefix='r0_'),
+            output_prefix='bi_rnntanh_0_'))
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.RNNCell(H, activation='tanh', prefix='l1_'),
+            mx.rnn.RNNCell(H, activation='tanh', prefix='r1_'),
+            output_prefix='bi_rnntanh_1_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_sym():
-    T, N, I, H = 5, 32, 200, 200
-
-    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
-    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
-    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
+        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
+        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_bidirectional():
-    T, N, I, H = 5, 20, 200, 200
-
-    fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu',
-                                bidirectional=True, get_next_state=True, prefix='')
-
-    stack = mx.rnn.SequentialRNNCell()
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.RNNCell(H, activation='relu', prefix='l0_'),
-        mx.rnn.RNNCell(H, activation='relu', prefix='r0_'),
-        output_prefix='bi_rnnrelu_0_'))
-    stack.add(mx.rnn.BidirectionalCell(
-        mx.rnn.RNNCell(H, activation='relu', prefix='l1_'),
-        mx.rnn.RNNCell(H, activation='relu', prefix='r1_'),
-        output_prefix='bi_rnnrelu_1_'))
-
-    check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2)
-    check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2)
-    check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2)
+    Ts = [1, 5]
+    Ns = [1, 32]
+    Is = [32, 128, 512]
+    Hs = [32, 128, 512]
+    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
+        fused = mx.rnn.FusedRNNCell(H, num_layers=2, mode='rnn_relu',
+                                    bidirectional=True, get_next_state=True, prefix='')
+
+        stack = mx.rnn.SequentialRNNCell()
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.RNNCell(H, activation='relu', prefix='l0_'),
+            mx.rnn.RNNCell(H, activation='relu', prefix='r0_'),
+            output_prefix='bi_rnnrelu_0_'))
+        stack.add(mx.rnn.BidirectionalCell(
+            mx.rnn.RNNCell(H, activation='relu', prefix='l1_'),
+            mx.rnn.RNNCell(H, activation='relu', prefix='r1_'),
+            output_prefix='bi_rnnrelu_1_'))
+
+        check_rnn_consistency(fused, stack, T, N, I, H, 'write', rtol=1e-2, atol=1e-2)
+        check_rnn_consistency(fused, stack, T, N, I, H, 'add', rtol=1e-2, atol=1e-2)
+        check_rnn_consistency(fused, stack, T, N, I, H, 'null', rtol=1e-2, atol=1e-2)
 
 @with_seed()
 def test_lstm_dropout():