Cognitive Model Discovery via Disentangled RNNs
Authors: Kevin Miller, Maria Eckstein, Matt Botvinick, Zeb Kurth-Nelson
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We fit behavior data using a recurrent neural network that is penalized for carrying information forward in time, leading to sparse, interpretable representations and dynamics. When fitting synthetic behavioral data from known cognitive models, our method recovers the underlying form of those models. When fit to laboratory data from rats performing either a reward learning task or a decision-making task, our method recovers simple and interpretable models that make testable predictions about neural mechanisms. (A minimal sketch of this information-penalty idea appears below the table.) |
| Researcher Affiliation | Collaboration | Kevin J. Miller, Google DeepMind and University College London, London, UK, kevinjmiller@deepmind.com |
| Pseudocode | No | The paper includes architectural diagrams and equations but no explicit pseudocode or algorithm blocks. |
| Open Source Code | No | The paper mentions using JAX and Haiku, providing their GitHub URLs, but does not provide a link to, or an explicit statement about, the authors' own source code developed for this paper. |
| Open Datasets | Yes | For the two-armed bandit task, datasets consisted of sequences of binary choices made (left vs. right) and outcomes experienced (reward vs. no reward) by either a rat [37] or an artificial agent (Q-Learning or Leaky Actor-Critic, see below). ... For the pulse accumulation task, datasets consisted of sequences of integer pulse counts (left vs right) observed and binary choices made (left vs right) by either a rat [7] or an artificial agent (Bounded Accumulation, see below). (A Q-learning data generator in this choice/reward format is sketched below the table.) |
| Dataset Splits | Yes | Following [37], we evaluated model performance using two-fold cross-validation: we divide each rat's dataset into even-numbered and odd-numbered sessions, fit a set of model parameters to each, and compute the log-likelihood for each parameter set using the unseen portion of the dataset. (The even/odd split is sketched below the table.) |
| Hardware Specification | Yes | Using a second-generation TPU, models required between four and fifty hours to complete this number of training steps. |
| Software Dependencies | Yes | Networks were defined using custom modules written using Jax [6] and Haiku [22]. ... JAX: composable transformations of Python+NumPy programs. Version 0.3.13. ... Haiku: Sonnet for JAX. Version 0.0.9. |
| Experiment Setup | Yes | Each network had five latent variables. Update MLPs consisted of three hidden layers containing five units each. The Choice MLP consisted of two hidden layers of two units each. We used the rectified linear (ReLU) activation function. Network parameters were optimized using gradient descent and the Adam optimizer [30], with a learning rate of 5 × 10⁻³. We typically trained networks for 10⁵ steps, except that networks with very low β (10⁻³ or 3 × 10⁻⁴) required longer to converge and were trained for 5 × 10⁵ steps. (A Haiku/optax sketch of this configuration appears below the table.) |
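
Since the authors' code is not released, the sketch below only illustrates the core idea quoted in the Research Type row: an RNN whose per-latent updates pass through gated bottlenecks, with a loss term that penalizes carrying information forward in time. The gating form, the penalty term, `beta = 1e-2`, and all parameter shapes are assumptions made for the sketch, not the paper's exact mechanism.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_latents=5, hidden_sizes=(5, 5, 5)):
    """Random parameters with the (assumed) shapes used by the sketch below."""
    keys = jax.random.split(key, len(hidden_sizes) + 2)
    sizes = (n_latents + 2,) + hidden_sizes
    hidden = [(0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
              for k, m, n in zip(keys, sizes[:-1], sizes[1:])]
    return {
        "hidden": hidden,                                  # update-MLP hidden layers
        "out_w": 0.1 * jax.random.normal(keys[-2], (hidden_sizes[-1], n_latents)),
        "out_b": jnp.zeros(n_latents),
        "gate_logits": jnp.zeros(n_latents),               # per-latent bottlenecks
        "choice_w": 0.1 * jax.random.normal(keys[-1], (n_latents, 2)),
        "choice_b": jnp.zeros(2),
    }


def update_latents(params, latents, choice, reward):
    """One time step: an update MLP maps (latents, choice, reward) to new latents."""
    h = jnp.concatenate([latents, jnp.array([choice, reward], dtype=jnp.float32)])
    for w, b in params["hidden"]:
        h = jax.nn.relu(h @ w + b)
    target = h @ params["out_w"] + params["out_b"]         # candidate new latent values
    gates = jax.nn.sigmoid(params["gate_logits"])          # how "open" each bottleneck is
    return (1.0 - gates) * latents + gates * target


def session_loss(params, choices, rewards, beta=1e-2):
    """Negative log-likelihood of the choices plus a penalty on open bottlenecks."""
    latents = jnp.zeros(5)                                 # five latent variables
    nll = 0.0
    for c, r in zip(choices, rewards):
        logits = latents @ params["choice_w"] + params["choice_b"]
        nll = nll - jax.nn.log_softmax(logits)[c]
        latents = update_latents(params, latents, c, r)
    # Penalizing open gates discourages latents from carrying information
    # forward in time, which is what pushes the fit toward sparse dynamics.
    return nll + beta * jnp.sum(jax.nn.sigmoid(params["gate_logits"]))
```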
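
For the synthetic two-armed-bandit datasets, a standard softmax Q-learning agent producing data in the same choice/reward format looks roughly like the sketch below. The learning rate `alpha`, inverse temperature `beta_softmax`, and reward probabilities are illustrative values, not the parameters used in the paper.

```python
import numpy as np


def simulate_q_learning(n_trials, reward_probs=(0.8, 0.2),
                        alpha=0.3, beta_softmax=3.0, seed=0):
    """Softmax Q-learning agent on a two-armed bandit.

    Returns binary choices (0 = left, 1 = right) and binary rewards, matching
    the choice/outcome sequence format described in the Open Datasets row.
    """
    rng = np.random.default_rng(seed)
    q = np.zeros(2)
    choices, rewards = [], []
    for _ in range(n_trials):
        p_right = 1.0 / (1.0 + np.exp(-beta_softmax * (q[1] - q[0])))
        choice = int(rng.random() < p_right)
        reward = int(rng.random() < reward_probs[choice])
        q[choice] += alpha * (reward - q[choice])          # standard delta-rule update
        choices.append(choice)
        rewards.append(reward)
    return np.array(choices), np.array(rewards)


# Example: 500 synthetic trials in the same format as the rat datasets.
choices, rewards = simulate_q_learning(n_trials=500)
```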
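
The even/odd-session two-fold cross-validation scheme quoted in the Dataset Splits row is easy to write down; `fit_model` and `log_likelihood` below are hypothetical placeholders for the (unreleased) fitting and evaluation routines.

```python
def crossvalidated_log_likelihood(sessions, fit_model, log_likelihood):
    """Two-fold cross-validation over even- and odd-numbered sessions."""
    even = sessions[0::2]            # fold 1: even-numbered sessions
    odd = sessions[1::2]             # fold 2: odd-numbered sessions

    params_even = fit_model(even)    # fit one parameter set per fold
    params_odd = fit_model(odd)

    # Score each parameter set on the unseen portion of the dataset.
    return log_likelihood(params_even, odd) + log_likelihood(params_odd, even)
```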
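
The Experiment Setup row pins down the architecture and optimizer, and the Haiku/optax sketch below reconstructs those hyperparameters (five latents, update MLPs with three hidden layers of five units, a choice MLP with two hidden layers of two units, ReLU activations, Adam at 5 × 10⁻³). It is an illustrative reconstruction, not the authors' implementation; the single-unit update-MLP output is an added assumption.

```python
import haiku as hk
import jax
import jax.numpy as jnp
import optax

N_LATENTS = 5  # "Each network had five latent variables."


def choice_logits(latents):
    # Choice MLP: two hidden layers of two units each (ReLU), then a linear
    # readout to logits over the two choices.
    return hk.nets.MLP([2, 2, 2], activation=jax.nn.relu)(latents)


def latent_update(inputs):
    # Update MLP: three hidden layers of five units each (ReLU); the single
    # linear output per latent is an assumption, not stated in the row above.
    return hk.nets.MLP([5, 5, 5, 1], activation=jax.nn.relu)(inputs)


choice_fn = hk.without_apply_rng(hk.transform(choice_logits))
update_fn = hk.without_apply_rng(hk.transform(latent_update))

# Optimizer settings quoted above: Adam with learning rate 5e-3; networks were
# typically trained for 1e5 steps (5e5 for the lowest-beta networks).
optimizer = optax.adam(learning_rate=5e-3)

# Example initialization of the choice MLP on a dummy latent state.
params = choice_fn.init(jax.random.PRNGKey(0), jnp.zeros(N_LATENTS))
```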