Review of “[2004.09666] Data Efficient and Weakly Supervised Computational Pathology on Whole Slide Images”

来源

这篇论文来自 2020 年 4 月 arxiv 预印的一篇 Lu et al. 所著关于若监督病理 WSI 图像处理的论文 Data Efficient and Weakly Supervised Computational Pathology on Whole Slide Images

*[Lu et al.]: M. Y. Lu, D. F. K. Williamson, T. Y. Chen, R. J. Chen, M. Barbieri, and F. Mahmood, “Data Efficient and Weakly Supervised Computational Pathology on Whole Slide Images,” Apr. 2020.

论文中提到了其实现代码的 GitHub 仓库 和一个交互式的 Demo

来龙去脉

WSI: Whole Slide Image, 全视野数字切片

Lu et al. 出发点就是当尝试对一个 WSI 进行分析的时候,往往分析的结果需要是一个宏观任务的结果,比如是不是乳头状癌、或者某种评分是多少。但是由于 WSI 图像本身的特点,无法直接通过端到端的网络对 WSI 图像进行监督的学习,就像 MNIST 任务那样。因为 WSI 图像的像素个数已经达到千万甚至亿的级别。这类任务通常都需要通过中观、微观的任务进行特征的“提取”,然后进行分类或者其他的判断。这就引发了一个问题:

中观、微观的标注量太大了

所以针对这样的一个问题,Lu et al. 提出了 “CLAM” 这样的一个算法框架。这个算法框架将对中观 “patch” 的聚类作为弱监督的一种约束方式。标注只需要给出 WSI 的分类,然后通过弱监督的方式一方面对 patch 通过聚类对进行有限度的“无监督分类”,另一方面对学习到的特征来进一步提取特征,然后达到对 WSI 分类的效果。

CLAM A-Z

这一节将通过 A-Z 的方式,梳理 Lu et al. 提出的 CLAM 的是怎样“运作”的。

BackBone

讲 Backblone 之前,先贴一个原论文中的图形摘要。这个图像把论文中提出的 CLAM 的流程大致的展示了出来。

Lu et al. 图形摘要

其实 CLAM 比较好理解,很简单的来说就是五个步骤

  1. 首先是提取的 patch, 这一步就是把非空的区域的 patch 提取出来。
  2. 其次是通过一个预训练好的 CNN 模型,对每一个 patch 进行特征提取。
  3. 然后是通过一个“注意力”的网络来对提取的特征计算一个“注意力评分”。(每个类别都会对应一个这样的“评分”)
  4. 接着是对这些评分进行处理,首先是对这些评分进行聚类,并得到一个聚类的 loss。同时将这些“注意力”的评分用于 attention pooling。
  5. 最后得到的特征,被用来 slide 层面上进行有监督的分类。

同时通过聚类于注意力机制,每一个 patch 可以获得一个评分,用来衡量对判断的支撑度。而这个支撑度反应回图像上又能对网络进行解释。

分割 与 Patch 提取

第一个步骤是 分割 与 Patch 提取,这部分代码可以在 create_patches.py 找到对应内容。而 GitHub 中对提取 Patch 的说明可以在 这里 找到。

分割

首先是分割,这部分代码对应 create_patches.py segment() 与其调用的 wsi_core/WholeSlideImage.py WholeSlideImage.segmentTissue() 中找到代码对应。

WSI 首先是下采样,并被装在到内存中,同时将 RGB 色系转换成 HSV 色系。然后通过对饱和通道(也就是 HSV 中的 S)进行阈值二值化进行分割,并找出前景。分割的阈值为图像的中值。 然后通过一些形态学的方法,对小孔、间隙进行封闭。最后将分割变为轮廓并储存起来。

Patch 提取

Patch 提取是将上一步分割出来的前景,提取成 256×256256 \times 256 的小块,并使用 HDF5 进行存储。

*[HDF5]: Hierarchical Data Format 5

提取的 Patch 通常是 20×20\times40×40\times 的中微观层面的内容。

基于切片层面的标注的弱监督分类学习

基于若监督分类的内容在 CLAM 的 文档 中有介绍。若监督分类主要是两个步骤,一个是基于特征提取,另一个是基于多 branch 和 注意力的弱监督网络上面。

特征提取

CLAM repo 中对于特征提取的说明可以在 这里 找到。而使用的代码是 extract_features.py

在 ImageNet 上预训练的 ResNet50 对提取的 Patch 进行特征提取(代码详见 extract_features.py __main__,将 256×256256 \times 256 的图像提取为 10241024 维度的特征。 提取的特征被用于进一步的基于注意力的弱监督的分类任务上。再将 256×256256 \times 256 的图像提取为 10241024 维度的特征的过程中,使用 adaptive mean-spatial pooling(Liu et al.) 进行池化操作,并得到 10241024 维度的特征。

训练

三分类示例二分类示例 训练的说明在 repo 中有。训练使用的 main.py 中的代码进行。

训练的代码实际的dingyizai utils/core_utils.py train() 中。 而 models/model_clam.py 中定义了 CLAM 使用的相关的神经网络模型。

models/model_clam.py class Attn_Netmodels/model_clam.py class Attn_Net_Gated 定义了注意力网络相关的内容。

models/model_clam.py class CLAM 这个是 CLAM 进行弱监督的网络模型。 而 models/model_clam.py CLAM.forward() 是模型执行的过程。

Attn_Net 由一个全连接层、一个 tanh()tanh(\cdot) 激活函数、可选的 Dropout\text{Dropout} 层和一个全连接层 组成的。

Attn_Net_Gated 则是由三个序列组成。第一个序列是全连接层、tanh()tanh(\cdot) 激活函数、可选的 Dropout\text{Dropout}; 第二条是全连接层、sigmoid()sigmoid(\cdot) 激活函数、可选的 Dropout\text{Dropout};第三个是单独的一个全连接层。若将三个序列表示的网络分别记为 A()A(\cdot)B()B(\cdot)C()C(\cdot),则整个 Attn_Net_Gated 网络可以表示为

y=f(x)=C(A(x)×B(x)) y = f(x) = C\left( A\left(x\right) \times B\left(x\right) \right)

Attn_Net 与 Attn_Net_Gated 两个网络实际实现中,forward(·) 中,返回的是 f(x)f(x)xx 所组成的字典。

CLAM 网络的组成相对复杂。第一步是通过注意力网络提取注意力评分。第二步是对提取的注意力评分进行聚类并生成 instance_level 的 loss。第三步 应用注意力评分。最后一步 则是将“应用注意力评分”的结果转换为最后的 slide-level 分类(使用 softmax)与训练。

网络最后的 loss 是由 CLAM 的 slide-level 分类的 loss 与 instance-level 聚类的 loss 相加获得的。

total=c1slide+c2patch \ell_{total} = c_1\ell_{slide} + c_2\ell_{patch}

SVM loss 与 smooth SVM loss

models/model_clam.py L166-L188 对 instance level 的数据进行了聚类,并且对聚类的内容进行了“评价”以获得一个损失函数,用于网络学习。这个聚类针对每一个分类分别进行。

“聚类”的过程,本质上来说,还是一种评判 —— 通过一定方式计算损失。当分类器与patch类别一直的时候,通过 models/model_clam.py CLAM.inst_eval() 来计算 loss,反之通过 models/model_clam.py CLAM.inst_eval_out 计算 loss。 Instance-level 聚类的算法在其论文也有些,但是这里先看如何获得 loss。

Instance-level 所用的 loss 是 smooth SVM loss , 也就是Berrade et al. 提出的 smooth top1 SVM loss

常规 SVM loss 如下公式所示:

L(s,y)=max{maxjY/{y}{sj+α}sy,0} \mathcal{L}(s,y) = \max \left\{ \max\limits_{j \in \mathcal{Y}/\{y\}} \left\{s_j + \alpha\right\}-s_y, 0 \right\}

而论文中使用的 smooth SVM loss 则如下公式所示:

L1,τ(s,y)=τlog(iYexp(1τ(αI(jy)+sjsy))) \mathcal{L}_{1,\tau}(s,y) = \tau \log\left(\sum\limits_{i \in \mathcal{Y}} \exp\left(\frac{1}{\tau}\left(\alpha\mathbb{I}(j \neq y) + s_j - s_y \right)\right) \right)

这两个公式是针对多 标签/分类 相关的任务的,对于类别 Y={0,1,,n1}\mathcal{Y} = \left\{0, 1, \cdots, n - 1\right\} 为 n 个类别。其中损失函数的两个参数 s\mathbf{s}yy 分别是预测评分向量与标注分类,α\alpha 是一个指定的余量或者阈值(a specified margin), τ\tau “温度系数”(a temperature scaling)。

以二分类为例,对这两个 loss 进行了可是化。根据损失函数的具体的形式,将 sjsys_j - s_y 这个部分变成了 Δs\Delta s,绘制了 Δs\Delta s \sim \ell 的图像。

SVM loss 1

SVM loss 2

SVM loss 3

SVM loss 与 smooth SVM loss 对比 smooth SVM loss ($\alpha$ 不同) smooth SVM loss ($\tau$ 不同)
图: 对 SVM loss 与 smooth SVM 的可视化

无监督聚类

为了通过上面的 SVM loss 或者 smooth SVM loss 对网络进行优化,instance-level 的标注是需要有的。然而 CLAM 是 slide-level 标注的弱监督分类,也就是说 instance-level 并没有标注。所以就需要 instance-level 的聚类来赋予 label。

CLAM 的多类别分类是通过多个 branch 的分类完成的,每个 branch 都是二分类。所以每个聚类也是二分类。 论文中使用了一个有 512 个神经元的隐藏层,WR2×512\mathbf{W} \in \mathbb{R}^{2\times 512}

简单来说,在提取了这些内容之后,首先根据注意力相关评分对关键元素进行排序,然后会按照类内与类外来分别计算“loss”。对于类别 YY 的分类器来说,loss 会计算“是”与“不是”两种情况,而对于其他的,则是计算”不是“者一种情况。”是“与”不是“在论文中被称作为 ”prediction for positive evidence“ 与 ”prediction for negative evidence“。通过这样的方式来完成聚类。

好吧,说实话,原论文中的聚类的这部分讲的也不是太好理解,所以我也只是看出来一个大概。

Reference

M. Y. Lu, D. F. K. Williamson, T. Y. Chen, R. J. Chen, M. Barbieri, and F. Mahmood, “Data Efficient and Weakly Supervised Computational Pathology on Whole Slide Images,” Apr. 2020.

Y. Liu, Y. M. Zhang, X. Y. Zhang, and C. L. Liu, “Adaptive spatial pooling for image classification,” Pattern Recognit., vol. 55, pp. 58–67, Jul. 2016, doi: 10.1016/j.patcog.2016.01.030.

L. Berrada, A. Zisserman, and M. P. Kumar, “Smooth loss functions for deep top-k classification,” 6th Int. Conf. Learn. Represent. ICLR 2018 - Conf. Track Proc., Feb. 2018.