ASPP
Atrous Spatial Pyramid Pooling (ASPP) combines two ideas: SPP and atrous (dilated) convolution.
SPP (Spatial Pyramid Pooling): a neural network module designed to capture multi-scale context by extracting features at several scales. SPP was originally proposed for image classification and was later widely adopted in object detection, semantic segmentation, and other computer vision tasks.
Atrous (dilated) convolution: enlarges a convolution kernel by inserting gaps (zeros) between its elements. This expands the receptive field and captures multi-scale context without adding parameters. It was popularized by the DeepLab series and is widely used in semantic segmentation and object detection. Because semantic segmentation needs high-resolution output, networks often remove the downsampling in the last two stages and use atrous convolution to compensate for the lost receptive field.
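As a quick illustration (a minimal sketch, not from the original post): a 3x3 kernel with dilation d covers an effective area of (2d+1)x(2d+1), and setting padding=dilation keeps the spatial size unchanged:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 64, 64)
for d in (1, 6, 12):
    # effective kernel size: 3 + 2 * (d - 1) = 2d + 1
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=d, dilation=d, bias=False)
    print(d, conv(x).shape)  # (1, 8, 64, 64) for every d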
The full code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Atrous (dilated) convolution branch
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)

# Global pooling -> 1x1 conv -> upsample
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),  # adaptive average pooling to 1x1
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        # upsample back to the input resolution
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)

# The full ASPP module
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        # 1x1 convolution branch
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))
        # multi-scale atrous convolution branches
        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))
        # image-level pooling branch
        modules.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(modules)
        # projection applied after concatenating all branches
        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
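A quick sanity check (my own usage sketch; the rates 6/12/18 are the common DeepLabv3 setting, the input shape is an assumption):

aspp = ASPP(in_channels=2048, atrous_rates=[6, 12, 18])
x = torch.randn(2, 2048, 32, 32)
# five branches of 256 channels each are concatenated and projected back to 256
print(aspp(x).shape)  # torch.Size([2, 256, 32, 32])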
PPM
PPM (Pyramid Pooling Module) aggregates context from different regions to improve the network's ability to capture global information. Concretely: apply pooling at several scales to the original feature map to obtain feature maps of different sizes, then concatenate them (together with the original feature map) along the channel dimension. The result is a composite feature map that blends multiple scales, combining global semantics with local detail.
Input Image
Feature Map: the raw feature map extracted by a CNN
Pyramid Pooling Module: pool the raw feature map at several scales, upsample each pooled map back to the original size (bilinear interpolation), and concatenate everything along the channel dimension to obtain the final composite feature map.
Final Prediction: this representation is fed into a convolution layer to produce the per-pixel final prediction
Example code:
# _*_coding:utf-8_*_
import torch
import torch.nn as nn
import torch.nn.functional as F

class PPM(nn.Module):
    def __init__(self, in_dim, out_dim, bins):
        super(PPM, self).__init__()
        self.features = []
        # one branch per pyramid level: pool to bin x bin, then a 1x1 conv
        for bin in bins:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(bin),
                nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

    def forward(self, x):
        x_size = x.size()
        out = [x]  # keep the original feature map
        for f in self.features:
            temp = f(x)
            # upsample each pooled map back to the input resolution
            temp = F.interpolate(temp, x_size[2:], mode="bilinear", align_corners=True)
            out.append(temp)
        return torch.cat(out, 1)
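Usage sketch (my own example; bins (1, 2, 3, 6) are the PSPNet defaults, the other numbers are assumptions). Note the output has in_dim + len(bins) * out_dim channels, because the original feature map is kept:

ppm = PPM(in_dim=2048, out_dim=512, bins=(1, 2, 3, 6))
x = torch.randn(2, 2048, 60, 60)
print(ppm(x).shape)  # torch.Size([2, 4096, 60, 60]): 2048 + 4 * 512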
DAPPM
DAPPM (Deep Aggregation Pyramid Pooling Module) is a PPM-like multi-scale module attached to the end of the semantic branch to extract further context from the low-resolution feature map; its structure is shown in Figure 5 of the DDRNet paper. Taking a 1/64-resolution feature map as input, it uses large pooling kernels with exponentially increasing strides to generate 1/128-, 1/256-, and 1/512-resolution feature maps, plus the input feature map itself and the image-level information produced by global average pooling (red blocks).
In the paper, the authors argue that fusing all the multi-scale context with a single 3x3 or 1x1 convolution is not enough. They therefore adopt a Res2Net-like structure: first upsample the feature maps (gray blocks), then fuse context from different scales in a hierarchical-residual fashion with additional 3x3 convolutions (purple blocks). Finally, all resulting features are concatenated and compressed with a 1x1 convolution. A 1x1 convolution shortcut is also added to ease optimization.
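The hierarchical-residual fusion can be summarized by the following recurrence (my compact paraphrase of the implementation below, not notation from the paper), where scale_i is the i-th pooling + 1x1-conv branch and Up is bilinear upsampling:

    y_0 = scale_0(x)
    y_i = Conv3x3( Up(scale_i(x)) + y_{i-1} ),  i = 1, ..., n-1
    out = Conv1x1([y_0; y_1; ...; y_{n-1}]) + shortcut(x)

The implementation below appears to come from MMSegmentation; it relies on ConvModule from mmcv and BaseModule from mmengine.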
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor


class DAPPM(BaseModule):
    """DAPPM module in `DDRNet <https://arxiv.org/abs/2101.06085>`_.

    Args:
        in_channels (int): Input channels.
        branch_channels (int): Branch channels.
        out_channels (int): Output channels.
        num_scales (int): Number of scales.
        kernel_sizes (list[int]): Kernel sizes of each scale.
        strides (list[int]): Strides of each scale.
        paddings (list[int]): Paddings of each scale.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer in ConvModule.
            Default: dict(order=('norm', 'act', 'conv'), bias=False).
        upsample_mode (str): Upsample mode. Default: 'bilinear'.
    """

    def __init__(self,
                 in_channels: int,
                 branch_channels: int,
                 out_channels: int,
                 num_scales: int,
                 kernel_sizes: List[int] = [5, 9, 17],
                 strides: List[int] = [2, 4, 8],
                 paddings: List[int] = [2, 4, 8],
                 norm_cfg: Dict = dict(type='BN', momentum=0.1),
                 act_cfg: Dict = dict(type='ReLU', inplace=True),
                 conv_cfg: Dict = dict(
                     order=('norm', 'act', 'conv'), bias=False),
                 upsample_mode: str = 'bilinear'):
        super().__init__()
        self.num_scales = num_scales
        self.upsample_mode = upsample_mode
        self.in_channels = in_channels
        self.branch_channels = branch_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.conv_cfg = conv_cfg
        # scale 0: plain 1x1 conv at the input resolution
        self.scales = ModuleList([
            ConvModule(
                in_channels,
                branch_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **conv_cfg)
        ])
        # intermediate scales: large-kernel average pooling with exponential strides
        for i in range(1, num_scales - 1):
            self.scales.append(
                Sequential(*[
                    nn.AvgPool2d(
                        kernel_size=kernel_sizes[i - 1],
                        stride=strides[i - 1],
                        padding=paddings[i - 1]),
                    ConvModule(
                        in_channels,
                        branch_channels,
                        kernel_size=1,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        **conv_cfg)
                ]))
        # last scale: global average pooling for image-level context
        self.scales.append(
            Sequential(*[
                nn.AdaptiveAvgPool2d((1, 1)),
                ConvModule(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **conv_cfg)
            ]))
        # 3x3 convs for the hierarchical-residual fusion
        self.processes = ModuleList()
        for i in range(num_scales - 1):
            self.processes.append(
                ConvModule(
                    branch_channels,
                    branch_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **conv_cfg))
        # 1x1 conv that compresses the concatenated scales
        self.compression = ConvModule(
            branch_channels * num_scales,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)
        # 1x1 shortcut added for easier optimization
        self.shortcut = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)

    def forward(self, inputs: Tensor):  # e.g. (16, 1024, 8, 8)
        feats = []
        feats.append(self.scales[0](inputs))
        for i in range(1, self.num_scales):
            feat_up = F.interpolate(
                self.scales[i](inputs),
                size=inputs.shape[2:],
                mode=self.upsample_mode)
            # hierarchical residual: fuse each upsampled scale with the previous output
            feats.append(self.processes[i - 1](feat_up + feats[i - 1]))
        # feats: num_scales maps, each e.g. (16, 128, 8, 8)
        return self.compression(torch.cat(feats,
                                          dim=1)) + self.shortcut(inputs)  # (16, 256, 8, 8)
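Usage sketch (my own example, matching the shape comments above; requires mmcv and mmengine to be installed):

dappm = DAPPM(in_channels=1024, branch_channels=128, out_channels=256, num_scales=5)
x = torch.randn(16, 1024, 8, 8)
# num_scales=5 means: 1x1 branch + three pooled branches + global branch
print(dappm(x).shape)  # torch.Size([16, 256, 8, 8])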
PAPPM
To build a better global scene prior, PSPNet introduced the Pyramid Pooling Module (PPM), which concatenates multi-scale pooled maps before a convolution layer to form local and global context representations. Deep Aggregation PPM (DAPPM), proposed in [20], further improves the context-embedding ability of PPM and shows strong performance. However, because of its depth, DAPPM's computation cannot be parallelized, which makes it slow. Moreover, DAPPM uses many channels per scale, which may exceed the representation capacity of lightweight models.
The PIDNet authors therefore modified DAPPM's connectivity so it can be computed in parallel, and reduced the per-scale channel count from 128 to 96. The resulting context-extraction module, Parallel Aggregation PPM (PAPPM), is used in PIDNet-M and PIDNet-S to guarantee their speed. For the deeper model PIDNet-L, they keep the DAPPM structure to exploit its depth, but reduce its channel count to cut computation and improve speed.
The code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

# align_corners flag; set to False in the PIDNet reference implementation
algc = False


class PAPPM(nn.Module):
    def __init__(self, inplanes, branch_planes, outplanes, BatchNorm=nn.BatchNorm2d):
        super(PAPPM, self).__init__()
        bn_mom = 0.1
        # pooled branches at increasingly coarse resolutions
        self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        # image-level branch (global average pooling)
        self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                    BatchNorm(inplanes, momentum=bn_mom),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
                                    )
        # branch at the input resolution
        self.scale0 = nn.Sequential(
            BatchNorm(inplanes, momentum=bn_mom),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
        )
        # one grouped 3x3 conv fuses all four pooled scales at once
        self.scale_process = nn.Sequential(
            BatchNorm(branch_planes * 4, momentum=bn_mom),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_planes * 4, branch_planes * 4, kernel_size=3, padding=1, groups=4, bias=False),
        )
        self.compression = nn.Sequential(
            BatchNorm(branch_planes * 5, momentum=bn_mom),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
        )
        self.shortcut = nn.Sequential(
            BatchNorm(inplanes, momentum=bn_mom),
            nn.ReLU(inplace=True),
            nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
        )

    def forward(self, x):
        width = x.shape[-1]
        height = x.shape[-2]
        scale_list = []
        x_ = self.scale0(x)
        # each pooled branch is upsampled and added to x_ independently,
        # so the four branches can run in parallel
        scale_list.append(F.interpolate(self.scale1(x), size=[height, width],
                                        mode='bilinear', align_corners=algc) + x_)
        scale_list.append(F.interpolate(self.scale2(x), size=[height, width],
                                        mode='bilinear', align_corners=algc) + x_)
        scale_list.append(F.interpolate(self.scale3(x), size=[height, width],
                                        mode='bilinear', align_corners=algc) + x_)
        scale_list.append(F.interpolate(self.scale4(x), size=[height, width],
                                        mode='bilinear', align_corners=algc) + x_)
        scale_out = self.scale_process(torch.cat(scale_list, 1))
        out = self.compression(torch.cat([x_, scale_out], 1)) + self.shortcut(x)
        return out
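Usage sketch (my own example; the shapes mirror the DAPPM example above, and branch_planes=96 is the reduced channel count from the paper). The key design difference from DAPPM is visible in scale_process: instead of fusing the scales one after another, the four pooled scales are concatenated and fused by a single groups=4 convolution, so they are processed in parallel:

pappm = PAPPM(inplanes=1024, branch_planes=96, outplanes=256)
x = torch.randn(16, 1024, 8, 8)
print(pappm(x).shape)  # torch.Size([16, 256, 8, 8])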