diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt index ccee7739dc6..ddd1f4a6b16 100644 --- a/.ci/docker/ci_commit_pins/pytorch.txt +++ b/.ci/docker/ci_commit_pins/pytorch.txt @@ -1 +1 @@ -b1984237a0fb32b760c1b84d6d02d2f0f7ed293b +48b6c8dbc376db4406a979b35cd6909bcb428931 diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index fec0e057ba1..cbca22cfa33 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -23,7 +23,8 @@ install_ubuntu() { unzip \ gdb \ rsync \ - libssl-dev + libssl-dev \ + zip # Cleanup package manager apt-get autoclean && apt-get clean diff --git a/.github/workflows/android.yml b/.github/workflows/android.yml index 324440dd745..051c9b22a78 100644 --- a/.github/workflows/android.yml +++ b/.github/workflows/android.yml @@ -61,6 +61,9 @@ jobs: cp cmake-out-android-arm64-v8a/extension/android/*.so artifacts-to-be-uploaded/arm64-v8a/ cp cmake-out-android-x86_64/lib/*.a artifacts-to-be-uploaded/x86_64/ cp cmake-out-android-x86_64/extension/android/*.so artifacts-to-be-uploaded/x86_64/ + # Copy AAR to S3 + cp build_aar/executorch.aar artifacts-to-be-uploaded/ + cp build_aar/executorch-llama.aar artifacts-to-be-uploaded/ # Upload the app and its test suite to S3 so that they can be downloaded by the test job upload-artifacts: diff --git a/.gitmodules b/.gitmodules index 40af2980839..e22780c8e84 100644 --- a/.gitmodules +++ b/.gitmodules @@ -68,3 +68,6 @@ [submodule "third-party/ios-cmake"] path = third-party/ios-cmake url = https://github.com/leetal/ios-cmake +[submodule "backends/cadence/hifi/third-party/nnlib/nnlib-hifi4"] + path = backends/cadence/hifi/third-party/nnlib/nnlib-hifi4 + url = https://github.com/foss-xtensa/nnlib-hifi4.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 3b4c31a131c..42568129c02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -417,7 +417,7 @@ target_link_libraries(executorch_no_prim_ops PRIVATE program_schema) 
# Check if dl exists for this toolchain and only then link it. find_library(DL_LIBRARY_EXISTS NAMES dl) # Check if the library was found -if(DL_LIBRARY_EXISTS) +if(DL_LIBRARY_EXISTS AND NOT EXECUTORCH_BUILD_CADENCE) target_link_libraries(executorch_no_prim_ops PRIVATE dl) # For dladdr() endif() target_include_directories( @@ -443,7 +443,9 @@ target_link_options_shared_lib(executorch) # Real integrations should supply their own YAML file that only lists the # operators necessary for the models that will run. # +if(NOT EXECUTORCH_BUILD_CADENCE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/portable) +endif() if(EXECUTORCH_BUILD_KERNELS_CUSTOM) # TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom @@ -496,6 +498,8 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER) if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) list(APPEND _executor_runner_libs optimized_native_cpu_ops_lib) + elseif(EXECUTORCH_BUILD_CADENCE) + list(APPEND _executor_runner_libs cadence_ops_lib) else() list(APPEND _executor_runner_libs portable_ops_lib) endif() @@ -566,6 +570,10 @@ if(EXECUTORCH_BUILD_COREML) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/apple/coreml) endif() +if(EXECUTORCH_BUILD_CADENCE) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cadence) +endif() + if(EXECUTORCH_BUILD_PYBIND) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/pybind11) diff --git a/backends/apple/coreml/CMakeLists.txt b/backends/apple/coreml/CMakeLists.txt index 0f56e3bdf7e..c9a15ed935e 100644 --- a/backends/apple/coreml/CMakeLists.txt +++ b/backends/apple/coreml/CMakeLists.txt @@ -155,6 +155,8 @@ target_link_libraries( ${FOUNDATION_FRAMEWORK} ${SQLITE_LIBRARY} ) +target_link_options_shared_lib(coremldelegate) + if(COREML_BUILD_EXECUTOR_RUNNER) target_link_libraries( coremldelegate PRIVATE portable_ops_lib portable_kernels diff --git a/backends/apple/coreml/README.md b/backends/apple/coreml/README.md index 4a21d8d8ae1..05b56e9c788 100644 --- a/backends/apple/coreml/README.md +++ 
b/backends/apple/coreml/README.md @@ -28,7 +28,7 @@ import torch import executorch.exir from executorch.backends.apple.coreml.compiler import CoreMLBackend -from executorch.backends.apple.coreml.partition.coreml_partitioner import CoreMLPartitioner +from executorch.backends.apple.coreml.partition import CoreMLPartitioner class Model(torch.nn.Module): def __init__(self): @@ -72,7 +72,7 @@ from torch.ao.quantization.quantize_pt2e import ( prepare_qat_pt2e, ) -from executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer +from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from coremltools.optimize.torch.quantization.quantization_config import ( LinearQuantizerConfig, QuantizationScheme, diff --git a/backends/apple/coreml/partition/__init__.py b/backends/apple/coreml/partition/__init__.py new file mode 100644 index 00000000000..1630e9ece45 --- /dev/null +++ b/backends/apple/coreml/partition/__init__.py @@ -0,0 +1,9 @@ +# Copyright © 2023 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. + +from .coreml_partitioner import CoreMLPartitioner + +__all__ = [ + "CoreMLPartitioner", +] diff --git a/backends/apple/coreml/quantizer/__init__.py b/backends/apple/coreml/quantizer/__init__.py new file mode 100644 index 00000000000..f6282834fa1 --- /dev/null +++ b/backends/apple/coreml/quantizer/__init__.py @@ -0,0 +1,9 @@ +# Copyright © 2023 Apple Inc. All rights reserved. +# +# Please refer to the license found in the LICENSE file in the root directory of the source tree. 
+ +from .coreml_quantizer import CoreMLQuantizer + +__all__ = [ + "CoreMLQuantizer", +] diff --git a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.cpp b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.cpp index bddb7e4d410..f699316cfdb 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.cpp +++ b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #if __has_include() @@ -22,7 +21,8 @@ namespace filesystem = std::experimental::filesystem; } #endif -#include +#include "range.hpp" +#include "reversed_memory_stream.hpp" namespace { using namespace inmemoryfs; diff --git a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.hpp b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.hpp index fedf4190334..d0ace1a5250 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.hpp +++ b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem.hpp @@ -8,14 +8,15 @@ #pragma once #include -#include #include -#include #include #include #include #include +#include "inmemory_filesystem_metadata.hpp" +#include "memory_buffer.hpp" + namespace inmemoryfs { /// A class representing an in-memory file system. @@ -29,36 +30,36 @@ class InMemoryFileSystem final { DirectoryExpected, // If path is not a directory. FileExpected, // If the path is not a file. }; - + /// Options for loading file content. enum class FileLoadOption: int8_t { Malloc = 1, // Copy file contents into memory. MMap, // Memory map file contents. LazyMMap // Memory map file contents but lazily. }; - + /// The error category for `InMemoryFileSystem`. 
struct ErrorCategory final: public std::error_category { public: inline const char* name() const noexcept override { return "InMemoryFileSystem"; } - + std::string message(int code) const override; }; - + struct Attributes { time_t modificationTime; - + inline Attributes() noexcept: modificationTime(time(0)) {} }; - + using MetadataWriter = std::function; using MetadataWriterInMemory = std::function; using MetadataReader = std::function(std::istream&)>; - + /// A class representing an in-memory node. This could either be a file node or a directory node. class InMemoryNode { public: @@ -67,7 +68,7 @@ class InMemoryFileSystem final { File = 0, /// Node is a File. Directory /// Node is a Directory. }; - + /// Constructs an in-memory node instance. /// /// @param name The name of the Node. It must be unique in the enclosing Directory. @@ -78,38 +79,38 @@ class InMemoryFileSystem final { attributes_(std::move(attributes)), kind_(kind) {} - + InMemoryNode(InMemoryNode const&) = delete; InMemoryNode& operator=(InMemoryNode const&) = delete; - + inline virtual ~InMemoryNode() {} - + /// Returns the node attributes. inline Attributes attributes() const noexcept { return attributes_; } - + /// Sets the node attributes. /// /// @param attributes The node attributes. inline void set_attributes(Attributes attributes) noexcept { attributes_ = std::move(attributes); } - + /// Returns the node kind, possible values are `File` and `Directory`. inline Kind kind() const noexcept { return kind_; } - + /// Returns the name of the node. inline const std::string& name() const noexcept { return name_; } - + inline void set_name(std::string name) noexcept { std::swap(name_, name); } - + /// Returns `true` if the node is a directory otherwise `false`. inline bool isDirectory() const noexcept { switch (kind_) { @@ -119,58 +120,58 @@ class InMemoryFileSystem final { return false; } } - + /// Returns `true` if the node is a file otherwise `false`. 
inline bool isFile() const noexcept { return !isDirectory(); } - + private: std::string name_; InMemoryFileSystem::Attributes attributes_; const Kind kind_; }; - + /// Constructs an`InMemoryFileSystem` instance with an empty root and the specified name. /// /// @param rootName The name of the root node. explicit InMemoryFileSystem(std::string rootName = "root") noexcept; - + /// Constructs an`InMemoryFileSystem` instance with the specified root. /// /// @param root The root node. explicit InMemoryFileSystem(std::unique_ptr root) noexcept :root_(std::move(root)) {} - + InMemoryFileSystem(InMemoryFileSystem const&) = delete; InMemoryFileSystem& operator=(InMemoryFileSystem const&) = delete; - + virtual ~InMemoryFileSystem() {} - + /// Returns the root. InMemoryNode *root() const noexcept { return root_.get(); } - + /// Checks if the node at the specified path is a directory. /// /// @param canonical_path The path components from the root. /// @retval `true` if the node at the specified path is a directory otherwise `false`. bool is_directory(const std::vector& canonical_path) noexcept; - + /// Checks if the node at the specified path is a file. /// /// @param canonical_path The path components from the root. /// @retval `true` if the node at the specified path is a file otherwise `false`. bool is_file(const std::vector& canonical_path) noexcept; - + /// Checks if the node at the specified path exists. /// /// @param canonical_path The path components from the root. /// @retval `true` if the node at the specified path exists. bool exists(const std::vector& canonical_path) const noexcept; - + /// Retrieves the canonical path of all the child nodes at the specified path. The node /// at the specified path must be a directory otherwise it returns an empty vector with the `error` /// populated. @@ -180,7 +181,7 @@ class InMemoryFileSystem final { /// @retval paths to all the items at the specified path. 
std::vector> get_item_paths(const std::vector& canonical_path, std::error_code& error) const noexcept; - + /// Retrieves the attributes of the item at the specified path. /// /// @param canonical_path The path components from the root. @@ -188,7 +189,7 @@ class InMemoryFileSystem final { /// @retval The item attributes at the specified path. std::optional get_attributes(const std::vector& canonical_path, std::error_code& error) const noexcept; - + /// Retrieves the contents of the file at the specified path. /// /// @param canonical_path The path components from the root. @@ -196,7 +197,7 @@ class InMemoryFileSystem final { /// @retval The file contents or `nullptr` if the item at the specified path is not a file. std::shared_ptr get_file_content(const std::vector& canonical_path, std::error_code& error) const noexcept; - + /// Creates an in-memory directory at the specified path. /// /// @param canonical_path The path components from the root. @@ -208,7 +209,7 @@ class InMemoryFileSystem final { Attributes attributes, bool create_intermediate_directories, std::error_code& error) noexcept; - + /// Creates an in-memory file at the specified path. /// /// @param canonical_path The path components from the root. @@ -222,7 +223,7 @@ class InMemoryFileSystem final { Attributes attributes, bool overwrite, std::error_code& error) noexcept; - + /// Removes the item at the specified path. /// /// @param canonical_path The path components from the root. @@ -230,7 +231,7 @@ class InMemoryFileSystem final { /// @retval `true` if the item is removed otherwise `false`. bool remove_item(const std::vector& canonical_path, std::error_code& error) noexcept; - + /// Sets the attributes at the specified path. /// /// @param canonical_path The path components from the root. 
@@ -239,7 +240,7 @@ class InMemoryFileSystem final { bool set_attributes(const std::vector& canonical_path, Attributes attributes, std::error_code& error) noexcept; - + /// Writes the item at the specified path to the filesystem. /// /// @param canonical_path The path components from the root. @@ -251,7 +252,7 @@ class InMemoryFileSystem final { const std::string& dst_path, bool recursive, std::error_code& error) const noexcept; - + /// Renames the item at the specified path, if there is already an item with the same name then /// the rename would fail. /// @@ -262,7 +263,7 @@ class InMemoryFileSystem final { bool rename_item(const std::vector& canonical_path, const std::string& name, std::error_code& error) noexcept; - + /// Creates an`InMemoryFileSystem` from the filesystem path. /// /// The structure of the `InMemoryFileSystem` is identical to the structure of the filesystem at the @@ -275,7 +276,7 @@ class InMemoryFileSystem final { static std::unique_ptr make_from_directory(const std::string& path, FileLoadOption option, std::error_code& error) noexcept; - + /// Serializes the item at the specified path and writes it to the stream. /// /// The structure of the `InMemoryFileSystem` is identical to the structure of the filesystem at the @@ -292,7 +293,7 @@ class InMemoryFileSystem final { const MetadataWriter& metadata_writer, std::ostream& ostream, std::error_code& error) const noexcept; - + /// Serializes the item at the specified path and writes it to the stream. /// /// The structure of the `InMemoryFileSystem` is identical to the structure of the filesystem at the @@ -309,7 +310,7 @@ class InMemoryFileSystem final { const MetadataWriterInMemory& metadata_writer, void *dst, std::error_code& error) const noexcept; - + /// Computes the size of the buffer that would be needed to serialized the item at the specified path. /// /// @param canonical_path The path components from the root. 
@@ -319,7 +320,7 @@ class InMemoryFileSystem final { size_t get_buffer_size_for_serialization(const std::vector& canonical_path, size_t alignment, const MetadataWriter& metadata_writer) const noexcept; - + /// Constructs an `InMemoryFileSystem` instance from the buffer contents. /// /// @param buffer The memory buffer. @@ -327,7 +328,7 @@ class InMemoryFileSystem final { /// @retval The constructed `InMemoryFileSystem` or `nullptr` if the deserialization failed. static std::unique_ptr make_from_buffer(const std::shared_ptr& buffer, const MetadataReader& metadata_reader) noexcept; - + private: const std::unique_ptr root_; }; diff --git a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_metadata.hpp b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_metadata.hpp index d9a807a7fc7..4f183205b05 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_metadata.hpp +++ b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_metadata.hpp @@ -7,11 +7,12 @@ #pragma once -#include #include #include #include -#include + +#include "memory_buffer.hpp" +#include "range.hpp" namespace inmemoryfs { @@ -27,4 +28,3 @@ struct InMemoryFileSystemMetadata { }; } // namespace inmemoryfs - diff --git a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_py.cpp b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_py.cpp index 90bc0eb3e1b..66ffa697654 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_py.cpp +++ b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_py.cpp @@ -6,13 +6,9 @@ // Please refer to the license found in the LICENSE file in the root directory of the source tree. 
-#include #include #include -#include -#include -#include -#include +#include #include #include #include @@ -21,6 +17,13 @@ #include #include +#include +#include + +#include "inmemory_filesystem_utils.hpp" +#include "memory_buffer.hpp" +#include "memory_stream.hpp" + #if __has_include() #include #elif __has_include() diff --git a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.cpp b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.cpp index 1dffacf15a5..a7810e23db3 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.cpp +++ b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.cpp @@ -5,14 +5,17 @@ // // Please refer to the license found in the LICENSE file in the root directory of the source tree. -#include -#include -#include +#include "inmemory_filesystem_utils.hpp" + #include -#include -#include #include +#include + +#include "inmemory_filesystem_metadata.hpp" +#include "inmemory_filesystem_metadata_keys.hpp" +#include "json_util.hpp" + namespace inmemoryfs { using json = nlohmann::json; diff --git a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.mm b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.mm index 1f018e3c74a..309b95e8d85 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.mm +++ b/backends/apple/coreml/runtime/inmemoryfs/inmemory_filesystem_utils.mm @@ -6,15 +6,18 @@ // Please refer to the license found in the LICENSE file in the root directory of the source tree. 
#import "inmemory_filesystem_utils.hpp" -#import -#import -#import + #import -#import -#import #import #import +#import + +#import "inmemory_filesystem_metadata.hpp" +#import "inmemory_filesystem_metadata_keys.hpp" +#import "json_util.hpp" +#import "objc_json_serde.h" + namespace executorchcoreml { namespace serde { namespace json { @@ -29,13 +32,13 @@ static id to_json(const Range& range) { to_string(RangeKeys::kSize) : to_json_value(range.size) }; } - + static void from_json(id json, Range& range) { NSDictionary *json_dict = SAFE_CAST(json, NSDictionary); if (!json_dict) { return; } - + from_json_value(json_dict[to_string(RangeKeys::kOffset)], range.offset); from_json_value(json_dict[to_string(RangeKeys::kSize)], range.size); } @@ -51,13 +54,13 @@ static id to_json(const InMemoryNodeMetadata& node) { to_string(InMemoryNodeMetadataKeys::kKind) : to_json_value(node.kind) }; } - + static void from_json(id json, InMemoryNodeMetadata& node) { NSDictionary *json_dict = SAFE_CAST(json, NSDictionary); if (!json_dict) { return; } - + from_json_value(json_dict[to_string(InMemoryNodeMetadataKeys::kName)], node.name); from_json_value(json_dict[to_string(InMemoryNodeMetadataKeys::kDataRegion)], node.data_region); from_json_value(json_dict[to_string(InMemoryNodeMetadataKeys::kChildIndices)], node.child_name_to_indices_map); @@ -72,13 +75,13 @@ static id to_json(const InMemoryFileSystemMetadata& fs) { to_string(InMemoryFileSystemMetadataKeys::kNodes) : to_json_value(fs.nodes) }; } - + static void from_json(id json, InMemoryFileSystemMetadata& fs) { NSDictionary *json_dict = SAFE_CAST(json, NSDictionary); if (!json_dict) { return; } - + from_json_value(json_dict[to_string(InMemoryFileSystemMetadataKeys::kNodes)], fs.nodes); } }; @@ -114,7 +117,7 @@ size_t write_metadata_to_buffer(const InMemoryFileSystemMetadata& metadata, void if (!json_object) { return std::optional(); } - + InMemoryFileSystemMetadata metadata; Converter::from_json(to_json_object(json_object.value()), 
metadata); return metadata; @@ -132,7 +135,7 @@ bool serialize(const InMemoryFileSystem& file_system, write_metadata_to_stream(fs_metadata, stream); return true; }; - + return file_system.serialize(canonical_path, alignment, metadata_writer, ostream, ec); } @@ -145,7 +148,7 @@ bool serialize(const InMemoryFileSystem& file_system, void *metadata_dst) { return ::write_metadata_to_buffer(fs_metadata, metadata_dst); }; - + return file_system.serialize(canonical_path, alignment, metadata_writer, dst, ec); } @@ -156,7 +159,7 @@ size_t get_buffer_size_for_serialization(const InMemoryFileSystem& file_system, std::ostream& stream) { return ::write_metadata_to_stream(fs_metadata, stream); }; - + return file_system.get_buffer_size_for_serialization(canonical_path, alignment, metadata_writer); } @@ -164,7 +167,7 @@ size_t get_buffer_size_for_serialization(const InMemoryFileSystem& file_system, InMemoryFileSystem::MetadataReader metadata_reader = [](std::istream& stream) { return ::read_metadata_from_stream(stream); }; - + return InMemoryFileSystem::make_from_buffer(buffer, metadata_reader); } } // namespace inmemoryfs diff --git a/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.cpp b/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.cpp index 61b50a54655..c4485569d56 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.cpp +++ b/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.cpp @@ -5,9 +5,10 @@ // // Please refer to the license found in the LICENSE file in the root directory of the source tree. 
-#include +#include "memory_buffer.hpp" #include +#include #include #include #include diff --git a/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.hpp b/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.hpp index e6e33f5ce26..5243401e2df 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.hpp +++ b/backends/apple/coreml/runtime/inmemoryfs/memory_buffer.hpp @@ -8,12 +8,13 @@ #pragma once #include -#include #include #include #include #include +#include "range.hpp" + namespace inmemoryfs { /// A class representing a memory buffer. class MemoryBuffer: public std::enable_shared_from_this { @@ -23,38 +24,38 @@ class MemoryBuffer: public std::enable_shared_from_this { MMap = 0, // If the buffer is memory mapped. Malloc , // If the buffer is heap allocated. }; - + enum class ReadOption: uint8_t { Malloc = 0, MMap, LazyMMap }; - + inline MemoryBuffer(void *data, size_t size, Kind kind = Kind::Malloc, std::shared_ptr parent = nullptr) noexcept: - data_(data), + data_(data), size_(size), kind_(kind), parent_(parent) {} - + MemoryBuffer(const MemoryBuffer &) = delete; MemoryBuffer &operator=(const MemoryBuffer &) = delete; - + virtual ~MemoryBuffer() noexcept {} - + /// Returns the underlying data. virtual inline void *data() noexcept { return data_; } - + /// Returns the size of the buffer. inline const size_t size() const noexcept { return size_; } - + /// Loads the contents of the buffer. /// /// - For a malloced buffer, the method is a no op, content is loaded at the initialization time. @@ -65,12 +66,12 @@ class MemoryBuffer: public std::enable_shared_from_this { inline virtual bool load(std::error_code& error) noexcept { return true; } - + /// Returns the kind of the buffer. inline const Kind kind() const noexcept { return kind_; } - + /// Returns the offset range that would be used when writing the buffer content. /// /// @param proposed_offset The proposed offset. 
@@ -78,7 +79,7 @@ class MemoryBuffer: public std::enable_shared_from_this { inline virtual std::pair get_offset_range(size_t proposed_offset) const noexcept { return {proposed_offset, proposed_offset}; } - + /// Returns the revised range that must be used for writing. /// /// @param dst The destination pointer. @@ -87,7 +88,7 @@ class MemoryBuffer: public std::enable_shared_from_this { inline virtual Range get_revised_range_for_writing(void *dst, Range proposed_range) const noexcept { return proposed_range; } - + /// Writes the contents of the buffer to the destination buffer at the given offset. /// /// @param dst The destination pointer. @@ -97,13 +98,13 @@ class MemoryBuffer: public std::enable_shared_from_this { virtual bool write(void *dst, size_t offset, std::error_code& error) noexcept; - + /// Slices a buffer. /// /// @param range The memory range. /// @retval The sliced buffer if the region is inside the buffer otherwise `nullptr`. virtual std::shared_ptr slice(Range range) noexcept; - + /// Reads the file content at the specified path. /// /// @param file_path The file path. @@ -116,7 +117,7 @@ class MemoryBuffer: public std::enable_shared_from_this { const std::vector& ranges, ReadOption option, std::error_code& error); - + /// Reads the whole file content at the specified path. /// /// @param file_path The file path. @@ -127,28 +128,28 @@ class MemoryBuffer: public std::enable_shared_from_this { read_file_content(const std::string& file_path, ReadOption option, std::error_code& error); - + /// Constructs a `MemoryBuffer`. /// /// @param size The size of the buffer. /// @param alignment The address alignment. static std::unique_ptr make_using_malloc(size_t size, size_t alignment = 1); - - + + /// Constructs a `MemoryBuffer` from memory allocated using `mmap`. /// /// @param size The size of the buffer. static std::unique_ptr make_using_mmap(size_t size); - + /// Constructs a `MemoryBuffer` without copying data. /// /// @param data The buffer content. 
/// @param size The size of the buffer. static std::unique_ptr make_unowned(void *data, size_t size); - + /// Constructs a `MemoryBuffer` with copying data. /// /// @param data The buffer content. diff --git a/backends/apple/coreml/runtime/inmemoryfs/memory_stream.cpp b/backends/apple/coreml/runtime/inmemoryfs/memory_stream.cpp index 5078db66e80..cb634234c5c 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/memory_stream.cpp +++ b/backends/apple/coreml/runtime/inmemoryfs/memory_stream.cpp @@ -7,6 +7,8 @@ #include "memory_stream.hpp" +#include + namespace inmemoryfs { MemoryStreamBuf::MemoryStreamBuf(const std::shared_ptr& buffer) noexcept : buffer_(buffer) { diff --git a/backends/apple/coreml/runtime/inmemoryfs/memory_stream.hpp b/backends/apple/coreml/runtime/inmemoryfs/memory_stream.hpp index a5f40f26b5f..f7a8100f74f 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/memory_stream.hpp +++ b/backends/apple/coreml/runtime/inmemoryfs/memory_stream.hpp @@ -7,18 +7,18 @@ #pragma once -#include - #include #include +#include "memory_buffer.hpp" + namespace inmemoryfs { /// A class representing an in-memory stream buffer. class MemoryStreamBuf: public std::streambuf { public: ~MemoryStreamBuf() = default; - + /// Constructs a `MemoryStreamBuf` from a `MemoryBuffer`. /// /// @param buffer The memory buffer. @@ -31,7 +31,7 @@ class MemoryStreamBuf: public std::streambuf { /// @param dir The seek direction. /// @retval The stream position. pos_type iseekoff(off_type offset, std::ios_base::seekdir dir); - + /// Called by `seekof` if the `openmode` is output. /// /// @param offset The offset value relative to the `dir`. @@ -44,7 +44,7 @@ class MemoryStreamBuf: public std::streambuf { /// @param which The open mode. /// @retval The stream position. pos_type seekpos(pos_type pos, std::ios_base::openmode which) override; - + /// Called by the public member function `pubseekoff` to alter the stream position. /// /// @param offset The offset value relative to the `dir`. 
@@ -74,18 +74,18 @@ class MemoryStreamBuf: public std::streambuf { /// /// Returns the value of the current character, converted to a value of type int. std::streambuf::int_type uflow() override; - + /// Called by other member functions to put a character into the controlled output sequence. /// /// Returns the value of the character that's put into the stream, converted to a value of type int. int_type overflow(int_type ch) override; - + /// Retrieves characters from the controlled input sequence and stores them in the array pointed by s, /// until either n characters have been extracted or the end of the sequence is reached. /// /// Returns the number of characters copied. std::streamsize xsgetn(char *s, std::streamsize n) override; - + /// Writes characters from the array pointed to by s into the controlled output sequence, /// until either n characters have been written or the end of the output sequence is reached. /// @@ -122,4 +122,3 @@ class MemoryOStream final : public std::ostream { }; } - diff --git a/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.cpp b/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.cpp index e38b9d08b19..7fe6c26ca41 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.cpp +++ b/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.cpp @@ -7,6 +7,8 @@ #include "reversed_memory_stream.hpp" +#include + namespace inmemoryfs { ReversedIMemoryStreamBuf::ReversedIMemoryStreamBuf(std::shared_ptr buffer) noexcept diff --git a/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.hpp b/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.hpp index 1827af36413..09b3606bfe0 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.hpp +++ b/backends/apple/coreml/runtime/inmemoryfs/reversed_memory_stream.hpp @@ -7,18 +7,18 @@ #pragma once -#include - #include #include +#include "memory_buffer.hpp" + namespace inmemoryfs { /// A class for 
reading an in-memory stream buffer in reverse. class ReversedIMemoryStreamBuf: public std::streambuf { public: ~ReversedIMemoryStreamBuf() = default; - + /// Constructs a `ReversedIMemoryStreamBuf` from a `MemoryBuffer`. /// /// @param buffer The memory buffer. @@ -50,7 +50,7 @@ class ReversedIMemoryStreamBuf: public std::streambuf { /// /// Returns the value of the current character, converted to a value of type int. std::streambuf::int_type uflow() override; - + /// Retrieves characters from the controlled input sequence and stores them in the array pointed by s, /// until either n characters have been extracted or the end of the sequence is reached. /// @@ -60,7 +60,7 @@ class ReversedIMemoryStreamBuf: public std::streambuf { private: /// Reads the character at the specified position. std::streambuf::int_type read(char *pos); - + const std::shared_ptr buffer_; char *start_; char *current_; @@ -70,7 +70,7 @@ class ReversedIMemoryStreamBuf: public std::streambuf { /// A class for reading an in-memory buffer in reverse. class ReversedIMemoryStream final : public std::istream { public: - + /// Constructs a `ReversedIMemoryStream` from a `MemoryBuffer`. /// /// @param buffer The memory buffer. 
@@ -83,4 +83,3 @@ class ReversedIMemoryStream final : public std::istream { }; } - diff --git a/backends/apple/coreml/runtime/inmemoryfs/setup.py b/backends/apple/coreml/runtime/inmemoryfs/setup.py index 95818485ca8..c93022ed341 100644 --- a/backends/apple/coreml/runtime/inmemoryfs/setup.py +++ b/backends/apple/coreml/runtime/inmemoryfs/setup.py @@ -30,7 +30,7 @@ cxx_std=cxx_std, extra_compile_args=["-mmacosx-version-min=10.15", "-g"], include_dirs=[ - "../../third-party/nlohmann_json/single_include/nlohmann", + "../../third-party/nlohmann_json/single_include", ".", "../util", ], diff --git a/backends/apple/coreml/runtime/test/DatabaseTests.mm b/backends/apple/coreml/runtime/test/DatabaseTests.mm index 9b89f20aa5a..1d66448852e 100644 --- a/backends/apple/coreml/runtime/test/DatabaseTests.mm +++ b/backends/apple/coreml/runtime/test/DatabaseTests.mm @@ -8,7 +8,7 @@ #import #import -#import +#import @interface DatabaseTests : XCTestCase @@ -58,7 +58,7 @@ - (void)testDatabaseQuery { XCTAssertTrue(insertStatement->bind_name("$value", std::string("1"), error)); XCTAssertTrue(insertStatement->execute(error)); XCTAssertTrue(database->get_row_count("TEST", error) == 1); - + auto query = database->prepare_statement("SELECT * FROM TEST", error); XCTAssertTrue(query != nullptr); XCTAssertTrue(query->step(error)); diff --git a/backends/apple/coreml/runtime/test/InMemoryFileSystemTests.mm b/backends/apple/coreml/runtime/test/InMemoryFileSystemTests.mm index a4ccbd94b68..226a42aaaaf 100644 --- a/backends/apple/coreml/runtime/test/InMemoryFileSystemTests.mm +++ b/backends/apple/coreml/runtime/test/InMemoryFileSystemTests.mm @@ -13,7 +13,7 @@ #import #import -#import +#import #import using json = nlohmann::json; @@ -25,11 +25,11 @@ inline Content(std::string identifier, std::string value) noexcept :identifier(std::move(identifier)), value(std::move(value)) {} - + inline Content() noexcept :identifier(""), value("") {} - + std::string identifier; std::string value; }; @@ -80,7 +80,7 
@@ T from_memory_buffer(const std::shared_ptr& buffer) { for (size_t i = 0; i < length; ++i) { result += chars[rand() % (sizeof(chars) - 1)]; } - + return result; } @@ -178,12 +178,12 @@ - (void)testWriteItemAtPath { Content content("abc", "xyz"); std::shared_ptr buffer = to_memory_buffer(content); std::error_code error; - + XCTAssertTrue(fs.make_directory({"dir1"}, InMemoryFileSystem::Attributes(), false, error)); XCTAssertTrue(fs.make_file({"dir1", "content.json"}, buffer, InMemoryFileSystem::Attributes(), false /*overwrite*/, error)); XCTAssertTrue(fs.make_directory({"dir1", "dir2"}, InMemoryFileSystem::Attributes(), false, error)); XCTAssertTrue(fs.make_file({"dir1", "dir2", "content.json"}, buffer, InMemoryFileSystem::Attributes(), false /*overwrite*/, error)); - + NSURL *dirURL = [[NSURL fileURLWithPath:NSTemporaryDirectory()] URLByAppendingPathComponent:[NSUUID UUID].UUIDString]; NSFileManager *fm = [[NSFileManager alloc] init]; NSError *localError = nil; @@ -220,7 +220,7 @@ - (void)testCreationFromFileSystem { NSData *data = [NSData dataWithBytesNoCopy:buffer->data() length:buffer->size() freeWhenDone:NO]; XCTAssertTrue([data writeToURL:[dirURL URLByAppendingPathComponent:@"dir1/content.json"] atomically:YES]); XCTAssertTrue([data writeToURL:[dirURL URLByAppendingPathComponent:@"dir2/content.json"] atomically:YES]); - + std::filesystem::path dirPath(dirURL.path.UTF8String); std::error_code error; auto fs = InMemoryFileSystem::make_from_directory(dirPath, @@ -256,7 +256,7 @@ - (void)_testSerdeWithConfig:(SerdeVerificationConfig)config { } XCTAssertTrue(fs.write_item_to_disk({}, dirURL.path.UTF8String, true, error)); } - + // Verify serialization. 
std::shared_ptr buffer = nullptr; { @@ -264,7 +264,7 @@ - (void)_testSerdeWithConfig:(SerdeVerificationConfig)config { auto fs = InMemoryFileSystem::make_from_directory(dirURL.path.UTF8String, config.file_load_option, error); - + XCTAssertTrue(fs != nullptr); size_t length = inmemoryfs::get_buffer_size_for_serialization(*fs, {}, config.alignment); switch (config.file_load_option) { @@ -272,15 +272,15 @@ - (void)_testSerdeWithConfig:(SerdeVerificationConfig)config { buffer = MemoryBuffer::make_using_mmap(length); break; } - + default: buffer = MemoryBuffer::make_using_malloc(length); break; } - + XCTAssertTrue(inmemoryfs::serialize(*fs, {}, config.alignment, buffer->data(), error)); } - + // Verify de-serialization. { auto fs = inmemoryfs::make_from_buffer(buffer); @@ -290,7 +290,7 @@ - (void)_testSerdeWithConfig:(SerdeVerificationConfig)config { XCTAssertEqual(from_memory_buffer(fs->get_file_content({"test", "dir", content.identifier}, error)), content); } } - + [fm removeItemAtURL:dirURL error:nil]; } @@ -332,7 +332,7 @@ - (void)testSerde { .file_base_length = 100, .alignment = 2 * (size_t)getpagesize(), }); - + for (const auto& config : configs) { [self _testSerdeWithConfig:config]; } @@ -349,7 +349,7 @@ - (void)testReadJSONObject { auto j = json::parse(object.value().begin(), object.value().end()); XCTAssertEqual(j["x"], 1, "The value must match"); } - + { std::stringstream ss; std::string fragment("{\"x\" : 1"); @@ -357,8 +357,8 @@ - (void)testReadJSONObject { auto object = executorchcoreml::json::read_object_from_stream(ss); XCTAssertFalse(object.has_value(), "There is no closing brace, `read_json_object` must return nullopt"); } - - + + { std::stringstream ss; std::string fragment("{\"x\" : \"\\\"1\"}xyz"); @@ -369,7 +369,7 @@ - (void)testReadJSONObject { std::string value = j["x"]; XCTAssertEqual(value, std::string("\"1"), "The value must match"); } - + { std::stringstream ss; std::string fragment("{sdhalskjks}"); @@ -384,7 +384,7 @@ - 
(void)testReadJSONObject { } XCTAssertNotEqual(eptr, nullptr, "Parsing invalid json object must throw an exception"); } - + } @end diff --git a/backends/apple/coreml/runtime/test/KeyValueStoreTests.mm b/backends/apple/coreml/runtime/test/KeyValueStoreTests.mm index 81a667bc375..4d113efa43a 100644 --- a/backends/apple/coreml/runtime/test/KeyValueStoreTests.mm +++ b/backends/apple/coreml/runtime/test/KeyValueStoreTests.mm @@ -9,7 +9,7 @@ #import #import -#import +#import namespace { using json = nlohmann::json; @@ -24,11 +24,11 @@ inline Entry(std::string identifier, size_t count) noexcept :identifier(std::move(identifier)), count(count) {} - + inline Entry() noexcept :identifier(""), count(0) {} - + inline std::string to_json_string() const noexcept { json j; to_json(j, *this); @@ -36,12 +36,12 @@ inline Entry() noexcept ss << j; return ss.str(); } - + inline void from_json_string(const std::string& json_string) noexcept { auto j = json::parse(json_string); from_json(j, *this); } - + std::string identifier; size_t count; }; @@ -110,12 +110,12 @@ - (void)testJSONKeyValueStore { std::error_code error; auto database = Database::make_inmemory(Database::SynchronousMode::Normal, 100, error); auto store = JSONKeyValueStore::make(std::move(database), "test", error); - + XCTAssertTrue(store->put(1, Entry("1", 1), error)); auto entry1 = store->get(1, error); XCTAssertTrue(entry1.value().count == 1); XCTAssertTrue(entry1.value().identifier == "1"); - + XCTAssertTrue(store->put(2, Entry("2", 2), error)); auto entry2 = store->get(2, error); XCTAssertTrue(entry2.value().count == 2); @@ -134,7 +134,7 @@ - (void)testKVStoreTransactionCommit { // Commit the transaction. return true; }, Database::TransactionBehavior::Immediate, error)); - + XCTAssertTrue(store->size(error) == 2); } @@ -150,7 +150,7 @@ - (void)testKVStoreTransactionRollback { // Rollback the transaction. 
return false; }, Database::TransactionBehavior::Immediate, error)); - + XCTAssertTrue(store->size(error) == 0); } @@ -173,7 +173,7 @@ - (void)testKVStoreGetKeysSortedByAccessTime { // 1 is accessed first then 2 and then 3 XCTAssertTrue(keys == (std::vector{1, 2, 3})); } - + { std::vector keys; XCTAssertTrue(store->get_keys_sorted_by_access_time([&keys](int key) { @@ -210,7 +210,7 @@ - (void)testKVStoreGetKeysSortedByAccessCount { // 3 is accessed 1 time, 2 is accessed 2 times, and 1 is accessed 3 times. XCTAssertTrue(keys == (std::vector{3, 2, 1})); } - + { std::vector keys; XCTAssertTrue(store->get_keys_sorted_by_access_count([&keys](int key) { diff --git a/backends/apple/coreml/runtime/util/json_util.cpp b/backends/apple/coreml/runtime/util/json_util.cpp index 80605c55e8f..a7592541a49 100644 --- a/backends/apple/coreml/runtime/util/json_util.cpp +++ b/backends/apple/coreml/runtime/util/json_util.cpp @@ -6,7 +6,7 @@ // // Please refer to the license found in the LICENSE file in the root directory of the source tree. -#include +#include "json_util.hpp" #include #include diff --git a/backends/apple/coreml/runtime/util/objc_json_serde.mm b/backends/apple/coreml/runtime/util/objc_json_serde.mm index 0f55d4b5919..9102046a759 100644 --- a/backends/apple/coreml/runtime/util/objc_json_serde.mm +++ b/backends/apple/coreml/runtime/util/objc_json_serde.mm @@ -7,7 +7,7 @@ // Please refer to the license found in the LICENSE file in the root directory of the source tree. 
-#import +#import "objc_json_serde.h" namespace executorchcoreml { namespace serde { diff --git a/backends/apple/coreml/runtime/workspace/executorchcoreml.xcodeproj/project.pbxproj b/backends/apple/coreml/runtime/workspace/executorchcoreml.xcodeproj/project.pbxproj index d8ee4ea693a..d8a5e611077 100644 --- a/backends/apple/coreml/runtime/workspace/executorchcoreml.xcodeproj/project.pbxproj +++ b/backends/apple/coreml/runtime/workspace/executorchcoreml.xcodeproj/project.pbxproj @@ -900,7 +900,7 @@ "$(SRCROOT)/../include", "$(SRCROOT)/../sdk", "$(SRCROOT)/../util", - "$(SRCROOT)/../../third-party/nlohmann_json/single_include/nlohmann", + "$(SRCROOT)/../../third-party/nlohmann_json/single_include", "$(SRCROOT)/../../third-party/coremltools/deps/protobuf/src", ); IPHONEOS_DEPLOYMENT_TARGET = 16.0; @@ -931,7 +931,7 @@ "$(SRCROOT)/../include", "$(SRCROOT)/../sdk", "$(SRCROOT)/../util", - "$(SRCROOT)/../../third-party/nlohmann_json/single_include/nlohmann", + "$(SRCROOT)/../../third-party/nlohmann_json/single_include", "$(SRCROOT)/../../third-party/coremltools/deps/protobuf/src", ); IPHONEOS_DEPLOYMENT_TARGET = 16.0; diff --git a/backends/apple/coreml/scripts/build_tests.sh b/backends/apple/coreml/scripts/build_tests.sh index 730ba0839db..911c6cd4e10 100755 --- a/backends/apple/coreml/scripts/build_tests.sh +++ b/backends/apple/coreml/scripts/build_tests.sh @@ -13,7 +13,7 @@ SCRIPT_DIR_PATH="$( EXECUTORCH_ROOT_PATH=$(realpath "$SCRIPT_DIR_PATH/../../../../") COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/backends/apple/coreml" PROTOBUF_DIR_PATH="$COREML_DIR_PATH/third-party/coremltools/deps/protobuf" -IOS_TOOLCHAIN_PATH="$COREML_DIR_PATH/third-party/ios-cmake/ios.toolchain.cmake" +IOS_TOOLCHAIN_PATH="$EXECUTORCH_ROOT_PATH/third-party/ios-cmake/ios.toolchain.cmake" CMAKE_EXECUTORCH_BUILD_DIR_PATH="$COREML_DIR_PATH/executorch-cmake-out" CMAKE_PROTOBUF_BUILD_DIR_PATH="$COREML_DIR_PATH/protobuf-cmake-out" LIBRARIES_DIR_PATH="$COREML_DIR_PATH/runtime/libraries" diff --git 
a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh index b48ac7bfb69..baf731452e9 100755 --- a/backends/apple/coreml/scripts/install_requirements.sh +++ b/backends/apple/coreml/scripts/install_requirements.sh @@ -53,14 +53,6 @@ if [ $STATUS -ne 0 ]; then exit 1 fi -echo "${green}ExecuTorch: Cloning ios-cmake." -git clone https://github.com/leetal/ios-cmake.git "$COREML_DIR_PATH/third-party/ios-cmake" -STATUS=$? -if [ $STATUS -ne 0 ]; then - echo "${red}ExecuTorch: Failed to clone ios-cmake." - exit 1 -fi - echo "${green}ExecuTorch: Cloning nlohmann." git clone https://github.com/nlohmann/json.git "$COREML_DIR_PATH/third-party/nlohmann_json" STATUS=$? @@ -72,5 +64,5 @@ fi sh "$COREML_DIR_PATH/scripts/install_inmemoryfs.sh" echo "${green}ExecuTorch: Copying protobuf files." -mkdir -p "$COREML_DIR_PATH/runtime/sdk/format/" -cp -rf "$PROTOBUF_FILES_DIR_PATH" "$COREML_DIR_PATH/runtime/sdk/format/" +mkdir -p "$COREML_DIR_PATH/runtime/sdk/format/" +cp -rf "$PROTOBUF_FILES_DIR_PATH" "$COREML_DIR_PATH/runtime/sdk/format/" diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index e59e5c95544..45c468e450b 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -9,9 +9,7 @@ import torch import torchvision -from executorch.backends.apple.coreml.partition.coreml_partitioner import ( - CoreMLPartitioner, -) +from executorch.backends.apple.coreml.partition import CoreMLPartitioner class TestCoreMLPartitioner(unittest.TestCase): diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py index 67eee3593fd..c05cde05a0a 100644 --- a/backends/apple/coreml/test/test_coreml_quantizer.py +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -14,7 +14,7 @@ QuantizationScheme, ) -from 
executorch.backends.apple.coreml.quantizer.coreml_quantizer import CoreMLQuantizer +from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from torch._export import capture_pre_autograd_graph from torch.ao.quantization.quantize_pt2e import ( convert_pt2e, diff --git a/backends/apple/mps/operators/indexing_ops.py b/backends/apple/mps/operators/indexing_ops.py index 690549973a4..02506e11823 100644 --- a/backends/apple/mps/operators/indexing_ops.py +++ b/backends/apple/mps/operators/indexing_ops.py @@ -16,6 +16,7 @@ MPSIndexPut, MPSIndexSelect, MPSIndexTensor, + MPSScatter, ) from executorch.backends.apple.mps.utils.mps_utils import get_input_node from executorch.backends.transforms import get_shape @@ -65,12 +66,9 @@ def define_node( mps_graph.mps_nodes.append(mps_node) -# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens -# are wrong when using Index put. Disabling it for now. @register_node_visitor class IndexPutVisitor(NodeVisitor): - # target = "aten.index_put.default" - target = "disabled" + target = "aten.index_put.default" def __init__(self, *args) -> None: super().__init__(*args) @@ -115,6 +113,88 @@ def define_node( mps_graph.mps_nodes.append(mps_node) +@register_node_visitor +class SliceScatterVisitor(NodeVisitor): + target = "aten.slice_scatter.default" + + def __init__(self, *args) -> None: + super().__init__(*args) + self.invalid_val = 2**63 - 1 + + def maybe_wrap_dim(self, dim: int, n: int) -> List[int]: + if dim < 0: + wrapped_dim = dim + n + if wrapped_dim < 0: + wrapped_dim = 0 + return wrapped_dim + elif dim > n: + return n + return dim + + def get_exapnded_index(self, idx, shape, dim): + if idx.dim() == 0: + return idx.expand(shape) + + dim = self.maybe_wrap_dim(dim, len(shape)) + + # setup new_index_shape as [BS, 1, ..., idx_size, ..., 1] + # to reshape index_ + idx_size = idx.size(0) + new_index_shape = [1] * len(shape) + new_index_shape[dim] = idx_size + + # Now apply expand to index_ + index = 
idx.view(new_index_shape) + new_index_shape = list(shape) + new_index_shape[dim] = idx_size + index = index.expand(new_index_shape) + + return index + + def get_slice_scatter_indices( + self, dim, start, end, step, input_shape, dtype=torch.int64 + ): + idx = torch.arange(start, end, step, dtype=dtype) + return self.get_exapnded_index(idx, input_shape, dim) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + mps_node = self.create_unary_node(node, mps_graph, MPSScatter) + + start = None + end = None + step = 1 + + mps_node.mpsnode_union.src_id = self.define_tensor( + get_input_node(node, 1), mps_graph + ) + if len(node.args) >= 3: + mps_node.mpsnode_union.dim = cast(int, node.args[2]) + if len(node.args) >= 4: + start = cast(int, node.args[3]) + if len(node.args) >= 5 and node.args[4] != self.invalid_val: + end = cast(int, node.args[4]) + if len(node.args) >= 6: + step = cast(int, node.args[5]) + + input_shape = get_shape(get_input_node(node, 0)) + dim_len = input_shape[ + self.maybe_wrap_dim(mps_node.mpsnode_union.dim, len(input_shape)) + ] + + start_val = start if start is not None else 0 + end_val = end if end is not None else dim_len + + scatter_indices = self.get_slice_scatter_indices( + mps_node.mpsnode_union.dim, start_val, end_val, step, input_shape + ) + mps_node.mpsnode_union.idx_id = self.define_constant(scatter_indices, mps_graph) + mps_graph.mps_nodes.append(mps_node) + + @register_node_visitor class EmbeddingVisitor(NodeVisitor): target = "aten.embedding.default" diff --git a/backends/apple/mps/operators/node_visitor.py b/backends/apple/mps/operators/node_visitor.py index e9f879db88a..0b9b2d5512c 100644 --- a/backends/apple/mps/operators/node_visitor.py +++ b/backends/apple/mps/operators/node_visitor.py @@ -143,6 +143,38 @@ def define_tensor_list(self, node: torch.fx.Node, mps_graph: MPSGraph) -> List[i mps_graph.mps_values.append(mps_tensor) return self.tensor_to_id[node] + def define_constant( + self, + 
constant_tensor: torch.tensor, + mps_graph: MPSGraph, + ): + """Defines a constant tensor into the MPSGraph serialization schema + + Args: + constant_tensor (torch.Tensor): constant tensor to define into mps_graph + mps_graph (MPSGraph): MPSGraph object for serializing into flatbuffer + """ + constant_tensor = constant_tensor.contiguous() + # MPS TODO: cache these values + id = len(mps_graph.mps_values) + self.tensor_to_id[constant_tensor] = id + mps_data_type = edge_dtype_to_mps_dtype(constant_tensor.dtype) + constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data( + constant_tensor, mps_graph, mps_data_type, id + ) + dims = list(constant_tensor.shape) + + mps_tensor = MPSTensor( + datatype=mps_data_type, + num_dims=len(dims), + dims=dims, + constant_buffer_size=constant_buffer_size, + constant_buffer=constant_buffer, + ) + + mps_graph.mps_values.append(mps_tensor) + return id + def define_scalar( self, val: Union[float, int], @@ -157,6 +189,7 @@ def define_scalar( """ assert isinstance(val, int) or isinstance(val, float) + # MPS TODO: cache these values id = len(mps_graph.mps_values) self.tensor_to_id[val] = id diff --git a/backends/apple/mps/runtime/MPSGraphBuilder.h b/backends/apple/mps/runtime/MPSGraphBuilder.h index e4e89d68691..29b9471ae9a 100644 --- a/backends/apple/mps/runtime/MPSGraphBuilder.h +++ b/backends/apple/mps/runtime/MPSGraphBuilder.h @@ -123,6 +123,7 @@ class MPSGraphBuilder { _DEFINE_MPS_OP(Embedding); _DEFINE_MPS_OP(IndexTensor); _DEFINE_MPS_OP(IndexPut); + _DEFINE_MPS_OP(Scatter); // Linear algebra ops _DEFINE_MPS_OP(MatMul); _DEFINE_MPS_OP(Addmm); diff --git a/backends/apple/mps/runtime/operations/IndexingOps.mm b/backends/apple/mps/runtime/operations/IndexingOps.mm index b4dcf192b46..6536aa52cf3 100644 --- a/backends/apple/mps/runtime/operations/IndexingOps.mm +++ b/backends/apple/mps/runtime/operations/IndexingOps.mm @@ -204,6 +204,30 @@ return err; } +Error +MPSGraphBuilder::mpsScatterOp(NodePtr nodePtr) { + auto graphNode = 
nodePtr->mpsnode_union_as_MPSScatter(); + ET_LOG( + Debug, "%s %d: %d", + __FUNCTION__, graphNode->input1_id(), graphNode->output_id() + ); + + int64_t dim = graphNode->dim(); + MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id()); + MPSGraphTensor* indicesTensor = getMPSGraphTensor(graphNode->idx_id()); + MPSGraphTensor* updatesTensor = getMPSGraphTensor(graphNode->src_id()); + + _idToMPSGraphTensor[graphNode->output_id()] = + [_mpsGraph scatterAlongAxis:dim + withDataTensor:inputTensor + updatesTensor:updatesTensor + indicesTensor:indicesTensor + mode:MPSGraphScatterModeSet + name:nil]; + return Error::Ok; +} + + } // namespace delegate } // namespace mps } // namespace executor diff --git a/backends/apple/mps/runtime/operations/OperationUtils.mm b/backends/apple/mps/runtime/operations/OperationUtils.mm index 648421ee2cd..21c4a0d3e7b 100644 --- a/backends/apple/mps/runtime/operations/OperationUtils.mm +++ b/backends/apple/mps/runtime/operations/OperationUtils.mm @@ -181,6 +181,7 @@ _DEFINE_MPS_NODE(Embedding); _DEFINE_MPS_NODE(IndexTensor); _DEFINE_MPS_NODE(IndexPut); + _DEFINE_MPS_NODE(Scatter); // Reduce ops _DEFINE_MPS_NODE(Mean); // Shape ops diff --git a/backends/apple/mps/serialization/mps_graph_schema.py b/backends/apple/mps/serialization/mps_graph_schema.py index 8134091a01d..6909926e8cf 100644 --- a/backends/apple/mps/serialization/mps_graph_schema.py +++ b/backends/apple/mps/serialization/mps_graph_schema.py @@ -456,6 +456,13 @@ class MPSIndexPut(MPSNode1x1): values_id: int = -1 +@dataclass +class MPSScatter(MPSNode1x1): + dim: int = 0 + idx_id: int = -1 + src_id: int = -1 + + ## ## Shape ops ## @@ -703,6 +710,7 @@ class MPSArange: MPSEmbedding, MPSIndexTensor, MPSIndexPut, + MPSScatter, # Shape ops MPSPermute, MPSView, diff --git a/backends/apple/mps/serialization/schema.fbs b/backends/apple/mps/serialization/schema.fbs index 6ba2c937f32..6e089d4526f 100644 --- a/backends/apple/mps/serialization/schema.fbs +++ 
b/backends/apple/mps/serialization/schema.fbs @@ -166,6 +166,14 @@ table MPSIndexPut { output_id:int; } +table MPSScatter { + input1_id:int; + output_id:int; + dim:long; + idx_id:int; + src_id:int; +} + // Shape ops. table MPSPermute { input1_id:int; @@ -390,6 +398,7 @@ union MPSNodeUnion { MPSEmbedding, MPSIndexTensor, MPSIndexPut, + MPSScatter, // Reduce ops MPSMean, diff --git a/backends/apple/mps/test/test_mps_indexing_ops.py b/backends/apple/mps/test/test_mps_indexing_ops.py index 7991f1a165a..03709fc891a 100644 --- a/backends/apple/mps/test/test_mps_indexing_ops.py +++ b/backends/apple/mps/test/test_mps_indexing_ops.py @@ -201,7 +201,6 @@ def forward(self, x): # ) def test_mps_indexing_put_1(self): - class IndexPut(torch.nn.Module): def __init__(self): super().__init__() @@ -223,3 +222,43 @@ def forward(self, x, y, z): self.lower_and_test_with_partitioner( module, model_inputs, func_name=inspect.stack()[0].function[5:] ) + + def test_mps_indexing_slice_scatter_1(self): + class IndexSliceScatter(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x.slice_scatter(y, start=6) + + module = IndexSliceScatter() + input = torch.zeros(8, 8) + src = torch.ones(2, 8) + model_inputs = ( + input, + src, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_slice_scatter_2(self): + class IndexSliceScatter(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x.slice_scatter(y, dim=1, start=2, end=6, step=2) + + module = IndexSliceScatter() + input = torch.zeros(8, 8) + src = torch.ones(8, 2) + model_inputs = ( + input, + src, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 08088df7db5..36f11c08c80 100644 --- 
a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -247,7 +247,9 @@ def lower_module_and_test_output( ) executorch_program = delegated_program.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) else: delegated_program = to_backend( @@ -264,7 +266,9 @@ def lower_module_and_test_output( _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. ), ).to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) if bundled_program: diff --git a/backends/arm/test/arm_tosa_reference.py b/backends/arm/test/arm_tosa_reference.py index ef6db7db526..f6a7fd97876 100644 --- a/backends/arm/test/arm_tosa_reference.py +++ b/backends/arm/test/arm_tosa_reference.py @@ -202,7 +202,9 @@ def tosa_run_test(op, profile=TosaProfile.MI): # noqa: C901 model_edge = model_edge.to_backend(ArmPartitioner(compile_spec)) exec_prog = model_edge.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) # Save ground truth results to file diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 32554dfadd6..2eb94a82322 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -73,8 +73,8 @@ def _test_clone_tosa_BI_pipeline( if common.TOSA_REF_MODEL_INSTALLED: tester.run_method_and_compare_outputs(qtol=1) else: - logger.warning( - "TOSA ref model tool not installed, skip numerical correctness tests" + raise RuntimeError( + "TOSA ref model tool not installed and the test is an expected fail" ) def _test_clone_tosa_u55_pipeline( diff --git a/backends/arm/test/ops/test_view.py 
b/backends/arm/test/ops/test_view.py index 0620ecb49b4..fddd21ed2fb 100644 --- a/backends/arm/test/ops/test_view.py +++ b/backends/arm/test/ops/test_view.py @@ -70,8 +70,8 @@ def _test_view_tosa_BI_pipeline( if common.TOSA_REF_MODEL_INSTALLED: tester.run_method_and_compare_outputs(qtol=1) else: - logger.warning( - "TOSA ref model tool not installed, skip numerical correctness tests" + raise RuntimeError( + "TOSA ref model tool not installed and the test is an expected fail" ) def _test_view_u55_BI_pipeline( diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index f1d5ccbd2e5..14030719153 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -12,7 +12,7 @@ if(NOT CMAKE_CXX_STANDARD) endif() # Set the project name. -project(cadence_executorch_example) +project(cadence_backend) # Source root directory for executorch. if(NOT EXECUTORCH_ROOT) @@ -21,121 +21,21 @@ endif() include(${EXECUTORCH_ROOT}/build/Utils.cmake) -if(NOT PYTHON_EXECUTABLE) - resolve_python_executable() -endif() - # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) 
-# Find prebuilt executorch lib -find_package(executorch CONFIG REQUIRED) - -add_compile_options( - -DSDK_DEBUGCONSOLE=1 - -DSERIAL_PORT_TYPE_UART=1 - -DDEBUG_CONSOLE_RX_ENABLE=0 - -DDEBUG - -DCPU_MIMXRT685SFVKB_dsp - -DMCUXPRESSO_SDK - -g - -O0 - -Wall - -fsigned-char - -Wno-missing-braces - -fmessage-length=0 - -DPRINTF_FLOAT_ENABLE=1 -) - -if(NOT DEFINED NXP_SDK_ROOT_DIR) - message(FATAL_ERROR "NXP_SDK_ROOT_DIR is not set") -endif() - -# lint_cmake: -linelength -set(SOURCES - ${NXP_SDK_ROOT_DIR}/components/lists/fsl_component_generic_list.c - ${NXP_SDK_ROOT_DIR}/components/uart/fsl_adapter_usart.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_clock.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_common.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_common_dsp.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_flexcomm.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_gpio.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_mu.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_reset.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers/fsl_usart.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/system_MIMXRT685S_dsp.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/utilities/debug_console_lite/fsl_assert.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/utilities/debug_console_lite/fsl_debug_console.c - ${NXP_SDK_ROOT_DIR}/boards/evkmimxrt685/dsp_examples/mu_polling/dsp/board_hifi4.c - ${NXP_SDK_ROOT_DIR}/boards/evkmimxrt685/dsp_examples/mu_polling/dsp/pin_mux.c - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/utilities/str/fsl_str.c -) - -add_library(dsp_mu_polling_libs STATIC ${SOURCES}) - -target_include_directories( - dsp_mu_polling_libs - PUBLIC ${NXP_SDK_ROOT_DIR} - ${NXP_SDK_ROOT_DIR}/components/uart - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/drivers - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/utilities/debug_console_lite - ${NXP_SDK_ROOT_DIR}/components/lists - ${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S - ${NXP_SDK_ROOT_DIR}/CMSIS/Core/Include - 
${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/utilities/str - ${NXP_SDK_ROOT_DIR}/boards/evkmimxrt685/dsp_examples/mu_polling/dsp -) - -add_library(extension_runner_util STATIC IMPORTED) -set_property( - TARGET extension_runner_util - PROPERTY - IMPORTED_LOCATION - "${CMAKE_CURRENT_LIST_DIR}/../../cmake-out/extension/runner_util/libextension_runner_util.a" -) - +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/third-party/nnlib) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/operators) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/hifi/kernels) -# Generate the model header file -add_custom_command( - OUTPUT ${CMAKE_BINARY_DIR}/model_pte.h - COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/utils/gen_header.py - --model_path ${MODEL_PATH} --header_output_path ${CMAKE_BINARY_DIR} - COMMENT "Converting .pte model to header file..." - DEPENDS ${CMAKE_CURRENT_LIST_DIR}/utils/gen_header.py -) - -add_custom_target(gen_model_header DEPENDS ${CMAKE_BINARY_DIR}/model_pte.h) -add_executable(cadence_executorch_example executor_runner.cpp) -add_dependencies(cadence_executorch_example gen_model_header) - -# lint_cmake: -linelength -target_include_directories( - cadence_executorch_example PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} - ${_common_include_directories} +install( + TARGETS cadence_ops_lib + DESTINATION lib + INCLUDES + DESTINATION ${_common_include_directories} ) -target_link_options( - cadence_executorch_example PRIVATE - -mlsp=${NXP_SDK_ROOT_DIR}/devices/MIMXRT685S/xtensa/min-rt -) -target_link_libraries( - cadence_executorch_example dsp_mu_polling_libs cadence_ops_lib - extension_runner_util executorch -) -add_custom_command( - TARGET cadence_executorch_example - POST_BUILD - COMMAND - ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/utils/post_compilation.py - ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME} ${CMAKE_BINARY_DIR} - COMMENT - "Generating .bin files that can be used to flash the DSP with. 
Copy over - the dsp_text_release.bin and dsp_data_release.bin that are generated into - your NXP MCUXpresso IDE workspace and flash the DSP with these binaries." - DEPENDS - ${CMAKE_CURRENT_LIST_DIR}/utils/post_compilation.py -) + + diff --git a/backends/cadence/aot/__init__.py b/backends/cadence/aot/__init__.py index e69de29bb2d..2e41cd717f6 100644 --- a/backends/cadence/aot/__init__.py +++ b/backends/cadence/aot/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index e95a39b5a36..bf96de2afdf 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -14,12 +14,12 @@ from typing import Any, Tuple from executorch.backends.cadence.aot.compiler import export_to_edge -from executorch.backends.cadence.aot.quantizer import ( - CadenceBaseQuantizer, - QuantFusion, +from executorch.backends.cadence.aot.passes import ( ReplacePT2DequantWithCadenceDequant, ReplacePT2QuantWithCadenceQuant, ) +from executorch.backends.cadence.aot.quantizer.fusion_pass import QuantFusion +from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer from executorch.exir import ExecutorchProgramManager from torch import nn from torch._export import capture_pre_autograd_graph @@ -52,7 +52,7 @@ def export_model( model: nn.Module, example_inputs: Tuple[Any], file_name: str = "CadenceDemoModel" ): # Quantizer - quantizer = CadenceBaseQuantizer() + quantizer = CadenceQuantizer() # Export model_exp = capture_pre_autograd_graph(model, example_inputs) diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py new file mode 100644 index 00000000000..2ced2eaf87a --- /dev/null +++ b/backends/cadence/aot/passes.py @@ -0,0 +1,42 @@ +# Copyright (c) 
Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class ReplacePT2QuantWithCadenceQuant(ExportPass): + """ + Replace the pt2 quantization ops with custom cadence quantization ops. + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.cadence.quantize_per_tensor.default, + args, + kwargs, + meta, + ) + + +class ReplacePT2DequantWithCadenceDequant(ExportPass): + """ + Replace the pt2 dequantization ops with custom cadence dequantization ops. + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}: + return super().call_operator(op, args, kwargs, meta) + + return super().call_operator( + exir_ops.edge.cadence.dequantize_per_tensor.default, + args, + kwargs, + meta, + ) diff --git a/backends/cadence/aot/quantizer.py b/backends/cadence/aot/quantizer.py deleted file mode 100644 index df184f9d92c..00000000000 --- a/backends/cadence/aot/quantizer.py +++ /dev/null @@ -1,855 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from math import frexp, isclose, trunc -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union - -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass - -from torch import fx - -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( - OperatorConfig, - QuantizationAnnotation, - QuantizationConfig, - QuantizationSpec, - SharedQuantizationSpec, -) -from torch.fx import GraphModule -from torch.fx.passes.infra.pass_base import PassResult -from torch.fx.passes.utils.fuser_utils import legalize_graph - - -def quantize_tensor_multiplier( - requantize_scale_tensor: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Given requantize_scale_tensor with values in the interval (0, 1), - produce a pair of tensors (out_multiplier, right_shift) where out_multiplier - is an int32 tensor representing fixed-point values in the interval [-1, 1), - and right_shift is an amount to shift right by, so that the floating-point - multiplication of some int32 input with each value of requantize_scale_tensor: - result = int32_value * requantize_scale_tensors[i] - is best approximated by the integer-arithmetic-only code: - result = RoundingRightShift(FixedPointMultiplication(int32_value, - out_multiplier[i]), right_shift[i]) - """ - - # This is identical to C++11 std::round(). The general python round rounds - # down, and C++ rounds away from zero. 
- def round_away_zero(f) -> int: - r = -0.5 if (f < 0) else 0.5 - return trunc(f + r) - - def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]: - significand, exponent = frexp(requantize_scale) - significand_q31 = int(round_away_zero(significand * (1 << 31))) - # Handle the special case when the real multiplier was so close to 1 - # that its fixed-point approximation was indistinguishable from 1. - # We handle this by dividing it by two, incrementing exponent by 1. - # the right shift amount. - if significand_q31 == (1 << 31): - significand_q31 //= 2 - exponent += 1 - - # Verify that the decomposition of requantize_scale into significand - # and exponent is correct. - reconstructed = significand_q31 / (1 << 31) * pow(2, exponent) - assert isclose( - requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4 - ), "computation of significand and exponent from requantize_scale is not accurate" - - return (significand_q31, exponent) - - # Flatten the input scale tensor so that we can operate on individual values - orig_shape = requantize_scale_tensor.shape - flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32) - out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32) - right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32) - - # Iterate over the flattened scale tensor and compute the decomposition of - # each value in scale tensor into significand(out_multiplier) and - # exponent(right_shift) - for idx, scale in enumerate(flattened_tensor): - (si, ex) = quantize_scalar_multiplier(scale) - out_multiplier[idx], right_shift[idx] = si, ex - - # Reshape the tensors back to the original shape - out_multiplier = out_multiplier.reshape(orig_shape) - right_shift = right_shift.reshape(orig_shape) - - return (out_multiplier, right_shift) - - -def _is_annotated(nodes: List[fx.Node]) -> bool: - annotated = False - for node in nodes: - annotated = annotated or ( - "quantization_annotation" in node.meta - and 
node.meta["quantization_annotation"]._annotated - ) - return annotated - - -def _no_outside_users(fused_partition) -> bool: - """ - Checks if each partition other than the last does not have any outside users. - """ - for source_partition in fused_partition[:-1]: - if len(source_partition.output_nodes) != 1: - return False - if len(source_partition.output_nodes[0].users) != 1: - return False - return True - - -# Helper function to get the weight node for both quantized and unquantized weights -# TODO(matthiascremon): get a better test! -def get_weight_node(weights_inputs: fx.Node, dequants_weights: fx.Node) -> fx.Node: - """ - Returns the weight node. - """ - weight_node = ( - weights_inputs - if weights_inputs.name.endswith("_frozen_param") - else dequants_weights - ) - return weight_node - - -# Helper function to get the args and kwargs for the linear replacement op -def get_args_and_kwargs_linear( - graph_module: GraphModule, - inputs_inputs: List[fx.Node], - dequants_inputs: List[fx.Node], - other_inputs: List[fx.Node], - weights_inputs: List[fx.Node], - dequants_weights: List[fx.Node], - bias_inputs: List[fx.Node], - quant_node: fx.Node, -) -> Tuple[Tuple[Any], Dict[str, Any]]: - """ - Returns the args and kwargs for the linear replacement op. 
- """ - weight_scale = get_weight_node(weights_inputs[0], dequants_weights[0]).args[1] - # pyre-fixme[58]: Unsupported operand types - bias_scale = dequants_inputs[0].args[1] * weight_scale - requantize_scale = bias_scale / quant_node.args[1] - requantize_scale_t = torch.tensor([requantize_scale]) - - (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) - - # If bias is not available, create a bias tensor with the shape of weight[0] - if not bias_inputs: - weight_node = get_weight_node(weights_inputs[0], dequants_weights[0]).args[0] - # pyre-fixme[16]: Undefined attribute - attr_node = getattr(graph_module, weight_node.target) - weight_shape = list(attr_node.shape) - bias_shape = weight_shape[0] - bias = graph_module.graph.call_function( - torch.ops.aten.full.default, ([bias_shape], 0.0) - ) - else: - bias = bias_inputs[0] - - bias_int32_quant = graph_module.graph.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ( - bias, - bias_scale, - 0, - -(2**31), - (2**31) - 1, - torch.int32, - ), - ) - - # Create single element tensors for weight_zero_point, out_multiplier, out_shift. - # Note that the function expects int32_t, when it would default to int64_t, so - # we explicitly require that type. 
- weight_zero_point_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], dequants_weights[0].args[2]), - {"dtype": torch.int32}, - ) - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - - args = tuple(inputs_inputs + weights_inputs + other_inputs + [bias_int32_quant]) - kwargs = { - "src_zero_point": dequants_inputs[0].args[2], - "weight_zero_point": weight_zero_point_, - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, - "out_zero_point": quant_node.args[2], - "offset": None, - } - return args, kwargs - - -# Helper function to get the args and kwargs for the layer norm replacement op -def get_args_and_kwargs_layer_norm( - graph_module: GraphModule, - inputs_inputs: List[fx.Node], - dequants_inputs: List[fx.Node], - other_inputs: List[fx.Node], - weights_init_inputs: List[fx.Node], - bias_inputs: List[fx.Node], - quant_node: fx.Node, -) -> Tuple[Tuple[Any], Dict[str, Any]]: - """ - Returns the args and kwargs for the layer norm replacement op. 
- """ - # Check if the input is per-channel quantized - # TODO(matthiascremon): add proper support and testing for per-channel quantization - assert isinstance(dequants_inputs[0].args[1], float) and isinstance( - dequants_inputs[0].args[2], int - ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars" - - # Make the scale and zero_point tensors - scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[1], - ), - ) - zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ( - [1], - dequants_inputs[0].args[2], - ), - ) - - # Make the args and kwargs for the replacement op - args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor]) - kwargs = { - "normalized_shape": other_inputs[0], - "weight": weights_init_inputs[0], - "bias": bias_inputs[0], - "eps": 1e-05, - "output_scale": quant_node.args[1], - "output_zero_point": quant_node.args[2], - } - return args, kwargs - - -def get_conv_args(arg, first_val: int) -> List[fx.Node]: - return arg if len(arg) == 2 else [first_val, arg[0]] - - -def get_args_and_kwargs_conv1d( - graph_module: GraphModule, - inputs_inputs: List[fx.Node], - dequants_inputs: List[fx.Node], - other_inputs: List[fx.Node], - weights_inputs: List[fx.Node], - dequants_weights: List[fx.Node], - bias_inputs: List[fx.Node], - quant_node: fx.Node, - op_node: fx.Node, -): - weight_scale = get_weight_node(weights_inputs[0], dequants_weights[0]).args[1] - weight_zero_point = get_weight_node(weights_inputs[0], dequants_weights[0]).args[2] - # pyre-fixme[58]: Unsupported operand types - bias_scale = dequants_inputs[0].args[1] * weight_scale - stride = [1, 1] if len(op_node.args) < 4 else get_conv_args(op_node.args[3], 1) - padding = [0, 0] if len(op_node.args) < 5 else get_conv_args(op_node.args[4], 0) - dilation = [1, 1] if len(op_node.args) < 6 else get_conv_args(op_node.args[5], 1) - groups = 1 if 
len(op_node.args) < 7 else op_node.args[6] - # If bias is not available, create a bias tensor with the shape of weight[0] - if not bias_inputs: - weight_node = get_weight_node(weights_inputs[0], dequants_weights[0]).args[0] - # pyre-fixme[16]: Undefined attribute - attr_node = getattr(graph_module, weight_node.target) - weight_shape = list(attr_node.shape) - bias_shape = weight_shape[0] - bias = graph_module.graph.call_function( - torch.ops.aten.full.default, ([bias_shape], 0.0) - ) - else: - bias = bias_inputs[0] - # The bias is quantized to int32_t - bias_int32_quant = graph_module.graph.call_function( - torch.ops.quantized_decomposed.quantize_per_tensor.default, - ( - bias, - bias_scale, - 0, - -(2**31), - (2**31) - 1, - torch.int32, - ), - ) - - # Compute the out multiplier and out shift. They are used when the conv op is - # replaced by quantized linear, we compute them a priori for simplicity but - # may revisit the decision. - requantize_scale = bias_scale / quant_node.args[1] - requantize_scale_t = torch.tensor([requantize_scale]) - - (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) - - out_multiplier_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_multiplier[0].item()), - {"dtype": torch.int32}, - ) - out_shift_ = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], out_shift[0].item()), - {"dtype": torch.int32}, - ) - - # Create a single element tensor for the weight zero point - weight_zero_point_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], weight_zero_point), - {"dtype": torch.int32}, - ) - - # Create a single element tensor for the bias scale - bias_scale_tensor = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([1], bias_scale), - {"dtype": torch.float32}, - ) - - # Make the args and kwargs for the replacement op - args = tuple(inputs_inputs + weights_inputs + other_inputs + [bias_int32_quant]) - kwargs = { - 
"stride": stride, - "padding": padding, - "dilation": dilation, - "groups": groups, - "input_zero_point": dequants_inputs[0].args[2], - "weight_zero_point": weight_zero_point_tensor, - "bias_scale": bias_scale_tensor, - "out_scale": quant_node.args[1], - "out_zero_point": quant_node.args[2], - "out_multiplier": out_multiplier_, - "out_shift": out_shift_, - "channel_last": False, - } - return args, kwargs - - -def get_args_and_kwargs_relu( - graph_module: GraphModule, - inputs_inputs: List[fx.Node], - dequants_inputs: List[fx.Node], -): - # Make the args and kwargs for the replacement op - args = tuple(inputs_inputs) - - X_zero_point = graph_module.graph.call_function( - torch.ops.aten.full.default, ([1], dequants_inputs[0].args[2]) - ) - - kwargs = { - "X_zero_point": X_zero_point, - } - return args, kwargs - - -@dataclass -class PartitionAnchors: - """ - All fields except output are lists of (node, args_index) pair, where node is from - the given partition and node.args[args_index] is an input to the partition. Assumes - a single output. - - Quantizer uses inputs, weights and biases for quantization annotation. The others - field contains tensor inputs that aren't quantized, and the literals fields contains - is used for other types of input values as well as handling default parameters. - """ - - inputs: List[Tuple[fx.Node, int]] = field(default_factory=list) - weights: List[Tuple[fx.Node, int]] = field(default_factory=list) - biases: List[Tuple[fx.Node, int]] = field(default_factory=list) - others: List[Tuple[fx.Node, int]] = field(default_factory=list) - literals: List[Tuple[fx.Node, int]] = field(default_factory=list) - output: List[Union[Tuple[fx.Node], Tuple[fx.Node, QuantizationSpec]]] = field( - default_factory=list - ) - - -class QuantizationPattern(ABC): - @abstractmethod - def partition_types(self) -> List[Any]: - """ - List of types to be passed to find_sequential_partitions. 
- """ - pass - - @abstractmethod - def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]: - pass - - @abstractmethod - def replacement_op(self) -> Callable[..., Any]: - """ - Operator (most likely a custom one) that this partition should be fused into in - the backend. Refer to the QuantFusion pass for examples. - """ - pass - - -class LinearPattern(QuantizationPattern): - def partition_types(self) -> List[Type[torch.nn.Module]]: - return [torch.nn.Linear] - - def get_anchors( - self, gm: GraphModule, fused_partition: List[GraphModule] - ) -> PartitionAnchors: - linear_node = fused_partition[0].nodes[-1] - - # Keep bias empty if not supplied - bias = [] - if len(linear_node.args) > 2: - bias = [(linear_node, 2)] - - return PartitionAnchors( - inputs=[(linear_node, 0)], - weights=[(linear_node, 1)], - biases=bias, - output=[(linear_node,)], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_linear.default - - -class LinearFunctionalPattern(QuantizationPattern): - def partition_types(self): - return [torch.nn.functional.linear] - - def get_anchors( - self, gm: GraphModule, fused_partition: List[GraphModule] - ) -> PartitionAnchors: - linear_node = fused_partition[0].nodes[-1] - - return PartitionAnchors( - inputs=[(linear_node, 0)], - weights=[(linear_node, 1)], - biases=[(linear_node, 2)], - output=[(linear_node,)], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_linear.default - - -class LayerNormPattern(QuantizationPattern): - def partition_types(self): - return [torch.nn.LayerNorm] - - def get_anchors(self, gm, fused_partition) -> PartitionAnchors: - layer_norm_node = fused_partition[0].nodes[-1] - - return PartitionAnchors( - inputs=[(layer_norm_node, 0)], - weights=[(layer_norm_node, 2)], - biases=[(layer_norm_node, 3)], - others=[(layer_norm_node, 1)], - output=[(layer_norm_node,)], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_layer_norm.default - - -class 
Conv1dPattern(QuantizationPattern): - def partition_types(self) -> List[Type[torch.nn.Module]]: - return [torch.nn.Conv1d] - - def get_anchors( - self, gm: GraphModule, fused_partition: List[GraphModule] - ) -> PartitionAnchors: - conv1d_node = fused_partition[0].nodes[-1] - - # If bias is None, replace it with an empty list. - bias = ( - [(conv1d_node, 2)] - if len(conv1d_node.args) > 2 and conv1d_node.args[2] - else [] - ) - - return PartitionAnchors( - inputs=[(conv1d_node, 0)], - weights=[(conv1d_node, 1)], - biases=bias, - output=[(conv1d_node,)], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_conv.default - - -class Conv2dPattern(QuantizationPattern): - def partition_types(self) -> List[Type[torch.nn.Module]]: - return [torch.nn.Conv2d] - - def get_anchors( - self, gm: GraphModule, fused_partition: List[GraphModule] - ) -> PartitionAnchors: - conv2d_node = fused_partition[0].nodes[-1] - - # If bias is None, replace it with an empty list. - bias = ( - [(conv2d_node, 2)] - if len(conv2d_node.args) > 2 and conv2d_node.args[2] - else [] - ) - - return PartitionAnchors( - inputs=[(conv2d_node, 0)], - weights=[(conv2d_node, 1)], - biases=bias, - output=[(conv2d_node,)], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_conv.default - - -class AddmmPattern(QuantizationPattern): - def partition_types(self) -> List[Type[torch.nn.Module]]: - return [torch.addmm] - - def get_anchors( - self, gm: GraphModule, fused_partition: List[GraphModule] - ) -> PartitionAnchors: - addmm_node = fused_partition[0].nodes[-1] - - return PartitionAnchors( - inputs=[(addmm_node, 1)], - weights=[(addmm_node, 2)], - biases=[(addmm_node, 0)], - output=[(addmm_node,)], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_linear.default - - -class ReluPattern(QuantizationPattern): - def partition_types(self) -> List[Type[torch.nn.Module]]: - return [torch.nn.ReLU] - - def get_anchors( - self, gm: GraphModule, fused_partition: 
List[GraphModule] - ) -> PartitionAnchors: - relu_node = fused_partition[0].nodes[-1] - - return PartitionAnchors( - inputs=[(relu_node, 0)], - weights=[], - biases=[], - # pyre-fixme[6]: Incompatible parameter type - output=[ - (relu_node, SharedQuantizationSpec((relu_node.args[0], relu_node))) - ], - ) - - def replacement_op(self): - return torch.ops.cadence.quantized_relu.default - - -class GenericQuantizer(Quantizer): - def __init__(self, pattern, quantization_config): - super().__init__() - self.pattern = pattern - self.quantization_config = quantization_config - - def annotate(self, model): - fused_partitions = find_sequential_partitions( - model, - self.pattern.partition_types(), - ) - - input_act_qspec = self.quantization_config.input_activation - weight_qspec = self.quantization_config.weight - bias_qspec = self.quantization_config.bias - output_act_qspec = self.quantization_config.output_activation - - for fused_partition in fused_partitions: - if not _no_outside_users(fused_partition): - continue - - anchors = self.pattern.get_anchors(model, fused_partition) - if not anchors: - continue - if _is_annotated( - [ - x[0] - for x in anchors.inputs - + anchors.weights - + anchors.biases - + anchors.output - ] - ): - continue - - for output, *custom_spec in anchors.output: - output.meta["quantization_annotation"] = QuantizationAnnotation( - output_qspec=custom_spec[0] if custom_spec else output_act_qspec, - _annotated=True, - ) - - def annotate_inputs(inputs, spec): - for node, idx in inputs: - annotation = node.meta.get( - "quantization_annotation", - QuantizationAnnotation(_annotated=True), - ) - annotation.input_qspec_map[node.args[idx]] = spec - node.meta["quantization_annotation"] = annotation - - annotate_inputs(anchors.inputs, input_act_qspec) - annotate_inputs(anchors.weights, weight_qspec) - annotate_inputs(anchors.biases, bias_qspec) - - def validate(self, model: fx.GraphModule) -> None: - pass - - @classmethod - def get_supported_operators(cls) -> 
List[OperatorConfig]: - return [] - - -act_qspec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), -) - -wgt_qspec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=255, - qscheme=torch.per_tensor_affine, - is_dynamic=False, - observer_or_fake_quant_ctr=MinMaxObserver, -) - - -class CadenceBaseQuantizer(ComposableQuantizer): - def __init__(self): - static_qconfig = QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, - None, - ) - static_qconfig_no_wgt = QuantizationConfig( - act_qspec, - act_qspec, - None, - None, - ) - super().__init__( - [ - GenericQuantizer(AddmmPattern(), static_qconfig), - GenericQuantizer(Conv1dPattern(), static_qconfig), - GenericQuantizer(Conv2dPattern(), static_qconfig), - GenericQuantizer(LayerNormPattern(), static_qconfig_no_wgt), - GenericQuantizer(LinearFunctionalPattern(), static_qconfig), - GenericQuantizer(LinearPattern(), static_qconfig), - GenericQuantizer(ReluPattern(), static_qconfig), - ] - ) - - -class QuantFusion(ExportPass): - def __init__(self, patterns): - super().__init__() - self.patterns = patterns - - def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 - for pattern in self.patterns: - fused_partitions = find_sequential_partitions( - graph_module, - pattern.partition_types(), - ) - for fused_partition in fused_partitions: - anchors = pattern.get_anchors(graph_module, fused_partition) - if not anchors: - continue - if any(self.is_fused(p.nodes) for p in fused_partition): - continue - - for p in fused_partition: - self.mark_fused(p.nodes) - - dequants_inputs = [] - for node, idx in anchors.inputs: - if ( - node.args[idx].target - == torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): - dequants_inputs.append(node.args[idx]) - dequants_weights = [] - for node, idx in anchors.weights: - if ( - node.args[idx].target - 
== torch.ops.quantized_decomposed.dequantize_per_tensor.default - ): - dequants_weights.append(node.args[idx]) - - inputs_inputs = [node.args[0] for node in dequants_inputs] - weights_inputs = [node.args[0] for node in dequants_weights] - weights_init_inputs = [node.args[idx] for node, idx in anchors.weights] - bias_inputs = [node.args[idx] for node, idx in anchors.biases] - other_inputs = [node.args[idx] for node, idx in anchors.others] - - # The node is the first index of the list and first of the tuple - op_node = anchors.output[0][0] - - assert len(op_node.users) == 1 - quant_node = list(op_node.users.keys())[0] - - with graph_module.graph.inserting_after(op_node): - args = tuple( - inputs_inputs + weights_inputs + other_inputs + bias_inputs - ) - kwargs = {} - if isinstance(pattern, Conv1dPattern) or isinstance( - pattern, Conv2dPattern - ): - args, kwargs = get_args_and_kwargs_conv1d( - graph_module, - inputs_inputs, - dequants_inputs, - other_inputs, - weights_inputs, - dequants_weights, - bias_inputs, - quant_node, - op_node, - ) - elif isinstance(pattern, LinearPattern) or isinstance( - pattern, LinearFunctionalPattern - ): - args, kwargs = get_args_and_kwargs_linear( - graph_module, - inputs_inputs, - dequants_inputs, - other_inputs, - weights_inputs, - dequants_weights, - bias_inputs, - quant_node, - ) - elif isinstance(pattern, LayerNormPattern): - args, kwargs = get_args_and_kwargs_layer_norm( - graph_module, - inputs_inputs, - dequants_inputs, - other_inputs, - weights_init_inputs, - bias_inputs, - quant_node, - ) - elif isinstance(pattern, AddmmPattern): - # Transpose the weight tensor - transposed_weights = graph_module.graph.call_function( - torch.ops.aten.transpose.int, - (weights_inputs[0], 0, 1), - ) - # Call linear with transposed weight - args, kwargs = get_args_and_kwargs_linear( - graph_module, - inputs_inputs, - dequants_inputs, - other_inputs, - [transposed_weights], - dequants_weights, - bias_inputs, - quant_node, - ) - elif 
isinstance(pattern, ReluPattern): - args, kwargs = get_args_and_kwargs_relu( - graph_module, - inputs_inputs, - dequants_inputs, - ) - fused = graph_module.graph.call_function( - pattern.replacement_op(), - args, - kwargs, - ) - fused.meta = quant_node.meta - quant_node.replace_all_uses_with(fused) - - legalize_graph(graph_module) - graph_module.graph.eliminate_dead_code() - # pyre-fixme[7]: Incompatible return type - graph_module.recompile() - - @classmethod - def is_fused(cls, nodes) -> bool: - return any(cls.__qualname__ in n.meta for n in nodes) - - @classmethod - def mark_fused(cls, nodes) -> bool: - for n in nodes: - # pyre-fixme[7]: Incompatible return type - n.meta["QuantFusion"] = True - - -class ReplacePT2QuantWithCadenceQuant(ExportPass): - """ - Replace the pt2 quantization ops with custom cadence quantization ops. - """ - - def call_operator(self, op, args, kwargs, meta): - if op not in {exir_ops.edge.quantized_decomposed.quantize_per_tensor.default}: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.cadence.quantize_per_tensor.default, - args, - kwargs, - meta, - ) - - -class ReplacePT2DequantWithCadenceDequant(ExportPass): - """ - Replace the pt2 dequantization ops with custom cadence dequantization ops. - """ - - def call_operator(self, op, args, kwargs, meta): - if op not in {exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default}: - return super().call_operator(op, args, kwargs, meta) - - return super().call_operator( - exir_ops.edge.cadence.dequantize_per_tensor.default, - args, - kwargs, - meta, - ) diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py new file mode 100644 index 00000000000..69b12cf7a97 --- /dev/null +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Tuple + +import torch +from executorch.backends.cadence.aot.quantizer.patterns import ( + AddmmPattern, + Conv1dPattern, + Conv2dPattern, + LayerNormFunctionalPattern, + LayerNormPattern, + LinearFunctionalPattern, + LinearPattern, + MatmulPattern, + ReluPattern, +) +from executorch.backends.cadence.aot.quantizer.utils import ( + create_zero_bias_int32, + get_conv_args, + quantize_tensor_multiplier, +) +from executorch.exir.pass_base import ExportPass +from torch import fx +from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.utils.fuser_utils import legalize_graph + + +# Helper function to get the args and kwargs for the linear replacement op +def get_args_and_kwargs_linear( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + quant_node: fx.Node, +) -> Tuple[Tuple[Any], Dict[str, Any]]: + """ + Returns the args and kwargs for the linear replacement op. 
+ """ + weight_scale = dequants_weights[0].args[1] + # pyre-fixme[58]: Unsupported operand types + bias_scale = dequants_inputs[0].args[1] * weight_scale + requantize_scale = bias_scale / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) + + # If bias is not available, create a bias tensor with the shape of weight[0] + if not bias_inputs: + weight_node = dequants_weights[0].args[0] + assert isinstance(weight_node, fx.Node) + bias = create_zero_bias_int32(graph_module, weight_node, bias_scale) + else: + bias = bias_inputs[0] + + # Create single element tensors for weight_zero_point, out_multiplier, out_shift. + # Note that the function expects int32_t, when it would default to int64_t, so + # we explicitly require that type. + weight_zero_point_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], dequants_weights[0].args[2]), + {"dtype": torch.int32}, + ) + out_multiplier_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_multiplier[0].item()), + {"dtype": torch.int32}, + ) + out_shift_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_shift[0].item()), + {"dtype": torch.int32}, + ) + + args = tuple(inputs_inputs + weights_inputs + [bias]) + kwargs = { + "src_zero_point": dequants_inputs[0].args[2], + "weight_zero_point": weight_zero_point_, + "out_multiplier": out_multiplier_, + "out_shift": out_shift_, + "out_zero_point": quant_node.args[2], + "offset": None, + } + return args, kwargs + + +# Helper function to get the args and kwargs for the layer norm replacement op +def get_args_and_kwargs_layer_norm( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + other_inputs: List[fx.Node], + quant_node: fx.Node, +) -> Tuple[Tuple[Any], Dict[str, Any]]: + """ + Returns the args and kwargs for the layer norm replacement op. 
+ """ + # Check if the input is per-channel quantized + # TODO(matthiascremon): add proper support and testing for per-channel quantization + assert isinstance(dequants_inputs[0].args[1], float) and isinstance( + dequants_inputs[0].args[2], int + ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars" + + # Make the scale and zero_point tensors + scale_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + [1], + dequants_inputs[0].args[1], + ), + {"dtype": torch.float32}, + ) + zero_point_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + [1], + dequants_inputs[0].args[2], + ), + {"dtype": torch.int32}, + ) + + weight = other_inputs[1] if len(other_inputs) > 1 else None + + if not weight: + weight = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + other_inputs[0], + 1, + ), + {"dtype": torch.float32}, + ) + + bias = other_inputs[2] if len(other_inputs) > 2 else None + + if not bias: + bias = graph_module.graph.call_function( + torch.ops.aten.full.default, + ( + other_inputs[0], + 0, + ), + {"dtype": torch.float32}, + ) + + # Make the args and kwargs for the replacement op + args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor]) + kwargs = { + "normalized_shape": other_inputs[0], + "weight": weight, + "bias": bias, + "eps": 1e-05, + "output_scale": quant_node.args[1], + "output_zero_point": quant_node.args[2], + } + return args, kwargs + + +def get_args_and_kwargs_matmul( + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + quant_node: fx.Node, +) -> Tuple[Tuple[Any], Dict[str, Any]]: + requantize_scale = ( + # pyre-ignore[58]: Unsupported operand + dequants_inputs[0].args[1] + * dequants_inputs[1].args[1] + ) / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) + + args = ( + inputs_inputs[0], + 
dequants_inputs[0].args[2], + inputs_inputs[1], + dequants_inputs[1].args[2], + None, + ) + + kwargs = { + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + "out_zero_point": quant_node.args[2], + "transposed": False, + } + return args, kwargs + + +def get_args_and_kwargs_conv( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], + weights_inputs: List[fx.Node], + dequants_weights: List[fx.Node], + bias_inputs: List[fx.Node], + quant_node: fx.Node, + op_node: fx.Node, +): + weight_scale = dequants_weights[0].args[1] + weight_zero_point = dequants_weights[0].args[2] + # pyre-fixme[58]: Unsupported operand types + bias_scale = dequants_inputs[0].args[1] * weight_scale + stride = [1, 1] if len(op_node.args) < 4 else get_conv_args(op_node.args[3], 1) + padding = [0, 0] if len(op_node.args) < 5 else get_conv_args(op_node.args[4], 0) + dilation = [1, 1] if len(op_node.args) < 6 else get_conv_args(op_node.args[5], 1) + groups = 1 if len(op_node.args) < 7 else op_node.args[6] + + # If bias is not available, create a bias tensor with the shape of weight[0] + if not bias_inputs: + weight_node = dequants_weights[0].args[0] + assert isinstance(weight_node, fx.Node) + bias = create_zero_bias_int32(graph_module, weight_node, bias_scale) + else: + bias = bias_inputs[0] + + # Compute the out multiplier and out shift. They are used when the conv op is + # replaced by quantized linear, we compute them a priori for simplicity but + # may revisit the decision. 
+ requantize_scale = bias_scale / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + + (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t) + + out_multiplier_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_multiplier[0].item()), + {"dtype": torch.int32}, + ) + out_shift_ = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], out_shift[0].item()), + {"dtype": torch.int32}, + ) + + # Create a single element tensor for the weight zero point + weight_zero_point_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], weight_zero_point), + {"dtype": torch.int32}, + ) + + # Create a single element tensor for the bias scale + bias_scale_tensor = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([1], bias_scale), + {"dtype": torch.float32}, + ) + + # Make the args and kwargs for the replacement op + args = tuple(inputs_inputs + weights_inputs + [bias]) + kwargs = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "groups": groups, + "input_zero_point": dequants_inputs[0].args[2], + "weight_zero_point": weight_zero_point_tensor, + "bias_scale": bias_scale_tensor, + "out_scale": quant_node.args[1], + "out_zero_point": quant_node.args[2], + "out_multiplier": out_multiplier_, + "out_shift": out_shift_, + "channel_last": False, + } + return args, kwargs + + +def get_args_and_kwargs_relu( + graph_module: GraphModule, + inputs_inputs: List[fx.Node], + dequants_inputs: List[fx.Node], +): + # Make the args and kwargs for the replacement op + args = tuple(inputs_inputs) + + X_zero_point = graph_module.graph.call_function( + torch.ops.aten.full.default, ([1], dequants_inputs[0].args[2]) + ) + + kwargs = { + "X_zero_point": X_zero_point, + } + return args, kwargs + + +class QuantFusion(ExportPass): + def __init__(self, patterns): + super().__init__() + self.patterns = patterns + + def call(self, graph_module: 
fx.GraphModule) -> PassResult: # noqa: C901 + for pattern in self.patterns: + fused_partitions = find_sequential_partitions( + graph_module, + pattern.partition_types(), + ) + for fused_partition in fused_partitions: + anchors = pattern.get_anchors(graph_module, fused_partition) + if not anchors: + continue + if any(self.is_fused(p.nodes) for p in fused_partition): + continue + + for p in fused_partition: + self.mark_fused(p.nodes) + + dequants_inputs = [] + for node, idx in anchors.inputs: + if ( + node.args[idx].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + dequants_inputs.append(node.args[idx]) + dequants_weights = [] + for node, idx in anchors.weights: + if ( + node.args[idx].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + dequants_weights.append(node.args[idx]) + dequants_biases = [] + for node, idx, *_spec in anchors.biases: + if ( + node.args[idx].target + == torch.ops.quantized_decomposed.dequantize_per_tensor.default + ): + dequants_biases.append(node.args[idx]) + + inputs_inputs = [node.args[0] for node in dequants_inputs] + weights_inputs = [node.args[0] for node in dequants_weights] + bias_inputs = [node.args[0] for node in dequants_biases] + other_inputs = [node.args[idx] for node, idx in anchors.others] + + # The node is the first index of the list and first of the tuple + op_node = anchors.output[0][0] + + assert len(op_node.users) == 1 + quant_node = list(op_node.users.keys())[0] + + with graph_module.graph.inserting_after(op_node): + args = tuple( + inputs_inputs + weights_inputs + other_inputs + bias_inputs + ) + kwargs = {} + if isinstance(pattern, Conv1dPattern) or isinstance( + pattern, Conv2dPattern + ): + args, kwargs = get_args_and_kwargs_conv( + graph_module, + inputs_inputs, + dequants_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + quant_node, + op_node, + ) + elif isinstance(pattern, LinearPattern) or isinstance( + pattern, LinearFunctionalPattern + ): + 
args, kwargs = get_args_and_kwargs_linear( + graph_module, + inputs_inputs, + dequants_inputs, + weights_inputs, + dequants_weights, + bias_inputs, + quant_node, + ) + elif isinstance(pattern, LayerNormPattern) or isinstance( + pattern, LayerNormFunctionalPattern + ): + args, kwargs = get_args_and_kwargs_layer_norm( + graph_module, + inputs_inputs, + dequants_inputs, + other_inputs, + quant_node, + ) + elif isinstance(pattern, MatmulPattern): + args, kwargs = get_args_and_kwargs_matmul( + inputs_inputs, + dequants_inputs, + quant_node, + ) + elif isinstance(pattern, AddmmPattern): + # Transpose the weight tensor + transposed_weights = graph_module.graph.call_function( + torch.ops.aten.transpose.int, + (weights_inputs[0], 0, 1), + ) + # Call linear with transposed weight + args, kwargs = get_args_and_kwargs_linear( + graph_module, + inputs_inputs, + dequants_inputs, + [transposed_weights], + dequants_weights, + bias_inputs, + quant_node, + ) + elif isinstance(pattern, ReluPattern): + args, kwargs = get_args_and_kwargs_relu( + graph_module, + inputs_inputs, + dequants_inputs, + ) + fused = graph_module.graph.call_function( + pattern.replacement_op(), + args, + kwargs, + ) + fused.meta = quant_node.meta + quant_node.replace_all_uses_with(fused) + + legalize_graph(graph_module) + graph_module.graph.eliminate_dead_code() + # pyre-fixme[7]: Incompatible return type + graph_module.recompile() + + @classmethod + def is_fused(cls, nodes) -> bool: + return any(cls.__qualname__ in n.meta for n in nodes) + + @classmethod + def mark_fused(cls, nodes) -> bool: + for n in nodes: + # pyre-fixme[7]: Incompatible return type + n.meta["QuantFusion"] = True diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py new file mode 100644 index 00000000000..6df27982585 --- /dev/null +++ b/backends/cadence/aot/quantizer/patterns.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams

from torch import fx
from torch.ao.quantization.quantizer import (
    DerivedQuantizationSpec,
    SharedQuantizationSpec,
)


@dataclass
class PartitionAnchors:
    """
    All fields except output are lists of (node, args_index) pair, where node is from
    the given partition and node.args[args_index] is an input to the partition. Assumes
    a single output.

    Quantizer uses inputs, weights and biases for quantization annotation. The others
    field contains tensor inputs that aren't quantized, and the literals field is
    used for other types of input values as well as handling default parameters.
    """

    inputs: List[Tuple[fx.Node, int]] = field(default_factory=list)
    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
    biases: List[
        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
    ] = field(default_factory=list)
    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
        default_factory=list
    )


class QuantizationPattern(ABC):
    @abstractmethod
    def partition_types(self):
        """
        List of types to be passed to find_sequential_partitions.
        """
        pass

    @abstractmethod
    def get_anchors(self, gm, fused_partition) -> Optional[PartitionAnchors]:
        """
        Return the anchors of the matched partition; a falsy result makes the
        callers skip the partition.
        """
        pass

    @abstractmethod
    def replacement_op(self) -> Callable[..., Any]:
        """
        Operator (most likely a custom one) that this partition should be fused into in
        the backend. Refer to the QuantFusion pass for examples.
        """
        pass


def _derived_bias_qspec(
    node: fx.Node, act_idx: int, weight_idx: int
) -> DerivedQuantizationSpec:
    """
    Build the int32 per-tensor bias spec whose qparams are derived (through
    get_bias_qparams) from node.args[act_idx] and node.args[weight_idx].
    """
    return DerivedQuantizationSpec(
        derived_from=[
            (node.args[act_idx], node),
            (node.args[weight_idx], node),
        ],
        derive_qparams_fn=get_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_affine,
    )


class AddmmPattern(QuantizationPattern):
    def partition_types(self) -> List[Type[torch.nn.Module]]:
        return [torch.addmm]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        addmm_node = fused_partition[0].nodes[-1]

        # addmm(bias, input, weight): the bias is args[0], so the bias spec is
        # derived from args[1] and args[2].
        bias_qspec = _derived_bias_qspec(addmm_node, 1, 2)

        return PartitionAnchors(
            inputs=[(addmm_node, 1)],
            weights=[(addmm_node, 2)],
            biases=[(addmm_node, 0, bias_qspec)],
            output=[(addmm_node,)],
        )

    def replacement_op(self):
        # Use the explicit overload, like every other pattern in this file.
        return torch.ops.cadence.quantized_linear.default


class Conv1dPattern(QuantizationPattern):
    def partition_types(self) -> List[Type[torch.nn.Module]]:
        return [torch.nn.Conv1d]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        conv1d_node = fused_partition[0].nodes[-1]

        bias_qspec = _derived_bias_qspec(conv1d_node, 0, 1)

        # Keep bias empty if not supplied
        bias = []
        if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
            bias = [(conv1d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv1d_node, 0)],
            weights=[(conv1d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv1d_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_conv.default


class Conv2dPattern(QuantizationPattern):
    def partition_types(self) -> List[Type[torch.nn.Module]]:
        return [torch.nn.Conv2d]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        conv2d_node = fused_partition[0].nodes[-1]

        bias_qspec = _derived_bias_qspec(conv2d_node, 0, 1)

        # Keep bias empty if not supplied
        bias = []
        if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
            bias = [(conv2d_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(conv2d_node, 0)],
            weights=[(conv2d_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(conv2d_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_conv.default


class LayerNormPattern(QuantizationPattern):
    def partition_types(self):
        return [torch.nn.LayerNorm]

    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
        layer_norm_node = fused_partition[0].nodes[-1]

        # Weights and biases are used as fp32 by our kernel, so they are
        # passed in as others here along with the normalized shape.
        return PartitionAnchors(
            inputs=[(layer_norm_node, 0)],
            weights=[],
            biases=[],
            # Ordering: normalized_shape, weights, bias
            others=[(layer_norm_node, 1), (layer_norm_node, 2), (layer_norm_node, 3)],
            output=[(layer_norm_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_layer_norm.default


class LayerNormFunctionalPattern(QuantizationPattern):
    def partition_types(self):
        return [torch.nn.functional.layer_norm]

    def get_anchors(self, gm, fused_partition) -> PartitionAnchors:
        layer_norm_node = fused_partition[0].nodes[-1]

        others = [(layer_norm_node, 1)]

        # Add weights if supplied
        if len(layer_norm_node.args) > 2 and layer_norm_node.args[2]:
            others.append((layer_norm_node, 2))

        # Add bias if supplied
        if len(layer_norm_node.args) > 3 and layer_norm_node.args[3]:
            others.append((layer_norm_node, 3))

        # Weight and bias are passed in as others (i.e. not annotated for
        # quantization) along with the normalized shape.
        return PartitionAnchors(
            inputs=[(layer_norm_node, 0)],
            weights=[],
            biases=[],
            # Ordering: normalized_shape, weights, bias
            # pyre-fixme[6]: Incompatible parameter type
            others=others,
            output=[(layer_norm_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_layer_norm.default


class LinearPattern(QuantizationPattern):
    def partition_types(self) -> List[Type[torch.nn.Module]]:
        return [torch.nn.Linear]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        linear_node = fused_partition[0].nodes[-1]

        bias_qspec = _derived_bias_qspec(linear_node, 0, 1)

        # Keep bias empty if not supplied. An explicit None in args[2] also
        # counts as "not supplied", matching the other patterns in this file.
        bias = []
        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
            bias = [(linear_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(linear_node, 0)],
            weights=[(linear_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(linear_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_linear.default


class LinearFunctionalPattern(QuantizationPattern):
    def partition_types(self):
        return [torch.nn.functional.linear]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        linear_node = fused_partition[0].nodes[-1]

        bias_qspec = _derived_bias_qspec(linear_node, 0, 1)

        # Keep bias empty if not supplied
        bias = []
        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
            bias = [(linear_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(linear_node, 0)],
            weights=[(linear_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(linear_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_linear.default


class MatmulPattern(QuantizationPattern):
    def partition_types(self):
        return [torch.matmul]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        matmul_node = fused_partition[0].nodes[-1]

        # Both operands are annotated as activations; there is no weight or bias.
        return PartitionAnchors(
            inputs=[(matmul_node, 0), (matmul_node, 1)],
            weights=[],
            biases=[],
            output=[(matmul_node,)],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_matmul.default


class ReluPattern(QuantizationPattern):
    def partition_types(self) -> List[Type[torch.nn.Module]]:
        return [torch.nn.ReLU]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        relu_node = fused_partition[0].nodes[-1]

        # The output shares the quantization spec of the relu's input edge.
        return PartitionAnchors(
            inputs=[(relu_node, 0)],
            weights=[],
            biases=[],
            output=[
                (relu_node, SharedQuantizationSpec((relu_node.args[0], relu_node)))
            ],
        )

    def replacement_op(self):
        return torch.ops.cadence.quantized_relu.default


# --- patch residue, kept verbatim as comments (start of the next file) ---
# diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py
# new file mode 100644
# index 00000000000..79e6fb28149
# --- /dev/null
# +++ b/backends/cadence/aot/quantizer/quantizer.py
# @@ -0,0 +1,145 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    Conv1dPattern,
    Conv2dPattern,
    LayerNormFunctionalPattern,
    LayerNormPattern,
    LinearFunctionalPattern,
    LinearPattern,
    MatmulPattern,
    ReluPattern,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    is_annotated,
    no_outside_users,
)

from torch import fx

from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    QuantizationAnnotation,
    QuantizationConfig,
    QuantizationSpec,
)


# Activation spec: static per-tensor affine uint8, histogram observer.
act_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12),
)

# Weight spec: static per-tensor affine uint8, min/max observer.
wgt_qspec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    is_dynamic=False,
    observer_or_fake_quant_ctr=MinMaxObserver,
)

# No default bias spec; patterns that quantize a bias attach their own spec to
# the bias anchor.
bias_qspec = None


class CadenceGenericQuantizer(Quantizer):
    """Annotates every partition matched by a single pattern."""

    def __init__(self, pattern, quantization_config):
        super().__init__()
        self.pattern = pattern
        self.quantization_config = quantization_config

    @staticmethod
    def _annotate_inputs(anchor_list, spec) -> None:
        """
        Record the input qspec for every (node, idx[, custom_spec]) anchor.

        A custom spec carried by the anchor wins over `spec`. The node's
        existing annotation is reused when present so that the output spec
        set earlier is kept.
        """
        for node, idx, *custom_spec in anchor_list:
            annotation = node.meta.get(
                "quantization_annotation",
                QuantizationAnnotation(_annotated=True),
            )
            annotation.input_qspec_map[node.args[idx]] = (
                custom_spec[0] if custom_spec else spec
            )
            node.meta["quantization_annotation"] = annotation

    def annotate(self, model: fx.GraphModule) -> fx.GraphModule:
        """
        Annotate `model` in place for every partition matching the pattern.

        Partitions with outside users, without anchors, or with an already
        annotated anchor node are skipped.

        Returns:
            The same `model` object.
        """
        fused_partitions = find_sequential_partitions(
            model,
            self.pattern.partition_types(),
        )

        config = self.quantization_config
        input_act_qspec = config.input_activation
        weight_qspec = config.weight
        bias_spec = config.bias
        output_act_qspec = config.output_activation

        for fused_partition in fused_partitions:
            if not no_outside_users(fused_partition):
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
            if not anchors:
                continue

            anchor_nodes = [
                x[0]
                for x in anchors.inputs
                + anchors.weights
                + anchors.biases
                + anchors.output
            ]
            if is_annotated(anchor_nodes):
                continue

            # Outputs first: the input annotation below reuses this object when
            # the same node is both an output and an input anchor.
            for output, *custom_spec in anchors.output:
                output.meta["quantization_annotation"] = QuantizationAnnotation(
                    output_qspec=custom_spec[0] if custom_spec else output_act_qspec,
                    _annotated=True,
                )

            self._annotate_inputs(anchors.inputs, input_act_qspec)
            self._annotate_inputs(anchors.weights, weight_qspec)
            self._annotate_inputs(anchors.biases, bias_spec)

        return model

    def validate(self, model: fx.GraphModule) -> None:
        pass

    @classmethod
    def get_supported_operators(cls) -> List[OperatorConfig]:
        return []


class CadenceQuantizer(ComposableQuantizer):
    """Composes one CadenceGenericQuantizer per supported pattern."""

    def __init__(self):
        static_qconfig = QuantizationConfig(
            act_qspec,
            act_qspec,
            wgt_qspec,
            None,
        )
        super().__init__(
            [
                CadenceGenericQuantizer(AddmmPattern(), static_qconfig),
                CadenceGenericQuantizer(Conv1dPattern(), static_qconfig),
                CadenceGenericQuantizer(Conv2dPattern(), static_qconfig),
                CadenceGenericQuantizer(LayerNormPattern(), static_qconfig),
                CadenceGenericQuantizer(LayerNormFunctionalPattern(), static_qconfig),
                CadenceGenericQuantizer(LinearPattern(), static_qconfig),
                CadenceGenericQuantizer(LinearFunctionalPattern(), static_qconfig),
                CadenceGenericQuantizer(MatmulPattern(), static_qconfig),
                CadenceGenericQuantizer(ReluPattern(), static_qconfig),
            ]
        )


# --- patch residue, kept verbatim as comments (start of the next file) ---
# diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py
# new file mode 100644
# index 00000000000..21dac6b0b0f
# --- /dev/null
# +++ b/backends/cadence/aot/quantizer/utils.py
# @@ -0,0 +1,129 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from math import frexp, isclose, trunc +from typing import List, Tuple + +import torch +from torch import fx +from torch.ao.quantization import ObserverOrFakeQuantize + +from torch.fx import GraphModule + + +def quantize_tensor_multiplier( + requantize_scale_tensor: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given requantize_scale_tensor with values in the interval (0, 1), + produce a pair of tensors (out_multiplier, right_shift) where out_multiplier + is an int32 tensor representing fixed-point values in the interval [-1, 1), + and right_shift is an amount to shift right by, so that the floating-point + multiplication of some int32 input with each value of requantize_scale_tensor: + result = int32_value * requantize_scale_tensors[i] + is best approximated by the integer-arithmetic-only code: + result = RoundingRightShift(FixedPointMultiplication(int32_value, + out_multiplier[i]), right_shift[i]) + """ + + # This is identical to C++11 std::round(). The general python round rounds + # down, and C++ rounds away from zero. + def round_away_zero(f) -> int: + r = -0.5 if (f < 0) else 0.5 + return trunc(f + r) + + def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]: + significand, exponent = frexp(requantize_scale) + significand_q31 = int(round_away_zero(significand * (1 << 31))) + # Handle the special case when the real multiplier was so close to 1 + # that its fixed-point approximation was indistinguishable from 1. + # We handle this by dividing it by two, incrementing exponent by 1. + # the right shift amount. + if significand_q31 == (1 << 31): + significand_q31 //= 2 + exponent += 1 + + # Verify that the decomposition of requantize_scale into significand + # and exponent is correct. 
+ reconstructed = significand_q31 / (1 << 31) * pow(2, exponent) + assert isclose( + requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4 + ), "computation of significand and exponent from requantize_scale is not accurate" + + return (significand_q31, exponent) + + # Flatten the input scale tensor so that we can operate on individual values + orig_shape = requantize_scale_tensor.shape + flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32) + out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32) + right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32) + + # Iterate over the flattened scale tensor and compute the decomposition of + # each value in scale tensor into significand(out_multiplier) and + # exponent(right_shift) + for idx, scale in enumerate(flattened_tensor): + (si, ex) = quantize_scalar_multiplier(scale) + out_multiplier[idx], right_shift[idx] = si, ex + + # Reshape the tensors back to the original shape + out_multiplier = out_multiplier.reshape(orig_shape) + right_shift = right_shift.reshape(orig_shape) + + return (out_multiplier, right_shift) + + +def is_annotated(nodes: List[fx.Node]) -> bool: + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + +def no_outside_users(fused_partition) -> bool: + """ + Checks if each partition other than the last does not have any outside users. 
+ """ + for source_partition in fused_partition[:-1]: + if len(source_partition.output_nodes) != 1: + return False + if len(source_partition.output_nodes[0].users) != 1: + return False + return True + + +def create_zero_bias_int32( + graph_module: GraphModule, + weight_node: fx.Node, + bias_scale: float, +) -> fx.Node: + """ + Creates a zero bias tensor with the shape of weight[0] + """ + attr_node = getattr(graph_module, weight_node.target) + weight_shape = list(attr_node.shape) + bias_shape = weight_shape[0] + return graph_module.graph.call_function( + torch.ops.aten.full.default, + ([bias_shape], 0.0), + {"dtype": torch.int32}, + ) + + +def get_bias_qparams( + obs_or_fqs: List[ObserverOrFakeQuantize], +) -> Tuple[torch.Tensor, torch.Tensor]: + act_scale, _ = obs_or_fqs[0].calculate_qparams() + weight_scale, _ = obs_or_fqs[1].calculate_qparams() + bias_scale = act_scale * weight_scale + bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32) + return bias_scale, bias_zero_point + + +def get_conv_args(arg, first_val: int) -> List[fx.Node]: + return arg if len(arg) == 2 else [first_val, arg[0]] diff --git a/backends/cadence/cadence.cmake b/backends/cadence/cadence.cmake index 137b178ab92..25f241f205c 100644 --- a/backends/cadence/cadence.cmake +++ b/backends/cadence/cadence.cmake @@ -41,8 +41,8 @@ set(CMAKE_CROSSCOMPILING TRUE) set(CMAKE_C_COMPILER ${TOOLCHAIN_HOME}/bin/${CROSS_COMPILE_TARGET}-clang) set(CMAKE_CXX_COMPILER ${TOOLCHAIN_HOME}/bin/${CROSS_COMPILE_TARGET}-clang++) -set(CMAKE_C_FLAGS_INIT "-stdlib=libc++") -set(CMAKE_CXX_FLAGS_INIT "-stdlib=libc++") +set(CMAKE_C_FLAGS_INIT "-stdlib=libc++ -mtext-section-literals -mlongcalls") +set(CMAKE_CXX_FLAGS_INIT "-stdlib=libc++ -mtext-section-literals -mlongcalls") set(CMAKE_SYSROOT ${TOOLCHAIN_HOME}/${SYSROOT_TARGET}) set(CMAKE_LINKER ${TOOLCHAIN_HOME}/bin/xt-ld) add_link_options(-lm -stdlib=libc++ -Wl,--no-as-needed -static) diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt 
b/backends/cadence/hifi/kernels/CMakeLists.txt index 9d4d456d8bc..872e62fc970 100644 --- a/backends/cadence/hifi/kernels/CMakeLists.txt +++ b/backends/cadence/hifi/kernels/CMakeLists.txt @@ -20,3 +20,5 @@ target_include_directories( ${NN_LIB_BASE_DIR}/xa_nnlib/algo/ndsp/hifi4/include/ ${NXP_SDK_ROOT_DIR}/middleware/dsp/naturedsp/hifi4/include/ ) + +target_link_libraries(cadence_kernels PRIVATE xa_nnlib) \ No newline at end of file diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h index 13e0470b382..59bf4c41f65 100644 --- a/backends/cadence/hifi/kernels/kernels.h +++ b/backends/cadence/hifi/kernels/kernels.h @@ -12,6 +12,31 @@ #include "stddef.h" #include "xa_type_def.h" + +extern "C" WORD32 xa_nn_elm_quantize_f32_asym8s(WORD8 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp, + FLOAT32 out_scale, + WORD32 out_zero_bias, + WORD32 num_elm); + +/*extern "C" WORD32 xa_nn_elm_quantize_f32_asym8u(UWORD8 * __restrict__ p_out, + const FLOAT32 * __restrict__ p_inp, + FLOAT32 out_scale, + WORD32 out_zero_bias, + WORD32 num_elm); */ + +extern "C" WORD32 xa_nn_elm_dequantize_asym8s_f32(FLOAT32 * __restrict__ p_out, + const WORD8 * __restrict__ p_inp, + WORD32 inp_zero_bias, + FLOAT32 inp_scale, + WORD32 num_elm); + +/*extern "C" WORD32 xa_nn_elm_dequantize_asym8u_f32(FLOAT32 * __restrict__ p_out, + const UWORD8 * __restrict__ p_inp, + WORD32 inp_zero_bias, + FLOAT32 inp_scale, + WORD32 num_elm);*/ + namespace impl { namespace HiFi { namespace kernels { diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp index dcc4ace7898..3f683a6d713 100644 --- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp @@ -31,16 +31,21 @@ void dequantize_per_tensor_out( if (input.scalar_type() == ScalarType::Byte) { const uint8_t* input_data = input.const_data_ptr(); - 
impl::HiFi::kernels::dequantize( - out_data, input_data, scale, zero_point, numel); +#if 0 //NNLIB_OPT (not available in nnlib) + xa_nn_elm_dequantize_asym8u_f32(out_data, input_data, zero_point, scale, numel); +#else + impl::HiFi::kernels::dequantize(out_data, input_data, scale, zero_point, numel); +#endif } else if (input.scalar_type() == ScalarType::Char) { const int8_t* input_data = input.const_data_ptr(); - impl::HiFi::kernels::dequantize( - out_data, input_data, scale, zero_point, numel); +#if NNLIB_OPT + xa_nn_elm_dequantize_asym8s_f32(out_data, input_data, zero_point, scale, numel); +#else + impl::HiFi::kernels::dequantize(out_data, input_data, scale, zero_point, numel); +#endif } else if (input.scalar_type() == ScalarType::Int) { const int32_t* input_data = input.const_data_ptr(); - impl::HiFi::kernels::dequantize( - out_data, input_data, scale, zero_point, numel); + impl::HiFi::kernels::dequantize(out_data, input_data, scale, zero_point, numel); } else { ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type()); } diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp index ec186cc68e2..3137b91f6be 100644 --- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp +++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp @@ -33,16 +33,21 @@ void quantize_per_tensor_out( if (out.scalar_type() == ScalarType::Byte) { uint8_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); +#if 0 //NNLIB_OPT (not available in nnlib) + xa_nn_elm_quantize_f32_asym8u(out_data, input_data, scale, zero_point, numel); +#else + impl::HiFi::kernels::quantize(out_data, input_data, 1. / scale, zero_point, numel); +#endif } else if (out.scalar_type() == ScalarType::Char) { int8_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( - out_data, input_data, 1. 
/ scale, zero_point, numel); +#if NNLIB_OPT + xa_nn_elm_quantize_f32_asym8s(out_data, input_data, scale, zero_point, numel); +#else + impl::HiFi::kernels::quantize(out_data, input_data, 1. / scale, zero_point, numel); +#endif } else if (out.scalar_type() == ScalarType::Int) { int32_t* out_data = out.mutable_data_ptr(); - impl::HiFi::kernels::quantize( - out_data, input_data, 1. / scale, zero_point, numel); + impl::HiFi::kernels::quantize(out_data, input_data, 1. / scale, zero_point, numel); } else { ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type()); } diff --git a/backends/cadence/hifi/third-party/nnlib/CMakeLists.txt b/backends/cadence/hifi/third-party/nnlib/CMakeLists.txt new file mode 100644 index 00000000000..d8f2b4eb3d9 --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/CMakeLists.txt @@ -0,0 +1,33 @@ + +cmake_minimum_required(VERSION 3.10.0) +project(cadence_nnlib) + + +add_custom_target( nnlib_target ALL COMMAND + make install_nnlib -f makefile -C ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/nnlib-hifi4/xa_nnlib/build + OBJDIR=${CMAKE_CURRENT_BINARY_DIR}/obj + LIBDIR=${CMAKE_CURRENT_BINARY_DIR}/lib + -j8 ) + +add_library(xa_nnlib STATIC IMPORTED GLOBAL) +add_dependencies(xa_nnlib nnlib_target) + +message("NNLIB") +message("${CMAKE_CURRENT_BINARY_DIR}") + +set_property( + TARGET xa_nnlib + PROPERTY + IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/lib/xa_nnlib.a" +) + + + + + + + + + + + diff --git a/backends/cadence/hifi/third-party/nnlib/nnlib-hifi4 b/backends/cadence/hifi/third-party/nnlib/nnlib-hifi4 new file mode 160000 index 00000000000..6a9ea45e23e --- /dev/null +++ b/backends/cadence/hifi/third-party/nnlib/nnlib-hifi4 @@ -0,0 +1 @@ +Subproject commit 6a9ea45e23ef591fe207442df33a5ebe88bbe8de diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 59a48f123da..b7390bd42b2 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -187,6 +187,7 @@ def 
lower_module_and_test_output( ) exec_prog = delegated_program.to_executorch( exir.ExecutorchBackendConfig( + extract_delegate_segments=False, # For shared buffer, user must pass the memory address # which is allocated by RPC memory to executor runner. # Therefore, won't want to pre-allocate @@ -195,7 +196,7 @@ def lower_module_and_test_output( memory_planning_algo="greedy", alloc_graph_input=not self.shared_buffer, alloc_graph_output=not self.shared_buffer, - ) + ), ) ) diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 93e703dc030..86dfb74d069 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -38,11 +38,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: # Unary operators exir_ops.edge.aten.abs.default, exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.gelu.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten.sigmoid.default, - exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.tanh.default, # Matrix multiplication operators exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.mm.default, diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index c1f3f06b440..1c754599678 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -48,7 +48,7 @@ using BytesVector = const flatbuffers::Vector>*; using UIntVector = const flatbuffers::Vector*; -const uint8_t* getConstantDataPtr( +const uint8_t* get_constant_data_ptr( VkGraphPtr flatbuffer_graph, const int32_t buffer_idx, const uint8_t* constant_data) { @@ -111,19 +111,19 @@ GraphConfig get_graph_config(ArrayRef& compile_specs) { const size_t value_size = spec.value.nbytes; if (strcmp(spec.key, "storage_type_override") == 0) { ET_CHECK_MSG(value_size == sizeof(int32_t), "Unexpected 
value size!"); - int value_as_int = static_cast(GetUInt32LE(value_data)); + int value_as_int = static_cast(getUInt32LE(value_data)); api::StorageType storage_type = static_cast(value_as_int); - config.setStorageTypeOverride(storage_type); + config.set_storage_type_override(storage_type); } if (strcmp(spec.key, "memory_layout_override") == 0) { ET_CHECK_MSG(value_size == sizeof(uint32_t), "Unexpected value size!"); - uint32_t value_as_int = GetUInt32LE(value_data); + uint32_t value_as_int = getUInt32LE(value_data); api::GPUMemoryLayout memory_layout = static_cast(value_as_int); - config.setMemoryLayoutOverride(memory_layout); + config.set_memory_layout_override(memory_layout); } } return config; @@ -181,7 +181,7 @@ class GraphBuilder { ValueRef ref; if (tensor_fb->constant_id() >= 0) { - const uint8_t* tensor_data = getConstantDataPtr( + const uint8_t* tensor_data = get_constant_data_ptr( flatbuffer_, tensor_fb->constant_id(), constant_data_); ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data); @@ -399,7 +399,7 @@ class VulkanBackend final : public PyTorchBackendInterface { __ET_NODISCARD Error compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const { Result header = - VulkanDelegateHeader::Parse(buffer_pointer); + VulkanDelegateHeader::parse(buffer_pointer); const uint8_t* flatbuffer_data = nullptr; const uint8_t* constant_data = nullptr; diff --git a/backends/vulkan/runtime/VulkanDelegateHeader.cpp b/backends/vulkan/runtime/VulkanDelegateHeader.cpp index 4415996a648..a9a9fa849a7 100644 --- a/backends/vulkan/runtime/VulkanDelegateHeader.cpp +++ b/backends/vulkan/runtime/VulkanDelegateHeader.cpp @@ -39,7 +39,7 @@ constexpr ByteSlice kBytesSize = {22, 8}; } // namespace /// Interprets the 8 bytes at `data` as a little-endian uint64_t. 
-uint64_t GetUInt64LE(const uint8_t* data) { +uint64_t getUInt64LE(const uint8_t* data) { return (uint64_t)data[0] | ((uint64_t)data[1] << 8) | ((uint64_t)data[2] << 16) | ((uint64_t)data[3] << 24) | ((uint64_t)data[4] << 32) | ((uint64_t)data[5] << 40) | @@ -47,13 +47,13 @@ uint64_t GetUInt64LE(const uint8_t* data) { } /// Interprets the 4 bytes at `data` as a little-endian uint32_t. -uint32_t GetUInt32LE(const uint8_t* data) { +uint32_t getUInt32LE(const uint8_t* data) { return (uint32_t)data[0] | ((uint32_t)data[1] << 8) | ((uint32_t)data[2] << 16) | ((uint32_t)data[3] << 24); } /// Interprets the 2 bytes at `data` as a little-endian uint32_t. -uint32_t GetUInt16LE(const uint8_t* data) { +uint32_t getUInt16LE(const uint8_t* data) { return (uint32_t)data[0] | ((uint32_t)data[1] << 8); } @@ -77,7 +77,7 @@ bool VulkanDelegateHeader::is_valid() const { return true; } -Result VulkanDelegateHeader::Parse(const void* data) { +Result VulkanDelegateHeader::parse(const void* data) { const uint8_t* header_data = (const uint8_t*)data; const uint8_t* magic_start = header_data + kMagic.offset; @@ -86,11 +86,11 @@ Result VulkanDelegateHeader::Parse(const void* data) { } VulkanDelegateHeader header = VulkanDelegateHeader{ - GetUInt16LE(header_data + kHeaderSize.offset), - GetUInt32LE(header_data + kFlatbufferOffset.offset), - GetUInt32LE(header_data + kFlatbufferSize.offset), - GetUInt32LE(header_data + kBytesOffset.offset), - GetUInt64LE(header_data + kBytesSize.offset), + getUInt16LE(header_data + kHeaderSize.offset), + getUInt32LE(header_data + kFlatbufferOffset.offset), + getUInt32LE(header_data + kFlatbufferSize.offset), + getUInt32LE(header_data + kBytesOffset.offset), + getUInt64LE(header_data + kBytesSize.offset), }; if (!header.is_valid()) { diff --git a/backends/vulkan/runtime/VulkanDelegateHeader.h b/backends/vulkan/runtime/VulkanDelegateHeader.h index f9757ef4c2a..c5e8859743a 100644 --- a/backends/vulkan/runtime/VulkanDelegateHeader.h +++ 
b/backends/vulkan/runtime/VulkanDelegateHeader.h @@ -15,14 +15,14 @@ namespace executor { namespace vulkan { // Byte decoding utilities -uint64_t GetUInt64LE(const uint8_t* data); -uint32_t GetUInt32LE(const uint8_t* data); -uint32_t GetUInt16LE(const uint8_t* data); +uint64_t getUInt64LE(const uint8_t* data); +uint32_t getUInt32LE(const uint8_t* data); +uint32_t getUInt16LE(const uint8_t* data); struct VulkanDelegateHeader { bool is_valid() const; - static Result Parse(const void* data); + static Result parse(const void* data); uint32_t header_size; uint32_t flatbuffer_offset; diff --git a/backends/vulkan/runtime/api/Adapter.cpp b/backends/vulkan/runtime/api/Adapter.cpp index 5db2642e3ec..932678f18fc 100644 --- a/backends/vulkan/runtime/api/Adapter.cpp +++ b/backends/vulkan/runtime/api/Adapter.cpp @@ -292,7 +292,8 @@ DeviceHandle::~DeviceHandle() { Adapter::Adapter( VkInstance instance, PhysicalDevice physical_device, - const uint32_t num_queues) + const uint32_t num_queues, + const std::string& cache_data_path) : queue_usage_mutex_{}, physical_device_(std::move(physical_device)), queues_{}, @@ -307,7 +308,7 @@ Adapter::Adapter( shader_layout_cache_(device_.handle_), shader_cache_(device_.handle_), pipeline_layout_cache_(device_.handle_), - compute_pipeline_cache_(device_.handle_), + compute_pipeline_cache_(device_.handle_, cache_data_path), sampler_cache_(device_.handle_), vma_(instance_, physical_device_.handle, device_.handle_) {} diff --git a/backends/vulkan/runtime/api/Adapter.h b/backends/vulkan/runtime/api/Adapter.h index b038aea9fa8..fcbba281642 100644 --- a/backends/vulkan/runtime/api/Adapter.h +++ b/backends/vulkan/runtime/api/Adapter.h @@ -16,6 +16,8 @@ #include #include +#include + #include #include #include @@ -99,7 +101,8 @@ class Adapter final { explicit Adapter( VkInstance instance, PhysicalDevice physical_device, - const uint32_t num_queues); + const uint32_t num_queues, + const std::string& cache_data_path); Adapter(const Adapter&) = delete; 
Adapter& operator=(const Adapter&) = delete; @@ -136,7 +139,7 @@ class Adapter final { ComputePipelineCache compute_pipeline_cache_; // Memory Management SamplerCache sampler_cache_; - MemoryAllocator vma_; + Allocator vma_; public: // Physical Device metadata @@ -194,7 +197,7 @@ class Adapter final { return sampler_cache_; } - inline MemoryAllocator& vma() { + inline Allocator& vma() { return vma_; } diff --git a/backends/vulkan/runtime/api/Command.cpp b/backends/vulkan/runtime/api/Command.cpp index 2ddb4ab15aa..9c70cfa60b2 100644 --- a/backends/vulkan/runtime/api/Command.cpp +++ b/backends/vulkan/runtime/api/Command.cpp @@ -133,16 +133,14 @@ void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) { if (!pipeline_barrier.buffer_barrier_handles.empty()) { pipeline_barrier.buffer_barrier_handles.clear(); } - for (const api::BufferMemoryBarrier& memory_barrier : - pipeline_barrier.buffers) { + for (const BufferMemoryBarrier& memory_barrier : pipeline_barrier.buffers) { pipeline_barrier.buffer_barrier_handles.push_back(memory_barrier.handle); } if (!pipeline_barrier.image_barrier_handles.empty()) { pipeline_barrier.image_barrier_handles.clear(); } - for (const api::ImageMemoryBarrier& memory_barrier : - pipeline_barrier.images) { + for (const ImageMemoryBarrier& memory_barrier : pipeline_barrier.images) { pipeline_barrier.image_barrier_handles.push_back(memory_barrier.handle); } vkCmdPipelineBarrier( @@ -185,11 +183,11 @@ void CommandBuffer::dispatch(const utils::uvec3& global_workgroup_size) { } void CommandBuffer::copy_buffer_to_buffer( - const api::VulkanBuffer& source, - const api::VulkanBuffer& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const VulkanBuffer& source, + const VulkanBuffer& destination, + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { VK_CHECK_COND( state_ == 
CommandBuffer::State::BARRIERS_INSERTED, "Vulkan CommandBuffer: called copy_buffer_to_buffer() on a command buffer whose state " @@ -208,11 +206,11 @@ void CommandBuffer::copy_buffer_to_buffer( } void CommandBuffer::copy_texture_to_texture( - const api::VulkanImage& source, - const api::VulkanImage& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const VulkanImage& source, + const VulkanImage& destination, + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { VK_CHECK_COND( state_ == CommandBuffer::State::BARRIERS_INSERTED, "Vulkan CommandBuffer: called copy_texture_to_texture() on a command buffer whose state " @@ -253,11 +251,11 @@ void CommandBuffer::copy_texture_to_texture( } void CommandBuffer::copy_texture_to_buffer( - const api::VulkanImage& source, - const api::VulkanBuffer& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const VulkanImage& source, + const VulkanBuffer& destination, + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { VK_CHECK_COND( state_ == CommandBuffer::State::BARRIERS_INSERTED, "Vulkan CommandBuffer: called copy_texture_to_buffer() on a command buffer whose state " @@ -291,11 +289,11 @@ void CommandBuffer::copy_texture_to_buffer( } void CommandBuffer::copy_buffer_to_texture( - const api::VulkanBuffer& source, - const api::VulkanImage& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const VulkanBuffer& source, + const VulkanImage& destination, + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { VK_CHECK_COND( state_ == CommandBuffer::State::BARRIERS_INSERTED, "Vulkan CommandBuffer: called copy_buffer_to_texture() on a command buffer 
whose state " @@ -392,7 +390,7 @@ CommandPool::CommandPool( VK_CHECK(vkCreateCommandPool(device_, &create_info, nullptr, &pool_)); // Pre-allocate some command buffers - allocate_new_batch(config_.cmdPoolInitialSize); + allocate_new_batch(config_.cmd_pool_initial_size); } CommandPool::~CommandPool() { @@ -406,7 +404,7 @@ CommandBuffer CommandPool::get_new_cmd(bool reusable) { std::lock_guard lock(mutex_); // No-ops if there are command buffers available - allocate_new_batch(config_.cmdPoolBatchSize); + allocate_new_batch(config_.cmd_pool_batch_size); VkCommandBuffer handle = buffers_[in_use_]; diff --git a/backends/vulkan/runtime/api/Command.h b/backends/vulkan/runtime/api/Command.h index 904631b2ac4..ff009de8fc0 100644 --- a/backends/vulkan/runtime/api/Command.h +++ b/backends/vulkan/runtime/api/Command.h @@ -14,10 +14,12 @@ #include #include -#include #include #include +#include +#include + namespace vkcompute { namespace api { @@ -92,32 +94,32 @@ class CommandBuffer final { void dispatch(const utils::uvec3&); void copy_buffer_to_buffer( - const api::VulkanBuffer&, - const api::VulkanBuffer&, - const api::utils::uvec3&, - const api::utils::uvec3&, - const api::utils::uvec3&); + const VulkanBuffer&, + const VulkanBuffer&, + const utils::uvec3&, + const utils::uvec3&, + const utils::uvec3&); void copy_texture_to_texture( - const api::VulkanImage&, - const api::VulkanImage&, - const api::utils::uvec3&, - const api::utils::uvec3&, - const api::utils::uvec3&); + const VulkanImage&, + const VulkanImage&, + const utils::uvec3&, + const utils::uvec3&, + const utils::uvec3&); void copy_texture_to_buffer( - const api::VulkanImage&, - const api::VulkanBuffer&, - const api::utils::uvec3&, - const api::utils::uvec3&, - const api::utils::uvec3&); + const VulkanImage&, + const VulkanBuffer&, + const utils::uvec3&, + const utils::uvec3&, + const utils::uvec3&); void copy_buffer_to_texture( - const api::VulkanBuffer&, - const api::VulkanImage&, - const api::utils::uvec3&, - const 
api::utils::uvec3&, - const api::utils::uvec3&); + const VulkanBuffer&, + const VulkanImage&, + const utils::uvec3&, + const utils::uvec3&, + const utils::uvec3&); void write_timestamp(VkQueryPool, const uint32_t) const; void reset_querypool(VkQueryPool, const uint32_t, const uint32_t) const; @@ -130,8 +132,8 @@ class CommandBuffer final { }; struct CommandPoolConfig final { - uint32_t cmdPoolInitialSize; - uint32_t cmdPoolBatchSize; + uint32_t cmd_pool_initial_size; + uint32_t cmd_pool_batch_size; }; class CommandPool final { diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp index 9a43cf455d6..99d9ab0aa5d 100644 --- a/backends/vulkan/runtime/api/Context.cpp +++ b/backends/vulkan/runtime/api/Context.cpp @@ -30,12 +30,12 @@ Context::Context(size_t adapter_i, const ContextConfig& config) device_(adapter_p_->device_handle()), queue_(adapter_p_->request_queue()), // Resource pools - command_pool_(device_, queue_.family_index, config_.cmdPoolConfig), - descriptor_pool_(device_, config_.descriptorPoolConfig), + command_pool_(device_, queue_.family_index, config_.cmd_pool_config), + descriptor_pool_(device_, config_.descriptor_pool_config), fences_(device_), // Diagnostics #ifdef USE_VULKAN_GPU_DIAGNOSTICS - querypool_(config_.queryPoolConfig, adapter_p_), + querypool_(config_.query_pool_config, adapter_p_), #endif /* USE_VULKAN_GPU_DIAGNOSTICS */ // Command buffer submission cmd_mutex_{}, @@ -143,7 +143,7 @@ bool available() { Context* context() { static const std::unique_ptr context([]() -> Context* { try { - const uint32_t submit_frequency = 16u; + const uint32_t cmd_submit_frequency = 16u; const CommandPoolConfig cmd_config{ 32u, // cmdPoolInitialSize @@ -165,10 +165,10 @@ Context* context() { }; const ContextConfig config{ - submit_frequency, // cmdSubmitFrequency - cmd_config, // cmdPoolConfig - descriptor_pool_config, // descriptorPoolConfig - query_pool_config, // queryPoolConfig + cmd_submit_frequency, + cmd_config, + 
descriptor_pool_config, + query_pool_config, }; return new Context(runtime()->default_adapter_i(), config); @@ -236,7 +236,7 @@ UniformParamsBuffer& UniformParamsBuffer::operator=( } ParamsBindList::ParamsBindList( - std::initializer_list init_list) { + std::initializer_list init_list) { bind_infos.resize(init_list.size()); std::copy(init_list.begin(), init_list.end(), bind_infos.begin()); } diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h index d79344dce8d..f8bc923f394 100644 --- a/backends/vulkan/runtime/api/Context.h +++ b/backends/vulkan/runtime/api/Context.h @@ -15,21 +15,23 @@ #include #include #include +#include #include #include -#include #include #include #include +#include + namespace vkcompute { namespace api { struct ContextConfig final { - uint32_t cmdSubmitFrequency; - CommandPoolConfig cmdPoolConfig; - DescriptorPoolConfig descriptorPoolConfig; - QueryPoolConfig queryPoolConfig; + uint32_t cmd_submit_frequency; + CommandPoolConfig cmd_pool_config; + DescriptorPoolConfig descriptor_pool_config; + QueryPoolConfig query_pool_config; }; // @@ -194,9 +196,9 @@ class Context final { PipelineBarrier&, const S&, const D&, - const api::utils::uvec3&, - const api::utils::uvec3&, - const api::utils::uvec3&, + const utils::uvec3&, + const utils::uvec3&, + const utils::uvec3&, VkFence fence_handle); template @@ -265,9 +267,9 @@ class UniformParamsBuffer final { }; struct ParamsBindList final { - std::vector bind_infos; + std::vector bind_infos; - ParamsBindList(std::initializer_list init_list); + ParamsBindList(std::initializer_list init_list); }; class StorageBuffer final { @@ -374,18 +376,18 @@ inline void record_copy( CommandBuffer& cmd, const S& source, const D& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) = delete; + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) = delete; template 
<> inline void record_copy( CommandBuffer& cmd, const VulkanBuffer& source, const VulkanBuffer& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { cmd.copy_buffer_to_buffer( source, destination, copy_range, src_offset, dst_offset); } @@ -395,9 +397,9 @@ inline void record_copy( CommandBuffer& cmd, const VulkanImage& source, const VulkanImage& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { cmd.copy_texture_to_texture( source, destination, copy_range, src_offset, dst_offset); } @@ -407,9 +409,9 @@ inline void record_copy( CommandBuffer& cmd, const VulkanImage& source, const VulkanBuffer& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { cmd.copy_texture_to_buffer( source, destination, copy_range, src_offset, dst_offset); } @@ -419,9 +421,9 @@ inline void record_copy( CommandBuffer& cmd, const VulkanBuffer& source, const VulkanImage& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset) { + const utils::uvec3& copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset) { cmd.copy_buffer_to_texture( source, destination, copy_range, src_offset, dst_offset); } @@ -438,9 +440,9 @@ inline bool Context::submit_copy( PipelineBarrier& pipeline_barrier, const S& source, const D& destination, - const api::utils::uvec3& copy_range, - const api::utils::uvec3& src_offset, - const api::utils::uvec3& dst_offset, + const utils::uvec3& 
copy_range, + const utils::uvec3& src_offset, + const utils::uvec3& dst_offset, VkFence fence_handle) { // If any of the provided arguments does not have memory associated with it, // then exit early as there is no work to be done. However, if a fence has @@ -485,7 +487,7 @@ inline bool Context::submit_copy( submit_count_++; if (fence_handle != VK_NULL_HANDLE || - submit_count_ >= config_.cmdSubmitFrequency) { + submit_count_ >= config_.cmd_submit_frequency) { submit_cmd_to_gpu(fence_handle); return true; } @@ -568,7 +570,7 @@ inline bool Context::submit_compute_job( submit_count_++; if (fence_handle != VK_NULL_HANDLE || - submit_count_ >= config_.cmdSubmitFrequency) { + submit_count_ >= config_.cmd_submit_frequency) { submit_cmd_to_gpu(fence_handle); return true; } diff --git a/backends/vulkan/runtime/api/Descriptor.cpp b/backends/vulkan/runtime/api/Descriptor.cpp index 572cc674981..99ca6978594 100644 --- a/backends/vulkan/runtime/api/Descriptor.cpp +++ b/backends/vulkan/runtime/api/Descriptor.cpp @@ -235,7 +235,7 @@ DescriptorPool::DescriptorPool( config_(config), mutex_{}, piles_{} { - if (config.descriptorPoolMaxSets > 0) { + if (config.descriptor_pool_max_sets > 0) { init(config); } } @@ -257,19 +257,19 @@ void DescriptorPool::init(const DescriptorPoolConfig& config) { std::vector type_sizes{ { VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, - config_.descriptorUniformBufferCount, + config_.descriptor_uniform_buffer_count, }, { VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, - config_.descriptorStorageBufferCount, + config_.descriptor_storage_buffer_count, }, { VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, - config_.descriptorCombinedSamplerCount, + config_.descriptor_combined_sampler_count, }, { VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, - config_.descriptorStorageBufferCount, + config_.descriptor_storage_buffer_count, }, }; @@ -277,7 +277,7 @@ void DescriptorPool::init(const DescriptorPoolConfig& config) { VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, // sType nullptr, // pNext 0u, // flags - 
config_.descriptorPoolMaxSets, // maxSets + config_.descriptor_pool_max_sets, // maxSets static_cast(type_sizes.size()), // poolSizeCounts type_sizes.data(), // pPoolSizes }; @@ -297,7 +297,7 @@ DescriptorSet DescriptorPool::get_descriptor_set( .insert({ set_layout, DescriptorSetPile( - config_.descriptorPileSizes, set_layout, device_, pool_), + config_.descriptor_pile_sizes, set_layout, device_, pool_), }) .first; } diff --git a/backends/vulkan/runtime/api/Descriptor.h b/backends/vulkan/runtime/api/Descriptor.h index 0b6b1cd885a..e1b40fbd173 100644 --- a/backends/vulkan/runtime/api/Descriptor.h +++ b/backends/vulkan/runtime/api/Descriptor.h @@ -12,9 +12,11 @@ #include -#include #include +#include +#include + #include namespace vkcompute { @@ -107,14 +109,14 @@ class DescriptorSetPile final { struct DescriptorPoolConfig final { // Overall Pool capacity - uint32_t descriptorPoolMaxSets; + uint32_t descriptor_pool_max_sets; // DescriptorCounts by type - uint32_t descriptorUniformBufferCount; - uint32_t descriptorStorageBufferCount; - uint32_t descriptorCombinedSamplerCount; - uint32_t descriptorStorageImageCount; + uint32_t descriptor_uniform_buffer_count; + uint32_t descriptor_storage_buffer_count; + uint32_t descriptor_combined_sampler_count; + uint32_t descriptor_storage_image_count; // Pile size for pre-allocating descriptor sets - uint32_t descriptorPileSizes; + uint32_t descriptor_pile_sizes; }; class DescriptorPool final { diff --git a/backends/vulkan/runtime/api/Exception.h b/backends/vulkan/runtime/api/Exception.h index 28ee096984b..05dc10ee953 100644 --- a/backends/vulkan/runtime/api/Exception.h +++ b/backends/vulkan/runtime/api/Exception.h @@ -9,14 +9,15 @@ #pragma once // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName +#include + +#include + #include #include #include #include -#include -#include - #define VK_CHECK(function) \ do { \ const VkResult result = (function); \ diff --git a/backends/vulkan/runtime/api/Fence.cpp 
b/backends/vulkan/runtime/api/Fence.cpp new file mode 100644 index 00000000000..6253a5e13e1 --- /dev/null +++ b/backends/vulkan/runtime/api/Fence.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace vkcompute { +namespace api { + +VulkanFence::VulkanFence() + : device_(VK_NULL_HANDLE), handle_(VK_NULL_HANDLE), waiting_(false) {} + +VulkanFence::VulkanFence(VkDevice device) + : device_(device), handle_(VK_NULL_HANDLE), waiting_(VK_NULL_HANDLE) { + const VkFenceCreateInfo fence_create_info{ + VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + }; + + VK_CHECK(vkCreateFence(device_, &fence_create_info, nullptr, &handle_)); +} + +VulkanFence::VulkanFence(VulkanFence&& other) noexcept + : device_(other.device_), handle_(other.handle_), waiting_(other.waiting_) { + other.handle_ = VK_NULL_HANDLE; + other.waiting_ = false; +} + +VulkanFence& VulkanFence::operator=(VulkanFence&& other) noexcept { + device_ = other.device_; + handle_ = other.handle_; + waiting_ = other.waiting_; + + other.device_ = VK_NULL_HANDLE; + other.handle_ = VK_NULL_HANDLE; + other.waiting_ = false; + + return *this; +} + +VulkanFence::~VulkanFence() { + if (VK_NULL_HANDLE == handle_) { + return; + } + vkDestroyFence(device_, handle_, nullptr); +} + +void VulkanFence::wait() { + // if get_submit_handle() has not been called, then this will no-op + if (waiting_) { + VkResult fence_status = VK_NOT_READY; + // Run the wait in a loop to keep the CPU hot. A single call to + // vkWaitForFences with no timeout may cause the calling thread to be + // scheduled out. 
+ do { + // The timeout (last) arg is in units of ns + fence_status = vkWaitForFences(device_, 1u, &handle_, VK_TRUE, 100000); + + VK_CHECK_COND( + fence_status != VK_ERROR_DEVICE_LOST, + "Vulkan Fence: Device lost while waiting for fence!"); + } while (fence_status != VK_SUCCESS); + + VK_CHECK(vkResetFences(device_, 1u, &handle_)); + + waiting_ = false; + } +} + +} // namespace api +} // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Fence.h b/backends/vulkan/runtime/api/Fence.h new file mode 100644 index 00000000000..613a24aaec5 --- /dev/null +++ b/backends/vulkan/runtime/api/Fence.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName + +#include + +#include + +#include + +namespace vkcompute { +namespace api { + +class VulkanFence final { + public: + // TODO: This is required for the lazy allocation pattern in api/Tensor. + // It will be disabled pending future refactors. + explicit VulkanFence(); + + explicit VulkanFence(VkDevice); + + VulkanFence(const VulkanFence&) = delete; + VulkanFence& operator=(const VulkanFence&) = delete; + + VulkanFence(VulkanFence&&) noexcept; + VulkanFence& operator=(VulkanFence&&) noexcept; + + ~VulkanFence(); + + private: + VkDevice device_; + VkFence handle_; + bool waiting_; + + public: + // Used to get the handle for a queue submission. 
+ VkFence get_submit_handle() { + if (handle_ != VK_NULL_HANDLE) { + // Indicate we are now waiting for this fence to be signaled + waiting_ = true; + } + return handle_; + } + + VkFence handle() { + return handle_; + } + + // Trigger a synchronous wait for the fence to be signaled + void wait(); + + bool waiting() const { + return waiting_; + } + + operator bool() const { + return (VK_NULL_HANDLE != handle_); + } +}; + +// A pool to track created Fences and reuse ones that are available. +// Only intended to be modified by one thread at a time. +struct FencePool final { + VkDevice device_; + + std::stack pool_; + + explicit FencePool(VkDevice device) : device_(device), pool_{} {} + + // Returns an rvalue reference to a fence, so that it can be moved + inline VulkanFence get_fence() { + if (pool_.empty()) { + VulkanFence new_fence = VulkanFence(device_); + return new_fence; + } + + VulkanFence top_fence = std::move(pool_.top()); + pool_.pop(); + + return top_fence; + } + + // Marks the fence as available + inline void return_fence(VulkanFence& fence) { + pool_.push(std::move(fence)); + } +}; + +} // namespace api +} // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Pipeline.cpp b/backends/vulkan/runtime/api/Pipeline.cpp index f4be0039e67..a6bff47cac1 100644 --- a/backends/vulkan/runtime/api/Pipeline.cpp +++ b/backends/vulkan/runtime/api/Pipeline.cpp @@ -8,6 +8,8 @@ #include +#include + namespace vkcompute { namespace api { @@ -137,7 +139,7 @@ uint32_t SpecVar::val_size() const { } uint32_t SpecVar::val_offset() const { - return api::utils::safe_downcast(offsetof(SpecVar, value)); + return utils::safe_downcast(offsetof(SpecVar, value)); } bool operator==(const SpecVar& lhs, const SpecVar& rhs) { @@ -358,17 +360,24 @@ void PipelineLayoutCache::purge() { // ComputePipelineCache // -ComputePipelineCache::ComputePipelineCache(VkDevice device) +ComputePipelineCache::ComputePipelineCache( + VkDevice device, + const std::string& cache_data_path) : 
cache_mutex_{}, device_(device), pipeline_cache_{VK_NULL_HANDLE}, - cache_{} { - const VkPipelineCacheCreateInfo pipeline_cache_create_info{ + cache_{}, + cache_data_path_(cache_data_path) { + VkPipelineCacheCreateInfo pipeline_cache_create_info{}; + + auto buffer = load_cache(); + + pipeline_cache_create_info = { VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, // sType nullptr, // pNext 0u, // flags - 0u, // initialDataSize - nullptr, // pInitialData + buffer.size(), // initialDataSize + buffer.data(), // pInitialData }; VK_CHECK(vkCreatePipelineCache( @@ -392,6 +401,9 @@ ComputePipelineCache::~ComputePipelineCache() { if (VK_NULL_HANDLE == pipeline_cache_) { return; } + + save_cache(); + vkDestroyPipelineCache(device_, pipeline_cache_, nullptr); pipeline_cache_ = VK_NULL_HANDLE; } @@ -416,5 +428,37 @@ void ComputePipelineCache::purge() { cache_.clear(); } +std::vector ComputePipelineCache::load_cache() { + // Return if path is not specified; this means the optimization is disabled + if (cache_data_path_.empty()) { + return {}; + } + + // Return if file doesn't exist; this is expected on the first model-load + std::ifstream file(cache_data_path_, std::ios::binary | std::ios::ate); + if (file.fail()) { + return {}; + } + + auto size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(size); + file.read(buffer.data(), size); + + return buffer; +} + +void ComputePipelineCache::save_cache() { + size_t size{}; + vkGetPipelineCacheData(device_, pipeline_cache_, &size, nullptr); + + std::vector buffer(size); + vkGetPipelineCacheData(device_, pipeline_cache_, &size, buffer.data()); + + std::ofstream file(cache_data_path_, std::ios::binary); + file.write(buffer.data(), buffer.size()); +} + } // namespace api } // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Pipeline.h b/backends/vulkan/runtime/api/Pipeline.h index b8c16efd910..35b3b6275b4 100644 --- a/backends/vulkan/runtime/api/Pipeline.h +++ b/backends/vulkan/runtime/api/Pipeline.h @@ 
-12,9 +12,11 @@ #include -#include #include +#include +#include + #include #include @@ -69,7 +71,7 @@ class SpecVarList final { } inline uint32_t size() const { - return api::utils::safe_downcast(vars.size()); + return utils::safe_downcast(vars.size()); } inline uint32_t data_nbytes() const { @@ -214,7 +216,9 @@ class PipelineLayoutCache final { class ComputePipelineCache final { public: - explicit ComputePipelineCache(VkDevice device); + explicit ComputePipelineCache( + VkDevice device, + const std::string& cache_data_path); ComputePipelineCache(const ComputePipelineCache&) = delete; ComputePipelineCache& operator=(const ComputePipelineCache&) = delete; @@ -264,6 +268,9 @@ class ComputePipelineCache final { }; private: + std::vector load_cache(); + void save_cache(); + // Multiple threads could potentially be adding entries into the cache, so use // a mutex to manage access std::mutex cache_mutex_; @@ -271,6 +278,7 @@ class ComputePipelineCache final { VkDevice device_; VkPipelineCache pipeline_cache_; std::unordered_map cache_; + const std::string cache_data_path_; public: VkPipeline retrieve(const Key&); diff --git a/backends/vulkan/runtime/api/QueryPool.cpp b/backends/vulkan/runtime/api/QueryPool.cpp index b908c6e53b4..5deff1d4c4c 100644 --- a/backends/vulkan/runtime/api/QueryPool.cpp +++ b/backends/vulkan/runtime/api/QueryPool.cpp @@ -42,13 +42,13 @@ QueryPool::QueryPool(const QueryPoolConfig& config, const Adapter* adapter_p) nullptr, // pNext 0u, // flags VK_QUERY_TYPE_TIMESTAMP, // queryType - config_.maxQueryCount, // queryCount + config_.max_query_count, // queryCount 0u, // pipelineStatistics }; VK_CHECK(vkCreateQueryPool(device_, &info, nullptr, &querypool_)); - shader_log().reserve(config_.initialReserveSize); + shader_log().reserve(config_.initial_reserve_size); VK_CHECK_COND(adapter_p, "Valid GPU device must be created for QueryPool"); ns_per_tick_ = std::lround(adapter_p->timestamp_period()); @@ -79,16 +79,16 @@ void QueryPool::reset(const 
CommandBuffer& cmd) { previous_shader_count_ += shader_log().size(); in_use_ = 0u; shader_logs_.emplace_back(); - shader_log().reserve(config_.initialReserveSize); + shader_log().reserve(config_.initial_reserve_size); results_pending_ = false; } size_t QueryPool::write_timestamp(const CommandBuffer& cmd) { VK_CHECK_COND( - in_use_ < config_.maxQueryCount, + in_use_ < config_.max_query_count, "Vulkan QueryPool: Exceeded the maximum number of queries " "allowed by the queryPool (", - config_.maxQueryCount, + config_.max_query_count, ")!"); cmd.write_timestamp(querypool_, in_use_); diff --git a/backends/vulkan/runtime/api/QueryPool.h b/backends/vulkan/runtime/api/QueryPool.h index 9249942df08..a0c6d9b14f1 100644 --- a/backends/vulkan/runtime/api/QueryPool.h +++ b/backends/vulkan/runtime/api/QueryPool.h @@ -22,8 +22,8 @@ namespace vkcompute { namespace api { struct QueryPoolConfig final { - uint32_t maxQueryCount; - uint32_t initialReserveSize; + uint32_t max_query_count; + uint32_t initial_reserve_size; }; struct ShaderDuration final { diff --git a/backends/vulkan/runtime/api/Resource.cpp b/backends/vulkan/runtime/api/Resource.cpp deleted file mode 100644 index d15dfc05275..00000000000 --- a/backends/vulkan/runtime/api/Resource.cpp +++ /dev/null @@ -1,838 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include - -#define PRINT_FIELD(struct, field) #field << ": " << struct.field << std::endl - -std::ostream& operator<<(std::ostream& out, VmaTotalStatistics stats) { - VmaDetailedStatistics total_stats = stats.total; - out << "VmaTotalStatistics: " << std::endl; - out << " " << PRINT_FIELD(total_stats.statistics, blockCount); - out << " " << PRINT_FIELD(total_stats.statistics, allocationCount); - out << " " << PRINT_FIELD(total_stats.statistics, blockBytes); - out << " " << PRINT_FIELD(total_stats.statistics, allocationBytes); - return out; -} - -#undef PRINT_FIELD - -namespace vkcompute { -namespace api { - -// -// MemoryBarrier -// - -MemoryBarrier::MemoryBarrier( - const VkAccessFlags src_access_flags, - const VkAccessFlags dst_access_flags) - : handle{ - VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType - nullptr, // pNext - src_access_flags, // srcAccessMask - dst_access_flags, // dstAccessMask - } {} - -// -// MemoryAllocation -// - -MemoryAllocation::MemoryAllocation() - : memory_requirements{}, - create_info{}, - allocator(VK_NULL_HANDLE), - allocation(VK_NULL_HANDLE) {} - -MemoryAllocation::MemoryAllocation( - VmaAllocator vma_allocator, - const VkMemoryRequirements& mem_props, - const VmaAllocationCreateInfo& create_info) - : memory_requirements(mem_props), - create_info(create_info), - allocator(vma_allocator), - allocation(VK_NULL_HANDLE) { - VK_CHECK(vmaAllocateMemory( - allocator, &memory_requirements, &create_info, &allocation, nullptr)); -} - -MemoryAllocation::MemoryAllocation(MemoryAllocation&& other) noexcept - : memory_requirements(other.memory_requirements), - create_info(other.create_info), - allocator(other.allocator), - allocation(other.allocation) { - other.allocation = VK_NULL_HANDLE; -} - -MemoryAllocation& MemoryAllocation::operator=( - MemoryAllocation&& other) noexcept { - VmaAllocation tmp_allocation = allocation; - - memory_requirements = other.memory_requirements; - create_info = other.create_info; - allocator = 
other.allocator; - allocation = other.allocation; - - other.allocation = tmp_allocation; - - return *this; -} - -MemoryAllocation::~MemoryAllocation() { - if (VK_NULL_HANDLE != allocation) { - vmaFreeMemory(allocator, allocation); - } -} - -// -// VulkanBuffer -// - -VulkanBuffer::VulkanBuffer() - : buffer_properties_{}, - allocator_(VK_NULL_HANDLE), - memory_{}, - owns_memory_(false), - handle_(VK_NULL_HANDLE) {} - -VulkanBuffer::VulkanBuffer( - VmaAllocator vma_allocator, - const VkDeviceSize size, - const VmaAllocationCreateInfo& allocation_create_info, - const VkBufferUsageFlags usage, - const bool allocate_memory) - : buffer_properties_({ - size, - 0u, - size, - usage, - }), - allocator_(vma_allocator), - memory_{}, - owns_memory_(allocate_memory), - handle_(VK_NULL_HANDLE) { - // Only allocate memory if the buffer has non-zero size - if (size == 0) { - return; - } - - const VkBufferCreateInfo buffer_create_info{ - VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - size, // size - buffer_properties_.buffer_usage, // usage - VK_SHARING_MODE_EXCLUSIVE, // sharingMode - 0u, // queueFamilyIndexCount - nullptr, // pQueueFamilyIndices - }; - - memory_.create_info = allocation_create_info; - - if (allocate_memory) { - VK_CHECK(vmaCreateBuffer( - allocator_, - &buffer_create_info, - &allocation_create_info, - &handle_, - &(memory_.allocation), - nullptr)); - } else { - VmaAllocatorInfo allocator_info{}; - vmaGetAllocatorInfo(allocator_, &allocator_info); - VK_CHECK(vkCreateBuffer( - allocator_info.device, &buffer_create_info, nullptr, &handle_)); - } -} - -VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) noexcept - : buffer_properties_(other.buffer_properties_), - allocator_(other.allocator_), - memory_(std::move(other.memory_)), - owns_memory_(other.owns_memory_), - handle_(other.handle_) { - other.handle_ = VK_NULL_HANDLE; -} - -VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) noexcept { - VkBuffer tmp_buffer = handle_; - 
bool tmp_owns_memory = owns_memory_; - - buffer_properties_ = other.buffer_properties_; - allocator_ = other.allocator_; - memory_ = std::move(other.memory_); - owns_memory_ = other.owns_memory_; - handle_ = other.handle_; - - other.handle_ = tmp_buffer; - other.owns_memory_ = tmp_owns_memory; - - return *this; -} - -VulkanBuffer::~VulkanBuffer() { - if (VK_NULL_HANDLE != handle_) { - if (owns_memory_) { - vmaDestroyBuffer(allocator_, handle_, memory_.allocation); - } else { - vkDestroyBuffer(this->device(), handle_, nullptr); - } - // Prevent the underlying memory allocation from being freed; it was either - // freed by vmaDestroyBuffer, or this resource does not own the underlying - // memory - memory_.allocation = VK_NULL_HANDLE; - } -} - -VkMemoryRequirements VulkanBuffer::get_memory_requirements() const { - VkMemoryRequirements memory_requirements; - vkGetBufferMemoryRequirements(this->device(), handle_, &memory_requirements); - return memory_requirements; -} - -// -// MemoryMap -// - -MemoryMap::MemoryMap(const VulkanBuffer& buffer, const uint8_t access) - : access_(access), - allocator_(buffer.vma_allocator()), - allocation_(buffer.allocation()), - data_(nullptr), - data_len_{buffer.mem_size()} { - if (allocation_) { - VK_CHECK(vmaMapMemory(allocator_, allocation_, &data_)); - } -} - -MemoryMap::MemoryMap(MemoryMap&& other) noexcept - : access_(other.access_), - allocator_(other.allocator_), - allocation_(other.allocation_), - data_(other.data_), - data_len_{other.data_len_} { - other.allocation_ = VK_NULL_HANDLE; - other.data_ = nullptr; -} - -MemoryMap::~MemoryMap() { - if (!data_) { - return; - } - - if (allocation_) { - if (access_ & MemoryAccessType::WRITE) { - // Call will be ignored by implementation if the memory type this - // allocation belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is - // the behavior we want. Don't check the result here as the destructor - // cannot throw. 
- vmaFlushAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE); - } - - vmaUnmapMemory(allocator_, allocation_); - } -} - -void MemoryMap::invalidate() { - if (access_ & MemoryAccessType::READ && allocation_) { - // Call will be ignored by implementation if the memory type this allocation - // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior - // we want. - VK_CHECK( - vmaInvalidateAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE)); - } -} - -// -// BufferMemoryBarrier -// - -BufferMemoryBarrier::BufferMemoryBarrier( - const VkAccessFlags src_access_flags, - const VkAccessFlags dst_access_flags, - const VulkanBuffer& buffer) - : handle{ - VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, // sType - nullptr, // pNext - src_access_flags, // srcAccessMask - dst_access_flags, // dstAccessMask - VK_QUEUE_FAMILY_IGNORED, // srcQueueFamilyIndex - VK_QUEUE_FAMILY_IGNORED, // dstQueueFamilyIndex - buffer.handle_, // buffer - buffer.buffer_properties_.mem_offset, // offset - buffer.buffer_properties_.mem_range, // size - } {} - -// -// ImageSampler -// - -bool operator==( - const ImageSampler::Properties& _1, - const ImageSampler::Properties& _2) { - return ( - _1.filter == _2.filter && _1.mipmap_mode == _2.mipmap_mode && - _1.address_mode == _2.address_mode && _1.border_color == _2.border_color); -} - -ImageSampler::ImageSampler( - VkDevice device, - const ImageSampler::Properties& props) - : device_(device), handle_(VK_NULL_HANDLE) { - const VkSamplerCreateInfo sampler_create_info{ - VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - props.filter, // magFilter - props.filter, // minFilter - props.mipmap_mode, // mipmapMode - props.address_mode, // addressModeU - props.address_mode, // addressModeV - props.address_mode, // addressModeW - 0.0f, // mipLodBias - VK_FALSE, // anisotropyEnable - 1.0f, // maxAnisotropy, - VK_FALSE, // compareEnable - VK_COMPARE_OP_NEVER, // compareOp - 0.0f, // minLod - VK_LOD_CLAMP_NONE, 
// maxLod - props.border_color, // borderColor - VK_FALSE, // unnormalizedCoordinates - }; - - VK_CHECK(vkCreateSampler(device_, &sampler_create_info, nullptr, &handle_)); -} - -ImageSampler::ImageSampler(ImageSampler&& other) noexcept - : device_(other.device_), handle_(other.handle_) { - other.handle_ = VK_NULL_HANDLE; -} - -ImageSampler::~ImageSampler() { - if (VK_NULL_HANDLE == handle_) { - return; - } - vkDestroySampler(device_, handle_, nullptr); -} - -size_t ImageSampler::Hasher::operator()( - const ImageSampler::Properties& props) const { - size_t seed = 0; - seed = utils::hash_combine(seed, std::hash()(props.filter)); - seed = utils::hash_combine( - seed, std::hash()(props.mipmap_mode)); - seed = utils::hash_combine( - seed, std::hash()(props.address_mode)); - seed = - utils::hash_combine(seed, std::hash()(props.border_color)); - return seed; -} - -void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept { - VkDevice tmp_device = lhs.device_; - VkSampler tmp_handle = lhs.handle_; - - lhs.device_ = rhs.device_; - lhs.handle_ = rhs.handle_; - - rhs.device_ = tmp_device; - rhs.handle_ = tmp_handle; -} - -// -// VulkanImage -// - -VulkanImage::VulkanImage() - : image_properties_{}, - view_properties_{}, - sampler_properties_{}, - allocator_(VK_NULL_HANDLE), - memory_{}, - owns_memory_(false), - handles_{ - VK_NULL_HANDLE, - VK_NULL_HANDLE, - VK_NULL_HANDLE, - }, - layout_{} {} - -VulkanImage::VulkanImage( - VmaAllocator vma_allocator, - const VmaAllocationCreateInfo& allocation_create_info, - const ImageProperties& image_props, - const ViewProperties& view_props, - const SamplerProperties& sampler_props, - const VkImageLayout layout, - VkSampler sampler, - const bool allocate_memory) - : image_properties_(image_props), - view_properties_(view_props), - sampler_properties_(sampler_props), - allocator_(vma_allocator), - memory_{}, - owns_memory_{allocate_memory}, - handles_{ - VK_NULL_HANDLE, - VK_NULL_HANDLE, - sampler, - }, - layout_(layout) { - 
VmaAllocatorInfo allocator_info{}; - vmaGetAllocatorInfo(allocator_, &allocator_info); - - // If any dims are zero, then no memory will be allocated for the image. - if (image_props.image_extents.width == 0 || - image_props.image_extents.height == 0 || - image_props.image_extents.depth == 0) { - return; - } - - const VkImageCreateInfo image_create_info{ - VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - image_properties_.image_type, // imageType - image_properties_.image_format, // format - image_properties_.image_extents, // extents - 1u, // mipLevels - 1u, // arrayLayers - VK_SAMPLE_COUNT_1_BIT, // samples - VK_IMAGE_TILING_OPTIMAL, // tiling - image_properties_.image_usage, // usage - VK_SHARING_MODE_EXCLUSIVE, // sharingMode - 0u, // queueFamilyIndexCount - nullptr, // pQueueFamilyIndices - layout_, // initialLayout - }; - - memory_.create_info = allocation_create_info; - - if (allocate_memory) { - VK_CHECK(vmaCreateImage( - allocator_, - &image_create_info, - &allocation_create_info, - &(handles_.image), - &(memory_.allocation), - nullptr)); - // Only create the image view if the image has been bound to memory - create_image_view(); - } else { - VK_CHECK(vkCreateImage( - allocator_info.device, &image_create_info, nullptr, &(handles_.image))); - } -} - -VulkanImage::VulkanImage(VulkanImage&& other) noexcept - : image_properties_(other.image_properties_), - view_properties_(other.view_properties_), - sampler_properties_(other.sampler_properties_), - allocator_(other.allocator_), - memory_(std::move(other.memory_)), - owns_memory_(other.owns_memory_), - handles_(other.handles_), - layout_(other.layout_) { - other.handles_.image = VK_NULL_HANDLE; - other.handles_.image_view = VK_NULL_HANDLE; - other.handles_.sampler = VK_NULL_HANDLE; - other.owns_memory_ = false; -} - -VulkanImage& VulkanImage::operator=(VulkanImage&& other) noexcept { - VkImage tmp_image = handles_.image; - VkImageView tmp_image_view = handles_.image_view; - bool 
tmp_owns_memory = owns_memory_; - - image_properties_ = other.image_properties_; - view_properties_ = other.view_properties_; - sampler_properties_ = other.sampler_properties_; - allocator_ = other.allocator_; - memory_ = std::move(other.memory_); - owns_memory_ = other.owns_memory_; - handles_ = other.handles_; - layout_ = other.layout_; - - other.handles_.image = tmp_image; - other.handles_.image_view = tmp_image_view; - other.owns_memory_ = tmp_owns_memory; - - return *this; -} - -VulkanImage::~VulkanImage() { - if (VK_NULL_HANDLE != handles_.image_view) { - vkDestroyImageView(this->device(), handles_.image_view, nullptr); - } - - if (VK_NULL_HANDLE != handles_.image) { - if (owns_memory_) { - vmaDestroyImage(allocator_, handles_.image, memory_.allocation); - } else { - vkDestroyImage(this->device(), handles_.image, nullptr); - } - // Prevent the underlying memory allocation from being freed; it was either - // freed by vmaDestroyImage, or this resource does not own the underlying - // memory - memory_.allocation = VK_NULL_HANDLE; - } -} - -void VulkanImage::create_image_view() { - VmaAllocatorInfo allocator_info{}; - vmaGetAllocatorInfo(allocator_, &allocator_info); - - const VkComponentMapping component_mapping{ - VK_COMPONENT_SWIZZLE_IDENTITY, // r - VK_COMPONENT_SWIZZLE_IDENTITY, // g - VK_COMPONENT_SWIZZLE_IDENTITY, // b - VK_COMPONENT_SWIZZLE_IDENTITY, // a - }; - - const VkImageSubresourceRange subresource_range{ - VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask - 0u, // baseMipLevel - VK_REMAINING_MIP_LEVELS, // levelCount - 0u, // baseArrayLayer - VK_REMAINING_ARRAY_LAYERS, // layerCount - }; - - const VkImageViewCreateInfo image_view_create_info{ - VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - handles_.image, // image - view_properties_.view_type, // viewType - view_properties_.view_format, // format - component_mapping, // components - subresource_range, // subresourceRange - }; - - VK_CHECK(vkCreateImageView( - 
allocator_info.device, - &(image_view_create_info), - nullptr, - &(handles_.image_view))); -} - -VkMemoryRequirements VulkanImage::get_memory_requirements() const { - VkMemoryRequirements memory_requirements; - vkGetImageMemoryRequirements( - this->device(), handles_.image, &memory_requirements); - return memory_requirements; -} - -// -// ImageMemoryBarrier -// - -ImageMemoryBarrier::ImageMemoryBarrier( - const VkAccessFlags src_access_flags, - const VkAccessFlags dst_access_flags, - const VkImageLayout src_layout_flags, - const VkImageLayout dst_layout_flags, - const VulkanImage& image) - : handle{ - VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, // sType - nullptr, // pNext - src_access_flags, // srcAccessMask - dst_access_flags, // dstAccessMask - src_layout_flags, // oldLayout - dst_layout_flags, // newLayout - VK_QUEUE_FAMILY_IGNORED, // srcQueueFamilyIndex - VK_QUEUE_FAMILY_IGNORED, // dstQueueFamilyIndex - image.handles_.image, // image - { - // subresourceRange - VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask - 0u, // baseMipLevel - VK_REMAINING_MIP_LEVELS, // levelCount - 0u, // baseArrayLayer - VK_REMAINING_ARRAY_LAYERS, // layerCount - }, - } {} - -// -// SamplerCache -// - -SamplerCache::SamplerCache(VkDevice device) - : cache_mutex_{}, device_(device), cache_{} {} - -SamplerCache::SamplerCache(SamplerCache&& other) noexcept - : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) { - std::lock_guard lock(other.cache_mutex_); -} - -SamplerCache::~SamplerCache() { - purge(); -} - -VkSampler SamplerCache::retrieve(const SamplerCache::Key& key) { - std::lock_guard lock(cache_mutex_); - - auto it = cache_.find(key); - if (cache_.cend() == it) { - it = cache_.insert({key, SamplerCache::Value(device_, key)}).first; - } - - return it->second.handle(); -} - -void SamplerCache::purge() { - std::lock_guard lock(cache_mutex_); - cache_.clear(); -} - -// -// MemoryAllocator -// - -MemoryAllocator::MemoryAllocator( - VkInstance instance, - VkPhysicalDevice 
physical_device, - VkDevice device) - : instance_{}, - physical_device_(physical_device), - device_(device), - allocator_{VK_NULL_HANDLE} { - VmaVulkanFunctions vk_functions{}; - vk_functions.vkGetInstanceProcAddr = vkGetInstanceProcAddr; - vk_functions.vkGetDeviceProcAddr = vkGetDeviceProcAddr; - - const VmaAllocatorCreateInfo allocator_create_info{ - 0u, // flags - physical_device_, // physicalDevice - device_, // device - 0u, // preferredLargeHeapBlockSize - nullptr, // pAllocationCallbacks - nullptr, // pDeviceMemoryCallbacks - nullptr, // pHeapSizeLimit - &vk_functions, // pVulkanFunctions - instance, // instance - VK_API_VERSION_1_0, // vulkanApiVersion - nullptr, // pTypeExternalMemoryHandleTypes - }; - - VK_CHECK(vmaCreateAllocator(&allocator_create_info, &allocator_)); -} - -MemoryAllocator::MemoryAllocator(MemoryAllocator&& other) noexcept - : instance_(other.instance_), - physical_device_(other.physical_device_), - device_(other.device_), - allocator_(other.allocator_) { - other.allocator_ = VK_NULL_HANDLE; - other.device_ = VK_NULL_HANDLE; - other.physical_device_ = VK_NULL_HANDLE; - other.instance_ = VK_NULL_HANDLE; -} - -MemoryAllocator::~MemoryAllocator() { - if (VK_NULL_HANDLE == allocator_) { - return; - } - vmaDestroyAllocator(allocator_); -} - -MemoryAllocation MemoryAllocator::create_allocation( - const VkMemoryRequirements& memory_requirements, - const VmaAllocationCreateInfo& create_info) { - VmaAllocationCreateInfo alloc_create_info = create_info; - // Protect against using VMA_MEMORY_USAGE_AUTO_* flags when allocating memory - // directly, since those usage flags require that VkBufferCreateInfo and/or - // VkImageCreateInfo also be available. - switch (create_info.usage) { - // The logic for the below usage options are too complex, therefore prevent - // those from being used with direct memory allocation. 
- case VMA_MEMORY_USAGE_AUTO: - case VMA_MEMORY_USAGE_AUTO_PREFER_HOST: - VK_THROW( - "Only the VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE usage flag is compatible with create_allocation()"); - break; - // Most of the time, VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE will simply set the - // DEVICE_LOCAL_BIT as a preferred memory flag. Therefore the below is a - // decent approximation for VMA behaviour. - case VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE: - alloc_create_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; - alloc_create_info.usage = VMA_MEMORY_USAGE_UNKNOWN; - break; - default: - break; - } - - return MemoryAllocation(allocator_, memory_requirements, alloc_create_info); -} - -VulkanImage MemoryAllocator::create_image( - const VkExtent3D& extents, - const VkFormat image_format, - const VkImageType image_type, - const VkImageViewType image_view_type, - const VulkanImage::SamplerProperties& sampler_props, - VkSampler sampler, - const bool allow_transfer, - const bool allocate_memory) { - VkImageUsageFlags usage = - VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT; - if (allow_transfer) { - usage |= - (VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT); - } - - VmaAllocationCreateInfo alloc_create_info = {}; - alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY; - alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE; - - const VulkanImage::ImageProperties image_props{ - image_type, - image_format, - extents, - usage, - }; - - const VulkanImage::ViewProperties view_props{ - image_view_type, - image_format, - }; - - const VkImageLayout initial_layout = VK_IMAGE_LAYOUT_UNDEFINED; - - return VulkanImage( - allocator_, - alloc_create_info, - image_props, - view_props, - sampler_props, - initial_layout, - sampler, - allocate_memory); -} - -VulkanBuffer MemoryAllocator::create_storage_buffer( - const VkDeviceSize size, - const bool gpu_only, - const bool allocate_memory) { - const VkBufferUsageFlags buffer_usage = 
VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - - VmaAllocationCreateInfo alloc_create_info = {}; - alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY; - alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE; - - // The create storage buffer will be accessed by both the CPU and GPU, so set - // the appropriate flags to indicate that the host device will be accessing - // the data from this buffer. - if (!gpu_only) { - // Deferred memory allocation should only be used for GPU only buffers. - VK_CHECK_COND( - allocate_memory, - "Only GPU-only buffers should use deferred memory allocation"); - - alloc_create_info.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT; - alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_HOST; - alloc_create_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - alloc_create_info.preferredFlags = VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | - VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - } - - return VulkanBuffer( - allocator_, size, alloc_create_info, buffer_usage, allocate_memory); -} - -VulkanBuffer MemoryAllocator::create_staging_buffer(const VkDeviceSize size) { - VmaAllocationCreateInfo alloc_create_info = {}; - alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY; - alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_HOST; - - VkBufferUsageFlags buffer_usage = - VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - - return VulkanBuffer(allocator_, size, alloc_create_info, buffer_usage); -} - -VulkanBuffer MemoryAllocator::create_uniform_buffer(const VkDeviceSize size) { - VmaAllocationCreateInfo alloc_create_info = {}; - alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY | - VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT; - alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO; - - VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; - - VulkanBuffer uniform_buffer( - allocator_, size, alloc_create_info, buffer_usage); - return uniform_buffer; -} - -// -// VulkanFence -// - 
-VulkanFence::VulkanFence() - : device_(VK_NULL_HANDLE), handle_(VK_NULL_HANDLE), waiting_(false) {} - -VulkanFence::VulkanFence(VkDevice device) - : device_(device), handle_(VK_NULL_HANDLE), waiting_(VK_NULL_HANDLE) { - const VkFenceCreateInfo fence_create_info{ - VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, // sType - nullptr, // pNext - 0u, // flags - }; - - VK_CHECK(vkCreateFence(device_, &fence_create_info, nullptr, &handle_)); -} - -VulkanFence::VulkanFence(VulkanFence&& other) noexcept - : device_(other.device_), handle_(other.handle_), waiting_(other.waiting_) { - other.handle_ = VK_NULL_HANDLE; - other.waiting_ = false; -} - -VulkanFence& VulkanFence::operator=(VulkanFence&& other) noexcept { - device_ = other.device_; - handle_ = other.handle_; - waiting_ = other.waiting_; - - other.device_ = VK_NULL_HANDLE; - other.handle_ = VK_NULL_HANDLE; - other.waiting_ = false; - - return *this; -} - -VulkanFence::~VulkanFence() { - if (VK_NULL_HANDLE == handle_) { - return; - } - vkDestroyFence(device_, handle_, nullptr); -} - -void VulkanFence::wait() { - // if get_submit_handle() has not been called, then this will no-op - if (waiting_) { - VkResult fence_status = VK_NOT_READY; - // Run the wait in a loop to keep the CPU hot. A single call to - // vkWaitForFences with no timeout may cause the calling thread to be - // scheduled out. 
- do { - // The timeout (last) arg is in units of ns - fence_status = vkWaitForFences(device_, 1u, &handle_, VK_TRUE, 100000); - - VK_CHECK_COND( - fence_status != VK_ERROR_DEVICE_LOST, - "Vulkan Fence: Device lost while waiting for fence!"); - } while (fence_status != VK_SUCCESS); - - VK_CHECK(vkResetFences(device_, 1u, &handle_)); - - waiting_ = false; - } -} - -} // namespace api -} // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Resource.h b/backends/vulkan/runtime/api/Resource.h deleted file mode 100644 index 81388cdcb06..00000000000 --- a/backends/vulkan/runtime/api/Resource.h +++ /dev/null @@ -1,599 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName - -#include - -#include -#include -#include - -#include -#include -#include -#include - -std::ostream& operator<<(std::ostream& out, VmaTotalStatistics stats); - -namespace vkcompute { -namespace api { - -using MemoryAccessFlags = uint8_t; - -constexpr VmaAllocationCreateFlags DEFAULT_ALLOCATION_STRATEGY = - VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT; - -enum MemoryAccessType : MemoryAccessFlags { - NONE = 0u << 0u, - READ = 1u << 0u, - WRITE = 1u << 1u, -}; - -struct MemoryBarrier final { - VkMemoryBarrier handle; - - MemoryBarrier( - const VkAccessFlags src_access_flags, - const VkAccessFlags dst_access_flags); -}; - -struct MemoryAllocation final { - explicit MemoryAllocation(); - - explicit MemoryAllocation( - const VmaAllocator, - const VkMemoryRequirements&, - const VmaAllocationCreateInfo&); - - MemoryAllocation(const MemoryAllocation&) = delete; - MemoryAllocation& operator=(const MemoryAllocation&) = delete; - - MemoryAllocation(MemoryAllocation&&) noexcept; - MemoryAllocation& operator=(MemoryAllocation&&) noexcept; - - 
~MemoryAllocation(); - - VkMemoryRequirements memory_requirements; - // The properties this allocation was created with - VmaAllocationCreateInfo create_info; - // The allocator object this was allocated from - VmaAllocator allocator; - // Handles to the allocated memory - VmaAllocation allocation; - - operator bool() const { - return (allocation != VK_NULL_HANDLE); - } -}; - -class VulkanBuffer final { - public: - struct BufferProperties final { - VkDeviceSize size; - VkDeviceSize mem_offset; - VkDeviceSize mem_range; - VkBufferUsageFlags buffer_usage; - }; - - explicit VulkanBuffer(); - - explicit VulkanBuffer( - const VmaAllocator, - const VkDeviceSize, - const VmaAllocationCreateInfo&, - const VkBufferUsageFlags, - const bool allocate_memory = true); - - VulkanBuffer(const VulkanBuffer&) = delete; - VulkanBuffer& operator=(const VulkanBuffer&) = delete; - - VulkanBuffer(VulkanBuffer&&) noexcept; - VulkanBuffer& operator=(VulkanBuffer&&) noexcept; - - ~VulkanBuffer(); - - struct Package final { - VkBuffer handle; - VkDeviceSize buffer_offset; - VkDeviceSize buffer_range; - }; - - friend struct BufferMemoryBarrier; - - private: - BufferProperties buffer_properties_; - VmaAllocator allocator_; - MemoryAllocation memory_; - // Indicates whether the underlying memory is owned by this resource - bool owns_memory_; - VkBuffer handle_; - - public: - inline VkDevice device() const { - VmaAllocatorInfo allocator_info{}; - vmaGetAllocatorInfo(allocator_, &allocator_info); - return allocator_info.device; - } - - inline VmaAllocator vma_allocator() const { - return allocator_; - } - - inline VmaAllocation allocation() const { - return memory_.allocation; - } - - inline VmaAllocationCreateInfo allocation_create_info() const { - return VmaAllocationCreateInfo(memory_.create_info); - } - - inline VkBuffer handle() const { - return handle_; - } - - inline VkDeviceSize mem_offset() const { - return buffer_properties_.mem_offset; - } - - inline VkDeviceSize mem_range() const { - 
return buffer_properties_.mem_range; - } - - inline VkDeviceSize mem_size() const { - return buffer_properties_.size; - } - - inline bool has_memory() const { - return (memory_.allocation != VK_NULL_HANDLE); - } - - inline bool owns_memory() const { - return owns_memory_; - } - - operator bool() const { - return (handle_ != VK_NULL_HANDLE); - } - - inline void bind_allocation(const MemoryAllocation& memory) { - VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); - VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_)); - memory_.allocation = memory.allocation; - } - - VkMemoryRequirements get_memory_requirements() const; -}; - -class MemoryMap final { - public: - explicit MemoryMap( - const VulkanBuffer& buffer, - const MemoryAccessFlags access); - - MemoryMap(const MemoryMap&) = delete; - MemoryMap& operator=(const MemoryMap&) = delete; - - MemoryMap(MemoryMap&&) noexcept; - MemoryMap& operator=(MemoryMap&&) = delete; - - ~MemoryMap(); - - private: - uint8_t access_; - VmaAllocator allocator_; - VmaAllocation allocation_; - void* data_; - VkDeviceSize data_len_; - - public: - template - T* data() { - return reinterpret_cast(data_); - } - - inline size_t nbytes() { - return utils::safe_downcast(data_len_); - } - - void invalidate(); -}; - -struct BufferMemoryBarrier final { - VkBufferMemoryBarrier handle; - - BufferMemoryBarrier( - const VkAccessFlags src_access_flags, - const VkAccessFlags dst_access_flags, - const VulkanBuffer& buffer); -}; - -class ImageSampler final { - public: - struct Properties final { - VkFilter filter; - VkSamplerMipmapMode mipmap_mode; - VkSamplerAddressMode address_mode; - VkBorderColor border_color; - }; - - explicit ImageSampler(VkDevice, const Properties&); - - ImageSampler(const ImageSampler&) = delete; - ImageSampler& operator=(const ImageSampler&) = delete; - - ImageSampler(ImageSampler&&) noexcept; - ImageSampler& operator=(ImageSampler&&) = delete; - - ~ImageSampler(); - - private: - VkDevice 
device_; - VkSampler handle_; - - public: - VkSampler handle() const { - return handle_; - } - - struct Hasher { - size_t operator()(const Properties&) const; - }; - - // We need to define a custom swap function since this class - // does not allow for move assignment. The swap function will - // be used in the hash map. - friend void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept; -}; - -class VulkanImage final { - public: - struct ImageProperties final { - VkImageType image_type; - VkFormat image_format; - VkExtent3D image_extents; - VkImageUsageFlags image_usage; - }; - - struct ViewProperties final { - VkImageViewType view_type; - VkFormat view_format; - }; - - using SamplerProperties = ImageSampler::Properties; - - struct Handles final { - VkImage image; - VkImageView image_view; - VkSampler sampler; - }; - - explicit VulkanImage(); - - explicit VulkanImage( - const VmaAllocator, - const VmaAllocationCreateInfo&, - const ImageProperties&, - const ViewProperties&, - const SamplerProperties&, - const VkImageLayout layout, - VkSampler, - const bool allocate_memory = true); - - VulkanImage(const VulkanImage&) = delete; - VulkanImage& operator=(const VulkanImage&) = delete; - - VulkanImage(VulkanImage&&) noexcept; - VulkanImage& operator=(VulkanImage&&) noexcept; - - ~VulkanImage(); - - struct Package final { - VkImage handle; - VkImageLayout image_layout; - VkImageView image_view; - VkSampler image_sampler; - }; - - friend struct ImageMemoryBarrier; - - private: - ImageProperties image_properties_; - ViewProperties view_properties_; - SamplerProperties sampler_properties_; - // The allocator object this was allocated from - VmaAllocator allocator_; - // Handles to the allocated memory - MemoryAllocation memory_; - // Indicates whether the underlying memory is owned by this resource - bool owns_memory_; - Handles handles_; - // Layout - VkImageLayout layout_; - - public: - void create_image_view(); - - inline VkDevice device() const { - VmaAllocatorInfo 
allocator_info{}; - vmaGetAllocatorInfo(allocator_, &allocator_info); - return allocator_info.device; - } - - inline VmaAllocator vma_allocator() const { - return allocator_; - } - - inline VmaAllocation allocation() const { - return memory_.allocation; - } - - inline VmaAllocationCreateInfo allocation_create_info() const { - return VmaAllocationCreateInfo(memory_.create_info); - } - - inline VkFormat format() const { - return image_properties_.image_format; - } - - inline VkExtent3D extents() const { - return image_properties_.image_extents; - } - - inline VkImage handle() const { - return handles_.image; - } - - inline VkImageView image_view() const { - return handles_.image_view; - } - - inline VkSampler sampler() const { - return handles_.sampler; - } - - Package package() const { - return { - handles_.image, - layout_, - handles_.image_view, - handles_.sampler, - }; - } - - inline VkImageLayout layout() const { - return layout_; - } - - inline void set_layout(const VkImageLayout layout) { - layout_ = layout; - } - - inline bool has_memory() const { - return (memory_.allocation != VK_NULL_HANDLE); - } - - inline bool owns_memory() const { - return owns_memory_; - } - - inline operator bool() const { - return (handles_.image != VK_NULL_HANDLE); - } - - inline void bind_allocation(const MemoryAllocation& memory) { - VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); - VK_CHECK(vmaBindImageMemory(allocator_, memory.allocation, handles_.image)); - memory_.allocation = memory.allocation; - - // Only create the image view if the image has been bound to memory - create_image_view(); - } - - VkMemoryRequirements get_memory_requirements() const; -}; - -struct ImageMemoryBarrier final { - VkImageMemoryBarrier handle; - - ImageMemoryBarrier( - const VkAccessFlags src_access_flags, - const VkAccessFlags dst_access_flags, - const VkImageLayout src_layout_flags, - const VkImageLayout dst_layout_flags, - const VulkanImage& image); -}; - -class SamplerCache 
final { - public: - explicit SamplerCache(VkDevice device); - - SamplerCache(const SamplerCache&) = delete; - SamplerCache& operator=(const SamplerCache&) = delete; - - SamplerCache(SamplerCache&&) noexcept; - SamplerCache& operator=(SamplerCache&&) = delete; - - ~SamplerCache(); - - using Key = ImageSampler::Properties; - using Value = ImageSampler; - using Hasher = ImageSampler::Hasher; - - private: - // Multiple threads could potentially be adding entries into the cache, so use - // a mutex to manage access - std::mutex cache_mutex_; - - VkDevice device_; - std::unordered_map cache_; - - public: - VkSampler retrieve(const Key&); - void purge(); -}; - -class MemoryAllocator final { - public: - explicit MemoryAllocator( - VkInstance instance, - VkPhysicalDevice physical_device, - VkDevice device); - - MemoryAllocator(const MemoryAllocator&) = delete; - MemoryAllocator& operator=(const MemoryAllocator&) = delete; - - MemoryAllocator(MemoryAllocator&&) noexcept; - MemoryAllocator& operator=(MemoryAllocator&&) = delete; - - ~MemoryAllocator(); - - private: - VkInstance instance_; - VkPhysicalDevice physical_device_; - VkDevice device_; - VmaAllocator allocator_; - - public: - MemoryAllocation create_allocation( - const VkMemoryRequirements& memory_requirements, - const VmaAllocationCreateInfo& create_info); - - VulkanImage create_image( - const VkExtent3D&, - const VkFormat, - const VkImageType, - const VkImageViewType, - const VulkanImage::SamplerProperties&, - VkSampler, - const bool allow_transfer = false, - const bool allocate_memory = true); - - VulkanBuffer create_storage_buffer( - const VkDeviceSize, - const bool gpu_only = true, - const bool allocate_memory = true); - - VulkanBuffer create_staging_buffer(const VkDeviceSize); - - /* - * Create a uniform buffer with a specified size - */ - VulkanBuffer create_uniform_buffer(const VkDeviceSize); - - /* - * Create a uniform buffer containing the data in an arbitrary struct - */ - template - VulkanBuffer 
create_params_buffer(const Block& block); - - VmaTotalStatistics get_memory_statistics() const { - VmaTotalStatistics stats = {}; - vmaCalculateStatistics(allocator_, &stats); - return stats; - } -}; - -class VulkanFence final { - public: - // TODO: This is required for the lazy allocation pattern in api/Tensor. - // It will be disabled pending future refactors. - explicit VulkanFence(); - - explicit VulkanFence(VkDevice); - - VulkanFence(const VulkanFence&) = delete; - VulkanFence& operator=(const VulkanFence&) = delete; - - VulkanFence(VulkanFence&&) noexcept; - VulkanFence& operator=(VulkanFence&&) noexcept; - - ~VulkanFence(); - - private: - VkDevice device_; - VkFence handle_; - bool waiting_; - - public: - // Used to get the handle for a queue submission. - VkFence get_submit_handle() { - if (handle_ != VK_NULL_HANDLE) { - // Indicate we are now waiting for this fence to be signaled - waiting_ = true; - } - return handle_; - } - - VkFence handle() { - return handle_; - } - - // Trigger a synchronous wait for the fence to be signaled - void wait(); - - bool waiting() const { - return waiting_; - } - - operator bool() const { - return (VK_NULL_HANDLE != handle_); - } -}; - -// A pool to track created Fences and reuse ones that are available. -// Only intended to be modified by one thread at a time. 
-struct FencePool final { - VkDevice device_; - - std::stack pool_; - - explicit FencePool(VkDevice device) : device_(device), pool_{} {} - - // Returns an rvalue reference to a fence, so that it can be moved - inline VulkanFence get_fence() { - if (pool_.empty()) { - VulkanFence new_fence = VulkanFence(device_); - return new_fence; - } - - VulkanFence top_fence = std::move(pool_.top()); - pool_.pop(); - - return top_fence; - } - - // Marks the fence as available - inline void return_fence(VulkanFence& fence) { - pool_.push(std::move(fence)); - } -}; - -// -// Impl -// - -template -inline VulkanBuffer MemoryAllocator::create_params_buffer(const Block& block) { - VulkanBuffer uniform_buffer = create_uniform_buffer(sizeof(Block)); - - // Fill the uniform buffer with data in block - { - MemoryMap mapping(uniform_buffer, MemoryAccessType::WRITE); - Block* data_ptr = mapping.template data(); - - *data_ptr = block; - } - - return uniform_buffer; -} - -} // namespace api -} // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Runtime.cpp b/backends/vulkan/runtime/api/Runtime.cpp index e113a4e3b4f..432af326a53 100644 --- a/backends/vulkan/runtime/api/Runtime.cpp +++ b/backends/vulkan/runtime/api/Runtime.cpp @@ -91,7 +91,7 @@ VkInstance create_instance(const RuntimeConfiguration& config) { std::vector enabled_layers; std::vector enabled_extensions; - if (config.enableValidationMessages) { + if (config.enable_validation_messages) { std::vector requested_layers{ // "VK_LAYER_LUNARG_api_dump", "VK_LAYER_KHRONOS_validation", @@ -175,7 +175,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn( VkDebugReportCallbackEXT create_debug_report_callback( VkInstance instance, const RuntimeConfiguration config) { - if (VK_NULL_HANDLE == instance || !config.enableValidationMessages) { + if (VK_NULL_HANDLE == instance || !config.enable_validation_messages) { return VkDebugReportCallbackEXT{}; } @@ -245,20 +245,22 @@ std::unique_ptr init_global_vulkan_runtime() { } #endif 
/* USE_VULKAN_VOLK, USE_VULKAN_WRAPPER */ - const bool enableValidationMessages = + const bool enable_validation_messages = #if defined(VULKAN_DEBUG) true; #else false; #endif /* VULKAN_DEBUG */ - const bool initDefaultDevice = true; - const uint32_t numRequestedQueues = 1; // TODO: raise this value + const bool init_default_device = true; + const uint32_t num_requested_queues = 1; // TODO: raise this value + const std::string cache_data_path = ""; // TODO: expose to client const RuntimeConfiguration default_config{ - enableValidationMessages, - initDefaultDevice, + enable_validation_messages, + init_default_device, AdapterSelector::First, - numRequestedQueues, + num_requested_queues, + cache_data_path, }; try { @@ -281,9 +283,9 @@ Runtime::Runtime(const RuntimeConfiguration config) // List of adapters will never exceed the number of physical devices adapters_.reserve(device_mappings_.size()); - if (config.initDefaultDevice) { + if (config.init_default_device) { try { - switch (config.defaultSelector) { + switch (config.default_selector) { case AdapterSelector::First: default_adapter_i_ = create_adapter(select_first); } @@ -350,8 +352,11 @@ uint32_t Runtime::create_adapter(const Selector& selector) { } // Otherwise, create an adapter for the selected physical device adapter_i = utils::safe_downcast(adapters_.size()); - adapters_.emplace_back( - new Adapter(instance_, device_mapping.first, config_.numRequestedQueues)); + adapters_.emplace_back(new Adapter( + instance_, + device_mapping.first, + config_.num_requested_queues, + config_.cache_data_path)); device_mapping.second = adapter_i; return adapter_i; diff --git a/backends/vulkan/runtime/api/Runtime.h b/backends/vulkan/runtime/api/Runtime.h index f54bd7522ac..e4cb6922ad8 100644 --- a/backends/vulkan/runtime/api/Runtime.h +++ b/backends/vulkan/runtime/api/Runtime.h @@ -35,10 +35,11 @@ enum AdapterSelector { }; struct RuntimeConfiguration final { - bool enableValidationMessages; - bool initDefaultDevice; - 
AdapterSelector defaultSelector; - uint32_t numRequestedQueues; + bool enable_validation_messages; + bool init_default_device; + AdapterSelector default_selector; + uint32_t num_requested_queues; + std::string cache_data_path; }; class Runtime final { diff --git a/backends/vulkan/runtime/api/Tensor.cpp b/backends/vulkan/runtime/api/Tensor.cpp index 402d35d75bb..4148601ee78 100644 --- a/backends/vulkan/runtime/api/Tensor.cpp +++ b/backends/vulkan/runtime/api/Tensor.cpp @@ -143,6 +143,7 @@ vTensor::vTensor( // Utility Uniform Buffers that can be passed to shaders as arguments sizes_uniform_(), texture_limits_uniform_(), + packed_dim_meta_(), // Construct Tensor storage storage_( context, @@ -212,6 +213,30 @@ const api::BufferBindInfo vTensor::texture_limits_ubo() { return api::BufferBindInfo(texture_limits_uniform_.buffer()); } +vTensor::PackedDimMeta vTensor::make_packed_dim_metadata() const { + int64_t packed_dim = gpu_memory_layout_int(); + int32_t dim_size = api::utils::val_at(-(packed_dim + 1), sizes_); + int32_t dim_size_padded = api::utils::val_at(-(packed_dim + 1), gpu_sizes_); + int32_t dim_texel_len = + api::utils::safe_downcast(extents().data[packed_dim]); + int32_t padding = dim_size_padded - dim_size; + + return { + dim_size, + dim_size_padded, + dim_texel_len, + padding, + }; +} + +const api::BufferBindInfo vTensor::packed_dim_meta_ubo() { + if (!packed_dim_meta_.buffer()) { + packed_dim_meta_ = + api::UniformParamsBuffer(storage_.context_, make_packed_dim_metadata()); + } + return api::BufferBindInfo(packed_dim_meta_.buffer()); +} + VmaAllocationCreateInfo vTensor::get_allocation_create_info() const { switch (storage_type()) { case api::kBuffer: @@ -234,7 +259,7 @@ VkMemoryRequirements vTensor::get_memory_requirements() const { return {}; } -void vTensor::bind_allocation(const api::MemoryAllocation& allocation) { +void vTensor::bind_allocation(const api::Allocation& allocation) { switch (storage_type()) { case api::kBuffer: 
storage_.buffer_.bind_allocation(allocation); @@ -268,6 +293,9 @@ void vTensor::update_size_metadata(const std::vector& new_sizes) { if (texture_limits_uniform_.buffer()) { texture_limits_uniform_.update(texture_limits_); } + if (packed_dim_meta_.buffer()) { + packed_dim_meta_.update(make_packed_dim_metadata()); + } } void vTensor::reallocate(const std::vector& new_sizes) { diff --git a/backends/vulkan/runtime/api/Tensor.h b/backends/vulkan/runtime/api/Tensor.h index 787e8111204..cb0fad76eb6 100644 --- a/backends/vulkan/runtime/api/Tensor.h +++ b/backends/vulkan/runtime/api/Tensor.h @@ -117,6 +117,13 @@ class vTensor final { vTensor& operator=(vTensor&& other) = default; private: + struct PackedDimMeta { + int32_t dim_size; + int32_t dim_size_padded; + int32_t dim_texel_len; + int32_t padding; + }; + api::ScalarType dtype_; api::GPUMemoryLayout memory_layout_; @@ -134,6 +141,10 @@ class vTensor final { // tensor has been resized with `virtual_resize()`. api::UniformParamsBuffer texture_limits_uniform_; + // A Vulkan uniform buffer containing an instance of PackedDimMeta which + // describes how the tensor's packed dimension is padded. 
+ api::UniformParamsBuffer packed_dim_meta_; + vTensorStorage storage_; public: @@ -220,6 +231,12 @@ class vTensor final { */ const api::BufferBindInfo texture_limits_ubo(); + private: + vTensor::PackedDimMeta make_packed_dim_metadata() const; + + public: + const api::BufferBindInfo packed_dim_meta_ubo(); + inline const api::utils::ivec3 texture_limits() const { return texture_limits_.limits; } @@ -259,7 +276,7 @@ class vTensor final { /* * Binds the underlying resource to the given memory allocation */ - void bind_allocation(const api::MemoryAllocation& allocation); + void bind_allocation(const api::Allocation& allocation); private: /* diff --git a/backends/vulkan/runtime/api/Utils.h b/backends/vulkan/runtime/api/Utils.h index ca36f7f75c6..26fe4ac075e 100644 --- a/backends/vulkan/runtime/api/Utils.h +++ b/backends/vulkan/runtime/api/Utils.h @@ -279,6 +279,17 @@ inline std::ostream& operator<<(std::ostream& os, const ivec4& v) { return os; } +template +inline detail::vec divup_vec( + const detail::vec& a, + const detail::vec& b) { + detail::vec result; + for (uint32_t i = 0; i < N; ++i) { + result.data[i] = api::utils::div_up(a.data[i], b.data[i]); + } + return result; +} + // // std::vector Handling // diff --git a/backends/vulkan/runtime/api/api.h b/backends/vulkan/runtime/api/api.h index 117f326cb45..16e2b969871 100644 --- a/backends/vulkan/runtime/api/api.h +++ b/backends/vulkan/runtime/api/api.h @@ -12,10 +12,15 @@ #include #include #include +#include #include -#include #include #include #include #include #include + +#include +#include +#include +#include diff --git a/backends/vulkan/runtime/api/memory/Allocation.cpp b/backends/vulkan/runtime/api/memory/Allocation.cpp new file mode 100644 index 00000000000..9bde2ac744d --- /dev/null +++ b/backends/vulkan/runtime/api/memory/Allocation.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+// Streams "<field>: <value>" followed by a newline for one member of a struct.
+#define PRINT_FIELD(struct, field) #field << ": " << struct.field << std::endl
+
+// Prints the block/allocation counts and byte totals found in stats.total.
+std::ostream& operator<<(std::ostream& out, VmaTotalStatistics stats) {
+  VmaDetailedStatistics total_stats = stats.total;
+  out << "VmaTotalStatistics: " << std::endl;
+  out << " " << PRINT_FIELD(total_stats.statistics, blockCount);
+  out << " " << PRINT_FIELD(total_stats.statistics, allocationCount);
+  out << " " << PRINT_FIELD(total_stats.statistics, blockBytes);
+  out << " " << PRINT_FIELD(total_stats.statistics, allocationBytes);
+  return out;
+}
+
+#undef PRINT_FIELD
+
+namespace vkcompute {
+namespace api {
+
+// Creates an empty Allocation: no allocator and no memory handle.
+Allocation::Allocation()
+    : memory_requirements{},
+      create_info{},
+      allocator(VK_NULL_HANDLE),
+      allocation(VK_NULL_HANDLE) {}
+
+// Allocates memory from vma_allocator that satisfies mem_props. VK_CHECK
+// handles a failing vmaAllocateMemory() result.
+Allocation::Allocation(
+    VmaAllocator vma_allocator,
+    const VkMemoryRequirements& mem_props,
+    const VmaAllocationCreateInfo& create_info)
+    : memory_requirements(mem_props),
+      create_info(create_info),
+      allocator(vma_allocator),
+      allocation(VK_NULL_HANDLE) {
+  // Inside the body `create_info` names the constructor parameter, which
+  // shadows the member of the same name.
+  VK_CHECK(vmaAllocateMemory(
+      allocator, &memory_requirements, &create_info, &allocation, nullptr));
+}
+
+// Takes over other's memory handle; other is left without an allocation so
+// its destructor does not free it.
+Allocation::Allocation(Allocation&& other) noexcept
+    : memory_requirements(other.memory_requirements),
+      create_info(other.create_info),
+      allocator(other.allocator),
+      allocation(other.allocation) {
+  other.allocation = VK_NULL_HANDLE;
+}
+
+// Move assignment is implemented as an ownership swap: this object takes
+// other's allocation, and other receives the allocation previously held here
+// so that it is released when other is destroyed.
+Allocation& Allocation::operator=(Allocation&& other) noexcept {
+  // The allocator must travel with the allocation it produced. Previously only
+  // the allocation handle was handed to `other`, so other's destructor would
+  // have called vmaFreeMemory() with other's own allocator even when the
+  // memory came from a different one.
+  VmaAllocator tmp_allocator = allocator;
+  VmaAllocation tmp_allocation = allocation;
+
+  memory_requirements = other.memory_requirements;
+  create_info = other.create_info;
+  allocator = other.allocator;
+  allocation = other.allocation;
+
+  other.allocator = tmp_allocator;
+  other.allocation = tmp_allocation;
+
+  return *this;
+}
+
+// Frees the memory, if this object still holds an allocation.
+Allocation::~Allocation() {
+  if (VK_NULL_HANDLE != allocation) {
+    vmaFreeMemory(allocator, allocation);
+  }
+}
+
+} // namespace api
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/api/memory/Allocation.h b/backends/vulkan/runtime/api/memory/Allocation.h
new file mode 100644
index 00000000000..b93556bd501
--- /dev/null
+++ b/backends/vulkan/runtime/api/memory/Allocation.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
+
+#include
+
+#include
+
+#include
+
+#include
+
+// Prints the totals (block/allocation counts and bytes) of a VMA statistics
+// snapshot.
+std::ostream& operator<<(std::ostream& out, VmaTotalStatistics stats);
+
+namespace vkcompute {
+namespace api {
+
+// Move-only owner of a VmaAllocation. The destructor frees the memory through
+// `allocator` whenever `allocation` is not VK_NULL_HANDLE.
+struct Allocation final {
+  // Creates an empty Allocation that holds no memory.
+  explicit Allocation();
+
+  // Allocates memory from the given allocator according to the given memory
+  // requirements and allocation create info.
+  explicit Allocation(
+      const VmaAllocator,
+      const VkMemoryRequirements&,
+      const VmaAllocationCreateInfo&);
+
+  // Copying is disabled; there is exactly one owner per allocation.
+  Allocation(const Allocation&) = delete;
+  Allocation& operator=(const Allocation&) = delete;
+
+  Allocation(Allocation&&) noexcept;
+  Allocation& operator=(Allocation&&) noexcept;
+
+  ~Allocation();
+
+  // The memory requirements this allocation was created to satisfy
+  VkMemoryRequirements memory_requirements;
+  // The properties this allocation was created with
+  VmaAllocationCreateInfo create_info;
+  // The allocator object this was allocated from
+  VmaAllocator allocator;
+  // Handles to the allocated memory
+  VmaAllocation allocation;
+
+  // True when this object currently holds an allocation.
+  operator bool() const {
+    return (allocation != VK_NULL_HANDLE);
+  }
+};
+
+} // namespace api
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/api/memory/Allocator.cpp b/backends/vulkan/runtime/api/memory/Allocator.cpp
new file mode 100644
index 00000000000..5749ecd0714
--- /dev/null
+++ b/backends/vulkan/runtime/api/memory/Allocator.cpp
@@ -0,0 +1,190 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+namespace vkcompute {
+namespace api {
+
+// Creates the VMA allocator for the given instance / physical device / device
+// triple. VK_CHECK handles a failing vmaCreateAllocator() result.
+Allocator::Allocator(
+    VkInstance instance,
+    VkPhysicalDevice physical_device,
+    VkDevice device)
+    // Store the instance that was passed in; it was previously
+    // value-initialized, leaving instance_ empty for the object's lifetime.
+    : instance_(instance),
+      physical_device_(physical_device),
+      device_(device),
+      allocator_{VK_NULL_HANDLE} {
+  // Give VMA the two loader entry points; the remaining function pointers are
+  // left zero-initialized.
+  VmaVulkanFunctions vk_functions{};
+  vk_functions.vkGetInstanceProcAddr = vkGetInstanceProcAddr;
+  vk_functions.vkGetDeviceProcAddr = vkGetDeviceProcAddr;
+
+  const VmaAllocatorCreateInfo allocator_create_info{
+      0u, // flags
+      physical_device_, // physicalDevice
+      device_, // device
+      0u, // preferredLargeHeapBlockSize
+      nullptr, // pAllocationCallbacks
+      nullptr, // pDeviceMemoryCallbacks
+      nullptr, // pHeapSizeLimit
+      &vk_functions, // pVulkanFunctions
+      instance, // instance
+      VK_API_VERSION_1_0, // vulkanApiVersion
+      nullptr, // pTypeExternalMemoryHandleTypes
+  };
+
+  VK_CHECK(vmaCreateAllocator(&allocator_create_info, &allocator_));
+}
+
+// Takes over other's handles and nulls them in other, so that other's
+// destructor becomes a no-op.
+Allocator::Allocator(Allocator&& other) noexcept
+    : instance_(other.instance_),
+      physical_device_(other.physical_device_),
+      device_(other.device_),
+      allocator_(other.allocator_) {
+  other.allocator_ = VK_NULL_HANDLE;
+  other.device_ = VK_NULL_HANDLE;
+  other.physical_device_ = VK_NULL_HANDLE;
+  other.instance_ = VK_NULL_HANDLE;
+}
+
+// Destroys the VMA allocator unless it was moved away.
+Allocator::~Allocator() {
+  if (VK_NULL_HANDLE == allocator_) {
+    return;
+  }
+  vmaDestroyAllocator(allocator_);
+}
+
+// Allocates raw memory (not tied to a buffer or image) that satisfies
+// memory_requirements. Throws via VK_THROW for the VMA_MEMORY_USAGE_AUTO and
+// VMA_MEMORY_USAGE_AUTO_PREFER_HOST usages.
+Allocation Allocator::create_allocation(
+    const VkMemoryRequirements& memory_requirements,
+    const VmaAllocationCreateInfo& create_info) {
+  VmaAllocationCreateInfo alloc_create_info = create_info;
+  // Protect against using VMA_MEMORY_USAGE_AUTO_* flags when allocating memory
+  // directly, since those usage flags require that VkBufferCreateInfo and/or
+  // VkImageCreateInfo also be available.
+  switch (create_info.usage) {
+    // The logic for the usage options below is too complex to approximate,
+    // so prevent them from being used with direct memory allocation.
+    case VMA_MEMORY_USAGE_AUTO:
+    case VMA_MEMORY_USAGE_AUTO_PREFER_HOST:
+      VK_THROW(
+          "Only the VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE usage flag is compatible with create_allocation()");
+      break;
+    // Most of the time, VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE will simply set the
+    // DEVICE_LOCAL_BIT as a preferred memory flag. Therefore the below is a
+    // decent approximation for VMA behaviour.
+    case VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE:
+      alloc_create_info.preferredFlags = VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT;
+      alloc_create_info.usage = VMA_MEMORY_USAGE_UNKNOWN;
+      break;
+    // Any other usage value is forwarded unchanged.
+    default:
+      break;
+  }
+
+  return Allocation(allocator_, memory_requirements, alloc_create_info);
+}
+
+// Creates an image usable as a sampled and storage image, optionally also as
+// a transfer source/destination. When allocate_memory is false the flag is
+// forwarded to VulkanImage so that memory can be bound later.
+VulkanImage Allocator::create_image(
+    const VkExtent3D& extents,
+    const VkFormat image_format,
+    const VkImageType image_type,
+    const VkImageViewType image_view_type,
+    const VulkanImage::SamplerProperties& sampler_props,
+    VkSampler sampler,
+    const bool allow_transfer,
+    const bool allocate_memory) {
+  VkImageUsageFlags usage =
+      VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT;
+  if (allow_transfer) {
+    usage |=
+        (VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT);
+  }
+
+  VmaAllocationCreateInfo alloc_create_info = {};
+  alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY;
+  alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE;
+
+  const VulkanImage::ImageProperties image_props{
+      image_type,
+      image_format,
+      extents,
+      usage,
+  };
+
+  const VulkanImage::ViewProperties view_props{
+      image_view_type,
+      image_format,
+  };
+
+  const VkImageLayout initial_layout = VK_IMAGE_LAYOUT_UNDEFINED;
+
+  return VulkanImage(
+      allocator_,
+      alloc_create_info,
+      image_props,
+      view_props,
+      sampler_props,
+      initial_layout,
+      sampler,
+      allocate_memory);
+}
+
+// Creates a storage buffer of `size` bytes. gpu_only selects device-preferred
+// memory; otherwise host-visible memory is required. Deferred allocation
+// (allocate_memory == false) is only accepted for GPU-only buffers.
+VulkanBuffer Allocator::create_storage_buffer(
+    const VkDeviceSize size,
+    const bool gpu_only,
+    const bool allocate_memory) {
+  const VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+
+  VmaAllocationCreateInfo alloc_create_info = {};
+  alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY;
+  alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_DEVICE;
+
+  // If the storage buffer will be accessed by both the CPU and GPU, set the
+  // appropriate flags to indicate that the host device will be accessing the
+  // data from this buffer.
+  if (!gpu_only) {
+    // Deferred memory allocation should only be used for GPU only buffers.
+    VK_CHECK_COND(
+        allocate_memory,
+        "Only GPU-only buffers should use deferred memory allocation");
+
+    alloc_create_info.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT;
+    alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_HOST;
+    alloc_create_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
+    alloc_create_info.preferredFlags = VK_MEMORY_PROPERTY_HOST_COHERENT_BIT |
+        VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
+  }
+
+  return VulkanBuffer(
+      allocator_, size, alloc_create_info, buffer_usage, allocate_memory);
+}
+
+// Creates a host-preferred buffer usable as a transfer source and destination.
+VulkanBuffer Allocator::create_staging_buffer(const VkDeviceSize size) {
+  VmaAllocationCreateInfo alloc_create_info = {};
+  alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY;
+  // NOTE(review): no VMA_ALLOCATION_CREATE_HOST_ACCESS_* flag is set here,
+  // unlike create_storage_buffer() and create_uniform_buffer(). If staging
+  // buffers are mapped by the host, VMA's documentation for the AUTO usages
+  // asks for one of those flags -- confirm against the callers.
+  alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_HOST;
+
+  VkBufferUsageFlags buffer_usage =
+      VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT;
+
+  return VulkanBuffer(allocator_, size, alloc_create_info, buffer_usage);
+}
+
+// Creates a uniform buffer of `size` bytes that the host writes sequentially.
+VulkanBuffer Allocator::create_uniform_buffer(const VkDeviceSize size) {
+  VmaAllocationCreateInfo alloc_create_info = {};
+  alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY |
+      VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT;
+  alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO;
+
+  VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
+
+  VulkanBuffer uniform_buffer(
+      allocator_, size, alloc_create_info, buffer_usage);
+  return uniform_buffer;
+}
+
+} // namespace api
+} // namespace vkcompute
diff --git
a/backends/vulkan/runtime/api/memory/Allocator.h b/backends/vulkan/runtime/api/memory/Allocator.h
new file mode 100644
index 00000000000..f1d3a449f56
--- /dev/null
+++ b/backends/vulkan/runtime/api/memory/Allocator.h
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
+
+#include
+
+#include
+
+#include
+
+#include
+#include
+#include
+
+namespace vkcompute {
+namespace api {
+
+// Allocation strategy flag applied to every resource created by Allocator.
+constexpr VmaAllocationCreateFlags DEFAULT_ALLOCATION_STRATEGY =
+    VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT;
+
+/*
+ * Owns a VmaAllocator and creates the buffers, images and raw memory
+ * allocations that are backed by it. Movable but not copyable.
+ */
+class Allocator final {
+ public:
+  explicit Allocator(
+      VkInstance instance,
+      VkPhysicalDevice physical_device,
+      VkDevice device);
+
+  Allocator(const Allocator&) = delete;
+  Allocator& operator=(const Allocator&) = delete;
+
+  Allocator(Allocator&&) noexcept;
+  Allocator& operator=(Allocator&&) = delete;
+
+  ~Allocator();
+
+ private:
+  VkInstance instance_;
+  VkPhysicalDevice physical_device_;
+  VkDevice device_;
+  VmaAllocator allocator_;
+
+ public:
+  /*
+   * Allocate memory that is not tied to a particular buffer or image
+   */
+  Allocation create_allocation(
+      const VkMemoryRequirements& memory_requirements,
+      const VmaAllocationCreateInfo& create_info);
+
+  /*
+   * Create an image; pass allocate_memory = false to bind memory later
+   */
+  VulkanImage create_image(
+      const VkExtent3D&,
+      const VkFormat,
+      const VkImageType,
+      const VkImageViewType,
+      const VulkanImage::SamplerProperties&,
+      VkSampler,
+      const bool allow_transfer = false,
+      const bool allocate_memory = true);
+
+  /*
+   * Create a storage buffer with a specified size
+   */
+  VulkanBuffer create_storage_buffer(
+      const VkDeviceSize,
+      const bool gpu_only = true,
+      const bool allocate_memory = true);
+
+  /*
+   * Create a buffer usable as a transfer source and destination
+   */
+  VulkanBuffer create_staging_buffer(const VkDeviceSize);
+
+  /*
+   * Create a uniform buffer with a specified size
+   */
+  VulkanBuffer create_uniform_buffer(const VkDeviceSize);
+
+  /*
+   * Create a uniform buffer containing the data in an arbitrary struct
+   */
+  template <typename Block>
+  VulkanBuffer create_params_buffer(const Block& block);
+
+  /*
+   * Compute current statistics for the underlying VMA allocator
+   */
+  VmaTotalStatistics get_memory_statistics() const {
+    VmaTotalStatistics stats = {};
+    vmaCalculateStatistics(allocator_, &stats);
+    return stats;
+  }
+};
+
+//
+// Impl
+//
+
+template <typename Block>
+inline VulkanBuffer Allocator::create_params_buffer(const Block& block) {
+  VulkanBuffer uniform_buffer = create_uniform_buffer(sizeof(Block));
+
+  // Fill the uniform buffer with data in block. The mapping lives in its own
+  // scope so that it is released before the buffer is returned.
+  {
+    MemoryMap mapping(uniform_buffer, MemoryAccessType::WRITE);
+    Block* data_ptr = mapping.template data<Block>();
+
+    *data_ptr = block;
+  }
+
+  return uniform_buffer;
+}
+
+} // namespace api
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/api/memory/Buffer.cpp b/backends/vulkan/runtime/api/memory/Buffer.cpp
new file mode 100644
index 00000000000..b12f1bf8deb
--- /dev/null
+++ b/backends/vulkan/runtime/api/memory/Buffer.cpp
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+namespace vkcompute {
+namespace api {
+
+//
+// VulkanBuffer
+//
+
+// Creates an empty buffer object that holds no Vulkan handle and no memory.
+VulkanBuffer::VulkanBuffer()
+    : buffer_properties_{},
+      allocator_(VK_NULL_HANDLE),
+      memory_{},
+      owns_memory_(false),
+      handle_(VK_NULL_HANDLE) {}
+
+// Creates a buffer of `size` bytes. With allocate_memory the buffer and its
+// memory are created together through VMA; otherwise only the VkBuffer is
+// created and memory must be bound later via bind_allocation().
+VulkanBuffer::VulkanBuffer(
+    VmaAllocator vma_allocator,
+    const VkDeviceSize size,
+    const VmaAllocationCreateInfo& allocation_create_info,
+    const VkBufferUsageFlags usage,
+    const bool allocate_memory)
+    : buffer_properties_({
+          size,
+          0u,
+          size,
+          usage,
+      }),
+      allocator_(vma_allocator),
+      memory_{},
+      owns_memory_(allocate_memory),
+      handle_(VK_NULL_HANDLE) {
+  // Only allocate memory if the buffer has non-zero size
+  if (size == 0) {
+    return;
+  }
+
+  const VkBufferCreateInfo buffer_create_info{
+      VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // sType
+      nullptr, // pNext
+      0u, // flags
+      size, // size
+      buffer_properties_.buffer_usage, // usage
+      VK_SHARING_MODE_EXCLUSIVE, // sharingMode
+      0u, // queueFamilyIndexCount
+      nullptr, // pQueueFamilyIndices
+  };
+
+  memory_.create_info = allocation_create_info;
+
+  if (allocate_memory) {
+    VK_CHECK(vmaCreateBuffer(
+        allocator_,
+        &buffer_create_info,
+        &allocation_create_info,
+        &handle_,
+        &(memory_.allocation),
+        nullptr));
+  } else {
+    // Deferred allocation: create the bare VkBuffer on the allocator's device.
+    VmaAllocatorInfo allocator_info{};
+    vmaGetAllocatorInfo(allocator_, &allocator_info);
+    VK_CHECK(vkCreateBuffer(
+        allocator_info.device, &buffer_create_info, nullptr, &handle_));
+  }
+}
+
+// Takes over other's buffer handle and memory; other is left without a handle
+// so its destructor does nothing.
+VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) noexcept
+    : buffer_properties_(other.buffer_properties_),
+      allocator_(other.allocator_),
+      memory_(std::move(other.memory_)),
+      owns_memory_(other.owns_memory_),
+      handle_(other.handle_) {
+  other.handle_ = VK_NULL_HANDLE;
+}
+
+// Move assignment is an ownership swap: the resources previously held by this
+// object are handed to `other`, which releases them when it is destroyed.
+VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) noexcept {
+  // The allocator must be handed over together with the handle it created.
+  // Previously `other` kept its own allocator_, so its destructor would have
+  // destroyed the swapped-in buffer through an allocator that never owned it.
+  VmaAllocator tmp_allocator = allocator_;
+  VkBuffer tmp_buffer = handle_;
+  bool tmp_owns_memory = owns_memory_;
+
+  buffer_properties_ = other.buffer_properties_;
+  allocator_ = other.allocator_;
+  memory_ = std::move(other.memory_);
+  owns_memory_ = other.owns_memory_;
+  handle_ = other.handle_;
+
+  other.allocator_ = tmp_allocator;
+  other.handle_ = tmp_buffer;
+  other.owns_memory_ = tmp_owns_memory;
+
+  return *this;
+}
+
+// Destroys the buffer, and its memory too when this object owns it.
+VulkanBuffer::~VulkanBuffer() {
+  if (VK_NULL_HANDLE != handle_) {
+    if (owns_memory_) {
+      vmaDestroyBuffer(allocator_, handle_, memory_.allocation);
+    } else {
+      vkDestroyBuffer(this->device(), handle_, nullptr);
+    }
+    // Prevent the underlying memory allocation from being freed; it was either
+    // freed by vmaDestroyBuffer, or this resource does not own the underlying
+    // memory
+    memory_.allocation = VK_NULL_HANDLE;
+  }
+}
+
+// Queries Vulkan for the memory requirements of the buffer handle.
+VkMemoryRequirements VulkanBuffer::get_memory_requirements() const {
+  VkMemoryRequirements memory_requirements;
+  vkGetBufferMemoryRequirements(this->device(), handle_, &memory_requirements);
+  return memory_requirements;
+}
+
+//
+// MemoryMap
+//
+
+// Maps the buffer's memory for host access. Nothing is mapped (data_ stays
+// null) when the buffer has no allocation.
+MemoryMap::MemoryMap(const VulkanBuffer& buffer, const uint8_t access)
+    : access_(access),
+      allocator_(buffer.vma_allocator()),
+      allocation_(buffer.allocation()),
+      data_(nullptr),
+      data_len_{buffer.mem_size()} {
+  if (allocation_) {
+    VK_CHECK(vmaMapMemory(allocator_, allocation_, &data_));
+  }
+}
+
+// Takes over other's mapping; other is cleared so it will not unmap.
+MemoryMap::MemoryMap(MemoryMap&& other) noexcept
+    : access_(other.access_),
+      allocator_(other.allocator_),
+      allocation_(other.allocation_),
+      data_(other.data_),
+      data_len_{other.data_len_} {
+  other.allocation_ = VK_NULL_HANDLE;
+  other.data_ = nullptr;
+}
+
+// Flushes host writes (for WRITE access) and unmaps the memory.
+MemoryMap::~MemoryMap() {
+  if (!data_) {
+    return;
+  }
+
+  if (allocation_) {
+    if (access_ & MemoryAccessType::WRITE) {
+      // Call will be ignored by implementation if the memory type this
+      // allocation belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is
+      // the behavior we want. Don't check the result here as the destructor
+      // cannot throw.
+      vmaFlushAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE);
+    }
+
+    vmaUnmapMemory(allocator_, allocation_);
+  }
+}
+
+// Invalidates the mapped range; only acts on mappings opened with READ access.
+void MemoryMap::invalidate() {
+  if (access_ & MemoryAccessType::READ && allocation_) {
+    // Call will be ignored by implementation if the memory type this allocation
+    // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior
+    // we want.
+    VK_CHECK(
+        vmaInvalidateAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE));
+  }
+}
+
+//
+// BufferMemoryBarrier
+//
+
+// Builds a barrier covering the buffer's recorded offset and range, with no
+// queue family ownership transfer.
+BufferMemoryBarrier::BufferMemoryBarrier(
+    const VkAccessFlags src_access_flags,
+    const VkAccessFlags dst_access_flags,
+    const VulkanBuffer& buffer)
+    : handle{
+          VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER, // sType
+          nullptr, // pNext
+          src_access_flags, // srcAccessMask
+          dst_access_flags, // dstAccessMask
+          VK_QUEUE_FAMILY_IGNORED, // srcQueueFamilyIndex
+          VK_QUEUE_FAMILY_IGNORED, // dstQueueFamilyIndex
+          buffer.handle_, // buffer
+          buffer.buffer_properties_.mem_offset, // offset
+          buffer.buffer_properties_.mem_range, // size
+      } {}
+
+} // namespace api
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/api/memory/Buffer.h b/backends/vulkan/runtime/api/memory/Buffer.h
new file mode 100644
index 00000000000..c0eea5bea6e
--- /dev/null
+++ b/backends/vulkan/runtime/api/memory/Buffer.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
+
+#include
+
+#include
+
+#include
+
+#include
+
+namespace vkcompute {
+namespace api {
+
+using MemoryAccessFlags = uint8_t;
+
+// Bit flags describing how the host accesses a mapped buffer.
+enum MemoryAccessType : MemoryAccessFlags {
+  NONE = 0u << 0u,
+  READ = 1u << 0u,
+  WRITE = 1u << 1u,
+};
+
+/*
+ * Move-only wrapper around a VkBuffer and, optionally, the memory backing it.
+ */
+class VulkanBuffer final {
+ public:
+  struct BufferProperties final {
+    VkDeviceSize size;
+    VkDeviceSize mem_offset;
+    VkDeviceSize mem_range;
+    VkBufferUsageFlags buffer_usage;
+  };
+
+  explicit VulkanBuffer();
+
+  explicit VulkanBuffer(
+      const VmaAllocator,
+      const VkDeviceSize,
+      const VmaAllocationCreateInfo&,
+      const VkBufferUsageFlags,
+      const bool allocate_memory = true);
+
+  VulkanBuffer(const VulkanBuffer&) = delete;
+  VulkanBuffer& operator=(const VulkanBuffer&) = delete;
+
+  VulkanBuffer(VulkanBuffer&&) noexcept;
+  VulkanBuffer& operator=(VulkanBuffer&&) noexcept;
+
+  ~VulkanBuffer();
+
+  struct Package final {
+    VkBuffer handle;
+    VkDeviceSize buffer_offset;
+    VkDeviceSize buffer_range;
+  };
+
+  // BufferMemoryBarrier reads handle_ and buffer_properties_ directly.
+  friend struct BufferMemoryBarrier;
+
+ private:
+  BufferProperties buffer_properties_;
+  VmaAllocator allocator_;
+  Allocation memory_;
+  // Indicates whether the underlying memory is owned by this resource
+  bool owns_memory_;
+  VkBuffer handle_;
+
+ public:
+  // The VkDevice that the VMA allocator was created for.
+  inline VkDevice device() const {
+    VmaAllocatorInfo allocator_info{};
+    vmaGetAllocatorInfo(allocator_, &allocator_info);
+    return allocator_info.device;
+  }
+
+  inline VmaAllocator vma_allocator() const {
+    return allocator_;
+  }
+
+  inline VmaAllocation allocation() const {
+    return memory_.allocation;
+  }
+
+  // Returns a copy of the create info recorded for the buffer's memory.
+  inline VmaAllocationCreateInfo allocation_create_info() const {
+    return VmaAllocationCreateInfo(memory_.create_info);
+  }
+
+  inline VkBuffer handle() const {
+    return handle_;
+  }
+
+  inline VkDeviceSize mem_offset() const {
+    return buffer_properties_.mem_offset;
+  }
+
+  inline VkDeviceSize mem_range() const {
+    return buffer_properties_.mem_range;
+  }
+
+  inline VkDeviceSize mem_size() const {
+    return buffer_properties_.size;
+  }
+
+  inline bool has_memory() const {
+    return (memory_.allocation != VK_NULL_HANDLE);
+  }
+
+  inline bool owns_memory() const {
+    return owns_memory_;
+  }
+
+  // True when a VkBuffer handle has been created.
+  operator bool() const {
+    return (handle_ != VK_NULL_HANDLE);
+  }
+
+  // Binds the buffer to externally provided memory and records its handle.
+  // Fails the VK_CHECK_COND if the buffer already has memory.
+  inline void bind_allocation(const Allocation& memory) {
+    VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
+    VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_));
+    memory_.allocation = memory.allocation;
+  }
+
+  VkMemoryRequirements get_memory_requirements() const;
+};
+
+/*
+ * Scoped host mapping of a VulkanBuffer's memory; the memory is unmapped when
+ * the MemoryMap is destroyed.
+ */
+class MemoryMap final {
+ public:
+  explicit MemoryMap(
+      const VulkanBuffer& buffer,
+      const MemoryAccessFlags access);
+
+  MemoryMap(const MemoryMap&) = delete;
+  MemoryMap& operator=(const MemoryMap&) = delete;
+
+  MemoryMap(MemoryMap&&) noexcept;
+  MemoryMap& operator=(MemoryMap&&) = delete;
+
+  ~MemoryMap();
+
+ private:
+  uint8_t access_;
+  VmaAllocator allocator_;
+  VmaAllocation allocation_;
+  void* data_;
+  VkDeviceSize data_len_;
+
+ public:
+  // Returns the mapped pointer reinterpreted as T*; null when nothing is
+  // mapped.
+  template <typename T>
+  T* data() {
+    return reinterpret_cast<T*>(data_);
+  }
+
+  // Size of the mapped buffer in bytes.
+  inline size_t nbytes() {
+    return utils::safe_downcast<size_t>(data_len_);
+  }
+
+  void invalidate();
+};
+
+// Wraps a VkBufferMemoryBarrier describing an access transition on a buffer.
+struct BufferMemoryBarrier final {
+  VkBufferMemoryBarrier handle;
+
+  BufferMemoryBarrier(
+      const VkAccessFlags src_access_flags,
+      const VkAccessFlags dst_access_flags,
+      const VulkanBuffer& buffer);
+};
+
+} // namespace api
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/api/memory/Image.cpp b/backends/vulkan/runtime/api/memory/Image.cpp
new file mode 100644
index 00000000000..449dbaf2416
--- /dev/null
+++ b/backends/vulkan/runtime/api/memory/Image.cpp
@@ -0,0 +1,336 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include + +namespace vkcompute { +namespace api { + +// +// ImageSampler +// + +bool operator==( + const ImageSampler::Properties& _1, + const ImageSampler::Properties& _2) { + return ( + _1.filter == _2.filter && _1.mipmap_mode == _2.mipmap_mode && + _1.address_mode == _2.address_mode && _1.border_color == _2.border_color); +} + +ImageSampler::ImageSampler( + VkDevice device, + const ImageSampler::Properties& props) + : device_(device), handle_(VK_NULL_HANDLE) { + const VkSamplerCreateInfo sampler_create_info{ + VK_STRUCTURE_TYPE_SAMPLER_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + props.filter, // magFilter + props.filter, // minFilter + props.mipmap_mode, // mipmapMode + props.address_mode, // addressModeU + props.address_mode, // addressModeV + props.address_mode, // addressModeW + 0.0f, // mipLodBias + VK_FALSE, // anisotropyEnable + 1.0f, // maxAnisotropy, + VK_FALSE, // compareEnable + VK_COMPARE_OP_NEVER, // compareOp + 0.0f, // minLod + VK_LOD_CLAMP_NONE, // maxLod + props.border_color, // borderColor + VK_FALSE, // unnormalizedCoordinates + }; + + VK_CHECK(vkCreateSampler(device_, &sampler_create_info, nullptr, &handle_)); +} + +ImageSampler::ImageSampler(ImageSampler&& other) noexcept + : device_(other.device_), handle_(other.handle_) { + other.handle_ = VK_NULL_HANDLE; +} + +ImageSampler::~ImageSampler() { + if (VK_NULL_HANDLE == handle_) { + return; + } + vkDestroySampler(device_, handle_, nullptr); +} + +size_t ImageSampler::Hasher::operator()( + const ImageSampler::Properties& props) const { + size_t seed = 0; + seed = utils::hash_combine(seed, std::hash()(props.filter)); + seed = utils::hash_combine( + seed, std::hash()(props.mipmap_mode)); + seed = utils::hash_combine( + seed, std::hash()(props.address_mode)); + seed = + utils::hash_combine(seed, std::hash()(props.border_color)); + return seed; +} + +void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept { + VkDevice tmp_device = lhs.device_; + VkSampler tmp_handle = 
lhs.handle_; + + lhs.device_ = rhs.device_; + lhs.handle_ = rhs.handle_; + + rhs.device_ = tmp_device; + rhs.handle_ = tmp_handle; +} + +// +// VulkanImage +// + +VulkanImage::VulkanImage() + : image_properties_{}, + view_properties_{}, + sampler_properties_{}, + allocator_(VK_NULL_HANDLE), + memory_{}, + owns_memory_(false), + handles_{ + VK_NULL_HANDLE, + VK_NULL_HANDLE, + VK_NULL_HANDLE, + }, + layout_{} {} + +VulkanImage::VulkanImage( + VmaAllocator vma_allocator, + const VmaAllocationCreateInfo& allocation_create_info, + const ImageProperties& image_props, + const ViewProperties& view_props, + const SamplerProperties& sampler_props, + const VkImageLayout layout, + VkSampler sampler, + const bool allocate_memory) + : image_properties_(image_props), + view_properties_(view_props), + sampler_properties_(sampler_props), + allocator_(vma_allocator), + memory_{}, + owns_memory_{allocate_memory}, + handles_{ + VK_NULL_HANDLE, + VK_NULL_HANDLE, + sampler, + }, + layout_(layout) { + VmaAllocatorInfo allocator_info{}; + vmaGetAllocatorInfo(allocator_, &allocator_info); + + // If any dims are zero, then no memory will be allocated for the image. 
+ if (image_props.image_extents.width == 0 || + image_props.image_extents.height == 0 || + image_props.image_extents.depth == 0) { + return; + } + + const VkImageCreateInfo image_create_info{ + VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + image_properties_.image_type, // imageType + image_properties_.image_format, // format + image_properties_.image_extents, // extents + 1u, // mipLevels + 1u, // arrayLayers + VK_SAMPLE_COUNT_1_BIT, // samples + VK_IMAGE_TILING_OPTIMAL, // tiling + image_properties_.image_usage, // usage + VK_SHARING_MODE_EXCLUSIVE, // sharingMode + 0u, // queueFamilyIndexCount + nullptr, // pQueueFamilyIndices + layout_, // initialLayout + }; + + memory_.create_info = allocation_create_info; + + if (allocate_memory) { + VK_CHECK(vmaCreateImage( + allocator_, + &image_create_info, + &allocation_create_info, + &(handles_.image), + &(memory_.allocation), + nullptr)); + // Only create the image view if the image has been bound to memory + create_image_view(); + } else { + VK_CHECK(vkCreateImage( + allocator_info.device, &image_create_info, nullptr, &(handles_.image))); + } +} + +VulkanImage::VulkanImage(VulkanImage&& other) noexcept + : image_properties_(other.image_properties_), + view_properties_(other.view_properties_), + sampler_properties_(other.sampler_properties_), + allocator_(other.allocator_), + memory_(std::move(other.memory_)), + owns_memory_(other.owns_memory_), + handles_(other.handles_), + layout_(other.layout_) { + other.handles_.image = VK_NULL_HANDLE; + other.handles_.image_view = VK_NULL_HANDLE; + other.handles_.sampler = VK_NULL_HANDLE; + other.owns_memory_ = false; +} + +VulkanImage& VulkanImage::operator=(VulkanImage&& other) noexcept { + VkImage tmp_image = handles_.image; + VkImageView tmp_image_view = handles_.image_view; + bool tmp_owns_memory = owns_memory_; + + image_properties_ = other.image_properties_; + view_properties_ = other.view_properties_; + sampler_properties_ = 
other.sampler_properties_; + allocator_ = other.allocator_; + memory_ = std::move(other.memory_); + owns_memory_ = other.owns_memory_; + handles_ = other.handles_; + layout_ = other.layout_; + + other.handles_.image = tmp_image; + other.handles_.image_view = tmp_image_view; + other.owns_memory_ = tmp_owns_memory; + + return *this; +} + +VulkanImage::~VulkanImage() { + if (VK_NULL_HANDLE != handles_.image_view) { + vkDestroyImageView(this->device(), handles_.image_view, nullptr); + } + + if (VK_NULL_HANDLE != handles_.image) { + if (owns_memory_) { + vmaDestroyImage(allocator_, handles_.image, memory_.allocation); + } else { + vkDestroyImage(this->device(), handles_.image, nullptr); + } + // Prevent the underlying memory allocation from being freed; it was either + // freed by vmaDestroyImage, or this resource does not own the underlying + // memory + memory_.allocation = VK_NULL_HANDLE; + } +} + +void VulkanImage::create_image_view() { + VmaAllocatorInfo allocator_info{}; + vmaGetAllocatorInfo(allocator_, &allocator_info); + + const VkComponentMapping component_mapping{ + VK_COMPONENT_SWIZZLE_IDENTITY, // r + VK_COMPONENT_SWIZZLE_IDENTITY, // g + VK_COMPONENT_SWIZZLE_IDENTITY, // b + VK_COMPONENT_SWIZZLE_IDENTITY, // a + }; + + const VkImageSubresourceRange subresource_range{ + VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask + 0u, // baseMipLevel + VK_REMAINING_MIP_LEVELS, // levelCount + 0u, // baseArrayLayer + VK_REMAINING_ARRAY_LAYERS, // layerCount + }; + + const VkImageViewCreateInfo image_view_create_info{ + VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO, // sType + nullptr, // pNext + 0u, // flags + handles_.image, // image + view_properties_.view_type, // viewType + view_properties_.view_format, // format + component_mapping, // components + subresource_range, // subresourceRange + }; + + VK_CHECK(vkCreateImageView( + allocator_info.device, + &(image_view_create_info), + nullptr, + &(handles_.image_view))); +} + +VkMemoryRequirements 
VulkanImage::get_memory_requirements() const { + VkMemoryRequirements memory_requirements; + vkGetImageMemoryRequirements( + this->device(), handles_.image, &memory_requirements); + return memory_requirements; +} + +// +// ImageMemoryBarrier +// + +ImageMemoryBarrier::ImageMemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags, + const VkImageLayout src_layout_flags, + const VkImageLayout dst_layout_flags, + const VulkanImage& image) + : handle{ + VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER, // sType + nullptr, // pNext + src_access_flags, // srcAccessMask + dst_access_flags, // dstAccessMask + src_layout_flags, // oldLayout + dst_layout_flags, // newLayout + VK_QUEUE_FAMILY_IGNORED, // srcQueueFamilyIndex + VK_QUEUE_FAMILY_IGNORED, // dstQueueFamilyIndex + image.handles_.image, // image + { + // subresourceRange + VK_IMAGE_ASPECT_COLOR_BIT, // aspectMask + 0u, // baseMipLevel + VK_REMAINING_MIP_LEVELS, // levelCount + 0u, // baseArrayLayer + VK_REMAINING_ARRAY_LAYERS, // layerCount + }, + } {} + +// +// SamplerCache +// + +SamplerCache::SamplerCache(VkDevice device) + : cache_mutex_{}, device_(device), cache_{} {} + +SamplerCache::SamplerCache(SamplerCache&& other) noexcept + : cache_mutex_{}, device_(other.device_), cache_(std::move(other.cache_)) { + std::lock_guard lock(other.cache_mutex_); +} + +SamplerCache::~SamplerCache() { + purge(); +} + +VkSampler SamplerCache::retrieve(const SamplerCache::Key& key) { + std::lock_guard lock(cache_mutex_); + + auto it = cache_.find(key); + if (cache_.cend() == it) { + it = cache_.insert({key, SamplerCache::Value(device_, key)}).first; + } + + return it->second.handle(); +} + +void SamplerCache::purge() { + std::lock_guard lock(cache_mutex_); + cache_.clear(); +} + +} // namespace api +} // namespace vkcompute diff --git a/backends/vulkan/runtime/api/memory/Image.h b/backends/vulkan/runtime/api/memory/Image.h new file mode 100644 index 00000000000..e3f4d7437df --- /dev/null +++ 
b/backends/vulkan/runtime/api/memory/Image.h @@ -0,0 +1,253 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName + +#include + +#include + +#include + +#include + +#include +#include + +namespace vkcompute { +namespace api { + +class ImageSampler final { + public: + struct Properties final { + VkFilter filter; + VkSamplerMipmapMode mipmap_mode; + VkSamplerAddressMode address_mode; + VkBorderColor border_color; + }; + + explicit ImageSampler(VkDevice, const Properties&); + + ImageSampler(const ImageSampler&) = delete; + ImageSampler& operator=(const ImageSampler&) = delete; + + ImageSampler(ImageSampler&&) noexcept; + ImageSampler& operator=(ImageSampler&&) = delete; + + ~ImageSampler(); + + private: + VkDevice device_; + VkSampler handle_; + + public: + VkSampler handle() const { + return handle_; + } + + struct Hasher { + size_t operator()(const Properties&) const; + }; + + // We need to define a custom swap function since this class + // does not allow for move assignment. The swap function will + // be used in the hash map. 
+ friend void swap(ImageSampler& lhs, ImageSampler& rhs) noexcept; +}; + +class VulkanImage final { + public: + struct ImageProperties final { + VkImageType image_type; + VkFormat image_format; + VkExtent3D image_extents; + VkImageUsageFlags image_usage; + }; + + struct ViewProperties final { + VkImageViewType view_type; + VkFormat view_format; + }; + + using SamplerProperties = ImageSampler::Properties; + + struct Handles final { + VkImage image; + VkImageView image_view; + VkSampler sampler; + }; + + explicit VulkanImage(); + + explicit VulkanImage( + const VmaAllocator, + const VmaAllocationCreateInfo&, + const ImageProperties&, + const ViewProperties&, + const SamplerProperties&, + const VkImageLayout layout, + VkSampler, + const bool allocate_memory = true); + + VulkanImage(const VulkanImage&) = delete; + VulkanImage& operator=(const VulkanImage&) = delete; + + VulkanImage(VulkanImage&&) noexcept; + VulkanImage& operator=(VulkanImage&&) noexcept; + + ~VulkanImage(); + + struct Package final { + VkImage handle; + VkImageLayout image_layout; + VkImageView image_view; + VkSampler image_sampler; + }; + + friend struct ImageMemoryBarrier; + + private: + ImageProperties image_properties_; + ViewProperties view_properties_; + SamplerProperties sampler_properties_; + // The allocator object this was allocated from + VmaAllocator allocator_; + // Handles to the allocated memory + Allocation memory_; + // Indicates whether the underlying memory is owned by this resource + bool owns_memory_; + Handles handles_; + // Layout + VkImageLayout layout_; + + public: + void create_image_view(); + + inline VkDevice device() const { + VmaAllocatorInfo allocator_info{}; + vmaGetAllocatorInfo(allocator_, &allocator_info); + return allocator_info.device; + } + + inline VmaAllocator vma_allocator() const { + return allocator_; + } + + inline VmaAllocation allocation() const { + return memory_.allocation; + } + + inline VmaAllocationCreateInfo allocation_create_info() const { + return 
VmaAllocationCreateInfo(memory_.create_info); + } + + inline VkFormat format() const { + return image_properties_.image_format; + } + + inline VkExtent3D extents() const { + return image_properties_.image_extents; + } + + inline VkImage handle() const { + return handles_.image; + } + + inline VkImageView image_view() const { + return handles_.image_view; + } + + inline VkSampler sampler() const { + return handles_.sampler; + } + + Package package() const { + return { + handles_.image, + layout_, + handles_.image_view, + handles_.sampler, + }; + } + + inline VkImageLayout layout() const { + return layout_; + } + + inline void set_layout(const VkImageLayout layout) { + layout_ = layout; + } + + inline bool has_memory() const { + return (memory_.allocation != VK_NULL_HANDLE); + } + + inline bool owns_memory() const { + return owns_memory_; + } + + inline operator bool() const { + return (handles_.image != VK_NULL_HANDLE); + } + + inline void bind_allocation(const Allocation& memory) { + VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!"); + VK_CHECK(vmaBindImageMemory(allocator_, memory.allocation, handles_.image)); + memory_.allocation = memory.allocation; + + // Only create the image view if the image has been bound to memory + create_image_view(); + } + + VkMemoryRequirements get_memory_requirements() const; +}; + +struct ImageMemoryBarrier final { + VkImageMemoryBarrier handle; + + ImageMemoryBarrier( + const VkAccessFlags src_access_flags, + const VkAccessFlags dst_access_flags, + const VkImageLayout src_layout_flags, + const VkImageLayout dst_layout_flags, + const VulkanImage& image); +}; + +class SamplerCache final { + public: + explicit SamplerCache(VkDevice device); + + SamplerCache(const SamplerCache&) = delete; + SamplerCache& operator=(const SamplerCache&) = delete; + + SamplerCache(SamplerCache&&) noexcept; + SamplerCache& operator=(SamplerCache&&) = delete; + + ~SamplerCache(); + + using Key = ImageSampler::Properties; + using Value = 
ImageSampler; + using Hasher = ImageSampler::Hasher; + + private: + // Multiple threads could potentially be adding entries into the cache, so use + // a mutex to manage access + std::mutex cache_mutex_; + + VkDevice device_; + std::unordered_map cache_; + + public: + VkSampler retrieve(const Key&); + void purge(); +}; + +} // namespace api +} // namespace vkcompute diff --git a/backends/vulkan/runtime/api/Allocator.cpp b/backends/vulkan/runtime/api/memory/vma_api.cpp similarity index 78% rename from backends/vulkan/runtime/api/Allocator.cpp rename to backends/vulkan/runtime/api/memory/vma_api.cpp index 3cedaa2f5af..d1180305fea 100644 --- a/backends/vulkan/runtime/api/Allocator.cpp +++ b/backends/vulkan/runtime/api/memory/vma_api.cpp @@ -7,4 +7,4 @@ */ #define VMA_IMPLEMENTATION -#include +#include diff --git a/backends/vulkan/runtime/api/Allocator.h b/backends/vulkan/runtime/api/memory/vma_api.h similarity index 92% rename from backends/vulkan/runtime/api/Allocator.h rename to backends/vulkan/runtime/api/memory/vma_api.h index a5a9ea02a98..34e3219d934 100644 --- a/backends/vulkan/runtime/api/Allocator.h +++ b/backends/vulkan/runtime/api/memory/vma_api.h @@ -10,11 +10,9 @@ // // Do NOT include vk_mem_alloc.h directly. -// Always include this file (Allocator.h) instead. +// Always include this file (vma_api.h) instead. 
// -#include - #define VMA_VULKAN_VERSION 1000000 #ifdef USE_VULKAN_WRAPPER diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 0c7941d6f52..aa34aae6771 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -56,7 +56,7 @@ ComputeGraph::ComputeGraph(GraphConfig config) execute_descriptor_counts_{}, context_{new api::Context( api::runtime()->default_adapter_i(), - config_.contextConfig)}, + config_.context_config)}, shared_objects_{}, values_{}, param_ubos_{}, @@ -65,17 +65,17 @@ ComputeGraph::ComputeGraph(GraphConfig config) inputs_{}, outputs_{} { // Ensure that descriptor counts are initialized to 0 - prepack_descriptor_counts_.descriptorPoolMaxSets = 0; - prepack_descriptor_counts_.descriptorUniformBufferCount = 0; - prepack_descriptor_counts_.descriptorStorageBufferCount = 0; - prepack_descriptor_counts_.descriptorCombinedSamplerCount = 0; - prepack_descriptor_counts_.descriptorStorageImageCount = 0; - - execute_descriptor_counts_.descriptorPoolMaxSets = 0; - execute_descriptor_counts_.descriptorUniformBufferCount = 0; - execute_descriptor_counts_.descriptorStorageBufferCount = 0; - execute_descriptor_counts_.descriptorCombinedSamplerCount = 0; - execute_descriptor_counts_.descriptorStorageImageCount = 0; + prepack_descriptor_counts_.descriptor_pool_max_sets = 0; + prepack_descriptor_counts_.descriptor_uniform_buffer_count = 0; + prepack_descriptor_counts_.descriptor_storage_buffer_count = 0; + prepack_descriptor_counts_.descriptor_combined_sampler_count = 0; + prepack_descriptor_counts_.descriptor_storage_image_count = 0; + + execute_descriptor_counts_.descriptor_pool_max_sets = 0; + execute_descriptor_counts_.descriptor_uniform_buffer_count = 0; + execute_descriptor_counts_.descriptor_storage_buffer_count = 0; + execute_descriptor_counts_.descriptor_combined_sampler_count = 0; + execute_descriptor_counts_.descriptor_storage_image_count = 
0; context_->set_cmd(/*reusable = */ true); } @@ -89,44 +89,17 @@ ComputeGraph::~ComputeGraph() { context_->flush(); } -void ComputeGraph::update_descriptor_counts( - const api::ShaderInfo& shader_info, - bool execute) { - api::DescriptorPoolConfig* config = - execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_; - - config->descriptorPoolMaxSets += 1; - for (const VkDescriptorType arg_type : shader_info.kernel_layout) { - switch (arg_type) { - case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: - config->descriptorUniformBufferCount += 1; - break; - case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: - config->descriptorStorageBufferCount += 1; - break; - case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: - config->descriptorCombinedSamplerCount += 1; - break; - case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: - config->descriptorStorageImageCount += 1; - break; - default: - VK_THROW("Unsupported descriptor type!"); - } - } -} - api::StorageType ComputeGraph::suggested_storage_type() { - if (config_.enableStorageTypeOverride) { - return config_.storageTypeOverride; + if (config_.enable_storage_type_override) { + return config_.storage_type_override; } return api::kTexture3D; } api::GPUMemoryLayout ComputeGraph::suggested_memory_layout( const std::vector& sizes) { - if (config_.enableMemoryLayoutOverride) { - return config_.memoryLayoutOverride; + if (config_.enable_memory_layout_override) { + return config_.memory_layout_override; } if (sizes.size() < 3) { return api::kWidthPacked; @@ -148,22 +121,22 @@ void ComputeGraph::check_no_active_value_ptrs() { "invalidated."); } -std::vector ComputeGraph::get_sizes_of(ValueRef idx) { - Value& val = values_.at(idx); +std::vector ComputeGraph::sizes_of(const ValueRef idx) const { + const Value& val = values_.at(idx); if (val.isTensor()) { - return val.toTensor().sizes(); + return val.toConstTensor().sizes(); } else if (val.isTensorRef()) { - return val.toTensorRef().sizes; + return val.toConstTensorRef().sizes; } VK_THROW("Could not get sizes of 
value with type ", val.type()); } -api::ScalarType ComputeGraph::get_dtype_of(ValueRef idx) { - Value& val = values_.at(idx); +api::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const { + const Value& val = values_.at(idx); if (val.isTensor()) { - return val.toTensor().dtype(); + return val.toConstTensor().dtype(); } else if (val.isTensorRef()) { - return val.toTensorRef().dtype; + return val.toConstTensorRef().dtype; } VK_THROW("Could not get dtype of value with type ", val.type()); } @@ -200,14 +173,13 @@ ValueRef ComputeGraph::add_tensor_like( const ValueRef idx, const api::StorageType storage_type, const api::GPUMemoryLayout memory_layout) { - return add_tensor( - get_sizes_of(idx), get_dtype_of(idx), storage_type, memory_layout); + return add_tensor(sizes_of(idx), dtype_of(idx), storage_type, memory_layout); } ValueRef ComputeGraph::add_tensor_like( const ValueRef idx, const api::GPUMemoryLayout memory_layout) { - return add_tensor(get_sizes_of(idx), get_dtype_of(idx), memory_layout); + return add_tensor(sizes_of(idx), dtype_of(idx), memory_layout); } ValueRef ComputeGraph::add_tensor( @@ -280,7 +252,12 @@ ValueRef ComputeGraph::set_output_tensor( api::ScalarType dtype = get_tensor(idx)->dtype(); size_t gpu_numel = get_tensor(idx)->gpu_numel(); ValueRef staging_idx = add_staging(dtype, gpu_numel); - add_tensor_to_staging_node(*this, idx, staging_idx); + // Only add the staging node when the tensor is non-empty. When the + // tensor is empty (e.g. gpu_numel == 0), no VkImage is allocated for it, + // so the node could not be bound for execution.
+ if (gpu_numel > 0) { + add_tensor_to_staging_node(*this, idx, staging_idx); + } outputs_.push_back({idx, staging_idx}); return staging_idx; } @@ -295,6 +272,33 @@ SharedObject& ComputeGraph::get_shared_object(const int64_t idx) { return shared_objects_.at(idx); } +void ComputeGraph::update_descriptor_counts( + const api::ShaderInfo& shader_info, + bool execute) { + api::DescriptorPoolConfig* config = + execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_; + + config->descriptor_pool_max_sets += 1; + for (const VkDescriptorType arg_type : shader_info.kernel_layout) { + switch (arg_type) { + case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER: + config->descriptor_uniform_buffer_count += 1; + break; + case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: + config->descriptor_storage_buffer_count += 1; + break; + case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: + config->descriptor_combined_sampler_count += 1; + break; + case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: + config->descriptor_storage_image_count += 1; + break; + default: + VK_THROW("Unsupported descriptor type!"); + } + } +} + void ComputeGraph::copy_into_staging( const ValueRef idx, const void* data, @@ -319,15 +323,15 @@ void ComputeGraph::prepare() { std::max( \ execute_descriptor_counts_.field, \ prepack_descriptor_counts_.field) * \ - config_.descriptorPoolSafetyFactor)) + config_.descriptor_pool_safety_factor)) - uint32_t max_sets = MERGE_FIELD(descriptorPoolMaxSets); + uint32_t max_sets = MERGE_FIELD(descriptor_pool_max_sets); api::DescriptorPoolConfig config{ max_sets, - std::max(MERGE_FIELD(descriptorUniformBufferCount), max_sets), - std::max(MERGE_FIELD(descriptorStorageBufferCount), max_sets), - std::max(MERGE_FIELD(descriptorCombinedSamplerCount), max_sets), - std::max(MERGE_FIELD(descriptorStorageImageCount), max_sets), + std::max(MERGE_FIELD(descriptor_uniform_buffer_count), max_sets), + std::max(MERGE_FIELD(descriptor_storage_buffer_count), max_sets), + std::max(MERGE_FIELD(descriptor_combined_sampler_count), 
max_sets), + std::max(MERGE_FIELD(descriptor_storage_image_count), max_sets), 1u, }; diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 00d8cbd3c55..6a7c9e3f424 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -123,9 +123,17 @@ class ComputeGraph final { return outputs_; } - void update_descriptor_counts( - const api::ShaderInfo& shader_info, - bool execute); + inline std::vector>& prepack_nodes() { + return prepack_nodes_; + } + + inline std::vector>& execute_nodes() { + return execute_nodes_; + } + + // + // Value Extraction + // #define GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(ptr_type, short_name, type_name) \ inline ptr_type get_##short_name(const ValueRef idx) { \ @@ -168,9 +176,33 @@ class ComputeGraph final { return values_.at(idx).type(); } - std::vector get_sizes_of(ValueRef idx); + // Get Tensor Property + + std::vector sizes_of(const ValueRef idx) const; + + api::ScalarType dtype_of(const ValueRef idx) const; + + inline api::utils::uvec3 extents_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().extents(); + } + + inline api::GPUMemoryLayout memory_layout_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().gpu_memory_layout(); + } - api::ScalarType get_dtype_of(ValueRef idx); + inline api::BufferBindInfo sizes_ubo(const ValueRef idx) { + return values_.at(idx).toTensor().sizes_ubo(); + } + + inline api::BufferBindInfo texture_limits_ubo(const ValueRef idx) { + return values_.at(idx).toTensor().texture_limits_ubo(); + } + + inline api::BufferBindInfo packed_dim_meta_ubo(const ValueRef idx) { + return values_.at(idx).toTensor().packed_dim_meta_ubo(); + } + + // Scalar Value Extraction template T extract_scalar(const ValueRef idx) { @@ -196,12 +228,8 @@ class ComputeGraph final { } } - inline std::vector>& prepack_nodes() { - return prepack_nodes_; - } - - inline std::vector>& execute_nodes() { - 
return execute_nodes_; + std::string extract_string(const ValueRef idx) { + return values_.at(idx).toString(); } // @@ -229,13 +257,6 @@ class ComputeGraph final { api::GPUMemoryLayout suggested_memory_layout( const std::vector& sizes); - /* - * Returns the memory layout of a Tensor value at the specified index. - */ - inline api::GPUMemoryLayout memory_layout_of(ValueRef idx) { - return get_tensor(idx)->gpu_memory_layout(); - } - // // Graph Building // @@ -363,6 +384,10 @@ class ComputeGraph final { // Graph Preparation // + void update_descriptor_counts( + const api::ShaderInfo& shader_info, + bool execute); + void prepare(); // diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp index 98b2d9a4263..29de4704395 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.cpp +++ b/backends/vulkan/runtime/graph/GraphConfig.cpp @@ -12,12 +12,12 @@ namespace vkcompute { GraphConfig::GraphConfig() { // No automatic submissions - const uint32_t submit_frequency = UINT32_MAX; + const uint32_t cmd_submit_frequency = UINT32_MAX; // Only one command buffer will be encoded at a time const api::CommandPoolConfig cmd_config{ - 1u, // cmdPoolInitialSize - 1u, // cmdPoolBatchSize + 1u, // cmd_pool_initial_size + 1u, // cmd_pool_batch_size }; // Use lazy descriptor pool initialization by default; the graph runtime will @@ -25,49 +25,48 @@ GraphConfig::GraphConfig() { // trigger descriptor pool initialization with exact sizes before encoding the // command buffer. 
const api::DescriptorPoolConfig descriptor_pool_config{ - 0u, // descriptorPoolMaxSets - 0u, // descriptorUniformBufferCount - 0u, // descriptorStorageBufferCount - 0u, // descriptorCombinedSamplerCount - 0u, // descriptorStorageImageCount - 0u, // descriptorPileSizes + 0u, // descriptor_pool_max_sets + 0u, // descriptor_uniform_buffer_count + 0u, // descriptor_storage_buffer_count + 0u, // descriptor_combined_sampler_count + 0u, // descriptor_storage_image_count + 0u, // descriptor_pile_sizes }; const api::QueryPoolConfig query_pool_config{}; - const api::ContextConfig context_config{ - submit_frequency, // cmdSubmitFrequency - cmd_config, // cmdPoolConfig - descriptor_pool_config, // descriptorPoolConfig - query_pool_config, // queryPoolConfig + context_config = { + cmd_submit_frequency, + cmd_config, + descriptor_pool_config, + query_pool_config, }; - contextConfig = context_config; - // Empirically selected safety factor. If descriptor pools start running out // of memory, increase this safety factor. - descriptorPoolSafetyFactor = 1.25; + descriptor_pool_safety_factor = 1.25; // For now, force kTexture3D storage as we are still developing shader support // for buffer storage type. - enableStorageTypeOverride = true; - storageTypeOverride = api::kTexture3D; + enable_storage_type_override = true; + storage_type_override = api::kTexture3D; // For now, force kWidthPacked memory layout by default as we are still // developing support for other memory layouts. In the future memory layout // settings will be serialized as part of the graph. 
- enableMemoryLayoutOverride = true; - memoryLayoutOverride = api::kWidthPacked; + enable_memory_layout_override = true; + memory_layout_override = api::kWidthPacked; } -void GraphConfig::setStorageTypeOverride(api::StorageType storage_type) { - enableStorageTypeOverride = true; - storageTypeOverride = storage_type; +void GraphConfig::set_storage_type_override(api::StorageType storage_type) { + enable_storage_type_override = true; + storage_type_override = storage_type; } -void GraphConfig::setMemoryLayoutOverride(api::GPUMemoryLayout memory_layout) { - enableMemoryLayoutOverride = true; - memoryLayoutOverride = memory_layout; +void GraphConfig::set_memory_layout_override( + api::GPUMemoryLayout memory_layout) { + enable_memory_layout_override = true; + memory_layout_override = memory_layout; } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/GraphConfig.h b/backends/vulkan/runtime/graph/GraphConfig.h index 7fb99f50407..f3e311daa22 100644 --- a/backends/vulkan/runtime/graph/GraphConfig.h +++ b/backends/vulkan/runtime/graph/GraphConfig.h @@ -13,26 +13,26 @@ namespace vkcompute { struct GraphConfig final { - api::ContextConfig contextConfig; + api::ContextConfig context_config; // Creating a descriptor pool with exactly the number of descriptors tallied // by iterating through the shader layouts of shaders used in the graph risks // the descriptor pool running out of memory, therefore apply a safety factor // to descriptor counts when creating the descriptor pool to mitigate this // risk. 
- float descriptorPoolSafetyFactor; + float descriptor_pool_safety_factor; - bool enableStorageTypeOverride; - api::StorageType storageTypeOverride; + bool enable_storage_type_override; + api::StorageType storage_type_override; - bool enableMemoryLayoutOverride; - api::GPUMemoryLayout memoryLayoutOverride; + bool enable_memory_layout_override; + api::GPUMemoryLayout memory_layout_override; // Generate a default graph config with pre-configured settings explicit GraphConfig(); - void setStorageTypeOverride(api::StorageType storage_type); - void setMemoryLayoutOverride(api::GPUMemoryLayout memory_layout); + void set_storage_type_override(api::StorageType storage_type); + void set_memory_layout_override(api::GPUMemoryLayout memory_layout); }; } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/containers/SharedObject.h b/backends/vulkan/runtime/graph/containers/SharedObject.h index f1e96bf0c2c..09509ad45b9 100644 --- a/backends/vulkan/runtime/graph/containers/SharedObject.h +++ b/backends/vulkan/runtime/graph/containers/SharedObject.h @@ -30,7 +30,7 @@ struct SharedObject { VkMemoryRequirements aggregate_memory_requirements; VmaAllocationCreateInfo aggregate_create_info; std::vector users; - api::MemoryAllocation allocation; + api::Allocation allocation; void add_user(ComputeGraph* const graph, const ValueRef idx); void allocate(ComputeGraph* const graph); diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index 2e5da86a723..2b3da29dde9 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -230,6 +230,14 @@ struct Value final { tag, \ " instead."); \ return payload.member_name; \ + } \ + inline const type& toConst##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return payload.member_name; \ } 
SUPPORT_TRIVIALLY_MOVEABLE_TYPE(vTensor, Tensor, TypeTag::TENSOR, as_tensor); diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive.glsl new file mode 100644 index 00000000000..dbc87eb7944 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive.glsl @@ -0,0 +1,79 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +$if MAT2_IS_TRANSPOSED: + #define MAT2_IS_TRANSPOSED + +#include "indexing_utils.h" +#include "matmul.h" + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; +layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; +layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; +layout(set = 0, binding = 3) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_self; + +layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { + ivec3 out_limits; +}; + +layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; + +layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes { + ivec3 self_sizes; +}; + +layout(set = 0, binding = 7) uniform PRECISION restrict AddmmParams { + float alpha; + float beta; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + vec4 texel = vec4(0); + + $if MAT1_PACKING == "W_packed": + $if MAT2_PACKING == "H_packed": + ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z); + texel = matmul_naive_W_packed_H_packed( + im_mat1, + im_mat2, + pos, + in_sizes[0]); + $elif MAT2_PACKING == "W_packed": + texel = 
matmul_naive_W_packed_W_packed( + im_mat1, + im_mat2, + pos, + in_sizes[0]); + $else: + $raise Exception("Unsupported value for MAT2_PACKING") + $else: + $raise Exception("Unsupported value combo for MAT1_PACKING and MAT2_PACKING") + + vec4 self_texel = get_texel_W_packed( + im_self, + pos, + self_sizes.x == 1, + self_sizes.y == 1); + + texel = beta * self_texel + alpha * texel; + imageStore(im_out, pos, texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive.yaml new file mode 100644 index 00000000000..48db85cb56e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive.yaml @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +addmm_naive: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + MAT1_PACKING: W_packed + MAT2_PACKING: H_packed + MAT2_IS_TRANSPOSED: false + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: addmm_naive_W_packed_H_packed + - NAME: addmm_naive_W_packed_W_packed + MAT2_PACKING: W_packed + - NAME: linear_naive_W_packed_W_packed + MAT2_PACKING: W_packed + MAT2_IS_TRANSPOSED: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl new file mode 100644 index 00000000000..9d45c33704f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl @@ -0,0 +1,84 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +$if MAT2_IS_TRANSPOSED: + #define MAT2_IS_TRANSPOSED + +#include "indexing_utils.h" +#include "matmul.h" + +// addmm will have additional arguments compared to regular mm +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; +layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; +layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; +layout(set = 0, binding = 3) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_self; + +layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { + ivec3 out_limits; +}; + +layout(set = 0, binding = 5) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; + +layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes { + ivec4 self_sizes; +}; + +layout(set = 0, binding = 7) uniform PRECISION restrict InLimits { + ivec3 in_limits; +}; + +layout(set = 0, binding = 8) uniform PRECISION restrict Params { + float alpha; + float beta; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + FloatMatrix results = matmul_partial_4x4( + im_mat1, + im_mat2, + pos, + out_sizes[2], + in_limits[0]); + + for (int idx_c = 0; idx_c < FOUR; idx_c++) { + for (int idx_r = 0; idx_r < FOUR; idx_r++) { + const ivec3 out_pos = + ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z); + + vec4 self_texel = get_texel_C_packed( + im_self, + out_pos, + self_sizes.x == 1, + self_sizes.y == 1); + + // results is in transposed order w.r.t. 
the desired output + imageStore( + im_out, + out_pos, + vec4( + beta * self_texel.x + alpha * results.data[idx_c][idx_r][0], + beta * self_texel.x + alpha * results.data[idx_c][idx_r][1], + beta * self_texel.x + alpha * results.data[idx_c][idx_r][2], + beta * self_texel.x + alpha * results.data[idx_c][idx_r][3])); + } + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml new file mode 100644 index 00000000000..73014d440dd --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +addmm_optimized: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + PACKING: C_packed + MAT2_IS_TRANSPOSED: false + generate_variant_forall: + DTYPE: + - VALUE: float + - VALUE: half + shader_variants: + - NAME: addmm_optimized + - NAME: linear_optimized + MAT2_IS_TRANSPOSED: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl new file mode 100644 index 00000000000..deb03192af0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl @@ -0,0 +1,55 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_type(DTYPE)} + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler3D weight_in; +layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in; +layout(set = 0, binding = 4) uniform PRECISION sampler3D mean_in; +layout(set = 0, binding = 5) uniform PRECISION sampler3D var_in; + +layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits { + ivec3 out_limits; +}; + +layout(set = 0, binding = 7) uniform PRECISION restrict Params { + float eps; +}; + +layout(set = 0, binding = 8) uniform PRECISION restrict Params2 { + int num_texel_per_batch; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + ivec3 pos = ivec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(pos, out_limits))) { + return; + } + + VEC4_T v = VEC4_T(texelFetch(image_in, pos, 0)); + + ivec3 param_pos = ivec3(pos.z % num_texel_per_batch, 0, 0); + + VEC4_T weight = VEC4_T(texelFetch(weight_in, param_pos, 0)); + VEC4_T bias = VEC4_T(texelFetch(bias_in, param_pos, 0)); + VEC4_T mean = VEC4_T(texelFetch(mean_in, param_pos, 0)); + VEC4_T var = VEC4_T(texelFetch(var_in, param_pos, 0)); + + v = ((v - mean) / sqrt(var + eps)) * weight + bias; + + imageStore(image_out, pos, v); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml new file mode 100644 index 00000000000..a92e44f636b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml @@ -0,0 +1,10 @@ +batchnorm: + parameter_names_with_default_values: + DTYPE: float + NDIM: 3 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: batchnorm diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/matmul.h b/backends/vulkan/runtime/graph/ops/glsl/matmul.h index f157828f616..5a7f6795879 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul.h +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul.h @@ -16,38 +16,66 @@ struct FloatMatrix { float data[FOUR][FOUR][FOUR]; }; +#ifdef MAT2_IS_TRANSPOSED +vec4 matmul_naive_W_packed_W_packed( +#else vec4 matmul_naive_W_packed_H_packed( - sampler3D im_mat1, - sampler3D im_mat2, - ivec3 mat1_pos, - ivec3 mat2_pos, - int width) { +#endif + const sampler3D im_mat1, + const sampler3D im_mat2, + const ivec3 out_pos, + const int width) { + ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z); +#ifdef MAT2_IS_TRANSPOSED + ivec3 mat2_pos = ivec3(0, out_pos.x * 4, 0); +#else + ivec3 mat2_pos = ivec3(out_pos.x * 4, 0, out_pos.z); +#endif + vec4 texel = vec4(0); - int K = (width + 3) / 4; + const int K = (width + 3) / 4; for (int i = 0; i < K; ++i) { - vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0); - vec4 sums = vec4( + const vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0); +#ifdef MAT2_IS_TRANSPOSED + const vec4 sums = vec4( + dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)), + dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 1, 0), 0)), + dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 2, 0), 0)), + dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(0, 3, 0), 0))); +#else + const vec4 sums = vec4( dot(mat1_tex, texelFetch(im_mat2, mat2_pos, 0)), dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(1, 0, 0), 0)), dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(2, 0, 0), 0)), dot(mat1_tex, texelFetch(im_mat2, mat2_pos + ivec3(3, 0, 0), 0))); +#endif texel += sums; mat1_pos.x++; +#ifdef MAT2_IS_TRANSPOSED + mat2_pos.x++; +#else mat2_pos.y++; +#endif } return texel; } +#ifdef MAT2_IS_TRANSPOSED +vec4 matmul_naive_W_packed_H_packed( +#else vec4 matmul_naive_W_packed_W_packed( - sampler3D im_mat1, - sampler3D im_mat2, - ivec3 mat1_pos, - ivec3 mat2_pos, - int width) { +#endif + 
const sampler3D im_mat1, + const sampler3D im_mat2, + const ivec3 out_pos, + const int width) { + ivec3 mat1_pos = ivec3(0, out_pos.y, out_pos.z); + ivec3 mat2_pos = ivec3(out_pos.x, 0, out_pos.z); + vec4 texel = vec4(0); int K = divup4(width); @@ -71,23 +99,23 @@ vec4 matmul_naive_W_packed_W_packed( // get texel from self tensor (width_packed) in addmm vec4 get_texel_W_packed( sampler3D im_self, - ivec3 pos, - int broadcast_at_width, - int broadcast_at_height) { + const ivec3 pos, + const bool broadcast_at_width, + const bool broadcast_at_height) { vec4 self_texel; // self is of shape {1} - if (broadcast_at_width == 1 && broadcast_at_height == 1) { + if (broadcast_at_width && broadcast_at_height) { self_texel = texelFetch(im_self, ivec3(0, 0, 0), 0).xxxx; } // self is of shape {*, 1} - else if (broadcast_at_width == 1) { + else if (broadcast_at_width) { self_texel = texelFetch(im_self, ivec3(0, pos.y, 0), 0).xxxx; } // self is of shape {1, *} - else if (broadcast_at_height == 1) { + else if (broadcast_at_height) { self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0); } else { - self_texel = texelFetch(im_self, pos, 0); + self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0); } return self_texel; @@ -96,23 +124,23 @@ vec4 get_texel_W_packed( // get texel from self tensor (channel_packed) in addmm vec4 get_texel_C_packed( sampler3D im_self, - ivec3 pos, - int broadcast_at_width, - int broadcast_at_height) { + const ivec3 pos, + const bool broadcast_at_width, + const bool broadcast_at_height) { vec4 self_texel; // self is of shape {1} - if (broadcast_at_width == 1 && broadcast_at_height == 1) { + if (broadcast_at_width && broadcast_at_height) { self_texel = texelFetch(im_self, ivec3(0, 0, 0), 0); } // self is of shape {*, 1} - else if (broadcast_at_width == 1) { + else if (broadcast_at_width) { self_texel = texelFetch(im_self, ivec3(0, pos.y, 0), 0); } // self is of shape {1, *} - else if (broadcast_at_height == 1) { + else if (broadcast_at_height) { 
self_texel = texelFetch(im_self, ivec3(pos.x, 0, 0), 0); } else { - self_texel = texelFetch(im_self, pos, 0); + self_texel = texelFetch(im_self, ivec3(pos.x, pos.y, 0), 0); } return self_texel; @@ -121,10 +149,9 @@ vec4 get_texel_C_packed( FloatMatrix matmul_partial_4x4( sampler3D im_mat1, sampler3D im_mat2, - ivec3 pos, - int batch_size, - int step_size, - int reminder) { + const ivec3 pos, + const int batch_size, + const int K_texel_len) { FloatMatrix results; for (int i = 0; i < FOUR; i++) { for (int j = 0; j < FOUR; j++) { @@ -133,43 +160,36 @@ FloatMatrix matmul_partial_4x4( } } } - // read and cache 4x4 tile of im_mat1 (4 adjacent rows) - vec4 im_mat1_partial_rows[FOUR]; - vec4 im_mat2_partial_cols[FOUR]; + vec4 im_mat1_partial_load[FOUR]; + vec4 im_mat2_partial_load[FOUR]; - for (int c = 0; c < FOUR; c++) { - if (FOUR * pos.z + c >= batch_size) { + for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) { + if (FOUR * pos.z + batch_idx >= batch_size) { break; } - for (int j = 0; j < step_size; j++) { - for (int k = 0; k < FOUR; k++) { - const int pos_y_offset = (FOUR * pos.y) + k; - const ivec3 pos_rd = ivec3(j, pos_y_offset, FOUR * pos.z + c); - im_mat1_partial_rows[k] = texelFetch(im_mat1, pos_rd, 0); - // set the value out of the boundary to be 0 - if (j == step_size - 1 && reminder > 0) { - for (int kk = 0; kk < 4 - reminder; kk++) { - im_mat1_partial_rows[k][3 - kk] = 0; - } - } - } - // read and cache 4x4 tile of im_mat2 (4 adjacent columns) - for (int k = 0; k < FOUR; k++) { - const int pos_x_offset = (FOUR * pos.x) + k; - const ivec3 pos_rd = ivec3(pos_x_offset, j, FOUR * pos.z + c); - im_mat2_partial_cols[k] = texelFetch(im_mat2, pos_rd, 0); - // set the value out of the boundary to be 0 - if (j == step_size - 1 && reminder > 0) { - for (int kk = 0; kk < 4 - reminder; kk++) { - im_mat2_partial_cols[k][3 - kk] = 0; - } - } + int mat_z = FOUR * pos.z + batch_idx; + for (int mat1_x = 0; mat1_x < K_texel_len; mat1_x++) { + for (int offset = 0; offset < 
FOUR; offset++) { + // read and cache 4x4 tile of im_mat1 + const int mat1_y = (FOUR * pos.y) + offset; + const ivec3 mat1_pos = ivec3(mat1_x, mat1_y, mat_z); + im_mat1_partial_load[offset] = texelFetch(im_mat1, mat1_pos, 0); + // read and cache 4x4 tile of im_mat2 +#ifdef MAT2_IS_TRANSPOSED + const int mat2_y = (FOUR * pos.x) + offset; + const ivec3 mat2_pos = ivec3(mat1_x, mat2_y, 0); + im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0); +#else + const int mat2_x = (FOUR * pos.x) + offset; + const ivec3 mat2_pos = ivec3(mat2_x, mat1_x, mat_z); + im_mat2_partial_load[offset] = texelFetch(im_mat2, mat2_pos, 0); +#endif } // perform partial dot products and add partial result to results - for (int idx_r = 0; idx_r < FOUR; idx_r++) { - for (int idx_c = 0; idx_c < FOUR; idx_c++) { - results.data[idx_r][idx_c][c] += - dot(im_mat1_partial_rows[idx_r], im_mat2_partial_cols[idx_c]); + for (int out_row = 0; out_row < FOUR; out_row++) { + for (int out_col = 0; out_col < FOUR; out_col++) { + results.data[out_row][out_col][batch_idx] += + dot(im_mat1_partial_load[out_row], im_mat2_partial_load[out_col]); } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl index cb8371cb5df..37a9b60f3c5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl @@ -10,43 +10,23 @@ #define PRECISION ${PRECISION} +$if MAT2_IS_TRANSPOSED: + #define MAT2_IS_TRANSPOSED + #include "indexing_utils.h" #include "matmul.h" -$if IS_ADDMM: - // addmm will have additional arguments compared to regular mm - layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; - layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; - layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - layout(set = 0, binding = 3) uniform PRECISION 
${SAMPLER_T[NDIM][DTYPE]} im_self; - - layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { - ivec3 out_limits; - }; - - layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { - ivec4 in_sizes; - }; +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; +layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; +layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - layout(set = 0, binding = 6) uniform PRECISION restrict AddmmParams { - int broadcast_at_width; - int broadcast_at_height; - float alpha; - float beta; - }; -$else: - // define original matmul_naive arguments - layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; - layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; - layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; +layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { + ivec3 out_limits; +}; - layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { - ivec3 out_limits; - }; - - layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { - ivec4 in_sizes; - }; +layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -58,23 +38,24 @@ void main() { } vec4 texel = vec4(0); - ivec3 mat1_pos = ivec3(0, pos.y, pos.z); $if MAT1_PACKING == "W_packed": - $if MAT2_PACKING == "H_packed": - ivec3 mat2_pos = ivec3(pos.x * 4, 0, pos.z); - texel = matmul_naive_W_packed_H_packed(im_mat1, im_mat2, mat1_pos, mat2_pos, in_sizes[0]); - $elif MAT2_PACKING == "W_packed": - ivec3 mat2_pos = ivec3(pos.x, 0, pos.z); - texel = matmul_naive_W_packed_W_packed(im_mat1, im_mat2, mat1_pos, mat2_pos, in_sizes[0]); - $else: - $raise Exception("Unsupported value for MAT2_PACKING") + $if MAT2_PACKING == 
"H_packed": + texel = matmul_naive_W_packed_H_packed( + im_mat1, + im_mat2, + pos, + in_sizes[0]); + $elif MAT2_PACKING == "W_packed": + texel = matmul_naive_W_packed_W_packed( + im_mat1, + im_mat2, + pos, + in_sizes[0]); + $else: + $raise Exception("Unsupported value for MAT2_PACKING") $else: $raise Exception("Unsupported value combo for MAT1_PACKING and MAT2_PACKING") - $if IS_ADDMM: - vec4 self_texel = get_texel_W_packed(im_self, pos, broadcast_at_width, broadcast_at_height); - texel = beta * self_texel + alpha * texel; - imageStore(im_out, pos, texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml index 32cff0cf09e..1c4db3f0ce9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml @@ -10,17 +10,15 @@ matmul_naive: NDIM: 3 MAT1_PACKING: W_packed MAT2_PACKING: H_packed - IS_ADDMM: false + MAT2_IS_TRANSPOSED: false generate_variant_forall: DTYPE: - VALUE: float - VALUE: half shader_variants: - NAME: matmul_naive_W_packed_H_packed - - NAME: addmm_naive_W_packed_H_packed - IS_ADDMM: true - NAME: matmul_naive_W_packed_W_packed MAT2_PACKING: W_packed - - NAME: addmm_naive_W_packed_W_packed + - NAME: matmul_transposed_naive_W_packed_W_packed MAT2_PACKING: W_packed - IS_ADDMM: true + MAT2_IS_TRANSPOSED: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl index b9f62cc6593..f39bea12be3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl @@ -10,59 +10,27 @@ #define PRECISION ${PRECISION} +$if MAT2_IS_TRANSPOSED: + #define MAT2_IS_TRANSPOSED + #include "indexing_utils.h" #include "matmul.h" -$if IS_ADDMM: - // addmm will have additional arguments compared to regular mm - layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform 
PRECISION restrict writeonly image3D im_out; - layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; - layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - layout(set = 0, binding = 3) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_self; - - layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { - ivec3 out_limits; - }; - - layout(set = 0, binding = 5) uniform PRECISION restrict StepSize { - int step_size; - }; - - layout(set = 0, binding = 6) uniform PRECISION restrict Reminder { - int reminder; - }; - - layout(set = 0, binding = 7) uniform PRECISION restrict BatchSize { - int batch_size; - }; - - layout(set = 0, binding = 8) uniform PRECISION restrict AddmmParams { - int broadcast_at_width; - int broadcast_at_height; - float alpha; - float beta; - }; -$else: - // define original matmul_optimized arguments - layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; - layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; - layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - - layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { - ivec3 out_limits; - }; +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; +layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; +layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - layout(set = 0, binding = 4) uniform PRECISION restrict StepSize { - int step_size; - }; +layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { + ivec3 out_limits; +}; - layout(set = 0, binding = 5) uniform PRECISION restrict Reminder { - int reminder; - }; +layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; - layout(set = 0, binding = 6) uniform PRECISION restrict BatchSize { - int batch_size; - }; 
+layout(set = 0, binding = 5) uniform PRECISION restrict InLimits { + ivec3 in_limits; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -73,15 +41,17 @@ void main() { return; } - FloatMatrix results = matmul_partial_4x4(im_mat1, im_mat2, pos, batch_size, step_size, reminder); + FloatMatrix results = matmul_partial_4x4( + im_mat1, + im_mat2, + pos, + out_sizes[2], + in_limits[0]); for (int idx_c = 0; idx_c < FOUR; idx_c++) { for (int idx_r = 0; idx_r < FOUR; idx_r++) { const ivec3 out_pos = ivec3(idx_r + FOUR * pos.x, idx_c + FOUR * pos.y, pos.z); - $if IS_ADDMM: - vec4 self_texel = get_texel_C_packed(im_self, out_pos, broadcast_at_width, broadcast_at_height); - results.data[idx_c][idx_r][0] = beta * self_texel.x + alpha * results.data[idx_c][idx_r][0]; // results is in transposed order w.r.t. the desired output imageStore( diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml index 250d2b1a5b9..ecc62f7ca3c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml @@ -9,12 +9,12 @@ matmul_optimized: DTYPE: float NDIM: 3 PACKING: C_packed - IS_ADDMM: false + MAT2_IS_TRANSPOSED: false generate_variant_forall: DTYPE: - VALUE: float - VALUE: half shader_variants: - NAME: matmul_optimized - - NAME: addmm_optimized - IS_ADDMM: true + - NAME: matmul_transposed_optimized + MAT2_IS_TRANSPOSED: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 2d8ec36d9a8..14e4e111a2d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -12,9 +12,11 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) + - NAME: gelu + OPERATOR: 0.5 * X * (1 + tanh(sqrt(2 / 3.141593) * (X + 0.044715 * X * X * X))) - NAME: sigmoid OPERATOR: 1 / (1 
+ exp(-1 * X)) - - NAME: tanh - OPERATOR: tanh(clamp(X, -15.0, 15.0)) - NAME: sqrt OPERATOR: sqrt(X) + - NAME: tanh + OPERATOR: tanh(clamp(X, -15.0, 15.0)) diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl index 2429c841c9c..6680baad031 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -35,7 +35,7 @@ layout(constant_id = 4) const int out_packed_dim = C_DIM; void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_sizes, out_packed_dim); + ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_sizes, out_packed_dim); if (all(greaterThanEqual(out_tensor_idx, out_sizes))) { return; @@ -46,13 +46,15 @@ void main() { // the input position from the indx. const ivec4 buf_indices = get_texel_nchw_buffer_ixs(out_tensor_idx, out_sizes, out_packed_dim); - VEC4_T value; + VEC4_T value = VEC4_T(0); // Need to look up the 4 values in the output texel separately. 
- for (int i =0 ; i < 4; i++) { - ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], in_sizes); - ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_sizes, in_packed_dim); - VEC4_T intex = texelFetch(image_in, in_pos_elem.xyz, 0); - value[i] = intex[in_pos_elem.w]; + for (int i = 0 ; i < 4; i++) { + if (out_tensor_idx[out_packed_dim]++ < out_sizes[out_packed_dim]) { + ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], in_sizes); + ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_sizes, in_packed_dim); + VEC4_T intex = texelFetch(image_in, in_pos_elem.xyz, 0); + value[i] = intex[in_pos_elem.w]; + } } imageStore(image_out, out_pos, value); diff --git a/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp new file mode 100644 index 00000000000..7ea541aab46 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/BatchNorm.cpp @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +#include + +namespace vkcompute { + +ValueRef prepack_arg( + ComputeGraph& graph, + ValueRef arg_ref, + int64_t num_channels, + const std::string& debug_name) { + VK_CHECK_COND( + graph.val_is_tref(arg_ref), + "native_batch_norm requires ", + debug_name, + " to be a constant tensorref"); + VK_CHECK_COND(graph.get_tref(arg_ref)->sizes[0] == num_channels); + + // batch_norm's params are broadcast along the channel dimension. + // In this implementation, we pack the weights along the x dimension, and + // in the shader, we look them up along the x axis.
+ return prepack_if_tensor_ref(graph, arg_ref, api::kWidthPacked); +} + +void add_native_batch_norm_node( + ComputeGraph& graph, + ValueRef in_ref, + ValueRef weight_ref, + ValueRef bias_ref, + ValueRef mean_ref, + ValueRef var_ref, + ValueRef eps_ref, + ValueRef out_tuple_ref) { + std::vector in_sizes = graph.get_tensor(in_ref)->sizes(); + std::vector out_sizes = graph.get_tensor(in_ref)->sizes(); + + VK_CHECK_COND(in_sizes.size() == 4, "BatchNorm only support 4d tensor"); + VK_CHECK_COND(out_sizes.size() == 4, "BatchNorm only support 4d tensor"); + + int64_t num_channels = dim_at(in_sizes); + + ValueRef arg_weight = prepack_arg(graph, weight_ref, num_channels, "weight"); + ValueRef arg_bias = prepack_arg(graph, bias_ref, num_channels, "bias"); + ValueRef arg_mean = prepack_arg(graph, mean_ref, num_channels, "mean"); + ValueRef arg_var = prepack_arg(graph, var_ref, num_channels, "var"); + float epsilon = graph.extract_scalar(eps_ref); + + vTensorPtr t_in = graph.get_tensor(in_ref); + + // Only the first element of the return value is propagated. The remaining 2 + // elements are zero-size dummy tensor. 
+ const auto out_tuple_val = graph.get_value_list(out_tuple_ref); + + ValueRef out_ref = out_tuple_val->at(0); + + VK_CHECK_COND(!graph.val_is_tref(out_ref), "Output should not be tref"); + vTensorPtr t_out = graph.get_tensor(out_ref); + + VK_CHECK_COND( + dim_at(t_out->sizes()) == num_channels, + "out channel must match in channel"); + + std::string kernel_name = "batchnorm"; + add_dtype_suffix(kernel_name, *t_out); + + api::utils::uvec3 global_size = t_out->extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + int32_t num_texel_per_batch = + api::utils::div_up((dim_at(t_in->sizes())), 4); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + {{out_ref, api::MemoryAccessType::WRITE}, + {{in_ref, arg_weight, arg_bias, arg_mean, arg_var}, + api::MemoryAccessType::READ}}, + {t_out->texture_limits_ubo(), + graph.create_params_buffer(epsilon), + graph.create_params_buffer(num_texel_per_batch)})); +} + +void native_batch_norm(ComputeGraph& graph, const std::vector& args) { + // args[5] is momentum. It is not used in the calculation. + return add_native_batch_norm_node( + graph, args[0], args[1], args[2], args[3], args[4], args[6], args[7]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP( + aten._native_batch_norm_legit_no_training.default, native_batch_norm); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 4221f6f373e..d457f637d47 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -90,11 +90,11 @@ ValueRef prepack_biases( const bool transposed, const api::StorageType storage_type, const api::GPUMemoryLayout memory_layout) { - auto sizes = graph.get_sizes_of(weight); + auto sizes = graph.sizes_of(weight); const int64_t out_channels = transposed ? 
sizes.at(1) : sizes.at(0); ValueRef v = graph.add_tensor( - {out_channels}, graph.get_dtype_of(weight), storage_type, memory_layout); + {out_channels}, graph.dtype_of(weight), storage_type, memory_layout); vTensorPtr t = graph.get_tensor(v); api::ShaderInfo shader = get_nchw_to_image_shader(*t); @@ -193,14 +193,11 @@ ValueRef prepack_weights( ComputeGraph& graph, const ValueRef vref, const Conv2dMethod method) { - const auto original_sizes = graph.get_sizes_of(vref); + const auto original_sizes = graph.sizes_of(vref); const auto final_sizes = get_final_sizes(original_sizes, method); ValueRef v = graph.add_tensor( - final_sizes, - graph.get_dtype_of(vref), - api::kTexture2D, - api::kChannelsPacked); + final_sizes, graph.dtype_of(vref), api::kTexture2D, api::kChannelsPacked); vTensorPtr t = graph.get_tensor(v); api::utils::uvec3 global_size = t->extents(); @@ -246,7 +243,7 @@ Conv2dParams create_conv2d_params( p.kernel_size.data[1] + (p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1), }); - const auto weight_sizes = graph.get_sizes_of(weight); + const auto weight_sizes = graph.sizes_of(weight); const int32_t in_group_size = api::utils::safe_downcast(api::utils::align_up( transposed ? weight_sizes.at(0) : weight_sizes.at(1), INT64_C(4))); @@ -274,7 +271,7 @@ Conv2dMethod get_conv2d_method( const ValueRef weight, const int64_t groups, const bool transposed) { - const auto weight_sizes = graph.get_sizes_of(weight); + const auto weight_sizes = graph.sizes_of(weight); if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) { return Conv2dMethod::Depthwise; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp new file mode 100644 index 00000000000..8c963579da9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -0,0 +1,270 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include + +#include + +namespace vkcompute { + +void check_addmm_args( + ComputeGraph& graph, + const ValueRef self, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef beta, + const ValueRef alpha, + const ValueRef out) { + (void)alpha; + (void)beta; + + std::vector self_sizes = graph.sizes_of(self); + std::vector mat1_sizes = graph.sizes_of(mat1); + std::vector mat2_sizes = graph.sizes_of(mat2_data); + + VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); + VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); + + VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out)); + + VK_CHECK_COND( + api::utils::val_at(-1, mat1_sizes) == api::utils::val_at(-2, mat2_sizes)); + + if (api::utils::val_at(-1, self_sizes) != 1) { + VK_CHECK_COND( + api::utils::val_at(-1, self_sizes) == + api::utils::val_at(-1, mat2_sizes)); + } + if (api::utils::val_at(-2, self_sizes) != 1) { + VK_CHECK_COND( + api::utils::val_at(-2, self_sizes) == + api::utils::val_at(-2, mat1_sizes)); + } +} + +void resize_addmm_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); + vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + vTensorPtr self = graph->get_tensor(args[1].refs[2]); + + bool mat2_is_transposed = graph->get_bool(extra_args[0]); + + const int out_cols = api::utils::val_at(-2, mat1->sizes()); + const int out_rows = mat2_is_transposed + ? 
api::utils::val_at(-2, mat2->sizes()) + : api::utils::val_at(-1, mat2->sizes()); + + std::vector new_out_sizes(3); + if (mat1->sizes().size() == 2) { + new_out_sizes.resize(2); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; + } else { + new_out_sizes.at(0) = mat1->sizes().at(0); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; + } + + out->virtual_resize(new_out_sizes); +} + +struct Params final { + float alpha; + float beta; +}; + +void add_addmm_naive_node( + ComputeGraph& graph, + const ValueRef self_data, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef beta, + const ValueRef alpha, + const ValueRef out, + const Params& params, + const ValueRef mat2_is_transposed) { + ValueRef self = prepack_if_tensor_ref(graph, self_data, api::kWidthPacked); + ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked); + + api::utils::uvec3 global_size = graph.extents_of(out); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + std::string kernel_name = + graph.get_bool(mat2_is_transposed) ? 
"linear_naive" : "addmm_naive"; + kernel_name.reserve(kShaderNameReserve); + add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1)); + add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{mat1, mat2, self}, api::MemoryAccessType::READ}}, + // Shader params buffers + { + graph.texture_limits_ubo(out), + graph.sizes_ubo(mat1), + graph.sizes_ubo(self), + graph.create_params_buffer(params), + }, + // Specialization Constants + {}, + // Resizing Logic + resize_addmm_node, + {mat2_is_transposed})); +} + +void add_addmm_optimized_node( + ComputeGraph& graph, + const ValueRef self_data, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef beta, + const ValueRef alpha, + const ValueRef out, + const Params& params, + const ValueRef mat2_is_transposed) { + ValueRef self = prepack_if_tensor_ref(graph, self_data, api::kChannelsPacked); + ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked); + + // Ensure mat1 is width packed + ValueRef mat1_W_packed = graph.add_tensor_like(mat1, api::kWidthPacked); + auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); + + const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed); + + // Ensure mat2 is width packed if transposed, height packed otherwise + ValueRef mat2_packed = mat2; + const api::GPUMemoryLayout mat2_layout = + mat2_is_transposed_val ? 
api::kWidthPacked : api::kHeightPacked; + if (graph.memory_layout_of(mat2) != mat2_layout) { + mat2_packed = graph.add_tensor_like(mat2, mat2_layout); + viewFn(graph, {mat2, graph.add_none(), mat2_packed}); + } + + api::utils::uvec3 global_size = + api::utils::divup_vec(graph.extents_of(out), {4, 4, 1}); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + std::string kernel_name = graph.get_bool(mat2_is_transposed) + ? "linear_optimized" + : "addmm_optimized"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{mat1_W_packed, mat2_packed, self}, api::MemoryAccessType::READ}}, + // Shader params buffers + { + graph.texture_limits_ubo(out), + graph.sizes_ubo(out), + graph.sizes_ubo(self), + graph.texture_limits_ubo(mat1_W_packed), + graph.create_params_buffer(params), + }, + // Specialization Constants + {}, + // Resizing Logic + resize_addmm_node, + {mat2_is_transposed})); +} + +void add_addmm_node( + ComputeGraph& graph, + const ValueRef self, + const ValueRef mat1, + const ValueRef mat2, + const ValueRef beta, + const ValueRef alpha, + const ValueRef out, + const ValueRef mat2_is_transposed) { + float alpha_val = 1.0f; + float beta_val = 1.0f; + + if (alpha != kDummyValueRef) { + alpha_val = graph.extract_scalar(alpha); + } + if (beta != kDummyValueRef) { + beta_val = graph.extract_scalar(beta); + } + + Params params = {alpha_val, beta_val}; + if (graph.memory_layout_of(mat1) == api::kChannelsPacked) { + add_addmm_optimized_node( + graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); + } else if (graph.memory_layout_of(mat1) == api::kWidthPacked) { + add_addmm_naive_node( + graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); + } else { + VK_THROW("Input should be channel packed or width packed."); + } +} + 
+void addmm(ComputeGraph& graph, const std::vector& args) { + check_addmm_args(graph, args[0], args[1], args[2], args[3], args[4], args[5]); + ValueRef mat2_is_transposed = graph.add_scalar(false); + return add_addmm_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + mat2_is_transposed); +} + +void linear(ComputeGraph& graph, const std::vector& args) { + ValueRef input = args.at(0); + ValueRef weight_data = args.at(1); + ValueRef bias = args.at(2); + ValueRef out = args.at(3); + ValueRef weight = + prepack_if_tensor_ref(graph, weight_data, api::kWidthPacked); + ValueRef mat2_is_transposed = graph.add_scalar(true); + if (graph.val_is_none(bias)) { + return add_matmul_node(graph, input, weight, out, mat2_is_transposed); + } else { + return add_addmm_node( + graph, + bias, + input, + weight, + kDummyValueRef, + kDummyValueRef, + out, + mat2_is_transposed); + } +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.addmm.default, addmm); + VK_REGISTER_OP(aten.linear.default, linear); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 45a74636ea5..0bdfad1c23a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -18,363 +19,168 @@ namespace vkcompute { void check_matmul_args( - ComputeGraph& graph, - const ValueRef arg1, - const ValueRef arg2, + const ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, const ValueRef out) { - vTensorPtr t_mat1 = graph.get_tensor(arg1); - vTensorPtr t_mat2 = graph.get_tensor(arg2); - vTensorPtr t_out = graph.get_tensor(out); + std::vector mat1_sizes = graph.sizes_of(mat1); + std::vector mat2_sizes = graph.sizes_of(mat2_data); - VK_CHECK_COND(check_ndim_is(*t_mat1, 2) || check_ndim_is(*t_mat1, 3)); - VK_CHECK_COND(check_same_ndim(*t_mat1, *t_mat2)); + 
VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3); + VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size()); - VK_CHECK_COND(check_same_memory_layout(*t_mat1, *t_out)); + VK_CHECK_COND(graph.memory_layout_of(mat1) == graph.memory_layout_of(out)); - VK_CHECK_COND(check_same_sizes_at(*t_mat1, -1, *t_mat2, -2)); + VK_CHECK_COND( + api::utils::val_at(-1, mat1_sizes) == api::utils::val_at(-2, mat2_sizes)); } void resize_matmul_node( ComputeGraph* graph, const std::vector& args, const std::vector& extra_args) { - (void)extra_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); + bool mat2_is_transposed = graph->get_bool(extra_args[0]); + + const int out_cols = api::utils::val_at(-2, mat1->sizes()); + const int out_rows = mat2_is_transposed + ? api::utils::val_at(-2, mat2->sizes()) + : api::utils::val_at(-1, mat2->sizes()); + std::vector new_out_sizes(3); if (mat1->sizes().size() == 2) { new_out_sizes.resize(2); - new_out_sizes.at(0) = mat1->sizes().at(0); - new_out_sizes.at(1) = mat2->sizes().at(1); + new_out_sizes.at(0) = out_cols; + new_out_sizes.at(1) = out_rows; } else { new_out_sizes.at(0) = mat1->sizes().at(0); - new_out_sizes.at(1) = mat1->sizes().at(1); - new_out_sizes.at(2) = mat2->sizes().at(2); + new_out_sizes.at(1) = out_cols; + new_out_sizes.at(2) = out_rows; } out->virtual_resize(new_out_sizes); } -struct AddmmParams final { - int broadcast_at_width; - int broadcast_at_height; - float alpha; - float beta; -}; - -// TODO: `add_matmul_node` and `add_addmm_node` has lots of duplicated code. -// We should do refactoring to simplify. 
-void add_matmul_node( +void add_matmul_naive_node( ComputeGraph& graph, const ValueRef mat1, - const ValueRef mat2, - const ValueRef out) { - ValueRef arg1 = mat1; - ValueRef arg2 = prepack_if_tensor_ref(graph, mat2, api::kHeightPacked); - - std::vector t_mat1_sizes = graph.get_tensor(arg1)->sizes(); - std::vector t_mat2_sizes = graph.get_tensor(arg2)->sizes(); - std::vector out_sizes = graph.get_tensor(out)->sizes(); - int64_t t_mat1_dim = t_mat1_sizes.size(); - int64_t out_dim = out_sizes.size(); - - check_matmul_args(graph, arg1, arg2, out); - auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); - - // optimized mm - if (graph.memory_layout_of(arg1) == api::kChannelsPacked) { - ValueRef t_mat1_width_packed = - graph.add_tensor_like(arg1, api::kWidthPacked); - viewFn(graph, {arg1, graph.add_none(), t_mat1_width_packed}); - arg1 = t_mat1_width_packed; - - if (graph.memory_layout_of(arg2) != api::kHeightPacked) { - ValueRef t_mat2_height_packed = - graph.add_tensor(t_mat2_sizes, api::kFloat, api::kHeightPacked); - viewFn(graph, {arg2, graph.add_none(), t_mat2_height_packed}); - arg2 = t_mat2_height_packed; - } - - vTensorPtr t_mat1 = graph.get_tensor(arg1); - vTensorPtr t_mat2 = graph.get_tensor(arg2); - - VK_CHECK_COND(check_memory_layout_is(*t_mat1, api::kWidthPacked)); - VK_CHECK_COND(check_memory_layout_is(*t_mat2, api::kHeightPacked)); - - // Step size is the 2d input's width dimension / 4. 
- int32_t step_size = - api::utils::div_up(t_mat1_sizes.at(t_mat1_dim - 1), INT64_C(4)); - - // reminder is used in shader to detect whether the fetched texel is out of - // boundary - int32_t reminder = t_mat1_sizes.at(t_mat1_dim - 1) % INT64_C(4); - - int64_t batch_size = 1; - if (t_mat1_dim == 3) { - batch_size = t_mat1_sizes.at(0); - } - - vTensorPtr t_out = graph.get_tensor(out); - - api::utils::uvec3 global_size = { - static_cast( - api::utils::div_up(out_sizes.at(t_mat1_dim - 1), INT64_C(4))), - static_cast( - api::utils::div_up(out_sizes.at(t_mat1_dim - 2), INT64_C(4))), - static_cast( - out_dim == 3 ? api::utils::div_up(out_sizes.at(0), INT64_C(4)) - : 1)}; - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - std::string kernel_name("matmul_optimized"); - kernel_name.reserve(kShaderNameReserve); - - add_dtype_suffix(kernel_name, *t_out); - - graph.execute_nodes().emplace_back(new ExecuteNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, - // Inputs and Outputs - {{out, api::MemoryAccessType::WRITE}, - {{arg1, arg2}, api::MemoryAccessType::READ}}, - // Shader params buffers - { - t_out->texture_limits_ubo(), - graph.create_params_buffer(step_size), - graph.create_params_buffer(reminder), - graph.create_params_buffer(batch_size), - }, - // Specialization Constants - {}, - // Resizing Logic - resize_matmul_node)); - } else if (graph.memory_layout_of(arg1) == api::kWidthPacked) { - // native mm - if (graph.memory_layout_of(arg2) != api::kHeightPacked) { - ValueRef t_mat2_height_packed = - graph.add_tensor(t_mat2_sizes, api::kFloat, api::kHeightPacked); - viewFn(graph, {arg2, graph.add_none(), t_mat2_height_packed}); - arg2 = t_mat2_height_packed; - } - - vTensorPtr t_mat1 = graph.get_tensor(arg1); - vTensorPtr t_mat2 = graph.get_tensor(arg2); - vTensorPtr t_out = graph.get_tensor(out); - - VK_CHECK_COND(check_memory_layout_is(*t_mat2, api::kHeightPacked)); - - api::utils::uvec3 global_size = t_out->extents(); 
- api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - std::string kernel_name("matmul_naive"); - kernel_name.reserve(kShaderNameReserve); - add_memory_layout_suffix(kernel_name, *t_mat1); - add_memory_layout_suffix(kernel_name, *t_mat2); - add_dtype_suffix(kernel_name, *t_out); - - graph.execute_nodes().emplace_back(new ExecuteNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, - // Inputs and Outputs - {{out, api::MemoryAccessType::WRITE}, - {{arg1, arg2}, api::MemoryAccessType::READ}}, - // Shader params buffers - { - t_out->texture_limits_ubo(), - t_mat1->sizes_ubo(), - }, - // Specialization Constants - {}, - // Resizing Logic - resize_matmul_node)); - } else { - VK_THROW("Input should be channel packed or width packed."); - } + const ValueRef mat2_data, + const ValueRef out, + const ValueRef mat2_is_transposed) { + ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked); + + api::utils::uvec3 global_size = graph.extents_of(out); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + std::string kernel_name = graph.get_bool(mat2_is_transposed) + ? 
"matmul_transposed_naive" + : "matmul_naive"; + kernel_name.reserve(kShaderNameReserve); + add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1)); + add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{mat1, mat2}, api::MemoryAccessType::READ}}, + // Shader params buffers + { + graph.texture_limits_ubo(out), + graph.sizes_ubo(mat1), + }, + // Specialization Constants + {}, + // Resizing Logic + resize_matmul_node, + {mat2_is_transposed})); } -void add_addmm_node( +void add_matmul_optimized_node( ComputeGraph& graph, - const ValueRef self, const ValueRef mat1, - const ValueRef mat2, - const ValueRef beta, - const ValueRef alpha, - const ValueRef out) { - ValueRef arg1 = prepack_if_tensor_ref(graph, mat1, api::kChannelsPacked); - ValueRef arg2 = prepack_if_tensor_ref(graph, mat2, api::kHeightPacked); - - std::vector t_mat1_sizes = graph.get_tensor(arg1)->sizes(); - std::vector t_mat2_sizes = graph.get_tensor(arg2)->sizes(); - std::vector out_sizes = graph.get_tensor(out)->sizes(); - int64_t t_mat1_dim = t_mat1_sizes.size(); - - ValueRef self_arg; - int broadcast_at_width = 0; - int broadcast_at_height = 0; - float alpha_val = 1.0f; - float beta_val = 1.0f; - if (graph.memory_layout_of(arg1) == api::kChannelsPacked) { - self_arg = prepack_if_tensor_ref(graph, self, api::kChannelsPacked); - } else if (graph.memory_layout_of(arg1) == api::kWidthPacked) { - self_arg = prepack_if_tensor_ref(graph, self, api::kWidthPacked); - } else { - VK_THROW("Input should be channel packed or width packed."); - } + const ValueRef mat2_data, + const ValueRef out, + const ValueRef mat2_is_transposed) { + ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, api::kHeightPacked); - std::vector self_sizes = 
graph.get_tensor(self_arg)->sizes(); - int64_t self_dim = self_sizes.size(); - if (self_sizes.at(self_dim - 1) < out_sizes.at(t_mat1_dim - 1)) { - broadcast_at_width = 1; - } - if (self_dim < t_mat1_dim || self_sizes.at(0) < out_sizes.at(0)) { - broadcast_at_height = 1; - } - alpha_val = graph.extract_scalar(alpha); - beta_val = graph.extract_scalar(beta); - - AddmmParams addmm_params = { - broadcast_at_width, broadcast_at_height, alpha_val, beta_val}; - - check_matmul_args(graph, arg1, arg2, out); + // Ensure mat1 is width packed + ValueRef mat1_W_packed = graph.add_tensor_like(mat1, api::kWidthPacked); auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); - // optimized mm - if (graph.memory_layout_of(arg1) == api::kChannelsPacked) { - ValueRef t_mat1_width_packed = - graph.add_tensor(t_mat1_sizes, api::kFloat, api::kWidthPacked); - viewFn(graph, {arg1, graph.add_none(), t_mat1_width_packed}); - arg1 = t_mat1_width_packed; - - if (graph.memory_layout_of(arg2) != api::kHeightPacked) { - ValueRef t_mat2_height_packed = - graph.add_tensor(t_mat2_sizes, api::kFloat, api::kHeightPacked); - viewFn(graph, {arg2, graph.add_none(), t_mat2_height_packed}); - arg2 = t_mat2_height_packed; - } - - vTensorPtr t_mat1 = graph.get_tensor(arg1); - vTensorPtr t_mat2 = graph.get_tensor(arg2); - - VK_CHECK_COND(check_memory_layout_is(*t_mat1, api::kWidthPacked)); - VK_CHECK_COND(check_memory_layout_is(*t_mat2, api::kHeightPacked)); - - // Step size is the 2d input's width dimension / 4. 
- int32_t step_size = - api::utils::div_up(t_mat1_sizes.at(t_mat1_dim - 1), INT64_C(4)); - - // reminder is used in shader to detect whether the fetched texel is out of - // boundary - int32_t reminder = t_mat1_sizes.at(t_mat1_dim - 1) % INT64_C(4); + const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed); - int64_t batch_size = 1; - if (t_mat1_dim == 3) { - batch_size = t_mat1_sizes.at(0); - } - - vTensorPtr t_out = graph.get_tensor(out); - int64_t out_dim = out_sizes.size(); - - api::utils::uvec3 global_size = { - static_cast( - api::utils::div_up(out_sizes.at(t_mat1_dim - 1), INT64_C(4))), - static_cast( - api::utils::div_up(out_sizes.at(t_mat1_dim - 2), INT64_C(4))), - static_cast( - out_dim == 3 ? api::utils::div_up(out_sizes.at(0), INT64_C(4)) - : 1)}; - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - - std::string kernel_name("addmm_optimized"); - kernel_name.reserve(kShaderNameReserve); - - add_dtype_suffix(kernel_name, *t_out); - - graph.execute_nodes().emplace_back(new ExecuteNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, - // Inputs and Outputs - {{out, api::MemoryAccessType::WRITE}, - {{arg1, arg2, self_arg}, api::MemoryAccessType::READ}}, - // Shader params buffers - { - t_out->texture_limits_ubo(), - graph.create_params_buffer(step_size), - graph.create_params_buffer(reminder), - graph.create_params_buffer(batch_size), - graph.create_params_buffer(addmm_params), - }, - // Specialization Constants - {}, - // Resizing Logic - resize_matmul_node)); - } else if (graph.memory_layout_of(arg1) == api::kWidthPacked) { - // native mm - if (graph.memory_layout_of(arg2) != api::kHeightPacked) { - ValueRef t_mat2_height_packed = - graph.add_tensor(t_mat2_sizes, api::kFloat, api::kHeightPacked); - viewFn(graph, {arg2, graph.add_none(), t_mat2_height_packed}); - arg2 = t_mat2_height_packed; - } - - vTensorPtr t_mat1 = graph.get_tensor(arg1); - vTensorPtr t_mat2 = graph.get_tensor(arg2); - 
vTensorPtr t_out = graph.get_tensor(out); - - VK_CHECK_COND(check_memory_layout_is(*t_mat2, api::kHeightPacked)); - - api::utils::uvec3 global_size = t_out->extents(); - api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + // Ensure mat2 to height packed + ValueRef mat2_packed = mat2; + const api::GPUMemoryLayout mat2_layout = + mat2_is_transposed_val ? api::kWidthPacked : api::kHeightPacked; + if (graph.memory_layout_of(mat2) != mat2_layout) { + mat2_packed = graph.add_tensor_like(mat2, mat2_layout); + viewFn(graph, {mat2, graph.add_none(), mat2_packed}); + } - std::string kernel_name("addmm_naive"); - kernel_name.reserve(kShaderNameReserve); - add_memory_layout_suffix(kernel_name, *t_mat1); - add_memory_layout_suffix(kernel_name, *t_mat2); - add_dtype_suffix(kernel_name, *t_out); + api::utils::uvec3 global_size = + api::utils::divup_vec(graph.extents_of(out), {4, 4, 1}); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + std::string kernel_name = mat2_is_transposed_val + ? 
"matmul_transposed_optimized" + : "matmul_optimized"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{mat1_W_packed, mat2_packed}, api::MemoryAccessType::READ}}, + // Shader params buffers + { + graph.texture_limits_ubo(out), + graph.sizes_ubo(out), + graph.texture_limits_ubo(mat1_W_packed), + }, + // Specialization Constants + {}, + // Resizing Logic + resize_matmul_node, + {mat2_is_transposed})); +} - graph.execute_nodes().emplace_back(new ExecuteNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, - // Inputs and Outputs - {{out, api::MemoryAccessType::WRITE}, - {{arg1, arg2, self_arg}, api::MemoryAccessType::READ}}, - // Shader params buffers - { - t_out->texture_limits_ubo(), - t_mat1->sizes_ubo(), - graph.create_params_buffer(addmm_params), - }, - // Specialization Constants - {}, - // Resizing Logic - resize_matmul_node)); +void add_matmul_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef out, + const ValueRef mat2_is_transposed) { + if (graph.memory_layout_of(mat1) == api::kChannelsPacked) { + add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed); + } else if (graph.memory_layout_of(mat1) == api::kWidthPacked) { + add_matmul_naive_node(graph, mat1, mat2_data, out, mat2_is_transposed); } else { VK_THROW("Input should be channel packed or width packed."); } } -void addmm(ComputeGraph& graph, const std::vector& args) { - return add_addmm_node( - graph, args[0], args[1], args[2], args[3], args[4], args[5]); -} - void matmul(ComputeGraph& graph, const std::vector& args) { - return add_matmul_node(graph, args[0], args[1], args[2]); + check_matmul_args(graph, args[0], args[1], args[2]); + const ValueRef mat2_is_transposed = graph.add_scalar(false); + return 
add_matmul_node(graph, args[0], args[1], args[2], mat2_is_transposed); } REGISTER_OPERATORS { VK_REGISTER_OP(aten.mm.default, matmul); VK_REGISTER_OP(aten.bmm.default, matmul); - VK_REGISTER_OP(aten.addmm.default, addmm); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.h b/backends/vulkan/runtime/graph/ops/impl/MatMul.h new file mode 100644 index 00000000000..38f7907f1b6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +void add_matmul_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef out, + const ValueRef mat2_is_transposed); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index b2fb1135d77..4dd615cda18 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -100,10 +100,18 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { kClampShaderName); \ } +void gelu(ComputeGraph& graph, const std::vector& args) { + // args[1] is the `approximate` string + // https://fburl.com/code/9omngmyo + // currently only `approximate = "tanh"` is supported + return add_unary_op_node( + graph, args[0], kDummyFloat, kDummyFloat, args[2], "gelu"); +} + DEFINE_ACTIVATION_FN(abs); DEFINE_ACTIVATION_FN(sigmoid); -DEFINE_ACTIVATION_FN(tanh); DEFINE_ACTIVATION_FN(sqrt); +DEFINE_ACTIVATION_FN(tanh); DEFINE_CLAMP_FN(clamp); DEFINE_CLAMP_FN(hardtanh); DEFINE_RELU_FN(relu); @@ -111,11 +119,12 @@ DEFINE_RELU_FN(relu); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); VK_REGISTER_OP(aten.clamp.default, 
clamp); + VK_REGISTER_OP(aten.gelu.default, gelu); VK_REGISTER_OP(aten.hardtanh.default, hardtanh); VK_REGISTER_OP(aten.relu.default, relu); VK_REGISTER_OP(aten.sigmoid.default, sigmoid); - VK_REGISTER_OP(aten.tanh.default, tanh); VK_REGISTER_OP(aten.sqrt.default, sqrt); + VK_REGISTER_OP(aten.tanh.default, tanh); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index ef23110d116..b3b4dedefd5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -14,7 +14,50 @@ namespace vkcompute { -void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) { +std::vector compute_out_sizes( + std::vector orig_sizes, + std::vector& view_sizes) { + std::vector out_sizes(view_sizes.begin(), view_sizes.end()); + int64_t numel = 1; + int64_t transferred_numel = 1; + + for (int i = 0; i < orig_sizes.size(); i++) { + numel *= orig_sizes.at(i); + } + for (int i = 0; i < view_sizes.size(); i++) { + if (view_sizes.at(i) > 0) { + transferred_numel *= view_sizes.at(i); + } + } + for (int i = 0; i < out_sizes.size(); i++) { + if (out_sizes.at(i) == -1) { + out_sizes.at(i) = numel / transferred_numel; + } + } + return out_sizes; +} + +void resize_view_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensorPtr out = graph->get_tensor(args[0].refs[0]); + vTensorPtr in = graph->get_tensor(args[1].refs[0]); + if (extra_args[0] == kDummyValueRef || graph->val_is_none(extra_args[0])) { + out->virtual_resize(in->sizes()); + } else { + IntListPtr view_sizes = graph->get_int_list(extra_args[0]); + std::vector out_sizes = + compute_out_sizes(in->sizes(), *view_sizes); + out->virtual_resize(out_sizes); + } +} + +void add_view_node( + ComputeGraph& graph, + ValueRef in, + ValueRef sizes, + ValueRef out) { vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = graph.get_tensor(out); @@ -35,13 +78,14 @@ 
void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) { // Parameter Buffers {t_out->sizes_ubo(), t_in->sizes_ubo()}, // Specialization Constants - {SV(t_in->gpu_memory_layout_int()), SV(t_out->gpu_memory_layout_int())})); + {SV(t_in->gpu_memory_layout_int()), SV(t_out->gpu_memory_layout_int())}, + // Resizing Logic + resize_view_node, + {sizes})); } void view(ComputeGraph& graph, const std::vector& args) { - // Note: The second argument size_ref is not used here. Since the output - // tensor's size have been determined during compilation. - return add_view_node(graph, args[0], args[2]); + return add_view_node(graph, args[0], args[1], args[2]); } REGISTER_OPERATORS { diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp index 0bca0b4f055..41e8f0fb02f 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -10,15 +10,15 @@ namespace vkcompute { -void add_dtype_suffix(std::string& kernel_name, const vTensor& tensor) { - switch (tensor.image().format()) { - case VK_FORMAT_R32G32B32A32_SFLOAT: +void add_dtype_suffix(std::string& kernel_name, const api::ScalarType dtype) { + switch (dtype) { + case api::kFloat: kernel_name += "_float"; break; - case VK_FORMAT_R16G16B16A16_SFLOAT: + case api::kHalf: kernel_name += "_half"; break; - case VK_FORMAT_R32G32B32A32_SINT: + case api::kInt: kernel_name += "_int"; break; default: @@ -26,6 +26,10 @@ void add_dtype_suffix(std::string& kernel_name, const vTensor& tensor) { } } +void add_dtype_suffix(std::string& kernel_name, const vTensor& tensor) { + return add_dtype_suffix(kernel_name, tensor.dtype()); +} + void add_ndim_suffix(std::string& kernel_name, const vTensor& tensor) { switch (tensor.storage_type()) { case api::kTexture3D: @@ -39,8 +43,10 @@ void add_ndim_suffix(std::string& kernel_name, const vTensor& tensor) { } } -void 
add_memory_layout_suffix(std::string& kernel_name, const vTensor& tensor) { - switch (tensor.gpu_memory_layout()) { +void add_memory_layout_suffix( + std::string& kernel_name, + api::GPUMemoryLayout layout) { + switch (layout) { case api::kChannelsPacked: kernel_name += "_C_packed"; break; @@ -55,4 +61,8 @@ void add_memory_layout_suffix(std::string& kernel_name, const vTensor& tensor) { } } +void add_memory_layout_suffix(std::string& kernel_name, const vTensor& tensor) { + return add_memory_layout_suffix(kernel_name, tensor.gpu_memory_layout()); +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h index a784a4acb4c..bf97efcd4ec 100644 --- a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h @@ -16,10 +16,15 @@ namespace vkcompute { constexpr size_t kShaderNameReserve = 64u; +void add_dtype_suffix(std::string& kernel_name, const api::ScalarType dtype); void add_dtype_suffix(std::string& kernel_name, const vTensor& tensor); +void add_ndim_suffix(std::string& kernel_name, const size_t ndim); void add_ndim_suffix(std::string& kernel_name, const vTensor& tensor); +void add_memory_layout_suffix( + std::string& kernel_name, + const api::GPUMemoryLayout layout); void add_memory_layout_suffix(std::string& kernel_name, const vTensor& tensor); } // namespace vkcompute diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 573eaf2e1c0..d80c694e6db 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -150,10 +150,10 @@ def define_common_targets(is_fbcode = False): name = "vulkan_compute_api", compiler_flags = get_vulkan_compiler_flags(), srcs = native.glob([ - "runtime/api/*.cpp", + "runtime/api/**/*.cpp", ]), exported_headers = native.glob([ - "runtime/api/*.h", + "runtime/api/**/*.h", ]), visibility = [ "//executorch/backends/vulkan/...", diff --git 
a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 8e7fbab6636..d115f1897fa 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -94,6 +94,26 @@ def get_addmm_inputs(): return test_suite +def get_linear_inputs(): + MKN_list = [ + (S2, M2, M1), + (L, L, M1), + ] + + inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list] + inputs_list += [((M, K), (N, K), (N)) for M, K, N in MKN_list] + inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list] + inputs_list += [((3, M, K), (N, K), (N)) for M, K, N in MKN_list] + + test_suite = VkTestSuite(inputs_list) + test_suite.dtypes = ["at::kFloat"] + test_suite.layouts = [ + "api::kWidthPacked", + "api::kChannelsPacked", + ] + return test_suite + + def get_pool2d_inputs(): test_suite = VkTestSuite( [ @@ -651,6 +671,94 @@ def get_unary_ops_inputs(): return test_suite +def get_native_batch_norm_inputs(): + Test = namedtuple( + "VkSliceTest", ["self", "weight", "bias", "mean", "var", "momentum", "eps"] + ) + + test_cases = [ + Test( + self=(1, 1, 2, 5), + weight=(1,), + bias=(1,), + mean=(1,), + var=(1,), + momentum=0.0, + eps=0.001, + ), + Test( + self=(S2, 1, 2, 5), + weight=(1,), + bias=(1,), + mean=(1,), + var=(1,), + momentum=0.0, + eps=0.001, + ), + Test( + self=(1, S2, 2, 5), + weight=(S2,), + bias=(S2,), + mean=(S2,), + var=(S2,), + momentum=0.0, + eps=0.001, + ), + Test( + self=(9, S1, 2, 5), + weight=(S1,), + bias=(S1,), + mean=(S1,), + var=(S1,), + momentum=0.0, + eps=0.01, + ), + Test( + self=(3, S1, 2, 5), + weight=(S1,), + bias=(S1,), + mean=(S1,), + var=(S1,), + momentum=0.0, + eps=0.001, + ), + Test( + self=(3, S2, 2, 5), + weight=(S2,), + bias=(S2,), + mean=(S2,), + var=(S2,), + momentum=0.0, + eps=0.001, + ), + Test( + self=(3, S2, 2, 5), + weight=(S2,), + bias=(S2,), + mean=(S2,), + var=(S2,), + momentum=0.0, + eps=0.000, + ), + ] + + test_suite = VkTestSuite(test_cases) + + return test_suite + + +def 
get_gelu_inputs(): + test_suite = VkTestSuite( + [ + ((M1), "tanh"), + ((M1, M2), "tanh"), + ((S1, M1, M2), "tanh"), + ((S1, S2, S2, M2), "tanh"), + ] + ) + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -659,6 +767,7 @@ def get_unary_ops_inputs(): "aten.addmm.default": get_addmm_inputs(), "aten.bmm.default": get_bmm_inputs(), "aten.mm.default": get_mm_inputs(), + "aten.linear.default": get_linear_inputs(), "aten.max_pool2d_with_indices.default": get_pool2d_inputs(), "aten.convolution.default": get_conv_inputs(), "aten.native_layer_norm.default": get_native_layer_norm_inputs(), @@ -678,4 +787,6 @@ def get_unary_ops_inputs(): "aten.sqrt.default": get_unary_ops_inputs(), "aten._softmax.default": get_softmax_inputs(), "aten._log_softmax.default": get_softmax_inputs(), + "aten._native_batch_norm_legit_no_training.default": get_native_batch_norm_inputs(), + "aten.gelu.default": get_gelu_inputs(), } diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py index a43998b47c9..c803f767920 100644 --- a/backends/vulkan/test/op_tests/utils/codegen.py +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -24,6 +24,7 @@ OPT_LAYOUT, OPT_MEMORY_FORMAT, OPT_SCALAR_TYPE, + STRING, TENSOR_VECTOR, TestSuite, TestSuiteGen, @@ -130,7 +131,14 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): ATenArg(name=arg.name, cpp_type=cpp_type, default=arg.default) ) - requires_prepack = "weight" in arg.name or "bias" in arg.name + # These are the argument will be passed as a "weight" tensor, the + # corresponding object will be TensorRef in the compute graph. 
+ requires_prepack = ( + "weight" in arg.name + or "bias" in arg.name + or "running_mean" in arg.name + or "running_var" in arg.name + ) supports_prepack = False if arg.name in self.suite_def.prepacked_args: supports_prepack = True @@ -344,6 +352,8 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901 or ref.src_cpp_type == OPT_MEMORY_FORMAT ): ret_str += "add_none(); \n" + elif ref.src_cpp_type == STRING: + ret_str += f"add_string(std::string({ref.src_cpp_name})); \n" elif ref.src_cpp_type == TWO_TENSOR_TUPLE: ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n" elif ref.src_cpp_type == THREE_TENSOR_TUPLE: @@ -582,8 +592,8 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple api::StorageType default_storage_type; api::GPUMemoryLayout default_memory_layout; std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam(); - config.setStorageTypeOverride(default_storage_type); - config.setMemoryLayoutOverride(default_memory_layout); + config.set_storage_type_override(default_storage_type); + config.set_memory_layout_override(default_memory_layout); graph = new ComputeGraph(config); if (test_dtype == at::kHalf) {{ diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py index 6dac97583c6..c1c6249e27f 100644 --- a/backends/vulkan/test/op_tests/utils/codegen_base.py +++ b/backends/vulkan/test/op_tests/utils/codegen_base.py @@ -29,6 +29,7 @@ OPT_LAYOUT = "::std::optional" OPT_MEMORY_FORMAT = "::std::optional" OPT_SCALAR_TYPE = "::std::optional" +STRING = "c10::string_view" TWO_TENSOR_TUPLE = "::std::tuple" THREE_TENSOR_TUPLE = "::std::tuple" TENSOR_VECTOR = "::std::vector" @@ -166,6 +167,8 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901 ret_str += "std::nullopt;" else: ret_str += f"{str(data)};" + elif cpp_type == STRING: + ret_str += f'c10::string_view("{data}");' elif ( cpp_type == OPT_SCALAR_TYPE 
or cpp_type == OPT_LAYOUT diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 531f1d28a92..2cd3bc3a270 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1034,3 +1034,14 @@ def forward(self, x): sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def test_vulkan_backend_gelu(self): + class GeluModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU(approximate="tanh") + + def forward(self, x): + return self.gelu(x) + + self.lower_unary_module_and_test_output(GeluModule()) diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 37ced363b61..7bbba9108a5 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -204,7 +204,7 @@ void submit_to_gpu() { fence.wait(); } -api::MemoryAllocation allocate_memory_for(const vTensor& vten) { +api::Allocation allocate_memory_for(const vTensor& vten) { return api::context()->adapter_ptr()->vma().create_allocation( vten.get_memory_requirements(), vten.get_allocation_create_info()); } diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 168f643fe52..1a65ea04c26 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -179,7 +179,7 @@ inline int64_t get_buf_idx( void submit_to_gpu(); -api::MemoryAllocation allocate_memory_for(const vTensor& vten); +api::Allocation allocate_memory_for(const vTensor& vten); VmaTotalStatistics get_vma_stats(); diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 85c0a5ebb46..614a2ffcaf6 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -301,11 +301,11 @@ TEST_F(VulkanComputeAPITest, 
texture_deferred_allocation_test) { std::fill(data_b.begin(), data_b.end(), 1.5f); // Allocate memory at the last possible opportunity - api::MemoryAllocation a_mem = allocate_memory_for(a); + api::Allocation a_mem = allocate_memory_for(a); a.image().bind_allocation(a_mem); - api::MemoryAllocation b_mem = allocate_memory_for(b); + api::Allocation b_mem = allocate_memory_for(b); b.image().bind_allocation(b_mem); - api::MemoryAllocation c_mem = allocate_memory_for(c); + api::Allocation c_mem = allocate_memory_for(c); c.image().bind_allocation(c_mem); // One allocation for each tensor @@ -341,15 +341,15 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) { EXPECT_TRUE(get_vma_allocation_count() == 0); // a and d can share the same memory allocation - api::MemoryAllocation a_d_mem = allocate_memory_for(a); + api::Allocation a_d_mem = allocate_memory_for(a); a.image().bind_allocation(a_d_mem); d.image().bind_allocation(a_d_mem); // b and e can share the same memory allocation - api::MemoryAllocation b_e_mem = allocate_memory_for(b); + api::Allocation b_e_mem = allocate_memory_for(b); b.image().bind_allocation(b_e_mem); e.image().bind_allocation(b_e_mem); // c must have its own memory allocation - api::MemoryAllocation c_mem = allocate_memory_for(c); + api::Allocation c_mem = allocate_memory_for(c); c.image().bind_allocation(c_mem); // 3 allocations should be made @@ -394,7 +394,7 @@ TEST_F(VulkanComputeAPITest, resource_bind_twice_fails) { vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ true); // Try to double bind a resource, which should fail - api::MemoryAllocation a_mem = allocate_memory_for(a); + api::Allocation a_mem = allocate_memory_for(a); EXPECT_THROW(a.image().bind_allocation(a_mem), api::Error); } @@ -402,9 +402,9 @@ TEST_F(VulkanComputeAPITest, resource_destructor_non_owning_memory) { // Check that the destructor of a vTensor that does not own its memory // does not free the memory - api::MemoryAllocation memory; + api::Allocation 
memory; - // Default MemoryAllocation constructor should not allocate memory + // Default Allocation constructor should not allocate memory EXPECT_TRUE(get_vma_allocation_count() == 0); std::vector sizes = {4, 4, 1}; @@ -464,11 +464,11 @@ TEST_F( vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); - api::MemoryAllocation a_mem = allocate_memory_for(a); + api::Allocation a_mem = allocate_memory_for(a); a.image().bind_allocation(a_mem); - api::MemoryAllocation b_mem = allocate_memory_for(b); + api::Allocation b_mem = allocate_memory_for(b); b.image().bind_allocation(b_mem); - api::MemoryAllocation c_mem = allocate_memory_for(c); + api::Allocation c_mem = allocate_memory_for(c); c.image().bind_allocation(c_mem); execute_and_check_add(a, b, c, 4.0f, 8.0f); diff --git a/backends/xnnpack/operators/op_avg_pooling2d.py b/backends/xnnpack/operators/op_avg_pooling2d.py index 18f981cb330..94cd06cc08e 100644 --- a/backends/xnnpack/operators/op_avg_pooling2d.py +++ b/backends/xnnpack/operators/op_avg_pooling2d.py @@ -16,6 +16,7 @@ XNNGraph, XNode, ) +from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS @register_node_visitor @@ -67,7 +68,7 @@ def define_node( dilation_width=0, # Unused input_id=input_id, output_id=output_id, - flags=0, + flags=XNN_FLAG_KEEP_DIMS, ), debug_handle=debug_handle, ) diff --git a/backends/xnnpack/operators/op_max_pool2d.py b/backends/xnnpack/operators/op_max_pool2d.py index 6fb49d30d57..d1a010295ef 100644 --- a/backends/xnnpack/operators/op_max_pool2d.py +++ b/backends/xnnpack/operators/op_max_pool2d.py @@ -18,6 +18,7 @@ XNNMaxPooling2d, XNode, ) +from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS @register_node_visitor @@ -80,7 +81,7 @@ def define_node( kwargs["dilation_height"] = dilation[0] kwargs["dilation_width"] = dilation[1] - kwargs["flags"] = 0 + kwargs["flags"] = XNN_FLAG_KEEP_DIMS ser_node 
= XNode( xnode_union=XNNMaxPooling2d( diff --git a/backends/xnnpack/operators/op_mean_dim.py b/backends/xnnpack/operators/op_mean_dim.py index fe9f2249631..663606a8880 100644 --- a/backends/xnnpack/operators/op_mean_dim.py +++ b/backends/xnnpack/operators/op_mean_dim.py @@ -18,6 +18,7 @@ XNNGraph, XNode, ) +from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS @register_node_visitor @@ -70,7 +71,7 @@ def define_node( ser_node = XNode( xnode_union=XNNGlobalAvgPooling2d( - input_id=input_id, output_id=output_id, flags=0 + input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS ), debug_handle=debug_handle, ) diff --git a/backends/xnnpack/partition/xnnpack_partitioner.py b/backends/xnnpack/partition/xnnpack_partitioner.py index f5b11a631a3..6d483e4ea00 100644 --- a/backends/xnnpack/partition/xnnpack_partitioner.py +++ b/backends/xnnpack/partition/xnnpack_partitioner.py @@ -166,6 +166,8 @@ def _check_outputs_are_valid_dtypes(self, node, valid_dtypes): return True def check_node_has_valid_dtype(self, node): + # max_pool2d_with_indicies returns indicies which is int64 + # this is supportable within XNNPACK if node.target in {exir_ops.edge.aten.max_pool2d_with_indices.default}: return True @@ -268,13 +270,16 @@ def maxpool2d_with_indices( ) -> bool: """ Only if the first output value is consumed in the graph + and it is not in ceil mode """ users = list(node.users.keys()) + is_ceil_mode = len(node.args) >= 6 and node.args[5] return ( True if len(users) == 1 and users[0].target == operator.getitem and users[0].args[1] == 0 + and not is_ceil_mode else False ) diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py index 85b760e38ad..06ca72e377c 100644 --- a/backends/xnnpack/test/ops/linear.py +++ b/backends/xnnpack/test/ops/linear.py @@ -48,6 +48,18 @@ def test_fp32_linear(self): num_batch_dims=num_batch_dims, ) + def test_qc8_linear(self): + for use_bias in (True, False): + for num_batch_dims in range(1, 
3): + self._test_linear( + lambda in_size, out_size: torch.nn.Linear( + in_size, out_size, bias=use_bias # noqa + ), + uses_bias=use_bias, + quant_type="per_channel", + num_batch_dims=num_batch_dims, + ) + def test_fp32_addmm(self): """ Note that the ConvertToLinear pass requires the weight matrix to be transposed. @@ -107,7 +119,7 @@ def forward(self, x): ), num_batch_dims=num_batch_dims, uses_bias=use_bias, - quant=True, + quant_type="per_tensor", ) def test_qs8_linear(self): @@ -119,6 +131,7 @@ def test_qs8_linear(self): ), uses_bias=use_bias, num_batch_dims=num_batch_dims, + quant_type="per_tensor", ) @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.") @@ -726,7 +739,7 @@ def _test_linear( make_module, uses_bias, num_batch_dims=1, - quant=False, + quant_type=None, dtype: torch.dtype = torch.float, atol=1e-03, ): @@ -746,6 +759,8 @@ def _test_linear( input_sizes = [4, 37, 17] output_sizes = [4, 17, 37] + quant = quant_type is not None + """ Note that torch.nn.Linear maps to aten.mm.default (no bias) or aten.addmm.default (bias), which ares then transformed into aten.linear.default by the ConvertToLinear pass. 
@@ -769,7 +784,19 @@ def _test_linear( tester = Tester(module, inputs, dynamic_shapes=dynamic_shape) if quant: - tester.quantize() + if quant_type == "per_channel": + quant_config = get_symmetric_quantization_config( + is_per_channel=True, + is_dynamic=False, + ) + elif quant_type == "per_tensor": + quant_config = get_symmetric_quantization_config( + is_per_channel=False, + is_dynamic=False, + ) + else: + raise ValueError(f"Unsupported quant type {quant_type}") + tester.quantize(Quantize(quantization_config=quant_config)) tester.export() tester.check_count({aten_op: 1}) diff --git a/backends/xnnpack/test/ops/maxpool2d.py b/backends/xnnpack/test/ops/maxpool2d.py index e919fc6e776..889c29a5f38 100644 --- a/backends/xnnpack/test/ops/maxpool2d.py +++ b/backends/xnnpack/test/ops/maxpool2d.py @@ -38,6 +38,14 @@ def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1): def forward(self, x): return self.max_pool2d_module(x)[1] + class MaxPool2dUnsupportedCeilMode(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + + def forward(self, x): + return self.max_pool2d_module(x) + def _test_maxpool2d(self, inputs): """ Note that the export process generates aten.max_pool2d_with_indices. The remove_getitem_op @@ -99,6 +107,34 @@ def test_fp32_maxpool2d_unsupported(self): ) ) + def test_fp32_maxpool2d_unsupported_ceilmode(self): + """ + MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint). + """ + inputs = (torch.randn(1, 32, 23, 23),) + ( + Tester(self.MaxPool2dUnsupportedCeilMode(), inputs) + .export() + .check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1}) + .to_edge() + .check_count( + { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 + } + ) + .partition() + # We expect it not be be delegated. 
+ .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + .check_count( + { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1 + } + ) + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + def test_qs8_maxpool2d(self): class MaxPool(torch.nn.Module): def __init__(self, maxpool_params): diff --git a/backends/xnnpack/third-party/XNNPACK b/backends/xnnpack/third-party/XNNPACK index 70bbd07c1de..20c0d886fb7 160000 --- a/backends/xnnpack/third-party/XNNPACK +++ b/backends/xnnpack/third-party/XNNPACK @@ -1 +1 @@ -Subproject commit 70bbd07c1de310a1f89379c746b8f24a506c3283 +Subproject commit 20c0d886fb78d6497362e8303b999bf5d67aaa02 diff --git a/backends/xnnpack/utils/xnnpack_constants.py b/backends/xnnpack/utils/xnnpack_constants.py index 63c8d6fdeef..351cc8ad897 100644 --- a/backends/xnnpack/utils/xnnpack_constants.py +++ b/backends/xnnpack/utils/xnnpack_constants.py @@ -8,21 +8,25 @@ UINT32_MAX = 4294967295 XNN_EXTRA_BYTES = 16 XNN_MAX_TENSOR_DIMS = 6 -XNN_FLAG_SPARSE_INFERENCE = 0x00000001 -XNN_FLAG_HINT_SPARSE_INFERENCE = XNN_FLAG_SPARSE_INFERENCE -XNN_FLAG_FP16_INFERENCE = 0x00000002 -XNN_FLAG_HINT_FP16_INFERENCE = XNN_FLAG_FP16_INFERENCE +XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001 +XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002 XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004 XNN_FLAG_BASIC_PROFILING = 0x00000008 +XNN_FLAG_JIT = 0x00000010 XNN_FLAG_DEPTHWISE_CONVOLUTION = 0x00000001 XNN_FLAG_TRANSPOSE_WEIGHTS = 0x00000001 XNN_FLAG_INPUT_NHWC = 0x00000002 XNN_FLAG_TENSORFLOW_SAME_PADDING = 0x00000004 +XNN_FLAG_TRANSPOSE_B = XNN_FLAG_TRANSPOSE_WEIGHTS +XNN_FLAG_TRANSPOSE_A = 0x00000002 XNN_FLAG_TENSORFLOW_RESHAPE_2D = 0x00000004 XNN_FLAG_TENSORFLOW_LEGACY_MODE = 0x00000004 XNN_FLAG_FP32_STATIC_WEIGHTS = 0x00000008 XNN_FLAG_ALIGN_CORNERS = 0x00000008 XNN_FLAG_YIELD_WORKERS = 0x00000010 +XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020 +XNN_FLAG_KEEP_DIMS = 0x00000040 +XNN_EXTRA_QUANTIZATION_PARAMS = 8 
XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001 XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002 XNN_VALUE_FLAG_PERSISTENT = 0x00000004 diff --git a/build/packaging/pre_build_script.sh b/build/packaging/pre_build_script.sh index 1abb1a76fe3..74c98406d05 100644 --- a/build/packaging/pre_build_script.sh +++ b/build/packaging/pre_build_script.sh @@ -16,6 +16,7 @@ set -euxo pipefail readonly BUILD_DEPS=( # This list must match the build-system.requires list from pyproject.toml. "cmake" + "pip>=23" "pyyaml" "setuptools>=63" "tomli" diff --git a/build/test_android_ci.sh b/build/test_android_ci.sh index 8d9391146dc..b1f17730f5d 100755 --- a/build/test_android_ci.sh +++ b/build/test_android_ci.sh @@ -22,6 +22,7 @@ build_android_native_library() { pushd examples/demo-apps/android/LlamaDemo CMAKE_OUT="cmake-out-android-$1" ANDROID_NDK=/opt/ndk ANDROID_ABI="$1" ./gradlew setup popd + cp "cmake-out-android-$1"/extension/android/*.so build_aar/jni/$1/ } build_android_demo_app() { @@ -37,8 +38,25 @@ build_android_llama_demo_app() { popd } +build_aar() { + cp extension/android/build/libs/executorch.jar build_aar/libs + echo \ \ + \ \ + \ > build_aar/AndroidManifest.xml + pushd build_aar + zip -r executorch.aar libs jni AndroidManifest.xml + + rm jni/arm64-v8a/libexecutorch_jni.so jni/x86_64/libexecutorch_jni.so + zip -r executorch-llama.aar libs jni AndroidManifest.xml + popd +} + +mkdir -p build_aar/jni/arm64-v8a build_aar/jni/x86_64 build_aar/libs + build_android_native_library arm64-v8a build_android_native_library x86_64 export_model build_android_demo_app build_android_llama_demo_app +build_aar diff --git a/codegen/tools/gen_all_oplist.py b/codegen/tools/gen_all_oplist.py index 6626803df74..ec02ff7ec3c 100644 --- a/codegen/tools/gen_all_oplist.py +++ b/codegen/tools/gen_all_oplist.py @@ -22,10 +22,7 @@ def main(argv: List[Any]) -> None: parser = argparse.ArgumentParser(description="Generate operator lists") parser.add_argument( "--output_dir", - help=( - "The directory to store the 
output yaml files (selected_mobile_ops.h, " - + "selected_kernel_dtypes.h, selected_operators.yaml)" - ), + help=("The directory to store the output yaml file (selected_operators.yaml)"), required=True, ) parser.add_argument( diff --git a/codegen/tools/gen_oplist_copy_from_core.py b/codegen/tools/gen_oplist_copy_from_core.py index 452f84ba572..34a8af245bb 100644 --- a/codegen/tools/gen_oplist_copy_from_core.py +++ b/codegen/tools/gen_oplist_copy_from_core.py @@ -5,138 +5,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# This is a copy from //xplat/caffe2/tools/code_analyzer/gen_oplist.py -# TODO(mnachin): We will need to either simplify or remove this code altogether. -# This is necessary to remove dependency from pytorch core from ExecuTorch. +# This is a simplified copy from //xplat/caffe2/tools/code_analyzer/gen_oplist.py import argparse -import json import os import sys from functools import reduce -from typing import Any, List, Set +from typing import Any, List import yaml -from torchgen.code_template import CodeTemplate from torchgen.selective_build.selector import ( combine_selective_builders, SelectiveBuilder, ) -if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) { - return $dtype_checks; -}""" -if_condition_template = CodeTemplate(if_condition_template_str) - -selected_kernel_dtypes_h_template_str = """ -#include -#include -#include - -namespace at { -inline constexpr bool should_include_kernel_dtype( - const char *kernel_tag_str, - at::ScalarType scalar_type -) { - c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str); - $body - return false; -} -} -""" -selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str) - -selected_mobile_ops_preamble = """#pragma once -/** - * Generated by gen_selected_mobile_ops_header.py - */ - -""" - - -def get_selected_kernel_dtypes_code( - 
selective_builder: SelectiveBuilder, -) -> str: - # See https://www.internalfb.com/intern/paste/P153411698/ for an example of the - # generated code in case all kernel dtypes are selected and in case some kernel - # dtypes are selected (i.e. both cases). - # - body = "return true;" - if ( - selective_builder.include_all_operators is False - and selective_builder.include_all_non_op_selectives is False - ): - body_parts = [] - for kernel_tag, dtypes in selective_builder.kernel_metadata.items(): - conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes] - body_parts.append( - if_condition_template.substitute( - kernel_tag_name=kernel_tag, - dtype_checks=" || ".join(conditions), - ), - ) - body = " else ".join(body_parts) - - header_contents = selected_kernel_dtypes_h_template.substitute(body=body) - return header_contents - - -def extract_root_operators(selective_builder: SelectiveBuilder) -> Set[str]: - ops = [] - for op_name, op in selective_builder.operators.items(): - if op.is_root_operator: - ops.append(op_name) - return set(ops) - - -# Write the file selected_mobile_ops.h with optionally: -# 1. The selected root operators -# 2. The selected kernel dtypes -def write_selected_mobile_ops( - output_file_path: str, - selective_builder: SelectiveBuilder, -) -> None: - root_ops = extract_root_operators(selective_builder) - custom_classes = selective_builder.custom_classes - build_features = selective_builder.build_features - with open(output_file_path, "wb") as out_file: - body_parts = [selected_mobile_ops_preamble] - # This condition checks if we are in selective build. 
- # if these lists are not defined the corresponding selective build macros trivially return the item in question was selected - if not selective_builder.include_all_operators: - body_parts.append( - "#define TORCH_OPERATOR_WHITELIST " - + (";".join(sorted(root_ops))) - + ";\n\n" - ) - # This condition checks if we are in tracing based selective build - if selective_builder.include_all_non_op_selectives is False: - body_parts.append( - "#define TORCH_CUSTOM_CLASS_ALLOWLIST " - + (";".join(sorted(custom_classes))) - + ";\n\n" - ) - body_parts.append( - "#define TORCH_BUILD_FEATURE_ALLOWLIST " - + (";".join(sorted(build_features))) - + ";\n\n" - ) - - body_parts.append(get_selected_kernel_dtypes_code(selective_builder)) - header_contents = "".join(body_parts) - out_file.write(header_contents.encode("utf-8")) - - -def extract_all_operators(selective_builder: SelectiveBuilder) -> Set[str]: - return set(selective_builder.operators.keys()) - - -def extract_training_operators(selective_builder: SelectiveBuilder) -> Set[str]: - ops = [] - for op_name, op in selective_builder.operators.items(): - if op.is_used_for_training: - ops.append(op_name) - return set(ops) - def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None: ops = [] @@ -153,49 +34,6 @@ def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> N ) -def gen_supported_mobile_models(model_dicts: List[Any], output_dir: str) -> None: - supported_mobile_models_source = """/* - * Generated by gen_oplist.py - */ -#include "fb/supported_mobile_models/SupportedMobileModels.h" - - -struct SupportedMobileModelCheckerRegistry {{ - SupportedMobileModelCheckerRegistry() {{ - auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton(); - ref.set_supported_md5_hashes(std::unordered_set{{ - {supported_hashes_template} - }}); - }} -}}; - -// This is a global object, initializing which causes the registration to happen. 
-SupportedMobileModelCheckerRegistry register_model_versions; - - -""" - - # Generate SupportedMobileModelsRegistration.cpp - md5_hashes = set() - for model_dict in model_dicts: - if "debug_info" in model_dict: - debug_info = json.loads(model_dict["debug_info"][0]) - if debug_info["is_new_style_rule"]: - for asset_info in debug_info["asset_info"].values(): - md5_hashes.update(asset_info["md5_hash"]) - - supported_hashes = "" - for md5 in md5_hashes: - supported_hashes += f'"{md5}",\n' - with open( - os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb" - ) as out_file: - source = supported_mobile_models_source.format( - supported_hashes_template=supported_hashes - ) - out_file.write(source.encode("utf-8")) - - def main(argv: List[Any]) -> None: """This binary generates 3 files: @@ -258,9 +96,6 @@ def main(argv: List[Any]) -> None: selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts] - # While we have the model_dicts generate the supported mobile models api - gen_supported_mobile_models(model_dicts, options.output_dir) - # We may have 0 selective builders since there may not be any viable # pt_operator_library rule marked as a dep for the pt_operator_registry rule. 
# This is potentially an error, and we should probably raise an assertion @@ -283,11 +118,6 @@ def main(argv: List[Any]) -> None: ).encode("utf-8"), ) - write_selected_mobile_ops( - os.path.join(options.output_dir, "selected_mobile_ops.h"), - selective_builder, - ) - if __name__ == "__main__": main(sys.argv[1:]) diff --git a/examples/apple/coreml/scripts/build_executor_runner.sh b/examples/apple/coreml/scripts/build_executor_runner.sh index 347f3b4474f..86ff5f6edb9 100755 --- a/examples/apple/coreml/scripts/build_executor_runner.sh +++ b/examples/apple/coreml/scripts/build_executor_runner.sh @@ -13,7 +13,7 @@ SCRIPT_DIR_PATH="$( EXECUTORCH_ROOT_PATH=$(realpath "$SCRIPT_DIR_PATH/../../../../") COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/backends/apple/coreml" EXAMPLES_COREML_DIR_PATH="$EXECUTORCH_ROOT_PATH/examples/apple/coreml" -IOS_TOOLCHAIN_PATH="$COREML_DIR_PATH/third-party/ios-cmake/ios.toolchain.cmake" +IOS_TOOLCHAIN_PATH="$EXECUTORCH_ROOT_PATH/third-party/ios-cmake/ios.toolchain.cmake" CMAKE_BUILD_DIR_PATH="$EXAMPLES_COREML_DIR_PATH/cmake-out" LIBRARIES_DIR_PATH="$EXAMPLES_COREML_DIR_PATH/executor_runner/libraries" INCLUDE_DIR_PATH="$EXAMPLES_COREML_DIR_PATH/executor_runner/include" diff --git a/examples/apple/coreml/scripts/export.py b/examples/apple/coreml/scripts/export.py index 966714ba31c..4bf26a7f3ea 100644 --- a/examples/apple/coreml/scripts/export.py +++ b/examples/apple/coreml/scripts/export.py @@ -16,9 +16,7 @@ from executorch.backends.apple.coreml.compiler import CoreMLBackend -from executorch.backends.apple.coreml.partition.coreml_partitioner import ( - CoreMLPartitioner, -) +from executorch.backends.apple.coreml.partition import CoreMLPartitioner from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 0bfef7bf4ce..c6ef6b14c74 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ 
b/examples/apple/mps/scripts/mps_example.py @@ -182,7 +182,9 @@ def get_model_config(args): logging.info(f"Lowered graph:\n{edge.exported_program().graph}") executorch_program = edge.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) else: lowered_module = to_backend( @@ -192,7 +194,11 @@ def get_model_config(args): lowered_module, example_inputs, edge_compile_config=exir.EdgeCompileConfig(_check_ir_validity=False), - ).to_executorch(config=ExecutorchBackendConfig(extract_constant_segment=False)) + ).to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) + ) model_name = f"{args.model_name}_mps" diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 2c74a829b87..7f30924b7b4 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -211,7 +211,9 @@ def forward(self, x): logging.debug(f"Lowered graph:\n{edge.exported_program().graph}") exec_prog = edge.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) model_name = f"{args.model_name}" + ( diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ResourceMonitor.swift b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ResourceMonitor.swift index 847eb51bae3..3ec16463e8a 100644 --- a/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ResourceMonitor.swift +++ b/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ResourceMonitor.swift @@ -33,16 +33,16 @@ final class ResourceMonitor: ObservableObject { } private func usedMemoryInMB() -> Int { - var info = mach_task_basic_info() - var count = mach_msg_type_number_t(MemoryLayout.size) / 4 + var info = task_vm_info_data_t() + var count = 
mach_msg_type_number_t(MemoryLayout.size) / 4 let kerr: kern_return_t = withUnsafeMutablePointer(to: &info) { $0.withMemoryRebound(to: integer_t.self, capacity: Int(count)) { - task_info(mach_task_self_, task_flavor_t(MACH_TASK_BASIC_INFO), $0, &count) + task_info(mach_task_self_, task_flavor_t(TASK_VM_INFO), $0, &count) } } guard kerr == KERN_SUCCESS else { return 0 } - return Int(info.resident_size / 0x100000) + return Int(info.phys_footprint / 0x100000) } private func availableMemoryInMB() -> Int { diff --git a/examples/models/llama2/README.md b/examples/models/llama2/README.md index 74b6e8e1343..8fc77dd72ba 100644 --- a/examples/models/llama2/README.md +++ b/examples/models/llama2/README.md @@ -292,7 +292,7 @@ Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-de ## Optional: Smaller models delegated to other backends Currently we supported lowering the stories model to other backends, including, CoreML, MPS and QNN. Please refer to the instruction -for each backend ([CoreML](https://pytorch.org/executorch/main/build-run-coreml.html), [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [QNN](https://pytorch.org/executorch/main/build-run-qualcomm.html)) before trying to lower them. After the backend library is installed, the script to export a lowered model is +for each backend ([CoreML](https://pytorch.org/executorch/main/build-run-coreml.html), [MPS](https://pytorch.org/executorch/main/build-run-mps.html), [QNN](https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html)) before trying to lower them. 
After the backend library is installed, the script to export a lowered model is - Lower to CoreML: `python -m examples.models.llama2.export_llama -kv --coreml -c stories110M.pt -p params.json` - MPS: `python -m examples.models.llama2.export_llama -kv --mps -c stories110M.pt -p params.json` diff --git a/examples/models/llama2/lib/partitioner_lib.py b/examples/models/llama2/lib/partitioner_lib.py index 1638a357576..c11e74a7e0b 100644 --- a/examples/models/llama2/lib/partitioner_lib.py +++ b/examples/models/llama2/lib/partitioner_lib.py @@ -57,16 +57,14 @@ def get_coreml_partitioner(args): args.use_kv_cache is True ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment" try: - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.partition.coreml_partitioner`. + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `coremltools`. 
import coremltools as ct # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.compiler` from executorch.backends.apple.coreml.compiler import CoreMLBackend - # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.partition.coreml_partitioner` - from executorch.backends.apple.coreml.partition.coreml_partitioner import ( - CoreMLPartitioner, - ) + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.coreml.partition` + from executorch.backends.apple.coreml.partition import CoreMLPartitioner except ImportError: raise ImportError( "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html" diff --git a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp index ed7d34aca4d..7af2357d9be 100644 --- a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp @@ -146,10 +146,7 @@ BPETokenizer::~BPETokenizer() { * token. 
*/ Result BPETokenizer::decode(uint64_t prev_token, uint64_t token) { - if (!initialized_) { - ET_LOG(Error, "Tokenizer not initialized"); - return Error::NotSupported; - } + ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token)); const char* piece = vocab_[token]; // following BOS token, sentencepiece decoder strips any leading // whitespace diff --git a/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp index 1d1f83065cf..e9eada338d5 100644 --- a/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/test/test_bpe_tokenizer.cpp @@ -39,6 +39,14 @@ TEST_F(TokenizerExtensionTest, DecodeWithoutLoadFails) { EXPECT_EQ(result.error(), Error::NotSupported); } +TEST_F(TokenizerExtensionTest, DecodeOutOfRangeFails) { + Error res = tokenizer_->load(modelPath_.c_str()); + EXPECT_EQ(res, Error::Ok); + auto result = tokenizer_->decode(0, 64000); + // The vocab size is 32000, and token 64000 is out of vocab range. + EXPECT_EQ(result.error(), Error::NotSupported); +} + TEST_F(TokenizerExtensionTest, TokenizerVocabSizeIsExpected) { Error res = tokenizer_->load(modelPath_.c_str()); EXPECT_EQ(res, Error::Ok); diff --git a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp index 2f08e2a1aa7..6130a9e858a 100644 --- a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp +++ b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp @@ -77,5 +77,14 @@ TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) { } } +TEST_F(TiktokenExtensionTest, TokenizerDecodeOutOfRangeFails) { + Error res = tokenizer_->load(modelPath_.c_str()); + EXPECT_EQ(res, Error::Ok); + // The vocab size is 128256, addes 256 just so the token is out of vocab + // range. 
+ Result out = tokenizer_->decode(0, 128256 + 256); + EXPECT_EQ(out.error(), Error::NotSupported); +} + } // namespace executor } // namespace torch diff --git a/examples/models/llama2/tokenizer/tiktoken.cpp b/examples/models/llama2/tokenizer/tiktoken.cpp index 849a2ff1e8d..79b61e5eb64 100644 --- a/examples/models/llama2/tokenizer/tiktoken.cpp +++ b/examples/models/llama2/tokenizer/tiktoken.cpp @@ -364,9 +364,7 @@ Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) { Result Tiktoken::decode(uint64_t prev, uint64_t cur) { (void)prev; - if (!initialized_) { - return Error::NotSupported; - } + ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur)); std::string ret; std::string token_bytes; diff --git a/examples/models/llama2/tokenizer/tokenizer.h b/examples/models/llama2/tokenizer/tokenizer.h index 5e9f0925823..7ad3b32bbb8 100644 --- a/examples/models/llama2/tokenizer/tokenizer.h +++ b/examples/models/llama2/tokenizer/tokenizer.h @@ -40,6 +40,22 @@ class Tokenizer { virtual Result> encode(const std::string& input, int8_t bos, int8_t eos) = 0; + Error decode_verify(uint64_t token) const { + if (!initialized_) { + ET_LOG(Error, "Tokenizer not initialized"); + return Error::NotSupported; + } + if (token >= vocab_size_) { + ET_LOG( + Error, + "token %" PRIu64 " is out side of vacab range %d", + token, + vocab_size_); + return Error::NotSupported; + } + return Error::Ok; + } + virtual Result decode(uint64_t prev_token, uint64_t token) = 0; // getters diff --git a/examples/portable/utils.py b/examples/portable/utils.py index 82242f585a8..9e4a9607618 100644 --- a/examples/portable/utils.py +++ b/examples/portable/utils.py @@ -20,6 +20,7 @@ _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( _check_ir_validity=True, + _skip_dim_order=True, # TODO(T189114319): Reuse dim order op after solving the ios oss issue ) diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index a6d2e6d1a3e..98b245c512d 100644 --- 
a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -96,7 +96,9 @@ ) executorch_program = delegated_program.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) if args.generate_etrecord: diff --git a/examples/xnnpack/aot_compiler.py b/examples/xnnpack/aot_compiler.py index 4ef6852fd28..f23ba5e9c21 100644 --- a/examples/xnnpack/aot_compiler.py +++ b/examples/xnnpack/aot_compiler.py @@ -103,7 +103,9 @@ logging.info(f"Lowered graph:\n{edge.exported_program().graph}") exec_prog = edge.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) if args.etrecord is not None: diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index 4804af0b42e..a47d2180667 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -191,7 +191,9 @@ def main() -> None: start = time.perf_counter() prog = edge_m.to_executorch( - config=ExecutorchBackendConfig(extract_constant_segment=False) + config=ExecutorchBackendConfig( + extract_delegate_segments=False, extract_constant_segment=False + ) ) save_pte_program(prog, f"{args.model_name}_quantized") end = time.perf_counter() diff --git a/exir/capture/_config.py b/exir/capture/_config.py index fecb2382e27..dd0ed94094f 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass, field -from typing import List, Optional +from typing import Dict, List, Optional, Union from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.pass_manager import PassType @@ -45,7 +45,12 @@ class EdgeCompileConfig: @dataclass class ExecutorchBackendConfig: passes: List[PassType] = field(default_factory=list) - memory_planning_pass: PassType = MemoryPlanningPass("greedy") + + # A single memory planning pass can be defined for all the programs in the + # EdgeProgramManager or can be defined per program. + memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass( + "greedy" + ) to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False) dynamic_memory_planning_mode: DynamicMemoryPlanningMode = ( DynamicMemoryPlanningMode.UPPER_BOUND @@ -55,7 +60,7 @@ class ExecutorchBackendConfig: # Whether to move delegate data blobs from the Program into separate # segments, rather than encoding those blobs in the flatbuffer data. # This makes it possible to free those blobs at runtime. - extract_delegate_segments: bool = False + extract_delegate_segments: bool = True # Whether to extract constants from the Program into separate segments, # rather than encoding those constants in the flatbuffer data. diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py index 0fabf223fb8..96c40e65363 100644 --- a/exir/passes/constant_prop_pass.py +++ b/exir/passes/constant_prop_pass.py @@ -112,11 +112,11 @@ def get_propagated_const_tensor_dict( # Initialize dict with all constant placeholders. const_node_to_tensor = get_constant_placeholder_dict(exported_program) - all_skip_targets: set[EdgeOpOverload] = set() - # Default set of targets to skip. - all_skip_targets.update(_DEFAULT_SKIP_TARGETS) if custom_skip_targets is not None: - all_skip_targets.update(custom_skip_targets) + all_skip_targets = custom_skip_targets + else: + # Default set of targets to skip. 
+ all_skip_targets = _DEFAULT_SKIP_TARGETS for node in exported_program.graph.nodes: if node.op != "call_function" or node.target in all_skip_targets: diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index a9304f3eec8..378b9332119 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -273,7 +273,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: - if _is_view_copy(node): + # Note: We only replace view_copy nodes that are not output, since + # the output pointer could be modified at runtime (T187925929) + if _is_view_copy(node) and all(u.op != "output" for u in node.users): base, _ = node.args node.target = _VIEW_OP @@ -298,7 +300,11 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None: if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: - assert not _is_view_copy(node) + # Note: We only replace view_copy nodes that are not output, since + # the output pointer could be modified at runtime (T187925929) + assert not ( + _is_view_copy(node) and all(u.op != "output" for u in node.users) + ) if node.op == "call_function" and node.target == _VIEW_OP: assert isinstance(node.meta["spec"], _ViewSpec) @@ -311,6 +317,8 @@ def requires(self, graph_module: torch.fx.GraphModule) -> None: if not isinstance(module, torch.fx.GraphModule): continue for node in module.graph.nodes: - if _is_view_copy(node): + # Note: We only replace view_copy nodes that are not output, since + # the output pointer could be modified at runtime (T187925929) + if _is_view_copy(node) and all(u.op != "output" for u in node.users): base, size = node.args assert not _is_view_copy(base) diff --git a/exir/program/_program.py b/exir/program/_program.py index f2c2a5438fd..c5afe011691 100644 --- a/exir/program/_program.py +++ 
b/exir/program/_program.py @@ -412,7 +412,7 @@ def to_executorch( # Existing user passes dont use run so Im just cheating here because they dont need to work on mutable buffers yet. # After exir.capture is gone I will clean up the memory planning infra to be consistent. # Frankly all of exir has big code quality issues because of the migrations that need to be addressed. - new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[19] + new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[29] assert new_gm_res is not None new_gm = new_gm_res.graph_module new_prog = ExirExportedProgram( @@ -889,7 +889,8 @@ def to_backend( ) def to_executorch( - self, config: Optional[ExecutorchBackendConfig] = None + self, + config: Optional[ExecutorchBackendConfig] = None, ) -> "ExecutorchProgramManager": """ Transforms the program to the ExecuTorch backend. @@ -926,13 +927,19 @@ def to_executorch( # TODO(who?) p.update_placeholder_tensor_specs(program, new_gm) + if isinstance(config.memory_planning_pass, dict): + memory_planning_pass = config.memory_planning_pass.get( + name, ExecutorchBackendConfig().memory_planning_pass + ) + else: + memory_planning_pass = config.memory_planning_pass # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work - if hasattr(config.memory_planning_pass, "run"): - new_gm_res = config.memory_planning_pass.run( # pyre-ignore[16] + if hasattr(memory_planning_pass, "run"): + new_gm_res = memory_planning_pass.run( # pyre-ignore[16] new_gm, new_signature ) else: - new_gm_res = config.memory_planning_pass(new_gm) # pyre-ignore[19] + new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29] assert new_gm_res is not None new_gm = new_gm_res.graph_module diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 01de1f3befd..f84f0c1cd02 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -16,6 +16,7 @@ from executorch.exir.error 
import ExportError from executorch.exir.lowered_backend_module import get_lowered_submodules from executorch.exir.pass_base import ExportPass +from executorch.exir.passes import MemoryPlanningPass from executorch.exir.program._program import ( EdgeProgramManager, ExecutorchProgramManager, @@ -160,6 +161,45 @@ def test_executorch_manager_basic_api(self): 3, ) + def test_executorch_manager_multi_config(self): + def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]: + return { + "forward": MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=True, + alloc_graph_output=False, + ), + "foo": MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=True, + ), + } + + executorch_manager: ExecutorchProgramManager = to_edge( + get_exported_programs(), get_config_methods() + ).to_executorch( + ExecutorchBackendConfig( + memory_planning_pass=get_executorch_memory_planning_passes() + ) + ) + + method = executorch_manager._emitter_output.program.execution_plan[0] + if method.name == "forward": + for input_val in method.inputs: + evalue = method.values[input_val] + self.assertEqual(evalue.val.allocation_info, None) + for output_val in method.outputs: + evalue = method.values[output_val] + self.assertNotEqual(evalue.val.allocation_info, None) + else: + for input_val in method.inputs: + evalue = method.values[input_val] + self.assertEqual(evalue.val.allocation_info, None) + for output_val in method.outputs: + evalue = method.values[output_val] + self.assertNotEqual(evalue.val.allocation_info, None) + def test_no_getattr(self): class Mul(torch.nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -293,9 +333,7 @@ def test_edge_to_backend_replaces_subgraph(self): # two delegate blobs for forward and foo self.assertEqual( len( - delegate_manager.to_executorch( - ExecutorchBackendConfig(extract_delegate_segments=True) - ) + delegate_manager.to_executorch(ExecutorchBackendConfig()) 
._emitter_output.program.execution_plan[0] .delegates ), @@ -303,9 +341,7 @@ def test_edge_to_backend_replaces_subgraph(self): ) self.assertEqual( len( - delegate_manager.to_executorch( - ExecutorchBackendConfig(extract_delegate_segments=True) - ) + delegate_manager.to_executorch(ExecutorchBackendConfig()) ._emitter_output.program.execution_plan[1] .delegates ), @@ -349,7 +385,11 @@ def test_edge_to_backend_selective(self): # one delegate blob for forward self.assertEqual( len( - delegate_manager.to_executorch(ExecutorchBackendConfig()) + delegate_manager.to_executorch( + ExecutorchBackendConfig( + extract_delegate_segments=False, + ) + ) ._emitter_output.program.execution_plan[0] # foo .delegates ), @@ -357,7 +397,11 @@ def test_edge_to_backend_selective(self): ) self.assertEqual( len( - delegate_manager.to_executorch(ExecutorchBackendConfig()) + delegate_manager.to_executorch( + ExecutorchBackendConfig( + extract_delegate_segments=False, + ) + ) ._emitter_output.program.execution_plan[1] # forward .delegates ), diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 87691dfbee2..13590b78dde 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -22,6 +22,7 @@ import sympy import torch +import torch._export.exported_program import torch.export.exported_program as ep from torch._export.serde.schema import ( diff --git a/exir/tensor.py b/exir/tensor.py index ee2633654e8..ee074cf7119 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -37,7 +37,11 @@ def contiguous_stride_from_shape(shape: torch.Size) -> Tuple[int]: strides.append(accum) # For sizes[i] == 0, treat it as 1 to be consistent with core Pytorch # This preserves the PT equivalent behavior for dims with 0 elements - if sz != 0: + if isinstance(sz, int): + if sz != 0: + accum *= sz + else: + # Unbacked symints may error on the != 0 check accum *= sz return tuple(reversed(strides)) diff --git a/exir/tests/test_memory_planning.py 
b/exir/tests/test_memory_planning.py index 90a6d7b7d8c..12a0583ab41 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -495,7 +495,7 @@ def test_multiple_pools( memory_planning_pass=CustomPoolMemoryPlanningPass( memory_planning_algo=algo, alignment=1, - ) + ), ) ) graph_module = edge_program.exported_program().graph_module diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 9c5e4b59adc..f65ccff13b0 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1602,7 +1602,9 @@ def __init__(self): def forward(self, x): o1 = torch.ops.aten.view_copy.default(x, [1]) o2 = torch.ops.aten.view_copy.default(self.parameter, [1]) - return o1, o2 + # view_copys at the end of a function are not replaced, so add + # a computation before the end of the graph. + return torch.ops.aten.add.Tensor(o1, o2) ep = torch.export.export( TestViewCopies(), @@ -1630,7 +1632,7 @@ def forward(self, x): assert gm_res is not None gm = gm_res.graph_module - # Check before transformation + # Check after transformation FileCheck().check_count( "torch.ops.aten.view_copy.default", 0, exactly=True ).run(gm.code) diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py index 0c5b61f8d8f..f64a1f19981 100644 --- a/exir/tests/test_remove_view_copy.py +++ b/exir/tests/test_remove_view_copy.py @@ -19,6 +19,8 @@ def __init__(self): super().__init__() self.parameter = nn.Parameter(torch.rand(5, 6)) self.parameter.requires_grad = False + self.parameter2 = nn.Parameter(torch.rand(30)) + self.parameter2.requires_grad = False def forward(self, x): v1 = self.parameter.view( @@ -28,7 +30,10 @@ def forward(self, x): v3 = torch.ops.aten.mul.Tensor(v1, v2).view( 30 ) # removed, lifetime of mul.Tensor will be extended - return v3 + v4 = torch.ops.aten.mul.Tensor(v3, self.parameter2) + v5 = v4.view(6, 5) # not removed, output of the graph + v6 = v4.view(2, 15) # not removed, output of the graph + return v5, v6 def 
get_example_inputs(self): return (torch.rand(5, 6),) @@ -83,10 +88,15 @@ def test_output_matches(self) -> None: ), ) - out_remove = etpm_remove.exported_program().module()(*example_inputs) - out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs) + out_remove_v5, out_remove_v6 = etpm_remove.exported_program().module()( + *example_inputs + ) + out_no_remove_v5, out_no_remove_v6 = etpm_no_remove.exported_program().module()( + *example_inputs + ) - self.assertTrue(torch.allclose(out_remove, out_no_remove)) + self.assertTrue(torch.allclose(out_remove_v5, out_no_remove_v5)) + self.assertTrue(torch.allclose(out_remove_v6, out_no_remove_v6)) def test_spec(self) -> None: model = TestModel1() @@ -106,20 +116,25 @@ def test_spec(self) -> None: # etpm.exported_program().graph.print_tabular() # idx opcode name target args kwargs - # --- ------------- ------------------------ ---------------------------------- -------------------------------------------------- -------------- + # --- ------------- ------------------------ ---------------------------------- -------------------------------------------------- ---------------- # 0 placeholder p_parameter p_parameter () {} - # 1 placeholder x x () {} - # 2 call_function aten_view_copy_default (p_parameter, [6, 5]) {} - # 3 call_function aten_view_copy_default_1 (x, [6, 5]) {} - # 4 call_function alloc (((6, 5), torch.float32),) {} - # 5 call_function aten_mul_tensor aten.mul.out (aten_view_copy_default, aten_view_copy_default_1) {'out': alloc} - # 6 call_function aten_view_copy_default_2 (aten_mul_tensor, [30]) {} - # 7 output output_1 output ((aten_view_copy_default_2,),) {} + # 1 placeholder p_parameter2 p_parameter2 () {} + # 2 placeholder x x () {} + # 3 call_function aten_view_copy_default (p_parameter, [6, 5]) {} + # 4 call_function aten_view_copy_default_1 (x, [6, 5]) {} + # 5 call_function alloc (((6, 5), torch.float32),) {} + # 6 call_function aten_mul_tensor aten.mul.out (aten_view_copy_default, 
aten_view_copy_default_1) {'out': alloc} + # 7 call_function aten_view_copy_default_2 (aten_mul_tensor, [30]) {} + # 8 call_function alloc_1 (((30,), torch.float32),) {} + # 9 call_function aten_mul_tensor_1 aten.mul.out (aten_view_copy_default_2, p_parameter2) {'out': alloc_1} + # 10 call_function alloc_2 (((6, 5), torch.float32),) {} + # 11 call_function aten_view_copy_default_3 aten.view_copy.out (aten_mul_tensor_1, [6, 5]) {'out': alloc_2} + # 12 output output_1 output ((aten_view_copy_default_3,),) {} for node in etpm.exported_program().graph.nodes: if node.name == "p_parameter": - # p_parameter's lifetime is extended through aten_view_copy_default (memory.view) to idx 5 - self.assertEqual(node.meta["spec"].lifetime, [0, 5]) + # p_parameter's lifetime is extended through aten_view_copy_default (memory.view) to idx 6 + self.assertEqual(node.meta["spec"].lifetime, [0, 6]) elif node.name == "aten_view_copy_default": # aten_view_copy_default is a memory.view of p_parameter. # p_parameter is a constant with storage, so we check that the view's storage matches the base @@ -149,8 +164,8 @@ def test_spec(self) -> None: node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime ) elif node.name == "aten_mul_tensor": - # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 7 - self.assertEqual(node.meta["spec"].lifetime, [4, 7]) + # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 9 + self.assertEqual(node.meta["spec"].lifetime, [5, 9]) elif node.name == "aten_view_copy_default_2": # aten_view_copy_default_2 is a memory.view of aten_mul_tensor @@ -184,9 +199,10 @@ def test_spec(self) -> None: plan = etpm.executorch_program.execution_plan[0] self.assertEqual(plan.operators[0].name, "executorch_prim::et_view") self.assertEqual(plan.operators[1].name, "aten::mul") + self.assertEqual(plan.operators[2].name, "aten::view_copy") instructions = plan.chains[0].instructions - 
self.assertEqual(len(instructions), 4) + self.assertEqual(len(instructions), 7) self.assertEqual( instructions[0].instr_args.op_index, 0 # pyre-ignore @@ -196,7 +212,16 @@ def test_spec(self) -> None: ) # view @ idx3 self.assertEqual( instructions[2].instr_args.op_index, 1 # pyre-ignore - ) # aten:mul @ idx5 + ) # aten:mul @ idx6 self.assertEqual( instructions[3].instr_args.op_index, 0 # pyre-ignore - ) # view @ idx6 + ) # view @ idx7 + self.assertEqual( + instructions[4].instr_args.op_index, 1 # pyre-ignore + ) # aten:mul @ idx9 + self.assertEqual( + instructions[5].instr_args.op_index, 2 # pyre-ignore + ) # aten:view_copy @ idx11 + self.assertEqual( + instructions[6].instr_args.op_index, 2 # pyre-ignore + ) # aten:view_copy @ idx11 diff --git a/extension/training/optimizer/TARGETS b/extension/training/optimizer/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/training/optimizer/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/training/optimizer/sgd.h b/extension/training/optimizer/sgd.h new file mode 100644 index 00000000000..a5f46b44066 --- /dev/null +++ b/extension/training/optimizer/sgd.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * SGD (stochastic gradient descent) optimizer to perform on-device training. + * This uses the gradients calculated in the backwards pass of the loss function + * and updates the parameters such that it minimizes the loss. 
+ * + * This is similar to the Lite Interpreter implementation of the SGD optimizer, + * but without the dependency on ATen Tensors and autograd. + */ +#pragma once + +namespace torch { +namespace executor { +namespace optimizer { + +/** + * SGD optimizer state. This keeps track of the state of a given parameter to + * be used in later epochs. + */ +class SGDParamState {}; + +/** + * SGD optimizer options. This contains options for performing training on a + * param group, such as the learning rate. + */ +class SGDOptions {}; + +/** + * SGD optimizer param group. This contains the parameters and + * the OptimizerOptions associated to it. + */ +class SGDParamGroup {}; + +/** + * SGD optimizer class. This is responsible for performing the optimization + * step. + */ +class SGD {}; + +} // namespace optimizer +} // namespace executor +} // namespace torch diff --git a/extension/training/optimizer/targets.bzl b/extension/training/optimizer/targets.bzl new file mode 100644 index 00000000000..ffe8e30d7b6 --- /dev/null +++ b/extension/training/optimizer/targets.bzl @@ -0,0 +1,20 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + runtime.cxx_library( + name = "optimizer", + exported_headers = [ + "sgd.h", + ], + exported_deps = [ + ], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + ) diff --git a/extension/training/optimizer/test/TARGETS b/extension/training/optimizer/test/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/training/optimizer/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. 
+ +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/training/optimizer/test/sgd_test.cpp b/extension/training/optimizer/test/sgd_test.cpp new file mode 100644 index 00000000000..1d35e43458f --- /dev/null +++ b/extension/training/optimizer/test/sgd_test.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +using namespace ::testing; +using namespace torch::executor::optimizer; + +class SGDOptimizerTest : public ::testing::Test {}; + +TEST_F(SGDOptimizerTest, InstantiateTypes) { + SGDParamState state; + SGDOptions options; + SGDParamGroup param_group; + SGD sgd; + + EXPECT_TRUE(dynamic_cast(&state) != nullptr); + EXPECT_TRUE(dynamic_cast(&options) != nullptr); + EXPECT_TRUE(dynamic_cast(¶m_group) != nullptr); + EXPECT_TRUE(dynamic_cast(&sgd) != nullptr); +} diff --git a/extension/training/optimizer/test/targets.bzl b/extension/training/optimizer/test/targets.bzl new file mode 100644 index 00000000000..9d380f90a14 --- /dev/null +++ b/extension/training/optimizer/test/targets.bzl @@ -0,0 +1,18 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. 
+ """ + + runtime.cxx_test( + name = "sgd_test", + srcs = [ + "sgd_test.cpp", + ], + deps = [ + "//executorch/extension/training/optimizer:optimizer", + ], + ) diff --git a/install_requirements.sh b/install_requirements.sh index c5b45706709..24a01cae9b6 100755 --- a/install_requirements.sh +++ b/install_requirements.sh @@ -59,7 +59,7 @@ done # NOTE: If a newly-fetched version of the executorch repo changes the value of # NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -NIGHTLY_VERSION=dev20240422 +NIGHTLY_VERSION=dev20240507 # The pip repository that hosts nightly torch packages. TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu" @@ -73,6 +73,7 @@ EXIR_REQUIREMENTS=( # pip packages needed for development. DEVEL_REQUIREMENTS=( cmake # For building binary targets. + "pip>=23" # For building the pip package. pyyaml # Imported by the kernel codegen tools. "setuptools>=63" # For building the pip package. tomli # Imported by extract_sources.py when using python < 3.11. 
diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp index c11c9977fe5..b62c3b154fa 100644 --- a/kernels/optimized/cpu/op_add.cpp +++ b/kernels/optimized/cpu/op_add.cpp @@ -16,6 +16,55 @@ namespace torch { namespace executor { namespace native { +namespace { + +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct AddInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct AddInner { + static void + run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = a_casted + alpha_val * b_casted; + + return static_cast(value); + }, + a, + b, + out); + } +}; + +template +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct AddInner + : public ReportCanCastBug {}; + +} // namespace using Tensor = exec_aten::Tensor; using ScalarType = exec_aten::ScalarType; @@ -69,26 +118,20 @@ Tensor& opt_add_out( ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES(common_type, ctx, "add.out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { - CTYPE_IN alpha_val; - ET_KERNEL_CHECK( - ctx, - utils::extract_scalar(alpha, &alpha_val), - InvalidArgument, ); - - apply_binary_elementwise_fn( - [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); 
- CTYPE_IN value = a_casted + alpha_val * b_casted; - - return static_cast(value); - }, - a, - b, - out); - }); + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { + CTYPE_IN alpha_val; + ET_KERNEL_CHECK( + ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); + + AddInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, alpha_val, out); }); }); }); diff --git a/kernels/optimized/cpu/op_le.cpp b/kernels/optimized/cpu/op_le.cpp index 05e7889671b..15481403c2d 100644 --- a/kernels/optimized/cpu/op_le.cpp +++ b/kernels/optimized/cpu/op_le.cpp @@ -53,31 +53,26 @@ Tensor& opt_le_tensor_out( a.numel()); }); } else { - ScalarType common_type = promoteTypes(a_type, b_type); ET_SWITCH_REAL_TYPES_AND( Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() { + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK( + CppTypeToScalarType::value == + promoteTypes(a_type, b_type)); ET_SWITCH_REAL_TYPES_AND( - Bool, common_type, ctx, "le.Tensor_out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, - out_type, - ctx, - "le.Tensor_out", - CTYPE_OUT, - [&]() { - const size_t n = a.numel(); - const CTYPE_A* a_data = a.const_data_ptr(); - const CTYPE_B* b_data = b.const_data_ptr(); - CTYPE_OUT* out_data = - out.mutable_data_ptr(); - for (auto i = 0; i < n; ++i) { - out_data[i] = static_cast( - static_cast(a_data[i]) <= - static_cast(b_data[i])); - } - }); + Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() { + const size_t n = a.numel(); + const CTYPE_A* a_data = a.const_data_ptr(); + const CTYPE_B* b_data = b.const_data_ptr(); + CTYPE_OUT* out_data = out.mutable_data_ptr(); + for (auto i = 0; i < n; ++i) { + out_data[i] = static_cast( + static_cast(a_data[i]) <= + static_cast(b_data[i])); + } 
}); }); }); diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index 3b2926a8a74..adcd8999150 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -41,6 +41,50 @@ bool can_use_optimized_path( (a.numel() == b.numel() && a.numel() == out.numel())); return can_use_optimized_path; } + +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MulInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MulInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = a_casted * b_casted; + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MulInner + : public ReportCanCastBug {}; } // namespace Tensor& opt_mul_out( @@ -86,20 +130,21 @@ Tensor& opt_mul_out( ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES(common_type, ctx, "mul.out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted * b_casted; - - return static_cast(value); - }, - a, - b, - out); - }); + using CTYPE_IN = typename torch::executor:: + promote_types::type; + 
ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { + apply_binary_elementwise_fn( + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = a_casted * b_casted; + + return static_cast(value); + }, + a, + b, + out); }); }); }); diff --git a/kernels/optimized/cpu/op_sub.cpp b/kernels/optimized/cpu/op_sub.cpp index 77917c0eda6..87368f3ed76 100644 --- a/kernels/optimized/cpu/op_sub.cpp +++ b/kernels/optimized/cpu/op_sub.cpp @@ -17,6 +17,55 @@ namespace torch { namespace executor { namespace native { +namespace { + +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct SubInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct SubInner { + static void + run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = a_casted - alpha_val * b_casted; + + return static_cast(value); + }, + a, + b, + out); + } +}; + +template +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct SubInner + : public ReportCanCastBug {}; + +} // namespace using Tensor = exec_aten::Tensor; using ScalarType = exec_aten::ScalarType; @@ -72,26 +121,19 @@ Tensor& opt_sub_out( ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() { ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES(common_type, 
ctx, "sub.out", CTYPE_IN, [&]() { - ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() { - CTYPE_IN alpha_val; - ET_KERNEL_CHECK( - ctx, - utils::extract_scalar(alpha, &alpha_val), - InvalidArgument, ); - - apply_binary_elementwise_fn( - [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted - alpha_val * b_casted; - - return static_cast(value); - }, - a, - b, - out); - }); + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() { + CTYPE_IN alpha_val; + ET_KERNEL_CHECK( + ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); + SubInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, alpha_val, out); }); }); }); diff --git a/kernels/portable/cpu/op_bitwise_and.cpp b/kernels/portable/cpu/op_bitwise_and.cpp index b1078f780a4..de137afbec2 100644 --- a/kernels/portable/cpu/op_bitwise_and.cpp +++ b/kernels/portable/cpu/op_bitwise_and.cpp @@ -6,8 +6,10 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include +// patternlint-disable-next-line executorch-cpp-nostdinc +#include +#include #include #include #include @@ -17,20 +19,6 @@ namespace torch { namespace executor { namespace native { -namespace { - -template -CTYPE bitwise_and(CTYPE a, CTYPE b) { - return a & b; -} - -template <> -bool bitwise_and(bool a, bool b) { - return a && b; -} - -} // namespace - using Tensor = exec_aten::Tensor; Tensor& bitwise_and_Tensor_out( @@ -55,38 +43,23 @@ Tensor& bitwise_and_Tensor_out( Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() { ET_SWITCH_INT_TYPES_AND( Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() { - ET_SWITCH_INT_TYPES_AND( + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REAL_TYPES_AND( Bool, - common_type, + out_type, ctx, "bitwise_and.Tensor_out", - CTYPE_IN, + CTYPE_OUT, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, - out_type, - ctx, - "bitwise_and.Tensor_out", - CTYPE_OUT, - [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = - bitwise_and(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + internal::BitwiseOpInner< + can_cast::value, + std::bit_and, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); }); }); }); @@ -142,8 +115,8 @@ Tensor& bitwise_and_Scalar_out( static_cast(val_a); CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = - bitwise_and(a_casted, b_casted); + CTYPE_IN value = std::bit_and()( + a_casted, b_casted); return static_cast(value); }, diff --git a/kernels/portable/cpu/op_bitwise_or.cpp b/kernels/portable/cpu/op_bitwise_or.cpp index c13c68d3db4..39707de07ce 100644 --- a/kernels/portable/cpu/op_bitwise_or.cpp +++ b/kernels/portable/cpu/op_bitwise_or.cpp @@ -6,8 +6,10 @@ * LICENSE file in the root 
directory of this source tree. */ -#include +// patternlint-disable-next-line executorch-cpp-nostdinc +#include +#include #include #include #include @@ -17,20 +19,6 @@ namespace torch { namespace executor { namespace native { -namespace { - -template -CTYPE bitwise_or(CTYPE a, CTYPE b) { - return a | b; -} - -template <> -bool bitwise_or(bool a, bool b) { - return a || b; -} - -} // namespace - using Tensor = exec_aten::Tensor; Tensor& bitwise_or_Tensor_out( @@ -55,37 +43,23 @@ Tensor& bitwise_or_Tensor_out( Bool, a_type, ctx, "bitwise_or.Tensor_out", CTYPE_A, [&]() { ET_SWITCH_INT_TYPES_AND( Bool, b_type, ctx, "bitwise_or.Tensor_out", CTYPE_B, [&]() { - ET_SWITCH_INT_TYPES_AND( + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REAL_TYPES_AND( Bool, - common_type, + out_type, ctx, "bitwise_or.Tensor_out", - CTYPE_IN, + CTYPE_OUT, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, - out_type, - ctx, - "bitwise_or.Tensor_out", - CTYPE_OUT, - [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = bitwise_or(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + internal::BitwiseOpInner< + can_cast::value, + std::bit_or, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); }); }); }); @@ -141,7 +115,8 @@ Tensor& bitwise_or_Scalar_out( static_cast(val_a); CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = bitwise_or(a_casted, b_casted); + CTYPE_IN value = + std::bit_or()(a_casted, b_casted); return static_cast(value); }, diff --git a/kernels/portable/cpu/op_bitwise_xor.cpp b/kernels/portable/cpu/op_bitwise_xor.cpp index d2ea8a81cfb..1855485ee52 100644 --- a/kernels/portable/cpu/op_bitwise_xor.cpp +++ b/kernels/portable/cpu/op_bitwise_xor.cpp @@ -6,8 +6,10 @@ * LICENSE file 
in the root directory of this source tree. */ -#include +// patternlint-disable-next-line executorch-cpp-nostdinc +#include +#include #include #include #include @@ -17,20 +19,6 @@ namespace torch { namespace executor { namespace native { -namespace { - -template -CTYPE bitwise_xor(CTYPE a, CTYPE b) { - return a ^ b; -} - -template <> -bool bitwise_xor(bool a, bool b) { - return a != b; -} - -} // namespace - using Tensor = exec_aten::Tensor; Tensor& bitwise_xor_Tensor_out( @@ -38,7 +26,6 @@ Tensor& bitwise_xor_Tensor_out( const Tensor& a, const Tensor& b, Tensor& out) { - // Determine output size and resize for dynamic shapes ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, out) == Error::Ok, @@ -56,38 +43,23 @@ Tensor& bitwise_xor_Tensor_out( Bool, a_type, ctx, "bitwise_xor.Tensor_out", CTYPE_A, [&]() { ET_SWITCH_INT_TYPES_AND( Bool, b_type, ctx, "bitwise_xor.Tensor_out", CTYPE_B, [&]() { - ET_SWITCH_INT_TYPES_AND( + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REAL_TYPES_AND( Bool, - common_type, + out_type, ctx, "bitwise_xor.Tensor_out", - CTYPE_IN, + CTYPE_OUT, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, - out_type, - ctx, - "bitwise_xor.Tensor_out", - CTYPE_OUT, - [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = - bitwise_xor(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + internal::BitwiseOpInner< + can_cast::value, + std::bit_xor, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); }); }); }); @@ -143,8 +115,8 @@ Tensor& bitwise_xor_Scalar_out( static_cast(val_a); CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = - bitwise_xor(a_casted, b_casted); + CTYPE_IN value = std::bit_xor()( + a_casted, b_casted); return 
static_cast(value); }, diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp index 06c87d03f2d..50d7e8c374d 100644 --- a/kernels/portable/cpu/op_clamp.cpp +++ b/kernels/portable/cpu/op_clamp.cpp @@ -53,7 +53,7 @@ __ET_NODISCARD bool check_bounds( } }); } else if (isFloatingType(out_type)) { - ET_SWITCH_FLOAT_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { + ET_SWITCH_FLOATH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { if (std::isfinite(val) && is_out_of_bounds(val)) { ET_LOG(Error, "%s value out of bounds", val_name); @@ -119,7 +119,7 @@ Tensor& clamp_out( ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); - ET_SWITCH_REAL_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { + ET_SWITCH_REALH_TYPES(out_type, ctx, "clamp", CTYPE_OUT, [&]() { // Extract optional min value CTYPE_OUT min = 0; if (has_min) { @@ -140,7 +140,7 @@ Tensor& clamp_out( }); } - ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "clamp", CTYPE_IN, [&]() { + ET_SWITCH_REALHB_TYPES(in_type, ctx, "clamp", CTYPE_IN, [&]() { apply_unary_map_fn( [has_min, min, has_max, max](const CTYPE_IN val_in) { CTYPE_OUT val_out = static_cast(val_in); @@ -195,20 +195,20 @@ Tensor& clamp_tensor_out( ScalarType out_type = out.scalar_type(); if (has_min) { - common_type = promoteTypes(common_type, min_type); + common_type = promoteTypes(common_type, min_type, /*half_to_float*/ true); } if (has_max) { - common_type = promoteTypes(common_type, max_type); + common_type = promoteTypes(common_type, max_type, /*half_to_float*/ true); } ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); constexpr auto name = "clamp.Tensor_out"; - ET_SWITCH_REALB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() { - ET_SWITCH_REALB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() { - ET_SWITCH_REALB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() { - ET_SWITCH_REALB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { + ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&]() { + 
ET_SWITCH_REALHB_TYPES(min_type, ctx, name, CTYPE_MIN, [&]() { + ET_SWITCH_REALHB_TYPES(max_type, ctx, name, CTYPE_MAX, [&]() { + ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { apply_ternary_elementwise_fn< CTYPE_IN, CTYPE_MIN, diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp index 261f77ce617..0514df0ca25 100644 --- a/kernels/portable/cpu/op_floor_divide.cpp +++ b/kernels/portable/cpu/op_floor_divide.cpp @@ -20,6 +20,60 @@ namespace native { using Tensor = exec_aten::Tensor; using ScalarType = exec_aten::ScalarType; +namespace { +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FloorDivideInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FloorDivideInner { + static void + run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) { + if (is_integral_type::value) { + if (val_b == 0) { + div_by_zero_error = true; + return static_cast(0); + } + } + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = utils::floor_divide(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&, bool&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FloorDivideInner + : public ReportCanCastBug {}; + +} // namespace + Tensor& floor_divide_out( RuntimeContext& ctx, const Tensor& a, @@ -46,36 +100,17 @@ Tensor& floor_divide_out( Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, 
"floor_divide.out", CTYPE_B, [&]() { + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); ET_SWITCH_REAL_TYPES( - common_type, ctx, "floor_divide.out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES( - out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [common_type, &div_by_zero_error]( - const CTYPE_A val_a, const CTYPE_B val_b) { - if (isIntegralType( - common_type, /*includeBool=*/true)) { - if (val_b == 0) { - div_by_zero_error = true; - return static_cast(0); - } - } - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = utils::floor_divide( - a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() { + FloorDivideInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out, div_by_zero_error); }); }); }); diff --git a/kernels/portable/cpu/op_fmod.cpp b/kernels/portable/cpu/op_fmod.cpp index 0083c1379d5..42f83731199 100644 --- a/kernels/portable/cpu/op_fmod.cpp +++ b/kernels/portable/cpu/op_fmod.cpp @@ -19,6 +19,60 @@ namespace native { using Tensor = exec_aten::Tensor; +namespace { +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FmodInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FmodInner { + static void + run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) { + if (is_integral_type::value) { + if (val_b == 0) { + div_by_zero_error = true; + return static_cast(0); + } + } + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + 
CTYPE_IN value = std::fmod(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&, bool&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct FmodInner + : public ReportCanCastBug {}; + +} // namespace + Tensor& fmod_Tensor_out( RuntimeContext& ctx, const Tensor& a, @@ -44,35 +98,18 @@ Tensor& fmod_Tensor_out( Bool, a_type, ctx, "fmod.Tensor_out", CTYPE_A, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, "fmod.Tensor_out", CTYPE_B, [&]() { + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); ET_SWITCH_REAL_TYPES( - common_type, ctx, "fmod.Tensor_out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES( - out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [common_type, &div_by_zero_error]( - const CTYPE_A val_a, const CTYPE_B val_b) { - if (isIntegralType( - common_type, /*includeBool=*/true)) { - if (val_b == 0) { - div_by_zero_error = true; - return static_cast(0); - } - } - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = std::fmod(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() { + FmodInner< + !std::is_same::value && + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out, div_by_zero_error); }); }); }); diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp index 3e34035d5f6..1353479b294 100644 --- a/kernels/portable/cpu/op_maximum.cpp +++ b/kernels/portable/cpu/op_maximum.cpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace torch { @@ -15,10 +16,49 @@ namespace executor { 
namespace native { namespace { -template -const T& max(const T& a, const T& b) { - return (b > a) ? b : a; -} +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MaximumInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MaximumInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = utils::max_override(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MaximumInner + : public ReportCanCastBug {}; } // namespace @@ -44,20 +84,16 @@ Tensor& maximum_out( ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() { ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = max(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() { + MaximumInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, 
out); }); }); }); diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp index 767a2c4ca59..f18d1a6d368 100644 --- a/kernels/portable/cpu/op_minimum.cpp +++ b/kernels/portable/cpu/op_minimum.cpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace torch { @@ -15,10 +16,49 @@ namespace executor { namespace native { namespace { -template -const T& min(const T& a, const T& b) { - return (b < a) ? b : a; -} +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MinimumInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MinimumInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = utils::min_override(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct MinimumInner + : public ReportCanCastBug {}; } // namespace @@ -37,30 +77,24 @@ Tensor& minimum_out( ScalarType a_type = a.scalar_type(); ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type); + ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); ScalarType out_type = out.scalar_type(); ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() { - ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() { - 
ET_SWITCH_REAL_TYPES_AND( - Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES_AND( - Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = min(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); - }); + ET_SWITCH_REALHB_TYPES(a_type, ctx, "minimum.out", CTYPE_A, [&]() { + ET_SWITCH_REALHB_TYPES(b_type, ctx, "minimum.out", CTYPE_B, [&]() { + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "minimum.out", CTYPE_OUT, [&]() { + MinimumInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); + }); }); }); diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp index 9e48374a81a..7c858c1c08a 100644 --- a/kernels/portable/cpu/op_remainder.cpp +++ b/kernels/portable/cpu/op_remainder.cpp @@ -20,6 +20,52 @@ namespace native { using Tensor = exec_aten::Tensor; +namespace { +template < + bool can_cast, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct RemainderInner; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct RemainderInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = utils::remainder_override(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: 
canCast should have been checked above"); + } +}; + +template < + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct RemainderInner + : public ReportCanCastBug {}; + +} // namespace Tensor& remainder_Tensor_out( RuntimeContext& ctx, const Tensor& a, @@ -45,32 +91,17 @@ Tensor& remainder_Tensor_out( Bool, a_type, ctx, "remainder.Tensor_out", CTYPE_A, [&]() { ET_SWITCH_REAL_TYPES_AND( Bool, b_type, ctx, "remainder.Tensor_out", CTYPE_B, [&]() { + using CTYPE_IN = typename torch::executor:: + promote_types::type; + ET_DCHECK(CppTypeToScalarType::value == common_type); ET_SWITCH_REAL_TYPES( - common_type, ctx, "remainder.Tensor_out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES( - out_type, - ctx, - "remainder.Tensor_out", - CTYPE_OUT, - [&]() { - apply_binary_elementwise_fn< - CTYPE_A, - CTYPE_B, - CTYPE_OUT>( - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = - static_cast(val_a); - CTYPE_IN b_casted = - static_cast(val_b); - CTYPE_IN value = utils::remainder_override( - a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - }); + out_type, ctx, "remainder.Tensor_out", CTYPE_OUT, [&]() { + RemainderInner< + can_cast::value, + CTYPE_A, + CTYPE_B, + CTYPE_IN, + CTYPE_OUT>::run(a, b, out); }); }); }); diff --git a/kernels/portable/cpu/pattern/bitwise_op.h b/kernels/portable/cpu/pattern/bitwise_op.h new file mode 100644 index 00000000000..dda4fe5cd55 --- /dev/null +++ b/kernels/portable/cpu/pattern/bitwise_op.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace torch { +namespace executor { +namespace native { +namespace internal { + +template < + bool can_cast, + template + class OpFunc, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct BitwiseOpInner; + +template < + template + class OpFunc, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct BitwiseOpInner { + static void run(const Tensor& a, const Tensor& b, Tensor& out) { + apply_binary_elementwise_fn( + // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) + [](const CTYPE_A val_a, const CTYPE_B val_b) { + CTYPE_IN a_casted = static_cast(val_a); + CTYPE_IN b_casted = static_cast(val_b); + CTYPE_IN value = OpFunc()(a_casted, b_casted); + + return static_cast(value); + }, + a, + b, + out); + } +}; + +struct ReportCanCastBug { + static void run(const Tensor&, const Tensor&, Tensor&) { + ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); + } +}; + +template < + template + class OpFunc, + typename CTYPE_A, + typename CTYPE_B, + typename CTYPE_IN, + typename CTYPE_OUT> +struct BitwiseOpInner + : public ReportCanCastBug {}; + +} // namespace internal +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/pattern/targets.bzl b/kernels/portable/cpu/pattern/targets.bzl index 360d991767b..7e0b71ed950 100644 --- a/kernels/portable/cpu/pattern/targets.bzl +++ b/kernels/portable/cpu/pattern/targets.bzl @@ -6,6 +6,17 @@ def define_common_targets(): The directory containing this targets.bzl file should also contain both TARGETS and BUCK files that call this function. 
""" + runtime.cxx_library( + name = "bitwise_op", + exported_headers = [ + "bitwise_op.h", + ], + compiler_flags = [], + deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."], + ) runtime.cxx_library( name = "pattern", diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h index 989e7978fc3..3daf3e72526 100644 --- a/kernels/portable/cpu/scalar_utils.h +++ b/kernels/portable/cpu/scalar_utils.h @@ -84,9 +84,9 @@ template struct promote_type_with_scalar_type { private: static_assert( - std::is_same::value || - std::is_same::value || - std::is_same::value, + std::is_same::value || + std::is_same::value || + std::is_same::value, "scalar type can only be Bool, Long or Double"); static_assert( !is_qint_type::value, @@ -102,17 +102,19 @@ struct promote_type_with_scalar_type { "promote_type_with_scalar_type not valid for BFloat16"); using promote_type_with_scalar_type_not_respecting_half_to_float = typename std::conditional< - is_complex_type::value || std::is_same::value, + is_complex_type::value || + std::is_same::value, T1, typename std::conditional< - std::is_same::value, + std::is_same::value, typename std::conditional< - std::is_same::value, - internal::I8, + std::is_same::value, + torch::executor::internal::I8, T1>::type, - typename std:: - conditional::value, T1, internal::F4>:: - type>::type>::type; + typename std::conditional< + is_floating_point::value, + T1, + torch::executor::internal::F4>::type>::type>::type; public: using type = typename std::conditional< diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl index 77796c68526..7be1d94d2bf 100644 --- a/kernels/portable/cpu/targets.bzl +++ b/kernels/portable/cpu/targets.bzl @@ -142,6 +142,7 @@ _ATEN_OPS = ( deps = [ "//executorch/runtime/core/exec_aten/util:scalar_type_util", "//executorch/runtime/core/exec_aten/util:tensor_util", + 
"//executorch/kernels/portable/cpu/pattern:bitwise_op", "//executorch/kernels/portable/cpu/util:broadcast_util", "//executorch/kernels/portable/cpu/util:functional_util", ":scalar_utils", @@ -160,6 +161,7 @@ _ATEN_OPS = ( deps = [ "//executorch/runtime/core/exec_aten/util:scalar_type_util", "//executorch/runtime/core/exec_aten/util:tensor_util", + "//executorch/kernels/portable/cpu/pattern:bitwise_op", "//executorch/kernels/portable/cpu/util:broadcast_util", "//executorch/kernels/portable/cpu/util:functional_util", ":scalar_utils", @@ -170,6 +172,7 @@ _ATEN_OPS = ( deps = [ "//executorch/runtime/core/exec_aten/util:scalar_type_util", "//executorch/runtime/core/exec_aten/util:tensor_util", + "//executorch/kernels/portable/cpu/pattern:bitwise_op", "//executorch/kernels/portable/cpu/util:broadcast_util", "//executorch/kernels/portable/cpu/util:functional_util", ":scalar_utils", @@ -560,6 +563,7 @@ _ATEN_OPS = ( name = "op_maximum", deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:math_util", ":scalar_utils", ], ), @@ -591,6 +595,7 @@ _ATEN_OPS = ( name = "op_minimum", deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:math_util", ":scalar_utils", ], ), diff --git a/kernels/portable/cpu/util/math_util.h b/kernels/portable/cpu/util/math_util.h index 44cb47f8cba..df175147062 100644 --- a/kernels/portable/cpu/util/math_util.h +++ b/kernels/portable/cpu/util/math_util.h @@ -94,6 +94,48 @@ INT_T max_override(INT_T a, INT_T b) { return std::max(a, b); } +template < + typename T, + typename std::enable_if< + std::is_same::value, + bool>::type = true> +T min_override(T a, T b) { + const auto float_a = static_cast(a); + if (std::isnan(float_a)) { + return a; + } + const auto float_b = static_cast(b); + if (std::isnan(float_b)) { + return b; + } + + if (float_a < float_b) { + return a; + } + return b; +} + +template < + typename T, + typename std::enable_if< + 
std::is_same::value, + bool>::type = true> +T max_override(T a, T b) { + const auto float_a = static_cast(a); + if (std::isnan(float_a)) { + return a; + } + const auto float_b = static_cast(b); + if (std::isnan(float_b)) { + return b; + } + + if (float_a > float_b) { + return a; + } + return b; +} + /** * There is a slight difference in how std::fmod works compared to how ATen * determines remainders: diff --git a/kernels/test/op_clamp_test.cpp b/kernels/test/op_clamp_test.cpp index 871333482c8..0244fd55700 100644 --- a/kernels/test/op_clamp_test.cpp +++ b/kernels/test/op_clamp_test.cpp @@ -147,8 +147,16 @@ class OpClampOutTest : public OperatorTest { // Test cases that are compatible with float and double. template void run_floating_point_test_cases() { - constexpr auto kInfinity = - std::numeric_limits::ctype>::infinity(); + using ctype = typename TensorFactory::ctype; + using opt_infinity_type = std::conditional_t< + std::is_same::value, + float, + ctype>; + constexpr auto kInfinity = std::numeric_limits::infinity(); + const auto kOptInfinity = + OptScalar(static_cast(kInfinity)); + const auto kOptMinusInfinity = + OptScalar(static_cast(-kInfinity)); std::vector> test_cases = { { std::string(__func__) + ": Simple negative/positive clamp", @@ -178,7 +186,7 @@ class OpClampOutTest : public OperatorTest { std::string(__func__) + ": Infinite min", {2, 2}, // sizes {-10.1, -1.1, 1.1, 10.1}, // input_data - OptScalar(-kInfinity), // min + kOptMinusInfinity, // min OptScalar(5.5), // max {-10.1, -1.1, 1.1, 5.5}, // expected_data }, @@ -187,7 +195,7 @@ class OpClampOutTest : public OperatorTest { {2, 2}, // sizes {-10.1, -1.1, 1.1, 10.1}, // input_data OptScalar(-5.5), // min - OptScalar(kInfinity), // max + kOptInfinity, // max {-5.5, -1.1, 1.1, 10.1}, // expected_data }, { @@ -285,6 +293,15 @@ TEST_F(OpClampOutTest, LongTensors) { run_signed_integer_test_cases(); } +TEST_F(OpClampOutTest, HalfTensors) { + // Note that the integer test cases test the situation where 
the min/max value + // Scalars are integer types, demonstrating that floating point types can be + // clamped to integer values. + run_unsigned_integer_test_cases(); + run_signed_integer_test_cases(); + run_floating_point_test_cases(); +} + TEST_F(OpClampOutTest, FloatTensors) { // Note that the integer test cases test the situation where the min/max value // Scalars are integer types, demonstrating that floating point types can be diff --git a/kernels/test/op_fmod_test.cpp b/kernels/test/op_fmod_test.cpp index 475d4ea5cb4..4ee4d84c1cc 100644 --- a/kernels/test/op_fmod_test.cpp +++ b/kernels/test/op_fmod_test.cpp @@ -32,3 +32,16 @@ class OpFmodTest : public OperatorTest { return torch::executor::aten::fmod_outf(context_, self, other, out); } }; + +TEST_F(OpFmodTest, SmokeTest) { + TensorFactory tfDouble; + TensorFactory tfLong; + TensorFactory tfInt; + + Tensor self = tfLong.full({2, 2}, 46); + Tensor other = tfInt.full({2, 2}, 4); + Tensor out = tfDouble.zeros({2, 2}); + Tensor out_expected = tfDouble.full({2, 2}, 2.0); + op_fmod_tensor_out(self, other, out); + EXPECT_TENSOR_CLOSE(out, out_expected); +} diff --git a/kernels/test/op_minimum_test.cpp b/kernels/test/op_minimum_test.cpp index be43e0af07d..7e12374b8d1 100644 --- a/kernels/test/op_minimum_test.cpp +++ b/kernels/test/op_minimum_test.cpp @@ -65,6 +65,10 @@ TEST_F(OpMinimumOutTest, LongTensors) { test_minimum_out_same_size(); } +TEST_F(OpMinimumOutTest, HalfTensors) { + test_minimum_out_same_size(); +} + TEST_F(OpMinimumOutTest, FloatTensors) { test_minimum_out_same_size(); } diff --git a/kernels/test/op_remainder_test.cpp b/kernels/test/op_remainder_test.cpp index 4a550958a1a..254e8122b61 100644 --- a/kernels/test/op_remainder_test.cpp +++ b/kernels/test/op_remainder_test.cpp @@ -21,6 +21,7 @@ using exec_aten::Tensor; using torch::executor::testing::TensorFactory; class OpRemainderOutTest : public OperatorTest { + protected: Tensor& op_remainder_tensor_out( const Tensor& self, const Tensor& other, @@ 
-35,3 +36,16 @@ class OpRemainderOutTest : public OperatorTest { return torch::executor::aten::remainder_outf(context_, self, other, out); } }; + +TEST_F(OpRemainderOutTest, SmokeTest) { + TensorFactory tfDouble; + TensorFactory tfLong; + TensorFactory tfInt; + + Tensor self = tfLong.full({2, 2}, 46); + Tensor other = tfInt.full({2, 2}, 4); + Tensor out = tfDouble.zeros({2, 2}); + Tensor out_expected = tfDouble.full({2, 2}, 2.0); + op_remainder_tensor_out(self, other, out); + EXPECT_TENSOR_CLOSE(out, out_expected); +} diff --git a/pyproject.toml b/pyproject.toml index 099cdd0d32c..ca5358e25fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,7 @@ [build-system] requires = [ "cmake", # For building binary targets in the wheel. + "pip>=23", # For building the pip package. "pyyaml", # Imported by the kernel codegen tools. "setuptools>=63", # For building the pip package contents. "tomli", # Imported by extract_sources.py when using python < 3.11. @@ -50,6 +51,7 @@ dependencies=[ "expecttest", "flatbuffers", "hypothesis", + "mpmath==1.3.0", "numpy>=1.25.2", "packaging", "pandas", diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 595ed7a1c02..084289520aa 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -349,6 +349,12 @@ inline constexpr bool isIntegralType( t == exec_aten::ScalarType::Short); } +template +struct is_integral_type + : public std::integral_constant< + bool, + isIntegralType(CppTypeToScalarType::value, includeBool)> {}; + inline constexpr bool isFloatingType(exec_aten::ScalarType t) { return ( t == exec_aten::ScalarType::Double || t == exec_aten::ScalarType::Float || diff --git a/runtime/core/portable_type/test/CMakeLists.txt b/runtime/core/portable_type/test/CMakeLists.txt new file mode 100644 index 00000000000..3ea05677c3d --- /dev/null +++ b/runtime/core/portable_type/test/CMakeLists.txt @@ -0,0 +1,46 @@ +# 
Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +cmake_minimum_required(VERSION 3.19) +project(runtime_core_portable_type_test) + +# Use C++11 for test. +set(CMAKE_CXX_STANDARD 11) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) + +include(${EXECUTORCH_ROOT}/build/Utils.cmake) + +# Find prebuilt executorch library +find_package(executorch CONFIG REQUIRED) + +enable_testing() +find_package(GTest CONFIG REQUIRED) + +# Let files say "include ". +set(_common_include_directories + ${EXECUTORCH_ROOT}/.. +) +target_include_directories(executorch INTERFACE ${_common_include_directories}) + +set(_test_srcs optional_test.cpp executor_tensor_test.cpp half_test.cpp + scalar_test.cpp tensor_impl_test.cpp +) + +add_executable(runtime_core_portable_type_test ${_test_srcs}) +target_link_libraries( + runtime_core_portable_type_test GTest::gtest GTest::gtest_main executorch +) +add_test(ExecuTorchTest runtime_core_portable_type_test) diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 1184eb0d3c8..3ac8e4897e4 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -933,10 +933,15 @@ Method::set_output_data_ptr(void* buffer, size_t size, size_t output_idx) { InvalidState, "Outputs can not be retrieved until method has been initialized."); - ET_CHECK_OR_RETURN_ERROR( - !pre_allocated_output_, - InvalidState, - "Overriding output data pointer allocated by memory plan is not allowed."); + // ET_CHECK_OR_RETURN_ERROR( + // !pre_allocated_output_, + // InvalidState, + // "Overriding output data pointer allocated by memory plan is not + // allowed."); + // TODO(T188740925): for now, return error without 
logs. + if (pre_allocated_output_) { + return ::torch::executor::Error::InvalidState; + } // Check the args ET_CHECK_OR_RETURN_ERROR( diff --git a/shim/third-party/rust/Cargo.toml b/shim/third-party/rust/Cargo.toml index 1084abbddc8..e0e31bf578a 100644 --- a/shim/third-party/rust/Cargo.toml +++ b/shim/third-party/rust/Cargo.toml @@ -22,7 +22,7 @@ path = "top/main.rs" gazebo = {version = "0.8.1", features = ["str_pattern_extensions"]} fbinit = "0.1" -sorted_vector_map = "0.1" +sorted_vector_map = "0.2" watchman_client = "0.8.0" annotate-snippets = { version = "0.9.0", features = ["color"] } @@ -184,7 +184,7 @@ syn1 = { package = "syn", version = "1.0.109", features = ["extra-traits", "fold synstructure = "0.12" sync_wrapper = "0.1.0" sys-info = "0.9.1" -sysinfo = "0.26.8" +sysinfo = "0.30.11" take_mut = "0.2.2" tar = "0.4.38" tempfile = "3.1.0" @@ -209,6 +209,7 @@ tracing = "0.1.22" tracing-subscriber = { version = "0.3", features = ["env-filter"] } triomphe = "0.1.11" trybuild = "1.0.56" +typed-arena = "2.0" twox-hash = "1.6.1" unicode-segmentation = "1.7" uuid = { version = "1.2", features = ["v4"] } diff --git a/test/run_oss_cpp_tests.sh b/test/run_oss_cpp_tests.sh new file mode 100644 index 00000000000..dcc8b4d27f6 --- /dev/null +++ b/test/run_oss_cpp_tests.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +build_executorch() { + cmake . -DCMAKE_INSTALL_PREFIX=cmake-out -DEXECUTORCH_BUILD_GTESTS=ON -Bcmake-out + cmake --build cmake-out -j9 --target install +} + +build_and_run_test() { + local test_dir=$1 + cmake "${test_dir}" -Bcmake-out/"${test_dir}" -DCMAKE_INSTALL_PREFIX=cmake-out + cmake --build cmake-out/"${test_dir}" + for t in cmake-out/"${test_dir}"/*test; do ./"$t"; done # Iterate the glob itself; wrapping it in $(...) would execute the binaries. +} + +build_executorch +build_and_run_test runtime/core/portable_type/test/