Learning to Embed Distributions via Maximum Kernel Entropy
Authors: Oleksii Kachaiev, Stefano Recanatesi
NeurIPS 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We empirically demonstrate the performance of our method by performing classification tasks across multiple modalities, showing that the proposed method successfully performs unsupervised learning of a data-dependent distribution kernel. The experimental setup is divided into two phases: unsupervised pre-training and downstream classification using the learned kernel. |
| Researcher Affiliation | Academia | Oleksii Kachaiev, Dipartimento di Matematica, Università degli Studi di Genova, Genoa, Italy (oleksii.kachaiev@gmail.com); Stefano Recanatesi, Technion Israel Institute of Technology, Haifa, Israel, and Allen Institute for Neural Dynamics, Seattle, USA (stefano.recanatesi@gmail.com) |
| Pseudocode | Yes | In this section, we provide an example that illustrates the implementation of the proposed method using the PyTorch framework. Functions to compute the distribution kernel Gram matrix: `def pairwise_kernel(x, gamma1): ...` and `def distribution_kernel_gram(x, gamma1, gamma2): ...`; distribution kernel entropy estimator and the MKDE loss: `def distribution_kernel_entropy(K): ...` and `def mkde_loss(encoder, X, gamma1, gamma2): ...` (a hedged implementation sketch follows the table). |
| Open Source Code | Yes | The code for the implementation of the proposed loss is provided in Appendix E. |
| Open Datasets | Yes | For this study we used a dataset [56] where more than 100,000 cells are measured per patient (subject). Images: MNIST [12] and Fashion-MNIST [65] consist of 28 × 28 pixel grayscale images divided into 10 classes. Text: to assess our method's performance in a larger discrete support space, we utilized "20 Newsgroups" [27], a multi-class text classification dataset. |
| Dataset Splits | Yes | For each dataset, we select a hold-out validation subset with balanced classes, while the remainder of the dataset is utilized for unsupervised pre-training. A grid search with 5 splits (70/30) is conducted to optimize the strength of the squared l2 regularization penalty C, exploring 50 values over the log-spaced range {10^-7, ..., 10^5}. The best estimator is then applied to evaluate classification accuracy on the validation subset, which we report. |
| Hardware Specification | No | All experiments were performed on a single machine with 1 GPU and 6 CPUs. |
| Software Dependencies | No | In this section, we provide an example that illustrates the implementation of the proposed method using the PyTorch framework. |
| Experiment Setup | Yes | We use mini-batch Adam [22] with a fixed learning rate of 0.0005. We report mini-batch-based (rather than epoch-based) training dynamics, as our tasks do not require cycling over the entire dataset to converge to the optimal loss value. A grid search with 5 splits (70/30) is conducted to optimize the strength of the squared l2 regularization penalty C, exploring 50 values over the log-spaced range {10^-7, ..., 10^5} (a sketch of this grid search follows the table). |
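
The appendix snippet quoted in the Pseudocode row lists only function signatures. Below is a minimal sketch of how those functions could be filled in, assuming a Gaussian (RBF) point-level kernel with bandwidth `gamma1`, a Gaussian distribution-level kernel on the squared MMD with bandwidth `gamma2`, and the von Neumann entropy of the trace-normalized Gram matrix as the entropy estimator. These specific kernel and entropy choices are assumptions for illustration, not the authors' exact definitions from Appendix E.

```python
import torch

def pairwise_kernel(x, gamma1):
    # Point-level RBF Gram matrix over a set of points x of shape (n, d) -> (n, n).
    sq_dists = torch.cdist(x, x) ** 2
    return torch.exp(-gamma1 * sq_dists)

def distribution_kernel_gram(x, gamma1, gamma2):
    # x: (b, n, d) -- a batch of b empirical distributions with n points each.
    # Block-average the point-level Gram matrix to obtain mean-embedding inner
    # products, then apply a Gaussian kernel on the induced squared MMD.
    b, n, d = x.shape
    Kp = pairwise_kernel(x.reshape(b * n, d), gamma1)      # (b*n, b*n)
    M = Kp.reshape(b, n, b, n).mean(dim=(1, 3))            # (b, b) mean-embedding inner products
    diag = torch.diag(M)
    mmd_sq = diag[:, None] + diag[None, :] - 2.0 * M       # squared MMD between distribution pairs
    return torch.exp(-gamma2 * mmd_sq)

def distribution_kernel_entropy(K):
    # von Neumann entropy of the trace-normalized distribution Gram matrix.
    rho = K / torch.trace(K)
    eigvals = torch.linalg.eigvalsh(rho).clamp_min(1e-12)
    return -(eigvals * eigvals.log()).sum()

def mkde_loss(encoder, X, gamma1, gamma2):
    # Encode every point, build the distribution kernel over the encoded batch,
    # and maximize its entropy (i.e. minimize the negative entropy).
    b, n, d = X.shape
    Z = encoder(X.reshape(b * n, d)).reshape(b, n, -1)
    K = distribution_kernel_gram(Z, gamma1, gamma2)
    return -distribution_kernel_entropy(K)
```

With a simple point-wise encoder (e.g., an MLP), `mkde_loss` can be minimized with mini-batch Adam at the learning rate reported in the Experiment Setup row.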
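
The downstream grid search described in the Dataset Splits and Experiment Setup rows (50 log-spaced values of the squared l2 penalty C between 10^-7 and 10^5, over 5 random 70/30 splits) could be reproduced roughly as sketched below. The use of scikit-learn's `SVC` with a precomputed Gram matrix, `GridSearchCV`, and `ShuffleSplit`, as well as the function name `downstream_accuracy`, are assumptions; the quoted text does not name the downstream library or classifier.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.svm import SVC

def downstream_accuracy(K_train, y_train, K_val, y_val):
    # K_train: learned kernel among pre-training distributions, shape (n_train, n_train).
    # K_val: learned kernel between validation and pre-training distributions, shape (n_val, n_train).
    param_grid = {"C": np.logspace(-7, 5, 50)}                        # 50 log-spaced values in [1e-7, 1e5]
    splits = ShuffleSplit(n_splits=5, test_size=0.3, random_state=0)  # 5 random 70/30 splits
    search = GridSearchCV(SVC(kernel="precomputed"), param_grid, cv=splits)
    search.fit(K_train, y_train)
    # Accuracy of the best estimator on the held-out, class-balanced validation subset.
    return search.best_estimator_.score(K_val, y_val)
```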