diff --git a/src/utils/serialize.cpp b/src/utils/serialize.cpp index 60ca741..e3f88fe 100644 --- a/src/utils/serialize.cpp +++ b/src/utils/serialize.cpp @@ -1,4 +1,5 @@ // torch +#include #include // kintera @@ -15,17 +16,31 @@ void save_tensors(const std::map& tensor_map, archive.save_to(filename); } -void load_tensors(std::map& tensor_map, - const std::string& filename) { - torch::serialize::InputArchive archive; +std::map load_tensors(const std::string& filename) { + std::map data; + + // get keys + torch::jit::Module m = torch::jit::load(filename); + + for (const auto& p : m.named_parameters(/*recurse=*/true)) { + data[p.name] = p.value; + } + + for (const auto& p : m.named_buffers(/*recurse=*/true)) { + data[p.name] = p.value; + } + + /*torch::serialize::InputArchive archive; archive.load_from(filename); - for (auto& pair : tensor_map) { + for (auto& pair : data) { try { archive.read(pair.first, pair.second); } catch (const c10::Error& e) { // skip missing tensors } - } + }*/ + + return data; } } // namespace kintera diff --git a/src/utils/serialize.hpp b/src/utils/serialize.hpp index 4e15b1e..0f07762 100644 --- a/src/utils/serialize.hpp +++ b/src/utils/serialize.hpp @@ -11,7 +11,6 @@ namespace kintera { void save_tensors(const std::map& tensor_map, const std::string& filename); -void load_tensors(std::map& tensor_map, - const std::string& filename); +std::map load_tensors(const std::string& filename); } // namespace kintera