From 48a21b5e64707ee2c057cf70256b7827f65eebf2 Mon Sep 17 00:00:00 2001
From: Dan Field
Date: Mon, 13 Feb 2023 15:47:49 -0800
Subject: [PATCH] Fix multi-function compute

Move the threadgroup sizing and the dispatchThreadgroups call in
compute_pass_mtl.mm from after the enclosing block to inside it, so the
dispatch is issued from within that block rather than once at the end.

While shrinking the threadgroup to fit maxTotalThreadsPerThreadgroup,
clamp each dimension to at least 1 so that a 1-D grid such as 512x1 is
never halved down to a zero-sized dispatch.

Add the stage1.comp and stage2.comp fixtures plus a
MultiStageInputAndOutput test that chains them in one compute pass:
stage 1 writes two values (x * 2 and x * 3) per input element, and
stage 2 doubles every element of stage 1's output.

---
 impeller/fixtures/BUILD.gn                    |  10 +-
 impeller/fixtures/stage1.comp                 |  29 +++++
 impeller/fixtures/stage2.comp                 |  26 ++++
 .../backend/metal/compute_pass_mtl.mm         |  32 ++---
 impeller/renderer/compute_unittests.cc        | 114 ++++++++++++++++++
 5 files changed, 194 insertions(+), 17 deletions(-)
 create mode 100644 impeller/fixtures/stage1.comp
 create mode 100644 impeller/fixtures/stage2.comp

diff --git a/impeller/fixtures/BUILD.gn b/impeller/fixtures/BUILD.gn
index 6ff7eeddc2c13..a2a73ccc843a5 100644
--- a/impeller/fixtures/BUILD.gn
+++ b/impeller/fixtures/BUILD.gn
@@ -27,13 +27,19 @@ impeller_shaders("shader_fixtures") {
     "mipmaps.frag",
     "mipmaps.vert",
     "sample.comp",
+    "stage1.comp",
+    "stage2.comp",
     "simple.vert",
     "test_texture.frag",
     "test_texture.vert",
   ]
 
   if (impeller_enable_opengles) {
-    gles_exclusions = [ "sample.comp" ]
+    gles_exclusions = [
+      "sample.comp",
+      "stage1.comp",
+      "stage2.comp",
+    ]
   }
 }
 
@@ -77,6 +83,8 @@ test_fixtures("file_fixtures") {
     "sample_with_binding.vert",
     "simple.vert.hlsl",
     "sa%m#ple.vert",
+    "stage1.comp",
+    "stage2.comp",
     "struct_def_bug.vert",
     "table_mountain_nx.png",
     "table_mountain_ny.png",
diff --git a/impeller/fixtures/stage1.comp b/impeller/fixtures/stage1.comp
new file mode 100644
index 0000000000000..9c39e3f5a95a0
--- /dev/null
+++ b/impeller/fixtures/stage1.comp
@@ -0,0 +1,29 @@
+layout(local_size_x = 128) in;
+layout(std430) buffer;
+
+layout(binding = 0) writeonly buffer Output {
+  uint count;
+  uint elements[];
+}
+output_data;
+
+layout(binding = 1) readonly buffer Input {
+  uint count;
+  uint elements[];
+}
+input_data;
+
+void main() {
+  uint ident = gl_GlobalInvocationID.x;
+
+  if (ident >= input_data.count) {
+    return;
+  }
+
+  uint out_slot = ident * 2;
+
+  output_data.count = input_data.count * 2;
+
+  output_data.elements[out_slot] = input_data.elements[ident] * 2;
+  output_data.elements[out_slot + 1] = input_data.elements[ident] * 3;
+}
diff --git a/impeller/fixtures/stage2.comp b/impeller/fixtures/stage2.comp
new file mode 100644
index 0000000000000..ed51278da16f8
--- /dev/null
+++ b/impeller/fixtures/stage2.comp
@@ -0,0 +1,26 @@
+layout(local_size_x = 128) in;
+layout(std430) buffer;
+
+layout(binding = 0) writeonly buffer Output {
+  uint count;
+  uint elements[];
+}
+output_data;
+
+layout(binding = 1) readonly buffer Input {
+  uint count;
+  uint elements[];
+}
+input_data;
+
+void main() {
+  uint ident = gl_GlobalInvocationID.x;
+
+  if (ident >= input_data.count) {
+    return;
+  }
+
+  output_data.count = input_data.count;
+
+  output_data.elements[ident] = input_data.elements[ident] * 2;
+}
diff --git a/impeller/renderer/backend/metal/compute_pass_mtl.mm b/impeller/renderer/backend/metal/compute_pass_mtl.mm
index fd831f35a2aa2..e4755fef47115 100644
--- a/impeller/renderer/backend/metal/compute_pass_mtl.mm
+++ b/impeller/renderer/backend/metal/compute_pass_mtl.mm
@@ -241,23 +241,23 @@ static bool Bind(ComputePassBindingsCache& pass,
         return false;
       }
     }
+    // TODO(dnfield): use feature detection to support non-uniform threadgroup
+    // sizes.
+    // https://github.com/flutter/flutter/issues/110619
+
+    // For now, check that the sizes are uniform.
+    FML_DCHECK(grid_size == thread_group_size);
+    auto width = grid_size.width;
+    auto height = grid_size.height;
+    const auto max_threads = static_cast<int64_t>(
+        pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup);
+    while (width * height > max_threads && width * height > 1) {
+      width = width > 1 ? width / 2 : 1;
+      height = height > 1 ? height / 2 : 1;
+    }
+    auto size = MTLSizeMake(width, height, 1);
+    [encoder dispatchThreadgroups:size threadsPerThreadgroup:size];
   }
-  // TODO(dnfield): use feature detection to support non-uniform threadgroup
-  // sizes.
-  // https://github.com/flutter/flutter/issues/110619
-
-  // For now, check that the sizes are uniform.
-  FML_DCHECK(grid_size == thread_group_size);
-  auto width = grid_size.width;
-  auto height = grid_size.height;
-  while (width * height >
-         static_cast<int64_t>(
-             pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup)) {
-    width /= 2;
-    height /= 2;
-  }
-  auto size = MTLSizeMake(width, height, 1);
-  [encoder dispatchThreadgroups:size threadsPerThreadgroup:size];
 
   return true;
 }
diff --git a/impeller/renderer/compute_unittests.cc b/impeller/renderer/compute_unittests.cc
index 5da7fe133ec6b..4ae43c04fccbf 100644
--- a/impeller/renderer/compute_unittests.cc
+++ b/impeller/renderer/compute_unittests.cc
@@ -5,8 +5,11 @@
 #include "flutter/fml/synchronization/waitable_event.h"
 #include "flutter/fml/time/time_point.h"
 #include "flutter/testing/testing.h"
+#include "gmock/gmock.h"
 #include "impeller/base/strings.h"
 #include "impeller/fixtures/sample.comp.h"
+#include "impeller/fixtures/stage1.comp.h"
+#include "impeller/fixtures/stage2.comp.h"
 #include "impeller/playground/compute_playground_test.h"
 #include "impeller/renderer/command_buffer.h"
 #include "impeller/renderer/compute_command.h"
@@ -102,5 +105,116 @@ TEST_P(ComputeTest, CanCreateComputePass) {
   latch.Wait();
 }
 
+TEST_P(ComputeTest, MultiStageInputAndOutput) {
+  using CS1 = Stage1ComputeShader;
+  using Stage1PipelineBuilder = ComputePipelineBuilder<CS1>;
+  using CS2 = Stage2ComputeShader;
+  using Stage2PipelineBuilder = ComputePipelineBuilder<CS2>;
+
+  auto context = GetContext();
+  ASSERT_TRUE(context);
+
+  auto pipeline_desc_1 =
+      Stage1PipelineBuilder::MakeDefaultPipelineDescriptor(*context);
+  ASSERT_TRUE(pipeline_desc_1.has_value());
+  auto compute_pipeline_1 =
+      context->GetPipelineLibrary()->GetPipeline(pipeline_desc_1).Get();
+  ASSERT_TRUE(compute_pipeline_1);
+
+  auto pipeline_desc_2 =
+      Stage2PipelineBuilder::MakeDefaultPipelineDescriptor(*context);
+  ASSERT_TRUE(pipeline_desc_2.has_value());
+  auto compute_pipeline_2 =
+      context->GetPipelineLibrary()->GetPipeline(pipeline_desc_2).Get();
+  ASSERT_TRUE(compute_pipeline_2);
+
+  auto cmd_buffer = context->CreateCommandBuffer();
+  auto pass = cmd_buffer->CreateComputePass();
+  ASSERT_TRUE(pass && pass->IsValid());
+
+  static constexpr size_t kCount1 = 5;
+  static constexpr size_t kCount2 = kCount1 * 2;
+
+  pass->SetGridSize(ISize(512, 1));
+  pass->SetThreadGroupSize(ISize(512, 1));
+
+  CS1::Input<kCount1> input_1;
+  input_1.count = kCount1;
+  for (size_t i = 0; i < kCount1; i++) {
+    input_1.elements[i] = i;
+  }
+
+  CS2::Input<kCount2> input_2;
+  input_2.count = kCount2;
+  for (size_t i = 0; i < kCount2; i++) {
+    input_2.elements[i] = i;
+  }
+
+  DeviceBufferDescriptor output_desc_1;
+  output_desc_1.storage_mode = StorageMode::kHostVisible;
+  output_desc_1.size = sizeof(CS1::Output<kCount2>);
+
+  auto output_buffer_1 =
+      context->GetResourceAllocator()->CreateBuffer(output_desc_1);
+  output_buffer_1->SetLabel("Output Buffer Stage 1");
+
+  DeviceBufferDescriptor output_desc_2;
+  output_desc_2.storage_mode = StorageMode::kHostVisible;
+  output_desc_2.size = sizeof(CS2::Output<kCount2>);
+
+  auto output_buffer_2 =
+      context->GetResourceAllocator()->CreateBuffer(output_desc_2);
+  output_buffer_2->SetLabel("Output Buffer Stage 2");
+
+  {
+    ComputeCommand cmd;
+    cmd.label = "Compute1";
+    cmd.pipeline = compute_pipeline_1;
+
+    CS1::BindInput(cmd,
+                   pass->GetTransientsBuffer().EmplaceStorageBuffer(input_1));
+    CS1::BindOutput(cmd, output_buffer_1->AsBufferView());
+
+    ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
+  }
+
+  {
+    ComputeCommand cmd;
+    cmd.label = "Compute2";
+    cmd.pipeline = compute_pipeline_2;
+
+    CS2::BindInput(cmd, output_buffer_1->AsBufferView());
+    CS2::BindOutput(cmd, output_buffer_2->AsBufferView());
+    ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
+  }
+
+  ASSERT_TRUE(pass->EncodeCommands());
+
+  fml::AutoResetWaitableEvent latch;
+  ASSERT_TRUE(cmd_buffer->SubmitCommands([&latch, &output_buffer_1,
+                                          &output_buffer_2](
+                                             CommandBuffer::Status status) {
+    EXPECT_EQ(status, CommandBuffer::Status::kCompleted);
+
+    CS1::Output<kCount2>* output_1 = reinterpret_cast<CS1::Output<kCount2>*>(
+        output_buffer_1->AsBufferView().contents);
+    EXPECT_TRUE(output_1);
+    EXPECT_EQ(output_1->count, 10u);
+    EXPECT_THAT(output_1->elements,
+                ::testing::ElementsAre(0, 0, 2, 3, 4, 6, 6, 9, 8, 12));
+
+    CS2::Output<kCount2>* output_2 = reinterpret_cast<CS2::Output<kCount2>*>(
+        output_buffer_2->AsBufferView().contents);
+    EXPECT_TRUE(output_2);
+    EXPECT_EQ(output_2->count, 10u);
+    EXPECT_THAT(output_2->elements,
+                ::testing::ElementsAre(0, 0, 4, 6, 8, 12, 12, 18, 16, 24));
+
+    latch.Signal();
+  }));
+
+  latch.Wait();
+}
+
 }  // namespace testing
 }  // namespace impeller