Before we dive in, it is important that you understand the attention mechanism itself; you can read my blog post about it.
Flash Attention
First, let's clear up some jargon:
- SRAM (Static RAM) - the fastest memory in the hierarchy. It is built directly onto the GPU die, inside each SM (streaming multiprocessor).
- VRAM (Video RAM) or HBM (High Bandwidth Memory) - this is the memory that nvidia-smi reports. It's stacked DRAM dies placed right next to the GPU die, which gives very short, wide interconnects.
The main speedup from Flash Attention comes from avoiding the intermediate memory transfers between HBM and the SMs while computing attention. So let's look at naive attention and the memory transfers it requires:
- First, $Q$ and $K$ are loaded from HBM and $S = QK^T$ is computed. After the computation, $S$ is written back to HBM.
- Next, $S$ is loaded back from HBM and $P = \text{softmax}(S)$ is computed, after which $P$ is written back to HBM.
- Finally, $P$ is loaded back from HBM along with $V$, and $O = PV$ is computed and written back to HBM.
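To make the three round trips concrete, here is a minimal pure-Python sketch of naive single-head attention. It is an illustration under simplifying assumptions: plain lists stand in for HBM, the usual $1/\sqrt{d}$ scaling is omitted to match the formulas above, and the name `naive_attention` is hypothetical. The key point is that the full score matrix `S` and probability matrix `P` are built completely before the next step starts.

```python
import math


def naive_attention(Q, K, V):
    """Naive single-head attention on plain Python lists (no 1/sqrt(d) scaling).

    Sketch only: Python lists stand in for HBM. Q is n_q x d, K is n_k x d,
    V is n_k x d_v. Each intermediate matrix is fully built before the next step.
    """
    # Step 1: S = Q K^T, a full n_q x n_k matrix (on a GPU: written back to HBM)
    S = [[sum(q_c * k_c for q_c, k_c in zip(q, k)) for k in K] for q in Q]

    # Step 2: P = softmax(S), row by row; another full n_q x n_k matrix
    P = []
    for row in S:
        row_max = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])

    # Step 3: O = P V, reading P back along with V
    d_v = len(V[0])
    return [[sum(p[j] * V[j][c] for j in range(len(V))) for c in range(d_v)] for p in P]


# One query (d = 1) against four keys gives the score row [2, 3, 5, 4]
out = naive_attention([[1.0]], [[2.0], [3.0], [5.0], [4.0]], [[10.0], [20.0], [30.0], [40.0]])
print(round(out[0][0], 2))  # 30.86
```

The score row $[2, 3, 5, 4]$ and values $[10, 20, 30, 40]$ are the same numbers used in the Flash Attention walkthrough later, so this gives a reference answer to compare against.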
As you can see, there are many unnecessary reads and writes, and they slow the whole process down. For huge matrices, many heads, and many transformer blocks, these reads and writes add up and hurt the overall speed. The GPU's compute units can do arithmetic far faster than HBM can supply data, so naive attention is bottlenecked by memory bandwidth, not by the matmuls.
For a sequence of length $n$, the $S$ matrix alone requires $O(n^2)$ memory. At $n = 8192$ with fp16 (2 bytes per element), that's $8192^2 \times 2 \approx 134$ MB just for attention scores, per head, per layer. Flash Attention reduces the extra memory to $O(n)$ and, more importantly, never writes the $n \times n$ matrices $S$ and $P$ to HBM at all.
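The 134 MB figure is simply $n^2$ elements times 2 bytes. A tiny sketch (the helper name `score_matrix_bytes` is hypothetical) makes the arithmetic explicit:

```python
def score_matrix_bytes(n, bytes_per_element=2):
    """Bytes for one n x n attention-score matrix; fp16 uses 2 bytes per element."""
    return n * n * bytes_per_element


size = score_matrix_bytes(8192)
print(size, "bytes, about", round(size / 1e6), "MB per head, per layer")
# 134217728 bytes, about 134 MB per head, per layer
```

This uses decimal megabytes; in binary units it is exactly $2^{27}$ bytes, or 128 MiB.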
The way Flash Attention achieves this is by fusing the softmax and the output multiplication into a single pass. To understand how, we first need to look at streaming softmax.
Streaming Softmax
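The main idea is to compute the softmax in tiles. We keep two running variables, initialized as $m = -\infty$ and $\ell = 0$, where $m$ is the running maximum and $\ell$ is the running denominator.
More specifically, for each tile $x^{(j)}$:
- Find the Local Max: Find the maximum value within just this tile: $m_{\text{local}} = \max_i x^{(j)}_i$
- Update the Global Max: Figure out the new overall maximum: $m_{\text{new}} = \max(m_{\text{old}}, m_{\text{local}})$
- Compute the Local Denominator: Calculate the sum of exponentials for just this tile, using the new global max to keep the numbers stable: $\ell_{\text{local}} = \sum_i e^{x^{(j)}_i - m_{\text{new}}}$
- Update the Global Denominator: Scale the old global denominator by the correction factor, then add the local denominator: $\ell_{\text{new}} = \ell_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \ell_{\text{local}}$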
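The trick here is the correction factor $e^{m_{\text{old}} - m_{\text{new}}}$ (the exponential of the old max minus the new max). Whenever we hit a new maximum, this factor scales down the previously accumulated denominator; when the maximum doesn't change, it is simply $e^0 = 1$. It adjusts the old sum so that it acts as if we had known the new global maximum from the very beginning.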
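The four update rules translate almost line for line into code. Below is a minimal Python sketch of one update step; `update_stats` is a hypothetical helper name, and plain floats stand in for the on-chip state. Starting from $m = -\infty$ works because `math.exp(-math.inf)` is `0.0`, so the first correction factor simply zeroes out the empty old sum.

```python
import math


def update_stats(m_old, l_old, tile):
    """Fold one tile into the running max m and running denominator l.

    Sketch only: plain floats stand in for the state a kernel keeps on-chip.
    """
    m_local = max(tile)                               # 1. local max of this tile
    m_new = max(m_old, m_local)                       # 2. updated global max
    l_local = sum(math.exp(x - m_new) for x in tile)  # 3. local denominator w.r.t. m_new
    correction = math.exp(m_old - m_new)              # e^(m_old - m_new); 0.0 when m_old = -inf
    l_new = l_old * correction + l_local              # 4. rescale the old sum, then add
    return m_new, l_new


m, l = -math.inf, 0.0  # initial state
for tile in ([1.0, 2.0], [3.0, 4.0]):
    m, l = update_stats(m, l, tile)
print(m, round(l, 3))  # 4.0 1.553
```

The usage lines run the row $[1, 2, 3, 4]$ with tiles of size 2, the same example worked by hand next.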
Let's take one row of a matrix, $x = [1, 2, 3, 4]$, with a tile size of 2.
First, let's see how it's done normally:
Pass 1: Find the max of the full row
- Load all the values from VRAM and calculate the max: $m = 4$.
Pass 2: Compute Exponentials & Sum
- Load the values again from VRAM, then compute the exponentials $e^{x_i - m}$ and their sum: $\ell = e^{-3} + e^{-2} + e^{-1} + e^{0} \approx 1.553$.
- After computing the exponentials, we need to store them back to VRAM.
Pass 3: Final Division
- Now we need to load those exponentials from VRAM and divide by the row sum.
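For comparison, here is the normal version as a small Python sketch (the function name `three_pass_softmax` is illustrative). The list `exps` plays the role of the exponentials that, on a GPU, would be written back to VRAM between passes 2 and 3.

```python
import math


def three_pass_softmax(row):
    """Standard (numerically safe) softmax, written as three separate passes.

    Sketch only: Python lists stand in for VRAM.
    """
    # Pass 1: read every value to find the row max
    m = max(row)
    # Pass 2: read again, exponentiate, keep the exponentials, and sum them
    exps = [math.exp(x - m) for x in row]  # on a GPU this would go back to VRAM
    total = sum(exps)
    # Pass 3: read the stored exponentials and divide by the row sum
    return [e / total for e in exps]


print([round(p, 2) for p in three_pass_softmax([1.0, 2.0, 3.0, 4.0])])
# [0.03, 0.09, 0.24, 0.64]
```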
Now let's see how it is done in the streaming version:
Pass 1: Streaming the Tiles
Processing Tile 1: [1, 2].
- Load [1, 2] from VRAM into registers/SRAM.
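- Local Max: $m_{\text{local}} = \max(1, 2) = 2$
- New Global Max: $m_{\text{new}} = \max(-\infty, 2) = 2$
- Local Denominator: $\ell_{\text{local}} = e^{1-2} + e^{2-2} = e^{-1} + 1 \approx 1.368$
- New Global Denominator: $\ell_{\text{new}} = 0 \cdot e^{-\infty - 2} + 1.368 = 1.368$
- Current State: $m = 2$, $\ell \approx 1.368$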
Processing Tile 2: [3, 4].
- Load [3, 4] into registers.
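- Local Max: $m_{\text{local}} = \max(3, 4) = 4$
- New Global Max: $m_{\text{new}} = \max(2, 4) = 4$
- Local Denominator: $\ell_{\text{local}} = e^{3-4} + e^{4-4} = e^{-1} + 1 \approx 1.368$
- New Global Denominator: Here is where the magic happens. We scale the old denominator ($\ell_{\text{old}} \approx 1.368$) by the exponential of the difference between the old max ($m_{\text{old}} = 2$) and the new max ($m_{\text{new}} = 4$): $\ell_{\text{new}} = 1.368 \cdot e^{2-4} + 1.368 \approx 0.185 + 1.368 = 1.553$
Final Output State: $m = 4$, $\ell \approx 1.553$. This $\ell$ is exactly the same global denominator we got in the normal (non-streaming) version. The math guarantees that chunking the data doesn't change the final answer.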
Pass 2: Computing the Probabilities
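Now that we have our true global max ($m = 4$) and global denominator ($\ell \approx 1.553$), we do our second pass over the tiles to compute and write the final probabilities.
- Load Tile 1 [1, 2]: Compute $e^{1-4} / 1.553 \approx 0.03$ and $e^{2-4} / 1.553 \approx 0.09$. Write [0.03, 0.09] to VRAM.
- Load Tile 2 [3, 4]: Compute $e^{3-4} / 1.553 \approx 0.24$ and $e^{4-4} / 1.553 \approx 0.64$. Write [0.24, 0.64] to VRAM.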
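Both streaming passes fit in one short function. This is a sketch in plain Python, assuming the row is split into contiguous tiles of `tile_size` elements; `streaming_softmax` is a hypothetical name. Between tiles only `m` and `l` are carried over, which is the state a kernel would keep in registers.

```python
import math


def streaming_softmax(row, tile_size=2):
    """Two-pass tiled softmax: pass 1 builds (m, l), pass 2 writes the probabilities.

    Sketch only: list slices stand in for tiles loaded from VRAM.
    """
    tiles = [row[i:i + tile_size] for i in range(0, len(row), tile_size)]

    # Pass 1: only the running max m and running denominator l survive between tiles
    m, l = -math.inf, 0.0
    for tile in tiles:
        m_new = max(m, max(tile))
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in tile)
        m = m_new

    # Pass 2: reload each tile and normalize with the final global m and l
    probs = []
    for tile in tiles:
        probs.extend(math.exp(x - m) / l for x in tile)
    return probs


print([round(p, 2) for p in streaming_softmax([1.0, 2.0, 3.0, 4.0])])
# [0.03, 0.09, 0.24, 0.64]
```

The last tile may be shorter than `tile_size`; the update rules don't care, so any row length works.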
So far we have discussed streaming softmax; now let's see how Flash Attention incorporates it. Along with the two running variables, Flash Attention maintains a third one: the output accumulator $O$ (the output of the last step of the attention process, $O = PV$; see above), initialized to zero.
The steps above stay the same, but there is now an additional step:
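- Update the Output Accumulator: The values already in $O$ were multiplied by exponentials computed with the old maximum. We can fix the entire block of $O$ with the exact same scalar correction factor: we scale the old accumulator down, and then add the new block's contribution:
$$O_{\text{new}} = O_{\text{old}} \cdot e^{m_{\text{old}} - m_{\text{new}}} + \tilde{p}^{(j)} V^{(j)}$$
where $\tilde{p}^{(j)}_i = e^{s^{(j)}_i - m_{\text{new}}}$ are the unnormalized exponentials of tile $j$'s scores $s^{(j)}$, and $V^{(j)}$ holds the matching rows of $V$.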
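The extra step only changes the state update. Here is a minimal Python sketch of one tile's update for a single query row; the name `flash_update` and the list-based layout are assumptions for illustration, not a library API. Value rows are lists of length $d_v$ so that the same scalar correction visibly rescales every entry of $O$.

```python
import math


def flash_update(m_old, l_old, O_old, scores, values):
    """Fold one tile into the running (m, l, O) state of a single query row.

    Sketch only (hypothetical helper): plain lists stand in for on-chip memory.
    scores: this tile's attention scores, one per key
    values: this tile's value rows, each a list of length d_v
    O_old:  the unnormalized output accumulator, a list of length d_v
    """
    m_new = max(m_old, max(scores))
    correction = math.exp(m_old - m_new)       # one scalar, shared by l and all of O
    p = [math.exp(s - m_new) for s in scores]  # unnormalized exponentials (no division by l)
    l_new = l_old * correction + sum(p)
    O_new = [
        o * correction + sum(p_j * v[c] for p_j, v in zip(p, values))
        for c, o in enumerate(O_old)
    ]
    return m_new, l_new, O_new


# Tile 1 of the example that follows: scores [2, 3] with one-dimensional values [10, 20]
m, l, O = flash_update(-math.inf, 0.0, [0.0], [2.0, 3.0], [[10.0], [20.0]])
print(m, round(l, 3), [round(o, 2) for o in O])  # 3.0 1.368 [23.68]
```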
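Let's see how it works with a similar example, again with a tile size of 2, but this time we also need the value matrix $V$. We use one row of scores $s$ and, to keep things simple, one-dimensional values:
$s$: [2, 3, 5, 4]
$V$: [10, 20, 30, 40]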
Tile 1
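Load the first block from VRAM into fast memory: $s^{(1)} = [2, 3]$ and $V^{(1)} = [10, 20]$.
- Find New Max: $m = \max(-\infty, 2, 3) = 3$
- Update Denominator ($\ell$): $\ell = 0 + e^{2-3} + e^{3-3} = e^{-1} + 1 \approx 1.368$
- Update Output Accumulator ($O$): Instead of dividing by $\ell$, we just multiply the exponentials directly against the values: $O = 0 + e^{-1} \cdot 10 + e^{0} \cdot 20 \approx 3.68 + 20 = 23.68$
Tile 1 is done, and its data is evicted from SRAM. Current hardware state: $m = 3$, $\ell \approx 1.368$, $O \approx 23.68$.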
Tile 2
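Load the next block: $s^{(2)} = [5, 4]$ and $V^{(2)} = [30, 40]$.
- Find New Max: $m_{\text{new}} = \max(3, 5, 4) = 5$
Because the global max just changed from 3 to 5, we need to compute our correction factor: $e^{m_{\text{old}} - m_{\text{new}}} = e^{3-5} = e^{-2} \approx 0.135$.
- Update Denominator ($\ell$):
We apply the correction factor to fix the old denominator, then add the new tile's sum: $\ell = 1.368 \cdot 0.135 + (e^{5-5} + e^{4-5}) \approx 0.185 + 1.368 = 1.553$
- Update Output Accumulator ($O$):
We apply the exact same correction factor to fix our running accumulator, then add the new tile's unnormalized contribution: $O = 23.68 \cdot 0.135 + (e^{0} \cdot 30 + e^{-1} \cdot 40) \approx 3.20 + 44.72 = 47.92$
Tile 2 is done, and we have reached the end of the sequence. Final hardware state: $m = 5$, $\ell \approx 1.553$, $O \approx 47.92$.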
The Final Division
We streamed through the entire sequence without ever writing a single intermediate probability or exponential to main memory; we only maintained three variables in our fast registers. To get the final, true attention output, we do one division right before writing to VRAM:
$$\text{output} = \frac{O}{\ell} \approx \frac{47.92}{1.553} \approx 30.86$$
This is where the savings come from: because the $n \times n$ matrices $S$ and $P$ never exist in HBM, the extra memory drops from $O(n^2)$ to $O(n)$ (just $m$ and $\ell$ per query row), and we streamed through all the tiles of $K$ and $V$ in a single pass.
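Putting it all together, the sketch below runs the whole single-row computation: stream over the tiles, keep $m$, $\ell$ and $O$, and divide once at the end. As before, this is plain Python under simplifying assumptions (one query row, no $1/\sqrt{d}$ scaling, hypothetical function name `flash_attention_row`), not the actual GPU kernel.

```python
import math


def flash_attention_row(scores, values, tile_size=2):
    """Tiled attention for one query row: one pass over the tiles, one final division.

    Sketch only (hypothetical helper): plain lists stand in for HBM and on-chip state.
    """
    d_v = len(values[0])
    m, l, O = -math.inf, 0.0, [0.0] * d_v  # the three running variables
    for start in range(0, len(scores), tile_size):
        s_tile = scores[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        m_new = max(m, max(s_tile))
        correction = math.exp(m - m_new)   # rescales both l and O
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * correction + sum(p)
        O = [o * correction + sum(p_j * v[c] for p_j, v in zip(p, v_tile))
             for c, o in enumerate(O)]
        m = m_new
    return [o / l for o in O]  # the single division right before writing the output


out = flash_attention_row([2.0, 3.0, 5.0, 4.0], [[10.0], [20.0], [30.0], [40.0]])
print(round(out[0], 2))  # 30.86
```

In the real kernel, a whole block of query rows is processed together and the loop runs inside a single fused GPU kernel with the state held on-chip; this sketch only shows the per-row arithmetic. The result does not depend on the tile size, which is a quick way to sanity-check the update rules.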