본문 바로가기
Digital pathology

In context MIL

by 연금(Pension)술사 2026. 6. 29.

2026.06.04 arXiv

 

Motivation


1. in-context learning의 필요성: 기존 병리 MIL에서는 WSI 하나를 bag, tile/patch embedding을 instance로 보고 slide-level label만으로 학습합니다. 문제는 실제 병리 데이터셋이 생각보다 작거나, organ/task마다 label 수가 적다는 점입니다. 예를 들어 rare cancer detection, 새로운 biomarker prediction, gastric biopsy에서 특정 grading task처럼 slide-level label은 있지만 수백 장 이하인 경우 ABMIL/TransMIL을 매번 학습하면 overfit, hyperparameter instability, seed variance가 큼. => 그렇기에 LLM의 in-context learning 처럼 바꿈. 

즉, ABMIL을 새로 학습하지 않고 {labeled WSI bags + labels}를 context로 넣고 query WSI를 바로 예측하는 pretrained MIL learner를 쓰자.

다른 말로하면, task마다 새로 fit하지 말고, labeled bags를 context로 넣어서 새 bag을 바로 예측하자

 

Method


  • $Q_{\theta} (y_{N+1} | B_{N+1}, {(B_i, y_i)}_{i=1}^{N} )$ : 이미 알려진 데이터셋 i~N개의 (WSI, label)을 보고, 새로운 abg의 라벨을 예측하는 모델. 근데, 이 모델 $Q_{\theta}$
  • $L(theta) = E_{D~p(D)}[-log q_theta(y_{N+1} | B_{N+1}, context bags)]$
  •  
  •  

2.2 Prior-data fitted networks: 진짜 데이터로 매번 모델을 학습하는 대신, 가짜 데이터 생성 규칙(prior)으로 엄청 많이 사전학습해둔 뒤, 새 데이터셋이 오면 학습 없이 바로 예측하는 모델. 약간 few-shot prompting이랑 비슷함. 일종의 psuedo-bag 및 task을 만드는 과정임.

주의: synthetic feature vector bag 사용한다는 것. MIL처럼 real WSI patch embedding bag 사용하는게 아님..

1. prior에서 data-generating process 하나를 뽑는다. = 하나의 synthetic task 생성
2. 그 process로 작은 dataset을 만든다. = 여러 개의 pseudo-bag + label 생성
3. 일부 bag-label은 context로 주고, query bag label을 맞히게 한다.
4. 이걸 수백만 번 반복한다.

예를 들어, task을 만드는게 핵심인데 아래와 같은 느낌. MLP-SCM, Tree-SCM이라는 generator를 TabICL에서 가져와 씁니다. SCM은 structural causal model인데, 여기서는 너무 어렵게 볼 필요 없고, 그냥: 랜덤하게 만든 nonlinear data generator

Task 1:
bag 안에 특정 type의 instance가 하나라도 있으면 positive

Task 2:
bag 안에 특정 type의 instance가 3개 이상이면 positive

Task 3:
A type과 B type instance가 같이 있으면 positive

Task 4:
bag 전체 feature 평균이 threshold보다 크면 positive

일단, bag 안에서 먼저 augmentation하는 것 

  1. Factorized prior: 한 bag내에 특정 instance 가 있는지? 또는 3개 이상인지? type1+type4가 있는지를 샘플링. 아래의 수식(4)에서의 $\phi$가 이 기능을함.
  2. Joint prior: bag 전체 구조로 label 결정. bag의 구조를보고 라벨을 결정하는식.

3. In-context learning for MIL

어찌되든, 아래의 수식에 맞춰서 보유데이터 + 쿼리데이터 식으로 넣어줘야하는데, 3가지 문제점이 있음

  • Challenge 1 (계산량) N개의 데이터셋을 추론시에 같이 넣는게 상당히 계산량이 큼.  예) WSI 100장 x 인스턴스 평균 100개
  • Challenge 2 (인스턴스의 의존성): bag을 요약할때, 어떤 인스턴스가 중요한지는 task따라 다름. => instance aggregation과 inter-bag attention을 여러 번 반복
  • Challenge 3 (순서비존성): MIL의 주요가정(다만, 어느WSI 소속인지는 유지되어야함)   learnable bag token으로 해결

 

3.1. ICMIL: Perceiver MIL 구조임. 

  • instance aggergation: bag 마다 bag-level token과 연산되게해서 요약함. feature-level split을 해서, G축을 하나 더만듬. 결과적으로는 M 개의 슬라이드 x I개의 인스턴스 x G개의 특징그룹 x F개의 특징차원
  • bag-level token을 넣어서 attention계산 후, residual해서 슬라이드 단위 임베딩을 구함.
  • Inter-bag attention: 그 슬라이드 단위 임베딩을 다시 계산 (식3)

 

 

Results


병리쪽 TCGA(LUAD vs LUSC)에서는 좀 약한데, 다른 의료쪽인 RSNA ICH에서는 그래도 ABMIL대비 성능개선이 좀 있음. 

반응형