둔비의 공부공간
DOT: A Distillation-Oriented Trainer 본문
https://arxiv.org/abs/2307.08436
DOT: A Distillation-Oriented Trainer
Knowledge distillation transfers knowledge from a large model to a small one via task and distillation losses. In this paper, we observe a trade-off between task and distillation losses, i.e., introducing distillation loss limits the convergence of task lo
arxiv.org
Accepted by ICCV 2023
Decoupled Knowledge distillation이랑, Curriculum temperature for knowledge distillation의 저자와 동일하다.
한동안 종이로 논문을 읽었더니, 리뷰를 잘 안하게 됐다.
논문 세미나 발표자료를 만드는김에 오랜만에 리뷰를 작성하게 됐다.
Motivation
Knowledge distillation은 잘 학습된 Large Teacher의 분포를 small student의 network가 따라가도록 한다.
vanilla kd의 loss는 $L = \alpha H(y,\sigma(p)) + (1 - \alpha) D_{KL} (\sigma(p/T) || (\sigma(q/T))$로 이뤄져있다.
이를 논문에서는 task loss + distillation loss로 말한다.
단순하게 생각했을때, Large Teacher는 student model에 비해 task loss도 낮고 distillation loss도 낮을 것이다.
그렇다면 student가 distillation loss를 추가하여 학습하면 자연스럽게 task loss도 낮아져야한다.
하지만, 논문에서는 distillation loss를 추가했더니, task loss만 갖고 학습하는 모델에 비해서 task loss가 수렴하지 못하는 trade off를 발견한다.
이러한 문제를 논문에서는 multi-task문제로 task와 distillation을 적당하게 최적화하여, distillation이 최적화 되지 않았던 것은 아닐까? 하고 예상했다.
그렇다면, distillation을 최적화 시킬 수 있다면 task loss와 distillation loss를 둘 다 줄일 수 있지 않을까? 로 논문이 시작된다.
Methods
task loss와 distillation loss를 같이 최적화하되, distillation을 조금 더 확실하게 최적화 하기 위해, 논문에서는 모멘텀을 수정했다.
위가 기존 momentum 공식이다.
이를 갖고, 논문에서는 위와 같이 task와 distillation의 momentum을 분리시키고, coefficient를 다르게 주었다.
이렇게 momentum의 coefficient를 바꾸면, distillation의 momentum은 커지고, 커진만큼 task의 momentum은 작아지게 된다.
그 결과, distillation은 task보다 조금 더 빠르게 수렴할 수 있게 된다.
Results
또한, distillation loss를 사용하면 student의 local minima surface가 flat해진다는 이야기가 있다.
DOT에서는 distillation를 더욱 최적화했으니, 더 flat해지는 것을 보였다.
또한, distillation을 최적화하면, 이미 task loss가 낮은 teacher의 분포를 따라가기에, task loss도 더 낮아질 수 있음을 보였다.
그렇다면, 오히려 task loss를 강화하거나, coefficient를 바꾸면 어떻게 되는지도 ablation study가 있다.
task momentum을 키우면 성능이 떨어졌으며, coefficient를 바꾸면 의미있는 성능향상이 없었다고 한다.
Summary
- 기본적으로 distillation은 task loss가 작은 large teacher model의 분포를 따라가는데, 오히려 distillation loss를 사용하니, Task loss가 수렴하지 않더라
- 직관적으로 이해가 가지 않는데, Distillation 이 충분히 optimization 되지 않은 것은 아닐까?
- Task와 Distillation의 Gradient를 계산할 때, momentum을 다르게 주어, distillation을 더욱 빠르게 수렴시켰더니, 성능이 오르더라.
'Papers > Compression' 카테고리의 다른 글
Curriculum Temperature for Knowledge Distillation (0) | 2023.12.28 |
---|---|
Multi-level Logit Distillation (1) | 2023.12.08 |
Toward domain generalized pruning by scoring out-of-distribution importance (0) | 2023.09.26 |
Prune Your Model Before Distill It (0) | 2023.08.17 |
Triplet Knowledge Distillation (0) | 2023.07.26 |