TabR: Tabular Deep Learning Meets Nearest Neighbors
Authors: Yury Gorishniy, Ivan Rubachev, Nikolay Kartashev, Daniil Shlenskii, Akim Kotelnikov, Artem Babenko
ICLR 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | On a set of public benchmarks with datasets up to several million objects, TabR marks a big step forward for tabular DL: it demonstrates the best average performance among tabular DL models, becomes the new state-of-the-art on several datasets, and even outperforms GBDT models on the recently proposed GBDT-friendly benchmark (see Figure 1). In this section, we compare TabR (section 3) with existing retrieval-based solutions and state-of-the-art parametric models. In addition to the fully-fledged configuration of TabR (with all degrees of freedom available for E and P as described in Figure 3), we also use TabR-S (S stands for "simple"), a simple configuration which does not use feature embeddings (Gorishniy et al., 2022), has a linear encoder (N_E = 0) and a one-block predictor (N_P = 1). We specify when TabR-S is used only in tables, figures, and captions but not in the text. For other details on TabR, including hyperparameter tuning, see subsection D.8. |
| Researcher Affiliation | Collaboration | Yury Gorishniy, Ivan Rubachev, Nikolay Kartashev, Daniil Shlenskii, Akim Kotelnikov, Artem Babenko (Yandex, HSE). Deep learning (DL) models for tabular data problems (e.g. classification, regression) are currently receiving increasingly more attention from researchers. However, despite the recent efforts, the non-DL algorithms based on gradient-boosted decision trees (GBDT) remain a strong go-to solution for these problems. One of the research directions aimed at improving the position of tabular DL involves designing so-called retrieval-augmented models. For a target object, such models retrieve other objects (e.g. the nearest neighbors) from the available training data and use their features and labels to make a better prediction. In this work, we present TabR: essentially, a feed-forward network with a custom k-Nearest-Neighbors-like component in the middle. On a set of public benchmarks with datasets up to several million objects, TabR marks a big step forward for tabular DL: it demonstrates the best average performance among tabular DL models, becomes the new state-of-the-art on several datasets, and even outperforms GBDT models on the recently proposed GBDT-friendly benchmark (see Figure 1). Among the important findings and technical details powering TabR, the main ones lie in the attention-like mechanism that is responsible for retrieving the nearest neighbors and extracting valuable signal from them (a hedged sketch of this mechanism is given after the table). In addition to the higher performance, TabR is simple and significantly more efficient compared to prior retrieval-based tabular DL models. The source code is published: link. Corresponding author: firstnamelastname@gmail.com |
| Pseudocode | No | The paper describes the architecture and processes using figures and mathematical equations (e.g., Figure 2, Figure 3, Figure 4, Equation 1-5) but does not contain a formal pseudocode or algorithm block. |
| Open Source Code | Yes | The source code is published: link. To make the results and models reproducible and verifiable, we provide our full codebase, all the results, and step-by-step usage instructions: link. |
| Open Datasets | Yes | In this work, we mostly use the datasets from prior literature and provide their summary in Table 1 (sometimes, we refer to this set of datasets as the "default benchmark"). Additionally, in subsection 4.2, we use the recently introduced benchmark with middle-scale tasks (≤50K objects) (Grinsztajn et al., 2022) where GBDT was reported to be superior to DL solutions. Table 16 (datasets of the main benchmark; sizes are # Train / # Validation / # Test; features are # Num / # Bin / # Cat, the numbers of numerical, binary, and categorical features; the batch size is the default used to train DL-based models): CH Churn Modelling, 6400/1600/2000, 10/3/1, binclass, 128; CA California Housing, 13209/3303/4128, 8/0/0, regression, 256; HO House 16H, 14581/3646/4557, 16/0/0, regression, 256; AD Adult, 26048/6513/16281, 6/1/8, binclass, 256; DI Diamond, 34521/8631/10788, 6/0/3, regression, 512; OT Otto Group Products, 39601/9901/12376, 93/0/0, multiclass, 512; HI Higgs Small, 62751/15688/19610, 28/0/0, binclass, 512; BL Black Friday, 106764/26692/33365, 4/1/4, regression, 512; WE Shifts Weather (subset), 296554/47373/53172, 118/1/0, regression, 1024; CO Covertype, 371847/92962/116203, 54/44/0, multiclass, 1024; WE (full) Shifts Weather (full), 2965542/47373/531720, 118/1/0, regression, 1024. |
| Dataset Splits | Yes | The dataset is split into three disjoint parts: {1, ..., n} = I_train ∪ I_val ∪ I_test, where the train part is used for training, the validation part is used for early stopping and hyperparameter tuning, and the test part is used for the final evaluation (a minimal splitting sketch is given after the table). Table 16 lists the # Train / # Validation / # Test sizes for every dataset. |
| Hardware Specification | Yes | We report the used hardware in the results published along with the source code. In a nutshell, the vast majority of experiments on GPU were performed on one NVIDIA A100 GPU, the remaining small part of GPU experiments was performed on one NVIDIA 2080 Ti GPU, and there was also a small portion of runs performed on CPU (e.g. all the experiments on LightGBM). |
| Software Dependencies | No | We report the used hardware in the results published along with the source code. The implementation, tuning hyperparameters, evaluation hyperparameters, metrics, execution times, hardware and other details are available in the source code. (The paper itself does not provide specific software versions; it only states that such details are available in the source code.) |
| Experiment Setup | Yes | For each dataset, we used a predefined dataset-specific batch size. We continue training until there are patience + 1 consecutive epochs without improvements on the validation set; we set patience = 16 for the DL models (a hedged sketch of this stopping rule follows the table). The specific default hyperparameter values of TabR-S are as follows: d = 265; attention dropout rate = 0.38920071545944357; dropout rate in FFN = 0.38852797479169876; learning rate = 0.0003121273641315169; weight decay = 0.0000012260352006404615. Table 17: The hyperparameter tuning space for TabR. |
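
The attention-like retrieval mechanism quoted in the Researcher Affiliation row can be illustrated with a minimal sketch. This is one plausible reading of the quoted description (keys produced by a linear map, similarity as negative squared L2 distance between keys, and values combining a label embedding with a transformed key difference), not the authors' implementation; the class name, tensor shapes, default context size, and the residual update are assumptions made here for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RetrievalSketch(nn.Module):
    """Minimal sketch of a TabR-style attention-like retrieval step.

    Not the authors' code: keys, similarity, and values follow the
    mechanism described in the quoted excerpts, but all module and
    dimension choices here are illustrative assumptions.
    """

    def __init__(self, d: int, n_classes: int, m: int = 96):
        super().__init__()
        self.m = m                            # context size: number of retrieved neighbors
        self.key = nn.Linear(d, d)            # W_K: maps encoded objects to keys
        self.label = nn.Linear(n_classes, d)  # W_Y: embeds neighbor labels
        self.diff = nn.Linear(d, d)           # T: transforms key differences

    def forward(self, x, cand_x, cand_y):
        # x: (B, d) encoded target objects; cand_x: (N, d) encoded training
        # candidates; cand_y: (N, n_classes) one-hot labels of the candidates.
        k, k_cand = self.key(x), self.key(cand_x)
        sim = -torch.cdist(k, k_cand).pow(2)   # (B, N): negative squared L2 distance
        sim, idx = sim.topk(self.m, dim=1)     # keep the m nearest neighbors
        # Value of a neighbor: its label embedding plus a transformed key difference.
        v = self.label(cand_y[idx]) + self.diff(k.unsqueeze(1) - k_cand[idx])  # (B, m, d)
        w = F.softmax(sim, dim=1).unsqueeze(-1)  # attention weights over neighbors
        return x + (w * v).sum(dim=1)            # residual update of the target representation
```

In the full model, such a module would sit between the encoder E and the predictor P; in TabR-S, the encoder before it is linear (N_E = 0) and the predictor after it is a single block (N_P = 1).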
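
The disjoint three-way split from the Dataset Splits row can be sketched with standard tooling. The dataset size and the roughly 64/16/20 proportions below are hypothetical (loosely mirroring the California Housing row of Table 16), not prescribed by the paper.

```python
import numpy as np
from sklearn.model_selection import train_test_split

idx = np.arange(20000)  # hypothetical dataset size
idx_trainval, idx_test = train_test_split(idx, test_size=0.20, random_state=0)
idx_train, idx_val = train_test_split(idx_trainval, test_size=0.20, random_state=0)
# The three index sets are disjoint: I_train, I_val, and I_test partition {0, ..., n-1}.
assert not set(idx_train) & set(idx_val) and not set(idx_trainval) & set(idx_test)
```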
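
The patience-based early stopping described in the Experiment Setup row (stop after patience + 1 consecutive epochs without improvement on the validation set, with patience = 16 for DL models) can be sketched as below; `train_one_epoch` and `evaluate_val` are hypothetical caller-supplied callables, and a higher-is-better validation score is an assumption.

```python
def train_with_early_stopping(train_one_epoch, evaluate_val,
                              patience: int = 16, max_epochs: int = 100000):
    """Stop after patience + 1 consecutive epochs without validation improvement.

    train_one_epoch and evaluate_val are hypothetical callables: one training
    pass over the train part, and a validation score where higher is better.
    """
    best_score, bad_epochs = float("-inf"), 0
    for _ in range(max_epochs):
        train_one_epoch()
        score = evaluate_val()
        if score > best_score:
            best_score, bad_epochs = score, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:  # i.e. patience + 1 bad epochs in a row
            break
    return best_score
```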