
Writing FlashAttention in Triton (Part 1): The Memory Wall and the Online Softmax Trick

May 26, 2026

Before we dive in, make sure you understand the attention mechanism itself; you can read my blog post about it.

Flash Attention

First, let's clear up some jargon:

  1. SRAM (Static RAM): the small, very fast on-chip memory built directly into each SM (Streaming Multiprocessor) on the GPU die.
  2. VRAM (Video RAM) or HBM (High Bandwidth Memory): the large GPU memory whose capacity and usage nvidia-smi reports. HBM is stacked DRAM dies placed right next to the GPU die, giving very short, wide interconnects. It is far larger than SRAM, but also much slower.

The main speedup of Flash Attention comes from avoiding the intermediate memory transfers between HBM and the SMs while computing attention. Let's look at naive attention and the memory transfers it requires:

$$X = \frac{QK^T}{\sqrt{d}}, \qquad Y = \text{softmax}(X), \qquad O = YV$$

  1. First, $Q$ and $K$ are loaded from HBM and $X$ is computed. $X$ is then written back to HBM.
  2. Next, $X$ is loaded back from HBM, the softmax is computed, and $Y$ is written back to HBM.
  3. Finally, $Y$ and $V$ are loaded from HBM, and $O$ is computed and written back to HBM.
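To make the three steps concrete, here is a minimal pure-Python sketch of naive attention. Plain lists stand in for GPU tensors and the function names are hypothetical; the comments only mark where the HBM round trips would happen on a real GPU.

```python
import math

def matmul(a, b):
    """Plain-Python matrix multiply: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def softmax_rows(x):
    """Numerically stable row-wise softmax (subtract the row max before exp)."""
    out = []
    for row in x:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def naive_attention(Q, K, V):
    d = len(Q[0])
    # Step 1: the full n x n score matrix X is materialized (on a GPU: written to HBM).
    X = [[v / math.sqrt(d) for v in row] for row in matmul(Q, transpose(K))]
    # Step 2: X is read back and the full n x n matrix Y is materialized (written to HBM again).
    Y = softmax_rows(X)
    # Step 3: Y and V are read back and O is computed (and written to HBM).
    return matmul(Y, V)

print(naive_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]]))
# [[2.0, 3.0]]  (all scores are 0, so the output is the mean of the rows of V)
```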

As you can see, there are many unnecessary reads and writes, and they add up quickly with long sequences, many heads, and many layers. The GPU's compute units can do arithmetic far faster than HBM can supply data, so naive attention ends up bottlenecked by memory bandwidth rather than by the matmuls.

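For a sequence of length $n$, the score matrix ($X$ above, which we will call $S$ from here on) alone requires $O(n^2)$ memory. At $n = 8192$ with fp16 (2 bytes per element), that's $8192^2 \times 2 \approx 134$ MB just for attention scores, per head, per layer. Flash Attention reduces the extra memory to $O(n)$ and, more importantly, never writes the scores or probabilities to HBM at all: $Q$, $K$, and $V$ are streamed through SRAM in tiles, and only the final output $O$ is written back.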
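A quick check of that arithmetic in Python. The function names are hypothetical, and the per-row state size assumes one fp32 running max and one fp32 running denominator per query row, which is an illustrative assumption rather than the layout of any particular kernel.

```python
def score_matrix_bytes(n, bytes_per_element=2):
    """Bytes for one n x n score matrix (one head, one layer, one batch element)."""
    return n * n * bytes_per_element

def running_stats_bytes(n, stats_per_row=2, bytes_per_stat=4):
    """Bytes for per-row streaming state, assuming an fp32 max m and denominator d per row."""
    return n * stats_per_row * bytes_per_stat

n = 8192
print(score_matrix_bytes(n))   # 134217728 bytes, roughly 134 MB
print(running_stats_bytes(n))  # 65536 bytes, roughly 66 KB
```

Doubling $n$ quadruples the first number but only doubles the second, which is the $O(n^2)$ versus $O(n)$ difference.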

The way Flash Attention achieves this is by fusing the softmax and the output multiplication into a single pass. To understand how, we first need to look at streaming softmax.

Streaming Softmax

The main idea (often called online softmax) is to compute the softmax tile by tile.

Two variables are initialized: $m_{old} = -\infty$ and $d_{old} = 0$, where $m$ is the running maximum and $d$ the running denominator.

Then, for each tile:

  1. Find the local max: the maximum value within just this tile, $m_{local}$.
  2. Update the global max:
$$m_{new} = \max(m_{old}, m_{local})$$
  3. Compute the local denominator: the sum of exponentials for just this tile, using the new global max to keep the numbers stable:
$$d_{local} = \sum_{x \in \text{tile}} e^{x - m_{new}}$$
  4. Update the global denominator: scale the old global denominator by the correction factor, then add the local denominator:
$$d_{new} = d_{old} \cdot e^{m_{old} - m_{new}} + d_{local}$$

After each tile, $m_{new}$ and $d_{new}$ become $m_{old}$ and $d_{old}$ for the next one.
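The four steps above translate almost line for line into a short pure-Python sketch (hypothetical function names; a Python list slice stands in for a tile):

```python
import math

def update_stats(m_old, d_old, tile):
    """Fold one tile into the running max m and running denominator d."""
    m_local = max(tile)                                  # 1. local max of this tile
    m_new = max(m_old, m_local)                          # 2. updated global max
    d_local = sum(math.exp(x - m_new) for x in tile)     # 3. tile's sum, shifted by the new max
    d_new = d_old * math.exp(m_old - m_new) + d_local    # 4. rescale the old sum, then add
    return m_new, d_new

def streaming_stats(row, tile_size):
    """Run the update over consecutive tiles, starting from m = -inf, d = 0."""
    m, d = -math.inf, 0.0
    for start in range(0, len(row), tile_size):
        m, d = update_stats(m, d, row[start:start + tile_size])
    return m, d

m, d = streaming_stats([1.0, 2.0, 3.0, 4.0], tile_size=2)
print(m, round(d, 3))  # 4.0 1.553
```

On the first tile, `math.exp(-inf - m_new)` evaluates to `0.0`, so the initial $d_{old} = 0$ contributes nothing, exactly as in the formulas.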

The trick here is the correction factor $e^{m_{old} - m_{new}}$. Whenever we hit a new maximum, this factor (which is at most 1) scales down the previously accumulated denominator. It adjusts the old sum so that it is as if we had known the new global maximum from the very beginning.
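To see why, expand $d_{old}$ by its definition, writing "seen" for every element from the earlier tiles:

$$
\begin{aligned}
d_{old} \cdot e^{m_{old} - m_{new}}
&= \Big(\sum_{x \in \text{seen}} e^{x - m_{old}}\Big) \cdot e^{m_{old} - m_{new}} \\
&= \sum_{x \in \text{seen}} e^{x - m_{new}}
\end{aligned}
$$

So $d_{new} = \sum_{x \in \text{seen}} e^{x - m_{new}} + \sum_{x \in \text{tile}} e^{x - m_{new}}$, which is exactly the denominator we would have computed had $m_{new}$ been the maximum all along. Since $m_{new} \ge$ every element seen so far, every exponent stays $\le 0$ and nothing overflows.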

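Let's take one row, $[1, 2, 3, 4]$, with a tile size of 2.

First, the standard way:

Pass 1: Find the max of the full row

$$m = \max(1, 2, 3, 4) = 4$$

Pass 2: Compute Exponentials & Sum

$$d = e^{1-4} + e^{2-4} + e^{3-4} + e^{4-4} \approx 0.050 + 0.135 + 0.368 + 1 = 1.553$$

Pass 3: Final Division

$$\text{softmax}(x)_i = \frac{e^{x_i - 4}}{1.553} \approx [0.032, 0.087, 0.237, 0.644]$$

Now let's see the streaming version.

Pass 1: Streaming the Tiles

Processing Tile 1: $[1, 2]$, starting from $m_{old} = -\infty$ and $d_{old} = 0$.

$$m_{new} = \max(-\infty, 2) = 2, \qquad d_{new} = 0 \cdot e^{-\infty - 2} + (e^{1-2} + e^{2-2}) \approx 0.368 + 1 = 1.368$$

Processing Tile 2: $[3, 4]$. The max moves from 2 to 4, so $m_{new} = 4$, and the local denominator is $e^{3-4} + e^{4-4} \approx 1.368$. We apply the correction factor $e^{2-4}$ to the old denominator:

$$
\begin{aligned}
d_{new} &= 1.368 \cdot e^{2 - 4} + 1.368 \\
&= 1.368 \cdot 0.135 + 1.368 \\
&= 0.185 + 1.368 = \mathbf{1.553}
\end{aligned}
$$

Final state: $m = 4$, $d = 1.553$. This is exactly the global denominator from the standard three-pass computation. Mathematically, chunking the data doesn't change the final answer; it only changes the order of the additions (so results can differ only by floating-point rounding).

Pass 2: Computing the Probabilities

Now that we have the true global max ($4$) and global denominator ($1.553$), we make a second pass over the tiles to compute and write the final probabilities $e^{x_i - 4} / 1.553$, which match the standard version.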
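Both passes of the streaming version fit in a short pure-Python sketch (hypothetical function names; list slices stand in for tiles), checked against an ordinary softmax. Note that pass 2 has to revisit every tile.

```python
import math

def streaming_softmax(row, tile_size):
    """Two-pass streaming softmax over tiles."""
    # Pass 1: stream over tiles, keeping only the running max m and denominator d.
    m, d = -math.inf, 0.0
    for start in range(0, len(row), tile_size):
        tile = row[start:start + tile_size]
        m_new = max(m, max(tile))
        d = d * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new
    # Pass 2: revisit each tile and write exp(x - m) / d for every element.
    probs = []
    for start in range(0, len(row), tile_size):
        probs.extend(math.exp(x - m) / d for x in row[start:start + tile_size])
    return probs

def reference_softmax(row):
    """Standard softmax: max, exponentials and sum, division."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

p = streaming_softmax([1.0, 2.0, 3.0, 4.0], tile_size=2)
print([round(x, 3) for x in p])  # [0.032, 0.087, 0.237, 0.644]
```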

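So far we have discussed streaming softmax; now let's see how Flash Attention uses it. Along with the two running variables, Flash Attention maintains a third one: the output accumulator $O$ (from $O = YV$, the last step of attention above). Because we only need $O$ and never the probabilities themselves, the normalization can be deferred to the very end, which removes the second pass.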

The max and denominator updates stay exactly the same. There is one additional step:

The values in the $O_{old}$ accumulator were weighted by exponentials computed with the old maximum. We can fix the entire block of $O$ with the exact same correction factor (one scalar per query row): we scale the old $O$ down, and then add the new block's contribution:

$$O_{new} = O_{old} \cdot e^{m_{old} - m_{new}} + \sum_{j \in \text{tile}} e^{S_j - m_{new}} \cdot V_j$$

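where $S_j = q \cdot k_j / \sqrt{d}$ is the score of the current query row $q$ against key $k_j$ (so a tile's scores are $q K_{tile}^T / \sqrt{d}$, matching $X$ above), and $V_j$ is the $j$-th row of $V$.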
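A minimal pure-Python sketch of this update step for a single query row. `tile_scores` and `flash_step` are hypothetical helper names, and `o_old` is the unnormalized accumulator, a list with one entry per value dimension. The usage line reproduces Tile 1 of the worked example that follows.

```python
import math

def tile_scores(q, k_tile):
    """Scores S_j = (q . k_j) / sqrt(d) for one query row against one tile of keys."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]

def flash_step(m_old, d_old, o_old, scores, v_tile):
    """Fold one tile into the running (m, d, O) for a single query row."""
    m_new = max(m_old, max(scores))
    correction = math.exp(m_old - m_new)              # e^(m_old - m_new), at most 1
    weights = [math.exp(s - m_new) for s in scores]   # unnormalized weights for this tile
    d_new = d_old * correction + sum(weights)
    o_new = [o * correction for o in o_old]           # rescale the old accumulator
    for w, v in zip(weights, v_tile):                 # add sum_j e^(S_j - m_new) * V_j
        o_new = [o + w * vj for o, vj in zip(o_new, v)]
    return m_new, d_new, o_new

# Tile 1: scores [2, 3] and one-dimensional values [10], [20].
m, d, o = flash_step(-math.inf, 0.0, [0.0], [2.0, 3.0], [[10.0], [20.0]])
print(m, round(d, 3), [round(x, 2) for x in o])  # 3.0 1.368 [23.68]
```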

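Let's walk through a similar example. This time we also need the values $V$, and we start directly from one row of scores $S$. To keep the arithmetic scalar, each value row $V_j$ is a single number here: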

$S = [2, 3, 5, 4]$ and $V = [10, 20, 30, 40]$

Tile 1

Load the first $K$ and $V$ tiles from HBM into SRAM and compute this tile's scores on-chip: $S = [2, 3]$ and $V = [10, 20]$.

  1. Find the new max:
$$m_{local} = 3, \qquad m_{new} = \max(-\infty, 3) = \mathbf{3}$$
  2. Update the denominator ($d$):
$$
\begin{aligned}
d_{new} &= d_{old} \cdot e^{-\infty - 3} + (e^{2-3} + e^{3-3}) \\
&= 0 + (e^{-1} + 1) \\
&\approx 0.368 + 1 = \mathbf{1.368}
\end{aligned}
$$
  3. Update the output accumulator ($O$): instead of dividing by $d$, we multiply the unnormalized exponentials directly against the values.
$$
\begin{aligned}
O_{new} &= O_{old} \cdot e^{-\infty - 3} + (e^{2-3} \cdot 10 + e^{3-3} \cdot 20) \\
&= 0 + (0.368 \cdot 10 + 1 \cdot 20) \\
&= 3.68 + 20 = \mathbf{23.68}
\end{aligned}
$$

Tile 1 is done, and its data can be evicted from SRAM. Current state: $m = 3$, $d = 1.368$, $O = 23.68$.

Tile 2

Load the next tile: $S = [5, 4]$ and $V = [30, 40]$.

  1. Find the new max:
$$m_{local} = 5, \qquad m_{new} = \max(3, 5) = \mathbf{5}$$

Because the global max just changed from $3$ to $5$, we need the correction factor: $e^{3 - 5} = e^{-2} \approx \mathbf{0.135}$.

  2. Update the denominator ($d$):

We apply the correction factor to fix the old denominator, then add the new tile's sum.

$$
\begin{aligned}
d_{new} &= (d_{old} \cdot 0.135) + (e^{5-5} + e^{4-5}) \\
&= (1.368 \cdot 0.135) + (1 + 0.368) \\
&= 0.185 + 1.368 = \mathbf{1.553}
\end{aligned}
$$

  3. Update the output accumulator ($O$):

We apply the exact same correction factor to fix the running $O$, then add the new tile's unnormalized weighted sum of values.

$$
\begin{aligned}
O_{new} &= (O_{old} \cdot 0.135) + (e^{5-5} \cdot 30 + e^{4-5} \cdot 40) \\
&= (23.68 \cdot 0.135) + (1 \cdot 30 + 0.368 \cdot 40) \\
&= 3.20 + (30 + 14.72) \\
&= 3.20 + 44.72 = \mathbf{47.92}
\end{aligned}
$$

Tile 2 is done, and we have reached the end of the sequence. Final state: $m = 5$, $d = 1.553$, $O = 47.92$.

The Final Division

We streamed through the entire sequence without ever writing a single intermediate score, exponential, or probability to HBM. We only maintained three running quantities ($m$, $d$, and $O$) in fast on-chip memory. To get the final, true attention output, we do one division right before writing to HBM:

$$\text{Final Output} = \frac{O}{d} = \frac{47.92}{1.553} \approx \mathbf{30.86}$$
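The whole single-row loop, ending with that one division, as a pure-Python sketch (hypothetical names; scalar values as in the example). It reproduces the hand computation and agrees with computing $\text{softmax}(S) \cdot V$ directly, whatever the tile size.

```python
import math

def flash_attention_row(scores, values, tile_size):
    """Single pass over tiles for one query row with scalar values."""
    m, d, o = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        c = math.exp(m - m_new)                        # correction factor
        w = [math.exp(s - m_new) for s in s_tile]      # unnormalized weights
        d = d * c + sum(w)
        o = o * c + sum(wj * vj for wj, vj in zip(w, v_tile))
        m = m_new
    return o / d  # the single final division

def direct_attention_row(scores, values):
    """Reference: softmax over the whole row, then a weighted sum of the values."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wj * vj for wj, vj in zip(w, values)) / sum(w)

out = flash_attention_row([2.0, 3.0, 5.0, 4.0], [10.0, 20.0, 30.0, 40.0], tile_size=2)
print(round(out, 2))  # 30.86
```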

To recap: the full $n \times n$ score matrix (about 134 MB per head, per layer at $n = 8192$ in fp16) is never materialized in HBM, so the extra memory drops from $O(n^2)$ to $O(n)$. For each query row, we streamed through all the $K$/$V$ tiles in a single pass, and the only HBM traffic is reading $Q$, $K$, $V$ and writing $O$.
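Putting it together for whole matrices: a pure-Python sketch of the tiled forward pass, with an outer loop over query tiles and an inner loop over key/value tiles. The names and block sizes are illustrative assumptions, and this is not a Triton kernel. Each query row keeps its own $m$, $d$, and $O$, so the correction factor is one scalar per row. Every query tile walks over all of $K$ and $V$; on a GPU, each query tile would typically be handled by its own thread block (a Triton program), and the $n \times n$ score matrix is never stored.

```python
import math
import random

def flash_attention(Q, K, V, block_q=2, block_kv=2):
    """Tiled attention forward pass keeping only per-row running state (m, d, O)."""
    n, dk, dv = len(Q), len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(dk)
    out = []
    for qs in range(0, n, block_q):                      # one query tile
        q_tile = Q[qs:qs + block_q]
        m = [-math.inf] * len(q_tile)
        d = [0.0] * len(q_tile)
        o = [[0.0] * dv for _ in q_tile]
        for ks in range(0, len(K), block_kv):            # stream K/V tiles through "SRAM"
            k_tile, v_tile = K[ks:ks + block_kv], V[ks:ks + block_kv]
            for r, q in enumerate(q_tile):               # per-row correction factor
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
                m_new = max(m[r], max(s))
                c = math.exp(m[r] - m_new)
                w = [math.exp(x - m_new) for x in s]
                d[r] = d[r] * c + sum(w)
                o[r] = [c * o[r][j] + sum(wi * v[j] for wi, v in zip(w, v_tile))
                        for j in range(dv)]
                m[r] = m_new
        out.extend([x / d[r] for x in o[r]] for r in range(len(q_tile)))  # final division
    return out

def naive_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V, computed row by row over all keys."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        total = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / total for j in range(len(V[0]))])
    return out

random.seed(0)
n, dk, dv = 5, 4, 3
Q = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(n)]
K = [[random.gauss(0, 1) for _ in range(dk)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(dv)] for _ in range(n)]
err = max(abs(a - b) for ra, rb in zip(flash_attention(Q, K, V), naive_attention(Q, K, V))
          for a, b in zip(ra, rb))
print(err)  # tiny: the two differ only by floating-point rounding
```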