From ffbf43abcdcd609b2fb702838716ea078972e4e8 Mon Sep 17 00:00:00 2001
From: grygielski
Date: Tue, 23 Mar 2021 14:00:37 +0100
Subject: [PATCH 01/13] Add oneDNN code to interleaved kernels

---
 .../subgraph/mkldnn/mkldnn_transformer.cc     | 1164 +++++++++++++++++
 1 file changed, 1164 insertions(+)
 create mode 100644 src/operator/subgraph/mkldnn/mkldnn_transformer.cc

diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
new file mode 100644
index 000000000000..c34d0dde6996
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
@@ -0,0 +1,1164 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <mxnet/base.h>
+#include <memory>
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "./mkldnn_transformer-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "../../tensor/elemwise_unary_op.h"
+
+#include "../../quantization/quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(MKLDNNInterleavedMatMulParam);
+
+static bool MKLDNNInterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs,
+                                                  mxnet::ShapeVector* in_shape,
+                                                  mxnet::ShapeVector* out_shape) {
+  const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized) {
+    auto qkv_shape = in_shape->at(0);
+    out_shape->resize(3);
+    SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({param.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));  // output
+
+    if (!param.enable_float_output) {
+      SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));  // min output
+      SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));  // max output
+    }
+    return true;
+  } else {
+    CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently has "
+                                   << in_shape->size() << " inputs";
+    auto qkv_shape = in_shape->at(0);
+    CHECK_EQ(qkv_shape.ndim(), 3U)
+      << "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, "
+      << "currently is: " << qkv_shape.ndim() << "D";
+    out_shape->resize(1);
+    SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({param.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));
+    return true;
+  }
+}
+
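+// Shape example (illustrative values, not part of the op contract): with
+// heads = 16 and an interleaved QKV input of shape
+// (seq_length = 128, batch = 32, 3 * proj_dim = 3072), the QK output is
+// (heads * batch, seq_length, seq_length) = (512, 128, 128) -- one score
+// matrix per attention head and per batch element.
+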
+static bool MKLDNNInterleavedMatMulSelfAttQKInferType(const nnvm::NodeAttrs &attrs,
+                                                      std::vector<int> *in_types,
+                                                      std::vector<int> *out_types) {
+  const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized) {
+    TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8);     // qkv input
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);  // min value
+    TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);  // max value
+
+    if (param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);  // output
+    } else {
+      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);   // output
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);  // output
+      }
+      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);  // min output
+      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);  // max output
+    }
+    return true;
+  } else {
+    return DefaultSubgraphOpType(attrs, in_types, out_types);
+  }
+}
+
+static bool MKLDNNInterleavedMatMulSelfAttQKStorageType(const nnvm::NodeAttrs &attrs,
+                                                        const int dev_mask,
+                                                        DispatchMode *dispatch_mode,
+                                                        std::vector<int> *in_attrs,
+                                                        std::vector<int> *out_attrs) {
+  auto const &param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized) {
+    type_assign(&in_attrs->at(0), mxnet::kDefaultStorage);
+    type_assign(&in_attrs->at(1), mxnet::kDefaultStorage);
+    type_assign(&in_attrs->at(2), mxnet::kDefaultStorage);
+
+    type_assign(&out_attrs->at(0), mxnet::kDefaultStorage);
+    if (!param.enable_float_output) {
+      type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
+      type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
+    }
+    std::vector<int> base_in_attrs{in_attrs->at(0)};
+    std::vector<int> base_out_attrs{out_attrs->at(0)};
+    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                        &base_in_attrs, &base_out_attrs);
+  } else {
+    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                        in_attrs, out_attrs);
+  }
+}
+
+class MKLDNNInterleavedMatMulSelfAttQKOp {
+ public:
+  explicit MKLDNNInterleavedMatMulSelfAttQKOp(const nnvm::NodeAttrs &attrs) :
+    param_(nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed)) {}
+
+  void Forward(const OpContext &ctx,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+  void Backward(const OpContext &ctx,
+                const std::vector<NDArray> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &outputs) {
+    LOG(FATAL) << "Not implemented: subgraph mkldnn interleaved matmul only supports "
+                  "inference computation.";
+  }
+
+ private:
+  bool initialized_{false};
+  MKLDNNInterleavedMatMulParam param_;
+  mkldnn_args_map_t args_;
+  std::shared_ptr<dnnl::matmul> fwd_;
+  std::shared_ptr<dnnl::memory> cached_data1_mem_;
+  std::shared_ptr<dnnl::memory> cached_data2_mem_;
+  std::shared_ptr<dnnl::memory> cached_out_mem_;
+  float cached_min_data_;
+  float cached_max_data_;
+  float cached_min_output_;
+  float cached_max_output_;
+  float data_scale_{0.0f};
+};
+
+static OpStatePtr CreateMKLDNNInterleavedMatMulSelfAttQKState(const nnvm::NodeAttrs &attrs,
+                                                              Context ctx,
+                                                              const mxnet::ShapeVector &in_shapes,
+                                                              const std::vector<int> &in_types) {
+  return OpStatePtr::Create<MKLDNNInterleavedMatMulSelfAttQKOp>(attrs);
+}
+
+static void MKLDNNInterleavedMatMulSelfAttQKForward(const OpStatePtr &state_pointer,
+                                                    const OpContext &ctx,
+                                                    const std::vector<NDArray> &inputs,
+                                                    const std::vector<OpReqType> &req,
+                                                    const std::vector<NDArray> &outputs) {
+  MKLDNNInterleavedMatMulSelfAttQKOp &op =
+      state_pointer.get_state<MKLDNNInterleavedMatMulSelfAttQKOp>();
+  op.Forward(ctx, inputs, req, outputs);
+}
+
+void MKLDNNInterleavedMatMulSelfAttQKOp::Forward(const OpContext &ctx,
+                                                 const std::vector<NDArray> &inputs,
+                                                 const std::vector<OpReqType> &req,
+                                                 const std::vector<NDArray> &outputs) {
+  float min_data = 0.0f;
+  float max_data = 0.0f;
+
+  if (param_.quantized) {
+    min_data = inputs[1].data().dptr<float>()[0];
+    max_data = inputs[2].data().dptr<float>()[0];
+  }
+
+  const dnnl::memory::dim HEADS = param_.heads;
+  const dnnl::memory::dim BS = inputs[0].shape()[1];
+  const dnnl::memory::dim SEQ_LEN = inputs[0].shape()[0];
+  const dnnl::memory::dim EMBED = inputs[0].shape()[2];
+
+  if (!initialized_) {
+    const auto engine = CpuEngine::Get()->get_engine();
+
+    cached_min_data_ = min_data;
+    cached_max_data_ = max_data;
+
+    dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3};
+    dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN};
+    dnnl::memory::dims dst_dims  = {HEADS*BS, SEQ_LEN, SEQ_LEN};
+
+    dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1};
+    dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS};
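+
+    // Layout sketch: the interleaved projection stores, for token s, batch b,
+    // head h and feature e, the element at offset
+    //   s*BS*EMBED + b*EMBED + h*3*(EMBED/HEADS/3) + e.
+    // The fused batch dimension HEADS*BS is indexed as b*HEADS + h, so the
+    // single stride 3*(EMBED/HEADS/3) covers both: advancing h moves one
+    // (q, k, v) block and advancing b moves HEADS such blocks, i.e. one full
+    // EMBED row. src2 (the keys) swaps the last two strides, which realizes
+    // K^T without a physical transpose; the key pointer itself is offset by
+    // EMBED/HEADS/3 from the query pointer where the memories are created.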
+
+    auto src1_md = param_.quantized ?
+        dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides) :
+        dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, src1_strides);
+    auto src2_md = param_.quantized ?
+        dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides) :
+        dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides);
+
+    dnnl::memory::desc dst_md;
+
+    float tmp_scale = 1.0f;
+    if (param_.quantized) {
+      data_scale_ = GetQuantizeScale(mshadow::kInt8, cached_min_data_, cached_max_data_);
+
+      if (param_.min_calib_range.has_value() &&
+          param_.max_calib_range.has_value()) {
+        cached_min_output_ = param_.min_calib_range.value();
+        cached_max_output_ = param_.max_calib_range.value();
+
+        tmp_scale = GetQuantizeScale(mshadow::kInt8, cached_min_output_, cached_max_output_) /
+                    (data_scale_ * data_scale_);
+        dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8,
+                                    dnnl::memory::format_tag::abc);
+      } else if (param_.enable_float_output) {
+        tmp_scale = 1.0f / (data_scale_ * data_scale_);
+        dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32,
+                                    dnnl::memory::format_tag::abc);
+      } else {
+        mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+        mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(
+            s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_data,
+            &max_data);
+        dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32,
+                                    dnnl::memory::format_tag::abc);
+      }
+    } else {
+      dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32,
+                                  dnnl::memory::format_tag::abc);
+    }
+    tmp_scale /= sqrt(static_cast<float>(EMBED/HEADS/3));
+
+    dnnl::primitive_attr attr;
+    attr.set_output_scales(0, {tmp_scale});
+    auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md);
+    auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine);
+
+    fwd_ = std::make_shared<dnnl::matmul>(matmul_pd);
+
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      cached_data1_mem_ = std::make_shared<dnnl::memory>(
+          src1_md, engine, inputs[0].data().dptr<DType>());
+      cached_data2_mem_ = std::make_shared<dnnl::memory>(
+          src2_md, engine, inputs[0].data().dptr<DType>() + (EMBED/HEADS/3));
+    });
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_ = std::make_shared<dnnl::memory>(
+          dst_md, engine, outputs[0].data().dptr<DType>());
+    });
+
+    args_[DNNL_ARG_SRC] = *cached_data1_mem_;
+    args_[DNNL_ARG_WEIGHTS] = *cached_data2_mem_;
+    args_[DNNL_ARG_DST] = *cached_out_mem_;
+    initialized_ = true;
+  } else {
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      cached_data1_mem_->set_data_handle(
+          reinterpret_cast<void*>(inputs[0].data().dptr<DType>()));
+      cached_data2_mem_->set_data_handle(
+          reinterpret_cast<void*>(inputs[0].data().dptr<DType>() + (EMBED/HEADS/3)));
+    });
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_->set_data_handle(
+          reinterpret_cast<void*>(outputs[0].data().dptr<DType>()));
+    });
+  }
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+  MKLDNNStream::Get()->Submit();
+
+  if (param_.quantized && !param_.enable_float_output) {
+    float* min_output = outputs[1].data().dptr<float>();
+    float* max_output = outputs[2].data().dptr<float>();
+
+    *min_output = cached_min_output_;
+    *max_output = cached_max_output_;
+  }
+
+  // if (param_.quantized) {
+  //   if (param_.enable_float_output) {
+  //     dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+  //     dnnl::stream engine_stream(engine);
+
+  //     dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, 
EMBED/HEADS/3}; + // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; + // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; + + // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; + + // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides); + // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); + // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc); + + // // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); + // float min_data = inputs[1].dptr()[0]; + // float max_data = inputs[2].dptr()[0]; + + // float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data); + // const float scale = 1.0f / (data_scale * data_scale) / sqrt(static_cast(EMBED/HEADS/3)); + + // dnnl::primitive_attr attr; + // attr.set_output_scales(0, {scale}); + + // // CODE FOR HANLDING MASKING + // // float* mask = inputs[1].FlatTo2D(s).dptr_; + // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); + // // dnnl::post_ops post_op; + // // post_op.append_sum(1); + // // attr.set_post_ops(post_op); + + // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); + // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + + // auto matmul_prim = dnnl::matmul(matmul_pd); + + // mshadow::Stream* s = ctx.get_stream(); + // int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + + // float* output = outputs[0].FlatTo2D(s).dptr_; + + // auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values); + // auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3)); + // auto dst_mem = dnnl::memory(dst_md, engine, output); + + // std::unordered_map matmul_args; + // matmul_args.insert({DNNL_ARG_SRC, src1_mem}); + // matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); + // matmul_args.insert({DNNL_ARG_DST, dst_mem}); + + // matmul_prim.execute(engine_stream, matmul_args); + // engine_stream.wait(); + // } else if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { + // const dnnl::memory::dim HEADS = param.heads; + // const dnnl::memory::dim BS = inputs[0].shape_[1]; + // const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; + // const dnnl::memory::dim EMBED = inputs[0].shape_[2]; + + // dnnl::engine engine(dnnl::engine::kind::cpu, 0); + // dnnl::stream engine_stream(engine); + + // dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; + // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; + // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; + + // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; + + // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides); + // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); + // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8, dnnl::memory::format_tag::abc); + + // // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); + // float min_data = inputs[1].dptr()[0]; + // float max_data = inputs[2].dptr()[0]; + + + + // float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data); + // const float scale = GetQuantizeScale(mshadow::kInt8, param.min_calib_range.value(), param.max_calib_range.value()) + 
// / (data_scale * data_scale) / sqrt(static_cast(EMBED/HEADS/3)); + + // dnnl::primitive_attr attr; + // attr.set_output_scales(0, {scale}); + + // // CODE FOR HANLDING MASKING + // // float* mask = inputs[1].FlatTo2D(s).dptr_; + // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); + // // dnnl::post_ops post_op; + // // post_op.append_sum(1); + // // attr.set_post_ops(post_op); + + // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); + // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + + // auto matmul_prim = dnnl::matmul(matmul_pd); + + // mshadow::Stream* s = ctx.get_stream(); + // int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + + // int8_t* output = outputs[0].FlatTo2D(s).dptr_; + // float* min_output = outputs[1].dptr(); + // float* max_output = outputs[2].dptr(); + // min_output[0] = param.min_calib_range.value(); + // max_output[0] = param.max_calib_range.value(); + + // auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values); + // auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3)); + // auto dst_mem = dnnl::memory(dst_md, engine, output); + + // std::unordered_map matmul_args; + // matmul_args.insert({DNNL_ARG_SRC, src1_mem}); + // matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); + // matmul_args.insert({DNNL_ARG_DST, dst_mem}); + + // matmul_prim.execute(engine_stream, matmul_args); + // engine_stream.wait(); + // } else { + // const dnnl::memory::dim HEADS = param.heads; + // const dnnl::memory::dim BS = inputs[0].shape_[1]; + // const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; + // const dnnl::memory::dim EMBED = inputs[0].shape_[2]; + + // dnnl::engine engine(dnnl::engine::kind::cpu, 0); + // dnnl::stream engine_stream(engine); + + // dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; + // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; + // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; + + // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; + + // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides); + // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); + // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32, dnnl::memory::format_tag::abc); + + // // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); + // float min_data = inputs[1].dptr()[0]; + // float max_data = inputs[2].dptr()[0]; + + // // float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data); + // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); + + // dnnl::primitive_attr attr; + // attr.set_output_scales(0, {scale}); + + // // CODE FOR HANLDING MASKING + // // float* mask = inputs[1].FlatTo2D(s).dptr_; + // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); + // // dnnl::post_ops post_op; + // // post_op.append_sum(1); + // // attr.set_post_ops(post_op); + + // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); + // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + + // auto matmul_prim = dnnl::matmul(matmul_pd); + + // mshadow::Stream* s = ctx.get_stream(); + // int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + + // int32_t* output = outputs[0].FlatTo2D(s).dptr_; + // float* min_output = outputs[1].dptr(); + // float* max_output = outputs[2].dptr(); + + // 
mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(
+  //         s, 1, min_output, max_output, &min_data, &max_data, &min_data,
+  //         &max_data);
+
+  //     // min_output[0] = min_data;
+  //     // max_output[0] = max_data;
+
+  //     auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values);
+  //     auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3));
+  //     auto dst_mem = dnnl::memory(dst_md, engine, output);
+
+  //     std::unordered_map<int, dnnl::memory> matmul_args;
+  //     matmul_args.insert({DNNL_ARG_SRC, src1_mem});
+  //     matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem});
+  //     matmul_args.insert({DNNL_ARG_DST, dst_mem});
+
+  //     matmul_prim.execute(engine_stream, matmul_args);
+  //     engine_stream.wait();
+  //   }
+  // } else {
+  //   const dnnl::memory::dim HEADS = param.heads;
+  //   const dnnl::memory::dim BS = inputs[0].shape_[1];
+  //   const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0];
+  //   const dnnl::memory::dim EMBED = inputs[0].shape_[2];
+
+  //   dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+  //   dnnl::stream engine_stream(engine);
+
+  //   dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3};
+  //   dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN};
+  //   dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN};
+
+  //   dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1};
+  //   dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS};
+
+  //   auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, src1_strides);
+  //   auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides);
+  //   auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc);
+
+  //   const float scale = 1.0f / sqrt(static_cast<float>(EMBED/HEADS/3));
+
+  //   dnnl::primitive_attr attr;
+  //   attr.set_output_scales(0, {scale});
+
+  //   // CODE FOR HANDLING MASKING
+  //   // float* mask = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+  //   // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN);
+  //   // dnnl::post_ops post_op;
+  //   // post_op.append_sum(1);
+  //   // attr.set_post_ops(post_op);
+
+  //   auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md);
+  //   auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine);
+
+  //   auto matmul_prim = dnnl::matmul(matmul_pd);
+
+  //   mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+  //   float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+
+  //   float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+
+  //   auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values);
+  //   auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3));
+  //   auto dst_mem = dnnl::memory(dst_md, engine, output);
+
+  //   std::unordered_map<int, dnnl::memory> matmul_args;
+  //   matmul_args.insert({DNNL_ARG_SRC, src1_mem});
+  //   matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem});
+  //   matmul_args.insert({DNNL_ARG_DST, dst_mem});
+
+  //   matmul_prim.execute(engine_stream, matmul_args);
+  //   engine_stream.wait();
+  // }
+}
+
+nnvm::ObjectPtr SgMKLDNNInterleavedMatMulSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  auto const &param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  node->attrs.op = Op::Get("_sg_mkldnn_interleaved_matmul_selfatt_qk");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  node->attrs.dict["heads"] = std::to_string(param.heads);
+  node->attrs.dict["quantized"] = "True";
+  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+  for (auto sub : attrs.subgraphs) {
+    node->attrs.subgraphs.push_back(sub);
+  }
+  node->op()->attr_parser(&(node->attrs));
+  return node;
+}
+
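+// Example of the node this callback produces (attribute values are
+// illustrative): quantizing a fused node "interleaved_qk0" with heads=16
+// yields a "_sg_mkldnn_interleaved_matmul_selfatt_qk" node named
+// "quantized_interleaved_qk0" with dict {"heads": "16", "quantized": "True"};
+// the min/max calibration ranges are attached later by the requantize pass.
+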
+NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_qk)
+.describe(R"code(_sg_mkldnn_interleaved_matmul_selfatt_qk)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized && !param.enable_float_output) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_attr_parser(ParamParser<MKLDNNInterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  std::vector<std::string> input_names {"queries_keys_values"};
+  if (param.quantized) {
+    input_names.emplace_back("min_qkv");
+    input_names.emplace_back("max_qkv");
+  }
+  return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", MKLDNNInterleavedMatMulSelfAttQKShape)
+.set_attr<nnvm::FInferType>("FInferType", MKLDNNInterleavedMatMulSelfAttQKInferType)
+.set_attr<FInferStorageType>("FInferStorageType", MKLDNNInterleavedMatMulSelfAttQKStorageType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttQKState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", MKLDNNInterleavedMatMulSelfAttQKForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+  return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttQKQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
+.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__());
+
+/********************************************************************************************************/
+
+static bool MKLDNNInterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs,
+                                                      mxnet::ShapeVector* in_shape,
+                                                      mxnet::ShapeVector* out_shape) {
+  const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized) {
+    auto qkv_shape = in_shape->at(0);
+
+    out_shape->resize(3);
+    SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
+    if (!param.enable_float_output) {
+      SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1}));  // min output
+      SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1}));  // max output
+    }
+
+    return true;
+  } else {
+    CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently has "
+                                   << in_shape->size() << " inputs";
+    auto qkv_shape = in_shape->at(0);
+    auto att_shape = in_shape->at(1);
+    CHECK_EQ(qkv_shape.ndim(), 3U)
+      << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, "
+      << "currently is: " << qkv_shape.ndim() << "D";
+    CHECK_EQ(att_shape.ndim(), 3U)
+      << "Input attention should be 3D in batch-seq_length-seq_length, "
+      << "currently is: " << att_shape.ndim() << "D";
+    CHECK_EQ(qkv_shape[0], att_shape[1])
+      << "queries_keys_values.shape[0] and attention.shape[1] should be the same, "
+      << "currently are " << qkv_shape[0] << " and " << att_shape[1];
+    CHECK_EQ(qkv_shape[0], att_shape[2])
+      << "queries_keys_values.shape[0] and attention.shape[2] should be the same, "
+      << "currently are " << qkv_shape[0] << " and " << att_shape[2];
+    CHECK_EQ(qkv_shape[2] % 3, 0)
+      << "queries_keys_values.shape[2] should be a multiple of 3, "
+      << "currently is " << qkv_shape[2];
+    SHAPE_ASSIGN_CHECK(*out_shape, 0,
+      mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
+    return true;
+  }
+}
+
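+// Shape example (illustrative values): with heads = 16,
+// queries_keys_values of shape (seq_length = 128, batch = 32, 3072) and
+// attention of shape (heads * batch = 512, 128, 128), the output is
+// (128, 32, 3072 / 3) = (128, 32, 1024), i.e. the re-assembled context
+// vectors back in seq_length-batch-embedding layout.
+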
"currently are " << qkv_shape[0] << " and " << att_shape[2]; + CHECK_EQ(qkv_shape[2] % 3, 0) + << "queries_keys_values.shape[2] should be a multiple of 3, " + << "currently is " << qkv_shape[2]; + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3})); + return true; + } +} + +static bool MKLDNNInterleavedMatMulSelfAttValAttInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { + const auto& param = nnvm::get(attrs.parsed); + if (param.quantized) { + TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input + TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8); // att input + + TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32); // min qkv + TYPE_ASSIGN_CHECK(*in_types, 3, mshadow::kFloat32); // max qkv + + TYPE_ASSIGN_CHECK(*in_types, 4, mshadow::kFloat32); // min att + TYPE_ASSIGN_CHECK(*in_types, 5, mshadow::kFloat32); // max att + + if (param.enable_float_output) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); // output + } else { + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8); // output + } else { + TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32); // output + } + TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32); // min output + TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32); // max output + } + return true; + } else { + return DefaultSubgraphOpType(attrs, in_types, out_types); + } +} + +static bool MKLDNNInterleavedMatMulSelfAttValAttStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + auto const ¶m = nnvm::get(attrs.parsed); + if (param.quantized) { + type_assign(&in_attrs->at(0), mxnet::kDefaultStorage); + type_assign(&in_attrs->at(1), mxnet::kDefaultStorage); + type_assign(&in_attrs->at(2), mxnet::kDefaultStorage); + type_assign(&in_attrs->at(3), mxnet::kDefaultStorage); + type_assign(&in_attrs->at(4), mxnet::kDefaultStorage); + type_assign(&in_attrs->at(5), mxnet::kDefaultStorage); + + type_assign(&out_attrs->at(0), mxnet::kDefaultStorage); + if (!param.enable_float_output) { + type_assign(&out_attrs->at(1), mxnet::kDefaultStorage); + type_assign(&out_attrs->at(2), mxnet::kDefaultStorage); + } + std::vector base_in_attrs{in_attrs->at(0), in_attrs->at(1)}; + std::vector base_out_attrs{out_attrs->at(0)}; + return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + &base_in_attrs, &base_out_attrs);; + } else { + return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + in_attrs, out_attrs); + } +} + +nnvm::ObjectPtr SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp(const NodeAttrs& attrs) { + nnvm::ObjectPtr node = nnvm::Node::Create(); + auto const ¶m = nnvm::get(attrs.parsed); + node->attrs.op = Op::Get("_sg_mkldnn_interleaved_matmul_selfatt_valatt"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["heads"] = std::to_string(param.heads); + node->attrs.dict["quantized"] = "True"; + node->attrs.subgraphs.reserve(attrs.subgraphs.size()); + for (auto sub : attrs.subgraphs) { + node->attrs.subgraphs.push_back(sub); + } + node->op()->attr_parser(&(node->attrs)); + return node; +} + +class MKLDNNInterleavedMatMulSelfAttValAttOp { + public: + explicit MKLDNNInterleavedMatMulSelfAttValAttOp(const nnvm::NodeAttrs &attrs) : + param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext &ctx, + const std::vector &inputs, + const 
std::vector &req, + const std::vector &outputs); + + void Backward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + LOG(FATAL) << "Not implemented: subgraph mkldnn fully connected only supports " + "inference computation."; + } + + private: + bool initialized_{false}; + MKLDNNInterleavedMatMulParam param_; + mkldnn_args_map_t args_; + std::shared_ptr fwd_; + std::shared_ptr cached_data1_mem_; + std::shared_ptr cached_data2_mem_; + std::shared_ptr cached_out_mem_; + float cached_min_qkv_; + float cached_max_qkv_; + float cached_min_att_; + float cached_max_att_; + float cached_min_output_; + float cached_max_output_; + float qkv_scale_{0.0f}; + float att_scale_{0.0f}; +}; + +static OpStatePtr CreateMKLDNNInterleavedMatMulSelfAttValAttState(const nnvm::NodeAttrs &attrs, + Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); +} + +static void MKLDNNInterleavedMatMulSelfAttValAttForward(const OpStatePtr &state_pointer, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + MKLDNNInterleavedMatMulSelfAttValAttOp &op = state_pointer.get_state(); + op.Forward(ctx, inputs, req, outputs); +} + +void MKLDNNInterleavedMatMulSelfAttValAttOp::Forward( + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + + float min_qkv = 0.0f; + float max_qkv = 0.0f; + float min_att = 0.0f; + float max_att = 0.0f; + + if (param_.quantized) { + min_qkv = inputs[2].data().dptr()[0]; + max_qkv = inputs[3].data().dptr()[0]; + min_att = inputs[4].data().dptr()[0]; + max_att = inputs[5].data().dptr()[0]; + } + + const dnnl::memory::dim HEADS = param_.heads; + const dnnl::memory::dim BS = inputs[0].shape()[1]; + const dnnl::memory::dim SEQ_LEN = inputs[0].shape()[0]; + const dnnl::memory::dim EMBED = inputs[0].shape()[2]; + + if (!initialized_) { + const auto engine = CpuEngine::Get()->get_engine(); + + cached_min_qkv_ = min_qkv; + cached_max_qkv_ = max_qkv; + cached_min_att_ = min_att; + cached_max_att_ = max_att; + + dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; + dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; + dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; + + dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; + dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + + auto src1_md = param_.quantized ? dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, src1_strides) : + dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, src1_strides); + auto src2_md = param_.quantized ? 
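+
+    // Layout sketch: src1 is the attention tensor [BS*HEADS, SEQ_LEN,
+    // SEQ_LEN], kept dense (strides {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}). src2
+    // picks the value vectors out of the interleaved QKV tensor with the
+    // same strided trick as the QK kernel, except the data pointer is
+    // offset by 2*(EMBED/HEADS/3) to skip q and k. The destination uses
+    // format_tag::bac so the [BS*HEADS, SEQ_LEN, head_dim] result lands
+    // directly in seq_length-batch-embedding order.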
+
+    dnnl::memory::desc dst_md;
+
+    float tmp_scale = 1.0f;
+    if (param_.quantized) {
+      qkv_scale_ = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv);
+      att_scale_ = GetQuantizeScale(mshadow::kUint8, min_att, max_att);
+
+      if (param_.min_calib_range.has_value() &&
+          param_.max_calib_range.has_value()) {
+        cached_min_output_ = param_.min_calib_range.value();
+        cached_max_output_ = param_.max_calib_range.value();
+
+        tmp_scale = GetQuantizeScale(mshadow::kInt8, cached_min_output_, cached_max_output_) /
+                    (qkv_scale_ * att_scale_);
+        dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8,
+                                    dnnl::memory::format_tag::bac);
+      } else if (param_.enable_float_output) {
+        tmp_scale = 1.0f / (qkv_scale_ * att_scale_);
+        dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32,
+                                    dnnl::memory::format_tag::bac);
+      } else {
+        mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+        mxnet_op::Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(
+            s, 1, &cached_min_output_, &cached_max_output_, &min_qkv, &max_qkv, &min_att,
+            &max_att);
+        dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32,
+                                    dnnl::memory::format_tag::bac);
+      }
+    } else {
+      dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32,
+                                  dnnl::memory::format_tag::bac);
+    }
+
+    dnnl::primitive_attr attr;
+    attr.set_output_scales(0, {tmp_scale});
+    auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md);
+    auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine);
+
+    fwd_ = std::make_shared<dnnl::matmul>(matmul_pd);
+
+    MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+      cached_data1_mem_ = std::make_shared<dnnl::memory>(
+          src1_md, engine, inputs[1].data().dptr<DType>());
+    });
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      cached_data2_mem_ = std::make_shared<dnnl::memory>(
+          src2_md, engine, inputs[0].data().dptr<DType>() + 2*(EMBED/HEADS/3));
+    });
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_ = std::make_shared<dnnl::memory>(
+          dst_md, engine, outputs[0].data().dptr<DType>());
+    });
+
+    args_[DNNL_ARG_SRC] = *cached_data1_mem_;
+    args_[DNNL_ARG_WEIGHTS] = *cached_data2_mem_;
+    args_[DNNL_ARG_DST] = *cached_out_mem_;
+    initialized_ = true;
+  } else {
+    MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+      cached_data1_mem_->set_data_handle(
+          reinterpret_cast<void*>(inputs[1].data().dptr<DType>()));
+    });
+    MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+      cached_data2_mem_->set_data_handle(
+          reinterpret_cast<void*>(inputs[0].data().dptr<DType>() + 2*(EMBED/HEADS/3)));
+    });
+    MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+      cached_out_mem_->set_data_handle(
+          reinterpret_cast<void*>(outputs[0].data().dptr<DType>()));
+    });
+  }
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+  MKLDNNStream::Get()->Submit();
+
+  if (param_.quantized && !param_.enable_float_output) {
+    float* min_output = outputs[1].data().dptr<float>();
+    float* max_output = outputs[2].data().dptr<float>();
+
+    *min_output = cached_min_output_;
+    *max_output = cached_max_output_;
+  }
+}
+
+// void MKLDNNInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs,
+//                                              const OpContext &ctx,
+//                                              const std::vector<TBlob> &inputs,
+//                                              const std::vector<OpReqType> &req,
+//                                              const std::vector<TBlob> &outputs) {
+//   const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+
+//   if (param.quantized) {
+//     if (param.enable_float_output) {
+//       const dnnl::memory::dim HEADS = param.heads;
+//       const dnnl::memory::dim BS = inputs[0].shape_[1];
+//       const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0];
+//       const dnnl::memory::dim EMBED = inputs[0].shape_[2];
+
+//       dnnl::engine 
engine(dnnl::engine::kind::cpu, 0); +// dnnl::stream engine_stream(engine); + +// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; +// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; +// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; + +// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; +// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + +// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, dnnl::memory::format_tag::abc); // CHECK IF IT IS U8 FOR SURE +// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); +// auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::bac); + +// float min_qkv = inputs[2].dptr()[0]; +// float max_qkv = inputs[3].dptr()[0]; +// float min_att = inputs[4].dptr()[0]; +// float max_att = inputs[5].dptr()[0]; + +// float qkv_scale = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); +// float att_scale = GetQuantizeScale(mshadow::kUint8, min_att, max_att); + +// const float scale = 1.0f / (qkv_scale * att_scale); + +// dnnl::primitive_attr attr; +// attr.set_output_scales(0, {scale}); + +// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); +// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + +// auto matmul_prim = dnnl::matmul(matmul_pd); + +// mshadow::Stream* s = ctx.get_stream(); +// int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; +// uint8_t* attention_maps = inputs[1].FlatTo2D(s).dptr_; +// float* output = outputs[0].FlatTo2D(s).dptr_; + +// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); +// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); +// auto dst_mem = dnnl::memory(dst_md, engine, output); + +// std::unordered_map matmul_args; +// matmul_args.insert({DNNL_ARG_SRC, src1_mem}); +// matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); +// matmul_args.insert({DNNL_ARG_DST, dst_mem}); + +// matmul_prim.execute(engine_stream, matmul_args); +// engine_stream.wait(); +// } else if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { +// const dnnl::memory::dim HEADS = param.heads; +// const dnnl::memory::dim BS = inputs[0].shape_[1]; +// const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; +// const dnnl::memory::dim EMBED = inputs[0].shape_[2]; + +// dnnl::engine engine(dnnl::engine::kind::cpu, 0); +// dnnl::stream engine_stream(engine); + +// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; +// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; +// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; + +// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; +// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + +// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, dnnl::memory::format_tag::abc); // CHECK IF IT IS U8 FOR SURE +// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); +// auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8, dnnl::memory::format_tag::bac); + +// float min_qkv = inputs[2].dptr()[0]; +// float max_qkv = inputs[3].dptr()[0]; +// float min_att = inputs[4].dptr()[0]; +// float max_att = inputs[5].dptr()[0]; + +// float qkv_scale = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); +// float att_scale = GetQuantizeScale(mshadow::kUint8, min_att, max_att); + +// const float 
scale = GetQuantizeScale(mshadow::kInt8, param.min_calib_range.value(), param.max_calib_range.value()) +// / (qkv_scale * att_scale); + +// dnnl::primitive_attr attr; +// attr.set_output_scales(0, {scale}); + +// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); +// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + +// auto matmul_prim = dnnl::matmul(matmul_pd); + +// mshadow::Stream* s = ctx.get_stream(); +// int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; +// uint8_t* attention_maps = inputs[1].FlatTo2D(s).dptr_; +// int8_t* output = outputs[0].FlatTo2D(s).dptr_; + +// float* min_output = outputs[1].dptr(); +// float* max_output = outputs[2].dptr(); +// min_output[0] = param.min_calib_range.value(); +// max_output[0] = param.max_calib_range.value(); + +// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); +// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); +// auto dst_mem = dnnl::memory(dst_md, engine, output); + +// std::unordered_map matmul_args; +// matmul_args.insert({DNNL_ARG_SRC, src1_mem}); +// matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); +// matmul_args.insert({DNNL_ARG_DST, dst_mem}); + +// matmul_prim.execute(engine_stream, matmul_args); +// engine_stream.wait(); +// } else { +// const dnnl::memory::dim HEADS = param.heads; +// const dnnl::memory::dim BS = inputs[0].shape_[1]; +// const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; +// const dnnl::memory::dim EMBED = inputs[0].shape_[2]; + +// dnnl::engine engine(dnnl::engine::kind::cpu, 0); +// dnnl::stream engine_stream(engine); + +// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; +// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; +// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; + +// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; +// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + +// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, dnnl::memory::format_tag::abc); // CHECK IF IT IS U8 FOR SURE +// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); +// auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32, dnnl::memory::format_tag::bac); + +// float min_qkv = inputs[2].dptr()[0]; +// float max_qkv = inputs[3].dptr()[0]; +// float min_att = inputs[4].dptr()[0]; +// float max_att = inputs[5].dptr()[0]; + +// // float qkv_scale = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); +// // float att_scale = GetQuantizeScale(mshadow::kUint8, min_att, max_att); + +// // const float scale = 1.0f / (qkv_scale * att_scale); + +// // dnnl::primitive_attr attr; +// // attr.set_output_scales(0, {scale}); + +// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); +// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, engine); + +// auto matmul_prim = dnnl::matmul(matmul_pd); + +// mshadow::Stream* s = ctx.get_stream(); +// int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; +// uint8_t* attention_maps = inputs[1].FlatTo2D(s).dptr_; +// int32_t* output = outputs[0].FlatTo2D(s).dptr_; + +// float* min_output = outputs[1].dptr(); +// float* max_output = outputs[2].dptr(); + +// mxnet_op::Kernel::Launch( +// s, 1, min_output, max_output, &min_qkv, &max_qkv, &min_att, +// &max_att); + +// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); +// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); +// 
auto dst_mem = dnnl::memory(dst_md, engine, output);
+
+//       std::unordered_map<int, dnnl::memory> matmul_args;
+//       matmul_args.insert({DNNL_ARG_SRC, src1_mem});
+//       matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem});
+//       matmul_args.insert({DNNL_ARG_DST, dst_mem});
+
+//       matmul_prim.execute(engine_stream, matmul_args);
+//       engine_stream.wait();
+//     }
+//   } else {
+//     const dnnl::memory::dim HEADS = param.heads;
+//     const dnnl::memory::dim BS = inputs[0].shape_[1];
+//     const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0];
+//     const dnnl::memory::dim EMBED = inputs[0].shape_[2];
+
+//     dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+//     dnnl::stream engine_stream(engine);
+
+//     dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN};
+//     dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3};
+//     dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3};
+
+//     // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1};
+//     dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1};
+
+//     auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc);
+//     auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides);
+//     auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::bac);
+
+//     auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md);
+//     auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, engine);
+
+//     auto matmul_prim = dnnl::matmul(matmul_pd);
+
+//     mshadow::Stream<cpu>* s = ctx.get_stream<cpu>();
+//     float* queries_keys_values = inputs[0].FlatTo2D<cpu, float>(s).dptr_;
+//     float* attention_maps = inputs[1].FlatTo2D<cpu, float>(s).dptr_;
+//     float* output = outputs[0].FlatTo2D<cpu, float>(s).dptr_;
+
+//     auto src1_mem = dnnl::memory(src1_md, engine, attention_maps);
+//     auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3));
+//     auto dst_mem = dnnl::memory(dst_md, engine, output);
+
+//     std::unordered_map<int, dnnl::memory> matmul_args;
+//     matmul_args.insert({DNNL_ARG_SRC, src1_mem});
+//     matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem});
+//     matmul_args.insert({DNNL_ARG_DST, dst_mem});
+
+//     matmul_prim.execute(engine_stream, matmul_args);
+//     engine_stream.wait();
+//   }
+// }
+
+NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_valatt)
+.describe(R"code(_sg_mkldnn_interleaved_matmul_selfatt_valatt)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized) {
+    return 6;
+  } else {
+    return 2;
+  }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  if (param.quantized && !param.enable_float_output) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_attr_parser(ParamParser<MKLDNNInterleavedMatMulParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  std::vector<std::string> input_names {"queries_keys_values", "attention"};
+  if (param.quantized) {
+    input_names.emplace_back("min_qkv");
+    input_names.emplace_back("max_qkv");
+
+    input_names.emplace_back("min_attention");
+    input_names.emplace_back("max_attention");
+  }
+  return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", MKLDNNInterleavedMatMulSelfAttValAttShape)
+.set_attr<nnvm::FInferType>("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType)
+// .set_attr<FCompute>("FCompute<cpu>", MKLDNNInterleavedMatMulSelfAttValAttCPU)
+.set_attr<FInferStorageType>("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", MKLDNNInterleavedMatMulSelfAttValAttForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+  return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__());
+
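+// How the two fused ops compose (sketch): a transformer self-attention
+// block lowers to
+//   scores  = _sg_mkldnn_interleaved_matmul_selfatt_qk(qkv)          // Q*K^T / sqrt(d)
+//   weights = softmax(scores)                                        // stays in the graph
+//   context = _sg_mkldnn_interleaved_matmul_selfatt_valatt(qkv, weights)
+// so only the two batched matmuls run through oneDNN here; softmax remains
+// a regular MXNet op between them.
+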
+.set_attr("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType) +// .set_attr("FCompute", MKLDNNInterleavedMatMulSelfAttValAttCPU) +.set_attr("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType) +.set_attr("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState) +.set_attr("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttValAttForward) +.set_attr("TIsMKLDNN", true) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FQuantizable", [](const NodeAttrs& attrs) { + return QuantizeType::kMust; +}) +.set_attr("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp) +.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") +.add_argument("attention", "NDArray-or-Symbol", "Attention maps") +.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet + +#endif \ No newline at end of file From b46a7a587845be1d3fb4d90261ec1abe7ddef167 Mon Sep 17 00:00:00 2001 From: "B. Gawrych" Date: Tue, 23 Mar 2021 14:00:51 +0100 Subject: [PATCH 02/13] check --- .../mkldnn/mkldnn_subgraph_property.cc | 21 +- .../subgraph/mkldnn/mkldnn_transformer-inl.h | 59 ++ .../subgraph/mkldnn/mkldnn_transformer.cc | 560 ++++++------------ ...kldnn_transformer_post_quantize_property.h | 207 +++++++ .../mkldnn/mkldnn_transformer_property.h | 126 ++++ 5 files changed, 565 insertions(+), 408 deletions(-) create mode 100644 src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h create mode 100644 src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h create mode 100644 src/operator/subgraph/mkldnn/mkldnn_transformer_property.h diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc index 18cd3031ef18..cf50c125c719 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc @@ -25,6 +25,8 @@ #include "mkldnn_fc_post_quantize_property.h" #include "mkldnn_elemwisemul_post_quantize_property.h" #include "mkldnn_post_quantize_align_scale_property.h" +#include "mkldnn_transformer_property.h" +#include "mkldnn_transformer_post_quantize_property.h" namespace mxnet { namespace op { @@ -35,34 +37,27 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN) MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); -#endif // MXNET_USE_MKLDNN == 1 -#if MXNET_USE_MKLDNN == 1 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty); -#endif // MXNET_USE_MKLDNN == 1 -#if MXNET_USE_MKLDNN == 1 + MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE) .set_attr("context", Context::CPU()); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty) .set_attr("quantize", true); -#endif // MXNET_USE_MKLDNN == 1 -#if MXNET_USE_MKLDNN == 1 - MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty) .set_attr("quantize", true); -#endif // MXNET_USE_MKLDNN == 1 -#if MXNET_USE_MKLDNN == 1 + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty); + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty); + MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty); -#endif // MXNET_USE_MKLDNN == 1 -#if MXNET_USE_MKLDNN == 1 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty); 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
new file mode 100644
index 000000000000..9318db9d3e50
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_MKLDNN_TRANSFORMER_INL_H_
+#define MXNET_OPERATOR_MKLDNN_TRANSFORMER_INL_H_
+
+// #include
+#include "../../mxnet_op.h"
+#include "../../mshadow_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+struct MKLDNNInterleavedMatMulParam : public dmlc::Parameter<MKLDNNInterleavedMatMulParam> {
+  int heads;
+  bool quantized;
+  bool enable_float_output;
+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+  DMLC_DECLARE_PARAMETER(MKLDNNInterleavedMatMulParam) {
+    DMLC_DECLARE_FIELD(heads)
+    .describe("Set number of heads");
+    DMLC_DECLARE_FIELD(quantized).set_default(false)
+    .describe("Whether it's a quantized InterleavedMatMul operator");
+    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+    .describe("Whether to enable float32 output");
+    DMLC_DECLARE_FIELD(min_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The minimum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized InterleavedMatMul op to calculate the primitive scale");
+    DMLC_DECLARE_FIELD(max_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The maximum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized InterleavedMatMul op to calculate the primitive scale");
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_MKLDNN_TRANSFORMER_INL_H_
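The parameter struct above is filled through the standard dmlc-parameter path
when the fused node is created (see the attr_parser calls in patch 01). A
minimal sketch of that path; the kwargs values are made up for illustration:

    // Sketch only: how an attribute dict reaches MKLDNNInterleavedMatMulParam.
    #include <string>
    #include <unordered_map>

    MKLDNNInterleavedMatMulParam ExampleParse() {
      std::unordered_map<std::string, std::string> kwargs = {
          {"heads", "16"}, {"quantized", "True"}, {"enable_float_output", "False"}};
      MKLDNNInterleavedMatMulParam param;
      param.Init(kwargs);  // parses the strings, applies defaults, and leaves
                           // min/max_calib_range as empty optionals
      return param;
    }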
DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + &base_in_attrs, &base_out_attrs); + + for (size_t i = 0; i < in_attrs->size(); ++i) { + if (i < base_in_attrs.size()) + in_attrs->at(i) = base_in_attrs[i]; + else + type_assign(&in_attrs->at(i), mxnet::kDefaultStorage); + } - type_assign(&out_attrs->at(0), mxnet::kDefaultStorage); + out_attrs->at(0) = base_out_attrs[0]; if (!param.enable_float_output) { type_assign(&out_attrs->at(1), mxnet::kDefaultStorage); type_assign(&out_attrs->at(2), mxnet::kDefaultStorage); } - std::vector base_in_attrs{in_attrs->at(0)}; - std::vector base_out_attrs{out_attrs->at(0)}; - return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, - &base_in_attrs, &base_out_attrs);; + return ret; } else { return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); } } -class MKLDNNInterleavedMatMulSelfAttQKOp { +class SgMKLDNNSelfAttQKOp { public: - explicit MKLDNNInterleavedMatMulSelfAttQKOp(const nnvm::NodeAttrs &attrs) : + explicit SgMKLDNNSelfAttQKOp(const nnvm::NodeAttrs &attrs) : param_(nnvm::get(attrs.parsed)) {} void Forward(const OpContext &ctx, @@ -139,8 +139,8 @@ class MKLDNNInterleavedMatMulSelfAttQKOp { MKLDNNInterleavedMatMulParam param_; mkldnn_args_map_t args_; std::shared_ptr fwd_; - std::shared_ptr cached_data1_mem_; - std::shared_ptr cached_data2_mem_; + std::shared_ptr cached_query_mem_; + std::shared_ptr cached_key_mem_; std::shared_ptr cached_out_mem_; float cached_min_data_; float cached_max_data_; @@ -149,27 +149,27 @@ class MKLDNNInterleavedMatMulSelfAttQKOp { float data_scale_{0.0f}; }; -static OpStatePtr CreateMKLDNNInterleavedMatMulSelfAttQKState(const nnvm::NodeAttrs &attrs, - Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - return OpStatePtr::Create(attrs); +static OpStatePtr CreateSgMKLDNNSelfAttQKState(const nnvm::NodeAttrs &attrs, + Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); } -static void MKLDNNInterleavedMatMulSelfAttQKForward(const OpStatePtr &state_pointer, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - MKLDNNInterleavedMatMulSelfAttQKOp &op = state_pointer.get_state(); +static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + SgMKLDNNSelfAttQKOp &op = state_pointer.get_state(); op.Forward(ctx, inputs, req, outputs); } -void MKLDNNInterleavedMatMulSelfAttQKOp::Forward( - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mkldnn; float min_data = 0.0f; float max_data = 0.0f; @@ -179,10 +179,21 @@ void MKLDNNInterleavedMatMulSelfAttQKOp::Forward( max_data = inputs[2].data().dptr()[0]; } - const dnnl::memory::dim HEADS = param_.heads; - const dnnl::memory::dim BS = inputs[0].shape()[1]; - const dnnl::memory::dim SEQ_LEN = inputs[0].shape()[0]; - const dnnl::memory::dim EMBED = inputs[0].shape()[2]; + const auto qkv_tensor = inputs[0]; + const auto out_tensor = outputs[0]; + const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype()); + const auto out_dtype = get_mkldnn_type(out_tensor.dtype()); + + const memory::dim heads = param_.heads; + const memory::dim sequences = 
inputs[0].shape()[1]; + const memory::dim qkv_seq_len = inputs[0].shape()[0]; + const memory::dim output_lin_dim = inputs[0].shape()[2]; + const memory::dim embed_dim = output_lin_dim / 3; + const memory::dim head_dim = embed_dim / heads; + const memory::dim attn_batches = heads * sequences; + const memory::dim lead_dim = attn_batches * 3 * head_dim; + const memory::dim batch_stride = 3 * head_dim; + const float scale = 1.0 / sqrt(static_cast(head_dim)); if (!initialized_) { const auto engine = CpuEngine::Get()->get_engine(); @@ -190,74 +201,80 @@ void MKLDNNInterleavedMatMulSelfAttQKOp::Forward( cached_min_data_ = min_data; cached_max_data_ = max_data; - dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; - dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; - dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; - - dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; + memory::dims query_dims = {attn_batches, qkv_seq_len, head_dim}; + memory::dims key_dims = {attn_batches, head_dim, qkv_seq_len}; + memory::dims out_dims = {attn_batches, qkv_seq_len, qkv_seq_len}; - auto src1_md = param_.quantized ? dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides) : - dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, src1_strides); - auto src2_md = param_.quantized ? dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides) : - dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides); + memory::dims query_strides = {batch_stride, lead_dim, 1}; + memory::dims key_strides = {batch_stride, 1, lead_dim}; + + auto query_md = memory::desc(query_dims, qkv_dtype, query_strides); + auto key_md = memory::desc(key_dims, qkv_dtype, key_dims); - dnnl::memory::desc dst_md; + memory::desc out_md; float tmp_scale = 1.0f; if (param_.quantized) { - data_scale_ = GetQuantizeScale(mshadow::kInt8, cached_min_data_, cached_max_data_); + data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), cached_min_data_, cached_max_data_); if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { cached_min_output_ = param_.min_calib_range.value(); cached_max_output_ = param_.max_calib_range.value(); - tmp_scale = GetQuantizeScale(mshadow::kInt8, cached_min_output_, cached_max_output_) / (data_scale_ * data_scale_); - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8, dnnl::memory::format_tag::abc); + tmp_scale = + GetQuantizeScale(out_tensor.dtype(), cached_min_output_, cached_max_output_) / + (data_scale_ * data_scale_); + out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abc); } else if (param_.enable_float_output) { tmp_scale = 1.0f / (data_scale_ * data_scale_); - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc); + out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc); } else { mshadow::Stream *s = ctx.get_stream(); mxnet_op::Kernel::Launch( s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_data, &max_data); - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32, dnnl::memory::format_tag::abc); + out_md = dnnl::memory::desc(out_dims, memory::data_type::s32, memory::format_tag::abc); } } else { - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc); + out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, 
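// The query and key matrices used below are never copied out of the qkv
// buffer: they are strided views over it. For reference, the same index
// math as a tiny commented sketch (a hypothetical helper, mirroring
// query_strides = {batch_stride, lead_dim, 1} over dims
// {attn_batches, qkv_seq_len, head_dim}):
//
//   // flat offset of query element (b, s, d) in the interleaved buffer
//   size_t query_offset(size_t b, size_t s, size_t d) {
//     return b * batch_stride + s * lead_dim + d;
//   }
//
// The key view reuses the same strides with its last two axes swapped and a
// base pointer advanced by head_dim elements (see cached_key_mem_ below).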
memory::format_tag::abc); } - tmp_scale /= sqrt(static_cast(EMBED/HEADS/3)); + tmp_scale /= sqrt(static_cast(head_dim)); dnnl::primitive_attr attr; attr.set_output_scales(0, {tmp_scale}); - auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); - auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + auto matmul_d = matmul::desc(query_md, key_md, out_md); + auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine); - fwd_ = std::make_shared(matmul_pd); + fwd_ = std::make_shared(matmul_pd); MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { - cached_data1_mem_ = std::make_shared(src1_md, engine, inputs[0].data().dptr()); - cached_data2_mem_ = std::make_shared(src2_md, engine, inputs[0].data().dptr() + (EMBED/HEADS/3)); + cached_query_mem_ = std::make_shared(query_md, engine, inputs[0].data().dptr()); + cached_key_mem_ = std::make_shared(key_md, engine, inputs[0].data().dptr() + (head_dim)); }); MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { - cached_out_mem_ = std::make_shared(dst_md, engine, outputs[0].data().dptr()); + cached_out_mem_ = std::make_shared(out_md, engine, outputs[0].data().dptr()); }); - args_[DNNL_ARG_SRC] = *cached_data1_mem_; - args_[DNNL_ARG_WEIGHTS] = *cached_data2_mem_; - args_[DNNL_ARG_DST] = *cached_out_mem_; + args_[DNNL_ARG_SRC] = *cached_query_mem_; + args_[DNNL_ARG_WEIGHTS] = *cached_key_mem_; + args_[DNNL_ARG_DST] = *cached_out_mem_; + initialized_ = true; } else { - MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { - cached_data1_mem_->set_data_handle(reinterpret_cast(inputs[0].data().dptr())); - cached_data2_mem_->set_data_handle(reinterpret_cast(inputs[0].data().dptr() + (EMBED/HEADS/3))); + + MSHADOW_TYPE_SWITCH(qkv_tensor.dtype(), DType, { + void* query_mem_ptr = reinterpret_cast(inputs[0].data().dptr()); + void* key_mem_ptr = query_mem_ptr + head_dim; + cached_query_mem_->set_data_handle(query_mem_ptr); + cached_key_mem_->set_data_handle(key_mem_ptr); }); - MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + + MSHADOW_TYPE_SWITCH(out_tensor.dtype(), DType, { cached_out_mem_->set_data_handle(reinterpret_cast(outputs[0].data().dptr())); }); } + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); MKLDNNStream::Get()->Submit(); @@ -269,256 +286,9 @@ void MKLDNNInterleavedMatMulSelfAttQKOp::Forward( *max_output = cached_max_output_; } - - - - - // if (param_.quantized) { - // if (param_.enable_float_output) { - - - // dnnl::engine engine(dnnl::engine::kind::cpu, 0); - // dnnl::stream engine_stream(engine); - - // dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; - // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; - // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; - - // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; - - // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides); - // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); - // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc); - - // // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); - // float min_data = inputs[1].dptr()[0]; - // float max_data = inputs[2].dptr()[0]; - - // float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data); - // const float scale = 1.0f / (data_scale * data_scale) / sqrt(static_cast(EMBED/HEADS/3)); - - // dnnl::primitive_attr attr; - // 
attr.set_output_scales(0, {scale}); - - // // CODE FOR HANLDING MASKING - // // float* mask = inputs[1].FlatTo2D(s).dptr_; - // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); - // // dnnl::post_ops post_op; - // // post_op.append_sum(1); - // // attr.set_post_ops(post_op); - - // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); - // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - - // auto matmul_prim = dnnl::matmul(matmul_pd); - - // mshadow::Stream* s = ctx.get_stream(); - // int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; - - // float* output = outputs[0].FlatTo2D(s).dptr_; - - // auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values); - // auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3)); - // auto dst_mem = dnnl::memory(dst_md, engine, output); - - // std::unordered_map matmul_args; - // matmul_args.insert({DNNL_ARG_SRC, src1_mem}); - // matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); - // matmul_args.insert({DNNL_ARG_DST, dst_mem}); - - // matmul_prim.execute(engine_stream, matmul_args); - // engine_stream.wait(); - // } else if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { - // const dnnl::memory::dim HEADS = param.heads; - // const dnnl::memory::dim BS = inputs[0].shape_[1]; - // const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; - // const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - - // dnnl::engine engine(dnnl::engine::kind::cpu, 0); - // dnnl::stream engine_stream(engine); - - // dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; - // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; - // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; - - // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; - - // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides); - // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); - // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8, dnnl::memory::format_tag::abc); - - // // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); - // float min_data = inputs[1].dptr()[0]; - // float max_data = inputs[2].dptr()[0]; - - - - // float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data); - // const float scale = GetQuantizeScale(mshadow::kInt8, param.min_calib_range.value(), param.max_calib_range.value()) - // / (data_scale * data_scale) / sqrt(static_cast(EMBED/HEADS/3)); - - // dnnl::primitive_attr attr; - // attr.set_output_scales(0, {scale}); - - // // CODE FOR HANLDING MASKING - // // float* mask = inputs[1].FlatTo2D(s).dptr_; - // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); - // // dnnl::post_ops post_op; - // // post_op.append_sum(1); - // // attr.set_post_ops(post_op); - - // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); - // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - - // auto matmul_prim = dnnl::matmul(matmul_pd); - - // mshadow::Stream* s = ctx.get_stream(); - // int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; - - // int8_t* output = outputs[0].FlatTo2D(s).dptr_; - // float* min_output = outputs[1].dptr(); - // float* max_output = outputs[2].dptr(); - // min_output[0] = param.min_calib_range.value(); - // max_output[0] = param.max_calib_range.value(); - - // auto src1_mem 
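// Output-scale composition shared by the live path above and by the
// reference branches kept commented out here (a sketch, assuming symmetric
// int8 quantization as implemented by GetQuantizeScale):
//
//   float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data);
//   // s8  output: out_scale / (data_scale * data_scale) / sqrt(head_dim)
//   // f32 output: 1.0f      / (data_scale * data_scale) / sqrt(head_dim)
//   // s32 output: 1.0f                                  / sqrt(head_dim)
//
// One factor dequantizes the q*k product, the other folds in the usual
// 1/sqrt(d_k) attention scaling, so both are applied inside the GEMM for
// free via attr.set_output_scales.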
= dnnl::memory(src1_md, engine, queries_keys_values); - // auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3)); - // auto dst_mem = dnnl::memory(dst_md, engine, output); - - // std::unordered_map matmul_args; - // matmul_args.insert({DNNL_ARG_SRC, src1_mem}); - // matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); - // matmul_args.insert({DNNL_ARG_DST, dst_mem}); - - // matmul_prim.execute(engine_stream, matmul_args); - // engine_stream.wait(); - // } else { - // const dnnl::memory::dim HEADS = param.heads; - // const dnnl::memory::dim BS = inputs[0].shape_[1]; - // const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; - // const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - - // dnnl::engine engine(dnnl::engine::kind::cpu, 0); - // dnnl::stream engine_stream(engine); - - // dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; - // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; - // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; - - // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; - - // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::s8, src1_strides); - // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); - // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32, dnnl::memory::format_tag::abc); - - // // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); - // float min_data = inputs[1].dptr()[0]; - // float max_data = inputs[2].dptr()[0]; - - // // float data_scale = GetQuantizeScale(mshadow::kInt8, min_data, max_data); - // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); - - // dnnl::primitive_attr attr; - // attr.set_output_scales(0, {scale}); - - // // CODE FOR HANLDING MASKING - // // float* mask = inputs[1].FlatTo2D(s).dptr_; - // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); - // // dnnl::post_ops post_op; - // // post_op.append_sum(1); - // // attr.set_post_ops(post_op); - - // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); - // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - - // auto matmul_prim = dnnl::matmul(matmul_pd); - - // mshadow::Stream* s = ctx.get_stream(); - // int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; - - // int32_t* output = outputs[0].FlatTo2D(s).dptr_; - // float* min_output = outputs[1].dptr(); - // float* max_output = outputs[2].dptr(); - - // mxnet_op::Kernel::Launch( - // s, 1, min_output, max_output, &min_data, &max_data, &min_data, - // &max_data); - - // // min_output[0] = min_data; - // // max_output[0] = max_data; - - // auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values); - // auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3)); - // auto dst_mem = dnnl::memory(dst_md, engine, output); - - // std::unordered_map matmul_args; - // matmul_args.insert({DNNL_ARG_SRC, src1_mem}); - // matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); - // matmul_args.insert({DNNL_ARG_DST, dst_mem}); - - // matmul_prim.execute(engine_stream, matmul_args); - // engine_stream.wait(); - // } - // } else { - // const dnnl::memory::dim HEADS = param.heads; - // const dnnl::memory::dim BS = inputs[0].shape_[1]; - // const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; - // const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - - // dnnl::engine engine(dnnl::engine::kind::cpu, 0); - // 
dnnl::stream engine_stream(engine); - - // dnnl::memory::dims src1_dims = {HEADS*BS, SEQ_LEN, EMBED/HEADS/3}; - // dnnl::memory::dims src2_dims = {HEADS*BS, EMBED/HEADS/3, SEQ_LEN}; - // dnnl::memory::dims dst_dims = {HEADS*BS, SEQ_LEN, SEQ_LEN}; - - // dnnl::memory::dims src1_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - // dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), 1, EMBED*BS}; - - // auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, src1_strides); - // auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides); - // auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc); - - // const float scale = 1.0f / sqrt(static_cast(EMBED/HEADS/3)); - - // dnnl::primitive_attr attr; - // attr.set_output_scales(0, {scale}); - - // // CODE FOR HANLDING MASKING - // // float* mask = inputs[1].FlatTo2D(s).dptr_; - // // memcpy(output, mask, sizeof(float)*HEADS*BS*SEQ_LEN*SEQ_LEN); - // // dnnl::post_ops post_op; - // // post_op.append_sum(1); - // // attr.set_post_ops(post_op); - - // auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); - // auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - - // auto matmul_prim = dnnl::matmul(matmul_pd); - - // mshadow::Stream* s = ctx.get_stream(); - // float* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; - - // float* output = outputs[0].FlatTo2D(s).dptr_; - - // auto src1_mem = dnnl::memory(src1_md, engine, queries_keys_values); - // auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+(EMBED/HEADS/3)); - // auto dst_mem = dnnl::memory(dst_md, engine, output); - - // std::unordered_map matmul_args; - // matmul_args.insert({DNNL_ARG_SRC, src1_mem}); - // matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); - // matmul_args.insert({DNNL_ARG_DST, dst_mem}); - - // matmul_prim.execute(engine_stream, matmul_args); - // engine_stream.wait(); - // } } -nnvm::ObjectPtr SgMKLDNNInterleavedMatMulSelfAttQKQuantizedOp(const NodeAttrs& attrs) { +nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) { nnvm::ObjectPtr node = nnvm::Node::Create(); auto const ¶m = nnvm::get(attrs.parsed); node->attrs.op = Op::Get("_sg_mkldnn_interleaved_matmul_selfatt_qk"); @@ -571,17 +341,17 @@ NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_qk) } return output_names; }) -.set_attr("FInferShape", MKLDNNInterleavedMatMulSelfAttQKShape) -.set_attr("FInferType", MKLDNNInterleavedMatMulSelfAttQKInferType) -.set_attr("FInferStorageType", MKLDNNInterleavedMatMulSelfAttQKStorageType) -.set_attr("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttQKState) -.set_attr("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttQKForward) +.set_attr("FInferShape", SgMKLDNNSelfAttQKShape) +.set_attr("FInferType", SgMKLDNNSelfAttQKInferType) +.set_attr("FInferStorageType", SgMKLDNNSelfAttQKStorageType) +.set_attr("FCreateOpState", CreateSgMKLDNNSelfAttQKState) +.set_attr("FStatefulComputeEx", SgMKLDNNSelfAttQKForward) .set_attr("TIsMKLDNN", true) .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FQuantizable", [](const NodeAttrs& attrs) { return QuantizeType::kMust; }) -.set_attr("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttQKQuantizedOp) +.set_attr("FQuantizedOp", SgMKLDNNSelfAttQKQuantizedOp) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) .add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values") .add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__()); @@ 
-1101,62 +871,62 @@ void MKLDNNInterleavedMatMulSelfAttValAttOp::Forward( // } // } -NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_valatt) -.describe(R"code(_sg_mkldnn_interleaved_matmul_selfatt_valatt)code" ADD_FILELINE) -.set_num_inputs([](const NodeAttrs& attrs) { - auto const& param = nnvm::get(attrs.parsed); - if (param.quantized) { - return 6; - } else { - return 2; - } -}) -.set_num_outputs([](const NodeAttrs& attrs) { - auto const& param = nnvm::get(attrs.parsed); - if (param.quantized && !param.enable_float_output) { - return 3; - } else { - return 1; - } -}) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", [](const NodeAttrs& attrs) { - auto const& param = nnvm::get(attrs.parsed); - std::vector input_names {"queries_keys_values", "attention"}; - if (param.quantized) { - input_names.emplace_back("min_qkv"); - input_names.emplace_back("max_qkv"); +// NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_valatt) +// .describe(R"code(_sg_mkldnn_interleaved_matmul_selfatt_valatt)code" ADD_FILELINE) +// .set_num_inputs([](const NodeAttrs& attrs) { +// auto const& param = nnvm::get(attrs.parsed); +// if (param.quantized) { +// return 6; +// } else { +// return 2; +// } +// }) +// .set_num_outputs([](const NodeAttrs& attrs) { +// auto const& param = nnvm::get(attrs.parsed); +// if (param.quantized && !param.enable_float_output) { +// return 3; +// } else { +// return 1; +// } +// }) +// .set_attr_parser(ParamParser) +// .set_attr("FListInputNames", [](const NodeAttrs& attrs) { +// auto const& param = nnvm::get(attrs.parsed); +// std::vector input_names {"queries_keys_values", "attention"}; +// if (param.quantized) { +// input_names.emplace_back("min_qkv"); +// input_names.emplace_back("max_qkv"); - input_names.emplace_back("min_attention"); - input_names.emplace_back("max_attention"); - } - return input_names; -}) -.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { - auto const& param = nnvm::get(attrs.parsed); - std::vector output_names {"output"}; - if (param.quantized && !param.enable_float_output) { - output_names.emplace_back("min_output"); - output_names.emplace_back("max_output"); - } - return output_names; -}) -.set_attr("FInferShape", MKLDNNInterleavedMatMulSelfAttValAttShape) -.set_attr("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType) -// .set_attr("FCompute", MKLDNNInterleavedMatMulSelfAttValAttCPU) -.set_attr("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType) -.set_attr("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState) -.set_attr("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttValAttForward) -.set_attr("TIsMKLDNN", true) -.set_attr("FGradient", MakeZeroGradNodes) -.set_attr("FQuantizable", [](const NodeAttrs& attrs) { - return QuantizeType::kMust; -}) -.set_attr("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp) -.set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) -.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") -.add_argument("attention", "NDArray-or-Symbol", "Attention maps") -.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__()); +// input_names.emplace_back("min_attention"); +// input_names.emplace_back("max_attention"); +// } +// return input_names; +// }) +// .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { +// auto const& param = nnvm::get(attrs.parsed); +// std::vector output_names {"output"}; +// if (param.quantized && !param.enable_float_output) { +// 
output_names.emplace_back("min_output"); +// output_names.emplace_back("max_output"); +// } +// return output_names; +// }) +// .set_attr("FInferShape", MKLDNNInterleavedMatMulSelfAttValAttShape) +// .set_attr("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType) +// // .set_attr("FCompute", MKLDNNInterleavedMatMulSelfAttValAttCPU) +// .set_attr("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType) +// .set_attr("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState) +// .set_attr("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttValAttForward) +// .set_attr("TIsMKLDNN", true) +// .set_attr("FGradient", MakeZeroGradNodes) +// .set_attr("FQuantizable", [](const NodeAttrs& attrs) { +// return QuantizeType::kMust; +// }) +// .set_attr("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp) +// .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) +// .add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") +// .add_argument("attention", "NDArray-or-Symbol", "Attention maps") +// .add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h new file mode 100644 index 000000000000..abb08b39b796 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_ +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include "../../quantization/requantize-inl.h" +#include "../common.h" +#include "mkldnn_subgraph_base-inl.h" + +namespace mxnet { +namespace op { + +class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector { + public: + /*! \brief pattern match status */ + enum SelectStatus { + kFail = 0, + kStart, + kRequantize, + kSuccess, + }; + + private: + bool disable_all; + bool disable_float_output; + SelectStatus status; + std::vector matched_list; + + public: + explicit SgMKLDNNTransformerPostQuantizeSelector(const bool dis_all, + const bool dis_float_output) + : disable_all(dis_all), + disable_float_output(dis_float_output) {} + + bool Select(const nnvm::Node &n) override { + if ((!disable_all) && + (n.op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_qk") || + n.op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_valatt"))) { + status = disable_all ? 
kSuccess : kStart; + matched_list.clear(); + matched_list.push_back(&n); + return true; + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + if (status == kFail || status == kSuccess || new_node.is_variable()) + return false; + // If n isn't the last matched node, then we encoutered a internal + // branch, we should pop out the node behind n and stop fusion. + if (matched_list.back() != &n) { + if (std::find(matched_list.begin(), matched_list.end(), &n) != + matched_list.end()) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + } + + status = kSuccess; + return false; + } + + switch (status) { + case kStart: + if (new_node.op() == Op::Get("_contrib_requantize")) { + auto const ¶m = nnvm::get(new_node.attrs.parsed); + if (param.min_calib_range.has_value() && + param.max_calib_range.has_value()) { + matched_list.push_back(&new_node); + status = kRequantize; + return true; + } + } + case kRequantize: + if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) { + matched_list.push_back(&new_node); + status = kSuccess; + return true; + } + default: + status = kSuccess; + return false; + } + } + + std::vector Filter( + const std::vector &candidates) override { + if ((status != kSuccess) || (matched_list.size() <= 1)) { + return std::vector(0); + } else { + std::vector ret; + for (auto i : matched_list) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != + candidates.end()) { + ret.push_back(non_const_i); + } + } + return ret; + } + } + + void Reset() override { + CHECK_GE(matched_list.size(), 1); + auto new_selector = SgMKLDNNTransformerPostQuantizeSelector(disable_all, disable_float_output); + new_selector.Select(*matched_list[0]); + *this = new_selector; + } +}; + +class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty { + public: + SgMKLDNNTransformerPostQuantizeProperty() { + disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FUSE_ALL", false); + disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FLOAT_OUTPUT", false); + } + + static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN Transformer post-quantization optimization pass"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + return property; + } + + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::ObjectPtr interleaved_node = nullptr; + nnvm::ObjectPtr requantize_node = nullptr; + nnvm::ObjectPtr dequantize_node = nullptr; + + DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) { + if (node->is_variable()) return; + if (node->op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_qk") || + node->op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_valatt")) { + interleaved_node = node; + } else if (node->op() == Op::Get("_contrib_requantize")) { + requantize_node = node; + } else if (node->op() == Op::Get("_contrib_dequantize")) { + dequantize_node = node; + } + }); + + CHECK_NOTNULL(interleaved_node); + CHECK_NOTNULL(requantize_node); + auto const &requantize_param = + nnvm::get(requantize_node->attrs.parsed); + CHECK(requantize_param.min_calib_range.has_value()); + CHECK(requantize_param.max_calib_range.has_value()); + + // When only fusing 
quantized_interleaved_matmul and requantize, set min/max_cablib_range, + // When fusing quantized_interleaved_matmul + requantize + dequantize, set dequantize flag to true. + // auto& param = nnvm::get(interleaved_node->attrs.parsed); + if (dequantize_node != nullptr) { + interleaved_node->attrs.dict["enable_float_output"] = "True"; + } else { + interleaved_node->attrs.dict["min_calib_range"] = + std::to_string(requantize_param.min_calib_range.value()); + interleaved_node->attrs.dict["max_calib_range"] = + std::to_string(requantize_param.max_calib_range.value()); + } + interleaved_node->op()->attr_parser(&(interleaved_node->attrs)); + return interleaved_node; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = + std::make_shared(disable_fuse_all, + disable_float_output); + return selector; + } + + private: + bool disable_fuse_all; + bool disable_float_output; +}; + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h new file mode 100644 index 000000000000..610d81536fd6 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_TRANSFORMER_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_TRANSFORMER_PROPERTY_H_ +#if MXNET_USE_MKLDNN == 1 + +#include +#include +#include "../common.h" +#include "../../tensor/matrix_op-inl.h" +#include "../../contrib/transformer-inl.h" +#include "mkldnn_transformer-inl.h" +#include "mkldnn_subgraph_base-inl.h" + +namespace mxnet { +namespace op { + +class SgMKLDNNTransformerSelector : public SubgraphSelector { + public: + explicit SgMKLDNNTransformerSelector() {} + + bool Select(const nnvm::Node &n, const std::shared_ptr& node_attr) override { + if (n.op() == Op::Get("_contrib_interleaved_matmul_selfatt_qk") || + n.op() == Op::Get("_contrib_interleaved_matmul_selfatt_valatt")) { + return true; + } + return false; + } + + bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } + + bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { + return false; + } +}; + +class SgMKLDNNTransformerProperty : public SubgraphProperty { + public: + SgMKLDNNTransformerProperty() {} + + static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN Transformer optimization pass"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT", 0)) { + property->SetAttr("disable", true); + } + return property; + } + + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym, + const int subgraph_id = 0) const override { + nnvm::ObjectPtr n = nnvm::Node::Create(); + // This op has single output, remove duplicated. + auto last_node = sym.outputs[0].node; + nnvm::Symbol new_sym; + new_sym.outputs.emplace_back(last_node); + std::ostringstream node_name; + node_name << "_sg_mkldnn"; + + std::string op_name; + MKLDNNInterleavedMatMulParam new_param; + DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) { + if (node->op() && + (node->op()->name == "_contrib_interleaved_matmul_selfatt_qk" || + node->op()->name == "_contrib_interleaved_matmul_selfatt_valatt")) { + op_name = node->op()->name; + auto param = nnvm::get(node->attrs.parsed); + new_param.heads = param.heads; + new_param.quantized = false; + new_param.enable_float_output = false; + } + }); + node_name << op_name << "_" << std::to_string(subgraph_id); + + + n->attrs.name = node_name.str(); + n->attrs.op = Op::Get("_sg_mkldnn" + op_name); + CHECK(n->attrs.op); + n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); + n->attrs.parsed = new_param; + return n; + } + + SubgraphSelectorPtr CreateSubgraphSelector() const override { + auto selector = std::make_shared(); + return selector; + } + + void ConnectSubgraphOutputs( + const nnvm::ObjectPtr n, + std::vector *output_entries) const override { + // Connect all extern output entries to output[0] + for (size_t i = 0; i < output_entries->size(); ++i) { + auto entry_ptr = output_entries->at(i); + *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; + } + } +}; + +} // namespace op +} // namespace mxnet + +#endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_TRANSFORMER_PROPERTY_H_ From f53e08436ddf8a9ce66790c1bcb6c63a8b151cc7 Mon Sep 17 00:00:00 2001 From: "B. 
Gawrych" Date: Wed, 24 Mar 2021 08:34:32 +0100 Subject: [PATCH 03/13] Fix selfattQK subgraph --- .../subgraph/mkldnn/mkldnn_transformer.cc | 4 ++-- ...kldnn_transformer_post_quantize_property.h | 8 +++---- .../mkldnn/mkldnn_transformer_property.h | 22 ++++++++++++------- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index 5143e28e55e7..3d2702e28b9a 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -304,8 +304,8 @@ nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) { return node; } -NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_qk) -.describe(R"code(_sg_mkldnn_interleaved_matmul_selfatt_qk)code" ADD_FILELINE) +NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk) +.describe(R"code(_sg_mkldnn_selfatt_qk)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { auto const& param = nnvm::get(attrs.parsed); if (param.quantized) { diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h index abb08b39b796..9291e427f0ed 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h @@ -54,8 +54,8 @@ class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector { bool Select(const nnvm::Node &n) override { if ((!disable_all) && - (n.op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_qk") || - n.op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_valatt"))) { + (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") || + n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) { status = disable_all ? 
kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); @@ -156,8 +156,8 @@ class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty { DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) { if (node->is_variable()) return; - if (node->op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_qk") || - node->op() == Op::Get("_sg_mkldnn_contrib_interleaved_matmul_selfatt_valatt")) { + if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") || + node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) { interleaved_node = node; } else if (node->op() == Op::Get("_contrib_requantize")) { requantize_node = node; diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h index 610d81536fd6..bd566f2c7a22 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h @@ -33,13 +33,21 @@ namespace mxnet { namespace op { +#define SELFATT_QK "_contrib_interleaved_matmul_selfatt_qk" +#define SELFATT_VALATT "_contrib_interleaved_matmul_selfatt_valatt" + +const std::map OpMapping = { + {SELFATT_QK, "_sg_mkldnn_selfatt_qk"}, + {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"} +}; + class SgMKLDNNTransformerSelector : public SubgraphSelector { public: explicit SgMKLDNNTransformerSelector() {} bool Select(const nnvm::Node &n, const std::shared_ptr& node_attr) override { - if (n.op() == Op::Get("_contrib_interleaved_matmul_selfatt_qk") || - n.op() == Op::Get("_contrib_interleaved_matmul_selfatt_valatt")) { + if (n.op() == Op::Get(SELFATT_QK)) { //|| + //n.op() == Op::Get(SELFATT_VALATT)) { // Enable when refactored return true; } return false; @@ -77,14 +85,12 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { nnvm::Symbol new_sym; new_sym.outputs.emplace_back(last_node); std::ostringstream node_name; - node_name << "_sg_mkldnn"; - std::string op_name; MKLDNNInterleavedMatMulParam new_param; DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) { if (node->op() && - (node->op()->name == "_contrib_interleaved_matmul_selfatt_qk" || - node->op()->name == "_contrib_interleaved_matmul_selfatt_valatt")) { + (node->op()->name == SELFATT_QK || + node->op()->name == SELFATT_VALATT)) { op_name = node->op()->name; auto param = nnvm::get(node->attrs.parsed); new_param.heads = param.heads; @@ -92,11 +98,11 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { new_param.enable_float_output = false; } }); - node_name << op_name << "_" << std::to_string(subgraph_id); + node_name << OpMapping.at(op_name) << "_" << std::to_string(subgraph_id); n->attrs.name = node_name.str(); - n->attrs.op = Op::Get("_sg_mkldnn" + op_name); + n->attrs.op = Op::Get(OpMapping.at(op_name)); CHECK(n->attrs.op); n->attrs.subgraphs.emplace_back(std::make_shared(new_sym)); n->attrs.parsed = new_param; From b5854942e7711d11fce820d9ad3259449c625a55 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 24 Mar 2021 21:26:29 +0800 Subject: [PATCH 04/13] fix qk --- .../subgraph/mkldnn/mkldnn_transformer.cc | 20 +++++++++++++------ ...kldnn_transformer_post_quantize_property.h | 8 ++++---- .../mkldnn/mkldnn_transformer_property.h | 6 +++--- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index 3d2702e28b9a..e5fa41226645 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ 
b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -40,10 +40,18 @@ static bool SgMKLDNNSelfAttQKShape(const NodeAttrs& attrs, mxnet::ShapeVector* out_shapes) { const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { - mxnet::ShapeVector base_in_shapes = {in_shapes[0]}; - mxnet::ShapeVector base_out_shapes = {out_shapes[0]}; + mxnet::ShapeVector base_in_shapes = {in_shapes->at(0)}; + mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)}; bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); + for (size_t i = 0; i < in_shapes->size(); ++i) { + if (i < base_in_shapes.size()) + in_shapes->at(i) = base_in_shapes[i]; + else + SHAPE_ASSIGN_CHECK(*in_shapes, i, mxnet::TShape({1})); + } + out_shapes->resize(3); + out_shapes->at(0) = base_out_shapes[0]; if (!param.enable_float_output) { SHAPE_ASSIGN_CHECK(*out_shapes, 1, mxnet::TShape({1})); // min output SHAPE_ASSIGN_CHECK(*out_shapes, 2, mxnet::TShape({1})); // max output @@ -191,8 +199,8 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, const memory::dim embed_dim = output_lin_dim / 3; const memory::dim head_dim = embed_dim / heads; const memory::dim attn_batches = heads * sequences; - const memory::dim lead_dim = attn_batches * 3 * head_dim; - const memory::dim batch_stride = 3 * head_dim; + const memory::dim lead_dim = attn_batches * 3 * head_dim; // sequences * output_lin_dim + const memory::dim batch_stride = 3 * head_dim; // output_lin_dim / heads; const float scale = 1.0 / sqrt(static_cast(head_dim)); if (!initialized_) { @@ -209,7 +217,7 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, memory::dims key_strides = {batch_stride, 1, lead_dim}; auto query_md = memory::desc(query_dims, qkv_dtype, query_strides); - auto key_md = memory::desc(key_dims, qkv_dtype, key_dims); + auto key_md = memory::desc(key_dims, qkv_dtype, key_strides); memory::desc out_md; @@ -291,7 +299,7 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) { nnvm::ObjectPtr node = nnvm::Node::Create(); auto const ¶m = nnvm::get(attrs.parsed); - node->attrs.op = Op::Get("_sg_mkldnn_interleaved_matmul_selfatt_qk"); + node->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk"); node->attrs.name = "quantized_" + attrs.name; node->attrs.dict = attrs.dict; node->attrs.dict["heads"] = std::to_string(param.heads); diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h index 9291e427f0ed..ba37c314f5b1 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h @@ -54,8 +54,8 @@ class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector { bool Select(const nnvm::Node &n) override { if ((!disable_all) && - (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") || - n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) { + (n.op() == Op::Get("_sg_mkldnn_selfatt_qk"))) { // || + //n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) { status = disable_all ? 
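// Easy to miss in the mkldnn_transformer.cc hunk above: the change from
// key_dims to key_strides in the key memory::desc is a genuine bug fix, not
// a rename. dnnl::memory::desc(dims, dtype, strides) accepts any dims
// vector as strides, so passing key_dims there built a silently
// mis-strided key view instead of failing loudly.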
kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); @@ -156,8 +156,8 @@ class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty { DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) { if (node->is_variable()) return; - if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") || - node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) { + if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk")) {// || + //node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) { interleaved_node = node; } else if (node->op() == Op::Get("_contrib_requantize")) { requantize_node = node; diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h index bd566f2c7a22..1935e40536ff 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h @@ -38,7 +38,7 @@ namespace op { const std::map OpMapping = { {SELFATT_QK, "_sg_mkldnn_selfatt_qk"}, - {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"} + // {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"} }; class SgMKLDNNTransformerSelector : public SubgraphSelector { @@ -89,8 +89,8 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { MKLDNNInterleavedMatMulParam new_param; DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) { if (node->op() && - (node->op()->name == SELFATT_QK || - node->op()->name == SELFATT_VALATT)) { + (node->op()->name == SELFATT_QK)) {// || + //node->op()->name == SELFATT_VALATT)) { op_name = node->op()->name; auto param = nnvm::get(node->attrs.parsed); new_param.heads = param.heads; From 2f5a7e7b9539137d08f831285eabe10b737c2059 Mon Sep 17 00:00:00 2001 From: "B. Gawrych" Date: Fri, 26 Mar 2021 09:31:41 +0100 Subject: [PATCH 05/13] Fixes QK --- .../mkldnn/mkldnn_subgraph_property.cc | 2 ++ .../subgraph/mkldnn/mkldnn_transformer.cc | 21 +++++++++---------- .../mkldnn/mkldnn_transformer_property.h | 9 ++++++-- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc index cf50c125c719..9190ba41afcb 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc @@ -39,6 +39,8 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty); + MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE) .set_attr("context", Context::CPU()); diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index e5fa41226645..490e74826ab6 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -36,8 +36,8 @@ namespace op { DMLC_REGISTER_PARAMETER(MKLDNNInterleavedMatMulParam); static bool SgMKLDNNSelfAttQKShape(const NodeAttrs& attrs, - mxnet::ShapeVector* in_shapes, - mxnet::ShapeVector* out_shapes) { + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { mxnet::ShapeVector base_in_shapes = {in_shapes->at(0)}; @@ -64,13 +64,12 @@ static bool SgMKLDNNSelfAttQKShape(const NodeAttrs& attrs, } static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs, - std::vector *in_types, - std::vector *out_types) { + std::vector *in_types, + std::vector *out_types) { const 
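// A cautious note on the InferType change just below: the quantized op now
// accepts only signed int8 input (the uint8 branch of the CHECK was
// dropped), which matches the s8 memory descriptors the kernel actually
// builds; re-enabling uint8 would presumably require revisiting the
// data_scale_ computation first.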
auto& param = nnvm::get(attrs.parsed); if (param.quantized) { - CHECK(in_types->at(0) == mshadow::kInt8 || - in_types->at(0) == mshadow::kUint8) - << "QuantizedInterleavedMatMulSelfAttQK only supports int8/uint8 input, while " + CHECK(in_types->at(0) == mshadow::kInt8) + << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while " << in_types->at(0) << " is given."; TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32); // min value @@ -94,10 +93,10 @@ static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs, } static bool SgMKLDNNSelfAttQKStorageType(const nnvm::NodeAttrs &attrs, - const int dev_mask, - DispatchMode *dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { auto const ¶m = nnvm::get(attrs.parsed); if (param.quantized) { std::vector base_in_attrs{in_attrs->at(0)}; diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h index 1935e40536ff..be188c0ba987 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h @@ -38,7 +38,12 @@ namespace op { const std::map OpMapping = { {SELFATT_QK, "_sg_mkldnn_selfatt_qk"}, - // {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"} + {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"} +}; + +const std::map NameMapping = { + {SELFATT_QK, "sg_mkldnn_selfatt_qk"}, + {SELFATT_VALATT, "sg_mkldnn_selfatt_valatt"} }; class SgMKLDNNTransformerSelector : public SubgraphSelector { @@ -98,7 +103,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { new_param.enable_float_output = false; } }); - node_name << OpMapping.at(op_name) << "_" << std::to_string(subgraph_id); + node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id); n->attrs.name = node_name.str(); From e8c98822cf4bc6d5ca0e728ff72949c4c293af35 Mon Sep 17 00:00:00 2001 From: "B. 
Gawrych" Date: Mon, 29 Mar 2021 08:30:48 +0200 Subject: [PATCH 06/13] add test for oneDNN self_att qk --- .../subgraph/mkldnn/mkldnn_transformer.cc | 142 ++++++++++-------- tests/python/mkl/test_subgraph.py | 26 +++- 2 files changed, 102 insertions(+), 66 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index 490e74826ab6..2d7c63b57bb3 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -141,6 +141,15 @@ class SgMKLDNNSelfAttQKOp { "inference computation."; } + void Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + bool IsInitialized() { + return initialized_; + } + private: bool initialized_{false}; MKLDNNInterleavedMatMulParam param_; @@ -149,10 +158,10 @@ class SgMKLDNNSelfAttQKOp { std::shared_ptr cached_query_mem_; std::shared_ptr cached_key_mem_; std::shared_ptr cached_out_mem_; - float cached_min_data_; - float cached_max_data_; - float cached_min_output_; - float cached_max_output_; + float min_data_; + float max_data_; + float min_output_; + float max_output_; float data_scale_{0.0f}; }; @@ -169,45 +178,41 @@ static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer, const std::vector &req, const std::vector &outputs) { SgMKLDNNSelfAttQKOp &op = state_pointer.get_state(); + if(!op.IsInitialized()) { + op.Initialize(ctx, inputs, req, outputs); + } op.Forward(ctx, inputs, req, outputs); } -void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mkldnn; - - float min_data = 0.0f; - float max_data = 0.0f; - - if (param_.quantized) { - min_data = inputs[1].data().dptr()[0]; - max_data = inputs[2].data().dptr()[0]; - } +void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mkldnn; + const auto qkv_tensor = inputs[0]; + const auto out_tensor = outputs[0]; + const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype()); + + const memory::dim heads = param_.heads; + const memory::dim sequences = inputs[0].shape()[1]; + const memory::dim qkv_seq_len = inputs[0].shape()[0]; + const memory::dim output_lin_dim = inputs[0].shape()[2]; + const memory::dim embed_dim = output_lin_dim / 3; + const memory::dim head_dim = embed_dim / heads; + const memory::dim attn_batches = heads * sequences; + const memory::dim lead_dim = attn_batches * 3 * head_dim; + const memory::dim batch_stride = 3 * head_dim; + + float min_data = 0.0f; + float max_data = 0.0f; - const auto qkv_tensor = inputs[0]; - const auto out_tensor = outputs[0]; - const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype()); - const auto out_dtype = get_mkldnn_type(out_tensor.dtype()); - - const memory::dim heads = param_.heads; - const memory::dim sequences = inputs[0].shape()[1]; - const memory::dim qkv_seq_len = inputs[0].shape()[0]; - const memory::dim output_lin_dim = inputs[0].shape()[2]; - const memory::dim embed_dim = output_lin_dim / 3; - const memory::dim head_dim = embed_dim / heads; - const memory::dim attn_batches = heads * sequences; - const memory::dim lead_dim = attn_batches * 3 * head_dim; // sequences * output_lin_dim - const memory::dim batch_stride = 3 * head_dim; // output_lin_dim / heads; - const float scale = 1.0 / sqrt(static_cast(head_dim)); + if (param_.quantized) { + 
min_data_ = inputs[1].data().dptr()[0]; + max_data_ = inputs[2].data().dptr()[0]; + } - if (!initialized_) { const auto engine = CpuEngine::Get()->get_engine(); - cached_min_data_ = min_data; - cached_max_data_ = max_data; - memory::dims query_dims = {attn_batches, qkv_seq_len, head_dim}; memory::dims key_dims = {attn_batches, head_dim, qkv_seq_len}; memory::dims out_dims = {attn_batches, qkv_seq_len, qkv_seq_len}; @@ -219,45 +224,46 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, auto key_md = memory::desc(key_dims, qkv_dtype, key_strides); memory::desc out_md; - - float tmp_scale = 1.0f; + + float oscale = 1.0f; if (param_.quantized) { - data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), cached_min_data_, cached_max_data_); + data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_); if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { - cached_min_output_ = param_.min_calib_range.value(); - cached_max_output_ = param_.max_calib_range.value(); - - tmp_scale = - GetQuantizeScale(out_tensor.dtype(), cached_min_output_, cached_max_output_) / + min_output_ = param_.min_calib_range.value(); + max_output_ = param_.max_calib_range.value(); + oscale = + GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) / (data_scale_ * data_scale_); out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abc); } else if (param_.enable_float_output) { - tmp_scale = 1.0f / (data_scale_ * data_scale_); + oscale = 1.0f / (data_scale_ * data_scale_); out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc); } else { mshadow::Stream *s = ctx.get_stream(); mxnet_op::Kernel::Launch( - s, 1, &cached_min_output_, &cached_max_output_, &min_data, &max_data, &min_data, + s, 1, &min_output_, &max_output_, &min_data, &max_data, &min_data, &max_data); out_md = dnnl::memory::desc(out_dims, memory::data_type::s32, memory::format_tag::abc); } } else { out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc); } - tmp_scale /= sqrt(static_cast(head_dim)); + oscale /= sqrt(static_cast(head_dim)); // combine quantized scale and sqrt(head_dim) dnnl::primitive_attr attr; - attr.set_output_scales(0, {tmp_scale}); + attr.set_output_scales(0, {oscale}); auto matmul_d = matmul::desc(query_md, key_md, out_md); auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine); fwd_ = std::make_shared(matmul_pd); MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { - cached_query_mem_ = std::make_shared(query_md, engine, inputs[0].data().dptr()); - cached_key_mem_ = std::make_shared(key_md, engine, inputs[0].data().dptr() + (head_dim)); + DType* query_mem_ptr = inputs[0].data().dptr(); + DType* key_mem_ptr = query_mem_ptr + head_dim; + cached_query_mem_ = std::make_shared(query_md, engine, query_mem_ptr); + cached_key_mem_ = std::make_shared(key_md, engine, key_mem_ptr); }); MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { cached_out_mem_ = std::make_shared(out_md, engine, outputs[0].data().dptr()); @@ -266,32 +272,38 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, args_[DNNL_ARG_SRC] = *cached_query_mem_; args_[DNNL_ARG_WEIGHTS] = *cached_key_mem_; args_[DNNL_ARG_DST] = *cached_out_mem_; - initialized_ = true; - } else { +} + + +void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + + const size_t head_dim = inputs[0].shape()[2] / 3 / param_.heads; - MSHADOW_TYPE_SWITCH(qkv_tensor.dtype(), DType, { - void* 
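// Worth noting in the hunk below: replacing the void* pointers with DType*
// changes the arithmetic: `query_mem_ptr + head_dim` now advances head_dim
// elements rather than head_dim bytes (void* arithmetic is a GNU extension
// that counts bytes), and elements is what the interleaved [q|k|v] layout
// requires once dtypes wider than int8 are in play.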
query_mem_ptr = reinterpret_cast(inputs[0].data().dptr()); - void* key_mem_ptr = query_mem_ptr + head_dim; + MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { + DType* query_mem_ptr = inputs[0].data().dptr(); + DType* key_mem_ptr = query_mem_ptr + head_dim; cached_query_mem_->set_data_handle(query_mem_ptr); cached_key_mem_->set_data_handle(key_mem_ptr); }); - MSHADOW_TYPE_SWITCH(out_tensor.dtype(), DType, { + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { cached_out_mem_->set_data_handle(reinterpret_cast(outputs[0].data().dptr())); }); - } - MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); - MKLDNNStream::Get()->Submit(); + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); + MKLDNNStream::Get()->Submit(); - if (param_.quantized && !param_.enable_float_output) { - float* min_output = outputs[1].data().dptr(); - float* max_output = outputs[2].data().dptr(); + if (param_.quantized && !param_.enable_float_output) { + float* output_min = outputs[1].data().dptr(); + float* output_max = outputs[2].data().dptr(); - *min_output = cached_min_output_; - *max_output = cached_max_output_; - } + *output_min = min_output_; + *output_max = max_output_; + } } diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 65b73e438ea6..b96e5dea625f 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -45,6 +45,10 @@ 'fc': { OP_NAME: 'sg_mkldnn_fully_connected', QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected' + }, + 'selfatt_qk': { + OP_NAME: 'sg_mkldnn_selfatt_qk', + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_qk' } } @@ -59,7 +63,9 @@ def check_qsym_calibrated(qsym, out_type, name='conv'): if k.find('_quantize') != -1: assert v['out_type'] == out_type if k.find(quantized_op_name) != -1: - if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v: + if ('enable_float_output' in v + and quantized_op_name.startswith(("quantized_sg_mkldnn_fully_connected", + "quantized_sg_mkldnn_selfatt_qk"))): continue assert 'min_calib_range' in v assert 'max_calib_range' in v @@ -677,6 +683,13 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'): return sym, attr +def single_selfatt_qk(data_shape, nheads=16): + attr = {'selfatt_qk': {}} + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(queries_keys_values=data, + heads=nheads) + return qk, attr + # fc + relu can't be fusion case # eg.1 # fc -----------> relu @@ -865,6 +878,17 @@ def test_fc_eltwise(): else: check_fusion(syms, dshape, attrs, check_quantization=False) +@with_seed() +def test_selfatt_qk(): + batchsizes = [1, 8, 16] + seq_lengths = [180, 255, 384]#,64,128,256,512,768,1024] + num_hidden = [1024, 2048, 3072] + num_heads = [8, 16] + for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads): + dshape = (seqlen, bs, nhidden) + syms, attrs = single_selfatt_qk(dshape, nheads) + check_fusion(syms, dshape, attrs, out_types=['int8', 'auto'], check_quantization=True) + @with_seed() def test_neg_fc_relu(): for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): From a0764eb0f93b10383d07c08f240fc50566d65197 Mon Sep 17 00:00:00 2001 From: "B. 
Gawrych" Date: Tue, 30 Mar 2021 10:20:08 +0200 Subject: [PATCH 07/13] basic valatt --- .../subgraph/mkldnn/mkldnn_transformer.cc | 444 ++++-------------- ...kldnn_transformer_post_quantize_property.h | 9 +- .../mkldnn/mkldnn_transformer_property.h | 8 +- 3 files changed, 112 insertions(+), 349 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index 2d7c63b57bb3..e6fb2713ed9f 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -35,12 +35,17 @@ namespace op { DMLC_REGISTER_PARAMETER(MKLDNNInterleavedMatMulParam); -static bool SgMKLDNNSelfAttQKShape(const NodeAttrs& attrs, +template +static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs, mxnet::ShapeVector* in_shapes, mxnet::ShapeVector* out_shapes) { const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { - mxnet::ShapeVector base_in_shapes = {in_shapes->at(0)}; + mxnet::ShapeVector base_in_shapes; + for(int i=0; i < base_num_inputs; i++) { + base_in_shapes.emplace_back(in_shapes->at(i)); + } + mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)}; bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); @@ -360,7 +365,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk) } return output_names; }) -.set_attr("FInferShape", SgMKLDNNSelfAttQKShape) +.set_attr("FInferShape", SgMKLDNNSelfAttShape<1>) .set_attr("FInferType", SgMKLDNNSelfAttQKInferType) .set_attr("FInferStorageType", SgMKLDNNSelfAttQKStorageType) .set_attr("FCreateOpState", CreateSgMKLDNNSelfAttQKState) @@ -377,53 +382,25 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk) /********************************************************************************************************/ -static bool MKLDNNInterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs, - mxnet::ShapeVector* in_shape, - mxnet::ShapeVector* out_shape) { - const auto& param = nnvm::get(attrs.parsed); - if (param.quantized) { - auto qkv_shape = in_shape->at(0); - - out_shape->resize(3); - SHAPE_ASSIGN_CHECK(*out_shape, 0, - mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3})); - if (!param.enable_float_output) { - SHAPE_ASSIGN_CHECK(*out_shape, 1, mxnet::TShape({1})); // min output - SHAPE_ASSIGN_CHECK(*out_shape, 2, mxnet::TShape({1})); // max output - } - - return true; - } else { - CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently have, " - << in_shape->size() << " inputs"; - auto qkv_shape = in_shape->at(0); - auto att_shape = in_shape->at(1); - CHECK_EQ(qkv_shape.ndim(), 3U) - << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, " - << "currently is: " << qkv_shape.ndim() << "D"; - CHECK_EQ(att_shape.ndim(), 3U) - << "Input attention should be 3D in batch-seq_length-seq_length, " - << "currently is: " << att_shape.ndim() << "D"; - CHECK_EQ(qkv_shape[0], att_shape[1]) - << "queries_keys_values.shape[0] and attention.shape[1] should be the same, " - << "currently are " << qkv_shape[0] << " and " << att_shape[1]; - CHECK_EQ(qkv_shape[0], att_shape[2]) - << "queries_keys_values.shape[0] and attention.shape[2] should be the same, " - << "currently are " << qkv_shape[0] << " and " << att_shape[2]; - CHECK_EQ(qkv_shape[2] % 3, 0) - << "queries_keys_values.shape[2] should be a multiple of 3, " - << "currently is " << qkv_shape[2]; - SHAPE_ASSIGN_CHECK(*out_shape, 0, - mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3})); - return true; - } -} - static bool 
MKLDNNInterleavedMatMulSelfAttValAttInferType(const nnvm::NodeAttrs &attrs, std::vector *in_types, std::vector *out_types) { const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { + // CHECK(in_types->at(0) == mshadow::kInt8 || + // in_types->at(0) == mshadow::kUint8) + // << "QuantizedInterleavedMatMulSelfAttValAtt only supports int8/uint8 input, while " + // << in_types->at(0) << " is given."; // attention weights + // // TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input + // CHECK(in_types->at(1) == mshadow::kInt8 || + // in_types->at(1) == mshadow::kUint8) + // << "QuantizedInterleavedMatMulSelfAttValAtt only supports int8/uint8 input, while " + // << in_types->at(1) << " is given."; // attention weights + + // //min qkv, max qkv, min att, max att + // for (size_t i = 2; i < in_types->size(); ++i) { + // TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + // } TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8); // att input @@ -457,22 +434,25 @@ static bool MKLDNNInterleavedMatMulSelfAttValAttStorageType(const nnvm::NodeAttr std::vector *out_attrs) { auto const ¶m = nnvm::get(attrs.parsed); if (param.quantized) { - type_assign(&in_attrs->at(0), mxnet::kDefaultStorage); - type_assign(&in_attrs->at(1), mxnet::kDefaultStorage); - type_assign(&in_attrs->at(2), mxnet::kDefaultStorage); - type_assign(&in_attrs->at(3), mxnet::kDefaultStorage); - type_assign(&in_attrs->at(4), mxnet::kDefaultStorage); - type_assign(&in_attrs->at(5), mxnet::kDefaultStorage); - - type_assign(&out_attrs->at(0), mxnet::kDefaultStorage); + std::vector base_in_attrs = {in_attrs->at(0), in_attrs->at(1)}; + std::vector base_out_attrs = {out_attrs->at(0)}; + + bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, + &base_in_attrs, &base_out_attrs); + + for (size_t i = 0; i < in_attrs->size(); ++i) { + if (i < base_in_attrs.size()) + in_attrs->at(i) = base_in_attrs[i]; + else + type_assign(&in_attrs->at(i), mxnet::kDefaultStorage); + } + + out_attrs->at(0) = base_out_attrs[0]; if (!param.enable_float_output) { type_assign(&out_attrs->at(1), mxnet::kDefaultStorage); type_assign(&out_attrs->at(2), mxnet::kDefaultStorage); } - std::vector base_in_attrs{in_attrs->at(0), in_attrs->at(1)}; - std::vector base_out_attrs{out_attrs->at(0)}; - return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, - &base_in_attrs, &base_out_attrs);; + return ret; } else { return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); @@ -482,7 +462,7 @@ static bool MKLDNNInterleavedMatMulSelfAttValAttStorageType(const nnvm::NodeAttr nnvm::ObjectPtr SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp(const NodeAttrs& attrs) { nnvm::ObjectPtr node = nnvm::Node::Create(); auto const ¶m = nnvm::get(attrs.parsed); - node->attrs.op = Op::Get("_sg_mkldnn_interleaved_matmul_selfatt_valatt"); + node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt"); node->attrs.name = "quantized_" + attrs.name; node->attrs.dict = attrs.dict; node->attrs.dict["heads"] = std::to_string(param.heads); @@ -590,13 +570,23 @@ void MKLDNNInterleavedMatMulSelfAttValAttOp::Forward( auto src2_md = param_.quantized ? 
dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides) : dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides); + // const auto src1_dtype = get_mkldnn_type(inputs[0].dtype()); + // const auto src2_dtype = get_mkldnn_type(inputs[1].dtype()); + + // auto src1_md = dnnl::memory::desc(src1_dims, src1_dtype, src1_strides); + // auto src2_md = dnnl::memory::desc(src2_dims, src2_dtype, src2_strides); + dnnl::memory::desc dst_md; - + float tmp_scale = 1.0f; if (param_.quantized) { + + // qkv_scale_ = GetQuantizeScale(inputs[0].dtype(), min_qkv, max_qkv); + // att_scale_ = GetQuantizeScale(inputs[1].dtype(), min_att, max_att); qkv_scale_ = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); att_scale_ = GetQuantizeScale(mshadow::kUint8, min_att, max_att); + if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { cached_min_output_ = param_.min_calib_range.value(); @@ -662,290 +652,62 @@ void MKLDNNInterleavedMatMulSelfAttValAttOp::Forward( } } -// void MKLDNNInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs, -// const OpContext &ctx, -// const std::vector &inputs, -// const std::vector &req, -// const std::vector &outputs) { -// const auto& param = nnvm::get(attrs.parsed); - -// if (param.quantized) { -// if (param.enable_float_output) { -// const dnnl::memory::dim HEADS = param.heads; -// const dnnl::memory::dim BS = inputs[0].shape_[1]; -// const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; -// const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - -// dnnl::engine engine(dnnl::engine::kind::cpu, 0); -// dnnl::stream engine_stream(engine); - -// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; -// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; -// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; - -// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; -// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - -// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, dnnl::memory::format_tag::abc); // CHECK IF IT IS U8 FOR SURE -// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); -// auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::bac); - -// float min_qkv = inputs[2].dptr()[0]; -// float max_qkv = inputs[3].dptr()[0]; -// float min_att = inputs[4].dptr()[0]; -// float max_att = inputs[5].dptr()[0]; - -// float qkv_scale = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); -// float att_scale = GetQuantizeScale(mshadow::kUint8, min_att, max_att); - -// const float scale = 1.0f / (qkv_scale * att_scale); - -// dnnl::primitive_attr attr; -// attr.set_output_scales(0, {scale}); - -// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); -// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - -// auto matmul_prim = dnnl::matmul(matmul_pd); - -// mshadow::Stream* s = ctx.get_stream(); -// int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; -// uint8_t* attention_maps = inputs[1].FlatTo2D(s).dptr_; -// float* output = outputs[0].FlatTo2D(s).dptr_; - -// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); -// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); -// auto dst_mem = dnnl::memory(dst_md, engine, output); - -// std::unordered_map matmul_args; -// matmul_args.insert({DNNL_ARG_SRC, src1_mem}); -// matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); 
-// matmul_args.insert({DNNL_ARG_DST, dst_mem}); - -// matmul_prim.execute(engine_stream, matmul_args); -// engine_stream.wait(); -// } else if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { -// const dnnl::memory::dim HEADS = param.heads; -// const dnnl::memory::dim BS = inputs[0].shape_[1]; -// const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; -// const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - -// dnnl::engine engine(dnnl::engine::kind::cpu, 0); -// dnnl::stream engine_stream(engine); - -// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; -// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; -// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; - -// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; -// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; -// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, dnnl::memory::format_tag::abc); // CHECK IF IT IS U8 FOR SURE -// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); -// auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8, dnnl::memory::format_tag::bac); - -// float min_qkv = inputs[2].dptr()[0]; -// float max_qkv = inputs[3].dptr()[0]; -// float min_att = inputs[4].dptr()[0]; -// float max_att = inputs[5].dptr()[0]; - -// float qkv_scale = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); -// float att_scale = GetQuantizeScale(mshadow::kUint8, min_att, max_att); - -// const float scale = GetQuantizeScale(mshadow::kInt8, param.min_calib_range.value(), param.max_calib_range.value()) -// / (qkv_scale * att_scale); - -// dnnl::primitive_attr attr; -// attr.set_output_scales(0, {scale}); - -// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); -// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - -// auto matmul_prim = dnnl::matmul(matmul_pd); - -// mshadow::Stream* s = ctx.get_stream(); -// int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; -// uint8_t* attention_maps = inputs[1].FlatTo2D(s).dptr_; -// int8_t* output = outputs[0].FlatTo2D(s).dptr_; - -// float* min_output = outputs[1].dptr(); -// float* max_output = outputs[2].dptr(); -// min_output[0] = param.min_calib_range.value(); -// max_output[0] = param.max_calib_range.value(); - -// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); -// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); -// auto dst_mem = dnnl::memory(dst_md, engine, output); - -// std::unordered_map matmul_args; -// matmul_args.insert({DNNL_ARG_SRC, src1_mem}); -// matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); -// matmul_args.insert({DNNL_ARG_DST, dst_mem}); - -// matmul_prim.execute(engine_stream, matmul_args); -// engine_stream.wait(); -// } else { -// const dnnl::memory::dim HEADS = param.heads; -// const dnnl::memory::dim BS = inputs[0].shape_[1]; -// const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; -// const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - -// dnnl::engine engine(dnnl::engine::kind::cpu, 0); -// dnnl::stream engine_stream(engine); - -// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; -// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; -// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; +NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt) +.describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE) +.set_num_inputs([](const NodeAttrs& attrs) { + 
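+  // In quantized mode the op takes queries_keys_values and attention plus the
+  // four min/max scalars (min_qkv, max_qkv, min_attention, max_attention).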
auto const& param = nnvm::get(attrs.parsed); + if (param.quantized) { + return 6; + } else { + return 2; + } +}) +.set_num_outputs([](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + if (param.quantized && !param.enable_float_output) { + return 3; + } else { + return 1; + } +}) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + auto const& param = nnvm::get(attrs.parsed); + std::vector input_names {"queries_keys_values", "attention"}; + if (param.quantized) { + input_names.emplace_back("min_qkv"); + input_names.emplace_back("max_qkv"); -// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; -// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - -// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, dnnl::memory::format_tag::abc); // CHECK IF IT IS U8 FOR SURE -// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides); -// auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32, dnnl::memory::format_tag::bac); - -// float min_qkv = inputs[2].dptr()[0]; -// float max_qkv = inputs[3].dptr()[0]; -// float min_att = inputs[4].dptr()[0]; -// float max_att = inputs[5].dptr()[0]; - -// // float qkv_scale = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); -// // float att_scale = GetQuantizeScale(mshadow::kUint8, min_att, max_att); - -// // const float scale = 1.0f / (qkv_scale * att_scale); - -// // dnnl::primitive_attr attr; -// // attr.set_output_scales(0, {scale}); - -// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); -// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, engine); - -// auto matmul_prim = dnnl::matmul(matmul_pd); - -// mshadow::Stream* s = ctx.get_stream(); -// int8_t* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; -// uint8_t* attention_maps = inputs[1].FlatTo2D(s).dptr_; -// int32_t* output = outputs[0].FlatTo2D(s).dptr_; - -// float* min_output = outputs[1].dptr(); -// float* max_output = outputs[2].dptr(); - -// mxnet_op::Kernel::Launch( -// s, 1, min_output, max_output, &min_qkv, &max_qkv, &min_att, -// &max_att); - -// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); -// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); -// auto dst_mem = dnnl::memory(dst_md, engine, output); - -// std::unordered_map matmul_args; -// matmul_args.insert({DNNL_ARG_SRC, src1_mem}); -// matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); -// matmul_args.insert({DNNL_ARG_DST, dst_mem}); - -// matmul_prim.execute(engine_stream, matmul_args); -// engine_stream.wait(); -// } -// } else { -// const dnnl::memory::dim HEADS = param.heads; -// const dnnl::memory::dim BS = inputs[0].shape_[1]; -// const dnnl::memory::dim SEQ_LEN = inputs[0].shape_[0]; -// const dnnl::memory::dim EMBED = inputs[0].shape_[2]; - -// dnnl::engine engine(dnnl::engine::kind::cpu, 0); -// dnnl::stream engine_stream(engine); - -// dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; -// dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; -// dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; - -// // dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; -// dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; - -// auto src1_md = dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::abc); -// auto src2_md = dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides); -// 
auto dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::bac); - -// auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); -// auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, engine); - -// auto matmul_prim = dnnl::matmul(matmul_pd); - -// mshadow::Stream* s = ctx.get_stream(); -// float* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; -// float* attention_maps = inputs[1].FlatTo2D(s).dptr_; -// float* output = outputs[0].FlatTo2D(s).dptr_; - -// auto src1_mem = dnnl::memory(src1_md, engine, attention_maps); -// auto src2_mem = dnnl::memory(src2_md, engine, queries_keys_values+2*(EMBED/HEADS/3)); -// auto dst_mem = dnnl::memory(dst_md, engine, output); - -// std::unordered_map matmul_args; -// matmul_args.insert({DNNL_ARG_SRC, src1_mem}); -// matmul_args.insert({DNNL_ARG_WEIGHTS, src2_mem}); -// matmul_args.insert({DNNL_ARG_DST, dst_mem}); - -// matmul_prim.execute(engine_stream, matmul_args); -// engine_stream.wait(); -// } -// } - -// NNVM_REGISTER_OP(_sg_mkldnn_interleaved_matmul_selfatt_valatt) -// .describe(R"code(_sg_mkldnn_interleaved_matmul_selfatt_valatt)code" ADD_FILELINE) -// .set_num_inputs([](const NodeAttrs& attrs) { -// auto const& param = nnvm::get(attrs.parsed); -// if (param.quantized) { -// return 6; -// } else { -// return 2; -// } -// }) -// .set_num_outputs([](const NodeAttrs& attrs) { -// auto const& param = nnvm::get(attrs.parsed); -// if (param.quantized && !param.enable_float_output) { -// return 3; -// } else { -// return 1; -// } -// }) -// .set_attr_parser(ParamParser) -// .set_attr("FListInputNames", [](const NodeAttrs& attrs) { -// auto const& param = nnvm::get(attrs.parsed); -// std::vector input_names {"queries_keys_values", "attention"}; -// if (param.quantized) { -// input_names.emplace_back("min_qkv"); -// input_names.emplace_back("max_qkv"); - -// input_names.emplace_back("min_attention"); -// input_names.emplace_back("max_attention"); -// } -// return input_names; -// }) -// .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { -// auto const& param = nnvm::get(attrs.parsed); -// std::vector output_names {"output"}; -// if (param.quantized && !param.enable_float_output) { -// output_names.emplace_back("min_output"); -// output_names.emplace_back("max_output"); -// } -// return output_names; -// }) -// .set_attr("FInferShape", MKLDNNInterleavedMatMulSelfAttValAttShape) -// .set_attr("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType) -// // .set_attr("FCompute", MKLDNNInterleavedMatMulSelfAttValAttCPU) -// .set_attr("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType) -// .set_attr("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState) -// .set_attr("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttValAttForward) -// .set_attr("TIsMKLDNN", true) -// .set_attr("FGradient", MakeZeroGradNodes) -// .set_attr("FQuantizable", [](const NodeAttrs& attrs) { -// return QuantizeType::kMust; -// }) -// .set_attr("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp) -// .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) -// .add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") -// .add_argument("attention", "NDArray-or-Symbol", "Attention maps") -// .add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__()); + input_names.emplace_back("min_attention"); + input_names.emplace_back("max_attention"); + } + return input_names; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) 
{
+  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<2>)
+.set_attr<nnvm::FInferType>("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType)
+.set_attr<FInferStorageType>("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttValAttForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+  return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__());
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
index ba37c314f5b1..28edc4608ab8 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
@@ -54,8 +54,8 @@ class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector {
   bool Select(const nnvm::Node &n) override {
     if ((!disable_all) &&
-        (n.op() == Op::Get("_sg_mkldnn_selfatt_qk"))) { // ||
-        // n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
+        (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+         n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
       status = disable_all ? kSuccess : kStart;
       matched_list.clear();
       matched_list.push_back(&n);
@@ -156,8 +156,9 @@ class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty {
     DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
       if (node->is_variable()) return;
-      if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk")) { // ||
-          // node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) {
+      if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+          node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) {
+        LOG(INFO) << " NOO COO?";
         interleaved_node = node;
       } else if (node->op() == Op::Get("_contrib_requantize")) {
         requantize_node = node;
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index be188c0ba987..aa4a9e8a679f 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -51,8 +51,8 @@ class SgMKLDNNTransformerSelector : public SubgraphSelector {
   explicit SgMKLDNNTransformerSelector() {}
   bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
-    if (n.op() == Op::Get(SELFATT_QK)) { // ||
-        // n.op() == Op::Get(SELFATT_VALATT)) { // Enable when refactored
+    if (n.op() == Op::Get(SELFATT_QK) ||
+        n.op() == Op::Get(SELFATT_VALATT)) {
      return true;
    }
    return false;
@@ -94,8 +94,8 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     MKLDNNInterleavedMatMulParam new_param;
     DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
       if (node->op() &&
-          (node->op()->name == SELFATT_QK)) { // ||
-          // node->op()->name == SELFATT_VALATT)) {
+          (node->op()->name == SELFATT_QK ||
+           node->op()->name == SELFATT_VALATT)) {
         op_name = node->op()->name;
         auto param = nnvm::get<MKLDNNInterleavedMatMulParam>(node->attrs.parsed);
         new_param.heads = param.heads;

From 94ad11f5f43373c1fb65b0ec529532c356ec135d Mon Sep 17 00:00:00 2001
From: "B.
Gawrych" Date: Wed, 31 Mar 2021 11:59:24 +0200 Subject: [PATCH 08/13] add valatt test --- tests/python/mkl/test_subgraph.py | 90 ++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index b96e5dea625f..cfb939f8146a 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -49,6 +49,10 @@ 'selfatt_qk': { OP_NAME: 'sg_mkldnn_selfatt_qk', QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_qk' + }, + 'selfatt_valatt': { + OP_NAME: 'sg_mkldnn_selfatt_valatt', + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_valatt' } } @@ -56,6 +60,10 @@ fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'square', 'square_root', 'abs', 'exp', 'bounded_relu'] +quant_op_fp32_output_support = ("quantized_sg_mkldnn_fully_connected", + "quantized_sg_mkldnn_selfatt_qk", + "quantized_sg_mkldnn_selfatt_valatt") + def check_qsym_calibrated(qsym, out_type, name='conv'): quantized_op_name = 'quantized_' + name assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1 @@ -64,8 +72,7 @@ def check_qsym_calibrated(qsym, out_type, name='conv'): assert v['out_type'] == out_type if k.find(quantized_op_name) != -1: if ('enable_float_output' in v - and quantized_op_name.startswith(("quantized_sg_mkldnn_fully_connected", - "quantized_sg_mkldnn_selfatt_qk"))): + and quantized_op_name.startswith(quant_op_fp32_output_support)): continue assert 'min_calib_range' in v assert 'max_calib_range' in v @@ -125,9 +132,11 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape): class CalibIter(mx.io.DataIter): def __init__(self, batch, data_shape, batch_size): super(CalibIter, self).__init__(batch_size) - self.data_shape = data_shape self.label_shape = (batch_size,) - self.provide_data = [('data', self.data_shape)] + if isinstance(data_shape, tuple): + self.provide_data = [('data', data_shape)] + else: + self.provide_data = data_shape self.provide_label = [] self.batch = batch @@ -255,7 +264,6 @@ def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quan if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1: check_quantization = False data_min = 0 - sym_sg = sym.get_backend_symbol(SG_PASS_NAME) for name, attrs in attrs_dict.items(): if name in config: @@ -881,7 +889,7 @@ def test_fc_eltwise(): @with_seed() def test_selfatt_qk(): batchsizes = [1, 8, 16] - seq_lengths = [180, 255, 384]#,64,128,256,512,768,1024] + seq_lengths = [180, 255, 384] num_hidden = [1024, 2048, 3072] num_heads = [8, 16] for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads): @@ -889,6 +897,76 @@ def test_selfatt_qk(): syms, attrs = single_selfatt_qk(dshape, nheads) check_fusion(syms, dshape, attrs, out_types=['int8', 'auto'], check_quantization=True) +@with_seed() +def test_selfatt_valatt(): + batchsizes = [1, 8] + seq_lengths = [18, 255, 384] + num_hidden = [1024, 3072] + num_heads = [1, 16] + + def get_valatt_symbol(qkv_shape, attention_shape, nheads): + qkv = mx.symbol.Variable('qkv', shape=qkv_shape, dtype='float32') + attention = mx.symbol.Variable('attention', shape=attention_shape, dtype='float32') + # CalibIter assumes that batch_size is always first dimension + # following operators changes shapes to the proper one + qkv_swap = mx.symbol.swapaxes(data=qkv, dim1=0, dim2=1) + attention_reshape = mx.symbol.reshape(data=attention, shape=(-1, 0, 0), reverse=True) + sym = 
mx.symbol.contrib.interleaved_matmul_selfatt_valatt(queries_keys_values=qkv_swap, + attention=attention_reshape, + heads=nheads) + return sym + + def check_valatt_quantize(sym, qkv_shape, att_shape): + qkv_nd = mx.nd.random.uniform(low=-1, high=1, shape=qkv_shape) + weight_nd = mx.nd.random.uniform(low=0, high=1, shape=att_shape) + arg_params = { + 'qkv': qkv_nd, + 'attention': weight_nd + } + + ex = sym.bind(mx.cpu(), arg_params, args_grad=None) + ex.forward() + ref_out = ex.outputs + + sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) + + batch = mx.io.DataBatch([qkv_nd, weight_nd], []) + calib_data = CalibIter(batch, [('qkv', qkv_shape), ('attention', att_shape)], bs) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, + arg_params=arg_params, + aux_params={}, + ctx=mx.cpu(), + excluded_sym_names=None, + excluded_op_names=None, + quantize_granularity='tensor-wise', + quantized_dtype='auto', + calib_mode='naive', + calib_data=calib_data, + data_names=('qkv', 'attention'), + label_names=None, + num_calib_examples=1, + quantize_mode='full') + qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME) + + qex = qsym.bind(mx.cpu(), arg_params, args_grad=None) + qex.forward() + quantized_out = qex.outputs + + for i in range(len(ref_out)): + min_range = mx.nd.min(ref_out[i]).asscalar() + max_range = mx.nd.max(ref_out[i]).asscalar() + atol = 0.1 * max(abs(min_range), abs(max_range)) + assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2) + + for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads): + qkv_shape = (bs, seqlen, 3*nhidden) + att_shape = (bs, nheads, seqlen, seqlen) + + sym = get_valatt_symbol(qkv_shape, att_shape, nheads) + check_fusion(sym, None, {'selfatt_valatt': {}}, check_quantization=False) + check_valatt_quantize(sym, qkv_shape, att_shape) + + @with_seed() def test_neg_fc_relu(): for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]): From 145461697301c338e88eca1bd2b8b53dc8ff5e23 Mon Sep 17 00:00:00 2001 From: "B. 
Gawrych" Date: Wed, 31 Mar 2021 11:59:39 +0200 Subject: [PATCH 09/13] refactor valatt --- .../subgraph/mkldnn/mkldnn_transformer.cc | 368 ++++++++---------- ...kldnn_transformer_post_quantize_property.h | 1 - 2 files changed, 163 insertions(+), 206 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index e6fb2713ed9f..9917e015b0fc 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -37,16 +37,16 @@ DMLC_REGISTER_PARAMETER(MKLDNNInterleavedMatMulParam); template static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs, - mxnet::ShapeVector* in_shapes, - mxnet::ShapeVector* out_shapes) { + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { mxnet::ShapeVector base_in_shapes; + mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)}; + for(int i=0; i < base_num_inputs; i++) { base_in_shapes.emplace_back(in_shapes->at(i)); } - - mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)}; bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes); for (size_t i = 0; i < in_shapes->size(); ++i) { @@ -97,15 +97,20 @@ static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs, } } -static bool SgMKLDNNSelfAttQKStorageType(const nnvm::NodeAttrs &attrs, - const int dev_mask, - DispatchMode *dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { +template +static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs, + const int dev_mask, + DispatchMode *dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { auto const ¶m = nnvm::get(attrs.parsed); if (param.quantized) { - std::vector base_in_attrs{in_attrs->at(0)}; + std::vector base_in_attrs; std::vector base_out_attrs{out_attrs->at(0)}; + + for(int i=0; i < base_num_inputs; i++) { + base_in_attrs.emplace_back(in_attrs->at(i)); + } bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, &base_in_attrs, &base_out_attrs); @@ -296,7 +301,7 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, }); MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { - cached_out_mem_->set_data_handle(reinterpret_cast(outputs[0].data().dptr())); + cached_out_mem_->set_data_handle(outputs[0].data().dptr()); }); MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); @@ -309,7 +314,6 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx, *output_min = min_output_; *output_max = max_output_; } - } nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) { @@ -367,7 +371,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk) }) .set_attr("FInferShape", SgMKLDNNSelfAttShape<1>) .set_attr("FInferType", SgMKLDNNSelfAttQKInferType) -.set_attr("FInferStorageType", SgMKLDNNSelfAttQKStorageType) +.set_attr("FInferStorageType", SgMKLDNNSelfAttStorageType<1>) .set_attr("FCreateOpState", CreateSgMKLDNNSelfAttQKState) .set_attr("FStatefulComputeEx", SgMKLDNNSelfAttQKForward) .set_attr("TIsMKLDNN", true) @@ -380,35 +384,20 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk) .add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values") .add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__()); -/********************************************************************************************************/ +/**********************************_sg_mkldnn_selfatt_valatt**********************************/ -static bool 
MKLDNNInterleavedMatMulSelfAttValAttInferType(const nnvm::NodeAttrs &attrs, - std::vector *in_types, - std::vector *out_types) { +static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs, + std::vector *in_types, + std::vector *out_types) { const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { - // CHECK(in_types->at(0) == mshadow::kInt8 || - // in_types->at(0) == mshadow::kUint8) - // << "QuantizedInterleavedMatMulSelfAttValAtt only supports int8/uint8 input, while " - // << in_types->at(0) << " is given."; // attention weights - // // TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input - // CHECK(in_types->at(1) == mshadow::kInt8 || - // in_types->at(1) == mshadow::kUint8) - // << "QuantizedInterleavedMatMulSelfAttValAtt only supports int8/uint8 input, while " - // << in_types->at(1) << " is given."; // attention weights - - // //min qkv, max qkv, min att, max att - // for (size_t i = 2; i < in_types->size(); ++i) { - // TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); - // } - TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input - TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8); // att input - - TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32); // min qkv - TYPE_ASSIGN_CHECK(*in_types, 3, mshadow::kFloat32); // max qkv - - TYPE_ASSIGN_CHECK(*in_types, 4, mshadow::kFloat32); // min att - TYPE_ASSIGN_CHECK(*in_types, 5, mshadow::kFloat32); // max att + TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8); // qkv input + TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8); // att input + + //min qkv, max qkv, min att, max att + for (size_t i = 2; i < in_types->size(); ++i) { + TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32); + } if (param.enable_float_output) { TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32); // output @@ -427,39 +416,7 @@ static bool MKLDNNInterleavedMatMulSelfAttValAttInferType(const nnvm::NodeAttrs } } -static bool MKLDNNInterleavedMatMulSelfAttValAttStorageType(const nnvm::NodeAttrs &attrs, - const int dev_mask, - DispatchMode *dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - auto const ¶m = nnvm::get(attrs.parsed); - if (param.quantized) { - std::vector base_in_attrs = {in_attrs->at(0), in_attrs->at(1)}; - std::vector base_out_attrs = {out_attrs->at(0)}; - - bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, - &base_in_attrs, &base_out_attrs); - - for (size_t i = 0; i < in_attrs->size(); ++i) { - if (i < base_in_attrs.size()) - in_attrs->at(i) = base_in_attrs[i]; - else - type_assign(&in_attrs->at(i), mxnet::kDefaultStorage); - } - - out_attrs->at(0) = base_out_attrs[0]; - if (!param.enable_float_output) { - type_assign(&out_attrs->at(1), mxnet::kDefaultStorage); - type_assign(&out_attrs->at(2), mxnet::kDefaultStorage); - } - return ret; - } else { - return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, - in_attrs, out_attrs); - } -} - -nnvm::ObjectPtr SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp(const NodeAttrs& attrs) { +nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) { nnvm::ObjectPtr node = nnvm::Node::Create(); auto const ¶m = nnvm::get(attrs.parsed); node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt"); @@ -475,9 +432,9 @@ nnvm::ObjectPtr SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp(const NodeAttr return node; } -class MKLDNNInterleavedMatMulSelfAttValAttOp { +class MKLDNNSelfAttValAttOp { public: - explicit MKLDNNInterleavedMatMulSelfAttValAttOp(const nnvm::NodeAttrs &attrs) : + explicit MKLDNNSelfAttValAttOp(const 
nnvm::NodeAttrs &attrs) : param_(nnvm::get(attrs.parsed)) {} void Forward(const OpContext &ctx, @@ -493,166 +450,167 @@ class MKLDNNInterleavedMatMulSelfAttValAttOp { "inference computation."; } + void Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + + bool IsInitialized() { + return initialized_; + } + private: bool initialized_{false}; MKLDNNInterleavedMatMulParam param_; mkldnn_args_map_t args_; std::shared_ptr fwd_; - std::shared_ptr cached_data1_mem_; - std::shared_ptr cached_data2_mem_; + std::shared_ptr cached_att_mem_; + std::shared_ptr cached_qkv_mem_; std::shared_ptr cached_out_mem_; - float cached_min_qkv_; - float cached_max_qkv_; - float cached_min_att_; - float cached_max_att_; - float cached_min_output_; - float cached_max_output_; + float min_qkv_; + float max_qkv_; + float min_att_; + float max_att_; + float min_output_; + float max_output_; float qkv_scale_{0.0f}; float att_scale_{0.0f}; }; -static OpStatePtr CreateMKLDNNInterleavedMatMulSelfAttValAttState(const nnvm::NodeAttrs &attrs, - Context ctx, - const mxnet::ShapeVector &in_shapes, - const std::vector &in_types) { - return OpStatePtr::Create(attrs); +static OpStatePtr CreateMKLDNNSelfAttValAttState(const nnvm::NodeAttrs &attrs, + Context ctx, + const mxnet::ShapeVector &in_shapes, + const std::vector &in_types) { + return OpStatePtr::Create(attrs); } -static void MKLDNNInterleavedMatMulSelfAttValAttForward(const OpStatePtr &state_pointer, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - MKLDNNInterleavedMatMulSelfAttValAttOp &op = state_pointer.get_state(); +static void MKLDNNSelfAttValAttForward(const OpStatePtr &state_pointer, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + MKLDNNSelfAttValAttOp &op = state_pointer.get_state(); + if(!op.IsInitialized()) { + op.Initialize(ctx, inputs, req, outputs); + } op.Forward(ctx, inputs, req, outputs); } -void MKLDNNInterleavedMatMulSelfAttValAttOp::Forward( - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - - float min_qkv = 0.0f; - float max_qkv = 0.0f; - float min_att = 0.0f; - float max_att = 0.0f; - - if (param_.quantized) { - min_qkv = inputs[2].data().dptr()[0]; - max_qkv = inputs[3].data().dptr()[0]; - min_att = inputs[4].data().dptr()[0]; - max_att = inputs[5].data().dptr()[0]; - } - - const dnnl::memory::dim HEADS = param_.heads; - const dnnl::memory::dim BS = inputs[0].shape()[1]; - const dnnl::memory::dim SEQ_LEN = inputs[0].shape()[0]; - const dnnl::memory::dim EMBED = inputs[0].shape()[2]; - - if (!initialized_) { - const auto engine = CpuEngine::Get()->get_engine(); - - cached_min_qkv_ = min_qkv; - cached_max_qkv_ = max_qkv; - cached_min_att_ = min_att; - cached_max_att_ = max_att; - - dnnl::memory::dims src1_dims = {BS*HEADS, SEQ_LEN, SEQ_LEN}; - dnnl::memory::dims src2_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; - dnnl::memory::dims dst_dims = {BS*HEADS, SEQ_LEN, EMBED/HEADS/3}; +void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { - dnnl::memory::dims src1_strides = {SEQ_LEN*SEQ_LEN, SEQ_LEN, 1}; - dnnl::memory::dims src2_strides = {3*(EMBED/HEADS/3), EMBED*BS, 1}; + const dnnl::memory::dim qkv_seq_len = inputs[0].shape()[0]; + const dnnl::memory::dim sequences = inputs[0].shape()[1]; + const 
dnnl::memory::dim output_lin_dim = inputs[0].shape()[2]; + const dnnl::memory::dim embed_dim = output_lin_dim / 3; + const dnnl::memory::dim head_dim = embed_dim / param_.heads; + const dnnl::memory::dim attn_batches = param_.heads * sequences; + const dnnl::memory::dim lead_dim = attn_batches * 3 * head_dim; + const dnnl::memory::dim batch_stride = 3 * head_dim; - auto src1_md = param_.quantized ? dnnl::memory::desc(src1_dims, dnnl::memory::data_type::u8, src1_strides) : - dnnl::memory::desc(src1_dims, dnnl::memory::data_type::f32, src1_strides); - auto src2_md = param_.quantized ? dnnl::memory::desc(src2_dims, dnnl::memory::data_type::s8, src2_strides) : - dnnl::memory::desc(src2_dims, dnnl::memory::data_type::f32, src2_strides); - // const auto src1_dtype = get_mkldnn_type(inputs[0].dtype()); - // const auto src2_dtype = get_mkldnn_type(inputs[1].dtype()); + dnnl::memory::dims att_dims = {attn_batches, qkv_seq_len, qkv_seq_len}; + dnnl::memory::dims qkv_dims = {attn_batches, qkv_seq_len, head_dim}; + dnnl::memory::dims dst_dims = {attn_batches, qkv_seq_len, head_dim}; - // auto src1_md = dnnl::memory::desc(src1_dims, src1_dtype, src1_strides); - // auto src2_md = dnnl::memory::desc(src2_dims, src2_dtype, src2_strides); - - dnnl::memory::desc dst_md; - - float tmp_scale = 1.0f; - if (param_.quantized) { + dnnl::memory::dims att_strides = {qkv_seq_len * qkv_seq_len, qkv_seq_len, 1}; + dnnl::memory::dims qkv_strides = {batch_stride, lead_dim, 1}; - // qkv_scale_ = GetQuantizeScale(inputs[0].dtype(), min_qkv, max_qkv); - // att_scale_ = GetQuantizeScale(inputs[1].dtype(), min_att, max_att); - qkv_scale_ = GetQuantizeScale(mshadow::kInt8, min_qkv, max_qkv); - att_scale_ = GetQuantizeScale(mshadow::kUint8, min_att, max_att); + auto att_dtype = inputs[1].dtype(); + auto qkv_dtype = inputs[0].dtype(); + auto out_dtype = outputs[0].dtype(); + auto att_md = dnnl::memory::desc(att_dims, get_mkldnn_type(att_dtype), att_strides); + auto qkv_md = dnnl::memory::desc(qkv_dims, get_mkldnn_type(qkv_dtype), qkv_strides); + dnnl::memory::desc dst_md; - if (param_.min_calib_range.has_value() && - param_.max_calib_range.has_value()) { - cached_min_output_ = param_.min_calib_range.value(); - cached_max_output_ = param_.max_calib_range.value(); - - tmp_scale = GetQuantizeScale(mshadow::kInt8, cached_min_output_, cached_max_output_) / (qkv_scale_ * att_scale_); - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s8, dnnl::memory::format_tag::bac); - } else if (param_.enable_float_output) { - tmp_scale = 1.0f / (qkv_scale_ * att_scale_); - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::bac); - } else { - mshadow::Stream *s = ctx.get_stream(); - mxnet_op::Kernel::Launch( - s, 1, &cached_min_output_, &cached_max_output_, &min_qkv, &max_qkv, &min_att, - &max_att); - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::s32, dnnl::memory::format_tag::bac); - } + float oscale = 1.0f; + if (param_.quantized) { + min_qkv_ = inputs[2].data().dptr()[0]; + max_qkv_ = inputs[3].data().dptr()[0]; + min_att_ = inputs[4].data().dptr()[0]; + max_att_ = inputs[5].data().dptr()[0]; + qkv_scale_ = GetQuantizeScale(qkv_dtype, min_qkv_, max_qkv_); + att_scale_ = GetQuantizeScale(att_dtype, min_att_, max_att_); + + if (param_.min_calib_range.has_value() && + param_.max_calib_range.has_value()) { + min_output_ = param_.min_calib_range.value(); + max_output_ = param_.max_calib_range.value(); + + oscale = GetQuantizeScale(out_dtype, min_output_, max_output_) / 
(qkv_scale_ * att_scale_); + } else if (param_.enable_float_output) { + oscale = 1.0f / (qkv_scale_ * att_scale_); } else { - dst_md = dnnl::memory::desc(dst_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::bac); + mshadow::Stream *s = ctx.get_stream(); + mxnet_op::Kernel::Launch( + s, 1, &min_output_, &max_output_, &min_qkv_, &max_qkv_, &min_att_, + &max_att_); } + } + dst_md = dnnl::memory::desc(dst_dims, get_mkldnn_type(out_dtype), dnnl::memory::format_tag::bac); + + const auto engine = CpuEngine::Get()->get_engine(); + dnnl::primitive_attr attr; + attr.set_output_scales(0, {oscale}); + auto matmul_d = dnnl::matmul::desc(att_md, qkv_md, dst_md); + auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); + + fwd_ = std::make_shared(matmul_pd); + + MSHADOW_TYPE_SWITCH(att_dtype, DType, { + DType* att_ptr = inputs[1].data().dptr(); + cached_att_mem_ = std::make_shared(att_md, engine, att_ptr); + }); + MSHADOW_TYPE_SWITCH(qkv_dtype, DType, { + DType* value_ptr = inputs[0].data().dptr() + 2*head_dim; + cached_qkv_mem_ = std::make_shared(qkv_md, engine, value_ptr); + }); + MSHADOW_TYPE_SWITCH(out_dtype, DType, { + DType* out_ptr = outputs[0].data().dptr(); + cached_out_mem_ = std::make_shared(dst_md, engine, out_ptr); + }); + + args_[DNNL_ARG_SRC] = *cached_att_mem_; + args_[DNNL_ARG_WEIGHTS] = *cached_qkv_mem_; + args_[DNNL_ARG_DST] = *cached_out_mem_; + initialized_ = true; +} - dnnl::primitive_attr attr; - attr.set_output_scales(0, {tmp_scale}); - auto matmul_d = dnnl::matmul::desc(src1_md, src2_md, dst_md); - auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); - - fwd_ = std::make_shared(matmul_pd); - - MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { - cached_data1_mem_ = std::make_shared(src1_md, engine, inputs[1].data().dptr()); - }); - MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { - cached_data2_mem_ = std::make_shared(src2_md, engine, inputs[0].data().dptr() + 2*(EMBED/HEADS/3)); - }); - MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { - cached_out_mem_ = std::make_shared(dst_md, engine, outputs[0].data().dptr()); - }); +void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + + const auto engine = CpuEngine::Get()->get_engine(); + const size_t head_dim = inputs[0].shape()[2] / param_.heads / 3; + MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { + DType* att_ptr = inputs[1].data().dptr(); + cached_att_mem_->set_data_handle(att_ptr); + }); + MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { + DType* value_ptr = inputs[0].data().dptr() + 2*head_dim; + cached_qkv_mem_->set_data_handle(value_ptr); + }); + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + DType* out_ptr = outputs[0].data().dptr(); + cached_out_mem_->set_data_handle(out_ptr); + }); - args_[DNNL_ARG_SRC] = *cached_data1_mem_; - args_[DNNL_ARG_WEIGHTS] = *cached_data2_mem_; - args_[DNNL_ARG_DST] = *cached_out_mem_; - initialized_ = true; - } else { - MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { - cached_data1_mem_->set_data_handle(reinterpret_cast(inputs[1].data().dptr())); - }); - MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, { - cached_data2_mem_->set_data_handle(reinterpret_cast(inputs[0].data().dptr() + 2*(EMBED/HEADS/3))); - }); - MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { - cached_out_mem_->set_data_handle(reinterpret_cast(outputs[0].data().dptr())); - }); - } MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_); MKLDNNStream::Get()->Submit(); if (param_.quantized && 
!param_.enable_float_output) { - float* min_output = outputs[1].data().dptr(); - float* max_output = outputs[2].data().dptr(); + float* output_min = outputs[1].data().dptr(); + float* output_max = outputs[2].data().dptr(); - *min_output = cached_min_output_; - *max_output = cached_max_output_; + *output_min = min_output_; + *output_max = max_output_; } } - NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt) .describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE) .set_num_inputs([](const NodeAttrs& attrs) { @@ -694,16 +652,16 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt) return output_names; }) .set_attr("FInferShape", SgMKLDNNSelfAttShape<2>) -.set_attr("FInferType", MKLDNNInterleavedMatMulSelfAttValAttInferType) -.set_attr("FInferStorageType", MKLDNNInterleavedMatMulSelfAttValAttStorageType) -.set_attr("FCreateOpState", CreateMKLDNNInterleavedMatMulSelfAttValAttState) -.set_attr("FStatefulComputeEx", MKLDNNInterleavedMatMulSelfAttValAttForward) +.set_attr("FInferType", SgMKLDNNSelfAttValAttInferType) +.set_attr("FInferStorageType", SgMKLDNNSelfAttStorageType<2>) +.set_attr("FCreateOpState", CreateMKLDNNSelfAttValAttState) +.set_attr("FStatefulComputeEx", MKLDNNSelfAttValAttForward) .set_attr("TIsMKLDNN", true) .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FQuantizable", [](const NodeAttrs& attrs) { return QuantizeType::kMust; }) -.set_attr("FQuantizedOp", SgMKLDNNInterleavedMatMulSelfAttValAttQuantizedOp) +.set_attr("FQuantizedOp", SgMKLDNNSelfAttValAttQuantizedOp) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) .add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") .add_argument("attention", "NDArray-or-Symbol", "Attention maps") diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h index 28edc4608ab8..9291e427f0ed 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h @@ -158,7 +158,6 @@ class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty { if (node->is_variable()) return; if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") || node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) { - LOG(INFO) << " NOO COO?"; interleaved_node = node; } else if (node->op() == Op::Get("_contrib_requantize")) { requantize_node = node; From d1ba73aa3ca88493a02680e0cb7299be141aec85 Mon Sep 17 00:00:00 2001 From: "B. Gawrych" Date: Wed, 7 Apr 2021 11:55:01 +0200 Subject: [PATCH 10/13] fix review --- .../subgraph/mkldnn/mkldnn_transformer-inl.h | 7 +++---- .../subgraph/mkldnn/mkldnn_transformer.cc | 17 +++++++++-------- .../mkldnn/mkldnn_transformer_property.h | 6 +++--- tests/python/mkl/test_subgraph.py | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h index 9318db9d3e50..22212812982a 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h @@ -17,10 +17,9 @@ * under the License. 
*/ -#ifndef MXNET_OPERATOR_MKLDNN_TRANSFORMER_INL_H_ -#define MXNET_OPERATOR_MKLDNN_TRANSFORMER_INL_H_ +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_ -// #include #include "../../mxnet_op.h" #include "../../mshadow_op.h" @@ -56,4 +55,4 @@ struct MKLDNNInterleavedMatMulParam : public dmlc::Parameter +template static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs, mxnet::ShapeVector* in_shapes, mxnet::ShapeVector* out_shapes) { @@ -97,7 +97,7 @@ static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs, } } -template +template static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs, const int dev_mask, DispatchMode *dispatch_mode, @@ -524,7 +524,8 @@ void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx, auto att_md = dnnl::memory::desc(att_dims, get_mkldnn_type(att_dtype), att_strides); auto qkv_md = dnnl::memory::desc(qkv_dims, get_mkldnn_type(qkv_dtype), qkv_strides); - dnnl::memory::desc dst_md; + dnnl::memory::desc out_md; + dnnl::primitive_attr attr; float oscale = 1.0f; if (param_.quantized) { @@ -541,8 +542,10 @@ void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx, max_output_ = param_.max_calib_range.value(); oscale = GetQuantizeScale(out_dtype, min_output_, max_output_) / (qkv_scale_ * att_scale_); + attr.set_output_scales(0, {oscale}); } else if (param_.enable_float_output) { oscale = 1.0f / (qkv_scale_ * att_scale_); + attr.set_output_scales(0, {oscale}); } else { mshadow::Stream *s = ctx.get_stream(); mxnet_op::Kernel::Launch( @@ -550,12 +553,10 @@ void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx, &max_att_); } } - dst_md = dnnl::memory::desc(dst_dims, get_mkldnn_type(out_dtype), dnnl::memory::format_tag::bac); + out_md = dnnl::memory::desc(dst_dims, get_mkldnn_type(out_dtype), dnnl::memory::format_tag::bac); const auto engine = CpuEngine::Get()->get_engine(); - dnnl::primitive_attr attr; - attr.set_output_scales(0, {oscale}); - auto matmul_d = dnnl::matmul::desc(att_md, qkv_md, dst_md); + auto matmul_d = dnnl::matmul::desc(att_md, qkv_md, out_md); auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine); fwd_ = std::make_shared(matmul_pd); @@ -570,7 +571,7 @@ void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx, }); MSHADOW_TYPE_SWITCH(out_dtype, DType, { DType* out_ptr = outputs[0].data().dptr(); - cached_out_mem_ = std::make_shared(dst_md, engine, out_ptr); + cached_out_mem_ = std::make_shared(out_md, engine, out_ptr); }); args_[DNNL_ARG_SRC] = *cached_att_mem_; diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h index aa4a9e8a679f..5413cc201f65 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h @@ -18,8 +18,8 @@ */ -#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_TRANSFORMER_PROPERTY_H_ -#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_TRANSFORMER_PROPERTY_H_ +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_ #if MXNET_USE_MKLDNN == 1 #include @@ -134,4 +134,4 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty { } // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 -#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_TRANSFORMER_PROPERTY_H_ +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_ diff --git a/tests/python/mkl/test_subgraph.py 
b/tests/python/mkl/test_subgraph.py index cfb939f8146a..79494a046e2b 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -888,9 +888,9 @@ def test_fc_eltwise(): @with_seed() def test_selfatt_qk(): - batchsizes = [1, 8, 16] - seq_lengths = [180, 255, 384] - num_hidden = [1024, 2048, 3072] + batchsizes = [1, 8] + seq_lengths = [180, 384] + num_hidden = [1024, 3072] num_heads = [8, 16] for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads): dshape = (seqlen, bs, nhidden) From 1468c86c83ea5dfdb374db58ab6a962c8defc37d Mon Sep 17 00:00:00 2001 From: "B. Gawrych" Date: Wed, 7 Apr 2021 12:01:31 +0200 Subject: [PATCH 11/13] Change param struct name --- .../subgraph/mkldnn/mkldnn_transformer-inl.h | 4 +- .../subgraph/mkldnn/mkldnn_transformer.cc | 46 +++++++++---------- ...kldnn_transformer_post_quantize_property.h | 2 +- .../mkldnn/mkldnn_transformer_property.h | 2 +- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h index 22212812982a..d4004351649e 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h @@ -27,13 +27,13 @@ namespace mxnet { namespace op { -struct MKLDNNInterleavedMatMulParam : public dmlc::Parameter { +struct MKLDNNSelfAttParam : public dmlc::Parameter { int heads; bool quantized; bool enable_float_output; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset - DMLC_DECLARE_PARAMETER(MKLDNNInterleavedMatMulParam) { + DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) { DMLC_DECLARE_FIELD(heads) .describe("Set number of heads"); DMLC_DECLARE_FIELD(quantized).set_default(false) diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc index 28c7409c1b08..585d1929b9ef 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc @@ -33,13 +33,13 @@ namespace mxnet { namespace op { -DMLC_REGISTER_PARAMETER(MKLDNNInterleavedMatMulParam); +DMLC_REGISTER_PARAMETER(MKLDNNSelfAttParam); template static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs, mxnet::ShapeVector* in_shapes, mxnet::ShapeVector* out_shapes) { - const auto& param = nnvm::get(attrs.parsed); + const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { mxnet::ShapeVector base_in_shapes; mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)}; @@ -71,7 +71,7 @@ static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs, static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs, std::vector *in_types, std::vector *out_types) { - const auto& param = nnvm::get(attrs.parsed); + const auto& param = nnvm::get(attrs.parsed); if (param.quantized) { CHECK(in_types->at(0) == mshadow::kInt8) << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while " @@ -103,7 +103,7 @@ static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs, DispatchMode *dispatch_mode, std::vector *in_attrs, std::vector *out_attrs) { - auto const ¶m = nnvm::get(attrs.parsed); + auto const ¶m = nnvm::get(attrs.parsed); if (param.quantized) { std::vector base_in_attrs; std::vector base_out_attrs{out_attrs->at(0)}; @@ -136,7 +136,7 @@ static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs, class 
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
index 28c7409c1b08..585d1929b9ef 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
@@ -33,13 +33,13 @@
 namespace mxnet {
 namespace op {

-DMLC_REGISTER_PARAMETER(MKLDNNInterleavedMatMulParam);
+DMLC_REGISTER_PARAMETER(MKLDNNSelfAttParam);

 template <int base_num_inputs>
 static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs,
                                  mxnet::ShapeVector* in_shapes,
                                  mxnet::ShapeVector* out_shapes) {
-  const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
     mxnet::ShapeVector base_in_shapes;
     mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)};
@@ -71,7 +71,7 @@ static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs,
 static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs,
                                        std::vector<int> *in_types,
                                        std::vector<int> *out_types) {
-  const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
     CHECK(in_types->at(0) == mshadow::kInt8)
       << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while "
@@ -103,7 +103,7 @@ static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs,
                                        DispatchMode *dispatch_mode,
                                        std::vector<int> *in_attrs,
                                        std::vector<int> *out_attrs) {
-  auto const &param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
     std::vector<int> base_in_attrs;
     std::vector<int> base_out_attrs{out_attrs->at(0)};
@@ -136,7 +136,7 @@ static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs,
 class SgMKLDNNSelfAttQKOp {
  public:
  explicit SgMKLDNNSelfAttQKOp(const nnvm::NodeAttrs &attrs) :
-    param_(nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed)) {}
+    param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}

   void Forward(const OpContext &ctx,
                const std::vector<NDArray> &inputs,
@@ -162,7 +162,7 @@ class SgMKLDNNSelfAttQKOp {

  private:
   bool initialized_{false};
-  MKLDNNInterleavedMatMulParam param_;
+  MKLDNNSelfAttParam param_;
   mkldnn_args_map_t args_;
   std::shared_ptr<dnnl::matmul> fwd_;
   std::shared_ptr<dnnl::memory> cached_query_mem_;
@@ -318,7 +318,7 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx,

 nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
   nnvm::ObjectPtr node = nnvm::Node::Create();
-  auto const &param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   node->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk");
   node->attrs.name = "quantized_" + attrs.name;
   node->attrs.dict = attrs.dict;
@@ -335,7 +335,7 @@ nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
 NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
 .describe(R"code(_sg_mkldnn_selfatt_qk)code" ADD_FILELINE)
 .set_num_inputs([](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
     return 3;
   } else {
@@ -343,16 +343,16 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
   }
 })
 .set_num_outputs([](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized && !param.enable_float_output) {
     return 3;
   } else {
     return 1;
   }
 })
-.set_attr_parser(ParamParser<MKLDNNInterleavedMatMulParam>)
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   std::vector<std::string> input_names {"queries_keys_values"};
   if (param.quantized) {
     input_names.emplace_back("min_qkv");
@@ -361,7 +361,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
   return input_names;
 })
 .set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   std::vector<std::string> output_names {"output"};
   if (param.quantized && !param.enable_float_output) {
     output_names.emplace_back("min_output");
@@ -382,14 +382,14 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
 .set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttQKQuantizedOp)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
-.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__());
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
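The 3-input/3-output arithmetic in the registration above follows MXNet's quantization convention: an int8 tensor always travels with a float min/max pair describing the real range it encodes, so the quantized QK op consumes (qkv, min_qkv, max_qkv) and, unless enable_float_output is set, produces (output, min_output, max_output). A sketch of that convention, with purely illustrative names:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Quantize floats to int8 symmetrically and return the (min, max) pair
    // that quantized operators consume as extra inputs.
    struct QuantizedTensor {
      std::vector<int8_t> data;
      float min_range;
      float max_range;
    };

    QuantizedTensor QuantizeSymmetric(const std::vector<float>& in) {
      float max_abs = 0.f;
      for (float v : in) max_abs = std::max(max_abs, std::fabs(v));
      const float scale = max_abs > 0.f ? 127.f / max_abs : 1.f;
      QuantizedTensor out;
      out.data.reserve(in.size());
      for (float v : in)
        out.data.push_back(static_cast<int8_t>(std::lround(v * scale)));
      out.min_range = -max_abs;  // symmetric range around zero
      out.max_range = max_abs;
      return out;
    }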
 /**********************************_sg_mkldnn_selfatt_valatt**********************************/

 static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs,
                                            std::vector<int> *in_types,
                                            std::vector<int> *out_types) {
-  const auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
     TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8);   // qkv input
     TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8);  // att input
@@ -418,7 +418,7 @@ static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs,

 nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) {
   nnvm::ObjectPtr node = nnvm::Node::Create();
-  auto const &param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt");
   node->attrs.name = "quantized_" + attrs.name;
   node->attrs.dict = attrs.dict;
@@ -435,7 +435,7 @@ nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) {
 class MKLDNNSelfAttValAttOp {
  public:
  explicit MKLDNNSelfAttValAttOp(const nnvm::NodeAttrs &attrs) :
-    param_(nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed)) {}
+    param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}

   void Forward(const OpContext &ctx,
                const std::vector<NDArray> &inputs,
@@ -461,7 +461,7 @@ class MKLDNNSelfAttValAttOp {

  private:
   bool initialized_{false};
-  MKLDNNInterleavedMatMulParam param_;
+  MKLDNNSelfAttParam param_;
   mkldnn_args_map_t args_;
   std::shared_ptr<dnnl::matmul> fwd_;
   std::shared_ptr<dnnl::memory> cached_att_mem_;
@@ -615,7 +615,7 @@ void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx,
 NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
 .describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE)
 .set_num_inputs([](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
     return 6;
   } else {
@@ -623,16 +623,16 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
   }
 })
 .set_num_outputs([](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized && !param.enable_float_output) {
     return 3;
   } else {
     return 1;
   }
 })
-.set_attr_parser(ParamParser<MKLDNNInterleavedMatMulParam>)
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   std::vector<std::string> input_names {"queries_keys_values", "attention"};
   if (param.quantized) {
     input_names.emplace_back("min_qkv");
@@ -644,7 +644,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
   return input_names;
 })
 .set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
-  auto const& param = nnvm::get<MKLDNNInterleavedMatMulParam>(attrs.parsed);
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   std::vector<std::string> output_names {"output"};
   if (param.quantized && !param.enable_float_output) {
     output_names.emplace_back("min_output");
@@ -666,7 +666,7 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
 .add_argument("attention", "NDArray-or-Symbol", "Attention maps")
-.add_arguments(MKLDNNInterleavedMatMulParam::__FIELDS__());
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());

 }  // namespace op
 }  // namespace mxnet
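Every hunk in this file is the same mechanical substitution; the numeric behaviour it leaves untouched is the scale arithmetic shown earlier, where the calibrated output scale is GetQuantizeScale(out_dtype, min, max) / (qkv_scale * att_scale). A sketch of what such a scale helper computes for an int8 output, mirroring the idea in quantization_utils.h but not copied from it:

    #include <algorithm>
    #include <cmath>

    // Illustrative helper, not the MXNet function: map a calibrated float
    // range to an int8 scale (127 quantized steps per max absolute value).
    float GetQuantizeScaleSketch(float min_range, float max_range) {
      const float range = std::max(std::fabs(min_range), std::fabs(max_range));
      return range > 0.f ? 127.f / range : 1.f;
    }

    // The fused operator divides this by the product of the input scales,
    // so one oneDNN matmul emits already-requantized int8 output and no
    // separate requantize node is needed downstream.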
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
index 9291e427f0ed..c3be3a157bf2 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
@@ -175,7 +175,7 @@ class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty {

       // When only fusing quantized_interleaved_matmul and requantize, set min/max_calib_range,
       // When fusing quantized_interleaved_matmul + requantize + dequantize, set dequantize flag to true.
-      // auto& param = nnvm::get<MKLDNNInterleavedMatMulParam>(interleaved_node->attrs.parsed);
+      // auto& param = nnvm::get<MKLDNNSelfAttParam>(interleaved_node->attrs.parsed);
       if (dequantize_node != nullptr) {
         interleaved_node->attrs.dict["enable_float_output"] = "True";
       } else {
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index 5413cc201f65..86d1f58da530 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -91,7 +91,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     new_sym.outputs.emplace_back(last_node);
     std::ostringstream node_name;
     std::string op_name;
-    MKLDNNInterleavedMatMulParam new_param;
+    MKLDNNSelfAttParam new_param;
     DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
       if (node->op() &&
           (node->op()->name == SELFATT_QK ||

From ba04ec809ca30316a9508644f91ad3d13ca4fff3 Mon Sep 17 00:00:00 2001
From: "B. Gawrych"
Date: Thu, 15 Apr 2021 08:38:53 +0200
Subject: [PATCH 12/13] Fix sanity

---
 .../subgraph/mkldnn/mkldnn_transformer.cc     | 30 ++++++++-----------
 ...kldnn_transformer_post_quantize_property.h |  6 ++--
 .../mkldnn/mkldnn_transformer_property.h      |  4 +--
 3 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
index 585d1929b9ef..8757da0b3c1f 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
@@ -22,7 +22,6 @@
 #include
 #include
 #include
-#include
 #include "../common.h"
 #include "./mkldnn_transformer-inl.h"
 #include "../../contrib/transformer-inl.h"
@@ -44,7 +43,7 @@ static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs,
     mxnet::ShapeVector base_in_shapes;
     mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)};

-    for(int i=0; i < base_num_inputs; i++) {
+    for (int i = 0; i < base_num_inputs; i++) {
       base_in_shapes.emplace_back(in_shapes->at(i));
     }
     bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
@@ -77,8 +76,8 @@ static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs,
       << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while "
       << in_types->at(0) << " is given.";

-    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);  // min value
-    TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);  // max value
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);   // min value
+    TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);   // max value

     if (param.enable_float_output) {
       TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);  // output
@@ -108,7 +107,7 @@ static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs,
     std::vector<int> base_in_attrs;
     std::vector<int> base_out_attrs{out_attrs->at(0)};

-    for(int i=0; i < base_num_inputs; i++) {
+    for (int i = 0; i < base_num_inputs; i++) {
       base_in_attrs.emplace_back(in_attrs->at(i));
     }
     bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
@@ -188,7 +187,7 @@ static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer,
                                      const std::vector<OpReqType> &req,
                                      const std::vector<NDArray> &outputs) {
   SgMKLDNNSelfAttQKOp &op = state_pointer.get_state<SgMKLDNNSelfAttQKOp>();
-  if(!op.IsInitialized()) {
+  if (!op.IsInitialized()) {
     op.Initialize(ctx, inputs, req, outputs);
   }
   op.Forward(ctx, inputs, req, outputs);
@@ -229,7 +228,7 @@ void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx,
   memory::dims query_strides = {batch_stride, lead_dim, 1};
   memory::dims key_strides = {batch_stride, 1, lead_dim};
-  
+
   auto query_md = memory::desc(query_dims, qkv_dtype, query_strides);
   auto key_md = memory::desc(key_dims, qkv_dtype, key_strides);
@@ -243,7 +242,7 @@ void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx,
       param_.max_calib_range.has_value()) {
     min_output_ = param_.min_calib_range.value();
     max_output_ = param_.max_calib_range.value();
-    oscale = 
+    oscale =
         GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
             (data_scale_ * data_scale_);
     out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abc);
@@ -260,7 +259,7 @@ void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx,
   } else {
     out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
   }
-  oscale /= sqrt(static_cast<float>(head_dim));  // combine quantized scale and sqrt(head_dim)
+  oscale /= sqrt(static_cast<float>(head_dim));   // combine quantized scale and sqrt(head_dim)

   dnnl::primitive_attr attr;
   attr.set_output_scales(0, {oscale});
@@ -290,7 +289,6 @@ void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx,
                                   const std::vector<NDArray> &inputs,
                                   const std::vector<OpReqType> &req,
                                   const std::vector<NDArray> &outputs) {
-
   const size_t head_dim = inputs[0].shape()[2] / 3 / param_.heads;

   MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
@@ -391,10 +389,10 @@ static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs,
                                            std::vector<int> *out_types) {
   const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
   if (param.quantized) {
-    TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8);  // qkv input
-    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8);  // att input
+    TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8);   // qkv input
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8);  // att input

-    //min qkv, max qkv, min att, max att
+    // min qkv, max qkv, min att, max att
     for (size_t i = 2; i < in_types->size(); ++i) {
       TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
     }
@@ -490,7 +488,7 @@ static void MKLDNNSelfAttValAttForward(const OpStatePtr &state_pointer,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<NDArray> &outputs) {
   MKLDNNSelfAttValAttOp &op = state_pointer.get_state<MKLDNNSelfAttValAttOp>();
-  if(!op.IsInitialized()) {
+  if (!op.IsInitialized()) {
     op.Initialize(ctx, inputs, req, outputs);
   }
   op.Forward(ctx, inputs, req, outputs);
@@ -500,7 +498,6 @@ void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx,
                                        const std::vector<NDArray> &inputs,
                                        const std::vector<OpReqType> &req,
                                        const std::vector<NDArray> &outputs) {
-
   const dnnl::memory::dim qkv_seq_len = inputs[0].shape()[0];
   const dnnl::memory::dim sequences = inputs[0].shape()[1];
   const dnnl::memory::dim output_lin_dim = inputs[0].shape()[2];
@@ -584,7 +581,6 @@ void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<NDArray> &outputs) {
-
   const auto engine = CpuEngine::Get()->get_engine();
   const size_t head_dim = inputs[0].shape()[2] / param_.heads / 3;
   MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
@@ -671,4 +667,4 @@ NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
 }  // namespace op
 }  // namespace mxnet

-#endif
\ No newline at end of file
+#endif
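The `oscale /= sqrt(static_cast<float>(head_dim))` line above relies on matmul being linear in either operand: scaling the product is the same as scaling Q before the product, so the 1/sqrt(head_dim) attention factor rides along with requantization in a single oneDNN output scale. A hedged sketch of the combined computation (function and parameter names are illustrative, not from the patch):

    #include <algorithm>
    #include <cmath>

    float CombinedQKScale(float data_scale, float min_calib, float max_calib,
                          int head_dim, bool int8_out) {
      // For s8 output, requantize by 127 / max|range|; Q and K share the
      // same input scale, hence data_scale * data_scale in the divisor.
      float oscale = 1.0f;
      if (int8_out) {
        const float range = std::max(std::fabs(min_calib), std::fabs(max_calib));
        oscale = (127.f / range) / (data_scale * data_scale);
      }
      // (Q / sqrt(d)) * K^T == (Q * K^T) / sqrt(d): fold the attention
      // scaling into the same per-matmul output scale.
      return oscale / std::sqrt(static_cast<float>(head_dim));
    }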
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
index c3be3a157bf2..adf623084807 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
@@ -53,7 +53,7 @@ class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector {
         disable_float_output(dis_float_output) {}

   bool Select(const nnvm::Node &n) override {
-    if ((!disable_all) && 
+    if ((!disable_all) &&
         (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
          n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
       status = disable_all ? kSuccess : kStart;
@@ -174,8 +174,8 @@ class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty {
       CHECK(requantize_param.max_calib_range.has_value());

       // When only fusing quantized_interleaved_matmul and requantize, set min/max_calib_range,
-      // When fusing quantized_interleaved_matmul + requantize + dequantize, set dequantize flag to true.
-      // auto& param = nnvm::get<MKLDNNSelfAttParam>(interleaved_node->attrs.parsed);
+      // When fusing quantized_interleaved_matmul + requantize + dequantize,
+      // set dequantize flag to true.
       if (dequantize_node != nullptr) {
         interleaved_node->attrs.dict["enable_float_output"] = "True";
       } else {
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index 86d1f58da530..a2cafd7a63d8 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -48,8 +48,6 @@ const std::map<std::string, std::string> NameMapping = {

 class SgMKLDNNTransformerSelector : public SubgraphSelector {
  public:
-  explicit SgMKLDNNTransformerSelector() {}
-
   bool Select(const nnvm::Node &n,
               const std::shared_ptr<NodeAttr>& node_attr) override {
     if (n.op() == Op::Get(SELFATT_QK) || n.op() == Op::Get(SELFATT_VALATT)) {
@@ -93,7 +91,7 @@ class SgMKLDNNTransformerProperty : public SubgraphProperty {
     std::string op_name;
     MKLDNNSelfAttParam new_param;
     DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
-      if (node->op() && 
+      if (node->op() &&
           (node->op()->name == SELFATT_QK ||
            node->op()->name == SELFATT_VALATT)) {
         op_name = node->op()->name;

From aecb8459b4b63193b3415778d9635a4b927371a7 Mon Sep 17 00:00:00 2001
From: "B. Gawrych"
Date: Thu, 15 Apr 2021 09:11:39 +0200
Subject: [PATCH 13/13] Fix sanity

---
 src/operator/subgraph/mkldnn/mkldnn_transformer_property.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
index a2cafd7a63d8..f022bccc24ac 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -22,6 +22,7 @@
 #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_

 #if MXNET_USE_MKLDNN == 1
+#include <map>
 #include <string>
 #include <vector>
 #include "../common.h"
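The post-quantize selector these last patches tidy up is, at heart, a three-state matcher over the chain hanging off the fused operator: the self-attention op, then `_contrib_requantize`, then an optional `_contrib_dequantize`. A simplified sketch of the accept logic (state and function names are illustrative; the real selector also verifies that the requantize node carries calibrated min/max values):

    #include <string>

    enum class MatchStatus { kFail, kStart, kRequantize, kSuccess };

    // Advance the matcher as we walk down the output chain of the fused op.
    MatchStatus Advance(MatchStatus status, const std::string& next_op) {
      if (status == MatchStatus::kStart && next_op == "_contrib_requantize")
        return MatchStatus::kRequantize;
      if (status == MatchStatus::kRequantize && next_op == "_contrib_dequantize")
        return MatchStatus::kSuccess;  // fold dequantize: enable_float_output
      if (status == MatchStatus::kRequantize)
        return MatchStatus::kSuccess;  // requantize only: keep int8 output
      return MatchStatus::kFail;
    }

Which terminal branch fires decides whether the fused node gets enable_float_output or the requantize node's min/max_calib_range copied into its attribute dictionary, exactly the two cases the comment in the hunk above distinguishes.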