Initial commit
Co-authored-by: Zhuohan Li <zhuohan@openai.com> Co-authored-by: Maratyszcza <marat@openai.com> Co-authored-by: Volodymyr Kyrylov <vol@wilab.org.ua>
This commit is contained in:
302
gpt_oss/triton/attention.py
Normal file
302
gpt_oss/triton/attention.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""FlashAttention w/support for learned sinks and banded attention.
|
||||
|
||||
This is an expanded version of the Flash Attention v2 implementation
(paper: https://tridao.me/publications/flash2/flash2.pdf) from the Triton fused-attention tutorial at
https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html.
|
||||
|
||||
This version has been extended to support banded attention and learned attention sinks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    Sinks,
    sm_scale,
    M,
    Out,  #
    Start_q,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,  #
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,  #
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vk,  #
    stride_oz,
    stride_oh,
    stride_om,
    stride_ok,  #
    Z,
    H,
    N_Q_CTX,
    N_KV_CTX,
    HEAD_DIM: tl.constexpr,  #
    BLOCK_M: tl.constexpr,  #
    BLOCK_N: tl.constexpr,  #
    BANDWIDTH: tl.constexpr,
):
    """Causal attention forward pass with an optional sliding window and per-head sink.

    Program axis 0 selects a tile of BLOCK_M query rows, program axis 1 selects a
    flattened (batch, head) pair (``off_hz // H`` / ``off_hz % H``).

    Args:
        Q, K, V, Out: base pointers; each is addressed as
            ``base + batch * stride_*z + head * stride_*h`` and then as a 2-D
            (rows, HEAD_DIM) matrix using the remaining two strides.
        Sinks: pointer to one sink logit per head (indexed by head), or None.
        sm_scale: multiplier applied to the raw Q.K^T scores.
        M: output pointer, receives ``m_i + log(l_i)`` per query row at
            ``off_hz * N_Q_CTX + row``.
        Start_q: pointer to a single integer; query row ``i`` is treated as
            absolute position ``start_q + i`` while key column ``j`` is position ``j``.
        Z: not read by the kernel body.
        H: number of heads per batch entry.
        N_Q_CTX: number of query rows per (batch, head) as laid out in memory.
        N_KV_CTX: number of valid key/value rows; columns at or beyond it are masked.
        HEAD_DIM, BLOCK_M, BLOCK_N: compile-time tile sizes.
        BANDWIDTH: compile-time sliding-window width; a falsy value disables the window.

    Loads carry no boundary check, so the caller must make the Q rows a multiple
    of BLOCK_M and the K/V rows a multiple of BLOCK_N (zero padding is fine).
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_q = tl.load(Start_q).to(tl.int32)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # 64-bit offsets so large batch/head strides cannot overflow int32.
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(N_Q_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(N_KV_CTX, HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=(1, 0),
    )
    # K is addressed transposed, (HEAD_DIM, N_KV_CTX), so tl.dot(q, k) yields Q.K^T.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(HEAD_DIM, N_KV_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(N_Q_CTX, HEAD_DIM),
        strides=(stride_om, stride_ok),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )

    # load attention sinks
    # NOTE(review): with Sinks=None the sink logit is 0 (not -inf), so it still
    # adds exp(0 - m) to the normaliser below -- confirm this is intended.
    if Sinks is not None:
        sink = tl.load(Sinks + off_h).to(tl.float32)
    else:
        sink = 0

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # absolute position of every query row handled by this program
    q_pos = start_q + offs_m
    # running max starts at the sink logit so the sink takes part in the softmax max
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    q = tl.load(Q_block_ptr)

    # Key range [lo, hi) that can be visible to this query tile.  Keys are numbered
    # from 0, so without a window the scan starts at 0 (not at start_q); with a
    # window it starts at the oldest key the first query row may still see,
    # rounded down to a BLOCK_N boundary so every tile stays block-aligned.
    if BANDWIDTH:
        lo = tl.maximum(start_q + start_m * BLOCK_M - BANDWIDTH, 0)
        lo = (lo // BLOCK_N) * BLOCK_N
    else:
        lo = 0
    # The last query row sits at start_q + (start_m + 1) * BLOCK_M - 1; never walk
    # past the valid keys either, which keeps all loads inside the padded buffers.
    hi = tl.minimum(start_q + (start_m + 1) * BLOCK_M, N_KV_CTX)

    # advance the KV block-pointers so they point at `lo`
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        kv_pos = start_n + offs_n

        # causal mask: a key is hidden when it lies after the query position
        mask = kv_pos[None, :] > q_pos[:, None]
        # hide padding columns beyond the valid key count
        mask = mask | (kv_pos[None, :] >= N_KV_CTX)

        if BANDWIDTH:
            too_old = kv_pos[None, :] < (q_pos[:, None] - BANDWIDTH + 1)
            mask = mask | too_old

        k = tl.load(K_block_ptr)
        qk = tl.dot(q, k, allow_tf32=False)

        # additive large-negative bias instead of -inf keeps exp() free of NaNs
        qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]

        p = tl.math.exp(qk)
        # rescale previously accumulated sums to the new running max
        alpha = tl.math.exp(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        acc = acc * alpha[:, None]

        v = tl.load(V_block_ptr).to(tl.float32)
        acc = tl.dot(p, v, acc, allow_tf32=False)

        l_i = l_i * alpha + l_ij
        m_i = m_ij

        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))

    # the sink only enlarges the normaliser; it contributes no value vector
    sink = tl.math.exp(sink - m_i)
    z = l_i + sink
    acc = acc / z[:, None]
    # value stored in M is computed from l_i alone, i.e. without the sink term
    m_i += tl.math.log(l_i)
    m_ptrs = M + off_hz * N_Q_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the `_attn_fwd` Triton kernel (forward only)."""

    @staticmethod
    def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q):
        """Run attention for grouped queries.

        Args:
            q: (batch, n_ctx, n_kv_heads, repeat_kv, head_dim) queries.
            k: (batch, n_kv_ctx, n_kv_heads, head_dim) keys.
            v: (batch, n_kv_ctx, n_kv_heads, head_dim) values.
            sinks: forwarded unchanged to the kernel's `Sinks` argument.
            sm_scale: multiplier for the attention scores.
            bandwidth: forwarded as the kernel's `BANDWIDTH` constant.
            start_q: length-1 container forwarded as the kernel's `Start_q`.

        Returns:
            Tensor of shape (batch, n_ctx, n_kv_heads * repeat_kv * head_dim).
        """
        assert len(start_q) == 1
        batch, q_len, kv_heads, groups, dim_q = q.shape
        batch, kv_len, kv_heads, dim_k = k.shape
        batch, kv_len, kv_heads, dim_v = v.shape
        heads = kv_heads * groups
        # fold the (kv_head, group) axes into a single head axis
        q = q.view(batch, q_len, heads, dim_q)
        assert dim_q == dim_k and dim_k == dim_v
        assert dim_k in {16, 32, 64, 128, 256}

        # kernel layout is (batch, head, seq, dim); K/V are replicated once per group
        q = q.transpose(1, 2).contiguous()
        k = k.repeat_interleave(groups, dim=2).transpose(1, 2).contiguous()
        v = v.repeat_interleave(groups, dim=2).transpose(1, 2).contiguous()

        BLOCK_M = 64
        BLOCK_N = 64
        # rows needed to reach the next multiple of the tile size (0 if already aligned)
        q_pad = -q_len % BLOCK_M
        kv_pad = -kv_len % BLOCK_N
        # pad along the sequence dimension (-2) only
        q = torch.nn.functional.pad(q, (0, 0, 0, q_pad))
        k = torch.nn.functional.pad(k, (0, 0, 0, kv_pad))
        v = torch.nn.functional.pad(v, (0, 0, 0, kv_pad))

        padded_q_len = q_len + q_pad
        o = torch.empty_like(q)
        M = torch.empty((batch, heads, padded_q_len), device=q.device, dtype=torch.float32)
        grid = (triton.cdiv(q_len, BLOCK_M), batch * heads, 1)
        _attn_fwd[grid](
            q,
            k,
            v,
            sinks,
            sm_scale,
            M,
            o,
            start_q,
            *q.stride(),
            *k.stride(),
            *v.stride(),
            *o.stride(),
            q.shape[0],
            q.shape[1],
            N_Q_CTX=padded_q_len,
            N_KV_CTX=kv_len,
            HEAD_DIM=dim_k,
            BANDWIDTH=bandwidth,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )

        ctx.save_for_backward(q, k, v, sinks, o, M, start_q)
        ctx.sm_scale = sm_scale
        ctx.bandwidth = bandwidth

        # drop the padded query rows and return to (batch, seq, heads * dim)
        trimmed = o[:, :, :q_len, :].transpose(1, 2).contiguous()
        return trimmed.view(batch, q_len, heads * dim_v)


attention = _attention.apply
|
||||
|
||||
|
||||
def attention_ref(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
sinks: torch.Tensor,
|
||||
sm_scale: float = 0.125,
|
||||
sliding_window: int | None = None,
|
||||
start_q: torch.LongTensor = 0,
|
||||
):
|
||||
batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim = query.shape
|
||||
batch_size, num_keys, num_key_value_heads, head_dim = key.shape
|
||||
|
||||
sinks = sinks.view(1, num_key_value_heads, num_key_value_groups, 1, 1).float()
|
||||
key = key.unsqueeze(3)
|
||||
value = value.unsqueeze(3)
|
||||
|
||||
pos_keys = torch.arange(num_keys, device=query.device)
|
||||
pos_queries = torch.arange(num_queries, device=query.device) + start_q
|
||||
mask = pos_keys[None, :] > pos_queries[:, None]
|
||||
mask = mask.float().masked_fill(mask, float("-inf"))
|
||||
|
||||
if sliding_window:
|
||||
too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
|
||||
mask.masked_fill_(too_old, float("-inf"))
|
||||
|
||||
logits = torch.einsum("bqhmd,bkhmd->bhmqk", query.float(), key.float()) * sm_scale
|
||||
logits = logits + mask[None, None, None, :, :]
|
||||
|
||||
logits_max = torch.max(logits, dim=-1, keepdim=True).values
|
||||
logits_or_sinks_max = torch.maximum(sinks, logits_max)
|
||||
sinks = torch.exp(sinks - logits_or_sinks_max)
|
||||
unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
|
||||
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
|
||||
scores = unnormalized_scores / normalizer
|
||||
|
||||
output = torch.einsum("bhmqk,bkhmd->bqhmd", scores, value.float())
|
||||
|
||||
output = output.reshape(batch_size, num_queries, num_key_value_heads * num_key_value_groups * head_dim).bfloat16()
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("num_queries", [1, 128])
@pytest.mark.parametrize("num_keys", [128, 32])
@pytest.mark.parametrize("num_key_value_heads", [8])
@pytest.mark.parametrize("num_key_value_groups", [8])
@pytest.mark.parametrize("head_dim", [64])
@pytest.mark.parametrize("sm_scale", [0.125])
@pytest.mark.parametrize("sliding_window", [None, 128])
@pytest.mark.parametrize("start_q", [0, 5])
def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_value_groups, head_dim, sm_scale, sliding_window, start_q):
    """Check that the Triton `attention` matches `attention_ref` on random inputs."""
    # The Triton path needs a GPU; skip rather than error on CPU-only machines.
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required")
    if num_queries > num_keys:
        pytest.skip("too many queries")

    # Fixed seed so a failing parameter combination can be reproduced.
    torch.manual_seed(0)

    q = torch.randn(batch_size, num_queries, num_key_value_heads, num_key_value_groups, head_dim).bfloat16().cuda()
    k = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
    v = torch.randn(batch_size, num_keys, num_key_value_heads, head_dim).bfloat16().cuda()
    sinks = torch.randn(num_key_value_heads * num_key_value_groups).bfloat16().cuda()

    # the kernel wrapper asserts a length-1 container for the query offset
    start_q = torch.tensor([start_q], dtype=torch.int32).cuda()

    o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
    o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)

    torch.testing.assert_close(o1, o2)
|
||||
Reference in New Issue
Block a user