diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 18cd3031ef18..9190ba41afcb 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -25,6 +25,8 @@
 #include "mkldnn_fc_post_quantize_property.h"
 #include "mkldnn_elemwisemul_post_quantize_property.h"
 #include "mkldnn_post_quantize_align_scale_property.h"
+#include "mkldnn_transformer_property.h"
+#include "mkldnn_transformer_post_quantize_property.h"
 
 namespace mxnet {
 namespace op {
@@ -35,34 +37,29 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNTransformerProperty);
+
 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
 .set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
 .set_attr("quantize", true);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
 .set_attr("quantize", true);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);
+
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerPostQuantizeProperty);
+
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 
 }  // namespace op
 }  // namespace mxnet
 
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
new file mode 100644
index 000000000000..d4004351649e
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer-inl.h
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
+
+#include "../../mxnet_op.h"
+#include "../../mshadow_op.h"
+
+
+namespace mxnet {
+namespace op {
+
+struct MKLDNNSelfAttParam : public dmlc::Parameter<MKLDNNSelfAttParam> {
+  int heads;
+  bool quantized;
+  bool enable_float_output;
+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset
+  DMLC_DECLARE_PARAMETER(MKLDNNSelfAttParam) {
+    DMLC_DECLARE_FIELD(heads)
+    .describe("Set number of heads");
+    DMLC_DECLARE_FIELD(quantized).set_default(false)
+    .describe("Whether it's a quantized InterleavedMatMul operator");
+    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
+    .describe("Whether to enable float32 output");
+    DMLC_DECLARE_FIELD(min_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The minimum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized InterleavedMatMul op to calculate the primitive scale");
+    DMLC_DECLARE_FIELD(max_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The maximum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized InterleavedMatMul op to calculate the primitive scale");
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_INL_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer.cc b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
new file mode 100644
index 000000000000..8757da0b3c1f
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer.cc
@@ -0,0 +1,670 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "./mkldnn_transformer-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "../../tensor/elemwise_unary_op.h"
+
+#include "../../quantization/quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(MKLDNNSelfAttParam);
+
+template <int base_num_inputs>
+static bool SgMKLDNNSelfAttShape(const NodeAttrs& attrs,
+                                 mxnet::ShapeVector* in_shapes,
+                                 mxnet::ShapeVector* out_shapes) {
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    mxnet::ShapeVector base_in_shapes;
+    mxnet::ShapeVector base_out_shapes = {out_shapes->at(0)};
+
+    for (int i = 0; i < base_num_inputs; i++) {
+      base_in_shapes.emplace_back(in_shapes->at(i));
+    }
+    bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, &base_out_shapes);
+
+    for (size_t i = 0; i < in_shapes->size(); ++i) {
+      if (i < base_in_shapes.size())
+        in_shapes->at(i) = base_in_shapes[i];
+      else
+        SHAPE_ASSIGN_CHECK(*in_shapes, i, mxnet::TShape({1}));
+    }
+    out_shapes->resize(3);
+    out_shapes->at(0) = base_out_shapes[0];
+    if (!param.enable_float_output) {
+      SHAPE_ASSIGN_CHECK(*out_shapes, 1, mxnet::TShape({1}));  // min output
+      SHAPE_ASSIGN_CHECK(*out_shapes, 2, mxnet::TShape({1}));  // max output
+    }
+
+    return ret;
+  } else {
+    return DefaultSubgraphOpShape(attrs, in_shapes, out_shapes);
+  }
+}
+
+static bool SgMKLDNNSelfAttQKInferType(const nnvm::NodeAttrs &attrs,
+                                       std::vector<int> *in_types,
+                                       std::vector<int> *out_types) {
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    CHECK(in_types->at(0) == mshadow::kInt8)
+        << "QuantizedInterleavedMatMulSelfAttQK only supports int8 input, while "
+        << in_types->at(0) << " is given.";
+
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kFloat32);  // min value
+    TYPE_ASSIGN_CHECK(*in_types, 2, mshadow::kFloat32);  // max value
+
+    if (param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);     // output
+    } else {
+      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);      // output
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);     // output
+      }
+      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);     // min output
+      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);     // max output
+    }
+    return true;
+  } else {
+    return DefaultSubgraphOpType(attrs, in_types, out_types);
+  }
+}
+
+template <int base_num_inputs>
+static bool SgMKLDNNSelfAttStorageType(const nnvm::NodeAttrs &attrs,
+                                       const int dev_mask,
+                                       DispatchMode *dispatch_mode,
+                                       std::vector<int> *in_attrs,
+                                       std::vector<int> *out_attrs) {
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    std::vector<int> base_in_attrs;
+    std::vector<int> base_out_attrs{out_attrs->at(0)};
+
+    for (int i = 0; i < base_num_inputs; i++) {
+      base_in_attrs.emplace_back(in_attrs->at(i));
+    }
+    bool ret = DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                            &base_in_attrs, &base_out_attrs);
+
+    for (size_t i = 0; i < in_attrs->size(); ++i) {
+      if (i < base_in_attrs.size())
+        in_attrs->at(i) = base_in_attrs[i];
+      else
+        type_assign(&in_attrs->at(i), mxnet::kDefaultStorage);
+    }
+
+    out_attrs->at(0) = base_out_attrs[0];
+    if (!param.enable_float_output) {
+      type_assign(&out_attrs->at(1), mxnet::kDefaultStorage);
+      type_assign(&out_attrs->at(2), mxnet::kDefaultStorage);
+    }
+    return ret;
+  } else {
+    return DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode,
+                                        in_attrs, out_attrs);
+  }
+}
+
+class SgMKLDNNSelfAttQKOp {
+ public:
+  explicit SgMKLDNNSelfAttQKOp(const nnvm::NodeAttrs &attrs) :
+    param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}
+
+  void Forward(const OpContext &ctx,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+  void Backward(const OpContext &ctx,
+                const std::vector<NDArray> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &outputs) {
+    LOG(FATAL) << "Not implemented: subgraph mkldnn self-attention only supports "
+                  "inference computation.";
+  }
+
+  void Initialize(const OpContext &ctx,
+                  const std::vector<NDArray> &inputs,
+                  const std::vector<OpReqType> &req,
+                  const std::vector<NDArray> &outputs);
+
+  bool IsInitialized() {
+    return initialized_;
+  }
+
+ private:
+  bool initialized_{false};
+  MKLDNNSelfAttParam param_;
+  mkldnn_args_map_t args_;
+  std::shared_ptr<dnnl::matmul> fwd_;
+  std::shared_ptr<dnnl::memory> cached_query_mem_;
+  std::shared_ptr<dnnl::memory> cached_key_mem_;
+  std::shared_ptr<dnnl::memory> cached_out_mem_;
+  float min_data_;
+  float max_data_;
+  float min_output_;
+  float max_output_;
+  float data_scale_{0.0f};
+};
+
+static OpStatePtr CreateSgMKLDNNSelfAttQKState(const nnvm::NodeAttrs &attrs,
+                                               Context ctx,
+                                               const mxnet::ShapeVector &in_shapes,
+                                               const std::vector<int> &in_types) {
+  return OpStatePtr::Create<SgMKLDNNSelfAttQKOp>(attrs);
+}
+
+static void SgMKLDNNSelfAttQKForward(const OpStatePtr &state_pointer,
+                                     const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  SgMKLDNNSelfAttQKOp &op = state_pointer.get_state<SgMKLDNNSelfAttQKOp>();
+  if (!op.IsInitialized()) {
+    op.Initialize(ctx, inputs, req, outputs);
+  }
+  op.Forward(ctx, inputs, req, outputs);
+}
+
+void SgMKLDNNSelfAttQKOp::Initialize(const OpContext &ctx,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<NDArray> &outputs) {
+  using namespace mkldnn;
+  const auto qkv_tensor = inputs[0];
+  const auto out_tensor = outputs[0];
+  const auto qkv_dtype = get_mkldnn_type(qkv_tensor.dtype());
+
+  const memory::dim heads          = param_.heads;
+  const memory::dim sequences      = inputs[0].shape()[1];
+  const memory::dim qkv_seq_len    = inputs[0].shape()[0];
+  const memory::dim output_lin_dim = inputs[0].shape()[2];
+  const memory::dim embed_dim      = output_lin_dim / 3;
+  const memory::dim head_dim       = embed_dim / heads;
+  const memory::dim attn_batches   = heads * sequences;
+  const memory::dim lead_dim       = attn_batches * 3 * head_dim;
+  const memory::dim batch_stride   = 3 * head_dim;
+
+  if (param_.quantized) {
+    min_data_ = inputs[1].data().dptr<float>()[0];
+    max_data_ = inputs[2].data().dptr<float>()[0];
+  }
+
+  const auto engine = CpuEngine::Get()->get_engine();
+
+  memory::dims query_dims = {attn_batches, qkv_seq_len, head_dim};
+  memory::dims key_dims   = {attn_batches, head_dim, qkv_seq_len};
+  memory::dims out_dims   = {attn_batches, qkv_seq_len, qkv_seq_len};
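+  // The input holds the projected Q, K and V interleaved per head along the
+  // last axis: shape (qkv_seq_len, sequences, heads * 3 * head_dim). The
+  // strided memory descriptors below view Q and K (K transposed) directly
+  // inside that buffer, so the batched matmul needs no slice, transpose or
+  // copy beforehand.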
+  memory::dims query_strides = {batch_stride, lead_dim, 1};
+  memory::dims key_strides   = {batch_stride, 1, lead_dim};
+
+  auto query_md = memory::desc(query_dims, qkv_dtype, query_strides);
+  auto key_md   = memory::desc(key_dims, qkv_dtype, key_strides);
+
+  memory::desc out_md;
+
+  float oscale = 1.0f;
+  if (param_.quantized) {
+    data_scale_ = GetQuantizeScale(qkv_tensor.dtype(), min_data_, max_data_);
+
+    if (param_.min_calib_range.has_value() &&
+        param_.max_calib_range.has_value()) {
+      min_output_ = param_.min_calib_range.value();
+      max_output_ = param_.max_calib_range.value();
+      oscale =
+          GetQuantizeScale(out_tensor.dtype(), min_output_, max_output_) /
+          (data_scale_ * data_scale_);
+      out_md = memory::desc(out_dims, memory::data_type::s8, memory::format_tag::abc);
+    } else if (param_.enable_float_output) {
+      oscale = 1.0f / (data_scale_ * data_scale_);
+      out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
+    } else {
+      mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+      mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
+          s, 1, &min_output_, &max_output_, &min_data_, &max_data_, &min_data_,
+          &max_data_);
+      out_md = dnnl::memory::desc(out_dims, memory::data_type::s32, memory::format_tag::abc);
+    }
+  } else {
+    out_md = dnnl::memory::desc(out_dims, memory::data_type::f32, memory::format_tag::abc);
+  }
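+  // The 1/sqrt(head_dim) attention scaling is folded into the matmul output
+  // scale together with the (de)quantization factors computed above, so the
+  // whole QK product runs as a single oneDNN primitive.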
+  oscale /= sqrt(static_cast<float>(head_dim));  // combine quantized scale and sqrt(head_dim)
+
+  dnnl::primitive_attr attr;
+  attr.set_output_scales(0, {oscale});
+  auto matmul_d  = matmul::desc(query_md, key_md, out_md);
+  auto matmul_pd = matmul::primitive_desc(matmul_d, attr, engine);
+
+  fwd_ = std::make_shared<matmul>(matmul_pd);
+
+  MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+    DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+    DType* key_mem_ptr   = query_mem_ptr + head_dim;
+    cached_query_mem_ = std::make_shared<memory>(query_md, engine, query_mem_ptr);
+    cached_key_mem_   = std::make_shared<memory>(key_md, engine, key_mem_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+    cached_out_mem_ = std::make_shared<memory>(out_md, engine, outputs[0].data().dptr<DType>());
+  });
+
+  args_[DNNL_ARG_SRC]     = *cached_query_mem_;
+  args_[DNNL_ARG_WEIGHTS] = *cached_key_mem_;
+  args_[DNNL_ARG_DST]     = *cached_out_mem_;
+  initialized_ = true;
+}
+
+
+void SgMKLDNNSelfAttQKOp::Forward(const OpContext &ctx,
+                                  const std::vector<NDArray> &inputs,
+                                  const std::vector<OpReqType> &req,
+                                  const std::vector<NDArray> &outputs) {
+  const size_t head_dim = inputs[0].shape()[2] / 3 / param_.heads;
+
+  MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+    DType* query_mem_ptr = inputs[0].data().dptr<DType>();
+    DType* key_mem_ptr   = query_mem_ptr + head_dim;
+    cached_query_mem_->set_data_handle(query_mem_ptr);
+    cached_key_mem_->set_data_handle(key_mem_ptr);
+  });
+
+  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+    cached_out_mem_->set_data_handle(outputs[0].data().dptr<DType>());
+  });
+
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+  MKLDNNStream::Get()->Submit();
+
+  if (param_.quantized && !param_.enable_float_output) {
+    float* output_min = outputs[1].data().dptr<float>();
+    float* output_max = outputs[2].data().dptr<float>();
+
+    *output_min = min_output_;
+    *output_max = max_output_;
+  }
+}
+
+nnvm::ObjectPtr SgMKLDNNSelfAttQKQuantizedOp(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  node->attrs.op = Op::Get("_sg_mkldnn_selfatt_qk");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  node->attrs.dict["heads"] = std::to_string(param.heads);
+  node->attrs.dict["quantized"] = "True";
+  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+  for (auto sub : attrs.subgraphs) {
+    node->attrs.subgraphs.push_back(sub);
+  }
+  node->op()->attr_parser(&(node->attrs));
+  return node;
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_selfatt_qk)
+.describe(R"code(_sg_mkldnn_selfatt_qk)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized && !param.enable_float_output) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> input_names {"queries_keys_values"};
+  if (param.quantized) {
+    input_names.emplace_back("min_qkv");
+    input_names.emplace_back("max_qkv");
+  }
+  return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<1>)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNSelfAttQKInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNSelfAttStorageType<1>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateSgMKLDNNSelfAttQKState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNSelfAttQKForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+  return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttQKQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
+
+/**********************************_sg_mkldnn_selfatt_valatt**********************************/
+
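+// The attention input is produced by a softmax, so it is non-negative and is
+// quantized to uint8, while the interleaved queries_keys_values input is int8.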
+static bool SgMKLDNNSelfAttValAttInferType(const nnvm::NodeAttrs &attrs,
+                                           std::vector<int> *in_types,
+                                           std::vector<int> *out_types) {
+  const auto& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    TYPE_ASSIGN_CHECK(*in_types, 0, mshadow::kInt8);   // qkv input
+    TYPE_ASSIGN_CHECK(*in_types, 1, mshadow::kUint8);  // att input
+
+    // min qkv, max qkv, min att, max att
+    for (size_t i = 2; i < in_types->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
+
+    if (param.enable_float_output) {
+      TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);     // output
+    } else {
+      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);      // output
+      } else {
+        TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt32);     // output
+      }
+      TYPE_ASSIGN_CHECK(*out_types, 1, mshadow::kFloat32);     // min output
+      TYPE_ASSIGN_CHECK(*out_types, 2, mshadow::kFloat32);     // max output
+    }
+    return true;
+  } else {
+    return DefaultSubgraphOpType(attrs, in_types, out_types);
+  }
+}
+
+nnvm::ObjectPtr SgMKLDNNSelfAttValAttQuantizedOp(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  auto const &param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  node->attrs.op = Op::Get("_sg_mkldnn_selfatt_valatt");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  node->attrs.dict["heads"] = std::to_string(param.heads);
+  node->attrs.dict["quantized"] = "True";
+  node->attrs.subgraphs.reserve(attrs.subgraphs.size());
+  for (auto sub : attrs.subgraphs) {
+    node->attrs.subgraphs.push_back(sub);
+  }
+  node->op()->attr_parser(&(node->attrs));
+  return node;
+}
+
+class MKLDNNSelfAttValAttOp {
+ public:
+  explicit MKLDNNSelfAttValAttOp(const nnvm::NodeAttrs &attrs) :
+    param_(nnvm::get<MKLDNNSelfAttParam>(attrs.parsed)) {}
+
+  void Forward(const OpContext &ctx,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+  void Backward(const OpContext &ctx,
+                const std::vector<NDArray> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &outputs) {
+    LOG(FATAL) << "Not implemented: subgraph mkldnn self-attention only supports "
+                  "inference computation.";
+  }
+
+  void Initialize(const OpContext &ctx,
+                  const std::vector<NDArray> &inputs,
+                  const std::vector<OpReqType> &req,
+                  const std::vector<NDArray> &outputs);
+
+  bool IsInitialized() {
+    return initialized_;
+  }
+
+ private:
+  bool initialized_{false};
+  MKLDNNSelfAttParam param_;
+  mkldnn_args_map_t args_;
+  std::shared_ptr<dnnl::matmul> fwd_;
+  std::shared_ptr<dnnl::memory> cached_att_mem_;
+  std::shared_ptr<dnnl::memory> cached_qkv_mem_;
+  std::shared_ptr<dnnl::memory> cached_out_mem_;
+  float min_qkv_;
+  float max_qkv_;
+  float min_att_;
+  float max_att_;
+  float min_output_;
+  float max_output_;
+  float qkv_scale_{0.0f};
+  float att_scale_{0.0f};
+};
+
+static OpStatePtr CreateMKLDNNSelfAttValAttState(const nnvm::NodeAttrs &attrs,
+                                                 Context ctx,
+                                                 const mxnet::ShapeVector &in_shapes,
+                                                 const std::vector<int> &in_types) {
+  return OpStatePtr::Create<MKLDNNSelfAttValAttOp>(attrs);
+}
+
+static void MKLDNNSelfAttValAttForward(const OpStatePtr &state_pointer,
+                                       const OpContext &ctx,
+                                       const std::vector<NDArray> &inputs,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<NDArray> &outputs) {
+  MKLDNNSelfAttValAttOp &op = state_pointer.get_state<MKLDNNSelfAttValAttOp>();
+  if (!op.IsInitialized()) {
+    op.Initialize(ctx, inputs, req, outputs);
+  }
+  op.Forward(ctx, inputs, req, outputs);
+}
+
+void MKLDNNSelfAttValAttOp::Initialize(const OpContext &ctx,
+                                       const std::vector<NDArray> &inputs,
+                                       const std::vector<OpReqType> &req,
+                                       const std::vector<NDArray> &outputs) {
+  const dnnl::memory::dim qkv_seq_len    = inputs[0].shape()[0];
+  const dnnl::memory::dim sequences      = inputs[0].shape()[1];
+  const dnnl::memory::dim output_lin_dim = inputs[0].shape()[2];
+  const dnnl::memory::dim embed_dim      = output_lin_dim / 3;
+  const dnnl::memory::dim head_dim       = embed_dim / param_.heads;
+  const dnnl::memory::dim attn_batches   = param_.heads * sequences;
+  const dnnl::memory::dim lead_dim       = attn_batches * 3 * head_dim;
+  const dnnl::memory::dim batch_stride   = 3 * head_dim;
+
+
+  dnnl::memory::dims att_dims = {attn_batches, qkv_seq_len, qkv_seq_len};
+  dnnl::memory::dims qkv_dims = {attn_batches, qkv_seq_len, head_dim};
+  dnnl::memory::dims dst_dims = {attn_batches, qkv_seq_len, head_dim};
+
+  dnnl::memory::dims att_strides = {qkv_seq_len * qkv_seq_len, qkv_seq_len, 1};
+  dnnl::memory::dims qkv_strides = {batch_stride, lead_dim, 1};
+
+  auto att_dtype = inputs[1].dtype();
+  auto qkv_dtype = inputs[0].dtype();
+  auto out_dtype = outputs[0].dtype();
+  auto att_md = dnnl::memory::desc(att_dims, get_mkldnn_type(att_dtype), att_strides);
+  auto qkv_md = dnnl::memory::desc(qkv_dims, get_mkldnn_type(qkv_dtype), qkv_strides);
+
+  dnnl::memory::desc out_md;
+  dnnl::primitive_attr attr;
+
+  float oscale = 1.0f;
+  if (param_.quantized) {
+    min_qkv_ = inputs[2].data().dptr<float>()[0];
+    max_qkv_ = inputs[3].data().dptr<float>()[0];
+    min_att_ = inputs[4].data().dptr<float>()[0];
+    max_att_ = inputs[5].data().dptr<float>()[0];
+    qkv_scale_ = GetQuantizeScale(qkv_dtype, min_qkv_, max_qkv_);
+    att_scale_ = GetQuantizeScale(att_dtype, min_att_, max_att_);
+
+    if (param_.min_calib_range.has_value() &&
+        param_.max_calib_range.has_value()) {
+      min_output_ = param_.min_calib_range.value();
+      max_output_ = param_.max_calib_range.value();
+
+      oscale = GetQuantizeScale(out_dtype, min_output_, max_output_) / (qkv_scale_ * att_scale_);
+      attr.set_output_scales(0, {oscale});
+    } else if (param_.enable_float_output) {
+      oscale = 1.0f / (qkv_scale_ * att_scale_);
+      attr.set_output_scales(0, {oscale});
+    } else {
+      mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+      mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
+          s, 1, &min_output_, &max_output_, &min_qkv_, &max_qkv_, &min_att_,
+          &max_att_);
+    }
+  }
+  out_md = dnnl::memory::desc(dst_dims, get_mkldnn_type(out_dtype), dnnl::memory::format_tag::bac);
+
+  const auto engine = CpuEngine::Get()->get_engine();
+  auto matmul_d  = dnnl::matmul::desc(att_md, qkv_md, out_md);
+  auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine);
+
+  fwd_ = std::make_shared<dnnl::matmul>(matmul_pd);
+
+  MSHADOW_TYPE_SWITCH(att_dtype, DType, {
+    DType* att_ptr = inputs[1].data().dptr<DType>();
+    cached_att_mem_ = std::make_shared<dnnl::memory>(att_md, engine, att_ptr);
+  });
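+  // Q, K and V are interleaved per head, so the values start 2 * head_dim
+  // elements into each head's chunk of the qkv buffer; offsetting the base
+  // pointer selects them without a copy.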
+  MSHADOW_TYPE_SWITCH(qkv_dtype, DType, {
+    DType* value_ptr = inputs[0].data().dptr<DType>() + 2 * head_dim;
+    cached_qkv_mem_ = std::make_shared<dnnl::memory>(qkv_md, engine, value_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(out_dtype, DType, {
+    DType* out_ptr = outputs[0].data().dptr<DType>();
+    cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine, out_ptr);
+  });
+
+  args_[DNNL_ARG_SRC]     = *cached_att_mem_;
+  args_[DNNL_ARG_WEIGHTS] = *cached_qkv_mem_;
+  args_[DNNL_ARG_DST]     = *cached_out_mem_;
+  initialized_ = true;
+}
+
+void MKLDNNSelfAttValAttOp::Forward(const OpContext &ctx,
+                                    const std::vector<NDArray> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<NDArray> &outputs) {
+  const size_t head_dim = inputs[0].shape()[2] / param_.heads / 3;
+  MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, {
+    DType* att_ptr = inputs[1].data().dptr<DType>();
+    cached_att_mem_->set_data_handle(att_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(inputs[0].dtype(), DType, {
+    DType* value_ptr = inputs[0].data().dptr<DType>() + 2 * head_dim;
+    cached_qkv_mem_->set_data_handle(value_ptr);
+  });
+  MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
+    DType* out_ptr = outputs[0].data().dptr<DType>();
+    cached_out_mem_->set_data_handle(out_ptr);
+  });
+
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, args_);
+  MKLDNNStream::Get()->Submit();
+
+  if (param_.quantized && !param_.enable_float_output) {
+    float* output_min = outputs[1].data().dptr<float>();
+    float* output_max = outputs[2].data().dptr<float>();
+
+    *output_min = min_output_;
+    *output_max = max_output_;
+  }
+}
+
+NNVM_REGISTER_OP(_sg_mkldnn_selfatt_valatt)
+.describe(R"code(_sg_mkldnn_selfatt_valatt)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized) {
+    return 6;
+  } else {
+    return 2;
+  }
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  if (param.quantized && !param.enable_float_output) {
+    return 3;
+  } else {
+    return 1;
+  }
+})
+.set_attr_parser(ParamParser<MKLDNNSelfAttParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> input_names {"queries_keys_values", "attention"};
+  if (param.quantized) {
+    input_names.emplace_back("min_qkv");
+    input_names.emplace_back("max_qkv");
+
+    input_names.emplace_back("min_attention");
+    input_names.emplace_back("max_attention");
+  }
+  return input_names;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  auto const& param = nnvm::get<MKLDNNSelfAttParam>(attrs.parsed);
+  std::vector<std::string> output_names {"output"};
+  if (param.quantized && !param.enable_float_output) {
+    output_names.emplace_back("min_output");
+    output_names.emplace_back("max_output");
+  }
+  return output_names;
+})
+.set_attr<mxnet::FInferShape>("FInferShape", SgMKLDNNSelfAttShape<2>)
+.set_attr<nnvm::FInferType>("FInferType", SgMKLDNNSelfAttValAttInferType)
+.set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNSelfAttStorageType<2>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateMKLDNNSelfAttValAttState)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", MKLDNNSelfAttValAttForward)
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FQuantizable>("FQuantizable", [](const NodeAttrs& attrs) {
+  return QuantizeType::kMust;
+})
+.set_attr<FQuantizedOp>("FQuantizedOp", SgMKLDNNSelfAttValAttQuantizedOp)
+.set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
+.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
+.add_arguments(MKLDNNSelfAttParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
new file mode 100644
index 000000000000..adf623084807
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_post_quantize_property.h
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <string>
+#include <vector>
+#include "../../quantization/requantize-inl.h"
+#include "../common.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNTransformerPostQuantizeSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kRequantize,
+    kSuccess,
+  };
+
+ private:
+  bool disable_all;
+  bool disable_float_output;
+  SelectStatus status;
+  std::vector<const nnvm::Node *> matched_list;
+
+ public:
+  explicit SgMKLDNNTransformerPostQuantizeSelector(const bool dis_all,
+                                                   const bool dis_float_output)
+      : disable_all(dis_all),
+        disable_float_output(dis_float_output) {}
+
+  bool Select(const nnvm::Node &n) override {
+    if ((!disable_all) &&
+        (n.op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+         n.op() == Op::Get("_sg_mkldnn_selfatt_valatt"))) {
+      status = disable_all ? kSuccess : kStart;
+      matched_list.clear();
+      matched_list.push_back(&n);
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    if (status == kFail || status == kSuccess || new_node.is_variable())
+      return false;
+    // If n isn't the last matched node, we encountered an internal branch:
+    // pop the nodes matched after n and stop the fusion here.
+    if (matched_list.back() != &n) {
+      if (std::find(matched_list.begin(), matched_list.end(), &n) !=
+          matched_list.end()) {
+        while (matched_list.back() != &n) {
+          matched_list.pop_back();
+        }
+      }
+
+      status = kSuccess;
+      return false;
+    }
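+    // The pattern is matched as a chain: kStart at the fused self-attention
+    // op, kRequantize after a requantize carrying calibrated min/max, and
+    // kSuccess once an optional dequantize is absorbed; any other consumer
+    // ends the match with whatever has been collected so far.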
+    switch (status) {
+      case kStart:
+        if (new_node.op() == Op::Get("_contrib_requantize")) {
+          auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
+          if (param.min_calib_range.has_value() &&
+              param.max_calib_range.has_value()) {
+            matched_list.push_back(&new_node);
+            status = kRequantize;
+            return true;
+          }
+        }
+      case kRequantize:
+        if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
+          matched_list.push_back(&new_node);
+          status = kSuccess;
+          return true;
+        }
+      default:
+        status = kSuccess;
+        return false;
+    }
+  }
+
+  std::vector<nnvm::Node *> Filter(
+      const std::vector<nnvm::Node *> &candidates) override {
+    if ((status != kSuccess) || (matched_list.size() <= 1)) {
+      return std::vector<nnvm::Node *>(0);
+    } else {
+      std::vector<nnvm::Node *> ret;
+      for (auto i : matched_list) {
+        auto non_const_i = const_cast<nnvm::Node *>(i);
+        if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
+            candidates.end()) {
+          ret.push_back(non_const_i);
+        }
+      }
+      return ret;
+    }
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list.size(), 1);
+    auto new_selector = SgMKLDNNTransformerPostQuantizeSelector(disable_all, disable_float_output);
+    new_selector.Select(*matched_list[0]);
+    *this = new_selector;
+  }
+};
+
+class SgMKLDNNTransformerPostQuantizeProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNTransformerPostQuantizeProperty() {
+    disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FUSE_ALL", false);
+    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QTRANSFORMER_FLOAT_OUTPUT", false);
+  }
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string &name = "MKLDNN Transformer post-quantization optimization pass";
+    auto property = std::make_shared<SgMKLDNNTransformerPostQuantizeProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr interleaved_node = nullptr;
+    nnvm::ObjectPtr requantize_node = nullptr;
+    nnvm::ObjectPtr dequantize_node = nullptr;
+
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
+      if (node->is_variable()) return;
+      if (node->op() == Op::Get("_sg_mkldnn_selfatt_qk") ||
+          node->op() == Op::Get("_sg_mkldnn_selfatt_valatt")) {
+        interleaved_node = node;
+      } else if (node->op() == Op::Get("_contrib_requantize")) {
+        requantize_node = node;
+      } else if (node->op() == Op::Get("_contrib_dequantize")) {
+        dequantize_node = node;
+      }
+    });
+
+    CHECK_NOTNULL(interleaved_node);
+    CHECK_NOTNULL(requantize_node);
+    auto const &requantize_param =
+        nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
+    CHECK(requantize_param.min_calib_range.has_value());
+    CHECK(requantize_param.max_calib_range.has_value());
+
+    // When only quantized_interleaved_matmul + requantize are fused, propagate
+    // the requantize min/max_calib_range to the fused node; when a trailing
+    // dequantize is fused as well, enable float output instead.
+    if (dequantize_node != nullptr) {
+      interleaved_node->attrs.dict["enable_float_output"] = "True";
+    } else {
+      interleaved_node->attrs.dict["min_calib_range"] =
+          std::to_string(requantize_param.min_calib_range.value());
+      interleaved_node->attrs.dict["max_calib_range"] =
+          std::to_string(requantize_param.max_calib_range.value());
+    }
+    interleaved_node->op()->attr_parser(&(interleaved_node->attrs));
+    return interleaved_node;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector =
+        std::make_shared<SgMKLDNNTransformerPostQuantizeSelector>(disable_fuse_all,
+                                                                  disable_float_output);
+    return selector;
+  }
+
+ private:
+  bool disable_fuse_all;
+  bool disable_float_output;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
new file mode 100644
index 000000000000..f022bccc24ac
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_transformer_property.h
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "../../tensor/matrix_op-inl.h"
+#include "../../contrib/transformer-inl.h"
+#include "mkldnn_transformer-inl.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#define SELFATT_QK "_contrib_interleaved_matmul_selfatt_qk"
+#define SELFATT_VALATT "_contrib_interleaved_matmul_selfatt_valatt"
+
+const std::map<std::string, std::string> OpMapping = {
+  {SELFATT_QK, "_sg_mkldnn_selfatt_qk"},
+  {SELFATT_VALATT, "_sg_mkldnn_selfatt_valatt"}
+};
+
+const std::map<std::string, std::string> NameMapping = {
+  {SELFATT_QK, "sg_mkldnn_selfatt_qk"},
+  {SELFATT_VALATT, "sg_mkldnn_selfatt_valatt"}
+};
+
+class SgMKLDNNTransformerSelector : public SubgraphSelector {
+ public:
+  bool Select(const nnvm::Node &n, const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (n.op() == Op::Get(SELFATT_QK) ||
+        n.op() == Op::Get(SELFATT_VALATT)) {
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+};
+
+class SgMKLDNNTransformerProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNTransformerProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string &name = "MKLDNN Transformer optimization pass";
+    auto property = std::make_shared<SgMKLDNNTransformerProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_TRANSFORMER_OPT", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+    // This op has a single output, so remove duplicated output entries.
+    auto last_node = sym.outputs[0].node;
+    nnvm::Symbol new_sym;
+    new_sym.outputs.emplace_back(last_node);
+    std::ostringstream node_name;
+    std::string op_name;
+    MKLDNNSelfAttParam new_param;
+    DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr &node) {
+      if (node->op() &&
+          (node->op()->name == SELFATT_QK ||
+           node->op()->name == SELFATT_VALATT)) {
+        op_name = node->op()->name;
+        auto param = nnvm::get<InterleavedMatMulParam>(node->attrs.parsed);
+        new_param.heads = param.heads;
+        new_param.quantized = false;
+        new_param.enable_float_output = false;
+      }
+    });
+    node_name << NameMapping.at(op_name) << "_" << std::to_string(subgraph_id);
+
+
+    n->attrs.name = node_name.str();
+    n->attrs.op = Op::Get(OpMapping.at(op_name));
+    CHECK(n->attrs.op);
+    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+    n->attrs.parsed = new_param;
+    return n;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector = std::make_shared<SgMKLDNNTransformerSelector>();
+    return selector;
+  }
+
+  void ConnectSubgraphOutputs(
+      const nnvm::ObjectPtr n,
+      std::vector<nnvm::NodeEntry *> *output_entries) const override {
+    // Connect all extern output entries to output[0]
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      auto entry_ptr = output_entries->at(i);
+      *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+    }
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_TRANSFORMER_PROPERTY_H_
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 65b73e438ea6..79494a046e2b 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -45,6 +45,14 @@
   'fc': {
     OP_NAME: 'sg_mkldnn_fully_connected',
     QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected'
+  },
+  'selfatt_qk': {
+    OP_NAME: 'sg_mkldnn_selfatt_qk',
+    QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_qk'
+  },
+  'selfatt_valatt': {
+    OP_NAME: 'sg_mkldnn_selfatt_valatt',
+    QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_selfatt_valatt'
   }
 }
 
@@ -52,6 +60,10 @@
 fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'square', 'square_root',
                   'abs', 'exp', 'bounded_relu']
 
+quant_op_fp32_output_support = ("quantized_sg_mkldnn_fully_connected",
+                                "quantized_sg_mkldnn_selfatt_qk",
+                                "quantized_sg_mkldnn_selfatt_valatt")
+
 def check_qsym_calibrated(qsym, out_type, name='conv'):
   quantized_op_name = 'quantized_' + name
   assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1
@@ -59,7 +71,8 @@
     if k.find('_quantize') != -1:
       assert v['out_type'] == out_type
     if k.find(quantized_op_name) != -1:
-      if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v:
+      if ('enable_float_output' in v
+          and quantized_op_name.startswith(quant_op_fp32_output_support)):
        continue
       assert 'min_calib_range' in v
       assert 'max_calib_range' in v
@@ -119,9 +132,11 @@ def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
 
 class CalibIter(mx.io.DataIter):
   def __init__(self, batch, data_shape, batch_size):
     super(CalibIter, self).__init__(batch_size)
-    self.data_shape = data_shape
     self.label_shape = (batch_size,)
-    self.provide_data = [('data', self.data_shape)]
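+    # data_shape may be a plain shape tuple for a single 'data' input, or a
+    # ready-made provide_data list for multi-input graphs (see the valatt test).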
+    if isinstance(data_shape, tuple):
+      self.provide_data = [('data', data_shape)]
+    else:
+      self.provide_data = data_shape
     self.provide_label = []
     self.batch = batch
@@ -249,7 +264,6 @@ def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quan
   if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
     check_quantization = False
     data_min = 0
-  sym_sg = sym.get_backend_symbol(SG_PASS_NAME)
   for name, attrs in attrs_dict.items():
     if name in config:
@@ -677,6 +691,13 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'):
 
   return sym, attr
 
+def single_selfatt_qk(data_shape, nheads=16):
+  attr = {'selfatt_qk': {}}
+  data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
+  qk = mx.symbol.contrib.interleaved_matmul_selfatt_qk(queries_keys_values=data,
+                                                       heads=nheads)
+  return qk, attr
+
 # fc + relu can't be fusion case
 # eg.1
 # fc -----------> relu
@@ -865,6 +886,87 @@ def test_fc_eltwise():
     else:
       check_fusion(syms, dshape, attrs, check_quantization=False)
 
+@with_seed()
+def test_selfatt_qk():
+  batchsizes = [1, 8]
+  seq_lengths = [180, 384]
+  num_hidden = [1024, 3072]
+  num_heads = [8, 16]
+  for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads):
+    dshape = (seqlen, bs, nhidden)
+    syms, attrs = single_selfatt_qk(dshape, nheads)
+    check_fusion(syms, dshape, attrs, out_types=['int8', 'auto'], check_quantization=True)
+
+@with_seed()
+def test_selfatt_valatt():
+  batchsizes = [1, 8]
+  seq_lengths = [18, 255, 384]
+  num_hidden = [1024, 3072]
+  num_heads = [1, 16]
+
+  def get_valatt_symbol(qkv_shape, attention_shape, nheads):
+    qkv = mx.symbol.Variable('qkv', shape=qkv_shape, dtype='float32')
+    attention = mx.symbol.Variable('attention', shape=attention_shape, dtype='float32')
+    # CalibIter assumes that batch_size is always the first dimension;
+    # the following operators change the shapes to the proper ones
+    qkv_swap = mx.symbol.swapaxes(data=qkv, dim1=0, dim2=1)
+    attention_reshape = mx.symbol.reshape(data=attention, shape=(-1, 0, 0), reverse=True)
+    sym = mx.symbol.contrib.interleaved_matmul_selfatt_valatt(queries_keys_values=qkv_swap,
+                                                              attention=attention_reshape,
+                                                              heads=nheads)
+    return sym
+
+  def check_valatt_quantize(sym, qkv_shape, att_shape):
+    qkv_nd = mx.nd.random.uniform(low=-1, high=1, shape=qkv_shape)
+    weight_nd = mx.nd.random.uniform(low=0, high=1, shape=att_shape)
+    arg_params = {
+      'qkv': qkv_nd,
+      'attention': weight_nd
+    }
+
+    ex = sym.bind(mx.cpu(), arg_params, args_grad=None)
+    ex.forward()
+    ref_out = ex.outputs
+
+    sym_sg = sym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+
+    batch = mx.io.DataBatch([qkv_nd, weight_nd], [])
+    calib_data = CalibIter(batch, [('qkv', qkv_shape), ('attention', att_shape)], bs)
+    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+                                                                     arg_params=arg_params,
+                                                                     aux_params={},
+                                                                     ctx=mx.cpu(),
+                                                                     excluded_sym_names=None,
+                                                                     excluded_op_names=None,
+                                                                     quantize_granularity='tensor-wise',
+                                                                     quantized_dtype='auto',
+                                                                     calib_mode='naive',
+                                                                     calib_data=calib_data,
+                                                                     data_names=('qkv', 'attention'),
+                                                                     label_names=None,
+                                                                     num_calib_examples=1,
+                                                                     quantize_mode='full')
+    qsym = qsym.get_backend_symbol(QUANTIZE_SG_PASS_NAME)
+
+    qex = qsym.bind(mx.cpu(), arg_params, args_grad=None)
+    qex.forward()
+    quantized_out = qex.outputs
+
+    for i in range(len(ref_out)):
+      min_range = mx.nd.min(ref_out[i]).asscalar()
+      max_range = mx.nd.max(ref_out[i]).asscalar()
+      atol = 0.1 * max(abs(min_range), abs(max_range))
+      assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(),
+                                   rtol=0.1, atol=atol, etol=0.2)
+
+  for bs, seqlen, nhidden, nheads in itertools.product(batchsizes, seq_lengths, num_hidden, num_heads):
+    qkv_shape = (bs, seqlen, 3*nhidden)
+    att_shape = (bs, nheads, seqlen, seqlen)
+
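+    # check_fusion only validates the fp32 fusion pass here (shapes are already
+    # fixed on the variables, so data_shape is None); quantization is verified
+    # separately below against the unfused fp32 reference output.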
+    sym = get_valatt_symbol(qkv_shape, att_shape, nheads)
+    check_fusion(sym, None, {'selfatt_valatt': {}}, check_quantization=False)
+    check_valatt_quantize(sym, qkv_shape, att_shape)
+
+
 @with_seed()
 def test_neg_fc_relu():
   for dshape, no_bias, flatten in itertools.product(DATA_SHAPE, [True, False], [True, False]):