From de19f322499ed50524dd7669097d30a05293cc10 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 17 Jan 2025 16:46:34 -0800 Subject: [PATCH 1/3] [torchlib] Register linear and use matmul to simplify graph --- onnxscript/function_libs/torch_lib/ops/nn.py | 23 ++++++-------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 35c89acd4c..38889c9545 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -825,26 +825,17 @@ def aten_leaky_relu_backward( raise NotImplementedError() -# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm) -def aten_linear(input: TFloat, weight: TFloat) -> TFloat: +@torch_op("aten::linear", trace_only=True) +def aten_linear(input: TFloat, weight: TFloat, bias: TFloat | None = None) -> TFloat: """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" - # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases - # Optimizers may consider this path and replace it with Gemm - # We do not use Gemm here because input can have batch dimensions, which Gemm does not support - weight_transposed = op.Transpose(weight, perm=[1, 0]) - return op.MatMul(input, weight_transposed) - - -# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm) -def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat: - """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" - - # NOTE: The symbolic function in torch.onnx also uses Gemm in certain cases - # Optimizers may consider this path and replace it with Gemm - # We do not use Gemm here because input can have batch dimensions, which Gemm does not support + if len(input.shape) == 2: + # Use Gemm for the rank 2 input + return op.Gemm(input, weight, transB=True) weight_transposed = op.Transpose(weight, perm=[1, 0]) mul = op.MatMul(input, weight_transposed) + if bias is None: + return mul return op.Add(mul, bias) From 61a474ed326b0c5237ab80f1869409d1b86ac166 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 17 Jan 2025 16:51:02 -0800 Subject: [PATCH 2/3] gemm --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- tests/function_libs/torch_lib/ops_test_data.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 38889c9545..d91a12ec35 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -831,7 +831,7 @@ def aten_linear(input: TFloat, weight: TFloat, bias: TFloat | None = None) -> TF if len(input.shape) == 2: # Use Gemm for the rank 2 input - return op.Gemm(input, weight, transB=True) + return op.Gemm(input, weight, bias, transB=True) weight_transposed = op.Transpose(weight, perm=[1, 0]) mul = op.MatMul(input, weight_transposed) if bias is None: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8422ab7306..8bfbb3ee88 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1855,6 +1855,7 @@ def _where_input_wrangler( tolerance={torch.float16: (8e-2, 1e-4)}, ), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), + TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear), TorchLibOpInfo( "nn.functional.unfold", nn_ops.aten_im2col, From 32db4bc0ca7eac190fbc5e5eaffae2bd40ad5ee3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 17 Jan 2025 17:00:16 -0800 Subject: [PATCH 3/3] test --- tests/function_libs/torch_lib/ops_test_data.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 8bfbb3ee88..ee86327362 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1855,7 +1855,9 @@ def _where_input_wrangler( tolerance={torch.float16: (8e-2, 1e-4)}, ), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), - TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear), + TorchLibOpInfo( + "nn.functional.linear", nn_ops.aten_linear, tolerance={torch.float16: (1e-2, 1e-3)} + ), TorchLibOpInfo( "nn.functional.unfold", nn_ops.aten_im2col, @@ -2177,9 +2179,6 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) -ops_test_common.duplicate_opinfo( - OPS_DB, "nn.functional.linear", ("nn.functional.linear_bias",) -) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad",