diff --git a/common/sampling.cpp b/common/sampling.cpp index 9c04d35fd00..452cefee3b9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_sampler_apply(chain, &cur_p); + /*for (int k = 0; k < (int)cur_p.size; ++k) { + LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n", + k, 0, cur_p.data[k].id, cur_p.data[k].p); + }*/ + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); const llama_token id = cur_p.data[cur_p.selected].id; @@ -577,3 +582,7 @@ std::vector common_sampler_types_from_chars(const std::stri return samplers; } + +void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) { + llama_sampler_apply(gsmpl->chain, cur_p); +} \ No newline at end of file diff --git a/common/sampling.h b/common/sampling.h index 2064421db4e..b424d7d6d70 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -105,3 +105,5 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); + +void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p); \ No newline at end of file diff --git a/common/speculative.cpp b/common/speculative.cpp index 262b2c23e72..64ef1e4f51a 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -5,6 +5,8 @@ #include "log.h" #include "common.h" #include "sampling.h" +#include "../src/llama-graph.h" +#include "../src/llama-context.h" #include #include @@ -359,3 +361,55 @@ llama_tokens common_speculative_gen_draft( } return result; } + + +llama_token mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx) { + + if (!smpl) { + return -1; + } + llama_batch mtp_batch = 
llama_batch_init(1, 0, 1); + const llama_seq_id draft_seq_id = 0; + common_batch_add(mtp_batch, id_last, n_past, {0}, true); + + mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_ONLY; + + // Perform the MTP draft generation decode. This writes the MTP layer's + // KV state for the draft token into the cache. + const int64_t t_start_us = ggml_time_us(); + if (llama_decode(ctx, mtp_batch) != 0) { + llama_batch_free(mtp_batch); + return -1; + } + const int64_t t_end_us = ggml_time_us(); + LOG_INF("[PERF-MTP] mtp_speculative_gen_draft internal decode: %.2f ms\n", (t_end_us - t_start_us) / 1000.0); + llama_batch_free(mtp_batch); + + // CRITICAL: Purge the metadata for the draft token we just wrote. + // This makes the physical cell available again for the main model's validation pass, + // preventing a cache state corruption where two cells map to the same logical position. + llama_kv_cache_seq_rm(ctx, draft_seq_id, n_past, n_past + 1); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_vocab = llama_n_vocab(vocab); + + llama_token_data_array * cur_p = common_sampler_get_candidates(smpl); + float * logits = llama_get_logits_ith(ctx, 0); + cur_p->size = n_vocab; + + for (int i = 0; i < n_vocab; ++i) { + cur_p->data[i].id = i; + cur_p->data[i].logit = logits[i]; + } + + cur_p->sorted = false; + common_sampler_apply_chain(smpl, cur_p); + + return cur_p->data[0].id; +} \ No newline at end of file diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1eb..5cf80701d3b 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -12,6 +12,12 @@ struct common_speculative_params { float p_min = 0.75f; // min probability required to accept a token in the draft }; +struct mtp_kv_update_data { + llama_token id; + int32_t n_past; + int32_t tok_idx; +}; + struct common_speculative * common_speculative_init( struct llama_context * ctx_tgt, struct llama_context * ctx_dft @@ -27,9 +33,18 @@ void 
common_speculative_add_replacement_tgt_dft( struct common_speculative * spec, const char *source, const char *dest); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_token mtp_speculative_gen_draft( + struct common_sampler* smpl, + struct llama_context* ctx, + llama_token id_last, + int32_t n_past, + int32_t last_tok_idx); + // sample up to n_draft tokens and add them to the batch using the draft model llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt, - llama_token id_last); + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); \ No newline at end of file diff --git a/include/llama.h b/include/llama.h index 545e957e5f5..ce53dd5f20f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -221,6 +221,16 @@ extern "C" { // - if not: only the last token is output // ) // + typedef enum { + MTP_OP_NONE, + MTP_OP_DRAFT_ONLY, + MTP_OP_UNIFIED, + } llama_mtp_op_type; + + typedef struct llama_mtp_params { + llama_mtp_op_type op_type; + } llama_mtp_params; + typedef struct llama_batch { int32_t n_tokens; @@ -230,6 +240,7 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + llama_mtp_params mtp_params; } llama_batch; enum llama_model_kv_override_type { @@ -495,6 +506,8 @@ extern "C" { LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); + LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model); + // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure // - The output string is always null-terminated and cleared on failure @@ -548,6 +561,8 @@ extern "C" { const char * fname_out, const llama_model_quantize_params * params); + + // // Adapters // @@ -1450,6 +1465,18 @@ extern 
"C" { ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval); + // + // MTP + // + + LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state); + + /** + * @brief Removes KV cache metadata for a specified sequence and token range. + * This makes the physical cells logically available again without deleting the tensor data. + */ + LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + #ifdef __cplusplus } #endif diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 18dcc6ddfe5..4b6fa3e6059 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2240,12 +2240,13 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, 
{LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 8698d89acec..c01960c55ea 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -275,7 +275,9 @@ bool llama_batch_allocr::init( } } - if (!ok) { + // TEMPORARILY DISABLING THIS SANITY CHECK + // TODO: UNDO THIS IF IT WORKS + /*if (!ok) { LLAMA_LOG_ERROR( "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n" " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n" @@ -284,7 +286,7 @@ bool llama_batch_allocr::init( __func__, s, s, p0, s, seq_pos_min(s)); return false; - } + }*/ } if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { @@ -832,13 +834,14 @@ struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch batch = { - /*n_tokens =*/ 0, - /*tokens =*/ nullptr, - /*embd =*/ nullptr, - /*pos =*/ nullptr, - /*n_seq_id =*/ nullptr, - /*seq_id =*/ nullptr, - /*logits =*/ nullptr, + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*.mtp_params =*/ { MTP_OP_NONE }, }; if (embd) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 26a5cf9c3f8..44555b83d4c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,8 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-graph.h" +#include "llama-kv-cache-unified.h" #include #include @@ -15,6 +17,26 @@ // // llama_context // +// Key for the graph cache. It contains all parameters that define the graph topology. 
+struct llama_graph_cache_key { + uint32_t n_tokens; + uint32_t n_outputs; + llama_mtp_op_type op_type; + bool causal_attn; + + bool operator<(const llama_graph_cache_key& other) const { + return std::tie(n_tokens, n_outputs, op_type, causal_attn) < + std::tie(other.n_tokens, other.n_outputs, other.op_type, other.causal_attn); + } +}; + +struct llama_context_kv_cache_data { + llama_kv_cache_unified::slot_info_vec_t last_main_model_sinfos; + llama_kv_cache_unified::slot_info_vec_t resized_sinfo_for_force; + const llama_kv_cache_unified::slot_info_vec_t * forced_sinfos = nullptr; + std::map graph_cache; + llm_graph_result_ptr gf_res_prev_validation; +}; llama_context::llama_context( const llama_model & model, @@ -103,6 +125,10 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + kv_cache_data = new llama_context_kv_cache_data(); + auto * kvd = static_cast(kv_cache_data); + kvd->gf_res_prev_validation = std::make_unique(graph_max_nodes()); + { const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); supports_set_rows = LLAMA_SET_ROWS ? 
(atoi(LLAMA_SET_ROWS) != 0) : supports_set_rows; @@ -279,9 +305,10 @@ llama_context::llama_context( } sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + sched_mtp.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); if (pipeline_parallel) { - LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()), ggml_backend_sched_get_n_copies(sched_mtp.get())); } } @@ -368,6 +395,7 @@ llama_context::llama_context( llama_context::~llama_context() { ggml_opt_free(opt_ctx); + delete static_cast(kv_cache_data); } void llama_context::synchronize() { @@ -522,6 +550,18 @@ float * llama_context::get_logits() { return logits; } +void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) { + output_reorder(); + + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + int64_t j = output_ids[i]; + + ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float)); +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -617,6 +657,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +ggml_tensor * llama_context::get_embeddings_tensor() { + return embd_tensor; +} + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -711,47 +755,85 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } 
-llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret, const llama_mtp_params & mtp_params) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } - auto * res = gf_res_prev.get(); - auto * gf = res->get_gf(); + auto * kvd = static_cast(kv_cache_data); + + ggml_backend_sched * current_sched = nullptr; + llm_graph_result * res = nullptr; - // the new graph parameters - // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters - const auto gparams = graph_params(res, ubatch, mctx, gtype); + if (mtp_params.op_type == MTP_OP_DRAFT_ONLY) { + current_sched = sched_mtp.get(); - if (!graph_reuse_disable && res->can_reuse(gparams)) { - //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); + int32_t n_outputs = 0; + for (int i = 0; i < ubatch.n_tokens; ++i) { if (ubatch.output[i]) n_outputs++; } + + const llama_graph_cache_key key = { + (uint32_t)ubatch.n_tokens, + (uint32_t)n_outputs, + mtp_params.op_type, + cparams.causal_attn + }; - n_reused++; + auto & res_ptr = kvd->graph_cache[key]; + if (!res_ptr) { + // LLAMA_LOG_INFO("[CACHE] New Entry: op=%d tokens=%d\n", key.op_type, key.n_tokens); + res_ptr = std::make_unique(graph_max_nodes()); + } + res = res_ptr.get(); + } else { - res->reset(); + current_sched = sched.get(); + res = gf_res_prev.get(); + } - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params); - //const auto t_start_us = ggml_time_us(); + bool structure_hit = !graph_reuse_disable && res->can_reuse(gparams); - gf = 
model.build_graph(gparams); + if (structure_hit) { + LLAMA_LOG_INFO("[GRAPH-REUSE] HIT (op=%d)\n", mtp_params.op_type); + if (current_sched == sched.get()) { + ggml_backend_sched_reset(current_sched); + + if (!ggml_backend_sched_alloc_graph(current_sched, res->gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph on reuse\n", __func__); + return nullptr; + } + } + + res->set_inputs(&ubatch); - //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + } else { + LLAMA_LOG_INFO("[GRAPH-REUSE] MISS (op=%d) - Rebuilding\n", mtp_params.op_type); + + ggml_backend_sched_reset(current_sched); + ggml_backend_sched_set_eval_callback(current_sched, cparams.cb_eval, cparams.cb_eval_user_data); - if (!gf) { - LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); - ret = GGML_STATUS_FAILED; - return nullptr; + res->reset(); + res->set_params(gparams); + res->gf = model.build_graph(gparams); + + if (!ggml_backend_sched_alloc_graph(current_sched, res->gf)) { + return nullptr; } + } + + if (mtp_params.op_type == MTP_OP_DRAFT_ONLY) { + LLAMA_LOG_INFO("[MTP-DEBUG] Executing DRAFT_ONLY path.\n"); - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); - ret = GGML_STATUS_ALLOC_FAILED; + const int64_t t_inputs_start_us = ggml_time_us(); + if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) { + ret = GGML_STATUS_FAILED; return nullptr; } + const int64_t t_inputs_end_us = ggml_time_us(); + LLAMA_LOG_INFO("[PERF-MTP] DRAFT_ONLY input setup: %.2f ms\n", (t_inputs_end_us - t_inputs_start_us) / 1000.0); } // set the input data for the input tensors @@ -763,7 +845,10 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } - const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); + const int64_t t_compute_start_us = ggml_time_us(); + auto status = 
graph_compute(res->get_gf(), ubatch.n_tokens > 1, current_sched); + const int64_t t_compute_end_us = ggml_time_us(); + LLAMA_LOG_INFO("[PERF-MTP] DRAFT_ONLY graph compute: %.2f ms\n", (t_compute_end_us - t_compute_start_us) / 1000.0); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); ret = status; @@ -771,7 +856,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } ret = GGML_STATUS_SUCCESS; - return res; } @@ -832,7 +916,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE }); cparams.causal_attn = causal_attn_org; @@ -946,6 +1030,8 @@ int llama_context::encode(const llama_batch & batch_inp) { int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + auto * kvd = static_cast(kv_cache_data); + if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); return encode(batch_inp); @@ -1004,6 +1090,7 @@ int llama_context::decode(const llama_batch & batch_inp) { while (true) { mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all); + if (!mctx) { return -2; } @@ -1015,29 +1102,28 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_MEMORY_STATUS_NO_UPDATE: { LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status()); - return -2; } case LLAMA_MEMORY_STATUS_FAILED_PREPARE: { + if (kvd->forced_sinfos) { + LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__); + return -1; + } + if (!did_optimize) { did_optimize = true; - if (kv_self_update(true)) { LLAMA_LOG_DEBUG("%s: 
retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens()); - continue; } } - LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens()); - return 1; } case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: { LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens()); - return -2; } } @@ -1052,10 +1138,9 @@ int llama_context::decode(const llama_batch & batch_inp) { }; int64_t n_outputs_prev = 0; - + do { const auto & ubatch = mctx->get_ubatch(); - // count the outputs in this ubatch { int32_t n_outputs_new = 0; @@ -1071,10 +1156,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // needs to happen before the graph is built n_outputs = n_outputs_new; } - ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); - + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache llama_pos pos_min[LLAMA_MAX_SEQ]; @@ -1118,9 +1201,13 @@ int llama_context::decode(const llama_batch & batch_inp) { t_embd = res->get_embd_pooled(); } + ggml_backend_sched_t active_sched = sched.get(); + if (batch_inp.mtp_params.op_type == MTP_OP_DRAFT_ONLY) { + active_sched = sched_mtp.get(); + } // extract logits if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(active_sched, t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1135,56 +1222,57 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract embeddings if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); - - 
switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + if (batch_inp.mtp_params.op_type == MTP_OP_NONE || batch_inp.mtp_params.op_type == MTP_OP_UNIFIED) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(active_sched, t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = 
ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } } } @@ -1249,7 +1337,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // overlap with device computation. 
ggml_backend_sched_reset(sched.get()); } - return 0; } @@ -1374,6 +1461,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u } ggml_backend_sched_reset(sched.get()); + if (sched_mtp) { + ggml_backend_sched_reset(sched_mtp.get()); + } // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that gf_res_prev->reset(); @@ -1389,7 +1479,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -1409,8 +1499,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, - const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + const llama_memory_context_i * mctx, + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1423,15 +1514,29 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.mtp_params =*/ mtp_params, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, }; } -ggml_status llama_context::graph_compute( - ggml_cgraph * gf, - bool batched) { +std::unique_ptr llama_context::mtp_memory_batch(const llama_batch& batch_inp) { + const auto& vocab = model.vocab; + const auto& hparams = model.hparams; + + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, false)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return nullptr; + } + + return memory->init_batch(*balloc, 1, false); +} + +ggml_status llama_context::graph_compute(ggml_cgraph * gf, bool batched, ggml_backend_sched_t custom_sched) { int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; ggml_threadpool_t tp = batched ? threadpool_batch : threadpool; @@ -1446,7 +1551,9 @@ ggml_status llama_context::graph_compute( set_n_threads_fn.second(set_n_threads_fn.first, n_threads); } - auto status = ggml_backend_sched_graph_compute_async(sched.get(), gf); + ggml_backend_sched_t target = custom_sched ? custom_sched : sched.get(); + + auto status = ggml_backend_sched_graph_compute_async(target, gf); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); } @@ -1456,8 +1563,10 @@ ggml_status llama_context::graph_compute( return status; } -llm_graph_cb llama_context::graph_get_cb() const { - return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { +llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const { + ggml_backend_sched * cb_sched = sched_override ? 
sched_override : sched.get(); + + return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) { if (il >= 0) { ggml_format_name(cur, "%s-%d", name, il); } else { @@ -1467,7 +1576,7 @@ llm_graph_cb llama_context::graph_get_cb() const { if (!cparams.offload_kqv) { if (strcmp(name, "kqv_merged_cont") == 0) { // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu); } } @@ -1480,7 +1589,7 @@ llm_graph_cb llama_context::graph_get_cb() const { for (const auto & backend : backends) { if (ggml_backend_get_device(backend.get()) == dev_layer) { if (ggml_backend_supports_op(backend.get(), cur)) { - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get()); + ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get()); } } } @@ -1489,6 +1598,10 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) { + return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload); +} + // // state save/load // @@ -2142,7 +2255,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE }); res->reset(); @@ -2233,6 +2346,7 @@ void llama_context::opt_epoch( llama_batch_free(batch); } + // // interface implementation // @@ -2274,6 +2388,8 @@ llama_context_params llama_context_default_params() { return result; } + + llama_context * llama_init_from_model( llama_model * model, llama_context_params params) { @@ -2412,6 +2528,7 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) { return ctx->get_logits_ith(i); } + float * 
llama_get_embeddings(llama_context * ctx) { ctx->synchronize(); @@ -2926,3 +3043,44 @@ void llama_opt_epoch( callback_train, callback_eval); } + +void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) { + ctx->draft_input_hidden_state = hidden_state; +} + +void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (memory) { + static_cast(memory.get())->seq_rm(seq_id, p0, p1); + } +} + +void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + ctx->kv_cache_seq_rm(seq_id, p0, p1); +} + +bool llama_context::prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params) { + + // We only need to inject hidden states manually for the DRAFT_ONLY path. + if (mtp_params.op_type != MTP_OP_DRAFT_ONLY) { + return true; + } + + struct ggml_tensor * inp_mtp = ggml_graph_get_tensor(res->gf, "mtp_draft_hidden_state"); + if (!inp_mtp) { + LLAMA_LOG_ERROR("MTP input tensor not found in graph\n"); + return false; + } + + const float * src_data = this->draft_input_hidden_state; + if (!src_data) { + LLAMA_LOG_ERROR("%s: Source hidden state data is NULL (draft_input_hidden_state)\n", __func__); + return false; + } + + ggml_backend_tensor_set(inp_mtp, src_data, 0, ggml_nbytes(inp_mtp)); + + return true; +} diff --git a/src/llama-context.h b/src/llama-context.h index 25c143d56df..895efbc343f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -20,6 +20,8 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; +struct llama_context_kv_cache_data; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -27,6 +29,15 @@ struct llama_context { llama_context_params params); ~llama_context(); + + // The llama_context manages significant resources (GPU memory, file handles, PImpl data) + // and is fundamentally a non-copyable, 
non-movable object. Deleting these special + // member functions enforces this rule and is also technically required to allow the + // PImpl pattern (via unique_ptr or void*) with an incomplete type in the header. + llama_context(const llama_context &) = delete; + llama_context & operator=(const llama_context &) = delete; + llama_context(llama_context &&) = delete; + llama_context & operator=(llama_context &&) = delete; void synchronize(); @@ -59,6 +70,9 @@ struct llama_context { float * get_embeddings(); float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + ggml_tensor * get_embeddings_tensor(); + + const float * draft_input_hidden_state = nullptr; void attach_threadpool( ggml_threadpool_t threadpool, @@ -90,6 +104,8 @@ struct llama_context { int32_t il_start, int32_t il_end); + void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -98,7 +114,8 @@ struct llama_context { const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, - ggml_status & ret); + ggml_status & ret, + const llama_mtp_params & mtp_params); int encode(const llama_batch & batch_inp); int decode(const llama_batch & batch_inp); @@ -194,19 +211,35 @@ struct llama_context { llm_graph_result * get_gf_res_reserve() const; // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute(ggml_cgraph * gf, bool batched); + ggml_status graph_compute(ggml_cgraph * gf, bool batched, ggml_backend_sched_t custom_sched = nullptr); // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); + void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, 
int32_t i); + + ggml_backend_sched_t create_temp_scheduler(size_t n_nodes); + + std::unique_ptr mtp_memory_batch(const llama_batch& batch_inp); + + // For MTP KV cache cell reuse + void * kv_cache_data; + private: llm_graph_params graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const; + llm_graph_type gtype, + const llama_mtp_params & mtp_params) const; + + llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const; - llm_graph_cb graph_get_cb() const; + // Methods for MTP decode + bool prepare_mtp_graph_inputs( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_mtp_params & mtp_params); // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); @@ -240,6 +273,7 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + ggml_tensor * embd_tensor = nullptr; // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE @@ -261,6 +295,8 @@ struct llama_context { ggml_backend_sched_ptr sched; + ggml_backend_sched_ptr sched_mtp; + ggml_backend_t backend_cpu = nullptr; std::vector backends; @@ -308,3 +344,4 @@ struct llama_context { mutable int32_t n_reused = 0; // number of times the previous graph was reused }; + diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 053c72d6dc8..50bd8e03f74 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -442,34 +442,22 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { bool llm_graph_result::can_reuse(const llm_graph_params & params) { if (!this->params.allow_reuse(params)) { - if (debug > 1) { - LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__); - } - + LLAMA_LOG_WARN("[GRAPH-REUSE-FAIL] Failure in 'allow_reuse'. 
Incompatible parameters."); + LLAMA_LOG_WARN(" n_tokens: %d vs %d, op_type: %d vs %d", + this->params.ubatch.n_tokens, params.ubatch.n_tokens, + (int)this->params.mtp_params.op_type, (int)params.mtp_params.op_type); return false; } - if (debug > 1) { - LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size()); - } - - bool res = true; - - for (auto & input : inputs) { - const bool cur = input->can_reuse(params); - - if (debug > 1) { - LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur); + for (size_t i = 0; i < inputs.size(); ++i) { + if (!inputs[i]->can_reuse(params)) { + LLAMA_LOG_WARN("[GRAPH-REUSE-FAIL] Failure in 'can_reuse' of the input node #%zu.", i); + return false; } - - res = res && cur; - } - - if (debug > 0) { - LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res); } - return res; + LLAMA_LOG_DEBUG("%s: can reuse graph = true\n", __func__); + return true; } llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) { @@ -1074,6 +1062,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { return cur; } + +ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const { + auto inp = std::make_unique(); + ggml_tensor * cur = nullptr; + + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_name(inp->tokens, "mtp_inp_tokens"); + ggml_set_input(inp->tokens); + + cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens); + } else { + GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings"); + } + + cb(cur, "mtp_inp_embd", -1); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_pos() const { auto inp = std::make_unique(hparams.n_pos_per_embd()); diff --git a/src/llama-graph.h b/src/llama-graph.h index 6ff49de3a1c..3c5feadfdc7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -29,6 +29,7 @@ enum llm_graph_type { 
LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DRAFT, }; enum llm_ffn_op_type { @@ -94,6 +95,20 @@ class llm_graph_input_i { using llm_graph_input_ptr = std::unique_ptr; +class llm_graph_input_mtp_states : public llm_graph_input_i { +public: + llm_graph_input_mtp_states() = default; + virtual ~llm_graph_input_mtp_states() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override {} + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * states = nullptr; +}; + class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -402,6 +417,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_mtp_params mtp_params; uint32_t n_outputs; @@ -450,6 +466,7 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == other.cross && + mtp_params.op_type == other.mtp_params.op_type && n_outputs == other.n_outputs; } }; @@ -664,6 +681,7 @@ struct llm_graph_context { // ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; + ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const; ggml_tensor * build_inp_pos() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; @@ -818,3 +836,4 @@ struct llm_graph_context { // TODO: better name int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); + diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index e539142e6b8..8d9b1f631f7 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( } if (model.arch == LLM_ARCH_GLM4_MOE) { // GLM-4.5: Only process up to last layer, skip final NextN layer - n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers; + n_layer_cache = 
hparams.n_layer;// - hparams.nextn_predict_layers; } // create a context for each buffer type @@ -508,6 +508,34 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); } +llama_memory_context_ptr llama_kv_cache_unified::init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update) { + + if (sinfos.empty()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + balloc.split_reset(); + std::vector ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + if (ubatch.n_tokens == 0) { + break; + } + ubatches.push_back(std::move(ubatch)); + } + + if (ubatches.size() != sinfos.size()) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, sinfos, std::move(ubatches), is_inplace_update); +} + llama_memory_context_ptr llama_kv_cache_unified::init_full() { return std::make_unique(this); } @@ -928,64 +956,81 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } assert(res.s1 >= res.s0); + if (!res.empty()) { + std::string idxs_str; + for (const auto& vec : res.idxs) { + if (!vec.empty()) { + if (vec.size() > 8) { + idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]"; + } else { + idxs_str += " ["; + for(size_t i = 0; i < vec.size(); ++i) { + idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? 
"" : ", "); + } + idxs_str += "]"; + } + } + } + } return res; } -void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { - // keep track of the max sequence position that we would overwrite with this ubatch - // for non-SWA cache, this would be always empty - llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - seq_pos_max_rm[s] = -1; - } - - assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - - for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { - for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { - const uint32_t i = s*sinfo.size() + ii; - - auto & cells = v_cells[sinfo.strm[s]]; +void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) { + // For "in-place" updates (MTP warmup/accept), we only update the tensor data. + // The cell metadata (logical position, sequence ID) has already been set + // by the main model's pass. We must skip all metadata modifications + // to prevent `pos_set` from asserting on an already-set cell. 
+ if (!is_inplace_update) { + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } - const auto idx = sinfo.idxs[s][ii]; + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + auto & cells = v_cells[sinfo.strm[s]]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + const auto idx = sinfo.idxs[s][ii]; - cells.rm(idx); - } + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + cells.rm(idx); + } - cells.pos_set(idx, ubatch.pos[i]); + cells.pos_set(idx, ubatch.pos[i]); - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } - } - // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence - // will be present in the cache. so we have to purge any position which is less than those we would overwrite - // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { - if (seq_pos_max_rm[s] == -1) { - continue; - } + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } - GGML_ASSERT(s < seq_to_stream.size()); + GGML_ASSERT(s < seq_to_stream.size()); - auto & cells = v_cells[seq_to_stream[s]]; + auto & cells = v_cells[seq_to_stream[s]]; - if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", - __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { - seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } } } @@ -2290,7 +2335,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_kv_cache_unified::slot_info_vec_t sinfos, - std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) { + std::vector ubatches, + bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) { } llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default; @@ -2315,13 +2361,18 @@ bool llama_kv_cache_unified_context::apply() { return true; } - kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]); + kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update); n_kv = kv->get_n_kv(); return true; } +void llama_kv_cache_unified_context::set_n_kv() { + n_kv = kv->get_n_kv(); +} + + llama_memory_status llama_kv_cache_unified_context::get_status() const { return status; } @@ -2384,6 +2435,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, con kv->set_input_pos_bucket(dst, ubatch); } +void 
llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) { + sinfos = new_sinfos; +} + uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 342a675962e..f64f7faa5c0 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -116,6 +116,12 @@ class llama_kv_cache_unified : public llama_memory_i { llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) override; + + llama_memory_context_ptr init_batch_with_sinfos( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + const slot_info_vec_t & sinfos, + bool is_inplace_update); llama_memory_context_ptr init_full() override; @@ -181,7 +187,7 @@ class llama_kv_cache_unified : public llama_memory_i { slot_info find_slot(const llama_ubatch & ubatch, bool cont) const; // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]] - void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch); + void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false); // // input API @@ -321,7 +327,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified_context( llama_kv_cache_unified * kv, slot_info_vec_t sinfos, - std::vector ubatches); + std::vector ubatches, + bool is_inplace_update = false); virtual ~llama_kv_cache_unified_context(); @@ -340,6 +347,7 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // uint32_t get_n_kv() const; + void set_n_kv(); // TODO: temporary bool get_supports_set_rows() const; @@ -362,6 +370,10 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void 
set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_sinfos(slot_info_vec_t new_sinfos); + + const slot_info_vec_t & get_sinfos() const { return sinfos; } + private: llama_memory_status status; @@ -396,4 +408,6 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { // a heuristic, to avoid attending the full cache if it is not yet utilized // as the cache gets filled, the benefit from this heuristic disappears int32_t n_kv; + + bool is_inplace_update = false; }; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 58ca7df707e..154ef1de37d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4507,9 +4507,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // but only PROCESS up to last layer (skipping final NextN layer) in forward pass for (int i = 0; i < n_layer; ++i) { int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; + // flags |= TENSOR_SKIP; } auto & layer = layers[i]; @@ -4573,12 +4574,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + + // our input/output layer sanity check prevents us from loading the eh_proj layer! + // this is because eh_proj is labelled with a layer number in existing GGUFs, + // so we need to set bid == to successfully load the tensors, but our io layer sanity check requires bid == -1. + // this function is a hack that creates the nextn layers as LLM_TENSOR_LAYER_REPEATING instead. 
+ /* auto create_tensor_override_io_sanity_check = + [&](llm_tensor type_enum, const char * suffix, int bid, const std::initializer_list& ne, int flags) -> ggml_tensor * { + + auto tn_orig = tn(type_enum, suffix, bid); + llm_tensor_info info_override = *tn_orig.info; + info_override.layer = LLM_TENSOR_LAYER_REPEATING; + + auto tn_override = tn_orig; + tn_override.info = &info_override; + + return create_tensor(tn_override, ne, flags); + };*/ + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // layer.nextn.eh_proj = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i, { 2 * n_embd, n_embd }, flags); + // layer.nextn.embed_tokens = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.enorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_ENORM, "weight", i, { n_embd }, flags); + // layer.nextn.hnorm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_HNORM, "weight", i, { n_embd }, flags); + // layer.nextn.shared_head_head = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i, { n_embd, n_vocab }, flags); + // layer.nextn.shared_head_norm = create_tensor_override_io_sanity_check(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i, { n_embd }, flags); } } } @@ -13763,159 +13789,316 @@ struct llm_build_glm4 : public 
llm_graph_context { struct llm_build_glm4_moe : public llm_graph_context { llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); ggml_tensor * cur; - ggml_tensor * inpL; - inpL = build_inp_embd(model.tok_embd); + if (params.mtp_params.op_type == MTP_OP_DRAFT_ONLY) { + ggml_tensor * hidden_state_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd); + ggml_set_name(hidden_state_input, "mtp_draft_hidden_state"); + ggml_set_input(hidden_state_input); + auto inp_mtp = std::make_unique(); + inp_mtp->states = hidden_state_input; + res->add_input(std::move(inp_mtp)); - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + res->t_logits = build_mtp_draft_graph(mtp_layer, hidden_state_input, n_embd_head); - auto * inp_attn = build_attn_inp_kv_unified(); + } else { + ggml_tensor * inp_raw_embd = build_inp_embd(model.tok_embd); + ggml_tensor * inpL = inp_raw_embd; + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // Pre-attention norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - ggml_tensor * inp_out_ids = build_inp_out_ids(); + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + } + cb(Qcur, "Qcur", 
il); - // Only process up to last layer (skip final NextN layer) - // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { - ggml_tensor * inpSA = inpL; + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + } + cb(Kcur, "Kcur", il); - // Pre-attention norm - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + } + cb(Vcur, "Vcur", il); - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 
freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - // Apply Q/K norm if available (GLM-4.5 355B variant) - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + // Post-attention norm + cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) + if (static_cast(il) < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + 
model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } - if (il == n_transformer_layers - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + // cb(cur, "result_norm", -1); + res->t_embd = cur; + + if (params.mtp_params.op_type == MTP_OP_UNIFIED) { + const int il_mtp = hparams.n_layer - 1; + const auto & mtp_layer = model.layers[il_mtp]; + + ggml_tensor * mtp_embd_input = inp_raw_embd; + + if (inp_out_ids) { + mtp_embd_input = ggml_get_rows(ctx0, inp_raw_embd, inp_out_ids); + 
ggml_set_name(mtp_embd_input, "mtp_sliced_embd"); + } + + build_mtp_update_graph(mtp_layer, cur, mtp_embd_input, inp_pos, inp_attn, n_embd_head); } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); + // Use the main model header + res->t_logits = build_lora_mm(model.output, cur); - // Post-attention norm - cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "post_attn_norm", il); + } - // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense) - if (static_cast(il) < hparams.n_layer_dense_lead) { - // Dense FFN layer - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // Process routed experts using existing MoE infrastructure - ggml_tensor * routed_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, - (llama_expert_gating_func_type) hparams.expert_gating_func, - il); - cb(routed_out, "ffn_moe_out", il); + ggml_build_forward_expand(gf, res->t_logits); + } - // Process shared expert on original input - ggml_tensor * shared_out = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(shared_out, "ffn_shexp_out", il); +private: + ggml_tensor * build_mtp_draft_graph(const llama_layer & mtp_layer, ggml_tensor * hidden_state_input, int64_t n_embd_head) { + const int il = hparams.n_layer - 1; - // Final output: routed_output + shared_output - cur = ggml_add(ctx0, routed_out, shared_out); - cb(cur, "ffn_out", il); + 
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens); + ggml_tensor * hidden_state_norm = build_norm(hidden_state_input, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + ggml_tensor * projected_input = build_lora_mm(mtp_layer.nextn.eh_proj, combined); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor* final_state = build_transformer_block(mtp_layer, projected_input, inp_pos, inp_attn, n_embd_head, il); + + ggml_tensor* logits = build_norm(final_state, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il); + logits = build_lora_mm(mtp_layer.nextn.shared_head_head, logits); + + return logits; + } + + void build_mtp_update_graph( + const llama_layer & mtp_layer, + ggml_tensor * main_model_hidden_state, + ggml_tensor * main_model_token_emb, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_unified * inp_attn, + int64_t n_embd_head + ) { + const int il = hparams.n_layer - 1; + + ggml_tensor * hidden_state_norm = build_norm(main_model_hidden_state, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * token_emb_norm = build_norm(main_model_token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il); + ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); + ggml_tensor * projected_input = build_lora_mm(mtp_layer.nextn.eh_proj, combined); + + build_transformer_block(mtp_layer, projected_input, inp_pos, inp_attn, n_embd_head, il); + + ggml_tensor * dummy_output = ggml_sum(ctx0, main_model_hidden_state); + ggml_set_name(dummy_output, "mtp_update_side_effect"); + } + + ggml_tensor * llm_build_glm4_moe::build_transformer_block(const llama_layer & layer, ggml_tensor* input, + ggml_tensor* current_pos, + llm_graph_input_attn_kv_unified* current_inp_attn, int64_t 
n_embd_head, int il) { + // now proceed through last layer (skipped in main model) + ggml_tensor * inpSA = input; + // Pre-attention norm for the MTP block + ggml_tensor* cur = build_norm(input, layer.attn_norm, NULL, LLM_NORM_RMS, il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur); + if (layer.bq) Qcur = ggml_add(ctx0, Qcur, layer.bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur); + if (layer.bk) Kcur = ggml_add(ctx0, Kcur, layer.bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur); + if (layer.bv) Vcur = ggml_add(ctx0, Vcur, layer.bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply Q/K norm if available (GLM-4.5 355B variant) + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); } - cur = ggml_add(ctx0, cur, ffn_inp); + Qcur = ggml_rope_ext( + ctx0, Qcur, current_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - cur = build_cvec(cur, il); - cb(cur, "l_out", il); + Kcur = ggml_rope_ext( + ctx0, Kcur, current_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - // input for next layer - inpL = cur; + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(current_inp_attn, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - cur = inpL; - cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, 
inpSA); + cur = build_norm(ffn_inp, layer.attn_post_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "result_norm", -1); - res->t_embd = cur; + // moe ffn for nextn block + { + // Process routed experts using existing MoE infrastructure + ggml_tensor * routed_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(routed_out, "ffn_moe_out", il); - // lm_head - cur = build_lora_mm(model.output, cur); + // Process shared expert on original input + ggml_tensor * shared_out = build_ffn(cur, + layer.ffn_up_shexp, NULL, NULL, + layer.ffn_gate_shexp, NULL, NULL, + layer.ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shared_out, "ffn_shexp_out", il); - cb(cur, "result_output", -1); - res->t_logits = cur; + // Final output: routed_output + shared_output + cur = ggml_add(ctx0, routed_out, shared_out); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + // cb(cur, "mtp_ffn_residual", il); - ggml_build_forward_expand(gf, cur); + return cur; } }; @@ -18144,8 +18327,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + const int64_t t_start_us = ggml_time_us(); std::unique_ptr llm; - switch (arch) { case LLM_ARCH_LLAMA: { @@ -18503,9 +18686,16 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { GGML_ABORT("fatal error"); } - // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - + if (params.mtp_params.op_type == MTP_OP_NONE || params.mtp_params.op_type == MTP_OP_UNIFIED) { + // add on pooling layer + llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + } + const int64_t t_end_us = ggml_time_us(); + 
LLAMA_LOG_INFO( + "[PERF] Graph build time: %.2f ms (MTP path: %s)\n", + (t_end_us - t_start_us) / 1000.0, + params.mtp_params.op_type != MTP_OP_NONE && params.mtp_params.op_type != MTP_OP_UNIFIED ? "yes" : "no" + ); return llm->res->get_gf(); } @@ -18587,6 +18777,10 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) return nullptr; } +int32_t llama_model_n_nextn_layer(const llama_model * model) { + return model->hparams.nextn_predict_layers; +} + // deprecated int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); @@ -18820,3 +19014,4 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + diff --git a/tools/server/server.cpp b/tools/server/server.cpp index a255d481a4d..e189b86d87b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result { } }; + struct server_slot { int id; int id_task = -1; @@ -1294,6 +1295,8 @@ struct server_slot { mtmd_context * mctx = nullptr; common_speculative * spec = nullptr; + bool has_mtp = false; + int32_t last_tok_idx = -1; std::vector lora; @@ -1363,6 +1366,7 @@ struct server_slot { // Speculative decoding stats int32_t n_draft_total = 0; // Total draft tokens generated int32_t n_draft_accepted = 0; // Draft tokens actually accepted + llama_tokens ids_prev_accepted; void reset() { SLT_DBG(*this, "%s", "\n"); @@ -1391,7 +1395,7 @@ struct server_slot { } bool need_embd() const { - return server_task_type_need_embd(task_type); + return server_task_type_need_embd(task_type) || has_mtp; } bool need_logits() const { @@ -1401,9 +1405,14 @@ struct server_slot { // if the context does not have a memory module then all embeddings have to be computed within a single ubatch // also we cannot split if the pooling would require any past tokens bool can_split() const { + 
//fprintf(stderr, "need_embd() %d\n", need_embd()); + //fprintf(stderr, "llama_get_memory(ctx) %d\n", llama_get_memory(ctx) != nullptr); + //fprintf(stderr, "POOLING_TYPE check %d\n", llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + return !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE); // this seems to save embeddings for whole batch? } bool can_batch_with(server_slot & other_slot) const { @@ -1431,7 +1440,8 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt; + // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt; } void add_token(const completion_token_output & token) { @@ -1566,6 +1576,7 @@ struct server_slot { } }; + struct server_metrics { int64_t t_start = 0; @@ -2122,6 +2133,22 @@ struct server_context { } } + // if model has MTP and no draft model is specified... 
+ else if (llama_model_n_nextn_layer(model) > 0) { + SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model)); + slot.has_mtp = true; + + // assume one speculative token (true of all well-known MTP models so far) + slot.batch_spec = llama_batch_init(2, 0, 1); + SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens); + + params_base.speculative.n_min = 0; + params_base.speculative.n_max = 1; + + SRV_INF("%s\n", "MTP needs embeddings on decode, enabling"); + llama_set_embeddings(ctx, true); + } + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); slot.params.sampling = params_base.sampling; @@ -3368,6 +3395,7 @@ struct server_context { const bool need_embd = server_task_type_need_embd(slot.task_type); common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + slot.cache_tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; @@ -3441,7 +3469,11 @@ struct server_context { batch.logits + i, }; + const int64_t t_prompt_main_start_us = ggml_time_us(); + batch_view.mtp_params.op_type = MTP_OP_UNIFIED; // TODO: apply only when the model has MTP const int ret = llama_decode(ctx, batch_view); + const int64_t t_prompt_main_end_us = ggml_time_us(); + LOG_INF("[PERF-PROMPT] Main model prompt processing: %.2f ms\n", (t_prompt_main_end_us - t_prompt_main_start_us) / 1000.0); metrics.on_decoded(slots); @@ -3518,9 +3550,9 @@ struct server_context { const int tok_idx = slot.i_batch - i; llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + slot.last_tok_idx = tok_idx; slot.i_batch = -1; - common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; @@ -3590,23 +3622,36 @@ struct server_context { llama_token id = slot.sampled; - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; - params_spec.p_min = slot.params.speculative.p_min; + llama_tokens draft; + // const int64_t t_spec_start_us 
= ggml_time_us(); + if (slot.has_mtp) { + llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1)); + + LOG_INF("[MTP-FLOW] Generating a new draft from the token %d in the position %d.\n", id, slot.n_past); + llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx); + draft.push_back(draft_id); + } + else { + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens(); - const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + } // ignore small drafts - if (slot.params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + if (slot.params.speculative.n_min > (int)draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min); continue; } // keep track of total number of drafted tokens tested slot.n_draft_total += draft.size(); + SLT_DBG(slot, "draft size = %d\n", (int) draft.size()); // construct the speculation batch common_batch_clear(slot.batch_spec); @@ -3617,11 +3662,17 @@ struct server_context { } SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - + const int64_t t_valid_start_us = ggml_time_us(); + slot.batch_spec.mtp_params.op_type = MTP_OP_UNIFIED; llama_decode(ctx, slot.batch_spec); + const int64_t t_valid_end_us = ggml_time_us(); // the accepted tokens from the speculation + const int64_t t_accept_start_us = ggml_time_us(); const auto ids = 
common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + const int64_t t_accept_end_us = ggml_time_us(); + + slot.ids_prev_accepted = ids; slot.n_past += ids.size(); slot.n_decoded += ids.size();