둔비의 공부공간

DOT: A Distillation-Oriented Trainer 본문

Papers/Compression

DOT: A Distillation-Oriented Trainer

Doonby 2023. 11. 23. 15:33

 

https://arxiv.org/abs/2307.08436

 

DOT: A Distillation-Oriented Trainer

Knowledge distillation transfers knowledge from a large model to a small one via task and distillation losses. In this paper, we observe a trade-off between task and distillation losses, i.e., introducing distillation loss limits the convergence of task lo

arxiv.org

 

Accepted by ICCV 2023

 

 

Decoupled Knowledge distillation이랑, Curriculum temperature for knowledge distillation의 저자와 동일하다.

 

 

한동안 종이로 논문을 읽었더니, 리뷰를 잘 안하게 됐다.

논문 세미나 발표자료를 만드는김에 오랜만에 리뷰를 작성하게 됐다. 

 

 

 

Motivation

 

Knowledge distillation은 잘 학습된 Large Teacher의 분포를 small student의 network가 따라가도록 한다.

 

vanilla kd의 loss는 $L = \alpha H(y,\sigma(p)) + (1 - \alpha) D_{KL} (\sigma(p/T) || (\sigma(q/T))$로 이뤄져있다.

 

이를 논문에서는 task loss + distillation loss로 말한다.

단순하게 생각했을때, Large Teacher는 student model에 비해 task loss도 낮고 distillation loss도 낮을 것이다.


그렇다면 student가 distillation loss를 추가하여 학습하면 자연스럽게 task loss도 낮아져야한다.

 

하지만, 논문에서는 distillation loss를 추가했더니, task loss만 갖고 학습하는 모델에 비해서 task loss가 수렴하지 못하는 trade off를 발견한다.

 

KD를 사용했는데, 오히려 task loss가 baseline보다 수렴하지 못하는 현상

 

 

 

이러한 문제를 논문에서는 multi-task문제로 task와 distillation을 적당하게 최적화하여, distillation이 최적화 되지 않았던 것은 아닐까? 하고 예상했다.

 

그렇다면, distillation을 최적화 시킬 수 있다면 task loss와 distillation loss를 둘 다 줄일 수 있지 않을까? 로 논문이 시작된다.

 

 

 

Methods

task loss와 distillation loss를 같이 최적화하되, distillation을 조금 더 확실하게 최적화 하기 위해, 논문에서는 모멘텀을 수정했다.

momentum을 키우면 local minima에 수렴하는 속도가 빨라진다는 논문들이 있다.

 

 

 

기존 모멘텀 공식

 

 

위가 기존 momentum 공식이다. 

 

 

논문에서 제안하는 모멘텀

 

 

이를 갖고, 논문에서는 위와 같이 task와 distillation의 momentum을 분리시키고, coefficient를 다르게 주었다.

Vanilla KD의 모멘텀과 DOT의 모멘텀 차이

 

 

이렇게 momentum의 coefficient를 바꾸면, distillation의 momentum은 커지고, 커진만큼 task의 momentum은 작아지게 된다.

그 결과, distillation은 task보다 조금 더 빠르게 수렴할 수 있게 된다.

 

 

Results

 

CIFAR100에서의 성능. SOTA logit KD라고 할 수 있는 DKD보다 높다.

 

ImageNet에 대해서도 DKD보다 높아지는 것을 볼 수 있다.

 

 

 

또한, distillation loss를 사용하면 student의 local minima surface가 flat해진다는 이야기가 있다.

DOT에서는 distillation를 더욱 최적화했으니, 더 flat해지는 것을 보였다.

flat해지는 모습 (일반화 성능의 향상)

 

 

또한, distillation을 최적화하면, 이미 task loss가 낮은 teacher의 분포를 따라가기에, task loss도 더 낮아질 수 있음을 보였다.

Task와 Distillation loss가 baseline, KD보다 낮아지는 모습

 

 

 

그렇다면, 오히려 task loss를 강화하거나, coefficient를 바꾸면 어떻게 되는지도 ablation study가 있다.

task momentum을 키우면 성능이 떨어졌으며, coefficient를 바꾸면 의미있는 성능향상이 없었다고 한다.

ablation study

 

 

Summary

  1. 기본적으로 distillation은 task loss가 작은 large teacher model의 분포를 따라가는데, 오히려 distillation loss를 사용하니, Task loss가 수렴하지 않더라

  2. 직관적으로 이해가 가지 않는데, Distillation 이 충분히 optimization 되지 않은 것은 아닐까?

  3. Task와 Distillation의 Gradient를 계산할 때, momentum을 다르게 주어, distillation을 더욱 빠르게 수렴시켰더니, 성능이 오르더라.
Comments