From de1314430a0b42f7802c3fb77876ae65145f850f Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Tue, 5 Oct 2021 23:14:33 -0400 Subject: [PATCH 1/3] fix bias_add --- python/tvm/relay/frontend/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a0a837f92df9..76cd0455661b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1469,7 +1469,7 @@ def linear(self, inputs, input_types): if isinstance(bias, _expr.Expr): bias_ndims = len(self.infer_shape_with_prelude(bias)) if bias_ndims == 1: - return _op.nn.bias_add(mm_out, bias) + return _op.nn.bias_add(mm_out, bias, axis=-1) mm_dtype = self.infer_type_with_prelude(mm_out).dtype return self.add([mm_out, bias], [mm_dtype, input_types[2]]) return mm_out From a03a11979082fea5444cc4580bdbf8c033a0c987 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 6 Oct 2021 22:40:42 -0400 Subject: [PATCH 2/3] add test --- tests/python/frontend/pytorch/test_forward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6f5eb1825dfc..6783eab07a0f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1602,6 +1602,8 @@ def forward(self, x, y, z): verify_model(LinearNoBias(), input_data=[input2d, weight1d]) # 3D input, 2D weight, no bias verify_model(LinearNoBias(), input_data=[input3d, weight3x2]) + # 3D input, 2D weight, 1D bias + verify_model(LinearNoBias(), input_data=[input3d, weight2d, bias1d]) verify_model(LinearNested(), input_data=[torch.randn(10, 10) for _ in range(3)]) From 3715cf37effbc531697d99248c802ee3fda9d8b6 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Wed, 6 Oct 2021 22:40:54 -0400 Subject: [PATCH 3/3] add test --- tests/python/frontend/pytorch/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6783eab07a0f..3a3889d5cfb7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1603,7 +1603,7 @@ def forward(self, x, y, z): # 3D input, 2D weight, no bias verify_model(LinearNoBias(), input_data=[input3d, weight3x2]) # 3D input, 2D weight, 1D bias - verify_model(LinearNoBias(), input_data=[input3d, weight2d, bias1d]) + verify_model(Linear(), input_data=[input3d, weight2d, bias1d]) verify_model(LinearNested(), input_data=[torch.randn(10, 10) for _ in range(3)])