diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index 35c89acd4c..d91a12ec35 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -825,26 +825,17 @@ def aten_leaky_relu_backward(
     raise NotImplementedError()
 
 
-# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm)
-def aten_linear(input: TFloat, weight: TFloat) -> TFloat:
+@torch_op("aten::linear", trace_only=True)
+def aten_linear(input: TFloat, weight: TFloat, bias: TFloat | None = None) -> TFloat:
     """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
 
-    # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
-    # Optimizers may consider this path and replace it with Gemm
-    # We do not use Gemm here because input can have batch dimensions, which Gemm does not support
-    weight_transposed = op.Transpose(weight, perm=[1, 0])
-    return op.MatMul(input, weight_transposed)
-
-
-# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm)
-def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat:
-    """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor"""
-
-    # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases
-    # Optimizers may consider this path and replace it with Gemm
-    # We do not use Gemm here because input can have batch dimensions, which Gemm does not support
+    if len(input.shape) == 2:
+        # Use Gemm for the rank 2 input
+        return op.Gemm(input, weight, bias, transB=True)
     weight_transposed = op.Transpose(weight, perm=[1, 0])
     mul = op.MatMul(input, weight_transposed)
+    if bias is None:
+        return mul
     return op.Add(mul, bias)
 
 
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 8422ab7306..ee86327362 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1855,6 +1855,9 @@ def _where_input_wrangler(
         tolerance={torch.float16: (8e-2, 1e-4)},
     ),
     TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
+    TorchLibOpInfo(
+        "nn.functional.linear", nn_ops.aten_linear, tolerance={torch.float16: (1e-2, 1e-3)}
+    ),
     TorchLibOpInfo(
         "nn.functional.unfold",
         nn_ops.aten_im2col,
@@ -2176,9 +2179,6 @@
 ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))
 ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",))
-ops_test_common.duplicate_opinfo(
-    OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",)
-)
 ops_test_common.duplicate_opinfo(
     OPS_DB, "nn.functional.pad",