From 9714120680fdecc4e6d0a34a5dbb49006ac8613d Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 12 Mar 2024 08:59:19 -0700 Subject: [PATCH] [ET-VK] Dynamic shape support in Vulkan Backend ## Context This changeset exposes API functions to the `ComputeGraph` class that allow inputs to be resized, and for the resizing to propagate through the graph via re-calculation of output shapes. Differential Revision: [D54754546](https://our.internmc.facebook.com/intern/diff/D54754546/) [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 75 ++++++++++++++++- .../vulkan/runtime/graph/ComputeGraph.cpp | 21 ++++- backends/vulkan/runtime/graph/ComputeGraph.h | 15 +++- .../vulkan/runtime/graph/ops/ExecuteNode.cpp | 8 +- .../vulkan/runtime/graph/ops/ExecuteNode.h | 17 +++- .../runtime/graph/ops/impl/BinaryOp.cpp | 27 ++++++- backends/vulkan/targets.bzl | 1 + backends/vulkan/test/test_vulkan_delegate.py | 56 ++++++++++++- .../vulkan/test/vulkan_compute_api_test.cpp | 80 +++++++++++++++++++ backends/vulkan/vulkan_preprocess.py | 3 + 10 files changed, 287 insertions(+), 16 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index a073919c696..ce7de51bc93 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -195,6 +196,68 @@ class GraphBuilder { } }; +// +// Execution tools +// + +bool maybe_resize_input( + ComputeGraph* graph, + const size_t input_i, + exec_aten::Tensor& et_tensor) { + ValueRef in_tensor_ref = graph->inputs()[input_i].value; + vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor(); + + ET_CHECK_MSG( + et_tensor.dim() == in_tensor.sizes().size(), + "Cannot resize input tensor: old ndim %zu does not match new ndim %zu", + static_cast(in_tensor.sizes().size()), + static_cast(et_tensor.dim())); + + bool should_resize = false; + std::vector new_sizes(et_tensor.dim()); + for (size_t i = 0; i < et_tensor.dim(); i++) { + if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) { + should_resize = true; + } + new_sizes[i] = et_tensor.sizes()[i]; + } + + if (should_resize) { + graph->resize_input(input_i, new_sizes); + } + + ET_CHECK_MSG( + in_tensor.numel() == et_tensor.numel(), + "Vulkan tensor numel %zu does not match ET tensor numel %zu", + static_cast(in_tensor.numel()), + static_cast(et_tensor.numel())); + + return should_resize; +} + +void resize_output( + ComputeGraph* graph, + const size_t output_i, + exec_aten::Tensor& et_tensor) { + ValueRef out_tensor_ref = graph->outputs()[output_i].value; + vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor(); + + exec_aten::SizesType new_output_size[kTensorDimensionLimit]; + size_t ndim = out_tensor.sizes().size(); + for (int i = 0; i < ndim; ++i) { + new_output_size[i] = out_tensor.sizes()[i]; + } + + exec_aten::ArrayRef output_size{new_output_size, ndim}; + Error err = resize_tensor(et_tensor, output_size); + + ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor."); +} + +// +// VulkanBackend class +// + class VulkanBackend final : public PyTorchBackendInterface { public: ~VulkanBackend() override = default; @@ -273,20 +336,28 @@ class VulkanBackend final : public PyTorchBackendInterface { ComputeGraph* compute_graph = static_cast(handle); const size_t num_inputs = compute_graph->inputs().size(); + bool should_propagate_resize = false; for (size_t i = 0; i < num_inputs; i++) { + bool was_resized = + maybe_resize_input(compute_graph, i, args[i]->toTensor()); + should_propagate_resize = should_propagate_resize || was_resized; compute_graph->copy_into_staging( - compute_graph->inputs()[i], + compute_graph->inputs()[i].staging, args[i]->toTensor().const_data_ptr(), args[i]->toTensor().numel()); } + if (should_propagate_resize) { + compute_graph->propagate_resize(); + } compute_graph->execute(); for (size_t i = 0; i < compute_graph->outputs().size(); i++) { + resize_output(compute_graph, i, args[num_inputs + i]->toTensor()); // args holds inputs directly followed by outputs, so the i'th output // for compute_graph corresponds to the (i + num_inputs)'th arg compute_graph->copy_from_staging( - compute_graph->outputs()[i], + compute_graph->outputs()[i].staging, args[num_inputs + i]->toTensor().mutable_data_ptr(), args[num_inputs + i]->toTensor().numel()); } diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 6583d4a5a3e..262dd64a313 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -135,10 +135,10 @@ ValueRef ComputeGraph::set_input_tensor( vTensor& tensor = get_val(idx).toTensor(); ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel()); add_staging_to_tensor_node(*this, staging_idx, idx); - inputs_.push_back(staging_idx); + inputs_.push_back({idx, staging_idx}); return staging_idx; } - inputs_.push_back(idx); + inputs_.push_back({idx, -1}); return idx; } @@ -149,10 +149,10 @@ ValueRef ComputeGraph::set_output_tensor( vTensor& tensor = get_val(idx).toTensor(); ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel()); add_tensor_to_staging_node(*this, idx, staging_idx); - outputs_.push_back(staging_idx); + outputs_.push_back({idx, staging_idx}); return staging_idx; } - outputs_.push_back(idx); + outputs_.push_back({idx, -1}); return idx; } @@ -241,6 +241,19 @@ void ComputeGraph::execute() const { fence.wait(); } +void ComputeGraph::resize_input( + const int64_t idx, + const std::vector& new_sizes) { + IOValueRef io_val = inputs_.at(idx); + get_val(io_val.value).toTensor().virtual_resize(new_sizes); +} + +void ComputeGraph::propagate_resize() { + for (std::unique_ptr& node : execute_nodes_) { + node->trigger_resize(this); + } +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 1253111150d..47c45f574e7 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -68,8 +68,8 @@ class ComputeGraph final { std::vector> prepack_nodes_; std::vector> execute_nodes_; - std::vector inputs_; - std::vector outputs_; + std::vector inputs_; + std::vector outputs_; public: // @@ -80,11 +80,11 @@ class ComputeGraph final { return context_.get(); } - inline std::vector& inputs() { + inline std::vector& inputs() { return inputs_; } - inline std::vector& outputs() { + inline std::vector& outputs() { return outputs_; } @@ -201,6 +201,13 @@ class ComputeGraph final { void encode_execute(); void execute() const; + + // + // Dynamic Shape support + // + + void resize_input(const int64_t idx, const std::vector& new_sizes); + void propagate_resize(); }; template diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp index 496a94238b4..586ef3ef4e6 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp @@ -22,12 +22,16 @@ ExecuteNode::ExecuteNode( const api::utils::uvec3& global_workgroup_size, const api::utils::uvec3& local_workgroup_size, const std::vector& args, - const std::vector>& params) + std::vector> params, + const std::vector& extra_args, + const ResizeFunction& resize_fn) : shader_(shader), global_workgroup_size_(global_workgroup_size), local_workgroup_size_(local_workgroup_size), args_(args), - params_(params) { + params_(params), + extra_args_(extra_args), + resize_fn_(resize_fn) { graph.update_descriptor_counts(shader, /*execute = */ true); } diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h index 5e3a1e003b8..52b0c738f28 100644 --- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h +++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h @@ -47,18 +47,31 @@ class ExecuteNode final { friend class ComputeGraph; public: + using ResizeFunction = const std::function&, + const std::vector&)>; + ExecuteNode( ComputeGraph& graph, const api::ShaderInfo& shader, const api::utils::uvec3& global_workgroup_size, const api::utils::uvec3& local_workgroup_size, const std::vector& args, - const std::vector>& params); + std::vector> params, + const std::vector& extra_args = {}, + const ResizeFunction& resize_fn = nullptr); ~ExecuteNode() = default; void encode(ComputeGraph* graph); + inline void trigger_resize(ComputeGraph* graph) { + if (resize_fn_ != nullptr) { + resize_fn_(graph, args_, extra_args_); + } + } + protected: const api::ShaderInfo shader_; const api::utils::uvec3 global_workgroup_size_; @@ -66,6 +79,8 @@ class ExecuteNode final { const std::vector args_; // TODO(T180906457): allow re-computing param buffers. std::vector> params_; + const std::vector extra_args_; + const ResizeFunction resize_fn_; }; } // namespace vulkan diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 887b529b208..5fc5bdeb3de 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -23,6 +23,26 @@ std::string get_arithmetic_shader_name(const std::string& op_name) { return "arithmetic_" + op_name; } +void resize_arithmetic_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensor& out = graph->get_val(args[0].refs[0]).toTensor(); + vTensor& self = graph->get_val(args[1].refs[0]).toTensor(); + vTensor& other = graph->get_val(args[1].refs[1]).toTensor(); + + std::vector new_out_sizes( + std::max(self.sizes().size(), other.sizes().size())); + + for (int i = -1; i >= -new_out_sizes.size(); --i) { + new_out_sizes[new_out_sizes.size() + i] = std::max( + api::utils::val_at(i, self.sizes()), + api::utils::val_at(i, other.sizes())); + } + + out.virtual_resize(new_out_sizes); +} + void add_arithmetic_node( ComputeGraph& graph, const ValueRef in1, @@ -56,12 +76,17 @@ void add_arithmetic_node( VK_KERNEL_FROM_STR(kernel_name.str()), global_size, local_size, + // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, + // Shader params buffers {t_out.gpu_sizes_ubo(), t_in1.gpu_sizes_ubo(), t_in2.gpu_sizes_ubo(), - graph.create_params_buffer(alpha_val)})); + graph.create_params_buffer(alpha_val)}, + // Resizing + {alpha}, + resize_arithmetic_node)); } #define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \ diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index d4f58062b17..eee2c3e823f 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -143,6 +143,7 @@ def define_common_targets(): ":vk_delegate_schema", ":vulkan_graph_runtime", "//executorch/runtime/backend:interface", + "//executorch/runtime/core/exec_aten/util:tensor_util", ], define_static_target = False, # VulkanBackend.cpp needs to compile with executor as whole diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 8a491497c31..98c48369661 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -14,7 +14,7 @@ from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend from executorch.exir import EdgeProgramManager, to_edge -from torch.export import export, ExportedProgram +from torch.export import Dim, export, ExportedProgram ctypes.CDLL("libvulkan.so.1") @@ -54,13 +54,17 @@ def lower_module_and_test_output( sample_inputs: Tuple[torch.Tensor], atol=1e-03, rtol=1e-01, + dynamic_shapes=None, + test_inputs=None, ): """ Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with the given sample inputs. It then runs the lowered module and compares its outputs with the outputs of the eager module. """ - program: ExportedProgram = export(model, sample_inputs) + program: ExportedProgram = export( + model, sample_inputs, dynamic_shapes=dynamic_shapes + ) edge_program: EdgeProgramManager = to_edge(program) edge_program = edge_program.to_backend(VulkanPartitioner()) @@ -80,6 +84,19 @@ def lower_module_and_test_output( self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol) + if test_inputs is not None: + for test_input in test_inputs: + # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`. + test_inputs_flattened, _ = tree_flatten(test_input) + model_output = executorch_module.run_method( + "forward", tuple(test_inputs_flattened) + ) + ref_output = model(*test_input) + + self.assert_outputs_equal( + model_output, ref_output, atol=atol, rtol=rtol + ) + def test_vulkan_backend_add(self): # This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts class AddModule(torch.nn.Module): @@ -251,3 +268,38 @@ def forward(self, x): model_inputs = (torch.rand(size=(2, 10), dtype=torch.float32),) self.lower_module_and_test_output(model, model_inputs) + + def test_vulkan_backend_partial_dynamic_shapes(self): + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.branch1 = torch.nn.Sequential( + torch.nn.Linear(64, 64), torch.nn.ReLU() + ) + self.branch2 = torch.nn.Sequential( + torch.nn.Linear(128, 64), torch.nn.ReLU() + ) + self.buffer_1 = torch.ones((1, 64)) * 0.5 + self.buffer_2 = torch.ones((1, 64)) * 1.4 + + def forward(self, x1, x2): + out1 = self.branch1(x1) + out2 = self.branch2(x2) + return (out1 + self.buffer_1 + out2) * self.buffer_2 + + model = SimpleModel() + model_inputs = (torch.randn(32, 64), torch.randn(32, 128)) + batch = Dim("batch", max=124) + dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} + + test_inputs = [ + (torch.randn(15, 64), torch.randn(15, 128)), + (torch.randn(6, 64), torch.randn(6, 128)), + (torch.randn(30, 64), torch.randn(30, 128)), + (torch.randn(20, 64), torch.randn(20, 128)), + (torch.randn(19, 64), torch.randn(19, 128)), + ] + + self.lower_module_and_test_output( + model, model_inputs, dynamic_shapes=dynamic_shapes, test_inputs=test_inputs + ) diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 5d4725e5dd5..d1f564920ac 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -695,3 +695,83 @@ TEST(VulkanComputeGraphTest, test_manual_virtual_resize) { } } } + +TEST(VulkanComputeGraphTest, test_resize) { + GraphConfig config; + ComputeGraph graph(config); + + std::vector size_big = {12, 64, 64}; + std::vector size_small = {12, 64, 64}; + + // Build graph + + IOValueRef a = graph.add_input_tensor( + size_big, + api::kFloat, + /*shared_object_idx = */ 2); + IOValueRef b = graph.add_input_tensor( + size_small, + api::kFloat, + /*shared_object_idx = */ 4); + + ValueRef c = graph.add_tensor( + size_big, + api::kFloat, + /*shared_object_idx = */ 6); + + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, c}); + + IOValueRef d = graph.add_input_tensor( + size_small, + api::kFloat, + /*shared_object_idx = */ 2); + + ValueRef e = graph.add_tensor( + size_big, + api::kFloat, + /*shared_object_idx = */ 4); + + auto mulFn = VK_GET_OP_FN("aten.mul.Tensor"); + mulFn(graph, {c, d.value, e}); + + IOValueRef out = {}; + out.value = e; + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + // Run graph + + std::vector> new_sizes_list = { + {8, 44, 34}, {4, 13, 56}, {8, 12, 64}, {12, 55, 33}, {4, 54, 10}}; + + for (auto& new_sizes : new_sizes_list) { + graph.resize_input(0, new_sizes); + graph.resize_input(1, new_sizes); + graph.resize_input(2, new_sizes); + graph.propagate_resize(); + + float val_a = new_sizes[1] + 6.0f; + float val_b = new_sizes[2] + 2.5f; + float val_d = new_sizes[0] + 4.0f; + float val_out = (val_a + val_b) * val_d; + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + fill_vtensor(graph, d, val_d); + + // Execute graph + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + int i = 0; + for (const auto& val : data_out) { + ASSERT_TRUE(val == val_out); + ++i; + } + } +} diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 293d114e8d3..27f42d1ec8f 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -22,6 +22,8 @@ from executorch.exir.passes import MemoryPlanningPass, SpecPropPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass + from executorch.exir.program._program import _copy_module from torch import dtype, float32 @@ -46,6 +48,7 @@ def preprocess( # noqa: C901 ) -> PreprocessResult: passes = [ SpecPropPass(), + ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass("greedy"), ]