Data-Distortion Guided Self-Distillation for Deep Neural Networks , AAAI 2019
배경: Knowledge distillation
큰 모델의 model confidence (softmax output)은 이 큰 모델이 학습한 지식을 담고있지 않을까라는 가설로 시작됩니다. 큰 모델의 결과물(증류된 정보)을 이용해서 작은 모델을 간단히 학습할 수 있지 않을까라는 생각입니다 [1]. 즉 큰 모델의지식(knowledge)이 model confidences 분포이겠구나라는 것입니다.
아래의 그림과 같이 Teacher model이 훨씬 큰 딥러닝모델이고, Student모델이 더 작은 모델입니다. 큰 모델의 예측결과(model confidences)자체가 큰 모델의 지식이라는 것입니다. 예를 들어, 고양이가 6번인덱스의 클레스였다면, 0번은 낮고, 1번도 낮고..., 6번이 제일 큰데 0.4정도라는 이 분포자체가 Teacher모델이 갖고있는 큰 파라미터의 지식이 됩니다. 이를 답습하자는것이 Student모델의 학습목표입니다.
Knowledge distillation 방법론
흔히 지식증류(knowledge distillation)방법들은 학습방법에 따라, 3가지로 나뉩니다.
- teacher-to-student: 큰 모델의 지식(logits, feature map,attention map등)을 작은 모델이 더 잘학습시키기 위한 방법론. 서로 다른 모델이 2개가 있어야하며, Teacher model은 학습되어있다고 가정합니다.
- student-to-student: 서로 다른 작은 다른 모델(student)을 서로 다른 초기값을 가지게하고, 학습을 동시에 해가면서 지식의 차이를 점점 줄여가나는 방식입니다. teacher-to-student와는 다르게 이미 사전학습된 모델이 필요없는 On-line training 방법입니다.
- self-distillation (self-teaching): 하나의 네트워크를 이용하며, 이미지 1개를 데이터 증강(random, rotation)등을 통해서 2개로 만들고, 서로 다른 지식을 일치시키는 방법을 사용합니다. 즉, 하나의 네트워크를 사용하며, 서로 다른이미지를 추론했음에도 분포가 같아야하는 모델 일반화가 더 잘되는 방법으로 학습을 유도합니다.
방법: Self-distillation
Self-distillation은 크게 4단계로 이어집니다.
1. Data distortion operation: data augmentation 방법과도 같습니다. 같은 원본데이터로, 서로 다른 두 이미지를 생성합니다.
2. Feature extraction layers: 하나의 모델에 두 증강된 이미지를 통과시켜 특징값(feature)을 얻습니다. 여기서는 두 특징값(벡터)의 차이를 계산합니다.
MMD은 Maximum Mean Discrepancy을 의미하며, 두 증강된 이미지를 특징으로 만든 벡터의 차이를 의미합니다. 수식으로는 아래와 같습니다.
$MMD_{emp}=||\frac{1}{n}\sum_{i=1}^{n}\mathbf{h}(x_{ai})- \frac{1}{n}\sum_{i=1}^{n}\mathbf{h}(x_{bi})||_{2}^{2}$
$ \mathbf{h} $은 네트워크를 의미합니다. 즉 $ \mathbf{h} (x_{ai}), \mathbf{h} (x_{bi})$ 은 서로 다른 이미지를 하나의 네트워크에 태워 얻은 벡터를 의미합니다. 이 벡터의 차이를 계산하는 것이 MMD입니다.
3. Classifier: classifier레이어에 통과시킵니다.
4. Predictor: 사후확률($\mathbf{p}(y|\mathbf{x}_{i})$)을 이용해서 2가지를 계산합니다. 하나는 분포의차이(KL divergence), 또 하나는 classification loss (cross-entropy loss)입니다.
- KL divergence은 방향이 없기에 양방향으로 $KL(\mathbf{p}_{a}, \mathbf{p}_{b}), KL(\mathbf{p}_{b}, \mathbf{p}_{a})$을 둘 다 계산합니다.
- Classification loss은 두 이미지에 대해서 라벨을 이용해서 classification을 계산합니다.
KL diveregnce loss 2개, classification loss2개, MMD loss을 합산, 총 5개의 손실함수를 합산하여 모델을 학습시킵니다.
Results
첫 번째 결과는 CIFAR-100에서의 분류성능이고 동일한 파라미터수에서는 가장 낮은 테스트에러를 보였습니다.
두 번째, 결과는 모델의 크기에따라, self-distillation을 할때 어느정도 큰 차이가 나는지를 확인해봤습니다. 작은모델에서 distillation할때 성능이 크게 향상되는 것을 볼 수 있습니다. 즉, 이 방법론을 사용할거면 작은모델을 사용하는 접근이 유리해보입니다.
세 번째, 결과는 다양한 네트워크를 이용한 종합적인 성능인데, 모두 SD(Self-distilation)을 이용하는 경우 더 좋은 성능을 보였습니다.
네 번째 결과로, 데이터가 매우 큰 경우에서도 잘 동작하는지를 봤는데, 데이터가 매우 큰 경우에는 성능향상이 있었는데, 그렇게 큰 차이는 아닌 것 같내요.
Self-supervised learning vs self-distillation
두 방법론은 1)augmetentation이 둘 다 사용되고, 2) 하나의 네트워크를 학습한다는 것에 공통점이 있습니다.
하지만, SSL(Self-supervised learning)은 라벨이 필요없는 데이터로 학습한다는 것에 큰 차이가 있습니다.
SD(Self-distillation)은 총 5가지의 손실함수를 사용하는것중에 2개의 손실함수는 classification loss가 사용되니 라벨이 반드시 필요한 방법론이 됩니다.
References
'Data science > Deep learning' 카테고리의 다른 글
Pytorch DDP(Distributed Data Parallel), DP(Data Parallel) 비교 총정리 (0) | 2024.05.03 |
---|---|
[5분 컷 리뷰] Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics (0) | 2024.04.22 |
분류문제: Cross entropy 대신에 MSE을 쓸 수 있나? + 역전파 및 MLE해석 (0) | 2022.11.23 |
[5분 컷 이해] Multiple Instance learning 이란? (0) | 2022.11.08 |
[5분 컷 이해] Probability calibration 설명 (0) | 2022.09.14 |