From b0bf0fc6eeff3ac3f34d2dfc2d9837d8e1dd8df3 Mon Sep 17 00:00:00 2001
From: taylor
Date: Tue, 3 Jun 2025 17:26:24 +0800
Subject: [PATCH 1/3] Resolving inconsistency between attention/attention_bias

---
 python/tvm/relax/op/nn/__init__.py |  1 +
 python/tvm/relax/op/nn/nn.py       | 95 ++++++++++++++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 61212f33d882..5ccc38d61612 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -20,6 +20,7 @@
     adaptive_avg_pool2d,
     adaptive_avg_pool3d,
     attention,
+    attention_bias,
     attention_var_len,
     avg_pool1d,
     avg_pool2d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 62d8b84321ce..9afb3bff6414 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1836,6 +1836,101 @@ def attention(
         query, key, value, bias, scale, causal_mask, window_size
     )  # type: ignore
 
+def attention_bias(
+    query: Expr,
+    key: Expr,
+    value: Expr,
+    bias: Optional[Expr] = None,
+    scale: Optional[FloatImm] = None,
+    causal_mask: Optional[str] = None,
+    window_size: Optional[int] = None,
+) -> Expr:
+    r"""Computes fused multi head attention.
+
+    IRModule.script() transforms attention op to attention_bias which is incompatible with TVMScript Parser
+    The function makes TVMScript's print compatible with TVMScript's parser.
+
+    All input tensors are of 4-D tensors with BSNH layout.
+
+    .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V
+
+    .. note::
+        The input tensor is required to have float16 dtype
+
+    Parameters
+    ----------
+    query: relax.Expr
+        The input query to the operator. The layout of the input query should be
+        (batch_size, seq_len, num_head, head_dim).
+
+    key: relax.Expr
+        The input key to the operator. The layout of the input key should be
+        (batch_size, seq_len_kv, num_head, head_dim).
+
+    value: relax.Expr
+        The input value to the operator. The layout of the input value should be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+    bias: Optional[Expr]
+        The optional attention bias to the operator. The layout of the attention bias should be
+        a 4-D tensor ending with seq_len_kv, and broadcastable to
+        (batch_size, num_head, seq_len, seq_len_kv).
+
+    scale: Optional[float]
+        The scale value to be applied to the attention score, by default 1 / sqrt(head_dim).
+
+    causal_mask: Optional[str]
+        The optional causal mask, i.e. 'TopLeft' and 'BottomRight'.
+        For 'TopLeft', the mask matrix is as `np.tril(*, k=0)`,
+        while for 'BottomRight', the mask matrix is as `np.tril(*, k=abs(seq_len - seq_len_kv))`
+        For example, with seq_len = 4, seq_len_kv = 2,
+        mask for 'TopLeft':
+
+        .. code:: python
+
+            [[1, 0],
+             [1, 1],
+             [1, 1],
+             [1, 1]]
+
+        mask for 'BottomRight':
+
+        .. code:: python
+
+            [[1, 1],
+             [1, 1],
+             [1, 1],
+             [1, 1]]
+
+        with seq_len = 2, seq_len_kv = 4,
+        mask for 'TopLeft':
+
+        .. code:: python
+
+            [[1, 0, 0, 0],
+             [1, 1, 0, 0]]
+
+        mask for 'BottomRight':
+
+        .. code:: python
+
+            [[1, 1, 1, 0],
+             [1, 1, 1, 1]]
+
+    window_size: Optional[int]
+        The size of the window for sliding-window attention.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result. The layout of the output should be
+        (batch_size, seq_len, num_head, head_dim_v).
+    """
+    return _ffi_api.attention(
+        query, key, value, bias, scale, causal_mask, window_size
+    )  # type: ignore
+
 
 def attention_var_len(
     queries: Expr,

From 4f7b69a56a990d7a9a6b3b7852f6356e77c9da88 Mon Sep 17 00:00:00 2001
From: taylor
Date: Wed, 4 Jun 2025 10:05:17 +0800
Subject: [PATCH 2/3] reformat

---
 python/tvm/relax/op/nn/nn.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9afb3bff6414..a305fd46df5a 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1836,6 +1836,7 @@ def attention(
         query, key, value, bias, scale, causal_mask, window_size
     )  # type: ignore
 
+
 def attention_bias(
     query: Expr,
     key: Expr,

From c6120abd2d319aff404f75b947f6cf69ea5eaf5b Mon Sep 17 00:00:00 2001
From: taylor
Date: Wed, 4 Jun 2025 14:52:24 +0800
Subject: [PATCH 3/3] reduce the length of line

---
 python/tvm/relax/op/nn/nn.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index a305fd46df5a..aef77b7c63af 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1848,7 +1848,8 @@ def attention_bias(
 ) -> Expr:
     r"""Computes fused multi head attention.
 
-    IRModule.script() transforms attention op to attention_bias which is incompatible with TVMScript Parser
+    IRModule.script() transforms attention op to attention_bias which is incompatible
+    with TVMScript Parser.
     The function makes TVMScript's print compatible with TVMScript's parser.
 
     All input tensors are of 4-D tensors with BSNH layout.
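
A minimal usage sketch of the new wrapper follows. Only `attention_bias` itself (re-exported from `tvm.relax.op.nn` by the first patch) and its argument order come from this series; the shapes, variable names, and the BlockBuilder scaffolding below are illustrative assumptions, not part of the change.

.. code:: python

    from tvm import relax
    from tvm.relax.op.nn import attention_bias

    # Illustrative shapes (assumption): batch=4, seq_len=16, seq_len_kv=8,
    # num_head=8, head_dim=32, head_dim_v=64, BSNH layout, float16 dtype.
    q = relax.Var("q", relax.TensorStructInfo([4, 16, 8, 32], "float16"))
    k = relax.Var("k", relax.TensorStructInfo([4, 8, 8, 32], "float16"))
    v = relax.Var("v", relax.TensorStructInfo([4, 8, 8, 64], "float16"))
    # Bias must be broadcastable to (batch_size, num_head, seq_len, seq_len_kv).
    b = relax.Var("b", relax.TensorStructInfo([4, 8, 16, 8], "float16"))

    bb = relax.BlockBuilder()
    with bb.function("main", [q, k, v, b]):
        # attention_bias forwards to the same _ffi_api.attention call as
        # relax.op.nn.attention, so it builds the usual attention op node.
        out = bb.emit(attention_bias(q, k, v, b, causal_mask="TopLeft"))
        bb.emit_func_output(out)

    mod = bb.get()
    # With a bias present, the printer emits R.nn.attention_bias; the new
    # wrapper gives that name a Python counterpart so the printed script can
    # be parsed back, which is the inconsistency this series resolves.
    print(mod.script())

The wrapper simply aliases the existing FFI entry point rather than registering a new operator, so runtime behaviour is unchanged; only the printer/parser mismatch is closed.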