Transformer推理性能优化技术很重要的一个就是K V cache,能否通俗分析,可以结合代码?

原文:Transformer推理性能优化技术很重要的一个就是K V cache,能否通俗分析,可以结合代码? - 知乎

为什么要研究KV cache?


设输入序列的长度为 s ,输出序列的长度为 n ,模型深度为l,维度为h,以 FP16 来保存KV cache,那么KV cache的峰值显存占用大小为 b(s+n)h∗l∗2∗2=4blh(s+n) 。这里第一个2表示K/V cache,第二个2表示 FP16 占2个bytes。
以 GPT3 (175B) 为例,对比 KV cache 与模型参数占用显存的大小。GPT3 模型weight占用显存大小为350GB (FP16),层数 l为96,维度h为12888。

batch size s+n KV cache(GB) KV cache/weight
4 4096 75.5 0.22
16 4096 302 0.86
64 4096 1208 3.45


参考上图,随着 batch size 和 长度的增大,KV cache 占用的显存开销快速增大,甚至会超过模型本身。从LLM的趋势上而讲,主要有三个方面来说明kv cache优化的必要性:
1、总体趋势上LLM 的窗口长度在不断增大,因此就出现一组主要矛盾,即:对不断增长的 LLM 的窗口长度的需要与有限的 GPU 显存之间的矛盾。因此优化 KV cache 非常必要。
OpenAI API场景,API最烧钱的是输入而非输出,输入包括prefill prompt 和conversation,长度动辄数十K token。虽说每输入token比每输出token便宜,但能够降低kv重新计算的开销,无论是硬件资源门槛,还是模型推理降本,都有着极为积极的作用。
2、对于消费级显卡这种性价比较高的显卡而言,显存容量相对较小,KV cache从一定程度上降低了模型的batch size,因而KV cache优化在工程落地中更显重要。

框架 模型 input/output 机器配置 tokens/s 最佳batch_size gpu/cpu负载
TGI llama2 70B 128/512 4090*8,32c 291 13 cpu25%,gpu95%
TGI llama2 70B 128/512 A800,32c 1122 43 cpu25%,gpu95%


从上表能够看出,类似4090等消费级显卡,其主要的瓶颈时batch_size,而影响batch_size的主要因素就是显存容量,而KV cache在推理过程中占用大量的显存。如果能通过KV cache降低显存占用,从一定程度上就能提升消费级显卡的性价比,带来非常高的商业收益。
3、sora/sd3等文生视频或者文生图的模型,纷纷放弃u-net架构,转而支持DIF(diffusion transformer)架构。对此类AIGC模型而言, KV cache同样能起到类似LLM上的加速效果。
根据资料,Sora类训练任务的特点是模型本体不大(10B以下),但是由于视频复杂性带来的序列长度特别长(接近1000kpatches的长度),可以对模型推理进行简易测算:

  • 按照batch size = 1 进行测算,kv cache和模型权重对显卡占比能达到10:1(例如4090的24G显存,2G分给模型,22G分给kv cache)左右,这个场景的显存分配占比与LLM差异性还是非常的大。
  • 按照batch size = 4 进行测算,kv cache和模型权重对显卡占比能达到40:1(batch size越大,kv cache的显存越大)

由此可见,KV cache会成为视频生成领域的一个重要瓶颈,但不排除有替代kv cache的加速方案。


KV cache的作用


解释kv cache之前,先看一组对话:
paki:What is the apples?
llama:Apples are a boring fruit.
上述对话中,paki是自然人,llama是模型。如果对上述对话进行分析,实际上需要将llama的推理步骤分成两个阶段,即prefill和decode。
prefill阶段:输入为Q,即‘What is the apples?’,返回了第一个token,即‘Apples’,同时初始化了kv cache。
decode阶段:输入为单个词或者说q,通过自回归的方式,生成‘Apples are a boring fruit’这个句子。需要注意的是,decode计算的过程中,q的长度为1,即当前词,返回下一个词,例如通过‘Apples’生成‘are’,同时更新kv cache。
暂时无法在飞书文档外展示此内容


如上图所示,kv cache是attention计算中的全量kv 缓存,主要作用在decode阶段,目的是将输入Q优化成输入q。
我们举例说明,假设通过decode阶段通过自回归生成‘Apples are a boring fruit’这句话,当生成到‘fruit’这个词的时候,如果没有KV cache,输入为Q(‘Apples are a boring’),进行attention计算。反之,如果有了kv cache之后,输入只需要q(‘boring’这个词),即可完成attention计算。
为什么会这样,主要跟下一个token的生成给当前token的q和全量KV有关,具体attention的计算公式不再粘贴。
从这里也能看出,为什么 KV cache那么吃显存,其实主要因为随着seq长度变长和batch size增大,KV cache需要存储历史全量KV,从而跟着增大。
那就有这么一个思路,KV cacha本质上是attention计算中的一部分,如果对其进行压缩或者优化,是不是能起到推理的加速效果?
答案是肯定的,下面介绍一下主要的优化方法。


基于KV cache的加速策略


从整体上来讲,KV cache主要分成5个方向的优化,即Sparse、Quantization、Allocator、Window、share,我们逐个对5个方向的最新技术,做一些探讨。


Window--窗口


多轮对话场景的 LLMs 有两个难点:1. 解码阶段缓存 KV 需要耗费大量的内存;2. 流行的 LLMs 不能拓展到训练长度之外。
基于window方向的技术,解决上述问题主要有StreamingLLM[8]和LM-Infinite[14]两种方案,我们基于StreamingLLM进行介绍。
首先,自回归LLMs的一个有趣现象:无论它们与语言建模任务的相关性如何,初始tokens都被分配了惊人的大量注意力得分,并将tokens称为“attention sinks”,参考下图:

Visualization of the average attention logits in Llama-2-7B over 256 sentences, each with a length of 16.


根据attention sinks特性,我们参考之前多轮对话场景的解决方法,给出StreamingLLM的解决方案,即只保留attention sink tokens的KV(只需4个初始tokens)以及滑动窗口的KV,以锚定注意力计算并稳定模型性能的方法。

Illustration of StreamingLLM vs. existing methods. The language model, pre-trained on texts of length L, predicts the Tth token (T ≫ L).

上图的四个模块,分别对应下面的描述:

(a) 密集注意力(Dense Attention

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值