FedBFPT: An Efficient Federated Learning Framework for Bert Further Pre-training

Authors: Xin'ao Wang, Huan Li, Ke Chen, Lidan Shou

IJCAI 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | Theoretical analysis is conducted to support the efficiency of FedBFPT, and experiments are conducted on corpora across domains such as medicine, biology, and computer science. Results indicate that FedBFPT achieves performance levels comparable to traditional FL methods while reducing computation and communication costs by 46.70% and 7.04%, respectively, even approaching the performance of centralized training models.
Researcher Affiliation | Academia | Xin'ao Wang, Huan Li, Ke Chen and Lidan Shou, Key Lab of Intelligent Computing Based Big Data of Zhejiang Province, Zhejiang University, Hangzhou, China, {wangxin.ao, lihuan.cs, chenk, should}@zju.edu.cn
Pseudocode | Yes | The whole process is formalized in Algorithm 1.
S1 Each client C_k fetches the parameters of the global model from the server and persists the fetched parameters in the local parameter pool. Simultaneously, an index number ℓ is received from the server, which indicates the specific ℓ-th T-layer to be trained and updated for the local model at C_k. The determination of the layer index number ℓ is to be detailed in Section 3.3.
S2 With the other layers, including the embedding layer, kept frozen, the ℓ-th T-layer and the output layer at each client C_k are trained using the local data D_k. Specifically, the MLM task [Devlin et al., 2019] is performed for a number e_k of epochs (cf. MLM_Train(D_k, P_k, ℓ, e_k) in Algorithm 1). How to train the ℓ-th T-layer in the small-capacity local model will be presented in Section 3.3.
S3 Each client C_k uploads the weights of the trained ℓ-th T-layer and the output layer to the server.
S4 The server merges the weights uploaded by all clients. There are many approaches for merging the weights in FL, such as FedAvg [McMahan et al., 2017], FedProx [Li et al., 2020a], and FedAdam [Reddi et al., 2020]. In our implementation, we choose the typical approach FedAvg, which obtains the merged weights W^(i) at the i-th FL training iteration as follows: W^(i) = (1/N) Σ_{k=1}^{N} W_k^(i), (1) where W_k^(i) denotes the weights uploaded by client C_k at the i-th iteration.
S5 The server updates the global model using the merged weights and determines the next index number ℓ.
S6 Resume to S1 if the current FL training iteration number i does not exceed the user-specified threshold I. Otherwise, we terminate the further pre-training process, and consequently, the further pre-trained global model is ready to be fine-tuned for downstream tasks.
To minimize computational expenses for clients with limited resources, we opt to train only a single T-layer and the output layer while keeping the remaining parts of the model frozen. However, choosing which T-layer to train (i.e., determining ℓ) and how to efficiently train on smaller-capacity local models (cf. step S2) are critical procedures. To resolve these, we propose the Progressive Learning with Sampled Deeper Layers (PL-SDL) method in Section 3.3.
Algorithm 1 FedBFPT
Input: datasets {D_1, ..., D_N}, global model G with initial model parameters W^(0), epoch threshold I, local MLM training epochs e_k, and layer index number ℓ.
Output: final model parameters W^(I) of G.
Initialize: parameter pool P_k of each client C_k; ℓ ← 0.
// FL training for I iterations
for i = 1 to I do
    // Parallel executions at clients
    for each client C_k do
        if i > 1 then update P_k with W^(i-1) fetched from the server end if
        // Train the ℓ-th T-layer and the output layer of the local model and obtain the weights W_k^(i)
        W_k^(i) ← MLM_Train(D_k, P_k, ℓ, e_k)
        upload W_k^(i) to the server
    end for
    // Merge clients' weights at the server
    W^(i) ← Merge(W_1^(i), ..., W_N^(i))
    determine the next layer index number ℓ
    send W^(i) and ℓ to each client
end for
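To make one FL iteration of Algorithm 1 concrete, below is a minimal PyTorch sketch, not the released implementation: the helper `client.mlm_train` (assumed to return only the trained ℓ-th T-layer and output-layer weights) is a hypothetical placeholder, and the averaging function mirrors Equation (1).

```python
# Minimal sketch of one FedBFPT training iteration (Algorithm 1).
# `client.mlm_train` is an assumed helper, not part of the released code.
import copy
import torch

def fedavg(client_weights):
    """Merge per-client state dicts by simple parameter averaging (Eq. 1)."""
    merged = copy.deepcopy(client_weights[0])
    for name in merged:
        stacked = torch.stack([w[name].float() for w in client_weights])
        merged[name] = stacked.mean(dim=0)
    return merged

def fedbfpt_round(server_model, clients, layer_idx, local_epochs):
    """One FL iteration: each client trains only the `layer_idx`-th T-layer
    and the output layer locally, then the server averages the uploads."""
    uploads = []
    for client in clients:
        # Client-side: fetch global weights, train the selected layer via MLM,
        # and return only the trained parameters (S1-S3).
        local_weights = client.mlm_train(server_model.state_dict(), layer_idx, local_epochs)
        uploads.append(local_weights)
    # Server-side: merge and update the global model (S4-S5).
    merged = fedavg(uploads)
    server_model.load_state_dict(merged, strict=False)  # only the trained layers are uploaded
    return server_model
```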
Open Source Code | Yes | The source code is released at https://github.com/Hanzhouu/FedBFPT.
Open Datasets | Yes | We use the S2ORC dataset [Lo et al., 2020], which contains many general-purpose corpora for NLP and text mining research over scientific papers, for MLM further pre-training, and then fine-tune the global BERT model on classification (CLS) and Named Entity Recognition (NER) tasks of the corresponding scientific domains. In particular, for fine-tuning we use the JNLPBA dataset [Collier and Kim, 2004] in the Biology domain for NER, the SciERC dataset [Luan et al., 2018] in the Computer Science domain for NER, and the RCT-20k dataset [Beltagy et al., 2019] in the Medicine domain for CLS.
Dataset Splits | Yes | To simulate the FL setting, we generate 6 clients and build a local BERT model for each client. The dataset for pre-training is evenly partitioned and then stored on these clients. The uniform and skewed distributions of local datasets are compared in Section 5.4. ... To create a more realistic simulation, we introduce a normal distribution to determine the sizes of the datasets assigned to clients. Each client receives a subset of the centralized dataset according to the proportions specified by the normal distribution.
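As a minimal illustration of the uniform versus skewed (normal-distribution) client splits described above, the sketch below partitions dataset indices across 6 clients. The mean and standard deviation of the normal distribution are assumptions for illustration only; the paper's exact sampling parameters are not quoted here.

```python
# Illustrative client partitioning: uniform split vs. normal-distribution proportions.
import numpy as np

def split_indices(num_samples, num_clients=6, skewed=False, seed=0):
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_samples)
    if skewed:
        # Draw client proportions from a normal distribution (assumed mean/std),
        # then normalize so they sum to 1.
        props = np.abs(rng.normal(loc=1.0, scale=0.5, size=num_clients))
        props = props / props.sum()
    else:
        # Uniform split: every client gets an equal share.
        props = np.full(num_clients, 1.0 / num_clients)
    cut_points = (np.cumsum(props)[:-1] * num_samples).astype(int)
    return np.split(indices, cut_points)  # list of per-client index arrays
```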
Hardware Specification | Yes | We use PyTorch 1.11 [Paszke et al., 2019] and an NVIDIA RTX A6000 with 49,140 MB of video memory for training.
Software Dependencies | Yes | We use PyTorch 1.11 [Paszke et al., 2019] and an NVIDIA RTX A6000 with 49,140 MB of video memory for training.
Experiment Setup | Yes | Corresponding to the costs defined in Section 3.1, we measure the computational cost as the time to perform one epoch of MLM training on the S2ORC domain subset with batch size 256 and learning rate 5 × 10⁻⁵, and the communication cost as the storage space occupied by the trained parameters per client.
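A hedged sketch of how these two cost metrics could be measured in PyTorch: wall-clock time for one MLM training epoch (computational cost) and the on-disk size of the parameters a client uploads (communication cost). `train_one_epoch` and `trained_params` are hypothetical placeholder names, not taken from the released code.

```python
# Sketch of the two cost measurements described in the experiment setup.
import os
import time
import torch

def epoch_time(train_one_epoch, *args, **kwargs):
    """Computational cost: seconds to perform one MLM training epoch."""
    start = time.perf_counter()
    train_one_epoch(*args, **kwargs)  # assumed user-supplied training function
    return time.perf_counter() - start

def upload_size_mb(trained_params, path="client_upload.pt"):
    """Communication cost: storage occupied by the trained parameters per client."""
    torch.save(trained_params, path)          # serialize only the uploaded weights
    size_mb = os.path.getsize(path) / (1024 ** 2)
    os.remove(path)
    return size_mb
```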