diff --git a/python/caffe/draw.py b/python/caffe/draw.py index 61205ca9f37..76d0c40ae95 100644 --- a/python/caffe/draw.py +++ b/python/caffe/draw.py @@ -91,11 +91,11 @@ def get_layer_label(layer, rankdir): separator, layer.type, separator, - layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1, + layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size) else 1, separator, - layer.convolution_param.stride[0] if len(layer.convolution_param.stride._values) else 1, + layer.convolution_param.stride[0] if len(layer.convolution_param.stride) else 1, separator, - layer.convolution_param.pad[0] if len(layer.convolution_param.pad._values) else 0) + layer.convolution_param.pad[0] if len(layer.convolution_param.pad) else 0) elif layer.type == 'Pooling': pooling_types_dict = get_pooling_types_dict() node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\ diff --git a/python/caffe/test/test_draw.py b/python/caffe/test/test_draw.py new file mode 100644 index 00000000000..835bb5df010 --- /dev/null +++ b/python/caffe/test/test_draw.py @@ -0,0 +1,37 @@ +import os +import unittest + +from google.protobuf import text_format + +import caffe.draw +from caffe.proto import caffe_pb2 + +def getFilenames(): + """Yields files in the source tree which are Net prototxts.""" + result = [] + + root_dir = os.path.abspath(os.path.join( + os.path.dirname(__file__), '..', '..', '..')) + assert os.path.exists(root_dir) + + for dirname in ('models', 'examples'): + dirname = os.path.join(root_dir, dirname) + assert os.path.exists(dirname) + for cwd, _, filenames in os.walk(dirname): + for filename in filenames: + filename = os.path.join(cwd, filename) + if filename.endswith('.prototxt') and 'solver' not in filename: + yield os.path.join(dirname, filename) + + +class TestDraw(unittest.TestCase): + def test_draw_net(self): + for filename in getFilenames(): + net = caffe_pb2.NetParameter() + with open(filename) as infile: + text_format.Merge(infile.read(), net) + caffe.draw.draw_net(net, 'LR') + + +if __name__ == "__main__": + unittest.main() diff --git a/scripts/travis/install-deps.sh b/scripts/travis/install-deps.sh index b04dd38d626..2824423bb0a 100755 --- a/scripts/travis/install-deps.sh +++ b/scripts/travis/install-deps.sh @@ -8,6 +8,7 @@ source $BASEDIR/defaults.sh apt-get -y update apt-get install -y --no-install-recommends \ build-essential \ + graphviz \ libboost-filesystem-dev \ libboost-python-dev \ libboost-system-dev \ @@ -31,6 +32,7 @@ if ! $WITH_PYTHON3 ; then python-dev \ python-numpy \ python-protobuf \ + python-pydot \ python-skimage else # Python3 diff --git a/scripts/travis/install-python-deps.sh b/scripts/travis/install-python-deps.sh index eeec302791f..910d35a93be 100755 --- a/scripts/travis/install-python-deps.sh +++ b/scripts/travis/install-python-deps.sh @@ -11,4 +11,5 @@ if ! $WITH_PYTHON3 ; then else # Python3 pip install --pre protobuf==3.0.0b3 + pip install pydot fi