【UNET】

UNET学习

引言

UNET是一种专为图像分割设计的卷积神经网络(CNN),由Olaf Ronneberger等人在2015年提出,尤其在医学图像分割中表现卓越。论文链接是UNET原文,UNET提出到现在过去了将近10年,但是因为其出色的表现,在今天的一些深度学习网络中仍在借鉴其提出的U形网络思想。

本文主要是对原文进行解读,包括代码复现,以及对UNET改进方向的思考。这部分内容很多,前置知识需要准备


全卷积网络(FCN)

UNET网络实际上是对全卷积网络的一种扩展,这里简单的介绍一下全卷积网络

(1) 基本定义

  • 全卷积化:将传统CNN中的全连接层替换为卷积层,使网络能够接受任意尺寸的输入图像,并输出与输入尺寸相同的分割结果(像素级预测)。
  • 核心思想:通过卷积和下采样(编码器)提取特征,再通过上采样(解码器)恢复分辨率,最终实现端到端的分割。
    在这里插入图片描述
    原论文中的图

(2) 关键特点

  • 去全连接层:传统CNN(如VGG、AlexNet)的全连接层会固定输入尺寸,而FCN通过全卷积化实现输入尺寸的灵活性。
  • 上采样(反卷积):使用转置卷积(Transposed Convolution)或插值法(如双线性插值)逐步恢复分辨率。
  • 跳跃连接(Skip Connections):将浅层特征(高分辨率、低语义)与深层特征(低分辨率、高语义)融合,提升细节定位能力。例如:
    • FCN-8s:融合第3、4、5层的特征,输出精细的分割结果。

这里列举一下与传统CNN网络的区别:

特性传统CNN全卷积网络(FCN)
任务类型图像分类(全局标签预测)图像分割(像素级预测)
输出形式一维向量(类别概率)二维/三维特征图(与输入尺寸相同)
全连接层包含全连接层(固定输入尺寸)无全连接层,全卷积化(任意输入尺寸)
空间信息保留最终丢失空间信息(通过池化压缩)保留空间信息(通过上采样恢复分辨率)
典型应用分类(如ResNet、VGG)分割(如FCN-8s、UNET)

(3) 局限性

  • 细节丢失:多次下采样导致空间信息丢失,即使有跳跃连接,边缘仍可能模糊。
  • 参数量大:深层网络计算成本高,对小数据集(如医学图像)容易过拟合。
  • 跳跃连接较弱:仅通过逐元素相加(Element-wise Addition)融合特征,信息传递效率有限。

具体细节可以看原论文中的Related work部分,写的很详细。


UNET

上面提到的FCN能够满足图像分割的部分要求了,UNET在FCN基础上进一步优化跳跃连接(通道拼接而非相加),并采用对称结构,更适合小数据和高精度边界需求。

首先,让我们来看一下UNET的网络结构

原文的图
上图解释:

  • 蓝色矩形是通道特征图,上面的是通道数量,右边的白色矩形框是左边复制过来的特征图,特征图的大小在每个矩形框的左边。
  • 可以看到,这是一个U型网络,红色箭头是下采样步骤,绿色箭头是上采样步骤,灰色箭头是跳跃连接。
  • 拓展路径和收缩路径基本对称,即左边和右边的网络结构是对称的,这样的好处是定位信息。
  • 网络没有全连接层。

听下来有点抽象,下面是详细解释

1.整体结构

UNET 是一种 对称的编码器-解码器(Encoder-Decoder)结构,形似字母“U”,故得名。其核心模块包括:

  • 编码器(下采样路径):提取多尺度特征,逐步压缩空间维度。(左半部分)
  • 解码器(上采样路径):逐步恢复空间分辨率,定位目标细节。(右半部分)
  • 跳跃连接(Skip Connections):连接编码器和解码器的对应层,融合局部细节与全局语义。(四条灰色箭头)

2.编码器(Contracting Path)

层级结构:由 下采样模块,每个模块由以下操作组成:

  • 双卷积层:两次 3×3 卷积 + 批归一化(BatchNorm) + ReLU 激活。(原论文中没有提到BatchNorm)
  • 最大池化:2×2 池化(步长 2),将分辨率减半,通道数加倍。

示例参数变化(输入:256×256×3):

  • 模块 1:3 通道 → 64 通道 → 最大池化 → 输出 128×128×64。
  • 模块 2:64 通道 → 128 通道 → 输出 64×64×128。
  • 依此类推,直至最低分辨率(如 16×16×1024)。

3. 解码器(Expansive Path)

  • 层级结构:与编码器对称的 上采样模块,每个模块包含:

    1. 上采样:转置卷积(Transposed Convolution)或双线性插值,分辨率翻倍。
    2. 跳跃连接:与编码器对应层的特征图进行 通道拼接(Concatenation)
    3. 双卷积层:与编码器相同的双卷积操作,细化特征。

    示例参数变化(最低层输入:16×16×1024):

    • 上采样至 32×32×512 → 拼接编码器对应层(32×32×512)→ 输出 32×32×1024。
    • 双卷积后输出 32×32×512,依此类推恢复至原始分辨率。

4. 跳跃连接(Skip Connections)

  • 作用:解决下采样导致的空间信息丢失问题,将编码器的 高分辨率低语义特征 与解码器的 低分辨率高语义特征 融合。
  • 实现方式
    • 编码器每层的输出特征图被复制并拼接到解码器对应层的输入。
    • 通道维度叠加:例如,编码器层输出 64×64×256,解码器上采样后为 64×64×128 → 拼接后为 64×64×(256+128)=384 通道。

5. 输出层

  • 1×1 卷积:将最终特征图的通道数映射到类别数(如二分类输出 1 通道)。
  • 激活函数:Sigmoid(二分类)或 Softmax(多分类)。

6.损失函数

  • 交叉熵损失(Cross-Entropy Loss)
  • Dice Loss(解决类别不平衡):
  • 联合损失:结合交叉熵与 Dice Loss:

    这部分准备单独列一个文章,比较多

UNET代码解析

此处代码需要结合上图

1.下采样模块(Encoder部分)


下采样模块的代码实现

class Down(nn.Module):
 
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), # stride默认是 1
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
  • 功能:先通过最大池化降低分辨率,再进行双卷积提取特征。

  • 执行流程
    1.MaxPool2d(2): 将输入特征图尺寸缩小。
    2.两次卷积: 扩展通道数(in_channelsout_channels),保持分辨率不变。

  • 这里的下采样模块为 池化 → 卷积 → 卷积,而不是选用卷积 → 卷积 → 池化,这是因为后面这个模块可能导致浅层特征提取不足,思考一下池化的特性。

  • 这里使用了一个批归一层(BatchNorm2d),原论文没有提及,一些对UNET代码的复现在这个地方也有不同。其实不必纠结于这些,因为批归一化的操作是为了加速网络的收敛过程。它通过使每层的输入分布在训练过程中保持相对稳定,从而减少每一层所需的学习率调整。同时也可以提高泛化能力超参数寻优的速度,在这里加上会有利于网络的性能。


将多个下采样模块组合起来,便构成了左边的Encoder网络(或者叫压缩路径)
左边网络完整代码(contracting path)

class UNetDownsampling(nn.Module):
    def __init__(self, n_channels, n_classes=1, bilinear=True):
        super(UNetDownsampling, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        return [x1, x2, x3, x4, x5]
class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

2.上采样模块(Decoder部分)

在这里插入图片描述
上采样模块的代码实现

class Up(nn.Module):

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

		# 上采样步骤,其实就是扩大分辨率的步骤
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
    
        x = torch.cat([x2, x1], dim=1)  # 尺寸拼接
        return self.conv(x)

总体流程可以总结为上采样 →卷积 →卷积

  • 输入
    x1: 来自解码器上一层的低分辨率特征图(需上采样)。
    x2: 来自编码器的对应层的高分辨率特征图(跳跃连接)。
  • 这里的上采样方式有两种:
    • 双线性插值:非参数化方法,计算速度快但不可学习。
    • 转置卷积:可学习的上采样,可能生成更精细的特征。
      根据实际情况选用不同的上采样方式,实际上是扩大分辨率和通道数
  • 尺寸对齐:通过填充(F.pad)解决因奇偶分辨率导致的尺寸不匹配问题(这一步也可以理解为跳跃连接的部分)。

右边网络完整的代码(expansive path)

class UNetUpsampling(nn.Module):
    def __init__(self, n_classes=1, bilinear=True):
        super(UNetUpsampling, self).__init__()
        self.up1 = Up(1024, 512, bilinear)
        self.up2 = Up(512, 256, bilinear)
        self.up3 = Up(256, 128, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, down_outputs):
        x = self.up1(down_outputs[4], down_outputs[3])
        x = self.up2(x, down_outputs[2])
        x = self.up3(x, down_outputs[1])
        x = self.up4(x, down_outputs[0])
        logits = self.outc(x)
        return logits

# 最后一层的输出
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)  # 1x1的卷积,减少通道数,生成预测结果

    def forward(self, x):
        return self.conv(x)

UNET网络结构代码实现

import torch
import torch.nn as nn
from parts_down import UNetDownsampling
from parts_up import UNetUpsampling

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # 下采样部分
        self.downsampling = UNetDownsampling(n_channels, n_classes, bilinear)

        # 上采样部分
        self.upsampling = UNetUpsampling(n_classes, bilinear)

    def forward(self, x):
        # 下采样部分
        down_outputs = self.downsampling(x)

        # 上采样部分
        logits = self.upsampling(down_outputs)

        return logits

将上面的下采样模块和上采样模块放到独立的文件,这样可以让代码更清晰(个人习惯)。


用UNET实现简单的医学图像分割

…待定


代码部分参考自Github开源代码Pytorch-UNet

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值