diff --git a/extension/pytree/aten_util/targets.bzl b/extension/pytree/aten_util/targets.bzl index d09ffe48cbd..e1793080203 100644 --- a/extension/pytree/aten_util/targets.bzl +++ b/extension/pytree/aten_util/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): ], exported_deps = [ "//executorch/extension/pytree:pytree", + "//executorch/runtime/platform:platform", ], compiler_flags = ["-Wno-missing-prototypes"], fbcode_deps = [ diff --git a/extension/pytree/function_ref.h b/extension/pytree/function_ref.h new file mode 100644 index 00000000000..01d2988597a --- /dev/null +++ b/extension/pytree/function_ref.h @@ -0,0 +1,158 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//===- llvm/ADT/STLFunctionalExtras.h - Extras for -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains some extension to . +// +// No library is required when using these functions. +// +//===----------------------------------------------------------------------===// +// Extra additions to +//===----------------------------------------------------------------------===// + +/// An efficient, type-erasing, non-owning reference to a callable. This is +/// intended for use as the type of a function parameter that is not used +/// after the function in question returns. +/// +/// This class does not own the callable, so it is not in general safe to store +/// a FunctionRef. + +// torch::executor: modified from llvm::function_ref +// see https://www.foonathan.net/2017/01/function-ref-implementation/ + +#pragma once + +#include +#include +#include + +namespace torch { +namespace executor { +namespace pytree { + +//===----------------------------------------------------------------------===// +// Features from C++20 +//===----------------------------------------------------------------------===// + +template +struct remove_cvref { + using type = + typename std::remove_cv::type>::type; +}; + +template +using remove_cvref_t = typename remove_cvref::type; + +template +class FunctionRef; + +template +class FunctionRef { + Ret (*callback_)(const void* memory, Params... params) = nullptr; + union Storage { + void* callable; + Ret (*function)(Params...); + } storage_; + + public: + FunctionRef() = default; + explicit FunctionRef(std::nullptr_t) {} + + /** + * Case 1: A callable object passed by lvalue reference. + * Taking rvalue reference is error prone because the object will be always + * be destroyed immediately. + */ + template < + typename Callable, + // This is not the copy-constructor. + typename std::enable_if< + !std::is_same, FunctionRef>::value, + int32_t>::type = 0, + // Avoid lvalue reference to non-capturing lambda. + typename std::enable_if< + !std::is_convertible::value, + int32_t>::type = 0, + // Functor must be callable and return a suitable type. + // To make this container type safe, we need to ensure either: + // 1. The return type is void. + // 2. Or the resulting type from calling the callable is convertible to + // the declared return type. + typename std::enable_if< + std::is_void::value || + std::is_convertible< + decltype(std::declval()(std::declval()...)), + Ret>::value, + int32_t>::type = 0> + explicit FunctionRef(Callable& callable) + : callback_([](const void* memory, Params... params) { + auto& storage = *static_cast(memory); + auto& callable = *static_cast(storage.callable); + return static_cast(callable(std::forward(params)...)); + }) { + storage_.callable = &callable; + } + + /** + * Case 2: A plain function pointer. + * Instead of storing an opaque pointer to underlying callable object, + * store a function pointer directly. + * Note that in the future a variant which coerces compatible function + * pointers could be implemented by erasing the storage type. + */ + /* implicit */ FunctionRef(Ret (*ptr)(Params...)) + : callback_([](const void* memory, Params... params) { + auto& storage = *static_cast(memory); + return storage.function(std::forward(params)...); + }) { + storage_.function = ptr; + } + + /** + * Case 3: Implicit conversion from lambda to FunctionRef. + * A common use pattern is like: + * void foo(FunctionRef<...>) {...} + * foo([](...){...}) + * Here constructors for non const lvalue reference or function pointer + * would not work because they do not cover implicit conversion from rvalue + * lambda. + * We need to define a constructor for capturing temporary callables and + * always try to convert the lambda to a function pointer behind the scene. + */ + template < + typename Function, + // This is not the copy-constructor. + typename std::enable_if< + !std::is_same::value, + int32_t>::type = 0, + // Function is convertible to pointer of (Params...) -> Ret. + typename std::enable_if< + std::is_convertible::value, + int32_t>::type = 0> + /* implicit */ FunctionRef(const Function& function) + : FunctionRef(static_cast(function)) {} + + Ret operator()(Params... params) const { + return callback_(&storage_, std::forward(params)...); + } + + explicit operator bool() const { + return callback_; + } +}; + +} // namespace pytree +} // namespace executor +} // namespace torch diff --git a/extension/pytree/pytree.h b/extension/pytree/pytree.h index 19dedb25814..4d6116fb4b5 100644 --- a/extension/pytree/pytree.h +++ b/extension/pytree/pytree.h @@ -16,7 +16,8 @@ #include #include -#include +// NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. +#include namespace torch { namespace executor { diff --git a/extension/pytree/targets.bzl b/extension/pytree/targets.bzl index 5bb8bae1f52..02779c26106 100644 --- a/extension/pytree/targets.bzl +++ b/extension/pytree/targets.bzl @@ -10,13 +10,9 @@ def define_common_targets(): runtime.cxx_library( name = "pytree", srcs = [], - exported_headers = ["pytree.h"], + exported_headers = ["pytree.h", "function_ref.h"], visibility = [ "//executorch/...", "@EXECUTORCH_CLIENTS", ], - exported_deps = [ - "//executorch/runtime/platform:platform", - "//executorch/runtime/core:core", - ], ) diff --git a/extension/pytree/test/TARGETS b/extension/pytree/test/TARGETS index 765f297b381..d1a4eb54590 100644 --- a/extension/pytree/test/TARGETS +++ b/extension/pytree/test/TARGETS @@ -11,6 +11,13 @@ cpp_unittest( deps = ["//executorch/extension/pytree:pytree"], ) +cpp_unittest( + name = "function_ref_test", + srcs = ["function_ref_test.cpp"], + supports_static_listing = True, + deps = ["//executorch/extension/pytree:pytree"], +) + python_unittest( name = "test", srcs = [ diff --git a/extension/pytree/test/function_ref_test.cpp b/extension/pytree/test/function_ref_test.cpp new file mode 100644 index 00000000000..f847c8ebd78 --- /dev/null +++ b/extension/pytree/test/function_ref_test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +using namespace ::testing; + +namespace torch { +namespace executor { +namespace pytree { + +namespace { +class Item { + private: + int32_t val_; + FunctionRef ref_; + + public: + /* implicit */ Item(int32_t val, FunctionRef ref) + : val_(val), ref_(ref) {} + + int32_t get() { + ref_(val_); + return val_; + } +}; + +void one(int32_t& i) { + i = 1; +} + +} // namespace + +TEST(FunctionRefTest, CapturingLambda) { + auto one = 1; + auto f = [&](int32_t& i) { i = one; }; + Item item(0, FunctionRef{f}); + EXPECT_EQ(item.get(), 1); + // ERROR: + // Item item1(0, f); + // Item item2(0, [&](int32_t& i) { i = 2; }); + // FunctionRef ref([&](int32_t&){}); +} + +TEST(FunctionRefTest, NonCapturingLambda) { + int32_t val = 0; + FunctionRef ref([](int32_t& i) { i = 1; }); + ref(val); + EXPECT_EQ(val, 1); + + val = 0; + auto lambda = [](int32_t& i) { i = 1; }; + FunctionRef ref1(lambda); + ref1(val); + EXPECT_EQ(val, 1); + + Item item(0, [](int32_t& i) { i = 1; }); + EXPECT_EQ(item.get(), 1); + + auto f = [](int32_t& i) { i = 1; }; + Item item1(0, f); + EXPECT_EQ(item1.get(), 1); + + Item item2(0, std::move(f)); + EXPECT_EQ(item2.get(), 1); +} + +TEST(FunctionRefTest, FunctionPointer) { + int32_t val = 0; + FunctionRef ref(one); + ref(val); + EXPECT_EQ(val, 1); + + Item item(0, one); + EXPECT_EQ(item.get(), 1); + + Item item1(0, &one); + EXPECT_EQ(item1.get(), 1); +} + +} // namespace pytree +} // namespace executor +} // namespace torch