https://github.com/FasterDecoding/TEAL

 

GitHub - FasterDecoding/TEAL

Contribute to FasterDecoding/TEAL development by creating an account on GitHub.

github.com

 

ICLR2025 spotlight 논문입니다.

 

 

Abstract

Activation sparsity는 forward에서 행렬 곱에 필요한 연산량과 메모리 이동량을 줄여서 추론속도를 높일 수 있습니다.

하지만, 기존 연구들은 ReLU기반의 모델에서만 동작되게 설계되거나, 추가 학습이 필요한 한계가 있었습니다.

 

이 논문에서는 TEAL(Training-Free Activatioon Sparsity in LLMs)라는 학습이 필요없는 간단한 방법을 제안합니다.

모델 전체의 hidden state에 magnitude-based activation sparsity를 적용합니다.

 

7B ~ 70B까지의 Llama2, Llama3, Mistral 계열 모델에서 성능 저하를 최소화하며 40~50%의 sparsity를 달성했습니다.

sparsity가 40%일때 최대 1.53배, 50%일때 최대 1.8배의 decoding 속도 향상을 보였습니다.

 

또한 weight quantization과도 호환이 가능하기 때문에, 더욱 효율성을 향상시킬 수 있습니다.

 

Overview of TEA, low-magnitude activation에 해당하는 weight를 불러오지 않음

 

 

Introduction

LLM의 병목에는 크게 두가지가 있습니다.

  • small-batch 환경에서, autoregressive inference일때 memory-bound
    • 작은 batch 추론에서, 모델 파라미터를 읽어오는 시간이 연산시간보다 오래걸리게 됩니다.
    • xW를 연산하는 것 보다, 모델 파라미터를 읽어오는 시간이 더 큰 병목이 되는 상황입니다.
  • prefill/학습에서, compute-bound
    • 처음 prompt로 KV Cache를 만들거나, 학습할 때는 모델 파라미터를 읽어오면 계속 사용하게 됩니다.
    • xW를 연산하는 것이 더 큰 병목이 되는 상황입니다. 

 

Activation이 0인 weight channel을 생략하고 읽어오면 더욱 빠르게 접근할 수 있는데, 이게 activation sparsity입니다.

 

과거 ReLU 기반의 transformer에서는 MLP 중간의 state가 95%의 sparsity를 갖고 있어서, 적용이 쉬웠습니다. 

하지만, 현대의 transformer는 ReLU가 아닌 GLU나 SwiGLU를 사용하기 때문에, activation이 sparsity를 갖지 않습니다.

또 다른 연구에서는 activation에 sparsity를 부여하기 위해, activation function을 수정하고 추가 학습을 필요로 했습니다.

 

이러한 문제점들을 해결하기 위해, 논문의 저자들은 TEAL(Training-Free Activation Sparsity in LLMs)를 제안합니다.

LLaMA계열의 activation이 평균이 0인 분포를 따른다는 관찰을 바탕으로, 작은 크기의 activation을 제거하여 모델 전체에서 40~50%의 sparsity를 달성합니다.

 

특수 커널을 통해서 40%에서 1.53, 50%에서 1.8배의 속도 향상을 보였으며, weight quantization도 호환되는 것을 보였습니다.

 

 

Method

Motivating Study: Distributional Properties of activations in LLMs

 

위 그림의 LLM activation 분포를 분석했을 때, 일반적으로 평균이 0에 가깝고 unimodal(한개의 봉우리) 형태를 보였습니다.

또한, Attention block과 MLP block의 앞쪽은 Gaussian(1,3)에 가깝고, 중간은 Laplacian(2,4)에 가까운 형태입니다.

 

저자들은 이렇게 0 근처에 값이 몰려있는 성질magnitude 기반의 activation pruning이 가능하도록 만들 수 있다고 이야기합니다.

 

 

TEAL

 

위 Motivating Study 분석을 기반으로, magnitude-based activation pruning이라는 단순한 접근을 제안합니다.

 

우선, magnitude pruning에는 threshold $t_{p}$를 정해야 합니다.

간단하게 이야기하면, $x$ 절대값의 크기가 threshold 이하인 비율이 sparsity level $p$가 되도록 하는 값입니다

sparsification 함수는 threshold보다 작으면 0으로 만들어주는 함수입니다. (프루닝 함수 생각하시면 돼요)

Threshold와 sparsificatin function 정의

 

프루닝하려고 하는 행렬들의 개수 $N$만큼 sparsity level을 갖습니다. ( $p = (p_{1}, ... , p_{N})$ )

 

입력 $x$를 sparsification 함수로 sparse하게 만든 후에, 연산을 수행합니다. 이를 $N$개 프루닝 대상에 모두 적용합니다.

 

 

Block-Wise Greedy Optimization

 

이제 최적의 $p_{i}$를 찾는 문제가 남았습니다. 저자들은 threshold를 gradient기반으로 학습하려고 했지만, optimization 문제가 있어서, 단순하게 greedy한 방법을 사용했다고 합니다.

 

방법은 각 출력의 $l_{2}$ activation error가 최소화되는 sparsity를 찾는 것입니다.

위 행렬들에 대해서, sparsity를 0으로 초기화 시키고, sparsity를 조금씩 증가시킵니다.

메모리가 큰 행렬은 작은 폭으로 올리고, 메모리가 작은 행렬은 큰 폭으로 올리면서 오차를 측정하고, 제일 오차가 적은 layer만 sparsity를 증가시킵니다. (단순하게 sparsity 비율로 접근하면, 아무래도 파라미터가 많은 행렬이 영향을 더 크게 받을테니까요)

 

이런 greedy optimization을 통해서, 모든 transformer block에 같은 block-level sparsity 목표를 부여하지만, 각 block 내부의 저 일곱 행렬에 분배되는 sparsity는 달라질 수 있습니다. 

 

이 과정은 Llama-3-8B기준 A100 한장에서 1GPU-hour정도가 소요된다고 합니다.

 

 

Hardware-aware Acceleration

 

 

내용을 잘 이해 못했는데, 대충 triton based sparse GEMV 커널에 여러 개선을 더해서 가속을 걸었다는 의미같습니다.

필요하신 분은 참고하시면 될 것 같습니다.

 

 

Experiments

모든 실험내용을 다루지 않았으니, 필요하신분은 논문을 참고해주세요.

 

25%에서는 성능하락이 거의 없고, 40%에서도 잘 방어가 된다. 모델이 클수록 더 잘된다.

 

비교대상은 CATS인데, 얘는 attention에는 적용하지 않고, MLP에만 적용하는 방법입니다. 둘 다 fine-tuning 없는 training-free 설정입니다.

 

TEAL이 CATS보다 잘되는 이유는 CATS랑 다르게, TEAL은 모든 행렬을 sparsify하기 때문에 특정 부분에 과하게 높은 sparsity가 적용되지 않아서라고 이야기합니다.

 

 

 

End-To-End Decoding Speed-Up

TEAL은 40%에서 최대 1.53배, 50%에서 최대 1.8배의 속도 향상을 보였습니다.

Llama3-8B가 Llama2-7B보다 속도향상이 작은데, 이는 LM head를 sparsify하지 않기 때문이라고 합니다.

 

 

 

Compatibility with Quantization

8-bit channelwise RTN, 4-bit AWQ, 2/3-bit QuIP#를 적용했을때, perplexity가 급격하게 악화되는 지점이 quantization bit와 상관없이 비슷하게 나왔습니다.

 

이는, activation sparsity와 weight quantization의 오차가 어느정도는 독립적으로 누적되는 것이라고 주장하며, 함께 사용할 수 있다고 이야기합니다.다만, 제대로 활용하려면 sparse + quantized 전용 커널이 필요할 것이라고도 이야기합니다.

 

 

 

Batched Sparsification

 

 

논문에서는 single-batch를 가정하고 설계되어있지만, 실제 배치로 inference가 이루어질때도 분석합니다.

가장 큰 문제는 입력마다 선호하는 sparsity pattern이 다를 수 있다는 것 입니다. (행렬 $W$마다 적절한 sparsity level $p$를 찾았다고 해도, 실제 입력에서 어떤 패턴으로 0이 될지는 다 다르게 되고, 이러면 weight column를 읽어올때 이점을 확보하기가 어렵습니다.)

 

이를 해결하기 위해서, 배치차원에서 activation magnitude의 평균을 기준으로 sparsify합니다. (마치 single-batch inference 처럼요)

 

위 Figure 8을 보면,  batch 1보다 약간 큰 (2, 4, 8) 정도에서는 어느정도의 효과를 볼 수 있었다고 합니다. 하지만, batch가 커질수록 activation-aware structured channel-wise pruning처럼 동작하게 되고, activation-dependent한 이점이 사라지게 됩니다.

 

이렇게 큰 batch가 지원이 안되는 것이 이 논문의 한계점이기도 합니다. 이는 나중에 후속연구에서 학습하면서 최적화할 수 있도록 바꿔야할 것이라고 이야기합니다. 

 

 

 

Summary

  • 학습없이 activation sparsity를 적용해서, 속도 향상을 할 수 있는 방법을 제안했다.
  • Layer별 sparsity를 최적화하고, spase kernel을 개선했으며, quantization과 같이 사용할 수 있음을 보였다.

 

https://arxiv.org/pdf/2210.09461

https://doonby.tistory.com/79 

 

DynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification

https://arxiv.org/pdf/2106.02034 최근에 Token Pruning에 대해서 볼일이 있어서 다시 보는 김에 리뷰를 남겨봅니다. AbstractVision Transformer에서 가장 정보량이 많은 일부 토큰만 사용해서, 정확한 이미지 인식

doonby.tistory.com

- 논문을 읽기전에, 간단하게 보고 오면 좋을 token pruning 논문리뷰

Abstract

학습없이 token pruning만큼 빠르면서도, 높은 정확도를 유지할 수 있는 token merging (ToMe)를 소개합니다.

이미지에서는 throughput을 약 2배, 비디오에서는 약 2.2배까지 올릴 수 있었으며, 정확도 하락은 0.2~0.3%를 유지했습니다.

심지어 ToMe는 inference뿐만 아니라, 학습에도 적용할 수 있습니다.

다양한 실험을 통해서, 이미지, 비디오, 오디오 분야에서 높은 정확도와 속도를 달성할 수 있음을 보였습니다.

 

 

Introduction

이전에는 token pruning으로 입력 토큰수를 줄여서 속도를 향상시키는 방법들이 연구됐습니다.

하지만 학습중에는 zero-masking(실제로 토큰을 줄이지 않고, 1과 0으로 만드는)을 사용했기 때문에, 학습에는 속도향상을 볼 수 없었습니다. 또한, 토큰 정보를 지우는 것이기 때문에 정보 손실이 크게 발생하여 줄일 수 있는 토큰수에 한계가 존재합니다.

 

이런 큰 문제점들을 개선하기 위해, 논문의 저자들은 토큰을 제거하지 않고 합치는 Token Merging (ToMe)를 제안합니다.

ToMe는 학습과 추론 모두에서 사용할 수 있으며, 별도의 학습없이도 적용할 수 있다는 장점이 있습니다.

 

논문의 큰 contribution은 다음과 같습니다.

  • 학습과 상관없이 ViT의 throughput과 실제 학습속도를 향상시키는 방법을 제안.
  • 이미지뿐만 아니라, 동영상, 오디오에도 적용할 수 있음을 보임.
  • 시각화를 통해서, 이미지와 비디오 등에서 병합 기준을 보임.

 

 

 

Token Merging

Token pruning처럼 Token merging도 모듈을 추가합니다.

 

Token Merging - Strategy

 

토큰을 줄이는 기준은 매번 $r$ 개의 토큰이 감소하도록 설계했습니다.

예를 들어, 총 $L$개의 모듈에서 병합을 수행하면 토큰은 $rL$개 줄어들게 됩니다. (개수입니다. 비율이 아니예요.)

 

장점은 token pruning에서 입력에 따라 dynamic하게 토큰 수를 줄이던 연구에서 정확도는 높지만, 배치 추론이 어려운 문제를 해결 할 수 있습니다. (동일하게 $rL$개가 줄었으니까요)

 

또 다른 점은, token pruning은 블록의 시작부분에 모듈을 삽입하여 토큰을 줄이는 방법을 사용했지만, ToMe는 위 Figure1처럼 attention과 MLP 사이에 삽입합니다. 이는 attention의 결과를 바로 가져와서, 토큰간의 정보를 효과적으로 사용하여 병합하기 위함이라고 합니다.

 

 

Token Merging - Token Similarity

 

토큰을 병합하려면 '비슷하다'는 기준을 먼저 정의해야합니다.

단순하게 토큰의 feature vector간의 거리로 정의할 수 있지만, 저자들은 이것이 최적이 아닐 수 있다고 이야기합니다.

(transformer는 overparameterized되어있어서, 중간 feature들은 불필요한 정보들이 포함되어있기 때문입니다.)

 

저자들은 단순한 feature vector의 거리가 아니라, QKVself-attention 구조의 정보를 활용하려고 합니다.

구체적으로, key(K)는 이미 토큰이 담고있는 정보를 dot-product 유사도 계산에 적합한 형태로 만들어주고 있습니다.

그렇기 때문에, 각 토큰의 key들 사이에서 유사도(cosine similarity)를 측정하여, 어떤 토큰들이 비슷한 정보를 담는지 판단합니다.

 

 

Token Merging - Bipartite Soft Matching

 

이제 위에서 Token Similarity를 계산했으니, 실제로 어떤 토큰끼리 매칭할지를 빠르게 결정해야합니다.

K-means clustering같은 알고리즘이 아니라, 매칭 자체의 시간이 무시할 수 있을 만큼 빠른 알고리즘을 선택합니다.

 

이를 해결하기 위해 채택한 알고리즘이 Bipartite Soft Matching 입니다.

 

알고리즘은 다음과 같습니다. (위 Figure 1 참고)

  1. 토큰들을 크기가 비슷한 두 집합 A, B로 나눕니다.
  2. A의 토큰을 돌면서, B에서 제일 비슷한 토큰 하나를 찾아둡니다.
  3. 이제 제일 유사도가 높은 r개의 쌍만 남깁니다.
  4. 연결된 토큰들을 병합합니다. (평균값) 
  5. 두 집합 A, B를 합칩니다.

- 흠, A B를 잘못나누면 항상 최적의 결과가 나올 것 같지는 않네요.

Token Merging - Tracking Token Size

 

한번 토큰이 병합되면, 해당 토큰은 더 이상 한개의 입력 패치가 아니게 됩니다. (최소 두개 이상의 이미지 패치를 합쳤으니까요) 이는 softmax attention의 결과에 영향을 줄 수 있다고 합니다.

 

이를 해결하기 위해서 proportional attention을 제안합니다.

proportional attention

 

$log$ $s$는 각 토큰이 몇개의 패치를 대표하고 있는지를 담은 내용이라고 합니다. (토큰의 개수만큼 정보를 담고 있는 vector s)

많이 더 해질수록 해당 토큰의 attention 계산에 가중치를 더해줍니다.

 

이 s값은 실제 병합할때도 활용하여 합칩니다.

A토큰과 B토큰을 단순하게 평균낼 때, B는 100개의 토큰을 합친 상태임에도 단순하게 (A+B)/2 로 계산하면 이상하기 때문입니다.

s는 병합된 토큰A, B의 개수, x는 토큰

 

Token Merging - Training with Merging

 

학습에서 ToMe를 사용하기 위해서, token merging을 일종의 pooling 연산처럼 취급해서 average pooling처럼 역전파를 수행한다고 합니다. 

 

일반적인 ViT 학습 설정으로 ToMe를 그대로 적용해도 최적의 성능을 보였어서, 단순하게 학습 속도를 높이기 위한 옵션으로 사용해도 괜찮을 것 같다고 이야기합니다.

 

 

Experiments

 

화살표 방향은 항상 동일한 토큰의 양을 감소시킬 건지(rL), 혹은 첫 layer를 많이, 맨 마지막은 0에 가깝게 선형 감소 시킬건지를 나눈것. 구분선이 이상하게 그려져서 보기가 약간 힘든데, baseline과 비교했을때 정확도도 높고 속도도 빠르다.

 

 

 

Pruning과 Merge, 그리고 다양한 매칭 알고리즘에 따른 속도와 정확도 비교

 

 

 

Pruning method와의 비교인데, 학습속도가 빨라진다는 장점이 추가됨. DynamicViT token pruning도 나쁘진 않아보인다.

 

 

 

이미지에서의 Token Merging visualize인데, 눈을 찡그리고 보면 실제 object끼리 merge된 것 같은 느낌을 받는다.

 

 

 

Video에서의 visualize, 공 같은 물체가 확실히 병합된 것을 볼 수 있다.

 

 

 

https://arxiv.org/pdf/2106.02034

 

 

최근에 Token Pruning에 대해서 볼일이 있어서 다시 보는 김에 리뷰를 남겨봅니다.

 

 

Abstract

Vision Transformer에서 가장 정보량이 많은 일부 토큰만 사용해서, 정확한 이미지 인식이 가능하다는 것을 발견했습니다.

이 관찰을 바탕으로, 중복 토큰들을 동적으로 제거하는 dynamic token sparsification 프레임워크를 제안합니다.

 

이를 위해, 현재 feature에서 각 token의 중요도를 측정하는 lightweight prediction module을 사용합니다.

이 모듈을 여러 layer에 추가하여 중복토큰을 반복적으로 제거합니다.

 

입력 token의 약 66%를 제거하여 31% ~ 37%정도의 FLOPs감소와 40% 이상의 처리량 향상을 달성하면서 정확도 하락은 0.5% 이내로 유지했습니다.

 

 

Introduction

 

일반적인 CNN의 pruning에서는 중요하지 않은 filter(channel)등을 프루닝하는 방법을 사용합니다.

하지만, ViT에서는 이미지의 패치 토큰으로 입력이 들어가기 때문에 덜 중요한 토큰을 제거하는 방식을 사용할 수 있습니다.

이러한 방식의 연산이 가능한 이유는 self-attention이 가변 길이의 token seq를 처리할 수 있기 때문입니다.

 

논문의 저자들은 특정 layer마다(예를 들어 4, 7, 10번째) 어떤 토큰을 제거할지 동적으로 결정하는 lightweight prediction module을 사용하여 점차 제거하는 토큰의 양을 늘려가는 hierarchical sparsification을 사용합니다.

 

모델과 모듈을 end-to-end로 학습하기 위해서 두가지 방법을 사용합니다.

1. Gumbel-softmax - softmax결과를 binary로 만들지만, gradient는 흘릴 수 있는 방법

2. attention masking - 실제로 token을 0으로 만들면 attention계산에 영향을 주기 때문에, 제거된 토큰은 attention 연산에 참여하지 못하도록 만듦. 

 

 

이러한 방법을 통해서, 31%~37%의 GFLOPs 감소, 40% 이상의 throughput 향상을 달성하면서도, 정확도 하락은 0.5% 이하를 달성했습니다.

 

 

Method

Hierarchical Token Sparsification with Prediction Modules

Token sparsification은 반복적으로 이루어지기 때문에, 토큰들은 binary decision mask ($\hat{D} \in {0, 1}^{N}$) 을 유지합니다.

(이때 N은 patch embedding의 개수)

맨 처음에는 모든 mask의 값을 1로 초기화합니다.

MLP에 토큰 $x$ ($NC$)를 넣고, ($NC'$)를 예측합니다. ($C' = C / 2$)

 

이렇게 계산된 $z^{local}$과 mask를 통합하여, 전체에서 살아있는 token만 모은 $z^{global}$을 만들어냅니다.

 

이때, Agg 함수는 살아있는 token들의 평균입니다.

mask D와 곱해서 살아있는 token feature만 다 더하고, 살아있는 개수로 나누는 평균 값

 

$z^{local}$은 특정 token자체의 정보, $z^{global}$은 전체 이미지의 정보를 담고 있어서 두개를 concat해서 사용합니다. concat한 local-global embedding 정보를 사용하여, 각 토큰을 버릴지 말지 결정하는 확률을 예측합니다.

i번째 token을 버리면 0, 살리면 1 (Nx2) 형태의 output

 

해당 파이 값들을 모아서 현재의 mask $D$ 를 만들고, 이전 mask $\hat{D}$와 곱해서 갱신합니다.

Mask 갱신

 

 

End-to-end Optimization with Attention Masking

 

크게 두가지 문제를 해결합니다.

 

1. 위에서 파이값으로 이진 mask를 만드는 과정이 미분이 안되는 문제가 있습니다. (softmax -> binary)

이를 해결하기 위해서, Gumbel-Softmax 기법을 사용하여 mask를 생성합니다.

 

 

2. 토큰을 단순하게 0으로 바꾸면, self-attention에서 score를 계산할때 영향을 주게 됩니다.

이유는 softmax 연산 때문인데, 0이 들어가도 softmax를 통과한 이후에는 0이 아니게 되기 때문입니다.

 

 

 

이를 해결하기 위해, softmax에서 0인 애들은 모두 제외하고 계산합니다.

이때 i == j인 self-loop는 1로 두었는데, 토큰이 pruned되었더라도 자기 자신은 볼 수 있게 두는 것이 안정성에 좋았다고 합니다.

 

Training and Inference

학습에서는 end-to-end 학습을 위해서 실제로 토큰을 줄이지는 않고, masking을 사용합니다.

추론때는 실제로 특정 비율만큼 토큰을 제거하여 속도를 향상시키는 방법을 사용합니다.

 

loss는 총 4가지를 사용하고, 첫번째로는 CrossEntropy를 사용합니다.

token pruning으로 인한 성능저하를 방지하기 위해 dense network를 teacher로 삼아서 distillation loss를 '두개' 사용합니다.

1. 맨 마지막 stage의 token끼리 distillation

맨 마지막 stage의 살아남은 token t끼리 distillation을 수행함

2. output logits에 대한 distillation (일반 KD)

token pruning sparsity를 유지하기 위한 regularization term을 추가합니다.

원하는 p만큼을 유지하기 위해서, mask와 차이를 regularizer로 설정함

 

 

Experiments

다른 방법들과 DynamicViT간의 정확도와 Params, GFLOPs비교, 일반적으로 비슷한 정확도 대비 30% 이상 FLOPs가 높다.

 

 

재미있는 결과인데, 실제 프루닝된 token으로 visualize했을때 신기하게 딱 물체를 구분할 수 있는 외적인 애들이 지워지는 것 같다.

 

각 stage별로 token이 살아있을 확률을 시각화한 것 같은데, 약간 Image의 물체는 항상 중앙에 있을 확률이 높다 << 는 정보가 반영된게 아닐까 싶다.

 

 

 

https://github.com/Ahnho/SERo/

 

GitHub - Ahnho/SERo

Contribute to Ahnho/SERo development by creating an account on GitHub.

github.com

https://openreview.net/pdf?id=n1UUgGyh2Z

 

UAI 2025에 Accept된 논문으로, 국민대학교와 현대자동차 로보틱스 연구실에서 작성한 논문입니다.

 

 

 

Abstract

논문에서는 ViTs에서 pruning 효과를 극대화할 수 있는 프레임워크 SERo를 제안합니다.

SERo에서 크게 집중하는 것은 다음과 같습니다.

  1. 실제 압축이 가능한 하드웨어에 친화적인 pruning을 하자.
  2. Exploration과 Re-optimization phase로 나누어서 학습하자.
  3. Pre-trained model에 대해서 심플한 gradient magnitude-based pruning을 사용하자.

SERo는 DeiT-Base에서 1.55% 정확도 하락만으로 2.4배 속도 향상과, computational cost를 69% 줄일 수 있었습니다.

 

 

 

 

Introduction

일반적으로 dynamic pruning 연구에서는 sparse structure를 잘 탐색하는 것이 중요하다고 이야기하고, 특정 연구에서는 이렇게 찾은 sparse structure는 평평한 loss landscape를 갖게 되면서, 더 좋은 최적화를 할 수 있다고 이야기합니다.

 

또 다른 연구에서는 one-shot pruning에서도 모델을 다시 최적화 (re-optimization)하면, 성능이 좋아질 수 있다고 이야기합니다.

 

이러한 관점에서 논문에서는 잘 찾은 structure를 갖고 re-optimization을 하면, 더 성능을 높일 수 있을 것이라고 이야기합니다. 하지만, 일반적인 CNN과 다르게 pretrained를 불러와서 사용하고, attention 개념이 적용된 ViT pruning에서는 고려해야할 것이 많습니다.

 

위 그림을 보면, Weight Q, Weight K가 Weight V보다 일반적으로 weight 크기가 큰 것을 볼 수 있습니다.

그렇기 때문에, 일반적인 pruning에서 사용하는 global weight magnitude pruning을 사용하면 V matrix가 극단적으로 프루닝되며 성능이 하락하는 문제가 있다고 이야기합니다.

 

이는 attention 매커니즘에서, Q와 K는 많이 프루닝되어도 token간의 정보는 보존되지만, V가 많이 프루닝되면 정보의 손실이 크기 때문입니다. 이러한 문제를 해결하기 위해서 SERo에서는 global gradient magnitude pruning을 사용합니다.

 

 

위 그림은 SERo의 전반적인 과정으로, 중요한 부분은 다음과 같습니다.

  • Pre-trained model에 대해서 gradient magnitude pruning으로 optimal sparse model을 탐색합니다.
    • Gradual pruning을 적용합니다. (sparsity를 0.0부터 목표치까지 차근차근 올려가는 방법)
    • 디테일을 하나 추가해서, pruning에서는 일반적으로 'weight magnitude'가 중요합니다. 하지만, SERo는 gradient 기반으로 프루닝을 하기 때문에, 계산한 loss로 update하는 과정에서 가중치의 크기에 비례해서 업데이트하는 전략을 추가합니다. -> 자세한 내용은 아래 Method를 참고하세요.
  • Gradual pruning에서 sparsity가 목표치에 도달하면, exploring sparse structure phase는 끝나고 Re-optimizing compressed sparse structure phase로 넘어가서 학습합니다.
    • 일반적인 fine-tuning과 다른 점은 '찾은 sparse structure'로 압축 후에, learning rate를 초기값으로 돌려서 학습하는 것입니다. (RigL과 다르게, 학습한 $W$는 초기화하지 않습니다.)

 

 

Method

Unit Pruning

ViT block은 크게 self-attention layer, proojection layer, feed-forward layer로 구성됩니다. 이는 self-attention layer($W^{q}$, $W^{k}$, $W^{v}$), projection layer($W^{p}$), feed-forward layer($W^{f1}$, $W^{f2}$) 로 표현하겠습니다.

 

하드웨어에서 실제로 압축이 가능한 형태로 만들기 위해서 다음과 같이 묶어서 프루닝을 진행합니다.

  • $W^{q}$, $W^{k}$는 공통의 pruning mask $M^{q,k}$를 사용하고
  • $W^{v}$, $W^{q}$는 $M^{v}$
  • $W^{f1}$, $W^{f2}$는 $M^{f}$를 사용합니다.

 

이렇게 묶어서 프루닝을 하는 이유는 self-attention의 차원적인 고려입니다.

 

위 그림처럼 프루닝에서 가중치를 0으로 바꾸고 계산하는 zeroing과 다르게, compression은 제거한 프루닝의 채널이 맞지 않으면 다음 연산에 영향을 주게 됩니다.

 

이제 프루닝할 채널은 다음과 같이 계산합니다.

 

예시로 가져온 $W^{q}$, $W^{k}$에 대한 Mask $M^{q,k}$를 계산하는 방법입니다.

 

채널 $j$ 별로 두 $W$의 gradient의 L1 norm이 평균을 사용해서, $W^{q}$와 $W^{k}$ 모두의 중요도를 고려하는 방법을 사용합니다. 이 평균값이 threshold $\tau$ 이상이면 1 아니면 0으로 계산합니다.

 

$M^{v}$를 계산할때엔 $W^{v}$에 대한 gradient의 L1 norm을 사용하여 중요도를 체크합니다.

$W^{v}$에 적용할때엔 $M^{v}$를 element-wise로 곱하고, $W^{p}$에 적용할때엔 $(M^{v})^{T}$를 element-wise로 곱해서 사용합니다.  

 

마찬가지로, feed-forward network에 대해서도 $W^{f1}$을 기준으로 계산해서 $W^{f2}$에는 transpose하고 곱해서 적용합니다.

 

 

Exploration and Re-optimization

위에서는 어떻게 차원을 고려해서 압축하는지에 대해서 설명했다면, 이번 챕터에서는 어떻게 exploration을 할지에 대해서 설명합니다.

 

 

수식이 조금 복잡한데, 결국 살아있는 채널에 대해서 업데이트를 진행할 것이고, 이때 gradient에 $W$의 크기를 곱해주겠다 라는 이야기입니다.

 

gradient에 $W$의 크기를 곱해주는 이유는, 현재는 프루닝 기준을 'gradient'로 설정합니다. 하지만 프루닝에서는 이미 잘 알려진 것처럼 weight의 크기도 중요합니다. 이러한 weight의 크기 정보를 프루닝에서 활용하기 위해 gradient에 적용하는 것입니다.

 

이렇게 한다면 $W$가 큰 weight는 업데이트할 때 더 많이 업데이트되고, $W$가 작은 weight는 업데이트가 덜 되게 됩니다.

$W$크기를 고려해서 차등을 두어 업데이트하겠다는 이야기입니다.

  • Exploration phase에서만 사용하고, Re-optimization에서는 $W$크기는 제외합니다. (성능이 더 떨어져요)

 

Re-optimization에서는 구조를 찾았으니, learning rate를 초기값으로 돌리고 다시 학습합니다.

SERo의 pseudo code

 

 

 

Results

DeiT에서 다양한 baseline과 비교했을때 빠르고 좋은 모습

 

 

단순 Zeroing과 FFN, Attention, 전부를 압축했을때 얼마만큼의 속도 향상이 있는지를 비교하는 table

 

 

 

Weight magntidue pruning과 gradient pruning, SERo의 방법에 따라서 각 Unit이 얼마나 프루닝이 됐는지에 대한 비율과 정확도 비교. (SERo가 gradient magnitude pruning과는 다르다라는 이야기)

 

 

Exploration만 했을때와 Re-optimization을 했을때의 정확도 차이. (Re-optimization의 중요성)

 

 

 

요약

  • Weight를 0으로 바꿔놓고 프루닝됐다! 하는 것이 아니라, 실제 차원을 고려해서 압축이 가능하도록 설계했다.
  • Weight magnitude pruning을 사용하면 안되는 이유를 밝히고, gradient magnitude pruning에 weight 크기를 고려한 방법을 제안했다.
  • Exploration과 Re-optimization phase를 나눠서 구조를 찾고, 압축하고, 다시 최적화하는 단계를 제안해서 성능을 향상시켰다.
  • 평상시에 작성하는 논문리뷰보다 쵸큼 더 자세하게 썼는데, 제가 참여했던 논문이라서 그렇습니다.

+ Recent posts