Digital pathology
[리뷰] ROAM: A transformer-based weakly supervised computational pathology method for clinical-grade diagnosis and molecular marker discovery of gliomas
연금(Pension)술사
2025. 1. 22. 15:56
Method: ROAM은 4단계로 구성: 1) ROI추출 -> 2) Self-attetion을 이용한 인스턴스 표현 3) attention weigthed aggregation, 4) Multi-level supervision의 4단계
1. ROI 추출 및 특징화: WSI로부터 2,048 x 2,048 (0.5mpp, x20)으로 ROI을 추출합니다.
- 추출과정에서는 Overlapping 없이 진행합니다.
- 2048x2048 ROI에서 2배, 4배의 다운샘플링을 진행합니다 (1024x1024, 512x512)가 생성. 논문에서는 $R_{i}^{0}, R_{i}^{1}, R_{i}^{2}$ 로 표현함
- 각 ROI에서 다시, 256x256사이즈의 패치를 뽑아서 pretrained CNN에 통과. 각 통과한 출력은 ($X_{i}^{0}\in \mathbb{R}^{64\times c }, X_{i}^{1}\in \mathbb{R}^{16\times c } , X_{i}^{2}\in \mathbb{R}^{4 \times c }$로 표현.
- 각 ROI의 출력을 FC을 하나 통과시켜서, 초기의 특징값(d차원으로 변경)을 구성. 이를 하나로 합쳐(concat) 하나의 특징으로 활용 $X_{i} \in \mathbb{R}^{84 \times d}$ . 즉 $\mathbf{X}_{i} = [X_{i}^{0}, X_{i}^{1}, X_{i}^{2}]$
- 한 슬라이드(b)은 M개의 ROI을 가짐
- 스케일 s에 대해서 $\mathbf{X}_{s} = [X_{i}^{s}, X_{i}^{s}, ...X_{n_{s}}^{s}]$ 로 표현함. 각 ROI은 d차원을 가짐
2. Self-attention(SA)을 이용한 인스턴스 표현
- 2.1. embedding projection + positional embedding
- $Z_{s0} = \begin{bmatrix} x^{s}_{cls}; & x^{s}_{1}E,..., x^{s}_{n_{s}E} \end{bmatrix} + E_{pos}$
- 2.2. Intra-scale SA(동일스캐일 내): ViT에서 원래 해주던 SA
- 추가로 Relative position bias을 도입합니다. $\text{Attention Scores} = \frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d_k}} + \text{Relative Position Bias}$
- 이 Relative position bias을 주는 이유는 PE(Positional embedding)은 절대적인 위치로만 임베딩 되기에 이를 보정하기위한 bias을 추가한 경우입니다. 예를 들어, A단어는 B단어의 2칸 옆에 떨어져있다와 같은 정보가 relative position이 되겠습니다.
- $h_{cls}^{s}$: Intrascale SA에서 나온 출력시퀀스에 첫 토큰으로, cls token에 해당합니다.
- 2.3 Inter-scale SA:
- Inter-scale SA은 고배율의 4개의 패치의 특징과 저배율의 1개의 패치(특징아님)을 concat하여 또 다른 transformer에 전달합니다. Transformer에서는 positional embedding도 해줍니다 (아마, 첫번째 토큰이 저배율임을 인식시키기 위함일 듯 합니다)
- 전달한 결과의 가장 첫 토큰을 inter-scale 표현으로 사용합니다. $Z^{s+1}, \_ = InterscaleSA(Z^{s->(s+1)})$
- 집계: Intrascale 결과의 cls_token결과만 weighted sum해서 사용합니다. 이 $w_{s}$은 [0,1]의 값을 가지고 sum to 1으로 정규화한갑입니다.
3. Pooling: attention weighted aggregation
- 한 슬라이드내에 ROI별로 $h$을 구했다면, 이를 모두 attention weighted sum하여 슬라이드 예측에 사용합니다.
- 식(6)은 gated attention weighted 입니다.
- 식(7)은 attention weighted sum
4. Multi-level supervision: Slide prediction과 patch level prediction을 multi-task로 학습
- Slide level prediction은 $h_{slide,k}$을 이용해서 FC을 통과하여 예측합니다.
- Instance level prediction은 $h_{i}$을 이용해서 FC을 통과하여 예측합니다.
- instance level label이 없기 때문에, CLAM의 방법과 동일하게 top k 의 라벨을 그대로 따라 씁니다.
반응형