Data movement between HBM and on-chip SRAM during Flash Attention (Dao et al. 2022). Each square represents one float, and every float transferred is tallied. Unlike standard attention, the N×N score (S) and probability (P̃) matrices live and die entirely inside SRAM; they never touch HBM.
Transfer Comparison

| Metric | Standard Attention | Flash Attention |
|---|---|---|
| Peak HBM for intermediates | S (128×128) + P (128×128) = 32,768 floats | No S or P in HBM |
| Total floats transferred | — | — |
| Savings | — | — |
| Extra FLOPs | — | Recomputes S, P̃ per tile (trades compute for less I/O) |
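The "Peak HBM for intermediates" row can be checked by hand. Assuming the figure's hypothetical sequence length N = 128, standard attention materializes both the score matrix S and the probability matrix P in HBM:

```python
# Worked check of the "Peak HBM for intermediates" row.
# N = 128 is the figure's illustrative sequence length, not a fixed constant.
N = 128
standard_intermediates = N * N + N * N  # one float per entry of S plus P
flash_intermediates = 0                 # S and P~ tiles never leave SRAM
print(standard_intermediates)           # prints 32768
```

At fp32 (4 bytes per float) that is 128 KiB of intermediate traffic per attention head that Flash Attention avoids entirely.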
Legend: X (input) · Q / Wq · K / Wk · V / Wv · S tile (SRAM only) · P̃ tile (SRAM only) · O (output) · m (row-max stats) · ℓ (row-sum stats)
S and P̃ tiles exist only inside SRAM (scratch slot) — they are never written to HBM. This eliminates
the O(N²) memory traffic that dominates standard attention.
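The mechanism that makes this possible is the online softmax: each S and P̃ tile is consumed immediately while the running statistics m (row max) and ℓ (row sum) let the partial output be rescaled as new tiles arrive. A minimal NumPy sketch of the forward pass under these assumptions (single head, no masking, `tile` chosen to fit SRAM in the real kernel; the function name is illustrative):

```python
import numpy as np

def flash_attention_forward(Q, K, V, tile=32):
    """Tiled attention with online softmax. S and P tiles are
    per-iteration scratch, mimicking SRAM-only intermediates."""
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row-max
    l = np.zeros(N)           # running row-sum of exponentials
    for j in range(0, K.shape[0], tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        S = Q @ Kj.T / np.sqrt(d)            # score tile (scratch only)
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])       # unnormalized P~ tile (scratch)
        scale = np.exp(m - m_new)            # rescale old partial results
        l = scale * l + P.sum(axis=1)
        O = scale[:, None] * O + P @ Vj
        m = m_new
    return O / l[:, None]                    # final softmax normalization
```

The result is numerically identical to standard attention, but no N×N array is ever allocated outside the loop body.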
Backward pass (not shown): In standard attention, S and P are stored in HBM during the
forward pass so the backward pass can read them. Flash Attention does not store S and P —
instead, it recomputes them from Q, K, V blocks during the backward pass. This trades
O(N²) extra FLOPs for O(N²) less HBM storage and I/O. Since attention is memory-bound (not
compute-bound), this trade is net faster.
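Recomputation in the backward pass only works because the forward pass saved the small per-row statistics m and ℓ (size N each, versus N² for S and P). A hedged NumPy sketch of one piece of the backward pass, the dV gradient, under the same single-head assumptions (function name and tiling are illustrative, not the paper's kernel):

```python
import numpy as np

def recompute_dV_tiled(Q, K, V, dO, m, l, tile=32):
    """Rebuild each S/P~ tile from Q, K and the saved softmax stats
    (m = row max, l = row sum) instead of reading S, P from HBM,
    then accumulate the standard gradient dV = P^T @ dO."""
    N, d = Q.shape
    dV = np.zeros_like(V)
    for j in range(0, N, tile):
        S = Q @ K[j:j + tile].T / np.sqrt(d)      # recomputed score tile
        P = np.exp(S - m[:, None]) / l[:, None]   # recomputed P~ tile
        dV[j:j + tile] = P.T @ dO
    return dV
```

Each tile costs O(N · tile · d) extra multiply-adds to rebuild, but saves reading O(N · tile) floats of S and P from HBM, which is the favorable side of the trade on memory-bound hardware.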