Deep Reinforcement Learning for Cost-Effective Medical Diagnosis

Authors: Zheng Yu, Yikuan Li, Joseph Chahn Kim, Kaixuan Huang, Yuan Luo, Mengdi Wang

ICLR 2023

| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | "Experiments with real-world data validate that SM-DDPO trains efficiently and identify all Pareto-front solutions. Across all tasks, SM-DDPO is able to achieve state-of-the-art diagnosis accuracy (in some cases higher than conventional methods) with up to 85% reduction in testing cost. Core codes are available on GitHub." |
| Researcher Affiliation | Academia | Zheng Yu (Princeton University), Yikuan Li (Northwestern University), Joseph C. Kim (Princeton University), Kaixuan Huang (Princeton University), Yuan Luo (Northwestern University), Mengdi Wang (Princeton University) |
| Pseudocode | Yes | Algorithm 1: Semi-Model-Based Deep Diagnosis Policy Optimization (SM-DDPO) |
| Open Source Code | Yes | "Core codes are available on GitHub." https://github.com/Zheng321/Deep-Reinforcement-Learning-for-Cost-Effective-Medical-Diagnosis |
| Open Datasets | Yes | "We followed steps in Zimmerman et al. (2019) to extract 23,950 ICU visits of 19,811 patients from the MIMIC-III dataset (Johnson et al., 2016)." |
| Dataset Splits | Yes | "We split each dataset into 3 parts: training set (75%), validation set (15%), and test set (10%)." |
| Hardware Specification | No | "Quest provides computing access of over 11,800 CPU cores. In our experiment, we deploy each model training job to one CPU, so that multiple configurations can be tested simultaneously. Each training job requires a wall-time less than 2 hours of a single CPU core." |
| Software Dependencies | No | "We use the Python package stable-baselines3 (Raffin et al., 2021) for implementing PPO." |
| Experiment Setup | Yes | "We list all the hyper-parameters we tuned in Table 7, including both the tuning range and final selection." Table 7 lists specific values such as batch size 256, learning rate 1e-3, hidden size (3-layer) 64, and 1024 timesteps per update. |
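The reported 75/15/10 split and the Table 7 selections can be mirrored in a short sketch. Only the split ratios, the cohort size, and the hyper-parameter values come from the source; the helper name `split_dataset`, the seed, and the use of a uniform shuffle are illustrative assumptions, and the dictionary keys follow stable-baselines3's `PPO` argument names where a natural counterpart exists.

```python
import random

# Final hyper-parameter selections quoted from Table 7; keys chosen to
# match stable-baselines3 PPO arguments (an assumption, not the paper's code).
PPO_HPARAMS = {
    "batch_size": 256,
    "learning_rate": 1e-3,
    "n_steps": 1024,  # "# Timesteps per update"
    "policy_kwargs": {"net_arch": [64, 64, 64]},  # 3-layer MLP, hidden size 64
}

def split_dataset(n_samples: int, seed: int = 0):
    """Shuffle indices and split them 75% / 15% / 10%, as reported.

    The seed and the shuffling procedure are illustrative; the paper
    states only the split ratios.
    """
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)
    n_train = int(0.75 * n_samples)
    n_val = int(0.15 * n_samples)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

# Example with the reported MIMIC-III cohort size (23,950 ICU visits):
train, val, test = split_dataset(23950)
```

With stable-baselines3 installed, the dictionary could then be passed as `PPO("MlpPolicy", env, **PPO_HPARAMS)`, where `env` is the diagnosis environment presumably defined in the released code.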