Supervised Tree-Wasserstein Distance

Authors: Yuki Takezawa, Ryoma Sato, Makoto Yamada

ICML 2021

Reproducibility Assessment (variable, result, and supporting LLM response):
Research Type: Experimental
  “Experimentally, we show that the STW distance can be computed fast, and improves the accuracy of document classification tasks. Furthermore, the STW distance is formulated by matrix multiplications, runs on a GPU, and is suitable for batch processing. Therefore, we show that the STW distance is extremely efficient when comparing a large number of documents.”
Researcher Affiliation: Academia
  Yuki Takezawa, Ryoma Sato, and Makoto Yamada are all affiliated with Kyoto University and RIKEN AIP.
Pseudocode: Yes
  The paper provides Algorithm 1, an implementation of the STW distance using PyTorch syntax (a hedged sketch in this style is given after this list).
Open Source Code: No
  The paper states: “We implement S-WMD, and the TSW and STW distances in PyTorch.” However, it does not provide a direct link to the implementation or explicitly state that the code is open-sourced.
Open Datasets: Yes
  “We evaluate the following methods in document classification tasks on the synthetic and six real datasets following S-WMD in the test error rate of the k-nearest neighbors (kNN) and the time consumption: TWITTER, AMAZON, CLASSIC, BBCSPORT, OHSUMED, and REUTERS.” Datasets are split into train/test as in the previous works (Kusner et al., 2015; Huang et al., 2016).
Dataset Splits: Yes
  “To select the margin m, we use 20% of the training dataset for validation. We then train our model at a learning rate of 0.1 and a batch size of 100 for 30 epochs.” To avoid overfitting, the STW distance is evaluated with the parameters that achieved the lowest validation loss over those 30 epochs (a training-loop sketch of this protocol is given after this list).
Hardware Specification: Yes
  WMD (Sinkhorn), S-WMD, and the TSW and STW distances were evaluated on an NVIDIA Quadro RTX 8000, and WMD, Quadtree, and Flowtree on an Intel Xeon CPU E5-2690 v4 (2.60 GHz).
Software Dependencies: No
  The paper states: “We implement S-WMD, and the TSW and STW distances in PyTorch.” However, it does not specify the PyTorch version or any other software dependencies.
Experiment Setup: Yes
  “We initialize D1 such that the tree whose adjacency matrix is D1 is a perfect 5-ary tree of depth 5, and optimize Eq. (6) using Adam (Kingma & Ba, 2015) and LARS (You et al., 2017). After optimization, the deepest level of the tree is 5 or 6. To select the margin m, we use 20% of the training dataset for validation. We then train our model at a learning rate of 0.1 and a batch size of 100 for 30 epochs.” (A sketch of the 5-ary tree initialization is given after this list.)
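
To make the “matrix multiplications” formulation concrete, below is a minimal sketch of the classical tree-Wasserstein distance computed as batched matrix products in PyTorch. This is a reconstruction, not the authors' Algorithm 1: the subtree-indicator matrix `B`, the edge weights `w`, and the batch layout are assumptions made for illustration, and STW additionally learns the tree structure itself.

```python
import torch

def tree_wasserstein_batch(B: torch.Tensor, w: torch.Tensor,
                           X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Closed-form tree-Wasserstein distance via matrix multiplication.

    B : (n_nodes, n_leaves) 0/1 matrix; B[v, i] = 1 iff leaf i (a vocabulary
        word) lies in the subtree rooted at node v.
    w : (n_nodes,) non-negative weight of the edge from each node to its parent.
    X, Y : (batch, n_leaves) rows are normalized bag-of-words histograms.
    Returns a (batch,) tensor of distances.
    """
    diff = X - Y                         # (batch, n_leaves) mass differences
    flow = diff @ B.T                    # (batch, n_nodes) net mass crossing each edge
    return (flow.abs() * w).sum(dim=1)   # weighted L1 norm over the tree's edges
```

Because this is nothing but dense tensor algebra over a batch dimension, moving `B`, `w`, `X`, and `Y` to a GPU with `.to("cuda")` runs the same code there, which is what makes batch processing of many document pairs cheap.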
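The split-and-select protocol in the Dataset Splits row (hold out 20% of the training data, train with learning rate 0.1 and batch size 100 for 30 epochs, keep the parameters with the lowest validation loss) maps onto a standard PyTorch loop. The sketch below is a generic reconstruction under loud assumptions: the toy data, model, and loss are placeholders rather than the paper's margin-m loss over trees, and the LARS optimizer the authors combine with Adam is omitted because it is not part of core PyTorch.

```python
import copy
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Placeholder data and model (hypothetical; the paper trains tree parameters
# with a margin-m loss, which is not reproduced here).
X = torch.randn(1000, 20)
y = torch.randint(0, 2, (1000,))
train_set = TensorDataset(X, y)
model = torch.nn.Linear(20, 2)
loss_fn = torch.nn.CrossEntropyLoss()

n_val = int(0.2 * len(train_set))                 # 20% of training data for validation
train_sub, val_sub = random_split(train_set, [len(train_set) - n_val, n_val])
train_loader = DataLoader(train_sub, batch_size=100, shuffle=True)
val_loader = DataLoader(val_sub, batch_size=100)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
best_val, best_state = float("inf"), None
for epoch in range(30):                           # 30 epochs, as in the paper
    model.train()
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss_fn(model(xb), yb).backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():                         # validation pass
        val_loss = sum(loss_fn(model(xb), yb).item() for xb, yb in val_loader)
    if val_loss < best_val:                       # keep lowest-validation-loss params
        best_val, best_state = val_loss, copy.deepcopy(model.state_dict())
model.load_state_dict(best_state)                 # evaluate with the selected parameters
```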
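Finally, the initialization in the Experiment Setup row, a perfect 5-ary tree of depth 5, can be built with array-heap indexing. A minimal sketch, assuming “depth 5” means five levels of children below the root and that D1 is the plain 0/1 adjacency matrix; the paper's actual parametrization of D1 (for example, directed or normalized) may differ.

```python
import torch

def perfect_kary_adjacency(k: int = 5, depth: int = 5) -> torch.Tensor:
    """Adjacency matrix of a perfect k-ary tree.

    Node 0 is the root and the children of node v are k*v + 1, ..., k*v + k
    (the usual array-heap layout), so levels 0..depth are fully populated.
    """
    n = (k ** (depth + 1) - 1) // (k - 1)    # total node count of the perfect tree
    A = torch.zeros(n, n)
    for v in range((n - 1) // k):            # iterate over internal (non-leaf) nodes
        for c in range(k * v + 1, k * v + k + 1):
            A[v, c] = A[c, v] = 1.0          # undirected parent-child edge
    return A

D1_init = perfect_kary_adjacency(5, 5)       # 3906 x 3906 for k = 5, depth = 5
```

With k = 5 and depth 5 this gives (5^6 - 1)/4 = 3906 nodes, 781 of them internal; the paper notes that after optimization the deepest level of the learned tree ends up at 5 or 6.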