diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 61212f33d882..5ccc38d61612 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -20,6 +20,7 @@
     adaptive_avg_pool2d,
     adaptive_avg_pool3d,
     attention,
+    attention_bias,
     attention_var_len,
     avg_pool1d,
     avg_pool2d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 62d8b84321ce..aef77b7c63af 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1837,6 +1837,103 @@ def attention(
     )  # type: ignore
 
 
+def attention_bias(
+    query: Expr,
+    key: Expr,
+    value: Expr,
+    bias: Optional[Expr] = None,
+    scale: Optional[FloatImm] = None,
+    causal_mask: Optional[str] = None,
+    window_size: Optional[int] = None,
+) -> Expr:
+    r"""Computes fused multi-head attention.
+
+    IRModule.script() prints the attention op as attention_bias, which would otherwise
+    be incompatible with the TVMScript parser.
+    This function makes TVMScript's printed output parsable by the TVMScript parser.
+
+    All input tensors are 4-D tensors with BSNH layout.
+
+    .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V
+
+    .. note::
+        The input tensors are required to have float16 dtype.
+
+    Parameters
+    ----------
+    query: relax.Expr
+        The input query to the operator. The layout of the input query should be
+        (batch_size, seq_len, num_head, head_dim).
+
+    key: relax.Expr
+        The input key to the operator. The layout of the input key should be
+        (batch_size, seq_len_kv, num_head, head_dim).
+
+    value: relax.Expr
+        The input value to the operator. The layout of the input value should be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+    bias: Optional[Expr]
+        The optional attention bias to the operator. The attention bias should be
+        a 4-D tensor whose last dimension is seq_len_kv, broadcastable to
+        (batch_size, num_head, seq_len, seq_len_kv).
+
+    scale: Optional[float]
+        The scale value to be applied to the attention score, by default 1 / sqrt(head_dim).
+
+    causal_mask: Optional[str]
+        The optional causal mask, either 'TopLeft' or 'BottomRight'.
+        For 'TopLeft', the mask matrix is as `np.tril(*, k=0)`,
+        while for 'BottomRight', the mask matrix is as `np.tril(*, k=abs(seq_len - seq_len_kv))`.
+        For example, with seq_len = 4, seq_len_kv = 2,
+        mask for 'TopLeft':
+
+        .. code:: python
+
+            [[1, 0],
+             [1, 1],
+             [1, 1],
+             [1, 1]]
+
+        mask for 'BottomRight':
+
+        .. code:: python
+
+            [[1, 1],
+             [1, 1],
+             [1, 1],
+             [1, 1]]
+
+        with seq_len = 2, seq_len_kv = 4,
+        mask for 'TopLeft':
+
+        .. code:: python
+
+            [[1, 0, 0, 0],
+             [1, 1, 0, 0]]
+
+        mask for 'BottomRight':
+
+        .. code:: python
+
+            [[1, 1, 1, 0],
+             [1, 1, 1, 1]]
+
+    window_size: Optional[int]
+        The size of the window for sliding-window attention.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result. The layout of the output should be
+        (batch_size, seq_len, num_head, head_dim_v).
+    """
+    return _ffi_api.attention(
+        query, key, value, bias, scale, causal_mask, window_size
+    )  # type: ignore
+
+
 def attention_var_len(
     queries: Expr,
     keys: Expr,
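
Below is a minimal, hypothetical round-trip sketch (not part of the patch) showing what this change enables. It assumes the standard relax.BlockBuilder API and tvm.script.from_source; the tensor shapes and variable names are illustrative only:

    from tvm import relax
    from tvm.relax.op.nn import attention
    from tvm.script import from_source

    bb = relax.BlockBuilder()
    # BSNH layouts: (batch, seq_len, num_head, head_dim) for q,
    # (batch, seq_len_kv, num_head, head_dim) for k and v.
    q = relax.Var("q", relax.TensorStructInfo((4, 16, 8, 32), "float16"))
    k = relax.Var("k", relax.TensorStructInfo((4, 8, 8, 32), "float16"))
    v = relax.Var("v", relax.TensorStructInfo((4, 8, 8, 32), "float16"))
    # Bias broadcastable to (batch, num_head, seq_len, seq_len_kv).
    bias = relax.Var("bias", relax.TensorStructInfo((4, 8, 16, 8), "float16"))

    with bb.function("main", [q, k, v, bias]):
        out = bb.emit(attention(q, k, v, bias=bias))
        bb.emit_func_output(out)

    mod = bb.get()
    script = mod.script()         # per the docstring above, the printer may render this op as attention_bias
    parsed = from_source(script)  # with this patch, the parser can resolve R.nn.attention_bias

Because attention_bias forwards to the same _ffi_api.attention entry point, the parsed module should contain the same underlying attention call as the original.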