Data science/Deep learning

[5분 컷 리뷰] Self-distillation 설명

연금(Pension)술사 2024. 4. 11. 20:54

 

Data-Distortion Guided Self-Distillation for Deep Neural Networks , AAAI 2019

 

배경: Knowledge distillation 

큰 모델의 model confidence (softmax output)은 이 큰 모델이 학습한 지식을 담고있지 않을까라는 가설로 시작됩니다. 큰 모델의 결과물(증류된 정보)을 이용해서 작은 모델을 간단히 학습할 수 있지 않을까라는 생각입니다 [1]. 즉 큰 모델의지식(knowledge)이 model confidences 분포이겠구나라는 것입니다. 

아래의 그림과 같이 Teacher model이 훨씬 큰 딥러닝모델이고, Student모델이 더 작은 모델입니다. 큰 모델의 예측결과(model confidences)자체가 큰 모델의 지식이라는 것입니다. 예를 들어, 고양이가 6번인덱스의 클레스였다면, 0번은 낮고, 1번도 낮고..., 6번이 제일 큰데 0.4정도라는 이 분포자체가 Teacher모델이 갖고있는 큰 파라미터의 지식이 됩니다. 이를 답습하자는것이 Student모델의 학습목표입니다.

Knowledge distillation 방법론

흔히 지식증류(knowledge distillation)방법들은 학습방법에 따라, 3가지로 나뉩니다. 

  1. teacher-to-student: 큰 모델의 지식(logits, feature map,attention map등)을 작은 모델이 더 잘학습시키기 위한 방법론. 서로 다른 모델이 2개가 있어야하며, Teacher model은 학습되어있다고 가정합니다.
  2. student-to-student: 서로 다른 작은 다른 모델(student)을 서로 다른 초기값을 가지게하고, 학습을 동시에 해가면서 지식의 차이를 점점 줄여가나는 방식입니다. teacher-to-student와는 다르게 이미 사전학습된 모델이 필요없는 On-line training 방법입니다.
  3. self-distillation (self-teaching): 하나의 네트워크를 이용하며, 이미지 1개를 데이터 증강(random, rotation)등을 통해서 2개로 만들고, 서로 다른 지식을 일치시키는 방법을 사용합니다. 즉, 하나의 네트워크를 사용하며, 서로 다른이미지를 추론했음에도 분포가 같아야하는 모델 일반화가 더 잘되는 방법으로 학습을 유도합니다.

 

방법: Self-distillation

Self-distillation은 크게 4단계로 이어집니다.

1. Data distortion operation: data augmentation 방법과도 같습니다. 같은 원본데이터로, 서로 다른 두 이미지를 생성합니다.

2. Feature extraction layers: 하나의 모델에 두 증강된 이미지를 통과시켜 특징값(feature)을 얻습니다. 여기서는 두 특징값(벡터)의 차이를 계산합니다. 

MMD은 Maximum Mean Discrepancy을 의미하며, 두 증강된 이미지를 특징으로 만든 벡터의 차이를 의미합니다. 수식으로는 아래와 같습니다.

$MMD_{emp}=||\frac{1}{n}\sum_{i=1}^{n}\mathbf{h}(x_{ai})- \frac{1}{n}\sum_{i=1}^{n}\mathbf{h}(x_{bi})||_{2}^{2}$

$ \mathbf{h} $은 네트워크를 의미합니다. 즉 $ \mathbf{h} (x_{ai}), \mathbf{h} (x_{bi})$ 은 서로 다른 이미지를 하나의 네트워크에 태워 얻은 벡터를 의미합니다. 이 벡터의 차이를 계산하는 것이 MMD입니다. 

3. Classifier: classifier레이어에 통과시킵니다.

4. Predictor: 사후확률($\mathbf{p}(y|\mathbf{x}_{i})$)을 이용해서 2가지를 계산합니다. 하나는 분포의차이(KL divergence), 또 하나는 classification loss (cross-entropy loss)입니다.

  • KL divergence은 방향이 없기에 양방향으로 $KL(\mathbf{p}_{a}, \mathbf{p}_{b}), KL(\mathbf{p}_{b}, \mathbf{p}_{a})$을 둘 다 계산합니다. 
  • Classification loss은 두 이미지에 대해서 라벨을 이용해서 classification을 계산합니다.

KL diveregnce loss 2개, classification loss2개, MMD loss을 합산, 총 5개의 손실함수를 합산하여 모델을 학습시킵니다.

Results

첫 번째 결과는 CIFAR-100에서의 분류성능이고 동일한 파라미터수에서는 가장 낮은 테스트에러를 보였습니다.

두 번째, 결과는 모델의 크기에따라, self-distillation을 할때 어느정도 큰 차이가 나는지를 확인해봤습니다. 작은모델에서 distillation할때 성능이 크게 향상되는 것을 볼 수 있습니다. 즉, 이 방법론을 사용할거면 작은모델을 사용하는 접근이 유리해보입니다.

세 번째, 결과는 다양한 네트워크를 이용한 종합적인 성능인데, 모두 SD(Self-distilation)을 이용하는 경우 더 좋은 성능을 보였습니다.

네 번째 결과로, 데이터가 매우 큰 경우에서도 잘 동작하는지를 봤는데, 데이터가 매우 큰 경우에는 성능향상이 있었는데, 그렇게 큰 차이는 아닌 것 같내요.

Self-supervised learning vs self-distillation

두 방법론은 1)augmetentation이 둘 다 사용되고, 2) 하나의 네트워크를 학습한다는 것에 공통점이 있습니다.

하지만, SSL(Self-supervised learning)은 라벨이 필요없는 데이터로 학습한다는 것에 큰 차이가 있습니다.

SD(Self-distillation)은 총 5가지의 손실함수를 사용하는것중에 2개의 손실함수는 classification loss가 사용되니 라벨이 반드시 필요한 방법론이 됩니다.

 

References

[1] https://arxiv.org/abs/1503.02531

반응형