diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index b40605bd25e2..b17fae4b3cf3 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -674,7 +674,7 @@ OpStatePtr CachedOp::StaticForward( std::lock_guard lock(state.mutex); bool match = SetForwardGraph(&state.info, recording, inputs); - match = match && state.recording != recording; + match = match && state.recording == recording; nnvm::Graph& g = state.info.fwd_graph; const auto& idx = g.indexed_graph(); diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index bb61af127240..5701a5df5a08 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1285,6 +1285,17 @@ def test_legacy_save_params(): model.load_params('test.params', ctx=mx.cpu()) +def test_hybrid_static_memory_recording(): + net = gluon.model_zoo.vision.get_resnet( + 1, 18, pretrained=True, ctx=mx.context.current_context()) + net.hybridize(static_alloc=True) + + x = mx.nd.random.uniform(shape=(1, 3, 32, 32)) + with mx.autograd.record(True): + net(x) + net(x) + + if __name__ == '__main__': import nose nose.runmodule()