diff --git a/extension/pytree/aten_util/targets.bzl b/extension/pytree/aten_util/targets.bzl index d09ffe48cbd..e1793080203 100644 --- a/extension/pytree/aten_util/targets.bzl +++ b/extension/pytree/aten_util/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): ], exported_deps = [ "//executorch/extension/pytree:pytree", + "//executorch/runtime/platform:platform", ], compiler_flags = ["-Wno-missing-prototypes"], fbcode_deps = [ diff --git a/extension/pytree/function_ref.h b/extension/pytree/function_ref.h new file mode 100644 index 00000000000..01d2988597a --- /dev/null +++ b/extension/pytree/function_ref.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===- llvm/ADT/STLFunctionalExtras.h - Extras for -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some extension to . +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// +// Extra additions to +//===----------------------------------------------------------------------===// + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a FunctionRef. + +// torch::executor: modified from llvm::function_ref +// see https://www.foonathan.net/2017/01/function-ref-implementation/ + +#pragma once + +#include +#include +#include + +namespace torch { +namespace executor { +namespace pytree { + +//===----------------------------------------------------------------------===// +// Features from C++20 +//===----------------------------------------------------------------------===// + +template +struct remove_cvref { + using type = + typename std::remove_cv::type>::type; +}; + +template +using remove_cvref_t = typename remove_cvref::type; + +template +class FunctionRef; + +template +class FunctionRef { + Ret (*callback_)(const void* memory, Params... params) = nullptr; + union Storage { + void* callable; + Ret (*function)(Params...); + } storage_; + + public: + FunctionRef() = default; + explicit FunctionRef(std::nullptr_t) {} + + /** + * Case 1: A callable object passed by lvalue reference. + * Taking rvalue reference is error prone because the object will be always + * be destroyed immediately. + */ + template < + typename Callable, + // This is not the copy-constructor. + typename std::enable_if< + !std::is_same, FunctionRef>::value, + int32_t>::type = 0, + // Avoid lvalue reference to non-capturing lambda. + typename std::enable_if< + !std::is_convertible::value, + int32_t>::type = 0, + // Functor must be callable and return a suitable type. + // To make this container type safe, we need to ensure either: + // 1. The return type is void. + // 2. Or the resulting type from calling the callable is convertible to + // the declared return type. + typename std::enable_if< + std::is_void::value || + std::is_convertible< + decltype(std::declval()(std::declval()...)), + Ret>::value, + int32_t>::type = 0> + explicit FunctionRef(Callable& callable) + : callback_([](const void* memory, Params... params) { + auto& storage = *static_cast(memory); + auto& callable = *static_cast(storage.callable); + return static_cast(callable(std::forward(params)...)); + }) { + storage_.callable = &callable; + } + + /** + * Case 2: A plain function pointer. + * Instead of storing an opaque pointer to underlying callable object, + * store a function pointer directly. + * Note that in the future a variant which coerces compatible function + * pointers could be implemented by erasing the storage type. + */ + /* implicit */ FunctionRef(Ret (*ptr)(Params...)) + : callback_([](const void* memory, Params... params) { + auto& storage = *static_cast(memory); + return storage.function(std::forward(params)...); + }) { + storage_.function = ptr; + } + + /** + * Case 3: Implicit conversion from lambda to FunctionRef. + * A common use pattern is like: + * void foo(FunctionRef<...>) {...} + * foo([](...){...}) + * Here constructors for non const lvalue reference or function pointer + * would not work because they do not cover implicit conversion from rvalue + * lambda. + * We need to define a constructor for capturing temporary callables and + * always try to convert the lambda to a function pointer behind the scene. + */ + template < + typename Function, + // This is not the copy-constructor. + typename std::enable_if< + !std::is_same::value, + int32_t>::type = 0, + // Function is convertible to pointer of (Params...) -> Ret. + typename std::enable_if< + std::is_convertible::value, + int32_t>::type = 0> + /* implicit */ FunctionRef(const Function& function) + : FunctionRef(static_cast(function)) {} + + Ret operator()(Params... params) const { + return callback_(&storage_, std::forward(params)...); + } + + explicit operator bool() const { + return callback_; + } +}; + +} // namespace pytree +} // namespace executor +} // namespace torch diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 19dedb25814..4d6116fb4b5 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -16,7 +16,8 @@ #include #include -#include +// NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. +#include namespace torch { namespace executor { diff --git a/extension/pytree/targets.bzl b/extension/pytree/targets.bzl index 5bb8bae1f52..02779c26106 100644 --- a/extension/pytree/targets.bzl +++ b/extension/pytree/targets.bzl @@ -10,13 +10,9 @@ def define_common_targets(): runtime.cxx_library( name = "pytree", srcs = [], - exported_headers = ["pytree.h"], + exported_headers = ["pytree.h", "function_ref.h"], visibility = [ "//executorch/...", "@EXECUTORCH_CLIENTS", ], - exported_deps = [ - "//executorch/runtime/platform:platform", - "//executorch/runtime/core:core", - ], ) diff --git a/extension/pytree/test/TARGETS b/extension/pytree/test/TARGETS index 765f297b381..d1a4eb54590 100644 --- a/extension/pytree/test/TARGETS +++ b/extension/pytree/test/TARGETS @@ -11,6 +11,13 @@ cpp_unittest( deps = ["//executorch/extension/pytree:pytree"], ) +cpp_unittest( + name = "function_ref_test", + srcs = ["function_ref_test.cpp"], + supports_static_listing = True, + deps = ["//executorch/extension/pytree:pytree"], +) + python_unittest( name = "test", srcs = [ diff --git a/extension/pytree/test/function_ref_test.cpp b/extension/pytree/test/function_ref_test.cpp new file mode 100644 index 00000000000..f847c8ebd78 --- /dev/null +++ b/extension/pytree/test/function_ref_test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +using namespace ::testing; + +namespace torch { +namespace executor { +namespace pytree { + +namespace { +class Item { + private: + int32_t val_; + FunctionRef ref_; + + public: + /* implicit */ Item(int32_t val, FunctionRef ref) + : val_(val), ref_(ref) {} + + int32_t get() { + ref_(val_); + return val_; + } +}; + +void one(int32_t& i) { + i = 1; +} + +} // namespace + +TEST(FunctionRefTest, CapturingLambda) { + auto one = 1; + auto f = [&](int32_t& i) { i = one; }; + Item item(0, FunctionRef{f}); + EXPECT_EQ(item.get(), 1); + // ERROR: + // Item item1(0, f); + // Item item2(0, [&](int32_t& i) { i = 2; }); + // FunctionRef ref([&](int32_t&){}); +} + +TEST(FunctionRefTest, NonCapturingLambda) { + int32_t val = 0; + FunctionRef ref([](int32_t& i) { i = 1; }); + ref(val); + EXPECT_EQ(val, 1); + + val = 0; + auto lambda = [](int32_t& i) { i = 1; }; + FunctionRef ref1(lambda); + ref1(val); + EXPECT_EQ(val, 1); + + Item item(0, [](int32_t& i) { i = 1; }); + EXPECT_EQ(item.get(), 1); + + auto f = [](int32_t& i) { i = 1; }; + Item item1(0, f); + EXPECT_EQ(item1.get(), 1); + + Item item2(0, std::move(f)); + EXPECT_EQ(item2.get(), 1); +} + +TEST(FunctionRefTest, FunctionPointer) { + int32_t val = 0; + FunctionRef ref(one); + ref(val); + EXPECT_EQ(val, 1); + + Item item(0, one); + EXPECT_EQ(item.get(), 1); + + Item item1(0, &one); + EXPECT_EQ(item1.get(), 1); +} + +} // namespace pytree +} // namespace executor +} // namespace torch diff --git a/runtime/core/function_ref.h b/runtime/core/function_ref.h index 92171134291..a07f6151f10 100644 --- a/runtime/core/function_ref.h +++ b/runtime/core/function_ref.h @@ -59,9 +59,7 @@ class FunctionRef; template class FunctionRef { - Ret (*callback_)(const void* memory, Params... params) = nullptr; union Storage { - void* callable; Ret (*function)(Params...); } storage_; @@ -70,57 +68,18 @@ class FunctionRef { explicit FunctionRef(std::nullptr_t) {} /** - * Case 1: A callable object passed by lvalue reference. - * Taking rvalue reference is error prone because the object will be always - * be destroyed immediately. - */ - template < - typename Callable, - // This is not the copy-constructor. - typename std::enable_if< - !std::is_same, FunctionRef>::value, - int32_t>::type = 0, - // Avoid lvalue reference to non-capturing lambda. - typename std::enable_if< - !std::is_convertible::value, - int32_t>::type = 0, - // Functor must be callable and return a suitable type. - // To make this container type safe, we need to ensure either: - // 1. The return type is void. - // 2. Or the resulting type from calling the callable is convertible to - // the declared return type. - typename std::enable_if< - std::is_void::value || - std::is_convertible< - decltype(std::declval()(std::declval()...)), - Ret>::value, - int32_t>::type = 0> - explicit FunctionRef(Callable& callable) - : callback_([](const void* memory, Params... params) { - auto& storage = *static_cast(memory); - auto& callable = *static_cast(storage.callable); - return static_cast(callable(std::forward(params)...)); - }) { - storage_.callable = &callable; - } - - /** - * Case 2: A plain function pointer. + * Case 1: A plain function pointer. * Instead of storing an opaque pointer to underlying callable object, * store a function pointer directly. * Note that in the future a variant which coerces compatible function * pointers could be implemented by erasing the storage type. */ - /* implicit */ FunctionRef(Ret (*ptr)(Params...)) - : callback_([](const void* memory, Params... params) { - auto& storage = *static_cast(memory); - return storage.function(std::forward(params)...); - }) { + /* implicit */ FunctionRef(Ret (*ptr)(Params...)) { storage_.function = ptr; } /** - * Case 3: Implicit conversion from lambda to FunctionRef. + * Case 2: Implicit conversion from lambda to FunctionRef. * A common use pattern is like: * void foo(FunctionRef<...>) {...} * foo([](...){...}) @@ -144,11 +103,11 @@ class FunctionRef { : FunctionRef(static_cast(function)) {} Ret operator()(Params... params) const { - return callback_(&storage_, std::forward(params)...); + return storage_.function(std::forward(params)...); } explicit operator bool() const { - return callback_; + return storage_.function; } }; diff --git a/runtime/core/test/function_ref_test.cpp b/runtime/core/test/function_ref_test.cpp index 69442a8ba4e..1ec8faccf50 100644 --- a/runtime/core/test/function_ref_test.cpp +++ b/runtime/core/test/function_ref_test.cpp @@ -37,17 +37,7 @@ void one(int32_t& i) { } // namespace -TEST(FunctionRefTest, CapturingLambda) { - auto one = 1; - auto f = [&](int32_t& i) { i = one; }; - Item item(0, FunctionRef{f}); - EXPECT_EQ(item.get(), 1); - // ERROR: - // Item item1(0, f); - // Item item2(0, [&](int32_t& i) { i = 2; }); - // FunctionRef ref([&](int32_t&){}); -} - +// Only non-capturing lambdas can be used to initialize a function reference. TEST(FunctionRefTest, NonCapturingLambda) { int32_t val = 0; FunctionRef ref([](int32_t& i) { i = 1; }); diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index 7f9ef5b6244..55cb4164715 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -42,7 +42,9 @@ namespace executor { class KernelRuntimeContext; // Forward declaration using RuntimeContext = KernelRuntimeContext; // TODO(T147221312): Remove -using OpFunction = FunctionRef; +using OpFunction = + FunctionRef; // TODO(T165139545): + // Remove FunctionRef /** * Dtype and dim order metadata for a Tensor argument to an operator.