Official PyTorch implementation of "Towards Universal Gene Regulatory Network Inference: Unlocking Generalizable Regulatory Knowledge in Single-cell Foundation Models", accepted to ICML 2026.
Gene Regulatory Network (GRN) inference from single-cell RNA sequencing data remains a fundamental challenge in computational biology. Existing methods typically operate in a closed-world setting: a specialized model is optimized on a fixed gene set and struggles to generalize to unseen genes or heterogeneous datasets due to dimension mismatches.
UGRN introduces a universal, transfer-learning-based framework that leverages frozen single-cell foundation models (scFMs) for generalizable feature extraction. A lightweight downstream "translator"
Figure 1. (a) Traditional GRN inference operates in a closed-world setting, where optimized
UGRN supports three complementary feature extraction strategies:
| Mode | Description |
|---|---|
| Embedding | Concatenated gene embeddings from the frozen scFM |
| Perturbation | In-silico expression perturbation to probe model response |
| Gradient (Ours) | Integrated gradients for directional regulatory signals |
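To make the gradient mode more concrete, the following is a minimal, dependency-free sketch of integrated gradients on a toy function. It is not the repository's implementation: `toy_target_expression` is a hypothetical stand-in for a model output, the all-zero baseline is an assumption, and the gradient is approximated by finite differences rather than backpropagation through the scFM.

```python
from typing import Callable, List, Sequence


def numerical_gradient(
    f: Callable[[Sequence[float]], float], x: Sequence[float], eps: float = 1e-5
) -> List[float]:
    """Central-difference gradient of a scalar function f at point x."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def integrated_gradients(
    f: Callable[[Sequence[float]], float],
    x: Sequence[float],
    baseline: Sequence[float],
    steps: int = 50,
) -> List[float]:
    """Midpoint-rule approximation of integrated gradients from baseline to x."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        # Point on the straight line between the baseline and the input.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = numerical_gradient(f, point)
        total = [t + g for t, g in zip(total, grad)]
    # Scale the averaged gradient by the input difference.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


def toy_target_expression(expr: Sequence[float]) -> float:
    """Hypothetical response of one target gene to three regulator values."""
    return 2.0 * expr[0] - 1.0 * expr[1] + 0.5 * expr[0] * expr[2]


if __name__ == "__main__":
    attributions = integrated_gradients(
        toy_target_expression, x=[1.0, 2.0, 4.0], baseline=[0.0, 0.0, 0.0]
    )
    print(attributions)
```

In this toy setting the attributions sum to the difference between the function value at the input and at the baseline (the completeness property of integrated gradients), and the sign of each attribution separates inputs that raise the output from inputs that lower it. Reading that sign as activation versus repression is an interpretation for illustration only.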
The full UGRN method ensembles perturbation and gradient features via weighted late fusion, achieving state-of-the-art cross-dataset generalization on the GENELink benchmark.
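Weighted late fusion can be sketched as follows. This is an illustration under stated assumptions, not the code in `transfer_ours.py`: the function names, the min-max normalization step, and the example weight of 0.5 are all hypothetical.

```python
from typing import List, Sequence


def min_max_normalize(scores: Sequence[float]) -> List[float]:
    """Rescale scores to [0, 1]; constant inputs map to 0.0 (an assumption)."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]


def weighted_late_fusion(
    pert_scores: Sequence[float],
    grad_scores: Sequence[float],
    weight_pert: float = 0.5,
) -> List[float]:
    """Combine per-edge scores from two predictors with a convex weight."""
    if len(pert_scores) != len(grad_scores):
        raise ValueError("both score lists must cover the same candidate edges")
    if not 0.0 <= weight_pert <= 1.0:
        raise ValueError("weight_pert must lie in [0, 1]")
    pert = min_max_normalize(pert_scores)
    grad = min_max_normalize(grad_scores)
    return [weight_pert * p + (1.0 - weight_pert) * g for p, g in zip(pert, grad)]


if __name__ == "__main__":
    # Hypothetical scores for three candidate TF-target edges.
    print(weighted_late_fusion([0.2, 0.8, 0.5], [1.0, 3.0, 5.0], weight_pert=0.5))
```

Fusing at the score level keeps the two predictors independent, so each can be trained and evaluated on its own before the fusion weight is chosen.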
git clone https://github.com/simpleshinobu/UGRN.git
cd UGRN
pip install -r requirements.txt

- Python >= 3.9
- PyTorch >= 2.0.0
- See requirements.txt for full dependencies
Download the pre-trained model checkpoint and place it under ./checkpoints/:
mkdir -p checkpoints
# Download model.pt and args.json

Download: Google Drive
Expected structure:
checkpoints/
├── model.pt
└── args.json
vocab.json is included in this repository and will be loaded automatically.
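Before launching an experiment, it can help to confirm that the checkpoint directory matches the expected structure. The helper below is a hypothetical convenience script, not part of the repository; it assumes only that `args.json` is a JSON object, since its keys are not documented here.

```python
import json
from pathlib import Path
from typing import Any, Dict, List

# File names taken from the expected structure described above.
EXPECTED_FILES = ("model.pt", "args.json")


def missing_checkpoint_files(ckpt_dir: str) -> List[str]:
    """Return the expected checkpoint files that are absent from ckpt_dir."""
    root = Path(ckpt_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


def load_args(ckpt_dir: str) -> Dict[str, Any]:
    """Read args.json, assuming it holds a single JSON object."""
    with open(Path(ckpt_dir) / "args.json", encoding="utf-8") as fh:
        return json.load(fh)


if __name__ == "__main__":
    missing = missing_checkpoint_files("checkpoints")
    if missing:
        print("Missing checkpoint files:", ", ".join(missing))
    else:
        print("Checkpoint argument keys:", sorted(load_args("checkpoints")))
```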
Download the benchmark dataset and extract it to ./data/:
mkdir -p data
# Download Benchmark Dataset.zip and extract to data/

Download: Google Drive
Expected structure:
data/
└── GENELink/
    └── Dataset/
        └── Benchmark Dataset/
            ├── STRING Dataset/
            │   └── hESC/TFs+500/
            │       ├── Label.csv
            │       ├── TF.csv
            │       ├── Target.csv
            │       └── BL--ExpressionData.csv
            ├── Non-Specific Dataset/
            ├── Lofgof Dataset/
            └── Specific Dataset/
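The layout above suggests how a `Network:Dataset` training source such as `STRING:hESC` could map onto a directory. The sketch below is an assumption about that mapping, not the logic in `ugrn/dataset.py` or `ugrn/config.py`: only `STRING:hESC` with `TFs+500` is confirmed by this README, and the other network keys and the `gene_set` parameter are hypothetical.

```python
from pathlib import Path
from typing import List

# Hypothetical mapping from network keys to the folder names in the tree above.
NETWORK_DIRS = {
    "STRING": "STRING Dataset",
    "Non-Specific": "Non-Specific Dataset",
    "Lofgof": "Lofgof Dataset",
    "Specific": "Specific Dataset",
}

# Files listed under hESC/TFs+500/ in the expected structure.
REQUIRED_FILES = ("Label.csv", "TF.csv", "Target.csv", "BL--ExpressionData.csv")


def resolve_dataset_dir(
    genelink_root: str, train_source: str, gene_set: str = "TFs+500"
) -> Path:
    """Turn a 'Network:Dataset' string into a dataset directory path."""
    network, sep, dataset = train_source.partition(":")
    if not sep or not network or not dataset:
        raise ValueError(f"expected 'Network:Dataset', got {train_source!r}")
    if network not in NETWORK_DIRS:
        raise ValueError(f"unknown network {network!r}")
    return Path(genelink_root) / NETWORK_DIRS[network] / dataset / gene_set


def missing_dataset_files(dataset_dir: Path) -> List[str]:
    """Return the required CSV files that are absent from dataset_dir."""
    return [name for name in REQUIRED_FILES if not (dataset_dir / name).is_file()]


if __name__ == "__main__":
    target = resolve_dataset_dir("data/GENELink/Dataset/Benchmark Dataset", "STRING:hESC")
    print(target)
    print("Missing files:", missing_dataset_files(target))
```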
Run with perturbation features (default):
python transfer_base.py \
--model_ckpt checkpoints/model.pt \
--genelink_root "data/GENELink/Dataset/Benchmark Dataset" \
--train_source STRING:hESC \
--save_root results/transfer_base

Run with embedding features:
python transfer_base.py \
--model_ckpt checkpoints/model.pt \
--genelink_root "data/GENELink/Dataset/Benchmark Dataset" \
--feature_type embedding \
--train_source STRING:hESC \
--save_root results/transfer_base_emb

Run the complete UGRN ensemble (perturbation + gradient):
python transfer_ours.py \
--model_ckpt checkpoints/model.pt \
--genelink_root "data/GENELink/Dataset/Benchmark Dataset" \
--train_source STRING:hESC \
--save_root results/transfer_ours

Run individual modes:
# Perturbation only
python transfer_ours.py \
--model_ckpt checkpoints/model.pt \
--genelink_root "data/GENELink/Dataset/Benchmark Dataset" \
--mode perturbation \
--train_source STRING:hESC \
--save_root results/transfer_ours_pert
# Gradient only
python transfer_ours.py \
--model_ckpt checkpoints/model.pt \
--genelink_root "data/GENELink/Dataset/Benchmark Dataset" \
--mode grad \
--train_source STRING:hESC \
--save_root results/transfer_ours_grad

| Argument | Description | Default |
|---|---|---|
| --model_ckpt | Path to pre-trained model checkpoint | ./checkpoints/model.pt |
| --genelink_root | Path to GENELink dataset root | ./data/GENELink/Dataset/Benchmark Dataset |
| --train_source | Training source as Network:Dataset | STRING:hESC |
| --feature_type | Feature type (embedding / perturbation / combined) | Script-dependent |
| --mode | Modes for transfer_ours (perturbation / grad) | perturbation grad |
| --runs | Number of random runs | 3 |
| --seed | Random seed | 42 |
| --save_root | Directory to save results | results/... |
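As a reading aid for the arguments above, here is a hypothetical `argparse` mirror of the documented options. It is a sketch, not the parser used by `transfer_base.py` or `transfer_ours.py`: treating `--mode` as a multi-value option is inferred from its two-word default, and `--feature_type` and `--save_root` are left without defaults because the README gives them as script-dependent or elided.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser that mirrors the documented options (illustrative only)."""
    parser = argparse.ArgumentParser(
        description="Hypothetical mirror of the documented UGRN options"
    )
    parser.add_argument("--model_ckpt", default="./checkpoints/model.pt")
    parser.add_argument(
        "--genelink_root", default="./data/GENELink/Dataset/Benchmark Dataset"
    )
    parser.add_argument("--train_source", default="STRING:hESC")
    # Default is script-dependent according to the README, so none is set here.
    parser.add_argument(
        "--feature_type",
        choices=["embedding", "perturbation", "combined"],
        default=None,
    )
    # Assumption: several modes can be passed at once, e.g. --mode perturbation grad.
    parser.add_argument(
        "--mode",
        nargs="+",
        choices=["perturbation", "grad"],
        default=["perturbation", "grad"],
    )
    parser.add_argument("--runs", type=int, default=3)
    parser.add_argument("--seed", type=int, default=42)
    # The documented default is elided ("results/..."), so it is not filled in.
    parser.add_argument("--save_root", default=None)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--mode", "grad", "--runs", "5"])
    print(args)
```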
UGRN/
├── ugrn/                 # Core package
│   ├── __init__.py
│   ├── model.py          # scModel: single-cell foundation model wrapper
│   ├── vocab.py          # GeneVocab: gene token vocabulary
│   ├── dataset.py        # GENELink data loading utilities
│   └── config.py         # Dataset and experiment configurations
├── transfer_base.py      # Baseline: embedding / perturbation features
├── transfer_ours.py      # UGRN: perturbation + gradient ensemble
├── vocab.json            # Gene vocabulary
├── requirements.txt
└── README.md
If you find this work useful, please consider citing:
@inproceedings{qi2026ugrn,
title={Towards Universal Gene Regulatory Network Inference: Unlocking Generalizable Regulatory Knowledge in Single-cell Foundation Models},
author={Qi, Jiaxin and Li, Hang and Cui, Yan and Zheng, Yuhua and Huang, Jianqiang},
booktitle={International Conference on Machine Learning (ICML)},
year={2026}
}

This project is released under the MIT License.
