Learned step size quantization
블로그에 자세히 정리한다고 생각하고 글을 쓰다 보니, 논문 리뷰가 아니라 번역에 가깝게 옮겨놓게 되면서
[작성중] 카테고리에 비공개로 쓰다가 만 논문들이 너무 많아졌다.
그냥 한번 읽고, 기억나는 내용을 중점으로 정리하는 방법으로 바꿔야겠다.
https://arxiv.org/abs/1902.08153
Learned Step Size Quantization
Deep networks run with low precision operations at inference time offer power and space advantages over high precision alternatives, but need to overcome the challenge of maintaining high accuracy as precision decreases. Here, we present a method for train
arxiv.org
이 논문은 QAT에서 step size $s$를 학습하는 논문이다.
공식적인 github repo는 없는 것 같다.
주의 깊게 볼만한 내용은 step size를 학습하면서 STE를 제안했다는 점이다.
논문의 버전은 v1, v2, v3로 이뤄진 것 같은데 나는 v3 버전을 본 것 같다.
위 그림을 보면 쉽게 이해할 수 있다.
weight용 step size $S_{w}$와 input(activation)용 step size $S_{x}$가 있고, 이를 통해서 아래 식으로 Quantization하여 $\bar{w}$와 $\bar{x}$를 만든다.
unsigned data (activation) 기준으로, $Q_{N} = 0$, $Q_{P} = 2^{b}-1$이다.
signed data (weight) 기준으로, $Q_{N} = 2^{b-1}$, $Q_{P} = 2^{b-1}-1$이다.
즉, 8bit라면 activation $Q_{N} = 0$, $Q_{P} = 255$, weight $Q_{N} = 128$, $Q_{P} = 127$
저렇게 quantization하고 다시 step size를 곱하면 $\hat{v}$를 만들 수 있다.
이제, 이렇게 나온 $\hat{v}$로 loss를 계산하여 $s$에 대해서 미분하려고 하는데, 위 $\bar{v}$를 만들때 사용했던 "round"함수가 문제가 된다. round함수는 미분이 불가능하기 때문이다.
이를 위해서 저자들은 다음과 같은 STE를 제안한다.
이 내용을 해석하면, 저 clip 범위안에 있어서 gradient가 흘러야할때는 -v/s + round(v/s)를 흘리겠다고 하는 것이다.
$\bar{V} = \left[ \text{clip}(\frac{v}{s}, -Q_N, Q_P) \right]$
$\hat{V} = \bar{V} \times S$
을 기억하라!
$\frac{\partial \hat{V}}{\partial S} = \frac{\partial (\bar{V} \times S)}{\partial S} = S \frac{\partial \bar{V}}{\partial S} + \bar{V} \frac{\partial S}{\partial S} = S \frac{\partial \bar{V}}{\partial S} + \bar{V} = S \frac{\partial \left[ \frac{V}{S} \right]}{\partial S} + \bar{V}$
여기서 $\left[ \frac{V}{S} \right ]$의 round의 gradient는 그냥 1로 근사함. (avg했더니 1로 쓸 수 있겠더라~)
$S\frac{\partial \frac{V}{S} }{\partial S} + \bar{V} = - S\frac{V}{S^2} + \left[ \frac{V}{S} \right ]$
$ = - \frac{V}{S} + \left[ \frac{V}{S} \right ]$
이런식으로 STE를 구성했더니, 기존 두 방법과 다르게 quantization transition point (quantization 했을때 값이 바뀌는 부분)에서 멀어질수록 gradient가 증가하는 컨셉과 맞게 잘 작동했다고 함.
읽으면서 들었던 생각이 그냥 $v$ -> $\bar{v}$ -> $\hat{v}$로 loss를 계산하면서 이 quantization error를 줄이는 방향으로 최적화하는거면, full precision $v$와 $\bar{v}$의 값을 줄이도록 KL Divergence로 학습하는 것과 동일한 컨셉인가? 했는데, 3.6 Quantization Error section 실험에서, L1, L2, KL로 최적화하도록 실험해봤는데, LSQ에서 학습한 s와 차이가 있었다고 한다. 즉, LSQ는 단순하게 quantization error를 최적화 하는 것과는 다르다고 한다.