테스트 이미지(좌측 Label 7)을 예측하는데 있어, 가장 안좋은 영향력(harmful)한 데이터를 influence function을 이용하여 찾은 결과


요약


흔히, 딥러닝모델을 예측치에 대한 설명이 되지않아 블랙박스 모델이라고 한다. Influnence function은 특정 테스트 데이터포인트를 예측하는데 도움을 주었던, 도움을 주지 않았던 훈련데이터를 정량적으로 측정할 수 있는 방법이다. 즉, 훈련데이터에서 특정데이터가 빠진 경우, 테스트데이터를 예측하는데 어느정도 영향이 있었는지를 평가한다. 이를 쉽게 측정할 수 있는 방법은 모델에서 특정 데이터를 하나 뺴고, 다시 훈련(Re-training)하면 된다. 하지만 딥러닝 학습에는 너무 많은 자원이 소모되므로, 훈련후에 이를 influence function으로 추정하는 방법을 저자들은 고안을 했다.



 

Influence function 계산을 위한 가정


훈련데이터에 대한 정의를 다음과 같이 한다. 예를 들어, 이미지와 같이 입력값을 $z$라고 한다.

  • 훈련데이터 포인트(개별 하나하나의 데이터)를 $z_{n}$라고 한다.
  • 이 훈련 데이터 포인트에 개별 관측치는 입력값 x와 라벨y로 이루어져 다음과 같이 표기한다 $z_{i}=(x_{i}, y_{i})$
  • 그리고, 모델의 학습된 파라미터를 $\theta$라고 한다.
  • 주어진 파라미터한에서, 데이터포인트 z를 주었을 때의 loss 값을 $L(z,\theta)$ 라고한다.
  • 훈련데이터 1부터 n개까지의 총합 로스를 emprical loss라고 정의한다 $1/n \sum_{i=1}^{n}L(z_{i},\theta)$
  • 학습과정을 통해 얻어진 파라미터가 위의 empirical loss을 가장 줄일 수 있게 학습된 파라미터를 $\hat{\theta}=argmin_{\theta\in\Theta}1/n\sum_{i=1}^{n}L(z_{i}\theta)$

 

Influence function 의 종류


Influence functiond은 training point가 모델의 결과에 어떤 영향을 주었는지 이해하기위해서 3가지 방법론을 고안했다.
1. 해당 훈련데이터포인트가 없으면, 모델의 결과가 어떻게 변화하냐? (=가중치가 어떻게 바뀌어서 결과까지 바뀌게 되나?)
= 이 방법론은 loss가 얼마나 변화하냐로도 확인할 수 있다.
연구에서는 이를 up-weighing 또는, up,params이라는 표현으로 쓴다. up-weighting이라는 표현을 쓰는 이유는 특정 훈련데이터가 매우작은 $\epsilon$ 만큼 가중치를 둬서 계산된다고하면, 이 $epsilon$을 0으로 고려해서 계산하기 때문이다. 또는 이를 loss에 적용하면 "up,loss"라는 표현으로 기술한다.
2. 해당 훈련데이터포인트가 변경되면(perturnbing), 모델의 결과가 어떻게 변경되나?

Influence function 의 계산(up,loss)


아래의 계산식에서 좌측($I_{up,loss}(z,z_{test})$)은 예측값 $z_{test}$을 예측하는데있어, 특정 훈련데이터포인트 $z$가 얼마나 영향을 주었나(=loss을 얼마나 변화시켰나)를 의미한다. 마지막줄의 각 계산에 들어가는 원소는 다음 의미를 갖는다.

  • $-\nabla_{\theta}L(z_{test},\hat{\theta})^{T}$: 특정 테스트 z포인트를 넣었을 때, 나오는 loss을 학습파라미터로 미분한 값이다. 즉
  • $H_{\hat{\theta}}^{-1}$: 학습된 파라미터의 hessian의 inversed matrix이다.
  • $\nabla_{\theta}L(z,\hat{\theta})^{T}$: 특정 훈련데이터포인트 z를 넣었을 때, 나오는 loss을 학습파라미터로 미분한 값이다.

위의 세 값을 다 계산하여 곱해주기만하면 특정 훈련데이터z가 테스트데이트를 예측하는데 로스를 얼마나 변화시켰는지 파악할 수 있다.


여기서 문제는 해시안 메트릭스($H_{\hat{\theta}}^{-1}$)를 계산하는 것이 매우 오래걸린다는 것이다. 따라서 저자들은 아래와 같이 추정방법으로 해시안메트릭스의 역행열을 구한다.

Influence function 계산의 어려움(비용)


Influence function 계산이 어려운 부분(계산비용이 큰 이유) 2부분 때문이다.

  1. 해시안 메트릭스의 역행렬을 구해야함: 해시안 메트릭스는 각 행렬에 파라미터의 개수가 p개라고 하면, $O(np^{2}+p^{3})$의 계산량이 든다는 것이다.
  2. 훈련 데이터 포인트에 대해서 모든 데이터포인트를 다 찾아야한다는 것. 예측에 도움이 안되는 데이터포인트를 찾으려면 훈련데이터가 n개면 n개의 데이터포인트를 찾아야한다.

다행이도 첫 번째 문제인 해시안 메트릭스이 역행렬을 그하는 것은 잘 연구가 되어있어서 해결할 수 있다. 명시적으로 해시안의 역행렬을 구하는것 대신에 해시안 역행렬과 손실값을 곱한 Heissan-vector product(HPV) 을  구하는 것이다



Inversed hessian matrix 구하기(HPV구하기, 중요)


조금더 구체적으로는 저자들은 해시안메트릭스의 역행열을 구하기보다는, 해시안메트릭스의 역행렬과 loss에 해당하는 vector($\nabla_{\theta}L(z_{test},\hat{\theta})^{T}$)을 한번에 구하는 방법을 소개한다. 이는 책 또는 github코드에서 주로 Inversed hessian vector product (inversed HVP)라고 불린다.


이를 구하는 2가지 방법이 있다.

  1. Conjugate gradient (CG)
  2. Stocahstic estimtation

저자들은 주로 Stocahstic estimation 방법으로 이를 해결하려고 한다. 

다음의 서술은 위의 HVP을 구하는 방법에 관한 것이다.

  1. 훈련 데이터 포인트 전체 n개에서 매번 t개만을 샘플링 한다.
  2. HVP의 초기값은 v을 이용한다 (이 때, v은 $\nabla_{\theta}L(z_{test}, \hat{\theta})^{T}$을 사용한다. 테스트세트에 대한 로스값의 gradient이다.)
  3. $\tilde{H_{j}}^{-1}v=v+(I-\nabla_{\theta}^{2}L(z_{s_{j}, \hat{\theta}})\tilde{H}_{j-1}^{-1}v$ 을 계속한다.
  4. 위의 과정(1-3)을 반복한다.

위와 과정을 충분히 반복하고, t을 충분히 크게 뽑는 경우 HVP값이 안정화되고 수렴된다고 한다. 이 방법이 CG방법보다 빠르다고 언급이 되어있다.

하이라이트 결과


Leave one out 방법과 유사하게 Influence function이 loss값의 변화를 추정한다.

개인적으로 아래의 Figure 2가 제일 중요해보인다. 훈련데이터를 하나 뻈을때(leave-one-out)의 방법과 Inlfuence function으로 예측한 예측값이 거의 정확하다면 선형을 이룰것이다. 아래의 그림은 이를 뒷받침하는 결과로 사용된다.

 

손실함수를 미분할 수 없는 경우도, 손실함수를 대략적으로 추정(Smooth approximation)  한 경우 추정할 수 있다.


아래의 결과는 Linear SVM을 이용하여 예측한 loss의 값과, 실제 retraining한 경우의 예측값을 보여주고 있다. Linear SVM 자체에서는 Hinge 값이 쓰이는데 (ReLU와 유사하게) 미분이 불가능하다. 따라서, 저자들은 0에서 미분불가능할 때 smooth하게 변경하기위해서 아래와 같이 smooth hinge라는 것을 만들어서 어느정도 추정한 값을 쓰고 있다($smooth ~hinge(s,t)=tlog(1+exp((1-s)/t)$). 즉 이 함수는 t가 0이되면 0이될수록 미분이 불가능한 hinge loss와 동일해진다. 이렇게 적용했을 때 결과가 (b)와 같다. 미분불가능한 함수를 미분 가능하게 조금 수정한다면, (b) 와 같이 t=0.001인 경우, t=0.1인 경우와 같이 어느정도 influence function이 데이터를 빼고 재학습한 것과 같이 추정이 가능하다는 것이다)

 

Influence function의 유즈케이스


아래는 influence function을 언제 쓰는지에 관한 것이다. 

1. 모델이 어떤 방식으로 행동하는지를 알 수 있다.: 아래의 그림 (Figure 4)은 SVM과 Inception(CNN기반)에서의 물고기 vs 개의 분류문제이다. 이 Figure 은 여러가지를 우리에게 알려준다.

  1. 우리의 직관은 "SVM은 딱히 이미지의 위상정보을 이용하지 않으니, 이미지의 거리(L2 dist)가 클수록 분류에 도움이 안될 것"으로 생각할 수 있다. Figure 4을보면 x축이 Euclidean dist인데, 400이상인경우 influcen function의 크기가 0에 가까워지는 것을 볼 수 있다.
  2. 녹색의 데이터포인트는 물고기의 학습데이터이다. 물고기를 예측하는데 있어, 물고기의 이미지만 도움을 주는 것으로 파악할 수 있다.
  3. 반대로, Inception 은 딱히 위상정보가 아니어도, 이미지의 외형(contour)을 잘 학습하기 때문에, 거의 모든 이미지가 분류에 영향을 주는 것을 알 수 있다. 즉 물고기의 예측하는데 있어서, 일부 강아지 이미지가 도움을 줄 수 있다는 것이다.

 

 

 

반응형

+ Recent posts