代码示例:
38, PyTorch TensorBoard 的使用与可视化技巧
TensorBoard 不仅是“画折线图”的工具,更是一套可嵌入 PyTorch 训练/推理生命周期的可视化操作系统。本节从“启动 → 记录 → 布局 → 交互 → 高阶技巧”五个维度,给出可直接落地的最佳实践。所有示例基于 PyTorch 2.1+,CPU 与单卡/多卡 GPU 通用。
38.1 启动:一行命令,零配置
# 方式1:系统级
tensorboard --logdir runs --port 6006 --host 0.0.0.0 --load_fast=false
# 方式2:Python 内嵌(推荐在 Notebook)
%load_ext tensorboard
%tensorboard --logdir runs
--load_fast=false
关闭 Rust 后端,旧版浏览器更稳定。- 多实验对比:
--logdir_spec base:runs/base,frozen:runs/frozen
浏览器左侧会自动出现下拉框切换。
38.2 记录:把训练循环变成“数据流”
38.2.1 单 writer 模式(单实验)
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/exp1')
for epoch in range(E):
train_loss = train_one_epoch()
val_loss, val_acc = validate()
writer.add_scalars('Loss', {'train': train_loss, 'val': val_loss}, epoch)
writer.add_scalar('Acc/val', val_acc, epoch)
38.2.2 多 writer 模式(推荐)
train_writer = SummaryWriter('runs/exp1/train')
val_writer = SummaryWriter('runs/exp1/val')
train_writer.add_scalar('loss', loss, global_step)
val_writer.add_scalar('loss', val_loss, epoch)
优点:
- 子目录天然隔离,TensorBoard 会自动折叠。
- 后续删除/保留部分日志时只需
rm -rf
。
38.2.3 记录任意 Python 对象
数据类型 | API | 示例 |
---|---|---|
直方图 | add_histogram | 权重分布、梯度分布 |
图像 | add_image/add_images | 单张/批量 3×224×224 |
视频 | add_video | 5D Tensor (N,T,C,H,W) |
音频 | add_audio | 1D Tensor 16 kHz |
PR 曲线 | add_pr_curve | 多分类 |
超参数 | add_hparams | 见 38.5 |
38.3 布局:让 10 个实验也能一眼定位
38.3.1 命名空间技巧
使用“/”分隔符,TensorBoard 会自动分组。
writer.add_scalar('network/encoder/conv1/weight_norm', w.norm(), step)
writer.add_scalar('network/decoder/linear/bias_max', b.max(), step)
效果:network
→ encoder
→ conv1
→ weight_norm
。
38.3.2 run 命名规范
{model}_{lr}_{bs}_{date}
例如:resnet50_1e-3_64_0730_2145
。
用 writer = SummaryWriter(comment=f'_{lr}_{bs}')
自动生成。
38.4 交互:在浏览器里做实验
- 滑杆过滤:
在“Scalars”面板输入正则.*loss.*
,其余曲线自动隐藏。 - Tooltip 显示:
鼠标悬停 → 显示(step, value)
,按m
可锁定。 - 图像/音频播放:
点击“Show actual image size”查看 1:1 像素;音频可直接播放。 - 并排对比:
打开两个浏览器 Tab,分别指向不同端口,拖拽成左右布局即可。
38.5 超参数搜索:HParams 面板
from torch.utils.tensorboard import SummaryWriter
from tensorboard.plugins.hparams import api as hp
HP_LR = hp.HParam('lr', hp.RealInterval(1e-5, 1e-1))
HP_BS = hp.HParam('batch_size', hp.Discrete([16, 32, 64]))
HP_OPT = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))
with SummaryWriter('runs/hparam') as w:
w.add_hparams(
{HP_LR: 3e-4, HP_BS: 32, HP_OPT: 'adam'},
{'val_acc': best_acc, 'val_loss': best_loss}
)
浏览器出现 “HPARAMS” 页签,可直接排序、筛选,生成平行坐标图。
38.6 高阶技巧
38.6.1 直方图 + 权重分布漂移检测
for name, p in model.named_parameters():
if p.grad is None:
continue
writer.add_histogram(f'grad/{name}', p.grad, global_step)
writer.add_histogram(f'weight/{name}', p, global_step)
当分布突然向 0 或 ±∞ 漂移,可第一时间发现梯度爆炸/消失。
38.6.2 图像网格:可视化 64 张预测
img_grid = torchvision.utils.make_grid(imgs[:64], nrow=8, normalize=True)
writer.add_image('batch_predict', img_grid, global_step)
38.6.3 嵌入向量:2D/3D 投影
features = model.encode(batch) # (N, 128)
labels = batch['label'] # (N,)
writer.add_embedding(
features, metadata=labels, label_img=img_grid,
global_step=epoch, tag='feature_space')
浏览器“PROJECTOR”页签支持 PCA/T-SNE/U-MAP 实时降维、旋转、搜索最近邻。
38.6.4 自定义图结构:Graph 面板
dummy = torch.rand(1, 3, 224, 224).cuda()
writer.add_graph(model, dummy, verbose=True)
verbose=True
会把每个torch.autograd.Function
节点打印出来,便于检查梯度断点。- 支持双击节点查看 shape/dtype,右键“Export as PNG”可直接贴论文。
38.7 生产环境:多进程与远程
# 在 slurm 集群
tensorboard --logdir runs \
--host $(hostname -I | awk '{print $1}') \
--port 0 > tb.log 2>&1 &
echo "TensorBoard URL: http://$(hostname -I | awk '{print $1}'):$(cat tb.log | grep -oP 'http://.*?(\d{4,5})')"
浏览器打开返回的 URL 即可,无需 SSH 隧道。
38.8 清理与版本管理
# 删除旧日志,仅保留最近 3 天
find runs -type d -name '20*' -mtime +3 | xargs rm -rf
Git 仓库里添加 .gitignore
:
runs/
*.tfevents.*
38.9 小结
目标 | 关键 API | 快捷命令 |
---|---|---|
训练曲线 | add_scalar(s) | Scalars |
权重监控 | add_histogram | Distributions / Histograms |
超参搜索 | add_hparams | HPARAMS |
图像/音频 | add_image/audio | Images / Audio |
特征降维 | add_embedding | PROJECTOR |
网络结构 | add_graph | Graphs |
把 TensorBoard 当成“实验日志实时数据库”,而不是“训练结束后画图工具”,你会获得更快的迭代速度、更少的 Bug、以及更漂亮的论文配图。
更多技术文章见公众号: 大城市小农民