diff --git a/src/operator/nn/dnnl/dnnl_deconvolution.cc b/src/operator/nn/dnnl/dnnl_deconvolution.cc index f4766a12c7f3..b853d1a1e52e 100644 --- a/src/operator/nn/dnnl/dnnl_deconvolution.cc +++ b/src/operator/nn/dnnl/dnnl_deconvolution.cc @@ -75,18 +75,23 @@ DNNLDeconvFwd& DNNLDeconvFwd::GetCached(const DeconvolutionParam& param, const T std::shared_ptr DNNLDeconvFwd::CreatePrimitiveDesc(const DeconvolutionParam& param, const Tensors& tensors) { DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out); + auto fwd_desc = ddc.CreateFwdDesc(); // `fwd_desc` lifetime must be longer than `pd` + // when using next_impl const auto& engine = CpuEngine::Get()->get_engine(); - const auto pd = std::make_shared(ddc.CreateFwdDesc(), engine); + const auto pd = std::make_shared(fwd_desc, engine); const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); }; while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { - // ImposePlainWherePadding fails when all memory descriptors already have plain formats - // imposed, meaning there is no implementation with plain formats - CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) - << "No implementation of deconvolution forward propagation"; - *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine); + if (!pd->next_impl()) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of deconvolution forward propagation"; + fwd_desc = ddc.CreateFwdDesc(); + *pd = deconv_fwd_pd_t(fwd_desc, engine); + } } return pd; } @@ -204,18 +209,23 @@ std::shared_ptr DNNLDeconvBwd::CreateDataPrimitiveDesc( const deconv_fwd_pd_t& fwd_pd) { DeconvDescCreator ddc( param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad); - const auto& engine = CpuEngine::Get()->get_engine(); - const auto pd = std::make_shared(ddc.CreateBwdDataDesc(), engine, fwd_pd); + auto bwd_d_desc = ddc.CreateBwdDataDesc(); // `bwd_d_desc` lifetime must be longer than `pd` + // when using next_impl + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared(bwd_d_desc, engine, fwd_pd); const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); }; const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { - // ImposePlainWherePadding fails when all memory descriptors already have plain formats - // imposed, meaning there is no implementation with plain formats - CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) - << "No implementation of deconvolution backward propagation"; - *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd); + if (!pd->next_impl()) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of deconvolution backward propagation"; + bwd_d_desc = ddc.CreateBwdDataDesc(); + *pd = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd); + } } return pd; } @@ -226,19 +236,23 @@ std::shared_ptr DNNLDeconvBwd::CreateWeightsPrimitiveDe const deconv_fwd_pd_t& fwd_pd) { DeconvDescCreator ddc( param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad); - const auto& engine = CpuEngine::Get()->get_engine(); - const auto pd = - std::make_shared(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); - const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + auto bwd_w_desc = ddc.CreateBwdWeightsDesc(); // `bwd_w_desc` lifetime must be longer than `pd` + // when using next_impl + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared(bwd_w_desc, engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); }; const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { - // ImposePlainWherePadding fails when all memory descriptors already have plain formats - // imposed, meaning there is no implementation with plain formats - CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) - << "No implementation of calculating deconvolution weights gradient"; - *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); + if (!pd->next_impl()) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of calculating deconvolution weights gradient"; + bwd_w_desc = ddc.CreateBwdWeightsDesc(); + *pd = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd); + } } return pd; }