둔비의 공부공간

Improving Ensemble Distillation With Weight Averaging and Diversifying Perturbation 본문

Papers/Compression

Improving Ensemble Distillation With Weight Averaging and Diversifying Perturbation

Doonby 2023. 3. 8. 19:33

(ICML 2022)

 

서론

옛날에 생각했던 diversify에 대한 논문이다.

 

teacher network가 학습하면서, train loss는 0에 수렴하고, student가 받을 teacher의 output distribution은 항상 비슷하다.

이를 학생과 선생모델 모두에게 diversify를 강조하여 학습하고 해결해보자.

 

앞서 말한 문제를 해결하는 참신한 ensemble distillation 방법을 제시한다.
1. 학생 네트워크를 설계하는 새로운 방법
- 학습할때는 subnetwork를 사용하지만, inference시에는 weight average를 통해 
단일 네트워크로 만든다.

2. 학생과 선생의 다양성을 모두 고려한 perturbation strategy 제시
- 학생의 weak points를 찾는다. (학생은 다 동일한 값일때, 선생들은 다른 값)인 그 input1. 

 

 

배경지식

 

1. avg연산은 diversity를 해친다.

multiple teacher, student 구조에서, teacher의 output을 avg하고, 학생의 output을 놓고 $KL$ 혹은 $CE loss$ 를 사용했다.

논문의 저자들은, output을 avg하는 과정에서 diversity가 없어진다고 이야기한다.

 

기존 mean대신, 학생모델 하나하나와 Loss를 주고, 이를 평균내는 방식을 사용했다.

 

 

2. ODS(Output Diversified Sampling) 혹은 ConfODS(Scaled ODS)는 diversity에 도움이 된다.

 

또한, 선생모델이 다 다른 의견을 갖도록 만드는 ODS(Output Diversified Sampling)를 사용했다.

w는 k-dim 가우시안 랜덤벡터로, teacher의 output prediction에 영향을 준다.

이를 $x + \epsilon$ 로 활용하여 $ \hat{x} $ 를 대신 사용한다.

 

 

 

 

Contribution

- New way of constructing a student network; we propose to distill with a student with multiple subnetworks during training, but average the sub- network weights later for inference to get a single student network as a result.

 

- The second contribution is a novel perturbation strategy improving upon the one proposed in Nam et al(ODS).

 

 

두번째 contribution은 다음과 같다.

student의 diversity는 작고, teacher의 diversity는 크도록하는 $\epsilon$ 값을 갖고 $\hat{x}$를 만들자

 

computational cost 절약을 위해서, $i, j$는 random sampling하여 stochastic approximation했다.

 

신기하게도, $KL Loss$ input중 하나를 stop gradient 하면 성능이 더 오른다고 한다.

 

 

 

 

Result

TDiv - SDiv의 성능이 높음을 알 수 있다.

 

Comments