论文学习18:Bilateral Reference for High-Resolution Dichotomous Image Segmentation

代码来源

https://github.com/ZhengPeng7/BiRefNet

模块作用

DIS 是一种旨在对高分辨率图像中的目标物体进行精确分割的技术,尤其适用于具有复杂细微结构的物体,例如细长的边缘或微小细节。传统方法在处理这类任务时往往难以捕捉细微特征或恢复高分辨率细节,因此论文提出了一种新颖的网络架构BiRefNet以解决这些挑战。

模块结构

定位模块(LM)
  • 输入高分辨率图像至视觉变换器骨干网络。
  • 提取多尺度的层次特征,捕捉全局语义信息。
  • 通过特征融合和压缩,生成低分辨率的粗略预测图。
  • 原理:利用变换器的全局建模能力,在低分辨率下快速定位目标物体,避免直接处理高分辨率带来的计算负担。
重建模块(RM)
  • 接收定位模块输出的低分辨率粗略预测图。
  • 在解码器的多个阶段,逐步上采样并结合双边参考信息。
  • 输出高分辨率的精细分割图。
  • 原理:通过将原始图像的分块输入解码器,提供高分辨率的细节参考,确保重建过程中细节不丢失。通过梯度图的监督,引导模型聚焦于边缘和细微结构,避免模糊或遗漏关键区域。从低分辨率到高分辨率的分阶段上采样,确保全局一致性和局部精确性的平衡。

代码

class BiRefNet(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="birefnet",
    repo_url="https://github.com/ZhengPeng7/BiRefNet",
    tags=['Image Segmentation', 'Background Removal', 'Mask Generation', 'Dichotomous Image Segmentation', 'Camouflaged Object Detection', 'Salient Object Detection']
):
    def __init__(self, bb_pretrained=True):
        super(BiRefNet, self).__init__()
        self.config = Config()
        self.epoch = 1
        self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)

        channels = self.config.lateral_channels_in_collection

        if self.config.auxiliary_classification:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.cls_head = nn.Sequential(
                nn.Linear(channels[0], len(class_labels_TR_sorted))
            )

        if self.config.squeeze_block:
            self.squeeze_module = nn.Sequential(*[
                eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
                for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
            ])

        self.decoder = Decoder(channels)

        if self.config.ender:
            self.dec_end = nn.Sequential(
                nn.Conv2d(1, 16, 3, 1, 1),
                nn.Conv2d(16, 1, 3, 1, 1),
                nn.ReLU(inplace=True),
            )

        # refine patch-level segmentation
        if self.config.refine:
            if self.config.refine == 'itself':
                self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
            else:
                self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))

        if self.config.freeze_bb:
            # Freeze the backbone...
            print(self.named_parameters())
            for key, value in self.named_parameters():
                if 'bb.' in key and 'refiner.' not in key:
                    value.requires_grad = False

    def forward_enc(self, x):
        if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
            x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
        else:
            x1, x2, x3, x4 = self.bb(x)
        if self.config.mul_scl_ipt:
            B, C, H, W = x.shape
            x_pyramid = F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)
            if self.config.mul_scl_ipt == 'cat':
                if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
                    x1_ = self.bb.conv1(x_pyramid); x2_ = self.bb.conv2(x1_); x3_ = self.bb.conv3(x2_); x4_ = self.bb.conv4(x3_)
                else:
                    x1_, x2_, x3_, x4_ = self.bb(x_pyramid)
                x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
                x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
            elif self.config.mul_scl_ipt == 'add':
                x1_, x2_, x3_, x4_ = self.bb(x_pyramid)
                x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
                x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
                x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
                x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
        class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
        if self.config.cxt:
            x4 = torch.cat(
                (
                    *[
                        F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
                        F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
                        F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
                    ][-len(self.config.cxt):],
                    x4
                ),
                dim=1
            )
        return (x1, x2, x3, x4), class_preds

    def forward_ori(self, x):
        ########## Encoder ##########
        (x1, x2, x3, x4), class_preds = self.forward_enc(x)
        if self.config.squeeze_block:
            x4 = self.squeeze_module(x4)
        ########## Decoder ##########
        features = [x, x1, x2, x3, x4]
        if self.training and self.config.out_ref:
            features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
        scaled_preds = self.decoder(features)
        return scaled_preds, class_preds

    def forward(self, x):
        scaled_preds, class_preds = self.forward_ori(x)
        class_preds_lst = [class_preds]
        return [scaled_preds, class_preds_lst] if self.training else scaled_preds

总结

本文提出了一个配备双边参考的 BiRefNet 框架,该框架可在同一框架内执行二分图像分割、高分辨率显著目标检测和隐藏目标检测。通过全面的实验,研究者发现未缩放的源图像和对信息丰富区域的关注对于生成 HR 图像中精细且细节丰富的区域至关重要。为此,研究者提出了双边参考来填充精细部分中缺失的信息(内向参考),并引导模型更加关注细节更丰富的区域(外向参考)。这显著提升了模型捕捉微小像素特征的能力。为了降低 HR 数据训练的高昂训练成本,本文还提供了各种实用技巧,以实现更高质量的预测和更快的收敛速度。在 13 个基准测试中取得的优异结果证明了BiRefNet 的卓越性能和强大的泛化能力。

内容概要:本文介绍了基于SMA-BP黏菌优化算法优化反向传播神经网络(BP)进行多变量回归预测的项目实例。项目旨在通过SMA优化BP神经网络的权重和阈值,解决BP神经网络易陷入局部最优、收敛速度慢及参数调优困难等问题。SMA算法模拟黏菌寻找食物的行为,具备优秀的全局搜索能力,能有效提高模型的预测准确性和训练效率。项目涵盖了数据预处理、模型设计、算法实现、性能验证等环节,适用于多变量非线性数据的建模和预测。; 适合人群:具备一定机器学习基础,特别是对神经网络和优化算法有一定了解的研发人员、数据科学家和研究人员。; 使用场景及目标:① 提升多变量回归模型的预测准确性,特别是在工业过程控制、金融风险管理等领域;② 加速神经网络训练过程,减少迭代次数和训练时间;③ 提高模型的稳定性和泛化能力,确保模型在不同数据集上均能保持良好表现;④ 推动智能优化算法与深度学习的融合创新,促进多领域复杂数据分析能力的提升。; 其他说明:项目采用Python实现,包含详细的代码示例和注释,便于理解和二次开发。模型架构由数据预处理模块、基于SMA优化的BP神经网络训练模块以及模型预测与评估模块组成,各模块接口清晰,便于扩展和维护。此外,项目还提供了多种评价指标和可视化分析方法,确保实验结果科学可信。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值