
Optimize your Torch training code for LLMs

Pankaj Chouhan

Florida State University

2024



Common Floating Point Data Types

Figure: Comparison between float16, float32 and bfloat16 (source)



Comparison and Best Use Cases

Figure: Histogram of activation gradient magnitudes throughout FP32 training. From the NVIDIA blog.



Comparison and Best Use Cases

Figure: Loss of information when using the FP16 format. From the NVIDIA blog.



Comparison and Best Use Cases

Float16:
Advantages: Memory efficiency.
Disadvantages: Limited precision and narrow dynamic range; not ideal for training complex models.
Float32:
Advantages: Higher precision, widely supported.
Disadvantages: Higher memory usage compared to float16 and bfloat16.
bfloat16:
Advantages: Same dynamic range as float32 with reduced memory use.
Disadvantages: Lower precision than float32, but often sufficient for deep learning tasks; only available on recent hardware, e.g., A100 or RTX 30-series GPUs.
Conclusion: float16 for inference, bfloat16 for training LLMs on supported GPUs/TPUs, float32 for general use.
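
A minimal sketch (assuming a recent PyTorch build) that makes the trade-offs above concrete by printing the bit width, dynamic range, and precision of the three formats:

import torch

# Compare bit width, dynamic range, and precision of the three formats.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>14}  bits={info.bits:2d}  max={info.max:.2e}  "
          f"smallest normal={info.tiny:.2e}  eps={info.eps:.2e}")

bfloat16 matches float32's dynamic range almost exactly (same max and smallest normal) but has a far larger eps, which is exactly the precision trade-off described above.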



Improvement 1: Set up the matrix multiplication precision

PyTorch allows you to set the matrix multiplication precision for float32 operations.
Example code:
torch.set_float32_matmul_precision("high")

The precision can be one of the following:

highest – Maximum precision (default); always use full FP32.
high – Mixed precision; use TensorFloat32, or represent each float32 as the sum of two bfloat16 values.
medium – Reduced precision for speed; use bfloat16 where allowed.
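
A minimal sketch of how this might be toggled at the top of a training script; the IMPROVEMENT_1 flag is hypothetical, mirroring the IMPROVEMENT_2/IMPROVEMENT_3 flags on later slides:

import torch

IMPROVEMENT_1 = True  # hypothetical toggle, mirroring the flags on the following slides

if IMPROVEMENT_1:
    # Allow TF32 (or bfloat16-based) kernels for float32 matrix multiplications.
    torch.set_float32_matmul_precision("high")

# Returns "high" if the toggle is on, otherwise the default "highest".
print(torch.get_float32_matmul_precision())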



Improvement 2: Use AMP

# Enable AMP for forward pass if IMPROVEMENT_2 is toggled
if IMPROVEMENT_2:
    with torch.autocast(device_type=device, dtype=torch.float16):
        logits, loss = model(x, y)
else:
    logits, loss = model(x, y)

Mixed Precision: combines float16 and float32 to reduce memory usage while preserving accuracy.
Safe Operations: matrix multiplications, activation functions, and element-wise operations (add, subtract) run in float16.
Risky Operations: reductions (e.g., summing or averaging across tensors) and the loss computation run in float32.
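
A minimal sketch of a full training step with float16 autocast; since float16 gradients can underflow to zero, torch.cuda.amp.GradScaler is commonly paired with autocast. The model, optimizer, and batch names here are assumptions, not code from the slides:

import torch

# GradScaler scales the loss so that small gradients do not underflow in float16.
scaler = torch.cuda.amp.GradScaler()

optimizer.zero_grad(set_to_none=True)            # `optimizer`, `model`, `x`, `y` are assumed
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits, loss = model(x, y)                   # forward pass in mixed precision
scaler.scale(loss).backward()                    # backward pass on the scaled loss
scaler.step(optimizer)                           # unscales gradients, then calls optimizer.step()
scaler.update()                                  # adjusts the scale factor for the next iteration

With bfloat16 autocast the scaler is usually unnecessary, since bfloat16 has the same dynamic range as float32.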



Improvement 3: Use Torch compile

if IMPROVEMENT_3:
    model = torch.compile(model)

Consider this equation:

f(x) = sin^2(x) + cos^2(x)    (1)

In eager mode, each operation in the computational graph is sent to the GPU/CPU one by one.
torch.compile looks at the full computational graph at once and optimizes it by fusing consecutive operations.
This is a good blog on the topic.
torch.compile: graph-based optimization, operation fusion, and reduced memory and Python overhead.
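
A minimal sketch illustrating the fusion idea on the toy function from Eq. (1): once compiled, the sin, cos, squares, and add can be fused into fewer kernels instead of being dispatched one by one in eager mode. The sizes and device choice are assumptions:

import torch

def f(x):
    # In eager mode, sin, cos, the two squares, and the add are dispatched as separate kernels.
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

compiled_f = torch.compile(f)   # captures the whole graph and fuses the element-wise ops

x = torch.randn(1_000_000, device="cuda" if torch.cuda.is_available() else "cpu")
print(torch.allclose(compiled_f(x), f(x)))   # same result, fewer kernel launches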



Improvement 4: Use flash attention

# Perform attention
if self.use_flash_attention:
    # Use flash attention (optimized attention)
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    # Traditional scaled dot-product attention with causal masking
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
    att = F.softmax(att, dim=-1)
    y = att @ v  # Apply attention weights to the values

The main idea is to move the bottleneck of computing self-attention away from reads and writes to slow HBM (High Bandwidth Memory) and into the ultra-fast on-chip SRAM of the GPU.
Self-attention is quadratic in sequence length, in both memory and compute time.
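
F.scaled_dot_product_attention is a drop-in replacement for the manual implementation above. A minimal standalone sketch, assuming a CUDA GPU and placeholder batch/head/sequence/head-dim sizes:

import torch
import torch.nn.functional as F

# Assumed sizes: batch, attention heads, sequence length, head dimension.
B, H, T, D = 4, 12, 1024, 64
q = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, T, D, device="cuda", dtype=torch.bfloat16)

# PyTorch picks the fastest available backend (FlashAttention, memory-efficient,
# or plain math) based on device, dtype, and shapes.
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # shape (B, H, T, D)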



Improvement 4: What is HBM and SRAM?

Figure: GPU memory. From NVIDIA blog.



Improvement 4: HBM vs. SRAM

Table: Comparison of HBM and SRAM in GPUs


Characteristic    | HBM (High Bandwidth Memory)        | SRAM (Static Random Access Memory)
Speed             | High                               | Extremely fast; used for cache
Capacity          | High (up to several GB per stack)  | Low (typically KB to MB)
Usage             | Main memory for GPU computation    | Cache memory on the GPU die
Power efficiency  | More power-efficient than GDDR     | Consumes more power per bit
Cost              | Expensive compared to GDDR         | Very expensive; takes up a lot of die space



Improvement 4: Flash attention

Figure: Minimizing calls to HBM in self-attention. From the HuggingFace blog.



Improvement 4: Flash attention

Introduced in 2022 by Tri Dao et al.

Figure: By using tiling and recomputation, they minimize the calls to HBM. Source: arXiv.

They keep track of two variables, m(x) and l(x), to compute the global softmax using only local blocks.
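
A sketch of that bookkeeping, following the FlashAttention paper's notation: for a block x, m(x) is the running maximum and l(x) the running sum of exponentials, and two blocks can be merged without ever materializing a full attention row:

\[ m(x) = \max_i x_i, \qquad l(x) = \sum_i e^{x_i - m(x)} \]

For the concatenation x = [x^{(1)}, x^{(2)}] of two blocks:

\[ m(x) = \max\big(m(x^{(1)}),\, m(x^{(2)})\big), \qquad
   l(x) = e^{m(x^{(1)}) - m(x)}\, l(x^{(1)}) + e^{m(x^{(2)}) - m(x)}\, l(x^{(2)}) \]

so that \( \operatorname{softmax}(x)_i = e^{x_i - m(x)} / l(x) \) can be assembled block by block.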
Improvement 5: Pad to a power of 2 for efficient GPU computation

The main reason is how the hardware is designed.

Supported by empirical results, this approach has become a widely accepted standard in the field (a small helper illustrating the idea is sketched after the list below).

ChatGPT’s answer:
Memory Access Efficiency: Aligning data to powers of 2 allows for coalesced memory
access, reducing memory transaction overhead.
Full Warp Utilization: Padding ensures that all threads in a warp (group of 32 threads)
are fully utilized, avoiding idle threads.
Cache Optimization: Data fits neatly into GPU cache lines (often power-of-2 sized),
minimizing cache misses and improving access speed.
Reduced Shared Memory Bank Conflicts: Ensures data is evenly distributed across shared
memory banks, improving parallel access.
SIMD Efficiency: Enables efficient execution of operations in GPU’s SIMD (Single
Instruction, Multiple Data) units.
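
As an illustration of the idea (a hypothetical helper, not code from the slides): in GPT-style models the vocabulary size is often padded up to the next multiple of a power of 2, such as 64, before building the embedding and output layers:

def pad_to_multiple(n: int, multiple: int = 64) -> int:
    """Round n up to the nearest multiple of `multiple` (a power of 2 works well on GPUs)."""
    return ((n + multiple - 1) // multiple) * multiple

vocab_size = 50257                           # GPT-2's vocabulary size
padded_vocab = pad_to_multiple(vocab_size)   # 50304, a much friendlier shape for the GPU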



Improvement 6: Use fused Adam optimizer

if use_fused_adam:
    optimizer = torch.optim.AdamW(..., fused=True)
else:
    optimizer = torch.optim.AdamW(...)

During training, the GPU spins up separate kernels for gradient computation, momentum updates, parameter updates, etc.
Fused Adam: combines the per-parameter update operations of Adam (momentum, variance, and weight updates) into a single GPU kernel.
Faster Training: reduces kernel launches and memory accesses, speeding up model training.
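
A minimal sketch filling in the elided arguments above with placeholder hyperparameters (not from the slides); the fused kernel requires CUDA parameters, so the flag is gated on the device:

import torch

use_fused_adam = torch.cuda.is_available()   # the fused kernel requires CUDA tensors
optimizer = torch.optim.AdamW(
    model.parameters(),      # `model` is assumed to be defined and already on the GPU
    lr=3e-4,                 # placeholder hyperparameters, not from the slides
    betas=(0.9, 0.95),
    weight_decay=0.1,
    fused=use_fused_adam,
)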



Other improvement

Distributed Data Parallel (DDP).

Model parallelism.
Tensor parallelism.
ZeRO optimization.
Quantization and KV caching (for inference).



Thank You
Questions?

