diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 1e7f8e0de1b3..0c4c1e60208f 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -821,11 +821,12 @@ OpStatePtr CachedOp::DynamicForward( const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - // If we are already recording, we don't need RunGraph to record all - // computation again. + // If CachedOp is running in the inline mode, it uses RunGraph to record + // computation; otherwise, CachedOp records computation itself. + // So if it's not the inline mode, we disable recording. RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, - !recording || inlining_); + recording && inlining_); return op_state; }