Alternating Gradient Descent and Mixture-of-Experts for Integrated Multimodal Perception
Authors: Hassan Akbari, Dan Kondratyuk, Yin Cui, Rachel Hornung, Huisheng Wang, Hartwig Adam
NeurIPS 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We conduct extensive empirical studies and reveal the following key insights: 1) performing gradient descent updates by alternating on diverse modalities, loss functions, and tasks, with varying input resolutions, efficiently improves the model. 2) sparsification with MoE on a single modality-agnostic encoder substantially improves the performance, outperforming dense models that use modality-specific encoders or additional fusion layers and greatly mitigates the conflicts between modalities. IMP achieves competitive performance on a wide range of downstream tasks including video classification, image classification, image-text, and video-text retrieval. |
| Researcher Affiliation | Industry | Hassan Akbari Dan Kondratyuk Yin Cui Rachel Hornung Huisheng Wang Hartwig Adam Google Research {hassanak, dankondratyuk, yincui, rachelhornung, huishengw, hadam}@google.com |
| Pseudocode | Yes | Algorithm 1 Accelerated Multimodal AGD Algorithm (a hedged JAX sketch of the alternating-update idea follows the table) |
| Open Source Code | No | The paper links to JAX documentation (e.g., https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) to explain the features it uses, but it provides no link to the authors' own implementation and does not state that the code for this work is open source or publicly available. |
| Open Datasets | Yes | Datasets. Our datasets consist of a diverse set of learnable signals across multiple modalities. We use WebLI [Chen et al., 2022], LAION-400M [Schuhmann et al., 2021], WIT [Srinivasan et al., 2021], CC12M [Changpinyo et al., 2021], and VCC [Nagrani et al., 2022] for vision-text contrastive learning; JFT-3B [Zhai et al., 2022], I21K [Ridnik et al., 2021], and WTS-70M [Stroud et al., 2020] for both supervised classification and label-based vision-text contrastive estimation (similar to BASIC [Pham et al., 2021]); HT100M [Miech et al., 2019] and AudioSet [Gemmeke et al., 2017] for vision-audio-text triplet contrastive loss (similar to VATT [Akbari et al., 2021]). |
| Dataset Splits | No | The paper mentions evaluating on various public datasets like ImageNet1K and CIFAR-100, which typically have standard validation splits. However, it does not explicitly state the specific validation splits or methodology (e.g., percentages, counts, or how they were used during training for model selection) for its own experimental setup. For example, it only mentions "We evaluate on ImageNet1K and CIFAR-100 by image-to-text retrieval and linear probing on the frozen model's features and report the results in Table 2." |
| Hardware Specification | Yes | Compared to the previous state-of-the-art, VideoCoCa [Yan et al., 2022], we train IMP-MoE-L on 256 TPU v4 chips for 6 days, representing only 15% of the total training cost of VideoCoCa. |
| Software Dependencies | No | The paper mentions the use of JAX APIs (jax.jit, jax.checkpoint, jax.lax.scan, jax.pjit) and refers to their documentation. However, it does not specify version numbers for JAX or any other software libraries or programming languages used in the experiments. |
| Experiment Setup | Yes | Training Parameters. For our final experiments, we train with a patch size of 4x16x16 on base input resolutions of 16x256x256 and 4x256x256 on video and image modalities respectively, resulting in a total of 1024 and 256 patches per sample. The text inputs in ImageNet21K and JFT are truncated to 16 tokens to improve step efficiency with no loss of information, while keeping the text length of the rest of the datasets to a maximum of 256 tokens. We use a base batch size of 65536 and train using the Adam optimizer, a peak learning rate of 1e-3 with a cosine schedule, and apply no weight decay. For MoE parameters, we apply experts-choose routing with a top-c capacity factor of 1.0 and do not apply any jittering to the routing or other auxiliary losses. (Hedged sketches of the optimizer schedule and the experts-choose routing follow the table.) |
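
The alternating-update idea behind Algorithm 1 (each training step optimizes one sampled combination of dataset, objective, and input resolution, with each combination compiled separately so differing shapes and losses do not force recompilation) can be illustrated with a minimal JAX sketch. The toy linear model, the two stand-in objectives, and the use of optax are assumptions made for this example; it is not the authors' implementation.

```python
# Minimal sketch of alternating gradient descent (AGD) over heterogeneous
# objectives, in the spirit of Algorithm 1. Toy model and losses are
# illustrative assumptions, not the IMP code.
import jax
import jax.numpy as jnp
import optax

# Placeholder shared ("modality-agnostic") parameters standing in for the
# single encoder used by IMP.
params = {"w": jnp.zeros((256, 64))}
opt = optax.adam(1e-3)
opt_state = opt.init(params)

def contrastive_loss(params, batch):
    # Stand-in for a vision-text contrastive objective.
    z = batch["x"] @ params["w"]
    return -jnp.mean(jnp.sum(z * batch["z_text"], axis=-1))

def classification_loss(params, batch):
    # Stand-in for a supervised classification objective.
    logits = batch["x"] @ params["w"]
    return optax.softmax_cross_entropy_with_integer_labels(
        logits, batch["y"]).mean()

def make_step(loss_fn):
    # One jitted update per objective, so each (objective, input shape)
    # combination is traced and compiled once, then reused.
    @jax.jit
    def step(params, opt_state, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        updates, opt_state = opt.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss
    return step

steps = {
    "contrastive": make_step(contrastive_loss),
    "classification": make_step(classification_loss),
}

key = jax.random.PRNGKey(0)
for i in range(4):
    # Alternate objectives across steps (simple round-robin here; the paper
    # alternates over datasets, losses, tasks, and input resolutions).
    if i % 2 == 0:
        name, batch = "contrastive", {
            "x": jax.random.normal(key, (8, 256)),
            "z_text": jax.random.normal(key, (8, 64)),
        }
    else:
        name, batch = "classification", {
            "x": jax.random.normal(key, (8, 256)),
            "y": jnp.zeros((8,), dtype=jnp.int32),
        }
    params, opt_state, loss = steps[name](params, opt_state, batch)
```

Because each objective gets its own jitted step function, switching objectives between steps only dispatches a different compiled program rather than retracing, which is the role the paper's references to jax.jit suggest.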
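
The reported optimizer settings (Adam, peak learning rate 1e-3 on a cosine schedule, no weight decay) map onto a short schedule definition. Using optax, the total step count, and omitting any warmup are assumptions of this sketch.

```python
import optax

# Sketch of the reported optimizer settings: Adam with a peak learning rate
# of 1e-3 on a cosine schedule and no weight decay. The total step count is
# a placeholder, and any warmup the authors may have used is omitted.
TOTAL_STEPS = 100_000  # placeholder; not reported in the excerpt

schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=TOTAL_STEPS)
optimizer = optax.adam(learning_rate=schedule)  # plain Adam, i.e. no weight decay
```

This optimizer drops in for `optax.adam(1e-3)` in the AGD sketch above via `optimizer.init(params)` and `optimizer.update(...)`.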
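
The MoE configuration in the experiment-setup row (experts-choose routing with a capacity factor of 1.0 and no routing jitter or auxiliary losses) amounts to each expert selecting its own top-c tokens. The sketch below uses toy single-matrix experts and illustrative dimensions; it is not the IMP routing code.

```python
import jax
import jax.numpy as jnp

def experts_choose_layer(tokens, router_w, expert_w, capacity_factor=1.0):
    """Minimal experts-choose (expert-choice) routing sketch.

    tokens:    [n_tokens, d]       token representations
    router_w:  [d, n_experts]      router projection
    expert_w:  [n_experts, d, d]   one toy linear "expert" per expert
    Each expert picks its own top-c tokens,
    where c = capacity_factor * n_tokens / n_experts.
    """
    n_tokens, _ = tokens.shape
    n_experts = router_w.shape[1]
    c = int(capacity_factor * n_tokens / n_experts)

    scores = jax.nn.softmax(tokens @ router_w, axis=-1)   # [n_tokens, n_experts]
    # Experts choose: each expert (row of scores.T) selects its top-c tokens.
    gate, idx = jax.lax.top_k(scores.T, c)                # both [n_experts, c]

    out = jnp.zeros_like(tokens)
    for e in range(n_experts):                            # plain loop for clarity
        chosen = tokens[idx[e]]                           # [c, d]
        expert_out = chosen @ expert_w[e]                 # toy per-expert transform
        out = out.at[idx[e]].add(gate[e][:, None] * expert_out)
    return out

# Illustrative usage with toy dimensions.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
tokens = jax.random.normal(k1, (16, 32))
router_w = jax.random.normal(k2, (32, 4))
expert_w = jax.random.normal(k3, (4, 32, 32))
y = experts_choose_layer(tokens, router_w, expert_w, capacity_factor=1.0)
```

Because experts-choose routing fixes each expert's token budget by construction, the load across experts is balanced without auxiliary losses, consistent with the row's note that no routing jitter or auxiliary losses are applied.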