Context-Guided Diffusion for Out-of-Distribution Molecular and Protein Design
Authors: Leo Klarner, Tim G. J. Rudner, Garrett M Morris, Charlotte Deane, Yee Whye Teh
ICML 2024 | Conference PDF | Archive PDF | Plain Text | LLM Run Details
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We demonstrate that this approach leads to substantial performance gains across various settings, including continuous, discrete, and graph-structured diffusion processes with applications across drug discovery, materials science, and protein design. We compare our method to both standard guidance models and sophisticated pre-training and domain adaptation techniques across a range of experimental settings. Specifically, we demonstrate the versatility of context-guided diffusion by applying it to the design of small molecules with graph-structured diffusion processes (Section 5.1), to the generation of novel materials with equivariant diffusion models (Section 5.2), and to the optimization of discrete protein sequences with categorical diffusion models (Section 5.3). |
| Researcher Affiliation | Collaboration | 1University of Oxford, UK 2New York University, New York, USA. Correspondence to: Leo Klarner <leo.klarner@stats.ox.ac.uk>. |
| Pseudocode | Yes | Algorithm 1 (one guidance model training iteration with the context-guided regularization scheme from Equation (9)). Require: data and context batch (x0, y), x̂0; diffusion time steps t, t′ ∼ U(0, T); guidance model ft(·; θ); regularizer DM(ft(·; θ), Qt(·))²; hyperparameters σt, τt, λ. Steps: xt ← NoisingProcess(x0, t); ȳ ← ft(xt; θ); L ← −log pt(y \| xt; ȳ) + ‖θ‖²₂ / (2λ); x̂t′ ← NoisingProcess(x̂0, t′); ŷ ← ft′(x̂t′; θ); R ← Σj∈{1,2} DM(ŷj, Qj,t′(x̂t′))²; θ ← Update(θ, L + R). A PyTorch-style sketch of this iteration is provided below the table. |
| Open Source Code | Yes | The code for our experiments can be accessed at: https://github.com/leojklarner/context-guided-diffusion |
| Open Datasets | Yes | We evaluate all methods on the same dataset as Lee et al. (2023), consisting of 250 000 molecules sampled uniformly from the ZINC database of commercially available compounds (Irwin et al., 2012). |
| Dataset Splits | Yes | Specifically, we pre-compute the labels y for every target and use them to split the data into a low-property training and a high-property validation set. This label split allows us to select regularization hyperparameters that maximize the ability of the guidance models to generalize well to novel, high-value regions of chemical space, serving as a proxy for the desired behavior at sampling time. A short illustrative split sketch follows the table. |
| Hardware Specification | Yes | All models were trained with an identical setup and on the same NVIDIA A100 GPUs, differing only in the regularizer. |
| Software Dependencies | Yes | All code was written in PYTHON (Van Rossum & Drake, 2009) and can be accessed at https://github.com/leojklarner/context-guided-diffusion. A range of core scientific computing libraries were used for data preparation and analysis, including NUMPY (Harris et al., 2020), SCIPY (Virtanen et al., 2020), PANDAS (pandas development team, 2020), MATPLOTLIB (Hunter, 2007), SEABORN (Waskom, 2021), SCIKIT-LEARN (Pedregosa et al., 2011) and RDKIT (Landrum et al., 2013). All deep learning models were implemented in PYTORCH (Paszke et al., 2019). |
| Experiment Setup | Yes | For training the guidance models ft(·; θ), we use the graph convolutional neural network (GNN) architecture from Lee et al. (2023), consisting of 3 GCN layers with 16 hidden units and tanh activations (Kipf & Welling, 2016). The embeddings from each layer are concatenated and fed into two fully connected heads with tanh and sigmoid activations, respectively, whose outputs of dimension 16 are then multiplied and fed into another two-layer MLP with the same dimension and RELU activation functions (Nair & Hinton, 2010). We train separate networks for each protein target, predicting both the most likely regression label µt(gt; θ) = ft¹(gt; θ), as well as log-variances log σ²t(gt; θ) = ft²(gt; θ) that serve as an estimator of their predictive uncertainty. Given the regression labels y, the models are optimized with respect to a negative log-likelihood loss Lt = −log N(y; µt, σ²t I). Following the protocol of Lee et al. (2023), all models are trained for 10 epochs of stochastic gradient descent, using the ADAM optimizer with a batch size of 1024 and a learning rate of 1 × 10⁻³ (Kingma & Ba, 2014). A hedged sketch of this architecture and training configuration is included below the table. |
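
The training iteration quoted in the Pseudocode row can be summarized in a short PyTorch-style sketch. This is a minimal illustration rather than the authors' implementation: the helper names `noising_process`, `guidance_model`, and `context_prior`, the uniform sampling of time steps, and the explicit Mahalanobis-style form of the distance DM are assumptions made here for readability; the exact regularizer is given by Equation (9) of the paper.

```python
import torch

def training_iteration(
    guidance_model,   # hypothetical f_t(.; theta): maps (x_t, t) -> (mean, log_var)
    noising_process,  # hypothetical forward diffusion q(x_t | x_0)
    context_prior,    # hypothetical Q_t(.): per-head (prior_mean, prior_cov) on context points
    x0, y,            # labelled data batch
    x0_ctx,           # unlabelled context batch
    T, lam,           # diffusion horizon and weight-regularization coefficient
    optimizer,
):
    # Sample independent diffusion time steps for the data and context batches.
    t = torch.rand(x0.shape[0]) * T
    t_ctx = torch.rand(x0_ctx.shape[0]) * T

    # Noise the labelled batch and compute the heteroscedastic Gaussian NLL loss.
    x_t = noising_process(x0, t)
    mean, log_var = guidance_model(x_t, t)
    nll = 0.5 * (log_var + (y - mean) ** 2 / log_var.exp()).mean()

    # L2 weight regularization, written here as ||theta||^2 / (2 * lambda).
    l2 = sum((p ** 2).sum() for p in guidance_model.parameters()) / (2 * lam)

    # Noise the context batch and pull both output heads towards the
    # context-dependent prior Q_t via a Mahalanobis-style distance.
    x_t_ctx = noising_process(x0_ctx, t_ctx)
    preds_ctx = guidance_model(x_t_ctx, t_ctx)  # (mean, log_var) on context points
    reg = 0.0
    for pred_j, (prior_mean_j, prior_cov_j) in zip(preds_ctx, context_prior(x_t_ctx, t_ctx)):
        diff = (pred_j - prior_mean_j).unsqueeze(-1)
        reg = reg + (diff.transpose(-2, -1) @ torch.linalg.solve(prior_cov_j, diff)).squeeze()

    loss = nll + l2 + reg
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```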
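The label-based split described in the Dataset Splits row can be illustrated in a few lines of NumPy. The 90th-percentile cutoff and the function name `low_high_split` are hypothetical; the quoted protocol only states that the highest-property molecules are held out as a validation set to guide hyperparameter selection.

```python
import numpy as np

def low_high_split(y, quantile=0.9):
    """Split indices into a low-property training set and a high-property validation set.

    The threshold is an assumed choice: the quoted protocol specifies only that the
    high-value region of chemical space is held out so that regularization
    hyperparameters can be selected for out-of-distribution generalization.
    """
    cutoff = np.quantile(y, quantile)
    train_idx = np.where(y < cutoff)[0]   # low-property molecules for training
    val_idx = np.where(y >= cutoff)[0]    # high-property molecules for validation
    return train_idx, val_idx
```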
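Similarly, the guidance-model architecture and training configuration in the Experiment Setup row can be sketched as follows. This is an illustrative re-implementation under stated assumptions, not the authors' code: the class name `GuidanceGNN`, the dense adjacency-based message passing, the mean-pooling readout, and the final two-dimensional output (regression mean and log-variance) are choices made here for self-containment.

```python
import torch
import torch.nn as nn

class GuidanceGNN(nn.Module):
    """Sketch of the guidance model described above: 3 GCN-style layers with 16
    hidden units and tanh activations, concatenated layer embeddings, a tanh head
    multiplied element-wise with a sigmoid head, and a two-layer ReLU MLP that
    predicts a regression mean and a log-variance."""

    def __init__(self, in_dim, hidden=16):
        super().__init__()
        self.gcn = nn.ModuleList([
            nn.Linear(in_dim if i == 0 else hidden, hidden) for i in range(3)
        ])
        self.head_tanh = nn.Sequential(nn.Linear(3 * hidden, hidden), nn.Tanh())
        self.head_sig = nn.Sequential(nn.Linear(3 * hidden, hidden), nn.Sigmoid())
        self.mlp = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 2)
        )

    def forward(self, node_feats, adj):
        # node_feats: (n_nodes, in_dim), adj: normalized adjacency (n_nodes, n_nodes)
        h, layer_outputs = node_feats, []
        for layer in self.gcn:
            h = torch.tanh(layer(adj @ h))    # simple GCN-style propagation
            layer_outputs.append(h)
        h = torch.cat(layer_outputs, dim=-1)  # concatenate per-layer embeddings
        h = self.head_tanh(h) * self.head_sig(h)
        h = h.mean(dim=0)                     # pool nodes into a graph embedding (assumed readout)
        mean, log_var = self.mlp(h)           # regression mean and log-variance
        return mean, log_var

# Training configuration quoted above: 10 epochs, Adam, batch size 1024, learning rate 1e-3, e.g.
# model = GuidanceGNN(in_dim=...); opt = torch.optim.Adam(model.parameters(), lr=1e-3)
```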