From 954efc69a778ff2669c3665fe6ba36851da8e6fb Mon Sep 17 00:00:00 2001 From: yuerqiqi <2500526025@qq.com> Date: Mon, 2 Feb 2026 09:01:19 +0000 Subject: [PATCH 1/2] Feat: implement qwen3 probing service and tests --- examples/qwen3_service/CMakeLists.txt | 13 + examples/qwen3_service/main_probing.cpp | 153 +++ examples/qwen3_service/test_accuracy.cpp | 232 ++++ .../qwen3_service/test_trivia_probing.cpp | 420 +++++++ mllm-cli/cmd/mllm-server/main.go | 9 +- .../qwen3/modeling_qwen3_probing_service.hpp | 1057 +++++++++++++++++ 6 files changed, 1880 insertions(+), 4 deletions(-) create mode 100644 examples/qwen3_service/main_probing.cpp create mode 100644 examples/qwen3_service/test_accuracy.cpp create mode 100644 examples/qwen3_service/test_trivia_probing.cpp create mode 100644 mllm/models/qwen3/modeling_qwen3_probing_service.hpp diff --git a/examples/qwen3_service/CMakeLists.txt b/examples/qwen3_service/CMakeLists.txt index 31faa395e..e0456a91f 100644 --- a/examples/qwen3_service/CMakeLists.txt +++ b/examples/qwen3_service/CMakeLists.txt @@ -1,3 +1,16 @@ add_executable(mllm-qwen3-service main.cpp) target_link_libraries(mllm-qwen3-service PRIVATE MllmRT MllmCPUBackend) target_include_directories(mllm-qwen3-service PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen3-accuracy test_accuracy.cpp) +target_link_libraries(mllm-qwen3-accuracy PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-accuracy PRIVATE ${MLLM_INCLUDE_DIR}) + + +add_executable(mllm-qwen3-probing main_probing.cpp) +target_link_libraries(mllm-qwen3-probing PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-probing PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen3-trivia-probing test_trivia_probing.cpp) +target_link_libraries(mllm-qwen3-trivia-probing PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen3-trivia-probing PRIVATE ${MLLM_INCLUDE_DIR}) \ No newline at end of file diff --git a/examples/qwen3_service/main_probing.cpp b/examples/qwen3_service/main_probing.cpp new file mode 100644 index 000000000..35cb5225c --- /dev/null +++ b/examples/qwen3_service/main_probing.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "mllm/models/qwen3/modeling_qwen3_probing_service.hpp" + +using namespace mllm; +using namespace mllm::models::qwen3_probing; +namespace fs = std::filesystem; + +std::vector parseLayers(const std::string& input) { + std::vector layers; + std::stringstream ss(input); + std::string segment; + int num; + ss >> segment; + while (ss >> num) layers.push_back(num); + return layers; +} + +MLLM_MAIN({ + mllm::setLogLevel(mllm::LogLevel::kError); + auto& model_path = mllm::Argparse::add("-m|--model_path").help("Model path").required(true); + auto& probe_path = mllm::Argparse::add("-p|--probe_path").help("Probes dir").required(true); + mllm::Argparse::parse(argc, argv); + + auto qwen3_session = std::make_shared(); + try { + qwen3_session->fromPreTrain(model_path.get()); + } catch (const std::exception& e) { + std::cerr << "Load Model Error: " << e.what() << std::endl; + return 1; + } + + ProbingArgs p_args; + p_args.enable_prefill_check = true; + p_args.enable_decode_check = true; + + p_args.prefill_stop_threshold = 0.7f; + p_args.decode_stop_threshold = 0.8f; + + p_args.pos_threshold = 0.9f; + + std::cout << ">>> Loading Probes..." 
<< std::endl;
+  qwen3_session->setProbingArgs(p_args);
+  qwen3_session->loadProbes(probe_path.get(), p_args);
+
+  mllm::service::insertSession("mllmTeam/Qwen3-Probing", qwen3_session);
+  mllm::service::startService();
+
+  std::vector<nlohmann::json> history;
+  std::vector<int> current_prefill_layers = {27, 30};  // default prefill layers
+
+  std::cout << "\n[System] Ready. Commands:\n";
+  std::cout << "  /prefill 15 20   Set prefill layers\n";
+  std::cout << "  /clear           Clear history\n";
+  std::cout << "  /exit            Exit\n";
+
+  while (true) {
+    std::cout << "\nUser: ";
+    std::string user_input;
+    std::getline(std::cin, user_input);
+
+    if (user_input == "/exit") break;
+    if (user_input == "/clear") {
+      history.clear();
+      continue;
+    }
+    if (user_input.rfind("/prefill", 0) == 0) {
+      current_prefill_layers = parseLayers(user_input);
+      std::cout << "Prefill layers: " << nlohmann::json(current_prefill_layers).dump() << "\n";
+      continue;
+    }
+
+    nlohmann::json user_msg;
+    user_msg["role"] = "user";
+    user_msg["content"] = user_input;
+    history.push_back(user_msg);
+
+    nlohmann::json req;
+    req["model"] = "mllmTeam/Qwen3-Probing";
+    req["messages"] = history;
+    req["prefill_layers"] = current_prefill_layers;
+    req["enable_thinking"] = false;
+    req["id"] = "chat-probing";
+
+    mllm::service::sendRequest(req.dump());
+
+    std::string assistant_content;
+    bool thinking = false;
+
+    while (true) {
+      std::string resp = mllm::service::getResponse("chat-probing");
+      auto j = nlohmann::json::parse(resp);
+
+      if (j.contains("choices") && j["choices"].size() > 0) {
+        auto& choice = j["choices"][0];
+        auto content = choice["delta"]["content"];
+
+        if (content.is_string()) {
+          std::string s = content.get<std::string>();
+          if (s.find("early_exit") != std::string::npos) {
+            try {
+              auto warn = nlohmann::json::parse(s);
+              fmt::print(fmt::fg(fmt::color::red) | fmt::emphasis::bold,
+                         "\n[Hallucination] Phase: {} | Layer: {} | Score: {:.4f}\n", warn.value("phase", "unknown"),
+                         warn.value("layer", -1), warn.value("score", 0.0f));
+            } catch (...) { fmt::print(fmt::fg(fmt::color::red), "\n[Hallucination] Raw: {}\n", s); }
+
+            if (!history.empty() && history.back()["role"] == "user") history.pop_back();
+            break;
+          }
+
+          if (s == "<think>") {
+            thinking = true;
+            continue;
+          }
+          if (s == "</think>") {
+            thinking = false;
+            continue;
+          }
+
+          if (thinking)
+            fmt::print(fmt::fg(fmt::color::gray), "{}", s);
+          else {
+            fmt::print("{}", s);
+            assistant_content += s;
+          }
+          std::fflush(stdout);
+        }
+
+        if (choice["finish_reason"] == "stop") break;
+      }
+    }
+
+    if (!assistant_content.empty()) {
+      nlohmann::json msg;
+      msg["role"] = "assistant";
+      msg["content"] = assistant_content;
+      history.push_back(msg);
+    }
+  }
+  mllm::service::stopService();
+  return 0;
+})
\ No newline at end of file
diff --git a/examples/qwen3_service/test_accuracy.cpp b/examples/qwen3_service/test_accuracy.cpp
new file mode 100644
index 000000000..a8bcf6fe7
--- /dev/null
+++ b/examples/qwen3_service/test_accuracy.cpp
@@ -0,0 +1,232 @@
+#include <cctype>
+#include <filesystem>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+#include "mllm/preprocessor/tokenizers/Unicode.hpp"
+#include "mllm/models/qwen3/modeling_qwen3_probing_service.hpp"
+#include "mllm/models/qwen3/tokenization_qwen3.hpp"
+
+using namespace mllm;
+using namespace mllm::models::qwen3_probing;
+using namespace mllm::models::qwen3;
+namespace fs = std::filesystem;
+
+struct Sample {
+  std::string question;
+  std::string exact_answer;
+};
+
+// Robust CSV parser (handles quoted fields and escaped quotes).
+std::vector<std::string> parse_csv_line(const std::string& line) {
+  std::vector<std::string> result;
+  bool in_quote = false;
+  std::string field;
+  for (size_t i = 0; i < line.size(); ++i) {
+    char c = line[i];
+    if (c == '"') {
+      if (in_quote && i + 1 < line.size() && line[i + 1] == '"') {
+        field += '"';  // escaped quote
+        i++;
+      } else {
+        in_quote = !in_quote;
+      }
+    } else if (c == ',' && !in_quote) {
+      result.push_back(field);
+      field.clear();
+    } else {
+      field += c;
+    }
+  }
+  result.push_back(field);
+  return result;
+}
+
+std::vector<Sample> load_samples(const std::string& path) {
+  std::vector<Sample> samples;
+  std::ifstream file(path);
+  if (!file.is_open()) {
+    std::cerr << "Cannot open file: " << path << std::endl;
+    return samples;
+  }
+  std::string line;
+  std::getline(file, line);  // header
+
+  while (std::getline(file, line)) {
+    if (line.empty()) continue;
+    auto fields = parse_csv_line(line);
+
+    std::string q, exact;
+
+    // 0: Unnamed, 1: raw_q, 2: q, 3: model_ans, 4: correct, 5: auto, 6: exact
+    if (fields.size() > 6) {
+      q = fields[2];
+      exact = fields[6];
+    } else {
+      continue;
+    }
+
+    if (exact == "NO ANSWER" || exact == "Answer:" || exact == "[]") continue;
+    if (exact.empty()) continue;
+
+    samples.push_back({q, exact});
+  }
+  return samples;
+}
+
+std::string normalize(std::string s) {
+  std::string out;
+  for (unsigned char c : s) {
+    if (!std::ispunct(c)) out += std::tolower(c);
+  }
+  return out;
+}
+
+MLLM_MAIN({
+  if (argc < 4) {
+    std::cout << "Usage: " << argv[0] << " <model_path> <probes_path> <csv_path> [limit]" << std::endl;
+    return 1;
+  }
+  std::string model_path = argv[1];
+  std::string probes_path = argv[2];
+  std::string csv_path = argv[3];
+  int limit = (argc > 4) ? std::stoi(argv[4]) : -1;
+
+  std::cout << "Loading samples..." << std::endl;
+  auto samples = load_samples(csv_path);
+  if (samples.empty()) {
+    std::cerr << "No samples loaded from " << csv_path << std::endl;
+    return 1;
+  }
+  std::cout << "Loaded " << samples.size() << " samples." << std::endl;
+
+  std::cout << "Initializing session..."
<< std::endl;
+  // Load Model
+  auto session = std::make_unique<Qwen3ProbingSession>();
+  std::cout << "Loading model from " << model_path << "..." << std::endl;
+  session->fromPreTrain(model_path);
+
+  // Config Probing
+  std::cout << "Loading probes..." << std::endl;
+  ProbingArgs p_args;
+  p_args.enable_prefill_check = false;
+  p_args.enable_decode_check = true;
+  p_args.decode_stop_threshold = 1.1f;
+  p_args.pos_threshold = 0.9f;  // Set a realistically high threshold for the debounce strategy
+
+  session->setProbingArgs(p_args);
+  session->loadProbes(probes_path, p_args);
+
+  // Tokenizer
+  std::cout << "Loading tokenizer..." << std::endl;
+  Qwen3Tokenizer tokenizer(model_path + "/tokenizer.json");
+
+  int total_gen_tokens = 0;
+  int global_tp = 0;
+  int global_activations = 0;
+  int global_fn = 0;  // Estimate
+  int global_real_positives = 0;
+
+  for (int i = 0; i < (int)samples.size(); ++i) {
+    if (limit > 0 && i >= limit) break;
+    const auto& s = samples[i];
+
+    std::cout << "\n=== Q [" << i << "]: " << s.question.substr(0, 100) << "..." << std::endl;
+    std::cout << "Target Exact: " << s.exact_answer << std::endl;
+
+    session->clearLastProbeResults();
+
+    nlohmann::json request;
+    request["messages"] = nlohmann::json::array();
+    request["messages"].push_back({{"role", "user"}, {"content", s.question}});
+    request["max_length"] = 512;
+    request["do_sample"] = false;
+
+    std::string full_response;
+    std::vector<std::string> generated_tokens_list;
+
+    try {
+      session->streamGenerate(request, [&](const nlohmann::json& chunk, bool is_finish) {
+        if (chunk.is_string()) {
+          std::string t = chunk.get<std::string>();
+          generated_tokens_list.push_back(t);
+          full_response += t;
+        }
+      });
+    } catch (const std::exception& e) {
+      std::cerr << "Generation failed for Q[" << i << "]: " << e.what() << std::endl;
+      continue;
+    }
+
+    std::cout << "Model Answer: " << full_response.substr(0, 100) << "..." << std::endl;
+
+    auto results = session->getLastProbeResults();
+    std::string target_norm = normalize(s.exact_answer);
+
+    std::set<int> real_positive_indices;
+
+    // 1. Identify ALL real positives in the generated sequence
+    for (int t_idx = 0; t_idx < (int)generated_tokens_list.size(); ++t_idx) {
+      std::string t_norm = normalize(generated_tokens_list[t_idx]);
+      bool matches_target = false;
+      if (!t_norm.empty() && target_norm.find(t_norm) != std::string::npos) {
+        // Heuristic: length > 2 OR exact match for short words
+        if (t_norm.length() > 2 || target_norm == t_norm) { matches_target = true; }
+      }
+      if (matches_target) {
+        real_positive_indices.insert(t_idx);
+        global_real_positives++;
+      }
+      total_gen_tokens++;
+    }
+
+    // 2. Check probe activations (TP vs FP)
+    int local_tp = 0;
+    int local_fp = 0;
+
+    for (const auto& res : results) {
+      if (res.type != "pos_check") continue;
+
+      // Note: res.token_idx is the index in the generated sequence
+      if (real_positive_indices.count(res.token_idx)) {
+        local_tp++;
+        global_tp++;
+      } else {
+        local_fp++;
+      }
+      global_activations++;
+    }
+
+    // 3. Estimate FN (real positives not detected).
+    // Since one activation could cover a whole phrase, strict token-wise matching
+    // is harsh on recall, but we stick to token-wise matching for now.
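+    // A minimal token-wise FN sketch under that assumption (res.token_idx indexes
+    // the generated sequence, as used above): every real-positive index that no
+    // "pos_check" activation covered counts as a miss. global_fn is only an
+    // auxiliary estimate here and is not part of the printed metrics below.
+    std::set<int> activated_indices;
+    for (const auto& res : results) {
+      if (res.type == "pos_check") activated_indices.insert(res.token_idx);
+    }
+    for (int rp_idx : real_positive_indices) {
+      if (!activated_indices.count(rp_idx)) global_fn++;
+    }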
+
+    std::cout << "  -> Stats: TP=" << local_tp << " FP=" << local_fp << " RealPos=" << real_positive_indices.size() << std::endl;
+  }
+
+  std::cout << "\n=== Strategy Evaluation (Thr=0.9, Debounced) ===" << std::endl;
+  std::cout << "Total Checked Tokens: " << total_gen_tokens << std::endl;
+  std::cout << "Total Real Positives: " << global_real_positives << std::endl;
+  std::cout << "Total Activations: " << global_activations << std::endl;
+
+  double precision = (global_activations > 0) ? (double)global_tp / global_activations : 0.0;
+  double recall = (global_real_positives > 0) ? (double)global_tp / global_real_positives : 0.0;
+  double f1 = (precision + recall > 0) ? 2 * (precision * recall) / (precision + recall) : 0.0;
+
+  std::cout << "\nFinal Metrics:" << std::endl;
+  std::cout << "Precision: " << std::fixed << std::setprecision(2) << precision * 100.0 << "% (" << global_tp << "/"
+            << global_activations << ")" << std::endl;
+  std::cout << "Recall: " << std::fixed << std::setprecision(2) << recall * 100.0 << "% (" << global_tp << "/"
+            << global_real_positives << ")" << std::endl;
+  std::cout << "F1-Score: " << std::fixed << std::setprecision(2) << f1 * 100.0 << "%" << std::endl;
+
+  return 0;
+})
diff --git a/examples/qwen3_service/test_trivia_probing.cpp b/examples/qwen3_service/test_trivia_probing.cpp
new file mode 100644
index 000000000..40bbf3b49
--- /dev/null
+++ b/examples/qwen3_service/test_trivia_probing.cpp
@@ -0,0 +1,420 @@
+#include <algorithm>
+#include <chrono>
+#include <filesystem>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <random>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+#include "mllm/models/qwen3/modeling_qwen3_probing_service.hpp"
+
+using namespace mllm;
+using namespace mllm::models::qwen3_probing;
+namespace fs = std::filesystem;
+
+struct TriviaSample {
+  std::string question;
+  std::vector<std::string> answers;
+  std::string expected_label;
+};
+
+struct GlobalStat {
+  float max_prefill_score;  // highest prefill-probe score across all probed layers for this sample
+  bool is_model_wrong;      // 1 = wrong/hallucination, 0 = correct
+};
+
+std::vector<std::string> split(const std::string& s, char delimiter) {
+  std::vector<std::string> tokens;
+  std::string token;
+  std::istringstream tokenStream(s);
+  while (std::getline(tokenStream, token, delimiter)) { tokens.push_back(token); }
+  return tokens;
+}
+
+std::vector<TriviaSample> loadTrivia(const std::string& path, int max_lines = -1) {
+  std::vector<TriviaSample> samples;
+  std::ifstream file(path);
+  if (!file.is_open()) return samples;
+  std::string line;
+  std::getline(file, line);
+  std::vector<std::string> all_lines;
+  while (std::getline(file, line)) {
+    if (!line.empty()) all_lines.push_back(line);
+  }
+  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
+  std::mt19937 g(seed);
+  std::shuffle(all_lines.begin(), all_lines.end(), g);
+  for (const auto& l : all_lines) {
+    bool in_quote = false;
+    std::vector<std::string> fields;
+    std::string current_field;
+    for (char c : l) {
+      if (c == '"') {
+        in_quote = !in_quote;
+      } else if (c == ',' && !in_quote) {
+        fields.push_back(current_field);
+        current_field.clear();
+      } else {
+        current_field += c;
+      }
+    }
+    fields.push_back(current_field);
+    if (fields.size() < 2) continue;
+    std::string q, a_str;
+    std::string f0 = fields[0];
+    if (f0.size() >= 2 && f0.front() == '"' && f0.back() == '"') f0 = f0.substr(1, f0.size() - 2);
+    if ((fields.size() >= 2) && (f0.find('_') != std::string::npos || f0.find("tc-") != std::string::npos)) {
+      q = fields[1];
+      if (fields.size() > 3)
+        a_str = fields[3];
+      else if (fields.size() > 2)
+        a_str = fields[2];
+      else
+        a_str = "";
+    } else {
+      q = fields[0];
+      int ans_idx = (fields.size() > 6) ? 6 : 1;
+      if ((int)fields.size() <= ans_idx) ans_idx = (int)fields.size() - 1;
+      a_str = fields[ans_idx];
+    }
+    if (q.size() >= 2 && q.front() == '"' && q.back() == '"') q = q.substr(1, q.size() - 2);
+    if (a_str.size() >= 2 && a_str.front() == '"' && a_str.back() == '"') a_str = a_str.substr(1, a_str.size() - 2);
+    if (q.find("bt_") == 0 || q.find("tc_") == 0 || q.length() < 5) continue;
+    if (q.size() >= 2 && q.front() == '"' && q.back() == '"') q = q.substr(1, q.size() - 2);
+    if (a_str.size() >= 2 && a_str.front() == '"' && a_str.back() == '"') a_str = a_str.substr(1, a_str.size() - 2);
+    size_t val_pos = a_str.find("'Value': '");
+    if (val_pos != std::string::npos) {
+      size_t end_pos = a_str.find("'", val_pos + 10);
+      if (end_pos != std::string::npos) {
+        a_str = a_str.substr(val_pos + 10, end_pos - (val_pos + 10));
+        samples.push_back({q, {a_str}});
+        continue;
+      }
+    }
+    auto ans_list = split(a_str, '|');
+    samples.push_back({q, ans_list});
+  }
+  return samples;
+}
+
+std::vector<TriviaSample> loadReplay(const std::string& path) {
+  std::vector<TriviaSample> samples;
+  std::ifstream file(path);
+  if (!file.is_open()) return samples;
+  nlohmann::json j;
+  file >> j;
+  for (const auto& item : j) {
+    TriviaSample s;
+    if (item.contains("question")) s.question = item["question"].get<std::string>();
+    if (item.contains("refs"))
+      for (const auto& r : item["refs"]) s.answers.push_back(r.get<std::string>());
+    if (item.contains("label")) s.expected_label = item["label"].get<std::string>();
+    if (!s.question.empty()) samples.push_back(s);
+  }
+  return samples;
+}
+
+std::string normalize(std::string s) {
+  if (s.empty()) return "";
+  std::string out;
+  out.reserve(s.size());
+  for (size_t i = 0; i < s.size(); ++i) {
+    unsigned char c = s[i];
+    if (c < 128) {
+      if (!std::ispunct(c) && !std::iscntrl(c)) out += std::tolower(c);
+    }
+  }
+  return out;
+}
+
+std::string removeThinking(const std::string& text) {
+  std::string out = text;
+  size_t start = out.find("<think>");
+  size_t end = out.find("</think>");
+  if (start != std::string::npos && end != std::string::npos && end > start) { out.erase(start, end - start + 8); }
+  return out;
+}
+
+bool checkAnswer(const std::string& generated, const std::vector<std::string>& refs) {
+  std::string clean_gen = removeThinking(generated);
+  std::string gen_norm = normalize(clean_gen);
+  for (const auto& ref : refs) {
+    std::string ref_norm = normalize(ref);
+    if (ref_norm.empty()) continue;
+    if (gen_norm.find(ref_norm) != std::string::npos) return true;
+  }
+  return false;
+}
+
+// --- AUC ---
+double calculate_auc(const std::vector<float>& positive_scores, const std::vector<float>& negative_scores) {
+  if (positive_scores.empty() || negative_scores.empty()) return 0.0;
+
+  struct Pair {
+    float score;
+    int label;
+  };
+  std::vector<Pair> all_samples;
+  all_samples.reserve(positive_scores.size() + negative_scores.size());
+
+  for (float s : positive_scores) all_samples.push_back({s, 1});
+  for (float s : negative_scores) all_samples.push_back({s, 0});
+
+  std::sort(all_samples.begin(), all_samples.end(), [](const Pair& a, const Pair& b) { return a.score > b.score; });
+
+  double auc_sum = 0;
+  double current_pos_count = 0;
+
+  for (const auto& p : all_samples) {
+    if (p.label == 1) {
+      current_pos_count++;
+    } else {
+      auc_sum += current_pos_count;
+    }
+  }
+
+  return auc_sum / (double)(positive_scores.size() * negative_scores.size());
+}
+
+void print_stats(const std::map<std::string, std::map<int, std::map<int, std::vector<float>>>>& stats) {
+  for (auto& [phase, layer_map] : stats) {
+    float threshold = (phase == "decode") ?
+                          0.6f : 0.7f;
+
+    std::cout << "\n--- Per-Layer Analysis (" << phase << ") [Thres=" << threshold << "] ---\n";
+    std::cout << "Layer | Acc (Det) | AUC | AvgScore (All) | Samples (C/W)\n";
+    std::cout << "----------------------------------------------------------------\n";
+
+    for (auto& [layer, correctness_map] : layer_map) {
+      const auto& correct_vec = correctness_map.count(0) ? correctness_map.at(0) : std::vector<float>{};
+      const auto& wrong_vec = correctness_map.count(1) ? correctness_map.at(1) : std::vector<float>{};
+
+      int total = correct_vec.size() + wrong_vec.size();
+      if (total == 0) continue;
+
+      // Acc
+      int tn = 0;
+      for (float s : correct_vec)
+        if (s < threshold) tn++;
+      int tp = 0;
+      for (float s : wrong_vec)
+        if (s >= threshold) tp++;
+      double acc = (double)(tn + tp) / total * 100.0;
+
+      // Avg
+      double sum_s = 0;
+      for (float s : correct_vec) sum_s += s;
+      for (float s : wrong_vec) sum_s += s;
+
+      // AUC
+      double auc = calculate_auc(wrong_vec, correct_vec);  // Wrong=Positive, Correct=Negative
+
+      std::cout << "L" << std::setw(2) << layer << " | " << std::fixed << std::setprecision(1) << std::setw(5) << acc
+                << "% | " << std::setprecision(3) << std::setw(6) << auc << " | " << std::setprecision(4) << (sum_s / total)
+                << " | " << correct_vec.size() << "/" << wrong_vec.size() << "\n";
+    }
+  }
+}
+
+void print_global_stats(const std::vector<GlobalStat>& global_stats) {
+  float threshold = 0.7f;
+  if (global_stats.empty()) return;
+
+  std::vector<float> pos_scores, neg_scores;
+  int g_tp = 0, g_tn = 0, g_fp = 0, g_fn = 0;
+
+  for (const auto& s : global_stats) {
+    if (s.is_model_wrong)
+      pos_scores.push_back(s.max_prefill_score);
+    else
+      neg_scores.push_back(s.max_prefill_score);
+
+    bool predicted_hallucination = (s.max_prefill_score >= threshold);
+    if (s.is_model_wrong) {
+      if (predicted_hallucination)
+        g_tp++;
+      else
+        g_fn++;
+    } else {
+      if (!predicted_hallucination)
+        g_tn++;
+      else
+        g_fp++;
+    }
+  }
+
+  int total = global_stats.size();
+  int total_wrong = g_tp + g_fn;
+  int total_correct = g_tn + g_fp;
+
+  double acc = (double)(g_tp + g_tn) / total * 100.0;
+  double recall = (total_wrong > 0) ? (double)g_tp / total_wrong * 100.0 : 0.0;
+  double precision = (g_tp + g_fp > 0) ?
(double)g_tp / (g_tp + g_fp) * 100.0 : 0.0;
+  double auc = calculate_auc(pos_scores, neg_scores);
+
+  std::cout << "\n>>> WHOLE MODEL PREFILL STATS (Any Layer >= " << threshold << ") <<<\n";
+  std::cout << "Total: " << total << " (Wrong: " << total_wrong << ", Correct: " << total_correct << ")\n";
+  std::cout << "AUC: " << std::fixed << std::setprecision(4) << auc << " <--- Classification Capability\n";
+  std::cout << "Accuracy: " << std::fixed << std::setprecision(2) << acc << "%\n";
+  std::cout << "Recall: " << recall << "% (Caught " << g_tp << ")\n";
+  std::cout << "Precision: " << precision << "%\n";
+  std::cout << "Confusion: [TP:" << g_tp << " FN:" << g_fn << "] [FP:" << g_fp << " TN:" << g_tn << "]\n";
+  std::cout << "---------------------------------------------------------\n";
+}
+
+MLLM_MAIN({
+  mllm::setLogLevel(mllm::LogLevel::kError);
+  auto& model_path = mllm::Argparse::add<std::string>("-m|--model_path").help("Model path").required(true);
+  auto& probe_path = mllm::Argparse::add<std::string>("-p|--probe_path").help("Probes dir").required(true);
+  auto& data_path = mllm::Argparse::add<std::string>("-d|--data_path").help("Trivia CSV path");
+  auto& replay_path = mllm::Argparse::add<std::string>("-r|--replay_file").help("Replay from JSON");
+  auto& limit = mllm::Argparse::add<int>("--limit").help("Max samples").def(200);
+  auto& balanced_target = mllm::Argparse::add<int>("-b|--balanced_target").def(0);
+  mllm::Argparse::parse(argc, argv);
+
+  if (data_path.get().empty() && replay_path.get().empty()) {
+    std::cerr << "Error: Must provide either -d or -r" << std::endl;
+    return 1;
+  }
+
+  auto session = std::make_shared<Qwen3ProbingSession>();
+  try {
+    session->fromPreTrain(model_path.get());
+  } catch (const std::exception& e) {
+    std::cerr << "Load Model Error: " << e.what() << std::endl;
+    return 1;
+  }
+
+  ProbingArgs p_args;
+  p_args.enable_prefill_check = true;
+  p_args.enable_decode_check = true;
+  p_args.prefill_stop_threshold = 1.1f;  // Don't stop early
+  p_args.decode_stop_threshold = 1.1f;
+  p_args.pos_threshold = 0.9f;
+
+  for (int i = 0; i < 36; ++i) p_args.default_prefill_layers.push_back(i);
+  session->setProbingArgs(p_args);
+  session->loadProbes(probe_path.get(), p_args);
+
+  std::cout << "Loading data..." << std::endl;
+  std::vector<TriviaSample> samples;
+  if (!replay_path.get().empty()) {
+    samples = loadReplay(replay_path.get());
+  } else {
+    samples = loadTrivia(data_path.get());
+  }
+  std::cout << "Loaded " << samples.size() << " samples." << std::endl;
+
+  std::map<std::string, std::map<int, std::map<int, std::vector<float>>>> stats;
+  std::vector<GlobalStat> global_prefill_stats;
+
+  nlohmann::json output_samples = nlohmann::json::array();
+  int model_correct_total = 0;
+  int model_wrong_total = 0;
+  int processed_count = 0;
+  int max_lim = limit.get();
+
+  for (size_t i = 0; i < samples.size(); ++i) {
+    if (processed_count >= max_lim && max_lim > 0) break;
+
+    const auto& sample = samples[i];
+    session->clearLastProbeResults();
+
+    nlohmann::json req;
+    nlohmann::json msg;
+    msg["role"] = "user";
+    msg["content"] = "Answer the question directly. Do not use <think> tags. Question: " + sample.question;
+    req["messages"] = nlohmann::json::array({msg});
+    req["max_length"] = 50;
+    req["do_sample"] = false;
+    req["enable_thinking"] = false;
+
+    std::string generated_text = "";
+    auto start_t = std::chrono::high_resolution_clock::now();
+    int tok_cnt = 0;
+    try {
+      session->streamGenerate(req, [&](const nlohmann::json& j, bool finished) {
+        if (j.is_string()) {
+          generated_text += j.get<std::string>();
+          tok_cnt++;
+        }
+      });
+    } catch (...)
{ continue; }
+    auto end_t = std::chrono::high_resolution_clock::now();
+    double dur = std::chrono::duration<double>(end_t - start_t).count();
+    double tps = (dur > 0) ? (tok_cnt / dur) : 0.0;
+
+    bool is_correct = checkAnswer(generated_text, sample.answers);
+    if (is_correct)
+      model_correct_total++;
+    else {
+      model_wrong_total++;
+      std::cout << "\n[WRONG] Q: " << sample.question << "\n";
+      std::cout << "  Gen: " << removeThinking(generated_text) << "\n";
+      std::cout << "  Ref: " << (sample.answers.empty() ? std::string("?") : sample.answers[0]) << "\n";
+    }
+
+    std::cout << "\r[" << (i + 1) << "] C:" << model_correct_total << " W:" << model_wrong_total << " TPS:" << std::fixed
+              << std::setprecision(1) << tps << " | " << sample.question.substr(0, 20) << "..." << std::flush;
+
+    auto probe_data = session->getLastProbeResults();
+    std::map<std::string, std::map<int, std::vector<float>>> sample_scores;
+    float max_prefill_this_sample = 0.0f;
+
+    for (const auto& p : probe_data) {
+      sample_scores[p.phase][p.layer].push_back(p.score);
+      if (p.phase == "prefill") {
+        if (p.score > max_prefill_this_sample) max_prefill_this_sample = p.score;
+      }
+    }
+
+    global_prefill_stats.push_back({max_prefill_this_sample, !is_correct});
+
+    for (const auto& [ph, layer_map] : sample_scores) {
+      for (const auto& [lay, scores] : layer_map) {
+        if (scores.empty()) continue;
+        double sum = 0;
+        for (float s : scores) sum += s;
+        double avg = sum / scores.size();
+        stats[ph][lay][is_correct ? 0 : 1].push_back(avg);
+      }
+    }
+
+    {
+      nlohmann::json s_obj;
+      s_obj["q"] = sample.question;
+      s_obj["g"] = generated_text;
+      s_obj["l"] = is_correct ? "correct" : "wrong";
+      output_samples.push_back(s_obj);
+    }
+    processed_count++;
+
+    if (processed_count % 10 == 0) {
+      print_stats(stats);
+      print_global_stats(global_prefill_stats);
+    }
+  }
+
+  std::cout << "\n\nTotal: " << processed_count << " (Acc: " << (float)model_correct_total / processed_count * 100.0 << "%)\n";
+  std::ofstream out_f("probing_results_replay.json");
+  if (out_f.is_open()) {
+    out_f << output_samples.dump(4);
+    out_f.close();
+  }
+
+  print_stats(stats);
+  print_global_stats(global_prefill_stats);
+
+  return 0;
+});
\ No newline at end of file
diff --git a/mllm-cli/cmd/mllm-server/main.go b/mllm-cli/cmd/mllm-server/main.go
index 4d78836be..819e9936b 100644
--- a/mllm-cli/cmd/mllm-server/main.go
+++ b/mllm-cli/cmd/mllm-server/main.go
@@ -29,10 +29,6 @@ func main() {
 		log.Fatal("FATAL: InitializeContext failed!")
 	}
 	mllm.SetLogLevel(2)
-	if !mllm.StartService(1) {
-		log.Fatal("FATAL: StartService failed!")
-	}
-	defer mllm.StopService()
 	defer mllm.ShutdownContext()
 
 	mllmService := pkgmllm.NewService()
@@ -69,6 +65,11 @@ func main() {
 		log.Printf("DeepSeek-OCR Session created and registered successfully with ID: %s", sessionID)
 	}
 
+	if !mllm.StartService(1) {
+		log.Fatal("FATAL: StartService failed!")
+	}
+	defer mllm.StopService()
+
 	httpServer := server.NewServer(":8080", mllmService)
 	go httpServer.Start()
 
diff --git a/mllm/models/qwen3/modeling_qwen3_probing_service.hpp b/mllm/models/qwen3/modeling_qwen3_probing_service.hpp
new file mode 100644
index 000000000..8565ab6fc
--- /dev/null
+++ b/mllm/models/qwen3/modeling_qwen3_probing_service.hpp
@@ -0,0 +1,1057 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
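+//
+// Qwen3 "probing" service: a Qwen3 causal-LM whose decoder layers can capture
+// per-layer MLP outputs into a ProbingContext; small linear probes (optionally
+// standard-scaled and PCA-projected) score those activations to flag likely
+// hallucinations at prefill time and on key decode tokens.
+//
+// Minimal usage sketch (mirrors examples/qwen3_service/main_probing.cpp; the
+// paths are placeholders):
+//
+//   auto session = std::make_shared<Qwen3ProbingSession>();
+//   session->fromPreTrain("/path/to/Qwen3");
+//   ProbingArgs p_args;
+//   p_args.enable_prefill_check = true;
+//   p_args.enable_decode_check = true;
+//   session->setProbingArgs(p_args);
+//   session->loadProbes("/path/to/probes", p_args);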
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/models/qwen3/tokenization_qwen3.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" + +// Service related. +#include "mllm/engine/service/Session.hpp" +#include "mllm/engine/prefix_cache/Cache.hpp" + +namespace mllm::models::qwen3_probing { + +using namespace mllm; +using namespace mllm::nn; +using namespace mllm::models::qwen3; + +struct ProbingArgs { + bool enable_prefill_check = false; + float prefill_stop_threshold = 0.7f; + std::vector default_prefill_layers; + + bool enable_decode_check = false; + float decode_stop_threshold = 0.8f; + float pos_threshold = 0.9f; +}; + +struct ProbingContext { + std::map mlp_outputs; + bool collecting = false; + bool save_last_token_only = false; + std::set target_layers; + + void reset() { + mlp_outputs.clear(); + collecting = false; + save_last_token_only = false; + target_layers.clear(); + } + + void soft_reset() { + collecting = false; + save_last_token_only = false; + target_layers.clear(); + } +}; + +// RoPE +inline auto makeRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0 / std::pow(rope_theta, 2.0 * i / output_dim); } + return inv_freq; +} + +inline auto makeRotaryPosEmbedding(Tensor& position_ids, const Tensor& inv_freq, + float attention_scaling = 1.0f) -> std::pair { + auto batch_size = position_ids.shape()[0]; + auto seq_len = position_ids.shape()[1]; + auto inv_freq_len = inv_freq.shape()[0]; + auto dim = inv_freq_len * 2; + auto freqs = Tensor::empty({batch_size, seq_len, inv_freq_len}, kFloat32, kCPU).alloc(); + auto freqs_ptr = freqs.ptr(); + auto position_ids_ptr = position_ids.ptr(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + auto pos = position_ids_ptr[b * seq_len + s]; + for (int d = 0; d < inv_freq_len; ++d) { + freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d] = static_cast(pos) * inv_freq_ptr[d]; + } + } + } + auto sin_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto cos_emb = Tensor::empty({batch_size, seq_len, dim}, kFloat32, kCPU).alloc(); + auto sin_ptr = sin_emb.ptr(); + auto cos_ptr = cos_emb.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { + for (int d = 0; d < inv_freq_len; ++d) { + auto freq = freqs_ptr[b * seq_len * inv_freq_len + s * inv_freq_len + d]; + auto sin_val = std::sin(freq) * attention_scaling; + auto cos_val = std::cos(freq) * attention_scaling; + sin_ptr[b * seq_len * dim + s * dim + d] = sin_val; + sin_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = sin_val; + cos_ptr[b * seq_len * dim + s * dim + d] = cos_val; + cos_ptr[b * seq_len * dim + s * dim + d + inv_freq_len] = cos_val; + } + } + } + return {sin_emb, cos_emb}; +} + +// Linear, Scaler, PCA +class ProbeClassifier : public Module { + Linear linear_; + + bool use_scaler_ = false; + Param scaler_mean_; + Param scaler_scale_; + + bool use_pca_ = false; + Param pca_components_; // [hidden_dim, 
pca_dim] + + public: + ProbeClassifier() = default; + + ProbeClassifier(const std::string& name, int hidden_dim, int linear_in_dim, bool use_scaler, bool use_pca, + std::string linear_name, std::string scaler_prefix, std::string pca_name) + : Module(name), use_scaler_(use_scaler), use_pca_(use_pca) { + if (use_scaler_) { + scaler_mean_ = reg("scaler_mean", scaler_prefix + "_mean.weight", Tensor::shape_t({1, 1, hidden_dim})); + scaler_scale_ = reg("scaler_scale", scaler_prefix + "_scale.weight", Tensor::shape_t({1, 1, hidden_dim})); + } + + if (use_pca_) { + pca_components_ = reg("pca_components", pca_name + ".weight", Tensor::shape_t({linear_in_dim, hidden_dim})); + } + linear_ = reg(linear_name, linear_in_dim, 1, true, mllm::aops::LinearImplTypes::kDefault); + } + + virtual float predict(Tensor& hidden_emb) { + Tensor x = hidden_emb; + + if (use_scaler_) { + x = x - scaler_mean_.weight(); + x = x / scaler_scale_.weight(); + } + if (use_pca_) { + // hidden_emb: [1, 1, hidden_dim] + // pca_components_.weight(): [linear_in_dim, hidden_dim] + // transpose(0, 1): [hidden_dim, linear_in_dim] + // matmul: [1, 1, hidden_dim] * [hidden_dim, linear_in_dim] -> [1, 1, linear_in_dim] + x = mllm::nn::functional::matmul(x, pca_components_.weight().transpose(0, 1)); + } + + auto logits = linear_(x); + + float val = 0.0f; + if (logits.dtype() == mllm::kFloat32) { + val = logits.ptr()[0]; + } else if (logits.dtype() == mllm::kFloat16) { + val = (float)logits.ptr<__fp16>()[0]; + } + + return 1.0f / (1.0f + std::exp(-val)); + } +}; + +// MODEL +class Qwen3ProbingMLP final : public Module { + Linear gate_proj_, up_proj_, down_proj_; + SiLU silu_; + + public: + Qwen3ProbingMLP() = default; + Qwen3ProbingMLP(const std::string& name, const Qwen3Config& cfg) : Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen3ProbingAttention final : public Module { + Linear q_proj_, k_proj_, v_proj_, o_proj_; + RMSNorm rms_norm_q_, rms_norm_k_; + RoPE q_rope_, k_rope_; + RadixAttn attn_; + int hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_; + int layer_idx_; + + public: + friend class Qwen3ProbingText; + friend class Qwen3ProbingDecoder; + + Qwen3ProbingAttention() = default; + Qwen3ProbingAttention(const std::string& name, const Qwen3Config& cfg) : Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = cfg.head_dim; + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, cfg.attention_bias, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, cfg.attention_bias, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, cfg.attention_bias, cfg.linear_impl_type); + rms_norm_q_ = reg("q_norm", cfg.rms_norm_eps); + rms_norm_k_ = reg("k_norm", cfg.rms_norm_eps); + q_rope_ 
= reg("q_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD); + k_rope_ = reg("k_rope", cfg.rope_theta, cfg.max_position_embeddings, aops::RoPEOpOptionsInputType::kBSHD); + attn_ = reg("attn", num_attention_heads_, num_key_value_heads_); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto k_cache_addr = args[0].get>*>(); + auto v_cache_addr = args[1].get>*>(); + auto prefix_cache_context = args[2].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = rms_norm_q_(query_states); + key_states = rms_norm_k_(key_states); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + std::vector k_addr_wait_for_promote, v_addr_wait_for_promote; + for (int s_idx = 0; s_idx < S; ++s_idx) { + k_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); + v_addr_wait_for_promote.push_back(prefix_cache_context->alloc(kCPU)); + } + std::vector k_phy_addr_wait_for_promote, v_phy_addr_wait_for_promote; + for (int s_idx = 0; s_idx < S; ++s_idx) { + k_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(k_addr_wait_for_promote[s_idx])); + v_phy_addr_wait_for_promote.push_back(prefix_cache_context->physicalAddr(v_addr_wait_for_promote[s_idx])); + } + auto k_wait_for_promote = Tensor::refVectorData(k_phy_addr_wait_for_promote, {S}, kInt64, kCPU); + auto v_wait_for_promote = Tensor::refVectorData(v_phy_addr_wait_for_promote, {S}, kInt64, kCPU); + + nn::functional::scatter2Shards(key_states, k_wait_for_promote, 1); + nn::functional::scatter2Shards(value_states, v_wait_for_promote, 1); + + { + auto& dst = (*k_cache_addr)[layer_idx_]; + dst.insert(dst.end(), k_addr_wait_for_promote.begin(), k_addr_wait_for_promote.end()); + } + { + auto& dst = (*v_cache_addr)[layer_idx_]; + dst.insert(dst.end(), v_addr_wait_for_promote.begin(), v_addr_wait_for_promote.end()); + } + + std::vector k_phy_cache_indicies, v_phy_cache_indicies; + int32_t kv_cache_len = (*k_cache_addr)[layer_idx_].size(); + k_phy_cache_indicies.reserve(kv_cache_len); + v_phy_cache_indicies.reserve(kv_cache_len); + for (int i = 0; i < kv_cache_len; ++i) { + k_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*k_cache_addr)[layer_idx_][i])); + v_phy_cache_indicies.push_back(prefix_cache_context->physicalAddr((*v_cache_addr)[layer_idx_][i])); + } + auto k_cache = Tensor::refVectorData(k_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); + auto v_cache = Tensor::refVectorData(v_phy_cache_indicies, {kv_cache_len}, kInt64, kCPU); + + auto output = attn_(query_states, k_cache, v_cache); + output = output.view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + + return {output}; + } +}; + +class Qwen3ProbingDecoder final : public Module { + public: + Qwen3ProbingAttention self_attn_; + Qwen3ProbingMLP mlp_; + RMSNorm input_layer_norm_, post_attention_layer_norm_; + int layer_idx_; + + Qwen3ProbingDecoder() = default; + Qwen3ProbingDecoder(const std::string& name, 
const Qwen3Config& cfg) : Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& k_cache_addr = args[0]; + auto& v_cache_addr = args[1]; + auto& prefix_cache_context = args[2]; + + ProbingContext* probe_ctx = nullptr; + if (args.size() > 3) probe_ctx = args[3].get(); + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, k_cache_addr, v_cache_addr, prefix_cache_context)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + + auto mlp_out = mlp_(x)[0]; + + // Probe + if (probe_ctx && probe_ctx->collecting) { + bool layer_needed = false; + if (probe_ctx->target_layers.empty()) + layer_needed = true; + else if (probe_ctx->target_layers.count(layer_idx_)) + layer_needed = true; + + if (layer_needed) { + int batch = mlp_out.shape()[0]; + int seq_len = mlp_out.shape()[1]; + int hidden_dim = mlp_out.shape()[2]; + + Tensor* dest_ptr = nullptr; + bool need_alloc = true; + if (probe_ctx->mlp_outputs.count(layer_idx_)) { + auto& t = probe_ctx->mlp_outputs[layer_idx_]; + if (t.shape().size() == 3 && t.shape()[0] == batch && t.shape()[1] == 1 && t.shape()[2] == hidden_dim + && t.dtype() == mlp_out.dtype()) { + dest_ptr = &t; + need_alloc = false; + } + } + + if (need_alloc) { + probe_ctx->mlp_outputs[layer_idx_] = Tensor::empty({batch, 1, hidden_dim}, mlp_out.dtype(), kCPU); + probe_ctx->mlp_outputs[layer_idx_].alloc(); + dest_ptr = &probe_ctx->mlp_outputs[layer_idx_]; + } + + int token_offset = probe_ctx->save_last_token_only ? (seq_len - 1) : 0; + + size_t dtype_size = (mlp_out.dtype() == mllm::kFloat32) ? 4 : 2; + char* src_base_ptr = (char*)mlp_out.ptr(); + size_t byte_offset = (size_t)token_offset * hidden_dim * dtype_size; + + if (src_base_ptr && dest_ptr->ptr()) { + std::memcpy(dest_ptr->ptr(), src_base_ptr + byte_offset, hidden_dim * dtype_size); + } + } + } + + x = mlp_out + tmp; + return {x}; + } +}; + +class Qwen3ProbingText final : public Module { + ModuleList decode_blocks_; + RMSNorm norm_; + Embedding embedding_; + + public: + Qwen3ProbingText() = default; + Qwen3ProbingText(const std::string& name, const Qwen3Config& cfg) : Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { + b.self_attn_.layer_idx_ = idx; + b.layer_idx_ = idx; + } + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = embedding_(inputs[0]); + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + + for (auto& block : blocks) { + x = block(x, llm_embedding_sin, llm_embedding_cos, args[0], args[1], args[2], args.size() > 3 ? 
args[3] : AnyValue())[0]; + } + x = norm_(x); + return {x}; + } +}; + +class Qwen3ProbingForCausalLM : public ARGeneration, public Module { + public: + explicit Qwen3ProbingForCausalLM(const Qwen3Config& cfg) : cfg(cfg) { + eos_token_id_ = cfg.end_of_text_token_id; + max_length_ = cfg.max_cache_length; + tie_word_embeddings_ = cfg.tie_word_embeddings; + llm = reg("model", cfg); + if (cfg.tie_word_embeddings) { + lm_head_ = reg("lm_head_out", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + auto inv = makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + // load probes from directory + void loadProbesFromDirectory(const std::string& dir_path, const ProbingArgs& args) { + namespace fs = std::filesystem; + if (!fs::exists(dir_path)) { + std::cerr << "Probe dir not found: " << dir_path << std::endl; + return; + } + + for (const auto& entry : fs::directory_iterator(dir_path)) { + std::string fn = entry.path().filename().string(); + std::string path = entry.path().string(); + if (fn.find(".mllm") == std::string::npos) continue; + + std::shared_ptr params; + try { + params = mllm::load(path, mllm::ModelFileVersion::kV2); + } catch (const std::exception& e) { + std::cerr << "Failed to open " << fn << ": " << e.what() << std::endl; + continue; + } + + std::string detect_linear_name = "classifier"; + std::string detect_scaler_prefix = "scaler"; + std::string detect_pca_name = "pca_components"; + + bool has_scaler = false; + bool has_pca = false; + int linear_in_dim = cfg.hidden_size; + + for (auto& [key, tensor] : *params) { + if (key.find("linear.weight") != std::string::npos) { + detect_linear_name = "linear"; + if (tensor.shape().size() > 1) linear_in_dim = tensor.shape()[1]; + } else if (key.find("classifier.weight") != std::string::npos) { + detect_linear_name = "classifier"; + if (tensor.shape().size() > 1) linear_in_dim = tensor.shape()[1]; + } + + if (key.find("scaler_mean") != std::string::npos) { + has_scaler = true; + detect_scaler_prefix = "scaler"; + } + + if (key.find("pca_proj") != std::string::npos) { + has_pca = true; + detect_pca_name = "pca_proj"; + } else if (key.find("pca_components") != std::string::npos) { + has_pca = true; + detect_pca_name = "pca_components"; + } + } + + bool is_prefill = (fn.find("prefill") != std::string::npos); + bool is_pos = (fn.find("pos_probe") != std::string::npos); + + int parsed_layer = -1; + size_t layer_pos = fn.find("layer-"); + if (layer_pos != std::string::npos) { + try { + size_t num_start = layer_pos + 6; + size_t num_end = fn.find_first_not_of("0123456789", num_start); + parsed_layer = std::stoi(fn.substr(num_start, num_end - num_start)); + } catch (...) {} + } + + bool use_scaler = has_scaler; + bool use_pca = has_pca; + + if (!use_pca && linear_in_dim != cfg.hidden_size) { + // If linear_in_dim differs from hidden_size, PCA must be used + } + + std::cout << " -> Loading " << fn << " [S:" << (use_scaler ? "ON" : "OFF") << ", P:" << (use_pca ? 
"ON" : "OFF") + << ", Dim:" << linear_in_dim << ", Layer:" << parsed_layer << "]" << std::endl; + + auto probe = std::make_shared("", cfg.hidden_size, linear_in_dim, use_scaler, use_pca, + detect_linear_name, detect_scaler_prefix, detect_pca_name); + try { + probe->load(params); + } catch (const std::exception& e) { + std::cerr << "Error loading weights for " << fn << ": " << e.what() << std::endl; + continue; + } + + if (is_pos) { + pos_probe = probe; + if (parsed_layer != -1) pos_probe_layer_idx = parsed_layer; + } else { + if (parsed_layer == -1) continue; + if (is_prefill) + prefill_probes[parsed_layer].push_back(probe); + else + decode_probes[parsed_layer].push_back(probe); + } + } + std::cout << "Loaded Summary: Prefill(" << prefill_probes.size() << "), Decode(" << decode_probes.size() << "), Pos(" + << (pos_probe ? "Yes@L" + std::to_string(pos_probe_layer_idx) : "No") << ")" << std::endl; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + // Standard forward pass (same as before) + auto sequence = input.at("sequence"); + auto batch_size = sequence.shape()[0]; + auto seq_len = sequence.shape()[1]; + Tensor position_ids = Tensor::nil(); + if (input.count("position_ids")) { + position_ids = input.at("position_ids"); + if (seq_len == 1) { + auto last_pos = *position_ids.offsettedPtr({0, position_ids.shape()[1] - 1}); + position_ids = Tensor::empty({batch_size, 1}, kInt64, kCPU).alloc(); + *position_ids.offsettedPtr({0, 0}) = last_pos + 1; + } + } else { + position_ids = Tensor::empty({batch_size, seq_len}, kInt64, kCPU).alloc(); + auto position_ids_ptr = position_ids.ptr(); + for (int b = 0; b < batch_size; ++b) { + for (int s = 0; s < seq_len; ++s) { position_ids_ptr[b * seq_len + s] = s; } + } + } + auto [llm_embedding_sin, llm_embedding_cos] = makeRotaryPosEmbedding(position_ids, getBuffer("inv_freq"), 1.0f); + std::vector forward_args = {args.at("k_cache_addrs"), args.at("v_cache_addrs"), args.at("prefix_cache_context")}; + if (args.count("probing_context")) forward_args.push_back(args.at("probing_context")); + + sequence = llm(sequence, llm_embedding_sin, llm_embedding_cos, forward_args)[0]; + { + auto S = sequence.shape()[1]; + sequence = sequence[{kAll, {S - 1}, kAll}]; + } + if (tie_word_embeddings_) { sequence = lm_head_(sequence); } + return {{"sequence", sequence}, {"position_ids", position_ids}}; + } + + const Qwen3Config cfg; + std::map>> prefill_probes; + std::map>> decode_probes; + std::shared_ptr pos_probe; + int pos_probe_layer_idx = -1; + + // Public exposure for collector + struct ProbeResult { + float score; + int layer; + std::string phase; + std::string type = "hallucination"; // "hallucination" or "pos_check" + bool is_key_predicted = false; + int token_idx = -1; + int token_id = -1; + }; + std::vector last_probe_results_; + void clearProbeResults() { last_probe_results_.clear(); } + + private: + Qwen3ProbingText llm; + Linear lm_head_; + bool tie_word_embeddings_; +}; + +// Session +class Qwen3ProbingSession final : public ::mllm::service::Session { + public: + Qwen3ProbingSession() = default; + + void setProbingArgs(const ProbingArgs& args) { probing_args_ = args; } + void loadProbes(const std::string& path, const ProbingArgs& args) { model_->loadProbesFromDirectory(path, args); } + + std::vector getLastProbeResults() { return model_->last_probe_results_; } + void clearLastProbeResults() { model_->clearProbeResults(); } + + std::size_t findThinkStartToken(const std::vector& output_ids) { + auto 
it = std::find(output_ids.begin(), output_ids.end(), model_->cfg.thinking_start_token_id);
+    return std::distance(output_ids.begin(), it);
+  }
+
+  void streamGenerate(const nlohmann::json& request,
+                      const std::function<void(const nlohmann::json&, bool)>& callback) override {
+    mllm::cpu::wakeupHpcThreadPool();
+    auto messages = request["messages"];
+
+    // Brief instruction appended to the system prompt to keep answers short.
+    std::string concise_instruction = " Please answer in a single, complete sentence. Keep it concise.";
+
+    bool has_system = false;
+    if (!messages.empty() && messages[0].value("role", "") == "system") {
+      std::string current_content = messages[0].value("content", "");
+      messages[0]["content"] = current_content + concise_instruction;
+      has_system = true;
+    }
+
+    if (!has_system) {
+      nlohmann::json sys_msg;
+      sys_msg["role"] = "system";
+      sys_msg["content"] = "You are a helpful assistant." + concise_instruction;
+      messages.insert(messages.begin(), sys_msg);
+    }
+    auto inputs = applyChatTemplate(messages, {}, true, request.value("enable_thinking", false));
+    auto full_seq_idx = tokenizer_->convert2Ids(tokenizer_->tokenize(inputs)).toVector<int64_t>();
+
+    ARGenerationArgs args;
+    ARGenerationOutputPast input;
+    auto prefix_cache_result = cache_->find(full_seq_idx);
+    std::span reduced_seq_idx(full_seq_idx.data() + prefix_cache_result.matched_length,
+                              full_seq_idx.size() - prefix_cache_result.matched_length);
+    std::vector<int64_t> position_ids;
+    {
+      auto start = prefix_cache_result.matched_length;
+      auto end = full_seq_idx.size();
+      position_ids.reserve(end - start);
+      std::ranges::copy(std::views::iota(static_cast<int64_t>(start), static_cast<int64_t>(end)),
+                        std::back_inserter(position_ids));
+    }
+    input["sequence"] = Tensor::fromVector(reduced_seq_idx, {1, (int32_t)reduced_seq_idx.size()}, kInt64, kCPU);
+    input["position_ids"] = Tensor::fromVector(position_ids, {1, (int32_t)position_ids.size()}, kInt64, kCPU);
+    k_cache_addrs_ = prefix_cache_result.k_cache_addresses;
+    v_cache_addrs_ = prefix_cache_result.v_cache_addresses;
+
+    ProbingContext probe_ctx;
+    args["k_cache_addrs"] = &k_cache_addrs_;
+    args["v_cache_addrs"] = &v_cache_addrs_;
+    args["prefix_cache_context"] = cache_.get();
+    args["probing_context"] = &probe_ctx;
+    args["temperature"] = request.value("temperature", 1.0f);
+    args["top_k"] = request.value("top_k", 0);
+    args["top_p"] = request.value("top_p", 0.0f);
+    auto max_length = request.value("max_length", 1024);
+    args["max_length"] = max_length;
+    args["do_sample"] = request.value("do_sample", false);
+
+    bool stop_generating = false;
+
+    // Resolve prefill layers: per-request override, else configured defaults.
+    std::vector<int> current_prefill_layers;
+    if (request.contains("prefill_layers") && request["prefill_layers"].is_array()) {
+      current_prefill_layers = request["prefill_layers"].get<std::vector<int>>();
+    } else {
+      current_prefill_layers = probing_args_.default_prefill_layers;
+    }
+
+    if (probing_args_.enable_prefill_check && !current_prefill_layers.empty()) {
+      probe_ctx.reset();
+      probe_ctx.collecting = true;
+      probe_ctx.save_last_token_only = true;
+      for (int l : current_prefill_layers) probe_ctx.target_layers.insert(l);
+    }
+
+    struct CandidateKey {
+      int token_idx;
+      int token_id;
+      float score;
+      std::map<int, Tensor> activations;
+    };
+    std::shared_ptr<CandidateKey> candidate_key = nullptr;
+    int debounce_counter = 0;
+    bool has_confirmed_key_in_decode = false;
+
+    int64_t package_cnt = 0;
+    std::string accumulated_output = "";
+    auto wrapped_callback = [this, &max_length, &request, &full_seq_idx, &package_cnt, &callback, &probe_ctx, &stop_generating,
+                             &candidate_key, &debounce_counter, &has_confirmed_key_in_decode,
+                             &accumulated_output](int64_t idx) {
+      if
(stop_generating) return; + + // Calculate token string early for punctuation check + std::string current_token_str = preprocessor::wideString2Utf8String(tokenizer_->detokenize(idx)); + + // 0. Accumulate output (Wait for safety check) + // Do not append EOS token to the buffer + if (idx != model_->cfg.eos_token_id) { accumulated_output += current_token_str; } + + auto should_skip_token = [](std::string s) -> bool { + // 1. Trim leading/trailing whitespace (include common whitespace chars) + size_t start = s.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) return true; // All spaces/empty + size_t end = s.find_last_not_of(" \t\n\r"); + std::string core = s.substr(start, end - start + 1); + + // 2. Check if Punctuation (all chars are punct) + bool all_punct = true; + for (unsigned char c : core) { + // If we find an alphanumeric char, it's NOT (just) punctuation + if (isalnum(c)) { + all_punct = false; + break; + } + // If high-bit (likely UTF-8 chinese/emoji), treat as Non-Punctuation for now (unless specific symbol list) + if (c & 0x80) { + all_punct = false; + break; + } + } + if (all_punct && !core.empty()) return true; + + // 3. Skip Blocklist (Articles, Prepositions, Conjunctions, Linking Verbs) + // Transform to lowercase for comparison + std::transform(core.begin(), core.end(), core.begin(), ::tolower); + static const std::set skip_words = { + // Articles + "the", "a", "an", + // Prepositions + "of", "in", "to", "for", "with", "on", "at", "from", "by", "about", "as", "into", "like", "through", "after", + "over", "between", "out", "against", "during", "without", "before", "under", "around", "among", + // Linking Verbs / Auxiliaries + "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "can", "could", "will", "would", "shall", + "should", "may", "might", "must", + // Conjunctions + "and", "but", "or", "nor", "so", "yet", "if", "than", "then", "else", "when", "while", "where", "because", "since", + "although", "though", "unless"}; + if (skip_words.count(core)) return true; + + return false; + }; + + bool skip_token = should_skip_token(current_token_str); + + // 1. Prefill check + if (probing_args_.enable_prefill_check && package_cnt == 0) { + float max_hallu = 0.0f; + for (auto& [layer_idx, tensor] : probe_ctx.mlp_outputs) { + if (model_->prefill_probes.count(layer_idx)) { + float total_prob = 0.0f; + auto& probes = model_->prefill_probes[layer_idx]; + if (probes.empty()) continue; + for (auto& probe : probes) total_prob += probe->predict(tensor); + float halluc_prob = 1.0f - (total_prob / probes.size()); + + if (model_->last_probe_results_.size() < 10000) { + model_->last_probe_results_.push_back({halluc_prob, layer_idx, "prefill", "hallucination", false, -1, -1}); + } + + if (halluc_prob >= probing_args_.prefill_stop_threshold) { + nlohmann::json stop_resp; + stop_resp["status"] = "early_exit_hallucination"; + stop_resp["score"] = halluc_prob; + stop_resp["layer"] = layer_idx; + stop_resp["phase"] = "prefill"; + callback(stop_resp.dump(), true); + stop_generating = true; + throw std::runtime_error("PROBING_INTERRUPT"); + } + } + } + probe_ctx.collecting = false; + } + + // 2. 
Decode check (Debounced) + bool finished = false; + if (idx == model_->cfg.eos_token_id || package_cnt + 1 >= max_length) finished = true; + + if (package_cnt > 0 && probing_args_.enable_decode_check && !has_confirmed_key_in_decode) { + bool is_new_potential = false; + int pos_layer = model_->cfg.num_hidden_layers - 1; + if (model_->pos_probe_layer_idx != -1) pos_layer = model_->pos_probe_layer_idx; + + if (idx != model_->cfg.eos_token_id && !skip_token && model_->pos_probe && probe_ctx.mlp_outputs.count(pos_layer)) { + float pos_score = model_->pos_probe->predict(probe_ctx.mlp_outputs[pos_layer]); + if (pos_score >= probing_args_.pos_threshold) { + // Found a key token! + is_new_potential = true; + + // Create/Update candidate + auto new_cand = std::make_shared(); + new_cand->token_idx = (int)package_cnt; + new_cand->token_id = (int)idx; + new_cand->score = pos_score; + + // Deep copy activations + for (auto const& [l, t_src] : probe_ctx.mlp_outputs) { + Tensor t_dst = Tensor::empty(t_src.shape(), t_src.dtype(), kCPU); + t_dst.alloc(); + size_t sz = t_src.numel() * (t_src.dtype() == mllm::kFloat32 ? 4 : 2); + if (t_src.ptr()) memcpy(t_dst.ptr(), t_src.ptr(), sz); + new_cand->activations[l] = t_dst; + } + + candidate_key = new_cand; + debounce_counter = 5; // Reset window + + std::cout << "[PosCheck] UpdateCandidate: '" << current_token_str << "' (" << pos_score << ")" << std::endl; + } + } + + // B. Debounce / Expiration Logic + bool trigger_hallu_check = false; + + if (!is_new_potential && candidate_key) { + debounce_counter--; + if (debounce_counter <= 0) trigger_hallu_check = true; + } + if (finished && candidate_key) trigger_hallu_check = true; + + // C. Execution + if (trigger_hallu_check && candidate_key) { + std::cout << "[PosCheck] ConfirmKey: Index " << candidate_key->token_idx << std::endl; + has_confirmed_key_in_decode = true; // Mark as done for this sentence + + if (model_->last_probe_results_.size() < 10000) { + Qwen3ProbingForCausalLM::ProbeResult res; + res.score = candidate_key->score; + res.layer = pos_layer; + res.phase = "decode"; + res.type = "pos_check"; + res.is_key_predicted = true; + res.token_idx = candidate_key->token_idx; + res.token_id = candidate_key->token_id; + model_->last_probe_results_.push_back(res); + } + + // Run Hallu Check on SAVED activations + for (auto& [layer_idx, tensor] : candidate_key->activations) { + // Only check Layer 22 for hallucination as mapped to user request + if (layer_idx != 22) continue; + + if (model_->decode_probes.count(layer_idx)) { + auto& probes = model_->decode_probes[layer_idx]; + if (probes.empty()) continue; + float total_prob = 0.0f; + for (auto& probe : probes) total_prob += probe->predict(tensor); + float halluc_prob = 1.0f - (total_prob / probes.size()); + + if (model_->last_probe_results_.size() < 100000) { + model_->last_probe_results_.push_back({halluc_prob, layer_idx, "decode", "hallucination", false, + candidate_key->token_idx, candidate_key->token_id}); + } + + if (halluc_prob >= probing_args_.decode_stop_threshold) { + nlohmann::json stop_resp; + stop_resp["status"] = "early_exit_hallucination"; + stop_resp["score"] = halluc_prob; + stop_resp["layer"] = layer_idx; + stop_resp["phase"] = "decode"; + callback(stop_resp.dump(), true); + stop_generating = true; + } + } + } + candidate_key = nullptr; // Clear after processing + if (stop_generating) throw std::runtime_error("PROBING_INTERRUPT"); + } + } + + // 3. 
+
+      // 3. Context reset for the next token
+      if (!stop_generating) {
+        probe_ctx.soft_reset();
+        probe_ctx.save_last_token_only = true;
+        if (probing_args_.enable_decode_check) {
+          probe_ctx.collecting = true;
+          int pos_layer = model_->cfg.num_hidden_layers - 1;
+          if (model_->pos_probe_layer_idx != -1) pos_layer = model_->pos_probe_layer_idx;
+
+          probe_ctx.target_layers.insert(pos_layer);
+          for (auto const& [l, _] : model_->decode_probes) probe_ctx.target_layers.insert(l);
+        }
+      }
+
+      // Only emit the buffered output if generation finished without a hallucination stop
+      if (finished && !stop_generating) { callback(accumulated_output, true); }
+
+      if (!finished) full_seq_idx.push_back(idx);
+
+      package_cnt++;
+    };
+
+    try {
+      model_->streamGenerate(input, args, wrapped_callback);
+    } catch (const std::exception& e) {
+      if (std::string(e.what()) != "PROBING_INTERRUPT") std::cerr << e.what() << std::endl;
+    }
+
+    // Truncate the sequence (and the cache address lists) at the think-start token
+    // before promoting into the prefix cache
+    auto thinking_end_token_idx = findThinkStartToken(full_seq_idx);
+    full_seq_idx.resize(thinking_end_token_idx);
+    for (auto& k_vec : k_cache_addrs_) k_vec.resize(thinking_end_token_idx);
+    for (auto& v_vec : v_cache_addrs_) v_vec.resize(thinking_end_token_idx);
+    cache_->promote(full_seq_idx, k_cache_addrs_, v_cache_addrs_);
+    k_cache_addrs_ = {};
+    v_cache_addrs_ = {};
+    mllm::cpu::idleHpcThreadPool();
+  }
+
+  void fromPreTrain(const std::string& model_path) override {
+    namespace fs = std::filesystem;
+    fs::path root = fs::path(model_path).lexically_normal();
+    auto cfg = Qwen3Config((root / "config.json").string());
+    model_ = std::make_shared<Qwen3ProbingForCausalLM>(cfg);
+    model_->load(mllm::load((root / "model.mllm").string(), ModelFileVersion::kV2));
+    // (concrete tokenizer/cache class names below are assumed)
+    tokenizer_ = std::make_shared<Qwen3Tokenizer>((root / "tokenizer.json").string());
+    cache_ = std::make_shared<prefix_cache::PrefixCache>(prefix_cache::CacheOptions{
+        .radix_tree_options = {.enable_lru_eviction = false,
+                               .eviction_threshold = 0.9f,
+                               .enable_path_compression = false,
+                               .min_compression_length = 2,
+                               .transformer_blocks_num = cfg.num_hidden_layers},
+        .allocator_options = {.per_k_token_ele = static_cast<size_t>(cfg.head_dim * cfg.num_key_value_heads),
+                              .per_v_token_ele = static_cast<size_t>(cfg.head_dim * cfg.num_key_value_heads),
+                              .k_dtype = mllm::kFloat32,
+                              .v_dtype = mllm::kFloat32,
+                              .enable_cuda = false,
+                              .cuda_mem_base = 0x100000,
+                              .enable_cpu_hierarchy_memory = true,
+                              .zen_fs_options = {
+                                  .record = false,
+                                  .working_dir = ".",
+                                  .blob_bits_size = 20,
+                                  .page_bits = 7,
+                                  .lane_bits = 5,
+                                  .per_k_token_ele = static_cast<size_t>(cfg.head_dim * cfg.num_key_value_heads),
+                                  .per_v_token_ele = static_cast<size_t>(cfg.head_dim * cfg.num_key_value_heads),
+                                  .k_dtype = mllm::kFloat32,
+                                  .v_dtype = mllm::kFloat32,
+                                  .mmap_type = mllm::prefix_cache::ZenFSBlobMMapType::kAnonymous,
+                              }}});
+  }
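+
+  // fromPreTrain expects this fixed layout under model_path (names are
+  // hard-coded by the loader above, not configurable):
+  //   config.json     - Qwen3Config hyperparameters
+  //   model.mllm      - weights in the kV2 model file format
+  //   tokenizer.json  - tokenizer definition
+  // The prefix cache is configured with FP32 KV entries of
+  // head_dim * num_key_value_heads elements per token per side.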
"" : s.substr(0, end + 1); + } + std::string trim(const std::string& s) { return rtrim(ltrim(s)); } + + std::string applyChatTemplate(const nlohmann::json& messages, const std::vector& tools = {}, + bool add_generation_prompt = true, bool enable_thinking = true, + const std::string& bos_token = "", const std::string& eos_token = "<|im_end|>") { + std::ostringstream oss; + if (!tools.empty()) { + oss << "<|im_start|>system\n"; + if (!messages.empty() && messages[0].value("role", "") == "system") { oss << messages[0].value("content", "") << "\n\n"; } + oss << "# Tools\n\nYou may call one or more functions to assist with the user query.\n"; + oss << "You are provided with function signatures within XML tags:\n"; + for (const auto& tool : tools) { oss << "\n" << tool.dump(); } + oss << "\n\n\nFor each function call, return a json object with function name and arguments within " + " XML tags:\n\n{\"name\": , \"arguments\": " + "}\n<|im_end|>\n"; + } else { + if (!messages.empty() && messages[0].value("role", "") == "system") { + oss << "<|im_start|>system\n" << messages[0].value("content", "") << "<|im_end|>\n"; + } + } + + size_t last_query_index = messages.empty() ? 0 : messages.size() - 1; + bool found_last_query = false; + if (!messages.empty()) { + for (int i = messages.size() - 1; i >= 0; --i) { + const auto& msg = messages[i]; + if (msg.value("role", "") == "user" && msg.contains("content") && msg["content"].is_string()) { + std::string content_str = msg["content"].get(); + if (!(content_str.starts_with("") + && content_str.find("") == content_str.length() - std::string("").length())) { + last_query_index = i; + found_last_query = true; + break; + } + } + } + } + if (messages.empty()) { found_last_query = false; } + + for (size_t i = 0; i < messages.size(); ++i) { + const auto& message = messages[i]; + std::string role = message.value("role", ""); + std::string content; + if (message.contains("content") && message["content"].is_string()) { content = message["content"].get(); } + + if (role == "user" || (role == "system" && i > 0)) { + oss << "<|im_start|>" << role << "\n" << content << "<|im_end|>\n"; + } else if (role == "assistant") { + std::string reasoning_content; + if (message.contains("reasoning_content") && message["reasoning_content"].is_string()) { + reasoning_content = message["reasoning_content"].get(); + } else { + auto think_end_pos = content.find(""); + if (think_end_pos != std::string::npos) { + auto think_start_pos = content.rfind("", think_end_pos); + if (think_start_pos != std::string::npos) { + reasoning_content = content.substr(think_start_pos + 7, think_end_pos - (think_start_pos + 7)); + content = content.substr(think_end_pos + 8); + } + } + } + + oss << "<|im_start|>" << role << "\n"; + if (found_last_query && i > last_query_index) { + if ((i == messages.size() - 1) || !reasoning_content.empty()) { + oss << "\n" << trim(reasoning_content) << "\n\n\n" << ltrim(content); + } else { + oss << content; + } + } else { + oss << content; + } + + if (message.contains("tool_calls")) { + bool is_first_tool = true; + for (const auto& tool_call_item : message["tool_calls"]) { + if ((is_first_tool && !content.empty()) || !is_first_tool) { oss << "\n"; } + is_first_tool = false; + const nlohmann::json* tool_call_ptr = &tool_call_item; + if (tool_call_item.contains("function")) { tool_call_ptr = &tool_call_item["function"]; } + const nlohmann::json& tool_call = *tool_call_ptr; + oss << "\n{\"name\": \"" << tool_call.value("name", "") << R"(", "arguments": )"; + const auto& args = 
+
+ private:
+  // NB: the concrete address/tokenizer/cache type names here are assumed; they
+  //     must match the prefix-cache API this session is built against.
+  std::vector<std::vector<uintptr_t>> k_cache_addrs_;
+  std::vector<std::vector<uintptr_t>> v_cache_addrs_;
+  std::shared_ptr<Qwen3ProbingForCausalLM> model_;
+  std::shared_ptr<Qwen3Tokenizer> tokenizer_;
+  std::shared_ptr<prefix_cache::PrefixCache> cache_;
+  ProbingArgs probing_args_;
+
+ public:
+  std::vector<Qwen3ProbingForCausalLM::ProbeResult> getLastProbeResults() const {
+    if (model_) return model_->last_probe_results_;
+    return {};
+  }
+};
+
+} // namespace mllm::models::qwen3_probing
\ No newline at end of file

From 37df73e2f44781a06a9e21a22e07f69034114a61 Mon Sep 17 00:00:00 2001
From: yuerqiqi <2500526025@qq.com>
Date: Mon, 2 Feb 2026 17:46:19 +0800
Subject: [PATCH 2/2] Refactor logits handling for float types

---
 mllm/models/qwen3/modeling_qwen3_probing_service.hpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mllm/models/qwen3/modeling_qwen3_probing_service.hpp b/mllm/models/qwen3/modeling_qwen3_probing_service.hpp
index 8565ab6fc..f4b3ceaf7 100644
--- a/mllm/models/qwen3/modeling_qwen3_probing_service.hpp
+++ b/mllm/models/qwen3/modeling_qwen3_probing_service.hpp
@@ -161,9 +161,7 @@ class ProbeClassifier : public Module {
     float val = 0.0f;
     if (logits.dtype() == mllm::kFloat32) {
       val = logits.ptr<float>()[0];
-    } else if (logits.dtype() == mllm::kFloat16) {
-      val = (float)logits.ptr<__fp16>()[0];
-    }
+    }
     return 1.0f / (1.0f + std::exp(-val));
   }
 
@@ -1054,4 +1052,4 @@ class Qwen3ProbingSession final : public ::mllm::service::Session {
   }
 };
 
-} // namespace mllm::models::qwen3_probing
\ No newline at end of file
+} // namespace mllm::models::qwen3_probing