From 3abb9d547dedbf3eea3c57aa3d7ee22b3db3cdab Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Tue, 31 Jul 2018 16:26:01 +0000 Subject: [PATCH 1/2] fix nested call on cachedop. --- src/imperative/cached_op.cc | 12 ++++++------ src/imperative/imperative.cc | 3 ++- src/imperative/imperative_utils.cc | 4 ++-- src/imperative/imperative_utils.h | 3 ++- tests/python/unittest/test_contrib_control_flow.py | 1 + 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index d4da99ea9e85..ea133535f00d 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -821,12 +821,11 @@ OpStatePtr CachedOp::DynamicForward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - if (recording && !inlining_) Imperative::Get()->set_is_recording(false); - + // We don't need to record when running the graph. The computation is recorded + // in forward. RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes); - - Imperative::Get()->set_is_recording(recording); + std::move(ref_count), &states, dispatch_modes, + recording && !inlining_ ? false : true); return op_state; } @@ -947,7 +946,8 @@ void CachedOp::DynamicBackward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, + Imperative::Get()->is_recording()); if (retain_graph) { buff.resize(num_forward_entries); diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index e1654259a2fb..0c5ff8417754 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -495,7 +495,8 @@ std::vector Imperative::Backward( int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes); + std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, + is_recording()); Engine::Get()->set_bulk_size(prev_bulk_size); set_is_recording(prev_recording); diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 464aefc220de..c84a3b9be502 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -30,7 +30,8 @@ void RunGraph( std::vector&& array_reqs, std::vector&& ref_count, std::vector *p_states, - const DispatchModeVector &dispatch_modes) { + const DispatchModeVector &dispatch_modes, + bool recording) { using namespace nnvm; using namespace imperative; static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); @@ -40,7 +41,6 @@ void RunGraph( const auto imp = Imperative::Get(); std::vector& states = *p_states; - bool recording = imp->is_recording(); std::vector ndinputs, ndoutputs; ShapeVector arg_shapes; diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 6daf96e60d0b..9c86843ca7af 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -994,7 +994,8 @@ void RunGraph(const bool retain_graph, std::vector&& array_reqs, std::vector&& ref_count, std::vector *p_states, - const DispatchModeVector &dispatch_modes); + const DispatchModeVector &dispatch_modes, + bool recording); } // namespace imperative } // namespace mxnet diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 67ed78ee0308..f1188b53d814 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1159,6 +1159,7 @@ def check_contrib_rnn(cell_type, num_states): configs = [ {}, + {'inline_limit': 0}, {'static_alloc': True}, {'static_alloc': True, 'static_shape': True} ] for config in configs: From b2542288713097a7803310ae338783afd9716737 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Tue, 31 Jul 2018 17:40:35 +0000 Subject: [PATCH 2/2] fix. --- src/imperative/cached_op.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index ea133535f00d..1e7f8e0de1b3 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -821,11 +821,11 @@ OpStatePtr CachedOp::DynamicForward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - // We don't need to record when running the graph. The computation is recorded - // in forward. + // If we are already recording, we don't need RunGraph to record all + // computation again. RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, - recording && !inlining_ ? false : true); + !recording || inlining_); return op_state; }