Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].

Leveraging Function Space Aggregation for Federated Learning at Scale

Authors: Nikita Dhawan, Nicole Elyse Mitchell, Zachary Charles, Zachary Garrett, Gintare Karolina Dziugaite

TMLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Our extensive evaluation includes domain-specific criteria as well as metrics specific to FL. We demonstrate settings in which FedFish outperforms FedAvg, especially as the amount of local training is varied. Image and language experiments with varying levels of client data heterogeneity show improved post-personalization performance of FedFish throughout training, when the global model is locally fine-tuned for a few steps by clients that were held out during training. This observation also holds when measuring transfer performance by drawing the evaluation clients from a shifted data distribution. For instance, in an experiment with federated pretraining on the large and heterogeneous C4 dataset, followed by few-shot personalization on Stack Overflow clients, FedFish improves upon FedAvg's next-token prediction performance by 5-7%, depending on the amount of personalization data available. We provide insight into these gains by assessing a measure of deviation between global and local models, coined the Client-Server Barrier. Finally, we discuss the impact of these methods and settings on the cost of communication between clients and the server.
Contributions. We formalize a function space perspective of federated learning to motivate a scalable algorithm, FedFish, which aims to match client input-output functions during aggregation. Via a synthetic example, we demonstrate that FedFish outperforms FedAvg as client data heterogeneity increases. We then investigate this performance at larger scales than have been explored by previous works.
Figure 1: Given two functions modeled over disjoint supports (left), a direct parameter average fails to represent either function well (center), while function space aggregation aims to preserve both functional relationships (right).
Our thorough empirical results show that FedFish allows for longer local client training compared to FedAvg.
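The Figure 1 intuition can be reproduced on a two-parameter toy problem. This sketch is hypothetical (the data, learning rate and helper names are our own) and, as a simplification, uses the closed-form diagonal Fisher of a unit-variance Gaussian likelihood (the sum of x_j² over a client's inputs) instead of the squared-gradient estimate in the pseudocode, because squared-error gradients vanish once a client fits its data exactly. It compares a size-weighted parameter average (FedAvg-style) with a Fisher-weighted average of the two local models.

```python
# Hypothetical toy setup mirroring Figure 1: two clients whose inputs live on
# disjoint coordinates of a linear model y = w[0]*x[0] + w[1]*x[1].

def predict(w, x):
    return sum(wj * xj for wj, xj in zip(w, x))

def local_sgd(w, data, lr=0.1, epochs=50):
    """Plain SGD on squared error, starting from a copy of the global weights."""
    w = list(w)
    for _ in range(epochs):
        for x, y in data:
            err = predict(w, x) - y
            w = [wj - lr * 2 * err * xj for wj, xj in zip(w, x)]
    return w

def gaussian_diag_fisher(data, dim):
    """Diagonal Fisher of a unit-variance Gaussian likelihood: sum of x_j**2.
    Simplification: stands in for the squared-gradient estimate of Algorithm 2."""
    return [sum(x[j] ** 2 for x, _ in data) for j in range(dim)]

def param_average(models, sizes):
    """Size-weighted parameter average (FedAvg-style)."""
    total = sum(sizes)
    return [sum(n * m[j] for m, n in zip(models, sizes)) / total
            for j in range(len(models[0]))]

def fisher_weighted_average(models, fishers, sizes):
    """Per-coordinate average weighted by client size times Fisher entry."""
    out = []
    for j in range(len(models[0])):
        den = sum(n * f[j] for f, n in zip(fishers, sizes))
        if den == 0.0:  # no client informs this coordinate: fall back to averaging
            out.append(param_average(models, sizes)[j])
        else:
            num = sum(n * f[j] * m[j] for m, f, n in zip(models, fishers, sizes))
            out.append(num / den)
    return out

client_a = [((1.0, 0.0), 2.0), ((2.0, 0.0), 4.0)]  # only uses x[0]
client_b = [((0.0, 1.0), -3.0)]                     # only uses x[1]
global_w = [0.0, 0.0]

local_a = local_sgd(global_w, client_a)
local_b = local_sgd(global_w, client_b)
sizes = [len(client_a), len(client_b)]
fishers = [gaussian_diag_fisher(client_a, 2), gaussian_diag_fisher(client_b, 2)]

avg_w = param_average([local_a, local_b], sizes)
fish_w = fisher_weighted_average([local_a, local_b], fishers, sizes)
print([round(v, 3) for v in avg_w])   # ≈ [1.333, -1.0]
print([round(v, 3) for v in fish_w])  # ≈ [2.0, -3.0]
```

Because client A's inputs never touch w[1], its Fisher entry for that coordinate is zero, so the Fisher-weighted average keeps client B's value there (and vice versa), whereas parameter averaging pulls both coordinates toward the other client's untouched value of zero.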
We find that the global models learned via FedFish have greater ability to be personalized via fine-tuning on the same or shifted data distributions, indicating they provide a better initialization for local training in each round. We propose to evaluate effects of aggregation via a Client-Server Barrier, leveraging the function space perspective to gain further insight into the observed results.
2 Federated Learning in the Function Space
4 Evaluation and Client-Server Barrier
5 Experiments
Using the criteria described in section 4, we now conduct a systematic empirical evaluation of FedFish in varied settings, compared to the best performing variant of FedAvg. We first demonstrate the advantage of FedFish as client data heterogeneity increases in a toy regression problem. We then assess its performance across settings in larger scale image and language benchmarks.
5.1 A Toy Regression Demonstration
Figure 2 shows a non-linear regression problem with two clients across which data is distributed with varying heterogeneity, including full (+), partial and no overlap. We plot the local functions learned by each client, as well as the global functions produced by aggregation via FedAvg and FedFish after one round. For completely homogeneous client data, FedAvg and FedFish fit similar functions. When there is partial overlap, FedAvg appears to reasonably retain predictions on one client's data while fitting the other poorly, whereas FedFish fits both datasets well. In the extreme case of completely disjoint supports, FedAvg fails to fit either client dataset, but FedFish matches the locally learned functions of both clients on their respective input data. The Client-Server Barrier (CSB) defined in eq. (5) is computed in terms of mean squared error for each client on its corresponding data. As shown by all the points above the x = y line in fig. 2 (bottom right), the CSB is lower for FedFish than FedAvg in each of these settings, with the gap widening as data heterogeneity increases. We hypothesize that accounting for the functions learned by local models confers this advantage upon FedFish.
5.2 Image Classification and Text Benchmarks
Datasets and architectures. We consider a variety of federated benchmarks for image classification (EMNIST (Cohen et al., 2017), CIFAR100 (Krizhevsky et al.)) and language modeling (Stack Overflow (Authors, 2019), CC-News (Hamborg et al., 2017) and C4 (Raffel et al., 2020)). In particular, C4 is a large-scale and highly heterogeneous dataset. For these domains, we use standard classifier and transformer architectures, respectively.
Performance metrics. We evaluate global performance, client personalization performance and the Client-Server Barrier (see section 4), using standard domain-relevant performance metrics: classification accuracy for images, and next-token prediction accuracy and perplexity for language modeling. Since C4 is a very large-scale dataset that may generally be used as a pretraining corpus, we evaluate its global model on held-out clients from C4 itself as well as on the shifted Stack Overflow and CC-News datasets. This tests the methods' transfer performance in addition to their adaptability to new clients that were not seen during training. For the C4 experiments, we also vary the amount of fine-tuning data used for personalization (25% or 50% of each held-out client's data) to assess few-shot performance. Additional details are reported in appendix A.3.
5.2.1 Effect of Local Training on Global Model Performance
We study the effect of local training on the global model's performance by varying the number of local training epochs that clients perform between rounds.
Since increasing the number of local epochs for a fixed number of rounds increases computational costs, we present our results separately with fixed compute (i.e., a fixed total number of training iterations) and with a fixed number of aggregation rounds. We provide a complete table of results covering all settings in table 3 of appendix A.4 and discuss representative experiments here. With fixed compute, fig. 3 (left) shows a decline in global accuracy as the number of local epochs increases for both FedAvg and FedFish, as expected from the client-drift phenomenon. However, within each setting and across datasets, FedFish outperforms FedAvg, declining more gracefully with more local training and converging to higher performance faster. Figure 4 similarly shows FedFish outperforming FedAvg in terms of classification accuracy and next-token prediction accuracy for the CIFAR100 (left) and Stack Overflow (right) datasets, respectively. In both of these settings, FedFish is much more robust to longer periods of local training than FedAvg, whose performance suffers as local training increases. In fig. 5 (left), we use the heterogeneous C4 dataset for federated training and evaluate the zero-shot performance of the resulting global model on unseen clients from C4 itself. The bottom row of the plot shows that the global model, without any personalization data, benefits considerably from increased local epochs during federated pretraining for both algorithms, with FedFish outperforming FedAvg.
5.2.2 Post-Personalization Performance
At scale, pretrained models are often personalized or fine-tuned for a small number of steps using local client data before deployment. Accordingly, we fine-tune the global models for a few steps on limited datapoints from held-out clients and evaluate the metrics discussed earlier. Consistent with the global performance reported above, we observe across tasks that FedFish yields models with higher post-personalization performance than those trained with FedAvg.
This is shown on EMNIST in fig. 3 (right), on CIFAR100 in fig. 4 (left), on Stack Overflow in fig. 4 (right), and on C4 in fig. 5 (left). Notably, personalization worsens the performance of the CIFAR100 model trained with FedAvg using 16 local epochs, while it substantially improves the CIFAR100 model trained with FedFish using the same configuration. By contrast, FedAvg with 16 local epochs on Stack Overflow improves dramatically with personalization, despite still underperforming all other models. Interestingly, in the case of C4, the amount of local training seems to impact personalization performance more than the choice of aggregation algorithm. While both FedFish models, trained with 1 or 16 local epochs, have higher zero-shot performance than either FedAvg model, the post-personalization performance of the 16-local-epoch FedAvg model surpasses that of FedFish with 1 local epoch as the amount of fine-tuning data increases. Overall, we find that models trained with FedFish using longer periods of local training tend to be more amenable to personalization than models trained with FedAvg.
5.2.3 Transfer Performance
Considering FedAvg and FedFish as methods of federated pretraining, we further evaluate the networks trained on C4 in terms of their transfer performance (fig. 5) on Stack Overflow (center) and CC-News (right). Here, performance is measured as next-token prediction accuracy; we similarly report the perplexity for each of these settings in appendix A.3.3. We vary the amount of data available for personalization between 0%, 25% and 50% of each held-out client dataset. The reported performance is always evaluated on the remaining, unseen 50% of the data. In each of these settings, we find that transfer performance benefits from longer local training for both methods, and FedFish yields better zero-shot, few-shot and post-personalization performance than FedAvg.
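The personalization protocol described here (0%, 25% or 50% of each held-out client's data used for fine-tuning, with evaluation always on the remaining 50%) can be sketched as a data split followed by an optional fine-tuning step. This is a hypothetical illustration, not the authors' code: the function names, the choice to draw personalization examples from the front of a client's data and the evaluation half from the back, and the caller-supplied fine_tune and evaluate callables are all assumptions. Combined with the at-least-4-examples filter described under Dataset Splits, the 25% setting always leaves at least one fine-tuning example.

```python
def split_for_personalization(examples, personalization_frac):
    """Hypothetical split of one held-out client's data.

    The final floor(n/2) examples are always reserved for evaluation; the
    personalization set is taken from the front, sized as a fraction
    (0.0, 0.25 or 0.5) of the whole client dataset, so the two never overlap.
    """
    if not 0.0 <= personalization_frac <= 0.5:
        raise ValueError("personalization_frac must lie in [0, 0.5]")
    n = len(examples)
    eval_start = n - n // 2
    n_personalize = int(personalization_frac * n)
    return examples[:n_personalize], examples[eval_start:]


def personalize_and_evaluate(global_model, examples, frac, fine_tune, evaluate):
    """Zero-shot when frac == 0; otherwise fine-tune on the personalization split.

    fine_tune(model, data) -> model and evaluate(model, data) -> metric are
    hypothetical callables supplied by the caller.
    """
    personalize, held_out = split_for_personalization(examples, frac)
    model = fine_tune(global_model, personalize) if personalize else global_model
    return evaluate(model, held_out)


for frac in (0.0, 0.25, 0.5):
    print(frac, split_for_personalization(list(range(8)), frac))
```

Fixing the evaluation half independently of the personalization fraction keeps zero-shot, 25% and 50% results comparable, since all three are scored on identical examples.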
We observe the largest gains in the case of federated pretraining on C4 followed by few-shot personalization on Stack Overflow clients, where FedFish improves upon FedAvg's next-token prediction performance by 5-7%, depending on the amount of personalization data available. These results are promising because they encourage longer local training, which can translate into parallelism and efficiency gains. Note that the results in fig. 5 correspond to a fixed number of federated rounds, R; we similarly report performance at round R/2 in table 5 of appendix A.4.
5.2.4 Client-Server Barrier
To gain more insight into the performance of FedFish, which differs from FedAvg only in the aggregation step, we measure the Client-Server Barrier (CSB) defined in eq. (5), using accuracy as the metric L, for different model checkpoints across rounds of federated training. We fix seeds to control cohort sampling in each federated round, so that the dataset iteration matches between the FedAvg and FedFish training runs. In fig. 6, we plot this quantity for FedFish on the x-axis and for FedAvg on the y-axis, using networks trained on Stack Overflow with 8 local epochs (top left), Stack Overflow with 16 local epochs (top right), CIFAR100 with 16 local epochs (bottom left) and C4 with 16 local epochs (bottom right). Points lying above or below the x = y line indicate that FedFish or FedAvg, respectively, achieves lower CSB in those rounds. We observe that the relative round-over-round aggregation performance of FedAvg and FedFish varies significantly across datasets and settings throughout training. For Stack Overflow, the CSB steadily increases with rounds of training, with FedFish achieving a lower barrier than FedAvg during later stages of federated training. This difference is more pronounced when training with 16 local epochs than with 8 local epochs.
In the case of C4, while the barrier generally increases with rounds of training, FedFish tends to have lower values in the early stages of training, whereas FedAvg obtains lower CSB in the final 20% of training rounds. This indicates connections to linear mode connectivity in later stages of training; we discuss this connection in appendix A.6 and leave deeper exploration to future work. Interestingly, CIFAR100 CSB values decrease with rounds of federated training, with FedFish achieving a lower barrier than FedAvg throughout. On further investigation, we find that the client and data splits on CIFAR100 are such that each local model achieves very high performance from the beginning of training and maintains that performance throughout. Hence, the reduction in CSB over rounds reflects the improvement of the global model as it closes its gap to the local models. This contrasts with the other datasets we present, where the local models often improve their performance more gradually.
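Equation (5) is not reproduced in this excerpt, so the sketch below assumes one plausible reading of the Client-Server Barrier: the client-weighted average gap between each local model and the aggregated global model, both scored on that client's own data, with the sign chosen so that a larger value means aggregation lost more, whether the metric is accuracy (higher is better, as in fig. 6) or mean squared error (lower is better, as in fig. 2). It also shows how the fig. 6 comparison reads: a round lies above the x = y line when FedAvg's barrier exceeds FedFish's. The function names and the weighting are hypothetical.

```python
def client_server_barrier(local_scores, global_scores, weights, higher_is_better=True):
    """ASSUMED form of the CSB: weighted mean of (local - global) performance,
    each client's local model and the global model scored on that client's data.
    Positive values mean the aggregated model underperforms the local models."""
    if not (len(local_scores) == len(global_scores) == len(weights)):
        raise ValueError("need one local score, one global score and one weight per client")
    sign = 1.0 if higher_is_better else -1.0
    total = sum(weights)
    gaps = (w * sign * (loc - glob)
            for loc, glob, w in zip(local_scores, global_scores, weights))
    return sum(gaps) / total


def rounds_fedfish_lower(csb_fedfish, csb_fedavg):
    """Rounds whose point (x = FedFish CSB, y = FedAvg CSB) lies above x = y."""
    return [r for r, (x, y) in enumerate(zip(csb_fedfish, csb_fedavg)) if y > x]


# Accuracy-based barrier for one hypothetical round with two equally sized clients.
print(client_server_barrier([0.9, 0.8], [0.7, 0.75], [100, 100]))
```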
Researcher Affiliation | Collaboration | Nikita Dhawan (University of Toronto and Vector Institute); Nicole Mitchell (Google Research); Zachary Charles (Google Research); Zachary Garrett (Google Research); Gintare Karolina Dziugaite (Google DeepMind and Mila Quebec AI Institute)
Pseudocode | Yes |
Algorithm 1 FedFish (SGD)
Require: rounds R, local epochs E, θ_G, client datasets D, global lr η_g, local lr η_c
for r ← 1 to R do
    Sample a cohort of clients, C
    for i ∈ C in parallel do
        Compute client weights, w_i = |D_i|
        θ_i ← θ_G
        Δθ_i, F_i ← FedLocalTrain(E, θ_i, D_i, η_c)
    end for
    θ_G ← θ_G − η_g (Σ_{i=1}^{N} w_i F_i^T Δθ_i) / (Σ_{i=1}^{N} w_i F_i)
end for
return θ_G

Algorithm 2 FedLocalTrain (SGD)
Require: E, θ_i, D_i, η_c
Model Delta, Sum Fisher ← 0, 0
for e ← 1 to E do
    for b ∈ D_i do
        g ← ∇_θ L(θ_i, b)
        θ_i ← θ_i − η_c g
        Model Delta ← Model Delta + η_c g
    end for
end for
for b ∈ D_i do
    g ← ∇_θ L(θ_i, b)
    Sum Fisher ← Sum Fisher + g²
end for
return Model Delta, Sum Fisher
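A minimal pure-Python sketch of Algorithms 1 and 2, assuming a linear model trained with squared error. squared_error_grad, the toy clients and the eps term are hypothetical additions, not from the paper. Client sampling is left to the caller, clients are processed sequentially rather than in parallel, and the diagonal Fisher is stored as a vector of per-parameter sums of squared minibatch gradients (as written in Algorithm 2), so F_i^T Δθ_i and the division in the server update become elementwise operations.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Batch = Sequence[Tuple[Sequence[float], float]]
GradFn = Callable[[Vector, Batch], Vector]


def fed_local_train(E: int, theta: Vector, batches: Sequence[Batch],
                    lr_c: float, grad_fn: GradFn) -> Tuple[Vector, Vector]:
    """Algorithm 2: E epochs of SGD, then one pass accumulating squared gradients."""
    theta = list(theta)  # theta_i <- theta_G (copy, so the global model is untouched)
    model_delta = [0.0] * len(theta)
    for _ in range(E):
        for b in batches:
            g = grad_fn(theta, b)
            theta = [t - lr_c * gj for t, gj in zip(theta, g)]
            model_delta = [d + lr_c * gj for d, gj in zip(model_delta, g)]
    sum_fisher = [0.0] * len(theta)
    for b in batches:  # evaluated at the final local parameters
        g = grad_fn(theta, b)
        sum_fisher = [f + gj * gj for f, gj in zip(sum_fisher, g)]
    return model_delta, sum_fisher


def fedfish_round(theta_g: Vector, cohort: Sequence[Sequence[Batch]], E: int,
                  lr_g: float, lr_c: float, grad_fn: GradFn, eps: float = 1e-12) -> Vector:
    """One round of Algorithm 1 over an already-sampled cohort of clients."""
    num = [0.0] * len(theta_g)
    den = [0.0] * len(theta_g)
    for batches in cohort:
        w_i = sum(len(b) for b in batches)  # w_i = |D_i|
        delta, fisher = fed_local_train(E, theta_g, batches, lr_c, grad_fn)
        num = [n + w_i * f * d for n, f, d in zip(num, fisher, delta)]
        den = [m + w_i * f for m, f in zip(den, fisher)]
    # eps (our assumption) avoids 0/0 on coordinates that no client's gradient touches.
    return [t - lr_g * n / (m + eps) for t, n, m in zip(theta_g, num, den)]


def squared_error_grad(theta: Vector, batch: Batch) -> Vector:
    """Hypothetical example loss: mean squared error of a linear model."""
    g = [0.0] * len(theta)
    for x, y in batch:
        err = sum(t * xj for t, xj in zip(theta, x)) - y
        g = [gj + 2 * err * xj / len(batch) for gj, xj in zip(g, x)]
    return g


client_a = [[((1.0, 0.0), 2.0)], [((2.0, 0.0), 4.0)]]  # batches of size 1; uses x[0] only
client_b = [[((0.0, 1.0), -3.0)]]                      # uses x[1] only
theta = [0.0, 0.0]
for _ in range(30):
    theta = fedfish_round(theta, [client_a, client_b], E=2, lr_g=1.0,
                          lr_c=0.1, grad_fn=squared_error_grad)
print([round(t, 3) for t in theta])  # ≈ [2.0, -3.0]
```

Since Model Delta accumulates η_c·g, it equals θ_G minus the final local parameters, which is why the server step subtracts the Fisher-weighted delta; with η_g = 1 and a single informative client per coordinate, each coordinate moves to that client's local value.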
Open Source Code | No | No explicit statement or link for the authors' own source code for the methodology described in this paper was found.
Open Datasets | Yes | Each of these datasets is publicly available. EMNIST is licensed under Standard Reference Data by NIST. CIFAR100 is published by the authors. Stack Overflow is licensed under the Creative Commons Attribution-ShareAlike 3.0 Unported License. C4 and CC-News are hosted by commoncrawl.org, and we access both through Hugging Face datasets.
Dataset Splits | Yes | For EMNIST, we partition the handwritten characters according to their author, as proposed by Caldas et al. (2018). For Stack Overflow, posts on the eponymous web forum are partitioned by their author as well. For CIFAR100, we partition the examples over 100 clients in a heterogeneous fashion using the two-level latent Dirichlet allocation scheme proposed by Reddi et al. (2020). For CC-News and C4, we use Dataset Grouper (Charles et al., 2023) to partition web-crawled examples according to their base URL (e.g. nytimes.com). More details about data splits, architectures and hyperparameters are included in appendix A.3. ... The data is split into train, test and validation: train client examples are from before 2018-01-01 UTC, test client examples are from after 2018-01-01 UTC, and validation clients are held out from both train and test splits. ... We further measure few-shot performance of the C4 base model by conducting a personalization evaluation on 25% of held-out client data for each dataset of interest (C4, CC-News, Stack Overflow). Because of this specific personalization evaluation, we filter C4 evaluation datasets to have at least 4 examples per client.
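A minimal sketch of the base-URL grouping with the minimum-examples filter, plus the Stack Overflow date cutoff. It does not use Dataset Grouper's API; it assumes each example is a dict with hypothetical "url" and "time" fields, that "base URL" means the hostname with any leading "www." removed, and that an example timestamped exactly at the cutoff belongs to the test split. Per the text above, the minimum-size filter applies to C4 evaluation datasets.

```python
from collections import defaultdict
from datetime import datetime, timezone
from urllib.parse import urlparse

CUTOFF = datetime(2018, 1, 1, tzinfo=timezone.utc)  # Stack Overflow train/test boundary


def base_url(url):
    """Hostname with a leading 'www.' stripped (assumed meaning of 'base URL')."""
    host = urlparse(url).hostname or ""
    return host[4:] if host.startswith("www.") else host


def group_by_base_url(examples, min_examples=4):
    """Map each base URL (client) to its examples; drop clients below min_examples."""
    clients = defaultdict(list)
    for ex in examples:
        clients[base_url(ex["url"])].append(ex)
    return {k: v for k, v in clients.items() if len(v) >= min_examples}


def temporal_split(client_examples):
    """Split one client's examples by timestamp around CUTOFF (ties go to test)."""
    train = [e for e in client_examples if e["time"] < CUTOFF]
    test = [e for e in client_examples if e["time"] >= CUTOFF]
    return train, test
```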
Hardware Specification | Yes | We run image-classification experiments on a TPU Pod slice consisting of 4 TPU v2 chips in a 2x2 topology, interconnected on a single machine. Each TPU v2 chip contains two Tensor Cores, 16 GiB of high-bandwidth memory, and 90.75 GiB RAM. We run language-modeling experiments on a TPU Pod slice consisting of 16 TPU v3 chips in a 4x4 topology, configured to use a multi-machine inter-chip interconnect mesh. Each TPU v3 chip contains two Tensor Cores, 32 GiB of high-bandwidth memory, 87.25 GiB RAM, 900 GBps bandwidth, and 123 teraflops peak compute.
Software Dependencies | No | No specific software dependencies with version numbers were explicitly mentioned.
Experiment Setup | Yes | Hyperparameter Tuning. We fix hyperparameters such as the number of clients per round, number of training rounds, local batch size, maximum dataset size for any client and sequence length for language models to reasonable values based on previous literature. For local and global learning rates, we conducted a grid search over [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1] and chose the best-performing values. Final hyperparameter configurations used to obtain the results in the paper are listed in table 2. These are held consistent for FedAvg and FedFish experiments.
Table 2: Final hyperparameter configurations for all datasets.
Dataset | EMNIST | CIFAR100 | Stack Overflow | C4
Number of clients per round | 64 | 64 | 16 | 8
Number of training rounds | 1500 | 1500 | 1500 | 10000
Batch size for local training | 10 | 25 | 4 | 4
Max client dataset size per round | 100 | 100 | 16 | 16
Sequence length (training) | n/a | n/a | 128 | 1024
Sequence length (personalization and evaluation) | n/a | n/a | 128 | 128
Number of personalization epochs | 1 | 4 | 4 | 4
Global learning rate | 1e-3 | 1e-3 | 1e-2 | 1e-2
Local learning rate | 1e-3 | 1e-3 | 1e-3 | 1e-3
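The learning-rate search can be sketched as an exhaustive 7 x 7 grid over local and global rates. train_and_eval is a hypothetical caller-supplied function that trains with the given rates and returns a validation metric where higher is better; the excerpt does not state which metric defines "best performing", so that choice is an assumption.

```python
from itertools import product

LEARNING_RATES = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2, 1e-1]


def grid_search(train_and_eval):
    """Return (best_score, local_lr, global_lr) over every (local, global) pair.

    train_and_eval(local_lr, global_lr) -> float is hypothetical; higher is better.
    Ties keep the first pair encountered.
    """
    best = None
    for local_lr, global_lr in product(LEARNING_RATES, repeat=2):
        score = train_and_eval(local_lr, global_lr)
        if best is None or score > best[0]:
            best = (score, local_lr, global_lr)
    return best
```

Each of the 49 configurations requires a full federated training run, which is why the remaining hyperparameters are fixed from prior literature rather than searched.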