diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index d4da99ea9e85..1e7f8e0de1b3 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -821,12 +821,11 @@ OpStatePtr CachedOp::DynamicForward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - if (recording && !inlining_) Imperative::Get()->set_is_recording(false); - + // If we are already recording, we don't need RunGraph to record all + // computation again. RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes); - - Imperative::Get()->set_is_recording(recording); + std::move(ref_count), &states, dispatch_modes, + !recording || inlining_); return op_state; } @@ -947,7 +946,8 @@ void CachedOp::DynamicBackward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, + Imperative::Get()->is_recording()); if (retain_graph) { buff.resize(num_forward_entries); diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index e1654259a2fb..0c5ff8417754 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -495,7 +495,8 @@ std::vector Imperative::Backward( int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, + is_recording()); Engine::Get()->set_bulk_size(prev_bulk_size); set_is_recording(prev_recording); diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 464aefc220de..c84a3b9be502 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -30,7 +30,8 @@ void RunGraph( std::vector&& array_reqs, std::vector&& ref_count, std::vector *p_states, - const DispatchModeVector &dispatch_modes) { + const DispatchModeVector &dispatch_modes, + bool recording) { using namespace nnvm; using namespace imperative; static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); @@ -40,7 +41,6 @@ void RunGraph( const auto imp = Imperative::Get(); std::vector& states = *p_states; - bool recording = imp->is_recording(); std::vector ndinputs, ndoutputs; ShapeVector arg_shapes; diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 6daf96e60d0b..9c86843ca7af 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -994,7 +994,8 @@ void RunGraph(const bool retain_graph, std::vector&& array_reqs, std::vector&& ref_count, std::vector *p_states, - const DispatchModeVector &dispatch_modes); + const DispatchModeVector &dispatch_modes, + bool recording); } // namespace imperative } // namespace mxnet diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 67ed78ee0308..f1188b53d814 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1159,6 +1159,7 @@ def check_contrib_rnn(cell_type, num_states): configs = [ {}, + {'inline_limit': 0}, {'static_alloc': True}, {'static_alloc': True, 'static_shape': True} ] for config in configs: