Motivation

1. Vanillar Attention MIL은 IID(independent and identical distribution)에 기반하고 있어서, 패치별 연관성을 고려하지 못함.

2. 이에 대한 대안으로 Transformer을 사용하는 경우도 계산량 때문에, 시퀀스 길이가 짧아야해서 WSI분석에는 적용이 어려움.

예를 들어, 아래의 그림과 같이 Vanillar Attention인 경우(d), $\sum \alpha_{i} * h_{i}$와 같이 attention weight * instance embedding으로 가중합으로 하게되는데, 이때 attention weight은 i의 인스턴스만이고려된 가중치임(=attention weight은 attention score로 유래되고, i 인스턴스만의 정보로만 계산됨). 달리말하면, 예측에 어떤 패치가 중요한지만을 가중치로 둠. 반면 Self-attention pooling을 하면, 국지적인 연관성도 하나의 정보로  확인해서, 이 국지적 연관성이 높다는 것을 하나의 가중치로 확인함.

 

Method:  TransMIL은 TPT Module로 이뤄짐

1. TPT Module의 중요 특징이 2가지

첫 째로, WSI 패치로 나눌 때, 패치가 N x N으로 연산되지 않고, $\sqrt{N} \times \sqrt{N}$으로 연산되도록 하는 Transformer구조를 가짐. 여기서 국지적인 연관성이 고려됨. * TPT Module에 대한 Full term이 논문내에서는 주어지지 아니함.

둘 째로, Convolution을 가변적으로 진행하기 위한 PPEG(Pyramid Position Encoding Generator). PPEG은 이미지를 N by N kernel 사이즈를 달리하여 Convolution을 적용함. 이 PPEG의 이점은 서로 다른 위치정보를 압축하고, 혼합할 수 있음.

 

TPT Module은 아래와 같은 알고리즘입. 크게는 1) 정사각형으로 시퀀싱을 만드는 단계, 2)멀티해드어텐션 단계, 3) PPEG적용단계), 4) 멑리해드어텐션 단계, 5) Slide level예측

그림으로 보면 단순함. 원래는 패치수가 30개면, 30 x 30의 Attention score (alignment score)을 구한후에 Attention weight을 계산하게 되는데, 여기서는 30x30이 아니라, 루트30 x 루트 30으로 계산함.  패치 수가 30개면, 이의 제곱근에 가장가까운 윗 정수는 6으로 계산됨. 즉, 이 행렬은 6x6의 사이즈를 만드는것이 목적임.

이렇게 하는 이유는 Attention score을 N x N으로 만들면 어짜피, 상삼각부분은 i,j의 패치와 i,j의 위치에 해당하는 연관성은 거의 동일할 것이기, 더 단순하게 $\sqrt{N} \times \sqrt{N}$으로 하려는 것으로 보임.

하지만, 이렇게 가공하면, 정방이 안나오기에, 정방이 나올수 있도록, 부족한 패치는 앞의 패치에서 따와 붙여줌. 그렇게 만든 패치의 모음이 $H_{s}^{f}$

이 방법을 이용해서 Sefl-attention을 구하면 됨. 하지만, 조금 더 나아가서 Nystrom Method라는 방법으로 O(n)의 연산을 가지도록 하고, 이렇게하면 토큰수가 엄청 커져도 효율적인 연산이 가능.

 

한편, PPEG은 위에서 만든 $H_{f}$을 여러 convolution kernel을 동시에 적용 후에,  padding하여 집계하는 방식으로 이뤄짐. 각 kernel을 통과한 featuremap은 단순히 합산으로 연산되고, 이를 Flattenting하여 다시 패치토큰으로 얻어냄.

 

 

반응형

 

Motivation

  • 한 조직슬라이드를 보면, 같은 패턴(예, Stroma, Cancer, Epithelium)의 조직슬라이드 여러 번 등장한다는 것
  • 여러 패턴의 특징을 집합으로해서, 분포로 특징화하면 슬라이드를 표현하기에 더 좋을 것 이라는 가설

 

Method

방법론은 크게 2가지로 구분됩니다. 1) 패치를 어떻게 프로토타입이랑 연계할 것인가, 2) 슬라이드 임베딩 과정

1. 패치를 임베딩: $z_{N_{j}}^{j} $

2. 슬라이드 임베딩: $z^{j}_{WSI} = \left[ \sum_{n=1}^{N_j} \varphi_j \left( z_{j \, n}, h_1 \right), \, \cdots , \, \sum_{n=1}^{N_j} \varphi_j \left( z_{j \, n}, h_C \right) \right]$

 - 패치의 임베딩인 집합을 패치의 개수가 더 작은 C개의 프로토타입의 집합으로 표현하고자합니다.프로토타입이란 개념을 이용합니다. 여기서 프로토타입은 각 패치의 특징을 대표할만한 벡터를 의미합니다. 예를 들어, 염증을 대표할만한 패치의 벡터, 암 패치를 대표할만한 벡터를 의미합니다. 

  -  이 패치 임베딩의 집합인 WSI은 수식으로 $Z_{WSI}^{j} \in \mathbb{R}^{C \cdot M}$ 으로 표현합니다. 이는 패치1부터 N까지를 프로토타입1에 대해서 유사도를 만들어 모두 합 합니다. 그리고 이 과정을 모든 프로토타입C까지를 진행하여, 벡터로 표현합니다. 결과적으로 M개의 벡터가 프로토타입개수인 C개까지하여 CM차원의 벡터를 얻을 수 있습니다. 라이드내 패치 수($N_{j}$)가 가변적이더라도, 이 축에 대해서 합계를 하기 때문에, $CM$의 차원으로 항상 고정인 벡터가 얻어집니다.

  - $\phi^{j}(\cdot, \cdot)$: 각 패치와 프로토타입의 유사도를 새로운 M차원으로 표현합니다. 이 매핑 함수를 정의하기위해서 GMM(Gausian mixture)을 이용합니다. 각 패치 임베딩($z_{n}^{j}$)가 GMM에서 얻어졌다라고 가정합니다. 즉, 패치임베딩은 여러 가우시안분포중 하나의 분포에서 생성될 확률이 있다고 봅니다. 

$p(x) = \sum_{k=1}^{K}\pi_{k} \cdot N(x|\mu_{x}, \sigma_{k})$

=> $p(z_{n_j}; \theta_j) = \sum_{c=1}^{C} p(c_{n_j} = c; \theta_j) \cdot p(z_{n_j} \mid c_{n_j} = c; \theta_j)$

=> $p(z_{n_j}; \theta_j) = \sum_{c=1}^{C} \pi_{c_j} \cdot N(z_{n_j}; \mu_{c_j}, \Sigma_{c_j})$

 - $ \pi $: 프로토타입c의 확률을 의미합니다. 여기서는 c번째 가우시안 분포가 선택될 확률을 나타냅니다. 이 확률의 총합은 1입니다.

 - $ \theta $ : 최적화해야할 파라미터를 의미합니다. GMM은 Mixutre probablity ($ \pi $), 평균($ \mu_{c} $), 공분산 행렬($ \Sigma_{c_j} $)의 집합을 의미합니다.

위의 수식을 풀어서 해석해보면, C개의 가우시안분포의확률 x 선택된 가우시안분포에서의 확률밀도 입니다.

1. 첫 번째 항($ p(c_{n_j} = c; \theta_j)  $)은 c번째 가우시안 분포로 선택될 확률을 의미합니다. 

2. 두 번째 항($ p(z_{n_j} \mid c_{n_j} = c; \theta_j) $): c번째 가우시안 분포일때, 확률밀도를 의미합니다. 예를 들어, 염증(c=염증)을 의미하는 가우시안 분포에서 임베딩이 나왔다면, 그 임베딩값이 나왔을 확률 밀도를 계산하는 것입니다.

여기서 중요한건, 혼합확률(Mixture probablity, $\pi _{c} ^{j}$)은 슬라이드 별로 계산합니다. 슬라이드의 특징을 고려해서, 특징을 뽑아서 WSI을 GMM의 파라미터로 설명합니다.

최종적으로, 아래와 같이 표현할 수 있습니다.

 

$ z_j^{\text{WSI}} = \left[z_j^{\text{WSI},1}, \cdots, z_j^{\text{WSI},C}\right] = \left[\pi b_j^1, \mu b_j^1, \Sigma b_j^1 \,  \cdots, \pi b_j^C, \mu b_j^C, \Sigma b_j^C \, \right]$

 

Q. GMM에 알고리즘에서 초기화는 어떻게하나?

1. 혼합확률(mixture probability, $\pi_{c} ^{j, (0)}$)은 균등분포로 1/C을 줍니다. 

2. 각 분포의 평균값($ \mu_{c} ^{j, (0)}$): c클레스 프로토타입에 대한 벡터

3. 공분산($Sigma_{c}^{j, (0)}$: Indentity matrix

4. 프로토타입 벡터에 대한 초기화: K-means clsutering을 훈련데이터셋에서 해서 

반응형

 

 

이미지에서 blur(블러, 흐림 현상)을 확인하는 데 사용되는 다양한 알고리즘과 지표들이 있습니다. 주로 이미지의 선명도, 엣지(경계)의 강도를 측정하여 블러를 평가합니다.

이 평가하는 방법은 2가지가 있습니다.

  1. 원본과 대비가 가능한 경우: With reference image
  2. 원본과 대비가 필요 없는 경우: without reference image

 

이 포스팅에서 without reference image(no refereince image)인 대표적인 지표 및 알고리즘들을 크게 분류하면 다음과 같습니다.

  1. Spatial domain: 방법은 이미지의 픽셀과 인접 픽셀 간의 관계를 계산하여 흐림과 선명한 이미지를 구분하는 기법입니다.
    1. Grayscale Gradient base method: 이미지를 그레이스케일로 변환하여, 인접픽셀과의 그레디언트값(변화도)를 측정합니다. 측정된 변화도가 클수록 또렷한 이미지로 간주됩니다. 대표적으로 Laplacian variance가 있습니다.
  2. Spectral-domain: 이미지의 고주파 성분과 저주파 성분을 분석하여 이미지의 선명도를 평가하는 방법입니다. 고주파 성분은 이미지의 선명한 부분과 관련이 있으며, 세부 정보와 경계 정보를 많이 포함하고 있습니다. 반면 저주파 성분은 흐릿한 부분에 해당합니다
  3. Learning: 머신러닝을 이용한 방법
  4. Combination: 2개 이상의 조합을 이용하는 방법

 

1. Laplacian Variance (라플라시안 분산)

라플라시안 분산은 라플라시안 커널을 이용해서, 이미지의 2차미분을 구해 분산을 구하는 방법입니다. 이 분산이 뜻하는 바는 픽셀의 흩어짐 정도로, 또렷한 이미지일수록 높은 분산값을 가지며, 흐릿한 이미지일수록 낮은 분산값을 가지게 됩니다. 특정한 임계점을 두어 또렷하다, 흐리다라고 판단할 수 있습니다. 

다음의 장점을 지닙니다.

  • 간단하고 빠릅니다.: 계산이 비교적 간단하여 실시간 블러 감지에도 적합합니다.
  • 효과적임: 엣지 정보를 기반으로 하므로 다양한 종류의 블러를 효과적으로 감지할 수 있습니다.
  • 수학적 직관성: 분산을 이용한 접근 방식은 통계적으로도 타당성이 있습니다.

 

2차원 이미지에서는 다음과 같이 표현합니다. 2차미분을 구하기 위해서는 Laplacian kernel을 이용해서 구합니다. 

$\nabla^2 I = \frac{\partial^2 I}{\partial x^2} + \frac{\partial^2 I}{\partial y^2}$

이를 메트릭스 연산으로 표현하면, 라플라시안 커널이 됩니다.

$$\begin{bmatrix}
0 & 1 & 0 \\
1 & -4 & 1 \\
0 & 1 & 0
\end{bmatrix}
$$

import cv2
import numpy as np


def cal_laplacian_variance(image_array: np.ndarray):
    """라플라시안 분산을 구함

    Args:
        image_array (np.ndarray): RGB image array

    Returns:
        float: Laplacian variance
    """

    gray_image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
    laplacian = cv2.Laplacian(gray_image, cv2.CV_64F)
    return laplacian.var()

라플라시안 분산은 분산이기에 항상 0또는 양의값을 갖게됩니다. 이런 경계면이 모호한 이미지를 바로 식별할 수 있습니다.

또는 경계면이 또렷한(=focal plane)에 맞게 하는 경우의 이미지도 선택할 수 있습니다.

 

2.  Wavelet-based transform

이미지도 신호기이 때문에, 이를 신호로 처리하는 방법이 있습니다. Spectral-domain에 속하는 방법으로, 이미지를 저주파 및 고주파로 나눕니다. 각각 저주파 고주파는 다시, 저저, 고고, 저고,고저로 나눌 수 있고, 저저를 제외한 나머지 신호 강도를 이용하여 이미지의 선명도를 추출할 수 있습니다.

 

Blur Detection for Digital Images Using Wavelet Transform*이 논문에서도 저저주파를 제외하고 신호강도를 측정합니다.

def cal_wavelet(image_array):
    # 이미지를 그레이스케일로 변환 (이미지가 컬러일 경우)
    if len(image_array.shape) == 3:
        image = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)

    # 2D 웨이브렛 변환 수행 (Haar 웨이브렛 사용)
    coeffs = pywt.dwt2(image, "haar")  # 'haar' 외에도 다양한 웨이브렛 사용 가능
    LL, (LH, HL, HH) = coeffs

    # 고주파 성분 (LH, HL, HH)의 에너지를 합산하여 선명도 측정
    high_freq_energy = (
        np.sum(np.abs(LH) ** 2) + np.sum(np.abs(HL) ** 2) + np.sum(np.abs(HH) ** 2)
    )

    return high_freq_energy
반응형

요약


register_buffer는 모델의 상태(state)로서 관리하고 싶은 텐서를 등록하는 데 사용됩니다. 즉, 이 메서드는 state_dict에 포함되어서, torch.nn.Module.state_dict()에 함께 저장되어, torch.save을 할 때, 함께 저장됩니다. 또한, register_buffer으로 등록된 텐서는 기본적으로 기울기를 계산하지 않습니다.

 

기능 1. state_dict을 통해 모델을 저장/로드 할 때, 함께 포함되도록

torch.nn.Module로 딥러닝 네트워크를 구성하고, 필요한 텐서(non-trainable)도 함께 저장이 가능합니다.

아래의 예시를 살펴보겠습니다.

  • 9번줄:  self.register_buffer("running_mean", torch.zeros(10))으로 텐서를 하나 저장합니다. 이렇게되면 self. running_mean에 속성으로도 저장됩니다.
  • 21번줄: torch.save(model.state_dict(), "model_with_buffer.pth")에서 model.state_dict()을 이용해서 state_dict을 저장합니다. 이 때, register_buffer으로 등록한 running_mean= torch.zeros(10)도 함께 state_dict에 저장됩니다.
  • 24번줄: register_buffer을 사용하지않고 저장하려면, state_dict의 딕셔너리에 key-value을 별도로 이렇게 저장해줘야합니다.

 

기능 2. non-trainable parameter을 저장하는 경우

배치 정규화에서 배치 단위의 평균(mean)과 분산(var)은 통계량값만 저장하고, gradient로는 사용되지 않습니다. 이 때도 사용이 가능합니다.

 

반응형

IEEE TCSVT, 2023

 

Preliminary

  1. Instance-based methods: 인스턴스 분류기를 슈도라벨을 이용해서 학습하고, Bag-level로 집계하는 방식
  2. Attention based bag level MIL: instance feature을 뽑은 후에, 이를 집계해서 bag label을 예측함. 주로 Attention weight을 이용하여 이를 instance-level classification처럼 이용.

 

Motivation: 기존 방법들의 단점

  1. Instance-base method: Bag label을 그대로 인스턴스 label로 할당하는 방법을 의미함. Instance pseudo label(할당된 라벨)의 노이즈가 많음. 실제 인스턴스 라벨을 모르기 때문에, 인스턴스 분류기의 품질을 모르고, 성능에 차이가 많이남 (Figure A).
  2. Bag-based method:
    1. Instance level classification 성능이 안좋음: 쉽게 식별가능한 인스턴스에 대해서만 Attention score가 높게 할당하여 학습되는 경향. 모든 positive 인스턴스를 고려하는 것이 아니라는 것 -> 몇 개의 인스턴스만 Attention score 엄청 높고 나머지 인스턴스는 낮게나옴
    2. Bag-level classification도 강건하지 않음: 몇개의 인스턴스만 스코어가 엄청 높아 이 스코어에 의존적으로 예측하다보니 Bag-level classification도 강건하지 않음. 만일 attention score가 높은 예측이 틀린경우, bag classification도 틀림(특히, positive instance가 몇 없는 경우). 논문내에서는 "laziness"라고 하여, 쉬운것만 배우려고하는 게으름(lazy)을 의미

 

 

Method


방법론 개괄: instance-based MIL에 1)contrastive learning, 2)prototype learning을 합친 INS을 제안

 

이 방법론의 주된 목적은 instance classifier을 학습하는 것. instance-based이기 때문.

  1. 임의의 패치($x_{i,j}$)을 하나 샘플링
  2. 이 패치를 서로 다른 방법으로 증강하여 View1(Query view), View2(Key view)을 생성
  3. Query view branch: Query view을 인코더에 전달하여 특징값을 뽑고, instance classifier와 projector에서 슈도라벨과임베딩 벧터를 얻음. 표기로는 클레스 예측($\hat{y}^{i,j} \in \mathbb R^{2}$)와 임베딩벡터 ($q_{i,j} \in \mathbb R^{d}$)에 해당.
  4. Key view branch: key view을 별도의 인코더에 넣고, 다시 projector에 넣어 임베딩 벡터($k_{i,j}\in \mathbb R^{d}$)을 획득.
  5. MoCo을 이용한 학습(Instance-level Weakly Supervised Contrastive learning):
    1. Enqueue: 인코더와 프로젝터의 학습은 MOCO 방식, Pico으로 업데이트. MoCo의 방식대로 key encoder의 값을 dictionary 에 enqueue 해서 넣음. 넣을 때, 슬라이드 이미지의 라벨에 따라 다르게 처리하여 넣음. MIL문제에서는 negative slide에서는 모두 negative patch이기 때문에, 모두 예측클레스를 사용하는게 아니라 0을 넣음. 반대로, positive slide에 대해서만 classifier을 추론시 킴. 그리고, 예측된 클레스( $\hat{y}^{i,j} $), 임베딩값($k_{i,j}$))을 넣음
    2. Contrastive learning: contrastive learning을 하기 위해서 positive, negative pair가 필요한데, 본문에서는 Family and Non-family Sample selection 이라고 표현.
      1. Family set $F(q_{i,j})$: 패밀리 샘플들은 2개의 집합의 합집합임. 첫 번째 집합은 같은 이미지로부터 나온 두 임베딩임 ($q_{i,j}, k_{i,j}$). 나머지 집합은 예측된 라벨과 같은 라벨을 지닌 이미지들임(즉, $\hat{y}_{i,j}$와 같은 임베딩값들을 의미함.
      2. Non-family set $F'(q_{i,j})$: Family set의 여집합을 의미. 즉, 여기서는 다른 예측 클레스 라벨(또는 negative slide)을 가진 임베딩인 경우.
    3. Positive key, Negative key selection: SuperCon(https://analytics4everything.tistory.com/303) 과 같은 알고리즘으로 contrastive learning을 진행. 차이라면, SuperCon에서는 instance label이 hard label로주어졋다면, 이 논문에서는 instance classifier와 negative bag을 이용해서 얻었다는 차이

    6. Prototype-based Psudo label (PPLG): Instance classifier을 학습하기 위한 방법. 주요가정: 고품질의 feature representation을 얻기위해서 정확학 psuedo labels을 할당해야함.

  • $s_{i,j}$: 슈도 라벨을 의미함. one-hot vector로 2차원으로 표현. [0, 1] 또는 [1, 0]임.
  • $\mu_{r}$: prototype vector을 의미함. $r$은 라벨을 의미함.

자세한 방법은 아래와 같음

  1. 2개의 슈도라벨을 생성: 프로토타입 벡터 ($\mu _{r} \in \mathbb R ^{d}, r=0,1$)을 생성. r은 라벨을 의미함. 즉 d차원을 가진 프로토타입 벡터를 생성
  2. 프로토타입 벡터 업데이트:
    1. True negative guided update: 패치가 negative positive bag에서 획득된 경우, 해당 패치도 negative일 것임. 따라서, $s_{i,j}=0$을 직접할당. 그리고 프로토타입 벡터도 이 임베딩값($q_{i,j}$)을 할당.
    2. Poistive update: positive prototype vector 
    3. 쿼리임베딩($q_{i,j}$)과 두 프로토타입($\mu_{r}$)과 내적. 내적값이 더 가까운쪽으로 학습하기위해서, onehot 인코딩으로 내적값을 [0, 1], 또는 [1, 0]으로 만듬
    4. 슈도라벨을 모멘텀 업데이트: $s_{i,j} = \alpha s_{i,j} + (1- \alpha ) z_{i,j}, z_{i,j} = onehot(argmax( q_{i,j}^{T}, \mu_{r))$
    5. 프로토타입 업데이트: 프로토타입 벡터($\mu_{c}$)을 classifier에 포워딩하여 얻은 클레스에 맞게 업데이트합니다.
    6. 슈도 라벨 업데이트: 식(7), 식(8)
    7. Bag constraint: mean pooling을 이용해서 bag embedding을 진행

 

    7. Loss

  • Bag loss: 패치 임베딩을 Mean pooling하여 라벨과 계산 (식10)
  • Total loss: $\mathcal{L}=\mathcal{L}_{IWSCL} + \mathcal{L}_{cls} +\mathcal{L}_{bc}$

 

 

 

 

반응형

 

Preliminary

  • Positive sampleOrigin이 동일한 이미지(또는 데이터포인트)
  • Negative sample: Origin이 다른 이미지(또는 데이터포인트)
  • Dictionary look-up task: 여기서는 contrastive learning을 의미합니다. Query image에서 하나의 positive이미와 나머지 (N-1)개의 negative이미지와 유사도를 계산하는 과정이기에 k-value와 같은 dictionary구조라고 일컫습니다.
  • InfoNCE: InfoNCE은 미니배치 K+1개에서 1개의 positive sample을 찾는 softmax 함수 + temperature scale추가

 

Methods: 딕셔너리 사이즈를 키워 negative samples을 포함할 수 있으면 좋은 학습방법이 될 것

Moco은 딕셔너리 사이즈를 키워, Negative key을 많이 갖고 있게되면 더 학습이 잘된다는 가설로 시작됩니다. 위의 딕셔너리가 가변적(dynamic)합니다. 하나의 query에 대해서 key도 augmentation을 달리하고, 인코더도 점차 학습되기 때문입니다. MoCo은 Negative sample을 임베딩하기에 Key encoder가 비교적 천천히 학습되고(=급격히 변하지 않고) 일관적으로 학습되는 것을 목표로 합니다.

1. MoCo은 샘플을 Queue방식으로 처리합니다. 즉, 새로운 샘플을 Queue에 넣으면, 먼저 들어가있던 오래된 샘플은 제외되는 방식입니다. 샘플을 넣는 과정을 "Enqueue" 제외하는 과정을 "dequeue"라 합니다. 따라서, 비교적 최근의 인코딩된 Key Image들은 enqueue후에 몇 회 재사용될 수 있습니다.

2. Momentum update: key encoder을 직접학습하는 대신에, query encoder을 학습하고 학습의 일부만 반영(=momentum)합니다. 이유는 key encoder에는 많은 negative samples에 대해서 backpropagation하기 어렵기 때문입니다즉, computational cost문제 때문입니다. 예를 들어, 이번 배치가 16개였고, 딕셔너리사이즈가 1024개면, 1,024개의 샘플에 대해서 backpropgation을 해야하기 때문에 계산량이 많아 현실적으로 어렵다는 것입니다. enqueue하는 방식으로는 back-propagation이가능하다는 글을 같이 공유합니다(URL, URL2) (보통, "intractable"이라 하면 미분이 불가능한, 다루기 어려운 이런 뜻인데 미분과 연관되어 쓰여 혼란이 있습니다)

또한 인코더가 급격하게 변하는 경우, 딕셔너리에 저장되어있던 기존의 인코딩의 값이 현재와 많이 상이하기에 천천히변화시켜야한다는 것입니다.

그래서 아래와 같이 momentum update을 대안으로 선택합니다. 아래의 수식(2)을 보면, query encoder의 파라미터의 내용이용하여 key encoder을 업데이트하는 방식으로 학습합니다.

 

$\theta_{k} \leftarrow m \theta_{k} + ( 1- m )\theta_{q}$ (2),

  • $\theta_{q}$: query encoder의 파라미터
  • $ \theta_{k} $: key encoder의 파라미터
  • $m$: 모멘텀계수. [0, 1]사이의 값

따라서, query encoder만 back-propagation을 진행합니다. 그리고, 급진적인 key encoder의 파라미터 변화를 막기 위해서, m값을 0.999로 학습합니다. 이는 m=0.9보다 성능이 더 좋았다고 합니다.

 

기타1: contrastive loss 학습방식

  • end-to-end update: query encoder와 key encoder가 동일한 경우를 의미합니다. 그리고, 딕셔너리는 항상 미니배치 사이즈와 같습니다. 따라서, 딕셔너리 사이즈를 엄청 키우진 못합니다.
  • memory bank: 메모리뱅크는 데이터셋에 모든 샘플을 일단 인코딩하여 넣어둡니다. 그리고, 매니배치에서 일부를 샘플링하여 contrastive learning에 사용합니다. 장점은 "큰 사이즈의 딕셔너리"라는 것입니다. 단점은 메모리뱅크의 업데이트이후 key encoder가 여러번업데이트되기 떄문에, 둘의 일관성(인코딩 차원에서의)이 다소 떨어지는다는 것입니다.

 

Linear probe 과정에서 learning rate을 30으로 한 이유

논문에서는 "optimal learning rate을 30, weight decay을 0"으로 진행합니다. 일반적으로 다른 hyper-parameter인데, 이 이유는 feature distirbtion ImageNet supervised training과 Instragram-1B와 크게 다름을 의미합니다.

반응형

+ Recent posts