Notice: The reproducibility variables underlying each score are classified using an automated LLM-based pipeline, validated against a manually labeled dataset. LLM-based classification introduces uncertainty and potential bias; scores should be interpreted as estimates. Full accuracy metrics and methodology are described in [1].

Steering Masked Discrete Diffusion Models via Discrete Denoising Posterior Prediction

Authors: Jarrid Rector-Brooks, Mohsin Hasan, Zhangzhi Peng, Chenghao Liu, Sarthak Mittal, Nouha Dziri, Michael Bronstein, Pranam Chatterjee, Alexander Tong, Joey Bose

ICLR 2025 | Venue PDF | LLM Run Details

Reproducibility Variable Result LLM Response
Research Type Experimental Empirically, we instantiate DDPP by steering MDMs to perform class-conditional pixel-level image modeling, RLHF-based alignment of MDMs using text-based rewards, and finetuning protein language models to generate more diverse secondary structures and shorter proteins. We substantiate our designs via wet-lab validation, where we observe transient expression of reward-optimized protein sequences.
Researcher Affiliation Collaboration 1Université de Montréal, 2Mila, 3Duke University, 4McGill University, 5Allen Institute for AI, 6Oxford University, 7Aithyra
Pseudocode Yes
Algorithm 1 Single-step DDPP-IS and DDPP-LB
Input: Reward R(x0), base MDM p^pre_0(x0|xt), sampling policy r(x0), fine-tuning MDM qθ(x0|xt)
1: while Training do
2:   t ~ U[0, 1], x0 ~ r(x0)  ▷ Sample time and clean data on- or off-policy
3:   xt ~ pt(xt|x0)  ▷ Construct partially masked sample given clean data
4:   if Importance Sample log Z(xt) then  ▷ Log-partition-function estimation strategy
Algorithm 2 DDPP-KL
Input: Differentiable reward R(x0), base MDM p^pre_0(x0|xt), fine-tuning MDM qθ(x0|xt), num samples K
1: while Training do
2:   t ~ U[0, 1], x0 ~ q(x0)  ▷ Sample time and clean data on-policy from the fine-tuning MDM
3:   xt ~ pt(xt|x0)  ▷ Construct a partially masked sample given clean data
4:   {x̂^i_0}_{i=1}^K ~ q_{t,θ}(·|xt)  ▷ Reparametrized sampling of clean data
5:   L_KL = (1/K) Σ_{i=1}^K [log q_{t,θ}(x̂^i_0|xt) − log p^pre_0(x̂^i_0|xt) − log R(x̂^i_0)]
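Line 5 of Algorithm 2 can be sketched as a Monte-Carlo average over the K sampled clean sequences. This is a minimal illustrative sketch, not the authors' implementation: the function name and the convention that the caller supplies the K evaluated log-probabilities and log-rewards are assumptions.

```python
import numpy as np

def ddpp_kl_loss(log_q, log_p_pre, log_r):
    """DDPP-KL loss estimate (sketch of Algorithm 2, line 5).

    Each argument is a length-K array evaluated at K clean samples
    x0_hat drawn (with reparametrization) from the fine-tuning MDM:
      log_q     -- log q_theta(x0_hat | x_t)
      log_p_pre -- log p_pre(x0_hat | x_t) under the base MDM
      log_r     -- log R(x0_hat)
    Returns (1/K) * sum_i [log_q_i - log_p_pre_i - log_r_i].
    """
    log_q, log_p_pre, log_r = map(np.asarray, (log_q, log_p_pre, log_r))
    return float(np.mean(log_q - log_p_pre - log_r))
```

Minimizing this quantity pushes the fine-tuned posterior toward the base model's posterior reweighted by the reward, matching the KL objective named in the algorithm.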
Open Source Code Yes Code is available at https://github.com/jarridrb/ddpp.
Open Datasets Yes We finetune MDMs to generate even MNIST digits.
Dataset Splits No For FLD, we draw K samples from the model, and K samples from the test set restricted to the target class. The FLD is computed using the DINOv2 feature space (from the ViT-B14 model) between these two sets of samples (Oquab et al., 2024). For MNIST, K = 5k. For baselines other than discrete guidance, this metric is computed on the test set restricted to the target class. For discrete guidance, BPD is computed by evaluating the MDM ELBO of the base model on samples generated using guidance (due to lacking an analogous ELBO for guidance-based sampling).
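To make the two-sample evaluation protocol concrete (K generated samples versus K class-restricted test samples, compared in a feature space), here is a minimal sketch using a diagonal-Gaussian Fréchet distance as a stand-in. Note the hedges: FLD proper is a different, likelihood-based metric (not this formula), the paper uses full DINOv2 ViT-B14 features rather than this diagonal simplification, and all names below are illustrative assumptions.

```python
import numpy as np

def frechet_proxy(feats_model, feats_test):
    """Fréchet distance between diagonal-Gaussian fits of two feature sets.

    Stand-in to illustrate the two-sample, feature-space protocol only;
    FLD itself additionally models per-sample likelihoods and is NOT
    given by this expression.
    feats_model, feats_test: arrays of shape (K, feature_dim).
    """
    mu_a, mu_b = feats_model.mean(axis=0), feats_test.mean(axis=0)
    var_a, var_b = feats_model.var(axis=0), feats_test.var(axis=0)
    mean_term = np.sum((mu_a - mu_b) ** 2)               # squared mean gap
    cov_term = np.sum(var_a + var_b - 2.0 * np.sqrt(var_a * var_b))
    return float(mean_term + cov_term)
```

Identical sample sets give a distance of zero; shifting every feature of one set by a constant increases only the mean term.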
Hardware Specification Yes All experiments were performed on a shared heterogeneous high-performance computing cluster. This cluster is primarily composed of GPU nodes with RTX8000, V100, A100, L40S, and H100 NVIDIA GPUs.
Software Dependencies No No specific software versions are provided. The paper mentions using RoBERTa, BERT, GPT-2, ESMFold, and the Adam optimizer, but no version numbers are specified for these or any other core libraries/frameworks.
Experiment Setup Yes For fine-tuning we train the model using our loss functions with the Adam optimizer, using a learning rate of 4e-3, (β1, β2) = (0.9, 0.999), and a weight decay of 0 across all methods. DDPP-IS used 16 samples to estimate the partition function. Training is done using a replay buffer populated with points x0 sampled on-policy from the model, as well as off-policy points from the prior distribution, added to the buffer every 100 training steps. A batch size of 64 is used.
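The reported replay-buffer scheme (on-policy points every step, off-policy prior points every 100 steps, batches of 64) can be sketched as below. This is a hedged reconstruction, not the authors' code: the function and argument names are assumptions, and the Adam configuration is shown only as a reference comment.

```python
import random
from collections import deque

# Reported optimizer configuration, for reference (`model` is a placeholder):
#   torch.optim.Adam(model.parameters(), lr=4e-3, betas=(0.9, 0.999), weight_decay=0.0)

def populate_buffer(buffer, step, on_policy_batch, prior_batch, prior_every=100):
    """Replay-buffer update sketch: add on-policy samples from the model every
    step, plus off-policy samples from the prior every `prior_every` steps."""
    buffer.extend(on_policy_batch)      # x0 sampled on-policy from the model
    if step % prior_every == 0:         # every 100 training steps...
        buffer.extend(prior_batch)      # ...also add off-policy prior points
    return buffer

def sample_training_batch(buffer, batch_size=64, rng=random):
    """Draw a training batch (batch size 64, as reported) from the buffer."""
    return rng.sample(list(buffer), k=min(batch_size, len(buffer)))
```

A `deque` with a `maxlen` bound would keep the buffer from growing without limit; the paper does not state the buffer capacity, so that choice is left to the reader.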