
# MatCLIP

**Multi-Modal Contrastive Learning for Materials Science**

Python 3.10+ · PyTorch 2.x · License: MIT · Materials Project

*Aligning crystal structures, XRD patterns, and text descriptions in a shared embedding space*

Overview | Results | Architecture | Quick Start | Training | Citation


## Overview

MatCLIP bridges three fundamental representations of materials into a unified 256-dimensional embedding space using contrastive learning:

**Crystal Structure**

```
    O --- Ti --- O
    |           |
    Ba          Ba
    |           |
    O --- Ti --- O
```

3D atomic arrangement, encoded as a graph (CGCNN)

**XRD Pattern**

```
  |
  |  ||
  | ||| |
  |||||||| |  |
  ||||||||||||||||
  ________________
  5    2theta   90
```

X-ray diffraction spectrum, encoded by a 1D ResNet

**Text Description**

```
"BaTiO3 is a tetragonal
material with space group
P4mm. It has a band gap
of 3.38 eV..."
```

Natural language summary, encoded by MatSciBERT

Once aligned, the shared space enables:

- **Cross-modal retrieval:** given a crystal structure, find its text description (or vice versa)
- **Zero-shot classification:** classify materials using text prompts, without labeled training data
- **Transfer learning:** use the pre-trained embeddings for downstream property prediction

Inspired by CLIP (Radford et al., 2021) and ImageBind (Girdhar et al., 2023), and adapted to the materials science domain.
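Zero-shot classification, for instance, reduces to nearest-neighbor search against embedded text prompts. A minimal sketch, assuming the embeddings are already L2-normalized; the function name and the toy vectors are illustrative, not part of the MatCLIP API:

```python
import numpy as np

def zero_shot_classify(query_emb, prompt_embs):
    """Return the index of the text prompt most similar to the query.

    query_emb:   (d,) L2-normalized embedding of a crystal or XRD pattern.
    prompt_embs: (n_classes, d) L2-normalized embeddings of prompts such as
                 "a cubic material" / "a tetragonal material" / ...
    """
    sims = prompt_embs @ query_emb        # cosine similarity (unit vectors)
    return int(np.argmax(sims))

# Toy 4-dim example with three orthonormal "prompt" embeddings.
prompts = np.eye(3, 4)
query = np.array([0.1, 0.9, 0.1, 0.0])
query /= np.linalg.norm(query)
label = zero_shot_classify(query, prompts)   # closest to prompt 1
```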


## Results

Trained on 9,751 materials from the Materials Project. Evaluated on 976 held-out test materials.

### Retrieval Performance

| Task | R@1 | R@5 | R@10 | MRR |
|---|---|---|---|---|
| Text to Crystal | 94.0% | 99.5% | 99.8% | 0.966 |
| Crystal to Text | 90.5% | 99.4% | 99.9% | 0.948 |
| XRD to Text | 23.8% | 51.7% | 61.6% | 0.371 |
| XRD to Crystal | 16.8% | 43.4% | 55.7% | 0.298 |
| Crystal to XRD | 16.4% | 44.9% | 58.2% | 0.297 |
| Random baseline | ~0.1% | ~0.5% | ~1.0% | ~0.005 |

### Training Highlights

```
Dataset:         9,751 materials (Materials Project, stable only)
Parameters:      113M total, 17.9M trainable (84% frozen)
Training:        200 epochs, 3.3 hours on NVIDIA TITAN RTX 24GB
Best Val Loss:   2.075 (75% reduction from start)
Overfitting:     None (train-val gap stable at ~1.5 for 200 epochs)
```

### No Overfitting (v1 vs v2)

| | v1 (overfit) | v2 (regularized) |
|---|---|---|
| Dataset | 2,480 materials | 9,751 materials |
| Val loss trajectory | 4.87, then rose to 6.48 | Dropped to 2.075 and stayed |
| Train-val gap | 6.44 (diverged) | 1.50 (stable) |
| Best R@1 (Crystal to Text) | 80.2% | 90.5% |
| Best R@1 (Text to Crystal) | N/A | 94.0% |

Anti-overfitting techniques applied: label smoothing (0.1), dropout (0.3), weight decay (5e-4), XRD noise augmentation, text word dropout (15%), 10/12 BERT layers frozen, early stopping (patience=25).
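The data-side augmentations in this list are simple to reproduce. A minimal sketch, assuming each XRD pattern is a 512-point numpy array; `augment_xrd` and `word_dropout` are illustrative names, not the repo's actual functions:

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_xrd(pattern, noise_std=0.05, max_shift=3):
    """Add Gaussian intensity noise and a small random shift, mimicking
    measurement noise and 2-theta calibration error."""
    noisy = pattern + rng.normal(0.0, noise_std, size=pattern.shape)
    shift = int(rng.integers(-max_shift, max_shift + 1))
    return np.roll(noisy, shift)

def word_dropout(text, p=0.15):
    """Randomly drop ~15% of words so the text encoder cannot rely on
    any single token."""
    kept = [w for w in text.split() if rng.random() > p]
    return " ".join(kept) if kept else text

aug = augment_xrd(np.zeros(512))
```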


## Architecture

```
                        ENCODERS                     PROJECTION        SHARED SPACE
                        --------                     ----------        ------------

Crystal Structure  -->  [ CGCNN Encoder       ] --> [ MLP + L2 ] --\
(graph: atoms+bonds)    [ 3 conv layers, 256d ]     [ Norm     ]    \
                                                                     \
XRD Pattern  -------->  [ 1D ResNet Encoder   ] --> [ MLP + L2 ] ----> [ 256-dim   ]
(512-point signal)      [ 3 ResBlocks, 256d   ]     [ Norm     ]       [ Shared    ]
                                                                       [ Embedding ]
Text Description  --->  [ MatSciBERT          ] --> [ MLP + L2 ] ----> [ Space     ]
(natural language)      [ 10/12 layers frozen ]     [ Norm     ]    /
                        [ 768d output         ]                    /
                                                                  /
                        Loss: InfoNCE + Label Smoothing
                        L = L(crystal,text) + L(crystal,xrd) + 0.5*L(xrd,text)
```
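The loss at the bottom of the diagram can be sketched in a few lines of PyTorch. The temperature value and the symmetric two-direction reduction are assumptions; only the pairing weights and the label smoothing come from the description above:

```python
import torch
import torch.nn.functional as F

def info_nce(za, zb, temperature=0.07, smoothing=0.1):
    """Symmetric InfoNCE between two batches of L2-normalized embeddings.
    Row i of `za` and `zb` is the positive pair; other rows are negatives."""
    logits = za @ zb.t() / temperature        # (B, B) scaled cosine similarities
    targets = torch.arange(za.size(0))        # positives sit on the diagonal
    return 0.5 * (
        F.cross_entropy(logits, targets, label_smoothing=smoothing)
        + F.cross_entropy(logits.t(), targets, label_smoothing=smoothing)
    )

def matclip_loss(z_crystal, z_xrd, z_text):
    """Crystal is the anchor modality; the XRD-text term gets half weight."""
    return (info_nce(z_crystal, z_text)
            + info_nce(z_crystal, z_xrd)
            + 0.5 * info_nce(z_xrd, z_text))

torch.manual_seed(0)
B, D = 8, 256
z = [F.normalize(torch.randn(B, D), dim=-1) for _ in range(3)]
loss = matclip_loss(*z)
```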

### Encoder Details

| Component | Architecture | Input | Output | Trainable Params |
|---|---|---|---|---|
| Crystal Encoder | CGCNN (3 conv layers + global mean pool) | PyG graph (atoms, bonds) | 256-dim | ~1.2M |
| XRD Encoder | 1D ResNet (3 stages of ResBlocks) | 512-point intensity signal | 256-dim | ~0.8M |
| Text Encoder | MatSciBERT (12-layer Transformer) | Tokenized text (max 256) | 768-dim | ~1.5M (2 layers unfrozen) |
| Projection Heads | MLP (Linear-BN-ReLU-Dropout-Linear) | Encoder output | 256-dim, L2-normalized | ~0.9M each |
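The projection-head row translates almost directly into a module. A sketch, assuming a 512-unit hidden layer (the table does not give the hidden width):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Linear-BN-ReLU-Dropout-Linear, followed by L2 normalization,
    mapping an encoder's output into the shared 256-dim space."""

    def __init__(self, in_dim, out_dim=256, hidden=512, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return F.normalize(self.net(x), dim=-1)   # unit-length embeddings

head = ProjectionHead(in_dim=768)                 # e.g. for the MatSciBERT branch
out = head(torch.randn(4, 768))
```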

## Quick Start

### 1. Installation

```bash
# Clone
git clone https://github.com/YOUR_USERNAME/MatCLIP.git
cd MatCLIP

# Create environment
python3 -m venv venv && source venv/bin/activate

# Install PyTorch (adjust CUDA version as needed)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128

# Install dependencies
pip install torch-geometric pymatgen mp-api transformers sentencepiece \
    scikit-learn scipy matplotlib seaborn umap-learn tqdm pyyaml

# Install MatCLIP
pip install -e .
```

### 2. Get a Materials Project API Key

1. Register at materialsproject.org
2. Go to Dashboard > API > Generate API Key
3. Set the environment variable:

```bash
export MP_API_KEY="your_api_key_here"
```

### 3. Prepare Data

```bash
# Download materials + generate XRD + build text descriptions
python3 scripts/prepare_large.py
```

This will:

- Download 10,000 stable materials from the Materials Project
- Generate simulated XRD patterns via pymatgen
- Build natural language descriptions from computed properties
- Create train/val/test splits (7,800 / 975 / 976)

### 4. Train

```bash
# v2 training with anti-overfitting (recommended)
python3 scripts/train_v2.py

# Original training (for comparison)
python3 -m scripts.train --config configs/default.yaml
```

### 5. Evaluate

```bash
python3 -m scripts.evaluate --checkpoint checkpoints_v2/best_model.pt
```

## Project Structure

```
MatCLIP/
|
|-- matclip/                        # Core library
|   |-- data/
|   |   |-- mp_downloader.py        # Materials Project API client
|   |   |-- xrd_generator.py        # Simulated XRD via pymatgen
|   |   |-- text_builder.py         # Natural language descriptions
|   |   |-- dataset.py              # PyTorch dataset + graph builder
|   |
|   |-- encoders/
|   |   |-- crystal_encoder.py      # CGCNN graph neural network
|   |   |-- xrd_encoder.py          # 1D ResNet for spectra
|   |   |-- text_encoder.py         # MatSciBERT wrapper
|   |   |-- matclip_model.py        # Unified model + projection heads
|   |
|   |-- losses/
|   |   |-- contrastive.py          # InfoNCE + multi-modal loss
|   |
|   |-- utils/
|       |-- metrics.py              # Recall@K, MRR, F1, alignment
|       |-- visualization.py        # t-SNE, UMAP, retrieval plots
|
|-- scripts/
|   |-- prepare_data.py             # Standard data prep
|   |-- prepare_large.py            # Fast large-scale data prep
|   |-- train.py                    # Training v1
|   |-- train_v2.py                 # Training v2 (anti-overfitting)
|   |-- evaluate.py                 # Full evaluation pipeline
|   |-- baselines.py                # Baseline comparisons
|
|-- configs/
|   |-- default.yaml                # Full hyperparameter config
|   |-- quick_train.yaml            # Fast iteration config
|
|-- notebooks/
|   |-- demo.ipynb                  # Interactive demo notebook
|   |-- kaggle_train.ipynb          # Self-contained Kaggle GPU notebook
|
|-- tests/
|   |-- test_encoders.py            # Unit tests (9/9 passing)
|
|-- results/                        # Training logs and figures
|-- REPORT.md                       # Full project report
|-- requirements.txt
|-- setup.py
```

## Training

### Configuration (v2, anti-overfitting)

| Parameter | Value | Purpose |
|---|---|---|
| Batch size | 64 | Larger batches = more negatives per sample |
| Learning rate | 1e-4 | Lower LR for stable convergence |
| Weight decay | 5e-4 | L2 regularization |
| Dropout | 0.3 | Prevent co-adaptation |
| Label smoothing | 0.1 | Prevent overconfident predictions |
| XRD noise (std) | 0.05 | Simulate measurement noise |
| XRD shift | +/- 3 pts | Simulate calibration error |
| Text word dropout | 15% | Force robust text understanding |
| BERT frozen layers | 10/12 | Reduce trainable capacity |
| Early stopping patience | 25 | Stop before overfitting |
| Scheduler | Cosine + 10-epoch warmup | Smooth LR decay |
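The cosine-with-warmup schedule in the last row can be sketched with `LambdaLR`; the exact warmup shape used in training is an assumption:

```python
import math
import torch

def cosine_warmup(optimizer, warmup_epochs=10, total_epochs=200):
    """Linear warmup over the first epochs, then cosine decay to zero."""
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
sched = cosine_warmup(opt)
lr0 = sched.get_last_lr()[0]          # first-epoch LR: 10% of the base rate
```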

### Kaggle GPU

A fully self-contained notebook is provided at notebooks/kaggle_train.ipynb. Upload it to Kaggle, enable a T4 GPU, and run all cells. No external files are needed.


## How It Works

### 1. Data Pipeline

```
Materials Project API
        |
        v
  10,000 stable crystal structures
        |
        +---> pymatgen XRD Calculator ---> 512-point XRD patterns
        |
        +---> Property-based templates ---> Text descriptions
        |
        +---> PyG graph builder ---------> Crystal graphs (nodes=atoms, edges=bonds)
        |
        v
  Paired dataset: (crystal_graph, xrd_pattern, text_description)
        |
        v
  Train: 7,800 / Val: 975 / Test: 976
```
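Pymatgen emits a discrete peak list, so producing the 512-point signal requires broadening the peaks onto a fixed grid. A sketch of one common approach (Gaussian broadening); the peak width is an assumed value and the function is illustrative, not the repo's actual code:

```python
import numpy as np

def peaks_to_signal(two_theta, intensity, n_points=512, lo=5.0, hi=90.0, width=0.3):
    """Broaden discrete (2-theta, intensity) peaks into a fixed-length
    signal by summing Gaussians on a uniform 2-theta grid."""
    grid = np.linspace(lo, hi, n_points)
    signal = np.zeros(n_points)
    for t, i in zip(two_theta, intensity):
        signal += i * np.exp(-0.5 * ((grid - t) / width) ** 2)
    if signal.max() > 0:
        signal /= signal.max()                 # normalize intensities to [0, 1]
    return signal

sig = peaks_to_signal([30.0, 45.0], [100.0, 60.0])
```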

### 2. Contrastive Training

For each batch of 64 materials:

1. Encode all three modalities into 256-dim vectors
2. Compute cosine similarity matrices between each pair
3. Apply the InfoNCE loss: push matching pairs together, non-matching pairs apart
4. Crystal is the "anchor" modality (following ImageBind):

```
Crystal <----(weight 1.0)----> Text
Crystal <----(weight 1.0)----> XRD
    XRD <----(weight 0.5)----> Text     (emerges through shared space)
```

### 3. Evaluation

- **Recall@K:** is the correct match among the top-K retrieved results?
- **MRR:** Mean Reciprocal Rank (average of 1/rank of the correct match)
- **Alignment:** mean cosine similarity of positive pairs (higher is better)
- **Uniformity:** how spread out the embeddings are on the hypersphere (lower is better)
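Recall@K and MRR both follow from the rank of the correct match in each row of the test-set similarity matrix. A minimal sketch, assuming matched pairs share an index (illustrative, not the repo's metrics code):

```python
import numpy as np

def retrieval_metrics(sim, ks=(1, 5, 10)):
    """Recall@K and MRR from an (N, N) similarity matrix where row i's
    correct match is column i."""
    n = sim.shape[0]
    order = np.argsort(-sim, axis=1)           # columns sorted by similarity
    # Rank of the correct match in each row (1 = retrieved first).
    ranks = np.array([np.where(order[i] == i)[0][0] + 1 for i in range(n)])
    recall = {k: float(np.mean(ranks <= k)) for k in ks}
    mrr = float(np.mean(1.0 / ranks))
    return recall, mrr

sim = np.eye(3)                                # perfect retrieval: R@1 = MRR = 1
recall, mrr = retrieval_metrics(sim)
```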

## Tech Stack

| Component | Technology |
|---|---|
| Deep Learning | PyTorch 2.11, CUDA 12.8 |
| Graph Neural Networks | PyTorch Geometric 2.7 |
| Crystal Structures | pymatgen 2026.3 |
| XRD Simulation | pymatgen.analysis.diffraction.xrd |
| Text Encoder | MatSciBERT (HuggingFace Transformers) |
| Materials Data | Materials Project API (mp-api) |
| Visualization | matplotlib, seaborn, t-SNE, UMAP |
| Hardware | NVIDIA TITAN RTX 24GB |

## Key References

| Paper | Venue | Relevance |
|---|---|---|
| CLIP (Radford et al.) | ICML 2021 | Core architecture inspiration |
| ImageBind (Girdhar et al.) | CVPR 2023 | Anchor-modality alignment strategy |
| CGCNN (Xie & Grossman) | PRL 2018 | Crystal graph encoder |
| MatSciBERT (Gupta et al.) | ACL 2022 | Domain-specific text encoder |
| GNoME (Merchant et al.) | Nature 2023 | Scaling crystal discovery with GNNs |
| MultiMat (Moro et al.) | Newton 2025 | Multi-modal foundation model for materials |
| SimCLR (Chen et al.) | ICML 2020 | Projection head design + contrastive framework |
| Wang & Isola | ICML 2020 | Alignment and uniformity metrics |

## Tests

```bash
# Run all tests (9/9 passing)
python3 -m pytest tests/ -v
```

Tests cover: XRD encoder shapes, crystal encoder shapes, Gaussian expansion, InfoNCE loss, multi-modal loss, retrieval metrics, projection head normalization, and text builder output.


## License

MIT

*Built with PyTorch, pymatgen, and MatSciBERT*

MatCLIP is an academic research project exploring multi-modal contrastive learning for materials science.
