둔비의 공부공간
Training Independent Subnetworks For Robust Prediction 본문
https://arxiv.org/abs/2010.06610
Training independent subnetworks for robust prediction
Recent approaches to efficiently ensemble neural networks have shown that strong robustness and uncertainty performance can be achieved with a negligible gain in parameters over the original network. However, these methods still require multiple forward pa
arxiv.org
(ICLR 2021) Cambridge, Google, Standford
Abstract
Ensemble은 여전히 prediction을 위해 multiple forward passes와 computational cost를 필요로 한다.
논문에서는, single model's forward pass로 multiple predictions을 하는 방법을 소개한다.
Multi-Input Multi-Output(MIMO)를 사용하면, 단일 모델의 용량으로 다양한 subnetworks를 학습시킬 수 있다.
subnetworks의 prediction을 ensemble하면, 계산량의 증가 없이 robustness를 향상시킬 수 있다.
Introduction
기존 ensemble이 acc향상이나 robustness등에 도움이 됐지만, 여전히 prediction을 위해서는 4~10번의 forward pass가 필요했다.
논문에서는 여러 predictions의 이점을 single model forward pass로 달성할 수 있다는 것을 보였다.
multi-input multi-output 구조를 사용하여, 동시에 여러 sub-network를 학습시키는 방법을 사용했다.
이러한 방법은 test 할때, 한번에 모든 sub-network를 평가할 수 있고, single forward pass로 ensemble 효과를 낼 수 있다.
MIMO는 기존 network에서 두가지만 변경하면 된다.
- Replace input layer
- Replace output layer
테스트 할때는 동일한 입력이 $M$번 반복되어 들어가고, 이를 avg한 prediction으로 사용한다.
subnetwork는 서로 disjoint한 부분을 사용하여 독립적으로 prediction하기에 robustness를 확보할 수 있었다.
Contributions
- MIMO구조로 multiple independent prediction이 single forward pass로 할 수 있다.
약간의 parameter와 compute cost로 불확실성을 줄이고 robustness를 키울 수 있다. - 각 member들의 diversity 를 분석하고, 이들이 각각 독립적으로 학습되는 것을 보였다.
- CIFAR10, CIFAR100, ImageNet에 대해서 SOTA를 달성했다.
Multi-input Multi-output networks
MIMO 구조에서, $M$ 개의 inputs과 $M$ 개의 outputs으로 구성되며, 각 outputs은 input에 상응하는 것을 prediction한다.
이를 하기 위해서 두가지 아키텍쳐 변경이 필요하다.
- $M$ inputs ${X_{1}, ... , X_{M}}$는 first hidden layer 전에서 concatenated된다.
- output layer에서, $M$ predictive distributions가 나온다.
- ${p_{\theta}(y_{1} | x_{1}, ... , x_{m}), ..., p_{\theta}(y_{M} | x_{1}, ..., x_{M})}$
이를 위해서 약간(0.03%)의 추가적인 파라미터가 필요하고, 0.01% 정도의 FLOPs가 증가됐다.

input과 output pair가 서로 독립적이기 때문에, M-tuples로 network를 학습시킨 것과 동일하다.
evalutation에서는 unseen input $x'$을 $M$개로 쌓아서 넣는다.
$M$개의 결과를 ensemble 처럼 사용하여, predictive performance를 향상시킨다.
multiple weight sample가 필요한 Bayesian methods나 parameter-efficient한 BatchEnsemble과 다르게, MIMO는 모든 ensemble member가 single forward pass로 계산이 가능하다.
Understanding the subnetworks
Loss-Landscape analysis
multiple inputs을 사용하는 것이 diverse subnetworks를 학습하는 중요 포인트이다.
naive multiheaded architecture의 경우 input이 공유되고, model의 outputs이 $M$개로 나뉜다.
이는, prediction을 위해 동일한 feature를 사용하는 것으로, 매우 낮은 diversity를 야기시킨다.
- CIFAR-10에서 훈련된 SmallCNN 모델을 살펴보고 가중치 공간에서 세 개의 subnetwork간에 선형 보간을 수행한다.
- MIMO의 경우 네트워크의 본체가 공유되므로 입력 및 출력 레이어를 보간한다.
- 네이브 멀티헤드 모델의 경우 네트워크의 입력과 본체가 공유되므로 출력 레이어만 보간합니다.
deep ensemble과 유사하게, MIMO를 사용하여 훈련된 subnetwork는 초기화의 차이로 인해 가중치 공간에서 단절된 모드로 수렴하는 반면, 나이브 멀티헤드 모델의 경우 서브 네트워크가 동일한 모드로 끝난다.(그림 3, 왼쪽).
그림 3(오른쪽)은 불일치, 즉 하위 네트워크가 예측된 클래스에 대해 동의하지 않을 확률이다. MIMO의 연결이 끊긴 모드는 다양한 예측을 산출하는 반면, 순진한 멀티헤드 모델의 예측은 높은 상관관계를 보인다.
Function space analysis
function space에서의 MIMO subnetworks의 training경로를 visualize했다.
SmallCNN ($M = 3$)를 학습했는데, training이 끝날때, predictions에 대해 t-SNE projection을 진행했다.
trajectories가 서로 다른 local optima에 수렴한 것을 확인했다.
large scale network의 diversity를 수치적으로 평가하기 위해, subnetworks predictions의 pairwise similarity를 average하고, 다른 ensemble 방법과 비교해봤다.
여기서 $D$는 Disagreement와 KL-divergence다.
naive multiheaded model, TreeNet, batch ensemble을 봤을 때, 그 중 naive multiheaded model이 diversity를 유도하는데 실패했다. TreeNet, Batch Ensemble은 prediction간의 correlation이 좀 있긴 했지만, multiheaded model보다 조금 더 diversity갖고 있었다. MIMO는 조금 더 diversity가 높았다.
Separation of the subnetworks
subnetwork가 네트워크의 개별 부분을 활용한다는 것을 보여주기 위해 activation과 각 $M$ input에 따라 어떻게 반응하는지 측정한다.
각 입력에 대해 pre-activation의 "conditional variance"를 측정했다.
activation의 conditional variance가 input에 대해 0이 아닌 경우, 이는 input이 바뀌면 activation도 바뀐다는 의미이므로 입력에 해당하는 하위 네트워크의 일부로 간주한다.
만약 subnetwork가 독립적이라면, 하나의 input에 대해서는 conditional variance의 값이 0이 아닐 것이고, 그 외의 나머지 input에 대해서는 conditional variance의 값이 0에 가까울 것이다.
첫번째와 두번째 그림 모두, 하나의 input에 대해서만 conditional variance값이 0이 아닌 것을 볼 수 있다.
The optimal number of subnetworks
그렇다면, 이상적인 subnetworks의 개수는 무엇일까?
너무 적으면, ensemble의 효과를 낼 수 없고, 너무 많으면 subnetwork의 개별 성능이 하락한다.
결론부터 말하자면, subnetwork의 개별 성능에 영향을 주지 않는 선에서 최대 $M$을 정하는 것이 제일 좋았다.
위의 그림5를 보면, CIFAR10, CIFAR100에 대해서 $M$크기의 따른 ensemble과 subnetwork의 성능을 보이고 있다.
$M = 1$일때는 일반적으로 단일 network를 학습시켜서 사용할때와 동일하다.
$M$와 network capacity가 커질 수록, subnetwork의 성능이 조금씩 감소하는 것을 볼 수 있었다.
그러나 ensemble의 성능은 $M=2$, $M=4$일때, ensemble의 성능향상폭이 subnetwork의 성능감소폭보다 컸다.
Input and batch repetition
MIMO는 기존 baseline에 multi-input, multi-output구조와 hyperparameter $M$만 추가 했는데 잘 동작한다.
두개의 hyperparameter를 추가하면 성능을 더 올릴 수 있었는데, 특히 network capacity가 제한적일때 효과가 좋았다.
- Input repetition
- subnetwork끼리 어떠한 feature도 공유하지 못하도록, independent example을 고르는 것이다.
이러한 방법은 network의 capacity가 subnetwork $M$에 충분할때는 유용하다. 하지만, capacity가 부족할때에는 오히려 완벽하게 independent한 것 보다, 약간 공유하는 것이 좋다는 것을 발견했다. - 그래서 논문에서는 inputs간의 independence를 완화시키는 방법을 제안했다. (input1 $x1$을 train set에서 고른 후에, 확률 $p$로 input2 $x2$를 $x1$과 동일한 것을 사용한다.
- 위 그림 6을 보면, $p=0$일때는 independent하고, $M=2$일때 resnet50의 capacity가 충분하지 않아서 성능이 나오지 않는다. $p$가 커지면서 subnetwork끼리 공유하는 feature가 증가하고, 성능이 향상되는 것을 볼 수 있다.
하지만, $p=1$이 되면, subnetwork의 diversity가 감소하면서, ensemble 이점이 없어져 성능이 낮아지는 것을 볼 수 있다.
- subnetwork끼리 어떠한 feature도 공유하지 못하도록, independent example을 고르는 것이다.
- Batch repetition
- MC dropout 및 Variational Bayesian neural network 같은 stochastic model의 경우, 각 예제에 대해 여러 개의 approximate posterior sample을 그리면 network의 모델 불확실성에 대한 gradient noise를 줄여 성능을 향상시킬 수 있다.
- mini-batch에서 example을 반복함으로써 비슷한 효과를 얻을 수 있다. SGD 단계에서 unique example 수를 결정하는 batch size 선택과 마찬가지로, 반복 횟수를 변경하면 implicit regularization 효과가 있다. 그림 6b은 배치 반복 횟수에 따른 성능을 보여주며, 각 그래프는 batchsize, lr, $M$등 12가지 설정에 대한 성능 범위를 나타낸다. 일반적으로 반복 횟수가 많을수록 성능이 약간 향상된다.
Experiments
Conclusion
- 약간의 architecture 수정과 hyperparameter추가로 subnetworks로 분리되고 독립적으로 학습한 효과를 낼 수 있다.
- 그렇기에 computational efficiency하며, single forward pass로 evaluation할 수 있다. (ensemble 효과)
'Papers > Ensemble' 카테고리의 다른 글
Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time (0) | 2023.03.08 |
---|