Cold Analysis of Rao-Blackwellized Straight-Through Gumbel-Softmax Gradient Estimator

Authors: Alexander Shekhovtsov

ICML 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We demonstrate that the simple ST ZGR family of estimators practically dominates the whole GR family in the bias-variance trade-off while also outperforming SOTA unbiased estimators. We compare ZGR with Gumbel-Softmax (GS), Straight-Through Gumbel-Softmax (GS-ST) (Jang et al., 2017), Gumbel-Rao with MC samples (GR-MC) (Paulus et al., 2021), and the ST estimator (4). We also compare to REINFORCE with the leave-one-out baseline (Kool et al., 2019) using M ≥ 2 inner samples, denoted RF(M), which is a strong baseline amongst unbiased estimators. In some tests we include the unbiased ARSM (Yin et al., 2019), which requires more computation than RF(4) but performs worse. See Appendix B.2 for implementation details.
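For context, the REINFORCE estimator with a leave-one-out baseline, RF(M), admits a compact sampling-based implementation. The following is only a hedged sketch under assumed conventions; the function name rf_loo_grad and the loss_fn callback are hypothetical illustrations, not the authors' code.

    import torch

    def rf_loo_grad(logits: torch.Tensor, loss_fn, M: int = 2) -> torch.Tensor:
        # REINFORCE with a leave-one-out baseline over M inner samples (sketch).
        # `logits` must have requires_grad=True; loss_fn maps sampled indices to losses.
        dist = torch.distributions.Categorical(logits=logits)
        idx = dist.sample((M,))                                      # [M, *batch]
        losses = torch.stack([loss_fn(i) for i in idx])              # [M, *batch]
        baseline = (losses.sum(0, keepdim=True) - losses) / (M - 1)  # mean of the other M-1 losses
        logp = dist.log_prob(idx)                                    # [M, *batch], differentiable
        surrogate = ((losses - baseline).detach() * logp).mean(0)
        return torch.autograd.grad(surrogate.sum(), logits)[0]       # estimate of d E[loss] / d logits

    # toy usage with an arbitrary loss on the sampled category index
    logits = torch.randn(8, 10, requires_grad=True)
    g = rf_loo_grad(logits, loss_fn=lambda i: i.float(), M=4)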
Researcher Affiliation | Academia | Department of Cybernetics, Czech Technical University in Prague, Czech Republic. Correspondence to: Alexander Shekhovtsov <shekhole@fel.cvut.cz>.
Pseudocode | Yes | The paper provides a PyTorch listing of the ZGR estimator:

    import torch
    from torch import Tensor, distributions
    from torch.nn import functional

    def ZGR(p: Tensor) -> Tensor:
        """Return a categorical sample from p [*, C] (over axis=-1) as a one-hot
        vector, with the ZGR gradient attached."""
        index = distributions.Categorical(probs=p).sample()
        x = functional.one_hot(index, num_classes=p.shape[-1]).to(p)
        logpx = p.log().gather(-1, index.unsqueeze(-1))   # log p(x), shape [*, 1]
        dx_ST = p                                         # straight-through part
        dx_RE = (x - p.detach()) * logpx                  # REINFORCE part
        dx = (dx_ST + dx_RE) / 2                          # ZGR surrogate
        return x + (dx - dx.detach())                     # value of x with backprop through dx
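A small usage sketch (not from the paper; the toy loss is arbitrary) showing that gradients flow back to the logits through the ZGR surrogate:

    logits = torch.randn(8, 10, requires_grad=True)   # batch of 8, C = 10 categories
    p = torch.softmax(logits, dim=-1)
    x = ZGR(p)                                        # one-hot samples [8, 10] with ZGR gradient
    loss = (x * torch.arange(10.0)).sum()             # any differentiable function of x
    loss.backward()                                   # gradient reaches `logits` via the surrogate
    print(logits.grad.shape)                          # torch.Size([8, 10])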
Open Source Code | Yes | Our implementation of the described experiments is available for research purposes at https://gitlab.com/shekhovt/zgr.
Open Datasets | Yes | We use MNIST data with a fixed binarization (Yin et al., 2019) and Omniglot data with dynamic binarization (Burda et al., 2016; Dong et al., 2021). The down-scaled dataset published by Burda et al. (2016) was used, the same as in the public implementation of Dong et al. (2021). It contains about 24000 training images, which were split into training (90%) and validation (10%) parts; we are currently not using the validation part.
Dataset Splits | Yes | In quantized training we use the MNIST and Fashion MNIST datasets. Each contains 60000 training and 10000 test images. We used 54000 images for training and 6000 for validation. The Omniglot data contains about 24000 training images, which were split into training (90%) and validation (10%) parts; the validation part is currently not used.
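As an illustration only (assuming torchvision and a random split; the actual split procedure and the fixed binarization used in the paper may differ), the 54000/6000 MNIST train/validation split could look like:

    import torch
    from torch.utils.data import random_split
    from torchvision import datasets, transforms

    full_train = datasets.MNIST("data", train=True, download=True,
                                transform=transforms.ToTensor())    # 60000 training images
    train_set, val_set = random_split(full_train, [54000, 6000],
                                      generator=torch.Generator().manual_seed(0))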
Hardware Specification | Yes | Time [ms] of a forward-backward pass per batch on GPU (Nvidia Tesla P100).
Software Dependencies | Yes | Gumbel-Softmax (GS) and Straight-Through Gumbel-Softmax (GS-ST) (Jang et al., 2017) are shipped with PyTorch. For Gumbel-Rao with MC samples (GR-MC) we adopted the public reimplementation by nshepperd, which is parallel over MC samples. Figure 5: Time [ms] of a forward-backward pass per batch on GPU (Nvidia Tesla P100). ... The time is measured after optimizing out Python and C++ calling overheads with CUDA Graphs in PyTorch 1.13.
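For reference, the PyTorch-shipped GS and GS-ST estimators are exposed through torch.nn.functional.gumbel_softmax; a minimal sketch of both modes (the temperature value is an arbitrary placeholder):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(8, 10, requires_grad=True)
    y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False)  # Gumbel-Softmax (GS): relaxed sample
    y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True)   # GS-ST: one-hot forward, relaxed backward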
Experiment Setup | Yes | Following Dong et al. (2021) we train with Adam with learning rate 10^-4 using batch size 50. Furthermore, we tried to match the training time to that of Dong et al. (2021). For MNIST we perform 500 epochs, and for Omniglot-28-D we perform 1000 epochs, roughly equivalent in both cases to their 500K iterations with batch size 50. We train a convolutional network (32C5-MP2-64C5-MP2-512FC-10FC) on Fashion MNIST. ... All methods are applied with the Adam optimizer for 200 epochs. For every method we select the best validation performance with a grid search over the learning rate from {10^-3, 3.3*10^-4, 10^-4}. We used a step-wise learning rate schedule decreasing the learning rate by a factor of 10 at epochs 100 and 150. The whole procedure is repeated for 3 different initialization seeds and we report the mean test error over seeds and (max − min)/2 over seeds.
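A hedged sketch of the optimization setup described above (Adam, 200 epochs, batch size 50, step-wise 10x decay at epochs 100 and 150); the model, data loader, and chosen learning rate are placeholders, since the paper grid-searches the learning rate and uses the 32C5-MP2-64C5-MP2-512FC-10FC network:

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(784, 10)                 # placeholder for the convolutional network
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
    train_loader = [(torch.randn(50, 1, 28, 28), torch.randint(0, 10, (50,)))]  # dummy batch of size 50

    for epoch in range(200):
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(images.flatten(1)), labels)
            loss.backward()
            optimizer.step()
        scheduler.step()                             # learning rate drops 10x at epochs 100 and 150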