From b122bba00c79ef33867615f35c507a75413c2816 Mon Sep 17 00:00:00 2001 From: Tyler Davis Date: Wed, 16 Dec 2020 15:12:21 -0800 Subject: [PATCH 1/5] Add div_ and is_floating_point operators --- python/tvm/relay/frontend/pytorch.py | 1 - tests/python/frontend/pytorch/test_forward.py | 366 ++++++++++-------- 2 files changed, 200 insertions(+), 167 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ebc0132435ba..a85d454cae43 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -146,7 +146,6 @@ def __init__(self, prelude, default_dtype): # above. def infer_type(self, node, mod=None): """An incremental method to infer the type of a node in the relay graph.""" - if node in self.types: return self.types[node] if isinstance(node, tvm.relay.Var): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2dda675c74f5..1742f1bc1f2d 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -185,10 +185,19 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at if isinstance(baseline_outputs, tuple): baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) + # If baseline output is a bool, wrap it to make it iterable + elif isinstance(baseline_outputs, bool): + baseline_outputs = np.array(baseline_outputs) + baseline_outputs = np.expand_dims(baseline_outputs, 0) else: baseline_outputs = (baseline_outputs.cpu().numpy(),) - trace = torch.jit.trace(baseline_model, baseline_input) + # If the base model outputs a raw bool, it cannot be traced. + # Use torch.jit.script instead. + try: + trace = torch.jit.trace(baseline_model, baseline_input) + except RuntimeError: + trace = torch.jit.script(baseline_model) if isinstance(baseline_model, torch.nn.Module): trace = trace.float().eval() @@ -196,6 +205,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at trace = trace.cuda() else: trace = trace.cpu() + print('-------------Torch Graph-------------') + print(trace.graph) input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) @@ -205,10 +216,13 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at with tvm.transform.PassContext(opt_level=3): for target, ctx in tvm.testing.enabled_targets(): relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) + print('--------------Relay Graph----------------') + print(relay_graph) relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) relay_model.set_input(**relay_params) - for name, inp in compiled_input.items(): - relay_model.set_input(name, inp) + # TODO(dvisnty): Model has no inputs, so none need to be set. Figure out how to handle this + # for name, inp in compiled_input.items(): + # relay_model.set_input(name, inp) relay_model.run() for i, baseline_output in enumerate(baseline_outputs): @@ -2949,6 +2963,25 @@ def forward(self, *args): TrueDivide().float().eval(), input_data=[dividend, divisor_scalar], atol=1e-4, rtol=1e-4 ) +# TODO(dvisnty): Test fails when using raw bool not wrapped in Tensor. Is this due to +# the `float()` in `trace.float.eval()`? +@tvm.testing.uses_gpu +def test_forward_is_floating_point(): + torch.set_grad_enabled(False) + + class IsFloatingPoint(Module): + def forward(self, arg): + # Uncomment to wrap Bool in a Tensor, allowing use of + # `torch.jit.trace` + return torch.Tensor([torch.is_floating_point(arg)]) + # Else `torch.jit.script` will be used + # return torch.is_floating_point(arg) + + # Input could be either float or non-float + int_tensor = torch.tensor([[1]]) + float_tensor = torch.tensor([[1.0]]) + verify_model(IsFloatingPoint().float().eval(), input_data=[int_tensor]) + verify_model(IsFloatingPoint().float().eval(), input_data=[float_tensor]) @tvm.testing.uses_gpu def test_forward_traced_function(): @@ -3365,166 +3398,167 @@ def test_fn(x, weights=None): if __name__ == "__main__": # some structural tests - test_forward_traced_function() - test_forward_dtypes() - test_weight_names() - test_duplicate_weight_use() - - # Single operator tests - test_forward_pixel_shuffle() - test_forward_add() - test_forward_subtract() - test_forward_multiply() - test_forward_matmul() - test_forward_rsub() - test_forward_onehot() - test_forward_embedding() - test_forward_reshape() - test_forward_reciprocal() - test_forward_repeat() - test_forward_repeat_interleave() - test_forward_squeeze() - test_forward_unsqueeze() - test_forward_concatenate() - test_forward_reduce_sum() - test_forward_reduce_prod() - test_forward_argmin() - test_forward_argmax() - test_forward_norm() - test_forward_frobenius_norm() - test_forward_std() - test_forward_variance() - test_forward_relu() - test_forward_prelu() - test_forward_leakyrelu() - test_forward_elu() - test_forward_celu() - test_forward_gelu() - test_forward_selu() - test_forward_log_sigmoid() - test_forward_adaptiveavgpool() - test_forward_maxpool2d() - test_forward_maxpool1d() - test_forward_maxpool3d() - test_forward_hardtanh() - test_forward_conv() - test_forward_conv_transpose() - test_forward_threshold() - test_forward_contiguous() - test_forward_batchnorm() - test_forward_instancenorm() - test_forward_layernorm() - test_forward_groupnorm() - test_forward_transpose() - test_forward_size() - test_forward_view() - test_forward_select() - test_forward_take() - test_forward_topk() - test_forward_where() - test_forward_addcdiv() - test_forward_addcmul() - test_forward_true_divide() - test_forward_clone() - test_forward_softplus() - test_forward_softsign() - test_forward_logsoftmax() - test_forward_sigmoid() - test_forward_dense() - test_forward_avgpool() - test_forward_avgpool3d() - test_forward_dropout() - test_forward_slice() - test_forward_mean() - test_forward_expand() - test_forward_pow() - test_forward_unary() - test_forward_clamp() - test_forward_clamp_() - test_forward_logical_not() - test_forward_bitwise_not() - test_forward_bitwise_xor() - test_forward_logical_xor() - test_forward_isfinite() - test_forward_isnan() - test_forward_isinf() - test_forward_ones() - test_forward_ones_like() - test_forward_zeros() - test_forward_zeros_like() - test_forward_full() - test_forward_full_like() - test_forward_linspace() - test_forward_arange() - test_forward_mesh_grid() - test_forward_chunk() - test_forward_split() - test_forward_gather() - test_upsample() - test_forward_upsample3d() - test_forward_nms() - test_forward_roi_align() - test_to() - test_flatten() - test_type_as() - test_forward_functional_pad() - test_forward_zero_pad2d() - test_forward_constant_pad1d() - test_forward_constant_pad2d() - test_forward_constant_pad3d() - test_forward_reflection_pad1d() - test_forward_reflection_pad2d() - test_forward_replication_pad1d() - test_forward_replication_pad2d() - test_forward_replication_pad3d() - test_adaptive_pool3d() - test_conv3d() - test_conv3d_transpose() - test_forward_index() - test_min_max() - test_logsumexp() - test_stack() - test_stack_dynamic() - test_forward_unbind() - test_forward_nonzero() - test_forward_scatter() - test_numel() - test_bincount() - - # Model tests - test_resnet18() - test_squeezenet1_0() - test_squeezenet1_1() - test_densenet121() - # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # test_inception_v3() - test_googlenet() - test_mnasnet0_5() - test_mobilenet_v2() - - test_custom_conversion_map() - - test_segmentaton_models() - test_3d_models() - - # Quantization test - from qnn_test import test_quantized_imagenet, test_quantized_modules - - test_quantized_modules() - test_quantized_imagenet() - - # Test simple conditionals and loop - test_control_flow() - test_simple_rnn() - - # More complex recurrent models - from test_lstm import test_custom_lstm - - test_custom_lstm() - - # Test bert model - test_forward_pretrained_bert_base_uncased() - - # Test convert torch script(jit) with specific inputs' types - test_convert_torch_script_with_input_types() + # test_forward_traced_function() + # test_forward_dtypes() + # test_weight_names() + # test_duplicate_weight_use() + + # # Single operator tests + # test_forward_pixel_shuffle() + # test_forward_add() + # test_forward_subtract() + # test_forward_multiply() + # test_forward_matmul() + # test_forward_rsub() + # test_forward_onehot() + # test_forward_embedding() + # test_forward_reshape() + # test_forward_reciprocal() + # test_forward_repeat() + # test_forward_repeat_interleave() + # test_forward_squeeze() + # test_forward_unsqueeze() + # test_forward_concatenate() + # test_forward_reduce_sum() + # test_forward_reduce_prod() + # test_forward_argmin() + # test_forward_argmax() + # test_forward_norm() + # test_forward_frobenius_norm() + # test_forward_std() + # test_forward_variance() + # test_forward_relu() + # test_forward_prelu() + # test_forward_leakyrelu() + # test_forward_elu() + # test_forward_celu() + # test_forward_gelu() + # test_forward_selu() + # test_forward_log_sigmoid() + # test_forward_adaptiveavgpool() + # test_forward_maxpool2d() + # test_forward_maxpool1d() + # test_forward_maxpool3d() + # test_forward_hardtanh() + # test_forward_conv() + # test_forward_conv_transpose() + # test_forward_threshold() + # test_forward_contiguous() + # test_forward_batchnorm() + # test_forward_instancenorm() + # test_forward_layernorm() + # test_forward_groupnorm() + # test_forward_transpose() + # test_forward_size() + # test_forward_view() + # test_forward_select() + # test_forward_take() + # test_forward_topk() + # test_forward_where() + # test_forward_addcdiv() + # test_forward_addcmul() + # test_forward_true_divide() + test_forward_is_floating_point() + # test_forward_clone() + # test_forward_softplus() + # test_forward_softsign() + # test_forward_logsoftmax() + # test_forward_sigmoid() + # test_forward_dense() + # test_forward_avgpool() + # test_forward_avgpool3d() + # test_forward_dropout() + # test_forward_slice() + # test_forward_mean() + # test_forward_expand() + # test_forward_pow() + # test_forward_unary() + # test_forward_clamp() + # test_forward_clamp_() + # test_forward_logical_not() + # test_forward_bitwise_not() + # test_forward_bitwise_xor() + # test_forward_logical_xor() + # test_forward_isfinite() + # test_forward_isnan() + # test_forward_isinf() + # test_forward_ones() + # test_forward_ones_like() + # test_forward_zeros() + # test_forward_zeros_like() + # test_forward_full() + # test_forward_full_like() + # test_forward_linspace() + # test_forward_arange() + # test_forward_mesh_grid() + # test_forward_chunk() + # test_forward_split() + # test_forward_gather() + # test_upsample() + # test_forward_upsample3d() + # test_forward_nms() + # test_forward_roi_align() + # test_to() + # test_flatten() + # test_type_as() + # test_forward_functional_pad() + # test_forward_zero_pad2d() + # test_forward_constant_pad1d() + # test_forward_constant_pad2d() + # test_forward_constant_pad3d() + # test_forward_reflection_pad1d() + # test_forward_reflection_pad2d() + # test_forward_replication_pad1d() + # test_forward_replication_pad2d() + # test_forward_replication_pad3d() + # test_adaptive_pool3d() + # test_conv3d() + # test_conv3d_transpose() + # test_forward_index() + # test_min_max() + # test_logsumexp() + # test_stack() + # test_stack_dynamic() + # test_forward_unbind() + # test_forward_nonzero() + # test_forward_scatter() + # test_numel() + # test_bincount() + + # # Model tests + # test_resnet18() + # test_squeezenet1_0() + # test_squeezenet1_1() + # test_densenet121() + # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # # test_inception_v3() + # test_googlenet() + # test_mnasnet0_5() + # test_mobilenet_v2() + + # test_custom_conversion_map() + + # test_segmentaton_models() + # test_3d_models() + + # # Quantization test + # from qnn_test import test_quantized_imagenet, test_quantized_modules + + # test_quantized_modules() + # test_quantized_imagenet() + + # # Test simple conditionals and loop + # test_control_flow() + # test_simple_rnn() + + # # More complex recurrent models + # from test_lstm import test_custom_lstm + + # test_custom_lstm() + + # # Test bert model + # test_forward_pretrained_bert_base_uncased() + + # # Test convert torch script(jit) with specific inputs' types + # test_convert_torch_script_with_input_types() From d36ceed2f21828305f70eccc09ef8573362f0f07 Mon Sep 17 00:00:00 2001 From: Tyler Davis Date: Thu, 17 Dec 2020 18:12:33 -0800 Subject: [PATCH 2/5] Add handling of exprs to op, update tests --- python/tvm/relay/frontend/pytorch.py | 1 + tests/python/frontend/pytorch/test_forward.py | 49 ++++++++++++++----- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a85d454cae43..69cfe1a85d75 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2067,6 +2067,7 @@ def is_floating_point(self, inputs, input_types): input_type = input_types[0] is_float = input_type in ["float32", "float64", "float16"] + return _expr.const(is_float) # Operator mappings diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 1742f1bc1f2d..ca1d57c18ee7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -205,19 +205,31 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at trace = trace.cuda() else: trace = trace.cpu() - print('-------------Torch Graph-------------') - print(trace.graph) input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] - input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) + # input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) + # input_shapes = [('input0', torch.Size([1, 1]))] + + # Dtype must be set in `input_shape` for scripted models to work properly + if baseline_input[0].dtype == torch.float64: + curr_dtype = 'float64' + elif baseline_input[0].dtype == torch.float32: + curr_dtype = 'float32' + elif baseline_input[0].dtype == torch.float16: + curr_dtype = 'float16' + elif baseline_input[0].dtype == torch.int8: + curr_dtype = 'int' + else: + print('uh oh, other dtype: {}'.format(baseline_input[0].dtype)) + exit(-1) + input_shapes = [('input0', ((1, 1), curr_dtype))] mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + compiled_input = dict(zip(input_names, [inp.cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3): for target, ctx in tvm.testing.enabled_targets(): relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) - print('--------------Relay Graph----------------') - print(relay_graph) relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) relay_model.set_input(**relay_params) # TODO(dvisnty): Model has no inputs, so none need to be set. Figure out how to handle this @@ -2971,17 +2983,32 @@ def test_forward_is_floating_point(): class IsFloatingPoint(Module): def forward(self, arg): - # Uncomment to wrap Bool in a Tensor, allowing use of + # `torch.jit.trace` cannot accept something that outputs + # a Bool, so `torch.jit.script` will be used instead + return torch.is_floating_point(arg) + + class IsFloatingPointWrapped(Module): + def forward(self, arg): + # Wrap Bool in a Tensor, allowing use of # `torch.jit.trace` return torch.Tensor([torch.is_floating_point(arg)]) - # Else `torch.jit.script` will be used - # return torch.is_floating_point(arg) # Input could be either float or non-float - int_tensor = torch.tensor([[1]]) - float_tensor = torch.tensor([[1.0]]) + int_tensor = torch.tensor([[1]], dtype=torch.int8) + float64_tensor = torch.tensor([[1.0]], dtype=torch.float64) + float32_tensor = torch.tensor([[1.0]], dtype=torch.float32) + float16_tensor = torch.tensor([[1.0]], dtype=torch.float16) + verify_model(IsFloatingPoint().float().eval(), input_data=[int_tensor]) - verify_model(IsFloatingPoint().float().eval(), input_data=[float_tensor]) + verify_model(IsFloatingPoint().float().eval(), input_data=[float64_tensor]) + verify_model(IsFloatingPoint().float().eval(), input_data=[float32_tensor]) + verify_model(IsFloatingPoint().float().eval(), input_data=[float16_tensor]) + + verify_model(IsFloatingPointWrapped().float().eval(), input_data=[int_tensor]) + verify_model(IsFloatingPointWrapped().float().eval(), input_data=[float64_tensor]) + verify_model(IsFloatingPointWrapped().float().eval(), input_data=[float32_tensor]) + verify_model(IsFloatingPointWrapped().float().eval(), input_data=[float16_tensor]) + @tvm.testing.uses_gpu def test_forward_traced_function(): From cfc0e8d5b08bbe80a06400bf1b99d9b697ebfccf Mon Sep 17 00:00:00 2001 From: Tyler Davis Date: Fri, 18 Dec 2020 12:13:29 -0800 Subject: [PATCH 3/5] Properly handle bfloat16 in is_floating_point --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 69cfe1a85d75..d5ac70813f2d 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2066,7 +2066,7 @@ def is_floating_point(self, inputs, input_types): else: input_type = input_types[0] - is_float = input_type in ["float32", "float64", "float16"] + is_float = input_type in ["float32", "float64", "float16", "bfloat16"] return _expr.const(is_float) From bba6d3d6971733de374c50a4a29e1d979fe54577 Mon Sep 17 00:00:00 2001 From: Tyler Davis Date: Fri, 18 Dec 2020 12:17:41 -0800 Subject: [PATCH 4/5] Revert test changes --- tests/python/frontend/pytorch/test_forward.py | 395 ++++++++---------- 1 file changed, 167 insertions(+), 228 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ca1d57c18ee7..2dda675c74f5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -185,19 +185,10 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at if isinstance(baseline_outputs, tuple): baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs) - # If baseline output is a bool, wrap it to make it iterable - elif isinstance(baseline_outputs, bool): - baseline_outputs = np.array(baseline_outputs) - baseline_outputs = np.expand_dims(baseline_outputs, 0) else: baseline_outputs = (baseline_outputs.cpu().numpy(),) - # If the base model outputs a raw bool, it cannot be traced. - # Use torch.jit.script instead. - try: - trace = torch.jit.trace(baseline_model, baseline_input) - except RuntimeError: - trace = torch.jit.script(baseline_model) + trace = torch.jit.trace(baseline_model, baseline_input) if isinstance(baseline_model, torch.nn.Module): trace = trace.float().eval() @@ -207,24 +198,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at trace = trace.cpu() input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] - # input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) - # input_shapes = [('input0', torch.Size([1, 1]))] - - # Dtype must be set in `input_shape` for scripted models to work properly - if baseline_input[0].dtype == torch.float64: - curr_dtype = 'float64' - elif baseline_input[0].dtype == torch.float32: - curr_dtype = 'float32' - elif baseline_input[0].dtype == torch.float16: - curr_dtype = 'float16' - elif baseline_input[0].dtype == torch.int8: - curr_dtype = 'int' - else: - print('uh oh, other dtype: {}'.format(baseline_input[0].dtype)) - exit(-1) - input_shapes = [('input0', ((1, 1), curr_dtype))] + input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) - compiled_input = dict(zip(input_names, [inp.cpu().numpy() for inp in baseline_input])) with tvm.transform.PassContext(opt_level=3): @@ -232,9 +207,8 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) relay_model.set_input(**relay_params) - # TODO(dvisnty): Model has no inputs, so none need to be set. Figure out how to handle this - # for name, inp in compiled_input.items(): - # relay_model.set_input(name, inp) + for name, inp in compiled_input.items(): + relay_model.set_input(name, inp) relay_model.run() for i, baseline_output in enumerate(baseline_outputs): @@ -2975,40 +2949,6 @@ def forward(self, *args): TrueDivide().float().eval(), input_data=[dividend, divisor_scalar], atol=1e-4, rtol=1e-4 ) -# TODO(dvisnty): Test fails when using raw bool not wrapped in Tensor. Is this due to -# the `float()` in `trace.float.eval()`? -@tvm.testing.uses_gpu -def test_forward_is_floating_point(): - torch.set_grad_enabled(False) - - class IsFloatingPoint(Module): - def forward(self, arg): - # `torch.jit.trace` cannot accept something that outputs - # a Bool, so `torch.jit.script` will be used instead - return torch.is_floating_point(arg) - - class IsFloatingPointWrapped(Module): - def forward(self, arg): - # Wrap Bool in a Tensor, allowing use of - # `torch.jit.trace` - return torch.Tensor([torch.is_floating_point(arg)]) - - # Input could be either float or non-float - int_tensor = torch.tensor([[1]], dtype=torch.int8) - float64_tensor = torch.tensor([[1.0]], dtype=torch.float64) - float32_tensor = torch.tensor([[1.0]], dtype=torch.float32) - float16_tensor = torch.tensor([[1.0]], dtype=torch.float16) - - verify_model(IsFloatingPoint().float().eval(), input_data=[int_tensor]) - verify_model(IsFloatingPoint().float().eval(), input_data=[float64_tensor]) - verify_model(IsFloatingPoint().float().eval(), input_data=[float32_tensor]) - verify_model(IsFloatingPoint().float().eval(), input_data=[float16_tensor]) - - verify_model(IsFloatingPointWrapped().float().eval(), input_data=[int_tensor]) - verify_model(IsFloatingPointWrapped().float().eval(), input_data=[float64_tensor]) - verify_model(IsFloatingPointWrapped().float().eval(), input_data=[float32_tensor]) - verify_model(IsFloatingPointWrapped().float().eval(), input_data=[float16_tensor]) - @tvm.testing.uses_gpu def test_forward_traced_function(): @@ -3425,167 +3365,166 @@ def test_fn(x, weights=None): if __name__ == "__main__": # some structural tests - # test_forward_traced_function() - # test_forward_dtypes() - # test_weight_names() - # test_duplicate_weight_use() - - # # Single operator tests - # test_forward_pixel_shuffle() - # test_forward_add() - # test_forward_subtract() - # test_forward_multiply() - # test_forward_matmul() - # test_forward_rsub() - # test_forward_onehot() - # test_forward_embedding() - # test_forward_reshape() - # test_forward_reciprocal() - # test_forward_repeat() - # test_forward_repeat_interleave() - # test_forward_squeeze() - # test_forward_unsqueeze() - # test_forward_concatenate() - # test_forward_reduce_sum() - # test_forward_reduce_prod() - # test_forward_argmin() - # test_forward_argmax() - # test_forward_norm() - # test_forward_frobenius_norm() - # test_forward_std() - # test_forward_variance() - # test_forward_relu() - # test_forward_prelu() - # test_forward_leakyrelu() - # test_forward_elu() - # test_forward_celu() - # test_forward_gelu() - # test_forward_selu() - # test_forward_log_sigmoid() - # test_forward_adaptiveavgpool() - # test_forward_maxpool2d() - # test_forward_maxpool1d() - # test_forward_maxpool3d() - # test_forward_hardtanh() - # test_forward_conv() - # test_forward_conv_transpose() - # test_forward_threshold() - # test_forward_contiguous() - # test_forward_batchnorm() - # test_forward_instancenorm() - # test_forward_layernorm() - # test_forward_groupnorm() - # test_forward_transpose() - # test_forward_size() - # test_forward_view() - # test_forward_select() - # test_forward_take() - # test_forward_topk() - # test_forward_where() - # test_forward_addcdiv() - # test_forward_addcmul() - # test_forward_true_divide() - test_forward_is_floating_point() - # test_forward_clone() - # test_forward_softplus() - # test_forward_softsign() - # test_forward_logsoftmax() - # test_forward_sigmoid() - # test_forward_dense() - # test_forward_avgpool() - # test_forward_avgpool3d() - # test_forward_dropout() - # test_forward_slice() - # test_forward_mean() - # test_forward_expand() - # test_forward_pow() - # test_forward_unary() - # test_forward_clamp() - # test_forward_clamp_() - # test_forward_logical_not() - # test_forward_bitwise_not() - # test_forward_bitwise_xor() - # test_forward_logical_xor() - # test_forward_isfinite() - # test_forward_isnan() - # test_forward_isinf() - # test_forward_ones() - # test_forward_ones_like() - # test_forward_zeros() - # test_forward_zeros_like() - # test_forward_full() - # test_forward_full_like() - # test_forward_linspace() - # test_forward_arange() - # test_forward_mesh_grid() - # test_forward_chunk() - # test_forward_split() - # test_forward_gather() - # test_upsample() - # test_forward_upsample3d() - # test_forward_nms() - # test_forward_roi_align() - # test_to() - # test_flatten() - # test_type_as() - # test_forward_functional_pad() - # test_forward_zero_pad2d() - # test_forward_constant_pad1d() - # test_forward_constant_pad2d() - # test_forward_constant_pad3d() - # test_forward_reflection_pad1d() - # test_forward_reflection_pad2d() - # test_forward_replication_pad1d() - # test_forward_replication_pad2d() - # test_forward_replication_pad3d() - # test_adaptive_pool3d() - # test_conv3d() - # test_conv3d_transpose() - # test_forward_index() - # test_min_max() - # test_logsumexp() - # test_stack() - # test_stack_dynamic() - # test_forward_unbind() - # test_forward_nonzero() - # test_forward_scatter() - # test_numel() - # test_bincount() - - # # Model tests - # test_resnet18() - # test_squeezenet1_0() - # test_squeezenet1_1() - # test_densenet121() - # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # # test_inception_v3() - # test_googlenet() - # test_mnasnet0_5() - # test_mobilenet_v2() - - # test_custom_conversion_map() - - # test_segmentaton_models() - # test_3d_models() - - # # Quantization test - # from qnn_test import test_quantized_imagenet, test_quantized_modules - - # test_quantized_modules() - # test_quantized_imagenet() - - # # Test simple conditionals and loop - # test_control_flow() - # test_simple_rnn() - - # # More complex recurrent models - # from test_lstm import test_custom_lstm - - # test_custom_lstm() - - # # Test bert model - # test_forward_pretrained_bert_base_uncased() - - # # Test convert torch script(jit) with specific inputs' types - # test_convert_torch_script_with_input_types() + test_forward_traced_function() + test_forward_dtypes() + test_weight_names() + test_duplicate_weight_use() + + # Single operator tests + test_forward_pixel_shuffle() + test_forward_add() + test_forward_subtract() + test_forward_multiply() + test_forward_matmul() + test_forward_rsub() + test_forward_onehot() + test_forward_embedding() + test_forward_reshape() + test_forward_reciprocal() + test_forward_repeat() + test_forward_repeat_interleave() + test_forward_squeeze() + test_forward_unsqueeze() + test_forward_concatenate() + test_forward_reduce_sum() + test_forward_reduce_prod() + test_forward_argmin() + test_forward_argmax() + test_forward_norm() + test_forward_frobenius_norm() + test_forward_std() + test_forward_variance() + test_forward_relu() + test_forward_prelu() + test_forward_leakyrelu() + test_forward_elu() + test_forward_celu() + test_forward_gelu() + test_forward_selu() + test_forward_log_sigmoid() + test_forward_adaptiveavgpool() + test_forward_maxpool2d() + test_forward_maxpool1d() + test_forward_maxpool3d() + test_forward_hardtanh() + test_forward_conv() + test_forward_conv_transpose() + test_forward_threshold() + test_forward_contiguous() + test_forward_batchnorm() + test_forward_instancenorm() + test_forward_layernorm() + test_forward_groupnorm() + test_forward_transpose() + test_forward_size() + test_forward_view() + test_forward_select() + test_forward_take() + test_forward_topk() + test_forward_where() + test_forward_addcdiv() + test_forward_addcmul() + test_forward_true_divide() + test_forward_clone() + test_forward_softplus() + test_forward_softsign() + test_forward_logsoftmax() + test_forward_sigmoid() + test_forward_dense() + test_forward_avgpool() + test_forward_avgpool3d() + test_forward_dropout() + test_forward_slice() + test_forward_mean() + test_forward_expand() + test_forward_pow() + test_forward_unary() + test_forward_clamp() + test_forward_clamp_() + test_forward_logical_not() + test_forward_bitwise_not() + test_forward_bitwise_xor() + test_forward_logical_xor() + test_forward_isfinite() + test_forward_isnan() + test_forward_isinf() + test_forward_ones() + test_forward_ones_like() + test_forward_zeros() + test_forward_zeros_like() + test_forward_full() + test_forward_full_like() + test_forward_linspace() + test_forward_arange() + test_forward_mesh_grid() + test_forward_chunk() + test_forward_split() + test_forward_gather() + test_upsample() + test_forward_upsample3d() + test_forward_nms() + test_forward_roi_align() + test_to() + test_flatten() + test_type_as() + test_forward_functional_pad() + test_forward_zero_pad2d() + test_forward_constant_pad1d() + test_forward_constant_pad2d() + test_forward_constant_pad3d() + test_forward_reflection_pad1d() + test_forward_reflection_pad2d() + test_forward_replication_pad1d() + test_forward_replication_pad2d() + test_forward_replication_pad3d() + test_adaptive_pool3d() + test_conv3d() + test_conv3d_transpose() + test_forward_index() + test_min_max() + test_logsumexp() + test_stack() + test_stack_dynamic() + test_forward_unbind() + test_forward_nonzero() + test_forward_scatter() + test_numel() + test_bincount() + + # Model tests + test_resnet18() + test_squeezenet1_0() + test_squeezenet1_1() + test_densenet121() + # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # test_inception_v3() + test_googlenet() + test_mnasnet0_5() + test_mobilenet_v2() + + test_custom_conversion_map() + + test_segmentaton_models() + test_3d_models() + + # Quantization test + from qnn_test import test_quantized_imagenet, test_quantized_modules + + test_quantized_modules() + test_quantized_imagenet() + + # Test simple conditionals and loop + test_control_flow() + test_simple_rnn() + + # More complex recurrent models + from test_lstm import test_custom_lstm + + test_custom_lstm() + + # Test bert model + test_forward_pretrained_bert_base_uncased() + + # Test convert torch script(jit) with specific inputs' types + test_convert_torch_script_with_input_types() From 100d6391a74980aae2815a56f4d1e248ddd532aa Mon Sep 17 00:00:00 2001 From: Tyler Davis Date: Fri, 18 Dec 2020 12:21:07 -0800 Subject: [PATCH 5/5] revert whitespace changes --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index d5ac70813f2d..c75bd2dd3c09 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -146,6 +146,7 @@ def __init__(self, prelude, default_dtype): # above. def infer_type(self, node, mod=None): """An incremental method to infer the type of a node in the relay graph.""" + if node in self.types: return self.types[node] if isinstance(node, tvm.relay.Var): @@ -2067,7 +2068,6 @@ def is_floating_point(self, inputs, input_types): input_type = input_types[0] is_float = input_type in ["float32", "float64", "float16", "bfloat16"] - return _expr.const(is_float) # Operator mappings