Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/tvm/relay/testing/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def get_workload(batch_size,
num_classes=1000,
image_shape=(3, 224, 224),
dtype="float32",
num_layers=11):
num_layers=11,
batch_norm=False):
"""Get benchmark workload for VGG nets.

Parameters
Expand All @@ -118,6 +119,9 @@ def get_workload(batch_size,
num_layers : int
Number of layers for the variant of vgg. Options are 11, 13, 16, 19.

batch_norm : bool
Use batch normalization.

Returns
-------
net : nnvm.Symbol
Expand All @@ -126,5 +130,5 @@ def get_workload(batch_size,
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, image_shape, num_classes, dtype, num_layers)
net = get_net(batch_size, image_shape, num_classes, dtype, num_layers, batch_norm)
return create_workload(net)