【AI模型学习】上/下采样


基于Transformer架构的图像分割模型(如 SegFormer、Swin-Unet)中,上采样和下采样结构几乎是标准配置。

分割中的上/下采样

为什么需要下采样?

  1. 提取高层语义特征
    Transformer擅长全局建模,结合下采样可以:降低分辨率;聚焦于更宽范围的上下文。

  2. 减少计算成本
    原始输入图像太大,直接送入多层Transformer(特别是多头注意力)会导致计算量和显存爆炸。

为什么要上采样?

  1. 恢复空间分辨率
    Segmentation任务最终要输出与输入图像同样大小的分割mask;

  2. 细粒度定位
    但如果没有上采样、跳跃连接或融合,容易失去细节;所以上采样常结合UNet-like结构来补偿细节损失。

模型下采样方式上采样方式
SETRViT backbone patchify多层反卷积上采样
SegFormerMLP Mixer + 4阶段卷积下采样多层插值 + FFN
Swin-UnetSwin Transformer 下采样Patch expanding + skip连接

下采样

SegFormer和PVT(使用卷积)

# 输入:img.shape = [B, 3, 512, 512]

# Stage 1
x1 = Conv2d(3, 32, kernel_size=7, stride=4, padding=3)(img)    # → [B, 32, 128, 128]
x1 = x1.flatten(2).transpose(1, 2)                             # → [B, 16384, 32]
x1 = TransformerBlock(x1)

# Stage 2
x2 = Conv2d(32, 64, kernel_size=3, stride=2, padding=1)(x1_reshaped)  # → [B, 64, 64, 64]
x2 = x2.flatten(2).transpose(1, 2)                                    # → [B, 4096, 64]
x2 = TransformerBlock(x2)

# 后面还有 Stage3、Stage4 类似

Shape 演化:

Stage1: [B, 128×128=16384, 32]
Stage2: [B, 64×64=4096, 64]
Stage3: [B, 32×32=1024, 160]
Stage4: [B, 16×16=256, 256]

Swin-Unet(使用 Patch Merging)

# 初始 patch embedding(patch_size=4)
x = Conv2d(3, 96, kernel_size=4, stride=4)(img)      # [B, 96, 128, 128]
x = x.flatten(2).transpose(1, 2)                     # → [B, 16384, 96]

# Stage 1
x = SwinBlock(x)                                     # [B, 16384, 96]
x = PatchMerging(x)                                  # → [B, 4096, 192]

# Stage 2
x = SwinBlock(x)                                     # [B, 4096, 192]
x = PatchMerging(x)                                  # → [B, 1024, 384]

# Stage 3
x = SwinBlock(x)
x = PatchMerging(x)                                  # → [B, 256, 768]

Shape 演化:

Stage0: [B, 128×128=16384, 96]
Stage1: [B, 64×64=4096, 192]
Stage2: [B, 32×32=1024, 384]
Stage3: [B, 16×16=256, 768]

过程不难,只是不好描述,可以看相关教程,这里就把代码贴出来

class PatchMerging(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.reduction = nn.Linear(in_dim * 4, in_dim * 2)

    def forward(self, x, H, W):
        # x: [B, H*W, C] → [B, H, W, C]
        x = x.view(B, H, W, C)

        # 拆分四个方向的 token
        x0 = x[:, 0::2, 0::2, :]  # top-left
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right

        x = torch.cat([x0, x1, x2, x3], dim=-1)  # → [B, H/2, W/2, 4C]
        x = x.view(B, -1, 4 * C)                 # → [B, H/2*W/2, 4C]
        x = self.reduction(x)                   # → [B, H/2*W/2, 2C]
        return x


上采样

SegFormer(interpolate)

在这里插入图片描述

def forward(self, x1, x2, x3, x4):  
    # 输入来自4个Stage:
    # x1: [B, 128*128, 32]
    # x2: [B, 64*64,   64]
    # x3: [B, 32*32,  160]
    # x4: [B, 16*16,  256]

    B = x1.shape[0]

    # === 1. Linear Projection:通道都投影为 256 ===
    _x1 = self.linear1(x1).permute(0, 2, 1).reshape(B, 256, 128, 128)  # [B, 256, 128, 128]
    _x2 = self.linear2(x2).permute(0, 2, 1).reshape(B, 256,  64,  64)  # [B, 256,  64,  64]
    _x3 = self.linear3(x3).permute(0, 2, 1).reshape(B, 256,  32,  32)  # [B, 256,  32,  32]
    _x4 = self.linear4(x4).permute(0, 2, 1).reshape(B, 256,  16,  16)  # [B, 256,  16,  16]

    # === 2. 上采样到统一大小 ===
    _x2 = F.interpolate(_x2, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]
    _x3 = F.interpolate(_x3, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]
    _x4 = F.interpolate(_x4, size=(128, 128), mode='bilinear', align_corners=False)  # [B, 256, 128, 128]

    # === 3. 拼接所有层 ===
    fused = torch.cat([_x1, _x2, _x3, _x4], dim=1)  # [B, 4*256=1024, 128, 128]

    # === 4. 1x1卷积融合通道数 ===
    out = self.fuse_conv(fused)  # [B, 256, 128, 128]

    return out

Swin-Unet(Patch Expanding)

看图也能看出来,十分经典的U-Net结构。

在这里插入图片描述

在上采样阶段
输入:

一个高语义 token,维度为 [4C],是上一步 Patch Merging 得到的。

  1. Linear 映射
    [4C] 投影为 [C] × 4,也就是还原为 2×2 patch 每格的 C 维向量。

  2. reshape → [H, W, 2, 2, C] → [2H, 2W, C]
    把这 4 个 token 安排到一个新的空间位置(上采样 ×2)。

  3. 最终输出为:

    Token 数量 × 4 , 通道数 ÷ 2 \text{Token 数量} \times 4,\quad \text{通道数} \div 2 Token 数量×4通道数÷2

class PatchExpanding(nn.Module):
    def __init__(self, in_dim, expand_ratio=2):
        super().__init__()
        # Linear: [B, H*W, in_dim] → [B, H*W, out_dim = (expand_ratio^2) * out_channels]
        # 例如:in_dim = 512,expand_ratio = 2 → 输出 4×C = 1024
        self.linear = nn.Linear(in_dim, in_dim // 2 * expand_ratio**2)
        self.expand_ratio = expand_ratio

    def forward(self, x, H, W):
        # x: [B, H*W, C]
        B, N, C = x.shape
        R = self.expand_ratio  # 通常为 2

        # 线性投影:C → 4 * (C/2),也就是 [B, H*W, 4*C'],每个 token 展开为 2×2 的 patch
        x = self.linear(x)                      # [B, H*W, 4*C'] = [B, H*W, R*R*(C//2)]

        #  reshape 成图像形式,带有 2×2 子结构 → [B, H, W, R, R, C']
        x = x.view(B, H, W, R, R, C // 2)       # [B, H, W, 2, 2, C//2]

        #  调整顺序,将 2×2 子结构移入空间维度 → [B, H*2, W*2, C//2]
        x = x.permute(0, 1, 3, 2, 4, 5)         # [B, H, 2, W, 2, C//2]
        x = x.reshape(B, H * R, W * R, C // 2)  # [B, 2H, 2W, C//2]

        #  flatten 成 token 序列形式(可再送入 Transformer)→ [B, 4*H*W, C//2]
        x = x.view(B, -1, C // 2)               # [B, 4*H*W, C//2]
        return x


逐级interpolate的方式

  1. 输入来自编码器 4 个 stage

    • x4:[16×16, 512] ← 最深层
    • x3:[32×32, 320]
    • x2:[64×64, 128]
    • x1:[128×128, 64] ← 最浅层
  2. 通道统一:
    每个特征图先通过 1×1 卷积或 Linear 映射,统一成相同维度(如全部 → 256 或 512)

  3. 上采样与融合(逐级):

    f4 = Conv(x4)                          # [16×16]
    f3 = F.interpolate(f4, scale=2) + Conv(x3)  # → [32×32]
    f2 = F.interpolate(f3, scale=2) + Conv(x2)  # → [64×64]
    f1 = F.interpolate(f2, scale=2) + Conv(x1)  # → [128×128]
    

反卷的方式

很经典的设计,不必过多介绍。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值