"""FlashAttention w/support for learned sinks and banded attention. This is an expanded version of the Flash Attention v2 implementation (see https://tridao.me/publications/flash2/flash2.pdf) which can be found at https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html. This version has been extended to support banded attention and learned attention sinks. """ import pytest import torch import triton import triton.language as tl @triton.jit def _attn_fwd( Q, K, V, Sinks, sm_scale, M, Out, # Start_q, stride_qz, stride_qh, stride_qm, stride_qk, # stride_kz, stride_kh, stride_kn, stride_kk, # stride_vz, stride_vh, stride_vn, stride_vk, # stride_oz, stride_oh, stride_om, stride_ok, # Z, H, N_Q_CTX, N_KV_CTX, HEAD_DIM: tl.constexpr, # BLOCK_M: tl.constexpr, # BLOCK_N: tl.constexpr, # BANDWIDTH: tl.constexpr, ): tl.static_assert(BLOCK_N <= HEAD_DIM) start_q = tl.load(Start_q).to(tl.int32) start_m = tl.program_id(0) off_hz = tl.program_id(1) off_z = off_hz // H off_h = off_hz % H q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh # block pointers Q_block_ptr = tl.make_block_ptr( base=Q + q_offset, shape=(N_Q_CTX, HEAD_DIM), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) V_block_ptr = tl.make_block_ptr( base=V + v_offset, shape=(N_KV_CTX, HEAD_DIM), strides=(stride_vn, stride_vk), offsets=(0, 0), block_shape=(BLOCK_N, HEAD_DIM), order=(1, 0), ) K_block_ptr = tl.make_block_ptr( base=K + k_offset, shape=(HEAD_DIM, N_KV_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(HEAD_DIM, BLOCK_N), order=(0, 1), ) O_block_ptr = tl.make_block_ptr( base=Out + o_offset, shape=(N_Q_CTX, HEAD_DIM), strides=(stride_om, stride_ok), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, HEAD_DIM), order=(1, 0), ) # load attention sinks if Sinks is not None: sink = tl.load(Sinks + off_h).to(tl.float32) else: sink = 0 # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) # load scales qk_scale = sm_scale q = tl.load(Q_block_ptr) if BANDWIDTH: lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), (start_q + start_m + 1) * BLOCK_M else: lo, hi = start_q, (start_q + start_m + 1) * BLOCK_M # advance the KV block-pointers so they point at `lo` K_block_ptr = tl.advance(K_block_ptr, (0, lo)) V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) for start_n in range(lo, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) mask = (start_n + offs_n)[None, :] > (start_q + offs_m)[:, None] if BANDWIDTH: too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1) mask = mask | too_old k = tl.load(K_block_ptr) qk = tl.dot(q, k, allow_tf32=False) qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] p = tl.math.exp(qk) alpha = tl.math.exp(m_i - m_ij) l_ij = tl.sum(p, 1) acc = acc * alpha[:, None] v = tl.load(V_block_ptr).to(tl.float32) acc = tl.dot(p, v, acc, allow_tf32=False) l_i = l_i * alpha + l_ij m_i = m_ij V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) 
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))

    # epilogue: the sink behaves like an extra key with logit `sink` and a zero
    # value vector, so it adds mass to the softmax normalizer but contributes
    # nothing to the output
    sink = tl.math.exp(sink - m_i)
    z = l_i + sink
    acc = acc / z[:, None]
    m_i += tl.math.log(l_i)  # log-sum-exp of the key logits, kept for a backward pass
    m_ptrs = M + off_hz * N_Q_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q):
        assert len(start_q) == 1
        bs, n_ctx, n_kv_heads, repeat_kv, HEAD_DIM_Q = q.shape
        bs, n_kv_ctx, n_kv_heads, HEAD_DIM_K = k.shape
        bs, n_kv_ctx, n_kv_heads, HEAD_DIM_V = v.shape
        n_heads = n_kv_heads * repeat_kv
        q = q.view(bs, n_ctx, n_heads, HEAD_DIM_Q)
        k = k.view(bs, n_kv_ctx, n_kv_heads, HEAD_DIM_K)
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        # move heads in front of the sequence dimension and expand grouped KV heads
        # so that every query head gets its own key/value stream
        q = q.transpose(1, 2).contiguous()
        k = k.repeat_interleave(repeat_kv, dim=2).transpose(1, 2).contiguous()
        v = v.repeat_interleave(repeat_kv, dim=2).transpose(1, 2).contiguous()

        BLOCK_M = 64
        BLOCK_N = 64

        m_pad_size = BLOCK_M - n_ctx % BLOCK_M if n_ctx % BLOCK_M != 0 else 0
        # pad q to a multiple of its block size in the n_ctx dimension (-2)
        q = torch.nn.functional.pad(q, (0, 0, 0, m_pad_size))

        n_pad_size = BLOCK_N - n_kv_ctx % BLOCK_N if n_kv_ctx % BLOCK_N != 0 else 0
        # pad k and v to a multiple of their block size in the n_kv_ctx dimension
        k = torch.nn.functional.pad(k, (0, 0, 0, n_pad_size))
        v = torch.nn.functional.pad(v, (0, 0, 0, n_pad_size))

        o = torch.empty_like(q)
        M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32)

        grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1)
        _attn_fwd[grid](
            q, k, v, sinks, sm_scale, M, o,  #
            start_q,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_Q_CTX=n_ctx + m_pad_size,  #
            N_KV_CTX=n_kv_ctx,  #
            HEAD_DIM=HEAD_DIM_K,  #
            BANDWIDTH=bandwidth,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )

        ctx.save_for_backward(q, k, v, sinks, o, M, start_q)
        ctx.sm_scale = sm_scale
        ctx.bandwidth = bandwidth

        # drop the padding rows and restore the (batch, n_ctx, n_heads * head_dim) layout
        o = o[:, :, :n_ctx, :].transpose(1, 2).contiguous()
        o = o.view(bs, n_ctx, n_heads * HEAD_DIM_V)
        return o


attention = _attention.apply


# Pure-PyTorch reference implementation used to validate the Triton kernel.
def attention_ref(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    sinks: torch.Tensor,
    sm_scale: float = 0.125,
    sliding_window: int | None = None,
    start_q: torch.LongTensor = 0,
):
    batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
    batch_size, num_keys, num_key_value_heads, head_dim = key.shape

    sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
    key = key.unsqueeze(3)
    value = value.unsqueeze(3)

    # queries start at absolute position start_q; keys always start at position 0
    pos_keys = torch.arange(num_keys, device=query.device)
    pos_queries = torch.arange(num_queries, device=query.device) + start_q
    mask = pos_keys[None, :] > pos_queries[:, None]
    mask = mask.float().masked_fill(mask, float("-inf"))

    if sliding_window:
        too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
        mask.masked_fill_(too_old, float("-inf"))

    logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
    logits = logits + mask[None, None, None, :, :]

    # the sink competes with the keys for probability mass but contributes no value
    logits_max = torch.max(logits, dim=-1, keepdim=True).values
    logits_or_sinks_max = torch.maximum(sinks, logits_max)
    sinks = torch.exp(sinks - logits_or_sinks_max)
    unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
    scores = unnormalized_scores / normalizer

    output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
    output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim).bfloat16()
    return output


@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("num_queries", [1, 128])
@pytest.mark.parametrize("num_keys", [128, 32])
@pytest.mark.parametrize("num_key_value_heads", [8])
@pytest.mark.parametrize("num_key_value_groups", [8])
@pytest.mark.parametrize("head_dim", [64])
@pytest.mark.parametrize("sm_scale", [0.125])
@pytest.mark.parametrize("sliding_window", [None, 128])
@pytest.mark.parametrize("start_q", [0, 5])
def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale,
            sliding_window, start_q):
    if num_queries > num_keys:
        pytest.skip("too many queries")
    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda()
    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda()
    start_q = torch.tensor([start_q], dtype=torch.int32).cuda()

    o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
    o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)

    torch.testing.assert_close(o1, o2)
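
# Minimal usage sketch (illustrative, not part of the test suite): the shapes, dtypes
# and arguments simply mirror `test_eq` above. q is (batch, n_ctx, n_kv_heads,
# n_kv_groups, head_dim), k/v are (batch, n_kv_ctx, n_kv_heads, head_dim), sinks holds
# one learned logit per query head, and start_q is a length-1 int32 tensor giving the
# absolute position of the first query. A CUDA device is assumed; bandwidth=None
# requests dense causal attention, an integer requests banded attention.
if __name__ == "__main__":
    torch.manual_seed(0)
    bs, n_ctx, n_kv_heads, n_kv_groups, head_dim = 1, 128, 8, 8, 64
    q = torch.randn(bs, n_ctx, n_kv_heads, n_kv_groups, head_dim).bfloat16().cuda()
    k = torch.randn(bs, n_ctx, n_kv_heads, head_dim).bfloat16().cuda()
    v = torch.randn(bs, n_ctx, n_kv_heads, head_dim).bfloat16().cuda()
    sinks = torch.randn(n_kv_heads * n_kv_groups).bfloat16().cuda()
    start_q = torch.tensor([0], dtype=torch.int32).cuda()

    out = attention(q, k, v, sinks, 0.125, 128, start_q)       # banded attention, bandwidth 128
    ref = attention_ref(q, k, v, sinks, 0.125, 128, start_q)   # reference for comparison
    print(out.shape, (out - ref).abs().max().item())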