https://arxiv.org/pdf/2106.02034

 

 

최근에 Token Pruning에 대해서 볼일이 있어서 다시 보는 김에 리뷰를 남겨봅니다.

 

 

Abstract

Vision Transformer에서 가장 정보량이 많은 일부 토큰만 사용해서, 정확한 이미지 인식이 가능하다는 것을 발견했습니다.

이 관찰을 바탕으로, 중복 토큰들을 동적으로 제거하는 dynamic token sparsification 프레임워크를 제안합니다.

 

이를 위해, 현재 feature에서 각 token의 중요도를 측정하는 lightweight prediction module을 사용합니다.

이 모듈을 여러 layer에 추가하여 중복토큰을 반복적으로 제거합니다.

 

입력 token의 약 66%를 제거하여 31% ~ 37%정도의 FLOPs감소와 40% 이상의 처리량 향상을 달성하면서 정확도 하락은 0.5% 이내로 유지했습니다.

 

 

Introduction

 

일반적인 CNN의 pruning에서는 중요하지 않은 filter(channel)등을 프루닝하는 방법을 사용합니다.

하지만, ViT에서는 이미지의 패치 토큰으로 입력이 들어가기 때문에 덜 중요한 토큰을 제거하는 방식을 사용할 수 있습니다.

이러한 방식의 연산이 가능한 이유는 self-attention이 가변 길이의 token seq를 처리할 수 있기 때문입니다.

 

논문의 저자들은 특정 layer마다(예를 들어 4, 7, 10번째) 어떤 토큰을 제거할지 동적으로 결정하는 lightweight prediction module을 사용하여 점차 제거하는 토큰의 양을 늘려가는 hierarchical sparsification을 사용합니다.

 

모델과 모듈을 end-to-end로 학습하기 위해서 두가지 방법을 사용합니다.

1. Gumbel-softmax - softmax결과를 binary로 만들지만, gradient는 흘릴 수 있는 방법

2. attention masking - 실제로 token을 0으로 만들면 attention계산에 영향을 주기 때문에, 제거된 토큰은 attention 연산에 참여하지 못하도록 만듦. 

 

 

이러한 방법을 통해서, 31%~37%의 GFLOPs 감소, 40% 이상의 throughput 향상을 달성하면서도, 정확도 하락은 0.5% 이하를 달성했습니다.

 

 

Method

Hierarchical Token Sparsification with Prediction Modules

Token sparsification은 반복적으로 이루어지기 때문에, 토큰들은 binary decision mask ($\hat{D} \in {0, 1}^{N}$) 을 유지합니다.

(이때 N은 patch embedding의 개수)

맨 처음에는 모든 mask의 값을 1로 초기화합니다.

MLP에 토큰 $x$ ($NC$)를 넣고, ($NC'$)를 예측합니다. ($C' = C / 2$)

 

이렇게 계산된 $z^{local}$과 mask를 통합하여, 전체에서 살아있는 token만 모은 $z^{global}$을 만들어냅니다.

 

이때, Agg 함수는 살아있는 token들의 평균입니다.

mask D와 곱해서 살아있는 token feature만 다 더하고, 살아있는 개수로 나누는 평균 값

 

$z^{local}$은 특정 token자체의 정보, $z^{global}$은 전체 이미지의 정보를 담고 있어서 두개를 concat해서 사용합니다. concat한 local-global embedding 정보를 사용하여, 각 토큰을 버릴지 말지 결정하는 확률을 예측합니다.

i번째 token을 버리면 0, 살리면 1 (Nx2) 형태의 output

 

해당 파이 값들을 모아서 현재의 mask $D$ 를 만들고, 이전 mask $\hat{D}$와 곱해서 갱신합니다.

Mask 갱신

 

 

End-to-end Optimization with Attention Masking

 

크게 두가지 문제를 해결합니다.

 

1. 위에서 파이값으로 이진 mask를 만드는 과정이 미분이 안되는 문제가 있습니다. (softmax -> binary)

이를 해결하기 위해서, Gumbel-Softmax 기법을 사용하여 mask를 생성합니다.

 

 

2. 토큰을 단순하게 0으로 바꾸면, self-attention에서 score를 계산할때 영향을 주게 됩니다.

이유는 softmax 연산 때문인데, 0이 들어가도 softmax를 통과한 이후에는 0이 아니게 되기 때문입니다.

 

 

 

이를 해결하기 위해, softmax에서 0인 애들은 모두 제외하고 계산합니다.

이때 i == j인 self-loop는 1로 두었는데, 토큰이 pruned되었더라도 자기 자신은 볼 수 있게 두는 것이 안정성에 좋았다고 합니다.

 

Training and Inference

학습에서는 end-to-end 학습을 위해서 실제로 토큰을 줄이지는 않고, masking을 사용합니다.

추론때는 실제로 특정 비율만큼 토큰을 제거하여 속도를 향상시키는 방법을 사용합니다.

 

loss는 총 4가지를 사용하고, 첫번째로는 CrossEntropy를 사용합니다.

token pruning으로 인한 성능저하를 방지하기 위해 dense network를 teacher로 삼아서 distillation loss를 '두개' 사용합니다.

1. 맨 마지막 stage의 token끼리 distillation

맨 마지막 stage의 살아남은 token t끼리 distillation을 수행함

2. output logits에 대한 distillation (일반 KD)

token pruning sparsity를 유지하기 위한 regularization term을 추가합니다.

원하는 p만큼을 유지하기 위해서, mask와 차이를 regularizer로 설정함

 

 

Experiments

다른 방법들과 DynamicViT간의 정확도와 Params, GFLOPs비교, 일반적으로 비슷한 정확도 대비 30% 이상 FLOPs가 높다.

 

 

재미있는 결과인데, 실제 프루닝된 token으로 visualize했을때 신기하게 딱 물체를 구분할 수 있는 외적인 애들이 지워지는 것 같다.

 

각 stage별로 token이 살아있을 확률을 시각화한 것 같은데, 약간 Image의 물체는 항상 중앙에 있을 확률이 높다 << 는 정보가 반영된게 아닐까 싶다.

 

 

 

+ Recent posts