Torch Optimization
Pankaj Chouhan
2024
Float16:
Advantages: Memory efficiency.
Disadvantages: Limited precision; not ideal for training complex models.
Float32:
Advantages: Higher precision, widely supported.
Disadvantages: Higher memory usage compared to float16 and bfloat16.
bfloat16:
Advantages: Same dynamic range as float32 with reduced memory use.
Disadvantages: Lower precision than float32, but often sufficient for deep learning tasks; only available on specific hardware, e.g., A100 and RTX 30-series GPUs.
Conclusion: Float16 for inference, bfloat16 for training LLMs or on GPUs/TPUs, float32 for general use.
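A minimal sketch of how this choice might look in practice, using torch.autocast for mixed precision; the dtype-selection policy and the toy Linear model are illustrative assumptions, not part of the original slides:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative policy: bfloat16 on GPUs that support it (e.g., A100 /
# RTX 30-series), else float16 on CUDA, else plain float32 on CPU.
if device == "cuda" and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
elif device == "cuda":
    dtype = torch.float16
else:
    dtype = torch.float32

model = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(8, 1024, device=device)

if dtype == torch.float32:
    y = model(x)  # general-purpose full precision
else:
    # Mixed precision: matmuls run in the low-precision dtype while the
    # master weights stay in float32.
    with torch.autocast(device_type=device, dtype=dtype):
        y = model(x)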
if IMPROVEMENT_3:
    # Improvement 3: compile the model into optimized fused kernels
    model = torch.compile(model)
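torch.compile traces the model and fuses its operations into optimized kernels (TorchInductor by default); the first call pays a one-time compilation cost, after which forward and backward passes typically run faster.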
# Perform attention
if self.use_flash_attention:
    # Use flash attention (optimized attention)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    # Traditional scaled dot-product attention with causal masking
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
    att = F.softmax(att, dim=-1)
    y = att @ v  # Apply attention weights to the values
Figure: By using tiling and re-computation, FlashAttention minimizes calls to HBM.
Source: arXiv.
FlashAttention tracks two running statistics, m(x) (the running row maximum) and l(x) (the running softmax normalizer), so the global softmax can be computed from local blocks alone, as sketched below.
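A minimal numeric sketch of that bookkeeping, assuming a single attention row split into blocks; the function name online_softmax and the block split are illustrative, and unlike the real kernel this sketch revisits the blocks at the end instead of rescaling partial value-matmul outputs on the fly:

import torch

def online_softmax(row_blocks):
    # m: running max m(x); l: running sum of exp(x - m(x)), i.e. l(x).
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    seen_blocks = []
    for block in row_blocks:
        m_new = torch.maximum(m, block.max())
        # Rescale the old normalizer to the new max, then add this block's mass.
        l = l * torch.exp(m - m_new) + torch.exp(block - m_new).sum()
        m = m_new
        seen_blocks.append(block)
    # Each entry's softmax needs only the final global m and l.
    return [torch.exp(b - m) / l for b in seen_blocks]

x = torch.randn(8)
blocks = [x[:4], x[4:]]
out = torch.cat(online_softmax(blocks))
assert torch.allclose(out, torch.softmax(x, dim=0))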
Improvement 5: Pad to a power of 2 for efficient GPU computation
ChatGPT’s answer:
Memory Access Efficiency: Aligning data to powers of 2 allows for coalesced memory access, reducing memory transaction overhead.
Full Warp Utilization: Padding ensures that all threads in a warp (a group of 32 threads) are fully utilized, avoiding idle threads.
Cache Optimization: Data fits neatly into GPU cache lines (often power-of-2 sized), minimizing cache misses and improving access speed.
Reduced Shared Memory Bank Conflicts: Ensures data is evenly distributed across shared memory banks, improving parallel access.
SIMD Efficiency: Enables efficient execution of operations in the GPU's SIMD (Single Instruction, Multiple Data) units.
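A minimal sketch of the idea, assuming a GPT-style embedding table; padding the vocabulary size is the illustrative case here, and next_power_of_2 / padded_vocab are hypothetical names:

import torch

def next_power_of_2(n: int) -> int:
    # Smallest power of 2 that is >= n.
    return 1 << (n - 1).bit_length()

vocab_size = 50257                          # e.g., the GPT-2 BPE vocabulary
padded_vocab = next_power_of_2(vocab_size)  # 65536
# The extra rows are never indexed by real tokens; they exist only so the
# embedding and output-projection matmuls have nicely aligned dimensions.
wte = torch.nn.Embedding(padded_vocab, 768)
lm_head = torch.nn.Linear(768, padded_vocab, bias=False)

In practice, padding to a multiple of 64 or 128 (e.g., 50304 instead of 50257) often captures most of the benefit without the memory cost of rounding all the way up to a power of 2.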
if use_fused_adam:
    # fused=True performs the parameter update in fused CUDA kernels,
    # reducing kernel-launch overhead.
    optimizer = torch.optim.AdamW(..., fused=use_fused_adam)
else:
    optimizer = torch.optim.AdamW(...)
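One way use_fused_adam might be set, checking that the installed PyTorch exposes the fused flag and that training runs on CUDA; the inspect-based check is an illustrative pattern, not taken from the slides:

import inspect
import torch

# Fused AdamW is only available in newer PyTorch builds and only on CUDA.
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
use_fused_adam = fused_available and torch.cuda.is_available()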