1 change: 1 addition & 0 deletions opt/BayesianScaling/.gitignore
@@ -1,3 +1,4 @@
# Ignore output directory (or symlink to it)
out/
out
fig/
5 changes: 4 additions & 1 deletion opt/BayesianScaling/plots/Project.toml
@@ -10,5 +10,8 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[sources]
BayesianScaling = {path = "../"}
BayesianScaling = {path = ".."}
MISTStyle = {path = "../../MISTStyle"}

[compat]
PrettyTables = "^2.0.0"
61 changes: 41 additions & 20 deletions opt/BayesianScaling/plots/paper.jl
100644 → 100755
@@ -1,3 +1,4 @@
#!/usr/bin/env -S uv run julia --project=@script --startup-file=no
using Makie
using MISTStyle
using DataFrames
@@ -6,6 +7,12 @@ using JLD2: jldopen
using BayesianScaling
using BayesianScaling: pf_day, get_nbins

const AVG_SEQ_LENGTH::Float64 = 65.2
"""
Average number of tokens per molecule for the Smirk Tokenizer on REALSpace
Source: SI for doi:10.1021/acs.jcim.5c01856
"""

function mark_model!(ax_scale, ax_compute;
model_loss,
d_model,
@@ -15,11 +22,12 @@ function mark_model!(ax_scale, ax_compute;
eff_batch_size,
colorbar,
h_loss,
avg_seq_length = 1,
kwargs...
)
N = BayesianScaling.non_embedding_size(d_model, ff_ratio * d_model, n_layers)
D = steps * eff_batch_size
C = 6 * float(N) * float(D)
C = 6 * float(N) * float(D) * float(avg_seq_length)
kwargs = (;
marker=:star5,
color=model_loss,
@@ -102,11 +110,11 @@ function plot_scaling_params!(f, chains; legend_pos=:top)
)

prior_works = [
"Kaplan et al." => (; α=0.076, β=0.103, a = 0.58, b = 0.42),
"Hoffmann et al. (Appr. 1)" => (; a = 0.50, b = 0.50),
"Hoffmann et al. (Appr. 3)" => (; α=0.34, β=0.28, a = 0.46, b = 0.54),
"Bi et al. (Early)" => (; a = 0.450, b = 0.550),
"Bi et al. (Current)" => (; a = 0.524, b = 0.478),
"Kaplan et al." => (; α=0.076, β=0.103, a = 0.58, b = 0.42),
]

gl = GridLayout(f[1, 1])
@@ -167,12 +175,14 @@ function plot_scaling_params!(f, chains; legend_pos=:top)
Legend(legend_pos, elem, MISTStyle.label.(elem);
tellwidth,
tellheight,
margin=2pt .* (1, 1, 1, 1),
orientation,
nbanks,
halign,
valign,
padding=(4pt, 4pt, 1pt, 1pt),
colgap=4pt,
)
rowgap!(gl, 2pt)

return f
end
@@ -320,10 +330,11 @@ function plot_lr_partial_dependence!(f, model, chains; p=0.95, npoints=100)
end

function figure_bayesian(data;
N=logrange(1e5, 3e9; length=50),
C=logrange(1e12, 2*pf_day; length=50),
N=logrange(1e5, 4e9; length=50),
C=logrange(1e-5 * pf_day, 100*pf_day; length=50),
p=0.95,
n_samples=500,
avg_seq_length::AbstractFloat = 1.0,
)
model = data["model"]
chains = data["chains"]
@@ -333,7 +344,7 @@

model_loss = StatsBase.response(model)
model_size = df.model_size
model_flops = @. 6 * float(df.model_size) * float(df.data_size)
model_flops = @. 6 * float(df.model_size) * float(df.data_size) * float(avg_seq_length)

f_height = model.formula isa BayesianScaling.ShapedScaling ? 4.5inch : 3inch
f = Figure(;
@@ -358,13 +369,13 @@
ylabel="Non-Embedding Parameters",
)

loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C; p)
loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C ./ avg_seq_length; p)
h_opt = predictionband!(ax_compute, C ./ pf_day, loss_c_opt...;
color=MISTStyle.UM_COLORS.maize,
linewidth=2pt,
band_color=(RGBf(0.596, 0.612, 0.592), 0.4),
)
loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C))'
loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C ./ avg_seq_length))'
levels = logrange(minimum(loss_cs), 1.0; length=10)
cb = Colorbar(gl[1, end+1];
label="Validation Loss (nats)",
@@ -388,7 +399,7 @@
)


N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C)
N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C ./ avg_seq_length)
predictionband!(ax_scale, C ./ pf_day, N_opt...;
color=h_opt.color,
linewidth=h_opt.linewidth,
@@ -414,6 +425,7 @@
model_loss=0.01019080262631178,
colorbar=cb,
h_loss,
avg_seq_length,
)

# dh61satt
@@ -426,6 +438,7 @@
model_loss=0.03703538700938225,
colorbar=cb,
h_loss,
avg_seq_length,
)


@@ -451,10 +464,11 @@
end

function figure_bayesian_panel(data;
N=logrange(1e5, 3e9; length=50),
C=logrange(1e12, 2*pf_day; length=50),
N=logrange(1e5, 4e9; length=50),
C=logrange(1e-5 * pf_day, 100*pf_day; length=50),
p=0.95,
n_samples=500,
avg_seq_length::AbstractFloat = 1,
)
f = Figure(; size=(89mm, 170mm), figure_padding=(2, 2, 2, 4))

@@ -466,7 +480,7 @@

model_loss = StatsBase.response(model)
model_size = df.model_size
model_flops = @. 6 * float(df.model_size) * float(df.data_size)
model_flops = @. 6 * float(df.model_size) * float(df.data_size) * float(avg_seq_length)

# Plot Covariance
colsize!(f.layout, 1, 100mm)
@@ -497,13 +511,13 @@
ylabel="Non-Embedding Parameters",
)

loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C; p)
loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C ./ avg_seq_length; p)
h_opt = predictionband!(ax_compute, C ./ pf_day, loss_c_opt...;
color=MISTStyle.UM_COLORS.maize,
linewidth=2pt,
band_color=(RGBf(0.596, 0.612, 0.592), 0.4),
)
loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C))'
loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C ./ avg_seq_length))'
levels = logrange(minimum(loss_cs), 1.0; length=10)
cb = Colorbar(gl[1, end+1];
label="Validation Loss (nats)",
@@ -527,7 +541,7 @@
)


N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C)
N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C ./ avg_seq_length)
predictionband!(ax_scale, C ./ pf_day, N_opt...;
color=h_opt.color,
linewidth=h_opt.linewidth,
@@ -553,6 +567,7 @@
model_loss=0.01019080262631178,
colorbar=cb,
h_loss,
avg_seq_length,
)

# dh61satt
@@ -565,6 +580,7 @@
model_loss=0.03703538700938225,
colorbar=cb,
h_loss,
avg_seq_length,
)

# Fit distributions
Expand All @@ -581,14 +597,15 @@ function figure_bayesian_panel(data;
return f
end

function plot_model(model; kwargs...)
function plot_model(model; avg_seq_length, kwargs...)
@info "plotting $model"
run_name = basename(model)
data = jldopen(joinpath(model, "chains.jld2"), "r")
try
suffix = avg_seq_length != 1 ? "-token-cost" : ""
with_theme(MISTStyle.theme()) do
figure_bayesian(data; kwargs...)
end |> MISTStyle.savefig("bayesian-$run_name")
figure_bayesian(data; avg_seq_length, kwargs...)
end |> MISTStyle.savefig("bayesian-$(run_name)$(suffix)")

with_theme(MISTStyle.theme()) do
plot_prediction_intervals(data["model"], data["chains"])
@@ -604,7 +621,8 @@
function plot_all(dir; kwargs...)
for model in readdir(dir; join=true)
isfile(joinpath(model, "chains.jld2")) || continue
plot_model(model)
plot_model(model; avg_seq_length=1.0, kwargs...)
plot_model(model; avg_seq_length=AVG_SEQ_LENGTH, kwargs...)
end
end

@@ -618,8 +636,11 @@ function (@main)(ARGS=[])
# Panel figure
data = jldopen(joinpath(chains_dir, "dec-3-sweep-smoothed-1.0--geometric-shape-gamma", "chains.jld2"), "r")
with_theme(MISTStyle.theme()) do
figure_bayesian_panel(data)
figure_bayesian_panel(data; avg_seq_length=1.0)
end |> MISTStyle.savefig("scaling_panel_baseline")
with_theme(MISTStyle.theme()) do
figure_bayesian_panel(data; avg_seq_length=AVG_SEQ_LENGTH)
end |> MISTStyle.savefig("scaling_panel_baseline-token-cost")

return nothing
end
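
Note on the compute accounting above: the fit's data axis D counts molecules (training steps × effective batch size), so the familiar C ≈ 6·N·D estimate is per molecule. Multiplying by AVG_SEQ_LENGTH = 65.2 converts it to token-level FLOPs, and a token-denominated budget is divided by AVG_SEQ_LENGTH again before querying the fitted compute-optimal curves. The lines below are a minimal sketch of that arithmetic, assuming pf_day is one petaFLOP/s-day expressed in FLOPs; the helper names are illustrative, not BayesianScaling API.

# Minimal sketch of the token-based compute accounting (illustrative helpers).
const AVG_SEQ_LENGTH = 65.2        # tokens per molecule (Smirk tokenizer on REALSpace)
const pf_day = 1e15 * 86_400       # assumed: one petaFLOP/s-day in FLOPs

# FLOPs for a run with N non-embedding parameters and D molecules seen.
token_flops(N, D; avg_seq_length=AVG_SEQ_LENGTH) =
    6 * float(N) * float(D) * float(avg_seq_length)

# The scaling fit is parameterised in molecules, so a token-denominated budget C
# is mapped back to the fit's units before querying the compute-optimal curves.
to_fit_units(C; avg_seq_length=AVG_SEQ_LENGTH) = C ./ avg_seq_length

# Example: a 10M-parameter model trained on 2M molecules.
C = token_flops(1e7, 2e6)
println(round(C / pf_day; sigdigits=3), " PF-days")   # ≈ 9.06e-5 PF-days
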
76 changes: 58 additions & 18 deletions opt/BayesianScaling/plots/param_table.jl
@@ -1,7 +1,9 @@
#!/usr/bin/env -S julia +release --color=auto --startup-file=no --project=@script
using PrettyTables
using DataFrames
using BayesianScaling: BayesianScaling, find, selectparam
using BayesianScaling: BayesianScaling, selectparam
using StatsBase: response, cor, corspearman
using Format: format
using JLD2: jldopen
using MISTStyle
using Makie
@@ -69,30 +71,29 @@ function model_summary(files)
return DataFrame(rows)
end

function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty)

named = Dict(
("Penalized", :model_size, true, false) => "Penalized, Baseline",
("Penalized", :d_model, true, false) => "Penalized, LR Scales with Hidden Size",
("Penalized", :model_size, false, true) => "Penalized, Additive",
("Penalized", :model_size, true, true) => "Penalized, Harmonic Shape Penalty",
("Chinchilla", nothing, nothing, nothing) => "No Penalties",
)
return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing)
end

function write_scaling_param_table(df::DataFrame; sigdigits=3)
df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(LatexCell∘model_name) => :model_name)
function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty)
named = Dict(
("Penalized", :model_size, true, false) => "Penalized, Baseline",
("Penalized", :d_model, true, false) => "Penalized, \\(\\eta_{\\star} = f(d_{\\text{model}})\\)",
("Penalized", :model_size, true, true) => "\\stacked{Penalized, Harmonic}{Shape Penalty}",
("Penalized", :model_size, false, true) => "Penalized, Additive",
("Chinchilla", nothing, nothing, nothing) => "No Penalties",
)
return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing)
end
df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(model_name) => :model_name)
subset!(df, :model_name => ByRow(!isnothing))
sort!(df, :waic)
df.model_name = LatexCell.(df.model_name)
round_sigfigs = ByRow(x -> round(x ; sigdigits))
cols = [
:model_name => "Model",
:waic => latex_cell"\acs{WAIC}",
:mape =>latex_cell"\acs{MAPE}",
# :aic =>latex_cell"\acs{AIC}",
:pearson =>latex_cell"Pearson's \(\rho\)",
:spearman =>latex_cell"Spearman's \(\rho\)",
:pearson =>latex_cell"\makecell{Pearson's\\\(r\)}",
:spearman =>latex_cell"\makecell{Spearman's\\\(\rho\)}",
:A =>latex_cell"$A$",
:B =>latex_cell"$B$",
:α =>latex_cell"$\alpha$",
@@ -121,7 +122,7 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3)
l = sn(l; sigdigits)
u = sn(u; sigdigits)
end
return LatexCell("\\ci{$μ}{$l}{$u}")
return LatexCell("\\cistacked{$μ}{$l}{$u}")
end
function fmt_metric(v, i, j)
metrics = [:waic, :mape, :aic, :pearson, :spearman]
@@ -136,7 +137,7 @@
return format("{}", v)
end
end
return pretty_table(df[:, first.(cols)];
return pretty_table(String, df[:, first.(cols)];
tf,
alignment=[col in [:model_name] ? :l : :c for col in first.(cols)],
formatters=(fmt_ci, fmt_metric),
@@ -194,6 +195,44 @@ function figure_smoothed_scaling(df)
return f
end

function summary_rows(df; sigdigits=3)
df = sort(df, :n_smooth)
function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty)
named = Dict(
("Penalized", :model_size, true, false) => "MIST, Penalized Scaling \\cref{eq:penalized_neural_scaling}",
("Chinchilla", nothing, nothing, nothing) => "MIST, No Penalty Terms",
)
return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing)
end
transform!(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(model_name) => :model_name)
df = subset(df,
:model_name => ByRow(!isnothing),
:file => ByRow(endswith("gamma")),
:n_smooth => ByRow(==(1)),
)

function format_ci(v)
μ, l, u = v
μ = round(μ; sigdigits)
l = round(l; sigdigits)
u = round(u; sigdigits)
return "\\ci{$μ}{$l}{$u}"
end
open(joinpath("fig", "scaling_summary_rows.tex"), "w") do io
for model in eachrow(df)
row = [
model.model_name,
format_ci(model.α),
format_ci(model.β),
format_ci(model.a),
format_ci(model.b),
format_ci(model.r)
]
println(io, join(row, " & ") * " \\\\ % $(model.file)")
end
end
return nothing
end

function (@main)(ARGS)
df = model_summary(BayesianScaling.find("out", r"dec-3-.*chains.jld2"))
@@ -202,4 +241,5 @@
open(joinpath("fig", "param_table.tex"), "w") do io
write(io, table)
end
summary_rows(df)
end
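
A side note on the pretty_table change above, as a hedged sketch rather than the script's exact call: passing String as the first argument makes PrettyTables return the rendered table as a string instead of printing it, which is what lets (@main) write the result to fig/param_table.tex, and it is consistent with the new PrettyTables = "^2.0.0" compat bound in plots/Project.toml. The DataFrame columns and keyword values below are placeholders.

# Hedged sketch: render a DataFrame to a LaTeX string and write it to a file.
# The script passes its own `tf`, alignment, and formatters; these are stand-ins.
using PrettyTables, DataFrames

df = DataFrame(model = ["Penalized, Baseline", "No Penalties"],
               waic  = [0.12, 0.18])        # placeholder values

table = pretty_table(String, df;            # `String` sink: return the rendered table
    backend = Val(:latex),                  # LaTeX back-end (PrettyTables 2.x)
    alignment = [:l, :c],
)

open(joinpath("fig", "param_table.tex"), "w") do io
    write(io, table)
end
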
2 changes: 1 addition & 1 deletion opt/BayesianScaling/scripts/Project.toml
@@ -13,4 +13,4 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[sources]
BayesianScaling = {path = "../"}
BayesianScaling = {path = ".."}