Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions electrolyte_fm/models/prod_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@

import torch
import torch.nn as nn
from smirk import SmirkTokenizerFast
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
DataCollatorWithPadding,
PreTrainedModel,
PretrainedConfig,
PreTrainedModel,
)

from smirk import SmirkTokenizerFast
from .prediction_task_head import PredictionTaskHead
from .normalize import AbstractNormalizer
from .prediction_task_head import PredictionTaskHead

AutoTokenizer.register("SmirkTokenizer", fast_tokenizer_class=SmirkTokenizerFast)

Expand Down Expand Up @@ -141,14 +141,28 @@ def _resolve_tokenizer(self, tokenizer=None):
return tokenizer
if getattr(self, "tokenizer", None) is not None:
return self.tokenizer
try:
return AutoTokenizer.from_pretrained(
self.name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
return AutoTokenizer.from_pretrained(
self.config._name_or_path, use_fast=True, trust_remote_code=True
)

if self.name_or_path and "/" in self.name_or_path:
try:
return AutoTokenizer.from_pretrained(
self.name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
pass

if (
hasattr(self.config, "_name_or_path")
and self.config._name_or_path
and "/" in self.config._name_or_path
):
try:
return AutoTokenizer.from_pretrained(
self.config._name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
pass

return None

def embed(self, smi: List[str], tokenizer=None):
batch = self.tokenizer(smi)
Expand Down Expand Up @@ -282,14 +296,28 @@ def _resolve_tokenizer(self, tokenizer=None):
return tokenizer
if getattr(self, "tokenizer", None) is not None:
return self.tokenizer
try:
return AutoTokenizer.from_pretrained(
self.name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
return AutoTokenizer.from_pretrained(
self.config._name_or_path, use_fast=True, trust_remote_code=True
)

if self.name_or_path and "/" in self.name_or_path:
try:
return AutoTokenizer.from_pretrained(
self.name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
pass

if (
hasattr(self.config, "_name_or_path")
and self.config._name_or_path
and "/" in self.config._name_or_path
):
try:
return AutoTokenizer.from_pretrained(
self.config._name_or_path, use_fast=True, trust_remote_code=True
)
except Exception:
pass

return None

def predict(self, smi: List[str], tokenizer=None):
batch = self.tokenizer(smi)
Expand Down
57 changes: 32 additions & 25 deletions opt/package/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,57 +5,60 @@
# python -m opt.package --help

from __future__ import annotations

import ast
import inspect
import json
import logging
from pathlib import Path
from typing import Iterable, Type, Optional, List, Tuple, Set
from typing import Iterable, List, Optional, Set, Tuple, Type

import typer
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModel

from electrolyte_fm.utils.ckpt import SaveConfigWithCkpts, get_ckpt_tokenizer
from electrolyte_fm.models.lm_finetuning import load_encoder
from electrolyte_fm.models import (
MISTFinetunedConfig,
MISTExcessPhysics,
MISTExcessPhysicsConfig,
MISTFinetuned,
MISTIonicConductivityConfig,
MISTFinetunedConfig,
MISTIonicConductivity,
MISTMultiTaskConfig,
MISTMultiTask,
MISTExcessPhysicsConfig,
MISTExcessPhysics,
MISTMixturesConfig,
MISTIonicConductivityConfig,
MISTMixtures,
MISTMixturesConfig,
MISTMultiTask,
MISTMultiTaskConfig,
)
from electrolyte_fm.models.lm_finetuning import load_encoder
from electrolyte_fm.models.mixture_model import TemperatureCondition
from electrolyte_fm.utils.tokenizer import load_tokenizer
from electrolyte_fm.models.prediction_task_head import PredictionTaskHead
from electrolyte_fm.models.normalize import (
AbstractNormalizer,
Standardize,
PowerTransform,
IdentityTransform,
LogTransform,
MaxScaleTransform,
IdentityTransform,
PowerTransform,
Standardize,
)
from electrolyte_fm.models.pairwise_fusion import pairwise_fusion
from electrolyte_fm.models.physics_task_heads import (
VFTDecayTaskHead,
ArrtheniusActivation,
LinearExogenousEffect,
VFTDecayTaskHead,
)
from electrolyte_fm.models.polynomials import LagrangePolynomial
from electrolyte_fm.models.pairwise_fusion import pairwise_fusion
from electrolyte_fm.models.prediction_task_head import PredictionTaskHead
from electrolyte_fm.utils.ckpt import SaveConfigWithCkpts, get_ckpt_tokenizer
from electrolyte_fm.utils.tokenizer import load_tokenizer

from .migrate_legacy import load_legacy_packaged_checkpoint
from .utils import (
name_model,
get_best_ckpt,
create_save_directory,
ckpt_id,
save_tokenizer,
create_save_directory,
create_tar_gz,
get_best_ckpt,
name_model,
save_tokenizer,
)
from .migrate_legacy import load_legacy_packaged_checkpoint
from .write_model_class import write_modeling_module

cli = typer.Typer()
Expand Down Expand Up @@ -421,10 +424,9 @@ def mixtures(ckpt: Path, name: Optional[str] = None, safe: bool = True):


def export_multitask(encoder_ckpt: Path, task_ckpt: List[Path]) -> MISTMultiTask:
encoder_ckpt = maybe_best_ckpt(encoder_ckpt)

try:
# Try loading from training checkpoints
encoder_ckpt = maybe_best_ckpt(encoder_ckpt)
encoder = load_encoder(encoder_ckpt)
tokenizer = get_ckpt_tokenizer(encoder_ckpt)

Expand Down Expand Up @@ -452,6 +454,11 @@ def export_multitask(encoder_ckpt: Path, task_ckpt: List[Path]) -> MISTMultiTask

try:
encoder_model, _ = load_model(encoder_ckpt)

# If the loaded model is already a MISTMultiTask, return it directly
if isinstance(encoder_model, MISTMultiTask):
return encoder_model

encoder = (
encoder_model.encoder
if hasattr(encoder_model, "encoder")
Expand Down Expand Up @@ -532,7 +539,7 @@ def multitask(
architecture_name="MISTMultiTask",
config_class_name="MISTMultiTaskConfig",
model_class_name="MISTMultiTask",
)(save_dir)
)
L.info("Saved multitask model to %s", save_dir)
create_tar_gz(save_dir)

Expand Down
5 changes: 3 additions & 2 deletions opt/package/generate_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
"""

import argparse
import json
import re
import sys
from pathlib import Path
import json

import yaml
from jinja2 import Environment, FileSystemLoader, Template

Expand Down Expand Up @@ -132,7 +133,7 @@ def generate_model_card_for_directory(
parts = model_dir.name.split("-")
if len(parts) >= 2:
encoder_size = parts[1]
if encoder_size in ["26.9M", "27.0M", "28M"]:
if encoder_size in ["26.9M", "27.0M", "27.1M", "28M"]:
encoder_key = "mist-28M"
elif encoder_size == "1.8B":
encoder_key = "mist-1.8B"
Expand Down
30 changes: 19 additions & 11 deletions opt/package/migrate_legacy.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
import json
from pathlib import Path
from typing import Optional

from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModel,
RobertaPreLayerNormConfig,
RobertaPreLayerNormForMaskedLM,
)
from typing import Optional
from safetensors.torch import load_file

from electrolyte_fm.models import (
MISTFinetunedConfig,
MISTExcessPhysics,
MISTExcessPhysicsConfig,
MISTFinetuned,
MISTIonicConductivityConfig,
MISTFinetunedConfig,
MISTIonicConductivity,
MISTMultiTaskConfig,
MISTIonicConductivityConfig,
MISTMultiTask,
MISTExcessPhysicsConfig,
MISTExcessPhysics,
MISTMultiTaskConfig,
)

from electrolyte_fm.models.excess_physics_model import pairwise_fusion
from electrolyte_fm.models.normalize import AbstractNormalizer
from electrolyte_fm.models.physics_task_heads import (
VFTDecayTaskHead,
ArrtheniusActivation,
LinearExogenousEffect,
VFTDecayTaskHead,
)
from electrolyte_fm.models.polynomials import LagrangePolynomial
from electrolyte_fm.models.excess_physics_model import pairwise_fusion
from electrolyte_fm.models.prediction_task_head import PredictionTaskHead
from electrolyte_fm.models.normalize import AbstractNormalizer


def load_legacy_packaged_checkpoint(path: Path, model_class: Optional = None):
Expand Down Expand Up @@ -154,3 +154,11 @@ def load_legacy_packaged_checkpoint(path: Path, model_class: Optional = None):

model.load_state_dict(load_file(str(path / "model.safetensors")), strict=False)
return model


if __name__ == "__main___":
model = load_legacy_packaged_checkpoint(
"/nfs/turbo/coe-venkvis/abhutani/electrolyte-fm/solvent-properties",
model_class=MISTMultiTask,
)
print(str(model))
4 changes: 2 additions & 2 deletions opt/package/model_card_template.j2
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
language: en
library_name: transformers
license: gpl-3.0
license: apache-2.0
tags:
- mist
- chemistry
Expand All @@ -27,7 +27,7 @@ Consistent with prior works, the encoder hidden states were pooled by taking the

- **Developed by:** [Electrochemical Energy Group](https://eeg.engin.umich.edu/), University of Michigan, Ann Arbor.
- **Model type:** Self-supervised pre-trained MIST encoder with supervised finetuning.
- **License:** GPL 3.0 (GNU General Public License version 3)
- **License:** Apache 2.0
- **Finetuned from model:** [``{{ encoder.base_model }}``](https://huggingface.co/mist-models/{{ encoder.base_model }})

### Model Sources
Expand Down
25 changes: 25 additions & 0 deletions opt/package/model_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,28 @@ datasets:
metrics: "Mean Absolute Error (MAE) for Cv predictions"
split_type: "random"
num_samples: 133885

# Solvent property models
pka:
name: "Dissociation Constant (pKa) in DMSO"
source: "The raw dataset was taken from Acid Dissociation Constants in Selected Dipolar Non-Hydrogen-Bond-Donor Solvents (IUPAC Technical Report). The dataset was filtered to include only measurements which included a value for the measurement in DMSO. Duplicates were averaged over based on InChI keys and two measurements for mixtures were excluded."
task_type: "Regression"
task_description: "dissociation constant in DMSO"
output_format: "Regression prediction for pKa"
output_description: "Returns predictions for pKa in DMSO"
loss_function: "Mean Squared Error (MSE) Loss"
metrics: "Mean Absolute Error (MAE) for pKa predictions"
split_type: "random"
num_samples: 3254

etn:
name: "Normalized Dimroth-Reichardt Solvent Polarity Parameter (ETN)"
source: "The Dimroth-Reichardt ET dataset retrieved from the Stenutz solvent polarity table (https://www.stenutz.eu/chem/solv20.php?sort=3) was used. The dataset includes tabulated normalized ETN values for listed solvents and solvent-like media."
task_type: "Regression"
task_description: "normalized empirical solvent polarity parameter E_T^N"
output_format: "Regression prediction for ETN"
output_description: "Returns predictions for the normalized Dimroth-Reichardt solvent polarity parameter ETN"
loss_function: "Mean Squared Error (MSE) Loss"
metrics: "Mean Absolute Error (MAE) for ETN predictions"
split_type: "random"
num_samples: 379
16 changes: 16 additions & 0 deletions opt/package/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,19 @@ def test_encoder_config(model, request):
config = model.encoder.config
if hasattr(config, "add_pooling_layer"):
assert config.add_pooling_layer is False


def test_export_multitask_preserves_tasks():
    """Regression test: exporting an already-packaged MISTMultiTask must keep
    its task heads and normalization transforms intact rather than dropping
    them."""
    from opt.package.__main__ import export_multitask

    # With no extra task checkpoints, the exporter should detect the packaged
    # multitask model and hand it back with its state untouched.
    exported = export_multitask(MULTITASK_CKPT, task_ckpt=[])

    n_heads = len(exported.task_networks)
    n_transforms = len(exported.transforms)
    assert n_heads > 0, "task_networks should not be empty"
    assert n_transforms > 0, "transforms should not be empty"
    assert n_heads == n_transforms, "task_networks and transforms should align"
3 changes: 3 additions & 0 deletions opt/screening/activate
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ export HF_HOME="$GIT_ROOT/.cache/huggingface"
# Julia
export JULIA_CONDAPKG_BACKEND=Null
export JULIA_PYTHONCALL_EXE="${ROOT}/.venv/bin/python"
export JULIA_GLMAKIE_BACKEND="headless"
export LD_LIBRARY_PATH=""
export DISPLAY=0
Loading
Loading