Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from . import mlp
from . import resnet
from . import dqn
from . import dcgan
96 changes: 96 additions & 0 deletions python/tvm/relay/testing/dcgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# pylint: disable=unused-argument
"""
Net of the generator of DCGAN

Adopted from:
https://github.com/tqchen/mxnet-gan/blob/master/mxgan/generator.py

Reference:
Radford, Alec, Luke Metz, and Soumith Chintala.
"Unsupervised representation learning with deep convolutional generative adversarial networks."
arXiv preprint arXiv:1511.06434 (2015).
"""
from tvm import relay
from . import layers
from .init import create_workload

def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
"""a deconv layer that enlarges the feature map"""
target_shape = (oshape[-2], oshape[-1])

pad_y = (kshape[0] - 1) // 2
pad_x = (kshape[1] - 1) // 2
adj_y = (target_shape[0] + 2 * pad_y - kshape[0]) % stride[0]
adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]

net = layers.conv2d_transpose(data,
kernel_size=kshape,
strides=stride,
channels=oshape[0],
padding=(pad_y, pad_x),
output_padding=(adj_y, adj_x),
name=name)
return net

def deconv2d_bn_relu(data, prefix, **kwargs):
"""a block of deconv + batch norm + relu"""
eps = 1e-5 + 1e-12
net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
net = layers.batch_norm_infer(net, epsilon=eps, name="batch_norm")
net = relay.nn.relu(net)
return net

def get_net(batch_size, random_len=100, oshape=(3, 64, 64), ngf=128, code=None, dtype="float32"):
"""get net of dcgan generator"""
assert oshape[-1] == 64, "Only support 64x64 image"
assert oshape[-2] == 64, "Only support 64x64 image"

code = relay.var("data", dtype=dtype, shape=(batch_size, random_len)) if code is None else code
dense_weight = relay.var("dense_weight")
dense = relay.nn.dense(code, weight=dense_weight, units=4*4*ngf*8)
relu = relay.nn.relu(dense)
# 4 x 4
reshape = relay.reshape(relu, newshape=(-1, ngf * 8, 4, 4))
# 8 x 8
dc8 = deconv2d_bn_relu(
reshape, ishape=(ngf * 8, 4, 4), oshape=(ngf * 4, 8, 8), kshape=(4, 4), prefix="g2")
# 16x16
dc16 = deconv2d_bn_relu(
dc8, ishape=(ngf * 4, 8, 8), oshape=(ngf * 2, 16, 16), kshape=(4, 4), prefix="g3")
# 32x32
dc32 = deconv2d_bn_relu(
dc16, ishape=(ngf * 2, 16, 16), oshape=(ngf, 32, 32), kshape=(4, 4), prefix="g4")
# 64x64
dc64 = deconv2d(
dc32, ishape=(ngf, 32, 32), oshape=oshape[-3:], kshape=(4, 4), name="g5_deconv")
tanh = relay.tanh(dc64)

args = relay.ir_pass.free_vars(tanh)
return relay.Function(args, tanh)


def get_workload(batch_size, oshape=(3, 64, 64), ngf=128, random_len=100, dtype="float32"):
"""Get benchmark workload for a DCGAN generator

Parameters
----------
batch_size : int
The batch size used in the model
oshape : tuple, optional
The shape of output image, layout="CHW"
ngf: int, optional
The number of final feature maps in the generator
random_len : int, optional
The length of random input
dtype : str, optional
The data type

Returns
-------
net : nnvm.symbol
The computational graph
params : dict of str to NDArray
The parameters.
"""
net = get_net(batch_size, random_len, oshape=oshape, ngf=ngf, dtype=dtype)
return create_workload(net)
10 changes: 10 additions & 0 deletions python/tvm/relay/testing/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,25 @@ def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"
"""get symbol of nature dqn"""
data_shape = (batch_size,) + image_shape
data = relay.var("data", shape=data_shape, dtype=dtype)

conv1_bias = relay.var("conv1_bias")
conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0),
channels=32, name="conv1")
conv1 = relay.nn.bias_add(conv1, conv1_bias)
relu1 = relay.nn.relu(conv1)

conv2_bias = relay.var("conv2_bias")
conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0),
channels=64, name="conv2")
conv2 = relay.nn.bias_add(conv2, conv2_bias)
relu2 = relay.nn.relu(conv2)

conv3_bias = relay.var("conv3_bias")
conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0),
channels=64, name="conv3")
conv3 = relay.nn.bias_add(conv3, conv3_bias)
relu3 = relay.nn.relu(conv3)

bf1 = relay.nn.batch_flatten(relu3)
dense1 = layers.dense_add_bias(bf1, units=512, name="dense1")
relu4 = relay.nn.relu(dense1)
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/relay/testing/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,30 @@ def conv2d(data, weight=None, **kwargs):
weight = relay.var(name + "_weight")
return relay.nn.conv2d(data, weight, **kwargs)

def conv2d_transpose(data, weight=None, **kwargs):
"""Wrapper of conv2d_transpose which automatically creates weights if not given.

Parameters
----------
data : relay.Expr
The input expression.

weight : relay.Expr
The weight to conv2d_transpose.

kwargs : dict
Additional arguments.

Returns
-------
result : relay.Expr
The result.
"""
name = kwargs.get("name")
kwargs.pop("name")
if not weight:
weight = relay.var(name + "_weight")
return relay.nn.conv2d_transpose(data, weight, **kwargs)

def dense_add_bias(data, weight=None, bias=None, **kwargs):
"""Wrapper of dense which automatically creates weights if not given.
Expand Down
7 changes: 6 additions & 1 deletion tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,18 @@ def test_resnet():

def test_dqn():
net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
show(net.astext())
net.astext()

def test_dcgan():
net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
net.astext()

if __name__ == "__main__":
do_print[0] = True
test_resnet()
test_mlp()
test_dqn()
test_dcgan()
test_func()
test_env()
test_meta_data()
Expand Down