Token Merging: Your ViT But Faster
Authors: Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman
ICLR 2023
| Reproducibility Variable | Result | LLM Response |
|---|---|---|
| Research Type | Experimental | We perform several experiments on ImageNet-1k (Deng et al., 2009) using ViT models trained in four different ways: AugReg (Steiner et al., 2022), MAE (He et al., 2022), SWAG (Singh et al., 2022), and DeiT (Touvron et al., 2021). For all experiments, we either run the model off-the-shelf with our method or, in the case of MAE and DeiT, trained with our method applied. |
| Researcher Affiliation | Collaboration | Daniel Bolya (1,2), Cheng-Yang Fu (2), Xiaoliang Dai (2), Peizhao Zhang (2), Christoph Feichtenhofer (2), Judy Hoffman (1). 1: Georgia Tech, 2: Meta AI. {dbolya,judy}@gatech.edu, {chengyangfu,xiaoliangdai,stzpz,feichtenhofer}@meta.com |
| Pseudocode | Yes | def bipartite_soft_matching(k: torch.Tensor, r: int) -> torch.Tensor: """ Input is k from attention, size [batch, tokens, channels]. """ k = k / k.norm(dim=-1, keepdim=True) a, b = k[..., ::2, :], k[..., 1::2, :] scores = a @ b.transpose(-1, -2) scores[..., 0, :] = -math.inf # don't merge cls token node_max, node_idx = scores.max(dim=-1) edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] unm_idx = edge_idx[..., r:, :] # Unmerged Tokens src_idx = edge_idx[..., :r, :] # Merged Tokens dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx) unm_idx = unm_idx.sort(dim=-2)[0] # Sort cls token back to idx 0 def merge(x: torch.Tensor) -> torch.Tensor: """ Input is of shape [batch, tokens, channels]. """ src, dst = x[..., ::2, :], x[..., 1::2, :] n, t1, c = src.shape unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c)) src = src.gather(dim=-2, index=src_idx.expand(n, r, c)) dst = dst.scatter_add(-2, dst_idx.expand(n, r, c), src) return torch.cat([unm, dst], dim=-2) return merge |
| Open Source Code | Yes | Work done during an internship at Meta AI. Code at http://github.com/facebookresearch/ToMe. |
| Open Datasets | Yes | We perform several experiments on ImageNet-1k (Deng et al., 2009) using ViT models trained in four different ways: AugReg (Steiner et al., 2022), MAE (He et al., 2022), SWAG (Singh et al., 2022), and DeiT (Touvron et al., 2021). For all experiments, we either run the model off-the-shelf with our method or, in the case of MAE and DeiT, trained with our method applied. All throughputs are measured during inference on a V100 GPU with optimal batch size and fp32 unless noted otherwise. |
| Dataset Splits | Yes | For each schedule, we test its accuracy and fp16 throughput on ImageNet-1k val using an off-the-shelf AugReg ViT-B/16 model. In Fig. 2, we plot the results of this experiment and find that a constant schedule is close to optimal, especially as the total tokens merged increases. We further analyze the best random samples (see Appendix C) and find that a linearly decreasing schedule works well at throughputs up to 3x. Thus, we also define a decreasing schedule that removes 2r tokens in the first layer and 0 tokens in the last layer, linearly interpolating for the rest. This also removes rL tokens, but is faster because more are removed early. |
| Hardware Specification | Yes | All throughputs are measured during inference on a V100 GPU with optimal batch size and fp32 unless noted otherwise. |
| Software Dependencies | No | The pseudocode is written in PyTorch (Paszke et al., 2019), but the paper does not provide specific version numbers for PyTorch, Python, or other software dependencies. |
| Experiment Setup | Yes | In Tab. 1, we ablate the design choices made in our approach. For each ablation, we start from our default parameters marked in purple. Unless otherwise noted, we test on an off-the-shelf ViT-L/16 MAE model without training (acc: 85.96%, im/s: 93.3), and merge with r = 8, which gradually removes 98% of tokens over the 24 layers of the network. All throughputs are measured during inference on a V100 GPU with optimal batch size and fp32 unless noted otherwise. |
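The pseudocode quoted in the table is flattened into a single line by the extraction; below is a self-contained, runnable rendering of it, with a toy usage example. It assumes PyTorch is installed; the input shapes and `r` value in the usage example are illustrative, not from the paper.

```python
import math
import torch

def bipartite_soft_matching(k: torch.Tensor, r: int):
    """k: keys from attention, shape [batch, tokens, channels].
    Returns a merge function that removes r tokens."""
    k = k / k.norm(dim=-1, keepdim=True)            # cosine-normalize keys
    a, b = k[..., ::2, :], k[..., 1::2, :]          # alternating bipartition
    scores = a @ b.transpose(-1, -2)                # pairwise similarity
    scores[..., 0, :] = -math.inf                   # never merge the cls token
    node_max, node_idx = scores.max(dim=-1)         # best match for each src token
    edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
    unm_idx = edge_idx[..., r:, :]                  # tokens left unmerged
    src_idx = edge_idx[..., :r, :]                  # tokens to merge away
    dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
    unm_idx = unm_idx.sort(dim=-2)[0]               # sort cls token back to idx 0

    def merge(x: torch.Tensor) -> torch.Tensor:
        """x: [batch, tokens, channels]; returns [batch, tokens - r, channels]."""
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_add(-2, dst_idx.expand(n, r, c), src)
        return torch.cat([unm, dst], dim=-2)

    return merge

# Toy usage: merge r=2 of 8 tokens (batch 1, 4 channels).
x = torch.randn(1, 8, 4)
merge = bipartite_soft_matching(x, r=2)
out = merge(x)
print(out.shape)  # torch.Size([1, 6, 4])
```

Note that `scatter_add` sums merged features into their destination tokens; in the full method the paper additionally tracks token "size" so merged tokens can be weighted in attention, which this excerpt omits.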