diff --git a/opt/BayesianScaling/.gitignore b/opt/BayesianScaling/.gitignore index 3da8b99e..320b911a 100644 --- a/opt/BayesianScaling/.gitignore +++ b/opt/BayesianScaling/.gitignore @@ -1,3 +1,4 @@ # Ignore output directory (or symlink to it) out/ out +fig/ diff --git a/opt/BayesianScaling/plots/Project.toml b/opt/BayesianScaling/plots/Project.toml index 62b8ddc0..c2472e07 100644 --- a/opt/BayesianScaling/plots/Project.toml +++ b/opt/BayesianScaling/plots/Project.toml @@ -10,5 +10,8 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [sources] -BayesianScaling = {path = "../"} +BayesianScaling = {path = ".."} MISTStyle = {path = "../../MISTStyle"} + +[compat] +PrettyTables = "^2.0.0" diff --git a/opt/BayesianScaling/plots/paper.jl b/opt/BayesianScaling/plots/paper.jl old mode 100644 new mode 100755 index d6ac5f57..e9346ae7 --- a/opt/BayesianScaling/plots/paper.jl +++ b/opt/BayesianScaling/plots/paper.jl @@ -1,3 +1,4 @@ +#!/usr/bin/env -S uv run julia --project=@script --startup-file=no using Makie using MISTStyle using DataFrames @@ -6,6 +7,12 @@ using JLD2: jldopen using BayesianScaling using BayesianScaling: pf_day, get_nbins +const AVG_SEQ_LENGTH::Float64 = 65.2 +""" +Average number of tokens per molecule for the Smirk Tokenizer on REALSpace +Source: SI for doi:10.1021/acs.jcim.5c01856 +""" + function mark_model!(ax_scale, ax_compute; model_loss, d_model, @@ -15,11 +22,12 @@ function mark_model!(ax_scale, ax_compute; eff_batch_size, colorbar, h_loss, + avg_seq_length = 1, kwargs... ) N = BayesianScaling.non_embedding_size(d_model, ff_ratio * d_model, n_layers) D = steps * eff_batch_size - C = 6 * float(N) * float(D) + C = 6 * float(N) * float(D) * float(avg_seq_length) kwargs = (; marker=:star5, color=model_loss, @@ -102,11 +110,11 @@ function plot_scaling_params!(f, chains; legend_pos=:top) ) prior_works = [ - "Kaplan et al." => (; α=0.076, β=0.103, a = 0.58, b = 0.42), "Hoffmann et al. (Appr. 1)" => (; a = 0.50, b = 0.50), "Hoffmann et al. (Appr. 3)" => (; α=0.34, β=0.28, a = 0.46, b = 0.54), "Bi et al. (Early)" => (; a = 0.450, b = 0.550), "Bi et al. (Current)" => (; a = 0.524, b = 0.478), + "Kaplan et al." => (; α=0.076, β=0.103, a = 0.58, b = 0.42), ] gl = GridLayout(f[1, 1]) @@ -167,12 +175,14 @@ function plot_scaling_params!(f, chains; legend_pos=:top) Legend(legend_pos, elem, MISTStyle.label.(elem); tellwidth, tellheight, - margin=2pt .* (1, 1, 1, 1), orientation, nbanks, halign, valign, + padding=(4pt, 4pt, 1pt, 1pt), + colgap=4pt, ) + rowgap!(gl, 2pt) return f end @@ -320,10 +330,11 @@ function plot_lr_partial_dependence!(f, model, chains; p=0.95, npoints=100) end function figure_bayesian(data; - N=logrange(1e5, 3e9; length=50), - C=logrange(1e12, 2*pf_day; length=50), + N=logrange(1e5, 4e9; length=50), + C=logrange(1e-5 * pf_day, 100*pf_day; length=50), p=0.95, n_samples=500, + avg_seq_length::AbstractFloat = 1.0, ) model = data["model"] chains = data["chains"] @@ -333,7 +344,7 @@ function figure_bayesian(data; model_loss = StatsBase.response(model) model_size = df.model_size - model_flops = @. 6 * float(df.model_size) * float(df.data_size) + model_flops = @. 6 * float(df.model_size) * float(df.data_size) * float(avg_seq_length) f_height = model.formula isa BayesianScaling.ShapedScaling ? 4.5inch : 3inch f = Figure(; @@ -358,13 +369,13 @@ function figure_bayesian(data; ylabel="Non-Embedding Parameters", ) - loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C; p) + loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C ./ avg_seq_length; p) h_opt = predictionband!(ax_compute, C ./ pf_day, loss_c_opt...; color=MISTStyle.UM_COLORS.maize, linewidth=2pt, band_color=(RGBf(0.596, 0.612, 0.592), 0.4), ) - loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C))' + loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C ./ avg_seq_length))' levels = logrange(minimum(loss_cs), 1.0; length=10) cb = Colorbar(gl[1, end+1]; label="Validation Loss (nats)", @@ -388,7 +399,7 @@ function figure_bayesian(data; ) - N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C) + N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C ./ avg_seq_length) predictionband!(ax_scale, C ./ pf_day, N_opt...; color=h_opt.color, linewidth=h_opt.linewidth, @@ -414,6 +425,7 @@ function figure_bayesian(data; model_loss=0.01019080262631178, colorbar=cb, h_loss, + avg_seq_length, ) # dh61satt @@ -426,6 +438,7 @@ function figure_bayesian(data; model_loss=0.03703538700938225, colorbar=cb, h_loss, + avg_seq_length, ) @@ -451,10 +464,11 @@ function figure_bayesian(data; end function figure_bayesian_panel(data; - N=logrange(1e5, 3e9; length=50), - C=logrange(1e12, 2*pf_day; length=50), + N=logrange(1e5, 4e9; length=50), + C=logrange(1e-5 * pf_day, 100*pf_day; length=50), p=0.95, n_samples=500, + avg_seq_length::AbstractFloat = 1, ) f = Figure(; size=(89mm, 170mm), figure_padding=(2, 2, 2, 4)) @@ -466,7 +480,7 @@ function figure_bayesian_panel(data; model_loss = StatsBase.response(model) model_size = df.model_size - model_flops = @. 6 * float(df.model_size) * float(df.data_size) + model_flops = @. 6 * float(df.model_size) * float(df.data_size) * float(avg_seq_length) # Plot Covariance colsize!(f.layout, 1, 100mm) @@ -497,13 +511,13 @@ function figure_bayesian_panel(data; ylabel="Non-Embedding Parameters", ) - loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C; p) + loss_c_opt = BayesianScaling.compute_optimal_loss(scaling(chain_samples), C ./ avg_seq_length; p) h_opt = predictionband!(ax_compute, C ./ pf_day, loss_c_opt...; color=MISTStyle.UM_COLORS.maize, linewidth=2pt, band_color=(RGBf(0.596, 0.612, 0.592), 0.4), ) - loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C))' + loss_cs = first(BayesianScaling.hoffman_compute_scaling(scaling(chain_samples), N, C ./ avg_seq_length))' levels = logrange(minimum(loss_cs), 1.0; length=10) cb = Colorbar(gl[1, end+1]; label="Validation Loss (nats)", @@ -527,7 +541,7 @@ function figure_bayesian_panel(data; ) - N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C) + N_opt = BayesianScaling.compute_optimal_model_size(scaling(chain_samples), C ./ avg_seq_length) predictionband!(ax_scale, C ./ pf_day, N_opt...; color=h_opt.color, linewidth=h_opt.linewidth, @@ -553,6 +567,7 @@ function figure_bayesian_panel(data; model_loss=0.01019080262631178, colorbar=cb, h_loss, + avg_seq_length, ) # dh61satt @@ -565,6 +580,7 @@ function figure_bayesian_panel(data; model_loss=0.03703538700938225, colorbar=cb, h_loss, + avg_seq_length, ) # Fit distributions @@ -581,14 +597,15 @@ function figure_bayesian_panel(data; return f end -function plot_model(model; kwargs...) +function plot_model(model; avg_seq_length, kwargs...) @info "plotting $model" run_name = basename(model) data = jldopen(joinpath(model, "chains.jld2"), "r") try + suffix = avg_seq_length != 1 ? "-token-cost" : "" with_theme(MISTStyle.theme()) do - figure_bayesian(data; kwargs...) - end |> MISTStyle.savefig("bayesian-$run_name") + figure_bayesian(data; avg_seq_length, kwargs...) + end |> MISTStyle.savefig("bayesian-$(run_name)$(suffix)") with_theme(MISTStyle.theme()) do plot_prediction_intervals(data["model"], data["chains"]) @@ -604,7 +621,8 @@ end function plot_all(dir; kwargs...) for model in readdir(dir; join=true) isfile(joinpath(model, "chains.jld2")) || continue - plot_model(model) + plot_model(model; avg_seq_length=1.0, kwargs...) + plot_model(model; avg_seq_length=AVG_SEQ_LENGTH, kwargs...) end end @@ -618,8 +636,11 @@ function (@main)(ARGS=[]) # Panel figure data = jldopen(joinpath(chains_dir, "dec-3-sweep-smoothed-1.0--geometric-shape-gamma", "chains.jld2"), "r") with_theme(MISTStyle.theme()) do - figure_bayesian_panel(data) + figure_bayesian_panel(data; avg_seq_length=1.0) end |> MISTStyle.savefig("scaling_panel_baseline") + with_theme(MISTStyle.theme()) do + figure_bayesian_panel(data; avg_seq_length=AVG_SEQ_LENGTH) + end |> MISTStyle.savefig("scaling_panel_baseline-token-cost") return nothing end diff --git a/opt/BayesianScaling/plots/param_table.jl b/opt/BayesianScaling/plots/param_table.jl index 7590b7fd..d2290d48 100755 --- a/opt/BayesianScaling/plots/param_table.jl +++ b/opt/BayesianScaling/plots/param_table.jl @@ -1,7 +1,9 @@ #!/usr/bin/env -S julia +release --color=auto --startup-file=no --project=@script using PrettyTables using DataFrames -using BayesianScaling: BayesianScaling, find, selectparam +using BayesianScaling: BayesianScaling, selectparam +using StatsBase: response, cor, corspearman +using Format: format using JLD2: jldopen using MISTStyle using Makie @@ -69,30 +71,29 @@ function model_summary(files) return DataFrame(rows) end -function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) - - named = Dict( - ("Penalized", :model_size, true, false) => "Penalized, Baseline", - ("Penalized", :d_model, true, false) => "Penalized, LR Scales with Hidden Size", - ("Penalized", :model_size, false, true) => "Penalized, Additive", - ("Penalized", :model_size, true, true) => "Penalized, Harmonic Shape Penalty", - ("Chinchilla", nothing, nothing, nothing) => "No Penalties", - ) - return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) -end - function write_scaling_param_table(df::DataFrame; sigdigits=3) - df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(LatexCell∘model_name) => :model_name) + function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) + named = Dict( + ("Penalized", :model_size, true, false) => "Penalized, Baseline", + ("Penalized", :d_model, true, false) => "Penalized, \\(\\eta_{\\star} = f(d_{\\text{model}})\\)", + ("Penalized", :model_size, true, true) => "\\stacked{Penalized, Harmonic}{Shape Penalty}", + ("Penalized", :model_size, false, true) => "Penalized, Additive", + ("Chinchilla", nothing, nothing, nothing) => "No Penalties", + ) + return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) + end + df = transform(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(model_name) => :model_name) subset!(df, :model_name => ByRow(!isnothing)) sort!(df, :waic) + df.model_name = LatexCell.(df.model_name) round_sigfigs = ByRow(x -> round(x ; sigdigits)) cols = [ :model_name => "Model", :waic => latex_cell"\acs{WAIC}", :mape =>latex_cell"\acs{MAPE}", # :aic =>latex_cell"\acs{AIC}", - :pearson =>latex_cell"Pearson's \(\rho\)", - :spearman =>latex_cell"Spearman's \(\rho\)", + :pearson =>latex_cell"\makecell{Pearson's\\\(r\)}", + :spearman =>latex_cell"\makecell{Spearman's\\\(\rho\)}", :A =>latex_cell"$A$", :B =>latex_cell"$B$", :α =>latex_cell"$\alpha$", @@ -121,7 +122,7 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3) l = sn(l; sigdigits) u = sn(u; sigdigits) end - return LatexCell("\\ci{$μ}{$l}{$u}") + return LatexCell("\\cistacked{$μ}{$l}{$u}") end function fmt_metric(v, i, j) metrics = [:waic, :mape, :aic, :pearson, :spearman] @@ -136,7 +137,7 @@ function write_scaling_param_table(df::DataFrame; sigdigits=3) return format("{}", v) end end - return pretty_table(df[:, first.(cols)]; + return pretty_table(String, df[:, first.(cols)]; tf, alignment=[col in [:model_name] ? :l : :c for col in first.(cols)], formatters=(fmt_ci, fmt_metric), @@ -194,6 +195,44 @@ function figure_smoothed_scaling(df) return f end +function summary_rows(df; sigdigits=3) + df = sort(df, :n_smooth) + function model_name(model, lr_model_size, geometric_penalty, harmonic_shape_penalty) + named = Dict( + ("Penalized", :model_size, true, false) => "MIST, Penalized Scaling \\cref{eq:penalized_neural_scaling}", + ("Chinchilla", nothing, nothing, nothing) => "MIST, No Penalty Terms", + ) + return get(named, (model, lr_model_size, geometric_penalty, harmonic_shape_penalty), nothing) + end + transform!(df, [:model, :lr_model_size, :geometric_penalty, :harmonic_shape_penalty] => ByRow(model_name) => :model_name) + df = subset(df, + :model_name => ByRow(!isnothing), + :file => ByRow(endswith("gamma")), + :n_smooth => ByRow(==(1)), + ) + + function format_ci(v) + μ, l, u = v + μ = round(μ; sigdigits) + l = round(l; sigdigits) + u = round(u; sigdigits) + return "\\ci{$μ}{$l}{$u}" + end + open(joinpath("fig", "scaling_summary_rows.tex"), "w") do io + for model in eachrow(df) + row = [ + model.model_name, + format_ci(model.α), + format_ci(model.β), + format_ci(model.a), + format_ci(model.b), + format_ci(model.r) + ] + println(io, join(row, " & ") * " \\\\ % $(model.file)") + end + end + return nothing +end function (@main)(ARGS) df = model_summary(BayesianScaling.find("out", r"dec-3-.*chains.jld2")) @@ -202,4 +241,5 @@ function (@main)(ARGS) open(joinpath("fig", "param_table.tex"), "w") do io write(io, table) end + summary_rows(df) end diff --git a/opt/BayesianScaling/scripts/Project.toml b/opt/BayesianScaling/scripts/Project.toml index 8d9d7811..e65d27a9 100644 --- a/opt/BayesianScaling/scripts/Project.toml +++ b/opt/BayesianScaling/scripts/Project.toml @@ -13,4 +13,4 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [sources] -BayesianScaling = {path = "../"} +BayesianScaling = {path = ".."} diff --git a/opt/BayesianScaling/scripts/brute_force_cost.jl b/opt/BayesianScaling/scripts/brute_force_cost.jl new file mode 100755 index 00000000..c7aa14fe --- /dev/null +++ b/opt/BayesianScaling/scripts/brute_force_cost.jl @@ -0,0 +1,318 @@ +#!/usr/bin/env -S julia +release --color=auto --startup-file=no --project=@script +# Estimate the compute cost of doing a brute force search +# Usage: +# ./brute_force_cost.jl RUNS_JSONL OUTPUT_TEX +# Example: +# ./brute_force_cost.jl ../out/dec_3_sweep_runs.jsonl cost_summary.tex + +using DataFrames +using JSON: JSON +using BayesianScaling: non_embedding_size, pf_day + +const AVG_SEQ_LENGTH::Float64 = 65.2 +""" +Average number of tokens per molecule for the Smirk Tokenizer on REALSpace. +Source: SI for doi:10.1021/acs.jcim.5c01856 +""" + +function read_jsonl(file::AbstractString) + rows = [] + open(file, "r") do fid + for line in eachline(fid) + push!(rows, JSON.parse(line)) + end + end + return DataFrame(rows) +end + +function unique_lr_prefactor(df::DataFrame) + # Round to avoid 3.125 and 3.1249999... being treated differently. + lr_prefactor = @. df.lr / sqrt(df.effective_batch_size) + return unique(round.(lr_prefactor; sigdigits=3)) +end + +function compute_full_cost(df::DataFrame) + levels = (; + d_model = unique(df.d_model), + n_layer = unique(df.d_model ./ df.aspect_ratio), + ff_ratio = unique(df.ff_ratio), + effective_batch_size = unique(df.effective_batch_size), + kv_size = unique(df.kv_size), + data_size = unique(df.max_steps .* df.effective_batch_size), + lr_prefactor = unique_lr_prefactor(df), + ) + l_names = keys(levels) + total_cost = sum(Iterators.product(levels...)) do p + p = NamedTuple{l_names}(p) + N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) + D = p.data_size + return 6 * N * D # * AVG_SEQ_LENGTH + end + return levels, total_cost +end + +function compute_partial_cost(df::DataFrame; include_lr::Bool=true) + levels = (; + model_size = unique(df.model_size), + data_size = unique(df.effective_batch_size .* df.max_steps), + ) + if include_lr + levels = (; levels..., lr_prefactor = unique_lr_prefactor(df)) + end + l_names = keys(levels) + total_cost = sum(Iterators.product(levels...)) do p + p = NamedTuple{l_names}(p) + return 6 * p.model_size * p.data_size # * AVG_SEQ_LENGTH + end + return levels, total_cost +end + +function compute_bayes_cost(df::DataFrame) + N = df.model_size + D = @. df.effective_batch_size * df.max_steps # * AVG_SEQ_LENGTH + return sum(x -> 6 * prod(x), zip(N, D)) +end + +function estimate_cost(levels::NamedTuple) + l_names = keys(levels) + total_compute = sum(Iterators.product(levels...)) do p + p = NamedTuple{l_names}(p) + N = non_embedding_size(p.d_model, p.ff_ratio, p.n_layer) + D = p.data_size + return 6 * N * D # * AVG_SEQ_LENGTH + end + return total_compute +end + +function show_levels(levels::NamedTuple) + for (k, v) in pairs(levels) + n = length(v) + println("$k ($n): $v") + end + return nothing +end + +function round_sig(x::Real; sigdigits::Int=3) + return round(float(x); sigdigits=sigdigits) +end + +function latex_escape_identifier(x) + s = String(x) + return replace(s, "_" => "\\_") +end + +function latex_math_identifier(x) + s = String(x) + if occursin("_", s) + head, tail... = split(s, "_") + return "\$" * head * "_{\\textrm{" * join(tail, "\\_") * "}}\$" + end + return "\$" * s * "\$" +end + +function latex_inline_identifier(x) + if x in (:n_layer, :d_model) + return latex_math_identifier(x) + elseif x == :model_size + return "\$N\$" + elseif x == :data_size + return "\$D\$" + elseif x == :lr_prefactor + return "\$\\eta_{0}\$" + elseif x == :ff_ratio + return latex_math_identifier("r_ff") + elseif x == :kv_size + return latex_math_identifier("r_kv") + elseif x == :effective_batch_size + return "\$\\mathcal{B}\$" + else + return latex_escape_identifier(x) + end +end + +function latex_number(x::Real; sigdigits::Int=6) + xr = round(float(x); sigdigits=sigdigits) + + if iszero(xr) + return "0" + end + + if 1e-3 <= abs(xr) < 1e4 + if isinteger(xr) + return string(Int(round(xr))) + end + return string(xr) + end + + exp10 = floor(Int, log10(abs(xr))) + mant = xr / 10.0^exp10 + mant = round(mant; sigdigits=sigdigits) + return "\\sn{$mant}{$exp10}" +end + +function latex_vector(values; sigdigits::Int=6) + parts = [latex_number(v; sigdigits=sigdigits) for v in sort(collect(values))] + return "[" * join(parts, ",\\;\n") * "]" +end + +function latex_levels_block( + io::IO, + title::AbstractString, + levels::NamedTuple; + skip::Set{Symbol}=Set{Symbol}(), +) + println(io, "\\paragraph{$title}") + println(io, "\\begin{description}") + println(io) + + for (k, v) in pairs(levels) + if k in skip + continue + end + n = length(v) + label = latex_inline_identifier(k) + vec = latex_vector(v) + println(io, "\\item[$label ($n)]") + println(io, "\$") + println(io, vec) + println(io, "\$") + println(io) + end + + println(io, "\\end{description}") + println(io) + return nothing +end + +function generate_cost_latex( + full_levels::NamedTuple, + C_full::Real, + partial_levels::NamedTuple, + C_partial::Real, + partial_levels_lr::NamedTuple, + C_partial_lr::Real, + C_bayes::Real; + cost_3lr::Union{Nothing, Real}=nothing, + sigdigits_cost::Int=3, +) + rs(x) = round(float(x); sigdigits=sigdigits_cost) + + rel_full = C_full / C_bayes + rel_partial = C_partial / C_bayes + rel_partial_3lr = isnothing(cost_3lr) ? nothing : cost_3lr / C_bayes + rel_partial_lr = C_partial_lr / C_bayes + + full_level_summary = join([ + "$(latex_inline_identifier(k))($(length(v)))" for (k, v) in pairs(full_levels) + ], ", ") + + partial_level_summary = join([ + "$(latex_inline_identifier(k))($(length(v)))" for (k, v) in pairs(partial_levels) + ], ", ") + + partial_lr_level_summary = join([ + "$(latex_inline_identifier(k))($(length(v)))" for (k, v) in pairs(partial_levels_lr) + ], ", ") + + io = IOBuffer() + + println(io, "\\begin{table}[h]") + println(io, "\\centering") + println(io, "\\begin{tabular}{lccc}") + println(io, "\\hline") + println(io, "Case & Levels (with \$n\$) & Total Cost (PF-days) & Cost Rel.\\ Bayes \\\\") + println(io, "\\hline") + println(io, "Full Factorial & $full_level_summary & $(rs(C_full / pf_day)) & $(rs(rel_full)) \\\\") + println(io, "\$N + D\$ Only & $partial_level_summary & $(rs(C_partial / pf_day)) & $(rs(rel_partial)) \\\\") + if !isnothing(cost_3lr) + println( + io, + "\$N + D\$ Only (3 lr) & model\\_size($(length(partial_levels.model_size))), \$D\$($(length(partial_levels.data_size))), \$lr\$(3) & $(rs(cost_3lr / pf_day)) & $(rs(rel_partial_3lr)) \\\\", + ) + end + println(io, "\$N + D + lr\$ & $partial_lr_level_summary & $(rs(C_partial_lr / pf_day)) & $(rs(rel_partial_lr)) \\\\") + println(io, "Bayesian Optimization & --- & $(rs(C_bayes / pf_day)) & 1.0 \\\\") + println(io, "\\hline") + println(io, "\\end{tabular}") + println(io, "\\caption{Training cost comparison across experimental design strategies.}") + println(io, "\\end{table}") + println(io) + println(io) + + latex_levels_block(io, "Full Factorial Levels", full_levels) + latex_levels_block(io, "\$N + D\$ Only Levels", partial_levels) + latex_levels_block( + io, + "\$N + D + lr\$ Levels", + partial_levels_lr; + skip=Set([:model_size, :data_size]), + ) + + return String(take!(io)) +end + +function compute_costs(df::DataFrame) + C_bayes = compute_bayes_cost(df) + full_levels, C_full = compute_full_cost(df) + partial_levels_lr, C_partial_lr = compute_partial_cost(df; include_lr=true) + partial_levels, C_partial = compute_partial_cost(df; include_lr=false) + + rs(x) = round(x; sigdigits=3) + + println("Full Factorial:") + println("Cost: $(rs(C_full / pf_day)) pf-day") + show_levels(full_levels) + + println("\n\nN + D Only:") + println("Cost: $(rs(C_partial / pf_day)) pf-day") + println("Cost (3 lr): $(rs(3 * C_partial / pf_day)) pf-day") + show_levels(partial_levels) + + println("\n\nN + D + lr_prefactor Only:") + println("Cost: $(rs(C_partial_lr / pf_day)) pf-day") + show_levels(partial_levels_lr) + + gpu_hours = sum(df.runtime .* df.world_size / (60 * 60)) + println("\n\nBayes Cost: $(rs(C_bayes / pf_day)) pf-day") + println("Bayes Cost: $(rs(gpu_hours)) GPU-hours") + println("Rel. Full: $(rs(C_full / C_bayes))") + println("Rel. Partial (N, D): $(rs(C_partial / C_bayes))") + println("Rel. Partial (N, D, 3 lr): $(rs(3 * C_partial / C_bayes))") + println("Rel. Partial (N, D, lr): $(rs(C_partial_lr / C_bayes))") + + latex = generate_cost_latex( + full_levels, + C_full, + partial_levels, + C_partial, + partial_levels_lr, + C_partial_lr, + C_bayes; + cost_3lr=3 * C_partial, + ) + + println("\n\nLaTeX:\n") + println(latex) + + return latex +end + +function (@main)(args=[]) + @assert length(args) >= 2 "Usage: brute_force_cost.jl RUNS_JSONL OUTPUT_TEX" + + runs_jsonl = args[1] + output_tex = args[2] + + @assert isfile(runs_jsonl) "Input JSONL file does not exist: $runs_jsonl" + + df = read_jsonl(runs_jsonl) + latex = compute_costs(df) + + open(output_tex, "w") do io + write(io, latex) + end + + println("\nWrote LaTeX to: $output_tex") + + return 0 +end diff --git a/opt/BayesianScaling/scripts/dec_3_sweep.jl b/opt/BayesianScaling/scripts/dec_3_sweep.jl index 88c61958..39403671 100755 --- a/opt/BayesianScaling/scripts/dec_3_sweep.jl +++ b/opt/BayesianScaling/scripts/dec_3_sweep.jl @@ -10,7 +10,11 @@ using Setfield: @set! GIT_ROOT = readchomp(`git rev-parse --show-toplevel`) -include("wandb_import.jl") +if isinteractive() + includet("wandb_import.jl") +else + include("wandb_import.jl") +end function prior_d_model_lr!(priors; lr_0=1.64e-4, d_model=768, batch_size=1024) # LR scaling with Model size used in Attention is All You Need @@ -22,7 +26,23 @@ function prior_d_model_lr!(priors; lr_0=1.64e-4, d_model=768, batch_size=1024) return priors end -function build_models() +function dec_3_sweep_runs() + df = pretraining_runs( + joinpath(GIT_ROOT, ".cache", "wandb-export"); + ) + subset!(df, + :tokenizer => ByRow(==("smirk")), + :created => ByRow(<=(DateTime(2025, 1))), + :tags => ByRow(tags -> "dec-3-sweep" in tags), + # [:step, :max_steps] => ByRow((s, ms) -> s / ms > 0.8); + skipmissing=true, + ) + dropmissing!(df, [:model_size, :d_model, :effective_batch_size, :max_steps, :lr, :ff_ratio, :aspect_ratio, :kv_size]) + subset!(df, :val_loss_best => ByRow(x -> 1e-6 < x < 1.0); skipmissing=true) + return df +end + +function build_models(df_sweep) formulas = Dict( "baseline" => ShapedScaling(), "hoffman" => HoffmanScaling(), @@ -34,17 +54,7 @@ function build_models() models = Dict() for eval_batch in [1, 1e3, 1e4, 1e5, 1e6] - df = pretraining_runs( - joinpath(GIT_ROOT, ".cache", "wandb-export"); - smoothed_eval_batch=eval_batch - ) - subset!(df, - :tokenizer => ByRow(==("smirk")), - :created => ByRow(<=(DateTime(2025, 1))), - :tags => ByRow(tags -> "dec-3-sweep" in tags), - [:step, :max_steps] => ByRow((s, ms) -> s / ms > 0.8); - skipmissing=true, - ) + df = deepcopy(df_sweep) loss = eval_batch == 1 ? :val_loss_best : :val_loss_smooth df = select(df, :model_size, @@ -164,7 +174,23 @@ function fit_summary() return df end +function export_runs() + df_sweep = dec_3_sweep_runs() + outdir = joinpath(pkgdir(BayesianScaling), "out") + mkpath(outdir) + open(joinpath(outdir, "dec_3_sweep_runs.jsonl"), "w") do fid + for run in eachrow(df_sweep) + println(fid, JSON.json(Dict(pairs(run)))) + end + end + return df_sweep +end + function (@main)(args) - models = build_models() + # Record runs + df_sweep = export_runs() + + # Build models + models = build_models(df_sweep) run_models(models) end diff --git a/opt/BayesianScaling/scripts/wandb_import.jl b/opt/BayesianScaling/scripts/wandb_import.jl index da2d9006..28fbab38 100644 --- a/opt/BayesianScaling/scripts/wandb_import.jl +++ b/opt/BayesianScaling/scripts/wandb_import.jl @@ -1,12 +1,38 @@ using DataFrames -using Dates: Dates, DateTime +using Dates: Dates, DateTime, @dateformat_str using JSON: JSON using BayesianScaling: find, average_last +function dedup_trace(metric_traces, name::AbstractString) + steps = metric_traces["step"] + y = metric_traces[name] + trace = Dict{eltype(steps), eltype(y)}() + fn = x -> (!isnothing(x[2]) && !ismissing(x[2])) + for (xt, yt) in Iterators.filter(fn, zip(steps, y)) + trace[xt] = yt isa Real ? yt : parse(Float64, yt) + end + x_out::Vector{Int} = (sort∘collect∘keys)(trace) + y_out::Vector{Float64} = [trace[x] for x in x_out] + return x_out, y_out +end + +function get_metric(metrics::Vector; kwargs...) + df = DataFrame(metrics) + kwargs = map(kv -> kv[1] => ByRow(==(kv[2])), collect(kwargs)) + subset!(df, kwargs...) + if nrow(df) == 1 + return first(df.value) + elseif nrow(df) == 0 + return missing + else + error("Multiple matches: $df") + end +end + function pretraining_runs(dir::AbstractString; smoothed_eval_batch=1e6) row = [] for file in find(joinpath(dir, "pretraining"), r".*\.json") - run = JSON.parsefile(file; null=missing) + run = JSON.parsefile(file; null=missing, allownan=true) pretrained_models = [ "electrolyte_fm.models.RoBERTa", @@ -25,43 +51,58 @@ function pretraining_runs(dir::AbstractString; smoothed_eval_batch=1e6) beta2 = !ismissing(beta) ? beta[2] : missing # Average over the last 1e6 samples - val_epoch_loss = map(run["metric_traces"]["val_loss"]) do x - x isa Real ? x : parse(Float64, x) - end + _, val_epoch_loss = dedup_trace(run["metric_traces"], "val_loss") effective_val_epoch = run["trainer"]["effective_val_epoch"] ismissing(effective_val_epoch) && continue val_loss_smooth = average_last(val_epoch_loss; n=smoothed_eval_batch / effective_val_epoch) - push!(row, (; - id=run["id"], - state=run["state"], - user=run["user"], - cluster=run["cluster"], - commit=run["commit"], - tags=string.(run["tags"]), - created=DateTime(run["created"][1:23], Dates.ISODateTimeFormat), - model_size=run["model"]["model_size"], - model_class=run["model"]["class_path"], - d_model, - ff_ratio, - kv_size, - aspect_ratio, - optimizer=run["optimizer"]["class_path"], - tokenizer=run["data"]["tokenizer"], - max_steps=run["trainer"]["num_training_steps"], - step=run["trainer"]["step"], - tokens=run["trainer"]["tokens"], - gas=run["trainer"]["gas"], - masked_tokens=run["trainer"]["masked_tokens"], - effective_batch_size=run["trainer"]["effective_batch_size"], - effective_val_epoch, - lr=run["optimizer"]["lr"], - beta1, - beta2, - val_loss_best=run["metrics"]["val_loss_best"], - val_loss_last=run["metrics"]["val_loss_last"], - val_loss_smooth - )) + # Get last/best metrics + val_loss_best=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="best") + val_loss_min=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="min") + val_loss_last=get_metric(run["metrics"]; metric="loss_epoch", split="val", type="last") + val_loss_best = ismissing(val_loss_best) ? val_loss_min : val_loss_best + + # Parse Date + if !ismissing(run["created"]) + created = DateTime(run["created"][1:23], Dates.ISODateTimeFormat) + else + created = missing + end + + try + push!(row, (; + id=run["id"], + state=run["state"], + user=run["user"], + cluster=run["cluster"], + commit=run["commit"], + tags=string.(run["tags"]), + created, + runtime=float(run["runtime"]), + world_size=Int(run["job_config"]["world_size"]), + model_size=run["model"]["model_size"], + model_class=run["model"]["class_path"], + d_model, + ff_ratio, + kv_size, + aspect_ratio, + optimizer=run["optimizer"]["class_path"], + tokenizer=run["data"]["tokenizer"], + max_steps=run["trainer"]["num_training_steps"], + step=get_metric(run["metrics"]; metric="global_step", type="last"), + gas=run["trainer"]["gas"], + effective_batch_size=run["trainer"]["effective_batch_size"], + effective_val_epoch, + lr=run["optimizer"]["lr"], + beta1, + beta2, + val_loss_best, + val_loss_last, + val_loss_smooth + )) + catch e + @error "Failed to import $file" e + end end return DataFrame(row) end diff --git a/opt/TokenizerStats/python/helper/atomic_oov.py b/opt/TokenizerStats/python/helper/atomic_oov.py index e8969145..f4be16a4 100644 --- a/opt/TokenizerStats/python/helper/atomic_oov.py +++ b/opt/TokenizerStats/python/helper/atomic_oov.py @@ -332,10 +332,19 @@ def tabulate_tokenizer( } -def process_tokenizer(dataset_name: str, tokenizer: dict[str, str], args) -> dict: +def process_tokenizer( + dataset_name: str, + tokenizer: dict[str, str], + tmqm_dataset: Optional[Path] = None, +) -> dict: + """Process a single dataset for a single tokenizer. + + NOTE: This function is executed in child processes when --workers is set. + Keep its signature/payload picklable. + """ tok = load_tokenizer(tokenizer["name_or_path"]) - if args.tmqm_dataset is not None: - DATASETS["tmQM"] = lambda: tmqm_dataset(args.tmqm_dataset) + if tmqm_dataset is not None: + DATASETS["tmQM"] = lambda: tmqm_dataset(tmqm_dataset) dataset = DATASETS[dataset_name] LOG.info("processing %s for %s", dataset_name, tokenizer["name"]) @@ -382,7 +391,12 @@ def process_tokenizer(dataset_name: str, tokenizer: dict[str, str], args) -> dic datasets = args.dataset or DATASETS.keys() with concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) as executor: futures = { - executor.submit(process_tokenizer, dataset, tokenizer, args): dataset + executor.submit( + process_tokenizer, + dataset, + tokenizer, + args.tmqm_dataset, + ): dataset for dataset in datasets } diff --git a/opt/run_logs/sync_wandb.py b/opt/run_logs/sync_wandb.py index a7a0b79b..1d459a83 100755 --- a/opt/run_logs/sync_wandb.py +++ b/opt/run_logs/sync_wandb.py @@ -174,6 +174,12 @@ def run_summary(run: Run): metadata = run.metadata if isinstance(run.metadata, dict) else {} git_info = metadata.get("git", {}) summary = run.summary if isinstance(run.summary, dict) else {} + summary_metrics = ( + dict(run.summary_metrics) + if hasattr(run, "summary_metrics") and run.summary_metrics + else {} + ) + sys_metrics = system_metrics(run) stats = { "id": run.id, @@ -187,7 +193,9 @@ def run_summary(run: Run): "created": metadata.get("startedAt"), "gpu": metadata.get("gpu"), "commit": git_info.get("commit") if isinstance(git_info, dict) else None, - "runtime": summary.get("_runtime"), + "runtime": summary.get("_runtime") + or summary_metrics.get("_runtime") + or get_entry(summary_metrics, "_wandb", "runtime"), "optimizer": { "class_path": get_entry( config, "cli", "model", "optimizer", "class_path" @@ -242,13 +250,11 @@ def run_summary(run: Run): run, "stats/val_batch_throughput", "mean" ), "train_batch_time": summary_metric(run, "stats/train_batch_time_epoch"), - **system_metrics(run), + **sys_metrics, }, - "summary_metrics": dict(run.summary_metrics) - if hasattr(run, "summary_metrics") and run.summary_metrics - else {}, + "summary_metrics": summary_metrics, } - + print(stats["id"], stats["runtime"]) # Record world_size world_size = (stats["job_config"]["nodes"] or nan) * ( stats["job_config"]["gpus_per_node"] or nan diff --git a/uv.lock b/uv.lock index 9895cb03..bfea30a3 100644 --- a/uv.lock +++ b/uv.lock @@ -3095,7 +3095,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.22.2" +version = "0.25.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3109,17 +3109,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/a8/680bd77e11a278e6c14a2cb4646e8ab9525b2baaa81c3d12dc0f616aa4aa/wandb-0.22.2.tar.gz", hash = "sha256:510f5a1ac30d16921c36c3b932da852f046641d4aee98a86a7f5ec03a6e95bda", size = 41401439, upload-time = "2025-10-07T19:54:21.88Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/b3/8c637fb594cfd574ce9c9f7d0ac2f2d12742eb38ec59dcbb713beae95343/wandb-0.22.2-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2e29c9fa4462b5411b2cd2175ae33eff4309c91de7c426bca6bc8e7abc7e5dec", size = 18677549, upload-time = "2025-10-07T19:54:00.839Z" }, - { url = "https://files.pythonhosted.org/packages/d3/f3/e309a726eaebddad6b8d9a73a50891e5796962ec8a091bb6a61d31692d1e/wandb-0.22.2-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:c42d594cd7a9da4fd39ecdb0abbc081b61f304123277b2b6c4ba84283956fd21", size = 19715188, upload-time = "2025-10-07T19:54:03.805Z" }, - { url = "https://files.pythonhosted.org/packages/f9/73/fad59910215876008f4781b57d828d1b19b3677c9b46af615e7229746435/wandb-0.22.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5188d84e66d3fd584f3b3ae4d2a70e78f29403c0528e6aecaa4188a1fcf54d8", size = 18463148, upload-time = "2025-10-07T19:54:05.676Z" }, - { url = "https://files.pythonhosted.org/packages/87/11/572c1913b5b92e4c519f735adfae572b46f2d79d99ede63eec0d6a272d6e/wandb-0.22.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88ccd484af9f21cfc127976793c3cf66cfe1acd75bd8cd650086a64e88bac4bf", size = 19908645, upload-time = "2025-10-07T19:54:07.693Z" }, - { url = "https://files.pythonhosted.org/packages/6d/0d/133aa82f5a505ba638b4fda5014cefddfe7f1f6238ef4afc0871ec61c41f/wandb-0.22.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:abf0ed175e791af64110e0a0b99ce02bbbbd1017722bc32d3bc328efb86450cd", size = 18501348, upload-time = "2025-10-07T19:54:10.234Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d5/776203be2601872f01dacc6a5b4274106ec0db7cd3bf2cdb3b741f8fc932/wandb-0.22.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:44e77c56403b90bf3473a7ca3bfc4d42c636b7c0e31a5fb9cd0382f08302f74b", size = 20001756, upload-time = "2025-10-07T19:54:12.452Z" }, - { url = "https://files.pythonhosted.org/packages/30/43/ae3fa46e20b1d9a6508dd9abe716d57205c038ed4661c5c98ace48a60eac/wandb-0.22.2-py3-none-win32.whl", hash = "sha256:44d12bd379dbe15be5ceed6bdf23803d42f648ba0dd111297b4c47a3c7be6dbd", size = 19075950, upload-time = "2025-10-07T19:54:14.892Z" }, - { url = "https://files.pythonhosted.org/packages/09/59/c174321e868205f7a659d1e5ec51f546e62267296d6f4179bb9119294964/wandb-0.22.2-py3-none-win_amd64.whl", hash = "sha256:c95eb221bf316c0872f7ac55071856b9f25f95a2de983ada48acf653ce259386", size = 19075953, upload-time = "2025-10-07T19:54:16.837Z" }, - { url = "https://files.pythonhosted.org/packages/7a/a2/c7c24fda78513cab5686949d8cb36459dbbccbbb4b2b6fc67237ece31a00/wandb-0.22.2-py3-none-win_arm64.whl", hash = "sha256:20d2ab9aa10445aab3d60914a980f002a4f66566e28b0cd156b1e462f0080a0d", size = 17383217, upload-time = "2025-10-07T19:54:19.384Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/60/bb/eb579bf9abac70934a014a9d4e45346aab307994f3021d201bebe5fa25ec/wandb-0.25.1.tar.gz", hash = "sha256:b2a95cd777ecbe7499599a43158834983448a0048329bc7210ef46ca18d21994", size = 43983308, upload-time = "2026-03-10T23:51:44.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/d8/873553b6818499d1b1de314067d528b892897baf0dc81fedc0e845abc2dd/wandb-0.25.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:9bb0679a3e2dcd96db9d9b6d3e17d046241d8d122974b24facb85cc93309a8c9", size = 23615900, upload-time = "2026-03-10T23:51:06.278Z" }, + { url = "https://files.pythonhosted.org/packages/71/ea/b131f319aaa5d0bf7572b6bfcff3dd89e1cf92b17eee443bbab71d12d74c/wandb-0.25.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:0fb13ed18914027523e7b4fc20380c520e0d10da0ee452f924a13f84509fbe12", size = 25576144, upload-time = "2026-03-10T23:51:11.527Z" }, + { url = "https://files.pythonhosted.org/packages/70/5f/81508581f0bb77b0495665c1c78e77606a48e66e855ca71ba7c8ae29efa4/wandb-0.25.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:cc4521eb5223429ddab5e8eee9b42fdf4caabdf0bc4e0e809042720e5fbef0ed", size = 23070425, upload-time = "2026-03-10T23:51:15.71Z" }, + { url = "https://files.pythonhosted.org/packages/f2/c7/445155ef010e2e35d190797d7c36ff441e062a5b566a6da4778e22233395/wandb-0.25.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:e73b4c55b947edae349232d5845204d30fac88e18eb4ad1d4b96bf7cf898405a", size = 25628142, upload-time = "2026-03-10T23:51:19.326Z" }, + { url = "https://files.pythonhosted.org/packages/d5/63/f5c55ee00cf481ef1ccd3c385a0585ad52e7840d08419d4f82ddbeeea959/wandb-0.25.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:22b84065aa398e1624d2e5ad79e08bc4d2af41a6db61697b03b3aaba332977c6", size = 23123172, upload-time = "2026-03-10T23:51:23.418Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d9/19eb7974c0e9253bcbaee655222c0f0e1a52e63e9479ee711b4208f8ac31/wandb-0.25.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:005c4c6b5126ef8f4b4110e5372d950918b00637d6dc4b615ad17445f9739478", size = 25714479, upload-time = "2026-03-10T23:51:27.421Z" }, + { url = "https://files.pythonhosted.org/packages/11/19/466c1d03323a4a0ed7d4036a59b18d6b6f67cb5032e444205927e226b18d/wandb-0.25.1-py3-none-win32.whl", hash = "sha256:8f2d04f16b88d65bfba9d79fb945f6c64e2686215469a841936e0972be8ec6a5", size = 24967338, upload-time = "2026-03-10T23:51:31.833Z" }, + { url = "https://files.pythonhosted.org/packages/89/22/680d34c1587f3a979c701b66d71aa7c42b4ef2fdf0774f67034e618e834e/wandb-0.25.1-py3-none-win_amd64.whl", hash = "sha256:62db5166de14456156d7a85953a58733a631228e6d4248a753605f75f75fb845", size = 24967343, upload-time = "2026-03-10T23:51:36.026Z" }, + { url = "https://files.pythonhosted.org/packages/c4/e8/76836b75d401ff5912aaf513176e64557ceaec4c4946bfd38a698ff84d48/wandb-0.25.1-py3-none-win_arm64.whl", hash = "sha256:cc7c34b70cf4b7be4d395541e82e325fd9d2be978d62c9ec01f1a7141523b6bb", size = 22080774, upload-time = "2026-03-10T23:51:40.196Z" }, ] [[package]]