From 49439bbd8d13375b31a3f69807a431c765d27b42 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 29 Sep 2023 08:02:48 -0700 Subject: [PATCH] Add split_with_sizes_copy Summary: Adds an implementation of `aten.split_with_sizes_copy`. The implementation is identical to the existing implementation for `split.Tensor`. The only difference is the input checking and output resizing functions. Differential Revision: D49763676 --- exir/dialects/edge/op/sample_input.py | 17 ++++ .../portable/cpu/op_split_with_sizes_copy.cpp | 87 ++++++++++++++++++ kernels/portable/cpu/targets.bzl | 6 ++ kernels/portable/cpu/util/copy_ops_util.cpp | 41 +++++++++ kernels/portable/cpu/util/copy_ops_util.h | 13 +++ kernels/portable/functions.yaml | 5 ++ .../test/op_split_with_sizes_copy_test.cpp | 90 +++++++++++++++++++ kernels/test/targets.bzl | 1 + 8 files changed, 260 insertions(+) create mode 100644 kernels/portable/cpu/op_split_with_sizes_copy.cpp create mode 100644 kernels/test/op_split_with_sizes_copy_test.cpp diff --git a/exir/dialects/edge/op/sample_input.py b/exir/dialects/edge/op/sample_input.py index d9c228a2887..f51e8d9ece6 100644 --- a/exir/dialects/edge/op/sample_input.py +++ b/exir/dialects/edge/op/sample_input.py @@ -1150,6 +1150,23 @@ ), ], }, + "split_with_sizes_copy.default": { # (Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + "args": [ + InArg(ArgType.Tensor, size=[2, 6, 3]), + InArg(ArgType.LengthList, value=[3, 1, 2]), + InArg(ArgType.Dim, value=1), + ], + "returns": [ + Return( + ArgType.TensorList, + value=[ + Return(ArgType.Tensor, size=[2, 3, 3]), + Return(ArgType.Tensor, size=[2, 1, 3]), + Return(ArgType.Tensor, size=[2, 2, 3]), + ], + ), + ], + }, "sqrt.default": { # (Tensor self) -> Tensor "args": [ InArg(ArgType.Tensor), diff --git a/kernels/portable/cpu/op_split_with_sizes_copy.cpp b/kernels/portable/cpu/op_split_with_sizes_copy.cpp new file mode 100644 index 00000000000..a394308dd53 --- /dev/null +++ 
b/kernels/portable/cpu/op_split_with_sizes_copy.cpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cstdint>
+#include <cstring>
+
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+using TensorList = exec_aten::TensorList;
+
+void split_with_sizes_copy_out(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    exec_aten::ArrayRef<int64_t> split_sizes,
+    int64_t dim,
+    TensorList out) {
+  (void)ctx;
+  // Support python-style negative indexing. Note that this op does not accept 0
+  // dimensional input tensors.
+  if (dim < 0) {
+    dim += in.dim();
+  }
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_split_with_sizes_copy_args(in, split_sizes, dim, out),
+      InvalidArgument,
+      out);
+
+  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
+  size_t expected_out_dim = 0;
+  for (size_t i = 0; i < split_sizes.size(); i++) {
+    expected_out_size[expected_out_dim++] = split_sizes[i];
+    get_split_with_sizes_copy_out_target_size(
+        in, split_sizes[i], dim, expected_out_size, &expected_out_dim);
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_tensor(out[i], {expected_out_size, expected_out_dim}) ==
+            Error::Ok,
+        InvalidArgument,
+        out);
+  }
+
+  const size_t leading_dims = getLeadingDims(in, dim);
+  const size_t trailing_dims = getTrailingDims(in, dim);
+  const size_t step = in.size(dim) * trailing_dims;
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType out_type = out[0].scalar_type();
+
+  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
+    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
+      const CTYPE_IN* in_data = in.const_data_ptr<CTYPE_IN>();
+      for (size_t i = 0, e = out.size(); i < e; ++i) {
+        size_t out_step = out[i].size(dim) * trailing_dims;
+        if (out_step == 0) {
+          continue;
+        }
+        const CTYPE_IN* src =
in_data;
+        CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
+        for (size_t j = 0; j < leading_dims; ++j) {
+          for (size_t k = 0; k < out_step; ++k) {
+            dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
+          }
+          src += step;
+          dest += out_step;
+        }
+        in_data += out_step;
+      }
+    });
+  });
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index f6587ac2134..62d93009e8a 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -691,6 +691,12 @@ _ATEN_OPS = (
     op_target(
         name = "op_split_copy",
     ),
+    op_target(
+        name = "op_split_with_sizes_copy",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:copy_ops_util",
+        ],
+    ),
     op_target(
         name = "op_sqrt",
         deps = [
diff --git a/kernels/portable/cpu/util/copy_ops_util.cpp b/kernels/portable/cpu/util/copy_ops_util.cpp
index 7a8fc3dbb6f..acefe61e780 100644
--- a/kernels/portable/cpu/util/copy_ops_util.cpp
+++ b/kernels/portable/cpu/util/copy_ops_util.cpp
@@ -163,6 +163,47 @@ void get_pixel_shuffle_out_target_size(
     out_sizes[i] = in.size(i) * casted_upscale_factor;
 }
 
+bool check_split_with_sizes_copy_args(
+    const Tensor& in,
+    exec_aten::ArrayRef<int64_t> split_sizes,
+    int64_t dim,
+    TensorList out) {
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(in, 1));
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
+
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      split_sizes.size() == out.size(),
+      "Number of split sizes must match the number of output tensors");
+
+  int64_t sum = 0;
+  for (int i = 0; i < split_sizes.size(); i++) {
+    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+        split_sizes[i] >= 0, "All split sizes must be non negative.");
+    sum += split_sizes[i];
+  }
+
+  const ssize_t dim_size = in.size(dim);
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      sum == dim_size,
+      "Sum of split sizes does not match input size at given dim");
+
+  return true;
+}
+
+void get_split_with_sizes_copy_out_target_size(
+    const Tensor& in,
+    int64_t split_size,
+
int64_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = in.dim();
+
+  for (size_t d = 0; d < in.dim(); ++d) {
+    out_sizes[d] = in.size(d);
+  }
+  out_sizes[dim] = split_size;
+}
+
 bool check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
diff --git a/kernels/portable/cpu/util/copy_ops_util.h b/kernels/portable/cpu/util/copy_ops_util.h
index f4ed50f189e..10f145341ea 100644
--- a/kernels/portable/cpu/util/copy_ops_util.h
+++ b/kernels/portable/cpu/util/copy_ops_util.h
@@ -43,6 +43,19 @@ void get_pixel_shuffle_out_target_size(
     Tensor::SizesType* out_sizes,
     size_t* out_ndim);
 
+bool check_split_with_sizes_copy_args(
+    const Tensor& in,
+    exec_aten::ArrayRef<int64_t> split_sizes,
+    int64_t dim,
+    TensorList out);
+
+void get_split_with_sizes_copy_out_target_size(
+    const Tensor& in,
+    int64_t split_size,
+    int64_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
 bool check_stack_args(
     exec_aten::ArrayRef<Tensor> tensors,
     int64_t dim,
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 949b771b9cc..4c4e06c2414 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -632,6 +632,11 @@
   - arg_meta: null
     kernel_name: torch::executor::split_copy_Tensor_out
 
+- op: split_with_sizes_copy.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::split_with_sizes_copy_out
+
 - op: sqrt.out
   kernels:
   - arg_meta: null
diff --git a/kernels/test/op_split_with_sizes_copy_test.cpp b/kernels/test/op_split_with_sizes_copy_test.cpp
new file mode 100644
index 00000000000..1941ca4b690
--- /dev/null
+++ b/kernels/test/op_split_with_sizes_copy_test.cpp
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/kernels/test/supported_features.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+
+void op_split_with_sizes_copy_out(
+    const exec_aten::Tensor& self,
+    exec_aten::ArrayRef<int64_t> split_sizes,
+    int64_t dim,
+    exec_aten::TensorList out) {
+  exec_aten::RuntimeContext context{};
+  return torch::executor::aten::split_with_sizes_copy_outf(
+      context, self, split_sizes, dim, out);
+}
+
+TEST(OpSplitWithSizesCopyOutTest, SanityCheckDim1) {
+  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float>
+      tfFloat;
+
+  exec_aten::Tensor self = tfFloat.make(
+      {2, 6, 3},
+      {-31.25, -92.75, -39.75, -3.25, 53.875, 88.25, -0.625, -1.125,
+       14.75, 42.0, 89.875, -21.125, -8.0, -64.125, 23.0, 37.0,
+       46.125, -83.25, -58.125, 19.625, -71.125, 64.75, -1.375, -83.5,
+       -61.375, 13.125, 28.625, -94.0, -67.0, -8.625, -88.875, -79.125,
+       0.375, -61.375, 65.0, -99.375});
+  ::std::vector<int64_t> split_sizes_vec = {3, 1, 2};
+  exec_aten::ArrayRef<int64_t> split_sizes = exec_aten::ArrayRef<int64_t>(
+      split_sizes_vec.data(), split_sizes_vec.size());
+  int64_t dim = 1;
+  ::std::vector<exec_aten::Tensor> out_vec = {
+      tfFloat.zeros({2, 3, 3}),
+      tfFloat.zeros({2, 1, 3}),
+      tfFloat.zeros({2, 2, 3})};
+  exec_aten::TensorList out =
+      exec_aten::TensorList(out_vec.data(), out_vec.size());
+  ::std::vector<exec_aten::Tensor> out_expected_vec = {
+      tfFloat.make(
+          {2, 3, 3},
+          {-31.25,
+           -92.75,
+           -39.75,
+           -3.25,
+           53.875,
+           88.25,
+           -0.625,
+           -1.125,
+           14.75,
+           -58.125,
+           19.625,
+           -71.125,
+           64.75,
+           -1.375,
+           -83.5,
+           -61.375,
+           13.125,
+           28.625}),
+      tfFloat.make({2, 1, 3}, {42.0, 89.875, -21.125, -94.0, -67.0, -8.625}),
+      tfFloat.make(
+          {2, 2, 3},
+          {-8.0,
+           -64.125,
+           23.0,
+           37.0,
+           46.125,
+           -83.25,
+           -88.875,
+           -79.125,
+           0.375,
+           -61.375,
+           65.0,
+           -99.375})};
+  exec_aten::TensorList out_expected =
+      exec_aten::TensorList(out_expected_vec.data(), out_expected_vec.size());
+  op_split_with_sizes_copy_out(self, split_sizes, dim, out);
+  EXPECT_TENSOR_LISTS_CLOSE(out, out_expected);
+}
diff --git a/kernels/test/targets.bzl
b/kernels/test/targets.bzl index d20f70caed3..85a00db3d3d 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -260,6 +260,7 @@ def define_common_targets(): _common_op_test("op_slice_copy_test", ["aten", "portable"]) _common_op_test("op_softmax_test", ["aten", "portable"]) _common_op_test("op_split_copy_test", ["aten", "portable"]) + _common_op_test("op_split_with_sizes_copy_test", ["aten", "portable"]) _common_op_test("op_sqrt_test", ["aten", "portable"]) _common_op_test("op_squeeze_copy_test", ["aten", "portable"]) _common_op_test("op_stack_test", ["aten", "portable"])