diff --git a/README.md b/README.md index 582f1154a..dbe36b38d 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,8 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func ``` ```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)): +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), alibi_slopes=None): """dropout_p should be set to 0.0 during evaluation If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation @@ -96,13 +97,16 @@ Arguments: Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to + the attention score of query i and key j. Return: out: (batch_size, seqlen, nheads, headdim). """ ``` ```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)): +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, + window_size=(-1, -1), alibi_slopes=None): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. @@ -121,6 +125,9 @@ Arguments: Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. Return: out: (batch_size, seqlen, nheads, headdim). """ @@ -141,6 +148,7 @@ def flash_attn_with_kvcache( causal=False, window_size=(-1, -1), # -1 means infinite context window rotary_interleaved=True, + alibi_slopes=None, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -183,10 +191,9 @@ def flash_attn_with_kvcache( If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 (i.e. GPT-NeoX style). - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. Return: out: (batch_size, seqlen, nheads, headdim). @@ -262,6 +269,10 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral AI](https://mistral.ai/) and in particular Timothée Lacroix for this contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model. +### 2.4: ALiBi (attention with linear bias) + +Implement ALiBi (Press et el., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution. + ## Performance We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).