UNET
UNET学习
引言
UNET是一种专为图像分割设计的卷积神经网络(CNN),由Olaf Ronneberger等人在2015年提出,尤其在医学图像分割中表现卓越。论文链接是UNET原文,UNET提出到现在过去了将近10年,但是因为其出色的表现,在今天的一些深度学习网络中仍在借鉴其提出的U形网络思想。
本文主要是对原文进行解读,包括代码复现,以及对UNET改进方向的思考。这部分内容很多,前置知识需要准备
全卷积网络(FCN)
UNET网络实际上是对全卷积网络的一种扩展,这里简单的介绍一下全卷积网络
(1) 基本定义
- 全卷积化:将传统CNN中的全连接层替换为卷积层,使网络能够接受任意尺寸的输入图像,并输出与输入尺寸相同的分割结果(像素级预测)。
- 核心思想:通过卷积和下采样(编码器)提取特征,再通过上采样(解码器)恢复分辨率,最终实现端到端的分割。
原论文中的图
(2) 关键特点
- 去全连接层:传统CNN(如VGG、AlexNet)的全连接层会固定输入尺寸,而FCN通过全卷积化实现输入尺寸的灵活性。
- 上采样(反卷积):使用转置卷积(Transposed Convolution)或插值法(如双线性插值)逐步恢复分辨率。
- 跳跃连接(Skip Connections):将浅层特征(高分辨率、低语义)与深层特征(低分辨率、高语义)融合,提升细节定位能力。例如:
- FCN-8s:融合第3、4、5层的特征,输出精细的分割结果。
这里列举一下与传统CNN网络的区别:
特性 | 传统CNN | 全卷积网络(FCN) |
---|---|---|
任务类型 | 图像分类(全局标签预测) | 图像分割(像素级预测) |
输出形式 | 一维向量(类别概率) | 二维/三维特征图(与输入尺寸相同) |
全连接层 | 包含全连接层(固定输入尺寸) | 无全连接层,全卷积化(任意输入尺寸) |
空间信息保留 | 最终丢失空间信息(通过池化压缩) | 保留空间信息(通过上采样恢复分辨率) |
典型应用 | 分类(如ResNet、VGG) | 分割(如FCN-8s、UNET) |
(3) 局限性
- 细节丢失:多次下采样导致空间信息丢失,即使有跳跃连接,边缘仍可能模糊。
- 参数量大:深层网络计算成本高,对小数据集(如医学图像)容易过拟合。
- 跳跃连接较弱:仅通过逐元素相加(Element-wise Addition)融合特征,信息传递效率有限。
具体细节可以看原论文中的Related work部分,写的很详细。
UNET
上面提到的FCN能够满足图像分割的部分要求了,UNET在FCN基础上进一步优化跳跃连接(通道拼接而非相加),并采用对称结构,更适合小数据和高精度边界需求。
首先,让我们来看一下UNET的网络结构
上图解释:
- 蓝色矩形是通道特征图,上面的是通道数量,右边的白色矩形框是左边复制过来的特征图,特征图的大小在每个矩形框的左边。
- 可以看到,这是一个U型网络,
红色箭头
是下采样步骤,绿色箭头是上采样步骤,灰色箭头是跳跃连接。 - 拓展路径和收缩路径基本对称,即左边和右边的网络结构是对称的,这样的好处是定位信息。
- 网络没有全连接层。
听下来有点抽象,下面是详细解释
1.整体结构
UNET 是一种 对称的编码器-解码器(Encoder-Decoder)结构,形似字母“U”,故得名。其核心模块包括:
- 编码器(下采样路径):提取多尺度特征,逐步压缩空间维度。(左半部分)
- 解码器(上采样路径):逐步恢复空间分辨率,定位目标细节。(右半部分)
- 跳跃连接(Skip Connections):连接编码器和解码器的对应层,融合局部细节与全局语义。(四条灰色箭头)
2.编码器(Contracting Path)
层级结构:由 下采样模块,每个模块由以下操作组成:
- 双卷积层:两次 3×3 卷积 + 批归一化(BatchNorm) + ReLU 激活。(原论文中没有提到BatchNorm)
- 最大池化:2×2 池化(步长 2),将分辨率减半,通道数加倍。
示例参数变化(输入:256×256×3):
- 模块 1:3 通道 → 64 通道 → 最大池化 → 输出 128×128×64。
- 模块 2:64 通道 → 128 通道 → 输出 64×64×128。
- 依此类推,直至最低分辨率(如 16×16×1024)。
3. 解码器(Expansive Path)
-
层级结构:与编码器对称的 上采样模块,每个模块包含:
- 上采样:转置卷积(Transposed Convolution)或双线性插值,分辨率翻倍。
- 跳跃连接:与编码器对应层的特征图进行 通道拼接(Concatenation)。
- 双卷积层:与编码器相同的双卷积操作,细化特征。
示例参数变化(最低层输入:16×16×1024):
- 上采样至 32×32×512 → 拼接编码器对应层(32×32×512)→ 输出 32×32×1024。
- 双卷积后输出 32×32×512,依此类推恢复至原始分辨率。
4. 跳跃连接(Skip Connections)
- 作用:解决下采样导致的空间信息丢失问题,将编码器的 高分辨率低语义特征 与解码器的 低分辨率高语义特征 融合。
- 实现方式:
- 编码器每层的输出特征图被复制并拼接到解码器对应层的输入。
- 通道维度叠加:例如,编码器层输出 64×64×256,解码器上采样后为 64×64×128 → 拼接后为 64×64×(256+128)=384 通道。
5. 输出层
- 1×1 卷积:将最终特征图的通道数映射到类别数(如二分类输出 1 通道)。
- 激活函数:Sigmoid(二分类)或 Softmax(多分类)。
6.损失函数
- 交叉熵损失(Cross-Entropy Loss):
… - Dice Loss(解决类别不平衡):
… - 联合损失:结合交叉熵与 Dice Loss:
…
这部分准备单独列一个文章,比较多
UNET代码解析
此处代码需要结合上图
1.下采样模块(Encoder部分)
下采样模块的代码实现
class Down(nn.Module):
def __init__(self, in_channels, out_channels):
super(Down, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), # stride默认是 1
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.maxpool_conv(x)
-
功能:先通过最大池化降低分辨率,再进行双卷积提取特征。
-
执行流程:
1.MaxPool2d(2): 将输入特征图尺寸缩小。
2.两次卷积: 扩展通道数(in_channels
→out_channels
),保持分辨率不变。 -
这里的下采样模块为 池化 → 卷积 → 卷积,而不是选用卷积 → 卷积 → 池化,这是因为后面这个模块可能导致浅层特征提取不足,思考一下池化的特性。
-
这里使用了一个批归一层(BatchNorm2d),原论文没有提及,一些对UNET代码的复现在这个地方也有不同。其实不必纠结于这些,因为批归一化的操作是为了加速网络的收敛过程。它通过使每层的输入分布在训练过程中保持相对稳定,从而减少每一层所需的学习率调整。同时也可以提高泛化能力和超参数寻优的速度,在这里加上会有利于网络的性能。
将多个下采样模块组合起来,便构成了左边的Encoder网络(或者叫压缩路径)
左边网络完整代码(contracting path)
class UNetDownsampling(nn.Module):
def __init__(self, n_channels, n_classes=1, bilinear=True):
super(UNetDownsampling, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
return [x1, x2, x3, x4, x5]
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
2.上采样模块(Decoder部分)
上采样模块的代码实现
class Up(nn.Module):
def __init__(self, in_channels, out_channels, bilinear=True):
super(Up, self).__init__()
# 上采样步骤,其实就是扩大分辨率的步骤
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1) # 尺寸拼接
return self.conv(x)
总体流程可以总结为上采样 →卷积 →卷积。
- 输入:
x1
: 来自解码器上一层的低分辨率特征图(需上采样)。
x2
: 来自编码器的对应层的高分辨率特征图(跳跃连接)。 - 这里的上采样方式有两种:
- 双线性插值:非参数化方法,计算速度快但不可学习。
- 转置卷积:可学习的上采样,可能生成更精细的特征。
根据实际情况选用不同的上采样方式,实际上是扩大分辨率和通道数
- 尺寸对齐:通过填充(
F.pad
)解决因奇偶分辨率导致的尺寸不匹配问题(这一步也可以理解为跳跃连接的部分)。
右边网络完整的代码(expansive path)
class UNetUpsampling(nn.Module):
def __init__(self, n_classes=1, bilinear=True):
super(UNetUpsampling, self).__init__()
self.up1 = Up(1024, 512, bilinear)
self.up2 = Up(512, 256, bilinear)
self.up3 = Up(256, 128, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, down_outputs):
x = self.up1(down_outputs[4], down_outputs[3])
x = self.up2(x, down_outputs[2])
x = self.up3(x, down_outputs[1])
x = self.up4(x, down_outputs[0])
logits = self.outc(x)
return logits
# 最后一层的输出
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) # 1x1的卷积,减少通道数,生成预测结果
def forward(self, x):
return self.conv(x)
UNET网络结构代码实现
import torch
import torch.nn as nn
from parts_down import UNetDownsampling
from parts_up import UNetUpsampling
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
# 下采样部分
self.downsampling = UNetDownsampling(n_channels, n_classes, bilinear)
# 上采样部分
self.upsampling = UNetUpsampling(n_classes, bilinear)
def forward(self, x):
# 下采样部分
down_outputs = self.downsampling(x)
# 上采样部分
logits = self.upsampling(down_outputs)
return logits
将上面的下采样模块和上采样模块放到独立的文件,这样可以让代码更清晰(个人习惯)。
用UNET实现简单的医学图像分割
…待定
代码部分参考自Github开源代码Pytorch-UNet