둔비의 공부공간

DaFKD: Domain-aware Federated Knowledge Distillation 본문

Papers/Compression

DaFKD: Domain-aware Federated Knowledge Distillation

Doonby 2024. 7. 31. 23:56

https://openaccess.thecvf.com/content/CVPR2023/html/Wang_DaFKD_Domain-Aware_Federated_Knowledge_Distillation_CVPR_2023_paper.html

 

CVPR 2023 Open Access Repository

DaFKD: Domain-Aware Federated Knowledge Distillation Haozhao Wang, Yichen Li, Wenchao Xu, Ruixuan Li, Yufeng Zhan, Zhigang Zeng; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2023, pp. 20412-20421 Abstract Federa

openaccess.thecvf.com

https://github.com/haozhaowang/DaFKD2023

 

GitHub - haozhaowang/DaFKD2023: Code for CVPR2023 DaFKD : Domain-aware Federated Knowledge Distillation

Code for CVPR2023 DaFKD : Domain-aware Federated Knowledge Distillation - haozhaowang/DaFKD2023

github.com

 

Abstract

기존 FD방법들은 client들의 soft logits을 단순히 avg만 하며, model의 다양성을 고려하지 않아 aggregate model의 성능이 떨어졌다.

(특히 특정 local이 아직 sample에 대해서 완전하게 학습되지 않은 상태에서 aggregate했을때 그렇다.)

 

이 논문에서는, 각 client의 local data를 specific domain으로 간주하고 distillation sample에 대한 중요성을 알 수 있는 Domain Knowledge Aware Federated Distillation을 제안했다.

 

각 client에 대해서 sample과 domain의 상관계수를 알 수 있도록 domain discriminator를 사용했다.

(communication cost를 줄이기 위해서, classification model 파라미터의 일부분을 공유해서 사용한다.)

 

Introduction

Classic FL paradigm, FedAvg는 client의 model을 aggregate하여 optimize를 했다.

하지만, 이러한 세팅은 client간의 Non-IID 세팅에서 성능하락을 보였다. 이는 다양한 방향으로 optimize가 되어, aggregate된 model의 variance가 엄청 커지기 때문이다.

 

이러한 문제를 해결하기 위해서, 여러 local model의 knowledge를 global model로 distillation하기 위해, output soft predictions을 aggregate하는 방법들이 나왔다.

  • Ensemble distillation for robust model fusion in federated learning
    • public dataset을 distillation data sample로 사용해서, 여러 local 모델들의 soft prediction을 만들고, 이 prediction을 사용해서 global 모델을 업데이트 하는 방법을 사용했다.
  • Fine-tuning global model via data-free knowledge distillation for non-iid federated learning
  • Data-free knowledge distillation for heterogeneous federated learning
    • 이 두 논문은 public dataset을 사용하지 않고, 생성모델에서 만든 데이터를 갖고 사용했다.

위 연구들은 모델의 다양성을 고려하지 않아, 성능에 한계가 있다고 이야기한다. (soft prediction만 사용하면 일부 local model이 distillation sample에 대해서 잘못된 예측을 하면 global 모델이 망가진다고 이야기한다.)

 

이 논문에서는 DaFKD라는 새로운 feaderated distillation 방법을 제안한다.
이 방법은 주어진 distillation sample에 대한 모델의 importance를 측정하여, 잘못된 soft prediction의 영향을 줄일 수 있다고 이야기한다.
 
구체적으로, 각 client의 local data를 domain으로 간주하고, 각 client에 domain discriminator를 사용해서 sample과 domain간의 상관계수를 파악한다. 상관계수가 높으면, 로컬모델에 중요도도 높이는 방법이다.
 
sample과 domain이 같을때 모델이 좋은 예측을 할것이다 라는 것에서 기반했다고 한다.
 

 

 

Methodology

 

위 그림에서 볼 수 있듯, DaFKD는 도메인 discriminator를 활용해서 각 local 모델의 중요도를 알아내고, ensemble distillation의 성능을 향상시키는 것이다. 

 

Domain Discriminator

 

우선, distillation sample이 해당 모델이 학습한 도메인과 동일할때, 모델은 올바른 예측을 할 확률이 높다는 직관으로 시작한다.

따라서, 도메인과 샘플간의 상관관계를 측정하는 것이 필요하다.

 

이 논문에서는 local dataset을 target으로 간주하고, distillation sample과 local domain간의 상관관계를 출력하는 domain discriminator를 제안한다.

- 그냥 generator를 분포에 맞게 잘 뽑게 하면 안돼...?

 

저자들은 각 client $k$에게 personalized discriminator $\theta_{k}^{d}$를 할당하고, 모든 client가 공유하는 global generator $\theta^{g}$를 사용했다.

 

매 round $t$마다, 참여하는 각 client는 서버로부터 generator $\theta^{g}$를 받아와서 분포 $p_{z}$에서 sampling한 noise로 pseudo dataset $\hat{D}_{k}$ 을 생성한다.

 

그 다음, 각 client는 local data에는 positive label을 붙이고, 생성한 pseudo data에는 false label을 붙여서 discriminator를 학습한다.

 $f(\theta_{k}^{d}; x_{i})$는 $x_{i}$가 real data일 확률이다.

discriminator $\theta^{d}$를 학습하면, client는 이 discriminator를 활용하여 generator를 학습한다. 

이 과정에서 아래 함수를 사용한다.

이렇게 학습하면, 각 client의 generator를 만들 수 있고, 이러한 generator를 합쳐서 global generator를 만들 수 있다.

Domain-aware Federated Distillation

Classification model을 얻기 위해, 매 round $t$마다 각 참가하는 client $k$는 local에서 $w_{t}^{k}$를 훈련하고, 이 model과 discriminator를 server로 전송한다. 서버는 받은 모델의 weight를 aggregate한다.

server는 global generator $\theta^g$를 사용해서 pseudo dataset $\hat{D}^g$를 생성하고, 이를 distillation data로 사용한다.

그 후, 매 distillation sample $x_{i}$에 대해서 server는 각 local의 domain discriminator $w_{t}^{k}$ 로  importance를 계산한다.

 

모든 local에 대한 importance의 합이 1이 되기 위해서 아래와 같이 normalize 한다.

최종적으로, server는 pseudo sample $x_{i}$를 각 local model $w_{t}^{k}$와 avg model $\hat{w}_{t+1}$에 넣어서 soft prediction두개를 뽑아내고, importance를 ensemble knowledge distillation에 적용하여 global model $w_{t+1}$을 얻는다.

 

 

 

Partial Parameters Sharing

Domain discriminator가 local dataset을 기반으로 학습되기 때문에, local dataset이 작으면 성능이 떨어질 수 있다.

서로 다른 task간에 encoder를 공유하면서, 상호보완적인 효과를 얻을 수 있는 Multi-task Learning 아이디어를 활용하여, discriminator $\theta_{k}^{d}$와 classification model $w^k$의 encoder parameter를 공유하는 방법을 사용했다.

  • Discriminator와 Classification Network 모두 추출된 feature로부터 sample을 구분하는 것.
  • Discriminator를 server에 보낼때, communication cost를 줄일 수 있음.

 

Experiments

Baselines

  • FEDAVG
  • FEDPROX
  • FEDDFUSION
  • FEDGEN
  • FEDFTG

Datasets

  • MNIST
  • EMNIST
  • FASHION MNIST
  • SVHN

Local Training Epoch : 20

Communication rounds : 60

Number of Clients : 20 (activate ratio 0.4)

 

alpha가 낮을수록 높은 이질성을 보인다.

 

위 Table1은 다양한 Dataset와 각 dataset의 이질성 $\alpha$ 에 대한 Test Accuracy를 측정한 결과이다.

FEDDFUSION, FEDGEN, FEDFTG, DaFKD등 KD를 사용한 방법이 기존 FEDAVG나 FEDPROX의 성능을 넘는 것을 볼 수 있다.

DaFKD는 다양한 데이터셋에 대해서 대부분 우수한 성능을 보이는 유일한 KD알고리즘이라고 말하며, 이러한 결과는 domain discriminator가 상관관계를 잘 식별하고, 이를 통해서 이질성 문제를 해결할 수 있었다고 주장한다.

Data heterogeneity and Client participant

위 그림 3은, 4가지 Image dataset에서 데이터 이질성에 따른 test acc를 측정한 결과이다. 이러한 결과를 보면, 모든 방법들은 데이터 이질성이 감소할수록 test acc가 올라가는 것을 볼 수 있다.

 

DaFKD는 $\alpha=0.1$에서 test acc가 크게 향상되며, 항상 좋은 성능을 보인다고 말한다.

위 그림 5는 client ratio에 대한 acc이다. DaFKD는 뭐 다양한 값에서 좋은 성능을 보였다고 이야기한다.

또한, 모든 방법들과 마찬가지로 참가하는 client의 수가 늘어날 수록 성능은 향상된다고 이야기한다.

 

Communication rounds

 

위 table 2는 target test acc에 도달하기 위해 필요한 communication rounds를 측정한 것이다.

best와 second를 굵게 표시했는데, 대부분 2위 안에는 들었다고 이야기한다.

 

위 그림 4는, communication rounds = 60동안의 test acc를 측정한 결과이다.

DaFKD는항상 높은 성능을 보였다고 이야기한다.

- 왜 중간에 acc가 떨어지지...? overfitting? Non-IID에 대한 불안정성?

 

Conclusion

이 논문에서는 연합 지식 증류(Federated Knowledge Distillation)에서 데이터 이질성 문제를 해결하기 위해, 도메인 인식 연합 지식 증류(Domain-Aware Federated Knowledge Distillation) DaFKD 라는 새로운 방법을 제안했다.

 

이 방법은 주어진 distillation sample에 대해 각 local model의 importance를 부여하기 위해, domain discriminator를 사용하여 sample과 local model학습에 사용된 domain간의 상관관계를 측정한다. 

 

또한, domain discriminator의 훈련을 빠르게 하기 위해, 일부 파라미터를 classification model과 공유했다. (conv layer)

 

그 결과, 다양한 데이터셋에서 test acc가 높게 나오는 것을 확인했다.

 

 

 

Process에 대한 간단한 요약

  1. Server는 Client $K$ 에게 generator를 넘겨준다.
  2. Client는 Generator와 Client의 Discriminator, Local Model을 학습한다.
  3. Client는 Server에게 Generator, Discriminator, Local Model을 넘겨준다.
  4. Server는 받은 Generator와 Local Model을 Aggregate한다.
  5. Server는 Aggregated generator로 pseudo distillation sample을 생성한다.
  6. 생성한 sample과 각 domain의 상관관계를 계산해둔다.
  7. local model들로 forwarding해서 prediction을 만들고, 계산한 상관관계를 사용하여 KL Divergence로 Aggregate model을 학습한다.

 

Comments