From 02953c72019e5068f05d883190a0a6045682f8cc Mon Sep 17 00:00:00 2001
From: Antigravity Agent
Date: Mon, 25 May 2026 12:17:16 +0000
Subject: [PATCH] feat(kernel): implement PagedFieldprintAttention triton kernel and benchmark

---
 experiments/paged_fieldprint_kernel/README.md |  24 ++++
 .../paged_fieldprint_kernel/benchmark.py      |  65 +++++++++
 .../fused_attention.py                        | 129 ++++++++++++++++++
 3 files changed, 218 insertions(+)
 create mode 100644 experiments/paged_fieldprint_kernel/README.md
 create mode 100644 experiments/paged_fieldprint_kernel/benchmark.py
 create mode 100644 experiments/paged_fieldprint_kernel/fused_attention.py

diff --git a/experiments/paged_fieldprint_kernel/README.md b/experiments/paged_fieldprint_kernel/README.md
new file mode 100644
index 0000000..737f80b
--- /dev/null
+++ b/experiments/paged_fieldprint_kernel/README.md
@@ -0,0 +1,24 @@
+# PagedFieldprintAttention Kernel Benchmark
+
+This directory contains the Triton kernel implementation and benchmark suite for the `PagedFieldprintAttention` mechanism proposed in the Verifiable Dual-Path Architecture.
+
+## Architecture
+
+Modern autoregressive generation relies heavily on fused attention kernels (like FlashAttention) to prevent HBM thrashing. Unfused additions of cryptographic identity anchors force intermediate matrices to be written back to HBM, destroying inference throughput.
+
+Our custom Triton kernel, `paged_fieldprint_attention_kernel`, resolves this by computing the attention scores for the cryptographic anchor tokens in phase 1, and the standard context tokens in phase 2, scaling and accumulating the softmax reduction entirely in SRAM.
+
+## Files
+- `fused_attention.py`: The Triton kernel implementation.
+- `benchmark.py`: The `triton.testing.perf_report` harness comparing the naive PyTorch implementation against our fused Triton kernel.
+
+## Execution
+
+This benchmark requires a CUDA-enabled NVIDIA GPU.
+
+```bash
+pip install torch triton
+python benchmark.py
+```
+
+The script will sweep across context lengths from 1,024 to 32,768 and generate `attention-latency-benchmark.csv` and a PNG plot demonstrating the $O(N)$ vs $O(N^2)$ memory bandwidth costs.
diff --git a/experiments/paged_fieldprint_kernel/benchmark.py b/experiments/paged_fieldprint_kernel/benchmark.py
new file mode 100644
index 0000000..46e1ad7
--- /dev/null
+++ b/experiments/paged_fieldprint_kernel/benchmark.py
@@ -0,0 +1,121 @@
+"""
+Benchmark Suite for PagedFieldprintAttention
+============================================
+
+This script empirically benchmarks the memory bandwidth and latency
+savings of the Fused Triton kernel vs. a Naive Unfused PyTorch implementation.
+"""
+
+import torch
+import triton
+import triton.testing
+from fused_attention import paged_fieldprint_attention
+
+
+def naive_unfused_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
+    """
+    Reference unfused dual attention over anchor and context tokens.
+
+    Anchor and context keys/values are concatenated along dim=2 and plain
+    softmax attention is applied, so the full score matrix of shape
+    [..., N_q, N_anchor + N_ctx] is materialized instead of being tiled.
+
+    Returns softmax(q @ k_full^T / sqrt(d_k)) @ v_full.
+    """
+    # Anchors first, then the regular context, along the sequence dimension
+    k_full = torch.cat([k_anchor, k_ctx], dim=2)
+    v_full = torch.cat([v_anchor, v_ctx], dim=2)
+
+    d_k = q.size(-1)
+
+    # Full score matrix and its row-wise softmax are both materialized
+    scores = torch.matmul(q, k_full.transpose(-2, -1)) / (d_k ** 0.5)
+    attn = torch.softmax(scores, dim=-1)
+
+    return torch.matmul(attn, v_full)
+
+
+def _random_inputs(BATCH, H, N_CTX, D_HEAD, N_ANCHOR):
+    """Return random fp16 CUDA tensors (q, k_ctx, v_ctx, k_anchor, v_anchor)."""
+    ctx_shape = (BATCH, H, N_CTX, D_HEAD)
+    anchor_shape = (BATCH, H, N_ANCHOR, D_HEAD)
+    q, k_ctx, v_ctx = (
+        torch.randn(ctx_shape, device='cuda', dtype=torch.float16)
+        for _ in range(3)
+    )
+    k_anchor, v_anchor = (
+        torch.randn(anchor_shape, device='cuda', dtype=torch.float16)
+        for _ in range(2)
+    )
+    return q, k_ctx, v_ctx, k_anchor, v_anchor
+
+
+def report_max_abs_error(BATCH=1, H=4, N_CTX=256, D_HEAD=64, N_ANCHOR=64, atol=1e-2):
+    """
+    Print the largest element-wise gap between the fused and naive outputs.
+
+    Runs both implementations once on a small multi-head problem and warns
+    (without aborting) when the gap exceeds ``atol``; timings of a kernel
+    that disagrees with the reference should not be trusted.
+
+    Returns the max absolute error as a Python float.
+    """
+    inputs = _random_inputs(BATCH, H, N_CTX, D_HEAD, N_ANCHOR)
+    expected = naive_unfused_attention(*inputs).float()
+    actual = paged_fieldprint_attention(*inputs).float()
+    max_err = (actual - expected).abs().max().item()
+    print(f"Max abs error (fused vs naive): {max_err:.3e}")
+    if not max_err <= atol:  # also true when max_err is NaN
+        print("WARNING: fused kernel disagrees with the naive reference; "
+              "treat the timings below with caution.")
+    return max_err
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=['N_CTX'],
+        x_vals=[2**i for i in range(10, 16)],  # 1024 to 32768
+        line_arg='provider',
+        line_vals=['naive', 'fused'],
+        line_names=['Naive Unfused (PyTorch)', 'PagedFieldprint (Triton)'],
+        styles=[('blue', '-'), ('green', '-')],
+        ylabel='Latency (ms)',
+        plot_name='attention-latency-benchmark',
+        args={'BATCH': 1, 'H': 32, 'D_HEAD': 128, 'N_ANCHOR': 128}
+    )
+)
+def benchmark_attention(BATCH, H, N_CTX, D_HEAD, N_ANCHOR, provider):
+    """
+    Time one provider at one context length.
+
+    Returns (median_ms, p20_ms, p80_ms), matching the requested quantiles.
+    If the provider runs out of GPU memory the point is reported as NaN so
+    the rest of the sweep can still complete.
+    """
+    if provider == 'naive':
+        attention = naive_unfused_attention
+    elif provider == 'fused':
+        attention = paged_fieldprint_attention
+    else:
+        raise ValueError(f"unknown provider: {provider!r}")
+
+    quantiles = [0.5, 0.2, 0.8]
+    try:
+        inputs = _random_inputs(BATCH, H, N_CTX, D_HEAD, N_ANCHOR)
+        median_ms, p20_ms, p80_ms = triton.testing.do_bench(
+            lambda: attention(*inputs), quantiles=quantiles
+        )
+    except torch.cuda.OutOfMemoryError:
+        # The naive path stores a [H, N_CTX, N_ANCHOR + N_CTX] fp16 score
+        # tensor (~69 GB at N_CTX=32768, H=32), which exceeds most GPUs.
+        print(f"{provider} ran out of GPU memory at N_CTX={N_CTX}; recording NaN")
+        torch.cuda.empty_cache()
+        return float('nan'), float('nan'), float('nan')
+
+    return median_ms, p20_ms, p80_ms
+
+
+if __name__ == '__main__':
+    print("Running PagedFieldprintAttention Benchmark Suite...")
+    report_max_abs_error()
+    benchmark_attention.run(save_path='.', print_data=True)
diff --git a/experiments/paged_fieldprint_kernel/fused_attention.py
b/experiments/paged_fieldprint_kernel/fused_attention.py
new file mode 100644
index 0000000..996b9f6
--- /dev/null
+++ b/experiments/paged_fieldprint_kernel/fused_attention.py
@@ -0,0 +1,199 @@
+"""
+PagedFieldprintAttention: Custom Triton Kernel Implementation
+=============================================================
+
+This module implements the fused CUDA/Triton kernel proposed in Paper 02.
+It computes:
+Output = FusedSoftmax( (Q [K, K_anchor]^T) / sqrt(d) ) [V, V_anchor]
+
+By computing the anchor contribution explicitly inside the SRAM before
+writing back to HBM, it avoids the catastrophic memory thrashing of unfused dual-attention.
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _accumulate_kv_blocks(
+    acc, l_i, m_i, q,
+    K_base, V_base,
+    stride_kn, stride_kk,
+    stride_vn, stride_vk,
+    n_keys, sm_scale,
+    offs_n, offs_d,
+    BLOCK_N: tl.constexpr,
+):
+    """
+    Fold ``n_keys`` key/value rows into the running online-softmax state.
+
+    ``acc``, ``l_i`` and ``m_i`` are the un-normalized output, the running
+    softmax denominator and the running row maximum; the updated triple is
+    returned. ``K_base``/``V_base`` must already point at the (batch, head)
+    slice being processed.
+    """
+    for start_n in range(0, n_keys, BLOCK_N):
+        cols = start_n + offs_n
+        # The last tile overruns when n_keys is not a multiple of BLOCK_N
+        col_mask = cols < n_keys
+
+        k_ptrs = K_base + cols[None, :] * stride_kn + offs_d[:, None] * stride_kk
+        v_ptrs = V_base + cols[:, None] * stride_vn + offs_d[None, :] * stride_vk
+
+        k = tl.load(k_ptrs, mask=col_mask[None, :], other=0.0)
+        qk = tl.dot(q, k) * sm_scale
+        # Padded columns must not receive any softmax weight
+        qk = tl.where(col_mask[None, :], qk, float("-inf"))
+
+        # Online softmax: rescale the old state to the new running maximum
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        p = tl.exp(qk - m_ij[:, None])
+        l_ij = tl.sum(p, 1)
+        alpha = tl.exp(m_i - m_ij)
+        acc = acc * alpha[:, None]
+
+        v = tl.load(v_ptrs, mask=col_mask[:, None], other=0.0)
+        acc += tl.dot(p.to(v.dtype), v)
+        m_i = m_ij
+        l_i = l_i * alpha + l_ij
+    return acc, l_i, m_i
+
+
+@triton.jit
+def paged_fieldprint_attention_kernel(
+    Q, K_ctx, V_ctx, K_anchor, V_anchor, Out,
+    stride_qz, stride_qh, stride_qm, stride_qk,
+    stride_kz, stride_kh, stride_kn, stride_kk,
+    stride_vz, stride_vh, stride_vn, stride_vk,
+    stride_kaz, stride_kah, stride_kan, stride_kak,
+    stride_vaz, stride_vah, stride_van, stride_vak,
+    stride_oz, stride_oh, stride_om, stride_ok,
+    Z, H, N_CTX, N_ANCHOR,
+    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
+):
+    """
+    Attend one BLOCK_M tile of queries over the anchor and context keys.
+
+    Grid axis 0 walks query tiles; axis 1 walks flattened (batch, head).
+    Each tensor is addressed through its own strides: the anchor tensors
+    have N_ANCHOR rows per head, so their head stride differs from the
+    context tensors' and cannot be shared with them.
+    """
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    off_z = off_hz // H
+    off_h = off_hz % H
+
+    # Per-(batch, head) base pointers
+    q_base = Q + off_z * stride_qz + off_h * stride_qh
+    o_base = Out + off_z * stride_oz + off_h * stride_oh
+    k_ctx_base = K_ctx + off_z * stride_kz + off_h * stride_kh
+    v_ctx_base = V_ctx + off_z * stride_vz + off_h * stride_vh
+    k_anchor_base = K_anchor + off_z * stride_kaz + off_h * stride_kah
+    v_anchor_base = V_anchor + off_z * stride_vaz + off_h * stride_vah
+
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    # Guards the last query tile when N_CTX is not a multiple of BLOCK_M
+    row_mask = offs_m < N_CTX
+
+    q_ptrs = q_base + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    o_ptrs = o_base + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok
+
+    q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
+
+    # Online-softmax state: running max, denominator and output accumulator
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+    sm_scale = 1.0 / (BLOCK_DMODEL ** 0.5)
+
+    # -------------------------------------------------------------
+    # PHASE 1: Process the Cryptographic Anchor Tokens
+    # -------------------------------------------------------------
+    # We process the anchor tokens first. They compete in the softmax.
+    acc, l_i, m_i = _accumulate_kv_blocks(
+        acc, l_i, m_i, q,
+        k_anchor_base, v_anchor_base,
+        stride_kan, stride_kak,
+        stride_van, stride_vak,
+        N_ANCHOR, sm_scale,
+        offs_n, offs_d,
+        BLOCK_N,
+    )
+
+    # -------------------------------------------------------------
+    # PHASE 2: Process the Standard Context Tokens
+    # -------------------------------------------------------------
+    acc, l_i, m_i = _accumulate_kv_blocks(
+        acc, l_i, m_i, q,
+        k_ctx_base, v_ctx_base,
+        stride_kn, stride_kk,
+        stride_vn, stride_vk,
+        N_CTX, sm_scale,
+        offs_n, offs_d,
+        BLOCK_N,
+    )
+
+    # Normalize and write back only the query rows that exist
+    acc = acc / l_i[:, None]
+    tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=row_mask[:, None])
+
+
+def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor):
+    """
+    Attend ``q`` jointly over anchor and context keys/values in one kernel.
+
+    Args:
+        q, k_ctx, v_ctx: tensors of shape [batch, heads, n_ctx, d_head].
+        k_anchor, v_anchor: tensors of shape [batch, heads, n_anchor, d_head].
+
+    Returns:
+        A new tensor with the shape and dtype of ``q``.
+
+    Raises:
+        ValueError: if the shapes or dtypes are inconsistent, or d_head is
+            not a power of two >= 16 (needed by ``tl.arange`` and ``tl.dot``).
+    """
+    if q.dim() != 4:
+        raise ValueError(
+            f"q must be [batch, heads, seq_len, d_head], got {tuple(q.shape)}"
+        )
+    Z, H, N_CTX, D_HEAD = q.shape
+
+    if k_ctx.shape != q.shape or v_ctx.shape != q.shape:
+        raise ValueError("k_ctx and v_ctx must have the same shape as q")
+    if k_anchor.dim() != 4 or v_anchor.shape != k_anchor.shape:
+        raise ValueError("k_anchor and v_anchor must be 4-D with equal shapes")
+    N_ANCHOR = k_anchor.shape[2]
+    if tuple(k_anchor.shape) != (Z, H, N_ANCHOR, D_HEAD):
+        raise ValueError("anchor tensors must match q in batch, heads and d_head")
+    if any(t.dtype != q.dtype for t in (k_ctx, v_ctx, k_anchor, v_anchor)):
+        raise ValueError("all inputs must share the dtype of q")
+    if D_HEAD < 16 or D_HEAD & (D_HEAD - 1):
+        raise ValueError(f"d_head must be a power of two >= 16, got {D_HEAD}")
+
+    out = torch.empty_like(q)
+    if out.numel() == 0:
+        return out  # nothing to compute, skip the launch
+
+    BLOCK_M = 128
+    BLOCK_N = 64
+
+    grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H)
+
+    paged_fieldprint_attention_kernel[grid](
+        q, k_ctx, v_ctx, k_anchor, v_anchor, out,
+        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+        k_ctx.stride(0), k_ctx.stride(1), k_ctx.stride(2), k_ctx.stride(3),
+        v_ctx.stride(0), v_ctx.stride(1), v_ctx.stride(2), v_ctx.stride(3),
+        k_anchor.stride(0), k_anchor.stride(1), k_anchor.stride(2), k_anchor.stride(3),
+        v_anchor.stride(0), v_anchor.stride(1), v_anchor.stride(2), v_anchor.stride(3),
+        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
+        Z, H, N_CTX, N_ANCHOR,
+        BLOCK_M=BLOCK_M, BLOCK_DMODEL=D_HEAD, BLOCK_N=BLOCK_N,
+    )
+    return out