Token Merging: Your ViT But Faster

Authors: Daniel Bolya, Cheng-Yang Fu, Xiaoliang Dai, Peizhao Zhang, Christoph Feichtenhofer, Judy Hoffman

ICLR 2023

Reproducibility Variable | Result | LLM Response
Research Type | Experimental | We perform several experiments on ImageNet-1k (Deng et al., 2009) using ViT models trained in four different ways: AugReg (Steiner et al., 2022), MAE (He et al., 2022), SWAG (Singh et al., 2022), and DeiT (Touvron et al., 2021). For all experiments, we either run the model off-the-shelf with our method or, in the case of MAE and DeiT, train with our method applied.
Researcher Affiliation | Collaboration | Daniel Bolya (1,2), Cheng-Yang Fu (2), Xiaoliang Dai (2), Peizhao Zhang (2), Christoph Feichtenhofer (2), Judy Hoffman (1); 1: Georgia Tech, 2: Meta AI. {dbolya,judy}@gatech.edu, {chengyangfu,xiaoliangdai,stzpz,feichtenhofer}@meta.com
Pseudocode | Yes |

import math
import torch
from typing import Callable

def bipartite_soft_matching(k: torch.Tensor, r: int) -> Callable:
    """ Input is k from attention, size [batch, tokens, channels]. """
    k = k / k.norm(dim=-1, keepdim=True)
    a, b = k[..., ::2, :], k[..., 1::2, :]   # split tokens into sets A (even) and B (odd)
    scores = a @ b.transpose(-1, -2)         # cosine similarity between the two sets
    scores[..., 0, :] = -math.inf            # don't merge cls token

    node_max, node_idx = scores.max(dim=-1)
    edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
    unm_idx = edge_idx[..., r:, :]   # Unmerged Tokens
    src_idx = edge_idx[..., :r, :]   # Merged Tokens
    dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)
    unm_idx = unm_idx.sort(dim=-2)[0]  # Sort cls token back to idx 0

    def merge(x: torch.Tensor) -> torch.Tensor:
        """ Input is of shape [batch, tokens, channels]. """
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_add(-2, dst_idx.expand(n, r, c), src)
        return torch.cat([unm, dst], dim=-2)

    return merge
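To illustrate the listing above, here is a minimal usage sketch (not from the paper); the batch size, token count, and feature widths are arbitrary placeholders, and it assumes the bipartite_soft_matching definition above is in scope:

import torch

# Hypothetical sizes: 1 image, 197 tokens (196 patches + cls), 64-dim keys, 768-dim features.
k = torch.randn(1, 197, 64)    # keys taken from one attention layer
x = torch.randn(1, 197, 768)   # token features entering the block

merge = bipartite_soft_matching(k, r=8)  # pick the r most similar token pairs
x = merge(x)                             # 197 tokens -> 189 tokens
print(x.shape)                           # torch.Size([1, 189, 768])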
Open Source Code | Yes | Work done during an internship at Meta AI. Code at http://github.com/facebookresearch/ToMe.
Open Datasets | Yes | We perform several experiments on ImageNet-1k (Deng et al., 2009) using ViT models trained in four different ways: AugReg (Steiner et al., 2022), MAE (He et al., 2022), SWAG (Singh et al., 2022), and DeiT (Touvron et al., 2021). For all experiments, we either run the model off-the-shelf with our method or, in the case of MAE and DeiT, train with our method applied. All throughputs are measured during inference on a V100 GPU with optimal batch size and fp32 unless noted otherwise.
Dataset Splits | Yes | For each schedule, we test its accuracy and fp16 throughput on ImageNet-1k val using an off-the-shelf AugReg ViT-B/16 model. In Fig. 2, we plot the results of this experiment and find that a constant schedule is close to optimal, especially as the total tokens merged increases. We further analyze the best random samples (see Appendix C) and find that a linearly decreasing schedule works well at throughputs up to 3x. Thus, we also define a decreasing schedule that removes 2r tokens in the first layer and 0 tokens in the last layer, linearly interpolating for the rest. This also removes rL tokens, but is faster because more are removed early.
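To make the two merging schedules concrete, here is a small sketch (our own illustration, not code from the paper; the function names are ours) of how many tokens each layer removes under the constant and the linearly decreasing schedule:

def constant_schedule(r: int, num_layers: int) -> list:
    # Remove r tokens in every layer: r*L tokens in total.
    return [r] * num_layers

def decreasing_schedule(r: int, num_layers: int) -> list:
    # Remove 2r tokens in the first layer and 0 in the last,
    # interpolating linearly in between (rounded per layer): also r*L tokens in total.
    return [round(2 * r * (num_layers - 1 - i) / (num_layers - 1))
            for i in range(num_layers)]

# For r = 8 and the 24 layers of ViT-L, both schedules remove 192 tokens overall.
assert sum(constant_schedule(8, 24)) == sum(decreasing_schedule(8, 24)) == 192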
Hardware Specification | Yes | All throughputs are measured during inference on a V100 GPU with optimal batch size and fp32 unless noted otherwise.
Software Dependencies | No | The pseudocode mentions PyTorch (Paszke et al., 2019), but the paper does not give version numbers for PyTorch, Python, or other software dependencies.
Experiment Setup | Yes | In Tab. 1, we ablate the design choices made in our approach. For each ablation, we start from our default parameters marked in purple. Unless otherwise noted, we test on an off-the-shelf ViT-L/16 MAE model without training (acc: 85.96%, im/s: 93.3) and merge with r = 8, which gradually removes 98% of tokens over the 24 layers of the network. All throughputs are measured during inference on a V100 GPU with optimal batch size and fp32 unless noted otherwise.
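As a quick sanity check on the 98% figure (our own arithmetic, assuming the standard 197-token input of a ViT-L/16 at 224x224: 196 patch tokens plus the class token):

tokens = 14 * 14 + 1      # 196 patch tokens + 1 cls token = 197
removed = 8 * 24          # r = 8 tokens merged in each of the 24 layers = 192
print(removed / tokens)   # ~0.97, i.e. roughly the quoted 98% of tokens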