Supervised Tree-Wasserstein Distance
Authors: Yuki Takezawa, Ryoma Sato, Makoto Yamada
ICML 2021
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Experimentally, we show that the STW distance can be computed fast, and improves the accuracy of document classification tasks. Furthermore, the STW distance is formulated by matrix multiplications, runs on a GPU, and is suitable for batch processing. Therefore, we show that the STW distance is extremely efficient when comparing a large number of documents. |
| Researcher Affiliation | Academia | Yuki Takezawa 1 2 Ryoma Sato 1 2 Makoto Yamada 1 2 1Kyoto University 2RIKEN AIP. |
| Pseudocode | Yes | Algorithm 1: Implementation of the STW distance, using PyTorch syntax. |
| Open Source Code | No | The paper states: “We implement S-WMD, and the TSW and STW distances in PyTorch.” However, it does not provide a direct link to the implementation code or explicitly state that it is open-sourced. |
| Open Datasets | Yes | We evaluate the following methods in document classification tasks on the synthetic and six real datasets following S-WMD, in terms of the k-nearest-neighbor (kNN) test error rate and time consumption: TWITTER, AMAZON, CLASSIC, BBCSPORT, OHSUMED, and REUTERS. Datasets are split into train/test as in the previous works (Kusner et al., 2015; Huang et al., 2016). |
| Dataset Splits | Yes | To select the margin m, we use 20% of the training dataset for validation. We then train our model at a learning rate of 0.1 and a batch size of 100 for 30 epochs. To avoid overfitting, we evaluate the STW distance using the parameters with the lowest validation loss over the 30 epochs. |
| Hardware Specification | Yes | We evaluated WMD (Sinkhorn), S-WMD, and the TSW and STW distances on an Nvidia Quadro RTX 8000, and WMD, Quadtree, and Flowtree on an Intel Xeon CPU E5-2690 v4 (2.60 GHz). |
| Software Dependencies | No | The paper states: “We implement S-WMD, and the TSW and STW distances in PyTorch.” However, it does not specify the version of PyTorch or of any other software dependency. |
| Experiment Setup | Yes | We initialize D1 such that the tree whose adjacency matrix is D1 is a perfect 5-ary tree of depth 5, and optimize Eq. (6) using Adam (Kingma & Ba, 2015) and LARS (You et al., 2017). After optimization, the deepest level of the tree is 5 or 6. To select the margin m, we use 20% of the training dataset for validation. We then train our model at a learning rate of 0.1 and a batch size of 100 for 30 epochs. |
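The table above notes that the STW distance is formulated by matrix multiplications, which is what makes it GPU-friendly and suitable for batch processing. As background, the underlying tree-Wasserstein distance between two distributions supported on a tree's leaves has a closed form: the weighted sum, over all edges, of the absolute difference in probability mass below each edge. A minimal sketch of that closed form in numpy follows; the toy tree, node ordering, and unit edge weights here are illustrative assumptions, not taken from the paper, and this is the generic tree-Wasserstein computation rather than the authors' learned (supervised) variant.

```python
import numpy as np

# Hypothetical 6-node tree (not from the paper):
#   root -> {u, v}, u -> {l1, l2}, v -> {l3}
# Rows of B index the non-root nodes [u, v, l1, l2, l3];
# columns index the leaves [l1, l2, l3].
# B[i, j] = 1 iff leaf j lies in the subtree rooted at node i,
# i.e. iff leaf j's mass flows through node i's parent edge.
B = np.array([
    [1, 1, 0],   # u's subtree contains l1, l2
    [0, 0, 1],   # v's subtree contains l3
    [1, 0, 0],   # l1's "subtree" is itself
    [0, 1, 0],   # l2
    [0, 0, 1],   # l3
])
w = np.ones(5)   # weight of each node's parent edge (unit weights, assumed)

def tree_wasserstein(a, b, B, w):
    """Closed-form tree-Wasserstein distance:
    sum over edges e of w_e * |mass of (a - b) in the subtree below e|.
    Only matrix-vector products and an absolute value are needed."""
    return float(w @ np.abs(B @ (a - b)))

a = np.array([0.5, 0.5, 0.0])   # distribution over leaves [l1, l2, l3]
b = np.array([0.0, 0.0, 1.0])

print(tree_wasserstein(a, b, B, w))   # prints 4.0 for this toy tree
```

Because the whole computation is `w @ |B @ (a - b)|`, replacing the vectors `a - b` with a matrix whose columns are many difference vectors evaluates many distances in one matrix product, which is the batch-processing property the paper exploits on a GPU.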