def bipartite_soft_matching( metric: torch.Tensor, r: int, class_token: bool = False, distill_token: bool = False, ) -> Tuple[Callable, Callable]: """ Applies ToMe with a balanced matching set (50%, 50%). Input size is [batch, tokens, channels]. r indicates the number of tokens to remove (max 50% of tokens). Extra args: - class_token: Whether or not there's a class token. - distill_token: Whether or not there's also a distillation token. When enabled, the class token and distillation tokens won't get merged. """ protected = 0 if class_token: protected += 1 if distill_token: protected += 1
# We can only reduce by a maximum of 50% tokens t = metric.shape[1] r = min(r, (t - protected) // 2)
if r <= 0: return do_nothing, do_nothing
with torch.no_grad(): metric = metric / metric.norm(dim=-1, keepdim=True) a, b = metric[..., ::2, :], metric[..., 1::2, :] scores = a @ b.transpose(-1, -2)
if class_token: scores[..., 0, :] = -math.inf if distill_token: scores[..., :, 0] = -math.inf
到目前为止,已经能够直接向已经训练好的 ViT 模型中添加 ToMe 模块。使用 ToMe 模块进行训练虽然不是必须的,但是它可以减少准确度下降,并且加快训练速度。ToMe 模块本质上是 token 的均值操作,因此可以视为是一种池化操作 (Pooling)。因此,我们可以按照平均池化操作 (Average Pooling) 的方式进行反向传播。
51.1.6 其他消融实验结果
定义式1所示的对不同 tokens 进行加权的方式为 weighted avg,在决定合并哪些 tokens 之后,通过对 tokens 进行平均加权来合并它们。下图5左侧的消融实验结果表明,weighted avg 的方式优于直接的 average pooling 的方式以及 max pooling 的方式。
如下图11所示是 ToMe 方法 + MAE 微调的模型 (具体是在 MAE 进行微调的环节用上了本文的 ToMe 方法) 与其他 ImageNet-1K 模型的性能对比,可以看到 ToMe 方法可以提高 ViT 模型的吞吐量,使得较深的 ViT 模型 (如 ViT-H 和 ViT-L) 的吞吐量与较浅的模型相当。
对于视频实验,作者使用 Kinetics-400 数据集,使用了 Spatiotemporal MAE[11] 的方式来训练。仿照图像实验的两种做法进行验证,一种是直接把 ToMe 方法应用在现成的训练好的模型中,另一种是在 MAE 进行微调的环节用上 ToMe 方法。实验结果如下图14所示。将 ToMe 方法应用在 ViT-L 上之后,吞吐量与 Swin-B 接近,同时性能更好。而且,将 ToMe 方法应用在 ViT-L 上之后,使用 Spatiotemporal MAE[11] 的方式,性能明显优于 MAE 方式训练的 ViT-B 模型,说明 token 融合的方法比 model scaling 更好。
图14:视频任务实验结果,蓝色是无需训练直接使用 ToMe 方法的结果,灰色是微调阶段使用 ToMe 方法的结果
总结
ToMe 是一个无需训练并且兼顾性能-速度权衡的 token 融合方法,意在缩减 ViT 模型中大量冗余的 tokens。Token Merging 的基本思路是在一个 ViT 模型中间插入一些 token merging 的模块,希望把这些模块植入 ViT 以后,训练和推理的速度都有提升。在图像和视频中多个模型的实验结果表明,这种 token 融合的方法能够合并背景和前景中的大量冗余的 tokens,提高 ViT 模型的吞吐量,而且不丢失信息。
参考
^abDynamicViT: Efficient Vision Transformers with Dynamic Token Sparsification
^Adavit: Adaptive vision transformers for efficient image recognition
^abA-ViT: Adaptive tokens for efficient vision transformer
^abSpvit: Enabling faster vision transformers via soft token pruning
^Token pooling in vision transformers
^Tokenlearner: Adaptive space-time tokenization for videos
^How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers
^Masked autoencoders are scalable vision learners
^Revisiting weakly supervised pre-training of visual perception models
^Training data-efficient image transformers & distillation through attention