# Accurate Predictions of Molecular Properties of Proteins via Graph Neural Networks and Transfer Learning
This repository contains code and pre-trained weights for Graph Transformer Networks that are trained for the computational prediction of protein properties. GSnet is adept at predicting a variety of physicochemical properties from three-dimensional protein structures, while aLCnet was specifically trained for residue-specific pKa prediction. Moreover, the application of transfer learning allows these models to utilize previously learned representations (i.e., embeddings) in new prediction tasks, even with limited specific training data.
Properties that these models can predict include:

- Free energy of solvation ($\Delta G_{sol}$)
- Hydrodynamic radius ($R_h$)
- Translational diffusion coefficient ($D_t$)
- Rotational diffusion coefficient ($D_r$)
- Molecular volume ($V$)
- Radius of gyration ($R_g$)
- Solvent accessible surface area ($SASA$)
- $pK_a$ values
## Paper

For more information about this project, see the paper: https://doi.org/10.1021/acs.jctc.4c01682
To cite our project in your work, use:
```bibtex
@article{doi:10.1021/acs.jctc.4c01682,
  author  = {Wozniak, Spencer and Janson, Giacomo and Feig, Michael},
  title   = {Accurate Predictions of Molecular Properties of Proteins via Graph Neural Networks and Transfer Learning},
  journal = {Journal of Chemical Theory and Computation},
  volume  = {21},
  number  = {9},
  pages   = {4830-4845},
  year    = {2025},
  doi     = {10.1021/acs.jctc.4c01682},
  note    = {PMID: 40270304},
  URL     = {https://doi.org/10.1021/acs.jctc.4c01682},
  eprint  = {https://doi.org/10.1021/acs.jctc.4c01682}
}
```

## Table of Contents

- [Paper](#paper)
- [Installation](#installation)
- [Making Predictions](#making-predictions)
- [Plotting Predictions](#plotting-predictions)
- [Models](#models)
- [Generating Embeddings](#generating-embeddings)
- [Generating Datasets](#generating-datasets)
- [Training](#training)
- [Example Workflow](#example-workflow)
- [Contact](#contact)
- [Directory Structure](#directory-structure)
## Installation

Before you can run the models, you need to set up your environment:

- Clone the repository:

  ```bash
  git clone https://github.com/feiglab/ProteinStructureEmbedding.git
  cd ProteinStructureEmbedding
  ```

- Set up the environment. If you use a virtual environment, set it up and activate it before installing the packages:

  ```bash
  conda create -n gsnet python=3.9 pip
  conda activate gsnet
  ```

- Install the required Python packages (Python 3.9 is required):

  ```bash
  pip install -r requirements.txt
  ```

## Making Predictions

The `predict.py` script predicts physicochemical properties from PDB files or from precomputed NPZ representations. Run it from the `src/` directory (or with `src` on `PYTHONPATH`). You can pass multiple input files.
The script supports three prediction modes:

- **Default** — six global properties: $\Delta G_{sol}$, $R_g$, $R_h$, $D_t$, $D_r$, $V$.
- **`--sasa`** — the same six plus SASA.
- **`--pka`** — residue-level pKa for titratable residues (ASP, CYS, GLU, HIS, LYS, TYR). Use `--atomic` for aLCnet (faster and more accurate); omit it to use GSnet. Use `--shift` to output the pKa shift from the standard value instead of the absolute pKa.

pKa output is tabular with a header. Columns: `Predicted`, (`Observed` if `--show-label`), `AA`, `Res`, `Chain`, `File`.

Two input types are supported:

- **PDB** — one or more `.pdb` files. Use `--clean` to strip non-standard residues/atoms before prediction; `--keep` keeps the cleaned files. `--chain` restricts prediction to a single chain; `--combine-chains` merges all chains into one structure.
- **`--numpy`** — one or more `.npz` files (e.g. from `generate_datasets.py`). Same prediction modes and outputs as PDB input. The residue index and chain are taken from the NPZ when present, and otherwise parsed from the filename when it follows `{pdb}_{chain}_{resid}.npz`.
With `--numpy` and `--show-label`, the script prints an Observed column. NPZ labels are stored as pKa shifts; when `--shift` is not given, the script converts them to absolute pKa values using the residue type. If a file has no label key, Observed is printed as `-`.
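That conversion amounts to adding the stored shift to the residue's standard (model-compound) pKa. A minimal sketch, assuming commonly cited standard values (the exact constants used by `predict.py` may differ):

```python
# Illustrative sketch of the shift -> absolute pKa conversion described above.
# ASSUMPTION: these model-compound pKa values are common approximations only;
# they are not necessarily the constants used by predict.py.
STANDARD_PKA = {'ASP': 3.7, 'CYS': 8.5, 'GLU': 4.2, 'HIS': 6.4, 'LYS': 10.4, 'TYR': 9.8}

def to_absolute_pka(shift: float, residue: str) -> float:
    """Convert a stored pKa shift into an absolute pKa for a residue type."""
    return STANDARD_PKA[residue] + shift
```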
`--state-dict PATH` loads a custom state dict for the main model instead of the default `.pt` file. Loading is loose (`strict=False`); if no parameters match the selected architecture, the script exits with an error. Both raw state dicts and checkpoint dicts with a `state_dict` / `model_state_dict` key are supported.
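For reference, a minimal sketch of this kind of loose loading in PyTorch (the model object and file name below are placeholders, not the script's actual internals):

```python
# Sketch of loose checkpoint loading as described above.
import torch

model = torch.nn.Linear(4, 1)  # placeholder module; the script builds GSnet or aLCnet
ckpt = torch.load('custom_weights.pt', map_location='cpu')

# Accept either a raw state dict or a checkpoint dict wrapping one.
if isinstance(ckpt, dict) and 'state_dict' in ckpt:
    state = ckpt['state_dict']
elif isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
    state = ckpt['model_state_dict']
else:
    state = ckpt

# strict=False ignores keys that do not match the selected architecture.
missing, unexpected = model.load_state_dict(state, strict=False)
print(f'missing: {len(missing)}, unexpected: {len(unexpected)}')
```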
Other flags:

- `--cpu` / `--gpu` — force the device (default: GPU if available).
- `--time` — print timing for model loading and the forward pass.
- `--skip-bad-files` — skip inputs that fail to load instead of raising an error.
Default (global properties) from PDB:

```bash
python predict.py /path/to/file.pdb
```

pKa (absolute) with aLCnet:

```bash
python predict.py --pka --atomic /path/to/file.pdb
```

pKa shifts from NPZ with observed labels and custom weights:

```bash
python predict.py --pka --atomic --numpy --shift --show-label --state-dict ../models/tr_25_test.pt ../pKa-datasets/msu-test-data/*.npz
```

Sample pKa output (with header and observed values):

```
Predicted  Observed  AA   Res  Chain  File
-0.99      -1.67     ASP  93   C      ../pKa-datasets/msu-test-data/1A2P_C_93.npz
...
```
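Because the output is whitespace-delimited with a header, it is easy to post-process. An illustrative pandas snippet, assuming the output was redirected to `pka_out.txt` as in the plotting examples below:

```python
# Illustrative post-processing of saved predict.py output.
import pandas as pd

df = pd.read_csv('pka_out.txt', sep=r'\s+')  # columns: Predicted, Observed, AA, Res, Chain, File
print(df.groupby('AA')[['Predicted', 'Observed']].mean())  # mean shift per residue type
```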
For all options: `python predict.py -h`
## Plotting Predictions

The `plot_predictions.py` script plots predicted vs. observed values from the text output of `predict.py`. It expects output that includes both a Predicted and an Observed column (e.g. pKa runs with `--show-label`). Run it from the `src/` directory.
Default behavior: the script saves the figure to `predictions.png` in the current directory (it does not open a window). Use `--show` to display the figure in a GUI instead of saving, or `-o path` / `--save path` to save to a different path.
Two input sources are supported:

- From file: one or more files containing `predict.py` stdout (e.g. after redirecting: `python predict.py ... > out.txt`).
- From stdin (pipe): use `-` as the input so you can pipe `predict.py` directly into the plot script.
Pipe the output of `predict.py` into `plot_predictions.py`. By default the figure is saved as `predictions.png`:

```bash
python predict.py --pka --atomic --numpy --shift --show-label ../pKa-datasets/msu-test-data/*.npz | python plot_predictions.py -
```

Use `-o` to save under a different name:

```bash
python predict.py --pka --atomic --numpy --shift --show-label ... | python plot_predictions.py - -o pka_plot.png
```

You can add other plot options before the output path:

```bash
python predict.py --pka --atomic --numpy --show-label ... | python plot_predictions.py - --title "pKa" --linreg -o pka.png
```

Alternatively, save `predict.py` output to a file, then pass that file to the plot script (the figure is saved as `predictions.png` unless you pass `-o`):

```bash
python predict.py --pka --atomic --numpy --show-label ... > pka_out.txt
python plot_predictions.py pka_out.txt
```

To save under a specific path:

```bash
python plot_predictions.py pka_out.txt -o pka_plot.png
```

You can pass multiple files; their data are combined into one plot.
Plot options:

- `--hexbin` — use a hexbin plot instead of a scatter plot (default: scatter).
- `--linreg` — add a linear regression line (default: on).
- `--mape` — report MAPE instead of Pearson r against the 1:1 line.
- `--title`, `--xlabel`, `--ylabel`, `--units` — labels and title.
- `--clean` — drop rows where the reference (x) value is < -700.
- `--save` / `-o` — save the figure to this path (default when not using `--show`: `predictions.png`).
- `--show` — display the figure in a GUI instead of saving.
- `--dpi` — resolution of the saved figure (default: 600).
For all options: `python plot_predictions.py -h`
## Models

| Model Name | Number of Parameters | Description |
|---|---|---|
| `GSnet_default.pt` | 5,971,748 | The original GSnet model trained on the 6 physicochemical properties. |
| `GSnet_SASA.pt` | 5,971,748 | GSnet fine-tuned for molecular SASA predictions. |
| `GSnet_pKa.pt` | 11,210,392 | GSnet fine-tuned for residue-level SASA, then further trained to predict pKa values. |
| `aLCnet_pKa.pt` | 4,784,324 | aLCnet trained from scratch on PHMD549 data and fine-tuned for pKa prediction on experimental data. |
## Generating Embeddings

The `embed_GSnet.py` and `embed_aLCnet.py` scripts allow you to easily generate embeddings for all PDB files within a specified directory.

Generated embeddings (via either method) are saved as tensors of shape `[N, d]`, where `N` is the number of residues in the protein and `d` is the embedding dimension.

- **Gather PDB files:** Put single-chain PDB files that you would like embeddings for into a directory. Make sure the files have the `.pdb` extension.
- **Run the script:** Navigate to the `src/` directory in your terminal and use the following command(s) to generate embeddings.

For GSnet embeddings:

```bash
python embed_GSnet.py --protein/--residue PDBPATH OUTPATH
```

Replace `PDBPATH` with the directory containing your PDB files and `OUTPATH` with the directory where you want to save the embeddings.

- Use the `--protein` option to generate GSnet embeddings optimized for whole-protein predictions (trained on 6 physicochemical properties).
- Use the `--residue` option to generate GSnet embeddings optimized for residue-specific predictions (fine-tuned on rSASA and pKa).
- In principle, either embedding (`--protein` or `--residue`) may be useful in either context; it can be worthwhile to try both for the same task to determine which is more useful.

For aLCnet embeddings:

```bash
python embed_aLCnet.py PDBPATH OUTPATH
```

Again, replace `PDBPATH` with the directory containing your PDB files and `OUTPATH` with the directory where you want to save the embeddings.

- This will take longer than generating GSnet embeddings because separate graphs are constructed for the atoms around each residue, rather than one graph for the whole protein.
- The scripts use multiprocessing to speed up embedding generation. Ensure your system has adequate resources to run multiple processes simultaneously.
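Once generated, the embeddings can serve as inputs for transfer learning on new prediction tasks. A minimal sketch, assuming one `[N, d]` tensor per protein saved via `torch.save` (file names are illustrative; check the scripts for the actual output format):

```python
# Transfer-learning sketch on saved embeddings.
# ASSUMPTION: one [N, d] tensor per protein saved with torch.save; names illustrative.
import torch

emb = torch.load('embeddings/protein1.pt')   # per-residue embeddings, shape [N, d]
protein_vec = emb.mean(dim=0)                # simple mean-pooling to a whole-protein vector

# Small downstream head for a new scalar property
head = torch.nn.Sequential(
    torch.nn.Linear(emb.shape[1], 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 1),
)
prediction = head(protein_vec)               # train this head on your own labels
```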
## Generating Datasets

Datasets for GSnet and aLCnet can be created in two ways: (1) with the `generate_datasets.py` script from residue-level CSV files (recommended when you have PDB IDs and pKa or other residue-level targets), or (2) manually with the `dataset.py` classes `NumpyRep` / `NumpyRep_atomic` and then `ProteinDataset` / `AtomicDataset`. The script automates downloading, chain extraction, and NPZ generation; the manual approach is for custom pipelines or when your data is already in a different format.

The `src/generate_datasets.py` script builds GSnet and/or aLCnet datasets from residue-level CSV files. It downloads PDBs from the RCSB if missing, extracts a single chain, writes cleaned single-chain PDBs to disk, and generates NPZ files (and output CSVs) in the format expected by `ProteinDataset` and `AtomicDataset`.
### Input CSV format
Each input CSV must have the following columns:
| Column | Description |
|---|---|
| `PDB` | PDB ID (e.g. `1abc`) or path to a local PDB file |
| `CHAIN` | Chain ID to use (e.g. `A`) |
| `RES` | Residue name (informational) |
| `RES_IDX` | Residue index (1-based) of the target residue |
| `PKA` | Target value (e.g. pKa) for that residue |
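For example, an input CSV might look like this (values are illustrative):

```
PDB,CHAIN,RES,RES_IDX,PKA
1abc,A,ASP,26,3.85
1abc,A,GLU,35,6.20
/path/to/local.pdb,B,HIS,15,7.10
```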
### How the script works

- For each row in the CSV, the script resolves the PDB file: if `PDB` looks like a PDB ID (e.g. `1abc`), it downloads the file from the RCSB into the output tree if not already present; otherwise it uses the given path.
- It extracts the requested chain and keeps only standard amino acids (the same filtering as the prediction script), writing a cleaned single-chain PDB (e.g. `{PDB}_{CHAIN}.pdb`) into the output directory.
- Using the cleaned PDB and the row's `RES_IDX` and `PKA`, it builds the appropriate representations and saves NPZ files for GSnet and/or aLCnet (depending on `--dataset`).
- It writes summary CSVs (`csv/gsnet.csv` and/or `csv/alcnet.csv`) that list the final PDB paths and target values, matching the layout expected by the manual workflow below.
### Command

From the project root (or with `src` on `PYTHONPATH`):

```bash
python src/generate_datasets.py --input_csv /path/to/file1.csv [/path/to/file2.csv ...] --outdir /path/to/output [--dataset gsnet|alcnet|both]
```

- `--input_csv`: one or more input CSV files (residue-level, format above).
- `--outdir`: root directory under which all outputs are written.
- `--dataset`: `gsnet`, `alcnet`, or `both` (default: `both`). Controls whether GSnet NPZs, aLCnet NPZs, or both are generated.
### Output layout

For each input CSV file, the script creates a subdirectory under `--outdir` named after the CSV (without the `.csv` extension). Inside that subdirectory:
| Path | Contents |
|---|---|
| `pdbs/` | Downloaded and/or chain-extracted PDBs (e.g. `1abc.pdb`, `1abc_A.pdb`). |
| `npz/` | NPZ files: `gsnet_0.npz`, `gsnet_1.npz`, ... and/or `alcnet_0.npz`, `alcnet_1.npz`, ... (one per CSV row). |
| `csv/` | Summary CSVs: `gsnet.csv` (columns `PDB`, `Target`) and/or `alcnet.csv` (columns `PDB`, `Res`, `Target`). |
The NPZ files have the same structure as in the manual workflow below. You can load them with `ProteinDataset` (for `npz/gsnet_*.npz`) or `AtomicDataset` (for `npz/alcnet_*.npz`) by passing the `npz` directory as `root` (e.g. `root='/path/to/output/my_dataset/npz'` for a dataset named `my_dataset`). Splitting into train/val/test is done by organizing or symlinking NPZ directories and then creating separate `ProteinDataset` / `AtomicDataset` instances for each split.
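For instance, loading a script-generated dataset might look like this (paths are illustrative; see the class signatures in `dataset.py`):

```python
# Load script-generated NPZ datasets (paths illustrative).
from dataset import ProteinDataset, AtomicDataset

gsnet_data = ProteinDataset(
    root='/path/to/output/my_dataset/npz',  # directory with gsnet_*.npz files
    use_dh=True,
    use_cc=True,
    normalize=True,
)
alcnet_data = AtomicDataset(
    root='/path/to/output/my_dataset/npz',  # directory with alcnet_*.npz files
    normalize=True,
)
```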
### Manual workflow: GSnet

- Have paths to PDBs and target values stored in a CSV file (or similar):

  ```
  PDB,Target Value
  /path/to/file1.pdb,4.10
  /path/to/file2.pdb,6.21
  /path/to/file3.pdb,7.94
  ...
  ```

- Generate NumPy representations of the data:
  ```python
  import numpy as np
  import pandas as pd

  from dataset import NumpyRep

  outdir = '/path/to/output/dir'
  df = pd.read_csv('/path/to/file.csv')  # Read CSV file

  # Iterate over datapoints in the dataset (this can be expedited with multiprocessing)
  for i, row in df.iterrows():
      rep = NumpyRep(row['PDB'])          # Create a NumpyRep for the PDB
      y = float(row['Target Value'])      # Extract target value

      # Generate an NPZ file for each datapoint
      np.savez(
          f'{outdir}/{i}.npz',            # Output file path
          label=y,                        # Target value
          x=rep.x,                        # Cartesian coordinates of residues
          a=rep.get_aas(),                # Residue types
          dh=rep.get_dh(),                # Dihedral information
          cc=rep.get_cc(),                # Alpha-carbon-to-center-of-mass distances
      )
  ```

- Generate a PyTorch dataset:
  ```python
  from dataset import ProteinDataset

  dataset = ProteinDataset(
      root='/path/to/output/dir',  # Path to directory containing NPZ files
      use_dh=True,                 # Use dihedral information
      use_cc=True,                 # Use CA-to-center-of-mass distances
      normalize=True,              # Normalize target values
  )
  ```

### Manual workflow: aLCnet

- Have paths to PDBs, residue indices, and target values stored in a CSV file (or similar):
  ```
  PDB,Res,Target Value
  /path/to/file1.pdb,24,4.10
  /path/to/file2.pdb,54,6.21
  /path/to/file3.pdb,91,7.94
  ...
  ```

- Generate NumPy representations of the data:
  ```python
  import numpy as np
  import pandas as pd

  from dataset import NumpyRep_atomic

  outdir = '/path/to/output/dir'
  df = pd.read_csv('/path/to/file.csv')  # Read CSV file

  # Iterate over datapoints in the dataset (this can be expedited with multiprocessing)
  for i, row in df.iterrows():
      rep = NumpyRep_atomic(row['PDB'], row['Res'])  # Create a NumpyRep for the residue in the PDB
      y = float(row['Target Value'])                 # Extract target value

      # Generate an NPZ file for each datapoint
      np.savez(
          f'{outdir}/{i}.npz',            # Output file path
          label=y,                        # Target value
          x=rep.x,                        # Cartesian coordinates of atoms
          a=rep.a,                        # Residue types
          atoms=rep.atoms,                # Atom types
          charge=rep.charge,              # Atom charges
          resid_atomic=rep.resid_atomic,  # Indices of atoms in the target residue
          resid_ca=rep.resid_ca,          # Alpha-carbon index
      )
  ```

- Generate a PyTorch dataset:
  ```python
  from dataset import AtomicDataset

  dataset = AtomicDataset(
      root='/path/to/output/dir',  # Path to directory containing NPZ files
      normalize=True,              # Normalize target values
  )
  ```

You can split the NPZ data into separate directories for training, validation, and test sets via any method you choose, then load one PyTorch dataset per split; one possible approach is sketched below.
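A simple random 80/10/10 split (a sketch; directory names and ratios are arbitrary):

```python
# Randomly split NPZ files into train/val/test directories (illustrative).
import os
import random
import shutil

src = '/path/to/output/dir'
files = sorted(f for f in os.listdir(src) if f.endswith('.npz'))
random.seed(0)
random.shuffle(files)

n = len(files)
splits = {
    'train': files[:int(0.8 * n)],
    'val':   files[int(0.8 * n):int(0.9 * n)],
    'test':  files[int(0.9 * n):],
}

for name, subset in splits.items():
    dst = f'{src}_{name}'
    os.makedirs(dst, exist_ok=True)
    for f in subset:
        shutil.copy(os.path.join(src, f), os.path.join(dst, f))
```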
## Training

Sample training scripts `train_GSnet.py` and `train_aLCnet.py` are provided for training GSnet and aLCnet, respectively.

To train a new model:

- Make sure you have PyTorch datasets generated. See [Generating Datasets](#generating-datasets) for more info.
- See the `train_GSnet.py` and `train_aLCnet.py` scripts for examples of how to train our models. Sample data for training both GSnet and aLCnet is provided for selected structures. A generic training-loop sketch follows below.
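For orientation only, a generic PyTorch Geometric-style training loop; the class name `GSnet` and the `batch.label` attribute are assumptions here, and the authoritative code is in `train_GSnet.py` / `train_aLCnet.py`:

```python
# Generic training-loop sketch; see train_GSnet.py for the real training code.
# ASSUMPTIONS: ProteinDataset yields PyTorch Geometric Data objects with a
# `label` attribute, and net.py exposes a GSnet class -- check net.py for names.
import torch
from torch_geometric.loader import DataLoader

from dataset import ProteinDataset
from net import GSnet  # hypothetical import

train_set = ProteinDataset(root='/path/to/train/npz', use_dh=True, use_cc=True, normalize=True)
loader = DataLoader(train_set, batch_size=32, shuffle=True)

model = GSnet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.MSELoss()

model.train()
for epoch in range(10):
    for batch in loader:
        optimizer.zero_grad()
        pred = model(batch)
        loss = loss_fn(pred.squeeze(), batch.label)
        loss.backward()
        optimizer.step()
```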
## Example Workflow

- Install the repo — see [Installation](#installation).
- Generate NPZ data sets — see [Generating Datasets](#generating-datasets).
- Make predictions — see [Making Predictions](#making-predictions).
- Plot predictions — run `plot_predictions.py` on the tabular output of `predict.py`. Use `--show-label` with `predict.py` so the output includes an Observed column, then pipe that output (optionally filtered, e.g. by residue type) into `plot_predictions.py`. By default the plot is saved to `predictions.png`; use `-o path` or `--show` to change the output.

Example — pKa predictions on a generated dataset, excluding CYS and TYR, then plotting:

```bash
cd src
python predict.py --pka --atomic --shift --numpy --show-label /path/to/dataset/*.npz | grep -v "CYS" | grep -v "TYR" | python plot_predictions.py -
```

To save the figure under a different path, add `-o figure.png` after `plot_predictions.py -`. Use `python predict.py -h` and `python plot_predictions.py -h` for all options.
## Contact

For more info, or if you have any questions, please email me at hey@spencerwozniak.com.
## Directory Structure

Here's a brief overview of the directory structure:

```
ProteinStructureEmbedding/
│
├── pKa-datasets/            # Our datasets for pKa training.
│   ├── MSU-pKa-train.csv    # Our training dataset.
│   ├── MSU-pKa-val.csv      # Our validation dataset.
│   └── MSU-pKa-test.csv     # Our test dataset.
│
├── src/                     # Source code of our application.
│   ├── dataset.py           # Classes for processing PDBs and generating datasets.
│   ├── generate_datasets.py # Script to generate GSnet/aLCnet datasets from residue-level CSVs.
│   ├── net.py               # Neural network architectures used in our project.
│   ├── predict.py           # Script for making predictions.
│   ├── plot_predictions.py  # Script to plot predicted vs. observed from predict.py output.
│   ├── embed_GSnet.py       # Script that generates GSnet embeddings.
│   ├── embed_aLCnet.py      # Script that generates aLCnet embeddings.
│   ├── train_GSnet.py       # Script for training GSnet.
│   ├── train_aLCnet.py      # Script for training aLCnet.
│   └── time.sh              # Script for timing the running of a script.
│
├── models/                  # State dictionaries containing weights and biases of the models.
│   ├── GSnet_default.pt     # Original pretrained GSnet.
│   ├── GSnet_SASA.pt        # GSnet fine-tuned for SASA predictions.
│   ├── GSnet_pKa.pt         # GSnet fine-tuned for pKa predictions.
│   ├── aLCnet_pKa.pt        # aLCnet trained for pKa predictions.
│   └── normalization.npz    # Normalization parameters.
│
├── sample_data/             # Sample data provided for running certain scripts.
│   ├── time_test/           # PDB structures used to test the speed of GSnet.
│   ├── GSnet/               # Sample training and test sets for retraining GSnet.
│   └── aLCnet/              # Sample training and test sets for retraining aLCnet.
│
├── requirements.txt         # Required Python packages to run our model.
└── README.md                # The file you are currently reading.
```

