Aligning crystal structures, XRD patterns, and text descriptions in a shared embedding space
Overview | Results | Architecture | Quick Start | Training | Citation
## Overview

MatCLIP bridges three fundamental representations of materials into a unified 256-dimensional embedding space using contrastive learning:
- **Crystal structure**: the 3D atomic arrangement (e.g., the Ba-Ti-O lattice of BaTiO3), encoded as a graph by CGCNN
- **XRD pattern**: the X-ray diffraction spectrum (intensity vs. 2θ over 5°–90°), encoded by a 1D ResNet
- **Text description**: a natural-language summary ("BaTiO3 is a tetragonal material with space group P4mm. It has a band gap of 3.38 eV..."), encoded by MatSciBERT
Once aligned, the shared space enables:
- Cross-modal retrieval: Given a crystal structure, find its text description (or vice versa)
- Zero-shot classification: Classify materials using text prompts without labeled training data
- Transfer learning: Use pre-trained embeddings for downstream property prediction
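The retrieval mechanics can be sketched with toy vectors (pure Python, 3-dim embeddings for illustration; the real system compares 256-dim L2-normalized encoder outputs, and the function names here are hypothetical):

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length so that dot product = cosine similarity."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def retrieve(query, candidates, top_k=2):
    """Rank candidate embeddings by cosine similarity to the query."""
    q = l2_normalize(query)
    scored = []
    for idx, c in enumerate(candidates):
        c = l2_normalize(c)
        scored.append((sum(a * b for a, b in zip(q, c)), idx))
    scored.sort(reverse=True)  # highest similarity first
    return [idx for _, idx in scored[:top_k]]

# Toy example: a crystal embedding queries three text embeddings;
# candidate 1 points in nearly the same direction, so it ranks first.
crystal = [0.9, 0.1, 0.0]
texts = [[0.0, 1.0, 0.0], [1.0, 0.2, 0.0], [0.0, 0.0, 1.0]]
print(retrieve(crystal, texts))  # [1, 0]
```

Zero-shot classification works the same way, with one candidate embedding per class prompt instead of per material.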
Inspired by CLIP (Radford et al., 2021) and ImageBind (Girdhar et al., 2023), adapted for the materials science domain.
## Results

Trained on 9,751 materials from the Materials Project. Evaluated on 976 held-out test materials.
| Task | R@1 | R@5 | R@10 | MRR |
|---|---|---|---|---|
| Text to Crystal | 94.0% | 99.5% | 99.8% | 0.966 |
| Crystal to Text | 90.5% | 99.4% | 99.9% | 0.948 |
| XRD to Text | 23.8% | 51.7% | 61.6% | 0.371 |
| XRD to Crystal | 16.8% | 43.4% | 55.7% | 0.298 |
| Crystal to XRD | 16.4% | 44.9% | 58.2% | 0.297 |
| Random baseline | ~0.1% | ~0.5% | ~1.0% | ~0.005 |
- **Dataset:** 9,751 materials (Materials Project, stable only)
- **Parameters:** 113M total, 17.9M trainable (84% frozen)
- **Training:** 200 epochs, 3.3 hours on an NVIDIA TITAN RTX (24 GB)
- **Best val loss:** 2.075 (75% reduction from start)
- **Overfitting:** none (train-val gap stable at ~1.5 across 200 epochs)
| | v1 (overfit) | v2 (regularized) |
|---|---|---|
| Dataset | 2,480 materials | 9,751 materials |
| Val loss trajectory | 4.87 then rose to 6.48 | Dropped to 2.075 and stayed |
| Train-val gap | 6.44 (diverged) | 1.50 (stable) |
| Best R@1 (Crystal to Text) | 80.2% | 90.5% |
| Best R@1 (Text to Crystal) | N/A | 94.0% |
Anti-overfitting techniques applied: label smoothing (0.1), dropout (0.3), weight decay (5e-4), XRD noise augmentation, text word dropout (15%), 10/12 BERT layers frozen, early stopping (patience=25).
## Architecture

```
ENCODERS                                  PROJECTION       SHARED SPACE
--------                                  ----------       ------------
Crystal Structure --> [ CGCNN Encoder       ] --> [ MLP + L2 ] -->
(graph: atoms+bonds)  [ 3 conv layers, 256d ]     [ Norm     ]    \
                                                                   \
XRD Pattern --------> [ 1D ResNet Encoder   ] --> [ MLP + L2 ] -----> [ 256-dim   ]
(512-point signal)    [ 3 ResBlocks, 256d   ]     [ Norm     ]        [ Shared    ]
                                                                      [ Embedding ]
Text Description ---> [ MatSciBERT          ] --> [ MLP + L2 ] -----> [ Space     ]
(natural language)    [ 10/12 layers frozen ]     [ Norm     ]    /
                      [ 768d output         ]                    /
```

Loss: InfoNCE + label smoothing

```
L = L(crystal,text) + L(crystal,xrd) + 0.5*L(xrd,text)
```
| Component | Architecture | Input | Output | Trainable Params |
|---|---|---|---|---|
| Crystal Encoder | CGCNN (3 conv layers + global mean pool) | PyG graph (atoms, bonds) | 256-dim | ~1.2M |
| XRD Encoder | 1D ResNet (3 stages of ResBlocks) | 512-point intensity signal | 256-dim | ~0.8M |
| Text Encoder | MatSciBERT (12-layer Transformer) | Tokenized text (max 256) | 768-dim | ~1.5M (2 layers unfrozen) |
| Projection Heads | MLP (Linear-BN-ReLU-Dropout-Linear) | Encoder output | 256-dim L2-normalized | ~0.9M each |
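The projection-head column can be sketched as an inference-mode forward pass (NumPy, with random stand-in weights; the 512 hidden width is an assumption, batch norm is folded into an affine scale/shift, and dropout is omitted because it is identity at inference):

```python
import numpy as np

def projection_head(x, w1, b1, gamma, beta, w2, b2):
    """Linear -> BatchNorm (folded affine) -> ReLU -> Linear -> L2 normalize.

    Dropout from the training-time head is omitted: it is inactive at inference.
    """
    h = x @ w1 + b1            # Linear
    h = gamma * h + beta       # BatchNorm folded into scale/shift
    h = np.maximum(h, 0.0)     # ReLU
    z = h @ w2 + b2            # Linear into the shared 256-dim space
    return z / np.linalg.norm(z, axis=-1, keepdims=True)  # unit length

# Hypothetical shapes: 768-dim text-encoder output -> 512 hidden -> 256 shared
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 768))
params = dict(
    w1=rng.normal(scale=0.02, size=(768, 512)), b1=np.zeros(512),
    gamma=np.ones(512), beta=np.zeros(512),
    w2=rng.normal(scale=0.02, size=(512, 256)), b2=np.zeros(256),
)
z = projection_head(x, **params)
print(z.shape)  # (4, 256), every row unit-norm
```

The final L2 normalization is what makes dot products in the shared space equal to cosine similarities.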
## Quick Start

```bash
# Clone
git clone https://github.com/YOUR_USERNAME/MatCLIP.git
cd MatCLIP

# Create environment
python3 -m venv venv && source venv/bin/activate

# Install PyTorch (adjust CUDA version as needed)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128

# Install dependencies
pip install torch-geometric pymatgen mp-api transformers sentencepiece \
    scikit-learn scipy matplotlib seaborn umap-learn tqdm pyyaml

# Install MatCLIP
pip install -e .
```

Get a Materials Project API key:

- Register at materialsproject.org
- Go to Dashboard > API > Generate API Key
- Set the environment variable:

```bash
export MP_API_KEY="your_api_key_here"
```

Prepare the dataset:

```bash
# Download materials + generate XRD + build text descriptions
python3 scripts/prepare_large.py
```

This will:
- Download 10,000 stable materials from Materials Project
- Generate simulated XRD patterns via pymatgen
- Build natural language descriptions from computed properties
- Create train/val/test splits (7,800 / 975 / 976)
```bash
# v2 training with anti-overfitting (recommended)
python3 scripts/train_v2.py

# Original training (for comparison)
python3 -m scripts.train --config configs/default.yaml
```

Evaluate a trained checkpoint:

```bash
python3 -m scripts.evaluate --checkpoint checkpoints_v2/best_model.pt
```

Repository layout:

```
MatCLIP/
|
|-- matclip/                     # Core library
|   |-- data/
|   |   |-- mp_downloader.py     # Materials Project API client
|   |   |-- xrd_generator.py     # Simulated XRD via pymatgen
|   |   |-- text_builder.py      # Natural language descriptions
|   |   |-- dataset.py           # PyTorch dataset + graph builder
|   |
|   |-- encoders/
|   |   |-- crystal_encoder.py   # CGCNN graph neural network
|   |   |-- xrd_encoder.py       # 1D ResNet for spectra
|   |   |-- text_encoder.py      # MatSciBERT wrapper
|   |   |-- matclip_model.py     # Unified model + projection heads
|   |
|   |-- losses/
|   |   |-- contrastive.py       # InfoNCE + multi-modal loss
|   |
|   |-- utils/
|       |-- metrics.py           # Recall@K, MRR, F1, alignment
|       |-- visualization.py     # t-SNE, UMAP, retrieval plots
|
|-- scripts/
|   |-- prepare_data.py          # Standard data prep
|   |-- prepare_large.py         # Fast large-scale data prep
|   |-- train.py                 # Training v1
|   |-- train_v2.py              # Training v2 (anti-overfitting)
|   |-- evaluate.py              # Full evaluation pipeline
|   |-- baselines.py             # Baseline comparisons
|
|-- configs/
|   |-- default.yaml             # Full hyperparameter config
|   |-- quick_train.yaml         # Fast iteration config
|
|-- notebooks/
|   |-- demo.ipynb               # Interactive demo notebook
|   |-- kaggle_train.ipynb       # Self-contained Kaggle GPU notebook
|
|-- tests/
|   |-- test_encoders.py         # Unit tests (9/9 passing)
|
|-- results/                     # Training logs and figures
|-- REPORT.md                    # Full project report
|-- requirements.txt
|-- setup.py
```
## Training

| Parameter | Value | Purpose |
|---|---|---|
| Batch size | 64 | Larger batches = more negatives per sample |
| Learning rate | 1e-4 | Lower LR for stable convergence |
| Weight decay | 5e-4 | L2 regularization |
| Dropout | 0.3 | Prevent co-adaptation |
| Label smoothing | 0.1 | Prevent overconfident predictions |
| XRD noise (std) | 0.05 | Simulate measurement noise |
| XRD shift | +/- 3 pts | Simulate calibration error |
| Text word dropout | 15% | Force robust text understanding |
| BERT frozen layers | 10/12 | Reduce trainable capacity |
| Early stopping | patience 25 | Stop before overfitting |
| Scheduler | Cosine + 10-epoch warmup | Smooth LR decay |
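The scheduler row can be sketched as a function of the epoch index (an illustrative assumption: linear warmup to the base LR over the first 10 epochs, then cosine decay to zero by epoch 200; the repository's exact implementation may differ):

```python
import math

def lr_at(epoch, base_lr=1e-4, warmup=10, total=200):
    """Linear warmup for `warmup` epochs, then cosine decay to zero."""
    if epoch < warmup:
        return base_lr * (epoch + 1) / warmup      # ramp: 1/10, 2/10, ... of base
    progress = (epoch - warmup) / (total - warmup)  # 0 -> 1 over the decay phase
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at(0))    # start of warmup: base_lr / 10
print(lr_at(9))    # warmup complete: full base_lr
print(lr_at(199))  # end of cosine decay: ~0
```

Warmup avoids large early updates while the projection heads are still random; the cosine tail gives the small late-stage steps that stabilize contrastive training.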
A fully self-contained notebook is provided at `notebooks/kaggle_train.ipynb`. Upload it to Kaggle, enable a T4 GPU, and run all cells. No external files are needed.
```
Materials Project API
        |
        v
10,000 stable crystal structures
        |
        +---> pymatgen XRD Calculator ---> 512-point XRD patterns
        |
        +---> Property-based templates --> Text descriptions
        |
        +---> PyG graph builder ---------> Crystal graphs (nodes=atoms, edges=bonds)
        |
        v
Paired dataset: (crystal_graph, xrd_pattern, text_description)
        |
        v
Train: 7,800 / Val: 975 / Test: 976
```
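The "512-point XRD pattern" step can be sketched as follows (an illustrative NumPy version, not the repository's exact pymatgen pipeline: discrete peaks are Gaussian-broadened onto a fixed 5–90° 2θ grid and max-normalized; the peak positions below are made up):

```python
import numpy as np

def render_xrd(two_theta_peaks, intensities, n_points=512, lo=5.0, hi=90.0, sigma=0.3):
    """Broaden discrete (2-theta, intensity) peaks into a fixed-length signal."""
    grid = np.linspace(lo, hi, n_points)
    signal = np.zeros(n_points)
    for tt, inten in zip(two_theta_peaks, intensities):
        # Gaussian peak centered at tt with width sigma (degrees)
        signal += inten * np.exp(-0.5 * ((grid - tt) / sigma) ** 2)
    return signal / signal.max()  # normalize intensities to [0, 1]

# Hypothetical peak list standing in for a pymatgen-simulated pattern
pattern = render_xrd([22.0, 31.5, 38.9, 45.2], [100.0, 60.0, 30.0, 45.0])
print(pattern.shape)  # (512,)
```

A fixed grid is what lets every material, regardless of how many reflections it has, feed the same 1D ResNet input layer.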
For each batch of 64 materials:
- Encode all three modalities into 256-dim vectors
- Compute cosine similarity matrices between each pair
- Apply InfoNCE loss: push matching pairs together, non-matching pairs apart
- Crystal is the "anchor" modality (following ImageBind)
```
Crystal <----(weight 1.0)----> Text
Crystal <----(weight 1.0)----> XRD
XRD     <----(weight 0.5)----> Text   (emerges through shared space)
```
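The weighted objective can be sketched in NumPy (illustrative only; the repository's `contrastive.py` may differ in details such as a learnable temperature):

```python
import numpy as np

def info_nce(za, zb, temperature=0.07, smoothing=0.1):
    """Symmetric InfoNCE over L2-normalized embeddings, with label smoothing.

    Row i of the similarity matrix should peak at column i (the matched pair);
    smoothing spreads a little target mass onto the negatives.
    """
    logits = za @ zb.T / temperature
    n = logits.shape[0]
    targets = np.full((n, n), smoothing / (n - 1))
    np.fill_diagonal(targets, 1.0 - smoothing)
    log_p_ab = logits - np.log(np.exp(logits).sum(1, keepdims=True))
    log_p_ba = logits.T - np.log(np.exp(logits.T).sum(1, keepdims=True))
    # Average both retrieval directions (a->b and b->a)
    return -0.5 * ((targets * log_p_ab).sum(1).mean()
                   + (targets * log_p_ba).sum(1).mean())

def matclip_loss(z_crystal, z_text, z_xrd):
    """Crystal-anchored multi-modal loss with the weights from the diagram."""
    return (info_nce(z_crystal, z_text)
            + info_nce(z_crystal, z_xrd)
            + 0.5 * info_nce(z_xrd, z_text))

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 256))
z /= np.linalg.norm(z, axis=1, keepdims=True)
# Perfectly matched triplets score lower than misaligned ones
print(matclip_loss(z, z, z) < matclip_loss(z, np.roll(z, 1, axis=0), z))  # True
```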
- Recall@K: Is the correct match in the top-K retrieved results?
- MRR: Mean Reciprocal Rank (average of 1/rank of correct match)
- Alignment: Mean cosine similarity of positive pairs (higher = better)
- Uniformity: How spread out embeddings are on the hypersphere (lower = better)
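Recall@K and MRR can be computed directly from a query-by-candidate similarity matrix (a small sketch; row i's correct match is assumed to be candidate i, as in the paired test set):

```python
import numpy as np

def recall_at_k(sim, k):
    """Fraction of queries whose correct match (index i) appears in the top-k."""
    topk = np.argsort(-sim, axis=1)[:, :k]
    return float(np.mean([i in topk[i] for i in range(sim.shape[0])]))

def mean_reciprocal_rank(sim):
    """Average of 1/rank of the correct match across all queries."""
    order = np.argsort(-sim, axis=1)
    ranks = [int(np.where(order[i] == i)[0][0]) + 1 for i in range(sim.shape[0])]
    return float(np.mean([1.0 / r for r in ranks]))

# Toy 3x3 similarity matrix: queries 0 and 2 rank their match first,
# query 1 ranks its match second.
sim = np.array([[0.9, 0.2, 0.1],
                [0.8, 0.6, 0.3],
                [0.1, 0.2, 0.7]])
print(recall_at_k(sim, 1))        # 2/3
print(mean_reciprocal_rank(sim))  # (1 + 1/2 + 1) / 3
```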
| Component | Technology |
|---|---|
| Deep Learning | PyTorch 2.11, CUDA 12.8 |
| Graph Neural Networks | PyTorch Geometric 2.7 |
| Crystal Structures | pymatgen 2026.3 |
| XRD Simulation | pymatgen.analysis.diffraction.xrd |
| Text Encoder | MatSciBERT (HuggingFace Transformers) |
| Materials Data | Materials Project API (mp-api) |
| Visualization | matplotlib, seaborn, t-SNE, UMAP |
| Hardware | NVIDIA TITAN RTX 24GB |
| Paper | Venue | Relevance |
|---|---|---|
| CLIP (Radford et al.) | ICML 2021 | Core architecture inspiration |
| ImageBind (Girdhar et al.) | CVPR 2023 | Anchor modality alignment strategy |
| CGCNN (Xie & Grossman) | PRL 2018 | Crystal graph encoder |
| MatSciBERT (Gupta et al.) | ACL 2022 | Domain-specific text encoder |
| GNoME (Merchant et al.) | Nature 2023 | Scaling crystal discovery with GNNs |
| MultiMat (Moro et al.) | Newton 2025 | Multi-modal foundation model for materials |
| SimCLR (Chen et al.) | ICML 2020 | Projection head design + contrastive framework |
| Wang & Isola | ICML 2020 | Alignment and uniformity metrics |
```bash
# Run all tests (9/9 passing)
python3 -m pytest tests/ -v
```

Tests cover: XRD encoder shapes, crystal encoder shapes, Gaussian expansion, InfoNCE loss, multi-modal loss, retrieval metrics, projection head normalization, and text builder output.
License: MIT

Built with PyTorch, pymatgen, and MatSciBERT.
MatCLIP is an academic research project exploring multi-modal contrastive learning for materials science.