diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 9785be209b7d..1046f01cf6e2 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -549,7 +549,6 @@ class RNNOp {
         CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
       }
       init_cudnn_ = false;
-      Storage::Get()->Free(temp_space_);
       Storage::Get()->Free(reserve_space_);
     }
 #if MXNET_USE_CUDNN_GE_7200
@@ -677,6 +676,12 @@ class RNNOp {
       Init(ctx, s, in_data, out_data);
     }
 
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<gpu, 1, DType> temp_space =
+        ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
+            mshadow::Shape1(temp_size), s);
+
 #if MXNET_USE_CUDNN_GE_7200
     cudnnRNNDataLayout_t layout_t;
 
@@ -770,7 +775,7 @@ class RNNOp {
                                            nullptr,
                                            nullptr,
                                            nullptr,
-                                           temp_space_.dptr,
+                                           temp_space.dptr_,
                                            workspace_byte_,
                                            reserve_space_.dptr,
                                            reserve_space_byte_));
@@ -792,7 +797,7 @@ class RNNOp {
                                          hy_ptr,
                                          cy_desc_,
                                          cy_ptr,
-                                         temp_space_.dptr,
+                                         temp_space.dptr_,
                                          workspace_byte_,
                                          reserve_space_.dptr,
                                          reserve_space_byte_));
@@ -823,7 +828,7 @@ class RNNOp {
                                             nullptr,
                                             nullptr,
                                             nullptr,
-                                            temp_space_.dptr,
+                                            temp_space.dptr_,
                                             workspace_byte_));
 #else
       CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
@@ -843,7 +848,7 @@ class RNNOp {
                                           hy_ptr,
                                           cy_desc_,
                                           cy_ptr,
-                                          temp_space_.dptr,
+                                          temp_space.dptr_,
                                           workspace_byte_));
 #endif
     }
@@ -1061,6 +1066,12 @@ class RNNOp {
       Init(ctx, s, in_data, out_data);
     }
 
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<gpu, 1, DType> temp_space =
+        ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
+            mshadow::Shape1(temp_size), s);
+
 #if MXNET_USE_CUDNN_GE_7200
     CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
                                       rnn_desc_,
@@ -1088,7 +1099,7 @@ class RNNOp {
                                       dcx_ptr,
                                       nullptr,
                                       nullptr,
-                                      temp_space_.dptr,
+                                      temp_space.dptr_,
                                       workspace_byte_,
                                       reserve_space_.dptr,
                                       reserve_space_byte_));
@@ -1100,7 +1111,7 @@ class RNNOp {
                                          hx.dptr_,
                                          y_data_desc_,
                                          y.dptr_,
-                                         temp_space_.dptr,
+                                         temp_space.dptr_,
                                          workspace_byte_,
                                          dw_desc_,
                                          dw.dptr_,
@@ -1130,7 +1141,7 @@ class RNNOp {
                                     dhx.dptr_,
                                     dcx_desc_,
                                     dcx_ptr,
-                                    temp_space_.dptr,
+                                    temp_space.dptr_,
                                     workspace_byte_,
                                     reserve_space_.dptr,
                                     reserve_space_byte_));
@@ -1143,7 +1154,7 @@ class RNNOp {
                                        hx.dptr_,
                                        y_desc_vec_.data(),
                                        y.dptr_,
-                                       temp_space_.dptr,
+                                       temp_space.dptr_,
                                        workspace_byte_,
                                        dw_desc_,
                                        dw.dptr_,
@@ -1378,17 +1389,16 @@ class RNNOp {
                                            strideA));
 
     // Create Dropout descriptors
-    DType* dropout_states_ = NULL;
     if (param_.p > 0) {
       ctx.requested[rnn_enum::kCuDNNDropoutDescSpace].get_cudnn_dropout_desc
           (&dropout_desc_, s, 1.0f - param_.p, seed_);
-    } else {
-      dropout_byte_ = 0;
     }
-
+    // Only update the probability by passing in a null dropout_states ptr
+    DType* dropout_states = NULL;
+    size_t dropout_bytes = 0;
     CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
                                          param_.p,  // discard probability
-                                         dropout_states_, dropout_byte_,
+                                         dropout_states, dropout_bytes,
                                          seed_));
 
     // RNN descriptors
@@ -1469,8 +1479,6 @@ class RNNOp {
     workspace_size_ = workspace_byte_ / sizeof(DType);
     // Allocate the reserve space
     reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
-    // Allocate the temp space
-    temp_space_ = Storage::Get()->Alloc(workspace_byte_, Context::GPU(s->dev_id));
     // Check that number of params are correct
     size_t cudnn_param_size;
     CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_,
@@ -1539,9 +1547,9 @@ class RNNOp {
   cudnnDirectionMode_t direction_;
   cudnnRNNInputMode_t input_mode_;
   cudnnDropoutDescriptor_t dropout_desc_;
-  Storage::Handle reserve_space_, temp_space_;
+  Storage::Handle reserve_space_;
   uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
-  size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
+  size_t workspace_byte_, reserve_space_byte_;
   int workspace_size_;
   std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
 #if MXNET_USE_CUDNN_GE_7200
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 32184943cac0..6a0dbd7a4e23 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -167,6 +167,22 @@ static bool RNNType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const int dev_mask,
+                                                  const DispatchMode dispatch_mode) {
+  std::vector<ResourceRequest> request;
+  if (dev_mask == kGPU) {
+#if MXNET_USE_CUDNN_RNN
+    request.emplace_back(ResourceRequest::kTempSpace);
+
+    const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+    if (param.p != 0 && 1.0f - param.p > 0) {
+      request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
+    }
+#endif
+  }
+  return request;
+}
+
 inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
                                   const int dev_mask,
                                   DispatchMode* dispatch_mode,
@@ -703,21 +719,7 @@ The definition of GRU here is slightly different from paper but compatible with
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
-.set_attr<FResourceRequestEx>("FResourceRequestEx",
-  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
-    std::vector<ResourceRequest> request;
-    if (dev_mask == kGPU) {
-#if MXNET_USE_CUDNN_RNN
-      request.emplace_back(ResourceRequest::kTempSpace);
-
-      const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
-      if (param.p != 0 && 1.0f - param.p > 0) {
-        request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
-      }
-#endif
-    }
-    return request;
-})
+.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx)
 .add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
 .add_argument("parameters", "NDArray-or-Symbol", "Vector of all RNN trainable "
               "parameters concatenated")
@@ -737,6 +739,7 @@ NNVM_REGISTER_OP(_backward_RNN)
 .set_attr_parser(ParamParser<RNNParam>)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
+.set_attr<FResourceRequestEx>("FResourceRequestEx", RNNResourceEx);
 }  // namespace op
 }  // namespace mxnet
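Taken together, the two files switch the cuDNN workspace from an operator-owned buffer (Storage::Get()->Alloc in Init, Storage::Get()->Free in the destructor) to per-invocation scratch space obtained through MXNet's resource manager, with RNNResourceEx declaring the kTempSpace requirement for both the forward op and _backward_RNN; the reserve space stays operator-owned because cuDNN requires it to persist from forward to backward. Below is a minimal sketch of that resource pattern, separate from the patch itself; MyOpForward, my_enum::kTempSpace, and workspace_size are hypothetical names used only for illustration.

#include <mshadow/tensor.h>
#include <mxnet/op_attr_types.h>

namespace my_enum {
// Hypothetical resource index: the position of the kTempSpace entry in the
// vector returned by the op's FResourceRequestEx.
enum MyOpResource { kTempSpace };
}  // namespace my_enum

template <typename DType>
void MyOpForward(const mxnet::OpContext& ctx, int workspace_size) {
  mshadow::Stream<mshadow::gpu>* s = ctx.get_stream<mshadow::gpu>();
  // Scratch buffer valid only for the duration of this call; the engine may
  // recycle it for other operators afterwards, so state that must survive
  // until backward (e.g. the cuDNN reserve space) cannot live here.
  mshadow::Tensor<mshadow::gpu, 1, DType> temp_space =
      ctx.requested[my_enum::kTempSpace].get_space_typed<mshadow::gpu, 1, DType>(
          mshadow::Shape1(workspace_size), s);
  // ... pass temp_space.dptr_ wherever a cuDNN workspace pointer is needed ...
}

Because the buffer is re-requested on every call, forward and backward each fetch their own workspace (as the patch does in both paths), and nothing needs to be freed in the destructor.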