From de614c756b91355e50cdd4e4de476e471e6ff7d5 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Sat, 14 Feb 2026 17:03:06 +0000 Subject: [PATCH 1/2] wip --- src/utils/serialize.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/utils/serialize.cpp b/src/utils/serialize.cpp index 60ca741..ba5154e 100644 --- a/src/utils/serialize.cpp +++ b/src/utils/serialize.cpp @@ -1,5 +1,6 @@ // torch #include +#include // kintera #include "serialize.hpp" @@ -17,6 +18,17 @@ void save_tensors(const std::map& tensor_map, void load_tensors(std::map& tensor_map, const std::string& filename) { + // get keys + torch::jit::Module m = torch::jit::load(filename); + + for (const auto& p : m.named_parameters(/*recurse=*/true)) { + tensor_map[p.name] = p.value; + } + + for (const auto& b : m.named_buffers(/*recurse=*/true)) { + tensor_map[p.name] = p.value; + } + torch::serialize::InputArchive archive; archive.load_from(filename); for (auto& pair : tensor_map) { From cd4bb5b79ccac73601e4e6317a2d7d4c6d91e7ab Mon Sep 17 00:00:00 2001 From: mac/cli Date: Sun, 15 Feb 2026 11:36:37 -0500 Subject: [PATCH 2/2] wip --- src/utils/serialize.cpp | 21 ++++++++++++--------- src/utils/serialize.hpp | 3 +-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/utils/serialize.cpp b/src/utils/serialize.cpp index ba5154e..e3f88fe 100644 --- a/src/utils/serialize.cpp +++ b/src/utils/serialize.cpp @@ -1,6 +1,6 @@ // torch -#include #include +#include // kintera #include "serialize.hpp" @@ -16,28 +16,31 @@ void save_tensors(const std::map& tensor_map, archive.save_to(filename); } -void load_tensors(std::map& tensor_map, - const std::string& filename) { +std::map load_tensors(const std::string& filename) { + std::map data; + // get keys torch::jit::Module m = torch::jit::load(filename); for (const auto& p : m.named_parameters(/*recurse=*/true)) { - tensor_map[p.name] = p.value; + data[p.name] = p.value; } - for (const auto& b : m.named_buffers(/*recurse=*/true)) { - tensor_map[p.name] = p.value; + for (const auto& p : m.named_buffers(/*recurse=*/true)) { + data[p.name] = p.value; } - torch::serialize::InputArchive archive; + /*torch::serialize::InputArchive archive; archive.load_from(filename); - for (auto& pair : tensor_map) { + for (auto& pair : data) { try { archive.read(pair.first, pair.second); } catch (const c10::Error& e) { // skip missing tensors } - } + }*/ + + return data; } } // namespace kintera diff --git a/src/utils/serialize.hpp b/src/utils/serialize.hpp index 4e15b1e..0f07762 100644 --- a/src/utils/serialize.hpp +++ b/src/utils/serialize.hpp @@ -11,7 +11,6 @@ namespace kintera { void save_tensors(const std::map& tensor_map, const std::string& filename); -void load_tensors(std::map& tensor_map, - const std::string& filename); +std::map load_tensors(const std::string& filename); } // namespace kintera