Best Paper review

[5분 컷 리뷰] Supervised Contrastive Learning

연금(Pension)술사 2024. 7. 23. 19:23

 

 

Motivation

CE(Cross-entropy)은 지도학습의 분류에 주로 사용됩니다. 하지만, 많은 단점들이 존재하는데, 예를 들어 noisy label이나 poor margin 같은게 있어 일반화 성능이 떨어집니다. CE의 대안으로 나온 여러가지 손실함수가 발명됬지만, 여러 챌린지에서 보면 여전히 CE을 쓰듯이, 실무에서는 큰 도움은 안됩니다.

최근에 대조학습(Constraive learning)으로, 라벨이 없이도 자기지도학습에서 사용됩니다. 미니 배치 내에서, 앵커와 같은 origin data은 가깝게, 앵커와 먼 데이터는 멀게 학습하는 방법으로 학습합니다.

Figure 2: Supervised The self-supervised contrastive loss (left, Eq. 1)

 

이 논문은 자기지도학습에서의 라벨을 이용하여 contrastive learning을 이용해서 학습하는 방법을 제안했습니다. 같은 클레스의 임베딩을 정규화한다음에 가깝게 위치시키고, 다른 클레스보다 멀게 학습시킵니다. 앵커하나에 여러 positive pair을 만들 수 있다는게 장점이고, 많은 negative pair도 만들 수 있습니다.

 

Method

  1. $\text{Aug()}$: 한 이미지 x로부터 이미지를 증강하는 모듈입니다. $\tilde{x}=Aug(x)$와 같이 $\tilde{x} $가 2개씩 생성됩니다. 각 $\tilde{x}$은 x로부터 새로운 view(증강법에 따른 변형)을 의미합니다. crop일수도, random rotation일 수도있습니다.
  2. $\text{Enc()}$: 인코더 파트입니다. 벡터표현을 위해서 인코더를 하나 만듭니다. 이 인코더는 1에서 만들어진 2개의 증강된이미지를 포워딩하여 벡터를 만듭니다. 벡터의 차원은 2,048입니다. $r=Enc(x)\in R^{D_{E}}$
  3. $Proj()$: Projection network입니다. 벡터표현이후에 128차원의 벡터로 프로젝션합니다. 이 벡터는 단위의 길이를 가질 수 있도록(unit hypersphere), 정규화를 합니다. Projection layer은 SSL에서는 사용하지 않습니다.

 

1. Contrastive loss function의 정의: 

  1. $\{x_{k}, y_{k}\}_{k=1,...N}$: 지도학습에서는 N개의 임의의 샘플이 있다고 가정합니다. N은 배치사이즈입니다.
  2. $\{\tilde{x_{l}}, \tilde{y_{k}|\}_{l=1,...,2N}$: 1.에서의 이미지를 각각 2개의 이미지 증강을한 결과를 의미합니다. 여기서 라벨은 $y_{k}$은 증강후에 $ \tilde{y_{2k-1}, \tilde{y_{2k} $로 되니, 다 같은 값입니다. 2N은 증강후의 배치입니다. 본문내에서는 "multi-viewed batch"라 합니다.

 

2. Self-supervised Contrastive Loss의 정의

Contrastive learning과 Supervised constrastive learning을 비교하면 아래와 같습니다. 라벨이 있으니, 같은 클레스끼리는 가깝게 학습시키게끔 변형되었습니다.

  • $P(i)$: 은 i랑 동일한 이미지는 아닌, positive 의 집합니다.
  • $A(i)$: 은 i(Anchor)가 아닌 나머지의 집합입니다.
  • $z_{i}$:은 Projection 이후의 임베딩값입니다.
  • $\tau$: temperature scaling

실제 논문에서 제안된건 아래와 같이 |P(i)|의 개수를 어디서 연산하냐에 따라서, 아래와 같이 2가지로 나뉩니다. 두 함수는 완전 동치는 아니라 각각 실험에서의 최적화된 방법을 사용하면 됩니다.

 

 

Results: 엄청좋아지나..? 그정도까진...

1. Top 1 classification 성능은 그냥 CE을 쓰나 SimCLR을 쓰나 엄청난 차이를 보이지 않습니다. 하지만 이 영역대에서 x%p올리기 쉽지않은데, SupCon으로 한번 최적화해보는건 좋은 선택같습니다.

 

2. 그래도 강건한 모델:

mCE라는 지표로 N은 이미지의 왜곡방법의 수를 의미하며, 분자 분모는 오차율입니다. mCE은 낮을수록 이미지 변환에도 강건하고 오차율이 없다는 것을 의미하는데요. 

$mCE = \frac{1}{N} \sum_{c=1}^{N} E_{c \text{baseline}} E_{c}$

여기서 corruption 이 정확히 어떤 이미지 변환인지는 표기는 안되어있으나, 이미지 corruption을 강하게 하더라도 종전모델보다 강건한 이미지를 나타냅니다.

 

얼마나 학습시켜야하나?

1. ResNet을 CE로 학습시키는데 1400정도 에폭을 활용: "Since SupCon uses 2 views per sample, its batch sizes are effectively twice the cross-entropy equiv-
alent. We therefore also experimented with the cross-entropy ResNet-50 baselines using a batch size
of 12,288. These only achieved 77.5% top-1 accuracy. We additionally experimented with increas-
ing the number of training epochs for cross-entropy all the way to 1400, but this actually decreased
accuracy (77.0%)"

2. 200에폭도 충분함

"The SupCon loss was trained for 700 epochs during pretraining for ResNet-200 and 350 epochs for
smaller models. Fig. 4(c) shows accuracy as a function of SupCon training epochs for a ResNet50,
demonstrating that even 200 epochs is likely sufficient for most purposes."

반응형