diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a0a837f92df9..76cd0455661b 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1469,7 +1469,7 @@ def linear(self, inputs, input_types): if isinstance(bias, _expr.Expr): bias_ndims = len(self.infer_shape_with_prelude(bias)) if bias_ndims == 1: - return _op.nn.bias_add(mm_out, bias) + return _op.nn.bias_add(mm_out, bias, axis=-1) mm_dtype = self.infer_type_with_prelude(mm_out).dtype return self.add([mm_out, bias], [mm_dtype, input_types[2]]) return mm_out diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6f5eb1825dfc..3a3889d5cfb7 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1602,6 +1602,8 @@ def forward(self, x, y, z): verify_model(LinearNoBias(), input_data=[input2d, weight1d]) # 3D input, 2D weight, no bias verify_model(LinearNoBias(), input_data=[input3d, weight3x2]) + # 3D input, 2D weight, 1D bias + verify_model(Linear(), input_data=[input3d, weight2d, bias1d]) verify_model(LinearNested(), input_data=[torch.randn(10, 10) for _ in range(3)])