From e819a984bafc76b3681f3d4b8eb3bad32c448f9a Mon Sep 17 00:00:00 2001 From: Eric Junyuan Xie Date: Fri, 15 Jun 2018 13:26:35 -0700 Subject: [PATCH 1/2] Update cached_op.cc --- src/imperative/cached_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index b40605bd25e2..b17fae4b3cf3 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -674,7 +674,7 @@ OpStatePtr CachedOp::StaticForward( std::lock_guard lock(state.mutex); bool match = SetForwardGraph(&state.info, recording, inputs); - match = match && state.recording != recording; + match = match && state.recording == recording; nnvm::Graph& g = state.info.fwd_graph; const auto& idx = g.indexed_graph(); From 29faa9f08a2d35747480d3d95efc5206642c196d Mon Sep 17 00:00:00 2001 From: Junyuan Xie Date: Fri, 15 Jun 2018 14:49:51 -0700 Subject: [PATCH 2/2] fix --- tests/python/unittest/test_gluon.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index bb61af127240..5701a5df5a08 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1285,6 +1285,17 @@ def test_legacy_save_params(): model.load_params('test.params', ctx=mx.cpu()) +def test_hybrid_static_memory_recording(): + net = gluon.model_zoo.vision.get_resnet( + 1, 18, pretrained=True, ctx=mx.context.current_context()) + net.hybridize(static_alloc=True) + + x = mx.nd.random.uniform(shape=(1, 3, 32, 32)) + with mx.autograd.record(True): + net(x) + net(x) + + if __name__ == '__main__': import nose nose.runmodule()