VLM多模态核心模块解析(一):Projector设计与分阶段训练策略

一、VLM模型基础:视觉-语言联合建模

VLM(Vision-Language Model) 是一种能够同时处理视觉(图像/视频)和语言(文本)信息的多模态模型,其核心目标是建立视觉与语言模态之间的语义对齐。典型的VLM架构包含三大组件:

  1. 视觉编码器(Vision Backbone)

    • 常用CLIP-ViT、ResNet等预训练模型

    • 输出视觉特征向量(如embed_dim=768

  2. 语言模型(LLM Backbone)

    • 采用LLaMA、GPT等大语言模型

    • 负责文本理解与生成

  3. 投影层(Projector)

    • 关键创新点:将视觉特征空间映射到语言模型输入空间

    • 设计质量直接影响多模态交互效果

🔍 为什么需要Projector?
视觉特征(如ViT输出)与文本特征(如LLM输入)通常存在于不同向量空间,直接拼接会导致语义割裂。Projector通过可学习的参数矩阵实现跨模态特征对齐。

二、Projector模块架构解析

Projector作为连接视觉编码器(Vision Backbone)与语言模型(LLM)的关键组件,其实现方式直接影响特征空间对齐效果:

if arch_specifier == "linear":
    self.projector = LinearProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
elif arch_specifier.endswith("fused-gelu-mlp"):
    self.projector = FusedMLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
elif arch_specifier.endswith("gelu-mlp"):
    self.projector = MLPProjector(vision_backbone.embed_dim, llm_backbone.embed_dim)
else:
    raise ValueError(f"VLM with `{arch_specifier = }` is not supported!")

三种Projector实现方案:

  1. LinearProjector

    • 最简单的线性映射层

    • 适用于特征空间差异较小的场景

  2. MLPProjector

    • 含GELU激活的多层感知机

    • 更强的非线性表征能力

  3. FusedMLPProjector

    • 优化后的高效MLP实现

    • 平衡性能与计算效率

注:选择取决于arch_specifier参数,未匹配时抛出异常确保配置正确性

三、分阶段训练策略详解

采用渐进式解冻策略,各阶段模块训练状态如下:

阶段Vision EncoderLLMProjector典型应用场景
align❄️ 冻结❄️ 冻结🔥 训练初始特征对齐
finetune❄️ 冻结🔥 训练🔥 训练语言模型适配
full-finetune🔥 训练(float32)🔥 训练🔥 训练全模型端到端微调
last-layer-finetune❄️ 冻结末层🔥❄️ 冻结轻量级微调
vla-sandwich-train🔥 训练末层🔥🔥 训练视觉-语言联合优化

关键实现逻辑

1. 训练状态控制

Align 阶段:只训练视觉特征到语言模型输入空间的「对齐层(projector)」。

self.vision_backbone.requires_grad_(False)
self.llm_backbone.requires_grad_(False)
self.projector.requires_grad_(True)

# Add to `self.trainable_module_keys`
self.trainable_module_keys = ["projector"]

# Update Trackers
self.vision_backbone_requires_grad = False

Finetune 阶段:训练Projector + LLM

self.vision_backbone.requires_grad_(False)
self.llm_backbone.requires_grad_(True)
self.projector.requires_grad_(True)

# Add to `self.trainable_module_keys`
self.trainable_module_keys = ["projector", "llm_backbone"]

# Update Trackers
self.vision_backbone_requires_grad = False

Full Finetune阶段:训练 Vision + Projector + LLM

self.vision_backbone.dtype = torch.float32
self.vision_backbone.requires_grad_(True)
self.llm_backbone.requires_grad_(True)
self.projector.requires_grad_(True)

# Add to `self.trainable_module_keys`
self.trainable_module_keys = ["vision_backbone", "projector", "llm_backbone"]

# Update Trackers
self.vision_backbone_requires_grad = True
2. 阶段依赖管理

✅ 如果选择align,跳过加载,因为需要从头训练。

✅ 如果选择finetune,必须先训练projector,加载训练好的参数,再训练projector + LLM。

  • align 已训练好的 projector 基础上,解冻语言模型等模块继续训练。因此它必须加载 align 训练后的权重

def load_from_checkpoint(self, stage: str, run_dir: Path, pretrained_checkpoint: Optional[Path] = None) -> None:
    # ...

    if stage == "align":
        log.info("Stage `align` does not require pretrained weights =>> Starting Training", ctx_level=1)
        return

    # otherwise:load align checkpoints
    align_dirs = [
            d
            for d in run_dir.parent.iterdir()
            if (d.name.startswith(f"{model}+{scale}") and d.name.endswith(f"+stage-align+{seed}"))
         ]
     assert len(align_dirs) == 1, "Multiple or No Valid Pretrained Directories Exist -- Double Check `runs`!"
     if (pretrained_checkpoint := (align_dirs[0] / "checkpoints" / "latest-checkpoint.pt")).exists():
         log.info(f"Loading from Discovered Checkpoint `{pretrained_checkpoint}`", ctx_level=1)
         model_state_dict = torch.load(pretrained_checkpoint)["model"]
         self.projector.load_state_dict(model_state_dict["projector"])
     else:
         raise ValueError(f"Could not find valid `align` checkpoint at {pretrained_checkpoint}!")

四、最佳实践建议

  1. 训练流程

    • 必须遵循align -> finetune的递进顺序

    • 对齐阶段建议使用MLPProjector获得更好初始化

  2. 精度控制

    • full-finetune阶段强制float32避免精度损失:vision_backbone.dtype = torch.float32

    • 可使用torch.autocast进行混合精度训练

  3. 调试技巧

    • 监控实际训练参数: 记录到 self.trainable_module_keys 方便跟踪。

    • 使用grad_check验证梯度传播是否符合预期

五、完整实现参考

本文代码以prismatic为例,详见官方GitHub仓库:
https://github.com/TRI-ML/prismatic-vlms/blob/main/prismatic/models/vlms/prismatic.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值