[Developing] Toolbox and Benchmark for Token Merging Modules


News

  • 2025-06-23: Set up the basic framework of OpenToMe.

Installation

Install experimental dependencies

conda create -n opentome python=3.10.0
conda activate opentome
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124
pip install -r requirements.txt

Install OpenToMe from source

git clone git@github.com:Westlake-AI/OpenToMe.git
cd OpenToMe
pip install -e .

Developing

Install FLA (flash-linear-attention)

Follow the flash-linear-attention repository to install the environment required for training.

Getting Started

Model Examples

Here is an example of using ToMe with timm Attention blocks.

import torch
import timm
from torch import nn
from opentome.timm import Block, tome_apply_patch
from opentome.tome import check_parse_r


class TransformerBlock(nn.Module):
    def __init__(self, *, embed_dim=768, num_layers=12, num_heads=12, drop_path=0.0,
                 with_cls_token=True, init_values=1e-5, use_flash_attn=False, **kwargs):
        super(TransformerBlock, self).__init__()

        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.with_cls_token = with_cls_token

        if self.with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=0.0)

        dp_rates = [x.item() for x in torch.linspace(drop_path, 0.0, num_layers)]
        self.blocks = nn.Sequential(
            *[Block(
                dim=self.embed_dim,
                num_heads=self.num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                init_values=init_values,
                drop_path=dp_rates[j],
                norm_layer=nn.LayerNorm,
                act_layer=nn.GELU,
                mlp_layer=timm.layers.Mlp,
                use_flash_attn=use_flash_attn,
            ) for j in range(num_layers)]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

    def forward(self, x):
        B, N, C = x.shape
        if self.with_cls_token:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)
        x = self.blocks(x)
        x = self.norm(x)
        return x

# test
embed_dim, token_num, merge_num, inflect = 384, 196, 100, 0.5
x = torch.randn(1, token_num, embed_dim)
# model = timm.create_model('vit_small_patch16_224')
model = TransformerBlock(embed_dim=384, num_layers=12, num_heads=8)
z = model(x)
print(x.shape, z.shape)

# update tome
merge_ratio = check_parse_r(len(model.blocks), merge_num, token_num, inflect)
tome_apply_patch(model)
model.r = (merge_ratio, inflect)
model._tome_info["r"] = model.r
model._tome_info["total_merge"] = merge_num

z = model(x)
print(x.shape, z.shape)
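In the snippet above, check_parse_r turns the total merge budget (merge_num) into a per-layer merge ratio, with inflect skewing how much merging happens in early versus late layers. The sketch below illustrates the idea with a hypothetical merge_schedule helper; it is an illustration of the scheduling concept only, not OpenToMe's actual check_parse_r implementation.

```python
def merge_schedule(num_layers, total_merge, inflect=0.0):
    """Split a total token-merge budget across layers.

    inflect > 0 tilts merging toward early layers, inflect < 0 toward
    late layers, and 0 gives a uniform split. Hypothetical sketch --
    OpenToMe's check_parse_r may use a different parameterization.
    """
    mid = (num_layers - 1) / 2
    # Linear weights centered on the middle layer, tilted by `inflect`.
    weights = [1.0 + inflect * (mid - i) / max(mid, 1) for i in range(num_layers)]
    scale = total_merge / sum(weights)
    r = [round(w * scale) for w in weights]
    # Fix rounding drift so the schedule sums exactly to the budget.
    r[-1] += total_merge - sum(r)
    return r
```

For example, merge_schedule(12, 100, inflect=0.5) merges more tokens in early layers while still removing exactly 100 tokens overall.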

Image Classification on ImageNet

Here is an example of evaluating the ImageNet validation set with various token compression methods.

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
./evaluations/image_classification/in1k_example.py \
--model_name vit_base_patch16_224 \
--tome tome \
--merge_num 100 \
--dataset /PATH/TO/ImageNet/val \
--inflect -0.5

You can also run the evaluation with the bash example on GPU 0:

bash evaluations/image_classification/in1k_eval.sh 0 tome 100 /PATH/TO/ImageNet/val 1 deit_small_patch16_224
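The evaluation script above reports top-1 accuracy on ImageNet-val. Stripped of data loading and distributed bookkeeping, the metric reduces to the following plain-Python sketch; topk_accuracy is a hypothetical helper for illustration, not part of the evaluation script.

```python
def topk_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the top-k scores.

    logits: list of per-class score lists; labels: list of int class ids.
    Hypothetical sketch of the metric the evaluation script reports.
    """
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k highest-scoring classes for this sample.
        topk = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)[:k]
        hits += label in topk
    return hits / len(labels)
```

With k=1 this is the usual top-1 accuracy; with k=5 it gives ImageNet's top-5 metric.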

Image ToMe Visualization

Here is an example of visualization with various token compression methods.

CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 \
evaluations/visualizations/vis_classification.py \
--image_path ./demo \
--model_name deit_small_patch16_224 \
--tome tome \
--merge_num 100 \
--inflect -0.5 \
--save_vis True

You can run the visualization with the bash example:

bash evaluations/visualizations/vis_eval.sh ./demo tome 100 1 deit_small_patch16_224

Flash Linear Attention Model Training

Here is an example of training with flash linear attention via flame.

#!/usr/bin/bash
export HF_ENDPOINT=https://hf-mirror.com

NNODE=1 NGPU=8 LOG_RANK=0 bash train.sh \
  --job.config_file flame/models/fla.toml \
  --job.dump_folder RESULTS/PATH \
  --model.config configs/gla_340M.json \
  --model.tokenizer_path TOKENIZER/PATH \
  --optimizer.name AdamW \
  --optimizer.eps 1e-15 \
  --optimizer.lr 3e-4 \
  --lr_scheduler.warmup_steps 1024 \
  --lr_scheduler.lr_min 0.1 \
  --lr_scheduler.decay_type cosine \
  --training.batch_size 32 \
  --training.seq_len 2048 \
  --training.gradient_accumulation_steps 1 \
  --training.steps 20480 \
  --training.max_norm 1.0 \
  --training.skip_nan_inf \
  --training.dataset DATASET/PATH \
  --training.dataset_name default \
  --training.dataset_split train \
  --training.streaming \
  --training.num_workers 32 \
  --training.prefetch_factor 2 \
  --training.seed 42 \
  --training.compile \
  --training.tensor_parallel_degree 1 \
  --training.disable_loss_parallel \
  --checkpoint.interval 2048 \
  --checkpoint.load_step -1 \
  --metrics.log_freq 1

NOTE: OpenToMe now supports other optimizers for training.
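As a sanity check on the configuration above, the token throughput per optimizer step follows from the batch settings. The arithmetic below assumes --training.batch_size is the per-GPU batch size (an assumption; check flame's documentation for the exact semantics).

```python
# Token-count arithmetic for the training command above.
ngpu = 8            # NGPU=8
batch_size = 32     # --training.batch_size (assumed per GPU)
grad_accum = 1      # --training.gradient_accumulation_steps
seq_len = 2048      # --training.seq_len
steps = 20480       # --training.steps

tokens_per_step = ngpu * batch_size * grad_accum * seq_len
total_tokens = tokens_per_step * steps
print(tokens_per_step, total_tokens)  # 524288 tokens/step, ~10.7B tokens total
```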

Evaluation of PPL and Common-sense Reasoning/QA

The evaluation setup follows flash-linear-attention. Please confirm that the requirements for lmms-eval-harness are satisfied.

Here is an example of PPL evaluation with a GLA model.

#!/usr/bin/bash
export HF_ENDPOINT=https://hf-mirror.com
MODEL='MODEL/PATH'

python -m harness --model hf \
    --model_args pretrained=$MODEL,dtype=bfloat16 \
    --tasks wikitext \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config
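The wikitext task above reports perplexity, which is the exponential of the mean negative log-likelihood per token. A minimal sketch of that relationship (perplexity here is an illustrative helper, not part of the harness):

```python
import math

def perplexity(token_logprobs):
    """Perplexity = exp of the mean negative log-likelihood per token.

    token_logprobs: natural-log probabilities the model assigned to each
    observed token. Illustrative sketch of the metric the harness reports.
    """
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)
```

For instance, a model that assigns probability 1/10 to every token has perplexity 10, i.e. it is as uncertain as a uniform choice over 10 tokens.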

Token Compression Baselines

  • ToMe [ICLR'23] Token Merging: Your ViT but Faster paper code
  • DiffRate [ICCV'23] Diffrate: Differentiable Compression Rate for Efficient Vision Transformers paper code
  • DCT [ACL'23] Fourier Transformer: Fast Long Range Modeling by Removing Sequence Redundancy with FFT Operator paper code
  • LTMP [TMLR'23] Learned Thresholds Token Merging and Pruning for Vision Transformers paper code
  • ToFu [WACV'24] Token Fusion: Bridging the Gap between Token Pruning and Token Merging paper
  • CrossGET [ICML'24] CrossGET: Cross-Guided Ensemble of Tokens for Accelerating Vision-Language Transformers paper code
  • MCTF [CVPR'24] Multi-criteria Token Fusion with One-step-ahead Attention for Efficient Vision Transformers paper code
  • ATC [ECCV'24] paper code
  • DTEM [NeurIPS'24] Learning to Merge Tokens via Decoupled Embedding for Efficient Vision Transformers paper code
  • PiToMe [NeurIPS'24] Accelerating Transformers with Spectrum-Preserving Token Merging paper code
  • FPET [CVPR'25] Faster Parameter-Efficient Tuning with Token Redundancy Reduction paper code

Support Tasks (TODO List)

  • Image Classification
  • LLM Inference
  • Long Sequence Training
  • Throughput
  • ToMe Visualization
  • Optimizers
  • AI for Science
  • Image Generation
  • MLLM Inference

Summary Functions

Method     Published    Matrix/Map Support
ToMe       ICLR'23      Matrix, Map
DiffRate   ICCV'23      Matrix
DCT        ACL'23       None
ToFu       WACV'24      Matrix, Map
MCTF       CVPR'24      Matrix, Map
ATC        ECCV'24      Map
DTEM       NeurIPS'24   Matrix, Map
PiToMe     NeurIPS'24   Matrix
FPET       CVPR'25      Matrix

Citation

If you find this repository useful, please consider giving a star ⭐ and citation:

@article{2025opentome,
  title = {OpenToMe},
  author = {Siyuan Li and Xin Jin and Kai Yu},
  year = {2025},
  url={https://github.com/Westlake-AI/OpenToMe},
  urldate = {2025-08-15},
}
