From f7442931d1d227d6be4db3ae688cc23f4e0ff3a7 Mon Sep 17 00:00:00 2001
From: "B. Gawrych"
Date: Thu, 28 Oct 2021 15:14:35 +0200
Subject: [PATCH] Optimize preparation of selfattn operators

---
 .../subgraph/dnnl/dnnl_transformer.cc | 94 ++++++++++---------
 1 file changed, 52 insertions(+), 42 deletions(-)

diff --git a/src/operator/subgraph/dnnl/dnnl_transformer.cc b/src/operator/subgraph/dnnl/dnnl_transformer.cc
index 965aff4df301..f6861737c972 100644
--- a/src/operator/subgraph/dnnl/dnnl_transformer.cc
+++ b/src/operator/subgraph/dnnl/dnnl_transformer.cc
@@ -116,7 +116,8 @@ class SgDNNLSelfAttQKOp {
   void Forward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
-               const std::vector<NDArray>& outputs);
+               const std::vector<NDArray>& outputs,
+               bool already_prepared);
 
   void Backward(const OpContext& ctx,
                 const std::vector<NDArray>& inputs,
@@ -163,10 +164,12 @@ static void SgDNNLSelfAttQKForward(const OpStatePtr& state_pointer,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
   SgDNNLSelfAttQKOp& op = state_pointer.get_state<SgDNNLSelfAttQKOp>();
+  bool already_prepared = false;
   if (!op.IsInitialized()) {
     op.Initialize(ctx, inputs, req, outputs);
+    already_prepared = true;
   }
-  op.Forward(ctx, inputs, req, outputs);
+  op.Forward(ctx, inputs, req, outputs, already_prepared);
 }
 
 static bool SgDNNLSelfAttStorageType(const nnvm::NodeAttrs& attrs,
@@ -264,21 +267,23 @@ void SgDNNLSelfAttQKOp::Initialize(const OpContext& ctx,
 void SgDNNLSelfAttQKOp::Forward(const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
-                                const std::vector<NDArray>& outputs) {
-  const size_t output_lin_dim = inputs[0].shape()[2];
-  const size_t embed_dim = output_lin_dim / QKV_NUM;
-
-  MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
-    DType* query_mem_ptr = inputs[0].data().dptr<DType>();
-    DType* key_mem_ptr = query_mem_ptr + embed_dim;
-    cached_query_mem_->set_data_handle(query_mem_ptr);
-    cached_key_mem_->set_data_handle(key_mem_ptr);
-  });
-
-  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-    cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
-  });
-
+                                const std::vector<NDArray>& outputs,
+                                bool already_prepared) {
+  if (!already_prepared) {
+    const size_t output_lin_dim = inputs[0].shape()[2];
+    const size_t embed_dim = output_lin_dim / QKV_NUM;
+
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+      DType* key_mem_ptr = query_mem_ptr + embed_dim;
+      cached_query_mem_->set_data_handle(query_mem_ptr);
+      cached_key_mem_->set_data_handle(key_mem_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+    });
+  }
   DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
   DNNLStream::Get()->Submit();
@@ -483,7 +488,8 @@ class DNNLSelfAttValAttOp {
   void Forward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
-               const std::vector<NDArray>& outputs);
+               const std::vector<NDArray>& outputs,
+               bool already_prepared);
 
   void Backward(const OpContext& ctx,
                 const std::vector<NDArray>& inputs,
@@ -537,10 +543,12 @@ static void DNNLSelfAttValAttForward(const OpStatePtr& state_pointer,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<NDArray>& outputs) {
   DNNLSelfAttValAttOp& op = state_pointer.get_state<DNNLSelfAttValAttOp>();
+  bool already_prepared = false;
   if (!op.IsInitialized()) {
     op.Initialize(ctx, inputs, req, outputs);
+    already_prepared = true;
   }
-  op.Forward(ctx, inputs, req, outputs);
+  op.Forward(ctx, inputs, req, outputs, already_prepared);
 }
 
 void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
@@ -663,29 +671,31 @@ void DNNLSelfAttValAttOp::Initialize(const OpContext& ctx,
 void DNNLSelfAttValAttOp::Forward(const OpContext& ctx,
                                   const std::vector<NDArray>& inputs,
                                   const std::vector<OpReqType>& req,
-                                  const std::vector<NDArray>& outputs) {
-  // multiply by 2 as we need to skip queries and keys
-  const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;
-
-  auto att_buffer = inputs[0];
-  if (att_buffer.IsDNNLData())
-    att_buffer = att_buffer.Reorder2Default();
-
-  MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
-    DType* attention_ptr = att_buffer.data().dptr<DType>();
-    cached_att_mem_->set_data_handle(attention_ptr);
-  });
-
-  MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
-    DType* qkv_ptr = inputs[1].data().dptr<DType>();
-    DType* value_mem_ptr = qkv_ptr + value_offset;
-    cached_value_mem_->set_data_handle(value_mem_ptr);
-  });
-
-  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-    cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
-  });
-
+                                  const std::vector<NDArray>& outputs,
+                                  bool already_prepared) {
+  if (!already_prepared) {
+    // multiply by 2 as we need to skip queries and keys
+    const size_t value_offset = inputs[1].shape()[2] / QKV_NUM * 2;
+
+    auto att_buffer = inputs[0];
+    if (att_buffer.IsDNNLData())
+      att_buffer = att_buffer.Reorder2Default();
+
+    MSHADOW_TYPE_SWITCH(att_buffer.dtype(), DType, {
+      DType* attention_ptr = att_buffer.data().dptr<DType>();
+      cached_att_mem_->set_data_handle(attention_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+      DType* qkv_ptr = inputs[1].data().dptr<DType>();
+      DType* value_mem_ptr = qkv_ptr + value_offset;
+      cached_value_mem_->set_data_handle(value_mem_ptr);
+    });
+
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_transposed_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+    });
+  }
   DNNLStream::Get()->RegisterPrimArgs(*fwd_, args_);
   DNNLStream::Get()->RegisterPrimArgs(*reorder_, reorder_args);
   DNNLStream::Get()->Submit();
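
Note: the idea behind both hunks is that Initialize has just attached the
current input/output buffers to the cached oneDNN memory objects, so the
set_data_handle() calls in Forward are redundant on the very first run; only
later calls, where MXNet may hand the operator different buffers, need the
refresh. The following is a minimal standalone sketch of that pattern, not
the upstream code: MemorySketch and SelfAttOpSketch are hypothetical stand-in
types invented here for illustration.

#include <vector>

// Hypothetical stand-in for a dnnl::memory-like object holding a raw
// data pointer that can be re-bound between executions.
struct MemorySketch {
  void* handle = nullptr;
  void set_data_handle(void* h) { handle = h; }
};

class SelfAttOpSketch {
 public:
  bool IsInitialized() const { return initialized_; }

  // Stands in for SgDNNLSelfAttQKOp::Initialize: it builds the cached
  // state and already binds the current buffers, so the first Forward
  // call has nothing left to prepare.
  void Initialize(float* input, float* output) {
    cached_in_.set_data_handle(input);
    cached_out_.set_data_handle(output);
    initialized_ = true;
  }

  // Forward re-binds the data handles only when Initialize has not just
  // done so, mirroring the already_prepared flag added by the patch.
  void Forward(float* input, float* output, bool already_prepared) {
    if (!already_prepared) {
      cached_in_.set_data_handle(input);
      cached_out_.set_data_handle(output);
    }
    // ... register and submit the cached primitive here ...
  }

 private:
  bool initialized_ = false;
  MemorySketch cached_in_;
  MemorySketch cached_out_;
};

int main() {
  std::vector<float> in(16), out(16);
  SelfAttOpSketch op;

  // Same dispatch shape as SgDNNLSelfAttQKForward after the patch.
  bool already_prepared = false;
  if (!op.IsInitialized()) {
    op.Initialize(in.data(), out.data());
    already_prepared = true;
  }
  op.Forward(in.data(), out.data(), already_prepared);
  return 0;
}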