Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ add_library(${TARGET} STATIC
regex-partial.h
sampling.cpp
sampling.h
sampling-reasoning-budget.cpp
sampling-reasoning-budget.h
speculative.cpp
speculative.h
unicode.cpp
Expand Down
101 changes: 101 additions & 0 deletions common/sampling-reasoning-budget.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "sampling-reasoning-budget.h"

#include <cmath>

// Phases of the reasoning-budget state machine.
enum reasoning_budget_state {
    REASONING_BUDGET_COUNTING    = 0, // tokens pass unmodified and are counted against the budget
    REASONING_BUDGET_FORCING     = 1, // budget used up: every candidate except the end token is masked
    REASONING_BUDGET_PASSTHROUGH = 2, // sampler no longer touches the candidates
};

// Per-instance state of the reasoning budget sampler; the callbacks in this
// file reach it through the sampler's `ctx` pointer.
// NOTE: the field order is relied upon by the positional initializer in
// common_sampler_init_reasoning_budget().
struct llama_sampler_reasoning_budget {
// Configuration, set at construction.
int32_t budget; // forcing starts once n_sampled >= budget
llama_token end_token; // token id that is compared against accepted tokens / candidates

// Runtime state, restored to COUNTING / 0 by the reset callback.
reasoning_budget_state state; // current phase of the state machine
int32_t n_sampled; // number of non-end tokens accepted while COUNTING
};

// Sampler `name` callback: returns the fixed identifier of this sampler.
// The sampler argument is not used.
static const char * llama_sampler_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "reasoning-budget";
    return sampler_name;
}

// Sampler `accept` callback: advances the state machine with the token that was just accepted.
//
//   COUNTING    -> PASSTHROUGH  when the accepted token is the end token
//   COUNTING    -> FORCING      when the count of non-end tokens reaches the budget
//   FORCING     -> PASSTHROUGH  unconditionally (the accepted token is expected to be the end token)
//   PASSTHROUGH                 stays as it is
static void llama_sampler_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
    auto * rb = (llama_sampler_reasoning_budget *) smpl->ctx;

    if (rb->state == REASONING_BUDGET_COUNTING) {
        if (token == rb->end_token) {
            // Reasoning finished on its own before the budget ran out.
            rb->state = REASONING_BUDGET_PASSTHROUGH;
            return;
        }

        rb->n_sampled += 1;
        if (rb->n_sampled >= rb->budget) {
            rb->state = REASONING_BUDGET_FORCING;
        }
    } else if (rb->state == REASONING_BUDGET_FORCING) {
        // The token is not inspected here: while forcing, only the end token
        // is supposed to be sampleable.
        rb->state = REASONING_BUDGET_PASSTHROUGH;
    }
    // REASONING_BUDGET_PASSTHROUGH: nothing to track anymore.
}

static void llama_sampler_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx;

switch (ctx->state) {
case REASONING_BUDGET_COUNTING:
// Allow everything through — no modification.
break;
case REASONING_BUDGET_FORCING:
// Set all logits to -inf except the end token.
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != ctx->end_token) {
cur_p->data[i].logit = -INFINITY;
}
}
break;
case REASONING_BUDGET_PASSTHROUGH:
// Allow everything through — no modification.
break;
}
}

// Sampler `reset` callback: clears the token count and puts the state machine
// back into the COUNTING phase. The configured budget and end token are kept.
static void llama_sampler_reasoning_budget_reset(struct llama_sampler * smpl) {
    auto * rb = (llama_sampler_reasoning_budget *) smpl->ctx;

    rb->n_sampled = 0;
    rb->state     = REASONING_BUDGET_COUNTING;
}

static struct llama_sampler * llama_sampler_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_reasoning_budget *) smpl->ctx;
return common_sampler_init_reasoning_budget(ctx->budget, ctx->end_token);
}

static void llama_sampler_reasoning_budget_free(struct llama_sampler * smpl) {
delete (llama_sampler_reasoning_budget *) smpl->ctx;
}

// Callback table handed to llama_sampler_init() for the reasoning budget sampler.
// The initializers are positional, so the order of the entries must not change;
// the inline comments name the slot each entry is meant to fill.
static struct llama_sampler_i llama_sampler_reasoning_budget_i = {
/* .name = */ llama_sampler_reasoning_budget_name,
/* .accept = */ llama_sampler_reasoning_budget_accept,
/* .apply = */ llama_sampler_reasoning_budget_apply,
/* .reset = */ llama_sampler_reasoning_budget_reset,
/* .clone = */ llama_sampler_reasoning_budget_clone,
/* .free = */ llama_sampler_reasoning_budget_free,
};

// Creates a reasoning budget sampler.
//
//   budget    - number of non-end tokens that may be accepted before the end token is forced
//   end_token - token id that is forced once the budget is used up
//
// The sampler starts in the COUNTING phase with a count of zero. Its state is
// allocated with `new` and released by the sampler's free callback.
struct llama_sampler * common_sampler_init_reasoning_budget(int32_t budget, llama_token end_token) {
    auto * rb = new llama_sampler_reasoning_budget;

    rb->budget    = budget;
    rb->end_token = end_token;
    rb->state     = REASONING_BUDGET_COUNTING;
    rb->n_sampled = 0;

    return llama_sampler_init(&llama_sampler_reasoning_budget_i, rb);
}
20 changes: 20 additions & 0 deletions common/sampling-reasoning-budget.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include "llama.h"

// Reasoning budget sampler
//
// A stateful sampler that enforces a token budget on reasoning (thinking) output.
//
// Behavior:
// 1. COUNTING phase: Accept all tokens, counting each one, until either:
// - The end token is seen (transition to PASSTHROUGH), or
// - The budget is exhausted (transition to FORCING)
// 2. FORCING phase: Only allow the end token by setting all other logits to -inf.
// Once the end token is accepted, transition to PASSTHROUGH.
// 3. PASSTHROUGH phase: Do nothing — allow all tokens through unmodified.
//
// This is intended to be inserted early in a sampler chain (before temperature,
// top-k, etc.) so that it can mask logits before other samplers run.

// Creates a reasoning budget sampler that starts in the COUNTING phase.
//   budget    - number of non-end tokens accepted before the end token is forced
//   end_token - token id that marks the end of the reasoning output
// NOTE(review): the returned sampler owns heap-allocated state; presumably it
// is released through llama_sampler_free() like other samplers — confirm.
struct llama_sampler * common_sampler_init_reasoning_budget(int32_t budget, llama_token end_token);
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ llama_build_and_test(
peg-parser/tests.h
)
llama_build_and_test(test-regex-partial.cpp)
llama_build_and_test(test-sampling-reasoning-budget.cpp)

if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
Expand Down
Loading
Loading