PairNet: Training with Observed Pairs to Estimate Individual Treatment Effect
Authors: Lokesh Nagalapatti, Pranava Singhal, Avishek Ghosh, Sunita Sarawagi
ICML 2024
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | Empirical comparison with thirteen existing methods across eight benchmarks, covering both discrete and continuous treatments, shows that PairNet achieves significantly lower ITE error compared to the baselines. Also, it is model-agnostic and easy to implement. We release the code at the URL: https://github.com/nlokeshiisc/pairnet_release. We conduct experiments to address the following research questions. |
| Researcher Affiliation | Academia | IIT Bombay. |
| Pseudocode | Yes | Algorithm 1 PairNet. Require: data D = {(x_i, t_i, y_i)}, distance threshold δ_pair, number of pairs num_z, epochs E, ψ for forming pairs. 1: Let ϕ be the representation network and {µ_t} the prediction heads; 2: D_trn, D_val ← SPLIT(D, pc=0.3, stratify=T); 3: D_val ← CREATEPAIRDS(D_val, D, δ_pair, num_z, ψ); 4: for e ∈ [E] do; 5: D^e_trn ← CREATEPAIRDS(D_trn, D_trn, δ_pair, num_z, ψ); 6: for each batch {(x, t, y, x′, t′, y′)} ∈ D^e_trn do; 7: z, z′ ← ϕ(x), ϕ(x′); 8: ŷ, ŷ′ ← µ_t(z), µ_{t′}(z′); 9: loss ← L((y − y′), (ŷ − ŷ′)); 10: ϕ, {µ_t} ← GRADDESC(loss); 11: end for; 12: break if EARLYSTOPPING(D_val, ϕ, {µ_t}); 13: end for; 14: Return ϕ, {µ_t}. — 1: function CREATEPAIRDS(D′, D, δ_pair, num_z, ψ); 2: N ← \|D\|, D_pair ← {}; 3: for (x′_i, t′_i, y′_i) ∈ D′ do; 4: d_i[j] ← d(ψ(x′_i), ψ(x_j)) ∀ j ∈ [N]; 5: d_i[j] ← ∞ if t′_i = t_j; 6: q_{t′_i}(x_j \| x′_i) ← softmax(−d_i); 7: Nbrs_i ← SAMPLE(q_{t′_i}, num_z); 8: D_pair ← D_pair ∪ {(x′_i, t′_i, y′_i, x_j, t_j, y_j)} ∀ j ∈ Nbrs_i; 9: end for; 10: Return D_pair after dropping the largest δ_pair distances. (A JAX sketch of this pairing step follows the table.) |
| Open Source Code | Yes | We release the code at the URL: https://github.com/nlokeshiisc/pairnet_release. |
| Open Datasets | Yes | Binary Datasets. We use the following three benchmark datasets: IHDP, ACIC, and Twins. The IHDP and ACIC datasets are semi-synthetic with synthetic potential outcome functions µ(x, t), while the Twins dataset contains real outcomes. |
| Dataset Splits | Yes | For model training, all methods reserved 30% of the data for validation and implemented early stopping based on it. PairNet early stops on pairs as shown in Algorithm 1. We configured all hyperparameters, network architecture, optimizer, learning rate, etc. according to the CATENets defaults. |
| Hardware Specification | Yes | Our experiments were conducted on a DGX machine equipped with an NVIDIA A100 GPU card, with 80 GB of GPU memory. The DGX machine is powered by an AMD EPYC 7742 64-Core Processor with 256 CPUs, featuring 64 cores per CPU. |
| Software Dependencies | No | Our codebase was entirely developed using JAX (Bradbury et al., 2018), a functional programming-based deep learning library that extends CUDA support for GPU acceleration. For a fair comparison, we adopt the hyperparameters used in the CATENets benchmark (Curth et al., 2021) as is except for the weight associated with the L2 penalty. |
| Experiment Setup | Yes | In each epoch of training, we sample mini-batches of 100 examples (along with their respective pairs for PairNet) and compute the losses on them. We use the Adam optimizer with a learning rate of 1e-4. Training proceeds for a maximum of 1000 epochs, with early stopping based on a 30% validation set and a patience of 10. (Minimal training-step and early-stopping sketches follow the table.) |
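
To make the CREATEPAIRDS step in the Pseudocode row concrete, here is a minimal sketch in JAX, the framework the authors report using. The function name `create_pair_indices`, the Euclidean distance, and the assumption that the ψ-embeddings are precomputed into an array `emb` are illustrative choices, not details taken from the released code.

```python
import jax
import jax.numpy as jnp

def create_pair_indices(key, emb, t, num_z):
    """For each example, sample num_z neighbour indices from the opposite
    treatment group, with probability decaying in embedding distance."""
    # Pairwise distances d_i[j] between psi-embeddings (Euclidean, for illustration).
    d = jnp.linalg.norm(emb[:, None, :] - emb[None, :, :], axis=-1)
    # Exclude same-treatment neighbours by pushing their distance to +inf.
    d = jnp.where(t[:, None] == t[None, :], jnp.inf, d)
    # q(x_j | x_i) = softmax(-d_i): closer examples are sampled more often.
    logits = -d
    keys = jax.random.split(key, emb.shape[0])
    sample_row = lambda k, lg: jax.random.categorical(k, lg, shape=(num_z,))
    return jax.vmap(sample_row)(keys, logits)  # shape (N, num_z)
```

Dropping pairs whose distances exceed the δ_pair threshold, as in the last step of CREATEPAIRDS, would then be a simple filter over the sampled indices.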
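
A correspondingly minimal sketch of one training step under the reported setup (Adam, learning rate 1e-4, mini-batches of 100 pairs). The parameter layout, the `apply_fn(params, x, t)` interface standing in for µ_t(ϕ(x)), the squared-error form of the pair loss L, and the use of optax for Adam are assumptions made for illustration, not the authors' implementation.

```python
import jax
import jax.numpy as jnp
import optax

def make_train_step(apply_fn, optimizer):
    """apply_fn(params, x, t) is assumed to return µ_t(ϕ(x)): the head for
    treatment t applied to the shared representation ϕ(x)."""

    def pair_loss(params, batch):
        # batch: anchor examples (x, t, y) and their sampled pairs (xp, tp, yp).
        x, t, y, xp, tp, yp = batch
        y_hat = apply_fn(params, x, t)
        yp_hat = apply_fn(params, xp, tp)
        # Match the predicted outcome difference to the observed factual difference.
        return jnp.mean(((y - yp) - (y_hat - yp_hat)) ** 2)

    @jax.jit
    def train_step(params, opt_state, batch):
        loss, grads = jax.value_and_grad(pair_loss)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    return train_step

# Example wiring under the reported hyperparameters (batch of 100 pairs):
# optimizer = optax.adam(1e-4)
# step = make_train_step(apply_fn, optimizer)
# opt_state = optimizer.init(params)
```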
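
Finally, the 70/30 stratified split and patience-10 early stopping described in the Dataset Splits and Experiment Setup rows could be wired up as below. `fit_one_epoch` and `validation_pair_loss` are hypothetical placeholders for the epoch loop and the pair loss on the held-out pairs, and scikit-learn's `train_test_split` stands in for the SPLIT routine in Algorithm 1.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def train_with_early_stopping(x, t, y, fit_one_epoch, validation_pair_loss,
                              max_epochs=1000, patience=10, seed=0):
    # SPLIT(D, pc=0.3, stratify=T): hold out 30% for validation, stratified by treatment.
    idx = np.arange(len(x))
    trn_idx, val_idx = train_test_split(idx, test_size=0.3, stratify=t,
                                        random_state=seed)

    best_val, wait, state = np.inf, 0, None
    for epoch in range(max_epochs):                     # at most 1000 epochs
        state = fit_one_epoch(state, trn_idx)           # re-pair D_trn, take gradient steps
        val = validation_pair_loss(state, val_idx)      # pair loss on held-out pairs
        if val < best_val:
            best_val, wait = val, 0
        else:
            wait += 1
            if wait >= patience:                        # patience of 10
                break
    return state
```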