From 3cdad7d677f4919e3e4d9414ff12991089aac4a7 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 9 Mar 2026 19:35:59 +0000 Subject: [PATCH 1/2] common: add reasoning budget sampler Implements a stateful sampler that enforces a token budget on reasoning output. It counts tokens until the budget is exhausted, then forces the configured end token, and finally becomes a passthrough for subsequent generation. https://claude.ai/code/session_016QuHSS4Xd8cDpBng1HiXY4 --- common/CMakeLists.txt | 2 + common/sampling-reasoning-budget.cpp | 101 +++++++++++++++++++++++++++ common/sampling-reasoning-budget.h | 20 ++++++ 3 files changed, 123 insertions(+) create mode 100644 common/sampling-reasoning-budget.cpp create mode 100644 common/sampling-reasoning-budget.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 377b26846b6..b173ce3ede2 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -77,6 +77,8 @@ add_library(${TARGET} STATIC regex-partial.h sampling.cpp sampling.h + sampling-reasoning-budget.cpp + sampling-reasoning-budget.h speculative.cpp speculative.h unicode.cpp diff --git a/common/sampling-reasoning-budget.cpp b/common/sampling-reasoning-budget.cpp new file mode 100644 index 00000000000..74be198259d --- /dev/null +++ b/common/sampling-reasoning-budget.cpp @@ -0,0 +1,101 @@ +#include "sampling-reasoning-budget.h" + +#include + +enum reasoning_budget_state { + REASONING_BUDGET_COUNTING, + REASONING_BUDGET_FORCING, + REASONING_BUDGET_PASSTHROUGH, +}; + +struct llama_sampler_reasoning_budget { + int32_t budget; + llama_token end_token; + + reasoning_budget_state state; + int32_t n_sampled; +}; + +static const char * llama_sampler_reasoning_budget_name(const struct llama_sampler * /*smpl*/) { + return "reasoning-budget"; +} + +static void llama_sampler_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx; + + switch (ctx->state) { + case REASONING_BUDGET_COUNTING: + if (token == ctx->end_token) { + ctx->state = REASONING_BUDGET_PASSTHROUGH; + } else { + ctx->n_sampled++; + if (ctx->n_sampled >= ctx->budget) { + ctx->state = REASONING_BUDGET_FORCING; + } + } + break; + case REASONING_BUDGET_FORCING: + // The only token that should have been sampled is the end token. + ctx->state = REASONING_BUDGET_PASSTHROUGH; + break; + case REASONING_BUDGET_PASSTHROUGH: + break; + } +} + +static void llama_sampler_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx; + + switch (ctx->state) { + case REASONING_BUDGET_COUNTING: + // Allow everything through — no modification. + break; + case REASONING_BUDGET_FORCING: + // Set all logits to -inf except the end token. + for (size_t i = 0; i < cur_p->size; i++) { + if (cur_p->data[i].id != ctx->end_token) { + cur_p->data[i].logit = -INFINITY; + } + } + break; + case REASONING_BUDGET_PASSTHROUGH: + // Allow everything through — no modification. + break; + } +} + +static void llama_sampler_reasoning_budget_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx; + ctx->state = REASONING_BUDGET_COUNTING; + ctx->n_sampled = 0; +} + +static struct llama_sampler * llama_sampler_reasoning_budget_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_reasoning_budget *) smpl->ctx; + return common_sampler_init_reasoning_budget(ctx->budget, ctx->end_token); +} + +static void llama_sampler_reasoning_budget_free(struct llama_sampler * smpl) { + delete (llama_sampler_reasoning_budget *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_reasoning_budget_i = { + /* .name = */ llama_sampler_reasoning_budget_name, + /* .accept = */ llama_sampler_reasoning_budget_accept, + /* .apply = */ llama_sampler_reasoning_budget_apply, + /* .reset = */ llama_sampler_reasoning_budget_reset, + /* .clone = */ llama_sampler_reasoning_budget_clone, + /* .free = */ llama_sampler_reasoning_budget_free, +}; + +struct llama_sampler * common_sampler_init_reasoning_budget(int32_t budget, llama_token end_token) { + return llama_sampler_init( + &llama_sampler_reasoning_budget_i, + new llama_sampler_reasoning_budget { + /* .budget = */ budget, + /* .end_token = */ end_token, + /* .state = */ REASONING_BUDGET_COUNTING, + /* .n_sampled = */ 0, + } + ); +} diff --git a/common/sampling-reasoning-budget.h b/common/sampling-reasoning-budget.h new file mode 100644 index 00000000000..cc9569dc6eb --- /dev/null +++ b/common/sampling-reasoning-budget.h @@ -0,0 +1,20 @@ +#pragma once + +#include "llama.h" + +// Reasoning budget sampler +// +// A stateful sampler that enforces a token budget on reasoning (thinking) output. +// +// Behavior: +// 1. COUNTING phase: Accept all tokens, counting each one, until either: +// - The end token is seen (transition to PASSTHROUGH), or +// - The budget is exhausted (transition to FORCING) +// 2. FORCING phase: Only allow the end token by setting all other logits to -inf. +// Once the end token is accepted, transition to PASSTHROUGH. +// 3. PASSTHROUGH phase: Do nothing — allow all tokens through unmodified. +// +// This is intended to be inserted early in a sampler chain (before temperature, +// top-k, etc.) so that it can mask logits before other samplers run. + +struct llama_sampler * common_sampler_init_reasoning_budget(int32_t budget, llama_token end_token); From 856a50f80858ff7be6d3d1aea62064430581c91d Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 9 Mar 2026 19:41:06 +0000 Subject: [PATCH 2/2] tests: add reasoning budget sampler tests Tests cover all three states (counting, forcing, passthrough), state transitions, reset, clone, zero budget edge case, and sampler name. Uses the peg-parser testing harness for consistent test output. https://claude.ai/code/session_016QuHSS4Xd8cDpBng1HiXY4 --- tests/CMakeLists.txt | 1 + tests/test-sampling-reasoning-budget.cpp | 328 +++++++++++++++++++++++ 2 files changed, 329 insertions(+) create mode 100644 tests/test-sampling-reasoning-budget.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9ba559c8dfb..1a5162815c2 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -200,6 +200,7 @@ llama_build_and_test( peg-parser/tests.h ) llama_build_and_test(test-regex-partial.cpp) +llama_build_and_test(test-sampling-reasoning-budget.cpp) if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2) diff --git a/tests/test-sampling-reasoning-budget.cpp b/tests/test-sampling-reasoning-budget.cpp new file mode 100644 index 00000000000..e3fc145e809 --- /dev/null +++ b/tests/test-sampling-reasoning-budget.cpp @@ -0,0 +1,328 @@ +#include +#include +#include +#include + +#include "peg-parser/testing.h" +#include "sampling-reasoning-budget.h" + +// Helper: build a token data array from a list of token ids, all with logit 0.0 +static std::vector make_candidates(const std::vector & tokens) { + std::vector data; + data.reserve(tokens.size()); + for (auto id : tokens) { + data.push_back({id, 0.0f, 0.0f}); + } + return data; +} + +static llama_token_data_array make_cur_p(std::vector & data) { + return { data.data(), data.size(), -1, false }; +} + +// Check whether a token's logit was set to -inf +static bool is_blocked(const llama_token_data_array & cur_p, llama_token id) { + for (size_t i = 0; i < cur_p.size; i++) { + if (cur_p.data[i].id == id) { + return cur_p.data[i].logit == -INFINITY; + } + } + return false; +} + +static bool is_allowed(const llama_token_data_array & cur_p, llama_token id) { + return !is_blocked(cur_p, id); +} + +static constexpr llama_token END_TOKEN = 99; +static constexpr llama_token TOKEN_A = 1; +static constexpr llama_token TOKEN_B = 2; +static constexpr llama_token TOKEN_C = 3; + +static const std::vector ALL_TOKENS = { TOKEN_A, TOKEN_B, TOKEN_C, END_TOKEN }; + +static void test_passthrough_before_budget(testing & t) { + t.test("all tokens allowed before budget", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(5, END_TOKEN); + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A allowed", is_allowed(cur_p, TOKEN_A)); + t.assert_true("token B allowed", is_allowed(cur_p, TOKEN_B)); + t.assert_true("token C allowed", is_allowed(cur_p, TOKEN_C)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); + + t.test("tokens allowed while counting", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(5, END_TOKEN); + + // Accept 3 tokens (under budget of 5) + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, TOKEN_B); + llama_sampler_accept(smpl, TOKEN_C); + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A allowed", is_allowed(cur_p, TOKEN_A)); + t.assert_true("token B allowed", is_allowed(cur_p, TOKEN_B)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); +} + +static void test_forcing_at_budget(testing & t) { + t.test("only end token allowed when budget exhausted", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(3, END_TOKEN); + + // Exhaust the budget + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, TOKEN_B); + llama_sampler_accept(smpl, TOKEN_C); + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A blocked", is_blocked(cur_p, TOKEN_A)); + t.assert_true("token B blocked", is_blocked(cur_p, TOKEN_B)); + t.assert_true("token C blocked", is_blocked(cur_p, TOKEN_C)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); + + t.test("budget of 1", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(1, END_TOKEN); + + llama_sampler_accept(smpl, TOKEN_A); + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A blocked", is_blocked(cur_p, TOKEN_A)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); +} + +static void test_passthrough_after_end(testing & t) { + t.test("passthrough after end token during counting", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(10, END_TOKEN); + + // Accept 2 tokens then the end token (under budget) + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, TOKEN_B); + llama_sampler_accept(smpl, END_TOKEN); + + // Now should be passthrough + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A allowed", is_allowed(cur_p, TOKEN_A)); + t.assert_true("token B allowed", is_allowed(cur_p, TOKEN_B)); + t.assert_true("token C allowed", is_allowed(cur_p, TOKEN_C)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); + + t.test("passthrough after forced end token", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(2, END_TOKEN); + + // Exhaust budget + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, TOKEN_B); + + // Now in FORCING state, accept the end token + llama_sampler_accept(smpl, END_TOKEN); + + // Should be passthrough now + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A allowed", is_allowed(cur_p, TOKEN_A)); + t.assert_true("token B allowed", is_allowed(cur_p, TOKEN_B)); + t.assert_true("token C allowed", is_allowed(cur_p, TOKEN_C)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); + + t.test("passthrough persists across many tokens", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(1, END_TOKEN); + + // Exhaust budget and force end + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, END_TOKEN); + + // Accept many more tokens in passthrough + for (int i = 0; i < 100; i++) { + llama_sampler_accept(smpl, TOKEN_B); + } + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A allowed", is_allowed(cur_p, TOKEN_A)); + t.assert_true("token B allowed", is_allowed(cur_p, TOKEN_B)); + + llama_sampler_free(smpl); + }); +} + +static void test_reset(testing & t) { + t.test("reset returns to counting state", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(2, END_TOKEN); + + // Exhaust budget + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, TOKEN_B); + + // Verify forcing + { + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + llama_sampler_apply(smpl, &cur_p); + t.assert_true("token A blocked before reset", is_blocked(cur_p, TOKEN_A)); + } + + // Reset + llama_sampler_reset(smpl); + + // Should be back to counting, all tokens allowed + { + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + llama_sampler_apply(smpl, &cur_p); + t.assert_true("token A allowed after reset", is_allowed(cur_p, TOKEN_A)); + t.assert_true("token B allowed after reset", is_allowed(cur_p, TOKEN_B)); + } + + llama_sampler_free(smpl); + }); + + t.test("reset from passthrough returns to counting", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(5, END_TOKEN); + + // Enter passthrough via early end token + llama_sampler_accept(smpl, END_TOKEN); + + // Reset + llama_sampler_reset(smpl); + + // Exhaust the budget again to prove we're counting + for (int i = 0; i < 5; i++) { + llama_sampler_accept(smpl, TOKEN_A); + } + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + llama_sampler_apply(smpl, &cur_p); + t.assert_true("forcing after reset and re-exhaust", is_blocked(cur_p, TOKEN_A)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); +} + +static void test_clone(testing & t) { + t.test("clone preserves config", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(3, END_TOKEN); + + // Partially use the budget + llama_sampler_accept(smpl, TOKEN_A); + llama_sampler_accept(smpl, TOKEN_B); + + // Clone (clone starts fresh per the clone implementation) + auto * cloned = llama_sampler_clone(smpl); + + // The clone should be in initial counting state with budget=3 + // Accept 3 tokens to exhaust its budget + llama_sampler_accept(cloned, TOKEN_A); + llama_sampler_accept(cloned, TOKEN_B); + llama_sampler_accept(cloned, TOKEN_C); + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + llama_sampler_apply(cloned, &cur_p); + + t.assert_true("cloned: token A blocked", is_blocked(cur_p, TOKEN_A)); + t.assert_true("cloned: end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + llama_sampler_free(cloned); + }); +} + +static void test_name(testing & t) { + t.test("sampler name", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(10, END_TOKEN); + + t.assert_equal("name", std::string("reasoning-budget"), std::string(llama_sampler_name(smpl))); + + llama_sampler_free(smpl); + }); +} + +static void test_zero_budget(testing & t) { + t.test("zero budget forces immediately", [](testing & t) { + auto * smpl = common_sampler_init_reasoning_budget(0, END_TOKEN); + + // Budget is 0 but no tokens have been accepted yet, so n_sampled (0) >= budget (0) + // The state machine starts in COUNTING and the transition happens in accept(). + // Since we haven't accepted any token, apply should still show COUNTING. + // But after accepting any token the budget check triggers. + // Let's accept one token to trigger the transition. + llama_sampler_accept(smpl, TOKEN_A); + + auto data = make_candidates(ALL_TOKENS); + auto cur_p = make_cur_p(data); + llama_sampler_apply(smpl, &cur_p); + + t.assert_true("token A blocked", is_blocked(cur_p, TOKEN_A)); + t.assert_true("end token allowed", is_allowed(cur_p, END_TOKEN)); + + llama_sampler_free(smpl); + }); +} + +int main(int argc, char * argv[]) { + testing t(std::cout); + if (argc >= 2) { + t.set_filter(argv[1]); + } + + const char * verbose = getenv("LLAMA_TEST_VERBOSE"); + if (verbose) { + t.verbose = std::string(verbose) == "1"; + } + + t.test("passthrough_before_budget", test_passthrough_before_budget); + t.test("forcing_at_budget", test_forcing_at_budget); + t.test("passthrough_after_end", test_passthrough_after_end); + t.test("reset", test_reset); + t.test("clone", test_clone); + t.test("name", test_name); + t.test("zero_budget", test_zero_budget); + + return t.summary(); +}