代码来源
https://github.com/ZhengPeng7/BiRefNet
模块作用
DIS 是一种旨在对高分辨率图像中的目标物体进行精确分割的技术,尤其适用于具有复杂细微结构的物体,例如细长的边缘或微小细节。传统方法在处理这类任务时往往难以捕捉细微特征或恢复高分辨率细节,因此论文提出了一种新颖的网络架构BiRefNet以解决这些挑战。
模块结构
定位模块(LM)
- 输入高分辨率图像至视觉变换器骨干网络。
- 提取多尺度的层次特征,捕捉全局语义信息。
- 通过特征融合和压缩,生成低分辨率的粗略预测图。
- 原理:利用变换器的全局建模能力,在低分辨率下快速定位目标物体,避免直接处理高分辨率带来的计算负担。
重建模块(RM)
- 接收定位模块输出的低分辨率粗略预测图。
- 在解码器的多个阶段,逐步上采样并结合双边参考信息。
- 输出高分辨率的精细分割图。
- 原理:通过将原始图像的分块输入解码器,提供高分辨率的细节参考,确保重建过程中细节不丢失。通过梯度图的监督,引导模型聚焦于边缘和细微结构,避免模糊或遗漏关键区域。从低分辨率到高分辨率的分阶段上采样,确保全局一致性和局部精确性的平衡。
代码
class BiRefNet(
nn.Module,
PyTorchModelHubMixin,
library_name="birefnet",
repo_url="https://github.com/ZhengPeng7/BiRefNet",
tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection']
):
def __init__(self, bb_pretrained=True):
super(BiRefNet, self).__init__()
self.config = Config()
self.epoch = 1
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
channels = self.config.lateral_channels_in_collection
if self.config.auxiliary_classification:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.cls_head = nn.Sequential(
nn.Linear(channels[0], len(class_labels_TR_sorted))
)
if self.config.squeeze_block:
self.squeeze_module = nn.Sequential(*[
eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
])
self.decoder = Decoder(channels)
if self.config.ender:
self.dec_end = nn.Sequential(
nn.Conv2d(1, 16, 3, 1, 1),
nn.Conv2d(16, 1, 3, 1, 1),
nn.ReLU(inplace=True),
)
# refine patch-level segmentation
if self.config.refine:
if self.config.refine == 'itself':
self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
else:
self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))
if self.config.freeze_bb:
# Freeze the backbone...
print(self.named_parameters())
for key, value in self.named_parameters():
if 'bb.' in key and 'refiner.' not in key:
value.requires_grad = False
def forward_enc(self, x):
if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
else:
x1, x2, x3, x4 = self.bb(x)
if self.config.mul_scl_ipt:
B, C, H, W = x.shape
x_pyramid = F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)
if self.config.mul_scl_ipt == 'cat':
if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
x1_ = self.bb.conv1(x_pyramid); x2_ = self.bb.conv2(x1_); x3_ = self.bb.conv3(x2_); x4_ = self.bb.conv4(x3_)
else:
x1_, x2_, x3_, x4_ = self.bb(x_pyramid)
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
elif self.config.mul_scl_ipt == 'add':
x1_, x2_, x3_, x4_ = self.bb(x_pyramid)
x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
if self.config.cxt:
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(self.config.cxt):],
x4
),
dim=1
)
return (x1, x2, x3, x4), class_preds
def forward_ori(self, x):
########## Encoder ##########
(x1, x2, x3, x4), class_preds = self.forward_enc(x)
if self.config.squeeze_block:
x4 = self.squeeze_module(x4)
########## Decoder ##########
features = [x, x1, x2, x3, x4]
if self.training and self.config.out_ref:
features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
scaled_preds = self.decoder(features)
return scaled_preds, class_preds
def forward(self, x):
scaled_preds, class_preds = self.forward_ori(x)
class_preds_lst = [class_preds]
return [scaled_preds, class_preds_lst] if self.training else scaled_preds
总结
本文提出了一个配备双边参考的 BiRefNet 框架,该框架可在同一框架内执行二分图像分割、高分辨率显著目标检测和隐藏目标检测。通过全面的实验,研究者发现未缩放的源图像和对信息丰富区域的关注对于生成 HR 图像中精细且细节丰富的区域至关重要。为此,研究者提出了双边参考来填充精细部分中缺失的信息(内向参考),并引导模型更加关注细节更丰富的区域(外向参考)。这显著提升了模型捕捉微小像素特征的能力。为了降低 HR 数据训练的高昂训练成本,本文还提供了各种实用技巧,以实现更高质量的预测和更快的收敛速度。在 13 个基准测试中取得的优异结果证明了BiRefNet 的卓越性能和强大的泛化能力。