둔비의 공부공간

Cumulative Spatial Knowledge Distillation for Vision Transformer 본문

Papers/Compression

Cumulative Spatial Knowledge Distillation for Vision Transformer

Doonby 2024. 4. 11. 19:43

https://arxiv.org/abs/2307.08500

 

Cumulative Spatial Knowledge Distillation for Vision Transformers

Distilling knowledge from convolutional neural networks (CNNs) is a double-edged sword for vision transformers (ViTs). It boosts the performance since the image-friendly local-inductive bias of CNN helps ViT learn faster and better, but leading to two prob

arxiv.org

 

DKD저자와 같은 저자로, 코드는 아직 올라오지 않은 것 같다.

 

 

Abstract

기존 CNN에서 ViT로 Knowledge distillation을 받는 것은 '양날의 검'이다.

Image에 친화적인 Convolution을 통해, inductive bias를 받는 것은 ViT의 초반 수렴속도를 향상시킬 수 있지만

아키텍쳐 차이와 모델의 capability의 차이에서 오는 문제점도 존재한다.

 

논문에서는 이러한 문제점을 Cumulative Spatial Knowledge Distillation(CSKD)을 제안하여 해결한다.

 

 

 

Methods

논문에서 주장하는 CNN -> ViT의 양날의 검은 다음과 같다.

 

장점

  • CNNs의 Inductive bias가 ViT에 학습 초반 수렴속도 향상에 도움이 될 수 있다.

단점

  • Network architecture가 다르다.
    • Receptive field
      • CNN은 Conv Block을 쌓으면서 점차 늘어나는 형태고, ViT는 self-attention으로 한번에 global정보를 볼 수 있다.
    • The way to stack blocks
      • CNN은 channel수가 점차 늘어나는 방향으로 쌓지만, ViT는 동일한 구조가 반복된다.
    • The type of normalization layers
      • CNN은 batch, group norm을 사용하지만, ViT는 Layer norm을 사용한다.
  • Network capability가 다르다.
    • CNN자체는 훌륭한 모델이 맞지만, ViT와 비교했을때는 capability가 약하다.

 

논문에서는 Cumulative Spatial Knowledge Distillation (CSKD)를 통해 위의 단점 두개를 극복하는 방법을 제안한다.

CSKD framework

 

 

논문에서는 기존 DeiT구조와 loss에 $L_{CSKD}$를 새로 정의한다.

DeiT 구조

 

 

$L_{CSKD}$는 어떻게 두 문제를 해결할 수 있을까?

Network architecture가 다르다는 문제는 어떻게 해결할까?

Network architecture가 다를 때 오는 제일 큰 문제는 'feature distillation'이 효과적이지 않다는 것이다.

두 architecture의 receptive field 등이 다르니까, feature자체가 다른 이미지를 보고 생성됐을 확률이 있다는 것이다.

 

이 논문의 저자들은 이를 해결하기 위해 hard target logit distillation $L_{CSKD}$를 정의한다.

$L_{CSKD} = CE(P_{patch}^{S}, Y^{T})$ 이다.

  • DeiT의 CE는 Hard Distillation라고 생각하면 된다.
    (CE의 target에 1.0이 아니라, teacher의 target에 대한 confidence가 들어간다.)
  • 저자들은 단순 feature matching이 아니라, spatial정보가 담긴 logits을 distillation하면서 기존 feature misalign 문제를 피했다.

CNN의 capability가 부족한 문제는 어떻게 해결할까?

DearKD(CVPR 2022)는 CNNs이 ViT의 초기 stage에서 수렴을 도와주지만, 오히려 나중에는 수렴을 방해한다고 한다.

기존 DearKD는 이를 해결하기 위해 stage를 둘로 나누어, 후반 stage에서는 KD를 주지 않는 방법을 선택했다.

 

하지만, 논문의 저자들은 CNN은 후반에도 여전히 효과적인 supervision을 줄 수 있을 것이라고 생각했다.

ViT는 학습후반에 들어서면, token과 다른 token사이의 global relation은 이미 다 학습했을 것이라고 이야기하며, 후반에는 local relation으로 supervision을 바꾸어 학습시킨다.

 

학습의 초반과 후반에 따라서 local과 global을 구분하기 위해, $\alpha$를 두어, epoch에 따라서 변화하도록 설계했다.

 

$Y^{T} = argmax_{c}(\alpha P_{local}^{T} + (1-\alpha)*P_{global}^{T})$

$\alpha = 1 - epoch/epoch_{max}$

local과 global logits은 어떻게 만들까?

$P_{global}$ 은 우리가 흔히 아는 last feature map -> pooling -> fc의 결과물이다.

  • (BS, C)의 output dims을 갖는다.

$P_{local}$은 last feature map -> fc인데, (B, N, H, W)를 (B * H* W, N)로 만들어서 fc를 통과시킨다.

  • (BS, H, W, C)의 output dims를 갖는다.

이제 CNNs에서는 $P_{global}$와 $P_{local}$을 만드는 방법을 알았는데, ViT에서는 어떻게 만들까?

일단, ViT에서는 distillation token을 사용하지만, CNN의 $P_{local}$을 distillation받기 위해 ViT도 $P_{local}$을 정의한다.

(BS, N, C)의 경우 N이 196이면 14x14과 같다. (BS, 14, 14, C) 이를 pooling등을 통해서 CNN의 feature shape와 맞춰준다.

 

 

 

Experiments

 

DeiT의 실험세팅을 그대로 가져왔다고 한다. ImageNet부터 downstream task에서 SOTA 성능을 보였다.

 

 

왼쪽을 보면, CKF를 넣었을때 확실히 후반의 val accuracy가 향상되는 것을 볼 수 있으며, 오른쪽에서는 local, global, CKF모두 사용하는 것이 좋았다고 한다.

 

 

 

Conclusion

CNN과 ViT의 architecture로 인한 mismatching 문제를 spatial 정보가 담긴 logits으로 해결했다.

동시에, CNN과 ViT의 capability로 인한 mismatching 문제를 CKF라는 epoch에 따라서 global / local로 변환해주는 term을 추가하여 해결함을 보였다.

다양한 데이터셋에 대해서 SOTA를 기록했다.

 

 

사견

최근 Feature alignment문제를 logits으로 풀어가는 논문들이 많은데, 이 논문도 비슷하다.

그렇다보니, logits 자체를 활용하는 것은 참신하다는 생각은 못들었지만, logits에 spatial정보를 담는것은 의미있어 보인다.

 

Comments