diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py
index cbe6a22f4a4d..bd9051716969 100644
--- a/python/tvm/relay/op/contrib/tensorrt.py
+++ b/python/tvm/relay/op/contrib/tensorrt.py
@@ -723,8 +723,12 @@ def pad_annotate_fn(expr):  # pylint: disable=unused-variable
     if attrs.pad_mode != "constant":
         logger.info("nn.pad: pad mode is %s but must be constant.", attrs.pad_mode)
         return False
-    if float(attrs.pad_value) != 0.0:
-        logger.info("nn.pad: pad value is %f but must be 0.0.", float(attrs.pad_value))
+    if (
+        not isinstance(args[1], relay.Constant)
+        or len(args[1].checked_type.shape) != 0
+        or args[1].data.numpy().item() != 0.0
+    ):
+        logger.info("nn.pad: pad value is %s but must be 0.0.", args[1])
         return False
     if len(attrs.pad_width) not in [4, 5]:
         logger.info("nn.pad: can only pad 4D or 5D inputs")
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 7197172d73db..020ef866692a 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -1057,7 +1057,7 @@ class ReshapeOpConverter : public TensorRTOpConverter {
 
 class PadOpConverter : public TensorRTOpConverter {
  public:
-  PadOpConverter() : TensorRTOpConverter({kTensor}) {}
+  PadOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index f9912c9674e5..7ff051a4d38b 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -793,16 +793,25 @@ def get_graph(x_shape=(1, 16)):
 
 
 def test_pad():
-    def get_graph(x_shape, pad_width):
+    def get_graph(x_shape, pad_width, pad_value=0.0):
         x = relay.var("x", shape=(x_shape), dtype="float32")
-        out = relay.nn.pad(x, pad_width=pad_width)
+        out = relay.nn.pad(x, pad_width=pad_width, pad_value=pad_value)
         f = relay.Function([x], out)
         return f, {"x": x_shape}, []
 
     run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0]]))
     run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]]))
+    run_and_verify_func(
+        get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [1, 1], [1, 1]], pad_value=-1.0e30)
+    )
     run_and_verify_func(get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]]))
+    run_and_verify_func(
+        get_graph((1, 8, 16, 16), [[0, 0], [0, 0], [0, 1], [2, 0]], pad_value=-1.0e30)
+    )
     run_and_verify_func(get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]]))
+    run_and_verify_func(
+        get_graph((1, 8, 3, 16, 16), [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], pad_value=-1.0e30)
+    )
 
 
 def test_softmax():
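
A minimal standalone sketch (illustrative, not part of the patch) of what the revised check in pad_annotate_fn now requires: the pad value must reach the annotator as a 0-d relay.Constant equal to 0.0. This assumes a TVM build where nn.pad carries its pad value as the call's second argument, which is what the patch itself relies on; the helper name below is hypothetical.

import tvm
from tvm import relay


def pad_value_is_scalar_zero(arg):
    """Hypothetical mirror of the new predicate in pad_annotate_fn."""
    # Must be a compile-time constant; a relay.Var or other expression is rejected.
    if not isinstance(arg, relay.Constant):
        return False
    # Must be a scalar (0-d); checked_type is populated only after type inference.
    if len(arg.checked_type.shape) != 0:
        return False
    # Must be exactly zero, since the TensorRT converter only emits zero padding.
    return arg.data.numpy().item() == 0.0


x = relay.var("x", shape=(1, 8, 16, 16), dtype="float32")
out = relay.nn.pad(x, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]], pad_value=0.0)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([x], out)))
call = mod["main"].body
print(pad_value_is_scalar_zero(call.args[1]))  # expected: True

Graphs that fail this check (for example the pad_value=-1.0e30 cases added to test_pad) are simply not offloaded to TensorRT and fall back to TVM, which is why run_and_verify_func still expects correct results for them.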