Exposing HDF5 saving and loading to python#4227
Conversation
|
Would you considering updating this to also add loadHDF5 as well? If the point of this is to save models with tensors > 2GB, then it would be nice to be able to load them in PyCaffe as well. Here's a patch that I'd suggest modifying yours to: diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -101,6 +101,14 @@
WriteProtoToBinaryFile(net_param, filename.c_str());
}
+void Net_SaveHDF5(const Net<Dtype>& net, string filename) {
+ net.ToHDF5(filename.c_str(), false);
+}
+
+void Net_LoadHDF5(Net<Dtype>* net, string filename) {
+ net->CopyTrainedLayersFromHDF5(filename.c_str());
+}
+
void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
bp::object labels_obj) {
// check that this network has an input MemoryDataLayer
@@ -254,6 +262,8 @@
bp::return_value_policy<bp::copy_const_reference>()))
.def("_set_input_arrays", &Net_SetInputArrays,
bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
+ .def("load_hdf5", &Net_LoadHDF5)
+ .def("save_hdf5", &Net_SaveHDF5)
.def("save", &Net_Save);
bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
--- a/python/caffe/test/test_net.py
+++ b/python/caffe/test/test_net.py
@@ -79,3 +79,17 @@
for i in range(len(self.net.params[name])):
self.assertEqual(abs(self.net.params[name][i].data
- net2.params[name][i].data).sum(), 0)
+
+ def test_save_hdf5(self):
+ f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+ f.close()
+ self.net.save_hdf5(f.name)
+ net_file = simple_net_file(self.num_output)
+ net2 = caffe.Net(net_file, caffe.TRAIN)
+ net2.load_hdf5(f.name)
+ os.remove(net_file)
+ os.remove(f.name)
+ for name in self.net.params:
+ for i in range(len(self.net.params[name])):
+ self.assertEqual(abs(self.net.params[name][i].data
+ - net2.params[name][i].data).sum(), 0) |
|
thank you @ajtulloch , I just added the load_hdf5 function. |
|
LGTM. @longjon, @shelhamer? |
| self.net.save_hdf5(f.name) | ||
| net_file = simple_net_file(self.num_output) | ||
| net2 = caffe.Net(net_file, caffe.TRAIN) | ||
| net2.load_hdf5(f.name) |
There was a problem hiding this comment.
You're using camel case in the _caffe.cpp file (saveHDF5, loadHDF5), but snake case here. I think for consistency you should use save_hdf5, load_hdf5 in the _caffe.cpp file. It's Python PEP-8 style to use snake_case for member functions.
|
Could you fix the casing issue? That will fix travis. |
|
@philkr I think you meant |
|
@shelhamer yes I did, changed it now. |
|
Thanks Philipp for exposing hdf5 net serialization to pycaffe. |
|
Thanks @philkr! |
[pycaffe] expose saving/loading nets as hdf5 to python
Exposes a new function 'Net.saveHDF5' to python. This allows caffe to store model weights in the hdf5 format (which can easily be used in other libraries).