diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 9bfda3c7abc7..2a9b81b84230 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -3,3 +3,4 @@ from . import mlp from . import resnet +from . import dqn diff --git a/python/tvm/relay/testing/dqn.py b/python/tvm/relay/testing/dqn.py new file mode 100644 index 000000000000..736894612e19 --- /dev/null +++ b/python/tvm/relay/testing/dqn.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Net of Nature DQN +Reference: +Mnih, Volodymyr, et al. "Human-level control through deep reinforcement learning." +Nature 518.7540 (2015): 529. +""" + +from tvm import relay +from . import layers +from .init import create_workload + +def get_net(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): + """get symbol of nature dqn""" + data_shape = (batch_size,) + image_shape + data = relay.var("data", shape=data_shape, dtype=dtype) + conv1 = layers.conv2d(data, kernel_size=(8, 8), strides=(4, 4), padding=(0, 0), + channels=32, name="conv1") + relu1 = relay.nn.relu(conv1) + conv2 = layers.conv2d(relu1, kernel_size=(4, 4), strides=(2, 2), padding=(0, 0), + channels=64, name="conv2") + relu2 = relay.nn.relu(conv2) + conv3 = layers.conv2d(relu2, kernel_size=(3, 3), strides=(1, 1), padding=(0, 0), + channels=64, name="conv3") + relu3 = relay.nn.relu(conv3) + bf1 = relay.nn.batch_flatten(relu3) + dense1 = layers.dense_add_bias(bf1, units=512, name="dense1") + relu4 = relay.nn.relu(dense1) + dense2 = layers.dense_add_bias(relu4, units=num_actions, name="dense2") + + args = relay.ir_pass.free_vars(dense2) + return relay.Function(args, dense2) + + +def get_workload(batch_size, num_actions=18, image_shape=(4, 84, 84), dtype="float32"): + """Get benchmark workload for a Deep Q Network + Parameters + ---------- + batch_size : int + The batch size used in the model + num_actions : int, optional + Number of actions + image_shape : tuple, optional + The input image shape + dtype : str, optional + The data type + Returns + ------- + net : nnvm.symbol + The computational graph + params : dict of str to NDArray + The parameters. + """ + net = get_net(batch_size, num_actions=num_actions, image_shape=image_shape, dtype=dtype) + return create_workload(net) diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 69ba4797a1c7..7b2c343b0844 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -104,10 +104,15 @@ def test_resnet(): net, params = tvm.relay.testing.resnet.get_workload(batch_size=1) net.astext() +def test_dqn(): + net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) + show(net.astext()) + if __name__ == "__main__": do_print[0] = True test_resnet() test_mlp() + test_dqn() test_func() test_env() test_meta_data()