文章目录
1 概述
1.1 要点
题目:蒙版硬实例挖掘多示例学习 (Multiple instance learning framework with masked hard instance mining)
背景:全幻灯片图像 (WSI) 分类常被看作是一个典型的MIL问题,其已有的方法大都关注于使用注意力机制来鉴别显著实例 (易于分类的实例,反之是硬实例)。
问题:已有方法忽略了难以分类的实例,其对建立更准确模型以及辨别性决策边界是有益的。
方法:提出了一个新颖算法MHIM-MIL,其使用蒙版硬实例挖掘下的Siamese结构来探索与训练硬实例。
1.2 代码
https://github.com/DearCaat/MHIM-MIL/
1.3 引用
@InProceedings{Tang:2023:40784087,
author = {Tang, Wen Hao and Huang, Sheng and Zhang, Xiao Xian and Zhou, Feng Tao and Zhang, Yi and Liu, Bo},
title = {Multiple instance learning framework with masked hard instance mining for whole slide image classification},
booktitle = {{ICCV}},
year = {2023},
pages = {4078-4087}
}
2 方法
2.1 背景:MIL定义
在MIL中,任意WSI可以表示为 X = { x i } i = 1 N X=\{x_i\}_{i=1}^N X={xi}i=1N,其中 x i x_i xi是第 i i i个区块,亦即实例。对于分类任务,包对应一个已知的标签 Y ∈ C Y\in C Y∈C,实例对应未知标签 y n ∈ C y_n\in C yn∈C,其中 C C C是类别标签的集合。MIL模型 M ( ⋅ ) \mathcal{M}(\cdot) M(⋅)的目的是预测包标签 Y ^ ← M ( X ) \hat{Y}\leftarrow\mathcal{M}(X) Y^←M(X)。目前常用的方法是从实例的提取特征 Z = { z i } i = 1 N Z=\{z_i\}_{i=1}^N Z={zi}i=1N中习得包表示 F F F,并在此基础上训练一个分类器 C ( ⋅ ) \mathcal{C}(\cdot) C(⋅),即 Y ^ ← C ( F ) \hat{Y}\leftarrow\mathcal{C}(F) Y^←C(F)。这里有两种讲实例特征汇聚为包嵌入的方法:
- 注意力聚合:
F = ∑ i = 1 N a i z i ∈ R D , (1) \tag{1} F=\sum_{i=1}^Na_iz_i\in\mathbb{R}^D, F=i=1∑Naizi∈RD,(1)其中 a i a_i ai是注意力稀疏; - 多头自注意力聚合 (MSA):类别token
z
0
z_0
z0通过实例特征嵌入,以获取初始的输入序列
Z
0
=
[
z
0
,
z
1
,
…
,
z
N
]
i
n
R
(
N
+
1
)
×
D
Z^0=[z_0,z_1,\dots,z_N]in\mathbb{R}^{(N+1)\times D}
Z0=[z0,z1,…,zN]inR(N+1)×D:
head = A ℓ ( Z ℓ − 1 W V ) ∈ R ( N + 1 ) × D H , ℓ = 1 … L Z ℓ = Concat ( h e a d 1 , … , h e a d H ) W O , ℓ = 1 … L (2) \tag{2} \begin{array}{ll} \text{head}=A^\ell(Z^{\ell-1}W^V)\in\mathbb{R}^{(N+1)\times\frac{D}{H}},&\ell=1\dots L\\ Z^\ell=\text{Concat}(head_1,\dots,head_H)W^O,&\ell=1\dots L\\ \end{array} head=Aℓ(Zℓ−1WV)∈R(N+1)×HD,Zℓ=Concat(head1,…,headH)WO,ℓ=1…Lℓ=1…L(2)其中 W V ∈ R D × D H W^V\in\mathbb{R}^{D\times \frac{D}{H}} WV∈RD×HD、 W O ∈ R D × D W^O\in\mathbb{R}^{D\times D} WO∈RD×D、 A ℓ ∈ R ( N + 1 ) × ( N + 1 ) A^\ell\in\mathbb{R}^{(N+1)\times (N+1)} Aℓ∈R(N+1)×(N+1)是注意力矩阵、 L L L是MSA块的数量,以及 H H H是每个MSA的头数。包嵌入是最后一层的类别token:
F = Z 0 L . (3) \tag{3} F=Z^L_0. F=Z0L.(3)自注意力嵌入本质上是一类特殊的MIL注意力方法,本文将这些方法均归类为注意力方法。
2.2 MHIM-MIL
在基于注意力的MIL方法中,实例的注意力得分用于指示实例对于包分类的贡献。这些具有高得分的显著实例是益于分类的,但却不易于训练一个泛化性强的模型。尽管硬样本再很多计算机视觉领域被证明可以提升模型的泛化性,但已有的MIL工作却常常将这些忽略。
本文提出的MHIM-MIL来处理这个问题,其框架如图2所示:
- 在训练阶段,MHIM-MIL使用孪生结构,其主体部分是一个注意力MIL网络 (Student),表示为 S ( ⋅ ) \mathcal{S}(\cdot) S(⋅),用于汇聚实例特征;
- 为了增加student模型的辨别难度,且更多地关注硬实例,引入动量teacher,表示为 T ( ⋅ ) \mathcal{T}(\cdot) T(⋅),其通过注意力权重来给实例打分,然后使用蒙版硬实例挖掘策略来遮蔽显著实例,以此保留硬实例;
- 在推理阶段,所有挖掘的特征均传递给student以获取包标签;
- Teacher与student共享同一结构;
所提出方法的被定义为:
Y ^ = S ( Z ^ ) = S ( M T ( Z ) ) , (4) \tag{4} \hat{Y}=\mathcal{S}(\hat{Z})=\mathcal{S}(M_\mathcal{T}(Z)), Y^=S(Z^)=S(MT(Z)),(4)其中 M T ( ⋅ ) M_\mathcal{T}(\cdot) MT(⋅)表示通过teacher模型设计蒙版硬实例挖掘策略,以及 Z ^ \hat{Z} Z^是挖掘到的实例。
图2:MHIM-MIL总体框架。动量teacher用于计算所有实例的分数。然后基于硬实例挖掘注意力来为实例添加蒙版,并将余下的实例传递给student模型。student模型通过一致性损失 L c o n \mathcal{L}_{con} Lcon以及标签损失 L c l s \mathcal{L}_{cls} Lcls。teacher的损失通过student参数的指数移动平均 (EMA) 来更新,而非梯度。在推理阶段,将使用完整的输入实例,且只使用student模型
2.3 蒙版硬实例挖掘策略
在没用实例级别监督时,传统的硬实例发掘策略将难以应用。对此,通过遮蔽高注意力得分的实例来明确硬实例。具体地,给定一个实例特征的序列
Z
=
{
z
i
}
i
=
1
N
Z=\{z_i\}_{i=1}^N
Z={zi}i=1N作为teacher模型
T
(
⋅
)
\mathcal{T}(\cdot)
T(⋅)是输入,其输出每个实例的注意力权重
a
i
a_i
ai:
A
=
[
a
1
,
…
,
a
i
,
…
,
a
N
]
=
T
(
Z
)
.
(5)
\tag{5} A=[a_1,\dots,a_i,\dots,a_N]=\mathcal{T}(Z).
A=[a1,…,ai,…,aN]=T(Z).(5)然后获取注意力序列的降序索引:
I
=
[
i
1
,
i
2
,
…
,
i
N
]
=
Sort(A)
,
I=[i_1,i_2,\dots,i_N]=\text{Sort(A)},
I=[i1,i2,…,iN]=Sort(A),其中
i
1
i_1
i1和
i
N
i_N
iN分别是具有最高注意力和最低注意力分数的实例的索引。基于
I
I
I,我们提出了几个蒙版硬实例挖掘策略来选择硬实例。令
M
=
[
m
1
,
…
,
m
i
,
…
,
m
N
]
M=[m_1,\dots,m_i,\dots,m_N]
M=[m1,…,mi,…,mN]表示用于编码蒙版标志的
N
N
N维二元向量,其中
m
i
∈
{
0
,
1
}
m_i\in\{0,1\}
mi∈{0,1}。如果
m
i
=
1
m_i=1
mi=1,则表示蒙蔽第
i
i
i个实例。
2.3.1 高注意力蒙版
最简单的蒙版硬实例挖掘策略是高注意力蒙版 (HAM),其简单的蒙蔽注意力最高的top β h % \beta_h\% βh%的实例。
2.3.2 混合蒙版
通过结合其它实例蒙蔽技术,提出三种混合蒙版,如图3:
- L-HAM:额外蒙蔽 β l % \beta_l\% βl%注意力最低的实例 M l M_l Ml,即共蒙蔽 M ^ = M h ∪ M l \hat{M}=M_h\cup M_l M^=Mh∪Ml;
- R-HAM:随机蒙蔽 β r % \beta_r\% βr%的实例,防止过拟合,即 M ^ = M h ∪ M r \hat{M}=M_h\cup M_r M^=Mh∪Mr;
- LR-HAM:
M
^
=
M
h
∪
M
r
∪
M
l
\hat{M}=M_h\cup M_r\cup M_l
M^=Mh∪Mr∪Ml;
基于
M
^
\hat{M}
M^,蒙蔽实例序列计算为:
Z
^
=
M
T
(
Z
)
=
Mask
(
Z
,
M
^
)
∈
R
N
^
×
D
,
(7)
\tag{7} \hat{Z}=M_\mathcal{T}(Z)=\text{Mask}(Z,\hat{M})\in\mathbf{R}^{\hat{N}\times D},
Z^=MT(Z)=Mask(Z,M^)∈RN^×D,(7)其中
N
^
\hat{N}
N^是未蒙蔽的实例。
2.4 一致性优化
在孪生网络结构下,teacher指导student训练的同时,student学习的知识也会用于更新teacher。这样迭代的优化过程将渐进式地提升teacher的挖掘能力以及student的辨别能力。为了进一步促进这个优化过程,并使用动量teacher提供的额外监督信息,提出了一个一致性损失来约束两个模型的分类结果。
2.4.1 Student优化
Student包含两个损失:
- 度量包标签的交叉熵:
L c l s = Y log Y ^ + ( 1 − Y ) log ( 1 − Y ^ ) . (8) \tag{8} \mathcal{L}_{cls}=Y\log\hat{Y}+(1-Y)\log(1-\hat{Y}). Lcls=YlogY^+(1−Y)log(1−Y^).(8) - Student
F
s
F_s
Fs和动量teacher
F
t
F_t
Ft包标签之间的一致性损失:
L c o n = − softmax ( F t / τ ) log F s , (9) \tag{9} \mathcal{L}_{con}=-\text{softmax}(F_t/\tau)\log F_s, Lcon=−softmax(Ft/τ)logFs,(9)其中 τ > 0 \tau>0 τ>0是温度参数;
综上,最终的优化损失定义未:
{
θ
^
s
}
←
arg min
θ
s
L
=
L
c
l
s
+
α
L
c
o
n
,
(10)
\tag{10} \{\hat{\theta}_s\}\leftarrow\argmin_{\theta_s}\mathcal{L}=\mathcal{L}_{cls}+\alpha\mathcal{L}_{con},
{θ^s}←θsargminL=Lcls+αLcon,(10)其中
θ
s
\theta_s
θs是
S
(
⋅
)
\mathcal{S}(\cdot)
S(⋅)的参数,以及
α
\alpha
α是缩放因子。
2.4.2 Teacher优化
Teacher的参数
θ
t
\theta_t
θt通过指数移动平均 (EMA) 获得:
θ
t
←
λ
θ
t
+
(
1
−
λ
)
θ
s
,
\theta_t\leftarrow\lambda\theta_t+(1-\lambda)\theta_s,
θt←λθt+(1−λ)θs,其中
λ
\lambda
λ是参参数。