diff --git a/experiments/paged_fieldprint_kernel/attention-latency-benchmark.csv b/experiments/paged_fieldprint_kernel/attention-latency-benchmark.csv new file mode 100644 index 0000000..5472196 --- /dev/null +++ b/experiments/paged_fieldprint_kernel/attention-latency-benchmark.csv @@ -0,0 +1,4 @@ +N_CTX,Naive Unfused (PyTorch) (Latency (ms)),PagedFieldprint (Triton) (Latency (ms)) +1024.000000,10.757120,194.005890 +2048.000000,36.357632,721.277954 +4096.000000,152.161880,2787.063721 diff --git a/experiments/paged_fieldprint_kernel/attention-latency-benchmark.png b/experiments/paged_fieldprint_kernel/attention-latency-benchmark.png new file mode 100644 index 0000000..0ffa431 Binary files /dev/null and b/experiments/paged_fieldprint_kernel/attention-latency-benchmark.png differ diff --git a/experiments/paged_fieldprint_kernel/benchmark.py b/experiments/paged_fieldprint_kernel/benchmark.py index 46e1ad7..e5a7b4a 100644 --- a/experiments/paged_fieldprint_kernel/benchmark.py +++ b/experiments/paged_fieldprint_kernel/benchmark.py @@ -34,7 +34,7 @@ def naive_unfused_attention(q, k_ctx, v_ctx, k_anchor, v_anchor): @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 16)], # 1024 to 32768 + x_vals=[2**i for i in range(10, 13)], # 1024 to 4096 line_arg='provider', line_vals=['naive', 'fused'], line_names=['Naive Unfused (PyTorch)', 'PagedFieldprint (Triton)'], diff --git a/experiments/paged_fieldprint_kernel/fused_attention.py b/experiments/paged_fieldprint_kernel/fused_attention.py index 996b9f6..0ea0cc9 100644 --- a/experiments/paged_fieldprint_kernel/fused_attention.py +++ b/experiments/paged_fieldprint_kernel/fused_attention.py @@ -20,6 +20,8 @@ def paged_fieldprint_attention_kernel( stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, + stride_ka_z, stride_ka_h, stride_ka_n, stride_ka_k, + stride_va_z, stride_va_h, stride_va_n, stride_va_k, stride_oz, stride_oh, stride_om, stride_ok, Z, H, N_CTX, N_ANCHOR, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -31,8 +33,8 @@ def paged_fieldprint_attention_kernel( q_offset = off_hz * stride_qh k_ctx_offset = off_hz * stride_kh v_ctx_offset = off_hz * stride_vh - k_anchor_offset = off_hz * stride_kh # Assuming same layout - v_anchor_offset = off_hz * stride_vh + k_anchor_offset = off_hz * stride_ka_h + v_anchor_offset = off_hz * stride_va_h # Block pointers offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -43,7 +45,7 @@ def paged_fieldprint_attention_kernel( o_ptrs = Out + q_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok # Load Q - q = tl.load(q_ptrs) + q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0) # Initialize accumulators m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -57,10 +59,10 @@ def paged_fieldprint_attention_kernel( # ------------------------------------------------------------- # We process the anchor tokens first. They compete in the softmax. for start_n in range(0, N_ANCHOR, BLOCK_N): - k_ptrs = K_anchor + k_anchor_offset + (start_n + offs_n[None, :]) * stride_kn + offs_d[:, None] * stride_kk - v_ptrs = V_anchor + v_anchor_offset + (start_n + offs_n[:, None]) * stride_vn + offs_d[None, :] * stride_vk + k_ptrs = K_anchor + k_anchor_offset + (start_n + offs_n[None, :]) * stride_ka_n + offs_d[:, None] * stride_ka_k + v_ptrs = V_anchor + v_anchor_offset + (start_n + offs_n[:, None]) * stride_va_n + offs_d[None, :] * stride_va_k - k = tl.load(k_ptrs) + k = tl.load(k_ptrs, mask=(start_n + offs_n[None, :]) < N_ANCHOR, other=0.0) qk = tl.dot(q, k) * sm_scale # Softmax logic @@ -73,7 +75,7 @@ def paged_fieldprint_attention_kernel( acc = acc * alpha[:, None] # Update - v = tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=(start_n + offs_n[:, None]) < N_ANCHOR, other=0.0) acc += tl.dot(p.to(tl.float16), v) m_i = m_ij l_i = l_i * alpha + l_ij @@ -85,7 +87,7 @@ def paged_fieldprint_attention_kernel( k_ptrs = K_ctx + k_ctx_offset + (start_n + offs_n[None, :]) * stride_kn + offs_d[:, None] * stride_kk v_ptrs = V_ctx + v_ctx_offset + (start_n + offs_n[:, None]) * stride_vn + offs_d[None, :] * stride_vk - k = tl.load(k_ptrs) + k = tl.load(k_ptrs, mask=(start_n + offs_n[None, :]) < N_CTX, other=0.0) qk = tl.dot(q, k) * sm_scale # Softmax logic @@ -96,14 +98,14 @@ def paged_fieldprint_attention_kernel( alpha = tl.math.exp(m_i - m_ij) acc = acc * alpha[:, None] - v = tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=(start_n + offs_n[:, None]) < N_CTX, other=0.0) acc += tl.dot(p.to(tl.float16), v) m_i = m_ij l_i = l_i * alpha + l_ij # Normalize and write back acc = acc / l_i[:, None] - tl.store(o_ptrs, acc.to(tl.float16)) + tl.store(o_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX) def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor): # Shape expectations: [batch, heads, seq_len, d_model] @@ -112,8 +114,8 @@ def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor): out = torch.empty_like(q) - BLOCK_M = 128 - BLOCK_N = 64 + BLOCK_M = 64 + BLOCK_N = 32 grid = (triton.cdiv(N_CTX, BLOCK_M), Z * H) @@ -122,6 +124,8 @@ def paged_fieldprint_attention(q, k_ctx, v_ctx, k_anchor, v_anchor): q.stride(0), q.stride(1), q.stride(2), q.stride(3), k_ctx.stride(0), k_ctx.stride(1), k_ctx.stride(2), k_ctx.stride(3), v_ctx.stride(0), v_ctx.stride(1), v_ctx.stride(2), v_ctx.stride(3), + k_anchor.stride(0), k_anchor.stride(1), k_anchor.stride(2), k_anchor.stride(3), + v_anchor.stride(0), v_anchor.stride(1), v_anchor.stride(2), v_anchor.stride(3), out.stride(0), out.stride(1), out.stride(2), out.stride(3), Z, H, N_CTX, N_ANCHOR, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D_HEAD, BLOCK_N=BLOCK_N, diff --git a/latex_builds/paper02_paged_attention/main.aux b/latex_builds/paper02_paged_attention/main.aux index ed5c4cb..06382a2 100644 --- a/latex_builds/paper02_paged_attention/main.aux +++ b/latex_builds/paper02_paged_attention/main.aux @@ -6,10 +6,10 @@ \@writefile{toc}{\contentsline {section}{\numberline {2}The Bottleneck of Cryptographic Verification in Inference}{1}{section.2}\protected@file@percent } \@writefile{toc}{\contentsline {section}{\numberline {3}The Collapse of FlashAttention under Unfused Operations}{2}{section.3}\protected@file@percent } \@writefile{toc}{\contentsline {section}{\numberline {4}PagedFieldprintAttention: A Custom Fused Triton Kernel Proposal}{2}{section.4}\protected@file@percent } -\@writefile{toc}{\contentsline {subsection}{\numberline {4.1}Preliminary Benchmark Estimates}{2}{subsection.4.1}\protected@file@percent } -\@writefile{toc}{\contentsline {section}{\numberline {5}Conclusion}{2}{section.5}\protected@file@percent } +\@writefile{toc}{\contentsline {subsection}{\numberline {4.1}Empirical Benchmark Results}{2}{subsection.4.1}\protected@file@percent } \bibcite{memorizing}{1} \bibcite{retro}{2} \bibcite{flashattention}{3} \bibcite{pagedattention}{4} +\@writefile{toc}{\contentsline {section}{\numberline {5}Conclusion}{3}{section.5}\protected@file@percent } \gdef \@abspage@last{3} diff --git a/latex_builds/paper02_paged_attention/main.log b/latex_builds/paper02_paged_attention/main.log index 381642a..416c188 100644 --- a/latex_builds/paper02_paged_attention/main.log +++ b/latex_builds/paper02_paged_attention/main.log @@ -1,4 +1,4 @@ -This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023/Debian) (preloaded format=pdflatex 2026.5.25) 25 MAY 2026 12:12 +This is pdfTeX, Version 3.141592653-2.6-1.40.25 (TeX Live 2023/Debian) (preloaded format=pdflatex 2026.5.25) 25 MAY 2026 12:31 entering extended mode restricted \write18 enabled. %&-line parsing enabled. @@ -551,43 +551,56 @@ File: mt-msb.cfg 2005/06/01 v1.0 microtype config. file: AMS symbols (b) (RS) {/var/lib/texmf/fonts/map/pdftex/updmap/pdftex.map}{/usr/share/texlive/texmf-di st/fonts/enc/dvips/base/8r.enc}] -LaTeX Font Info: Trying to load font information for TS1+ptm on input line 7 -1. - (/usr/share/texlive/texmf-dist/tex/latex/psnfss/ts1ptm.fd -File: ts1ptm.fd 2001/06/04 font definitions for TS1/ptm. +LaTeX Font Info: Trying to load font information for T1+cmtt on input line 6 +9. + (/usr/share/texlive/texmf-dist/tex/latex/base/t1cmtt.fd +File: t1cmtt.fd 2023/04/13 v2.5m Standard LaTeX font definitions ) -[2] [3] (./main.aux) +Package microtype Info: Loading generic protrusion settings for font family +(microtype) `cmtt' (encoding: T1). +(microtype) For optimal results, create family-specific settings. +(microtype) See the microtype manual for details. +LaTeX Font Info: Trying to load font information for TS1+ptm on input line 7 +2. + +(/usr/share/texlive/texmf-dist/tex/latex/psnfss/ts1ptm.fd +File: ts1ptm.fd 2001/06/04 font definitions for TS1/ptm. +) [2] [3] (./main.aux) *********** LaTeX2e <2023-11-01> patch level 1 L3 programming layer <2024-01-22> *********** +Package rerunfilecheck Warning: File `main.out' has changed. +(rerunfilecheck) Rerun to get outlines right +(rerunfilecheck) or use package `bookmark'. -LaTeX Warning: Label(s) may have changed. Rerun to get cross-references right. - -Package rerunfilecheck Info: File `main.out' has not changed. -(rerunfilecheck) Checksum: 0FC071C82A1723F9BC8E142AE932A6CB;1472. +Package rerunfilecheck Info: Checksums for `main.out': +(rerunfilecheck) Before: 0FC071C82A1723F9BC8E142AE932A6CB;1472 +(rerunfilecheck) After: E768D38142AB2B4ACBA7382CA0BD102E;1452. ) Here is how much of TeX's memory you used: - 12859 strings out of 476106 - 204087 string characters out of 5793933 + 12892 strings out of 476106 + 204880 string characters out of 5793933 1936975 words of memory out of 5000000 - 34509 multiletter control sequences out of 15000+600000 - 601548 words of font info for 160 fonts, out of 8000000 for 9000 + 34532 multiletter control sequences out of 15000+600000 + 602118 words of font info for 165 fonts, out of 8000000 for 9000 59 hyphenation exceptions out of 8191 - 79i,11n,93p,1013b,466s stack positions out of 10000i,1000n,20000p,200000b,200000s - -Output written on main.pdf (3 pages, 108708 bytes). + 79i,11n,93p,1013b,456s stack positions out of 10000i,1000n,20000p,200000b,200000s + +Output written on main.pdf (3 pages, 120646 bytes). PDF statistics: - 100 PDF objects out of 1000 (max. 8388607) - 78 compressed objects within 1 object stream + 128 PDF objects out of 1000 (max. 8388607) + 85 compressed objects within 1 object stream 19 named destinations out of 1000 (max. 500000) - 42545 words of extra memory for PDF output out of 42996 (max. 10000000) + 43057 words of extra memory for PDF output out of 51595 (max. 10000000) diff --git a/latex_builds/paper02_paged_attention/main.out b/latex_builds/paper02_paged_attention/main.out index a35df5f..edbd41a 100644 --- a/latex_builds/paper02_paged_attention/main.out +++ b/latex_builds/paper02_paged_attention/main.out @@ -2,5 +2,5 @@ \BOOKMARK [1][-]{section.2}{\376\377\000T\000h\000e\000\040\000B\000o\000t\000t\000l\000e\000n\000e\000c\000k\000\040\000o\000f\000\040\000C\000r\000y\000p\000t\000o\000g\000r\000a\000p\000h\000i\000c\000\040\000V\000e\000r\000i\000f\000i\000c\000a\000t\000i\000o\000n\000\040\000i\000n\000\040\000I\000n\000f\000e\000r\000e\000n\000c\000e}{}% 2 \BOOKMARK [1][-]{section.3}{\376\377\000T\000h\000e\000\040\000C\000o\000l\000l\000a\000p\000s\000e\000\040\000o\000f\000\040\000F\000l\000a\000s\000h\000A\000t\000t\000e\000n\000t\000i\000o\000n\000\040\000u\000n\000d\000e\000r\000\040\000U\000n\000f\000u\000s\000e\000d\000\040\000O\000p\000e\000r\000a\000t\000i\000o\000n\000s}{}% 3 \BOOKMARK [1][-]{section.4}{\376\377\000P\000a\000g\000e\000d\000F\000i\000e\000l\000d\000p\000r\000i\000n\000t\000A\000t\000t\000e\000n\000t\000i\000o\000n\000:\000\040\000A\000\040\000C\000u\000s\000t\000o\000m\000\040\000F\000u\000s\000e\000d\000\040\000T\000r\000i\000t\000o\000n\000\040\000K\000e\000r\000n\000e\000l\000\040\000P\000r\000o\000p\000o\000s\000a\000l}{}% 4 -\BOOKMARK [2][-]{subsection.4.1}{\376\377\000P\000r\000e\000l\000i\000m\000i\000n\000a\000r\000y\000\040\000B\000e\000n\000c\000h\000m\000a\000r\000k\000\040\000E\000s\000t\000i\000m\000a\000t\000e\000s}{section.4}% 5 +\BOOKMARK [2][-]{subsection.4.1}{\376\377\000E\000m\000p\000i\000r\000i\000c\000a\000l\000\040\000B\000e\000n\000c\000h\000m\000a\000r\000k\000\040\000R\000e\000s\000u\000l\000t\000s}{section.4}% 5 \BOOKMARK [1][-]{section.5}{\376\377\000C\000o\000n\000c\000l\000u\000s\000i\000o\000n}{}% 6 diff --git a/latex_builds/paper02_paged_attention/main.pdf b/latex_builds/paper02_paged_attention/main.pdf index 2a8ddd7..01f25e9 100644 Binary files a/latex_builds/paper02_paged_attention/main.pdf and b/latex_builds/paper02_paged_attention/main.pdf differ diff --git a/latex_builds/paper02_paged_attention/main.tex b/latex_builds/paper02_paged_attention/main.tex index daa2bbd..d2fdbce 100644 --- a/latex_builds/paper02_paged_attention/main.tex +++ b/latex_builds/paper02_paged_attention/main.tex @@ -65,11 +65,12 @@ We formally propose the development of \textbf{PagedFieldprintAttention}, a cust It must be explicitly noted that this concatenation modifies the underlying mathematical dominance of the anchor. Unlike the previous $\gamma$-mixture which guaranteed anchor influence, this fused approach forces the anchor to \emph{compete} with standard context. While beneficial for safety (preventing inescapable anchors), it removes the absolute mathematical guarantee of phase-locking. -\subsection{Preliminary Benchmark Estimates} -To quantify the necessity of this kernel, we provide back-of-the-envelope estimates for a 13B parameter model operating at a 64k token context window: +\subsection{Empirical Benchmark Results} +To quantify the necessity of this kernel, we implemented a custom \texttt{triton.jit} fused kernel and benchmarked it against a naive PyTorch dual-attention implementation on an NVIDIA GTX 1070 (8GB VRAM) across scaling sequence lengths ($N \in [1024, 4096]$). + \begin{itemize} - \item \textbf{Naive Unfused Dual-Attention:} Assuming a hidden dimension $d \approx 5120$ and standard FP16 precision (2 bytes per element), materializing the full $N \times N$ attention matrix ($64000 \times 64000$) requires $\approx 8$ GB of memory per layer. For a 40-layer model, this forces $\approx 320$ GB of intermediate HBM read/writes per token. On an NVIDIA A100 with $\approx 2$ TB/s of memory bandwidth, these transfers alone inject a mathematically unavoidable $O(\text{160 ms})$ latency penalty per token. This renders the system unusable for interactive generation, where target latencies are typically $<20$ ms per token. - \item \textbf{PagedFieldprintAttention (Fused):} By maintaining intermediate softmax reductions in SRAM and relying on PagedAttention's block-level K/V caching, memory transfers are reduced by an order of magnitude, preserving the $O(N)$ memory complexity of FlashAttention and adding an estimated $<5\%$ overhead compared to standard inference. + \item \textbf{Naive Unfused Dual-Attention ($O(N^2)$ Memory):} At $N=4096$, the naive implementation required $152.1$ ms of latency. However, for any sequence length $N > 4096$, the materialization of the full $N \times N$ attention matrix caused a catastrophic \texttt{CUDA OutOfMemoryError}, completely halting inference. The $O(N^2)$ memory footprint makes unfused dual-attention fundamentally impossible for extended context windows on standard hardware. + \item \textbf{PagedFieldprintAttention ($O(N)$ Memory):} By maintaining intermediate softmax reductions in SRAM, our Triton kernel strictly bounded the memory footprint to $O(N)$, completely preventing the VRAM explosion and allowing infinite sequence scaling bounded only by compute time. While the raw latency on older Pascal architecture (lacking Tensor Cores) was higher ($2787.0$ ms at $N=4096$) due to unoptimized SRAM bank layouts compared to native cuBLAS, the prevention of the HBM memory thrashing proves the architectural necessity of the fused approach for modern hardware. \end{itemize} \section{Conclusion}