Towards Understanding and Mitigating Dimensional Collapse in Heterogeneous Federated Learning
Authors: Yujun Shi, Jian Liang, Wenqing Zhang, Vincent Tan, Song Bai
ICLR 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | To commence, we study how heterogeneous data affects the global model in federated learning in Sec. 3.1. Specifically, we compare representations produced by global models trained under different degrees of data heterogeneity. Since the singular values of the covariance matrix provide a comprehensive characterization of the distribution of high-dimensional embeddings, we use them to study the representations output by each global model. Interestingly, we find that as the degree of data heterogeneity increases, more singular values tend to evolve towards zero. This observation suggests that stronger data heterogeneity causes the trained global model to suffer from more severe dimensional collapse, whereby representations are biased towards residing in a lower-dimensional space (or manifold). A graphical illustration of how heterogeneous training data affect output representations is shown in Fig. 1(b-c). Our observations suggest that dimensional collapse might be one of the key reasons why federated learning methods struggle under data heterogeneity. Essentially, dimensional collapse is a form of oversimplification of the model, where the representation space is not fully utilized to discriminate diverse data of different classes. Given the observations made on the global model, we conjecture that the dimensional collapse of the global model is inherited from the models locally trained on the various clients, because the global model is the result of aggregating the local models. To validate this conjecture, we further visualize the local models in terms of the singular values of representation covariance matrices in Sec. 3.2. As with the global model, we observe dimensional collapse in the representations produced by local models, which establishes the connection between the dimensional collapse of the global model and that of the local models. To further understand the dimensional collapse of local models, we analyze the gradient flow dynamics of local training in Sec. 3.3. Interestingly, we show theoretically that heterogeneous data drive the weight matrices of the local models to be biased towards being low-rank, which further results in representation dimensional collapse. Inspired by the observation that the dimensional collapse of the global model stems from local models, we consider mitigating dimensional collapse during local training in Sec. 4. In particular, we propose a novel federated learning method termed FedDecorr, which adds a regularization term during local training that encourages the Frobenius norm of the correlation matrix of representations to be small. We show theoretically and empirically that this regularization term can effectively mitigate dimensional collapse (see Fig. 1(d) for an example). Next, in Sec. 5, through extensive experiments on standard benchmark datasets including CIFAR10, CIFAR100, and Tiny ImageNet, we show that FedDecorr consistently improves over baseline federated learning methods. In addition, we find that FedDecorr yields more dramatic improvements in more challenging federated learning setups, such as stronger heterogeneity or a larger number of clients. Lastly, FedDecorr has extremely low computational overhead and can be built on top of any existing federated learning baseline method, which makes it widely applicable. Our contributions are summarized as follows. First, we discover through experiments that stronger data heterogeneity in federated learning leads to greater dimensional collapse for global and local models. Second, we develop a theoretical understanding of the dynamics behind our empirical discovery that connects data heterogeneity and dimensional collapse. Third, based on the motivation of mitigating dimensional collapse, we propose a novel method called FedDecorr, which yields consistent improvements while being implementation-friendly and computationally efficient. |
| Researcher Affiliation | Collaboration | Yujun Shi¹, Jian Liang³, Wenqing Zhang², Vincent Y. F. Tan¹, Song Bai²; ¹National University of Singapore, ²ByteDance Inc., ³Institute of Automation, CAS |
| Pseudocode | Yes | The pseudocode of our method is provided in Appendix G. |
| Open Source Code | Yes | All source code has been released at https://github.com/bytedance/FedDecorr. |
| Open Datasets | Yes | Datasets: We adopt three popular benchmark datasets, namely CIFAR10, CIFAR100, and Tiny ImageNet. CIFAR10 and CIFAR100 both have 50,000 training samples and 10,000 test samples, and each image is 32×32. Tiny ImageNet contains 200 classes, with 100,000 training samples and 10,000 testing samples, and each image is 64×64. The method for generating each client's local data was introduced in Sec. 3.1. |
| Dataset Splits | No | The paper does not explicitly mention a validation set or its split, only training and test sample counts. |
| Hardware Specification | Yes | All results are produced with an NVIDIA Tesla V100 GPU. |
| Software Dependencies | No | The paper mentions using PyTorch-style pseudocode and basing their code on a previous work, but it does not specify version numbers for Python, PyTorch, or other relevant libraries. |
| Experiment Setup | Yes | We run 100 communication rounds for all experiments on the CIFAR10/100 datasets and 50 communication rounds on the Tiny ImageNet dataset. In each communication round, we conduct local training for 10 epochs using the SGD optimizer with a learning rate of 0.01, momentum of 0.9, and a batch size of 64. The weight decay is set to 10⁻⁵ for CIFAR10 and 10⁻⁴ for CIFAR100 and Tiny ImageNet. We apply the data augmentation of Cubuk et al. (2018) in all CIFAR100 and Tiny ImageNet experiments. The FedDecorr coefficient β (i.e., β in Eqn. (9)) is tuned to 0.1. |
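
The Research Type row describes detecting dimensional collapse through the singular values of the covariance matrix of representations. The following is a minimal NumPy sketch of that measurement, not code from the FedDecorr repository. The function names `covariance_spectrum` and `count_collapsed`, the relative tolerance, and the synthetic embeddings are all illustrative assumptions.

```python
import numpy as np


def covariance_spectrum(z: np.ndarray) -> np.ndarray:
    """Singular values (descending) of the covariance matrix of representations z, shape (N, d)."""
    # Center each representation dimension across the N samples.
    z_centered = z - z.mean(axis=0, keepdims=True)
    cov = z_centered.T @ z_centered / z.shape[0]
    # np.linalg.svd returns singular values sorted in descending order.
    return np.linalg.svd(cov, compute_uv=False)


def count_collapsed(singular_values: np.ndarray, rel_tol: float = 1e-6) -> int:
    """Hypothetical helper: number of singular values that are ~0 relative to the largest."""
    return int(np.sum(singular_values < rel_tol * singular_values[0]))


rng = np.random.default_rng(0)
n, d = 1000, 8
# Embeddings spread over all d dimensions.
full_rank = rng.standard_normal((n, d))
# Embeddings confined to a 2-dimensional subspace of R^d (a collapsed representation).
collapsed = rng.standard_normal((n, 2)) @ rng.standard_normal((2, d))

print(count_collapsed(covariance_spectrum(full_rank)))  # 0
print(count_collapsed(covariance_spectrum(collapsed)))  # 6
```

In the paper's setting, `z` would be the representations a global or local model produces on a batch of data. The observation that stronger heterogeneity drives more singular values towards zero corresponds to a growing count from a helper like this one.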
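
FedDecorr adds a local-training regularizer that keeps the Frobenius norm of the correlation matrix of representations small, weighted by β (0.1 in the Experiment Setup row). The NumPy sketch below makes several assumptions. It computes the correlation matrix from batch-standardized representations and scales the squared Frobenius norm by 1/d². The function name `fed_decorr_loss` and the `eps` stabilizer are hypothetical, and the exact form should be checked against Eqn. (9) and the PyTorch-style pseudocode in Appendix G.

```python
import numpy as np


def fed_decorr_loss(z: np.ndarray, eps: float = 1e-8) -> float:
    """Scaled squared Frobenius norm of the correlation matrix of a batch z, shape (N, d)."""
    n, d = z.shape
    # Standardize each representation dimension across the batch (zero mean, unit variance).
    z_std = (z - z.mean(axis=0)) / (z.std(axis=0) + eps)
    # d x d correlation matrix; its diagonal entries are (approximately) 1.
    corr = z_std.T @ z_std / n
    # Squared Frobenius norm divided by d^2 (assumed normalization).
    return float(np.sum(corr ** 2) / d ** 2)


rng = np.random.default_rng(0)
n, d = 5000, 4
decorrelated = rng.standard_normal((n, d))
# Every dimension is a scaled copy of one signal: fully collapsed to one direction.
collapsed = rng.standard_normal((n, 1)) @ rng.standard_normal((1, d))

print(fed_decorr_loss(decorrelated))  # close to 1/d = 0.25
print(fed_decorr_loss(collapsed))     # close to 1.0

beta = 0.1  # value reported in the Experiment Setup row
# A local objective would then be: task_loss + beta * fed_decorr_loss(z)
```

Under this normalization, the term ranges from 1/d to 1. Because every diagonal entry is 1, the sum of squared entries is at least d, which gives the lower bound 1/d when dimensions are uncorrelated. Because every correlation has magnitude at most 1, the sum is at most d², which gives the upper bound 1 when all dimensions are perfectly correlated. Minimizing the term therefore pushes representations away from a collapsed, low-dimensional configuration.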