feat: add empirical Triton benchmarks to Paper 02
Mirror to GitLab / mirror (push) Waiting to run

This commit is contained in:
Antigravity Agent
2026-05-25 12:31:06 +00:00
parent 02953c7201
commit ca9f764ea3
9 changed files with 69 additions and 47 deletions
@@ -0,0 +1,4 @@
N_CTX,Naive Unfused (PyTorch) (Latency (ms)),PagedFieldprint (Triton) (Latency (ms))
1024.000000,10.757120,194.005890
2048.000000,36.357632,721.277954
4096.000000,152.161880,2787.063721
1 N_CTX Naive Unfused (PyTorch) (Latency (ms)) PagedFieldprint (Triton) (Latency (ms))
2 1024.000000 10.757120 194.005890
3 2048.000000 36.357632 721.277954
4 4096.000000 152.161880 2787.063721
Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

@@ -34,7 +34,7 @@ def naive_unfused_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(10, 16)], # 1024 to 32768
x_vals=[2**i for i in range(10, 13)], # 1024 to 4096
line_arg='provider',
line_vals=['naive', 'fused'],
line_names=['Naive Unfused (PyTorch)', 'PagedFieldprint (Triton)'],
@@ -20,6 +20,8 @@ def paged_fieldprint_attention_kernel(
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_ka_z, stride_ka_h, stride_ka_n, stride_ka_k,
stride_va_z, stride_va_h, stride_va_n, stride_va_k,
stride_oz, stride_oh, stride_om, stride_ok,
Z, H, N_CTX, N_ANCHOR,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
@@ -31,8 +33,8 @@ def paged_fieldprint_attention_kernel(
q_offset = off_hz * stride_qh
k_ctx_offset = off_hz * stride_kh
v_ctx_offset = off_hz * stride_vh
k_anchor_offset = off_hz * stride_kh # Assuming same layout
v_anchor_offset = off_hz * stride_vh
k_anchor_offset = off_hz * stride_ka_h
v_anchor_offset = off_hz * stride_va_h
# Block pointers
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -43,7 +45,7 @@ def paged_fieldprint_attention_kernel(
o_ptrs = Out + q_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok
# Load Q
q = tl.load(q_ptrs)
q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
# Initialize accumulators
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -57,10 +59,10 @@ def paged_fieldprint_attention_kernel(
# -------------------------------------------------------------
# We process the anchor tokens first. They compete in the softmax.
for start_n in range(0, N_ANCHOR, BLOCK_N):
k_ptrs = K_anchor + k_anchor_offset + (start_n + offs_n[None, :]) * stride_kn + offs_d[:, None] * stride_kk
v_ptrs = V_anchor + v_anchor_offset + (start_n + offs_n[:, None]) * stride_vn + offs_d[None, :] * stride_vk
k_ptrs = K_anchor + k_anchor_offset + (start_n + offs_n[None, :]) * stride_ka_n + offs_d[:, None] * stride_ka_k
v_ptrs = V_anchor + v_anchor_offset + (start_n + offs_n[:, None]) * stride_va_n + offs_d[None, :] * stride_va_k
k = tl.load(k_ptrs)
k = tl.load(k_ptrs, mask=(start_n + offs_n[None, :]) < N_ANCHOR, other=0.0)
qk = tl.dot(q, k) * sm_scale
# Softmax logic
@@ -73,7 +75,7 @@ def paged_fieldprint_attention_kernel(
acc = acc * alpha[:, None]
# Update
v = tl.load(v_ptrs)
v = tl.load(v_ptrs, mask=(start_n + offs_n[:, None]) < N_ANCHOR, other=0.0)
acc += tl.dot(p.to(tl.float16), v)
m_i = m_ij
l_i = l_i * alpha + l_ij
@@ -85,7 +87,7 @@ def paged_fieldprint_attention_kernel(
k_ptrs = K_ctx + k_ctx_offset + (start_n + offs_n[None, :]) * stride_kn + offs_d[:, None] * stride_kk
v_ptrs = V_ctx + v_ctx_offset + (start_n + offs_n[:, None]) * stride_vn + offs_d[None, :] * stride_vk
k = tl.load(k_ptrs)
k = tl.load(k_ptrs, mask=(start_n + offs_n[None, :]) < N_CTX, other=0.0)
qk = tl.dot(q, k) * sm_scale
# Softmax logic
@@ -96,14 +98,14 @@ def paged_fieldprint_attention_kernel(
alpha = tl.math.exp(m_i - m_ij)
acc = acc * alpha[:, None]
v = tl.load(v_ptrs)
v = tl.load(v_ptrs, mask=(start_n + offs_n[:, None]) < N_CTX, other=0.0)
acc += tl.dot(p.to(tl.float16), v)
m_i = m_ij
l_i = l_i * alpha + l_ij
# Normalize and write back
acc = acc / l_i[:, None]
tl.store(o_ptrs, acc.to(tl.float16))
tl.store(o_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)
def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
# Shape expectations: [batch, heads, seq_len, d_model]
@@ -112,8 +114,8 @@ def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
out = torch.empty_like(q)
BLOCK_M = 128
BLOCK_N = 64
BLOCK_M = 64
BLOCK_N = 32
grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H)
@@ -122,6 +124,8 @@ def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k_ctx.stride(0), k_ctx.stride(1), k_ctx.stride(2), k_ctx.stride(3),
v_ctx.stride(0), v_ctx.stride(1), v_ctx.stride(2), v_ctx.stride(3),
k_anchor.stride(0), k_anchor.stride(1), k_anchor.stride(2), k_anchor.stride(3),
v_anchor.stride(0), v_anchor.stride(1), v_anchor.stride(2), v_anchor.stride(3),
out.stride(0), out.stride(1), out.stride(2), out.stride(3),
Z, H, N_CTX, N_ANCHOR,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D_HEAD, BLOCK_N=BLOCK_N,