Conversation
Pull request overview
This PR revises the tensor deserialization API in `src/utils/serialize` by changing `load_tensors` from an out-parameter function to a function that returns a new tensor map, and updates the implementation to load data via TorchScript.
Changes:
- Changed the `load_tensors` signature to return `std::map<std::string, torch::Tensor>` instead of mutating an input map (see the usage sketch after this list).
- Reimplemented `load_tensors` to call `torch::jit::load` and populate the map from module parameters/buffers.
- Added the TorchScript header include needed for `torch::jit::load`.
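As a quick illustration of the call-site change, a minimal before/after sketch (the filename `diag.pt` and the surrounding `example` function are hypothetical):

```cpp
#include <torch/torch.h>
#include <map>
#include <string>

// New declaration per this PR; the definition lives in serialize.cpp.
std::map<std::string, torch::Tensor> load_tensors(const std::string& filename);

void example() {
  // Before this PR, the caller allocated the map and passed it in by reference:
  //   std::map<std::string, torch::Tensor> tensors;
  //   load_tensors(tensors, "diag.pt");
  // After this PR, the function returns a freshly populated map:
  std::map<std::string, torch::Tensor> tensors = load_tensors("diag.pt");
}
```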
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `src/utils/serialize.hpp` | Updates the `load_tensors` declaration to return a tensor map. |
| `src/utils/serialize.cpp` | Updates the `load_tensors` implementation to use TorchScript loading and return a populated map. |
```diff
 /*torch::serialize::InputArchive archive;
 archive.load_from(filename);
-for (auto& pair : tensor_map) {
+for (auto& pair : data) {
   try {
     archive.read(pair.first, pair.second);
   } catch (const c10::Error& e) {
     // skip missing tensors
   }
 }*/
```
The large commented-out InputArchive block should be removed or re-enabled behind a clear conditional/compile-time flag. Leaving dead code in production makes the intended behavior ambiguous and increases maintenance cost.
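If the archive path still has value, a minimal sketch of the compile-time-flag option, assuming a hypothetical flag name `SERIALIZE_USE_INPUT_ARCHIVE` and helper name `load_tensors_into`:

```cpp
#include <torch/torch.h>
#include <map>
#include <string>

// Sketch: keep the legacy InputArchive path compilable behind a build flag
// instead of leaving it in a comment block.
void load_tensors_into(std::map<std::string, torch::Tensor>& data,
                       const std::string& filename) {
#ifdef SERIALIZE_USE_INPUT_ARCHIVE
  torch::serialize::InputArchive archive;
  archive.load_from(filename);
  for (auto& pair : data) {
    try {
      archive.read(pair.first, pair.second);
    } catch (const c10::Error&) {
      // key absent from the archive: leave the caller's tensor untouched
    }
  }
#else
  (void)data;      // flag off: the TorchScript path handles loading
  (void)filename;
#endif
}
```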
```diff
-void load_tensors(std::map<std::string, torch::Tensor>& tensor_map,
-                  const std::string& filename);
+std::map<std::string, torch::Tensor> load_tensors(const std::string& filename);
```
The `load_tensors` API now returns a new map, but the implementation only extracts TorchScript module parameters/buffers rather than loading the tensors written by `save_tensors`. If this is meant to load the diagnostic `.pt` files created by `save_tensors`, the declaration/semantics should match that format (see the sketch below); otherwise the function name/docs should be updated to reflect the TorchScript-only behavior.
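A minimal sketch of a `load_tensors` that reads the archive format back, assuming the LibTorch build in use exposes `torch::serialize::InputArchive::keys()` for enumerating stored entries:

```cpp
#include <torch/torch.h>
#include <map>
#include <string>

// Sketch: return every tensor stored in an OutputArchive-written file,
// matching whatever keys save_tensors wrote.
std::map<std::string, torch::Tensor> load_tensors(const std::string& filename) {
  std::map<std::string, torch::Tensor> data;
  torch::serialize::InputArchive archive;
  archive.load_from(filename);
  for (const std::string& key : archive.keys()) {
    torch::Tensor t;
    archive.read(key, t);  // key came from the archive itself, so it exists
    data[key] = t;
  }
  return data;
}
```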
```cpp
std::map<std::string, torch::Tensor> load_tensors(const std::string& filename) {
  std::map<std::string, torch::Tensor> data;

  // get keys
  torch::jit::Module m = torch::jit::load(filename);

  for (const auto& p : m.named_parameters(/*recurse=*/true)) {
    data[p.name] = p.value;
  }

  for (const auto& p : m.named_buffers(/*recurse=*/true)) {
    data[p.name] = p.value;
  }
```
`load_tensors` now calls `torch::jit::load` and returns `named_parameters`/`named_buffers`, but `save_tensors` writes a `torch::serialize::OutputArchive`. These serialization formats are not compatible, so `load_tensors` will fail (or silently return unrelated data) when given files produced by `save_tensors`. Consider either (a) switching `load_tensors` back to `torch::serialize::InputArchive` to read the same keys that were written, or (b) updating `save_tensors` to write a TorchScript module (registering the tensors as buffers/parameters) if TorchScript loading is the new intended format. A sketch of option (b) follows.
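A minimal sketch of option (b), not verified against a specific LibTorch version; it assumes every map key is a valid attribute identifier (no dots), and the class name `tensor_container` is arbitrary:

```cpp
#include <torch/script.h>
#include <map>
#include <string>

// Sketch of option (b): persist the map as a TorchScript module whose
// buffers carry the tensors, so the new load_tensors can read them back
// via named_buffers(/*recurse=*/true).
void save_tensors(const std::map<std::string, torch::Tensor>& data,
                  const std::string& filename) {
  torch::jit::Module m(c10::QualifiedName("tensor_container"));
  for (const auto& kv : data) {
    m.register_buffer(kv.first, kv.second);  // assumes kv.first is a valid identifier
  }
  m.save(filename);
}
```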