Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b8db5ad
revise trends plots
anoushka2000 Mar 5, 2026
f5b810d
fewer func groups in trends plot
anoushka2000 Mar 5, 2026
e9430cc
olfaction exp validation plots
anoushka2000 Mar 5, 2026
fead611
better colors?
anoushka2000 Mar 6, 2026
9ae5953
fix paths
anoushka2000 Mar 17, 2026
b49fe6c
update smirk dep
anoushka2000 Mar 17, 2026
0f6b9b9
support huggingface checkpoints in OracleCritic
anoushka2000 Jan 29, 2026
f04baa7
fix logging in for ref fragments
anoushka2000 Feb 3, 2026
c8994c5
fix: vendor exec path
anoushka2000 Feb 9, 2026
e056848
feat: expose num passes arg for fragementation
anoushka2000 Feb 9, 2026
03fb6fa
fix: allow models without channel labels and toch compile doesn't wor…
anoushka2000 Feb 9, 2026
b98dc4d
fix: logging
anoushka2000 Feb 9, 2026
f46cbaf
power transforms don't work in fp16
anoushka2000 Feb 10, 2026
b5720f4
fix token resolution from production models
anoushka2000 Feb 24, 2026
dc5bc28
allow chkpts with channels stored as lists
anoushka2000 Feb 24, 2026
48411c1
pin smirk version
anoushka2000 Feb 24, 2026
bd03e8b
formatting
anoushka2000 Feb 24, 2026
204379f
add preprocessing for olfactory similarity dataset
anoushka2000 Dec 5, 2025
20f8192
add config for olfactory similarity task
anoushka2000 Dec 5, 2025
1ed66b4
model and data module to handle multiple mixtures
anoushka2000 Dec 5, 2025
917c9b2
fix docstrings
anoushka2000 Dec 5, 2025
140f7e6
feat: option to augment emb with Dragon features
anoushka2000 Jan 17, 2026
6d8b0cb
fix: mask direction
anoushka2000 Jan 17, 2026
bd13083
feat: get_encoder method for finetuned models to allow second ft stage
anoushka2000 Jan 17, 2026
07c4404
remove use of descriptors
anoushka2000 Jan 17, 2026
70f914d
use last tok emb for mols, don't enforce monotonicity, fix mask direc…
anoushka2000 Jan 17, 2026
81a32af
R2 loss and Manhattan Distance metric
anoushka2000 Jan 19, 2026
b0d521d
formatting
anoushka2000 Jan 19, 2026
ef39df3
add preprocessing for sarifard paper test set
anoushka2000 Jan 25, 2026
b0cbac8
seperate olfaction deps
anoushka2000 Jan 28, 2026
6a348c8
formatting
anoushka2000 Jan 28, 2026
bbb8d98
anonymize paths
anoushka2000 Jan 28, 2026
3187396
Merge branch 'master' into osmo-validation-plot
anoushka2000 Mar 23, 2026
e87eab5
combine olfaction opts
anoushka2000 Mar 30, 2026
33af0bb
switch to distinguishable colors and add triplet diagram
anoushka2000 Mar 30, 2026
f3c3a1d
finalize paper plots
anoushka2000 Mar 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion electrolyte_fm/data_modules/mixture_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ComponentDataModule(LightningDataModule):
def __init__(
self,
path: str,
target_columns: Union[str, List[str]],
target_columns: list | str,
n_components: int = 2,
tokenizer: Optional[str] = None,
batch_size: int = 64,
Expand Down
34 changes: 20 additions & 14 deletions opt/design/olfaction_trends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,13 @@ function plot_scent_by_type(df::DataFrame)
types = sort(unique(df.type))
scent_cols = get_scent_cols(df)
colors = MISTStyle.CAT_COLORS
linestyles = [
:solid, (:dash, :dense), (:dot, :dense),
:dashdot, :dashdotdot, (:dot, :loose), (:dash, :loose)
]
markers = [:circle, :rect, :diamond, :utriangle, :dtriangle, :pentagon, :cross, :xcross]

scent_colors = Dict(s => colors[mod1(i, length(colors))] for (i, s) in enumerate(scent_cols))
scent_linestyles = Dict(s => linestyles[mod1(i, length(linestyles))] for (i, s) in enumerate(scent_cols))
scent_markers = Dict(s => markers[mod1(i, length(markers))] for (i, s) in enumerate(scent_cols))

fig = Figure(size=(250mm, 120mm), figure_padding=(3, 3, 3, 3))
n_cols = ceil(Int, length(types) / 2)
fig = Figure(size=(175mm, 170mm), figure_padding=(3, 3, 3, 3))
n_cols = ceil(Int, length(types) / 4)

for (i, type) in enumerate(types)
df_type = filter(row -> row.type == type, df)
Expand Down Expand Up @@ -99,17 +96,18 @@ function plot_scent_by_type(df::DataFrame)
sigmoid.(y_mean .+ y_std),
color=(scent_colors[scent], 0.2)
)
lines!(
scatterlines!(
ax, df_type.n_carbon, y_prob,
color=scent_colors[scent],
linestyle=scent_linestyles[scent]
marker=scent_markers[scent],
markersize=5
)

push!(
plotted_elements,
LineElement(
MarkerElement(
color=scent_colors[scent],
linestyle=scent_linestyles[scent]
marker=scent_markers[scent],
)
)
push!(plotted_labels, scent)
Expand Down Expand Up @@ -139,16 +137,18 @@ function plot_branched_scent_comparison(df::DataFrame)
max_variants = maximum(nrow(filter(row -> row.type == type, df)) for type in types)
scent_colors = [colors[mod1(i, length(colors))] for i in eachindex(scents_to_plot)]

fig = Figure(size=(175mm, 140mm), figure_padding=(5, 5, 5, 5))
fig = Figure(size=(175mm, 120mm), figure_padding=(5, 5, 5, 5))
gl = GridLayout(fig[1, 1])

for (i, type) in enumerate(types)
df_type = filter(row -> row.type == type, df)
row_pos, col_pos = grid_position(i, 2)

ax = Axis(
fig[row_pos, col_pos],
gl[row_pos, col_pos],
ylabel="Probability",
title=type,
titlesize=6,
xticklabelrotation=π/3,
xticklabelalign=(:right, :center),
limits=(0.5, max_variants + 0.5, 0, 1.3),
Expand Down Expand Up @@ -180,8 +180,14 @@ function plot_branched_scent_comparison(df::DataFrame)
end
end

sublabels = ["a", "b", "c", "d", "e", "f"]
for (i, label) in enumerate(sublabels[1:length(types)])
row_pos, col_pos = grid_position(i, 2)
sublabel!(gl[row_pos, col_pos, TopLeft()], label; left=15pt)
end

Legend(
fig[3, :],
fig[2, 1],
[PolyElement(color=c) for c in scent_colors],
scents_to_plot,
orientation=:horizontal,
Expand Down
68 changes: 37 additions & 31 deletions opt/design/src/hydrocarbons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,27 @@ tert_alkyl(n::Int) = n < 4 ? "C"^n : "C(C)(C)" * "C"^(n - 3)

function branched_esters(n_total::Int)
esters = []
# Generate esters with different chain length splits
for n_acyl in 1:(n_total - 1)
n_alkoxy = n_total - n_acyl
# Linear acyl, iso alkoxy (if n_alkoxy >= 3)

push!(esters, (; smi="O=C(O" * "C"^n_alkoxy * ")" * "C"^(n_acyl - 1),
branch_pattern="linear"))

if n_alkoxy >= 3
push!(esters, (; smi="O=C(O" * iso_alkyl(n_alkoxy) * ")" * "C"^(n_acyl - 1),
branch_pattern="iso-alkoxy"))
end
# Iso acyl, linear alkoxy (n_acyl - 1 >= 3, so n_acyl >= 4)

if n_acyl >= 4
push!(esters, (; smi="O=C(O" * "C"^n_alkoxy * ")" * iso_alkyl(n_acyl - 1),
branch_pattern="iso-acyl"))
end
# Tert acyl, linear alkoxy (n_acyl - 1 >= 4, so n_acyl >= 5)

if n_acyl >= 5
push!(esters, (; smi="O=C(O" * "C"^n_alkoxy * ")" * tert_alkyl(n_acyl - 1),
branch_pattern="tert-acyl"))
end
# Both iso (n_acyl - 1 >= 3 and n_alkoxy >= 3, so n_acyl >= 4 and n_alkoxy >= 3)

if n_acyl >= 4 && n_alkoxy >= 3
push!(esters, (; smi="O=C(O" * iso_alkyl(n_alkoxy) * ")" * iso_alkyl(n_acyl - 1),
branch_pattern="iso-both"))
Expand All @@ -51,30 +53,32 @@ end

function branched_ethers(n_total::Int)
ethers = []
# Generate ethers with different chain length splits
for n_left in 1:(n_total - 1)
n_right = n_total - n_left
# Linear left, iso right (if n_right >= 3)

push!(ethers, (; smi="C"^n_left * "O" * "C"^n_right,
branch_pattern="linear"))

if n_right >= 3
push!(ethers, (; smi="C"^n_left * "O" * iso_alkyl(n_right),
branch_pattern="iso-right"))
end
# Iso left, linear right (if n_left >= 3)

if n_left >= 3
push!(ethers, (; smi=iso_alkyl(n_left) * "O" * "C"^n_right,
branch_pattern="iso-left"))
end
# Tert left, linear right (if n_left >= 4)

if n_left >= 4
push!(ethers, (; smi=tert_alkyl(n_left) * "O" * "C"^n_right,
branch_pattern="tert-left"))
end
# Linear left, tert right (if n_right >= 4)

if n_right >= 4
push!(ethers, (; smi="C"^n_left * "O" * tert_alkyl(n_right),
branch_pattern="tert-right"))
end
# Both iso (if both sides >= 3)

if n_left >= 3 && n_right >= 3
push!(ethers, (; smi=iso_alkyl(n_left) * "O" * iso_alkyl(n_right),
branch_pattern="iso-both"))
Expand All @@ -85,29 +89,30 @@ end

function branched_alcohols(n_total::Int)
alcohols = []
# Secondary alcohols - OH on carbon with iso branching (if n >= 3)

push!(alcohols, (; smi="O" * "C"^n_total,
branch_pattern="primary-linear"))

if n_total >= 3
push!(alcohols, (; smi="O" * iso_alkyl(n_total),
branch_pattern="secondary-at-O-iso"))
end
# Tertiary alcohols - OH on carbon with tert branching (if n >= 4)

if n_total >= 4
push!(alcohols, (; smi="O" * tert_alkyl(n_total),
branch_pattern="tertiary-at-O-tert"))
end
# Secondary alcohols (OH on internal carbon, linear)

for pos in 2:(n_total - 1)
# Linear secondary alcohol
push!(alcohols, (; smi="C"^(pos - 1) * "C(O)" * "C"^(n_total - pos),
branch_pattern="secondary-linear"))

# Tertiary alcohol with branching at OH position (both sides must exist)
if pos >= 2 && pos <= n_total - 2
push!(alcohols, (; smi="C"^(pos - 1) * "C(O)(C)" * "C"^(n_total - pos - 1),
branch_pattern="tertiary-at-O"))
end
end
# Tertiary alcohols (3 distinct alkyl groups)

if n_total >= 4
for n_side1 in 1:(n_total - 3)
for n_side2 in 1:(n_total - n_side1 - 2)
Expand All @@ -122,29 +127,30 @@ end

function branched_thiols(n_total::Int)
thiols = []
# Secondary thiols - SH on carbon with iso branching (if n >= 3)

push!(thiols, (; smi="S" * "C"^n_total,
branch_pattern="primary-linear"))

if n_total >= 3
push!(thiols, (; smi="S" * iso_alkyl(n_total),
branch_pattern="secondary-at-S-iso"))
end
# Tertiary thiols - SH on carbon with tert branching (if n >= 4)

if n_total >= 4
push!(thiols, (; smi="S" * tert_alkyl(n_total),
branch_pattern="tertiary-at-S-tert"))
end
# Secondary thiols (SH on internal carbon, linear)

for pos in 2:(n_total - 1)
# Linear secondary thiol
push!(thiols, (; smi="C"^(pos - 1) * "C(S)" * "C"^(n_total - pos),
branch_pattern="secondary-linear"))

# Tertiary thiol with branching at SH position (both sides must exist)
if pos >= 2 && pos <= n_total - 2
push!(thiols, (; smi="C"^(pos - 1) * "C(S)(C)" * "C"^(n_total - pos - 1),
branch_pattern="tertiary-at-S"))
end
end
# Tertiary thiols (3 distinct alkyl groups)

if n_total >= 4
for n_side1 in 1:(n_total - 3)
for n_side2 in 1:(n_total - n_side1 - 2)
Expand Down Expand Up @@ -205,15 +211,15 @@ end

function fragrance_compounds(n::Int)
df = DataFrame(vcat(
[(; type="Alkanes", smi=alkane(n)) for n in 1:n],
# [(; type="Alkanes", smi=alkane(n)) for n in 1:n],
[(; type="Esters", smi=ester(n)) for n in 1:n],
[(; type="Ethers", smi=ether(n)) for n in 3:n],
[(; type="Alcohols", smi=alcohol(n)) for n in 1:n],
[(; type="Aldehydes", smi=aldehyde(n)) for n in 1:n],
[(; type="Carboxylic acids", smi=carboxylic_acid(n)) for n in 2:n],
[(; type="Alkenes", smi=alkene(n)) for n in 2:n],
[(; type="Alkynes", smi=alkyne(n)) for n in 2:n],
[(; type="Arenes", smi=arene(n)) for n in 0:n],
# [(; type="Alkenes", smi=alkene(n)) for n in 2:n],
# [(; type="Alkynes", smi=alkyne(n)) for n in 2:n],
# [(; type="Arenes", smi=arene(n)) for n in 0:n],
[(; type="Thiol", smi=thiol(n)) for n in 1:n],

))
Expand All @@ -223,10 +229,10 @@ end

function branched_fragrance_compounds(n::Int)
df = DataFrame(vcat(
[(; type="Branched Esters", row...) for row in branched_esters(n)],
[(; type="Branched Ethers", row...) for row in branched_ethers(n)],
[(; type="Branched Alcohols", row...) for row in branched_alcohols(n)],
[(; type="Branched Thiols", row...) for row in branched_thiols(n)],
[(; type="Esters", row...) for row in branched_esters(n)],
[(; type="Ethers", row...) for row in branched_ethers(n)],
[(; type="Alcohols", row...) for row in branched_alcohols(n)],
[(; type="Thiols", row...) for row in branched_thiols(n)],
))
n_carbon!(df)
return df
Expand Down
Loading
Loading