Correlated Quantization for Distributed Mean Estimation and Optimization

Authors: Ananda Theertha Suresh, Ziteng Sun, Jae Ro, Felix Yu

ICML 2022

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | "Experimental results show that our proposed algorithm outperforms existing mean estimation protocols on a diverse set of tasks. We demonstrate that the proposed algorithm outperforms existing baselines on several distributed tasks."
Researcher Affiliation | Industry | "Google Research, New York."
Pseudocode | Yes | "Algorithm 1 ONEDIMONEBITCQ. Input: x_1, x_2, ..., x_n, l, r. Generate π, a random permutation of {0, 1, 2, ..., n − 1}." (A sketch of this quantizer appears after the table.)
Open Source Code | Yes | "We implement all algorithms and experiments using the open-source JAX (Bradbury et al., 2018) and FedJAX (Ro et al., 2021) libraries." Code: https://github.com/google-research/google-research/tree/master/correlated_compression
Open Datasets | Yes | "We also compare quantizers on the distributed mean estimation task for the MNIST (d = 784) dataset distributed over 100 clients. We use the image recognition task for the Federated MNIST dataset (Caldas et al., 2018a) provided by TensorFlow Federated (Bonawitz et al., 2019)."
Dataset Splits | No | The paper does not explicitly provide dataset split information (exact percentages, sample counts, citations to predefined splits, or a detailed splitting methodology).
Hardware Specification | No | The paper does not specify the hardware used for its experiments (no GPU/CPU models, processor speeds, or memory amounts).
Software Dependencies | No | "We implement all algorithms and experiments using the open-source JAX (Bradbury et al., 2018) and FedJAX (Ro et al., 2021) libraries" (library names only; no versions or full dependency list are given).
Experiment Setup | Yes | "We first fix the number of clients n to be 100, k = 2, and vary σ_md. We then fix σ_md = 0.01, n = 100 and vary k. Finally, we fix σ_md = 0.01, k = 2 and vary n. The experiments are averaged over ten runs for statistical consistency. We use 2-level quantization (one bit) for all the algorithms, except TernGrad, which uses 3 levels and hence requires log_2(3) bits per coordinate per client." (A sketch of this sweep appears after the table.)
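
The pseudocode row quotes only the opening lines of Algorithm 1 (ONEDIMONEBITCQ). The snippet below is a minimal NumPy sketch of one plausible reading of that quantizer: each client rounds its scalar to one of the endpoints l or r, but the independent uniforms of ordinary stochastic rounding are replaced by stratified uniforms built from the shared random permutation π named in the quote. The function name, rounding rule, and unbiasedness comment are reconstructions for illustration, not the authors' JAX implementation; the exact protocol is Algorithm 1 in the paper.

```python
import numpy as np

def one_dim_one_bit_cq(x, l, r, rng=None):
    """Sketch of a one-dimensional, one-bit correlated quantizer.

    Each client i rounds x[i] to either l or r. Instead of independent
    stochastic rounding, clients use stratified uniforms derived from a
    shared random permutation, so their rounding errors tend to cancel.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=float)
    n = x.shape[0]

    # Shared randomness: a random permutation of {0, 1, ..., n-1}; in a
    # distributed setting the server would share pi[i] with client i.
    pi = rng.permutation(n)

    # Client i draws U_i uniformly from the stratum [pi[i]/n, (pi[i]+1)/n),
    # so marginally U_i ~ Uniform[0, 1) and the quantizer stays unbiased.
    u = (pi + rng.uniform(size=n)) / n

    # Stochastic-rounding rule: round up to r with probability
    # p_i = (x_i - l) / (r - l), down to l otherwise.
    p = (x - l) / (r - l)
    return np.where(u < p, r, l)

# Usage: estimate the mean of clients' scalars from their one-bit messages.
rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=100)
y = one_dim_one_bit_cq(x, l=0.0, r=1.0, rng=rng)
print(abs(y.mean() - x.mean()))  # often smaller than with independent rounding
```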
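
The experiment-setup row describes three one-at-a-time sweeps over the synthetic-data parameters. A hypothetical way to write that down as configuration is sketched below; the field names (num_clients, num_levels, sigma_md, num_runs) and the `vary` marker are assumptions made here, and the actual value grids that were swept are not stated in the quote.

```python
# Hypothetical encoding of the three synthetic sweeps described above.
# Only the fixed values (n = 100, k = 2, sigma_md = 0.01, 10 runs) come from
# the quoted text; all field names and the `vary` convention are assumptions.
BASE = dict(num_clients=100, num_levels=2, sigma_md=0.01, num_runs=10)

SWEEPS = [
    dict(BASE, vary="sigma_md"),     # fix n = 100, k = 2, vary sigma_md
    dict(BASE, vary="num_levels"),   # fix sigma_md = 0.01, n = 100, vary k
    dict(BASE, vary="num_clients"),  # fix sigma_md = 0.01, k = 2, vary n
]
```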