diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index ba44ebd4ed4d..d646ec4d6f28 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -91,7 +91,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs, const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed); bool relu = param.act_type == activation::kReLU; CHECK_EQ(inputs.size(), relu ? 2U : 3U); - if (SupportMKLDNN(inputs[0])) { + if (SupportMKLDNN(inputs[0]) && ctx.need_grad) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); // XXX: for y = relu(x), y is passed as "in_data" to Backward() MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0], diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 6254a1e18662..ea6d1b3470d7 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -421,7 +421,8 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs, TShape shape = inputs[0].shape(); // MKLDNN batchnorm only works well on the special MKLDNN layout. 
if (SupportMKLDNNBN(inputs[0], param) - && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) { + && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData()) + && ctx.need_grad) { std::vector<NDArray> out_grad(1); std::vector<NDArray> out_data(3); std::vector<NDArray> in_data(3); diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 53b0c1380ed3..6c0681be9676 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -73,7 +73,7 @@ static void ConvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector<OpReqType>& req, const std::vector<NDArray>& outputs) { const ConvolutionParam& params = nnvm::get<ConvolutionParam>(attrs.parsed); - if (SupportMKLDNNConv(params, inputs[0])) { + if (SupportMKLDNNConv(params, inputs[0]) && ctx.need_grad) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNConvolutionBackward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(ConvolutionGradCompute, attrs, ctx, inputs, req, outputs); diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index 039c732c831d..c86c31ddf2c8 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -312,7 +312,7 @@ static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector<OpReqType>& req, const std::vector<NDArray>& outputs) { const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed); - if (SupportMKLDNNDeconv(param, inputs[0])) { + if (SupportMKLDNNDeconv(param, inputs[0]) && ctx.need_grad) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNDeconvolutionBackward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(DeconvolutionGradCompute, attrs, ctx, inputs, req, diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index a178b2759bf9..9d68dc2f818f 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -141,7 +141,7 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector<NDArray> &inputs, const 
std::vector<OpReqType> &req, const std::vector<NDArray> &outputs) { - if (SupportMKLDNN(inputs[0])) { + if (SupportMKLDNN(inputs[0]) && ctx.need_grad) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNFCBackward(attrs, ctx, inputs, req, outputs); MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute, attrs, ctx, inputs, req, diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 020cb479acc6..49eff2ad6c71 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -133,7 +133,7 @@ void LRNGradComputeExCPU(const nnvm::NodeAttrs &attrs, const NDArray &in_data = inputs[1]; const NDArray &in_grad = outputs[0]; - if (SupportMKLDNN(inputs[0])) { + if (SupportMKLDNN(inputs[0]) && ctx.need_grad) { MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); MKLDNNLRNBackward(ctx, param, out_grad, in_data, req[0], in_grad); MKLDNN_OPCHECK_RUN(LRNGradCompute, attrs, ctx, inputs, req, outputs); diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 611568807a9a..d94684fb377e 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -270,7 +270,8 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, if (SupportMKLDNN(inputs[0]) - && SupportMKLDNNPooling(param, inputs[0].shape())) { + && SupportMKLDNNPooling(param, inputs[0].shape()) + && ctx.need_grad) { const NDArray &out_grad = inputs[0]; const NDArray *workspace = nullptr; const NDArray *in_data = nullptr; diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 3049674821c9..5291d50ab946 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -1209,7 +1209,7 @@ def test_zero_grad(): grad = net.collect_params()['test_zero_grad_weight'].grad() assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0) -def check_hybrid_static_memory(**kwargs): +def check_hybrid_static_memory(train_modes, **kwargs): x = mx.nd.random.uniform(shape=(2, 3, 32, 32)) x.attach_grad() @@ -1221,27 +1221,29 
@@ def check_hybrid_static_memory(**kwargs): net1(x) net2(x) - def test(net, x): - with mx.autograd.record(): + def test(net, x, train_mode=True): + with mx.autograd.record(train_mode=train_mode): y = net(x) + net(x) - y.backward() + y.backward(train_mode=train_mode) grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'} return y, grads - y1, grads1 = test(net1, x) - y2, grads2 = test(net2, x) + for train_mode in train_modes: + y1, grads1 = test(net1, x, train_mode) + y2, grads2 = test(net2, x, train_mode) - assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5) - for key in grads1: - assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5) + for key in grads1: + assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5) @with_seed() def test_hybrid_static_memory(): - check_hybrid_static_memory() - check_hybrid_static_memory(static_alloc=True) - check_hybrid_static_memory(static_alloc=True, static_shape=True) + check_hybrid_static_memory(train_modes=[True, False]) + check_hybrid_static_memory(train_modes=[True, False], static_alloc=True) + # TODO: MKLDNN (issue #13445) does not work with static_shape backwards + check_hybrid_static_memory(train_modes=[True], static_alloc=True, static_shape=True) def check_hybrid_static_memory_switching(**kwargs): net = gluon.model_zoo.vision.get_resnet(