5 changes: 4 additions & 1 deletion cortex/cmdline/train_cortex_model.py
@@ -6,9 +6,11 @@
 import hydra
 import lightning as L
 import torch
-import wandb
 from omegaconf import DictConfig, OmegaConf
 
+# ruff: noqa: I001
+import wandb
+
 from cortex.logging import wandb_setup


@@ -35,6 +37,7 @@ def execute(cfg):
"""
instantiate and train a multitask neural tree
"""
torch.set_float32_matmul_precision("medium")

trainer = hydra.utils.instantiate(cfg.trainer)

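For context, `torch.set_float32_matmul_precision("medium")` lets PyTorch route float32 matmuls through lower-precision kernels (TF32 or bfloat16) on GPUs that support them, trading a small amount of accuracy for a large throughput gain. A minimal sketch of the setting's effect:

```python
import torch

# "highest" (the default) keeps exact float32 matmuls;
# "high" and "medium" permit TF32/bfloat16 internal kernels on supported GPUs.
torch.set_float32_matmul_precision("medium")

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)
c = a @ b  # uses the relaxed matmul precision where the hardware supports it
```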
9 changes: 9 additions & 0 deletions cortex/config/hydra/branches/folding_encoder.yaml
@@ -0,0 +1,9 @@
+folding:
+  _target_: cortex.model.branch.TransformerBranch
+  out_dim: 8
+  channel_dim: ${channel_dim}
+  num_blocks: 2
+  num_heads: 8
+  dropout_prob: ${dropout_prob}
+  is_causal: false
+  pooling_type: attention
@@ -3,6 +3,7 @@ protein_property:
   out_dim: 8
   channel_dim: ${channel_dim}
   num_blocks: 2
-  num_heads: 4
+  num_heads: 8
   dropout_prob: ${dropout_prob}
   is_causal: false
+  pooling_type: attention
4 changes: 2 additions & 2 deletions cortex/config/hydra/model_globals/default.yaml
@@ -1,7 +1,7 @@
 # @package _global_
-channel_dim: 128
+channel_dim: 512
 embed_dim: 32
-ensemble_size: 4
+ensemble_size: 2
 dropout_prob: 0.0
 kernel_size: 5
 pooling_type: mean
@@ -12,7 +12,7 @@ protein_seq:
   channel_dim: ${channel_dim}
   num_blocks: 2
   num_heads: 4
-  is_causal: false
+  is_causal: true
   dropout_prob: ${dropout_prob}
   pos_encoding: true
   train_transforms: null
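Flipping `is_causal` to `true` turns this root into a left-to-right decoder: each position can attend only to itself and earlier positions, which is what autoregressive generation requires. A minimal illustration of the mask this implies (plain PyTorch, not cortex internals):

```python
import torch

seq_len = 5
# Upper-triangular -inf mask: position i cannot attend to any j > i.
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
# Passed as attn_mask to nn.MultiheadAttention or scaled_dot_product_attention,
# the -inf entries zero out attention to future tokens after the softmax.
print(causal_mask)
```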
19 changes: 19 additions & 0 deletions cortex/config/hydra/roots/protein_seq_encoder.yaml
@@ -0,0 +1,19 @@
+protein_seq:
+  _target_: cortex.model.root.TransformerRoot
+  corruption_process:
+    _target_: cortex.corruption.MaskCorruptionProcess
+  tokenizer_transform:
+    _target_: cortex.transforms.HuggingFaceTokenizerTransform
+    tokenizer:
+      _target_: cortex.tokenization.ProteinSequenceTokenizerFast
+  max_len: 256
+  out_dim: ${channel_dim}
+  embed_dim: ${embed_dim}
+  channel_dim: ${channel_dim}
+  num_blocks: 10
+  num_heads: 8
+  is_causal: false
+  dropout_prob: ${dropout_prob}
+  pos_encoding: true
+  train_transforms: null
+  eval_transforms: null
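Hydra resolves each `_target_` recursively, passing the sibling keys as constructor arguments. Loading a top-level config and materializing this root would look roughly like the following; the config path and attribute names are assumptions based on this diff, not verified against the repo:

```python
from hydra import compose, initialize
from hydra.utils import instantiate

# config_path is relative to the calling module; adjust as needed.
with initialize(config_path="cortex/config/hydra", version_base=None):
    cfg = compose(config_name="train_protein_model")
    # instantiate() builds cortex.model.root.TransformerRoot, constructing the
    # nested corruption_process and tokenizer_transform nodes along the way.
    root = instantiate(cfg.roots.protein_seq)
```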
25 changes: 25 additions & 0 deletions cortex/config/hydra/tasks/generation/gfp_autoregressive.yaml
@@ -0,0 +1,25 @@
+gfp:
+  _target_: cortex.task.AutoregressiveLanguageModelTask
+  tokenizer:
+    _target_: cortex.tokenization.ProteinSequenceTokenizerFast
+  input_map:
+    protein_seq: ['tokenized_seq']
+  root_key: protein_seq
+  # Add BLOSUM62-based substitution corruption for data augmentation
+  corruption_process:
+    _target_: cortex.corruption.SubstitutionCorruptionProcess.from_blosum62
+    corruption_rate: 0.1 # Apply corruption to 10% of masked tokens
+  data_module:
+    _target_: cortex.data.data_module.TaskDataModule
+    _recursive_: false
+    batch_size: ${fit.batch_size}
+    balance_train_partition: null
+    drop_last: true
+    lengths: [1.0, 0.0]
+    train_on_everything: false
+    num_workers: ${num_workers}
+    dataset_config:
+      _target_: cortex.data.dataset.TAPEFluorescenceDataset
+      root: ${dataset_root_dir}
+      download: ${download_datasets}
+      train: ???
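`train: ???` is OmegaConf's mandatory-missing marker: the config composes fine, but resolving the key raises until a concrete value is supplied (presumably per split by the data module). A self-contained illustration:

```python
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"train": "???"})  # "???" marks a required value
try:
    _ = cfg.train
except MissingMandatoryValue:
    print("'train' must be set before use, e.g. cfg.train = True")
```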
@@ -0,0 +1,25 @@
+stable_proteins_autoregressive:
+  _target_: cortex.task.AutoregressiveLanguageModelTask
+  tokenizer:
+    _target_: cortex.tokenization.ProteinSequenceTokenizerFast
+  input_map:
+    protein_seq: ['tokenized_seq']
+  root_key: protein_seq
+  # Add BLOSUM62-based substitution corruption for data augmentation
+  corruption_process:
+    _target_: cortex.corruption.SubstitutionCorruptionProcess.from_blosum62
+    corruption_rate: 0.1 # Apply corruption to 10% of masked tokens
+  data_module:
+    _target_: cortex.data.data_module.TaskDataModule
+    _recursive_: false
+    batch_size: ${fit.batch_size}
+    balance_train_partition: null
+    drop_last: true
+    lengths: [1.0, 0.0]
+    train_on_everything: false
+    num_workers: ${num_workers}
+    dataset_config:
+      _target_: cortex.data.dataset.TAPEStabilityDataset
+      root: ${dataset_root_dir}
+      download: ${download_datasets}
+      train: ???
25 changes: 25 additions & 0 deletions cortex/config/hydra/tasks/generation/tape_combined.yaml
@@ -0,0 +1,25 @@
+tape_combined:
+  _target_: cortex.task.DenoisingLanguageModelTask
+  tokenizer:
+    _target_: cortex.tokenization.ProteinSequenceTokenizerFast
+  input_map:
+    protein_seq: ['tokenized_seq']
+  root_key: protein_seq
+  # Add BLOSUM62-based substitution corruption for data augmentation
+  corruption_process:
+    _target_: cortex.corruption.SubstitutionCorruptionProcess.from_blosum62
+    corruption_rate: 0.1 # Apply corruption to 10% of masked tokens
+  data_module:
+    _target_: cortex.data.data_module.TaskDataModule
+    _recursive_: false
+    batch_size: ${fit.batch_size}
+    balance_train_partition: partition
+    drop_last: true
+    lengths: [1.0, 0.0]
+    train_on_everything: false
+    num_workers: ${num_workers}
+    dataset_config:
+      _target_: cortex.data.dataset.TAPECombinedDataset
+      root: ${dataset_root_dir}
+      download: ${download_datasets}
+      train: ???
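All three generation tasks attach `SubstitutionCorruptionProcess.from_blosum62` with a 10% corruption rate. The idea behind BLOSUM62-based substitution is to swap a fraction of residues for biologically plausible alternatives by sampling from each amino acid's substitution-matrix row. A rough sketch of that idea; this is illustrative, not cortex's implementation, and `sub_probs` is an assumed input:

```python
import torch

def blosum_substitute(tokens: torch.Tensor, sub_probs: torch.Tensor, rate: float = 0.1) -> torch.Tensor:
    """Corrupt ~`rate` of positions with BLOSUM-style substitutions.

    tokens: (batch, seq) integer amino-acid ids.
    sub_probs: (vocab, vocab) row-stochastic matrix derived from BLOSUM62,
        where sub_probs[i, j] is the probability of substituting i -> j.
    """
    corrupt = torch.rand(tokens.shape) < rate
    # Sample one candidate replacement per position from its residue's row.
    replacements = torch.multinomial(sub_probs[tokens.flatten()], 1).view_as(tokens)
    return torch.where(corrupt, replacements, tokens)
```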
52 changes: 52 additions & 0 deletions cortex/config/hydra/train_gpt.yaml
@@ -0,0 +1,52 @@
+defaults:
+  - general_settings: default
+  - logging: default
+  - model_globals: default
+  - roots: [protein_seq_decoder]
+  - trunk: default
+  - branches: [generation]
+  - tree: protein_model
+  - tasks:
+    - generation/stable_proteins_autoregressive
+  - _self_
+
+fit:
+  batch_size: 128
+
+trainer:
+  _target_: lightning.Trainer
+  accelerator: gpu
+  max_epochs: 64
+  devices: 1
+  num_sanity_val_steps: 0
+
+
+tree:
+  _recursive_: false
+  fit_cfg:
+    reinitialize_roots: true
+    linear_probing: false
+    weight_averaging: null
+    optimizer:
+      _target_: torch.optim.Adam
+      lr: 6e-4
+      weight_decay: 0.
+      betas: [0.99, 0.999]
+      fused: false
+    lr_scheduler:
+      _target_: transformers.get_cosine_schedule_with_warmup
+      num_warmup_steps: 10
+      num_training_steps: ${trainer.max_epochs}
+
+tasks:
+  generation:
+    stable_proteins_autoregressive:
+      ensemble_size: 1
+
+train_on_everything: false
+linear_probing: false
+dataset_root_dir: /home/stantos5/scratch/datasets
+download_datasets: true
+num_workers: 2
+
+ckpt_name: ${exp_name}_${job_name}
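Note that `dataset_root_dir` is hard-coded to a personal scratch path, so other users will want to override it from the command line or via the compose API. A sketch, with the config path assumed:

```python
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="cortex/config/hydra", version_base=None):
    cfg = compose(
        config_name="train_gpt",
        overrides=[
            "dataset_root_dir=/tmp/datasets",  # replace the hard-coded scratch path
            "trainer.max_epochs=8",            # shorter smoke-test run
        ],
    )
    trainer = instantiate(cfg.trainer)
```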
28 changes: 12 additions & 16 deletions cortex/config/hydra/train_protein_model.yaml
@@ -2,15 +2,14 @@ defaults:
   - general_settings: default
   - logging: default
   - model_globals: default
-  - roots: [protein_seq]
+  - roots: [protein_seq_encoder]
   - trunk: default
-  - branches: [protein_property, generation]
+  - branches: [folding_encoder, protein_property_encoder, generation]
   - tree: protein_model
   - tasks:
     - protein_property/log_fluorescence
-    - protein_property/stability
-    - generation/gfp
-    - generation/stable_proteins
+    - folding/stability
+    - generation/tape_combined
   - _self_
 
 fit:
@@ -19,12 +18,12 @@ fit:
 trainer:
   _target_: lightning.Trainer
   accelerator: gpu
-  max_epochs: 64
+  max_epochs: 128
   devices: 1
   # devices: 8
   # strategy: ddp
   num_sanity_val_steps: 0
-
+  # precision: 16
 
 tree:
   _recursive_: false
@@ -34,7 +33,7 @@ tree:
     weight_averaging: null
     optimizer:
       _target_: torch.optim.Adam
-      lr: 5e-3
+      lr: 3e-4
       weight_decay: 0.
       betas: [0.99, 0.999]
       fused: false
@@ -44,17 +43,14 @@
       num_training_steps: ${trainer.max_epochs}
 
 tasks:
+  folding:
+    stability:
+      ensemble_size: ${ensemble_size}
   protein_property:
     log_fluorescence:
-      # ensemble_size: ${ensemble_size}
-      ensemble_size: 2
-    stability:
-      # ensemble_size: ${ensemble_size}
-      ensemble_size: 2
+      ensemble_size: ${ensemble_size}
   generation:
-    gfp:
-      ensemble_size: 1
-    stable_proteins:
+    tape_combined:
       ensemble_size: 1
 
 train_on_everything: false
3 changes: 2 additions & 1 deletion cortex/data/data_module/_task_data_module.py
@@ -123,7 +123,8 @@ def get_dataloader(self, split: str = "train"):
         # Full batch for evaluation on the test set
         if split == "test":
             self._dataloader_kwargs["batch_size"] = len(self.datasets[split])
-        dataloader = DataLoader(self.datasets[split], shuffle=True, drop_last=True, **self._dataloader_kwargs)
+        # self._dataloader_kwargs["batch_size"] = 2 * self._batch_size
+        dataloader = DataLoader(self.datasets[split], shuffle=False, drop_last=False, **self._dataloader_kwargs)
         if split == "test":
             self._dataloader_kwargs["batch_size"] = self._batch_size
         return dataloader
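The surrounding logic temporarily sets `batch_size` to the full dataset length for the test split, then restores it, so test evaluation sees exactly one batch containing every example, unshuffled and with nothing dropped. The pattern in isolation (a generic sketch, not the cortex class):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))

# Full-batch evaluation: one batch containing every example.
loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False, drop_last=False)
(batch,) = list(loader)  # exactly one batch
assert batch[0].shape[0] == len(dataset)
```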
4 changes: 4 additions & 0 deletions cortex/data/dataset/__init__.py
@@ -3,6 +3,9 @@
 from ._rfp_dataset import RedFluorescentProteinDataset
 from ._tape_fluorescence import TAPEFluorescenceDataset
 from ._tape_stability import TAPEStabilityDataset
+
+# ruff: noqa: I001
+from ._tape_combined import TAPECombinedDataset
 from ._transformed_dataset import TransformedDataset
 
 __all__ = [
@@ -12,5 +15,6 @@
"RedFluorescentProteinDataset",
"TAPEFluorescenceDataset",
"TAPEStabilityDataset",
"TAPECombinedDataset",
"TransformedDataset",
]
20 changes: 20 additions & 0 deletions cortex/data/dataset/_tape_combined.py
@@ -0,0 +1,20 @@
+import pandas as pd
+
+from cortex.data.dataset import TAPEFluorescenceDataset, TAPEStabilityDataset
+from cortex.data.dataset._data_frame_dataset import DataFrameDataset
+
+
+# hack to combine TAPE datasets for self-supervised training
+class TAPECombinedDataset(DataFrameDataset):
+    columns = [
+        "tokenized_seq",
+        "partition",
+    ]
+
+    def __init__(self, root: str, download: bool = False, **kwargs):
+        fluorescence_data = TAPEFluorescenceDataset(root=root, download=download, **kwargs)._data
+        stability_data = TAPEStabilityDataset(root=root, download=download, **kwargs)._data
+
+        fluorescence_data["partition"] = "fluorescence"
+        stability_data["partition"] = "stability"
+        self._data = pd.concat([fluorescence_data[self.columns], stability_data[self.columns]], ignore_index=True)
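The combined dataset simply concatenates the two TAPE frames and tags each row with its source, which is what `balance_train_partition: partition` in the tape_combined task stratifies over. A quick usage sketch; the `train` flag is assumed to be required by the underlying TAPE datasets (their configs mark it `???`):

```python
from cortex.data.dataset import TAPECombinedDataset

# Downloads both TAPE datasets if needed, tags rows with their source,
# and exposes the concatenated frame via the DataFrameDataset interface.
dataset = TAPECombinedDataset(root="/tmp/datasets", download=True, train=True)
print(dataset._data["partition"].value_counts())  # rows per source dataset
```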
4 changes: 3 additions & 1 deletion cortex/logging/_wandb_setup.py
@@ -1,9 +1,11 @@
 import uuid
 from typing import MutableMapping
 
-import wandb
 from omegaconf import DictConfig, OmegaConf
 
+# ruff: noqa: I001
+import wandb
+
 import cortex


5 changes: 4 additions & 1 deletion cortex/model/branch/_transformer_branch.py
@@ -9,6 +9,7 @@
     Apply,
     Expression,
     MeanPooling,
+    PoolingSelfAttention,
     WeightedMeanPooling,
     identity,
 )
@@ -32,7 +33,7 @@ def __init__(
         out_dim: int = 64,
         channel_dim: int = 64,
         num_blocks: int = 2,
-        num_heads: int = 5,
+        num_heads: int = 4,
         is_causal: bool = False,
         dropout_prob: float = 0.0,
         pooling_type: str = "mean",
@@ -74,6 +75,8 @@ def __init__(
             self.pooling_op = MeanPooling()
         elif pooling_type == "weighted_mean":
             self.pooling_op = WeightedMeanPooling(out_dim)
+        elif pooling_type == "attention":
+            self.pooling_op = PoolingSelfAttention(num_heads=num_heads, embed_dim=out_dim, dropout_p=dropout_prob)
         else:
             raise NotImplementedError

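`PoolingSelfAttention` itself isn't shown in this diff; judging by the constructor call it pools the token dimension with multi-head attention. A hypothetical sketch in that spirit (an assumption, not the cortex module):

```python
import torch
from torch import nn

class AttentionPoolSketch(nn.Module):
    """Pool (batch, seq, dim) -> (batch, dim) with a learned query."""

    def __init__(self, embed_dim: int, num_heads: int, dropout_p: float = 0.0):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout_p, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.query.expand(x.shape[0], -1, -1)  # one query per sequence
        pooled, _ = self.attn(q, x, x)             # attend over all tokens
        return pooled.squeeze(1)

pool = AttentionPoolSketch(embed_dim=8, num_heads=4)
out = pool(torch.randn(2, 16, 8))
print(out.shape)  # torch.Size([2, 8])
```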
2 changes: 2 additions & 0 deletions cortex/model/elemental/__init__.py
@@ -7,6 +7,7 @@
 from ._layernorm import MaskLayerNorm1d
 from ._mean_pooling import MeanPooling, WeightedMeanPooling
 from ._mlp import MLP
+from ._pooling_self_attention import PoolingSelfAttention
 from ._sine_pos_encoder import SinePosEncoder
 
 __all__ = [
@@ -20,6 +21,7 @@
"swish",
"MaskLayerNorm1d",
"MeanPooling",
"PoolingSelfAttention",
"WeightedMeanPooling",
"SinePosEncoder",
]