58 changes: 36 additions & 22 deletions src/operator/nn/dnnl/dnnl_deconvolution.cc
@@ -75,18 +75,23 @@ DNNLDeconvFwd& DNNLDeconvFwd::GetCached(const DeconvolutionParam& param, const T
 std::shared_ptr<deconv_fwd_pd_t> DNNLDeconvFwd::CreatePrimitiveDesc(const DeconvolutionParam& param,
                                                                     const Tensors& tensors) {
   DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out);
+  auto fwd_desc = ddc.CreateFwdDesc();  // `fwd_desc` lifetime must be longer than `pd`
+                                        // when using next_impl
   const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_fwd_pd_t>(ddc.CreateFwdDesc(), engine);
+  const auto pd = std::make_shared<deconv_fwd_pd_t>(fwd_desc, engine);
   const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); };

   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of deconvolution forward propagation";
-    *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of deconvolution forward propagation";
+      fwd_desc = ddc.CreateFwdDesc();
+      *pd      = deconv_fwd_pd_t(fwd_desc, engine);
+    }
   }
   return pd;
 }
@@ -204,18 +209,23 @@ std::shared_ptr<deconv_bwd_data_pd_t> DNNLDeconvBwd::CreateDataPrimitiveDesc(
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad);
-  const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+  auto bwd_d_desc = ddc.CreateBwdDataDesc();  // `bwd_d_desc` lifetime must be longer than `pd`
+                                              // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd = std::make_shared<deconv_bwd_data_pd_t>(bwd_d_desc, engine, fwd_pd);
   const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };

   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of deconvolution backward propagation";
-    *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of deconvolution backward propagation";
+      bwd_d_desc = ddc.CreateBwdDataDesc();
+      *pd        = deconv_bwd_data_pd_t(bwd_d_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }
@@ -226,19 +236,23 @@ std::shared_ptr<deconv_bwd_weights_pd_t> DNNLDeconvBwd::CreateWeightsPrimitiveDesc(
     const deconv_fwd_pd_t& fwd_pd) {
   DeconvDescCreator ddc(
       param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad);
-  const auto& engine = CpuEngine::Get()->get_engine();
-  const auto pd =
-      std::make_shared<deconv_bwd_weights_pd_t>(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
-  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
+  auto bwd_w_desc = ddc.CreateBwdWeightsDesc();  // `bwd_w_desc` lifetime must be longer than `pd`
+                                                 // when using next_impl
+  const auto& engine = CpuEngine::Get()->get_engine();
+  const auto pd = std::make_shared<deconv_bwd_weights_pd_t>(bwd_w_desc, engine, fwd_pd);
+  const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); };
   const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); };
   const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); };

   while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) {
-    // ImposePlainWherePadding fails when all memory descriptors already have plain formats
-    // imposed, meaning there is no implementation with plain formats
-    CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
-        << "No implementation of calculating deconvolution weights gradient";
-    *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd);
+    if (!pd->next_impl()) {
+      // ImposePlainWherePadding fails when all memory descriptors already have plain formats
+      // imposed, meaning there is no implementation with plain formats
+      CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size()))
+          << "No implementation of calculating deconvolution weights gradient";
+      bwd_w_desc = ddc.CreateBwdWeightsDesc();
+      *pd        = deconv_bwd_weights_pd_t(bwd_w_desc, engine, fwd_pd);
+    }
   }
   return pd;
 }
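
For readers unfamiliar with the pattern, the following is a minimal standalone sketch of the next_impl() retry loop that the new code in all three hunks relies on. It is not MXNet code and assumes the oneDNN 2.x C++ API (dnnl.hpp); the tensor shapes and the names dst_budget and acceptable are illustrative stand-ins for the real size check performed by DeconvDescCreator::CheckImplSizeReq.

// Standalone sketch (not MXNet code) of the next_impl() retry loop, assuming
// the oneDNN 2.x C++ API (dnnl.hpp). The shapes, `dst_budget`, and
// `acceptable` are illustrative placeholders for CheckImplSizeReq's real logic.
#include <dnnl.hpp>

#include <iostream>

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);

  // format_tag::any lets each implementation pick its own (possibly blocked,
  // padded) layout, so different implementations may report different sizes.
  const memory::desc src_md({1, 16, 8, 8}, memory::data_type::f32, memory::format_tag::any);
  const memory::desc wei_md({16, 16, 3, 3}, memory::data_type::f32, memory::format_tag::any);
  const memory::desc dst_md({1, 16, 10, 10}, memory::data_type::f32, memory::format_tag::any);

  // The operation descriptor is kept alive for as long as `pd`, mirroring the
  // lifetime requirement called out in the in-diff comments above.
  deconvolution_forward::desc fwd_desc(prop_kind::forward_inference,
                                       algorithm::deconvolution_direct,
                                       src_md, wei_md, dst_md,
                                       /*strides=*/{1, 1},
                                       /*padding_l=*/{0, 0},
                                       /*padding_r=*/{0, 0});
  deconvolution_forward::primitive_desc pd(fwd_desc, eng);

  // Hypothetical acceptance test: cap the destination buffer at its dense
  // f32 size, i.e. reject implementations that would pad it.
  const size_t dst_budget = 1 * 16 * 10 * 10 * sizeof(float);
  const auto acceptable = [&]() { return pd.dst_desc().get_size() <= dst_budget; };

  while (!acceptable()) {
    // First try the next implementation for the same operation descriptor;
    // the PR only falls back to plain formats once this list is exhausted.
    if (!pd.next_impl()) {
      std::cout << "no implementation satisfies the size requirement\n";
      return 1;
    }
  }
  std::cout << "picked implementation: " << pd.impl_info_str() << "\n"
            << "dst bytes: " << pd.dst_desc().get_size() << "\n";
  return 0;
}

The sketch walks the implementation list of a single operation descriptor before giving up, whereas the PR's give-up branch imposes plain formats and rebuilds the primitive descriptor. As the in-diff comments note, the operation descriptor must outlive the primitive descriptor when next_impl() is used, which is why fwd_desc, bwd_d_desc, and bwd_w_desc are hoisted out of their loops.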