66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
"""
|
|
Benchmark Suite for PagedFieldprintAttention
|
|
============================================
|
|
|
|
This script empirically benchmarks the memory bandwidth and latency
|
|
savings of the Fused Triton kernel vs. a Naive Unfused PyTorch implementation.
|
|
"""
|
|
|
|
import torch
|
|
import triton
|
|
import triton.testing
|
|
from fused_attention import paged_fieldprint_attention
|
|
|
|
def naive_unfused_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
    """Reference (unfused) attention over anchor + context keys/values.

    This is the mathematically valid but hardware-inefficient dual attention
    from the original paper, expressed with plain PyTorch ops.

    The anchor keys/values are prepended to the context keys/values along
    dim 2 (the sequence axis), then standard scaled dot-product attention is
    computed. Every intermediate is produced as a separate full tensor:

    - the concatenated K and V,
    - the (N_q x (N_anchor + N_ctx)) score matrix,
    - its softmax.

    This is the memory traffic the fused kernel is meant to avoid.

    Args:
        q: Query tensor. The benchmark in this file passes
            (batch, heads, N_q, d); other layouts are unverified.
        k_ctx, v_ctx: Context keys/values, laid out like ``q``.
        k_anchor, v_anchor: Anchor keys/values, placed *before* the context
            entries along dim 2.

    Returns:
        Attention output of shape (..., N_q, d_v).
    """
    keys = torch.cat((k_anchor, k_ctx), dim=2)
    values = torch.cat((v_anchor, v_ctx), dim=2)

    scale = q.shape[-1] ** 0.5

    # Full score matrix, shape (..., N_q, N_anchor + N_ctx), materialized as a
    # standalone tensor -- the cost the fused kernel is designed to avoid.
    logits = torch.matmul(q, keys.transpose(-2, -1)) / scale
    weights = torch.softmax(logits, dim=-1)

    return torch.matmul(weights, values)
|
|
|
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N_CTX'],
        x_vals=[2**i for i in range(10, 13)],  # 1024 to 4096
        line_arg='provider',
        line_vals=['naive', 'fused'],
        line_names=['Naive Unfused (PyTorch)', 'PagedFieldprint (Triton)'],
        styles=[('blue', '-'), ('green', '-')],
        ylabel='Latency (ms)',
        plot_name='attention-latency-benchmark',
        args={'BATCH': 1, 'H': 32, 'D_HEAD': 128, 'N_ANCHOR': 128},
    )
)
def benchmark_attention(BATCH, H, N_CTX, D_HEAD, N_ANCHOR, provider):
    """Time one attention implementation at one context length.

    ``perf_report`` invokes this once per (N_CTX, provider) combination from
    the ``Benchmark`` spec above. The fixed sizes come from ``args``.

    Args:
        BATCH, H, D_HEAD: Batch size, head count and head dimension.
        N_CTX: Number of query / context positions.
        N_ANCHOR: Number of anchor key/value positions.
        provider: ``'naive'`` (unfused PyTorch) or ``'fused'`` (Triton kernel).

    Returns:
        ``(ms, min_ms, max_ms)`` from ``do_bench``. With
        ``quantiles=[0.5, 0.2, 0.8]`` these are the median and the 20th/80th
        percentile latencies, not the true min/max.

    Raises:
        ValueError: If ``provider`` is not a known implementation.
    """
    providers = {
        'naive': naive_unfused_attention,
        'fused': paged_fieldprint_attention,
    }
    # Validate before allocating GPU memory. Previously an unknown provider
    # fell through both `if`s and crashed with UnboundLocalError on `ms`.
    if provider not in providers:
        raise ValueError(
            f"unknown provider {provider!r}; expected one of {sorted(providers)}"
        )
    attention_fn = providers[provider]

    device, dtype = 'cuda', torch.float16
    ctx_shape = (BATCH, H, N_CTX, D_HEAD)
    anchor_shape = (BATCH, H, N_ANCHOR, D_HEAD)

    q = torch.randn(ctx_shape, device=device, dtype=dtype)
    k_ctx = torch.randn(ctx_shape, device=device, dtype=dtype)
    v_ctx = torch.randn(ctx_shape, device=device, dtype=dtype)

    k_anchor = torch.randn(anchor_shape, device=device, dtype=dtype)
    v_anchor = torch.randn(anchor_shape, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: attention_fn(q, k_ctx, v_ctx, k_anchor, v_anchor),
        quantiles=quantiles,
    )

    return ms, min_ms, max_ms
|
|
|
|
if __name__ == '__main__':
    # benchmark_attention allocates every input with device='cuda'. Fail fast
    # with a readable message instead of a CUDA error inside the first run.
    if not torch.cuda.is_available():
        raise SystemExit("A CUDA-capable GPU is required to run this benchmark.")

    print("Running PagedFieldprintAttention Benchmark Suite...")
    # Save results under the current directory and print the data table.
    benchmark_attention.run(save_path='.', print_data=True)
|