Fast, Differentiable and Sparse Top-k: a Convex Analysis Perspective
Authors: Michael Eli Sander, Joan Puigcerver, Josip Djolonga, Gabriel Peyré, Mathieu Blondel
ICML 2023 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In Section 6, we chose to focus on three applications of our operators. First, we use them to prune weights in a multilayer perceptron during training and show that they lead to better accuracy than with a hard top-k. Second, we define top-k losses to fine-tune vision transformers (ViTs) and obtain better top-k accuracy than with the cross-entropy loss. Finally, we use our operators as a router in vision mixture of experts. |
| Researcher Affiliation | Collaboration | 1École Normale Supérieure 2Google Research, Brain Team 3CNRS. |
| Pseudocode | Yes | Algorithm 1 (Dykstra's projection algorithm) and Algorithm 1 (Pool Adjacent Violators, PAV). |
| Open Source Code | Yes | Our JAX (Bradbury et al., 2018) implementation is available at the following URL. |
| Open Datasets | Yes | Results on MNIST are displayed in Figure 5. ... finetune a ViT-B/16 (Dosovitskiy et al., 2020) pretrained on the ImageNet-21k dataset on CIFAR-100 (Krizhevsky et al., 2009). ... the JFT-300M dataset (Sun et al., 2017). |
| Dataset Splits | No | The paper mentions using standard datasets like MNIST, CIFAR-100, and JFT-300M, but does not explicitly provide specific training/validation/test split percentages, sample counts, or citations to predefined splits within the text. |
| Hardware Specification | Yes | In terms of hardware, we use 8 TPUs. We trained both models on TPUv2-128 devices. |
| Software Dependencies | No | The paper mentions 'JAX (Bradbury et al., 2018)' but does not specify a version number for it or any other key software dependencies. |
| Experiment Setup | Yes | We train the MLP using SGD with a batch size of 128 and a constant learning rate of 10^-2. We trained the model for 30 epochs. ... warmup phase until the learning rate reaches 3 × 10^-3. The model is trained for 100 steps using a cosine learning rate scheduler. Specifically, we train for 7 epochs using a batch size of 4,096. We use the Adam optimizer (β1 = 0.9, β2 = 0.999), with a peak learning rate of 10^-3, warmed up for 10,000 steps and followed by linear decay. We use mild data augmentations (random cropping and horizontal flipping) and weight decay of 10^-1 in all parameters as means of regularization. |
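
The "Research Type" row mentions pruning MLP weights with a hard top-k during training. For reference only, here is a minimal sketch of a hard top-k magnitude mask of the kind that baseline suggests; the function name and the magnitude-based selection are illustrative assumptions, and this is not the differentiable, sparse operator proposed in the paper.

```python
import numpy as np

def hard_topk_mask(w, k):
    # Keep the k largest-magnitude entries of w and zero out the rest.
    # Hypothetical helper for illustration; not the paper's relaxed top-k operator.
    idx = np.argsort(np.abs(w).ravel())[::-1][:k]
    mask = np.zeros(w.size)
    mask[idx] = 1.0
    return mask.reshape(w.shape)

w = np.array([0.3, -1.2, 0.05, 0.7])
print(hard_topk_mask(w, k=2))  # [0. 1. 0. 1.]
```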
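The "Pseudocode" row refers to a Pool Adjacent Violators (PAV) routine. The sketch below is a generic textbook PAV for least-squares isotonic regression onto non-decreasing sequences, included only to convey the pooling idea; the paper's Algorithm 1 solves a specific isotonic problem tied to its top-k formulation and may differ in its inputs and pooling rule.

```python
import numpy as np

def pav(y):
    # Project y onto the set of non-decreasing sequences in the least-squares sense.
    # Generic textbook version, not a reproduction of the paper's Algorithm 1.
    y = np.asarray(y, dtype=float)
    means, sizes = [], []  # each block stores (mean value, block size)
    for v in y:
        means.append(v)
        sizes.append(1)
        # Merge the last two blocks while they violate monotonicity.
        while len(means) > 1 and means[-2] > means[-1]:
            m2, s2 = means.pop(), sizes.pop()
            m1, s1 = means.pop(), sizes.pop()
            s = s1 + s2
            means.append((s1 * m1 + s2 * m2) / s)
            sizes.append(s)
    return np.concatenate([np.full(s, m) for m, s in zip(means, sizes)])

print(pav([3.0, 1.0, 2.0, 5.0]))  # [2. 2. 2. 5.]
```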
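The "Experiment Setup" row describes Adam with a peak learning rate of 10^-3, 10,000 warmup steps, linear decay, and weight decay of 10^-1. A minimal optax sketch of such a schedule is shown below; the total step count is an assumed placeholder, and using decoupled weight decay (adamw) rather than plain Adam plus L2 regularization is an assumption not confirmed by the summary.

```python
import optax

TOTAL_STEPS = 100_000   # assumed placeholder; the summary does not quote the total step count
WARMUP_STEPS = 10_000
PEAK_LR = 1e-3

# Linear warmup to the peak learning rate, then linear decay to zero.
schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_value=0.0, end_value=PEAK_LR,
                              transition_steps=WARMUP_STEPS),
        optax.linear_schedule(init_value=PEAK_LR, end_value=0.0,
                              transition_steps=TOTAL_STEPS - WARMUP_STEPS),
    ],
    boundaries=[WARMUP_STEPS],
)

# adamw applies decoupled weight decay to all parameters by default.
optimizer = optax.adamw(learning_rate=schedule, b1=0.9, b2=0.999, weight_decay=1e-1)
```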