diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index cdbb764bc535..356a3bd1e06b 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -540,13 +540,20 @@ void NDArray::Chunk::Reorder2Default() { dnnl_format_tag_t format = dnnl_mem_->GetDefaultFormat(); dnnl::memory::desc def_desc = dnnl_mem_->GetDesc(format); - dnnl_mem_ptr def_mem(new dnnl::memory(def_desc, CpuEngine::Get()->get_engine())); - dnnl_mem_->ReorderTo(def_mem.get()); CHECK(shandle.size >= def_desc.get_size()); CheckAndAlloc(def_desc.get_size()); - // TODO(zhengda) We need to avoid memory copy here. - memcpy(shandle.dptr, def_mem->get_data_handle(), def_desc.get_size()); + + // oneDNN reorder can't be performed in-place + if (shandle.dptr == dnnl_mem_->GetDataHandle()) { + dnnl_mem_ptr def_mem(new dnnl::memory(def_desc, CpuEngine::Get()->get_engine())); + dnnl_mem_->ReorderTo(def_mem.get()); + memcpy(shandle.dptr, def_mem->get_data_handle(), def_desc.get_size()); + } else { + dnnl_mem_ptr def_mem(new dnnl::memory(def_desc, CpuEngine::Get()->get_engine(), shandle.dptr)); + dnnl_mem_->ReorderTo(def_mem.get()); + } + dnnl_mem_ = nullptr; }