Transformers Inference Optimization Toolset
Large Language Models are pushing the boundaries of artificial intelligence, but their immense size poses
significant computational challenges. As these models grow, so does the need for smart optimization
techniques to keep them running efficiently on modern hardware.
In this post, we’ll explore key optimization strategies that are making LLMs faster and more memory-efficient. We’ll start with a brief look at the GPU memory hierarchy, which forms the foundation for many of these techniques. Then we’ll move on to algorithms that allow LLMs to process information more quickly and handle longer contexts. Understanding these techniques offers valuable insights that help unlock the full potential of Large Language Models.
The idea of this post is not just to discuss transformer-specific optimizations, since there are plenty of resources where one can examine every inch of the transformer to make it faster (my favourite one is “Let’s reproduce GPT-2” by Andrej Karpathy). The main goal is to lower the entry barrier for curious researchers who are currently unable to piece together the huge number of articles and papers into one picture.
A lot of optimization techniques will be left out, for example quantization methods, which are relatively diverse and deserve a separate post. We'll also mostly discuss transformer inference and won't mention some training tricks, such as mixed-precision training, gradient checkpointing or sequence packing. Even so, a lot of the optimizations from this post can be applied to training as well.
A graphics processing unit (GPU) performs all of its computations with multiple streaming multiprocessors (SMs) (these are similar to the cores in a CPU). An SM is the basic GPU building block: it has its own instruction schedulers and various instruction execution pipelines. Modern GPUs are also equipped with special off-chip memory called high bandwidth memory (HBM), where data is initially stored and ultimately written back. Unlike the system dynamic random access memory (DRAM), which is controlled by the CPU and typically optimized for low latency access, HBM is physically bonded to the GPU in stacked layers with thousands of pins and provides massively parallel data throughput by design.
Streaming multiprocessors access data and code from HBM via the L2 cache. It acts as an intermediate level between off-chip and on-chip memory and caches data that can be shared among multiple SMs. It also sits in the path of data moving between devices. And finally, each SM has its own L1 cache and shared memory (SRAM), low-latency on-chip memory caches: they are an order of magnitude faster than HBM but many orders of magnitude smaller in size. The L1 cache is managed by the GPU hardware, while SRAM can be explicitly managed by the programmer through NVIDIA tools.
GPUs can communicate with each other over a high-bandwidth interconnect called NVLink, and they can talk to the outside world over a PCIe bus (a high-speed bus standard, common on motherboards to transfer data) or a special ethernet alternative called InfiniBand. Usually, 8 GPUs are packed into a single node. Feel free to check out my post on parallelization strategies to learn more about multi-device training.
Diagram: GPU memory hierarchy. Each GPU contains multiple SMs with their own L1/SRAM, a shared L2 cache and device DRAM (HBM); GPUs are connected to each other via NVLink and to system DRAM via PCIe.
When it comes to performance, we mostly care about three GPU characteristics:
- Compute performance, measured by the number of trillion float operations per second (TFLOPS).
- GPU memory, required to store model parameters, hidden activations and cache values, measured in GBs. For instance, GPT-3 has 175 billion parameters, so we need 350 GB of memory just to keep them on device in fp16.
- Memory bandwidth, measured in GB/s - the speed at which bytes move from GPU memory to the processing units.
GPU capabilities grow exponentially fast. According to NVIDIA documentation, the T4 graphics card released in 2018 had 65 TFLOPS, 40 SMs with 64KB of L1 cache each, 4MB of L2 cache with 1.3TB/s bandwidth and 16GB of HBM with 300 GB/s bandwidth. Just 2 years later the A100 was released with 312 TFLOPS and 108 SMs with 192KB of L1, 40MB of L2 cache and 80 GB of HBM with 1.55 TB/s bandwidth. Compare those numbers to the latest B100 card, which can perform 1.8 PFLOPS and whose HBM has a capacity of 192 GB and a throughput of 8 TB/s.
Plot: memory bandwidth (GB/s) vs compute performance (TFLOPS) for T4, V100, A100, H100 and B100 cards, showing rapid growth of both. Note that both axes are log-scale. 1
Time required for memory accesses can vary depending on the devices, their modifications and infrastructure setups. But the main point to remember is that the throughput numbers of the different memory levels differ by orders of magnitude. They show us that while the number of operations per second matters, operand placement can be even more important when we optimize for inference speed. Keep in mind that the slower memory always dominates performance bottlenecks.
Depending on the balance of computation and memory accesses, operations can be classified as follows:
1. Compute-bound: the time spent on arithmetic operations exceeds the time spent on other operations such as memory accesses. Typical examples are a linear layer with a large inner dimension or a convolutional layer with a large number of channels.
2. Memory-bound: the time taken by memory accesses exceeds computation time. Most operations are memory-bound, e.g. elementwise operations (activation functions, dropouts) or reductions (sum, softmax, normalization).
3. Overhead-bound: everything else, such as communication-bound, interpreter-bound, etc. We won’t discuss it in this post, however I strongly advise taking a look at the Making Deep Learning Go Brrrr From First Principles blog post to understand GPU mechanics and why in most cases our bottleneck may not be related to them at all.
The balance between the first two is commonly measured by the arithmetic intensity, which is the number of arithmetic operations per byte of memory access required to run the computation. For example, if we apply a ReLU operation to an input tensor x stored in half-precision, for each element of the tensor we need to

- read 2 bytes
- make 1 comparison
- write 2 bytes

Regardless of the size of x, the arithmetic intensity for ReLU is equal to

$$\frac{\#\text{flops}}{\#\text{bytes}} = \frac{1}{4}.$$

This means that for each operation we need to make 4 memory accesses.
Arithmetic intensity is commonly compared to a hardware-specific ops:byte ratio to find out whether we are in a compute- or memory-bound scenario. To explain how it works, let's take as an example a linear layer forward pass on an A100 GPU. Given an input batch x ∈ R^{B×d} and weight matrix W ∈ R^{d×d} (here B is the batch size and d is the embedding dimension), the linear layer basically represents a matrix multiplication xW. We can calculate that the linear layer computation requires 2Bd² flops. 2 Hence the compute-time for the A100 will be

$$T_{\text{compute}} = \frac{\#\text{flops}}{\text{compute performance}} = \frac{2Bd^2}{312 \cdot 10^{12}} \text{ s}.$$

At the same time we need to read 2d² bytes from memory to load the weight matrix W (again under the condition that we work with fp16/bf16). Also, just for simplicity, let's say that B ≪ d so we can neglect the loading time of x compared to the weight matrix W. Model parameters are usually stored in HBM, therefore

$$T_{\text{memory}} = \frac{\#\text{bytes}}{\text{memory bandwidth}} = \frac{2d^2}{1.55 \cdot 10^{12}} \text{ s}.$$

Recall that arithmetic intensity is equal to $\frac{\#\text{flops}}{\#\text{bytes}}$, while the ops:byte ratio is given by $\frac{\text{compute performance}}{\text{memory bandwidth}}$. To find the bottleneck for our model we look at the ratio of these two terms, which is

$$\frac{T_{\text{compute}}}{T_{\text{memory}}} \approx \frac{B}{200}.$$

This means that as long as our batch size is smaller than 200, our system performance is memory-bound. Enlarging the input batch to a value greater than 200 increases the computation time while keeping the memory transfer time constant, which brings us to the compute-bound scenario.
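To make this concrete, here is a small back-of-the-envelope helper (a sketch; the peak-FLOPS and bandwidth constants are simply the A100 numbers quoted above):

Python

def linear_layer_bound(batch_size, d,
                       peak_flops=312e12,       # A100 fp16/bf16 compute, FLOPS
                       hbm_bandwidth=1.55e12):  # A100 HBM bandwidth, bytes/s
    flops = 2 * batch_size * d * d   # matmul xW
    bytes_moved = 2 * d * d          # load W in fp16, ignore x and the output
    t_compute = flops / peak_flops
    t_memory = bytes_moved / hbm_bandwidth
    return 'compute-bound' if t_compute > t_memory else 'memory-bound'

for B in (1, 64, 256, 512):
    print(B, linear_layer_bound(B, d=12_288))
# B < ~200 -> memory-bound, B > ~200 -> compute-bound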
The ops:byte ratio analysis is useful, but keep in mind that it assumes a GPU workload sufficiently large to saturate the compute and memory pipelines. If the workload is not large enough, or does not have sufficient parallelism, the processor will be under-utilized and performance will be limited by latency.
Now let's turn to the attention mechanism, the core of the transformer:

$$\text{Attention}(Q, K, V) = \text{softmax}\Big(\frac{QK^T}{\sqrt{d}}\Big) \cdot V,$$

where d is the hidden dimensionality for queries and keys. When we work with GPT-based models, we use masked attention, where the softmax input is modified with a mask tensor, setting masked attention values to −∞ if we don't want to attend to the corresponding tokens. Input tensors are Q ∈ R^{L×d} and K, V ∈ R^{M×d}, where L and M are sequence lengths. 3
In practice, attention is computed over h heads in parallel and the results are concatenated and projected back with an output matrix:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O, \quad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V), \quad i = 1, \dots, h.$$

When Q, K and V are projections of the same input sequence, we call it multi-head self-attention, otherwise it is called multi-head cross-attention. We'll focus on the self-attention mechanism, as it's widely used in generative LLMs.
Let's also introduce new tensor names for the core attention operation to simplify the notation: call the dot product S := QK^T ∈ R^{L×L}, the normalized attention weights P := softmax(S ⊗ mask) ∈ R^{L×L} (mask is broadcastable to S) and the output O := PV ∈ R^{L×d}.
KV Cache
When we work with models like GPT, text generation occurs in two stages:
1. Prefill - the model ingests a large chunk of our prompt tokens in parallel, computing all hidden states and outputs in one pass.
2. Decode - when prefill is finished, auto-regressive decoding is launched. Decoding is in general more time-consuming than prefill due to its sequential nature: response tokens are always generated one after another.
Representation of causal self-attention during text generation for different sequence lengths L. Scaling coefficient √d is omitted.
Attention values can be computed only once for the whole bunch of prompt tokens, but then sequential one-by-one computations are
required to generate response tokens.
Python

import jax.numpy as jnp
from jax import jit, nn

@jit
def dot_product_attention(query, key, value, mask=None):
    d = query.shape[-1]
    # attn_logits shape is [batch..., num_heads, q_seq_len, kv_seq_len]
    attn_logits = jnp.einsum('...lhd,...mhd->...hlm', query, key)
    attn_logits = attn_logits / jnp.sqrt(d)  # normalize logits
    if mask is not None:
        big_neg = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask, big_neg, attn_logits)
    # logits -> weights
    attention = nn.softmax(attn_logits, axis=-1)
    # return weighted sum over values for each query position
    output = jnp.einsum('...hlm,...mhv->...lhv', attention, value)
    return output, attention
Computational complexity:

- To compute the query Q we multiply the input matrix x ∈ R^{L×d} with matrices W^Q_{1…h} ∈ R^{d×d/h} across h heads, which takes O(Ld²) operations. The same amount of compute is needed for K and V.
- Attention computation requires O(L²d) for both S = QK^T and O = PV.

Memory accesses:

- The entire memory size to be accessed is equal to the sum of the sizes of all the tensors involved, which is O(Ld + L²h + d²) bytes. Therefore, we have arithmetic intensity proportional to

$$\frac{L^2 d + Ld^2}{L^2 h + Ld + d^2} \xrightarrow[L \gg d]{} \frac{d}{h}.$$
Arithmetic intensity vs sequence length L for multi-head self-attention with embedding dimension d = 12,288 and number of heads h = 96 (GPT-3 scale).
We've noticed already that modern GPU hardware has computational capacity orders of magnitude higher than its memory bandwidth. As the graph shows, for sufficiently large sequence lengths the arithmetic intensity approaches the embedding dimension per attention head d/h, which is usually one or a few hundred. 4 Hence, the arithmetic intensity is comparable to, if not greater than, the ops:byte ratio.
Generally, this would imply high algorithm efficiency, but the situation is different for the second phase, text generation. The first thing to notice is that in the generation scenario there is no need to compute the attention outputs for each token in the input sequence x, only for the last one to decode the next token x_{L+1}. Thus there is no need to send the whole query Q into the attention mechanism.

The second important thing is that we can reuse previously computed activations, namely we can cache K and V values during the generation process, hence the name. We store the KV cache to improve efficiency and reduce redundant computation, especially for long sequences.
Representation of causal self-attention with KV cache during decoding phase. Each timestep K and V computed for the last token are
added to the cache and re-used in the future steps.
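Concretely, one decoding step with a KV cache might look like the sketch below, reusing dot_product_attention from above; the shapes and the cache layout are illustrative assumptions rather than a reference implementation.

Python

def decode_step(x_new, w_q, w_k, w_v, k_cache, v_cache):
    # x_new: [batch, 1, d_model] - embedding of the last generated token
    q = jnp.einsum('bld,dhk->blhk', x_new, w_q)  # [batch, 1, heads, d_head]
    k = jnp.einsum('bld,dhk->blhk', x_new, w_k)
    v = jnp.einsum('bld,dhk->blhk', x_new, w_v)
    # append the new K/V pair to the cache along the sequence axis
    # (real systems pre-allocate the cache; concatenation here is just for clarity)
    k_cache = jnp.concatenate([k_cache, k], axis=1)
    v_cache = jnp.concatenate([v_cache, v], axis=1)
    # a single query attends to all cached keys/values (causality holds trivially)
    out, _ = dot_product_attention(q, k_cache, v_cache)
    return out, k_cache, v_cache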
Let's count flops and memory accesses for text generation at each decoding step (we can simply multiply these counts by L steps to get values for the whole sequence).

Compute:

- To compute the query Q we multiply the input vector x ∈ R^d with matrices W^Q_{1…h} ∈ R^{d×d/h} across h heads, which takes O(d²). The same holds for K and V when we store the KV cache.
- Attention computation requires at most O(Ld) for both S = QK^T and O = PV.
- In total we need to perform O(d² + Ld) operations for each step, so over L steps the number of operations stays the same as in the prefill stage.

Memory:

- The input x and intermediate tensors Q, O occupy O(d) bytes, but K and V in the cache require O(Ld) space.
- Attention logits S and weights P take at most O(Lh) bytes across all heads.
- Projection weights W^{Q,K,V}_{1…h} take again O(d²) bytes.

Putting this together,

$$\text{arithmetic intensity} \propto \frac{Ld + d^2}{Ld + Lh + d^2} < 1,$$
which is definitely smaller than ops:byte ratio and we end up memory-bound. Although we reduced the
amount of operations by a factor of L by removing L − 1 queries, the number of memory accesses did not
decrease as much. The reason for that is that at each step we retrieve all the values from the KV cache, the
size of which increases in proportion to the length of the sequence.
And that brings us to another drawback: storing the KV cache requires a lot of HBM capacity, e.g. when we run decoding on a transformer with n layers, we need to store O(Lnd) bytes of KV cache. Thus we either need to make sure we have enough memory to accommodate it, or we have to load it from CPU DRAM, which is one or two orders of magnitude slower than reading from HBM.
A real world example: take an A100 GPU with 80 GB of HBM. Say we work with a GPT-3 scale model (I believe it can still be considered large these days) with n = 96 and d = 12,288 and try to fit in a context of length L = 4096. Then the space we need additionally is

$$\underbrace{2}_{K \text{ and } V} \cdot \underbrace{2}_{\text{fp16 bytes}} \cdot \; n \cdot d \cdot L = 2 \cdot 2 \cdot 96 \cdot 12288 \cdot 4096 \text{ B} \approx 18 \text{ GB}.$$

So 18 GB, or 22.5% of A100 memory space, is required for the KV cache of just one sequence. Keeping in mind that most of the GPU space will be taken by model parameters, we can conclude that even without enlarging our batch size we may quickly run out of memory.
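A tiny helper (assuming fp16 values and the GPT-3-scale numbers above) reproduces this estimate:

Python

def kv_cache_bytes(n_layers, d_model, seq_len, batch_size=1, bytes_per_value=2):
    # 2 tensors (K and V) per layer, each of shape [batch, seq_len, d_model]
    return 2 * n_layers * batch_size * seq_len * d_model * bytes_per_value

print(kv_cache_bytes(96, 12_288, 4_096) / 2**30)  # ~18 GiB for one sequence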
In the standard multi-head attention mechanism, separate key-value pairs are computed and cached for each query head. However, in many cases different query heads may share similar attention patterns, and thus the corresponding keys and values could be reused across them. Multi-query attention (MQA) (Shazeer, 2019) shares a single cached key-value pair across all query heads, thus substantially reducing the memory requirements associated with the KV cache during text generation.
MQA not only lowers memory consumption, but it also leads to higher inference throughput. Our computational complexity doesn't change, because from an algorithmic point of view the number of matrix multiplications stays the same and we only reuse K and V for different heads. But in terms of memory, the KV cache now requires only O(L·d/h) space and the arithmetic intensity is proportional to

$$\frac{Ld + d^2}{L\frac{d}{h} + Lh + d^2} \xrightarrow[L \gg d]{} \frac{dh}{d + h^2} \approx h,$$

meaning it grows steadily with increasing L until it reaches a plateau on the order of h. We might still be memory-bound in most cases, but with the multi-query attention technique we can get much closer to the compute-bound regime.
Arithmetic intensity vs sequence length L for multi-query and multi-head self-attention mechanisms during auto-regressive generation with embedding dimension d = 12,288 and number of heads h = 96.
In our example above with GPT-3 model, using MQA would make KV cache h = 96 times smaller, thus the
required space would take around 200 MB which is just 0.25% of one A100 GPU memory.
Of course, such acceleration and memory reduction come at a price - we cut model parameters and therefore its potential capacity. A possible way to avoid quality degradation is to use a technique that interpolates between MHA and MQA - grouped-query attention (GQA) (Ainslie et al. (2023)). With GQA we split the h query heads into g groups, each with its own keys and values. Note that for g = 1 GQA is equal to multi-query attention and for g = h GQA is the same as multi-head attention. The choice of g is a trade-off between memory savings and potential accuracy loss. A larger group size will result in more memory savings but may also lead to a larger approximation error in the attention computations. In practice, the optimal group size may need to be determined empirically based on the specific model architecture and the trade-off between memory efficiency and model performance.
Diagram: multi-head, grouped-query and multi-query attention - query heads are grouped, with each group sharing one key head and one value head.
Python

@jit
def gqa_dot_product_attention(query, key, value, mask=None):
    num_heads, num_kv_heads = query.shape[-2], key.shape[-2]
    # broadcast K/V heads to match the number of Q heads
    num_heads_per_kv = num_heads // num_kv_heads
    key = jnp.repeat(key, num_heads_per_kv, axis=-2)
    value = jnp.repeat(value, num_heads_per_kv, axis=-2)
    return dot_product_attention(query, key, value, mask)
Below is a comparison table with batched decoding/inference algorithm complexities for input x ∈ R^{B×d×L} and large context size L. Note that the computational complexity is the same for all algorithms in the table! However, the real effectiveness can vary greatly depending on the setting.
A large KV cache is not the only source of memory problems as the sequence length increases. During the prefill phase we compute all the outputs and KV pairs in one pass. This requires us to compute the attention matrix S ∈ R^{L×L}, which depends quadratically on the context length. What if the prompt is so large that we can't fit all attention weights in memory? We could compute them by passing single tokens one-by-one as we do in the decoding stage, though this procedure is much slower since it's memory-bound. But since we know the future tokens in advance, we can feed them to the model in chunks:
- Take the first C tokens (C < L) from the prompt, run them through the prefill stage and store their KV cache values. Attention weights will be S ∈ R^{C×C}.
- Then apply the same procedure to the next C tokens, but now use the cached KV pairs to attend to the tokens in the previous chunk. Attention weights will then be S ∈ R^{C×2C}.
- Repeat until the whole prompt is prefilled. The maximum size of S at the end will be C × L.
With chunking, the maximum size of S depends only linearly on L, scaled by the controllable constant coefficient C.
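A minimal sketch of chunked prefill, again reusing dot_product_attention from above (the full K/V tensors stand in for the cache here, and the shapes are illustrative assumptions):

Python

def chunked_prefill(q, k, v, chunk_size):
    # q, k, v: [batch, seq_len, heads, d_head]
    seq_len = q.shape[1]
    outputs = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        q_chunk = q[:, start:end]               # current chunk of C queries
        k_ctx, v_ctx = k[:, :end], v[:, :end]   # "cached" + current keys/values
        # causal mask: query at global position start+i must not see keys j > start+i
        q_pos = jnp.arange(start, end)[:, None]
        k_pos = jnp.arange(end)[None, :]
        mask = k_pos > q_pos                    # True = masked out
        out, _ = dot_product_attention(q_chunk, k_ctx, v_ctx, mask)
        outputs.append(out)
    return jnp.concatenate(outputs, axis=1)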
We can see that even with multi-query attention and prefill chunking, our KV cache and attention weights still grow with the context length during both the prefill and decoding phases. If we cap the number of tokens each token can attend to at some constant L_w, our memory requirements will no longer depend on the input sequence length. This is exactly what a technique called sliding window attention does: it changes the attention mask from a lower-triangular matrix to a band matrix.
This makes the attention layer focus only on local context. But notice that tokens can still implicitly attend to the previous n ⋅ L_w tokens, where n is the number of layers in our transformer model. This is very similar to how the receptive field works in convolutional networks. Choosing the optimal window size involves a trade-off between memory efficiency and maintaining context: a larger window preserves more context but requires more memory, while a smaller window is more memory-efficient but may lose some context.
When we use sliding window attention, we might use a rolling KV cache as well. As the KV cache is now restricted to a constant number of 2 ⋅ L_w tensors and only one KV pair changes at each step, we can remove the pair related to the oldest token that we won't attend to anymore and replace it with the newest one. In practice we keep a write pointer at the oldest pair and move it forward by one after each replacement. When we reach the end of the buffer, we move it back to the first position.
Representation of causal self-attention with a rolling KV cache during decoding: the K/V pair of the oldest token is overwritten by the newest one.
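A rolling buffer can be sketched in a few lines (the cache layout is an illustrative assumption):

Python

def init_rolling_cache(window, batch, heads, d_head):
    k_buf = jnp.zeros((batch, window, heads, d_head))
    v_buf = jnp.zeros((batch, window, heads, d_head))
    return k_buf, v_buf

def update_rolling_cache(k_buf, v_buf, k_new, v_new, step):
    # overwrite the slot of the oldest token; step is the global token index
    pos = step % k_buf.shape[1]
    k_buf = k_buf.at[:, pos].set(k_new)   # k_new, v_new: [batch, heads, d_head]
    v_buf = v_buf.at[:, pos].set(v_new)
    return k_buf, v_buf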
Another advantage of sliding window attention is that combining it with chunking during the prefill phase not only keeps the maximum size of the attention matrix constant (S ∈ R^{C×L_w}), but also reduces the number of dot-products needed to compute it.
The drawback of sliding window attention is that it may lead to degradation, as not all interactions between tokens are captured. An interesting phenomenon was found by Xiao et al. (2024), which they called attention sink: keeping the KV of a small number of tokens at the beginning of the sequence largely recovers the performance of window attention. They observe that LLMs output strong attention scores towards initial tokens as a "sink" even if they are not semantically important.
Linear attention
The linear attention mechanism (Katharopoulos et al. (2020)) is an alternative family of methods that avoid O(L²) scaling for long sequences. Linear attention approximates the standard attention mechanism while achieving linear time and space complexity. The key idea is to reformulate the attention operation using the associative property of matrix multiplication and kernel functions.
A kernel function K(q, k) can be thought of as a similarity measure between a pair of inputs q and k, exactly like in any attention mechanism. To simplify computation, the kernel function is often chosen so that it can be represented in the form of a feature map ϕ:

$$K(q, k) = \phi(q)^T \phi(k).$$

If we find such a feature map, it allows us to implicitly compute similarities between queries and keys without explicitly computing the full attention matrix QK^T.
In the attention mechanism the unnormalized similarity between a query embedding q and a key embedding k is measured as $K(q, k) = \exp\big(\frac{q^T k}{\sqrt{d}}\big)$. Each element of the softmax masked attention matrix P = softmax(mask ⊗ S), i.e. the normalized similarity between query row Q_i and key row K_j (j ≤ i), can be represented as

$$P_{ij} = \frac{K(Q_i, K_j)}{\sum_{j \le i} K(Q_i, K_j)}.$$

Using feature maps we can rewrite each row i of the dot-product attention output O = PV as

$$\begin{aligned}
O_i &= \sum_{j \le i} P_{ij} V_j \\
&= \frac{\sum_{j \le i} K(Q_i, K_j) \cdot V_j}{\sum_{j \le i} K(Q_i, K_j)} \\
&= \frac{\sum_{j \le i} \phi(Q_i)^T \phi(K_j) \cdot V_j}{\sum_{j \le i} \phi(Q_i)^T \phi(K_j)} \\
&= \frac{\phi(Q_i)^T \cdot \sum_{j \le i} \phi(K_j) V_j^T}{\phi(Q_i)^T \cdot \sum_{j \le i} \phi(K_j)} \\
&= \frac{\phi(Q_i)^T \cdot U_i}{\phi(Q_i)^T \cdot Z_i}.
\end{aligned}$$
The above equation is simpler to follow when the numerator is written in vectorized form using the associativity of matrix multiplication: $(\phi(Q)\phi(K)^T)V = \phi(Q)(\phi(K)^T V)$. Regardless of the value of L, we no longer need to store the quadratically growing attention matrix - we only need O(d²) space for U_L = ϕ(K)^T V ∈ R^{d×d}:
Neither the prefill nor the decoding phase with linear attention requires O(L²) space anymore. The scalar denominator ϕ(Q_L)^T ⋅ Z_L is omitted here.
Another interesting property emerges with the introduction of feature maps: linear attention computation can be expressed recurrently. Note that we have

$$U_i = U_{i-1} + \phi(K_i)V_i^T, \qquad Z_i = Z_{i-1} + \phi(K_i).$$

This allows us to keep only constant-sized hidden states U and Z to compute attention during auto-regressive decoding, and we don't need to feed linearly increasing inputs to the model.
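Here is a sketch of one recurrent decoding step of linear attention, using ELU(x) + 1 as the feature map from the original paper (shapes are illustrative assumptions):

Python

def elu_plus_one(x):
    return nn.elu(x) + 1.0

def linear_attention_step(q_i, k_i, v_i, U, Z):
    # q_i, k_i: [d_k], v_i: [d_v]; hidden states U: [d_k, d_v], Z: [d_k]
    phi_k = elu_plus_one(k_i)
    U = U + jnp.outer(phi_k, v_i)     # U_i = U_{i-1} + phi(K_i) V_i^T
    Z = Z + phi_k                     # Z_i = Z_{i-1} + phi(K_i)
    phi_q = elu_plus_one(q_i)
    out = (phi_q @ U) / (phi_q @ Z)   # O_i = phi(Q_i)^T U_i / phi(Q_i)^T Z_i
    return out, U, Z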
Choosing a feature map ϕ such that ϕ(q)ϕ(k)^T ≈ exp(qk^T) is a non-trivial task. Although linear attention reduces computational complexity, it may lead to a decrease in model performance if the kernel approximation doesn't capture key properties of full attention.
The authors of the original linear attention used ϕ(x) = ELU(x) + 1 as a feature map in their experiments. Another option is to use the standard ReLU function (though it sets the gradients to 0 for negative inputs). But while such choices lead to simple and effective computations, Zhang et al. (2024) showed that, unlike softmax attention, these feature maps lose two key properties associated with higher performance:

- Low-entropy "spikyness": intuitively, attention is expected to attend only to relevant tokens and ignore irrelevant ones.
- Dot-product monotonicity: attention weights should increase as the dot products of their corresponding queries and keys increase. The lack of this monotonicity can produce unstable gradients during training.
To solve this problem they propose Hedgehog 5, a learnable linear layer with an exponential activation function, trained to capture these properties and mimic softmax attention weights:

$$\phi_{\text{mlp}}(x) = \exp(Wx).$$

To learn a softmax approximation, they train ϕ_mlp(x) to minimize the cross-entropy loss between the computed linear attention weights and those that would have been computed via softmax masked attention P:

$$\mathcal{L}_i = -\sum_{j \le i} P_{ij} \cdot \log \frac{\phi_{\text{mlp}}(Q_i)^T \phi_{\text{mlp}}(K_j)}{\sum_{j \le i} \phi_{\text{mlp}}(Q_i)^T \phi_{\text{mlp}}(K_j)}.$$
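A sketch of what such a feature map and the resulting causal linear attention weights could look like (the parameter shapes and helper names are illustrative, not the authors' code):

Python

def hedgehog_feature_map(x, W):
    # x: [L, d_head], W: [d_head, feature_dim]; linear layer + exponential
    return jnp.exp(x @ W)

def linear_attention_weights(q, k, W):
    # causally normalized linear attention weights that the learnable feature
    # map is trained to match against the softmax attention weights P
    phi_q, phi_k = hedgehog_feature_map(q, W), hedgehog_feature_map(k, W)
    scores = jnp.einsum('ld,md->lm', phi_q, phi_k)
    scores = scores * jnp.tril(jnp.ones_like(scores))   # keep only j <= i
    return scores / scores.sum(axis=-1, keepdims=True)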
We’ve observed before that bandwidth cost (moving data from one place in memory to another) is usually
the most substantial factor when we talk about performance. Let’s step back a bit and take another detailed
look at the GPU hardware to understand why this is the case.
The GPU is designed to perform thousands of simple operations in parallel, called threads. Each thread has its own (the fastest on the GPU) memory, called registers. Threads which run on the same SM can be grouped into thread blocks to be executed at the same time (the maximum number of threads in a block is limited by the architecture, usually to 1024). Threads within a thread block can load data from global memory (HBM) into shared memory (SRAM), which is used for communication between threads, perform computations, and write results back to global memory.
When an SM is given one or more thread blocks to execute, it partitions them into warps. A warp is a set of 32 threads, such that all the threads in a warp execute the same instruction. Finally, multiple thread blocks are combined to form a grid. All the blocks in the same grid contain the same number of threads. Since the number of threads in a block is limited, grids can be used for computations that require a large number of thread blocks to operate in parallel.
The computations are defined in kernels - small C++ functions which are supposed to be executed multiple times in parallel by different threads (as opposed to only once like regular C++ functions). A kernel is launched as a grid, and different kernels can have different grid and block configurations.
Let's take the A100 as an example again. It has 108 SMs, each of which can run up to 2048 threads. Shared memory capacity for each SM is up to 192KB. This means that we can take large matrices (up to a few MBs), split them up into smaller chunks that fit into A100 registers and SRAM, and then do matrix multiplication at the speed of SRAM bandwidth, 18 TB/s. This makes the GPU much more efficient than the CPU for matrix multiplications.

But we have a lot of non-matmul operations in deep learning, such as normalization layers, activation functions or dropouts. Even though they only account for a small fraction of the total FLOPs, GPUs run them much slower. Firstly, because GPUs have specialized units for matrix multiplication, called Tensor Cores. That's why matmul throughput can be up to 16× higher than non-matmul throughput, e.g. 312 TFLOPS vs 19.5 TFLOPS on A100. 6 And secondly, non-matmul operations can be extremely memory-bound.
Let's take P = softmax(S) ∈ R^{L×L} from attention. For a large sequence length L, the tensor S has to reside in HBM, since it is too large to fit in SRAM. The simplest implementation of softmax requires us to read L² floats of S and write L² floats of exp(S), read exp(S) again to compute and write L normalization sums, and finally read exp(S) one more time to write L² floats of P. In total we need to move back and forth 5L² + L floats. It would be much faster if we could just read the L² floats of S and write the L² floats of P, making this operation more than 2.5× faster.
That's where kernel fusion comes into play: instead of writing the computation output y = f(x) to low-bandwidth global memory only to read it again to get z = g(y), we can implement a kernel which performs multiple computations at once, z = (g ∘ f)(x), without the extra memory accesses. The XLA compiler in Jax can perform simple fusions, but a programmer can also write custom CUDA kernels with Triton or Pallas.
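As a toy illustration, relying on XLA's automatic fusion rather than a hand-written kernel, jitting a composition of elementwise ops lets the compiler fuse them so that the intermediate y = f(x) never has to round-trip through HBM:

Python

import jax

def f(x):
    return jnp.maximum(x, 0.0)       # ReLU

def g(y):
    return y * nn.sigmoid(y)         # another elementwise op

fused = jax.jit(lambda x: g(f(x)))   # z = (g ∘ f)(x), fused into one kernel by XLA
z = fused(jnp.ones((4096, 4096)))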
Memory-efficient attention
Online softmax
Recall that for a vector x ∈ R^L the softmax output is defined as

$$y_i = \frac{e^{x_i}}{\sum_j e^{x_j}}.$$
Python

@jit
def naive_softmax(logits):
    exp_logits = jnp.exp(logits)
    return exp_logits / exp_logits.sum()
The naive implementation of softmax scans x twice - once to calculate the normalization term and again to compute the output vector y. Unfortunately, on real hardware such an implementation has a serious flaw: for x_i ≥ 89 exponentiation results in infinity for bf16 and fp32. Here's a trick to avoid overflow: notice that for any constant m:
$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i} \cdot e^{-m}}{\sum_j e^{x_j} \cdot e^{-m}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = \text{softmax}(x - m)_i.$$
Thus, by choosing m(x) = max_i x_i and ℓ(x) = ∑_j e^{x_j − m(x)} and computing

$$y_i = \text{softmax}(x - m(x))_i = \frac{e^{x_i - m(x)}}{\ell(x)},$$

we implement a numerically stable version of softmax, which is sometimes called safe softmax.
Python

@jit
def safe_softmax(logits):
    exp_logits = jnp.exp(logits - logits.max())
    return exp_logits / exp_logits.sum()
But stability comes at a price in efficiency, since we now make one more pass over x to calculate m(x). This results in 4 memory accesses per vector element overall (3 loads and 1 store), and we want to improve on that.
Based on this property, Milakov and Gimelshein (2018) presented online softmax, which calculates both m(x) and ℓ(x) in one pass: initialize m_0 = −∞ and ℓ_0 = 0, then for each iteration i = 1, …, L update

$$m_i \leftarrow \max(m_{i-1}, x_i), \qquad \ell_i \leftarrow \ell_{i-1} e^{m_{i-1} - m_i} + e^{x_i - m_i}.$$

One can check by induction that the algorithm maintains the invariants m_i = m([x_1, …, x_i]) and ℓ_i = ℓ([x_1, …, x_i]) as it iterates over the elements of the input array. At each iteration it adjusts the normalizer to the new maximum m_i and only then adds the new value to ℓ_i.
Moreover, the computation can be written as a reduction

$$\begin{bmatrix} m(x) \\ \ell(x) \end{bmatrix} = \begin{bmatrix} x_1 \\ 1 \end{bmatrix} \oplus \begin{bmatrix} x_2 \\ 1 \end{bmatrix} \oplus \dots \oplus \begin{bmatrix} x_L \\ 1 \end{bmatrix}, \qquad \begin{bmatrix} m_i \\ \ell_i \end{bmatrix} \oplus \begin{bmatrix} m_j \\ \ell_j \end{bmatrix} = \begin{bmatrix} \max(m_i, m_j) \\ \ell_i e^{m_i - \max(m_i, m_j)} + \ell_j e^{m_j - \max(m_i, m_j)} \end{bmatrix}.$$

The operation ⊕ is associative and commutative, which enables parallel and efficient evaluation.
Python

@jit
def online_softmax(logits):

    def reducer(x, y):
        m_i, l_i = x
        m_j, l_j = y
        m = jnp.maximum(m_i, m_j)
        l = l_i * jnp.exp(m_i - m) + l_j * jnp.exp(m_j - m)
        return (m, l)

    m, l = jax.lax.reduce(
        (logits, jnp.ones_like(logits)),
        (-jnp.inf, 0.),
        reducer,
        (0,)
    )
    exp_logits = jnp.exp(logits - m)
    return exp_logits / l
One can run this little test script to evaluate the efficiency of each implementation:
Python

from jax import random

# create large random vector
logits = jax.random.uniform(random.PRNGKey(42), shape=(1_000_000,))

# one warmup run for each function to compile
naive_softmax(logits)
safe_softmax(logits)
online_softmax(logits)

print('Naive:')
%timeit naive_softmax(logits).block_until_ready()
print('\nSafe:')
%timeit safe_softmax(logits).block_until_ready()
print('\nOnline:')
%timeit online_softmax(logits).block_until_ready()
Plaintext
Naive:
194 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Safe:
254 μs ± 17.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Online:
199 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Check out the original paper for more details, including an algorithm with softmax + top-k fusion.
Lazy softmax
Now let's get back to the computation of the attention operation. Given query, key and value tensors Q, K, V ∈ R^{L×d}, our aim is to compute (we omit masking and the 1/√d normalization here and after for simplification)

$$S = QK^T, \quad P = \text{softmax}(S), \quad O = PV,$$

or, elementwise,

$$S_{ij} = Q_i^T K_j, \quad P_{ij} = \frac{e^{S_{ij}}}{\sum_{l=1}^L e^{S_{il}}}, \quad O_i = \sum_{l=1}^L P_{il} V_l \quad \forall i, j = 1, \dots, L.$$
The problem with the implementation above is that it requires us to first compute and remember S_ij for all j, leading to linear time and memory complexity for each query and to overall time and space complexity of O(L²). Rabe and Staats (2022) suggested moving the division by the normalization term to the very end of the attention operation using the distributive law:

$$O_i = \frac{\sum_{l=1}^L V_l e^{S_{il}}}{\sum_{l=1}^L e^{S_{il}}} \quad \forall i = 1, \dots, L.$$

This implementation, called lazy softmax, can be computed with constant memory for each query: we start from a vector v_0 ∈ R^d and a scalar ℓ_0, both initialized with 0, and as we process key/value pairs sequentially for j = 1, …, L, we only update

$$v_j \leftarrow v_{j-1} + V_j e^{S_{ij}}, \qquad \ell_j \leftarrow \ell_{j-1} + e^{S_{ij}}.$$
One can notice that such an algorithm has the same numerical problem as the naive implementation of softmax: incremental computation of the sum of exponentiated scores (and values). The standard safe-softmax trick cannot be applied here, as the maximum may depend on the last score in the sequence. The subtraction cannot be delayed either, since the scores must be exponentiated before they can be added to the cumulative sum.

To resolve this problem, the authors introduce an additional scalar m, as in online softmax, which keeps track of the maximum score that the incremental algorithm has seen so far, and they renormalize the sums of exponentiated values as needed:
$$m_j \leftarrow \max(m_{j-1}, S_{ij}), \qquad v_j \leftarrow v_{j-1} e^{m_{j-1} - m_j} + V_j e^{S_{ij} - m_j}, \qquad \ell_j \leftarrow \ell_{j-1} e^{m_{j-1} - m_j} + e^{S_{ij} - m_j}.$$
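A simplified single-query sketch of these updates (written with lax.scan purely for illustration; this is not the authors' parallel implementation):

Python

def lazy_softmax_attention(q, keys, values):
    # q: [d], keys: [L, d], values: [L, d_v]; constant memory per query
    def step(carry, kv):
        m_prev, v_prev, l_prev = carry
        k_j, v_j = kv
        s_j = q @ k_j                             # score S_ij
        m_new = jnp.maximum(m_prev, s_j)
        scale = jnp.exp(m_prev - m_new)           # renormalize previous sums
        v_new = v_prev * scale + v_j * jnp.exp(s_j - m_new)
        l_new = l_prev * scale + jnp.exp(s_j - m_new)
        return (m_new, v_new, l_new), None

    init = (jnp.array(-jnp.inf), jnp.zeros(values.shape[-1]), jnp.array(0.))
    (_, v_acc, l_acc), _ = jax.lax.scan(step, init, (keys, values))
    return v_acc / l_acc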
The authors also exploited massive parallelism and provided code in Jax for a memory-efficient parallel algorithm. Notice that they use the jax.checkpoint decorator in the summarize_chunk function. The reason is that during the forward pass this algorithm saves memory by summarizing parts of the attention matrix sequentially, allowing it to forget the parts of the attention matrix it has summarized already. A naive application of differentiation would have to store all those intermediate results, and the algorithm would lose its memory advantage entirely. So the authors propose to apply gradient checkpointing to the function that summarizes the individual chunks. The intermediate results can thus be forgotten during the forward pass and recomputed during backpropagation.

Applying checkpointing to the standard attention algorithm would not achieve these results. The standard attention algorithm with checkpointing would forget the attention matrix after it is formed; the query chunk attention algorithm never forms the full attention matrix at all.
FlashAttention
FlashAttention might be the most popular implementation of the attention mechanism nowadays. While it actually does more FLOPs than standard attention, it runs up to 3× faster just by making the attention algorithm IO-aware - accounting for reads and writes between levels of GPU memory. Remember how we discussed GPU architecture in the first section of this post: moving tensors from SRAM can be 10× faster on a modern GPU than moving them from HBM.

Dao et al. (2022) modified the attention algorithm with two techniques to reduce the amount of HBM accesses to sub-quadratic in L:
1. Tiling: inputs Q, K, V are divided into blocks that fit into SRAM, then online softmax is computed to get attention scores for each block and the results are combined together. With tiling the whole attention mechanism can be implemented in one CUDA kernel.
2. Recomputation: training requires storing intermediate outputs, such as S, P ∈ R^{L×L}, to compute gradients with respect to Q, K, V during the backward pass. The standard mechanism to reduce the memory footprint during training is gradient checkpointing: forgetting these outputs in the forward pass and recalculating them in the backward pass. However, this implementation has to trade speed for memory. The authors instead propose selective gradient checkpointing: store only the output O and the softmax normalization statistics (m, ℓ) and recompute the attention matrices S and P in the backward pass from blocks of Q, K, V residing in SRAM.
In the forward pass, Q is partitioned into T_Q blocks of size B_Q × d and K, V are partitioned into T_KV blocks of size B_KV × d. Going over the K/V blocks j = 1, …, T_KV in the outer loop and the Q blocks i = 1, …, T_Q in the inner loop, the kernel:

1. Loads K_j, V_j and Q_i, O_i, ℓ_i, m_i from HBM to SRAM and computes the block of attention scores
$$S_{ij} = Q_i K_j^T \in \mathbb{R}^{B_Q \times B_{KV}}.$$
2. Computes the block statistics
$$\tilde{m}_{ij} = \text{rowmax}(S_{ij}) \in \mathbb{R}^{B_Q}, \quad \tilde{P}_{ij} = \exp(S_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_Q \times B_{KV}}, \quad \tilde{\ell}_{ij} = \text{rowsum}(\tilde{P}_{ij}) \in \mathbb{R}^{B_Q}.$$
3. Renews the running statistics
$$m_i^{\text{new}} = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_Q}, \quad \ell_i^{\text{new}} = e^{m_i - m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\ell}_{ij} \in \mathbb{R}^{B_Q}.$$
4. Updates the output block by rescaling the previous value and adding the new contribution,
$$O_i \leftarrow \text{diag}(\ell_i^{\text{new}})^{-1}\big(\text{diag}(\ell_i)\, e^{m_i - m_i^{\text{new}}} O_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{P}_{ij} V_j\big),$$
and writes it to HBM.
5. Writes ℓ_i ← ℓ_i^new, m_i ← m_i^new to HBM.

In the end the kernel returns O.
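The same blockwise math, without any real SRAM management, can be sketched in plain JAX (the block sizes and single-head layout are illustrative assumptions):

Python

def flash_attention_forward(Q, K, V, block_q=128, block_kv=128):
    # Q, K, V: [L, d]; running statistics m, l and output O play the role of HBM state
    L, d = Q.shape
    O = jnp.zeros((L, d))
    m = jnp.full((L,), -jnp.inf)
    l = jnp.zeros((L,))
    for kv_start in range(0, L, block_kv):
        K_j = K[kv_start:kv_start + block_kv]
        V_j = V[kv_start:kv_start + block_kv]
        for q_start in range(0, L, block_q):
            sl = slice(q_start, q_start + block_q)
            S_ij = Q[sl] @ K_j.T                          # [B_Q, B_KV]
            m_ij = S_ij.max(axis=-1)                      # rowmax
            P_ij = jnp.exp(S_ij - m_ij[:, None])
            l_ij = P_ij.sum(axis=-1)                      # rowsum
            m_new = jnp.maximum(m[sl], m_ij)
            l_new = jnp.exp(m[sl] - m_new) * l[sl] + jnp.exp(m_ij - m_new) * l_ij
            O_new = (jnp.exp(m[sl] - m_new)[:, None] * l[sl][:, None] * O[sl]
                     + jnp.exp(m_ij - m_new)[:, None] * (P_ij @ V_j)) / l_new[:, None]
            O = O.at[sl].set(O_new)
            m = m.at[sl].set(m_new)
            l = l.at[sl].set(l_new)
    return O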
In the end the FlashAttention algorithm returns O = softmax(QK^T)V with O(L²d) FLOPs and requires O(L) additional memory beyond inputs and output. In terms of memory accesses, it requires O(L²d²M^{−1}) HBM accesses, where M is the SRAM size with d ≤ M ≤ Ld, compared to O(Ld + L²) for standard attention.
Schematic diagram of how the FlashAttention forward pass is performed, when Q is partitioned into T_Q = 1 block of size B_Q × d with B_Q = 3 and K/V are partitioned into T_KV = 2 blocks of size B_KV × d with B_KV = 2 each. Here ℓ_1 = ∑ e^{S_11}, ℓ_2 = ℓ_1 + ∑ e^{S_12}. The step with subtracting m in the softmax is omitted for simplification.
The authors of FlashAttention also compared it to the query chunk attention algorithm, stating three major differences:
1. Memory-efficient attention focuses on memory footprint, while FlashAttention focuses on reducing HBM
accesses.
2. Memory-efficient attention summarizes each block with its temporary output along with the softmax
normalization statistics. FlashAttention instead incrementally updates the output after processing each
block, so only one copy of the output is needed.
3. Memory-efficient attention uses gradient checkpointing to recompute the attention matrix and the
temporary output of each block. FlashAttention only recomputes the attention matrix and does not
recompute the temporary output of each block.
“FlashAttention” is an amazingly written paper and it's definitely worth reading for anyone who's training large language models. There are more details in the paper about the FlashAttention backward pass, theoretical proofs and comparisons to other optimization algorithms.
FlashAttention + Parallelism
FlashAttention significantly speeds up attention computation and also reduces memory usage from quadratic to linear in sequence length. While it works for most cases, it's not optimized for the case of long sequences with small batch sizes and/or a small number of attention heads, due to insufficient parallelism.

The first version of the FlashAttention kernel uses one thread block per attention head, leading to Bh thread blocks overall (B is the batch size, h is the number of attention heads). Each thread block is scheduled to run on an SM, and such scheduling is only efficient when Bh is as large as the number of SMs (e.g. 108 SMs for an A100 GPU) for all of the compute resources to be used effectively. If one trains an LLM with modern parallelism techniques, the batch size is reduced by a factor of DP (the data-parallel degree) and the number of heads is reduced by a factor of TP (the tensor-parallel degree).
Tri Dao, the author of the original FlashAttention, applied additional parallelization along the sequence axis to make better use of the multiprocessors on the GPU. In this regime, in the forward pass one attention head is processed by multiple thread blocks and each block takes care of its own segment of rows of the attention matrix. As the rows of the attention matrix don't depend on each other, we don't need to communicate between the blocks.
In the backward pass each thread block now takes care of a segment of columns of the attention matrix.
Parallelizing by columns is faster than parallelizing by rows due to the reduced communication between the
workers (parallelizing by columns requires aggregating the gradient of the query, while parallelizing by rows
requires aggregating the gradient of the key and value).
FlashAttention-2
A new version of FlashAttention (Tri Dao, 2023) included parallelization across different thread blocks to increase occupancy for long sequences. Besides that, two sources of overhead were reduced: non-matmul FLOPs and shared memory reads/writes.

The first tweak is to rewrite the online softmax trick to reduce the number of rescaling operations as well as bound-checking and causal masking operations, without changing the output (remember that matmul throughput can be more than an order of magnitude higher than non-matmul throughput).
The second tweak is an optimal work partitioning between warps, a group of threads working together. In the
first FlashAttention K and V were divided across 4 or 8 warps per thread block. Each warp multiplies to get a
slice of QK T , then they need to multiply with a slice of V and communicate to add up the result. However,
this is inefficient since all warps need to write their intermediate results out to shared memory, synchronize,
then add up the intermediate results.
In FlashAttention-2, Q is instead divided across 4 warps while keeping K and V accessible to all warps. After each warp performs a matrix multiply to get a slice of QK^T, it just needs to multiply with the shared slice of V to get its corresponding slice of the output. There is no need for communication between warps, and the reduction in SRAM reads/writes yields a speedup.
The second version also introduced support for larger head dimensions (d = 128 → 256) and support for
MQA and GQA.
FlashAttention-3
https://astralord.github.io/posts/transformer-inference-optimization-toolset/#pagedattention--vllm 23/29
02/10/2024, 20:41 Transformers Inference Optimization Toolset | AstraBlog
The most recent version, FlashAttention-3 (2024), focuses on optimizations specific to the H100 GPU, as previous versions achieved only up to 35% utilization on the Hopper architecture. The authors exploit new hardware features, in particular the tensor memory accelerator (TMA) - a new chunk of hardware that can do asynchronous address generation and fetch memory - to overlap computation with data movement and push utilization further.
Ring Attention
Even with FlashAttention, the memory complexity is linear in L, so scaling the sequence length is limited by the memory capacity of a single device. We could scale the context with the number of devices N: split the inputs into N parts, perform computations in parallel, then gather the results. However, attention requires Q to access all elements of the K, V matrices. 7 Sending large matrices between devices can add a huge communication overhead (e.g. A100 throughput is 600GB/s with NVLink and only 64GB/s with PCIe).
Ring Attention (Liu et al. (2023)) addresses this problem and explores the idea of hiding communication overhead behind computations in an extremely large context scenario. The algorithm is the following:

- Split the input sequence into N blocks along the sequence axis, such that each device stores one input block of size C = ⌈L/N⌉. Compute Q_i, K_i, V_i for its input block on each i-th device.
- For iter = 0, …, N − 1: each device computes blockwise attention between its query block Q_i and the K, V block it currently holds (accumulating the partial outputs with online-softmax statistics, as in FlashAttention), while simultaneously sending its current K, V block to the next device in the ring and receiving a new one from the previous device.
GPUs are arranged in a ring, each holding a portion of Q. During the Ring Attention process, the GPUs pass along blocks of K, V to each
other.
Local attention computation at each iteration requires O(C²d) operations, while each device needs to send K_j, V_j ∈ R^{C×d} tensors, or O(Cd) bytes. Thus the lower bound on the chunk size to effectively hide the communication overhead is

$$C = \Big\lceil \frac{L}{N} \Big\rceil \ge \frac{\text{compute performance}}{\text{communication bandwidth}}.$$
Striped Attention
Brandon et al. (2023) studied Ring Attention performance for causal transformers and found that it leads to a highly imbalanced workload due to the triangular structure of causal attention computations. They propose a simple extension called Striped Attention to fix this and achieve a near 1.5× speedup.

The problem with Ring Attention is that on all but the first iteration, the workload of some devices is entirely necessary (unmasked), while the workload of others is entirely unnecessary (masked) for the final output, and thus there is no need for it to be computed. The total latency is determined by the maximum latency of any participating device per iteration, so regardless of per-device optimizations, the latency per iteration is the same as the time taken to compute a fully unmasked workload. As a result, Ring Attention runs as fast as a workload with no attention masking, despite in principle needing to compute only half the operations.
Workload distribution of Ring Attention (left) and Striped Attention (right). The square matrices represent the set of all possible pairwise query/key interactions; row indices correspond to queries, and column indices correspond to keys. All cells above the diagonal are masked out by the causal mask and can be skipped. The colors indicate which devices are responsible for which parts of the computation. We can see that some devices in Ring Attention are responsible for workloads which are entirely masked out, whereas Striped Attention maintains a balanced workload across all devices.
Rather than partitioning the tokens into contiguous blocks as in Ring Attention, Striped Attention partitions them into sets of evenly spaced stripes based on their residues modulo N, so that the i-th token resides on device i mod N. In practice, we can achieve this partitioning scheme by permuting the input tokens before the model's first embedding layer and then partitioning the permuted sequence into contiguous blocks as in Ring Attention. After partitioning, the Striped and Ring Attention algorithms proceed almost identically.
Another recent improvement, called Tree Attention (Shyam et al. 2024), leveraged tree-reduction topology
to reduce communication costs for decoding across multiple GPUs and achieved asymptotically 8× speedup
compared to Ring Attention.
PagedAttention / vLLM
The KV cache takes up a significant amount of memory. In a straightforward implementation we can store the KV cache of a request in contiguous memory space, like all the other tensors. However, unlike the tensors in traditional deep learning workloads, the KV cache dynamically grows and shrinks over time as the model generates new tokens, and its lifetime and length are not known a priori.
To store the KV cache of a request in contiguous space, we have to pre-allocate a contiguous chunk of
memory with the request’s maximum length, while the request’s actual length can be much shorter. Another
memory inefficiency can be observed when we use advanced decoding techniques such as beam search or
parallel sampling to generate multiple outputs per request. In these scenarios, the request consists of
multiple sequences that can partially share their KV cache. However, memory sharing is not possible when KV
cache is stored in separate contiguous spaces.
To address the above limitations, Kwon et al., 2023 propose PagedAttention, an algorithm inspired by the OS
solution to memory fragmentation and sharing: virtual memory with paging. PagedAttention divides the
request’s KV cache into blocks, each of which can contain the attention keys and values of a fixed number of
tokens. In PagedAttention, the blocks for the KV cache are not necessarily stored in contiguous space and we
can manage the KV cache in a more flexible way.
The key idea is to represent KV cache as a series of logical KV blocks, filled from left to right as new tokens are
generated, that can be divided into non-contiguous physical KV blocks in GPU memory, allocated by a block
engine. The KV block manager maintains block tables — the mapping between logical and physical KV blocks
of each request. Separating logical and physical KV blocks allows vLLM to dynamically grow the KV cache
memory without reserving it for all positions in advance.
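A toy sketch of such a block table (class and method names are purely illustrative, not vLLM's actual API):

Python

class BlockManager:
    def __init__(self, num_physical_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_physical_blocks))
        self.block_tables = {}   # request_id -> list of physical block ids
        self.num_tokens = {}     # request_id -> number of cached tokens

    def append_token(self, request_id):
        # reserve space for one more KV entry of a request
        n = self.num_tokens.get(request_id, 0)
        table = self.block_tables.setdefault(request_id, [])
        if n % self.block_size == 0:           # last block is full: allocate a new one
            table.append(self.free_blocks.pop())
        self.num_tokens[request_id] = n + 1

    def lookup(self, request_id, token_idx):
        # map a logical token position to (physical block, offset within block)
        table = self.block_tables[request_id]
        return table[token_idx // self.block_size], token_idx % self.block_size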
Storing the KV cache of two requests with PagedAttention. The logical blocks of the two sequences are mapped to different physical blocks within the space reserved by the block engine in GPU workers. The block manager maps the first three logical blocks of request A to physical blocks 7, 1 and 3, and the first two blocks of request B to physical blocks 5 and 2.
In parallel sampling multiple responses share the same input prompt, allowing the KV cache of the prompt to be shared as well. Thus the logical blocks for the prompts of all sequences are mapped to the same physical block, which stores a reference count. At the generation phase, the first physical block that contains the response is replicated (copy-on-write) and its reference count is reduced, so sampling continues with separate blocks. By sharing physical KV blocks across multiple samples, memory usage can be greatly reduced, especially for long input prompts.
Beam search decoding with PagedAttention is more advanced, since blocks can also be shared across different candidates, and the sharing pattern changes dynamically as the decoding process progresses. As candidates drop out of the top beams, their logical blocks are freed and the reference counts of the corresponding physical blocks are reduced.
Additionally to the algorithm, the authors released vLLM, a high-throughput distributed LLM serving engine built on top of PagedAttention. Their evaluations on various models and workloads show that vLLM improves LLM serving throughput by 2-4× compared to other existing serving systems, without affecting model accuracy.
Conclusion
This post has covered a wide range of transformer inference optimization techniques, from high-level algorithmic improvements like KV caching and MQA/GQA, to low-level hardware optimizations such as CUDA kernel fusion and vLLM. The key takeaway is clear: LLMs are resource-intensive beasts, and taming them requires a diverse toolkit.
We’ve seen that successful ML engineering isn’t just about understanding algorithms or hardware - it’s about
grasping how they work together. An algorithmic breakthrough might fall flat without the right hardware
implementation, while a deep dive into GPU architecture could unlock unexpected performance gains.
For ML engineers, the message is simple: stay curious and be ready to learn across the entire stack. The most
effective optimizations often come from connecting the dots between different areas of expertise. As LLMs
continue to grow and evolve, so too will the tricks we use to run them efficiently.
1. Full names given to GPU architectures are: Turing, Volta, Ampere, Hopper, Blackwell. ↩
2. In general, the number of flops for a matrix-matrix multiplication AB with A ∈ R^{n×m} and B ∈ R^{m×k} is near 2mnk: we perform k matrix-vector multiplications, each of which can be represented as n inner products, and each inner product requires m multiplications and m − 1 additions. ↩
3. Dimension size of values V can be different from d, but usually it’s not the case. ↩
4. A reasonable question might be: “What is the best way to utilize GPU to generate small sequences, e.g. L ≪ d?” A possible solution is to enlarge batch processing, since one can compute that for a batched input the arithmetic intensity is $\frac{BL^2d + BLd^2}{BL^2h + BLd + d^2} \xrightarrow[d \gg L]{} \frac{BLd}{BL + d} \approx BL$. ↩
5. While it is clear that hedgehog comes from attention “spikyness” modelling, I still wonder what
porcupine in the title refers to. ↩
6. This isn’t unique to GPUs - in fact, TPUs are even less general than GPUs. ↩
7. If we don’t use SWA. ↩
jax gpu kv-cache linear attention cuda kernels online softmax flash attention ring attention