原文:Transformer推理性能优化技术很重要的一个就是K V cache,能否通俗分析,可以结合代码? - 知乎
为什么要研究KV cache?
设输入序列的长度为 s ,输出序列的长度为 n ,模型深度为l,维度为h,以 FP16 来保存KV cache,那么KV cache的峰值显存占用大小为 b(s+n)h∗l∗2∗2=4blh(s+n) 。这里第一个2表示K/V cache,第二个2表示 FP16 占2个bytes。
以 GPT3 (175B) 为例,对比 KV cache 与模型参数占用显存的大小。GPT3 模型weight占用显存大小为350GB (FP16),层数 l为96,维度h为12888。
batch size | s+n | KV cache(GB) | KV cache/weight |
4 | 4096 | 75.5 | 0.22 |
16 | 4096 | 302 | 0.86 |
64 | 4096 | 1208 | 3.45 |
参考上图,随着 batch size 和 长度的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。从LLM的趋势上而讲,主要有三个方面来说明kv cache优化的必要性:
1、总体趋势上LLM 的窗口长度在不断增大,因此就出现一组主要矛盾,即:对不断增长的 LLM 的窗口长度的需要与有限的 GPU 显存之间的矛盾。因此优化 KV cache 非常必要。
OpenAI API场景,API最烧钱的是输入而非输出,输入包括prefill prompt 和conversation,长度动辄数十K token。虽说每输入token比每输出token便宜,但能够降低kv重新计算的开销,无论是硬件资源门槛,还是模型推理降本,都有着极为积极的作用。
2、对于消费级显卡这种性价比较高的显卡而言,显存容量相对较小,KV cache从一定程度上降低了模型的batch size,因而KV cache优化在工程落地中更显重要。
框架 | 模型 | input/output | 机器配置 | tokens/s | 最佳batch_size | gpu/cpu负载 |
TGI | llama2 70B | 128/512 | 4090*8,32c | 291 | 13 | cpu25%,gpu95% |
TGI | llama2 70B | 128/512 | A800,32c | 1122 | 43 | cpu25%,gpu95% |
从上表能够看出,类似4090等消费级显卡,其主要的瓶颈时batch_size,而影响batch_size的主要因素就是显存容量,而KV cache在推理过程中占用大量的显存。如果能通过KV cache降低显存占用,从一定程度上就能提升消费级显卡的性价比,带来非常高的商业收益。
3、sora/sd3等文生视频或者文生图的模型,纷纷放弃u-net架构,转而支持DIF(diffusion transformer)架构。对此类AIGC模型而言, KV cache同样能起到类似LLM上的加速效果。
根据资料,Sora类训练任务的特点是模型本体不大(10B以下),但是由于视频复杂性带来的序列长度特别长(接近1000kpatches的长度),可以对模型推理进行简易测算:
- 按照batch size = 1 进行测算,kv cache和模型权重对显卡占比能达到10:1(例如4090的24G显存,2G分给模型,22G分给kv cache)左右,这个场景的显存分配占比与LLM差异性还是非常的大。
- 按照batch size = 4 进行测算,kv cache和模型权重对显卡占比能达到40:1(batch size越大,kv cache的显存越大)
由此可见,KV cache会成为视频生成领域的一个重要瓶颈,但不排除有替代kv cache的加速方案。
KV cache的作用
解释kv cache之前,先看一组对话:
paki:What is the apples?
llama:Apples are a boring fruit.
上述对话中,paki是自然人,llama是模型。如果对上述对话进行分析,实际上需要将llama的推理步骤分成两个阶段,即prefill和decode。
prefill阶段:输入为Q,即‘What is the apples?’,返回了第一个token,即‘Apples’,同时初始化了kv cache。
decode阶段:输入为单个词或者说q,通过自回归的方式,生成‘Apples are a boring fruit’这个句子。需要注意的是,decode计算的过程中,q的长度为1,即当前词,返回下一个词,例如通过‘Apples’生成‘are’,同时更新kv cache。
暂时无法在飞书文档外展示此内容
如上图所示,kv cache是attention计算中的全量kv 缓存,主要作用在decode阶段,目的是将输入Q优化成输入q。
我们举例说明,假设通过decode阶段通过自回归生成‘Apples are a boring fruit’这句话,当生成到‘fruit’这个词的时候,如果没有KV cache,输入为Q(‘Apples are a boring’),进行attention计算。反之,如果有了kv cache之后,输入只需要q(‘boring’这个词),即可完成attention计算。
为什么会这样,主要跟下一个token的生成给当前token的q和全量KV有关,具体attention的计算公式不再粘贴。
从这里也能看出,为什么 KV cache那么吃显存,其实主要因为随着seq长度变长和batch size增大,KV cache需要存储历史全量KV,从而跟着增大。
那就有这么一个思路,KV cacha本质上是attention计算中的一部分,如果对其进行压缩或者优化,是不是能起到推理的加速效果?
答案是肯定的,下面介绍一下主要的优化方法。
基于KV cache的加速策略
从整体上来讲,KV cache主要分成5个方向的优化,即Sparse、Quantization、Allocator、Window、share,我们逐个对5个方向的最新技术,做一些探讨。
Window--窗口
多轮对话场景的 LLMs 有两个难点:1. 解码阶段缓存 KV 需要耗费大量的内存;2. 流行的 LLMs 不能拓展到训练长度之外。
基于window方向的技术,解决上述问题主要有StreamingLLM[8]和LM-Infinite[14]两种方案,我们基于StreamingLLM进行介绍。
首先,自回归LLMs的一个有趣现象:无论它们与语言建模任务的相关性如何,初始tokens都被分配了惊人的大量注意力得分,并将tokens称为“attention sinks”,参考下图:
Visualization of the average attention logits in Llama-2-7B over 256 sentences, each with a length of 16.
根据attention sinks特性,我们参考之前多轮对话场景的解决方法,给出StreamingLLM的解决方案,即只保留attention sink tokens的KV(只需4个初始tokens)以及滑动窗口的KV,以锚定注意力计算并稳定模型性能的方法。
Illustration of StreamingLLM vs. existing methods. The language model, pre-trained on texts of length L, predicts the Tth token (T ≫ L).
上图的四个模块,分别对应下面的描述:
(a) 密集注意力(Dense Attention