From 2051559edf454881d92b6fbddf81304d81e0936e Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Tue, 20 Aug 2024 01:14:54 +0000 Subject: [PATCH 1/2] unregister aten_linear --- onnxscript/function_libs/torch_lib/ops/nn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 62edd7caa4..594c85515d 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -822,7 +822,7 @@ def aten_leaky_relu_backward( raise NotImplementedError() -@torch_op("aten::linear") +# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm) def aten_linear(input: TFloat, weight: TFloat) -> TFloat: """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" @@ -833,7 +833,7 @@ def aten_linear(input: TFloat, weight: TFloat) -> TFloat: return op.MatMul(input, weight_transposed) -@torch_op("aten::linear") +# NOTE: Do not register - We rely on PyTorch decomposition to aten_addmm (Gemm) def aten_linear_bias(input: TFloat, weight: TFloat, bias: TFloat) -> TFloat: """linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor""" From a4ee8a3d52e20444721fc0c5f63841f8c0bd87c8 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Tue, 20 Aug 2024 02:08:09 +0000 Subject: [PATCH 2/2] delete tests --- tests/function_libs/torch_lib/ops_test_data.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index f1099864e6..b4469a4d7b 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1942,20 +1942,6 @@ def _where_input_wrangler( or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", ), - TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip( - # input: input, args: weight, bias; so len(args) == 2 means bias is provided - matcher=lambda sample: len(sample.args) != 1, - reason="this overload is implemented for bias=None", - ), - TorchLibOpInfo( - "nn.functional.linear_bias", - nn_ops.aten_linear_bias, - tolerance={torch.float16: (2e-1, 4e-4)}, - ).skip( - # input: input, args: weight, bias; so len(args) == 2 means bias is provided - matcher=lambda sample: len(sample.args) != 2, - reason="this overload is implemented for bias!=None", - ), TorchLibOpInfo( "nn.functional.max_pool1d", nn_ops.aten_max_pool1d,