diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 401a88e2ed..c3f14790f3 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -812,19 +812,6 @@ def aten_atleast_3d_single_tensor(self: TTensor) -> TTensor: return self -def aten_avg_pool1d( - self: TensorType, - kernel_size: Sequence[int], - stride: Optional[Sequence[int]] = None, - padding: Sequence[int] = (0,), - ceil_mode: bool = False, - count_include_pad: bool = True, -) -> TensorType: - """avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor""" - - raise NotImplementedError() - - @torch_op("aten::baddbmm") def aten_baddbmm( self: TRealOrUInt8, diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 446103d6c5..945b3329c6 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -127,6 +127,70 @@ def aten_adaptive_max_pool3d_backward( raise NotImplementedError() +def _adjust_attributes_of_avg_pool( + expand_size: int, + kernel_size: Sequence[int], + stride: Sequence[int], + padding: Sequence[int], +) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]: + """Adjust attributes of avg_pool to match ONNX specification.""" + + if isinstance(kernel_size, int): + kernel_shape = [kernel_size] * expand_size + else: + kernel_shape = kernel_size + + if isinstance(padding, int): + pads = [padding] * expand_size * 2 + elif len(padding) == 1: + pads = padding * expand_size * 2 + elif len(padding) == 2: + pads = padding * expand_size + else: + pads = padding * 2 + + if isinstance(stride, int): + strides = [stride] * expand_size + elif not stride: + strides = kernel_shape + else: + strides = stride + + return (kernel_shape, strides, pads) + + +@torch_op("aten::avg_pool1d", trace_only=True) +def aten_avg_pool1d( + self: TFloat, + kernel_size: Sequence[int], + stride: Sequence[int] = (), + padding: Sequence[int] = (0,), + ceil_mode: bool = False, + count_include_pad: bool = True, +) -> TFloat: + """avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor""" + + # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly + # But ONNX needs pair number [x,y] to specify on each side explicitly + # For pool3d, this number should be 3 + expand_size = 1 + + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = op.AveragePool( + self, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + return result + + @torch_op("aten::avg_pool2d", trace_only=True) def aten_avg_pool2d( self: TFloat, @@ -144,30 +208,9 @@ def aten_avg_pool2d( # For pool3d, this number should be 3 expand_size = 2 - # The kernel_shape should be [x, y] - if isinstance(kernel_size, int): # x -> [x, x] - kernel_shape = [kernel_size] * expand_size - else: # assert(len(kernel_size)==2), already [x, y] - kernel_shape = kernel_size - - # The pads should be [w, x, y, z] - if isinstance(padding, int): # w -> [w, w, w, w] - pads = [padding] * expand_size * 2 - elif len(padding) == 1: # [w] -> [w, w, w, w] - pads = padding * 4 - elif len(padding) == 2: # [w, x] -> [w, x, w, x] - pads = padding * 2 - else: # assert len(padding) == 4, already [w, x, y, z] - pads = padding - - # The strides should be [x, y] - if isinstance(stride, int): # x -> [x, x] - strides = [stride] * expand_size - elif not stride: - # stride is empty - strides = kernel_shape - else: - strides = stride + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) result = op.AveragePool( self, @@ -209,18 +252,50 @@ def aten_avg_pool2d_backward( raise NotImplementedError() +@torch_op("aten::avg_pool3d", trace_only=True) def aten_avg_pool3d( - self: TensorType, + self: TFloat, kernel_size: Sequence[int], - stride: Optional[Sequence[int]] = None, + stride: Sequence[int] = (), padding: Sequence[int] = (0, 0, 0), ceil_mode: bool = False, count_include_pad: bool = True, - divisor_override: Optional[int] = None, -) -> TensorType: + divisor_override: Optional[int] = None, # pylint: disable=unused-argument +) -> TFloat: """avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor""" - raise NotImplementedError() + # Torch prefer to use single number x for kerne,stride,pad,dilation on both side implicitly + # But ONNX needs pair number [x,y] to specify on each side explicitly + # For pool3d, this number should be 3 + expand_size = 3 + + kernel_shape, strides, pads = _adjust_attributes_of_avg_pool( + expand_size, kernel_size, stride, padding + ) + + result = op.AveragePool( + self, + kernel_shape=kernel_shape, + strides=strides, + pads=pads, + count_include_pad=count_include_pad, + ceil_mode=ceil_mode, + ) + + # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) + # mask = [ + # 1, 2, 3, S,..3, 2, 1 + # 2, 4, 6, 2S, 6, 4, 2 + # 3, 6, 9, 3S, 9, 6, 3 + # S, 2S,3S,SS,3S,2S, S + # 3, 6, 9, 3S, 9, 6, 3 + # 2, 4, 6, 2S, 6, 4, 2 + # 1, 2, 3, S,..3, 2, 1 + # ] + # S is stride size, in this case S=4, + # S may dup lot of times according to the image size + + return result def aten_avg_pool3d_backward( diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index 29b89a7689..ffe3da3a23 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -181,7 +181,7 @@ def _amin_amax_input_wrangler( return args, kwargs -def _avg_pool2d_input_wrangler( +def _avg_pool_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: if "dim" not in kwargs: @@ -197,10 +197,11 @@ def _avg_pool2d_input_wrangler( # Cannot using list(padding) here, because the element will be numpy.int64 instead of int padding = padding.tolist() kwargs["padding"] = padding - stride = args.pop(2) - if isinstance(stride, np.ndarray): - stride = stride.tolist() - kwargs["stride"] = stride + if len(args) > 2: + stride = args.pop(2) + if isinstance(stride, np.ndarray): + stride = stride.tolist() + kwargs["stride"] = stride kernel_size = args.pop(1) if isinstance(kernel_size, np.ndarray): kernel_size = kernel_size.tolist() @@ -1393,15 +1394,52 @@ def _where_input_wrangler( trace_only=True, tolerance={torch.float32: (3.7e-5, 1.8e-4)}, ), + TorchLibOpInfo( + "nn.functional.avg_pool1d", + nn_ops.aten_avg_pool1d, + input_wrangler=_avg_pool_input_wrangler, + trace_only=True, + ) + .xfail( + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), + reason="ONNX doesn't support divisor_override argument", + ) + .xfail( + matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True) + and ( + sample.kwargs.get("count_include_pad") is True + or sample.input.shape[2] + % (sample.args[0][0] if isinstance(sample.args[0], tuple) else sample.args[0]) + != 0 + ), + reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", + ), TorchLibOpInfo( "nn.functional.avg_pool2d", nn_ops.aten_avg_pool2d, - input_wrangler=_avg_pool2d_input_wrangler, + input_wrangler=_avg_pool_input_wrangler, trace_only=True, ).xfail( - matcher=lambda sample: len(sample.args) > 5 and sample.args[5] is not None, + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), reason="ONNX doesn't support divisor_override argument", ), + TorchLibOpInfo( + "nn.functional.avg_pool3d", + nn_ops.aten_avg_pool3d, + input_wrangler=_avg_pool_input_wrangler, + trace_only=True, + ) + .xfail( + matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None) + or (sample.kwargs.get("divisor_override") is not None), + reason="ONNX doesn't support divisor_override argument", + ) + .xfail( + matcher=lambda sample: sample.kwargs.get("ceil_mode") is True, + reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19", + ), TorchLibOpInfo( "nn.functional.conv1d", core_ops.aten_conv1d,