NODE-GAM: Neural Generalized Additive Model for Interpretable Deep Learning
Authors: Chun-Hao Chang, Rich Caruana, Anna Goldenberg
ICLR 2022 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | In this work, we propose a neural GAM (NODE-GAM) and neural GA2M (NODE-GA2M) that scale well and perform better than other GAMs on large datasets, while remaining interpretable compared to other ensemble and deep learning models. We demonstrate that our models find interesting patterns in the data. Lastly, we show that we improve model accuracy via self-supervised pre-training, an improvement that is not possible for non-differentiable GAMs. ... (Sec. 4.1, Are NODE-GAM and NODE-GA2M Accurate?) We compare our performance on 6 popular binary classification datasets (Churn, Support2, MIMIC2, MIMIC3, Income, and Credit) and 2 regression datasets (Wine and Bikeshare). |
| Researcher Affiliation | Collaboration | 1University of Toronto, 2Vector Institute, 3The Hospital for Sick Children (SickKids), 4Microsoft Research |
| Pseudocode | Yes | C PSEUDO-CODE FOR NODE-GAM Here we provide the pseudocode for our model in Alg. 1-4. |
| Open Source Code | Yes | We released our code in https://github.com/zzzace2000/nodegam with instructions and hyperparameters to reproduce our final results. |
| Open Datasets | Yes | F DATASET DESCRIPTIONS Here we describe all 8 datasets we use and summarize them in Table 6. Churn: this dataset is used to predict which users are likely subscription churners for a telecom company. https://www.kaggle.com/blastchar/telco-customer-churn ... For the 6 datasets used in NODE, we use the scripts from the NODE paper (https://github.com/Qwicen/node), which download the datasets directly. Here we still cite and list their sources: Click: https://www.kaggle.com/c/kddcup2012-track2 Higgs: UCI (Dua & Graff, 2017) https://archive.ics.uci.edu/ml/datasets/HIGGS ... |
| Dataset Splits | Yes | We use 5-fold cross validation to derive the mean and standard deviation for each model, with an 80-20 split for the training and validation sets. For the medium-sized datasets we compile (Income, Churn, Credit, MIMIC2, MIMIC3, Support2, Bikeshare), we do 5-fold cross validation over 5 different test splits. For the datasets from the NODE paper (Click, Epsilon, Higgs, Microsoft, Yahoo, Year), we use the train/val/test split provided by the NODE paper authors. (See the split sketch below the table.) |
| Hardware Specification | Yes | All NODE and NODE-GAM/GA2M runs use 1 TITAN Xp GPU, 4 CPUs, and 8GB of memory. EBM and Spline are run on a machine with 32 CPUs and 120GB of memory. |
| Software Dependencies | No | We use the cubic spline from the PyGAM package (Servén & Brummitt, 2018), the QHAdam optimizer (Ma & Yarats, 2018), and sklearn.preprocessing.quantile_transform. Specific version numbers for these software components are not provided. (See the dependency sketch below the table.) |
| Experiment Setup | Yes | G HYPERPARAMETERS SELECTION... Below we describe the hyperparameters we use for each method. For EBM, we set inner_bags=100 and outer_bags=100 and set the maximum rounds to 20k... For NODE-GAM, NODE-GA2M, and NODE: optimizer: QHAdam (Ma & Yarats, 2018); lr_warmup_steps: 500; num_checkpoints_avged: 5; temperature_annealing_steps (S): 4k; min_temperature: 0.01; batch_size: 2048, or the max batch size that fits in GPU memory, with a minimum of 128; maximum training time: 20 hours. We use random search to find the best hyperparameters; the search ranges are listed below (see also the sampling sketch after the table). Random search range for NODE: num_layers: {2, 3, 4, 5}, default 3; total tree count (= num_trees × num_layers): {500, 1000, 2000, 4000}, default 2000; depth: {2, 4, 6}, default 4; output_dropout (p1): {0, 0.1, 0.2}, default 0; colsample_bytree: {1, 0.5, 0.1, 1e-5}, default 0.1; lr: {0.01, 0.005}, default 0.01; l2_lambda: {0, 1e-7, 1e-6, 1e-5}, default 1e-5. |
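
To illustrate the Dataset Splits row, here is a minimal sketch, assuming sklearn's `KFold` and `train_test_split`, of the split scheme described: an outer 5-fold split producing 5 test sets, with an 80-20 train/validation split inside each fold. This is not the authors' released script, and the data arrays are placeholders.

```python
# Hypothetical reconstruction of the split scheme (not the authors' code).
import numpy as np
from sklearn.model_selection import KFold, train_test_split

X = np.random.rand(1000, 10)        # placeholder features
y = np.random.randint(0, 2, 1000)   # placeholder binary labels

# Outer 5-fold CV: each fold contributes one of the 5 test splits.
outer = KFold(n_splits=5, shuffle=True, random_state=0)
for fold, (trainval_idx, test_idx) in enumerate(outer.split(X)):
    # 80-20 split of the remaining indices into train and validation sets.
    train_idx, val_idx = train_test_split(
        trainval_idx, test_size=0.2, random_state=fold)
    print(f"fold {fold}: train={len(train_idx)} "
          f"val={len(val_idx)} test={len(test_idx)}")
```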
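
For the Software Dependencies row, the following hedged sketch wires up the three named components: PyGAM cubic splines for the Spline baseline, `quantile_transform` for preprocessing, and QHAdam (here assumed to come from the `qhoptim` package) for the neural models. Package choices, hyperparameter values, and the toy data are assumptions, since the paper does not pin versions.

```python
# Sketch only: the exact package versions and settings are not given in the paper.
import numpy as np
import torch
from sklearn.preprocessing import quantile_transform   # feature preprocessing
from pygam import LinearGAM, s                          # cubic-spline GAM baseline
from qhoptim.pyt import QHAdam                          # QHAdam (Ma & Yarats, 2018)

X = np.random.rand(200, 3)                      # toy features (assumption)
y = X[:, 0] ** 2 + np.random.randn(200) * 0.1   # toy regression target

# Quantile-transform the features before fitting.
Xq = quantile_transform(X, output_distribution='normal', n_quantiles=100)

# Spline baseline: one cubic spline term (spline_order=3) per feature.
gam = LinearGAM(s(0, spline_order=3) + s(1, spline_order=3) + s(2, spline_order=3))
gam.fit(Xq, y)

# QHAdam optimizing a stand-in torch model (values are illustrative).
model = torch.nn.Linear(3, 1)
opt = QHAdam(model.parameters(), lr=1e-2, nus=(0.7, 1.0), betas=(0.95, 0.998))
pred = model(torch.tensor(Xq, dtype=torch.float32)).squeeze(-1)
loss = ((pred - torch.tensor(y, dtype=torch.float32)) ** 2).mean()
loss.backward()
opt.step()
```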
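
For the Experiment Setup row, here is a short sketch of uniform random search over the NODE ranges quoted above. The grid values come from the table; the `sample_config` helper and the derivation of `num_trees` from the total tree count are illustrative assumptions.

```python
# Illustrative random-search sampler over the quoted NODE hyperparameter ranges.
import random

search_space = {
    "num_layers": [2, 3, 4, 5],
    "total_trees": [500, 1000, 2000, 4000],   # total = num_trees * num_layers
    "depth": [2, 4, 6],
    "output_dropout": [0.0, 0.1, 0.2],
    "colsample_bytree": [1.0, 0.5, 0.1, 1e-5],
    "lr": [0.01, 0.005],
    "l2_lambda": [0.0, 1e-7, 1e-6, 1e-5],
}

def sample_config(rng=random):
    """Draw one configuration uniformly from the search space."""
    cfg = {k: rng.choice(v) for k, v in search_space.items()}
    # Per-layer tree count derived from the sampled total (assumption).
    cfg["num_trees"] = cfg["total_trees"] // cfg["num_layers"]
    return cfg

for trial in range(5):
    print(sample_config())
```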