https://arxiv.org/pdf/2106.02034
최근에 Token Pruning에 대해서 볼일이 있어서 다시 보는 김에 리뷰를 남겨봅니다.
Abstract
Vision Transformer에서 가장 정보량이 많은 일부 토큰만 사용해서, 정확한 이미지 인식이 가능하다는 것을 발견했습니다.
이 관찰을 바탕으로, 중복 토큰들을 동적으로 제거하는 dynamic token sparsification 프레임워크를 제안합니다.
이를 위해, 현재 feature에서 각 token의 중요도를 측정하는 lightweight prediction module을 사용합니다.
이 모듈을 여러 layer에 추가하여 중복토큰을 반복적으로 제거합니다.
입력 token의 약 66%를 제거하여 31% ~ 37%정도의 FLOPs감소와 40% 이상의 처리량 향상을 달성하면서 정확도 하락은 0.5% 이내로 유지했습니다.
Introduction

일반적인 CNN의 pruning에서는 중요하지 않은 filter(channel)등을 프루닝하는 방법을 사용합니다.
하지만, ViT에서는 이미지의 패치 토큰으로 입력이 들어가기 때문에 덜 중요한 토큰을 제거하는 방식을 사용할 수 있습니다.
이러한 방식의 연산이 가능한 이유는 self-attention이 가변 길이의 token seq를 처리할 수 있기 때문입니다.
논문의 저자들은 특정 layer마다(예를 들어 4, 7, 10번째) 어떤 토큰을 제거할지 동적으로 결정하는 lightweight prediction module을 사용하여 점차 제거하는 토큰의 양을 늘려가는 hierarchical sparsification을 사용합니다.
모델과 모듈을 end-to-end로 학습하기 위해서 두가지 방법을 사용합니다.
1. Gumbel-softmax - softmax결과를 binary로 만들지만, gradient는 흘릴 수 있는 방법
2. attention masking - 실제로 token을 0으로 만들면 attention계산에 영향을 주기 때문에, 제거된 토큰은 attention 연산에 참여하지 못하도록 만듦.
이러한 방법을 통해서, 31%~37%의 GFLOPs 감소, 40% 이상의 throughput 향상을 달성하면서도, 정확도 하락은 0.5% 이하를 달성했습니다.
Method
Hierarchical Token Sparsification with Prediction Modules
Token sparsification은 반복적으로 이루어지기 때문에, 토큰들은 binary decision mask ($\hat{D} \in {0, 1}^{N}$) 을 유지합니다.
(이때 N은 patch embedding의 개수)
맨 처음에는 모든 mask의 값을 1로 초기화합니다.

MLP에 토큰 $x$ ($NC$)를 넣고, ($NC'$)를 예측합니다. ($C' = C / 2$)
이렇게 계산된 $z^{local}$과 mask를 통합하여, 전체에서 살아있는 token만 모은 $z^{global}$을 만들어냅니다.

이때, Agg 함수는 살아있는 token들의 평균입니다.

$z^{local}$은 특정 token자체의 정보, $z^{global}$은 전체 이미지의 정보를 담고 있어서 두개를 concat해서 사용합니다. concat한 local-global embedding 정보를 사용하여, 각 토큰을 버릴지 말지 결정하는 확률을 예측합니다.

해당 파이 값들을 모아서 현재의 mask $D$ 를 만들고, 이전 mask $\hat{D}$와 곱해서 갱신합니다.

End-to-end Optimization with Attention Masking
크게 두가지 문제를 해결합니다.
1. 위에서 파이값으로 이진 mask를 만드는 과정이 미분이 안되는 문제가 있습니다. (softmax -> binary)
이를 해결하기 위해서, Gumbel-Softmax 기법을 사용하여 mask를 생성합니다.
2. 토큰을 단순하게 0으로 바꾸면, self-attention에서 score를 계산할때 영향을 주게 됩니다.
이유는 softmax 연산 때문인데, 0이 들어가도 softmax를 통과한 이후에는 0이 아니게 되기 때문입니다.

이를 해결하기 위해, softmax에서 0인 애들은 모두 제외하고 계산합니다.
이때 i == j인 self-loop는 1로 두었는데, 토큰이 pruned되었더라도 자기 자신은 볼 수 있게 두는 것이 안정성에 좋았다고 합니다.
Training and Inference
학습에서는 end-to-end 학습을 위해서 실제로 토큰을 줄이지는 않고, masking을 사용합니다.
추론때는 실제로 특정 비율만큼 토큰을 제거하여 속도를 향상시키는 방법을 사용합니다.
loss는 총 4가지를 사용하고, 첫번째로는 CrossEntropy를 사용합니다.
token pruning으로 인한 성능저하를 방지하기 위해 dense network를 teacher로 삼아서 distillation loss를 '두개' 사용합니다.
1. 맨 마지막 stage의 token끼리 distillation

2. output logits에 대한 distillation (일반 KD)
token pruning sparsity를 유지하기 위한 regularization term을 추가합니다.

Experiments



'Papers > Compression' 카테고리의 다른 글
| TRAINING-FREE ACTIVATION SPARSITY IN LARGE LANGUAGE MODELS (0) | 2026.03.17 |
|---|---|
| TOKEN MERGING: YOUR VIT BUT FASTER (0) | 2026.03.17 |
| Sparse Structure Exploration and Re-optimization for Vision Transformer (0) | 2026.02.19 |
| ParetoQ: Improving Scaling Laws in Extremely Low-bit LLM Quantization (0) | 2026.01.02 |
| BitNet: Scaling 1-bit Transformers for Large Language Models (0) | 2025.11.21 |