From Sparse to Soft Mixtures of Experts

Authors: Joan Puigcerver, Carlos Riquelme Ruiz, Basil Mustafa, Neil Houlsby

ICLR 2024

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We pretrain our models on JFT-4B (Zhai et al., 2022a), a proprietary dataset that contains more than 4B images, covering 29k classes. During pretraining, we evaluate the models on two metrics: upstream validation precision-at-1 on JFT-4B, and ImageNet 10-shot accuracy. Finally, we provide the accuracy on the validation set of ImageNet-1k after finetuning on the training set of ImageNet-1k (1.3 million images) at 384 resolution. Figure 3: Train Pareto frontiers. Soft MoE dominates both ViTs (dense) and popular MoEs (Experts and Tokens Choice) on the training cost / performance Pareto frontier. Table 1: Models trained for longer durations (cooldown steps in parentheses). We study the effect of changing the number of slots and experts in the Sparse and Soft MoEs. We study the impact of the components of the Soft MoE routing layer by running the following ablations:
Researcher Affiliation | Industry | Joan Puigcerver (Google DeepMind), Carlos Riquelme (Google DeepMind), Basil Mustafa (Google DeepMind), Neil Houlsby (Google DeepMind)
Pseudocode | Yes |
def soft_moe_layer(X, Phi, experts):
    # Compute the dispatch and combine weights.
    logits = jnp.einsum('md,dnp->mnp', X, Phi)
    D = jax.nn.softmax(logits, axis=(0,))
    C = jax.nn.softmax(logits, axis=(1, 2))
    # The input slots are a weighted average of all the input tokens,
    # given by the dispatch weights.
    Xs = jnp.einsum('md,mnp->npd', X, D)
    # Apply the corresponding expert function to each input slot.
    Ys = jnp.stack([
        f_i(Xs[i, :, :]) for i, f_i in enumerate(experts)],
        axis=0)
    # The output tokens are a weighted average of all the output slots,
    # given by the combine weights.
    Y = jnp.einsum('npd,mnp->md', Ys, C)
    return Y
Algorithm 1: Simple JAX (Bradbury et al., 2018) implementation of a Soft MoE layer. Full code is available at https://github.com/google-research/vmoe.
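For illustration, a minimal usage sketch of the layer above (not from the paper): the toy shapes, the random Phi, and the placeholder expert functions are assumptions; in Soft MoE the experts are the MLP blocks of a ViT and Phi is a learned parameter.

import jax
import jax.numpy as jnp

m, d, n, p = 16, 8, 4, 2                      # tokens, token dim, experts, slots per expert (toy values)
key_x, key_phi = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_x, (m, d))          # input tokens
Phi = jax.random.normal(key_phi, (d, n, p))   # slot parameters (learned in practice)
experts = [jax.nn.gelu for _ in range(n)]     # placeholder experts; real ones are ViT MLP blocks
Y = soft_moe_layer(X, Phi, experts)           # n*p output slots are combined back into m tokens
print(Y.shape)                                # (16, 8): one output token per input token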
Open Source Code | Yes | Full code is available at https://github.com/google-research/vmoe.
Open Datasets | Yes | We pretrain our models on JFT-4B (Zhai et al., 2022a), a proprietary dataset that contains more than 4B images, covering 29k classes. During pretraining, we evaluate the models on two metrics: upstream validation precision-at-1 on JFT-4B, and ImageNet 10-shot accuracy. The latter is computed by freezing the model weights and replacing the head with a new one that is only trained on a dataset containing 10 images per class from ImageNet-1k (Deng et al., 2009). Finally, we provide the accuracy on the validation set of ImageNet-1k after finetuning on the training set of ImageNet-1k (1.3 million images) at 384 resolution. Finally, in Appendix F.1 we show that Soft MoEs also surpass vanilla ViT and the Experts Choice router when trained from scratch on the publicly available LAION-400M (Schuhmann et al., 2021).
Dataset Splits | Yes | During pretraining, we evaluate the models on two metrics: upstream validation precision-at-1 on JFT-4B, and ImageNet 10-shot accuracy. The latter is computed by freezing the model weights and replacing the head with a new one that is only trained on a dataset containing 10 images per class from ImageNet-1k (Deng et al., 2009). Finally, we provide the accuracy on the validation set of ImageNet-1k after finetuning on the training set of ImageNet-1k (1.3 million images) at 384 resolution.
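The 10-shot protocol quoted above (frozen backbone, new head trained on 10 images per class) could be sketched as follows. This is only an illustration under assumptions: the frozen_backbone feature extractor, the gradient-descent head training, and all hyperparameters are hypothetical, not the authors' evaluation code.

import jax
import jax.numpy as jnp

def ten_shot_accuracy(frozen_backbone, params, x_shots, y_shots, x_val, y_val,
                      num_classes, steps=500, lr=1e-2):
    # Backbone weights stay frozen; only a freshly initialized linear head is trained.
    train_feats = frozen_backbone(params, x_shots)   # (10 * num_classes, d)
    val_feats = frozen_backbone(params, x_val)
    head = (jnp.zeros((train_feats.shape[-1], num_classes)), jnp.zeros((num_classes,)))

    def loss_fn(head, feats, labels):
        w, b = head
        logits = feats @ w + b
        one_hot = jax.nn.one_hot(labels, num_classes)
        return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))

    grad_fn = jax.jit(jax.grad(loss_fn))
    for _ in range(steps):
        gw, gb = grad_fn(head, train_feats, y_shots)
        head = (head[0] - lr * gw, head[1] - lr * gb)

    preds = jnp.argmax(val_feats @ head[0] + head[1], axis=-1)
    return jnp.mean(preds == y_val)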
Hardware Specification | Yes | Distributing the model typically adds an overhead in the cost of the model, which is not captured by the time complexity analysis based on FLOPs that we derived above. In order to account for this difference, in all of our experiments we measure not only the FLOPs, but also the wall-clock time in TPUv3-chip-hours. The paper's tables also report "Total Training TPUv3-days" and "Evaluation time (TPUv3 ms/img)".
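As a rough illustration of reporting wall-clock device time rather than FLOPs alone, one could time a jitted training step and convert to chip-hours as below; the train_step, state, and batch names are hypothetical, and this is not the authors' profiling setup.

import time
import jax

def chip_hours(train_step, state, batch, num_steps, num_chips):
    # Run one step first so compilation time is excluded from the measurement.
    state = jax.block_until_ready(train_step(state, batch))
    start = time.perf_counter()
    for _ in range(num_steps):
        state = train_step(state, batch)
    jax.block_until_ready(state)           # wait for async dispatch to finish
    elapsed = time.perf_counter() - start
    # Total device time = wall-clock seconds x number of chips, converted to hours.
    return elapsed * num_chips / 3600.0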
Software Dependencies | No | Algorithm 1: Simple JAX (Bradbury et al., 2018) implementation of a Soft MoE layer. The paper mentions JAX but does not specify a version number for JAX itself or for Python.
Experiment Setup | Yes | We trained for 300k steps with batch size 4096, resolution 224, using a reciprocal square root learning rate schedule. All models were trained for 4M steps, except for H/14, which was trained for 2M steps for cost reasons. We replace the last half of the blocks in ViT S/16, B/16, L/16, and H/14 with Soft MoE layers with 128 experts, using one slot per expert. Therefore, we increase the cooldown from 50k steps to 500k steps.
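A reciprocal square-root schedule with a linear cooldown, as referenced in the quote above, might look like the sketch below; the linear warmup, peak value, and exact cooldown shape are assumptions and may differ from the schedule actually used in the paper.

import jax.numpy as jnp

def rsqrt_with_cooldown(step, peak_lr, warmup_steps, total_steps, cooldown_steps):
    # Linear warmup to peak_lr, then decay proportional to 1/sqrt(step),
    # then a linear cooldown to zero over the final cooldown_steps.
    step = jnp.asarray(step, jnp.float32)
    warmup = jnp.minimum(step / jnp.maximum(warmup_steps, 1), 1.0)
    decay = jnp.sqrt(warmup_steps / jnp.maximum(step, warmup_steps))
    lr = peak_lr * warmup * decay
    cooldown = jnp.clip((total_steps - step) / cooldown_steps, 0.0, 1.0)
    return jnp.where(step > total_steps - cooldown_steps, lr * cooldown, lr)

For example, a 300k-step run with a 50k-step cooldown would be rsqrt_with_cooldown(step, peak_lr, warmup_steps, 300_000, 50_000); the warmup length and peak learning rate are not stated in the quoted text and would have to be taken from the released code.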