From 60f934fe8fa56cc657ba6fcb12e4d97499ed10da Mon Sep 17 00:00:00 2001
From: Renat Idrisov
Date: Sun, 26 Jan 2025 23:35:11 +0000
Subject: [PATCH 1/4] Resolving inconsistency between attention/attention_bias

---
 include/tvm/relax/attrs/nn.h |  2 ++
 src/relax/op/nn/attention.cc | 18 +-----------------
 2 files changed, 3 insertions(+), 17 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index e26cee26584b..1f5f5017c1df 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -546,11 +546,13 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
 
 /*! \brief Attributes used in Attention operator */
 struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  Optional bias;
   Optional<FloatImm> scale;
   Optional<String> causal_mask;
   Optional<IntImm> window_size;
 
   TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+    TVM_ATTR_FIELD(bias).describe("The input bias tensor.");
     TVM_ATTR_FIELD(scale).describe(
         "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
     TVM_ATTR_FIELD(causal_mask)
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index ca3746ddad4e..436b1a7d2582 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -34,12 +34,8 @@ Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
   attrs->window_size = window_size;
+  attrs->bias = bias;
 
-  if (bias) {
-    return Call(Op::Get("relax.nn.attention_bias"),
-                {std::move(query), std::move(key), std::move(value), std::move(bias.value())},
-                Attrs(attrs), {});
-  }
   return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
               Attrs(attrs), {});
 }
@@ -152,18 +148,6 @@ TVM_REGISTER_OP("relax.nn.attention")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention)
     .set_attr<Bool>("FPurity", Bool(true));
 
-TVM_REGISTER_OP("relax.nn.attention_bias")
-    .set_attrs_type<AttentionAttrs>()
-    .set_num_inputs(4)
-    .add_argument("query", "Tensor", "The input queries tensor.")
-    .add_argument("key", "Tensor", "The input keys tensor.")
-    .add_argument("value", "Tensor", "The input values tensor.")
-    .add_argument("bias", "Tensor", "The input bias tensor.")
-    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
-    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionAttention)
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention)
-    .set_attr<Bool>("FPurity", Bool(true));
-
 TVM_REGISTER_OP("relax.nn.attention_var_len")
     .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(7)

From b2c173a06fdad288dd6f136c39e07628d2dcf05e Mon Sep 17 00:00:00 2001
From: Renat Idrisov
Date: Mon, 27 Jan 2025 03:47:58 +0000
Subject: [PATCH 2/4] Fixing the build the other way around

---
 include/tvm/relax/attrs/nn.h                  |  2 --
 python/tvm/relax/transform/legalize_ops/nn.py | 19 +------------------
 src/relax/op/nn/attention.cc                  | 11 ++++++++---
 3 files changed, 9 insertions(+), 23 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 1f5f5017c1df..e26cee26584b 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -546,13 +546,11 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
 
 /*! \brief Attributes used in Attention operator */
 struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
-  Optional bias;
   Optional<FloatImm> scale;
   Optional<String> causal_mask;
   Optional<IntImm> window_size;
 
   TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
-    TVM_ATTR_FIELD(bias).describe("The input bias tensor.");
     TVM_ATTR_FIELD(scale).describe(
         "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
     TVM_ATTR_FIELD(causal_mask)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 8317d4504e1e..cba805d095af 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -648,23 +648,6 @@ def _te_attention(
 
 @register_legalize("relax.nn.attention")
 def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
-    assert (
-        call.attrs.window_size is None
-    ), "Legalization for sliding-window attention is not supported yet."
-    return bb.call_te(
-        _te_attention,
-        call.args[0],
-        call.args[1],
-        call.args[2],
-        None,
-        call.attrs.scale,
-        call.attrs.causal_mask,
-        primfunc_name_hint="attention",
-    )
-
-
-@register_legalize("relax.nn.attention_bias")
-def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
     assert (
         call.attrs.window_size is None
     ), "Legalization for sliding-window attention is not supported yet."
@@ -676,7 +659,7 @@ def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
         call.args[3],
         call.attrs.scale,
         call.attrs.causal_mask,
-        primfunc_name_hint="attention_bias",
+        primfunc_name_hint="attention",
     )
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 436b1a7d2582..8d0f597d73c8 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -34,10 +34,14 @@ Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
   attrs->window_size = window_size;
-  attrs->bias = bias;
+
+  if (bias) {
+    return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value), std::move(bias.value())},
+                Attrs(attrs), {});
+  }
 
   return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
-              Attrs(attrs), {});
+      Attrs(attrs), {});
 }
 
 Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr seqstart_k,
@@ -139,10 +143,11 @@ Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) {
 TVM_REGISTER_OP("relax.nn.attention")
     .set_attrs_type<AttentionAttrs>()
-    .set_num_inputs(3)
+    .set_num_inputs(4)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
     .add_argument("value", "Tensor", "The input values tensor.")
+    .add_argument("bias", "Tensor", "The input bias tensor.")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
     .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionAttention)
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention)
     .set_attr<Bool>("FPurity", Bool(true));

From af1a82c14fd5f1936fd0d5e23d9474aa81587c3e Mon Sep 17 00:00:00 2001
From: Renat Idrisov
Date: Mon, 27 Jan 2025 04:43:01 +0000
Subject: [PATCH 3/4] Fixing the build the other way around

---
 include/tvm/relax/attrs/nn.h                  |  2 ++
 python/tvm/relax/transform/legalize_ops/nn.py |  2 +-
 src/relax/op/nn/attention.cc                  | 13 ++++---------
 src/relax/op/nn/attention.h                   |  2 +-
 4 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index e26cee26584b..bc13d686e763 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -546,11 +546,13 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
 
 /*! \brief Attributes used in Attention operator */
 struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  Optional bias;
   Optional<FloatImm> scale;
   Optional<String> causal_mask;
   Optional<IntImm> window_size;
 
   TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+    TVM_ATTR_FIELD(bias).describe("The input bias tensor.");
     TVM_ATTR_FIELD(scale).describe(
         "The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
     TVM_ATTR_FIELD(causal_mask)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index cba805d095af..a4eb5c28f0ab 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -656,7 +656,7 @@ def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
         call.args[0],
         call.args[1],
         call.args[2],
-        call.args[3],
+        call.attrs.bias,
         call.attrs.scale,
         call.attrs.causal_mask,
         primfunc_name_hint="attention",
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 8d0f597d73c8..8b21acdac391 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -28,20 +28,16 @@ namespace relax {
 
 /* relax.nn.attention */
 TVM_REGISTER_NODE_TYPE(AttentionAttrs);
 
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
+Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional<FloatImm> scale,
                Optional<String> causal_mask, Optional<IntImm> window_size) {
   ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
   attrs->window_size = window_size;
-
-  if (bias) {
-    return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value), std::move(bias.value())},
-                Attrs(attrs), {});
-  }
+  attrs->bias = bias;
 
   return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
-      Attrs(attrs), {});
+              Attrs(attrs), {});
 }
@@ -143,11 +139,10 @@ Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) {
 
 TVM_REGISTER_OP("relax.nn.attention")
     .set_attrs_type<AttentionAttrs>()
-    .set_num_inputs(4)
+    .set_num_inputs(3)
    .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
     .add_argument("value", "Tensor", "The input values tensor.")
-    .add_argument("bias", "Tensor", "The input bias tensor.")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
     .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionAttention)
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention)
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 346907f8e938..16b266417591 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -33,7 +33,7 @@ namespace tvm {
 namespace relax {
 
 /*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
+Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional<FloatImm> scale,
                Optional<String> causal_mask, Optional<IntImm> window_size);
 
 }  // namespace relax

From 116cb1775d0f625572f7deb676452391a57d15d2 Mon Sep 17 00:00:00 2001
From: Renat Idrisov
Date: Mon, 27 Jan 2025 04:47:32 +0000
Subject: [PATCH 4/4] Fixing the build

---
 include/tvm/relax/attrs/nn.h | 2 +-
 src/relax/op/nn/attention.cc | 2 +-
 src/relax/op/nn/attention.h  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index bc13d686e763..1f5f5017c1df 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -546,7 +546,7 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
 
 /*! \brief Attributes used in Attention operator */
 struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
-  Optional bias;
+  Optional bias;
   Optional<FloatImm> scale;
   Optional<String> causal_mask;
   Optional<IntImm> window_size;
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index 8b21acdac391..436b1a7d2582 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -28,7 +28,7 @@ namespace relax {
 /* relax.nn.attention */
 TVM_REGISTER_NODE_TYPE(AttentionAttrs);
 
-Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional<FloatImm> scale,
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
                Optional<String> causal_mask, Optional<IntImm> window_size) {
   ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
   attrs->scale = scale;
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 16b266417591..346907f8e938 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -33,7 +33,7 @@ namespace tvm {
 namespace relax {
 
 /*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional bias, Optional<FloatImm> scale,
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale,
                Optional<String> causal_mask, Optional<IntImm> window_size);
 
 }  // namespace relax
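
For context, the following is a minimal usage sketch, not part of the patch series, of the behaviour these four patches converge on: a single "relax.nn.attention" operator whose optional bias is carried in AttentionAttrs, so the bias and no-bias forms both lower through the one remaining _nn_attention legalization. It assumes the public relax.op.nn.attention Python binding and the LegalizeOps pass are otherwise unchanged by this series; the module name, tensor shapes, and dtypes are illustrative only.

# Hypothetical sketch against a TVM build with these patches applied.
# Shapes follow the (batch, seq_len, num_heads, head_dim) layout the operator documents.
import tvm
from tvm import relax
from tvm.script import relax as R


@tvm.script.ir_module
class AttentionWithBias:  # illustrative module, not taken from the patches
    @R.function
    def main(
        q: R.Tensor((4, 16, 32, 8), "float32"),     # (batch, seq_len,    num_heads, head_dim)
        k: R.Tensor((4, 8, 32, 8), "float32"),      # (batch, seq_len_kv, num_heads, head_dim)
        v: R.Tensor((4, 8, 32, 16), "float32"),     # (batch, seq_len_kv, num_heads, head_dim_v)
        bias: R.Tensor((4, 32, 16, 8), "float32"),  # (batch, num_heads,  seq_len,   seq_len_kv)
    ) -> R.Tensor((4, 16, 32, 16), "float32"):
        # After this series, passing a bias is expected to emit a single
        # relax.nn.attention call rather than a separate relax.nn.attention_bias op.
        return R.nn.attention(q, k, v, bias)


# Lowering should route both the bias and no-bias forms through the same
# "relax.nn.attention" legalization path.
lowered = relax.transform.LegalizeOps()(AttentionWithBias)
print(lowered)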