In this paper, Yao Fu discusses the challenges of deploying LLMs with long context windows. The author highlights that these challenges arise from a KV cache that balloons as context size grows, which makes the case for research into KV cache compression techniques. The author lists four challenges arising from a large KV cache:
- Prefilling (feeding an entire prompt through a decoder-only transformer) takes longer and consumes more memory as context size grows.
- A large KV cache occupying GPU HBM (high bandwidth memory) limits the number of concurrent requests that can be served on that GPU.
- During decoding, for each token we load the KV cache from HBM to the SMs (streaming multiprocessors). A larger cache takes longer to load, and so increases latency.
- When the KV cache size exceeds HBM capacity, we offload it to CPU DDR memory, increasing latency (this is called "context switching latency" in the paper).
Sizing KV Cache
The author shows that for the Yi-34B model, the KV cache size grows from 0.91 GB at 4K context length to 22.8 GB at 100K context length. Let me compute these numbers for Meta's Llama-3 70B model, which uses Grouped Query Attention (64 attention heads but only 8 KV heads).
\[ \text{100K Context:} \qquad \underbrace{100000}_{\text{seqlen}} \times \underbrace{80}_{\text{hidden layers}} \times \underbrace{8}_{\text{# KV heads}} \times \underbrace{2}_{\text{KV}} \times \underbrace{2}_{\text{bf16}} \times \underbrace{8192}_{\text{hidden size}} / \underbrace{64}_{\text{# Q heads}} = 32.8 \text{ GB} \\ \text{4K Context:} \qquad \underbrace{4000}_{\text{seqlen}} \times \underbrace{80}_{\text{hidden layers}} \times \underbrace{8}_{\text{# KV heads}} \times \underbrace{2}_{\text{KV}} \times \underbrace{2}_{\text{bf16}} \times \underbrace{8192}_{\text{hidden size}} / \underbrace{64}_{\text{# Q heads}} = 1.3 \text{ GB} \\ \]If we use 4 x 80GB A100 with tensor parallelism to serve this Llama model in bf16, we'll be left with \(4 \times 80 - 2 \times 70 = 180 \text{ GB}\) space for storing the KV cache. This permits fitting in ~138 concurrent requests with 4K context, but only ~5 (=180/32.8) for 100K.
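The sizing arithmetic above can be reproduced with a short script. This is a minimal sketch: the function names are hypothetical, the defaults are the Llama-3 70B figures quoted in the text, and a GB is taken as 10^9 bytes. Because it uses unrounded cache sizes, the request counts can differ by one from figures computed with rounded values.

```python
def kv_cache_bytes(seq_len, n_layers=80, n_kv_heads=8,
                   head_dim=8192 // 64, bytes_per_elem=2):
    """KV cache size in bytes for a single sequence.

    The factor of 2 accounts for storing both keys and values.
    Defaults follow the Llama-3 70B numbers in the text (bf16 = 2 bytes).
    """
    return seq_len * n_layers * n_kv_heads * 2 * bytes_per_elem * head_dim


def max_concurrent_requests(seq_len, free_bytes):
    """Number of whole per-request KV caches that fit in the free HBM."""
    return free_bytes // kv_cache_bytes(seq_len)


GB = 10**9  # decimal gigabyte, matching the text's arithmetic

# Free HBM on 4 x 80 GB A100s after loading 70B parameters in bf16.
free_hbm = 4 * 80 * GB - 2 * 70 * GB

for seq_len in (4_000, 100_000):
    size_gb = kv_cache_bytes(seq_len) / GB
    n_requests = max_concurrent_requests(seq_len, free_hbm)
    print(f"{seq_len:>7} tokens: {size_gb:6.2f} GB per request, "
          f"{n_requests} concurrent requests")
```

Changing `n_kv_heads` to 64 shows what the same model would need without Grouped Query Attention: an 8x larger cache per request.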
Compute Bound vs. Memory Bound Operations
The author computes the "critical arithmetic intensity" of an A100 GPU as the ratio of (1) its bf16 FLOPs/second (312 T) to (2) its HBM bandwidth (2 TB/second), which equals 156. This means that in the time it takes to load one byte of data from HBM, the GPU can do 156 floating point operations. Equivalently, if we have an operation that can be performed independently on each byte, we could load 156 bytes from HBM and perform 1 FLOP on each byte, in parallel. If an operator can perform a FLOP in parallel on more than 156 bytes, we call that operator "compute-bound". If that number is less than 156, we call it "memory-bound". The author notes that the degree of parallelization for a transformer is the number of tokens it can process in parallel.
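The ratio and the resulting classification can be written down directly. This is a small sketch with hypothetical names. The hardware constants are the A100 figures quoted in the text, and treating exactly 156 as memory-bound is an arbitrary choice of mine, since the text leaves that boundary case undefined.

```python
A100_BF16_FLOPS = 312e12     # peak bf16 FLOPs/second, as quoted in the text
A100_HBM_BANDWIDTH = 2e12    # HBM bytes/second, as quoted in the text


def critical_arithmetic_intensity(peak_flops, bandwidth):
    """FLOPs the GPU can perform in the time it takes to load one byte."""
    return peak_flops / bandwidth


def classify(tokens_in_parallel, threshold):
    """Label a workload by its degree of parallelism (tokens processed at once)."""
    if tokens_in_parallel > threshold:
        return "compute-bound"
    return "memory-bound"


threshold = critical_arithmetic_intensity(A100_BF16_FLOPS, A100_HBM_BANDWIDTH)
print(threshold)                    # critical arithmetic intensity of the A100
print(classify(50_000, threshold))  # prefilling a 50K-token prompt
print(classify(1, threshold))       # decoding one token at a time
```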
Impact of KV Cache Size on Latency and Concurrency
The author shows with simple examples how a larger KV cache impacts latency (for each of prefilling, decoding, and context switching between GPU and CPU) as well as concurrency (the number of requests that can be served simultaneously).
Prefilling Latency
The self-attention operation during prefilling parallelizes easily: for each attention head, a triangular mask lets us compute self-attention for every token position in parallel. For contexts longer than 156 tokens, prefilling can be considered compute-bound. We can thus compute the theoretical prefilling latency as the ratio of (1) FLOPs for prefilling to (2) FLOPs/second of an A100 GPU. To compute prefilling FLOPs, the author uses the approximate expression for FLOPs per token from Table 1 in Kaplan et al. (2020) (\(N\) is the number of non-embedding parameters, \(n_{layer}\) is the number of layers, \(n_{ctx}\) is the number of tokens in the context, and \(d_{attn}\) is the dimension of the attention head output):
\[ \text{FLOPs per token} = 2N + 2n_{layer}n_{ctx}d_{attn} \] For a context length of 50K: \[ \text{Total FLOPs} = \underbrace{50k}_{\text{seqlen}} \times \left(2 \times \underbrace{70B}_{\text{# parameters}} + 2 \times \underbrace{80}_{\text{hidden layers}} \times \underbrace{50k}_{\text{seqlen}} \times \underbrace{128}_{\text{attn output dim}} \right) = 7.05 \text{ PFLOPs} \]Observe that in this calculation, the total number of FLOPs depends quadratically on sequence length. The prefilling latency is 7.05 PFLOPs / 312 TFLOPs/second = 22.6 seconds. For a 4K context length, the latency is 1.8 seconds (= ~560 TFLOPs / 312 TFLOPs/second).
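The prefilling estimate can be checked numerically. The sketch below applies the per-token expression exactly as written, including the factor of 2 on the attention term. The function names are hypothetical, and the defaults (70B parameters standing in for \(N\), 80 layers, \(d_{attn} = 128\), 312 TFLOPs/second) are the values this section uses, not measurements.

```python
def prefill_flops(n_ctx, n_params=70e9, n_layer=80, d_attn=128):
    """Total prefill FLOPs: n_ctx tokens times the per-token estimate
    2N + 2 * n_layer * n_ctx * d_attn."""
    flops_per_token = 2 * n_params + 2 * n_layer * n_ctx * d_attn
    return n_ctx * flops_per_token


def prefill_latency_seconds(n_ctx, peak_flops=312e12):
    """Theoretical latency, assuming prefilling is compute-bound and
    the GPU sustains its peak bf16 throughput."""
    return prefill_flops(n_ctx) / peak_flops


for n_ctx in (4_000, 50_000):
    pflops = prefill_flops(n_ctx) / 1e15
    print(f"{n_ctx:>6} tokens: {pflops:.2f} PFLOPs, "
          f"{prefill_latency_seconds(n_ctx):.1f} s")
```

At these context lengths the \(2N\) term dominates, so the latency grows almost linearly. The quadratic attention term only becomes comparable to it at much longer contexts under these assumptions.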
Decoding Latency
We cannot decode more than one token in parallel: to decode the token at time \(t\), we must already have decoded the token at \(t-1\). This makes decoding memory-bound, and the latency is the ratio of (1) bytes of memory accessed (the sum of model weights and KV cache size) to (2) A100 HBM bandwidth (2 TB/second). The author assumes we generate 250 tokens in one go on average. For a 50K context length using Llama-3 70B, this implies: \[ 250 \times \left(\underbrace{140 \text{ GB}}_{\text{model weights}} + \underbrace{16.4 \text{ GB}}_{\text{KV cache}} \right) = 39.1 \text{ TB} \] Given the A100 HBM bandwidth of 2 TB/second, this implies a latency of about 19.5 seconds. For a 4K-long sequence, the latency drops to 17.7 seconds.
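The decoding estimate reduces to a single expression. This is a minimal sketch with a hypothetical function name. Like the calculation above, it assumes the KV cache size stays fixed while the 250 tokens are generated, and that every decoding step reads the full weights and the full cache from HBM once.

```python
def decode_latency_seconds(kv_cache_gb, n_tokens=250,
                           weights_gb=140.0, hbm_gb_per_s=2000.0):
    """Memory-bound decoding latency: total GB read from HBM divided by bandwidth.

    Each generated token reads the model weights plus the KV cache once.
    """
    gb_read_per_token = weights_gb + kv_cache_gb
    return n_tokens * gb_read_per_token / hbm_gb_per_s


# KV cache sizes taken from the sizing section: 16.4 GB and 1.31 GB.
for label, kv_gb in (("16.4 GB cache", 16.4), ("1.31 GB cache", 1.31)):
    print(f"{label}: {decode_latency_seconds(kv_gb):.2f} s")
```

Setting `kv_cache_gb=0` gives the floor imposed by the weights alone (17.5 seconds under these assumptions), which shows that most of the decoding latency here comes from reading the model rather than the cache.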
Concurrency
See the section on Sizing KV Cache above for how context length impacts concurrency.
Context Switching
Say we have two users interacting with our LLM. Assume we want to offload KV Cache 1 from HBM to CPU memory and load KV Cache 2 from CPU memory to HBM. The latency is determined by how fast we can transfer data between HBM and CPU DDR, i.e. the PCIe bandwidth. Using a Gen 4 PCIe bandwidth of 20 GB/second, for a 50K context length, the latency is \(2 \times 16.4 \text{ GB} / 20 \text{ GB/second}\), or 1.6 seconds.
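The same estimate as a small function. This is a sketch with a hypothetical name. It assumes both caches have the same size and that the two transfers run one after the other over the link, which is what the factor of 2 in the expression above implies.

```python
def context_switch_latency_seconds(kv_cache_gb, pcie_gb_per_s=20.0):
    """Time to move one KV cache out of HBM and another of equal size in.

    The factor of 2 covers the offload plus the load, assumed sequential.
    """
    return 2 * kv_cache_gb / pcie_gb_per_s


print(f"{context_switch_latency_seconds(16.4):.2f} s")  # 50K-token context
```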