둔비의 공부공간
Continual Test-Time Domain Adaptation 본문
https://github.com/qinenergy/cotta
GitHub - qinenergy/cotta: [CVPR 2022] Official CoTTA Code for our paper Continual Test-Time Domain Adaptation
[CVPR 2022] Official CoTTA Code for our paper Continual Test-Time Domain Adaptation - GitHub - qinenergy/cotta: [CVPR 2022] Official CoTTA Code for our paper Continual Test-Time Domain Adaptation
github.com
Abstract
TTA (Test-time domain adaptation)은 source pretrained model을 target domain에 source data없이 adaptation시키는 것을 말한다.
TTA도 많은 연구가 있었는데, 기존의 연구들에서는 target domain이 고정되어 있다고 가정하고 연구를 했었다.
그러나, 실제 환경에서는 target domain은 계속 변하므로, self-training과 entropy regularization등의 기존 연구를 적용하기엔 어려움이 있었다. (계속 distribution이 변하면서, pseudo-label을 신뢰할 수 없기 때문이다.)
이런 변해버린 pseudo-label은 error accumulation과 catastrophic forgetting을 야기할 수 있다.
error accumulation은 정확도 하락을 말하고, catastrophic forgetting 배운걸 까먹는 현상을 말한다.
이러한 문제를 해결하기 위해, continual test-time adaptation approach (CoTTA)를 제안한다.
CoTTA의 two steps
- reduce error accumulation
- 더 정확해진 weight averaged prediction과 augmentation-averaged prediction 을 사용한다.
- catastrophic forgetting
- 확률적으로 neurons의 작은 부분을 source model로 보존하여, source knowledge를 보존할 수 있도록 설계했다.
Introduction
기존의 test-time adaptation 기존 연구에서는 pseudo-label이나 entropy regularization을 이용해서 model의 parameter를 업데이트해서, source domain <-> fixd target domain간의 distribution shift를 해결했었다.
근데, 기존연구는 test domain이 계속 바뀌면 불안정한 문제가 있었다.
- domain이 바뀔때 distribution shift로 인해 pseudo-labels가 noiser and miscalibrated 된다.
- model이 새로운 distribution에 오래 adaptation하면서, 기존 source domain의 정보를 보존하지 못하고, 이는 catastrophic forgetting으로 이어진다.
이러한 위의 문제를 해결하기 위해, 아래와 그림과 같은 모델을 제안했다.
또한 pseudo-label의 quality를 향상시키기 위해 두가지의 다른 방법을 사용했다.
- mean teacher prediction이 standard model보다 좋은 quality를 보인다는 점에서, weight average teacher model을 사용해서 조금 더 좋은 quality의 prediction을 넘겨준다.
- 만약, test data가 기존 data와 domain gap이 클 경우에는 augmentation-averaged prediction을 넘겨준다.
Preserve source knowledge and avoid forgetting을 방지하기 위해서는 다음과 같은 것들을 추가했다.
- 확률적으로 neurons의 작은 부분을 기존 source model로 돌려서 source domain knowledge를 잊지 않도록 했다.
기존 batchnorm만 학습하던 방법들과는 다르게 모든 layer를 학습하여 long-term adaptation이 가능하도록 했다.
Method
source data $(x^{s}, y^{s})$를 갖고 학습한 $\theta$ parameter의 model $f_{\theta_{0}}(x)$라고 하고
$t$ 번째 domain의 target data를 $(x_{t}^{T})$라고 하자.
time step $t$에서 target data $x_{t}^{T}$가 주어지면, model은 $f_{\theta_{t}}(x_{t}^{T})$에 대한 prediction을 만들어야하고
추후 $t+1$ 시점에서는 $t$의 domain에 대해서는 adaptation이 된 $\theta_{t+1}$이 되어야한다.
target data $x_{t}^{T}$가 들어오면 prediction $\hat{y}_{t}^{T} = f_{\theta_{t}}(x_{t}^{T})$와 pseudo-label간의 cross entropy를 minimize하며 training을 한다.
기존 TENT처럼 entropy를 minimization하는 방향으로 최적화하면, 변화하는 domain에서는 pseudo-label의 quality가 하락하면서 성능하락으로 이어질 수 있다.
weight-averaged model이 final model보다 성능이 좋다는 기존 연구들에서 영감을 받아서, weight-averaged teacher model $f_{\theta}$를 만들어서 pseudo-labels을 만들도록 설계했다.
- time $t = 0$일 때는 teacher network가 source pretrained network와 동일하다.
- time $t > 0$일 때는 $\hat{y'}_{t}^{T} = f_{\theta'_{t}}(x_{t}^{T})$ 로 pesudo-label을 만든다.
- 이 pseudo-label로 student network를 cross entropy로 학습시킨다.
위의 cross entropy로 학습한 student model의 파라미터가 $\theta_{t}$ -> $\theta_{t+1}$로 업데이트가 될때,
student model의 weight를 사용한 exponential moving average로 teacher model을 업데이트한다.
- $\theta_{t+1}' = \alpha \theta_{t}' + (1 - \alpha ) \theta_{t+1}$
- $\alpha$는 smoothing factor다.
weight-averaged consistency는 두가지 장점이 있다. (semi-supervised learning에서 사용하는 mean teacher method)
- 좀 더 정확한 weight-averaged prediction을 pseudo-label로 사용할 수 있기에, model은 continual adaptation에서 error accumulation을 겪을일이 줄어든다.
- mean teacher prediction $\hat{y'}_{t}^{T}$는 과거 iteration 정보 encoding 되어있다. 그러므로, long-term continual adaptation에서 catastrophic forgetting을 겪을일이 줄어들고, unseen domain에 대한 generalization capacity를 향상시킬 수 있다.
논문에서는 Augmentation-Averaged Pseudo-labels도 제안했다.
- test time의 domain shift를 고려해서, prediction confidence에 따라 domain의 차이를 측정한다.
- 이 domain의 차이가 클때 augmentation을 사용하여 error accumulation을 낮췄다.
source pretrained model에 현재 $t$의 target input $x_{t}^{T}$가 들어왔을때, 해당 confidence값이 threshold $p_{th}$ 보다 낮으면, 아까의 target input $x_{t}^{T}$에 augmentation을 한 input $aug_{i}(x_{t}^{T})$의 teacher model prediction을 평균내어 pseudo-label로 사용한다.
- domain difference가 작을때 random $i$ augmentation을 적용하면 모델의 성능이 하락하는 것을 확인했다.
- 무튼 이렇게 구한 pseudo-label로 student를 학습시킨다.
catastrophic forgetting을 해결하기 위해 stochastic restoration method를 제안했다.
- 확률적으로 source pretrained model로 되돌리는 작업이다.
- Mask $M$을 Bernoulli(p)로 생성한다.
- $p = 0.01$
- $W_{t+1} = M \bigodot W_{0} + (1-M) \bigodot W_{t+1}$
이러한 stochastic restoration method을 통해서 기존 source pretrained model을 잊지 않고, adaptation을 할 수 있다고 주장한다.
기존 TENT방식보다 error rate가 낮은 것을 볼 수 있다.
'Papers > Domain Adaptation' 카테고리의 다른 글
NOTE: Robust Continual Test-time Adaptation Against Temporal Correlation (0) | 2023.09.12 |
---|---|
TENT: FULLY TEST-TIME ADAPTATION BY ENTROPY MINIMIZATION (0) | 2023.09.06 |