Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/caffe/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ def get_layer_label(layer, rankdir):
separator,
layer.type,
separator,
layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size) else 1,
separator,
layer.convolution_param.stride[0] if len(layer.convolution_param.stride._values) else 1,
layer.convolution_param.stride[0] if len(layer.convolution_param.stride) else 1,
separator,
layer.convolution_param.pad[0] if len(layer.convolution_param.pad._values) else 0)
layer.convolution_param.pad[0] if len(layer.convolution_param.pad) else 0)
elif layer.type == 'Pooling':
pooling_types_dict = get_pooling_types_dict()
node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
Expand Down
37 changes: 37 additions & 0 deletions python/caffe/test/test_draw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import unittest

from google.protobuf import text_format

import caffe.draw
from caffe.proto import caffe_pb2

def getFilenames():
"""Yields files in the source tree which are Net prototxts."""
result = []

root_dir = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..', '..', '..'))
assert os.path.exists(root_dir)

for dirname in ('models', 'examples'):
dirname = os.path.join(root_dir, dirname)
assert os.path.exists(dirname)
for cwd, _, filenames in os.walk(dirname):
for filename in filenames:
filename = os.path.join(cwd, filename)
if filename.endswith('.prototxt') and 'solver' not in filename:
yield os.path.join(dirname, filename)


class TestDraw(unittest.TestCase):
def test_draw_net(self):
for filename in getFilenames():
net = caffe_pb2.NetParameter()
with open(filename) as infile:
text_format.Merge(infile.read(), net)
caffe.draw.draw_net(net, 'LR')


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions scripts/travis/install-deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ source $BASEDIR/defaults.sh
apt-get -y update
apt-get install -y --no-install-recommends \
build-essential \
graphviz \
libboost-filesystem-dev \
libboost-python-dev \
libboost-system-dev \
Expand All @@ -31,6 +32,7 @@ if ! $WITH_PYTHON3 ; then
python-dev \
python-numpy \
python-protobuf \
python-pydot \
python-skimage
else
# Python3
Expand Down
1 change: 1 addition & 0 deletions scripts/travis/install-python-deps.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ if ! $WITH_PYTHON3 ; then
else
# Python3
pip install --pre protobuf==3.0.0b3
pip install pydot
fi