From e9944a56c6d00c31c19d3449ca2525d9ea917ba5 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 8 Jul 2021 20:47:48 +0000 Subject: [PATCH 1/7] Add RandomUniform converter and tests to onnx frontend. --- python/tvm/relay/frontend/onnx.py | 27 +++ tests/python/frontend/onnx/test_forward.py | 208 +++++++++++++-------- 2 files changed, 153 insertions(+), 82 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f876b1d14fa1..25dec7e6a98a 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -34,6 +34,7 @@ from .. import qnn as _qnn from .. import ty as _ty from .. import vision as _vision +from .. import random as _random from .common import ( AttrCvt, Renamer, @@ -3244,6 +3245,30 @@ def _impl_v11(cls, inputs, attr, params): return _expr.TupleWrapper(_expr.Tuple([unique_vals, indices, inverse_indices, counts]), 4) +class RandomUniform(OnnxOpConverter): + """Operator converter for random_uniform""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + dtype = get_type(attr.get("dtype", 1)) + high = attr.get("high", 1.0) + low = attr.get("low", 0.0) + seed = attr.get("seed", None) + shape = attr["shape"] + + assert dtype in [ + "float32", + "float64", + ], "Only float random value generation is currently supported." + + if seed is None: + seed = np.random.randint(1e6) + key = _random.threefry_key(seed) + output = _op.random.uniform(key, shape, dtype=dtype, low=low, high=high) + _, vals = _expr.TupleWrapper(output, 2) + return vals + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -3421,6 +3446,8 @@ def _get_convert_map(opset): "ReverseSequence": ReverseSequence.get_converter(opset), "QLinearConv": QLinearConv.get_converter(opset), "QLinearAdd": QLinearAdd.get_converter(opset), + # Random number generation. + "RandomUniform": RandomUniform.get_converter(opset), } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c5407697de46..a5b394252578 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -22,6 +22,7 @@ import scipy import torch import torchvision +from deckhand import ONNXModel import tvm import tvm.testing import tvm.topi.testing @@ -4872,86 +4873,129 @@ def test_qlinearadd(): verify_qlinearadd([5, 1, 7], [2, 7], [5, 2, 7]) +def get_random_uniform(shape, dtype="float32", high=1.0, low=0.0, seed=None, target="llvm"): + ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + node = helper.make_node( + "RandomUniform", [], ["out"], shape=shape, dtype=ONNX_DTYPE, high=high, low=low + ) + if seed is not None: + seed_attr = helper.make_attribute("seed", seed) + node.attribute.append(seed_attr) + + graph = helper.make_graph( + [node], + "random_uniform_test", + inputs=[], + outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, shape)], + ) + model = helper.make_model(graph, producer_name="random_uniform_test") + return get_tvm_output_with_vm(model, [], target=target, device=tvm.device(target, 0)) + + +def test_random_uniform(): + targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()] + for target in targets: + # Check that function runs and produces proper shape. + vals = get_random_uniform([10], dtype="float32", target=target) + assert list(vals.shape) == [10] + assert vals.dtype == "float32" + + # Test N-D tensor generation. + vals = get_random_uniform([1, 3, 100, 100], dtype="float32", target=target) + assert list(vals.shape) == [1, 3, 100, 100] + + # Check that bounds aren't exceeded. + vals = get_random_uniform(shape=[100], high=100, low=-100) + assert list(vals.shape) == [100] + assert all(vals >= -100) and all(vals <= 100) + + # Check that fixed seed produces the same values. + vals_1 = get_random_uniform(shape=[10], seed=1) + vals_2 = get_random_uniform(shape=[10], seed=1) + assert all(vals_1 == vals_2) + + if __name__ == "__main__": - test_flatten() - test_reshape() - test_shape() - test_expand() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_round() - test_isinf() - test_isnan() - test_clip() - test_clip_min_max_as_inputs() - test_onehot() - test_gemm() - test_matmul() - test_gather() - test_gatherelements() - test_gather_nd() - test_scatter() - test_lrn() - test_instance_norm() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantofshape() - test_all_reduce_funcs() - test_pad() - test_split() - test_binary_ops() - test_unary_ops() - test_leaky_relu() - test_elu() - test_selu() - test_prelu() - test_ThresholdedRelu() - test_LogSoftmax() - test_resnet() - test_inception() - test_densenet() - test_sign() - test_not() - test_and() - test_tile() - test_erf() - test_where() - test_or() - test_depth_to_space() - test_space_to_depth() - test_batch_norm() - test_batch_norm_dynamic_subgraph() - test_conv() - test_convtranspose() - test_unsqueeze_constant() - test_pooling() - test_lppool() - test_lstm() - test_gru() - test_resize() - test_nonzero() - test_topk() - test_mod() - test_xor() - test_max_roi_pool() - test_roi_align() - test_range() - test_loop() - test_size() - test_maxunpool() - test_softplus() - test_cumsum() - test_wrong_input() - test_aten() - test_reverse_sequence() - test_eyelike() - test_qlinearconv() + # test_flatten() + # test_reshape() + # test_shape() + # test_expand() + # test_power() + # test_squeeze() + # test_unsqueeze() + # test_slice() + # test_floor() + # test_ceil() + # test_round() + # test_isinf() + # test_isnan() + # test_clip() + # test_clip_min_max_as_inputs() + # test_onehot() + # test_gemm() + # test_matmul() + # test_gather() + # test_gatherelements() + # test_gather_nd() + # test_scatter() + # test_lrn() + # test_instance_norm() + # test_upsample() + # test_forward_min() + # test_forward_max() + # test_forward_mean() + # test_forward_hardsigmoid() + # test_forward_arg_min_max() + # test_softmax() + # test_constantofshape() + # test_all_reduce_funcs() + # test_pad() + # test_split() + # test_binary_ops() + # test_unary_ops() + # test_leaky_relu() + # test_elu() + # test_selu() + # test_prelu() + # test_ThresholdedRelu() + # test_LogSoftmax() + # test_resnet() + # test_inception() + # test_densenet() + # test_sign() + # test_not() + # test_and() + # test_tile() + # test_erf() + # test_where() + # test_or() + # test_depth_to_space() + # test_space_to_depth() + # test_batch_norm() + # test_batch_norm_dynamic_subgraph() + # test_conv() + # test_convtranspose() + # test_unsqueeze_constant() + # test_pooling() + # test_lppool() + # test_lstm() + # test_gru() + # test_resize() + # test_nonzero() + # test_topk() + # test_mod() + # test_xor() + # test_max_roi_pool() + # test_roi_align() + # test_range() + # test_loop() + # test_size() + # test_maxunpool() + # test_softplus() + # test_cumsum() + # test_wrong_input() + # test_aten() + # test_reverse_sequence() + # test_eyelike() + # test_qlinearconv() + test_random_uniform() From 15c464ed6d8f03a087b96e7cc5e86cd3d00d87ef Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 8 Jul 2021 20:50:27 +0000 Subject: [PATCH 2/7] Fix comments. --- tests/python/frontend/onnx/test_forward.py | 164 ++++++++++----------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a5b394252578..ce34a5d70f7d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4916,86 +4916,86 @@ def test_random_uniform(): if __name__ == "__main__": - # test_flatten() - # test_reshape() - # test_shape() - # test_expand() - # test_power() - # test_squeeze() - # test_unsqueeze() - # test_slice() - # test_floor() - # test_ceil() - # test_round() - # test_isinf() - # test_isnan() - # test_clip() - # test_clip_min_max_as_inputs() - # test_onehot() - # test_gemm() - # test_matmul() - # test_gather() - # test_gatherelements() - # test_gather_nd() - # test_scatter() - # test_lrn() - # test_instance_norm() - # test_upsample() - # test_forward_min() - # test_forward_max() - # test_forward_mean() - # test_forward_hardsigmoid() - # test_forward_arg_min_max() - # test_softmax() - # test_constantofshape() - # test_all_reduce_funcs() - # test_pad() - # test_split() - # test_binary_ops() - # test_unary_ops() - # test_leaky_relu() - # test_elu() - # test_selu() - # test_prelu() - # test_ThresholdedRelu() - # test_LogSoftmax() - # test_resnet() - # test_inception() - # test_densenet() - # test_sign() - # test_not() - # test_and() - # test_tile() - # test_erf() - # test_where() - # test_or() - # test_depth_to_space() - # test_space_to_depth() - # test_batch_norm() - # test_batch_norm_dynamic_subgraph() - # test_conv() - # test_convtranspose() - # test_unsqueeze_constant() - # test_pooling() - # test_lppool() - # test_lstm() - # test_gru() - # test_resize() - # test_nonzero() - # test_topk() - # test_mod() - # test_xor() - # test_max_roi_pool() - # test_roi_align() - # test_range() - # test_loop() - # test_size() - # test_maxunpool() - # test_softplus() - # test_cumsum() - # test_wrong_input() - # test_aten() - # test_reverse_sequence() - # test_eyelike() - # test_qlinearconv() + test_flatten() + test_reshape() + test_shape() + test_expand() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_round() + test_isinf() + test_isnan() + test_clip() + test_clip_min_max_as_inputs() + test_onehot() + test_gemm() + test_matmul() + test_gather() + test_gatherelements() + test_gather_nd() + test_scatter() + test_lrn() + test_instance_norm() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantofshape() + test_all_reduce_funcs() + test_pad() + test_split() + test_binary_ops() + test_unary_ops() + test_leaky_relu() + test_elu() + test_selu() + test_prelu() + test_ThresholdedRelu() + test_LogSoftmax() + test_resnet() + test_inception() + test_densenet() + test_sign() + test_not() + test_and() + test_tile() + test_erf() + test_where() + test_or() + test_depth_to_space() + test_space_to_depth() + test_batch_norm() + test_batch_norm_dynamic_subgraph() + test_conv() + test_convtranspose() + test_unsqueeze_constant() + test_pooling() + test_lppool() + test_lstm() + test_gru() + test_resize() + test_nonzero() + test_topk() + test_mod() + test_xor() + test_max_roi_pool() + test_roi_align() + test_range() + test_loop() + test_size() + test_maxunpool() + test_softplus() + test_cumsum() + test_wrong_input() + test_aten() + test_reverse_sequence() + test_eyelike() + test_qlinearconv() test_random_uniform() From ee1ab4c7f446a16e3cdd983ce948c12e6407a69b Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 8 Jul 2021 20:51:13 +0000 Subject: [PATCH 3/7] Remove weird import. --- tests/python/frontend/onnx/test_forward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index ce34a5d70f7d..4557afe752ac 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -22,7 +22,6 @@ import scipy import torch import torchvision -from deckhand import ONNXModel import tvm import tvm.testing import tvm.topi.testing From 8d91a222b506c4a1c9a1fde9eb66a8251e01828b Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 8 Jul 2021 21:42:31 +0000 Subject: [PATCH 4/7] Add test against golden array. --- tests/python/frontend/onnx/test_forward.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 4557afe752ac..48aff78044bd 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4913,6 +4913,24 @@ def test_random_uniform(): vals_2 = get_random_uniform(shape=[10], seed=1) assert all(vals_1 == vals_2) + # Test against an expected output with a fixed seed. + real = get_random_uniform(shape=[10], seed=5) + expected = np.asarray( + [ + 0.8614111, + 0.46572232, + 0.6007328, + 0.21619737, + 0.6361222, + 0.7298056, + 0.13094282, + 0.03556716, + 0.32997167, + 0.2977605, + ] + ) + tvm.testing.assert_allclose(real, expected, rtol=1e-5) + if __name__ == "__main__": test_flatten() From d14edcb9365501b7dc88f0ed0c8ba669f6050f5c Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 12 Jul 2021 17:45:07 +0000 Subject: [PATCH 5/7] Retrigger CI From d991830a2f2fac0eef9e19fbfa7da7ff1b83c4ba Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 13 Jul 2021 15:31:09 -0700 Subject: [PATCH 6/7] Improve test comment. --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 48a926bbd81c..7013dac79860 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4908,7 +4908,7 @@ def test_random_uniform(): assert list(vals.shape) == [100] assert all(vals >= -100) and all(vals <= 100) - # Check that fixed seed produces the same values. + # Check that a fixed seed produces the same values when run twice. vals_1 = get_random_uniform(shape=[10], seed=1) vals_2 = get_random_uniform(shape=[10], seed=1) assert all(vals_1 == vals_2) From dd190be82be48267b4c967c06de2dd1be8770cd3 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 14 Jul 2021 03:31:50 +0000 Subject: [PATCH 7/7] Retrigger CI.