본문 바로가기
Digital pathology

ReMix: A General and EfficientFramework for Multiple InstanceLearning Based Whole Slide ImageClassification

by 연금(Pension)술사 2025. 9. 4.

MICCAI 2022

 

Motivation


  • 메모리 및 연산 비효율성 문제: Bag내의 multiple instance인 점으로 학습속도 느림. 패치수도 달라서 배치연산도 안됨.
  • 데이터 다양성부족: 딥러닝은 보통 다양한 샘플에서 비롯되는데, 거의 활용안됨. 라벨을 보존하면서 의미있는 증강필요.

 

주요가정


  • WSI 내의패치들이 거의 반복적, 유사, 중
  • 프로토타입으로 표현이 어느정도 가능할것

 

방법론


  • Reduce: 프로토타입(대표)패치를 정하는 방법
    1. K-means clustering을 진행하여, Centroid을 구함.
    2. 해당 클러스터에 속한(=맴버십)인 패치의 임베딩의 공분산을 구함
    3. 하나의 multivariate guassian mixture모델로 고려(N(중심점, 공분산)).
  • Mix: 데이터의 다양성을 위한 특징수준에서의 증강방법.
    • 다른 Bag에서 가져와서 믹스하는 방식
    • Bag두 개를 선정 (Query bag $B'_{q}$, key bag: $B'_{k}$) 각 bag은 프로토타입의 집합.
    • 하나의 쿼리 프로토타입에 대해서, 가장 인접합 key centroid을 하나 찾음(c_{i}^{*}).
    • 그리고 다음의 4가지 방법으로 증강가능
      • Append: 가장 유사한 key centroid을 quey bag에 추가
      • Replace: query centroid와 가장 유사한 key centroid을 교체
      • Interpolate: 두 centroid을 선형보간
      • Covary: $\hat {c}_{i} = c_{i}^{q} + \lambda \delta, \delta \~ N(0, \sum_{i}^{k})$

 

 

Q1. Augmentation은 몇개나 몇확률로? 

아래의 코드를 보면, rate의 인자를 기준으로 전체 centroid을 바꿀지말지 결정함. 

def mix_aug(src_feats, tgt_feats, mode='replace', rate=0.3, strength=0.5, shift=None):
    assert mode in ['replace', 'append', 'interpolate', 'cov', 'joint']
    auged_feats = [_ for _ in src_feats.reshape(-1, 512)]
    tgt_feats = tgt_feats.reshape(-1, 512)
    closest_idxs = np.argmin(cdist(src_feats.reshape(-1, 512), tgt_feats), axis=1)
    if mode != 'joint':
        for ix in range(len(src_feats)):
            if np.random.rand() <= rate:
                if mode == 'replace':
                    auged_feats[ix] = tgt_feats[closest_idxs[ix]]
                elif mode == 'append':
                    auged_feats.append(tgt_feats[closest_idxs[ix]])
                elif mode == 'interpolate':
                    generated = (1 - strength) * auged_feats[ix] + strength * tgt_feats[closest_idxs[ix]]
                    auged_feats.append(generated)
                elif mode == 'cov':
                    generated = auged_feats[ix][np.newaxis, :] + strength * shift[closest_idxs[ix]][np.random.choice(200, 1)]
                    auged_feats.append(generated.flatten())
                else:
                    raise NotImplementedError
    else:
        for ix in range(len(src_feats)):
            if np.random.rand() <= rate:
                # replace
                auged_feats[ix] = tgt_feats[closest_idxs[ix]]
            if np.random.rand() <= rate:
                # append
                auged_feats.append(tgt_feats[closest_idxs[ix]])
            if np.random.rand() <= rate:
                # interpolate
                generated = (1 - strength) * auged_feats[ix] + strength * tgt_feats[closest_idxs[ix]]
                auged_feats.append(generated)
            if np.random.rand() <= rate:
                # covary
                generated = auged_feats[ix][np.newaxis, :] + strength * shift[closest_idxs[ix]][np.random.choice(200, 1)]
                auged_feats.append(generated.flatten())
    return np.array(auged_feats)

 


 

반응형