https://github.com/Ahnho/SERo/

 

GitHub - Ahnho/SERo

Contribute to Ahnho/SERo development by creating an account on GitHub.

github.com

https://openreview.net/pdf?id=n1UUgGyh2Z

 

UAI 2025에 Accept된 논문으로, 국민대학교와 현대자동차 로보틱스 연구실에서 작성한 논문입니다.

 

 

 

Abstract

논문에서는 ViTs에서 pruning 효과를 극대화할 수 있는 프레임워크 SERo를 제안합니다.

SERo에서 크게 집중하는 것은 다음과 같습니다.

  1. 실제 압축이 가능한 하드웨어에 친화적인 pruning을 하자.
  2. Exploration과 Re-optimization phase로 나누어서 학습하자.
  3. Pre-trained model에 대해서 심플한 gradient magnitude-based pruning을 사용하자.

SERo는 DeiT-Base에서 1.55% 정확도 하락만으로 2.4배 속도 향상과, computational cost를 69% 줄일 수 있었습니다.

 

 

 

 

Introduction

일반적으로 dynamic pruning 연구에서는 sparse structure를 잘 탐색하는 것이 중요하다고 이야기하고, 특정 연구에서는 이렇게 찾은 sparse structure는 평평한 loss landscape를 갖게 되면서, 더 좋은 최적화를 할 수 있다고 이야기합니다.

 

또 다른 연구에서는 one-shot pruning에서도 모델을 다시 최적화 (re-optimization)하면, 성능이 좋아질 수 있다고 이야기합니다.

 

이러한 관점에서 논문에서는 잘 찾은 structure를 갖고 re-optimization을 하면, 더 성능을 높일 수 있을 것이라고 이야기합니다. 하지만, 일반적인 CNN과 다르게 pretrained를 불러와서 사용하고, attention 개념이 적용된 ViT pruning에서는 고려해야할 것이 많습니다.

 

위 그림을 보면, Weight Q, Weight K가 Weight V보다 일반적으로 weight 크기가 큰 것을 볼 수 있습니다.

그렇기 때문에, 일반적인 pruning에서 사용하는 global weight magnitude pruning을 사용하면 V matrix가 극단적으로 프루닝되며 성능이 하락하는 문제가 있다고 이야기합니다.

 

이는 attention 매커니즘에서, Q와 K는 많이 프루닝되어도 token간의 정보는 보존되지만, V가 많이 프루닝되면 정보의 손실이 크기 때문입니다. 이러한 문제를 해결하기 위해서 SERo에서는 global gradient magnitude pruning을 사용합니다.

 

 

위 그림은 SERo의 전반적인 과정으로, 중요한 부분은 다음과 같습니다.

  • Pre-trained model에 대해서 gradient magnitude pruning으로 optimal sparse model을 탐색합니다.
    • Gradual pruning을 적용합니다. (sparsity를 0.0부터 목표치까지 차근차근 올려가는 방법)
    • 디테일을 하나 추가해서, pruning에서는 일반적으로 'weight magnitude'가 중요합니다. 하지만, SERo는 gradient 기반으로 프루닝을 하기 때문에, 계산한 loss로 update하는 과정에서 가중치의 크기에 비례해서 업데이트하는 전략을 추가합니다. -> 자세한 내용은 아래 Method를 참고하세요.
  • Gradual pruning에서 sparsity가 목표치에 도달하면, exploring sparse structure phase는 끝나고 Re-optimizing compressed sparse structure phase로 넘어가서 학습합니다.
    • 일반적인 fine-tuning과 다른 점은 '찾은 sparse structure'로 압축 후에, learning rate를 초기값으로 돌려서 학습하는 것입니다. (RigL과 다르게, 학습한 $W$는 초기화하지 않습니다.)

 

 

Method

Unit Pruning

ViT block은 크게 self-attention layer, proojection layer, feed-forward layer로 구성됩니다. 이는 self-attention layer($W^{q}$, $W^{k}$, $W^{v}$), projection layer($W^{p}$), feed-forward layer($W^{f1}$, $W^{f2}$) 로 표현하겠습니다.

 

하드웨어에서 실제로 압축이 가능한 형태로 만들기 위해서 다음과 같이 묶어서 프루닝을 진행합니다.

  • $W^{q}$, $W^{k}$는 공통의 pruning mask $M^{q,k}$를 사용하고
  • $W^{v}$, $W^{q}$는 $M^{v}$
  • $W^{f1}$, $W^{f2}$는 $M^{f}$를 사용합니다.

 

이렇게 묶어서 프루닝을 하는 이유는 self-attention의 차원적인 고려입니다.

 

위 그림처럼 프루닝에서 가중치를 0으로 바꾸고 계산하는 zeroing과 다르게, compression은 제거한 프루닝의 채널이 맞지 않으면 다음 연산에 영향을 주게 됩니다.

 

이제 프루닝할 채널은 다음과 같이 계산합니다.

 

예시로 가져온 $W^{q}$, $W^{k}$에 대한 Mask $M^{q,k}$를 계산하는 방법입니다.

 

채널 $j$ 별로 두 $W$의 gradient의 L1 norm이 평균을 사용해서, $W^{q}$와 $W^{k}$ 모두의 중요도를 고려하는 방법을 사용합니다. 이 평균값이 threshold $\tau$ 이상이면 1 아니면 0으로 계산합니다.

 

$M^{v}$를 계산할때엔 $W^{v}$에 대한 gradient의 L1 norm을 사용하여 중요도를 체크합니다.

$W^{v}$에 적용할때엔 $M^{v}$를 element-wise로 곱하고, $W^{p}$에 적용할때엔 $(M^{v})^{T}$를 element-wise로 곱해서 사용합니다.  

 

마찬가지로, feed-forward network에 대해서도 $W^{f1}$을 기준으로 계산해서 $W^{f2}$에는 transpose하고 곱해서 적용합니다.

 

 

Exploration and Re-optimization

위에서는 어떻게 차원을 고려해서 압축하는지에 대해서 설명했다면, 이번 챕터에서는 어떻게 exploration을 할지에 대해서 설명합니다.

 

 

수식이 조금 복잡한데, 결국 살아있는 채널에 대해서 업데이트를 진행할 것이고, 이때 gradient에 $W$의 크기를 곱해주겠다 라는 이야기입니다.

 

gradient에 $W$의 크기를 곱해주는 이유는, 현재는 프루닝 기준을 'gradient'로 설정합니다. 하지만 프루닝에서는 이미 잘 알려진 것처럼 weight의 크기도 중요합니다. 이러한 weight의 크기 정보를 프루닝에서 활용하기 위해 gradient에 적용하는 것입니다.

 

이렇게 한다면 $W$가 큰 weight는 업데이트할 때 더 많이 업데이트되고, $W$가 작은 weight는 업데이트가 덜 되게 됩니다.

$W$크기를 고려해서 차등을 두어 업데이트하겠다는 이야기입니다.

  • Exploration phase에서만 사용하고, Re-optimization에서는 $W$크기는 제외합니다. (성능이 더 떨어져요)

 

Re-optimization에서는 구조를 찾았으니, learning rate를 초기값으로 돌리고 다시 학습합니다.

SERo의 pseudo code

 

 

 

Results

DeiT에서 다양한 baseline과 비교했을때 빠르고 좋은 모습

 

 

단순 Zeroing과 FFN, Attention, 전부를 압축했을때 얼마만큼의 속도 향상이 있는지를 비교하는 table

 

 

 

Weight magntidue pruning과 gradient pruning, SERo의 방법에 따라서 각 Unit이 얼마나 프루닝이 됐는지에 대한 비율과 정확도 비교. (SERo가 gradient magnitude pruning과는 다르다라는 이야기)

 

 

Exploration만 했을때와 Re-optimization을 했을때의 정확도 차이. (Re-optimization의 중요성)

 

 

 

요약

  • Weight를 0으로 바꿔놓고 프루닝됐다! 하는 것이 아니라, 실제 차원을 고려해서 압축이 가능하도록 설계했다.
  • Weight magnitude pruning을 사용하면 안되는 이유를 밝히고, gradient magnitude pruning에 weight 크기를 고려한 방법을 제안했다.
  • Exploration과 Re-optimization phase를 나눠서 구조를 찾고, 압축하고, 다시 최적화하는 단계를 제안해서 성능을 향상시켰다.
  • 평상시에 작성하는 논문리뷰보다 쵸큼 더 자세하게 썼는데, 제가 참여했던 논문이라서 그렇습니다.

+ Recent posts