From 3b3ab919fa26c64ae95d1282883ef9ba11da34a5 Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Fri, 27 Mar 2026 11:07:50 -0600 Subject: [PATCH 01/11] wip --- opt/BayesianScaling/plots/Project.toml | 6 ++- opt/BayesianScaling/plots/campaign.jl | 4 ++ opt/BayesianScaling/scripts/Project.toml | 2 +- opt/BayesianScaling/scripts/dec_3_sweep.jl | 49 +++++++++++++------ opt/BayesianScaling/scripts/wandb_import.jl | 46 +++++++++++++---- .../python/helper/atomic_oov.py | 22 +++++++-- opt/run_logs/sync_wandb.py | 18 ++++--- 7 files changed, 112 insertions(+), 35 deletions(-) diff --git a/opt/BayesianScaling/plots/Project.toml b/opt/BayesianScaling/plots/Project.toml index 62b8ddc0..a0a7f84e 100644 --- a/opt/BayesianScaling/plots/Project.toml +++ b/opt/BayesianScaling/plots/Project.toml @@ -10,5 +10,9 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [sources] -BayesianScaling = {path = "../"} +BayesianScaling = {path = ".."} MISTStyle = {path = "../../MISTStyle"} + +[compat] +PrettyTables = "^2.0.0" + diff --git a/opt/BayesianScaling/plots/campaign.jl b/opt/BayesianScaling/plots/campaign.jl index fdf0af66..c3275600 100755 --- a/opt/BayesianScaling/plots/campaign.jl +++ b/opt/BayesianScaling/plots/campaign.jl @@ -56,6 +56,10 @@ function format_table(df; sigdigits=3) ) end +function estimate_brute_force(df::DataFrame) + +end + function (@main)(ARGS=[]) data = jldopen(ARGS[1], "r") model = data["model"] diff --git a/opt/BayesianScaling/scripts/Project.toml b/opt/BayesianScaling/scripts/Project.toml index 8d9d7811..e65d27a9 100644 --- a/opt/BayesianScaling/scripts/Project.toml +++ b/opt/BayesianScaling/scripts/Project.toml @@ -13,4 +13,4 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [sources] -BayesianScaling = {path = "../"} +BayesianScaling = {path = ".."} diff --git a/opt/BayesianScaling/scripts/dec_3_sweep.jl b/opt/BayesianScaling/scripts/dec_3_sweep.jl index 88c61958..c9f5f033 100755 --- a/opt/BayesianScaling/scripts/dec_3_sweep.jl +++ b/opt/BayesianScaling/scripts/dec_3_sweep.jl @@ -10,7 +10,7 @@ using Setfield: @set! 
GIT_ROOT = readchomp(`git rev-parse --show-toplevel`) -include("wandb_import.jl") +# include("wandb_import.jl") function prior_d_model_lr!(priors; lr_0=1.64e-4, d_model=768, batch_size=1024) # LR scaling with Model size used in Attention is All You Need @@ -22,7 +22,23 @@ function prior_d_model_lr!(priors; lr_0=1.64e-4, d_model=768, batch_size=1024) return priors end -function build_models() +function dec_3_sweep_runs() + df = pretraining_runs( + joinpath(GIT_ROOT, ".cache", "wandb-export"); + ) + subset!(df, + :tokenizer => ByRow(==("smirk")), + :created => ByRow(<=(DateTime(2025, 1))), + :tags => ByRow(tags -> "dec-3-sweep" in tags), + [:step, :max_steps] => ByRow((s, ms) -> s / ms > 0.8); + skipmissing=true, + ) + dropmissing!(df, [:model_size, :d_model, :effective_batch_size, :max_steps, :lr, :ff_ratio, :aspect_ratio, :kv_size]) + subset!(df, :val_loss_best => ByRow(x -> 1e-6 < x < 1.0); skipmissing=true) + return df +end + +function build_models(df_sweep) formulas = Dict( "baseline" => ShapedScaling(), "hoffman" => HoffmanScaling(), @@ -34,17 +50,7 @@ function build_models() models = Dict() for eval_batch in [1, 1e3, 1e4, 1e5, 1e6] - df = pretraining_runs( - joinpath(GIT_ROOT, ".cache", "wandb-export"); - smoothed_eval_batch=eval_batch - ) - subset!(df, - :tokenizer => ByRow(==("smirk")), - :created => ByRow(<=(DateTime(2025, 1))), - :tags => ByRow(tags -> "dec-3-sweep" in tags), - [:step, :max_steps] => ByRow((s, ms) -> s / ms > 0.8); - skipmissing=true, - ) + df = deepcopy(df_sweep) loss = eval_batch == 1 ? :val_loss_best : :val_loss_smooth df = select(df, :model_size, @@ -164,7 +170,22 @@ function fit_summary() return df end +function export_runs() + df_sweep = dec_3_sweep_runs() + outdir = joinpath(pkgdir(BayesianScaling), "out") + open(joinpath(outdir, "dec_3_sweep_runs.jsonl"), "w") do fid + for run in eachrow(df_sweep) + println(fid, JSON.json(Dict(pairs(run)))) + end + end + return df_sweep +end + function (@main)(args) - models = build_models() + # Record runs + df_sweep = export_runs() + + # Build models + models = build_models(df_sweep) run_models(models) end diff --git a/opt/BayesianScaling/scripts/wandb_import.jl b/opt/BayesianScaling/scripts/wandb_import.jl index da2d9006..ae3d52d3 100644 --- a/opt/BayesianScaling/scripts/wandb_import.jl +++ b/opt/BayesianScaling/scripts/wandb_import.jl @@ -3,10 +3,36 @@ using Dates: Dates, DateTime using JSON: JSON using BayesianScaling: find, average_last +function dedup_trace(metric_traces, name::AbstractString) + steps = metric_traces["step"] + y = metric_traces[name] + trace = Dict{eltype(steps), eltype(y)}() + fn = x -> (!isnothing(x[2]) && !ismissing(x[2])) + for (xt, yt) in Iterators.filter(fn, zip(steps, y)) + trace[xt] = yt isa Real ? yt : parse(Float64, yt) + end + x_out::Vector{Int} = (sort∘collect∘keys)(trace) + y_out::Vector{Float64} = [trace[x] for x in x_out] + return x_out, y_out +end + +function get_metric(metrics::Vector; kwargs...) + df = DataFrame(metrics) + kwargs = map(kv -> kv[1] => ByRow(==(kv[2])), collect(kwargs)) + subset!(df, kwargs...) 
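+    # At most one row should match the (metric, split, type) filters: missing if absent, error on duplicates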
+    if nrow(df) == 1
+        return first(df.value)
+    elseif nrow(df) == 0
+        return missing
+    else
+        error("Multiple matches: $df")
+    end
+end
+
 function pretraining_runs(dir::AbstractString; smoothed_eval_batch=1e6)
     row = []
     for file in find(joinpath(dir, "pretraining"), r".*\.json")
-        run = JSON.parsefile(file; null=missing)
+        run = JSON.parsefile(file; null=missing, allownan=true)
 
         pretrained_models = [
             "electrolyte_fm.models.RoBERTa",
@@ -25,13 +51,17 @@
         beta2 = !ismissing(beta) ? beta[2] : missing
 
         # Average over the last 1e6 samples
-        val_epoch_loss = map(run["metric_traces"]["val_loss"]) do x
-            x isa Real ? x : parse(Float64, x)
-        end
+        _, val_epoch_loss = dedup_trace(run["metric_traces"], "val_loss")
         effective_val_epoch = run["trainer"]["effective_val_epoch"]
         ismissing(effective_val_epoch) && continue
         val_loss_smooth = average_last(val_epoch_loss; n=smoothed_eval_batch / effective_val_epoch)
 
+        # Get last/best metrics
+        val_loss_best=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="best")
+        val_loss_min=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="min")
+        val_loss_last=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="last")
+        val_loss_best = ismissing(val_loss_best) ? val_loss_min : val_loss_best
+
         push!(row, (;
             id=run["id"],
             state=run["state"],
@@ -49,17 +79,15 @@
             optimizer=run["optimizer"]["class_path"],
             tokenizer=run["data"]["tokenizer"],
             max_steps=run["trainer"]["num_training_steps"],
-            step=run["trainer"]["step"],
-            tokens=run["trainer"]["tokens"],
+            step=get_metric(run["metrics"]; metric="global_step", type="last"),
             gas=run["trainer"]["gas"],
-            masked_tokens=run["trainer"]["masked_tokens"],
             effective_batch_size=run["trainer"]["effective_batch_size"],
             effective_val_epoch,
             lr=run["optimizer"]["lr"],
             beta1,
             beta2,
-            val_loss_best=run["metrics"]["val_loss_best"],
-            val_loss_last=run["metrics"]["val_loss_last"],
+            val_loss_best,
+            val_loss_last,
             val_loss_smooth
         ))
     end
diff --git a/opt/TokenizerStats/python/helper/atomic_oov.py b/opt/TokenizerStats/python/helper/atomic_oov.py
index e8969145..f4be16a4 100644
--- a/opt/TokenizerStats/python/helper/atomic_oov.py
+++ b/opt/TokenizerStats/python/helper/atomic_oov.py
@@ -332,10 +332,19 @@ def tabulate_tokenizer(
     }
 
 
-def process_tokenizer(dataset_name: str, tokenizer: dict[str, str], args) -> dict:
+def process_tokenizer(
+    dataset_name: str,
+    tokenizer: dict[str, str],
+    tmqm_dataset_path: Optional[Path] = None,
+) -> dict:
+    """Process a single dataset for a single tokenizer.
+
+    NOTE: This function is executed in child processes when --workers is set.
+    Keep its signature/payload picklable.
+    """
     tok = load_tokenizer(tokenizer["name_or_path"])
-    if args.tmqm_dataset is not None:
-        DATASETS["tmQM"] = lambda: tmqm_dataset(args.tmqm_dataset)
+    if tmqm_dataset_path is not None:
+        DATASETS["tmQM"] = lambda: tmqm_dataset(tmqm_dataset_path)
 
     dataset = DATASETS[dataset_name]
     LOG.info("processing %s for %s", dataset_name, tokenizer["name"])
@@ -382,7 +391,12 @@ def process_tokenizer(dataset_name: str, tokenizer: dict[str, str], args) -> dic
     datasets = args.dataset or DATASETS.keys()
     with concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) as executor:
         futures = {
-            executor.submit(process_tokenizer, dataset, tokenizer, args): dataset
+            executor.submit(
+                process_tokenizer,
+                dataset,
+                tokenizer,
+                args.tmqm_dataset,
+            ): dataset
             for dataset in datasets
         }
diff --git a/opt/run_logs/sync_wandb.py b/opt/run_logs/sync_wandb.py
index a7a0b79b..1d459a83 100755
--- a/opt/run_logs/sync_wandb.py
+++ b/opt/run_logs/sync_wandb.py
@@ -174,6 +174,12 @@ def run_summary(run: Run):
     metadata = run.metadata if isinstance(run.metadata, dict) else {}
     git_info = metadata.get("git", {})
     summary = run.summary if isinstance(run.summary, dict) else {}
+    summary_metrics = (
+        dict(run.summary_metrics)
+        if hasattr(run, "summary_metrics") and run.summary_metrics
+        else {}
+    )
+    sys_metrics = system_metrics(run)
 
     stats = {
         "id": run.id,
@@ -187,7 +193,9 @@
         "created": metadata.get("startedAt"),
         "gpu": metadata.get("gpu"),
         "commit": git_info.get("commit") if isinstance(git_info, dict) else None,
-        "runtime": summary.get("_runtime"),
+        "runtime": summary.get("_runtime")
+        or summary_metrics.get("_runtime")
+        or get_entry(summary_metrics, "_wandb", "runtime"),
         "optimizer": {
             "class_path": get_entry(
                 config, "cli", "model", "optimizer", "class_path"
@@ -242,13 +250,11 @@
                 run, "stats/val_batch_throughput", "mean"
            ),
             "train_batch_time": summary_metric(run, "stats/train_batch_time_epoch"),
-            **system_metrics(run),
+            **sys_metrics,
         },
-        "summary_metrics": dict(run.summary_metrics)
-        if hasattr(run, "summary_metrics") and run.summary_metrics
-        else {},
+        "summary_metrics": summary_metrics,
     }
-
+    print(stats["id"], stats["runtime"])
     # Record world_size
     world_size = (stats["job_config"]["nodes"] or nan) * (
         stats["job_config"]["gpus_per_node"] or nan

From b79f9dfa9fc083643ae4037a2c95ee2cea1cba5d Mon Sep 17 00:00:00 2001
From: Alexius Wadell
Date: Fri, 27 Mar 2026 12:08:43 -0600
Subject: [PATCH 02/11] brute force calc

---
 .../scripts/brute_force_cost.jl               | 89 +++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 opt/BayesianScaling/scripts/brute_force_cost.jl

diff --git a/opt/BayesianScaling/scripts/brute_force_cost.jl b/opt/BayesianScaling/scripts/brute_force_cost.jl
new file mode 100644
index 00000000..55ad190b
--- /dev/null
+++ b/opt/BayesianScaling/scripts/brute_force_cost.jl
@@ -0,0 +1,89 @@
+#!/usr/bin/env -S julia +release --color=auto --startup-file=no --project=@script
+# Estimate the compute cost of doing a brute force search
+# Usage: ./brute_force_cost.jl ../out/dec_3_sweep_runs.jsonl
+using DataFrames
+using JSON: JSON
+using BayesianScaling: non_embedding_size
+
+function read_jsonl(file::AbstractString)
+    rows = []
+    open(file, "r") do fid
+        for line in eachline(fid)
+            push!(rows, JSON.parse(line))
+        end
+    end
+    return DataFrame(rows)
+end
+
+function unique_lr_prefactor(df::DataFrame)
+    # LR Prefector
+    # Round to avoid 3.125 and 3.1249999... being treated differently
+    lr_prefactor = @. 
df.lr / sqrt(df.effective_batch_size) + return round.(lr_prefactor; sigdigits=6) +end + +function compute_full_cost(df::DataFrame) + levels = (; + d_model = unique(df.d_model), + n_layer = unique(df.d_model ./ df.aspect_ratio), + ff_ratio = unique(df.d_model), + effective_batch_size = unique(df.effective_batch_size), + kv_size = unique(df.kv_size), + data_size = unique(df.max_steps .* df.effective_batch_size), + lr_prefactor = unique_lr_prefactor(df), + ) + l_names = keys(levels) + return sum(Iterators.product(levels...)) do p + p = NamedTuple{l_names}(p) + N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) + D = p.data_size + C = 6 * N * D + end +end + +function compute_partial_cost(df::DataFrame) + levels = (; + model_size = unique(df.model_size), + data_size = unique(df.effective_batch_size .* df.max_steps), + # lr_prefactor = unique_lr_prefactor(df), + ) + l_names = keys(levels) + return sum(Iterators.product(levels...)) do p + p = NamedTuple{l_names}(p) + C = 6 * p.model_size * p.data_size + end +end + +function compute_bayes_cost(df::DataFrame) + N = df.model_size + D = @. df.effective_batch_size * df.max_steps + return sum(x -> 6*prod(x), zip(N, D)) +end + +function estimate_cost(levels::NamedTuple) + l_names = keys(levels) + total_compute = sum(Iterators.product(levels...)) do p + p = NamedTuple{l_names}(p) + N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) + D = p.data_size + C = 6 * N * D + end + return total_compute +end + +function (@main)(args=[]) + runs_jsonl = args[1] + @assert isfile(runs_jsonl) + df = read_jsonl(runs_jsonl) + + @show C_bayes = compute_bayes_cost(df) + @show C_full = compute_full_cost(df) + @show C_partial = compute_partial_cost(df) + + # Cost Savings + @show C_full / C_bayes + @show C_partial / C_bayes + + + return 0 +end From 7259f6ed322dadd48e989e8ed5353af845042452 Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Fri, 27 Mar 2026 15:32:39 -0600 Subject: [PATCH 03/11] fix: account for seq len in flops calculation --- opt/BayesianScaling/plots/paper.jl | 40 ++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 13 deletions(-) mode change 100644 => 100755 opt/BayesianScaling/plots/paper.jl diff --git a/opt/BayesianScaling/plots/paper.jl b/opt/BayesianScaling/plots/paper.jl old mode 100644 new mode 100755 index d6ac5f57..af18a14f --- a/opt/BayesianScaling/plots/paper.jl +++ b/opt/BayesianScaling/plots/paper.jl @@ -1,3 +1,4 @@ +#!/usr/bin/env -S uv run julia --project=@script --startup-file=no using Makie using MISTStyle using DataFrames @@ -6,6 +7,12 @@ using JLD2: jldopen using BayesianScaling using BayesianScaling: pf_day, get_nbins +const AVG_SEQ_LENGTH::Float64 = 65.2 +""" +Average number of tokens per molecule for the Smirk Tokenizer on REALSpace +Source: SI for doi:10.1021/acs.jcim.5c01856 +""" + function mark_model!(ax_scale, ax_compute; model_loss, d_model, @@ -15,11 +22,12 @@ function mark_model!(ax_scale, ax_compute; eff_batch_size, colorbar, h_loss, + avg_seq_length = 1, kwargs... 
 )
     N = BayesianScaling.non_embedding_size(d_model, ff_ratio * d_model, n_layers)
     D = steps * eff_batch_size
-    C = 6 * float(N) * float(D)
+    C = 6 * float(N) * float(D) * float(avg_seq_length)
     kwargs = (;
         marker=:star5,
         color=model_loss,
@@ -321,9 +329,10 @@ end
 
 function figure_bayesian(data;
     N=logrange(1e5, 3e9; length=50),
-    C=logrange(1e12, 2*pf_day; length=50),
+    C=logrange(1e-5 * pf_day, 100*pf_day; length=50),
     p=0.95,
     n_samples=500,
+    avg_seq_length::AbstractFloat = 1.0,
 )
     model = data["model"]
     chains = data["chains"]
@@ -333,7 +342,7 @@
     model_loss = StatsBase.response(model)
     model_size = df.model_size
-    model_flops = @. 6 * float(df.model_size) * float(df.data_size)
+    model_flops = @. 6 * float(df.model_size) * float(df.data_size) * float(avg_seq_length)
 
     f_height = model.formula isa BayesianScaling.ShapedScaling ? 4.5inch : 3inch
     f = Figure(;
@@ -358,13 +367,13 @@
         ylabel="Non-Embedding Parameters",
     )
 
-    loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C; p)
+    loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C ./ avg_seq_length; p)
     h_opt = predictionband!(ax_compute, C ./ pf_day, loss_c_opt...;
         color=MISTStyle.UM_COLORS.maize,
         linewidth=2pt,
         band_color=(RGBf(0.596, 0.612, 0.592), 0.4),
     )
-    loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C))'
+    loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C ./ avg_seq_length))'
     levels = logrange(minimum(loss_cs), 1.0; length=10)
     cb = Colorbar(gl[1, end+1];
         label="Validation Loss (nats)",
@@ -388,7 +397,7 @@
     )
 
 
-    N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C)
+    N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C ./ avg_seq_length)
     predictionband!(ax_scale, C ./ pf_day, N_opt...;
         color=h_opt.color,
         linewidth=h_opt.linewidth,
@@ -414,6 +423,7 @@
         model_loss=0.01019080262631178,
         colorbar=cb,
         h_loss,
+        avg_seq_length,
    )
 
     # dh61satt
@@ -426,6 +436,7 @@
         model_loss=0.03703538700938225,
         colorbar=cb,
         h_loss,
+        avg_seq_length,
     )
 
 
@@ -452,9 +463,10 @@ end
 
 function figure_bayesian_panel(data;
     N=logrange(1e5, 3e9; length=50),
-    C=logrange(1e12, 2*pf_day; length=50),
+    C=logrange(1e-5 * pf_day, 100*pf_day; length=50),
     p=0.95,
     n_samples=500,
+    avg_seq_length::AbstractFloat = 1.0,
 )
 
     f = Figure(; size=(89mm, 170mm), figure_padding=(2, 2, 2, 4))
@@ -466,7 +478,7 @@
     model_loss = StatsBase.response(model)
     model_size = df.model_size
-    model_flops = @. 6 * float(df.model_size) * float(df.data_size)
+    model_flops = @. 
6 * float(df.model_size) * float(df.data_size) * float(avg_seq_length) # Plot Covariance colsize!(f.layout, 1, 100mm) @@ -497,13 +509,13 @@ function figure_bayesian_panel(data; ylabel="Non-Embedding Parameters", ) - loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C; p) + loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C ./ avg_seq_length; p) h_opt = predictionband!(ax_compute, C ./ pf_day, loss_c_opt...; color=MISTStyle.UM_COLORS.maize, linewidth=2pt, band_color=(RGBf(0.596, 0.612, 0.592), 0.4), ) - loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C))' + loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C ./ avg_seq_length))' levels = logrange(minimum(loss_cs), 1.0; length=10) cb = Colorbar(gl[1, end+1]; label="Validation Loss (nats)", @@ -527,7 +539,7 @@ function figure_bayesian_panel(data; ) - N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C) + N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C ./ avg_seq_length) predictionband!(ax_scale, C ./ pf_day, N_opt...; color=h_opt.color, linewidth=h_opt.linewidth, @@ -553,6 +565,7 @@ function figure_bayesian_panel(data; model_loss=0.01019080262631178, colorbar=cb, h_loss, + avg_seq_length, ) # dh61satt @@ -565,6 +578,7 @@ function figure_bayesian_panel(data; model_loss=0.03703538700938225, colorbar=cb, h_loss, + avg_seq_length, ) # Fit distributions @@ -604,7 +618,7 @@ end function plot_all(dir; kwargs...) for model in readdir(dir; join=true) isfile(joinpath(model, "chains.jld2")) || continue - plot_model(model) + plot_model(model; avg_seq_length=AVG_SEQ_LENGTH, kwargs...) end end @@ -618,7 +632,7 @@ function (@main)(ARGS=[]) # Panel figure data = jldopen(joinpath(chains_dir, "dec-3-sweep-smoothed-1.0--geometric-shape-gamma", "chains.jld2"), "r") with_theme(MISTStyle.theme()) do - figure_bayesian_panel(data) + figure_bayesian_panel(data; avg_seq_length=AVG_SEQ_LENGTH) end |> MISTStyle.savefig("scaling_panel_baseline") return nothing From c00e7b25757ccc0c79284cade3e32a92ea9e6b10 Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Fri, 27 Mar 2026 17:52:44 -0600 Subject: [PATCH 04/11] fixup scaling plots --- opt/BayesianScaling/plots/paper.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/opt/BayesianScaling/plots/paper.jl b/opt/BayesianScaling/plots/paper.jl index af18a14f..6ac752b9 100755 --- a/opt/BayesianScaling/plots/paper.jl +++ b/opt/BayesianScaling/plots/paper.jl @@ -110,11 +110,11 @@ function plot_scaling_params!(f, chains; legend_pos=:top) ) prior_works = [ - "Kaplan et al." => (; α=0.076, β=0.103, a = 0.58, b = 0.42), "Hoffmann et al. (Appr. 1)" => (; a = 0.50, b = 0.50), "Hoffmann et al. (Appr. 3)" => (; α=0.34, β=0.28, a = 0.46, b = 0.54), "Bi et al. (Early)" => (; a = 0.450, b = 0.550), "Bi et al. (Current)" => (; a = 0.524, b = 0.478), + "Kaplan et al." 
=> (; α=0.076, β=0.103, a = 0.58, b = 0.42), ] gl = GridLayout(f[1, 1]) @@ -175,12 +175,14 @@ function plot_scaling_params!(f, chains; legend_pos=:top) Legend(legend_pos, elem, MISTStyle.label.(elem); tellwidth, tellheight, - margin=2pt .* (1, 1, 1, 1), orientation, nbanks, halign, valign, + padding=(4pt, 4pt, 1pt, 1pt), + colgap=4pt, ) + rowgap!(gl, 2pt) return f end @@ -328,7 +330,7 @@ function plot_lr_partial_dependence!(f, model, chains; p=0.95, npoints=100) end function figure_bayesian(data; - N=logrange(1e5, 3e9; length=50), + N=logrange(1e5, 4e9; length=50), C=logrange(1e-5 * pf_day, 100*pf_day; length=50), p=0.95, n_samples=500, @@ -462,7 +464,7 @@ function figure_bayesian(data; end function figure_bayesian_panel(data; - N=logrange(1e5, 3e9; length=50), + N=logrange(1e5, 4e9; length=50), C=logrange(1e-5 * pf_day, 100*pf_day; length=50), p=0.95, n_samples=500, From 5d66c547af63cd402444f22ff89577d3b86e27ee Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Sat, 28 Mar 2026 13:01:29 -0600 Subject: [PATCH 05/11] add script to compute scaling campaign costs --- .../scripts/brute_force_cost.jl | 80 ++++++++++++++----- opt/BayesianScaling/scripts/dec_3_sweep.jl | 1 + opt/BayesianScaling/scripts/wandb_import.jl | 71 +++++++++------- uv.lock | 24 +++--- 4 files changed, 116 insertions(+), 60 deletions(-) mode change 100644 => 100755 opt/BayesianScaling/scripts/brute_force_cost.jl diff --git a/opt/BayesianScaling/scripts/brute_force_cost.jl b/opt/BayesianScaling/scripts/brute_force_cost.jl old mode 100644 new mode 100755 index 55ad190b..7ce9bc54 --- a/opt/BayesianScaling/scripts/brute_force_cost.jl +++ b/opt/BayesianScaling/scripts/brute_force_cost.jl @@ -5,6 +5,12 @@ using DataFrames using JSON: JSON using BayesianScaling: non_embedding_size +const AVG_SEQ_LENGTH::Float64 = 65.2 +""" +Average number of tokens per molecule for the Smirk Tokenizer on REALSpace +Source: SI for doi:10.1021/acs.jcim.5c01856 +""" + function read_jsonl(file::AbstractString) rows = [] open(file, "r") do fid @@ -19,44 +25,48 @@ function unique_lr_prefactor(df::DataFrame) # LR Prefector # Round to avoid 3.125 and 3.1249999... being treated differently lr_prefactor = @. 
df.lr / sqrt(df.effective_batch_size) - return round.(lr_prefactor; sigdigits=6) + return unique(round.(lr_prefactor; sigdigits=6)) end function compute_full_cost(df::DataFrame) levels = (; d_model = unique(df.d_model), n_layer = unique(df.d_model ./ df.aspect_ratio), - ff_ratio = unique(df.d_model), + ff_ratio = unique(df.ff_ratio), effective_batch_size = unique(df.effective_batch_size), kv_size = unique(df.kv_size), data_size = unique(df.max_steps .* df.effective_batch_size), lr_prefactor = unique_lr_prefactor(df), ) l_names = keys(levels) - return sum(Iterators.product(levels...)) do p + total_cost = sum(Iterators.product(levels...)) do p p = NamedTuple{l_names}(p) N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) D = p.data_size - C = 6 * N * D + return 6 * N * D * AVG_SEQ_LENGTH end + return levels, total_cost end -function compute_partial_cost(df::DataFrame) +function compute_partial_cost(df::DataFrame; include_lr::Bool = true) levels = (; model_size = unique(df.model_size), data_size = unique(df.effective_batch_size .* df.max_steps), - # lr_prefactor = unique_lr_prefactor(df), ) + if include_lr + levels = (; levels..., lr_prefactor = unique_lr_prefactor(df)) + end l_names = keys(levels) - return sum(Iterators.product(levels...)) do p + total_cost = sum(Iterators.product(levels...)) do p p = NamedTuple{l_names}(p) - C = 6 * p.model_size * p.data_size + 6 * p.model_size * p.data_size * AVG_SEQ_LENGTH end + return levels, total_cost end function compute_bayes_cost(df::DataFrame) N = df.model_size - D = @. df.effective_batch_size * df.max_steps + D = @. df.effective_batch_size * df.max_steps * AVG_SEQ_LENGTH return sum(x -> 6*prod(x), zip(N, D)) end @@ -66,24 +76,56 @@ function estimate_cost(levels::NamedTuple) p = NamedTuple{l_names}(p) N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) D = p.data_size - C = 6 * N * D + return 6 * N * D * AVG_SEQ_LENGTH end return total_compute end -function (@main)(args=[]) - runs_jsonl = args[1] - @assert isfile(runs_jsonl) - df = read_jsonl(runs_jsonl) +function show_levels(levels::NamedTuple) + for (k, v) in pairs(levels) + n = length(v) + println("$k ($n): $v") + end + return nothing +end + +function compute_costs(df) + C_bayes = compute_bayes_cost(df) + full_levels, C_full = compute_full_cost(df) + partial_levels_lr, C_partial_lr = compute_partial_cost(df; include_lr = true) + partial_levels, C_partial = compute_partial_cost(df; include_lr = false) - @show C_bayes = compute_bayes_cost(df) - @show C_full = compute_full_cost(df) - @show C_partial = compute_partial_cost(df) + rs(x) = round(x; sigdigits=3) + + println("Full Factorial:") + println("Cost: $(rs(C_full / pf_day)) pf-day") + show_levels(full_levels) + + println("\n\nN + D Only:") + println("Cost: $(rs(C_partial / pf_day)) pf-day") + println("Cost (3 lr): $(rs(3 * C_partial / pf_day)) pf-day") + show_levels(partial_levels) + + println("\n\nN + D + lr_prefactor Only:") + println("Cost: $(rs(C_partial_lr / pf_day)) pf-day") + show_levels(partial_levels_lr) # Cost Savings - @show C_full / C_bayes - @show C_partial / C_bayes + gpu_hours = sum(df.runtime .* df.world_size / (60 * 60)) + println("\n\nBayes Cost: $(rs(C_bayes / pf_day)) pf-day") + println("Bayes Cost: $(rs(gpu_hours)) GPU-hours") + println("Rel. Full: $(rs(C_full / C_bayes))") + println("Rel. Partial (N, D): $(rs(C_partial / C_bayes))") + println("Rel. Partial (N, D, 3 lr): $(rs(3 * C_partial / C_bayes))") + println("Rel. 
Partial (N, D, lr): $(rs(C_partial_lr / C_bayes))") + return nothing +end +function (@main)(args=[]) + runs_jsonl = args[1] + @assert isfile(runs_jsonl) + df = read_jsonl(runs_jsonl) + compute_costs(df) return 0 end diff --git a/opt/BayesianScaling/scripts/dec_3_sweep.jl b/opt/BayesianScaling/scripts/dec_3_sweep.jl index c9f5f033..24439a09 100755 --- a/opt/BayesianScaling/scripts/dec_3_sweep.jl +++ b/opt/BayesianScaling/scripts/dec_3_sweep.jl @@ -173,6 +173,7 @@ end function export_runs() df_sweep = dec_3_sweep_runs() outdir = joinpath(pkgdir(BayesianScaling), "out") + mkpath(outdir) open(joinpath(outdir, "dec_3_sweep_runs.jsonl"), "w") do fid for run in eachrow(df_sweep) println(fid, JSON.json(Dict(pairs(run)))) diff --git a/opt/BayesianScaling/scripts/wandb_import.jl b/opt/BayesianScaling/scripts/wandb_import.jl index ae3d52d3..28fbab38 100644 --- a/opt/BayesianScaling/scripts/wandb_import.jl +++ b/opt/BayesianScaling/scripts/wandb_import.jl @@ -1,5 +1,5 @@ using DataFrames -using Dates: Dates, DateTime +using Dates: Dates, DateTime, @dateformat_str using JSON: JSON using BayesianScaling: find, average_last @@ -62,34 +62,47 @@ function pretraining_runs(dir::AbstractString; smoothed_eval_batch=1e6) val_loss_last=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="last") val_loss_best = ismissing(val_loss_best) ? val_loss_min : val_loss_best - push!(row, (; - id=run["id"], - state=run["state"], - user=run["user"], - cluster=run["cluster"], - commit=run["commit"], - tags=string.(run["tags"]), - created=DateTime(run["created"][1:23], Dates.ISODateTimeFormat), - model_size=run["model"]["model_size"], - model_class=run["model"]["class_path"], - d_model, - ff_ratio, - kv_size, - aspect_ratio, - optimizer=run["optimizer"]["class_path"], - tokenizer=run["data"]["tokenizer"], - max_steps=run["trainer"]["num_training_steps"], - step=get_metric(run["metrics"]; metric="global_step", type="last"), - gas=run["trainer"]["gas"], - effective_batch_size=run["trainer"]["effective_batch_size"], - effective_val_epoch, - lr=run["optimizer"]["lr"], - beta1, - beta2, - val_loss_best, - val_loss_last, - val_loss_smooth - )) + # Parse Date + if !ismissing(run["created"]) + created = DateTime(run["created"][1:23], Dates.ISODateTimeFormat) + else + created = missing + end + + try + push!(row, (; + id=run["id"], + state=run["state"], + user=run["user"], + cluster=run["cluster"], + commit=run["commit"], + tags=string.(run["tags"]), + created, + runtime=float(run["runtime"]), + world_size=Int(run["job_config"]["world_size"]), + model_size=run["model"]["model_size"], + model_class=run["model"]["class_path"], + d_model, + ff_ratio, + kv_size, + aspect_ratio, + optimizer=run["optimizer"]["class_path"], + tokenizer=run["data"]["tokenizer"], + max_steps=run["trainer"]["num_training_steps"], + step=get_metric(run["metrics"]; metric="global_step", type="last"), + gas=run["trainer"]["gas"], + effective_batch_size=run["trainer"]["effective_batch_size"], + effective_val_epoch, + lr=run["optimizer"]["lr"], + beta1, + beta2, + val_loss_best, + val_loss_last, + val_loss_smooth + )) + catch e + @error "Failed to import $file" e + end end return DataFrame(row) end diff --git a/uv.lock b/uv.lock index 9895cb03..bfea30a3 100644 --- a/uv.lock +++ b/uv.lock @@ -3095,7 +3095,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.22.2" +version = "0.25.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3109,17 +3109,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = 
"typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/a8/680bd77e11a278e6c14a2cb4646e8ab9525b2baaa81c3d12dc0f616aa4aa/wandb-0.22.2.tar.gz", hash = "sha256:510f5a1ac30d16921c36c3b932da852f046641d4aee98a86a7f5ec03a6e95bda", size = 41401439, upload-time = "2025-10-07T19:54:21.88Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/b3/8c637fb594cfd574ce9c9f7d0ac2f2d12742eb38ec59dcbb713beae95343/wandb-0.22.2-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2e29c9fa4462b5411b2cd2175ae33eff4309c91de7c426bca6bc8e7abc7e5dec", size = 18677549, upload-time = "2025-10-07T19:54:00.839Z" }, - { url = "https://files.pythonhosted.org/packages/d3/f3/e309a726eaebddad6b8d9a73a50891e5796962ec8a091bb6a61d31692d1e/wandb-0.22.2-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:c42d594cd7a9da4fd39ecdb0abbc081b61f304123277b2b6c4ba84283956fd21", size = 19715188, upload-time = "2025-10-07T19:54:03.805Z" }, - { url = "https://files.pythonhosted.org/packages/f9/73/fad59910215876008f4781b57d828d1b19b3677c9b46af615e7229746435/wandb-0.22.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5188d84e66d3fd584f3b3ae4d2a70e78f29403c0528e6aecaa4188a1fcf54d8", size = 18463148, upload-time = "2025-10-07T19:54:05.676Z" }, - { url = "https://files.pythonhosted.org/packages/87/11/572c1913b5b92e4c519f735adfae572b46f2d79d99ede63eec0d6a272d6e/wandb-0.22.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88ccd484af9f21cfc127976793c3cf66cfe1acd75bd8cd650086a64e88bac4bf", size = 19908645, upload-time = "2025-10-07T19:54:07.693Z" }, - { url = "https://files.pythonhosted.org/packages/6d/0d/133aa82f5a505ba638b4fda5014cefddfe7f1f6238ef4afc0871ec61c41f/wandb-0.22.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:abf0ed175e791af64110e0a0b99ce02bbbbd1017722bc32d3bc328efb86450cd", size = 18501348, upload-time = "2025-10-07T19:54:10.234Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d5/776203be2601872f01dacc6a5b4274106ec0db7cd3bf2cdb3b741f8fc932/wandb-0.22.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:44e77c56403b90bf3473a7ca3bfc4d42c636b7c0e31a5fb9cd0382f08302f74b", size = 20001756, upload-time = "2025-10-07T19:54:12.452Z" }, - { url = "https://files.pythonhosted.org/packages/30/43/ae3fa46e20b1d9a6508dd9abe716d57205c038ed4661c5c98ace48a60eac/wandb-0.22.2-py3-none-win32.whl", hash = "sha256:44d12bd379dbe15be5ceed6bdf23803d42f648ba0dd111297b4c47a3c7be6dbd", size = 19075950, upload-time = "2025-10-07T19:54:14.892Z" }, - { url = "https://files.pythonhosted.org/packages/09/59/c174321e868205f7a659d1e5ec51f546e62267296d6f4179bb9119294964/wandb-0.22.2-py3-none-win_amd64.whl", hash = "sha256:c95eb221bf316c0872f7ac55071856b9f25f95a2de983ada48acf653ce259386", size = 19075953, upload-time = "2025-10-07T19:54:16.837Z" }, - { url = "https://files.pythonhosted.org/packages/7a/a2/c7c24fda78513cab5686949d8cb36459dbbccbbb4b2b6fc67237ece31a00/wandb-0.22.2-py3-none-win_arm64.whl", hash = "sha256:20d2ab9aa10445aab3d60914a980f002a4f66566e28b0cd156b1e462f0080a0d", size = 17383217, upload-time = "2025-10-07T19:54:19.384Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/60/bb/eb579bf9abac70934a014a9d4e45346aab307994f3021d201bebe5fa25ec/wandb-0.25.1.tar.gz", hash = "sha256:b2a95cd777ecbe7499599a43158834983448a0048329bc7210ef46ca18d21994", size = 43983308, upload-time = "2026-03-10T23:51:44.227Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e7/d8/873553b6818499d1b1de314067d528b892897baf0dc81fedc0e845abc2dd/wandb-0.25.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:9bb0679a3e2dcd96db9d9b6d3e17d046241d8d122974b24facb85cc93309a8c9", size = 23615900, upload-time = "2026-03-10T23:51:06.278Z" }, + { url = "https://files.pythonhosted.org/packages/71/ea/b131f319aaa5d0bf7572b6bfcff3dd89e1cf92b17eee443bbab71d12d74c/wandb-0.25.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:0fb13ed18914027523e7b4fc20380c520e0d10da0ee452f924a13f84509fbe12", size = 25576144, upload-time = "2026-03-10T23:51:11.527Z" }, + { url = "https://files.pythonhosted.org/packages/70/5f/81508581f0bb77b0495665c1c78e77606a48e66e855ca71ba7c8ae29efa4/wandb-0.25.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:cc4521eb5223429ddab5e8eee9b42fdf4caabdf0bc4e0e809042720e5fbef0ed", size = 23070425, upload-time = "2026-03-10T23:51:15.71Z" }, + { url = "https://files.pythonhosted.org/packages/f2/c7/445155ef010e2e35d190797d7c36ff441e062a5b566a6da4778e22233395/wandb-0.25.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:e73b4c55b947edae349232d5845204d30fac88e18eb4ad1d4b96bf7cf898405a", size = 25628142, upload-time = "2026-03-10T23:51:19.326Z" }, + { url = "https://files.pythonhosted.org/packages/d5/63/f5c55ee00cf481ef1ccd3c385a0585ad52e7840d08419d4f82ddbeeea959/wandb-0.25.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:22b84065aa398e1624d2e5ad79e08bc4d2af41a6db61697b03b3aaba332977c6", size = 23123172, upload-time = "2026-03-10T23:51:23.418Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d9/19eb7974c0e9253bcbaee655222c0f0e1a52e63e9479ee711b4208f8ac31/wandb-0.25.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:005c4c6b5126ef8f4b4110e5372d950918b00637d6dc4b615ad17445f9739478", size = 25714479, upload-time = "2026-03-10T23:51:27.421Z" }, + { url = "https://files.pythonhosted.org/packages/11/19/466c1d03323a4a0ed7d4036a59b18d6b6f67cb5032e444205927e226b18d/wandb-0.25.1-py3-none-win32.whl", hash = "sha256:8f2d04f16b88d65bfba9d79fb945f6c64e2686215469a841936e0972be8ec6a5", size = 24967338, upload-time = "2026-03-10T23:51:31.833Z" }, + { url = "https://files.pythonhosted.org/packages/89/22/680d34c1587f3a979c701b66d71aa7c42b4ef2fdf0774f67034e618e834e/wandb-0.25.1-py3-none-win_amd64.whl", hash = "sha256:62db5166de14456156d7a85953a58733a631228e6d4248a753605f75f75fb845", size = 24967343, upload-time = "2026-03-10T23:51:36.026Z" }, + { url = "https://files.pythonhosted.org/packages/c4/e8/76836b75d401ff5912aaf513176e64557ceaec4c4946bfd38a698ff84d48/wandb-0.25.1-py3-none-win_arm64.whl", hash = "sha256:cc7c34b70cf4b7be4d395541e82e325fd9d2be978d62c9ec01f1a7141523b6bb", size = 22080774, upload-time = "2026-03-10T23:51:40.196Z" }, ] [[package]] From 06566a5c97064e50a33509f018b6130f3dd391f8 Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Sat, 28 Mar 2026 13:03:21 -0600 Subject: [PATCH 06/11] fix: missing pf_day import --- opt/BayesianScaling/scripts/brute_force_cost.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opt/BayesianScaling/scripts/brute_force_cost.jl b/opt/BayesianScaling/scripts/brute_force_cost.jl index 7ce9bc54..3a721a86 100755 --- a/opt/BayesianScaling/scripts/brute_force_cost.jl +++ b/opt/BayesianScaling/scripts/brute_force_cost.jl @@ -3,7 +3,7 @@ # Usage: ./brute_force_cost.jl ../out/dec_3_sweep_runs.jsonl using DataFrames using JSON: JSON -using BayesianScaling: non_embedding_size +using BayesianScaling: non_embedding_size, pf_day const AVG_SEQ_LENGTH::Float64 = 65.2 """ From 
d7337fd1d3b0105ff4f3c487cc587e422fea4927 Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Sat, 28 Mar 2026 13:15:50 -0600 Subject: [PATCH 07/11] add latex output --- .../scripts/brute_force_cost.jl | 223 ++++++++++++++++-- 1 file changed, 208 insertions(+), 15 deletions(-) diff --git a/opt/BayesianScaling/scripts/brute_force_cost.jl b/opt/BayesianScaling/scripts/brute_force_cost.jl index 3a721a86..13517c80 100755 --- a/opt/BayesianScaling/scripts/brute_force_cost.jl +++ b/opt/BayesianScaling/scripts/brute_force_cost.jl @@ -1,13 +1,17 @@ #!/usr/bin/env -S julia +release --color=auto --startup-file=no --project=@script # Estimate the compute cost of doing a brute force search -# Usage: ./brute_force_cost.jl ../out/dec_3_sweep_runs.jsonl +# Usage: +# ./brute_force_cost.jl RUNS_JSONL OUTPUT_TEX +# Example: +# ./brute_force_cost.jl ../out/dec_3_sweep_runs.jsonl cost_summary.tex + using DataFrames using JSON: JSON using BayesianScaling: non_embedding_size, pf_day const AVG_SEQ_LENGTH::Float64 = 65.2 """ -Average number of tokens per molecule for the Smirk Tokenizer on REALSpace +Average number of tokens per molecule for the Smirk Tokenizer on REALSpace. Source: SI for doi:10.1021/acs.jcim.5c01856 """ @@ -22,10 +26,9 @@ function read_jsonl(file::AbstractString) end function unique_lr_prefactor(df::DataFrame) - # LR Prefector - # Round to avoid 3.125 and 3.1249999... being treated differently + # Round to avoid 3.125 and 3.1249999... being treated differently. lr_prefactor = @. df.lr / sqrt(df.effective_batch_size) - return unique(round.(lr_prefactor; sigdigits=6)) + return unique(round.(lr_prefactor; sigdigits=3)) end function compute_full_cost(df::DataFrame) @@ -48,7 +51,7 @@ function compute_full_cost(df::DataFrame) return levels, total_cost end -function compute_partial_cost(df::DataFrame; include_lr::Bool = true) +function compute_partial_cost(df::DataFrame; include_lr::Bool=true) levels = (; model_size = unique(df.model_size), data_size = unique(df.effective_batch_size .* df.max_steps), @@ -59,7 +62,7 @@ function compute_partial_cost(df::DataFrame; include_lr::Bool = true) l_names = keys(levels) total_cost = sum(Iterators.product(levels...)) do p p = NamedTuple{l_names}(p) - 6 * p.model_size * p.data_size * AVG_SEQ_LENGTH + return 6 * p.model_size * p.data_size * AVG_SEQ_LENGTH end return levels, total_cost end @@ -67,7 +70,7 @@ end function compute_bayes_cost(df::DataFrame) N = df.model_size D = @. df.effective_batch_size * df.max_steps * AVG_SEQ_LENGTH - return sum(x -> 6*prod(x), zip(N, D)) + return sum(x -> 6 * prod(x), zip(N, D)) end function estimate_cost(levels::NamedTuple) @@ -89,11 +92,176 @@ function show_levels(levels::NamedTuple) return nothing end -function compute_costs(df) +function round_sig(x::Real; sigdigits::Int=3) + return round(float(x); sigdigits=sigdigits) +end + +function latex_escape_identifier(x) + s = String(x) + return replace(s, "_" => "\\_") +end + +function latex_math_identifier(x) + s = String(x) + if occursin("_", s) + head, tail... 
= split(s, "_") + return "\$" * head * "_{\\text{" * join(tail, "\\_") * "}}\$" + end + return "\$" * s * "\$" +end + +function latex_inline_identifier(x) + if x in (:d_model, :n_layer, :kv_size) + return latex_math_identifier(x) + elseif x == :data_size + return "\$D\$" + elseif x == :lr_prefactor + return "\$lr\$" + elseif x == :ff_ratio + return "ff\\_ratio" + elseif x == :effective_batch_size + return "batch" + else + return latex_escape_identifier(x) + end +end + +function latex_description_label(x) + if x in (:d_model, :n_layer, :kv_size) + return latex_math_identifier(x) + elseif x == :lr_prefactor + return "\$lr\\_prefactor\$" + else + return latex_escape_identifier(x) + end +end + +function latex_number(x::Real; sigdigits::Int=6) + xr = round(float(x); sigdigits=sigdigits) + + if iszero(xr) + return "0" + end + + if 1e-3 <= abs(xr) < 1e4 + if isinteger(xr) + return string(Int(round(xr))) + end + return string(xr) + end + + exp10 = floor(Int, log10(abs(xr))) + mant = xr / 10.0^exp10 + mant = round(mant; sigdigits=sigdigits) + return "\\sn{$mant}{$exp10}" +end + +function latex_vector(values; sigdigits::Int=6) + parts = [latex_number(v; sigdigits=sigdigits) for v in sort(collect(values))] + return "[" * join(parts, ",\\;\n") * "]" +end + +function latex_levels_block( + io::IO, + title::AbstractString, + levels::NamedTuple; + skip::Set{Symbol}=Set{Symbol}(), +) + println(io, "\\paragraph{$title}") + println(io, "\\begin{description}") + println(io) + + for (k, v) in pairs(levels) + if k in skip + continue + end + n = length(v) + label = latex_description_label(k) + vec = latex_vector(v) + println(io, "\\item[$label ($n)]") + println(io, "\\[") + println(io, vec) + println(io, "\\]") + println(io) + end + + println(io, "\\end{description}") + println(io) + return nothing +end + +function generate_cost_latex( + full_levels::NamedTuple, + C_full::Real, + partial_levels::NamedTuple, + C_partial::Real, + partial_levels_lr::NamedTuple, + C_partial_lr::Real, + C_bayes::Real; + cost_3lr::Union{Nothing, Real}=nothing, + sigdigits_cost::Int=3, +) + rs(x) = round(float(x); sigdigits=sigdigits_cost) + + rel_full = C_full / C_bayes + rel_partial = C_partial / C_bayes + rel_partial_3lr = isnothing(cost_3lr) ? 
nothing : cost_3lr / C_bayes + rel_partial_lr = C_partial_lr / C_bayes + + full_level_summary = join([ + "$(latex_inline_identifier(k))($(length(v)))" for (k, v) in pairs(full_levels) + ], ", ") + + partial_level_summary = join([ + "$(latex_inline_identifier(k))($(length(v)))" for (k, v) in pairs(partial_levels) + ], ", ") + + partial_lr_level_summary = join([ + "$(latex_inline_identifier(k))($(length(v)))" for (k, v) in pairs(partial_levels_lr) + ], ", ") + + io = IOBuffer() + + println(io, "\\begin{table}[h]") + println(io, "\\centering") + println(io, "\\begin{tabular}{lccc}") + println(io, "\\hline") + println(io, "Case & Levels (with \$n\$) & Total Cost (PF-days) & Cost Rel.\\ Bayes \\\\") + println(io, "\\hline") + println(io, "Full Factorial & $full_level_summary & $(rs(C_full / pf_day)) & $(rs(rel_full)) \\\\") + println(io, "\$N + D\$ Only & $partial_level_summary & $(rs(C_partial / pf_day)) & $(rs(rel_partial)) \\\\") + if !isnothing(cost_3lr) + println( + io, + "\$N + D\$ Only (3 lr) & model\\_size($(length(partial_levels.model_size))), \$D\$($(length(partial_levels.data_size))), \$lr\$(3) & $(rs(cost_3lr / pf_day)) & $(rs(rel_partial_3lr)) \\\\", + ) + end + println(io, "\$N + D + lr\$ & $partial_lr_level_summary & $(rs(C_partial_lr / pf_day)) & $(rs(rel_partial_lr)) \\\\") + println(io, "Bayesian Optimization & --- & $(rs(C_bayes / pf_day)) & 1.0 \\\\") + println(io, "\\hline") + println(io, "\\end{tabular}") + println(io, "\\caption{Training cost comparison across experimental design strategies.}") + println(io, "\\end{table}") + println(io) + println(io) + + latex_levels_block(io, "Full Factorial Levels", full_levels) + latex_levels_block(io, "\$N + D\$ Only Levels", partial_levels) + latex_levels_block( + io, + "\$N + D + lr\$ Levels", + partial_levels_lr; + skip=Set([:model_size, :data_size]), + ) + + return String(take!(io)) +end + +function compute_costs(df::DataFrame) C_bayes = compute_bayes_cost(df) full_levels, C_full = compute_full_cost(df) - partial_levels_lr, C_partial_lr = compute_partial_cost(df; include_lr = true) - partial_levels, C_partial = compute_partial_cost(df; include_lr = false) + partial_levels_lr, C_partial_lr = compute_partial_cost(df; include_lr=true) + partial_levels, C_partial = compute_partial_cost(df; include_lr=false) rs(x) = round(x; sigdigits=3) @@ -110,7 +278,6 @@ function compute_costs(df) println("Cost: $(rs(C_partial_lr / pf_day)) pf-day") show_levels(partial_levels_lr) - # Cost Savings gpu_hours = sum(df.runtime .* df.world_size / (60 * 60)) println("\n\nBayes Cost: $(rs(C_bayes / pf_day)) pf-day") println("Bayes Cost: $(rs(gpu_hours)) GPU-hours") @@ -119,13 +286,39 @@ function compute_costs(df) println("Rel. Partial (N, D, 3 lr): $(rs(3 * C_partial / C_bayes))") println("Rel. 
Partial (N, D, lr): $(rs(C_partial_lr / C_bayes))") - return nothing + latex = generate_cost_latex( + full_levels, + C_full, + partial_levels, + C_partial, + partial_levels_lr, + C_partial_lr, + C_bayes; + cost_3lr=3 * C_partial, + ) + + println("\n\nLaTeX:\n") + println(latex) + + return latex end function (@main)(args=[]) + @assert length(args) >= 2 "Usage: brute_force_cost.jl RUNS_JSONL OUTPUT_TEX" + runs_jsonl = args[1] - @assert isfile(runs_jsonl) + output_tex = args[2] + + @assert isfile(runs_jsonl) "Input JSONL file does not exist: $runs_jsonl" + df = read_jsonl(runs_jsonl) - compute_costs(df) + latex = compute_costs(df) + + open(output_tex, "w") do io + write(io, latex) + end + + println("\nWrote LaTeX to: $output_tex") + return 0 end From 0286d77d595f04b5fd95c211bbb754096b5dfdcd Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Thu, 2 Apr 2026 18:46:20 -0600 Subject: [PATCH 08/11] generate all scaling tables programmatically --- opt/BayesianScaling/.gitignore | 1 + opt/BayesianScaling/plots/paper.jl | 13 +++- opt/BayesianScaling/plots/param_table.jl | 73 ++++++++++++++----- .../scripts/brute_force_cost.jl | 38 ++++------ opt/BayesianScaling/scripts/dec_3_sweep.jl | 8 +- 5 files changed, 88 insertions(+), 45 deletions(-) diff --git a/opt/BayesianScaling/.gitignore b/opt/BayesianScaling/.gitignore index 3da8b99e..320b911a 100644 --- a/opt/BayesianScaling/.gitignore +++ b/opt/BayesianScaling/.gitignore @@ -1,3 +1,4 @@ # Ignore output directory (or symlink to it) out/ out +fig/ diff --git a/opt/BayesianScaling/plots/paper.jl b/opt/BayesianScaling/plots/paper.jl index 6ac752b9..e9346ae7 100755 --- a/opt/BayesianScaling/plots/paper.jl +++ b/opt/BayesianScaling/plots/paper.jl @@ -597,14 +597,15 @@ function figure_bayesian_panel(data; return f end -function plot_model(model; kwargs...) +function plot_model(model; avg_seq_length, kwargs...) @info "plotting $model" run_name = basename(model) data = jldopen(joinpath(model, "chains.jld2"), "r") try + suffix = avg_seq_length != 1 ? "-token-cost" : "" with_theme(MISTStyle.theme()) do - figure_bayesian(data; kwargs...) - end |> MISTStyle.savefig("bayesian-$run_name") + figure_bayesian(data; avg_seq_length, kwargs...) + end |> MISTStyle.savefig("bayesian-$(run_name)$(suffix)") with_theme(MISTStyle.theme()) do plot_prediction_intervals(data["model"], data["chains"]) @@ -620,6 +621,7 @@ end function plot_all(dir; kwargs...) for model in readdir(dir; join=true) isfile(joinpath(model, "chains.jld2")) || continue + plot_model(model; avg_seq_length=1.0, kwargs...) plot_model(model; avg_seq_length=AVG_SEQ_LENGTH, kwargs...) 
end end @@ -634,8 +636,11 @@ function (@main)(ARGS=[]) # Panel figure data = jldopen(joinpath(chains_dir, "dec-3-sweep-smoothed-1.0--geometric-shape-gamma", "chains.jld2"), "r") with_theme(MISTStyle.theme()) do - figure_bayesian_panel(data; avg_seq_length=AVG_SEQ_LENGTH) + figure_bayesian_panel(data; avg_seq_length=1.0) end |> MISTStyle.savefig("scaling_panel_baseline") + with_theme(MISTStyle.theme()) do + figure_bayesian_panel(data; avg_seq_length=AVG_SEQ_LENGTH) + end |> MISTStyle.savefig("scaling_panel_baseline-token-cost") return nothing end diff --git a/opt/BayesianScaling/plots/param_table.jl b/opt/BayesianScaling/plots/param_table.jl index 7590b7fd..745a8d1e 100755 --- a/opt/BayesianScaling/plots/param_table.jl +++ b/opt/BayesianScaling/plots/param_table.jl @@ -1,7 +1,9 @@ #!/usr/bin/env -S julia +release --color=auto --startup-file=no --project=@script using PrettyTables using DataFrames -using BayesianScaling: BayesianScaling, find, selectparam +using BayesianScaling: BayesianScaling, selectparam +using StatsBase: response, cor, corspearman +using Format: format using JLD2: jldopen using MISTStyle using Makie @@ -69,19 +71,17 @@ function model_summary(files) return DataFrame(rows) end -function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) - - named = Dict( - ("Penalized", :model_size, true, false) => "Penalized, Baseline", - ("Penalized", :d_model, true, false) => "Penalized, LR Scales with Hidden Size", - ("Penalized", :model_size, false, true) => "Penalized, Additive", - ("Penalized", :model_size, true, true) => "Penalized, Harmonic Shape Penalty", - ("Chinchilla", nothing, nothing, nothing) => "No Penalties", - ) - return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) -end - function write_scaling_param_table(df::DataFrame; sigdigits=3) + function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) + named = Dict( + ("Penalized", :model_size, true, false) => "Penalized, Baseline", + ("Penalized", :d_model, true, false) => "Penalized, \\(\\eta_{\\star} = f(d_{\\text{model}})\\)", + ("Penalized", :model_size, true, true) => "\\stacked{Penalized, Harmonic}{Shape Penalty}", + ("Penalized", :model_size, false, true) => "Penalized, Additive", + ("Chinchilla", nothing, nothing, nothing) => "No Penalties", + ) + return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) + end df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(LatexCell∘model_name) => :model_name) subset!(df, :model_name => ByRow(!isnothing)) sort!(df, :waic) @@ -91,8 +91,8 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3) :waic => latex_cell"\acs{WAIC}", :mape =>latex_cell"\acs{MAPE}", # :aic =>latex_cell"\acs{AIC}", - :pearson =>latex_cell"Pearson's \(\rho\)", - :spearman =>latex_cell"Spearman's \(\rho\)", + :pearson =>latex_cell"\makecell{Pearson's\\\(r\)}", + :spearman =>latex_cell"\makecell{Spearman's\\\(\rho\)}", :A =>latex_cell"$A$", :B =>latex_cell"$B$", :α =>latex_cell"$\alpha$", @@ -121,7 +121,7 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3) l = sn(l; sigdigits) u = sn(u; sigdigits) end - return LatexCell("\\ci{$μ}{$l}{$u}") + return LatexCell("\\cistacked{$μ}{$l}{$u}") end function fmt_metric(v, i, j) metrics = [:waic, :mape, :aic, :pearson, :spearman] @@ -136,7 +136,7 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3) return format("{}", v) end end - return pretty_table(df[:, 
first.(cols)]; + return pretty_table(String, df[:, first.(cols)]; tf, alignment=[col in [:model_name] ? :l : :c for col in first.(cols)], formatters=(fmt_ci, fmt_metric), @@ -194,6 +194,44 @@ function figure_smoothed_scaling(df) return f end +function summary_rows(df; sigdigits=3) + sort!(df, :n_smooth) + function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) + named = Dict( + ("Penalized", :model_size, true, false) => "MIST, Penalized Scaling \\cref{eq:penalized_neural_scaling}", + ("Chinchilla", nothing, nothing, nothing) => "MIST, No Penalty Terms", + ) + return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) + end + transform!(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(model_name) => :model_name) + df = subset(df, + :model_name => ByRow(!isnothing), + :file => ByRow(endswith("gamma")), + :n_smooth => ByRow(==(1)), + ) + + function format_ci(v) + μ, l, u = v + μ = round(μ; sigdigits) + l = round(l; sigdigits) + u = round(u; sigdigits) + return "\\ci{$μ}{$l}{$u}" + end + open(joinpath("fig", "scaling_summary_rows.tex"), "w") do io + for model in eachrow(df) + row = [ + model.model_name, + format_ci(model.α), + format_ci(model.β), + format_ci(model.a), + format_ci(model.b), + format_ci(model.r) + ] + println(io, join(row, " & ") * " \\\\ % $(model.file)") + end + end + return nothing +end function (@main)(ARGS) df = model_summary(BayesianScaling.find("out", r"dec-3-.*chains.jld2")) @@ -202,4 +240,5 @@ function (@main)(ARGS) open(joinpath("fig", "param_table.tex"), "w") do io write(io, table) end + summary_rows(df) end diff --git a/opt/BayesianScaling/scripts/brute_force_cost.jl b/opt/BayesianScaling/scripts/brute_force_cost.jl index 13517c80..c7aa14fe 100755 --- a/opt/BayesianScaling/scripts/brute_force_cost.jl +++ b/opt/BayesianScaling/scripts/brute_force_cost.jl @@ -46,7 +46,7 @@ function compute_full_cost(df::DataFrame) p = NamedTuple{l_names}(p) N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) D = p.data_size - return 6 * N * D * AVG_SEQ_LENGTH + return 6 * N * D # * AVG_SEQ_LENGTH end return levels, total_cost end @@ -62,14 +62,14 @@ function compute_partial_cost(df::DataFrame; include_lr::Bool=true) l_names = keys(levels) total_cost = sum(Iterators.product(levels...)) do p p = NamedTuple{l_names}(p) - return 6 * p.model_size * p.data_size * AVG_SEQ_LENGTH + return 6 * p.model_size * p.data_size # * AVG_SEQ_LENGTH end return levels, total_cost end function compute_bayes_cost(df::DataFrame) N = df.model_size - D = @. df.effective_batch_size * df.max_steps * AVG_SEQ_LENGTH + D = @. df.effective_batch_size * df.max_steps # * AVG_SEQ_LENGTH return sum(x -> 6 * prod(x), zip(N, D)) end @@ -79,7 +79,7 @@ function estimate_cost(levels::NamedTuple) p = NamedTuple{l_names}(p) N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) D = p.data_size - return 6 * N * D * AVG_SEQ_LENGTH + return 6 * N * D # * AVG_SEQ_LENGTH end return total_compute end @@ -105,32 +105,26 @@ function latex_math_identifier(x) s = String(x) if occursin("_", s) head, tail... 
= split(s, "_") - return "\$" * head * "_{\\text{" * join(tail, "\\_") * "}}\$" + return "\$" * head * "_{\\textrm{" * join(tail, "\\_") * "}}\$" end return "\$" * s * "\$" end function latex_inline_identifier(x) - if x in (:d_model, :n_layer, :kv_size) + if x in (:n_layer, :d_model) return latex_math_identifier(x) + elseif x == :model_size + return "\$N\$" elseif x == :data_size return "\$D\$" elseif x == :lr_prefactor - return "\$lr\$" + return "\$\\eta_{0}\$" elseif x == :ff_ratio - return "ff\\_ratio" + return latex_math_identifier("r_ff") + elseif x == :kv_size + return latex_math_identifier("r_kv") elseif x == :effective_batch_size - return "batch" - else - return latex_escape_identifier(x) - end -end - -function latex_description_label(x) - if x in (:d_model, :n_layer, :kv_size) - return latex_math_identifier(x) - elseif x == :lr_prefactor - return "\$lr\\_prefactor\$" + return "\$\\mathcal{B}\$" else return latex_escape_identifier(x) end @@ -176,12 +170,12 @@ function latex_levels_block( continue end n = length(v) - label = latex_description_label(k) + label = latex_inline_identifier(k) vec = latex_vector(v) println(io, "\\item[$label ($n)]") - println(io, "\\[") + println(io, "\$") println(io, vec) - println(io, "\\]") + println(io, "\$") println(io) end diff --git a/opt/BayesianScaling/scripts/dec_3_sweep.jl b/opt/BayesianScaling/scripts/dec_3_sweep.jl index 24439a09..39403671 100755 --- a/opt/BayesianScaling/scripts/dec_3_sweep.jl +++ b/opt/BayesianScaling/scripts/dec_3_sweep.jl @@ -10,7 +10,11 @@ using Setfield: @set! GIT_ROOT = readchomp(`git rev-parse --show-toplevel`) -# include("wandb_import.jl") +if isinteractive() + includet("wandb_import.jl") +else + include("wandb_import.jl") +end function prior_d_model_lr!(priors; lr_0=1.64e-4, d_model=768, batch_size=1024) # LR scaling with Model size used in Attention is All You Need @@ -30,7 +34,7 @@ function dec_3_sweep_runs() :tokenizer => ByRow(==("smirk")), :created => ByRow(<=(DateTime(2025, 1))), :tags => ByRow(tags -> "dec-3-sweep" in tags), - [:step, :max_steps] => ByRow((s, ms) -> s / ms > 0.8); + # [:step, :max_steps] => ByRow((s, ms) -> s / ms > 0.8); skipmissing=true, ) dropmissing!(df, [:model_size, :d_model, :effective_batch_size, :max_steps, :lr, :ff_ratio, :aspect_ratio, :kv_size]) From c22f73d426262d71a4b6f0e28046b83c9f55e2d4 Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Mon, 27 Apr 2026 22:49:38 -0600 Subject: [PATCH 09/11] fix: correct param_table --- opt/BayesianScaling/plots/param_table.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/opt/BayesianScaling/plots/param_table.jl b/opt/BayesianScaling/plots/param_table.jl index 745a8d1e..d2290d48 100755 --- a/opt/BayesianScaling/plots/param_table.jl +++ b/opt/BayesianScaling/plots/param_table.jl @@ -82,9 +82,10 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3) ) return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) end - df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(LatexCell∘model_name) => :model_name) + df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(model_name) => :model_name) subset!(df, :model_name => ByRow(!isnothing)) sort!(df, :waic) + df.model_name = LatexCell.(df.model_name) round_sigfigs = ByRow(x -> round(x ; sigdigits)) cols = [ :model_name => "Model", @@ -195,7 +196,7 @@ function figure_smoothed_scaling(df) end function summary_rows(df; sigdigits=3) - sort!(df, 
:n_smooth) + df = sort(df, :n_smooth) function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) named = Dict( ("Penalized", :model_size, true, false) => "MIST, Penalized Scaling \\cref{eq:penalized_neural_scaling}", From 6056a947b6f59ac1ee8f829081f801e37a3969c6 Mon Sep 17 00:00:00 2001 From: Anoushka Bhutani Date: Sat, 2 May 2026 13:58:43 -0400 Subject: [PATCH 10/11] formatting fix --- opt/BayesianScaling/plots/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/opt/BayesianScaling/plots/Project.toml b/opt/BayesianScaling/plots/Project.toml index a0a7f84e..c2472e07 100644 --- a/opt/BayesianScaling/plots/Project.toml +++ b/opt/BayesianScaling/plots/Project.toml @@ -15,4 +15,3 @@ MISTStyle = {path = "../../MISTStyle"} [compat] PrettyTables = "^2.0.0" - From 43ec264561373adc9c85b791ac720ec9299cc7ed Mon Sep 17 00:00:00 2001 From: Alexius Wadell Date: Sun, 3 May 2026 21:42:35 -0600 Subject: [PATCH 11/11] remove unused function --- opt/BayesianScaling/plots/campaign.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/opt/BayesianScaling/plots/campaign.jl b/opt/BayesianScaling/plots/campaign.jl index c3275600..fdf0af66 100755 --- a/opt/BayesianScaling/plots/campaign.jl +++ b/opt/BayesianScaling/plots/campaign.jl @@ -56,10 +56,6 @@ function format_table(df; sigdigits=3) ) end -function estimate_brute_force(df::DataFrame) - -end - function (@main)(ARGS=[]) data = jldopen(ARGS[1], "r") model = data["model"]