From d274861f5a30639bd42af25711780c52f3a193f9 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 24 May 2025 15:05:23 +0200 Subject: [PATCH 01/47] Fix Vulkan interleave SPIRV codegen. Fix a bug in Simplify_Shuffle. Fix a bug in Deinterleave. --- src/CodeGen_Vulkan_Dev.cpp | 24 ++--- src/Deinterleave.cpp | 24 ++++- src/Simplify_Let.cpp | 12 ++- src/Simplify_Shuffle.cpp | 14 ++- src/runtime/vulkan_internal.h | 4 + src/runtime/vulkan_resources.h | 138 ++++++++++++++-------------- test/correctness/vector_shuffle.cpp | 116 ++++++++++++++++++++--- 7 files changed, 229 insertions(+), 103 deletions(-) diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index f65a57005175..896c411d2471 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -2080,31 +2080,21 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) { debug(3) << "\n"; if (arg_ids.size() == 1) { - // 1 argument, just do a simple assignment via a cast SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]); builder.update_id(result_id); } else if (arg_ids.size() == 2) { - - // 2 arguments, use a composite insert to update even and odd indices - uint32_t even_idx = 0; - uint32_t odd_idx = 1; - SpvFactory::Indices even_indices; - SpvFactory::Indices odd_indices; - for (int i = 0; i < op_lanes; ++i) { - even_indices.push_back(even_idx); - odd_indices.push_back(odd_idx); - even_idx += 2; - odd_idx += 2; + // 2 arguments, use vector-shuffle with logical indices indexing into (vec1[0], vec1[1], ..., vec2[0], vec2[1], ...) 
+ SpvFactory::Indices logical_indices; + for (int i = 0; i < arg_lanes; ++i) { + logical_indices.push_back(uint32_t(i)); + logical_indices.push_back(uint32_t(i + arg_lanes)); } SpvId type_id = builder.declare_type(op->type); - SpvId value_id = builder.declare_null_constant(op->type); - SpvId partial_id = builder.reserve_id(SpvResultId); SpvId result_id = builder.reserve_id(SpvResultId); - builder.append(SpvFactory::composite_insert(type_id, partial_id, arg_ids[0], value_id, even_indices)); - builder.append(SpvFactory::composite_insert(type_id, result_id, arg_ids[1], partial_id, odd_indices)); + builder.append(SpvFactory::vector_shuffle(type_id, result_id, arg_ids[0], arg_ids[1], logical_indices)); builder.update_id(result_id); } else { @@ -2134,7 +2124,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) { } else if (op->is_extract_element()) { int idx = op->indices[0]; internal_assert(idx >= 0); - internal_assert(idx <= op->vectors[0].type().lanes()); + internal_assert(idx < op->vectors[0].type().lanes()); if (op->vectors[0].type().is_vector()) { SpvFactory::Indices indices = {(uint32_t)idx}; SpvId type_id = builder.declare_type(op->type); diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index f7a5b5f49aa8..243760e9d050 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -299,6 +299,10 @@ class Deinterleaver : public IRGraphMutator { } else { Type t = op->type.with_lanes(new_lanes); + internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes) + << "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane + << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly." 
+ << " Deinterleaver probably recursed too deep into types of different lane count."; if (external_lets.contains(op->name) && starting_lane == 0 && lane_stride == 2) { @@ -393,8 +397,12 @@ class Deinterleaver : public IRGraphMutator { int index = indices.front(); for (const auto &i : op->vectors) { if (index < i.type().lanes()) { - ScopedValue lane(starting_lane, index); - return mutate(i); + if (i.type().lanes() == op->type.lanes()) { + ScopedValue scoped_starting_lane(starting_lane, index); + return mutate(i); + } else { + return Shuffle::make(op->vectors, indices); + } } index -= i.type().lanes(); } @@ -406,10 +414,18 @@ class Deinterleaver : public IRGraphMutator { }; Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) { + debug(3) << "Deinterleave " + << "(start:" << starting_lane << ", stide:" << lane_stride << ", new_lanes:" << new_lanes << "): " + << e << " of Type: " << e.type() << "\n"; + Type original_type = e.type(); e = substitute_in_all_lets(e); Deinterleaver d(starting_lane, lane_stride, new_lanes, lets); e = d.mutate(e); e = common_subexpression_elimination(e); + Type final_type = e.type(); + int expected_lanes = (original_type.lanes() + lane_stride - starting_lane - 1) / lane_stride; + internal_assert(original_type.code() == final_type.code()) << "Underlying types not identical after interleaving."; + internal_assert(expected_lanes == final_type.lanes()) << "Number of lanes incorrect after interleaving: " << final_type.lanes() << "while expected was " << expected_lanes << "."; return simplify(e); } @@ -420,12 +436,12 @@ Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) { Expr extract_even_lanes(const Expr &e, const Scope<> &lets) { internal_assert(e.type().lanes() % 2 == 0); - return deinterleave(e, 0, 2, (e.type().lanes() + 1) / 2, lets); + return deinterleave(e, 0, 2, e.type().lanes() / 2, lets); } Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) { 
internal_assert(e.type().lanes() % 3 == 0); - return deinterleave(e, lane, 3, (e.type().lanes() + 2) / 3, lets); + return deinterleave(e, lane, 3, e.type().lanes() / 3, lets); } } // namespace diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 1c60e7a2510d..801163215dc9 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -98,7 +98,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { Expr new_var = Variable::make(f.new_value.type(), f.new_name); Expr replacement = new_var; - debug(4) << "simplify let " << op->name << " = " << f.value << " in...\n"; + debug(4) << "simplify let " << op->name << " = (" << f.value.type() << ") " << f.value << " in...\n"; while (true) { const Variable *var = f.new_value.template as(); @@ -180,6 +180,16 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { f.new_value = cast->value; new_var = Variable::make(f.new_value.type(), f.new_name); replacement = substitute(f.new_name, Cast::make(cast->type, new_var), replacement); + } else if (shuffle && shuffle->is_concat() && is_pure(shuffle)) { + // Substitute in all concatenates as they will likely simplify + // with other shuffles. + // As the structure of this while loop makes it hard to peel off + // pure operations from _all_ arguments to the Shuffle, we will + // instead subsitute all of the vars that go in the shuffle, and + // instead guard against side effects by checking with `is_pure()`. + replacement = substitute(f.new_name, shuffle, replacement); + f.new_value = Expr(); + break; } else if (shuffle && shuffle->is_slice()) { // Replacing new_value below might free the shuffle // indices vector, so save them now. 
diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index aecb4c6fc99a..cf8d1f03317d 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -289,13 +289,18 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { if (inner_shuffle->is_concat()) { int slice_min = op->indices.front(); int slice_max = op->indices.back(); + if (slice_min > slice_max) { + // Slices can go backward. + std::swap(slice_min, slice_max); + } int concat_index = 0; int new_slice_start = -1; vector new_concat_vectors; for (const auto &v : inner_shuffle->vectors) { // Check if current concat vector overlaps with slice. - if ((concat_index >= slice_min && concat_index <= slice_max) || - ((concat_index + v.type().lanes() - 1) >= slice_min && (concat_index + v.type().lanes() - 1) <= slice_max)) { + int overlap_max = std::min(slice_max, concat_index + v.type().lanes() - 1); + int overlap_min = std::max(slice_min, concat_index); + if (overlap_min <= overlap_max) { if (new_slice_start < 0) { new_slice_start = concat_index; } @@ -305,7 +310,10 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { concat_index += v.type().lanes(); } if (new_concat_vectors.size() < inner_shuffle->vectors.size()) { - return Shuffle::make_slice(Shuffle::make_concat(new_concat_vectors), op->slice_begin() - new_slice_start, op->slice_stride(), op->indices.size()); + return Shuffle::make_slice(Shuffle::make_concat(new_concat_vectors), + op->slice_begin() - new_slice_start, + op->slice_stride(), + op->indices.size()); } } } diff --git a/src/runtime/vulkan_internal.h b/src/runtime/vulkan_internal.h index e9d345b6d403..db74739e20da 100644 --- a/src/runtime/vulkan_internal.h +++ b/src/runtime/vulkan_internal.h @@ -280,6 +280,8 @@ const char *vk_get_error_name(VkResult error) { return "VK_ERROR_FORMAT_NOT_SUPPORTED"; case VK_ERROR_FRAGMENTED_POOL: return "VK_ERROR_FRAGMENTED_POOL"; + case VK_ERROR_UNKNOWN: + return "VK_ERROR_UNKNOWN"; case VK_ERROR_SURFACE_LOST_KHR: return 
"VK_ERROR_SURFACE_LOST_KHR"; case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR: @@ -303,6 +305,8 @@ const char *vk_get_error_name(VkResult error) { } } +#define vk_report_error(user_context, code, func) (error((user_context)) << "Vulkan: " << (func) << " returned " << vk_get_error_name((code)) << " (code: " << (code) << ") ") + // -------------------------------------------------------------------------- } // namespace diff --git a/src/runtime/vulkan_resources.h b/src/runtime/vulkan_resources.h index 95d1ae8d4a9b..a4c4f3b0d7e5 100644 --- a/src/runtime/vulkan_resources.h +++ b/src/runtime/vulkan_resources.h @@ -85,7 +85,7 @@ int vk_create_command_pool(void *user_context, VulkanMemoryAllocator *allocator, debug(user_context) << " vk_create_command_pool (user_context: " << user_context << ", " << "allocator: " << (void *)allocator << ", " - << "queue_index: " << queue_index << ")\n"; + << "queue_index: " << queue_index << ")"; #endif if (allocator == nullptr) { @@ -103,7 +103,7 @@ int vk_create_command_pool(void *user_context, VulkanMemoryAllocator *allocator, VkResult result = vkCreateCommandPool(allocator->current_device(), &command_pool_info, allocator->callbacks(), command_pool); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: Failed to create command pool!\n"; + vk_report_error(user_context, result, "vkCreateCommandPool"); return halide_error_code_generic_error; } return halide_error_code_success; @@ -117,7 +117,7 @@ int vk_destroy_command_pool(void *user_context, VulkanMemoryAllocator *allocator << "command_pool: " << (void *)command_pool << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy command pool ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy command pool ... 
invalid allocator pointer!"; return halide_error_code_generic_error; } vkResetCommandPool(allocator->current_device(), command_pool, VK_COMMAND_POOL_RESET_RELEASE_RESOURCES_BIT); @@ -135,7 +135,7 @@ int vk_create_command_buffer(void *user_context, VulkanMemoryAllocator *allocato << "command_pool: " << (void *)command_pool << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create command buffer ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create command buffer ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -150,7 +150,7 @@ int vk_create_command_buffer(void *user_context, VulkanMemoryAllocator *allocato VkResult result = vkAllocateCommandBuffers(allocator->current_device(), &command_buffer_info, command_buffer); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: Failed to allocate command buffers!\n"; + vk_report_error(user_context, result, "vkAllocateCommandBuffers"); return halide_error_code_generic_error; } return halide_error_code_success; @@ -165,7 +165,7 @@ int vk_destroy_command_buffer(void *user_context, VulkanMemoryAllocator *allocat << "command_buffer: " << (void *)command_buffer << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy command buffer ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy command buffer ... 
invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -231,7 +231,7 @@ int vk_fill_command_buffer_with_dispatch_call(void *user_context, VkResult result = vkBeginCommandBuffer(command_buffer, &command_buffer_begin_info); if (result != VK_SUCCESS) { - error(user_context) << "vkBeginCommandBuffer returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkBeginCommandBuffer"); return halide_error_code_generic_error; } @@ -242,7 +242,7 @@ int vk_fill_command_buffer_with_dispatch_call(void *user_context, result = vkEndCommandBuffer(command_buffer); if (result != VK_SUCCESS) { - error(user_context) << "vkEndCommandBuffer returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkEndCommandBuffer"); return halide_error_code_generic_error; } @@ -272,7 +272,7 @@ int vk_submit_command_buffer(void *user_context, VkQueue queue, VkCommandBuffer VkResult result = vkQueueSubmit(queue, 1, &submit_info, VK_NULL_HANDLE); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: vkQueueSubmit returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkSubmitQueue"); return halide_error_code_generic_error; } return halide_error_code_success; @@ -325,7 +325,7 @@ int vk_create_descriptor_pool(void *user_context, << "storage_buffer_count: " << (uint32_t)storage_buffer_count << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create descriptor pool ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create descriptor pool ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -362,7 +362,7 @@ int vk_create_descriptor_pool(void *user_context, VkResult result = vkCreateDescriptorPool(allocator->current_device(), &descriptor_pool_info, allocator->callbacks(), descriptor_pool); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: Failed to create descriptor pool! 
vkCreateDescriptorPool returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkCreateDescriptorPool"); return halide_error_code_generic_error; } return halide_error_code_success; @@ -378,7 +378,7 @@ int vk_destroy_descriptor_pool(void *user_context, << "descriptor_pool: " << (void *)descriptor_pool << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy descriptor pool ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy descriptor pool ... invalid allocator pointer!"; return halide_error_code_generic_error; } vkDestroyDescriptorPool(allocator->current_device(), descriptor_pool, allocator->callbacks()); @@ -402,7 +402,7 @@ int vk_create_descriptor_set_layout(void *user_context, << "layout: " << (void *)layout << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create descriptor set layout ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create descriptor set layout ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -460,7 +460,7 @@ int vk_create_descriptor_set_layout(void *user_context, // Create the descriptor set layout VkResult result = vkCreateDescriptorSetLayout(allocator->current_device(), &layout_info, allocator->callbacks(), layout); if (result != VK_SUCCESS) { - error(user_context) << "vkCreateDescriptorSetLayout returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkCreateDescriptorSetLayout"); return halide_error_code_generic_error; } @@ -478,7 +478,7 @@ int vk_destroy_descriptor_set_layout(void *user_context, << "layout: " << (void *)descriptor_set_layout << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy descriptor set layout ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy descriptor set layout ... 
invalid allocator pointer!"; return halide_error_code_generic_error; } vkDestroyDescriptorSetLayout(allocator->current_device(), descriptor_set_layout, allocator->callbacks()); @@ -500,7 +500,7 @@ int vk_create_descriptor_set(void *user_context, << "descriptor_pool: " << (void *)descriptor_pool << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create descriptor set ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create descriptor set ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -515,7 +515,7 @@ int vk_create_descriptor_set(void *user_context, VkResult result = vkAllocateDescriptorSets(allocator->current_device(), &descriptor_set_info, descriptor_set); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: vkAllocateDescriptorSets returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkAllocateDescriptorSets"); return halide_error_code_generic_error; } @@ -541,7 +541,7 @@ int vk_update_descriptor_set(void *user_context, << "descriptor_set: " << (void *)descriptor_set << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create descriptor set ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create descriptor set ... 
invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -599,7 +599,7 @@ int vk_update_descriptor_set(void *user_context, // retrieve the buffer from the region VkBuffer *device_buffer = reinterpret_cast(owner->handle); if (device_buffer == nullptr) { - error(user_context) << "Vulkan: Failed to retrieve buffer for device memory!\n"; + error(user_context) << "Vulkan: Failed to retrieve buffer for device memory!"; return halide_error_code_internal_error; } @@ -698,7 +698,7 @@ MemoryRegion *vk_create_scalar_uniform_buffer(void *user_context, #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create scalar uniform buffer ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create scalar uniform buffer ... invalid allocator pointer!"; return nullptr; } @@ -711,7 +711,7 @@ MemoryRegion *vk_create_scalar_uniform_buffer(void *user_context, // allocate a new region MemoryRegion *region = allocator->reserve(user_context, request); if ((region == nullptr) || (region->handle == nullptr)) { - error(user_context) << "Vulkan: Failed to create scalar uniform buffer ... unable to allocate device memory!\n"; + error(user_context) << "Vulkan: Failed to create scalar uniform buffer ... unable to allocate device memory!"; return nullptr; } @@ -733,19 +733,19 @@ int vk_update_scalar_uniform_buffer(void *user_context, #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to update scalar uniform buffer ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to update scalar uniform buffer ... invalid allocator pointer!"; return halide_error_code_generic_error; } if ((region == nullptr) || (region->handle == nullptr)) { - error(user_context) << "Vulkan: Failed to update scalar uniform buffer ... invalid memory region!\n"; + error(user_context) << "Vulkan: Failed to update scalar uniform buffer ... 
invalid memory region!"; return halide_error_code_internal_error; } // map the region to a host ptr uint8_t *host_ptr = (uint8_t *)allocator->map(user_context, region); if (host_ptr == nullptr) { - error(user_context) << "Vulkan: Failed to update scalar uniform buffer ... unable to map host pointer to device memory!\n"; + error(user_context) << "Vulkan: Failed to update scalar uniform buffer ... unable to map host pointer to device memory!"; return halide_error_code_internal_error; } @@ -798,7 +798,7 @@ int vk_destroy_scalar_uniform_buffer(void *user_context, VulkanMemoryAllocator * << "scalar_args_region: " << (void *)scalar_args_region << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy scalar uniform buffer ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy scalar uniform buffer ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -832,7 +832,7 @@ int vk_create_pipeline_layout(void *user_context, << "pipeline_layout: " << (void *)pipeline_layout << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create pipeline layout ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create pipeline layout ... 
invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -841,7 +841,7 @@ int vk_create_pipeline_layout(void *user_context, if (descriptor_set_count > max_bound_descriptor_sets) { error(user_context) << "Vulkan: Number of descriptor sets for pipeline layout exceeds the number that can be bound by device!\n" << " requested: " << descriptor_set_count << "," - << " available: " << max_bound_descriptor_sets << "\n"; + << " available: " << max_bound_descriptor_sets; return halide_error_code_incompatible_device_interface; } } @@ -858,7 +858,7 @@ int vk_create_pipeline_layout(void *user_context, VkResult result = vkCreatePipelineLayout(allocator->current_device(), &pipeline_layout_info, allocator->callbacks(), pipeline_layout); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: vkCreatePipelineLayout returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkCreatePipelineLayout"); return halide_error_code_generic_error; } return halide_error_code_success; @@ -876,7 +876,7 @@ int vk_destroy_pipeline_layout(void *user_context, #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy pipeline layout ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy pipeline layout ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -898,11 +898,12 @@ int vk_create_compute_pipeline(void *user_context, debug(user_context) << " vk_create_compute_pipeline (user_context: " << user_context << ", " << "allocator: " << (void *)allocator << ", " + << "pipeline_name: " << pipeline_name << ", " << "shader_module: " << (void *)shader_module << ", " << "pipeline_layout: " << (void *)pipeline_layout << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to create compute pipeline ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to create compute pipeline ... 
invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -928,7 +929,10 @@ int vk_create_compute_pipeline(void *user_context, VkResult result = vkCreateComputePipelines(allocator->current_device(), VK_NULL_HANDLE, 1, &compute_pipeline_info, allocator->callbacks(), compute_pipeline); if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: Failed to create compute pipeline! vkCreateComputePipelines returned " << vk_get_error_name(result) << "\n"; + vk_report_error(user_context, result, "vkCreateComputePipeline") + << "failed to create compute pipeline " << pipeline_name << ".\n" + << " (This might be a bug in Halide. To debug this, see the HL_SPIRV_DUMP_FILE environment variable, and use the Khronos validator to make a bug report)"; + return halide_error_code_generic_error; } @@ -955,24 +959,24 @@ int vk_setup_compute_pipeline(void *user_context, #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to setup compute pipeline ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to setup compute pipeline ... invalid allocator pointer!"; return halide_error_code_generic_error; } if (shader_bindings == nullptr) { - error(user_context) << "Vulkan: Failed to setup compute pipeline ... invalid shader bindings!\n"; + error(user_context) << "Vulkan: Failed to setup compute pipeline ... invalid shader bindings!"; return halide_error_code_generic_error; } if (shader_bindings == nullptr) { - error(user_context) << "Vulkan: Failed to setup compute pipeline ... invalid dispatch data!\n"; + error(user_context) << "Vulkan: Failed to setup compute pipeline ... invalid dispatch data!"; return halide_error_code_generic_error; } VkResult result = VK_SUCCESS; const char *entry_point_name = shader_bindings->entry_point_name; if (entry_point_name == nullptr) { - error(user_context) << "Vulkan: Failed to setup compute pipeline ... 
missing entry point name!\n"; + error(user_context) << "Vulkan: Failed to setup compute pipeline ... missing entry point name!"; return halide_error_code_generic_error; } @@ -995,7 +999,7 @@ int vk_setup_compute_pipeline(void *user_context, } else { // dynamic allocation if (shared_mem_constant_id > 0) { - error(user_context) << "Vulkan: Multiple dynamic shared memory allocations found! Only one is suported!!\n"; + error(user_context) << "Vulkan: Multiple dynamic shared memory allocations found! Only one is suported!!"; result = VK_ERROR_TOO_MANY_OBJECTS; break; } @@ -1028,13 +1032,13 @@ int vk_setup_compute_pipeline(void *user_context, if (static_shared_mem_bytes > device_shared_mem_size) { error(user_context) << "Vulkan: Amount of static shared memory used exceeds device limit!\n" << " requested: " << static_shared_mem_bytes << " bytes," - << " available: " << device_shared_mem_size << " bytes\n"; + << " available: " << device_shared_mem_size << " bytes"; return halide_error_code_incompatible_device_interface; } if (dispatch_data->shared_mem_bytes > device_shared_mem_size) { error(user_context) << "Vulkan: Amount of dynamic shared memory used exceeds device limit!\n" << " requested: " << dispatch_data->shared_mem_bytes << " bytes," - << " available: " << device_shared_mem_size << " bytes\n"; + << " available: " << device_shared_mem_size << " bytes"; return halide_error_code_incompatible_device_interface; } } @@ -1065,14 +1069,14 @@ int vk_setup_compute_pipeline(void *user_context, } } if (found_index == invalid_index) { - error(user_context) << "Vulkan: Failed to locate dispatch constant index for shader binding!\n"; + error(user_context) << "Vulkan: Failed to locate dispatch constant index for shader binding!"; result = VK_ERROR_INITIALIZATION_FAILED; } } // don't even attempt to create the pipeline layout if we encountered errors in the shader binding if (result != VK_SUCCESS) { - error(user_context) << "Vulkan: Failed to decode shader bindings! 
" << vk_get_error_name(result) << "\n"; + error(user_context) << "Vulkan: Failed to decode shader bindings! " << vk_get_error_name(result); return halide_error_code_generic_error; } @@ -1100,7 +1104,7 @@ int vk_setup_compute_pipeline(void *user_context, if (shader_bindings->compute_pipeline) { int error_code = vk_destroy_compute_pipeline(user_context, allocator, shader_bindings->compute_pipeline); if (error_code != halide_error_code_success) { - error(user_context) << "Vulkan: Failed to destroy compute pipeline!\n"; + error(user_context) << "Vulkan: Failed to destroy compute pipeline!"; return halide_error_code_generic_error; } shader_bindings->compute_pipeline = VK_NULL_HANDLE; @@ -1108,7 +1112,7 @@ int vk_setup_compute_pipeline(void *user_context, int error_code = vk_create_compute_pipeline(user_context, allocator, entry_point_name, shader_module, pipeline_layout, &specialization_info, &(shader_bindings->compute_pipeline)); if (error_code != halide_error_code_success) { - error(user_context) << "Vulkan: Failed to create compute pipeline!\n"; + error(user_context) << "Vulkan: Failed to create compute pipeline!"; return error_code; } @@ -1118,7 +1122,7 @@ int vk_setup_compute_pipeline(void *user_context, if (shader_bindings->compute_pipeline == VK_NULL_HANDLE) { int error_code = vk_create_compute_pipeline(user_context, allocator, entry_point_name, shader_module, pipeline_layout, nullptr, &(shader_bindings->compute_pipeline)); if (error_code != halide_error_code_success) { - error(user_context) << "Vulkan: Failed to create compute pipeline!\n"; + error(user_context) << "Vulkan: Failed to create compute pipeline!"; return error_code; } } @@ -1138,7 +1142,7 @@ int vk_destroy_compute_pipeline(void *user_context, << "compute_pipeline: " << (void *)compute_pipeline << ")\n"; #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to destroy compute pipeline ... 
invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to destroy compute pipeline ... invalid allocator pointer!"; return halide_error_code_generic_error; } @@ -1160,12 +1164,12 @@ VulkanShaderBinding *vk_decode_shader_bindings(void *user_context, VulkanMemoryA #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to decode shader bindings ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to decode shader bindings ... invalid allocator pointer!"; return nullptr; } if ((module_ptr == nullptr) || (module_size < (2 * sizeof(uint32_t)))) { - error(user_context) << "Vulkan: Failed to decode shader bindings ... invalid module buffer!\n"; + error(user_context) << "Vulkan: Failed to decode shader bindings ... invalid module buffer!"; return nullptr; } @@ -1213,7 +1217,7 @@ VulkanShaderBinding *vk_decode_shader_bindings(void *user_context, VulkanMemoryA uint32_t idx = 1; // skip past the header_word_count uint32_t shader_count = module_ptr[idx++]; if (shader_count < 1) { - error(user_context) << "Vulkan: Failed to decode shader bindings ... no descriptors found!\n"; + error(user_context) << "Vulkan: Failed to decode shader bindings ... no descriptors found!"; return nullptr; // no descriptors } @@ -1222,7 +1226,7 @@ VulkanShaderBinding *vk_decode_shader_bindings(void *user_context, VulkanMemoryA size_t shader_bindings_size = shader_count * sizeof(VulkanShaderBinding); VulkanShaderBinding *shader_bindings = (VulkanShaderBinding *)vk_host_malloc(user_context, shader_bindings_size, 0, alloc_scope, allocator->callbacks()); if (shader_bindings == nullptr) { - error(user_context) << "Vulkan: Failed to allocate shader_bindings! Out of memory!\n"; + error(user_context) << "Vulkan: Failed to allocate shader_bindings! 
Out of memory!"; return nullptr; } memset(shader_bindings, 0, shader_bindings_size); @@ -1255,7 +1259,7 @@ VulkanShaderBinding *vk_decode_shader_bindings(void *user_context, VulkanMemoryA size_t specialization_constants_size = specialization_constants_count * sizeof(VulkanSpecializationConstant); specialization_constants = (VulkanSpecializationConstant *)vk_host_malloc(user_context, specialization_constants_size, 0, alloc_scope, allocator->callbacks()); if (specialization_constants == nullptr) { - error(user_context) << "Vulkan: Failed to allocate specialization_constants! Out of memory!\n"; + error(user_context) << "Vulkan: Failed to allocate specialization_constants! Out of memory!"; return nullptr; } memset(specialization_constants, 0, specialization_constants_size); @@ -1291,7 +1295,7 @@ VulkanShaderBinding *vk_decode_shader_bindings(void *user_context, VulkanMemoryA size_t shared_memory_allocations_size = shared_memory_allocations_count * sizeof(VulkanSharedMemoryAllocation); shared_memory_allocations = (VulkanSharedMemoryAllocation *)vk_host_malloc(user_context, shared_memory_allocations_size, 0, alloc_scope, allocator->callbacks()); if (shared_memory_allocations == nullptr) { - error(user_context) << "Vulkan: Failed to allocate shared_memory_allocations! Out of memory!\n"; + error(user_context) << "Vulkan: Failed to allocate shared_memory_allocations! Out of memory!"; return nullptr; } memset(shared_memory_allocations, 0, shared_memory_allocations_size); @@ -1356,7 +1360,7 @@ VulkanShaderBinding *vk_decode_shader_bindings(void *user_context, VulkanMemoryA #endif shader_bindings[n].entry_point_name = (char *)vk_host_malloc(user_context, entry_point_name_length * sizeof(uint32_t), 0, alloc_scope, allocator->callbacks()); if (shader_bindings[n].entry_point_name == nullptr) { - error(user_context) << "Vulkan: Failed to allocate entry_point_name! Out of memory!\n"; + error(user_context) << "Vulkan: Failed to allocate entry_point_name! 
Out of memory!"; return nullptr; } @@ -1408,7 +1412,7 @@ int vk_validate_shader_for_device(void *user_context, VulkanMemoryAllocator *all if (static_shared_mem_bytes > device_shared_mem_size) { error(user_context) << "Vulkan: Amount of static shared memory used exceeds device limit!\n" << " requested: " << static_shared_mem_bytes << " bytes," - << " available: " << device_shared_mem_size << " bytes\n"; + << " available: " << device_shared_mem_size << " bytes"; return halide_error_code_incompatible_device_interface; } } @@ -1420,7 +1424,7 @@ int vk_validate_shader_for_device(void *user_context, VulkanMemoryAllocator *all if (shader_count > max_descriptors) { error(user_context) << "Vulkan: Number of required descriptor sets exceeds the amount available for device!\n" << " requested: " << shader_count << "," - << " available: " << max_descriptors << "\n"; + << " available: " << max_descriptors; return halide_error_code_incompatible_device_interface; } } @@ -1516,7 +1520,7 @@ VulkanCompilationCacheEntry *vk_compile_kernel_module(void *user_context, Vulkan // Compile the "SPIR-V Module" for the kernel cache_entry->compiled_modules[i] = vk_compile_shader_module(user_context, allocator, (const char *)spirv_ptr, (int)spirv_size); if (cache_entry->compiled_modules[i] == nullptr) { - debug(user_context) << "Vulkan: Failed to compile shader module!\n"; + debug(user_context) << "Vulkan: Failed to compile shader module!"; error_code = halide_error_code_generic_error; } @@ -1556,12 +1560,12 @@ VulkanCompiledShaderModule *vk_compile_shader_module(void *user_context, VulkanM #endif if (allocator == nullptr) { - error(user_context) << "Vulkan: Failed to compile shader modules ... invalid allocator pointer!\n"; + error(user_context) << "Vulkan: Failed to compile shader modules ... invalid allocator pointer!"; return nullptr; } if ((ptr == nullptr) || (size <= 0)) { - error(user_context) << "Vulkan: Failed to compile shader modules ... 
invalid program source buffer!\n"; + error(user_context) << "Vulkan: Failed to compile shader modules ... invalid program source buffer!"; return nullptr; } @@ -1599,7 +1603,7 @@ VulkanCompiledShaderModule *vk_compile_shader_module(void *user_context, VulkanM VkSystemAllocationScope alloc_scope = VkSystemAllocationScope::VK_SYSTEM_ALLOCATION_SCOPE_OBJECT; VulkanCompiledShaderModule *compiled_module = (VulkanCompiledShaderModule *)vk_host_malloc(user_context, sizeof(VulkanCompiledShaderModule), 0, alloc_scope, allocator->callbacks()); if (compiled_module == nullptr) { - error(user_context) << "Vulkan: Failed to allocate compilation cache entry! Out of memory!\n"; + error(user_context) << "Vulkan: Failed to allocate compilation cache entry! Out of memory!"; return nullptr; } memset(compiled_module, 0, sizeof(VulkanCompiledShaderModule)); @@ -1607,7 +1611,7 @@ VulkanCompiledShaderModule *vk_compile_shader_module(void *user_context, VulkanM // decode the entry point data and extract the shader bindings VulkanShaderBinding *decoded_bindings = vk_decode_shader_bindings(user_context, allocator, module_ptr, module_size); if (decoded_bindings == nullptr) { - error(user_context) << "Vulkan: Failed to decode shader bindings!\n"; + error(user_context) << "Vulkan: Failed to decode shader bindings!"; return nullptr; } @@ -1624,8 +1628,8 @@ VulkanCompiledShaderModule *vk_compile_shader_module(void *user_context, VulkanM compiled_module->shader_count = shader_count; VkResult result = vkCreateShaderModule(allocator->current_device(), &shader_info, allocator->callbacks(), &compiled_module->shader_module); - if ((result != VK_SUCCESS)) { - error(user_context) << "Vulkan: vkCreateShaderModule Failed! 
Error returned: " << vk_get_error_name(result) << "\n"; + if (result != VK_SUCCESS) { + vk_report_error(user_context, result, "vkCreateShaderModule"); vk_host_free(user_context, compiled_module->shader_bindings, allocator->callbacks()); vk_host_free(user_context, compiled_module, allocator->callbacks()); return nullptr; @@ -1635,7 +1639,7 @@ VulkanCompiledShaderModule *vk_compile_shader_module(void *user_context, VulkanM if (compiled_module->shader_count) { compiled_module->descriptor_set_layouts = (VkDescriptorSetLayout *)vk_host_malloc(user_context, compiled_module->shader_count * sizeof(VkDescriptorSetLayout), 0, alloc_scope, allocator->callbacks()); if (compiled_module->descriptor_set_layouts == nullptr) { - error(user_context) << "Vulkan: Failed to allocate descriptor set layouts for cache entry! Out of memory!\n"; + error(user_context) << "Vulkan: Failed to allocate descriptor set layouts for cache entry! Out of memory!"; return nullptr; } memset(compiled_module->descriptor_set_layouts, 0, compiled_module->shader_count * sizeof(VkDescriptorSetLayout)); @@ -1808,7 +1812,7 @@ int vk_do_multidimensional_copy(void *user_context, VkCommandBuffer command_buff VkBuffer *src_buffer = reinterpret_cast(c.src); VkBuffer *dst_buffer = reinterpret_cast(c.dst); if (!src_buffer || !dst_buffer) { - error(user_context) << "Vulkan: Failed to retrieve buffer for device memory!\n"; + error(user_context) << "Vulkan: Failed to retrieve buffer for device memory!"; return halide_error_code_internal_error; } @@ -1846,7 +1850,7 @@ int vk_device_crop_from_offset(void *user_context, VulkanContext ctx(user_context); if (ctx.error != halide_error_code_success) { - error(user_context) << "Vulkan: Failed to acquire context!\n"; + error(user_context) << "Vulkan: Failed to acquire context!"; return ctx.error; } @@ -1854,15 +1858,15 @@ int vk_device_crop_from_offset(void *user_context, uint64_t t_before = halide_current_time_ns(user_context); #endif - if (byte_offset < 0) { - 
error(user_context) << "Vulkan: Invalid offset for device crop!\n"; + if (byte_offset < 0) { + error(user_context) << "Vulkan: Invalid offset for device crop!"; return halide_error_code_device_crop_failed; } // get the allocated region for the device MemoryRegion *device_region = reinterpret_cast(src->device); if (device_region == nullptr) { - error(user_context) << "Vulkan: Failed to crop region! Invalide device region!\n"; + error(user_context) << "Vulkan: Failed to crop region! Invalide device region!"; return halide_error_code_device_crop_failed; } @@ -1873,7 +1877,7 @@ int vk_device_crop_from_offset(void *user_context, region_indexing.offset = byte_offset / src->type.bytes(); MemoryRegion *cropped_region = ctx.allocator->create_crop(user_context, device_region, region_indexing); if ((cropped_region == nullptr) || (cropped_region->handle == nullptr)) { - error(user_context) << "Vulkan: Failed to crop region! Unable to create memory region!\n"; + error(user_context) << "Vulkan: Failed to crop region! 
Unable to create memory region!"; return halide_error_code_device_crop_failed; } diff --git a/test/correctness/vector_shuffle.cpp b/test/correctness/vector_shuffle.cpp index aff6fcbcddcf..f50a5607f52a 100644 --- a/test/correctness/vector_shuffle.cpp +++ b/test/correctness/vector_shuffle.cpp @@ -1,10 +1,20 @@ #include "Halide.h" +#include +#include #include using namespace Halide; -int main(int argc, char **argv) { - Target target = get_jit_target_from_environment(); +int test_with_indices(const Target &target, const std::vector &indices0, const std::vector &indices1) { + printf("indices0:"); + for (int i : indices0) { + printf(" %d", i); + } + printf(" indices1:"); + for (int i : indices1) { + printf(" %d", i); + } + printf("\n"); Var x{"x"}, y{"y"}; Func f0{"f0"}, f1{"f1"}, g{"g"}; @@ -12,15 +22,6 @@ int main(int argc, char **argv) { f1(x, y) = x * (y + 3); Expr vec1 = Internal::Shuffle::make_concat({f0(x, 0), f0(x, 1), f0(x, 2), f0(x, 3)}); Expr vec2 = Internal::Shuffle::make_concat({f1(x, 4), f1(x, 5), f1(x, 6), f1(x, 7)}); - std::vector indices0; - std::vector indices1; - if (!target.has_gpu_feature() || target.has_feature(Target::Feature::OpenCL) || target.has_feature(Target::Feature::CUDA)) { - indices0 = {3, 1, 6, 7, 2, 4, 0, 5}; - indices1 = {1, 0, 3, 4, 7, 0, 5, 2}; - } else { - indices0 = {3, 1, 6, 7}; - indices1 = {1, 0, 3, 4}; - } Expr shuffle1 = Internal::Shuffle::make({vec1, vec2}, indices0); Expr shuffle2 = Internal::Shuffle::make({vec1, vec2}, indices1); Expr result = shuffle1 * shuffle2; @@ -55,6 +56,99 @@ int main(int argc, char **argv) { return 1; } } + return 0; +} + +int main(int argc, char **argv) { + Target target = get_jit_target_from_environment(); + + int max_vec_size = 4; + if (!target.has_gpu_feature() || target.has_feature(Target::Feature::OpenCL) || target.has_feature(Target::Feature::CUDA)) { + max_vec_size = 8; + } + + for (int vec_size = max_vec_size; vec_size > 1; vec_size /= 2) { + printf("Testing vector size %d...\n", vec_size); 
+ std::vector indices0, indices1; + + // Test 1: All indices: foreward/backward and combined + for (int i = 0; i < vec_size; ++i) { + indices0.push_back(i); // forward + indices1.push_back(vec_size - i - 1); // backward + } + printf(" All indices forward...\n"); + if (test_with_indices(target, indices0, indices0)) { + return 1; + } + printf(" All indices backward...\n"); + if (test_with_indices(target, indices1, indices1)) { + return 1; + } + printf(" All indices mixed forware / backward...\n"); + if (test_with_indices(target, indices0, indices1)) { + return 1; + } + + // Test 2: Shuffled indices (4 repetitions) + for (int r = 0; r < 4; ++r) { + // Shuffle with Fisher-Yates + for (int i = vec_size - 1; i >= 1; --i) { + // indices0 + int idx = std::rand() % (i + 1); + std::swap(indices0[idx], indices0[i]); + // indices1 + idx = std::rand() % (i + 1); + std::swap(indices1[idx], indices1[i]); + } + printf(" Randomly shuffled...\n"); + if (test_with_indices(target, indices0, indices1)) { + return 1; + } + } + + // Test 3: Interleaved + indices0.clear(); + indices1.clear(); + for (int i = 0; i < vec_size / 2; ++i) { + // interleave (A, B) + indices0.push_back(i); + indices0.push_back(i + vec_size / 2); + + // interleave (B, A) + indices1.push_back(i + vec_size / 2); + indices1.push_back(i); + } + printf(" Interleaved...\n"); + if (test_with_indices(target, indices0, indices1)) { + return 1; + } + + // Test 4: Concat (not-really, as the input-vectors are size 4, so only if vec_size == 8, it's a concat) + indices0.clear(); + indices1.clear(); + for (int i = 0; i < vec_size; ++i) { + // concat (A, B) + indices0.push_back(i); + + // concat (B, A) + indices1.push_back((i + vec_size / 2) % (vec_size / 2)); + } + printf(" Concat...\n"); + if (test_with_indices(target, indices0, indices1)) { + return 1; + } + + if (vec_size == 4) { + indices0 = {1, 3, 2, 0}; + indices1 = {2, 3, 1, 0}; + + printf(" Specific index combination, known to have caused problems...\n"); + if 
(test_with_indices(target, indices0, indices1)) { + return 1; + } + } + } + printf("Success!\n"); return 0; } From 4fde93842e3260855a75369cf425e70356778ef1 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 27 May 2025 15:23:16 +0200 Subject: [PATCH 02/47] Vector Legalization Pass. Useful for vectorizing to GPU backends with limited vector lanes. --- .gitignore | 3 + src/CMakeLists.txt | 32 +- src/CSE.cpp | 5 + src/IR.cpp | 4 +- src/IROperator.h | 3 +- src/LegalizeVectors.cpp | 692 ++++++++++++++++++ src/LegalizeVectors.h | 19 + src/Lower.cpp | 5 + src/Simplify_Shuffle.cpp | 5 + src/VectorizeLoops.cpp | 4 +- test/correctness/CMakeLists.txt | 1 + .../metal_long_vectors.cpp} | 2 +- test/correctness/require.cpp | 16 +- test/correctness/specialize.cpp | 26 + test/correctness/vector_shuffle.cpp | 7 +- 15 files changed, 795 insertions(+), 29 deletions(-) create mode 100644 src/LegalizeVectors.cpp create mode 100644 src/LegalizeVectors.h rename test/{error/metal_vector_too_large.cpp => correctness/metal_long_vectors.cpp} (89%) diff --git a/.gitignore b/.gitignore index a08b8e8dd7f3..888235a389d8 100644 --- a/.gitignore +++ b/.gitignore @@ -240,6 +240,9 @@ xcuserdata # NeoVim + clangd .cache +# CCLS +.ccls-cache + # Emacs tags TAGS diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 162bb6f74e4e..fec57b9da9d4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -62,12 +62,14 @@ target_sources( Associativity.h AsyncProducers.h AutoScheduleUtils.h + BoundConstantExtentLoops.h + BoundSmallAllocations.h BoundaryConditions.h Bounds.h BoundsInference.h - BoundConstantExtentLoops.h - BoundSmallAllocations.h Buffer.h + CPlusPlusMangle.h + CSE.h Callable.h CanonicalizeGPUVars.h ClampUnsafeAccesses.h @@ -79,18 +81,16 @@ target_sources( CodeGen_LLVM.h CodeGen_Metal_Dev.h CodeGen_OpenCL_Dev.h - CodeGen_Posix.h CodeGen_PTX_Dev.h + CodeGen_Posix.h CodeGen_PyTorch.h CodeGen_Targets.h CodeGen_Vulkan_Dev.h CodeGen_WebGPU_Dev.h CompilerLogger.h ConciseCasts.h - 
CPlusPlusMangle.h ConstantBounds.h ConstantInterval.h - CSE.h Debug.h DebugArguments.h DebugToFile.h @@ -127,6 +127,13 @@ target_sources( Generator.h HexagonOffload.h HexagonOptimize.h + IR.h + IREquality.h + IRMatch.h + IRMutator.h + IROperator.h + IRPrinter.h + IRVisitor.h ImageParam.h InferArguments.h InjectHostDevBufferCopies.h @@ -135,19 +142,13 @@ target_sources( IntegerDivisionTable.h Interval.h IntrusivePtr.h - IR.h - IREquality.h - IRMatch.h - IRMutator.h - IROperator.h - IRPrinter.h - IRVisitor.h JITModule.h - Lambda.h - Lerp.h LICM.h LLVM_Output.h LLVM_Runtime_Linker.h + Lambda.h + LegalizeVectors.h + Lerp.h LoopCarry.h LoopPartitioningDirective.h Lower.h @@ -173,8 +174,8 @@ target_sources( PurifyIndexMath.h PythonExtensionGen.h Qualify.h - Random.h RDom.h + Random.h Realization.h RealizationOrder.h RebaseLoopsToZero.h @@ -320,6 +321,7 @@ target_sources( IRVisitor.cpp JITModule.cpp Lambda.cpp + LegalizeVectors.cpp Lerp.cpp LICM.cpp LLVM_Output.cpp diff --git a/src/CSE.cpp b/src/CSE.cpp index c2a46d93bc4d..6051e5e9cf62 100644 --- a/src/CSE.cpp +++ b/src/CSE.cpp @@ -33,6 +33,11 @@ bool should_extract(const Expr &e, bool lift_all) { return false; } + if (const Call *c = e.as()) { + // Calls with side effects should not be moved. 
+ return c->is_pure() || c->call_type == Call::Halide; + } + if (lift_all) { return true; } diff --git a/src/IR.cpp b/src/IR.cpp index c82ae4ebd252..6190486c79e0 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -12,7 +12,7 @@ namespace Internal { Expr Cast::make(Type t, Expr v) { internal_assert(v.defined()) << "Cast of undefined\n"; - internal_assert(t.lanes() == v.type().lanes()) << "Cast may not change vector widths\n"; + internal_assert(t.lanes() == v.type().lanes()) << "Cast may not change vector widths: " << v << " of type " << v.type() << " cannot be cast to " << t << "\n"; Cast *node = new Cast; node->type = t; @@ -281,7 +281,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) { Expr Broadcast::make(Expr value, int lanes) { internal_assert(value.defined()) << "Broadcast of undefined\n"; - internal_assert(lanes != 1) << "Broadcast of lanes 1\n"; + internal_assert(lanes != 1) << "Broadcast over 1 lane is not a broadcast\n"; Broadcast *node = new Broadcast; node->type = value.type().with_lanes(lanes * value.type().lanes()); diff --git a/src/IROperator.h b/src/IROperator.h index d6d33a1cf82e..527015770093 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -1278,7 +1278,8 @@ Expr random_int(Expr seed = Expr()); /** Create an Expr that prints out its value whenever it is * evaluated. It also prints out everything else in the arguments - * list, separated by spaces. This can include string literals. */ + * list, separated by spaces. This can include string literals. + * Evaluates to the first argument passed. 
*/ //@{ Expr print(const std::vector &values); diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp new file mode 100644 index 000000000000..4d939ec4a690 --- /dev/null +++ b/src/LegalizeVectors.cpp @@ -0,0 +1,692 @@ +#include "LegalizeVectors.h" +#include "CSE.h" +#include "Deinterleave.h" +#include "DeviceInterface.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "Simplify.h" +#include "Util.h" + +#include + +namespace Halide { +namespace Internal { + +namespace { + +using namespace std; + +int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { + switch (api) { + case DeviceAPI::Metal: + case DeviceAPI::WebGPU: + case DeviceAPI::Vulkan: + case DeviceAPI::D3D12Compute: + return 4; + case DeviceAPI::OpenCL: + return 16; + case DeviceAPI::CUDA: + case DeviceAPI::Host: + return 0; // No max: LLVM based legalization + case DeviceAPI::None: + return parent_max_lanes; + default: + return 0; + } +} + +std::string vec_name(const string &name, int lane_start, int lane_count) { + // return name + ".ls" + std::to_string(lane_start) + ".lc" + std::to_string(lane_count); + return name + ".lanes_" + std::to_string(lane_start) + "_" + std::to_string(lane_start + lane_count - 1); +} + +Expr simplify_shuffle(const Shuffle *op) { + if (op->is_extract_element()) { + if (op->vectors.size() == 1) { + if (op->vectors[0].type().is_scalar()) { + return op->vectors[0]; + } else { + return Expr(op); + } + } else { + // Figure out which element is comes from. + int index = op->indices[0]; + internal_assert(index >= 0); + for (const Expr &vector : op->vectors) { + if (index < vector.type().lanes()) { + if (vector.type().is_scalar()) { + return vector; + } else { + return Shuffle::make_extract_element(vector, index); + } + } + index -= vector.type().lanes(); + } + internal_error << "Index out of bounds."; + } + } + + // Figure out if all extracted lanes come from 1 component. 
+ vector> src_vec_and_lane_idx = op->vector_and_lane_indices(); + bool all_from_the_same = true; + bool is_full_vec = src_vec_and_lane_idx[0].second == 0; + for (int i = 1; i < op->indices.size(); ++i) { + if (src_vec_and_lane_idx[i].first != src_vec_and_lane_idx[0].first) { + all_from_the_same = false; + is_full_vec = false; + break; + } + if (src_vec_and_lane_idx[i].second != i) { + is_full_vec = false; + } + } + if (all_from_the_same) { + const Expr &src_vec = op->vectors[src_vec_and_lane_idx[0].first]; + is_full_vec &= src_vec.type().lanes() == op->indices.size(); + int first_lane_in_src = src_vec_and_lane_idx[0].second; + if (is_full_vec) { + return src_vec; + } else { + const Ramp *ramp = src_vec.as(); + if (ramp && op->is_slice() && op->slice_stride() == 1) { + return simplify(Ramp::make(ramp->base + first_lane_in_src * ramp->stride, ramp->stride, op->indices.size())); + } + vector new_indices; + for (int i = 0; i < op->indices.size(); ++i) { + new_indices.push_back(src_vec_and_lane_idx[i].second); + } + return Shuffle::make({src_vec}, new_indices); + } + } + + return op; +} + +class LiftLetToLetStmt : public IRMutator { +public: + vector lets; + Expr visit(const Let *op) override { + lets.push_back(op); + return mutate(op->body); + } + + Stmt mutate(const Stmt &s) override { + ScopedValue scoped_lets(lets, {}); + Stmt mutated = IRMutator::mutate(s); + for (const Let *let : reverse_view(lets)) { + mutated = LetStmt::make(let->name, let->value, mutated); + } + return mutated; + } + + Expr mutate(const Expr &e) override { + return IRMutator::mutate(e); + } +}; + +class ExtractLanes : public IRMutator { + int lane_start; + int lane_count; + int max_legal_lanes; + + Expr extract_lanes_from_make_struct(const Call *op) { + internal_assert(op); + internal_assert(op->is_intrinsic(Call::make_struct)); + vector args(op->args.size()); + for (int i = 0; i < op->args.size(); ++i) { + args[i] = mutate(op->args[i]); + } + return Call::make(op->type, Call::make_struct, 
args, Call::Intrinsic); + } + + Expr extract_lanes_trace(const Call *op) { + // user_error << "Cannot legalize vectors when tracing is enabled."; + auto event = as_const_int(op->args[6]); + internal_assert(event); + if (*event == halide_trace_load || *event == halide_trace_store) { + debug(3) << "Extracting Trace Lanes: " << Expr(op) << "\n"; + const Expr &func = op->args[0]; + Expr values = extract_lanes_from_make_struct(op->args[1].as()); + Expr coords = extract_lanes_from_make_struct(op->args[2].as()); + const Expr &type_code = op->args[3]; + const Expr &type_bits = op->args[4]; + int type_lanes = *as_const_int(op->args[5]); + const Expr &event = op->args[6]; + const Expr &parent_id = op->args[7]; + const Expr &idx = op->args[8]; + int size = *as_const_int(op->args[9]); + const Expr &tag = op->args[10]; + + int num_vecs = op->args[2].as()->args.size(); + internal_assert(size == type_lanes * num_vecs) << Expr(op); + vector args = { + func, + values, coords, + type_code, type_bits, Expr(lane_count), + event, parent_id, idx, Expr(lane_count * num_vecs), + tag}; + Expr result = Call::make(Int(32), Call::trace, args, Call::Extern); + debug(4) << " => " << result << "\n"; + return result; + } else { + user_warning << "Discarding tracing during vector legalization: " << Expr(op) << "\n"; + } + + // This is feasible: see VectorizeLoops. 
+ return Expr(0); + } + +public: + ExtractLanes(int start, int count, int max_legal) + : lane_start(start), lane_count(count), max_legal_lanes(max_legal) { + } + + Expr visit(const Shuffle *op) override { + vector new_indices; + for (int i = 0; i < lane_count; ++i) { + new_indices.push_back(op->indices[lane_start + i]); + } + Expr result = Shuffle::make(op->vectors, new_indices); + return simplify_shuffle(result.as()); + } + + Expr visit(const Ramp *op) override { + if (lane_count == 1) { + return simplify(op->base + op->stride * lane_start); + } + return simplify(Ramp::make(op->base + op->stride * lane_start, op->stride, lane_count)); + } + + Expr visit(const Broadcast *op) override { + Expr value = op->value; + if (const Call *call = op->value.as()) { + if (call->name == Call::trace) { + value = extract_lanes_trace(call); + } + } + if (lane_count == 1) { + return value; + } else { + return Broadcast::make(value, lane_count); + } + } + + Expr visit(const Variable *op) override { + return Variable::make(op->type.with_lanes(lane_count), vec_name(op->name, lane_start, lane_count)); + } + + Expr visit(const Load *op) override { + return Load::make(op->type.with_lanes(lane_count), + op->name, + mutate(op->index), + op->image, op->param, + mutate(op->predicate), + op->alignment + lane_start); + } + + Expr visit(const Call *op) override { + internal_assert(op->type.lanes() >= lane_start + lane_count); + Expr mutated = op; + std::vector args; + args.reserve(op->args.size()); + for (int i = 0; i < op->args.size(); ++i) { + const Expr &arg = op->args[i]; + internal_assert(arg.type().lanes() == op->type.lanes()) << "Call argument " << arg << " lane count of " << arg.type().lanes() << " does not match op lane count of " << op->type.lanes(); + Expr mutated = mutate(arg); + internal_assert(!mutated.same_as(arg)); + args.push_back(mutated); + } + mutated = Call::make(op->type.with_lanes(lane_count), op->name, args, op->call_type); + return mutated; + } + + Expr visit(const Cast 
*op) override { + return Cast::make(op->type.with_lanes(lane_count), mutate(op->value)); + } + + Expr visit(const Reinterpret *op) override { + Type result_type = op->type.with_lanes(lane_count); + int result_scalar_bits = op->type.element_of().bits(); + int input_scalar_bits = op->value.type().element_of().bits(); + + Expr value = op->value; + // If the bit widths of the scalar elements are the same, it's easy. + if (result_scalar_bits == input_scalar_bits) { + value = mutate(value); + } else { + // Otherwise, there can be two limiting aspects: the input lane count and the resulting lane count. + // In order to construct a correct Reinterpret from a small type to a wider type, we + // will need to produce multiple Reinterprets, all able to hold the lane count of the input + // and concatate the results together. + // Even worse, reinterpreting uint8x8 to uint64 would require intermediate reinterprets + // if the maximul legal vector length is 4. + // + // TODO implement this for all scenarios + internal_error << "Vector legalization for Reinterpret to different bit size per element is " + << "not supported yet: reinterpret<" << result_type << ">(" << value.type() << ")"; + + // int input_lane_start = lane_start * result_scalar_bits / input_scalar_bits; + // int input_lane_count = lane_count * result_scalar_bits / input_scalar_bits; + } + Expr result = Reinterpret::make(result_type, value); + debug(3) << "Legalized " << Expr(op) << " to " << result << "\n"; + return result; + } + + Expr visit(const VectorReduce *op) override { + internal_assert(op->type.lanes() >= lane_start + lane_count); + int vecs_per_reduction = op->value.type().lanes() / op->type.lanes(); + int input_lane_start = vecs_per_reduction * lane_start; + int input_lane_count = vecs_per_reduction * lane_count; + Expr arg = ExtractLanes(input_lane_start, input_lane_count, max_legal_lanes).mutate(op->value); + // This might fail if the extracted lanes reference a non-existing variable! 
+ return VectorReduce::make(op->op, arg, lane_count); + } + + // Small helper to assert the transform did what it's supposed to do. + Expr mutate(const Expr &e) override { + Type original_type = e.type(); + internal_assert(original_type.lanes() >= lane_start + lane_count) + << "Cannot extract lanes " << lane_start << " through " << lane_start + lane_count - 1 + << " when the input type is " << original_type; + Expr result = IRMutator::mutate(e); + Type new_type = result.type(); + internal_assert(new_type.lanes() == lane_count) + << "We didn't correctly legalize " << e << " of type " << original_type << ".\n" + << "Got back: " << result << " of type " << new_type << ", expected " << lane_count << " lanes."; + return result; + } + + Stmt mutate(const Stmt &s) override { + return IRMutator::mutate(s); + } +}; + +class LiftExceedingVectors : public IRMutator { + int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; + + vector> lets; + map replacements; + bool just_in_let_defintion{false}; + int in_strict_float = 0; + +public: + Stmt visit(const For *op) override { + ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); + return IRMutator::visit(op); + } + + template + decltype(LetOrLetStmt::body) visit_let(const LetOrLetStmt *op) { + ScopedValue scoped_just_in_let(just_in_let_defintion, true); + Expr def = mutate(op->value); + auto body = mutate(op->body); + if (def.same_as(op->value) && body.same_as(op->body)) { + return op; + } + return LetOrLetStmt::make(op->name, std::move(def), std::move(body)); + } + + Expr visit(const Let *op) override { + internal_error << "We don't want to process Lets. 
They should have all been converted to LetStmts."; + Expr def; + { + ScopedValue scoped_just_in_let(just_in_let_defintion, true); + def = mutate(op->value); + } + auto body = mutate(op->body); + if (def.same_as(op->value) && body.same_as(op->body)) { + return op; + } + return IRMutator::visit(op); + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + + Stmt visit(const Store *op) override { + bool exceeds_lanecount = max_lanes && op->index.type().lanes() > max_lanes; + if (exceeds_lanecount) { + Expr value = mutate(op->value); + + // Split up in multiple stores + int num_vecs = (op->index.type().lanes() + max_lanes - 1) / max_lanes; + std::vector assignments; + assignments.reserve(num_vecs); + for (int i = 0; i < num_vecs; ++i) { + int lane_start = i * max_lanes; + int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); + Expr rhs = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(value); + Expr index = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->index); + Expr predictate = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->predicate); + assignments.push_back(Store::make( + op->name, std::move(rhs), std::move(index), + op->param, std::move(predictate), op->alignment + lane_start)); + } + return Block::make(assignments); + } + return IRMutator::visit(op); + } + + Expr visit(const Call *op) override { + bool exceeds_lanecount = max_lanes && op->type.lanes() > max_lanes; + if (op->is_intrinsic(Call::strict_float)) { + in_strict_float++; + } + Expr mutated = op; + if (exceeds_lanecount) { + std::vector args; + args.reserve(op->args.size()); + bool changed = false; + for (int i = 0; i < op->args.size(); ++i) { + bool may_extract = true; + if (op->is_intrinsic(Call::require)) { + may_extract &= i < 2; + } + const Expr &arg = op->args[i]; + if (may_extract) { + internal_assert(arg.type().lanes() == op->type.lanes()); + Expr mutated = mutate(arg); + if 
(!mutated.same_as(arg)) { + changed = true; + } + args.push_back(mutated); + } else { + args.push_back(arg); + } + } + if (!changed) { + return op; + } + mutated = Call::make(op->type, op->name, args, op->call_type); + } else { + mutated = IRMutator::visit(op); + } + if (op->is_intrinsic(Call::strict_float)) { + in_strict_float--; + } + return mutated; + } + + Stmt mutate(const Stmt &s) override { + ScopedValue scoped_lets(lets, {}); + ScopedValue scoped_just_in_let(just_in_let_defintion, false); + Stmt mutated = IRMutator::mutate(s); + for (auto &let : reverse_view(lets)) { + // There is no recurse into let.second. This is handled by repeatedly calling this tranform. + mutated = LetStmt::make(let.first, let.second, mutated); + } + return mutated; + } + +#if 0 + Stmt visit(const IfThenElse *op) override { + debug(3) << "Visit IfThenElse: " << Stmt(op) << " with max lanes: " << max_lanes << "\n"; + Expr condition; + decltype(lets) condition_lets; + { + ScopedValue scoped_lets(lets, {}); + condition = mutate(op->condition); + condition_lets = lets; + } + Stmt then_case, else_case; + { + ScopedValue scoped_lets(lets, {}); + then_case = mutate(op->then_case); + for (auto &let : reverse_view(lets)) { + then_case = LetStmt::make(let.first, let.second, then_case); + } + } + { + ScopedValue scoped_lets(lets, {}); + else_case = mutate(op->else_case); + for (auto &let : reverse_view(lets)) { + else_case = LetStmt::make(let.first, let.second, else_case); + } + } + if (condition.same_as(op->condition) && + then_case.same_as(op->then_case) && + else_case.same_as(op->else_case)) { + return op; + } + Stmt mutated = IfThenElse::make(std::move(condition), std::move(then_case), std::move(else_case)); + for (auto &let : reverse_view(lets)) { + mutated = LetStmt::make(let.first, let.second, mutated); + } + return mutated; + } +#endif + + Expr mutate(const Expr &e) override { + bool exceeds_lanecount = max_lanes && e.type().lanes() > max_lanes; + + if (exceeds_lanecount) { + bool 
should_extract = true; + should_extract &= e.node_type() != IRNodeType::Variable; + should_extract &= e.node_type() != IRNodeType::Let; + should_extract &= e.node_type() != IRNodeType::Broadcast; + should_extract &= e.node_type() != IRNodeType::Ramp; + should_extract &= e.node_type() != IRNodeType::Call; + should_extract &= e.node_type() != IRNodeType::Add; + should_extract &= e.node_type() != IRNodeType::Sub; + should_extract &= e.node_type() != IRNodeType::Mul; + should_extract &= e.node_type() != IRNodeType::Div; + should_extract &= e.node_type() != IRNodeType::EQ; + should_extract &= e.node_type() != IRNodeType::NE; + should_extract &= e.node_type() != IRNodeType::LT; + should_extract &= e.node_type() != IRNodeType::GT; + should_extract &= e.node_type() != IRNodeType::GE; + should_extract &= e.node_type() != IRNodeType::LE; + + // TODO: Handling of strict_float is not well done. + // But at least it covers a few basic scenarios. + // This should be redone once we overhaul strict_float. + should_extract &= !in_strict_float; + + should_extract &= !just_in_let_defintion; + + debug((should_extract ? 
3 : 4)) << "Max lanes (" << max_lanes << ") exceeded (" << e.type().lanes() << ") by: " << e << "\n"; + if (should_extract) { + std::string name = unique_name('t'); + Expr var = Variable::make(e.type(), name); + replacements[e] = var; + lets.emplace_back(name, e); + debug(3) << " => Lifted out into " << name << "\n"; + return var; + } + } + + ScopedValue scoped_just_in_let(just_in_let_defintion, false); + return IRMutator::mutate(e); + } + +public: + LiftExceedingVectors() = default; +}; + +class LegalizeVectors : public IRMutator { + int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; + +public: + Stmt visit(const For *op) override { + ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); + return IRMutator::visit(op); + } + + template + decltype(LetOrLetStmt::body) visit_let(const LetOrLetStmt *op) { + bool exceeds_lanecount = max_lanes && op->value.type().lanes() > max_lanes; + + if (exceeds_lanecount) { + int num_vecs = (op->value.type().lanes() + max_lanes - 1) / max_lanes; + debug(3) << "Legalize let " << op->value.type() << ": " << op->name + << " = " << op->value << " into " << num_vecs << " vecs\n"; + decltype(LetOrLetStmt::body) body = IRMutator::mutate(op->body); + for (int i = num_vecs - 1; i >= 0; --i) { + int lane_start = i * max_lanes; + int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); + std::string name = vec_name(op->name, lane_start, lane_count_for_vec); + + Expr value = mutate(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->value)); + + debug(3) << " Add: let " << name << " = " << value << "\n"; + body = LetOrLetStmt::make(name, value, body); + } + return body; + } else { + return IRMutator::visit(op); + } + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + + Expr visit(const Let *op) override { + internal_error << "Lets should have been lifted into letStmts."; + return visit_let(op); + } + + Expr visit(const Shuffle *op) 
override { + if (max_lanes == 0) { + return IRMutator::visit(op); + } + internal_assert(op->type.lanes() <= max_lanes) << Expr(op); + bool requires_mutation = false; + for (int i = 0; i < op->vectors.size(); ++i) { + if (op->vectors[i].type().lanes() > max_lanes) { + requires_mutation = true; + break; + } + } + + if (requires_mutation) { + debug(4) << "Legalizing Shuffle " << Expr(op) << "\n"; + // We are dealing with a shuffle of an exceeding-lane-count vector argument. + // We can assume the variable here has extracted lane variables in surrounding Lets. + // So let's hope it's a simple case, and we can legalize. + + vector new_vectors; + vector> vector_and_lane_indices = op->vector_and_lane_indices(); + for (int i = 0; i < op->vectors.size(); ++i) { + const Expr &vec = op->vectors[i]; + if (vec.type().lanes() > max_lanes) { + debug(4) << " Arg " << i << ": " << vec << "\n"; + int num_vecs = (vec.type().lanes() + max_lanes - 1) / max_lanes; + for (int i = 0; i < num_vecs; i++) { + int lane_start = i * max_lanes; + int lane_count_for_vec = std::min(vec.type().lanes() - lane_start, max_lanes); + new_vectors.push_back(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(vec)); + } + } else { + new_vectors.push_back(IRMutator::mutate(vec)); + } + } + Expr result = Shuffle::make(new_vectors, op->indices); + result = simplify_shuffle(result.as()); + debug(3) << "Legalized " << Expr(op) << " => " << result << "\n"; + return result; + } + return IRMutator::visit(op); + } + + Expr visit(const VectorReduce *op) override { + if (max_lanes == 0) { + return IRMutator::visit(op); + } + const Expr &arg = op->value; + if (arg.type().lanes() > max_lanes) { + int vecs_per_reduction = op->value.type().lanes() / op->type.lanes(); + if (vecs_per_reduction % max_lanes == 0) { + // This should be possible too. 
TODO + } + + internal_assert(op->type.lanes() == 1) << "Vector legalization currently does not support VectorReduce with lanes != 1: " << Expr(op); + int num_vecs = (arg.type().lanes() + max_lanes - 1) / max_lanes; + Expr result; + for (int i = 0; i < num_vecs; i++) { + int lane_start = i * max_lanes; + int lane_count_for_vec = std::min(arg.type().lanes() - lane_start, max_lanes); + Expr partial_arg = mutate(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(arg)); + Expr partial_red = VectorReduce::make(op->op, std::move(partial_arg), op->type.lanes()); + if (i == 0) { + result = partial_red; + } else { + switch (op->op) { + case VectorReduce::Add: + result = result + partial_red; + break; + case VectorReduce::SaturatingAdd: + result = saturating_add(result, partial_red); + break; + case VectorReduce::Mul: + result = result * partial_red; + break; + case VectorReduce::Min: + result = min(result, partial_red); + break; + case VectorReduce::Max: + result = max(result, partial_red); + break; + case VectorReduce::And: + result = result && partial_red; + break; + case VectorReduce::Or: + result = result || partial_red; + break; + } + } + } + return result; + } else { + return IRMutator::visit(op); + } + } +}; + +} // namespace + +Stmt legalize_vectors(const Stmt &s) { + // Similar to CSE, lifting out stuff into variables. + // Pass 1): lift out vectors that exceed lane count into variables + // Pass 2): Rewrite those vector variables as bundles of vector variables. 
+ Stmt m0 = simplify(s); + Stmt m1 = common_subexpression_elimination(m0, false); + if (!m1.same_as(s)) { + debug(3) << "After CSE:\n" + << m1 << "\n"; + } + Stmt m2 = LiftLetToLetStmt().mutate(m1); + if (!m2.same_as(m1)) { + debug(3) << "After lifting Lets to LetStmts:\n" + << m2 << "\n"; + } + + Stmt m3 = m2; + while (true) { + Stmt m = LiftExceedingVectors().mutate(m3); + bool modified = !m3.same_as(m); + m3 = std::move(m); + if (!modified) { + debug(3) << "Nothing got lifted out\n"; + break; + } else { + debug(3) << "Atfer lifting exceeding vectors:\n" + << m3 << "\n"; + } + } + + Stmt m4 = LegalizeVectors().mutate(m3); + if (!m4.same_as(m3)) { + debug(3) << "After legalizing vectors:\n" + << m3 << "\n"; + } + if (m4.same_as(m2)) { + debug(3) << "Vector Legalization did do nothing, returning input.\n"; + return s; + } + return simplify(m4); +} +} // namespace Internal +} // namespace Halide diff --git a/src/LegalizeVectors.h b/src/LegalizeVectors.h new file mode 100644 index 000000000000..14fe8d806fb1 --- /dev/null +++ b/src/LegalizeVectors.h @@ -0,0 +1,19 @@ +#ifndef HALIDE_INTERNAL_LEGALIZE_VECTORS_H +#define HALIDE_INTERNAL_LEGALIZE_VECTORS_H + +#include "Expr.h" + +/** \file + * Defines a lowering pass that legalizes vectorized expressions + * to have a maximal lane count. 
+ */ + +namespace Halide { +namespace Internal { + +Stmt legalize_vectors(const Stmt &s); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Lower.cpp b/src/Lower.cpp index 32b64e83a2bd..7a86a4379a88 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -42,6 +42,7 @@ #include "InjectHostDevBufferCopies.h" #include "Inline.h" #include "LICM.h" +#include "LegalizeVectors.h" #include "LoopCarry.h" #include "LowerParallelTasks.h" #include "LowerWarpShuffles.h" @@ -444,6 +445,10 @@ void lower_impl(const vector &output_funcs, s = flatten_nested_ramps(s); log("Lowering after flattening nested ramps:", s); + debug(1) << "Legalizing vectors...\n"; + s = legalize_vectors(s); + log("Lowering after legalizing vectors:", s); + debug(1) << "Removing dead allocations and moving loop invariant code...\n"; s = remove_dead_allocations(s); s = simplify(s); diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index cf8d1f03317d..9c38f0faf622 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -45,6 +45,11 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { new_vectors.push_back(new_vector); } + // A concat of one vector, is just the vector. + if (op->vectors.size() == 1 && op->is_concat()) { + return new_vectors[0]; + } + // Try to convert a load with shuffled indices into a // shuffle of a dense load. if (const Load *first_load = new_vectors[0].as()) { diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 2d149adbaf20..fc6fd9531983 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -732,8 +732,8 @@ class VectorSubs : public IRMutator { if (op->is_intrinsic(Call::prefetch)) { // We don't want prefetch args to ve vectorized, but we can't just skip the mutation - // (otherwise we can end up with dead loop variables. Instead, use extract_lane() on each arg - // to scalarize it again. + // (otherwise we can end up with dead loop variables). 
Instead, use extract_lane() on + // each arg to scalarize it again. for (auto &arg : new_args) { if (arg.type().is_vector()) { arg = extract_lane(arg, 0); diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 690081f5ce4b..29d0670943f8 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -223,6 +223,7 @@ tests(GROUPS correctness median3x3.cpp metal_precompiled_shaders.cpp memoize_cloned.cpp + metal_long_vectors.cpp min_extent.cpp mod.cpp mul_div_mod.cpp diff --git a/test/error/metal_vector_too_large.cpp b/test/correctness/metal_long_vectors.cpp similarity index 89% rename from test/error/metal_vector_too_large.cpp rename to test/correctness/metal_long_vectors.cpp index bf4c74bb75a0..74c2e981fc2d 100644 --- a/test/error/metal_vector_too_large.cpp +++ b/test/correctness/metal_long_vectors.cpp @@ -9,7 +9,7 @@ int main(int argc, char **argv) { Var x("x"), y("y"); f(x, y) = input(x, y) + 42; - f.vectorize(x, 16).gpu_blocks(y, DeviceAPI::Metal); + f.vectorize(x, 32).gpu_blocks(y, DeviceAPI::Metal); std::string test_object = Internal::get_test_tmp_dir() + "metal_vector_too_large.o"; Target mac_target("x86-64-osx-metal"); diff --git a/test/correctness/require.cpp b/test/correctness/require.cpp index 625383f460df..58226077d971 100644 --- a/test/correctness/require.cpp +++ b/test/correctness/require.cpp @@ -9,7 +9,7 @@ void halide_error(JITUserContext *ctx, const char *msg) { // Emitting "error.*:" to stdout or stderr will cause CMake to report the // test as a failure on Windows, regardless of error code returned, // hence the abbreviation to "err". 
- printf("Saw (Expected) Halide Err: %s\n", msg); + printf("Saw (Expected) Halide Err: %s", msg); error_occurred = true; } @@ -46,14 +46,18 @@ static void test(int vector_width) { if (!error_occurred) { printf("There should have been a requirement error (vector_width = %d)\n", vector_width); exit(1); + } else { + printf("OK\n"); } + printf("\n"); + p1.set(1); p2.set(kPrime1 - 1); error_occurred = false; result = f.realize({realize_width}); if (error_occurred) { - printf("There should not have been a requirement error (vector_width = %d)\n", vector_width); + printf("There should NOT have been a requirement error (vector_width = %d)\n", vector_width); exit(1); } for (int i = 0; i < realize_width; ++i) { @@ -64,6 +68,8 @@ static void test(int vector_width) { exit(1); } } + printf("OK\n"); + printf("\n"); ImageParam input(Int(32), 2); Expr h = require(p1 == p2, p1); @@ -81,8 +87,12 @@ static void test(int vector_width) { if (!error_occurred) { printf("There should have been a requirement error (vector_width = %d)\n", vector_width); exit(1); + } else { + printf("OK\n"); } + printf("\n"); + p1.set(16); p2.set(16); @@ -91,6 +101,8 @@ static void test(int vector_width) { if (error_occurred) { printf("There should NOT have been a requirement error (vector_width = %d)\n", vector_width); exit(1); + } else { + printf("OK\n"); } } diff --git a/test/correctness/specialize.cpp b/test/correctness/specialize.cpp index 1a807003f72a..8df87dd27333 100644 --- a/test/correctness/specialize.cpp +++ b/test/correctness/specialize.cpp @@ -128,6 +128,11 @@ int main(int argc, char **argv) { } } + if (!vector_store && !scalar_store) { + printf("No stores were reported\n"); + return 1; + } + // Should have used vector stores if (!vector_store || scalar_store) { printf("This was supposed to use vector stores\n"); @@ -156,6 +161,11 @@ int main(int argc, char **argv) { } } + if (!vector_store && !scalar_store) { + printf("No stores were reported\n"); + return 1; + } + // Should have used scalar 
stores if (vector_store || !scalar_store) { printf("This was supposed to use scalar stores\n"); @@ -243,6 +253,10 @@ int main(int argc, char **argv) { // Check we don't crash with the small input, and that it uses scalar stores reset_trace(); f.realize({5}); + if (!vector_store && !scalar_store) { + printf("No stores were reported\n"); + return 1; + } if (!scalar_store || vector_store) { printf("These stores were supposed to be scalar.\n"); return 1; @@ -254,6 +268,10 @@ int main(int argc, char **argv) { reset_trace(); f.realize({100}); + if (!vector_store && !scalar_store) { + printf("No stores were reported\n"); + return 1; + } if (scalar_store || !vector_store) { printf("These stores were supposed to be vector.\n"); return 1; @@ -282,6 +300,10 @@ int main(int argc, char **argv) { // Check we used scalar stores for a strided input. reset_trace(); f.realize({100}); + if (!vector_store && !scalar_store) { + printf("No stores were reported\n"); + return 1; + } if (!scalar_store || vector_store) { printf("These stores were supposed to be scalar.\n"); return 1; @@ -293,6 +315,10 @@ int main(int argc, char **argv) { reset_trace(); f.realize({100}); + if (!vector_store && !scalar_store) { + printf("No stores were reported\n"); + return 1; + } if (scalar_store || !vector_store) { printf("These stores were supposed to be vector.\n"); return 1; diff --git a/test/correctness/vector_shuffle.cpp b/test/correctness/vector_shuffle.cpp index f50a5607f52a..f0a62ab3d8cd 100644 --- a/test/correctness/vector_shuffle.cpp +++ b/test/correctness/vector_shuffle.cpp @@ -62,12 +62,7 @@ int test_with_indices(const Target &target, const std::vector &indices0, co int main(int argc, char **argv) { Target target = get_jit_target_from_environment(); - int max_vec_size = 4; - if (!target.has_gpu_feature() || target.has_feature(Target::Feature::OpenCL) || target.has_feature(Target::Feature::CUDA)) { - max_vec_size = 8; - } - - for (int vec_size = max_vec_size; vec_size > 1; vec_size /= 2) { + for 
(int vec_size = 8; vec_size > 1; vec_size /= 2) { printf("Testing vector size %d...\n", vec_size); std::vector indices0, indices1; From e88e66ce77a44fca99475f8abc802401b1cc1fdb Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 27 May 2025 15:33:53 +0200 Subject: [PATCH 03/47] Fix Makefile. --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 54c61a622ae8..a928cd9b81bb 100644 --- a/Makefile +++ b/Makefile @@ -535,6 +535,7 @@ SOURCE_FILES = \ IRVisitor.cpp \ JITModule.cpp \ Lambda.cpp \ + LegalizeVectors.cpp \ Lerp.cpp \ LICM.cpp \ LLVM_Output.cpp \ @@ -737,6 +738,7 @@ HEADER_FILES = \ WasmExecutor.h \ JITModule.h \ Lambda.h \ + LegalizeVectors.h \ Lerp.h \ LICM.h \ LLVM_Output.h \ From 2182bd19439bc33b33e6c2a47d5fb68755e9f041 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 27 May 2025 15:46:25 +0200 Subject: [PATCH 04/47] Cleanup. --- src/LegalizeVectors.cpp | 96 ++++++++--------------------------------- 1 file changed, 19 insertions(+), 77 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 4d939ec4a690..af45538d3561 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -103,13 +103,13 @@ Expr simplify_shuffle(const Shuffle *op) { } class LiftLetToLetStmt : public IRMutator { -public: vector lets; Expr visit(const Let *op) override { lets.push_back(op); return mutate(op->body); } +public: Stmt mutate(const Stmt &s) override { ScopedValue scoped_lets(lets, {}); Stmt mutated = IRMutator::mutate(s); @@ -176,11 +176,6 @@ class ExtractLanes : public IRMutator { return Expr(0); } -public: - ExtractLanes(int start, int count, int max_legal) - : lane_start(start), lane_count(count), max_legal_lanes(max_legal) { - } - Expr visit(const Shuffle *op) override { vector new_indices; for (int i = 0; i < lane_count; ++i) { @@ -283,6 +278,7 @@ class ExtractLanes : public IRMutator { return VectorReduce::make(op->op, arg, lane_count); } +public: // Small helper to assert the 
transform did what it's supposed to do. Expr mutate(const Expr &e) override { Type original_type = e.type(); @@ -300,6 +296,10 @@ class ExtractLanes : public IRMutator { Stmt mutate(const Stmt &s) override { return IRMutator::mutate(s); } + + ExtractLanes(int start, int count, int max_legal) + : lane_start(start), lane_count(count), max_legal_lanes(max_legal) { + } }; class LiftExceedingVectors : public IRMutator { @@ -310,39 +310,24 @@ class LiftExceedingVectors : public IRMutator { bool just_in_let_defintion{false}; int in_strict_float = 0; -public: Stmt visit(const For *op) override { ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); return IRMutator::visit(op); } - template - decltype(LetOrLetStmt::body) visit_let(const LetOrLetStmt *op) { - ScopedValue scoped_just_in_let(just_in_let_defintion, true); - Expr def = mutate(op->value); - auto body = mutate(op->body); - if (def.same_as(op->value) && body.same_as(op->body)) { - return op; - } - return LetOrLetStmt::make(op->name, std::move(def), std::move(body)); - } - Expr visit(const Let *op) override { internal_error << "We don't want to process Lets. 
They should have all been converted to LetStmts."; - Expr def; - { - ScopedValue scoped_just_in_let(just_in_let_defintion, true); - def = mutate(op->value); - } - auto body = mutate(op->body); - if (def.same_as(op->value) && body.same_as(op->body)) { - return op; - } - return IRMutator::visit(op); + return {}; } Stmt visit(const LetStmt *op) override { - return visit_let(op); + ScopedValue scoped_just_in_let(just_in_let_defintion, true); + Expr def = mutate(op->value); + Stmt body = mutate(op->body); + if (def.same_as(op->value) && body.same_as(op->body)) { + return op; + } + return LetStmt::make(op->name, std::move(def), std::move(body)); } Stmt visit(const Store *op) override { @@ -409,6 +394,7 @@ class LiftExceedingVectors : public IRMutator { return mutated; } +public: Stmt mutate(const Stmt &s) override { ScopedValue scoped_lets(lets, {}); ScopedValue scoped_just_in_let(just_in_let_defintion, false); @@ -420,44 +406,6 @@ class LiftExceedingVectors : public IRMutator { return mutated; } -#if 0 - Stmt visit(const IfThenElse *op) override { - debug(3) << "Visit IfThenElse: " << Stmt(op) << " with max lanes: " << max_lanes << "\n"; - Expr condition; - decltype(lets) condition_lets; - { - ScopedValue scoped_lets(lets, {}); - condition = mutate(op->condition); - condition_lets = lets; - } - Stmt then_case, else_case; - { - ScopedValue scoped_lets(lets, {}); - then_case = mutate(op->then_case); - for (auto &let : reverse_view(lets)) { - then_case = LetStmt::make(let.first, let.second, then_case); - } - } - { - ScopedValue scoped_lets(lets, {}); - else_case = mutate(op->else_case); - for (auto &let : reverse_view(lets)) { - else_case = LetStmt::make(let.first, let.second, else_case); - } - } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { - return op; - } - Stmt mutated = IfThenElse::make(std::move(condition), std::move(then_case), std::move(else_case)); - for (auto &let : reverse_view(lets)) { - 
mutated = LetStmt::make(let.first, let.second, mutated); - } - return mutated; - } -#endif - Expr mutate(const Expr &e) override { bool exceeds_lanecount = max_lanes && e.type().lanes() > max_lanes; @@ -508,21 +456,19 @@ class LiftExceedingVectors : public IRMutator { class LegalizeVectors : public IRMutator { int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; -public: Stmt visit(const For *op) override { ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); return IRMutator::visit(op); } - template - decltype(LetOrLetStmt::body) visit_let(const LetOrLetStmt *op) { + Stmt visit(const LetStmt *op) override { bool exceeds_lanecount = max_lanes && op->value.type().lanes() > max_lanes; if (exceeds_lanecount) { int num_vecs = (op->value.type().lanes() + max_lanes - 1) / max_lanes; debug(3) << "Legalize let " << op->value.type() << ": " << op->name << " = " << op->value << " into " << num_vecs << " vecs\n"; - decltype(LetOrLetStmt::body) body = IRMutator::mutate(op->body); + Stmt body = IRMutator::mutate(op->body); for (int i = num_vecs - 1; i >= 0; --i) { int lane_start = i * max_lanes; int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); @@ -531,7 +477,7 @@ class LegalizeVectors : public IRMutator { Expr value = mutate(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->value)); debug(3) << " Add: let " << name << " = " << value << "\n"; - body = LetOrLetStmt::make(name, value, body); + body = LetStmt::make(name, value, body); } return body; } else { @@ -539,13 +485,9 @@ class LegalizeVectors : public IRMutator { } } - Stmt visit(const LetStmt *op) override { - return visit_let(op); - } - Expr visit(const Let *op) override { internal_error << "Lets should have been lifted into letStmts."; - return visit_let(op); + return {}; } Expr visit(const Shuffle *op) override { From b34592969fcbbf61ff93f1a4a1ef8cf789cf7bc3 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 27 May 
2025 18:06:40 +0200 Subject: [PATCH 05/47] Cleanup vector legalization. --- src/LegalizeVectors.cpp | 102 +++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 58 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index af45538d3561..d3ceb538bdc0 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -26,17 +26,19 @@ int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { case DeviceAPI::OpenCL: return 16; case DeviceAPI::CUDA: + case DeviceAPI::Hexagon: + case DeviceAPI::HexagonDma: case DeviceAPI::Host: return 0; // No max: LLVM based legalization case DeviceAPI::None: return parent_max_lanes; - default: + case DeviceAPI::Default_GPU: + internal_error << "No GPU API was selected."; return 0; } } std::string vec_name(const string &name, int lane_start, int lane_count) { - // return name + ".ls" + std::to_string(lane_start) + ".lc" + std::to_string(lane_count); return name + ".lanes_" + std::to_string(lane_start) + "_" + std::to_string(lane_start + lane_count - 1); } @@ -70,7 +72,7 @@ Expr simplify_shuffle(const Shuffle *op) { vector> src_vec_and_lane_idx = op->vector_and_lane_indices(); bool all_from_the_same = true; bool is_full_vec = src_vec_and_lane_idx[0].second == 0; - for (int i = 1; i < op->indices.size(); ++i) { + for (int i = 1; i < int(op->indices.size()); ++i) { if (src_vec_and_lane_idx[i].first != src_vec_and_lane_idx[0].first) { all_from_the_same = false; is_full_vec = false; @@ -82,7 +84,7 @@ Expr simplify_shuffle(const Shuffle *op) { } if (all_from_the_same) { const Expr &src_vec = op->vectors[src_vec_and_lane_idx[0].first]; - is_full_vec &= src_vec.type().lanes() == op->indices.size(); + is_full_vec &= src_vec.type().lanes() == int(op->indices.size()); int first_lane_in_src = src_vec_and_lane_idx[0].second; if (is_full_vec) { return src_vec; @@ -92,7 +94,7 @@ Expr simplify_shuffle(const Shuffle *op) { return simplify(Ramp::make(ramp->base + first_lane_in_src * ramp->stride, 
ramp->stride, op->indices.size())); } vector new_indices; - for (int i = 0; i < op->indices.size(); ++i) { + for (int i = 0; i < int(op->indices.size()); ++i) { new_indices.push_back(src_vec_and_lane_idx[i].second); } return Shuffle::make({src_vec}, new_indices); @@ -133,7 +135,7 @@ class ExtractLanes : public IRMutator { internal_assert(op); internal_assert(op->is_intrinsic(Call::make_struct)); vector args(op->args.size()); - for (int i = 0; i < op->args.size(); ++i) { + for (int i = 0; i < int(op->args.size()); ++i) { args[i] = mutate(op->args[i]); } return Call::make(op->type, Call::make_struct, args, Call::Intrinsic); @@ -224,7 +226,7 @@ class ExtractLanes : public IRMutator { Expr mutated = op; std::vector args; args.reserve(op->args.size()); - for (int i = 0; i < op->args.size(); ++i) { + for (int i = 0; i < int(op->args.size()); ++i) { const Expr &arg = op->args[i]; internal_assert(arg.type().lanes() == op->type.lanes()) << "Call argument " << arg << " lane count of " << arg.type().lanes() << " does not match op lane count of " << op->type.lanes(); Expr mutated = mutate(arg); @@ -306,7 +308,6 @@ class LiftExceedingVectors : public IRMutator { int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; vector> lets; - map replacements; bool just_in_let_defintion{false}; int in_strict_float = 0; @@ -330,31 +331,9 @@ class LiftExceedingVectors : public IRMutator { return LetStmt::make(op->name, std::move(def), std::move(body)); } - Stmt visit(const Store *op) override { - bool exceeds_lanecount = max_lanes && op->index.type().lanes() > max_lanes; - if (exceeds_lanecount) { - Expr value = mutate(op->value); - - // Split up in multiple stores - int num_vecs = (op->index.type().lanes() + max_lanes - 1) / max_lanes; - std::vector assignments; - assignments.reserve(num_vecs); - for (int i = 0; i < num_vecs; ++i) { - int lane_start = i * max_lanes; - int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); - Expr rhs = 
ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(value); - Expr index = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->index); - Expr predictate = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->predicate); - assignments.push_back(Store::make( - op->name, std::move(rhs), std::move(index), - op->param, std::move(predictate), op->alignment + lane_start)); - } - return Block::make(assignments); - } - return IRMutator::visit(op); - } - Expr visit(const Call *op) override { + // Custom handling of Call, to prevent certain things from being extracted out + // of the call arguments, as that's not always allowed. bool exceeds_lanecount = max_lanes && op->type.lanes() > max_lanes; if (op->is_intrinsic(Call::strict_float)) { in_strict_float++; @@ -364,7 +343,7 @@ class LiftExceedingVectors : public IRMutator { std::vector args; args.reserve(op->args.size()); bool changed = false; - for (int i = 0; i < op->args.size(); ++i) { + for (int i = 0; i < int(op->args.size()); ++i) { bool may_extract = true; if (op->is_intrinsic(Call::require)) { may_extract &= i < 2; @@ -410,22 +389,9 @@ class LiftExceedingVectors : public IRMutator { bool exceeds_lanecount = max_lanes && e.type().lanes() > max_lanes; if (exceeds_lanecount) { - bool should_extract = true; - should_extract &= e.node_type() != IRNodeType::Variable; - should_extract &= e.node_type() != IRNodeType::Let; - should_extract &= e.node_type() != IRNodeType::Broadcast; - should_extract &= e.node_type() != IRNodeType::Ramp; - should_extract &= e.node_type() != IRNodeType::Call; - should_extract &= e.node_type() != IRNodeType::Add; - should_extract &= e.node_type() != IRNodeType::Sub; - should_extract &= e.node_type() != IRNodeType::Mul; - should_extract &= e.node_type() != IRNodeType::Div; - should_extract &= e.node_type() != IRNodeType::EQ; - should_extract &= e.node_type() != IRNodeType::NE; - should_extract &= e.node_type() != IRNodeType::LT; - should_extract &= 
e.node_type() != IRNodeType::GT; - should_extract &= e.node_type() != IRNodeType::GE; - should_extract &= e.node_type() != IRNodeType::LE; + bool should_extract = false; + should_extract |= e.node_type() == IRNodeType::Shuffle; + should_extract |= e.node_type() == IRNodeType::VectorReduce; // TODO: Handling of strict_float is not well done. // But at least it covers a few basic scenarios. @@ -438,7 +404,6 @@ class LiftExceedingVectors : public IRMutator { if (should_extract) { std::string name = unique_name('t'); Expr var = Variable::make(e.type(), name); - replacements[e] = var; lets.emplace_back(name, e); debug(3) << " => Lifted out into " << name << "\n"; return var; @@ -448,9 +413,6 @@ class LiftExceedingVectors : public IRMutator { ScopedValue scoped_just_in_let(just_in_let_defintion, false); return IRMutator::mutate(e); } - -public: - LiftExceedingVectors() = default; }; class LegalizeVectors : public IRMutator { @@ -490,13 +452,37 @@ class LegalizeVectors : public IRMutator { return {}; } + Stmt visit(const Store *op) override { + bool exceeds_lanecount = max_lanes && op->index.type().lanes() > max_lanes; + if (exceeds_lanecount) { + // Split up in multiple stores + int num_vecs = (op->index.type().lanes() + max_lanes - 1) / max_lanes; + std::vector assignments; + assignments.reserve(num_vecs); + for (int i = 0; i < num_vecs; ++i) { + int lane_start = i * max_lanes; + int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); + Expr rhs = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->value); + Expr index = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->index); + Expr predictate = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->predicate); + assignments.push_back(Store::make( + op->name, std::move(rhs), std::move(index), + op->param, std::move(predictate), op->alignment + lane_start)); + } + Stmt result = Block::make(assignments); + debug(3) << "Legalized store " << Stmt(op) 
<< " => " << result << "\n"; + return result; + } + return IRMutator::visit(op); + } + Expr visit(const Shuffle *op) override { if (max_lanes == 0) { return IRMutator::visit(op); } internal_assert(op->type.lanes() <= max_lanes) << Expr(op); bool requires_mutation = false; - for (int i = 0; i < op->vectors.size(); ++i) { + for (size_t i = 0; i < op->vectors.size(); ++i) { if (op->vectors[i].type().lanes() > max_lanes) { requires_mutation = true; break; @@ -511,7 +497,7 @@ class LegalizeVectors : public IRMutator { vector new_vectors; vector> vector_and_lane_indices = op->vector_and_lane_indices(); - for (int i = 0; i < op->vectors.size(); ++i) { + for (int i = 0; i < int(op->vectors.size()); ++i) { const Expr &vec = op->vectors[i]; if (vec.type().lanes() > max_lanes) { debug(4) << " Arg " << i << ": " << vec << "\n"; @@ -591,8 +577,8 @@ class LegalizeVectors : public IRMutator { Stmt legalize_vectors(const Stmt &s) { // Similar to CSE, lifting out stuff into variables. - // Pass 1): lift out vectors that exceed lane count into variables - // Pass 2): Rewrite those vector variables as bundles of vector variables. + // Pass 1): lift out Shuffles that exceed lane count into variables + // Pass 2): Rewrite those vector variables as bundles of vector variables, while legalizing all other stuff. Stmt m0 = simplify(s); Stmt m1 = common_subexpression_elimination(m0, false); if (!m1.same_as(s)) { @@ -622,7 +608,7 @@ Stmt legalize_vectors(const Stmt &s) { Stmt m4 = LegalizeVectors().mutate(m3); if (!m4.same_as(m3)) { debug(3) << "After legalizing vectors:\n" - << m3 << "\n"; + << m4 << "\n"; } if (m4.same_as(m2)) { debug(3) << "Vector Legalization did do nothing, returning input.\n"; From 488426ce295801146cd60d150b7c4c8cb54ca549 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 28 May 2025 11:28:18 +0200 Subject: [PATCH 06/47] Try to fix the compiler complaint around visibility. 
--- src/LegalizeVectors.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index d3ceb538bdc0..164cc76a9342 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -105,6 +105,8 @@ Expr simplify_shuffle(const Shuffle *op) { } class LiftLetToLetStmt : public IRMutator { + using IRMutator::visit; + vector lets; Expr visit(const Let *op) override { lets.push_back(op); @@ -127,6 +129,8 @@ class LiftLetToLetStmt : public IRMutator { }; class ExtractLanes : public IRMutator { + using IRMutator::visit; + int lane_start; int lane_count; int max_legal_lanes; @@ -305,6 +309,8 @@ class ExtractLanes : public IRMutator { }; class LiftExceedingVectors : public IRMutator { + using IRMutator::visit; + int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; vector> lets; @@ -416,6 +422,8 @@ class LiftExceedingVectors : public IRMutator { }; class LegalizeVectors : public IRMutator { + using IRMutator::visit; + int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; Stmt visit(const For *op) override { From c44a130e501bae4544bddc3cc2d4d6415c1f796d Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 28 May 2025 11:44:25 +0200 Subject: [PATCH 07/47] GCC-9 does not understand a complete switch? 
--- src/LegalizeVectors.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 164cc76a9342..a70087b2862c 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -36,6 +36,8 @@ int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { internal_error << "No GPU API was selected."; return 0; } + internal_error << "Unknown Device API"; + return 0; } std::string vec_name(const string &name, int lane_start, int lane_count) { From 17a8c0aadb95ac98de39a853f8f749906685e9d1 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Thu, 5 Jun 2025 02:03:02 +0200 Subject: [PATCH 08/47] Do not lift Let out to LetStmt if we are not in a loop with lane limit. Other feedback: typos, and clarifications. --- src/LegalizeVectors.cpp | 77 +++++++++++++++++++++++++++++------------ src/Simplify_Let.cpp | 2 +- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index a70087b2862c..a7910a1a0a5b 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -16,7 +16,13 @@ namespace { using namespace std; +const char *legalization_error_guide = "\n(This issue can most likely be resolved by reducing lane count for vectorize() calls in the schedule, or disabling it.)"; + int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { + std::string envvar = Halide::Internal::get_env_variable("HL_FORCE_VECTOR_LEGALIZATION"); + if (!envvar.empty()) { + return std::atoi(envvar.c_str()); + } switch (api) { case DeviceAPI::Metal: case DeviceAPI::WebGPU: @@ -109,10 +115,25 @@ Expr simplify_shuffle(const Shuffle *op) { class LiftLetToLetStmt : public IRMutator { using IRMutator::visit; + int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; + + Stmt visit(const For *op) override { + ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); + return IRMutator::visit(op); + } + vector lets; Expr visit(const Let *op) override 
{ - lets.push_back(op); - return mutate(op->body); + if (max_lanes != 0) { + for (const Let *existing : lets) { + internal_assert(existing->name != op->name) + << "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique."; + } + lets.push_back(op); + return mutate(op->body); + } else { + return IRMutator::visit(op); + } } public: @@ -148,7 +169,6 @@ class ExtractLanes : public IRMutator { } Expr extract_lanes_trace(const Call *op) { - // user_error << "Cannot legalize vectors when tracing is enabled."; auto event = as_const_int(op->args[6]); internal_assert(event); if (*event == halide_trace_load || *event == halide_trace_store) { @@ -176,11 +196,10 @@ class ExtractLanes : public IRMutator { Expr result = Call::make(Int(32), Call::trace, args, Call::Extern); debug(4) << " => " << result << "\n"; return result; - } else { - user_warning << "Discarding tracing during vector legalization: " << Expr(op) << "\n"; } - // This is feasible: see VectorizeLoops. + internal_error << "Unhandled trace call in LegalizeVectors' ExtractLanes: " << *event << legalization_error_guide << "\n" + << "Please report this error on GitHub." 
<< legalization_error_guide; return Expr(0); } @@ -234,7 +253,9 @@ class ExtractLanes : public IRMutator { args.reserve(op->args.size()); for (int i = 0; i < int(op->args.size()); ++i) { const Expr &arg = op->args[i]; - internal_assert(arg.type().lanes() == op->type.lanes()) << "Call argument " << arg << " lane count of " << arg.type().lanes() << " does not match op lane count of " << op->type.lanes(); + internal_assert(arg.type().lanes() == op->type.lanes()) + << "Call argument " << arg << " lane count of " << arg.type().lanes() + << " does not match op lane count of " << op->type.lanes(); Expr mutated = mutate(arg); internal_assert(!mutated.same_as(arg)); args.push_back(mutated); @@ -266,7 +287,8 @@ class ExtractLanes : public IRMutator { // // TODO implement this for all scenarios internal_error << "Vector legalization for Reinterpret to different bit size per element is " - << "not supported yet: reinterpret<" << result_type << ">(" << value.type() << ")"; + << "not supported yet: reinterpret<" << result_type << ">(" << value.type() << ")" + << legalization_error_guide; // int input_lane_start = lane_start * result_scalar_bits / input_scalar_bits; // int input_lane_count = lane_count * result_scalar_bits / input_scalar_bits; @@ -316,7 +338,7 @@ class LiftExceedingVectors : public IRMutator { int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; vector> lets; - bool just_in_let_defintion{false}; + bool just_in_let_definition{false}; int in_strict_float = 0; Stmt visit(const For *op) override { @@ -325,13 +347,19 @@ class LiftExceedingVectors : public IRMutator { } Expr visit(const Let *op) override { - internal_error << "We don't want to process Lets. They should have all been converted to LetStmts."; - return {}; + if (max_lanes != 0) { + internal_error << "We don't want to process Lets. 
They should have all been converted to LetStmts."; + } else { + return IRMutator::visit(op); + } } Stmt visit(const LetStmt *op) override { - ScopedValue scoped_just_in_let(just_in_let_defintion, true); - Expr def = mutate(op->value); + Expr def; + { + ScopedValue scoped_just_in_let(just_in_let_definition, true); + def = mutate(op->value); + } Stmt body = mutate(op->body); if (def.same_as(op->value) && body.same_as(op->body)) { return op; @@ -384,7 +412,7 @@ class LiftExceedingVectors : public IRMutator { public: Stmt mutate(const Stmt &s) override { ScopedValue scoped_lets(lets, {}); - ScopedValue scoped_just_in_let(just_in_let_defintion, false); + ScopedValue scoped_just_in_let(just_in_let_definition, false); Stmt mutated = IRMutator::mutate(s); for (auto &let : reverse_view(lets)) { // There is no recurse into let.second. This is handled by repeatedly calling this tranform. @@ -406,7 +434,7 @@ class LiftExceedingVectors : public IRMutator { // This should be redone once we overhaul strict_float. should_extract &= !in_strict_float; - should_extract &= !just_in_let_defintion; + should_extract &= !just_in_let_definition; debug((should_extract ? 
3 : 4)) << "Max lanes (" << max_lanes << ") exceeded (" << e.type().lanes() << ") by: " << e << "\n"; if (should_extract) { @@ -418,7 +446,7 @@ class LiftExceedingVectors : public IRMutator { } } - ScopedValue scoped_just_in_let(just_in_let_defintion, false); + ScopedValue scoped_just_in_let(just_in_let_definition, false); return IRMutator::mutate(e); } }; @@ -458,8 +486,10 @@ class LegalizeVectors : public IRMutator { } Expr visit(const Let *op) override { - internal_error << "Lets should have been lifted into letStmts."; - return {}; + if (max_lanes != 0) { + internal_error << "Lets should have been lifted into LetStmts."; + } + return IRMutator::visit(op); } Stmt visit(const Store *op) override { @@ -535,12 +565,13 @@ class LegalizeVectors : public IRMutator { } const Expr &arg = op->value; if (arg.type().lanes() > max_lanes) { - int vecs_per_reduction = op->value.type().lanes() / op->type.lanes(); - if (vecs_per_reduction % max_lanes == 0) { - // This should be possible too. TODO - } + // TODO: The transformation below is not allowed under strict_float, but + // I won't bother right now, as strict_float is due for an overhaul. + // This should be an internal_assert. - internal_assert(op->type.lanes() == 1) << "Vector legalization currently does not support VectorReduce with lanes != 1: " << Expr(op); + internal_assert(op->type.lanes() == 1) + << "Vector legalization currently does not support VectorReduce with lanes != 1: " << Expr(op) + << legalization_error_guide; int num_vecs = (arg.type().lanes() + max_lanes - 1) / max_lanes; Expr result; for (int i = 0; i < num_vecs; i++) { diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 801163215dc9..0d7b6677f8e6 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -185,7 +185,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { // with other shuffles. 
// As the structure of this while loop makes it hard to peel off // pure operations from _all_ arguments to the Shuffle, we will - // instead subsitute all of the vars that go in the shuffle, and + // instead substitute all of the vars that go in the shuffle, and // instead guard against side effects by checking with `is_pure()`. replacement = substitute(f.new_name, shuffle, replacement); f.new_value = Expr(); From 306b6166877666849c220e99f2af2419d9d0868b Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Thu, 5 Jun 2025 02:20:16 +0200 Subject: [PATCH 09/47] Improve error message for reinterpret. --- src/LegalizeVectors.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index a7910a1a0a5b..fa11466ecc02 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -287,7 +287,7 @@ class ExtractLanes : public IRMutator { // // TODO implement this for all scenarios internal_error << "Vector legalization for Reinterpret to different bit size per element is " - << "not supported yet: reinterpret<" << result_type << ">(" << value.type() << ")" + << "not supported yet: reinterpret<" << op->type << ">(" << value.type() << ")" << legalization_error_guide; // int input_lane_start = lane_start * result_scalar_bits / input_scalar_bits; From 2a50d116fb4c9fdcadd5215de4107225fcbf94c4 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Thu, 5 Jun 2025 12:00:32 +0200 Subject: [PATCH 10/47] Only run vector legalization mutators on device loops that require it. 
--- src/LegalizeVectors.cpp | 142 ++++++++++++++++++++-------------------- 1 file changed, 72 insertions(+), 70 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index fa11466ecc02..986582e03c0c 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -115,25 +115,14 @@ Expr simplify_shuffle(const Shuffle *op) { class LiftLetToLetStmt : public IRMutator { using IRMutator::visit; - int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; - - Stmt visit(const For *op) override { - ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); - return IRMutator::visit(op); - } - vector lets; Expr visit(const Let *op) override { - if (max_lanes != 0) { - for (const Let *existing : lets) { - internal_assert(existing->name != op->name) - << "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique."; - } - lets.push_back(op); - return mutate(op->body); - } else { - return IRMutator::visit(op); + for (const Let *existing : lets) { + internal_assert(existing->name != op->name) + << "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique."; } + lets.push_back(op); + return mutate(op->body); } public: @@ -156,7 +145,6 @@ class ExtractLanes : public IRMutator { int lane_start; int lane_count; - int max_legal_lanes; Expr extract_lanes_from_make_struct(const Call *op) { internal_assert(op); @@ -303,7 +291,7 @@ class ExtractLanes : public IRMutator { int vecs_per_reduction = op->value.type().lanes() / op->type.lanes(); int input_lane_start = vecs_per_reduction * lane_start; int input_lane_count = vecs_per_reduction * lane_count; - Expr arg = ExtractLanes(input_lane_start, input_lane_count, max_legal_lanes).mutate(op->value); + Expr arg = ExtractLanes(input_lane_start, input_lane_count).mutate(op->value); // This might fail if the extracted lanes reference a non-existing variable! 
return VectorReduce::make(op->op, arg, lane_count); } @@ -327,39 +315,30 @@ class ExtractLanes : public IRMutator { return IRMutator::mutate(s); } - ExtractLanes(int start, int count, int max_legal) - : lane_start(start), lane_count(count), max_legal_lanes(max_legal) { + ExtractLanes(int start, int count) + : lane_start(start), lane_count(count) { } }; class LiftExceedingVectors : public IRMutator { using IRMutator::visit; - int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; + int max_lanes; vector> lets; bool just_in_let_definition{false}; int in_strict_float = 0; - Stmt visit(const For *op) override { - ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); - return IRMutator::visit(op); - } - Expr visit(const Let *op) override { - if (max_lanes != 0) { - internal_error << "We don't want to process Lets. They should have all been converted to LetStmts."; - } else { - return IRMutator::visit(op); - } + internal_error << "We don't want to process Lets. They should have all been converted to LetStmts."; + return IRMutator::visit(op); } Stmt visit(const LetStmt *op) override { - Expr def; - { - ScopedValue scoped_just_in_let(just_in_let_definition, true); - def = mutate(op->value); - } + just_in_let_definition = true; + Expr def = mutate(op->value); + just_in_let_definition = false; + Stmt body = mutate(op->body); if (def.same_as(op->value) && body.same_as(op->body)) { return op; @@ -370,7 +349,7 @@ class LiftExceedingVectors : public IRMutator { Expr visit(const Call *op) override { // Custom handling of Call, to prevent certain things from being extracted out // of the call arguments, as that's not always allowed. 
- bool exceeds_lanecount = max_lanes && op->type.lanes() > max_lanes; + bool exceeds_lanecount = op->type.lanes() > max_lanes; if (op->is_intrinsic(Call::strict_float)) { in_strict_float++; } @@ -382,8 +361,15 @@ class LiftExceedingVectors : public IRMutator { for (int i = 0; i < int(op->args.size()); ++i) { bool may_extract = true; if (op->is_intrinsic(Call::require)) { + // Call::require is special: it behaves a little like if-then-else: + // it runs the 3rd argument (the error handling part) only when there + // is an error. Extracting that would unconditionally print the error. may_extract &= i < 2; } + if (op->is_intrinsic(Call::if_then_else)) { + // Only allow the condition to be extracted. + may_extract &= i == 0; + } const Expr &arg = op->args[i]; if (may_extract) { internal_assert(arg.type().lanes() == op->type.lanes()); @@ -412,7 +398,7 @@ class LiftExceedingVectors : public IRMutator { public: Stmt mutate(const Stmt &s) override { ScopedValue scoped_lets(lets, {}); - ScopedValue scoped_just_in_let(just_in_let_definition, false); + just_in_let_definition = false; Stmt mutated = IRMutator::mutate(s); for (auto &let : reverse_view(lets)) { // There is no recurse into let.second. This is handled by repeatedly calling this tranform. 
@@ -422,7 +408,7 @@ class LiftExceedingVectors : public IRMutator { } Expr mutate(const Expr &e) override { - bool exceeds_lanecount = max_lanes && e.type().lanes() > max_lanes; + bool exceeds_lanecount = e.type().lanes() > max_lanes; if (exceeds_lanecount) { bool should_extract = false; @@ -446,23 +432,23 @@ class LiftExceedingVectors : public IRMutator { } } - ScopedValue scoped_just_in_let(just_in_let_definition, false); + just_in_let_definition = false; return IRMutator::mutate(e); } + + LiftExceedingVectors(int max_lanes) + : max_lanes(max_lanes) { + internal_assert(max_lanes != 0) << "LiftExceedingVectors should not be called when there is no lane limit."; + } }; class LegalizeVectors : public IRMutator { using IRMutator::visit; - int max_lanes{max_lanes_for_device(DeviceAPI::Host, 0)}; - - Stmt visit(const For *op) override { - ScopedValue scoped_max_lanes(max_lanes, max_lanes_for_device(op->device_api, max_lanes)); - return IRMutator::visit(op); - } + int max_lanes; Stmt visit(const LetStmt *op) override { - bool exceeds_lanecount = max_lanes && op->value.type().lanes() > max_lanes; + bool exceeds_lanecount = op->value.type().lanes() > max_lanes; if (exceeds_lanecount) { int num_vecs = (op->value.type().lanes() + max_lanes - 1) / max_lanes; @@ -474,7 +460,7 @@ class LegalizeVectors : public IRMutator { int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); std::string name = vec_name(op->name, lane_start, lane_count_for_vec); - Expr value = mutate(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->value)); + Expr value = mutate(ExtractLanes(lane_start, lane_count_for_vec).mutate(op->value)); debug(3) << " Add: let " << name << " = " << value << "\n"; body = LetStmt::make(name, value, body); @@ -486,14 +472,12 @@ class LegalizeVectors : public IRMutator { } Expr visit(const Let *op) override { - if (max_lanes != 0) { - internal_error << "Lets should have been lifted into LetStmts."; - } + internal_error << "Lets 
should have been lifted into LetStmts."; return IRMutator::visit(op); } Stmt visit(const Store *op) override { - bool exceeds_lanecount = max_lanes && op->index.type().lanes() > max_lanes; + bool exceeds_lanecount = op->index.type().lanes() > max_lanes; if (exceeds_lanecount) { // Split up in multiple stores int num_vecs = (op->index.type().lanes() + max_lanes - 1) / max_lanes; @@ -502,9 +486,9 @@ class LegalizeVectors : public IRMutator { for (int i = 0; i < num_vecs; ++i) { int lane_start = i * max_lanes; int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); - Expr rhs = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->value); - Expr index = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->index); - Expr predictate = ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(op->predicate); + Expr rhs = ExtractLanes(lane_start, lane_count_for_vec).mutate(op->value); + Expr index = ExtractLanes(lane_start, lane_count_for_vec).mutate(op->index); + Expr predictate = ExtractLanes(lane_start, lane_count_for_vec).mutate(op->predicate); assignments.push_back(Store::make( op->name, std::move(rhs), std::move(index), op->param, std::move(predictate), op->alignment + lane_start)); @@ -517,9 +501,6 @@ class LegalizeVectors : public IRMutator { } Expr visit(const Shuffle *op) override { - if (max_lanes == 0) { - return IRMutator::visit(op); - } internal_assert(op->type.lanes() <= max_lanes) << Expr(op); bool requires_mutation = false; for (size_t i = 0; i < op->vectors.size(); ++i) { @@ -545,7 +526,7 @@ class LegalizeVectors : public IRMutator { for (int i = 0; i < num_vecs; i++) { int lane_start = i * max_lanes; int lane_count_for_vec = std::min(vec.type().lanes() - lane_start, max_lanes); - new_vectors.push_back(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(vec)); + new_vectors.push_back(ExtractLanes(lane_start, lane_count_for_vec).mutate(vec)); } } else { 
new_vectors.push_back(IRMutator::mutate(vec)); @@ -560,9 +541,6 @@ class LegalizeVectors : public IRMutator { } Expr visit(const VectorReduce *op) override { - if (max_lanes == 0) { - return IRMutator::visit(op); - } const Expr &arg = op->value; if (arg.type().lanes() > max_lanes) { // TODO: The transformation below is not allowed under strict_float, but @@ -577,7 +555,7 @@ class LegalizeVectors : public IRMutator { for (int i = 0; i < num_vecs; i++) { int lane_start = i * max_lanes; int lane_count_for_vec = std::min(arg.type().lanes() - lane_start, max_lanes); - Expr partial_arg = mutate(ExtractLanes(lane_start, lane_count_for_vec, max_lanes).mutate(arg)); + Expr partial_arg = mutate(ExtractLanes(lane_start, lane_count_for_vec).mutate(arg)); Expr partial_red = VectorReduce::make(op->op, std::move(partial_arg), op->type.lanes()); if (i == 0) { result = partial_red; @@ -612,17 +590,25 @@ class LegalizeVectors : public IRMutator { return IRMutator::visit(op); } } + +public: + LegalizeVectors(int max_lanes) + : max_lanes(max_lanes) { + internal_assert(max_lanes != 0) << "LegalizeVectors should not be called when there is no lane limit."; + } }; } // namespace -Stmt legalize_vectors(const Stmt &s) { +Stmt legalize_vectors_in_device_loop(const For *op) { + int max_lanes = max_lanes_for_device(op->device_api, 0); + // Similar to CSE, lifting out stuff into variables. // Pass 1): lift out Shuffles that exceed lane count into variables // Pass 2): Rewrite those vector variables as bundles of vector variables, while legalizing all other stuff. 
- Stmt m0 = simplify(s); + Stmt m0 = simplify(op->body); Stmt m1 = common_subexpression_elimination(m0, false); - if (!m1.same_as(s)) { + if (!m1.same_as(op->body)) { debug(3) << "After CSE:\n" << m1 << "\n"; } @@ -634,7 +620,7 @@ Stmt legalize_vectors(const Stmt &s) { Stmt m3 = m2; while (true) { - Stmt m = LiftExceedingVectors().mutate(m3); + Stmt m = LiftExceedingVectors(max_lanes).mutate(m3); bool modified = !m3.same_as(m); m3 = std::move(m); if (!modified) { @@ -646,16 +632,32 @@ Stmt legalize_vectors(const Stmt &s) { } } - Stmt m4 = LegalizeVectors().mutate(m3); + Stmt m4 = LegalizeVectors(max_lanes).mutate(m3); if (!m4.same_as(m3)) { debug(3) << "After legalizing vectors:\n" << m4 << "\n"; } if (m4.same_as(m2)) { debug(3) << "Vector Legalization did do nothing, returning input.\n"; - return s; + return op; } - return simplify(m4); + m4 = simplify(m4); + return For::make(op->name, op->min, op->extent, op->for_type, + op->partition_policy, op->device_api, m4); +} + +Stmt legalize_vectors(const Stmt &s) { + class LegalizeDeviceLoops : public IRMutator { + using IRMutator::visit; + Stmt visit(const For *op) override { + if (max_lanes_for_device(op->device_api, 0)) { + return legalize_vectors_in_device_loop(op); + } else { + return IRMutator::visit(op); + } + } + } mutator; + return mutator.mutate(s); } } // namespace Internal } // namespace Halide From 963f510aed94e6448d898505bf77b003c171641a Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 14 Jun 2025 09:32:56 +0200 Subject: [PATCH 11/47] Move required simplifier logic for the vector legalization to the actual Simplifier. Adjust Hexagon simd op tests, as the Simplifier now does optimize away some shuffles. Bugfix in Hexagon shuffle_vector() logic. 
Co-authored-by: Andrew Adams --- src/CodeGen_Hexagon.cpp | 18 +-- src/CodeGen_LLVM.cpp | 3 +- src/LegalizeVectors.cpp | 76 ++--------- src/Simplify_Shuffle.cpp | 155 +++++++++++++++++++---- test/correctness/simd_op_check.h | 17 ++- test/correctness/simd_op_check_hvx.cpp | 24 ++-- test/correctness/stage_strided_loads.cpp | 12 +- 7 files changed, 184 insertions(+), 121 deletions(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 4bb61a4103ab..28a2f5211add 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1186,15 +1186,16 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, create_bitcast(a_call->getArgOperand(1), native_ty), create_bitcast(a_call->getArgOperand(0), native_ty), indices); } else if (ShuffleVectorInst *a_shuffle = dyn_cast(a)) { - bool is_identity = true; - for (int i = 0; i < a_elements; i++) { - int mask_i = a_shuffle->getMaskValue(i); - is_identity = is_identity && (mask_i == i || mask_i == -1); - } - if (is_identity) { - return shuffle_vectors(a_shuffle->getOperand(0), - a_shuffle->getOperand(1), indices); + std::vector new_indices(indices.size()); + for (size_t i = 0; i < indices.size(); i++) { + if (indices[i] != -1) { + new_indices[i] = a_shuffle->getMaskValue(indices[i]); + } else { + new_indices[i] = -1; + } } + return shuffle_vectors(a_shuffle->getOperand(0), + a_shuffle->getOperand(1), new_indices); } } @@ -1556,6 +1557,7 @@ Value *CodeGen_Hexagon::vdelta(Value *lut, const vector &indices) { Value *ret = nullptr; for (int i = 0; i < lut_elements; i += native_elements) { Value *lut_i = slice_vector(lut, i, native_elements); + internal_assert(get_vector_num_elements(lut_i->getType()) == native_elements); vector indices_i(native_elements); vector mask(native_elements); bool all_used = true; diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 31282ef1ee75..d25b27226bd1 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -5030,10 +5030,11 @@ Value 
*CodeGen_LLVM::shuffle_vectors(Value *a, Value *b, } // Check for type identity *after* normalizing to fixed vectors internal_assert(a->getType() == b->getType()); + int elements_a = get_vector_num_elements(a->getType()); vector llvm_indices(indices.size()); for (size_t i = 0; i < llvm_indices.size(); i++) { if (indices[i] >= 0) { - internal_assert(indices[i] < get_vector_num_elements(a->getType()) * 2); + internal_assert(indices[i] < elements_a * 2) << indices[i] << " " << elements_a * 2; llvm_indices[i] = ConstantInt::get(i32_t, indices[i]); } else { // Only let -1 be undef. diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 986582e03c0c..4c26acbb5806 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -50,68 +50,6 @@ std::string vec_name(const string &name, int lane_start, int lane_count) { return name + ".lanes_" + std::to_string(lane_start) + "_" + std::to_string(lane_start + lane_count - 1); } -Expr simplify_shuffle(const Shuffle *op) { - if (op->is_extract_element()) { - if (op->vectors.size() == 1) { - if (op->vectors[0].type().is_scalar()) { - return op->vectors[0]; - } else { - return Expr(op); - } - } else { - // Figure out which element is comes from. - int index = op->indices[0]; - internal_assert(index >= 0); - for (const Expr &vector : op->vectors) { - if (index < vector.type().lanes()) { - if (vector.type().is_scalar()) { - return vector; - } else { - return Shuffle::make_extract_element(vector, index); - } - } - index -= vector.type().lanes(); - } - internal_error << "Index out of bounds."; - } - } - - // Figure out if all extracted lanes come from 1 component. 
- vector> src_vec_and_lane_idx = op->vector_and_lane_indices(); - bool all_from_the_same = true; - bool is_full_vec = src_vec_and_lane_idx[0].second == 0; - for (int i = 1; i < int(op->indices.size()); ++i) { - if (src_vec_and_lane_idx[i].first != src_vec_and_lane_idx[0].first) { - all_from_the_same = false; - is_full_vec = false; - break; - } - if (src_vec_and_lane_idx[i].second != i) { - is_full_vec = false; - } - } - if (all_from_the_same) { - const Expr &src_vec = op->vectors[src_vec_and_lane_idx[0].first]; - is_full_vec &= src_vec.type().lanes() == int(op->indices.size()); - int first_lane_in_src = src_vec_and_lane_idx[0].second; - if (is_full_vec) { - return src_vec; - } else { - const Ramp *ramp = src_vec.as(); - if (ramp && op->is_slice() && op->slice_stride() == 1) { - return simplify(Ramp::make(ramp->base + first_lane_in_src * ramp->stride, ramp->stride, op->indices.size())); - } - vector new_indices; - for (int i = 0; i < int(op->indices.size()); ++i) { - new_indices.push_back(src_vec_and_lane_idx[i].second); - } - return Shuffle::make({src_vec}, new_indices); - } - } - - return op; -} - class LiftLetToLetStmt : public IRMutator { using IRMutator::visit; @@ -196,8 +134,7 @@ class ExtractLanes : public IRMutator { for (int i = 0; i < lane_count; ++i) { new_indices.push_back(op->indices[lane_start + i]); } - Expr result = Shuffle::make(op->vectors, new_indices); - return simplify_shuffle(result.as()); + return simplify(Shuffle::make(op->vectors, new_indices)); } Expr visit(const Ramp *op) override { @@ -532,8 +469,7 @@ class LegalizeVectors : public IRMutator { new_vectors.push_back(IRMutator::mutate(vec)); } } - Expr result = Shuffle::make(new_vectors, op->indices); - result = simplify_shuffle(result.as()); + Expr result = simplify(Shuffle::make(new_vectors, op->indices)); debug(3) << "Legalized " << Expr(op) << " => " << result << "\n"; return result; } @@ -641,9 +577,13 @@ Stmt legalize_vectors_in_device_loop(const For *op) { debug(3) << "Vector 
Legalization did do nothing, returning input.\n"; return op; } - m4 = simplify(m4); + Stmt m5 = simplify(m4); + if (!m4.same_as(m5)) { + debug(3) << "After simplify:\n" + << m5 << "\n"; + } return For::make(op->name, op->min, op->extent, op->for_type, - op->partition_policy, op->device_api, m4); + op->partition_policy, op->device_api, m5); } Stmt legalize_vectors(const Stmt &s) { diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 9c38f0faf622..60d15c0afcdc 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -5,6 +5,7 @@ namespace Halide { namespace Internal { +using std::pair; using std::vector; Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { @@ -25,9 +26,11 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } } - // Mutate the vectors vector new_vectors; + vector new_indices = op->indices; bool changed = false; + + // Mutate the vectors for (const Expr &vector : op->vectors) { ExprInfo v_info; Expr new_vector = mutate(vector, &v_info); @@ -46,52 +49,150 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } // A concat of one vector, is just the vector. + // (Early check, this is repeated below, once the argument list is potentially reduced) if (op->vectors.size() == 1 && op->is_concat()) { return new_vectors[0]; } - // Try to convert a load with shuffled indices into a - // shuffle of a dense load. + Expr result = op; + + // Analyze which input vectors are actually used. We will rewrite + // the vector of inputs and the indices jointly, and continue with + // those below. + { + vector arg_used(new_vectors.size()); + // Figure out if all extracted lanes come from 1 component. 
+ vector> src_vec_and_lane_idx = op->vector_and_lane_indices(); + for (int i = 0; i < int(op->indices.size()); ++i) { + arg_used[src_vec_and_lane_idx[i].first] = true; + } + size_t num_args_used = 0; + for (size_t i = 0; i < arg_used.size(); ++i) { + if (arg_used[i]) { + num_args_used++; + } + } + + if (num_args_used < op->vectors.size()) { + // Not all arguments to the shuffle are used by the indices. + // Let's throw them out. + for (int vi = arg_used.size() - 1; vi >= 0; --vi) { + if (!arg_used[vi]) { + int lanes_deleted = op->vectors[vi].type().lanes(); + int vector_start_lane = 0; + for (int i = 0; i < vi; ++i) { + vector_start_lane += op->vectors[i].type().lanes(); + } + for (size_t i = 0; i < new_indices.size(); ++i) { + if (new_indices[i] > vector_start_lane) { + internal_assert(new_indices[i] >= vector_start_lane + lanes_deleted); + new_indices[i] -= lanes_deleted; + } + } + new_vectors.erase(new_vectors.begin() + vi); + } + } + + changed = true; + } + } + + // Replace the op with the intermediate simplified result (if it changed), and continue. + if (changed) { + result = Shuffle::make(new_vectors, new_indices); + op = result.as(); + changed = false; + } + + if (new_vectors.size() == 1) { + const Ramp *ramp = new_vectors[0].as(); + if (ramp && op->is_slice()) { + int first_lane_in_src = op->indices[0]; + int slice_stride = op->slice_stride(); + if (slice_stride >= 1) { + return mutate(Ramp::make(ramp->base + first_lane_in_src * ramp->stride, + ramp->stride * slice_stride, + op->indices.size()), + nullptr); + } + } + + // Test this again, but now after new_vectors got potentially shorter. + if (op->is_concat()) { + return new_vectors[0]; + } + } + + // Try to convert a Shuffle of Loads into a single Load of a Ramp. + // Make sure to not undo the work of the StageStridedLoads pass: + // only if the result of the shuffled indices is a *dense* ramp, we + // can proceed. 
There are two side cases: concatenations of scalars, + // and when the loads weren't dense to begin with. if (const Load *first_load = new_vectors[0].as()) { vector load_predicates; vector load_indices; + bool all_loads_are_dense = true; bool unpredicated = true; + bool concat_of_scalars = true; for (const Expr &e : new_vectors) { const Load *load = e.as(); if (load && load->name == first_load->name) { load_predicates.push_back(load->predicate); load_indices.push_back(load->index); unpredicated = unpredicated && is_const_one(load->predicate); + if (const Ramp *index_ramp = load->index.as()) { + if (!is_const_one(index_ramp->stride)) { + all_loads_are_dense = false; + } + } else if (!load->index.type().is_scalar()) { + all_loads_are_dense = false; + } + if (!load->index.type().is_scalar()) { + concat_of_scalars = false; + } } else { break; } } + debug(3) << "Shuffle of Load found: " << result << " where" + << " all_loads_are_dense=" << all_loads_are_dense << "," + << " concat_of_scalars=" << concat_of_scalars << "\n"; + if (load_indices.size() == new_vectors.size()) { + // All of the Shuffle arguments are Loads. 
Type t = load_indices[0].type().with_lanes(op->indices.size()); Expr shuffled_index = Shuffle::make(load_indices, op->indices); + debug(3) << " Shuffled index: " << shuffled_index << "\n"; ExprInfo shuffled_index_info; shuffled_index = mutate(shuffled_index, &shuffled_index_info); - if (shuffled_index.as()) { - ExprInfo base_info; - if (const Ramp *r = shuffled_index.as()) { - mutate(r->base, &base_info); - } + debug(3) << " Simplified shuffled index: " << shuffled_index << "\n"; + if (const Ramp *index_ramp = shuffled_index.as()) { + if (is_const_one(index_ramp->stride) || !all_loads_are_dense || concat_of_scalars) { + ExprInfo base_info; + mutate(index_ramp->base, &base_info); - ModulusRemainder alignment = - ModulusRemainder::intersect(base_info.alignment, shuffled_index_info.alignment); + ModulusRemainder alignment = + ModulusRemainder::intersect(base_info.alignment, shuffled_index_info.alignment); - Expr shuffled_predicate; - if (unpredicated) { - shuffled_predicate = const_true(t.lanes(), nullptr); - } else { - shuffled_predicate = Shuffle::make(load_predicates, op->indices); - shuffled_predicate = mutate(shuffled_predicate, nullptr); + Expr shuffled_predicate; + if (unpredicated) { + shuffled_predicate = const_true(t.lanes(), nullptr); + } else { + shuffled_predicate = Shuffle::make(load_predicates, op->indices); + shuffled_predicate = mutate(shuffled_predicate, nullptr); + } + t = first_load->type; + t = t.with_lanes(op->indices.size()); + Expr result = Load::make(t, first_load->name, shuffled_index, first_load->image, + first_load->param, shuffled_predicate, alignment); + debug(3) << " => " << result << "\n"; + return result; } - t = first_load->type; - t = t.with_lanes(op->indices.size()); - return Load::make(t, first_load->name, shuffled_index, first_load->image, - first_load->param, shuffled_predicate, alignment); + } else { + // We can't... Leave it as a Shuffle of Loads. + // Note: don't proceed down. 
+ return result; } } } @@ -261,6 +362,14 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } } + for (size_t i = 0; i < new_vectors.size() && can_collapse; i++) { + if (new_vectors[i].as()) { + // Don't create a Ramp of a Load, like: + // ramp(buf[x], buf[x + 1] - buf[x], ...) + can_collapse = false; + } + } + if (can_collapse) { return Ramp::make(new_vectors[0], stride, op->indices.size()); } @@ -324,11 +433,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } } - if (!changed) { - return op; - } else { - return Shuffle::make(new_vectors, op->indices); - } + return result; } } // namespace Internal diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index 16bd60d836a8..d59872924420 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -507,20 +507,27 @@ class SimdOpCheckTest { })); } + std::vector failed_tests; + constexpr int tabstop = 32; for (auto &f : futures) { auto result = f.get(); - constexpr int tabstop = 32; const int spaces = std::max(1, tabstop - (int)result.op.size()); std::cout << result.op << std::string(spaces, ' ') << "(" << run_target_str << ")\n"; if (!result.error_msg.empty()) { std::cerr << result.error_msg; - // The thread-pool destructor will block until in-progress tasks - // are done, and then will discard any tasks that haven't been - // launched yet. 
- return false; + failed_tests.push_back(std::move(result)); } } + if (!failed_tests.empty()) { + std::cerr << "SIMD op check summary: " << failed_tests.size() << " tests failed:\n"; + for (auto &result : failed_tests) { + const int spaces = std::max(1, tabstop - (int)result.op.size()); + std::cerr << " " << result.op << std::string(spaces, ' ') << "(" << run_target_str << ")\n"; + } + return false; + } + return true; } diff --git a/test/correctness/simd_op_check_hvx.cpp b/test/correctness/simd_op_check_hvx.cpp index 5da8e85d8b23..241152df2342 100644 --- a/test/correctness/simd_op_check_hvx.cpp +++ b/test/correctness/simd_op_check_hvx.cpp @@ -54,16 +54,24 @@ class SimdOpCheckHVX : public SimdOpCheckTest { isa_version = 62; } + auto valign_test_u8 = [&](int off) { + return in_u8(x + off) + in_u8(x + off + 1); + }; + + auto valign_test_u16 = [&](int off) { + return in_u16(x + off) + in_u16(x + off + 1); + }; + // Verify that unaligned loads use the right instructions, and don't try to use // immediates of more than 3 bits. 
- check("valign(v*,v*,#7)", hvx_width / 1, in_u8(x + 7)); - check("vlalign(v*,v*,#7)", hvx_width / 1, in_u8(x + hvx_width - 7)); - check("valign(v*,v*,r*)", hvx_width / 1, in_u8(x + 8)); - check("valign(v*,v*,r*)", hvx_width / 1, in_u8(x + hvx_width - 8)); - check("valign(v*,v*,#6)", hvx_width / 1, in_u16(x + 3)); - check("vlalign(v*,v*,#6)", hvx_width / 1, in_u16(x + hvx_width - 3)); - check("valign(v*,v*,r*)", hvx_width / 1, in_u16(x + 4)); - check("valign(v*,v*,r*)", hvx_width / 1, in_u16(x + hvx_width - 4)); + check("valign(v*,v*,#7)", hvx_width / 1, valign_test_u8(6)); + check("vlalign(v*,v*,#7)", hvx_width / 1, valign_test_u8(hvx_width - 7)); + check("valign(v*,v*,r*)", hvx_width / 1, valign_test_u8(8)); + check("valign(v*,v*,r*)", hvx_width / 1, valign_test_u8(hvx_width - 8)); + check("valign(v*,v*,#6)", hvx_width / 1, valign_test_u16(3)); + check("vlalign(v*,v*,#6)", hvx_width / 1, valign_test_u16(hvx_width - 3)); + check("valign(v*,v*,r*)", hvx_width / 1, valign_test_u16(4)); + check("valign(v*,v*,r*)", hvx_width / 1, valign_test_u16(hvx_width - 4)); check("vunpack(v*.ub)", hvx_width / 1, u16(u8_1)); check("vunpack(v*.ub)", hvx_width / 1, i16(u8_1)); diff --git a/test/correctness/stage_strided_loads.cpp b/test/correctness/stage_strided_loads.cpp index f791385f7c25..dab19a370d93 100644 --- a/test/correctness/stage_strided_loads.cpp +++ b/test/correctness/stage_strided_loads.cpp @@ -10,7 +10,7 @@ class CheckForStridedLoads : public IRMutator { if (const Ramp *r = op->index.as()) { if (op->name == buf_name) { bool dense = is_const_one(r->stride); - found |= !dense; + found_strided_load |= !dense; dense_loads += dense; } } @@ -18,27 +18,27 @@ class CheckForStridedLoads : public IRMutator { } public: - bool found = false; + bool found_strided_load = false; int dense_loads = 0; std::string buf_name; void check(Func f, int desired_dense_loads, std::string name = "buf") { - found = false; + found_strided_load = false; dense_loads = 0; buf_name = name; 
f.add_custom_lowering_pass(this, nullptr); f.compile_jit(); - assert(!found); + assert(!found_strided_load); assert(dense_loads == desired_dense_loads); } void check_not(Func f, int desired_dense_loads, std::string name = "buf") { - found = false; + found_strided_load = false; dense_loads = 0; buf_name = name; f.add_custom_lowering_pass(this, nullptr); f.compile_jit(); - assert(found); + assert(found_strided_load); assert(dense_loads == desired_dense_loads); } } checker; From f381af02dc9af1daae3a84f09c5d3ee47a5bd054 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 14 Jun 2025 14:08:28 +0200 Subject: [PATCH 12/47] Remove special handling of strict_float, as those got overhauled. --- src/LegalizeVectors.cpp | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 4c26acbb5806..f8c5a4891948 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -264,7 +264,6 @@ class LiftExceedingVectors : public IRMutator { vector> lets; bool just_in_let_definition{false}; - int in_strict_float = 0; Expr visit(const Let *op) override { internal_error << "We don't want to process Lets. They should have all been converted to LetStmts."; @@ -287,9 +286,6 @@ class LiftExceedingVectors : public IRMutator { // Custom handling of Call, to prevent certain things from being extracted out // of the call arguments, as that's not always allowed. 
bool exceeds_lanecount = op->type.lanes() > max_lanes; - if (op->is_intrinsic(Call::strict_float)) { - in_strict_float++; - } Expr mutated = op; if (exceeds_lanecount) { std::vector args; @@ -326,9 +322,6 @@ class LiftExceedingVectors : public IRMutator { } else { mutated = IRMutator::visit(op); } - if (op->is_intrinsic(Call::strict_float)) { - in_strict_float--; - } return mutated; } @@ -352,11 +345,6 @@ class LiftExceedingVectors : public IRMutator { should_extract |= e.node_type() == IRNodeType::Shuffle; should_extract |= e.node_type() == IRNodeType::VectorReduce; - // TODO: Handling of strict_float is not well done. - // But at least it covers a few basic scenarios. - // This should be redone once we overhaul strict_float. - should_extract &= !in_strict_float; - should_extract &= !just_in_let_definition; debug((should_extract ? 3 : 4)) << "Max lanes (" << max_lanes << ") exceeded (" << e.type().lanes() << ") by: " << e << "\n"; @@ -480,7 +468,7 @@ class LegalizeVectors : public IRMutator { const Expr &arg = op->value; if (arg.type().lanes() > max_lanes) { // TODO: The transformation below is not allowed under strict_float, but - // I won't bother right now, as strict_float is due for an overhaul. + // I don't immediately know what to do here. // This should be an internal_assert. internal_assert(op->type.lanes() == 1) From 9e1329a839eb75cf959e32028f21d5a8f7f47d48 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 29 Aug 2025 19:48:32 +0200 Subject: [PATCH 13/47] Hexagon codegen for vdelta fix regarding dont-care values in shuffle indices. 
--- src/CodeGen_Hexagon.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 28a2f5211add..c46375c7c53e 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1517,7 +1517,11 @@ Value *CodeGen_Hexagon::vdelta(Value *lut, const vector &indices) { vector i8_indices(indices.size() * replicate); for (size_t i = 0; i < indices.size(); i++) { for (int j = 0; j < replicate; j++) { - i8_indices[i * replicate + j] = indices[i] * replicate + j; + if (indices[i] == -1) { + i8_indices[i * replicate + j] = -1; // Replicate the don't-care. + } else { + i8_indices[i * replicate + j] = indices[i] * replicate + j; + } } } Value *result = vdelta(i8_lut, i8_indices); From 43ed906afa038e4830a629969cb88aa8d71eb4b6 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 29 Aug 2025 19:52:27 +0200 Subject: [PATCH 14/47] Clang-format --- src/CodeGen_Hexagon.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index c46375c7c53e..92aeaf5018db 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1518,7 +1518,7 @@ Value *CodeGen_Hexagon::vdelta(Value *lut, const vector &indices) { for (size_t i = 0; i < indices.size(); i++) { for (int j = 0; j < replicate; j++) { if (indices[i] == -1) { - i8_indices[i * replicate + j] = -1; // Replicate the don't-care. + i8_indices[i * replicate + j] = -1; // Replicate the don't-care. 
} else { i8_indices[i * replicate + j] = indices[i] * replicate + j; } From 4cb5c2c1b2f0ca2e6913cce2a6045d9fbe1a6599 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 12 Oct 2025 17:20:17 +0200 Subject: [PATCH 15/47] Satisfy clang-tidy --- src/LegalizeVectors.cpp | 5 +++-- src/Simplify_Shuffle.cpp | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index f8c5a4891948..da80ed515077 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -131,6 +131,7 @@ class ExtractLanes : public IRMutator { Expr visit(const Shuffle *op) override { vector new_indices; + new_indices.reserve(lane_count); for (int i = 0; i < lane_count; ++i) { new_indices.push_back(op->indices[lane_start + i]); } @@ -428,8 +429,8 @@ class LegalizeVectors : public IRMutator { Expr visit(const Shuffle *op) override { internal_assert(op->type.lanes() <= max_lanes) << Expr(op); bool requires_mutation = false; - for (size_t i = 0; i < op->vectors.size(); ++i) { - if (op->vectors[i].type().lanes() > max_lanes) { + for (const auto &vec : op->vectors) { + if (vec.type().lanes() > max_lanes) { requires_mutation = true; break; } diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 60d15c0afcdc..5c84cea8d195 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -67,8 +67,8 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { arg_used[src_vec_and_lane_idx[i].first] = true; } size_t num_args_used = 0; - for (size_t i = 0; i < arg_used.size(); ++i) { - if (arg_used[i]) { + for (bool used : arg_used) { + if (used) { num_args_used++; } } @@ -83,10 +83,10 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { for (int i = 0; i < vi; ++i) { vector_start_lane += op->vectors[i].type().lanes(); } - for (size_t i = 0; i < new_indices.size(); ++i) { - if (new_indices[i] > vector_start_lane) { - internal_assert(new_indices[i] >= vector_start_lane + lanes_deleted); - 
new_indices[i] -= lanes_deleted; + for (int &new_index : new_indices) { + if (new_index > vector_start_lane) { + internal_assert(new_index >= vector_start_lane + lanes_deleted); + new_index -= lanes_deleted; } } new_vectors.erase(new_vectors.begin() + vi); From 3034e9244adc291ae947c69deac99965a4a00b55 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 13 Dec 2025 16:45:10 +0100 Subject: [PATCH 16/47] Revive. --- src/LegalizeVectors.cpp | 2 +- test/error/CMakeLists.txt | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index da80ed515077..07be6d438354 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -571,7 +571,7 @@ Stmt legalize_vectors_in_device_loop(const For *op) { debug(3) << "After simplify:\n" << m5 << "\n"; } - return For::make(op->name, op->min, op->extent, op->for_type, + return For::make(op->name, op->min, op->max, op->for_type, op->partition_policy, op->device_api, m5); } diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 41816d5ba36b..b7d6a380c504 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -79,7 +79,6 @@ tests(GROUPS error memoize_output_invalid.cpp memoize_redefine_eviction_key.cpp metal_threads_too_large.cpp - metal_vector_too_large.cpp mismatch_runtime_vscale.cpp missing_args.cpp no_default_device.cpp From 70debb12fcf8006ed1552bd937e100d00d3ce875 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 28 Jan 2026 15:39:03 +0100 Subject: [PATCH 17/47] Restore case-insensitive sorting order. --- src/CMakeLists.txt | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fec57b9da9d4..0e0a7c626bd5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -37,7 +37,8 @@ endif () set_target_properties(Halide PROPERTIES POSITION_INDEPENDENT_CODE ON) ## -# Lists of source files. 
Keep ALL lists sorted in alphabetical order. +# Lists of source files. Keep ALL lists sorted in case-insensitive alphabetical order. +# (neo)vim users can use ":sort i" in visual line mode. ## # The externally-visible header files that go into making Halide.h. @@ -62,14 +63,12 @@ target_sources( Associativity.h AsyncProducers.h AutoScheduleUtils.h - BoundConstantExtentLoops.h - BoundSmallAllocations.h BoundaryConditions.h + BoundConstantExtentLoops.h Bounds.h BoundsInference.h + BoundSmallAllocations.h Buffer.h - CPlusPlusMangle.h - CSE.h Callable.h CanonicalizeGPUVars.h ClampUnsafeAccesses.h @@ -81,8 +80,8 @@ target_sources( CodeGen_LLVM.h CodeGen_Metal_Dev.h CodeGen_OpenCL_Dev.h - CodeGen_PTX_Dev.h CodeGen_Posix.h + CodeGen_PTX_Dev.h CodeGen_PyTorch.h CodeGen_Targets.h CodeGen_Vulkan_Dev.h @@ -91,6 +90,8 @@ target_sources( ConciseCasts.h ConstantBounds.h ConstantInterval.h + CPlusPlusMangle.h + CSE.h Debug.h DebugArguments.h DebugToFile.h @@ -127,13 +128,6 @@ target_sources( Generator.h HexagonOffload.h HexagonOptimize.h - IR.h - IREquality.h - IRMatch.h - IRMutator.h - IROperator.h - IRPrinter.h - IRVisitor.h ImageParam.h InferArguments.h InjectHostDevBufferCopies.h @@ -142,13 +136,20 @@ target_sources( IntegerDivisionTable.h Interval.h IntrusivePtr.h + IR.h + IREquality.h + IRMatch.h + IRMutator.h + IROperator.h + IRPrinter.h + IRVisitor.h JITModule.h - LICM.h - LLVM_Output.h - LLVM_Runtime_Linker.h Lambda.h LegalizeVectors.h Lerp.h + LICM.h + LLVM_Output.h + LLVM_Runtime_Linker.h LoopCarry.h LoopPartitioningDirective.h Lower.h @@ -174,8 +175,8 @@ target_sources( PurifyIndexMath.h PythonExtensionGen.h Qualify.h - RDom.h Random.h + RDom.h Realization.h RealizationOrder.h RebaseLoopsToZero.h @@ -242,9 +243,9 @@ target_sources( AsyncProducers.cpp AutoScheduleUtils.cpp BoundaryConditions.cpp + BoundConstantExtentLoops.cpp Bounds.cpp BoundsInference.cpp - BoundConstantExtentLoops.cpp BoundSmallAllocations.cpp Buffer.cpp Callable.cpp @@ -270,9 +271,9 @@ 
target_sources( CodeGen_WebGPU_Dev.cpp CodeGen_X86.cpp CompilerLogger.cpp - CPlusPlusMangle.cpp ConstantBounds.cpp ConstantInterval.cpp + CPlusPlusMangle.cpp CSE.cpp Debug.cpp DebugArguments.cpp @@ -366,7 +367,6 @@ target_sources( Simplify_Add.cpp Simplify_And.cpp Simplify_Call.cpp - Simplify_Reinterpret.cpp Simplify_Cast.cpp Simplify_Div.cpp Simplify_EQ.cpp @@ -379,6 +379,7 @@ target_sources( Simplify_Mul.cpp Simplify_Not.cpp Simplify_Or.cpp + Simplify_Reinterpret.cpp Simplify_Select.cpp Simplify_Shuffle.cpp Simplify_Stmts.cpp From f29344dd9b887f0afa77b5f890c9e16fb342cfd9 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 21 Feb 2026 16:44:17 +0100 Subject: [PATCH 18/47] Feedback from Andrew. --- src/LegalizeVectors.cpp | 38 ++++++++++++++++++++------------------ src/Simplify_Let.cpp | 5 +++++ src/Simplify_Shuffle.cpp | 6 +++++- 3 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 07be6d438354..c3919ef67da5 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -8,6 +8,8 @@ #include "Util.h" #include +#include +#include namespace Halide { namespace Internal { @@ -16,13 +18,19 @@ namespace { using namespace std; -const char *legalization_error_guide = "\n(This issue can most likely be resolved by reducing lane count for vectorize() calls in the schedule, or disabling it.)"; +const char *legalization_error_guide = "\n" + "(This is an implemenation limitation in Halide right now. This issue can most likely be \n" + " worked around by reducing lane count for vectorize() calls in GPU schedules, or disabling it.)"; int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { + // The environment variable below (HL_FORCE_VECTOR_LEGALIZATION) is here solely for testing purposes. + // It is useful to "stress-test" this lowering pass by forcing a shorter maximal vector size across + // all codegen across the entire test suite. This should not be used in real uses of Halide. 
std::string envvar = Halide::Internal::get_env_variable("HL_FORCE_VECTOR_LEGALIZATION"); if (!envvar.empty()) { return std::atoi(envvar.c_str()); } + // The remainder of this function correctly determines the number of lanes the device API supports. switch (api) { case DeviceAPI::Metal: case DeviceAPI::WebGPU: @@ -53,13 +61,13 @@ std::string vec_name(const string &name, int lane_start, int lane_count) { class LiftLetToLetStmt : public IRMutator { using IRMutator::visit; + unordered_set lifted_let_names; vector lets; Expr visit(const Let *op) override { - for (const Let *existing : lets) { - internal_assert(existing->name != op->name) - << "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique."; - } + internal_assert(lifted_let_names.count(op->name) == 0) + << "Let " << op->name << " = ... cannot be lifted to LetStmt because the name is not unique."; lets.push_back(op); + lifted_let_names.insert(op->name); return mutate(op->body); } @@ -124,8 +132,7 @@ class ExtractLanes : public IRMutator { return result; } - internal_error << "Unhandled trace call in LegalizeVectors' ExtractLanes: " << *event << legalization_error_guide << "\n" - << "Please report this error on GitHub." << legalization_error_guide; + internal_error << "Unhandled trace call in LegalizeVectors' ExtractLanes: " << *event << legalization_error_guide; return Expr(0); } @@ -332,7 +339,7 @@ class LiftExceedingVectors : public IRMutator { just_in_let_definition = false; Stmt mutated = IRMutator::mutate(s); for (auto &let : reverse_view(lets)) { - // There is no recurse into let.second. This is handled by repeatedly calling this tranform. + // There is no recurse into let.second. This is handled by repeatedly calling this transform. 
mutated = LetStmt::make(let.first, let.second, mutated); } return mutated; @@ -576,17 +583,12 @@ Stmt legalize_vectors_in_device_loop(const For *op) { } Stmt legalize_vectors(const Stmt &s) { - class LegalizeDeviceLoops : public IRMutator { - using IRMutator::visit; - Stmt visit(const For *op) override { - if (max_lanes_for_device(op->device_api, 0)) { - return legalize_vectors_in_device_loop(op); - } else { - return IRMutator::visit(op); - } + return mutate_with(s, [&](auto *self, const For *op) { + if (max_lanes_for_device(op->device_api, 0)) { + return legalize_vectors_in_device_loop(op); } - } mutator; - return mutator.mutate(s); + return self->visit_base(op); + }); } } // namespace Internal } // namespace Halide diff --git a/src/Simplify_Let.cpp b/src/Simplify_Let.cpp index 0d7b6677f8e6..9f18a6fb25c1 100644 --- a/src/Simplify_Let.cpp +++ b/src/Simplify_Let.cpp @@ -187,6 +187,11 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) { // pure operations from _all_ arguments to the Shuffle, we will // instead substitute all of the vars that go in the shuffle, and // instead guard against side effects by checking with `is_pure()`. + // + // Also, it is safe to substitute in without combinatorial + // blow-up, because deeply nested concats implies a + // combinatorially-large number of vector lanes, which we can't + // express in the type system anyway. replacement = substitute(f.new_name, shuffle, replacement); f.new_value = Expr(); break; diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 5c84cea8d195..91ff3e784c8d 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -191,7 +191,9 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } } else { // We can't... Leave it as a Shuffle of Loads. - // Note: don't proceed down. + // Note: no mutate-recursion as we are dealing here with a + // Shuffle of Loads, which have already undergone mutation + // early in this function (new_vectors). 
return result; } } @@ -362,6 +364,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } } +#if 0 // Not sure what this was for. Disabling for now, and will run tests to see what's up. for (size_t i = 0; i < new_vectors.size() && can_collapse; i++) { if (new_vectors[i].as()) { // Don't create a Ramp of a Load, like: @@ -369,6 +372,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { can_collapse = false; } } +#endif if (can_collapse) { return Ramp::make(new_vectors[0], stride, op->indices.size()); From ef2274ab8af5d7885fa91b0dfab4fc702bfc30db Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 1 Mar 2026 20:29:53 +0100 Subject: [PATCH 19/47] Unify my own ExtractLanes and the existing Deinterleaver. At the same time support all the cases that were previously unsupported (AFAICT). Co-authored-by: Andrew Adams Co-authored-by: Google Gemini 3 Pro --- src/Deinterleave.cpp | 491 ++++++++++++++---- src/Deinterleave.h | 14 +- src/LegalizeVectors.cpp | 397 ++++++-------- src/Simplify_Shuffle.cpp | 10 - .../performance/nested_vectorization_gemm.cpp | 1 - 5 files changed, 542 insertions(+), 371 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 243760e9d050..0337cabae241 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -17,6 +17,33 @@ namespace Internal { using std::pair; +std::string variable_name_with_extracted_lanes( + const std::string &varname, int varlanes, + int starting_lane, int lane_stride, int new_lanes) { + + if (lane_stride * new_lanes == varlanes) { + if (starting_lane == 0 && lane_stride == 2) { + return varname + ".even_lanes"; + } else if (starting_lane == 1 && lane_stride == 2) { + return varname + ".odd_lanes"; + } + } + if (lane_stride == 1) { + return varname + ".lanes_" + std::to_string(starting_lane) + + "_to_" + std::to_string(starting_lane + new_lanes - 1); + } else { + // Just specify the slice + std::string name = varname; + name += ".slice_"; + name += 
std::to_string(starting_lane); + name += "_"; + name += std::to_string(lane_stride); + name += "_"; + name += std::to_string(new_lanes); + return name; + } +} + namespace { class StoreCollector : public IRMutator { @@ -176,13 +203,17 @@ Stmt collect_strided_stores(const Stmt &stmt, const std::string &name, int strid return collect.mutate(stmt); } -class Deinterleaver : public IRGraphMutator { +class ExtractLanes : public IRGraphMutator { public: - Deinterleaver(int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) + ExtractLanes( + int starting_lane, int lane_stride, int new_lanes, + const Scope<> &sliceable_lets, + Scope> &requested_slices) : starting_lane(starting_lane), lane_stride(lane_stride), new_lanes(new_lanes), - external_lets(lets) { + sliceable_lets(sliceable_lets), + requested_slices(requested_slices) { } private: @@ -191,23 +222,88 @@ class Deinterleaver : public IRGraphMutator { int new_lanes; // lets for which we have even and odd lane specializations - const Scope<> &external_lets; + const Scope<> &sliceable_lets; + Scope> &requested_slices; // We populate this with the slices we need from the external_lets. 
using IRMutator::visit; + Expr extract_lanes_from_make_struct(const Call *op) { + internal_assert(op); + internal_assert(op->is_intrinsic(Call::make_struct)); + std::vector args(op->args.size()); + for (int i = 0; i < int(op->args.size()); ++i) { + args[i] = mutate(op->args[i]); + } + return Call::make(op->type, Call::make_struct, args, Call::Intrinsic); + } + + Expr extract_lanes_trace(const Call *op) { + auto event = as_const_int(op->args[6]); + internal_assert(event); + if (*event == halide_trace_load || *event == halide_trace_store) { + debug(3) << "Extracting Trace Lanes: " << Expr(op) << "\n"; + const Expr &func = op->args[0]; + Expr values = extract_lanes_from_make_struct(op->args[1].as()); + Expr coords = extract_lanes_from_make_struct(op->args[2].as()); + const Expr &type_code = op->args[3]; + const Expr &type_bits = op->args[4]; + int type_lanes = *as_const_int(op->args[5]); + const Expr &event = op->args[6]; + const Expr &parent_id = op->args[7]; + const Expr &idx = op->args[8]; + int size = *as_const_int(op->args[9]); + const Expr &tag = op->args[10]; + + int num_vecs = op->args[2].as()->args.size(); + internal_assert(size == type_lanes * num_vecs) << Expr(op); + std::vector args = { + func, + values, coords, + type_code, type_bits, Expr(new_lanes), + event, parent_id, idx, Expr(new_lanes * num_vecs), + tag}; + Expr result = Call::make(Int(32), Call::trace, args, Call::Extern); + debug(4) << " => " << result << "\n"; + return result; + } + + internal_error << "Unhandled trace call in ExtractLanes: " << *event; + } + Expr visit(const VectorReduce *op) override { - std::vector input_lanes; int factor = op->value.type().lanes() / op->type.lanes(); - for (int i = starting_lane; i < op->type.lanes(); i += lane_stride) { - for (int j = 0; j < factor; j++) { - input_lanes.push_back(i * factor + j); + if (lane_stride != 1) { + std::vector input_lanes; + for (int i = 0; i < new_lanes; ++i) { + int lane_start = (starting_lane + lane_stride * i) * factor; + for 
(int j = 0; j < factor; j++) { + input_lanes.push_back(lane_start + j); + } + } + Expr in = Shuffle::make({op->value}, input_lanes); + return VectorReduce::make(op->op, in, new_lanes); + } else { + Expr in; + { + ScopedValue old_starting_lane(starting_lane, starting_lane * factor); + ScopedValue old_new_lanes(new_lanes, new_lanes * factor); + in = mutate(op->value); } + return VectorReduce::make(op->op, in, new_lanes); } - Expr in = Shuffle::make({op->value}, input_lanes); - return VectorReduce::make(op->op, in, new_lanes); } Expr visit(const Broadcast *op) override { + if (const Call *call = op->value.as()) { + if (call->name == Call::trace) { + Expr value = extract_lanes_trace(call); + if (new_lanes == 1) { + return value; + } else { + return Broadcast::make(value, new_lanes); + } + } + } if (new_lanes == 1) { if (op->value.type().lanes() == 1) { return op->value; @@ -299,30 +395,40 @@ class Deinterleaver : public IRGraphMutator { } else { Type t = op->type.with_lanes(new_lanes); + /* internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes) << "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly." 
<< " Deinterleaver probably recursed too deep into types of different lane count."; - if (external_lets.contains(op->name) && - starting_lane == 0 && - lane_stride == 2) { - return Variable::make(t, op->name + ".even_lanes", op->image, op->param, op->reduction_domain); - } else if (external_lets.contains(op->name) && - starting_lane == 1 && - lane_stride == 2) { - return Variable::make(t, op->name + ".odd_lanes", op->image, op->param, op->reduction_domain); - } else if (external_lets.contains(op->name) && - starting_lane == 0 && - lane_stride == 3) { - return Variable::make(t, op->name + ".lanes_0_of_3", op->image, op->param, op->reduction_domain); - } else if (external_lets.contains(op->name) && - starting_lane == 1 && - lane_stride == 3) { - return Variable::make(t, op->name + ".lanes_1_of_3", op->image, op->param, op->reduction_domain); - } else if (external_lets.contains(op->name) && - starting_lane == 2 && - lane_stride == 3) { - return Variable::make(t, op->name + ".lanes_2_of_3", op->image, op->param, op->reduction_domain); + */ + + if (sliceable_lets.contains(op->name)) { + // The variable accessed is marked as sliceable by the caller. + // Let's request a slice and pretend it exists. 
+ std::string sliced_var_name = variable_name_with_extracted_lanes( + op->name, op->type.lanes(), + starting_lane, lane_stride, new_lanes); + VectorSlice new_sl = {.start = starting_lane, + .stride = lane_stride, + .count = new_lanes, + .variable_name = sliced_var_name}; + if (auto *vec = requested_slices.shallow_find(op->name)) { + bool found = false; + for (const VectorSlice &existing_sl : *vec) { + if (existing_sl.start == starting_lane && + existing_sl.stride == lane_stride && + existing_sl.count == new_lanes) { + found = true; + break; + } + } + if (!found) { + vec->push_back(std::move(new_sl)); + } + } else { + requested_slices.push(op->name, {std::move(new_sl)}); + } + return Variable::make(t, sliced_var_name, op->image, op->param, op->reduction_domain); } else { return give_up_and_shuffle(op); } @@ -339,24 +445,117 @@ class Deinterleaver : public IRGraphMutator { } Expr visit(const Reinterpret *op) override { + // Written with assistance from Gemini 3 Pro, which required a lot of baby-sitting. + + // Simple case of a scalar reinterpret: always one lane: if (op->type.is_scalar()) { return op; - } else if (op->type.bits() != op->value.type().bits()) { - return give_up_and_shuffle(op); - } else { - Type t = op->type.with_lanes(new_lanes); - return Reinterpret::make(t, mutate(op->value)); } + + int out_bits = op->type.bits(); + int in_bits = op->value.type().bits(); + + internal_assert(out_bits % in_bits == 0 || in_bits % out_bits == 0); + + // Case A: Stride 1. Calculate everything with bit-offsets + if (lane_stride == 1) { + + // Compute range of bits required from the input. + int start_bit = starting_lane * out_bits; + int total_bits = new_lanes * out_bits; + int end_bit = start_bit + total_bits; + + // Convert this to a range of lane indices + int start_input_lane = start_bit / in_bits; + int end_input_lane = (end_bit + in_bits - 1) / in_bits; + int num_input_lanes = end_input_lane - start_input_lane; + + // Actually now get those lanes from the input. 
+ Expr extracted_input_lanes; + { + ScopedValue old_sl(starting_lane, start_input_lane); + ScopedValue old_nl(new_lanes, num_input_lanes); + extracted_input_lanes = mutate(op->value); + } + + // The range of lanes we extracted from the input still might be too big, because + // we had to grab whole elements from the input, which can be coarser if out_bits > in_bits. + // So calculate how many lanes we extracted, when measured in the reinterpreted output type. + int intm_lanes = (num_input_lanes * in_bits) / out_bits; + Expr reinterprted = Reinterpret::make(op->type.with_lanes(intm_lanes), extracted_input_lanes); + + // Now calculate how many we output Type lanes we need to trim away. + int bits_to_strip_front = start_bit - (start_input_lane * in_bits); + int lanes_to_strip_front = bits_to_strip_front / out_bits; + + if (lanes_to_strip_front == 0) { + internal_assert(reinterprted.type().lanes() == new_lanes); + return reinterprted; + } else { + return Shuffle::make_slice(reinterprted, lanes_to_strip_front, 1, new_lanes); + } + } + + // Case B: Stride != 1. We are effectively gathering. + // We will rewrite those Reinterprets as a Concat of Reinterprets of extracted lanes. + std::vector chunks(new_lanes); + for (int i = 0; i < new_lanes; ++i) { + // Find the bit range of this element in the output + int start_bit = (starting_lane + lane_stride * i) * out_bits; + int end_bit = start_bit + out_bits; + + // Map it to input lanes + int start_input_lane = start_bit / in_bits; + int end_input_lane = (end_bit + in_bits - 1) / in_bits; + int num_input_lanes = end_input_lane - start_input_lane; + + // Grab this range of lanes from the input + Expr input_chunk; + { + ScopedValue s_start(starting_lane, start_input_lane); + ScopedValue s_stride(lane_stride, 1); + ScopedValue s_len(new_lanes, num_input_lanes); + input_chunk = mutate(op->value); + } + + // Reinterpret the chunk. 
+ int extracted_bits = num_input_lanes * in_bits; + int reinterpreted_lanes = extracted_bits / out_bits; + internal_assert(reinterpreted_lanes != 0); + + Expr reinterpreted = Reinterpret::make(op->type.with_lanes(reinterpreted_lanes), input_chunk); + + // Now, in case of demotion: + // Example: + // R = ExtractLanes(Reinterpret([u32, u32, u32, u32], u8), 0, 2, 4) + // = ExtractLanes([u8_0, u8_1, u8_2, u8_3, ...], 0, 2, 4) + // = [u8_0, u8_2, u8_4, u8_6] + // A single extracted u32 element is too large, even after reinterpreting. + // So we need to slice the reinterpreted result. + int bit_offset = start_bit - (start_input_lane * in_bits); + int lane_offset = bit_offset / out_bits; + + if (lane_offset == 0 && reinterpreted_lanes == 1) { + chunks[i] = std::move(input_chunk); + } else { + chunks[i] = Shuffle::make_extract_element(reinterpreted, lane_offset); + } + } + + // In case of demotion, we will potentially extract and reinterpret the same input lane several times. + // Simplification afterwards will turn them into Lets. + + return Shuffle::make_concat(chunks); } Expr visit(const Call *op) override { + internal_assert(op->type.lanes() >= starting_lane + lane_stride * (new_lanes - 1)) << Expr(op) << starting_lane << " " << lane_stride << " " << new_lanes; Type t = op->type.with_lanes(new_lanes); // Don't mutate scalars if (op->type.is_scalar()) { return op; } else { - // Vector calls are always parallel across the lanes, so we // can just deinterleave the args. @@ -368,105 +567,182 @@ class Deinterleaver : public IRGraphMutator { } Expr visit(const Shuffle *op) override { + // Special case 1: Scalar extraction + if (new_lanes == 1) { + // Find in which vector it sits. + int index = starting_lane; + for (const auto &vec : op->vectors) { + if (index < vec.type().lanes()) { + // We found the source vector. Extract the scalar from it. 
+ ScopedValue old_start(starting_lane, index); + ScopedValue old_stride(lane_stride, 1); // Stride doesn't matter for scalar + ScopedValue old_count(new_lanes, 1); + return mutate(vec); + } + index -= vec.type().lanes(); + } + internal_error << "extract_lane index out of bounds: " << Expr(op) << " " << index << "\n"; + } + if (op->is_interleave()) { // Special case where we can discard some of the vector arguments entirely. - internal_assert(starting_lane >= 0 && starting_lane < lane_stride); - if ((int)op->vectors.size() == lane_stride) { - return op->vectors[starting_lane]; - } else if ((int)op->vectors.size() % lane_stride == 0) { - // Pick up every lane-stride vector. - std::vector new_vectors(op->vectors.size() / lane_stride); - for (size_t i = 0; i < new_vectors.size(); i++) { - new_vectors[i] = op->vectors[i * lane_stride + starting_lane]; + internal_assert(starting_lane >= 0); + int n_vectors = (int)op->vectors.size(); + + // Case A: Stride is a multiple of the number of input vectors. + // Example: extract_lanes(interleave(A, B), stride=4) + // result comes from either A or B, depending on starting lane modulo number of vectors, + // required stride of said vector is lane_stride / num_vectors + if (lane_stride % n_vectors == 0) { + const Expr &vec = op->vectors[starting_lane % n_vectors]; + if (vec.type().lanes() == new_lanes) { + // We need all lanes of this vector, just return it. + return vec; + } else { + // We don't need all lanes, unfortunately. Let's extract the part we need. + ScopedValue old_starting_lane(starting_lane, starting_lane / n_vectors); + ScopedValue old_lane_stride(lane_stride, lane_stride / n_vectors); + return mutate(vec); + } + } + + // Case B: Number of vectors is a multiple of the stride. + // Eg: extract_lanes(interleave(a, b, c, d, e, f), start=8, stride=3) + // = extract_lanes(a0, b0, c0, d0, e0, f0, a1, b1, c1, d1, e1, f1, ...) + // = (a2, c2, e2, c1, ...) 
+ // = interleave(a, c) + if (n_vectors % lane_stride == 0) { + int num_required_vectors = n_vectors / lane_stride; + + // The result is only an interleave if the number of constituent + // vectors divides the number of total required lanes. + if (new_lanes % num_required_vectors == 0) { + int lanes_per_vec = new_lanes / num_required_vectors; + + // Pick up every lane-stride vector. + std::vector new_vectors(num_required_vectors); + for (size_t i = 0; i < new_vectors.size(); i++) { + int absolute_lane_index = starting_lane + i * lane_stride; + int src_vec_idx = absolute_lane_index % n_vectors; + int vec_lane_start = absolute_lane_index / n_vectors; + const Expr &vec = op->vectors[src_vec_idx]; + + ScopedValue old_starting_lane(starting_lane, vec_lane_start); + ScopedValue old_lane_stride(lane_stride, 1); + ScopedValue old_new_lanes(new_lanes, lanes_per_vec); + new_vectors[i] = mutate(vec); + } + return Shuffle::make_interleave(new_vectors); } - return Shuffle::make_interleave(new_vectors); } } - // Keep the same set of vectors and extract every nth numeric - // arg to the shuffle. - std::vector indices; + // General case fallback + std::vector indices(new_lanes); + bool constant_stride = true; for (int i = 0; i < new_lanes; i++) { - int idx = i * lane_stride + starting_lane; - indices.push_back(op->indices[idx]); - } - - // If this is extracting a single lane, try to recursively deinterleave rather - // than leaving behind a shuffle. 
- if (indices.size() == 1) { - int index = indices.front(); - for (const auto &i : op->vectors) { - if (index < i.type().lanes()) { - if (i.type().lanes() == op->type.lanes()) { - ScopedValue scoped_starting_lane(starting_lane, index); - return mutate(i); - } else { - return Shuffle::make(op->vectors, indices); + int idx = op->indices[i * lane_stride + starting_lane]; + indices[i] = idx; + if (i > 1 && constant_stride) { + int stride = indices[1] - indices[0]; + if (indices[i] != indices[i - 1] + stride) { + constant_stride = false; + } + } + } + + // One optimization if we take a slice of a single vector. + if (constant_stride) { + int stride = indices[1] - indices[0]; + int first_idx = indices.front(); + int last_idx = indices.back(); + + // Find which vector contains this range + int current_bound = 0; + for (const auto &vec : op->vectors) { + int vec_lanes = vec.type().lanes(); + + // Check if the START of the ramp is in this vector + if (first_idx >= current_bound && first_idx < current_bound + vec_lanes) { + + // We found the vector containing the start. + // Now, because it is a linear ramp, we only need to check if the + // END of the ramp is also within this same vector. + // (This handles negative strides, forward strides, and broadcasts correctly). + if (last_idx >= current_bound && last_idx < current_bound + vec_lanes) { + + // Calculate the start index relative to this specific vector + int local_start = first_idx - current_bound; + + ScopedValue s_start(starting_lane, local_start); + ScopedValue s_stride(lane_stride, stride); + // new_lanes is already correct + return mutate(vec); } + + // If the start is here but the end is elsewhere, the ramp crosses + // vector boundaries. We cannot optimize this as a single vector extraction. 
+ break; } - index -= i.type().lanes(); + current_bound += vec_lanes; } - internal_error << "extract_lane index out of bounds: " << Expr(op) << " " << index << "\n"; } return Shuffle::make(op->vectors, indices); } }; -Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) { - debug(3) << "Deinterleave " - << "(start:" << starting_lane << ", stide:" << lane_stride << ", new_lanes:" << new_lanes << "): " - << e << " of Type: " << e.type() << "\n"; - Type original_type = e.type(); - e = substitute_in_all_lets(e); - Deinterleaver d(starting_lane, lane_stride, new_lanes, lets); +} // namespace + +Expr extract_lanes(Expr original_expr, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets, Scope> &requested_sliced_lets) { + internal_assert(starting_lane + (new_lanes - 1) * lane_stride <= original_expr.type().lanes()) + << "Extract lanes with start:" << starting_lane << ", stride:" << lane_stride << ", new_lanes:" << new_lanes << " " + << "out of " << original_expr.type() << " which goes out of bounds."; + + debug(3) << "ExtractLanes " + << "(start:" << starting_lane << ", stride:" << lane_stride << ", new_lanes:" << new_lanes << "): " + << original_expr << " of Type: " << original_expr.type() << "\n"; + Type original_type = original_expr.type(); + Expr e = substitute_in_all_lets(original_expr); + ExtractLanes d(starting_lane, lane_stride, new_lanes, lets, requested_sliced_lets); e = d.mutate(e); e = common_subexpression_elimination(e); + debug(3) << " => " << e << "\n"; Type final_type = e.type(); - int expected_lanes = (original_type.lanes() + lane_stride - starting_lane - 1) / lane_stride; - internal_assert(original_type.code() == final_type.code()) << "Underlying types not identical after interleaving."; - internal_assert(expected_lanes == final_type.lanes()) << "Number of lanes incorrect after interleaving: " << final_type.lanes() << "while expected was " << expected_lanes << "."; - return simplify(e); + 
internal_assert(original_type.code() == final_type.code()) << "Underlying types not identical after extract_lanes."; + e = simplify(e); + internal_assert(new_lanes == final_type.lanes()) + << "Number of lanes incorrect after extract_lanes: " << final_type.lanes() << " while expected was " << new_lanes << ": extract_lanes(" << starting_lane << ", " << lane_stride << ", " << new_lanes << "):\n" + << "Input: " << original_expr << "\nResult: " << e; + return e; } -Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) { - internal_assert(e.type().lanes() % 2 == 0); - return deinterleave(e, 1, 2, e.type().lanes() / 2, lets); -} - -Expr extract_even_lanes(const Expr &e, const Scope<> &lets) { - internal_assert(e.type().lanes() % 2 == 0); - return deinterleave(e, 0, 2, e.type().lanes() / 2, lets); -} - -Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) { - internal_assert(e.type().lanes() % 3 == 0); - return deinterleave(e, lane, 3, e.type().lanes() / 3, lets); +Expr extract_lanes(Expr e, int starting_lane, int lane_stride, int new_lanes) { + Scope<> lets; + Scope> req; + return extract_lanes(std::move(e), starting_lane, lane_stride, new_lanes, lets, req); } -} // namespace - Expr extract_even_lanes(const Expr &e) { internal_assert(e.type().lanes() % 2 == 0); - Scope<> lets; - return extract_even_lanes(e, lets); + return extract_lanes(e, 0, 2, e.type().lanes() / 2); } Expr extract_odd_lanes(const Expr &e) { internal_assert(e.type().lanes() % 2 == 0); - Scope<> lets; - return extract_odd_lanes(e, lets); + return extract_lanes(e, 1, 2, e.type().lanes() / 2); } Expr extract_lane(const Expr &e, int lane) { - Scope<> lets; - return deinterleave(e, lane, e.type().lanes(), 1, lets); + return extract_lanes(e, lane, e.type().lanes(), 1); } namespace { +// Change name to DenisfyStridedLoadsAndStores? 
class Interleaver : public IRMutator { Scope<> vector_lets; + Scope> requested_sliced_lets; using IRMutator::visit; @@ -475,9 +751,9 @@ class Interleaver : public IRMutator { Expr deinterleave_expr(const Expr &e) { std::vector exprs; + exprs.reserve(num_lanes); for (int i = 0; i < num_lanes; i++) { - Scope<> lets; - exprs.emplace_back(deinterleave(e, i, num_lanes, e.type().lanes() / num_lanes, lets)); + exprs.emplace_back(extract_lanes(e, i, num_lanes, e.type().lanes() / num_lanes, vector_lets, requested_sliced_lets)); } return Shuffle::make_interleave(exprs); } @@ -508,18 +784,17 @@ class Interleaver : public IRMutator { for (const auto &frame : reverse_view(frames)) { Expr value = std::move(frame.new_value); + // The original variable: result = LetOrLetStmt::make(frame.op->name, value, result); - // For vector lets, we may additionally need a let defining the even and odd lanes only + // For vector lets, we may additionally need a lets for the requested slices of this variable: if (value.type().is_vector()) { - if (value.type().lanes() % 2 == 0) { - result = LetOrLetStmt::make(frame.op->name + ".even_lanes", extract_even_lanes(value, vector_lets), result); - result = LetOrLetStmt::make(frame.op->name + ".odd_lanes", extract_odd_lanes(value, vector_lets), result); - } - if (value.type().lanes() % 3 == 0) { - result = LetOrLetStmt::make(frame.op->name + ".lanes_0_of_3", extract_mod3_lanes(value, 0, vector_lets), result); - result = LetOrLetStmt::make(frame.op->name + ".lanes_1_of_3", extract_mod3_lanes(value, 1, vector_lets), result); - result = LetOrLetStmt::make(frame.op->name + ".lanes_2_of_3", extract_mod3_lanes(value, 2, vector_lets), result); + if (std::vector *reqs = requested_sliced_lets.shallow_find(frame.op->name)) { + for (const VectorSlice &sl : *reqs) { + result = LetOrLetStmt::make( + sl.variable_name, + extract_lanes(value, sl.start, sl.stride, sl.count, vector_lets, requested_sliced_lets), result); + } } } } diff --git a/src/Deinterleave.h 
b/src/Deinterleave.h index 485641f71a5f..0332e0bfc8c0 100644 --- a/src/Deinterleave.h +++ b/src/Deinterleave.h @@ -9,15 +9,21 @@ */ #include "Expr.h" +#include "Scope.h" namespace Halide { namespace Internal { -/** Extract the odd-numbered lanes in a vector */ -Expr extract_odd_lanes(const Expr &a); +struct VectorSlice { + int start, stride, count; + std::string variable_name; +}; -/** Extract the even-numbered lanes in a vector */ -Expr extract_even_lanes(const Expr &a); +/* Extract lanes and relying on the fact that the caller will provide new variables in Lets or LetStmts which correspond to slices of the original variable. */ +Expr extract_lanes(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &sliceable_lets, Scope> &requested_sliced_lets); + +/* Extract lanes without requesting any extra slices from variables. */ +Expr extract_lanes(Expr e, int starting_lane, int lane_stride, int new_lanes); /** Extract the nth lane of a vector */ Expr extract_lane(const Expr &vec, int lane); diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index c3919ef67da5..a10b8cb8fd2a 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -1,13 +1,11 @@ #include "LegalizeVectors.h" #include "CSE.h" #include "Deinterleave.h" -#include "DeviceInterface.h" #include "IRMutator.h" #include "IROperator.h" #include "Simplify.h" #include "Util.h" -#include #include #include @@ -18,10 +16,6 @@ namespace { using namespace std; -const char *legalization_error_guide = "\n" - "(This is an implemenation limitation in Halide right now. This issue can most likely be \n" - " worked around by reducing lane count for vectorize() calls in GPU schedules, or disabling it.)"; - int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { // The environment variable below (HL_FORCE_VECTOR_LEGALIZATION) is here solely for testing purposes. 
// It is useful to "stress-test" this lowering pass by forcing a shorter maximal vector size across @@ -54,10 +48,6 @@ int max_lanes_for_device(DeviceAPI api, int parent_max_lanes) { return 0; } -std::string vec_name(const string &name, int lane_start, int lane_count) { - return name + ".lanes_" + std::to_string(lane_start) + "_" + std::to_string(lane_start + lane_count - 1); -} - class LiftLetToLetStmt : public IRMutator { using IRMutator::visit; @@ -86,184 +76,6 @@ class LiftLetToLetStmt : public IRMutator { } }; -class ExtractLanes : public IRMutator { - using IRMutator::visit; - - int lane_start; - int lane_count; - - Expr extract_lanes_from_make_struct(const Call *op) { - internal_assert(op); - internal_assert(op->is_intrinsic(Call::make_struct)); - vector args(op->args.size()); - for (int i = 0; i < int(op->args.size()); ++i) { - args[i] = mutate(op->args[i]); - } - return Call::make(op->type, Call::make_struct, args, Call::Intrinsic); - } - - Expr extract_lanes_trace(const Call *op) { - auto event = as_const_int(op->args[6]); - internal_assert(event); - if (*event == halide_trace_load || *event == halide_trace_store) { - debug(3) << "Extracting Trace Lanes: " << Expr(op) << "\n"; - const Expr &func = op->args[0]; - Expr values = extract_lanes_from_make_struct(op->args[1].as()); - Expr coords = extract_lanes_from_make_struct(op->args[2].as()); - const Expr &type_code = op->args[3]; - const Expr &type_bits = op->args[4]; - int type_lanes = *as_const_int(op->args[5]); - const Expr &event = op->args[6]; - const Expr &parent_id = op->args[7]; - const Expr &idx = op->args[8]; - int size = *as_const_int(op->args[9]); - const Expr &tag = op->args[10]; - - int num_vecs = op->args[2].as()->args.size(); - internal_assert(size == type_lanes * num_vecs) << Expr(op); - vector args = { - func, - values, coords, - type_code, type_bits, Expr(lane_count), - event, parent_id, idx, Expr(lane_count * num_vecs), - tag}; - Expr result = Call::make(Int(32), Call::trace, args, 
Call::Extern); - debug(4) << " => " << result << "\n"; - return result; - } - - internal_error << "Unhandled trace call in LegalizeVectors' ExtractLanes: " << *event << legalization_error_guide; - return Expr(0); - } - - Expr visit(const Shuffle *op) override { - vector new_indices; - new_indices.reserve(lane_count); - for (int i = 0; i < lane_count; ++i) { - new_indices.push_back(op->indices[lane_start + i]); - } - return simplify(Shuffle::make(op->vectors, new_indices)); - } - - Expr visit(const Ramp *op) override { - if (lane_count == 1) { - return simplify(op->base + op->stride * lane_start); - } - return simplify(Ramp::make(op->base + op->stride * lane_start, op->stride, lane_count)); - } - - Expr visit(const Broadcast *op) override { - Expr value = op->value; - if (const Call *call = op->value.as()) { - if (call->name == Call::trace) { - value = extract_lanes_trace(call); - } - } - if (lane_count == 1) { - return value; - } else { - return Broadcast::make(value, lane_count); - } - } - - Expr visit(const Variable *op) override { - return Variable::make(op->type.with_lanes(lane_count), vec_name(op->name, lane_start, lane_count)); - } - - Expr visit(const Load *op) override { - return Load::make(op->type.with_lanes(lane_count), - op->name, - mutate(op->index), - op->image, op->param, - mutate(op->predicate), - op->alignment + lane_start); - } - - Expr visit(const Call *op) override { - internal_assert(op->type.lanes() >= lane_start + lane_count); - Expr mutated = op; - std::vector args; - args.reserve(op->args.size()); - for (int i = 0; i < int(op->args.size()); ++i) { - const Expr &arg = op->args[i]; - internal_assert(arg.type().lanes() == op->type.lanes()) - << "Call argument " << arg << " lane count of " << arg.type().lanes() - << " does not match op lane count of " << op->type.lanes(); - Expr mutated = mutate(arg); - internal_assert(!mutated.same_as(arg)); - args.push_back(mutated); - } - mutated = Call::make(op->type.with_lanes(lane_count), op->name, args, 
op->call_type); - return mutated; - } - - Expr visit(const Cast *op) override { - return Cast::make(op->type.with_lanes(lane_count), mutate(op->value)); - } - - Expr visit(const Reinterpret *op) override { - Type result_type = op->type.with_lanes(lane_count); - int result_scalar_bits = op->type.element_of().bits(); - int input_scalar_bits = op->value.type().element_of().bits(); - - Expr value = op->value; - // If the bit widths of the scalar elements are the same, it's easy. - if (result_scalar_bits == input_scalar_bits) { - value = mutate(value); - } else { - // Otherwise, there can be two limiting aspects: the input lane count and the resulting lane count. - // In order to construct a correct Reinterpret from a small type to a wider type, we - // will need to produce multiple Reinterprets, all able to hold the lane count of the input - // and concatate the results together. - // Even worse, reinterpreting uint8x8 to uint64 would require intermediate reinterprets - // if the maximul legal vector length is 4. 
- // - // TODO implement this for all scenarios - internal_error << "Vector legalization for Reinterpret to different bit size per element is " - << "not supported yet: reinterpret<" << op->type << ">(" << value.type() << ")" - << legalization_error_guide; - - // int input_lane_start = lane_start * result_scalar_bits / input_scalar_bits; - // int input_lane_count = lane_count * result_scalar_bits / input_scalar_bits; - } - Expr result = Reinterpret::make(result_type, value); - debug(3) << "Legalized " << Expr(op) << " to " << result << "\n"; - return result; - } - - Expr visit(const VectorReduce *op) override { - internal_assert(op->type.lanes() >= lane_start + lane_count); - int vecs_per_reduction = op->value.type().lanes() / op->type.lanes(); - int input_lane_start = vecs_per_reduction * lane_start; - int input_lane_count = vecs_per_reduction * lane_count; - Expr arg = ExtractLanes(input_lane_start, input_lane_count).mutate(op->value); - // This might fail if the extracted lanes reference a non-existing variable! - return VectorReduce::make(op->op, arg, lane_count); - } - -public: - // Small helper to assert the transform did what it's supposed to do. 
- Expr mutate(const Expr &e) override { - Type original_type = e.type(); - internal_assert(original_type.lanes() >= lane_start + lane_count) - << "Cannot extract lanes " << lane_start << " through " << lane_start + lane_count - 1 - << " when the input type is " << original_type; - Expr result = IRMutator::mutate(e); - Type new_type = result.type(); - internal_assert(new_type.lanes() == lane_count) - << "We didn't correctly legalize " << e << " of type " << original_type << ".\n" - << "Got back: " << result << " of type " << new_type << ", expected " << lane_count << " lanes."; - return result; - } - - Stmt mutate(const Stmt &s) override { - return IRMutator::mutate(s); - } - - ExtractLanes(int start, int count) - : lane_start(start), lane_count(count) { - } -}; class LiftExceedingVectors : public IRMutator { using IRMutator::visit; @@ -380,23 +192,31 @@ class LegalizeVectors : public IRMutator { int max_lanes; - Stmt visit(const LetStmt *op) override { + Scope<> sliceable_vectors; + Scope> requested_slices; + + template + auto visit_let(const LetOrLetStmt *op) -> decltype(op->body) { bool exceeds_lanecount = op->value.type().lanes() > max_lanes; if (exceeds_lanecount) { int num_vecs = (op->value.type().lanes() + max_lanes - 1) / max_lanes; debug(3) << "Legalize let " << op->value.type() << ": " << op->name << " = " << op->value << " into " << num_vecs << " vecs\n"; - Stmt body = IRMutator::mutate(op->body); - for (int i = num_vecs - 1; i >= 0; --i) { - int lane_start = i * max_lanes; - int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); - std::string name = vec_name(op->name, lane_start, lane_count_for_vec); - Expr value = mutate(ExtractLanes(lane_start, lane_count_for_vec).mutate(op->value)); + // First mark this Let as sliceable before mutating the body: + ScopedBinding<> vector_is_slicable(sliceable_vectors, op->name); + + Stmt body = mutate(op->body); + // Here we know which requested vector variable slices should be created for 
the body of the Let/LetStmt to work. - debug(3) << " Add: let " << name << " = " << value << "\n"; - body = LetStmt::make(name, value, body); + if (std::vector *reqs = requested_slices.shallow_find(op->name)) { + for (const VectorSlice &sl : *reqs) { + Expr value = extract_lanes(op->value, sl.start, sl.stride, sl.count, sliceable_vectors, requested_slices); + value = mutate(value); + body = LetOrLetStmt::make(sl.variable_name, value, body); + debug(3) << " Add: let " << sl.variable_name << " = " << value << "\n"; + } } return body; } else { @@ -404,7 +224,12 @@ class LegalizeVectors : public IRMutator { } } + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + Expr visit(const Let *op) override { + // TODO is this still true? internal_error << "Lets should have been lifted into LetStmts."; return IRMutator::visit(op); } @@ -419,9 +244,10 @@ class LegalizeVectors : public IRMutator { for (int i = 0; i < num_vecs; ++i) { int lane_start = i * max_lanes; int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); - Expr rhs = ExtractLanes(lane_start, lane_count_for_vec).mutate(op->value); - Expr index = ExtractLanes(lane_start, lane_count_for_vec).mutate(op->index); - Expr predictate = ExtractLanes(lane_start, lane_count_for_vec).mutate(op->predicate); + + Expr rhs = extract_lanes(op->value, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices); + Expr index = extract_lanes(op->index, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices); + Expr predictate = extract_lanes(op->predicate, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices); assignments.push_back(Store::make( op->name, std::move(rhs), std::move(index), op->param, std::move(predictate), op->alignment + lane_start)); @@ -434,7 +260,29 @@ class LegalizeVectors : public IRMutator { } Expr visit(const Shuffle *op) override { - internal_assert(op->type.lanes() <= max_lanes) << Expr(op); + // Primary violatation: 
there are too many output lanes. + if (op->type.lanes() > max_lanes) { + // Break it down in multiple legal-output-length shuffles, and concatenate them back together. + int total_lanes = op->type.lanes(); + + std::vector output_chunks; + output_chunks.reserve((total_lanes + max_lanes - 1) / max_lanes); + for (int i = 0; i < total_lanes; i += max_lanes) { + int slice_len = std::min(max_lanes, total_lanes - i); + + std::vector slice_indices(slice_len); + for (int k = 0; k < slice_len; ++k) { + slice_indices[k] = op->indices[i + k]; + } + + Expr sub_shuffle = Shuffle::make(op->vectors, slice_indices); + + output_chunks.push_back(mutate(sub_shuffle)); + } + return Shuffle::make_concat(output_chunks); + } + + // Secondary violation: input vectors have too many lanes. bool requires_mutation = false; for (const auto &vec : op->vectors) { if (vec.type().lanes() > max_lanes) { @@ -459,7 +307,7 @@ class LegalizeVectors : public IRMutator { for (int i = 0; i < num_vecs; i++) { int lane_start = i * max_lanes; int lane_count_for_vec = std::min(vec.type().lanes() - lane_start, max_lanes); - new_vectors.push_back(ExtractLanes(lane_start, lane_count_for_vec).mutate(vec)); + new_vectors.push_back(extract_lanes(vec, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices)); } } else { new_vectors.push_back(IRMutator::mutate(vec)); @@ -469,58 +317,111 @@ class LegalizeVectors : public IRMutator { debug(3) << "Legalized " << Expr(op) << " => " << result << "\n"; return result; } + + // Base case: everything legal in this Shuffle return IRMutator::visit(op); } + Expr make_binary_reduce_op(VectorReduce::Operator op, Expr a, Expr b) { + switch (op) { + case VectorReduce::Add: + return a + b; + case VectorReduce::SaturatingAdd: + return saturating_add(a, b); + case VectorReduce::Mul: + return a * b; + case VectorReduce::Min: + return min(a, b); + case VectorReduce::Max: + return max(a, b); + case VectorReduce::And: + return a && b; + case VectorReduce::Or: + return a || b; 
+ default: + internal_error << "Unknown VectorReduce operator\n"; + return Expr(); + } + } + Expr visit(const VectorReduce *op) override { - const Expr &arg = op->value; - if (arg.type().lanes() > max_lanes) { - // TODO: The transformation below is not allowed under strict_float, but - // I don't immediately know what to do here. - // This should be an internal_assert. - - internal_assert(op->type.lanes() == 1) - << "Vector legalization currently does not support VectorReduce with lanes != 1: " << Expr(op) - << legalization_error_guide; - int num_vecs = (arg.type().lanes() + max_lanes - 1) / max_lanes; - Expr result; - for (int i = 0; i < num_vecs; i++) { - int lane_start = i * max_lanes; - int lane_count_for_vec = std::min(arg.type().lanes() - lane_start, max_lanes); - Expr partial_arg = mutate(ExtractLanes(lane_start, lane_count_for_vec).mutate(arg)); - Expr partial_red = VectorReduce::make(op->op, std::move(partial_arg), op->type.lanes()); - if (i == 0) { - result = partial_red; - } else { - switch (op->op) { - case VectorReduce::Add: - result = result + partial_red; - break; - case VectorReduce::SaturatingAdd: - result = saturating_add(result, partial_red); - break; - case VectorReduce::Mul: - result = result * partial_red; - break; - case VectorReduce::Min: - result = min(result, partial_red); - break; - case VectorReduce::Max: - result = max(result, partial_red); - break; - case VectorReduce::And: - result = result && partial_red; - break; - case VectorReduce::Or: - result = result || partial_red; - break; - } - } + // Written with the help of Gemini 3 Pro. + Expr value = mutate(op->value); + + int input_lanes = value.type().lanes(); + int output_lanes = op->type.lanes(); + + // Base case: we don't need legalization. 
+ if (input_lanes <= max_lanes && output_lanes <= max_lanes) { + if (value.same_as(op->value)) { + return op; + } else { + return VectorReduce::make(op->op, value, output_lanes); } - return result; - } else { - return IRMutator::visit(op); } + + // Recursive splitting strategy. + // Case A: Segmented Reduction (Multiple Output Lanes) + // Example: VectorReduce( <16 lanes>, output_lanes=2 ) with max_lanes=4. + // Input is too big. We split the OUTPUT domain. + // We calculate which chunk of the input corresponds to the first half of the output. + if (output_lanes > 1) { + // 1. Calculate good splitting point + int out_split = output_lanes / 2; + + // 2. However, do align to max_lanes to keep chunks native-sized if possible + if (out_split > max_lanes) { + out_split = (out_split / max_lanes) * max_lanes; + } else if (output_lanes > max_lanes) { + // If the total is > max, but half is < max (e.g. 6), + // we want to peel 'max' (4) rather than split (3). + out_split = max_lanes; + } + + // Take remainder beyond the split point + int out_remaining = output_lanes - out_split; + internal_assert(out_remaining >= 1); + + // Calculate the reduction factor to find where to split the input + // e.g., 16 input -> 2 output means factor is 8. + // If we want the first 1 output lane, we need the first 8 input lanes. 
+ int reduction_factor = input_lanes / output_lanes; + int in_split = out_split * reduction_factor; + int in_remaining = input_lanes - in_split; + + Expr arg_lo = extract_lanes(value, 0, 1, in_split, sliceable_vectors, requested_slices); + Expr arg_hi = extract_lanes(value, in_split, 1, in_remaining, sliceable_vectors, requested_slices); + + // Recursively mutate the smaller reductions + Expr res_lo = mutate(VectorReduce::make(op->op, arg_lo, out_split)); + Expr res_hi = mutate(VectorReduce::make(op->op, arg_hi, out_remaining)); + + // Concatenate the results to form the new vector + return Shuffle::make_concat({res_lo, res_hi}); + } + + // Case B: Horizontal Reduction (Single Output Lane) + // Example: VectorReduce( <16 lanes>, output_lanes=1 ) with max_lanes=4. + // We cannot split the output. We must split the INPUT, reduce both halves + // to scalars, and then combine them. + if (output_lanes == 1) { + int in_split = input_lanes / 2; + int in_remaining = input_lanes - in_split; + + // Extract input halves + Expr arg_lo = extract_lanes(value, 0, 1, in_split, sliceable_vectors, requested_slices); + Expr arg_hi = extract_lanes(value, in_split, 1, in_remaining, sliceable_vectors, requested_slices); + + // Recursively reduce both halves to scalars + Expr res_lo = mutate(VectorReduce::make(op->op, arg_lo, 1)); + Expr res_hi = mutate(VectorReduce::make(op->op, arg_hi, 1)); + + // Combine using the standard binary operator for this reduction type + return make_binary_reduce_op(op->op, res_lo, res_hi); + } + + internal_error << "Unreachable"; + return op; } public: diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 91ff3e784c8d..01cfce0151d9 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -364,16 +364,6 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { } } -#if 0 // Not sure what this was for. Disabling for now, and will run tests to see what's up. 
- for (size_t i = 0; i < new_vectors.size() && can_collapse; i++) { - if (new_vectors[i].as()) { - // Don't create a Ramp of a Load, like: - // ramp(buf[x], buf[x + 1] - buf[x], ...) - can_collapse = false; - } - } -#endif - if (can_collapse) { return Ramp::make(new_vectors[0], stride, op->indices.size()); } diff --git a/test/performance/nested_vectorization_gemm.cpp b/test/performance/nested_vectorization_gemm.cpp index 660d3d7bbdf8..7c12bca0b94e 100644 --- a/test/performance/nested_vectorization_gemm.cpp +++ b/test/performance/nested_vectorization_gemm.cpp @@ -300,7 +300,6 @@ int main(int argc, char **argv) { return 1; } } - printf("Success!\n"); // 8-bit sparse blur into 32-bit accumulator { From d9184c8ee5550469055bc58819d4ff21d2257f28 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 1 Mar 2026 23:37:46 +0100 Subject: [PATCH 20/47] Don't use designated initializers. We're not on C++20 yet... :( --- src/Deinterleave.cpp | 11 +++++++---- test/performance/nested_vectorization_gemm.cpp | 1 + 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 0337cabae241..6824e3b7bcaf 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -408,10 +408,13 @@ class ExtractLanes : public IRGraphMutator { std::string sliced_var_name = variable_name_with_extracted_lanes( op->name, op->type.lanes(), starting_lane, lane_stride, new_lanes); - VectorSlice new_sl = {.start = starting_lane, - .stride = lane_stride, - .count = new_lanes, - .variable_name = sliced_var_name}; + + VectorSlice new_sl; // When C++20 lands: Designated initializer + new_sl.start = starting_lane; + new_sl.stride = lane_stride; + new_sl.count = new_lanes; + new_sl.variable_name = sliced_var_name; + if (auto *vec = requested_slices.shallow_find(op->name)) { bool found = false; for (const VectorSlice &existing_sl : *vec) { diff --git a/test/performance/nested_vectorization_gemm.cpp b/test/performance/nested_vectorization_gemm.cpp index 
7c12bca0b94e..4d831e4ba247 100644 --- a/test/performance/nested_vectorization_gemm.cpp +++ b/test/performance/nested_vectorization_gemm.cpp @@ -395,5 +395,6 @@ int main(int argc, char **argv) { } } + printf("Success!\n"); return 0; } From 9601159a46d0a3110efa347ea881ad41f5b1d699 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 1 Mar 2026 23:40:50 +0100 Subject: [PATCH 21/47] clang-format --- src/Deinterleave.cpp | 2 +- src/LegalizeVectors.cpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 6824e3b7bcaf..7cbbd2a1bc88 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -409,7 +409,7 @@ class ExtractLanes : public IRGraphMutator { op->name, op->type.lanes(), starting_lane, lane_stride, new_lanes); - VectorSlice new_sl; // When C++20 lands: Designated initializer + VectorSlice new_sl; // When C++20 lands: Designated initializer new_sl.start = starting_lane; new_sl.stride = lane_stride; new_sl.count = new_lanes; diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index a10b8cb8fd2a..1333ed41f8cc 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -76,7 +76,6 @@ class LiftLetToLetStmt : public IRMutator { } }; - class LiftExceedingVectors : public IRMutator { using IRMutator::visit; From cbc0031b147a881c7e63ce1ad00acdc0c3d54d3b Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 1 Mar 2026 23:44:46 +0100 Subject: [PATCH 22/47] unrelated clang-format??? 
--- src/runtime/vulkan_resources.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/runtime/vulkan_resources.h b/src/runtime/vulkan_resources.h index a4c4f3b0d7e5..cf6f3d60a5ab 100644 --- a/src/runtime/vulkan_resources.h +++ b/src/runtime/vulkan_resources.h @@ -1858,7 +1858,7 @@ int vk_device_crop_from_offset(void *user_context, uint64_t t_before = halide_current_time_ns(user_context); #endif - if (byte_offset < 0) { + if (byte_offset < 0) { error(user_context) << "Vulkan: Invalid offset for device crop!"; return halide_error_code_device_crop_failed; } From d747743111dce76d5867acdf2ceaed27d514ea4a Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 3 Mar 2026 11:38:17 +0100 Subject: [PATCH 23/47] Slightly better early-outing of the ExtractLanes mutator. --- src/Deinterleave.cpp | 148 ++++++++++++++++++++++++------------------- 1 file changed, 84 insertions(+), 64 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 7cbbd2a1bc88..d817f085b2b5 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -227,12 +227,17 @@ class ExtractLanes : public IRGraphMutator { using IRMutator::visit; + inline bool needs_extracting(const Expr &op) { + if (op.type.is_scalar()) { return false; } + return !(starting_lane == 0 && lane_stride == 1 && new_lanes == op.type.lanes()); + } + Expr extract_lanes_from_make_struct(const Call *op) { internal_assert(op); internal_assert(op->is_intrinsic(Call::make_struct)); - std::vector args(op->args.size()); - for (int i = 0; i < int(op->args.size()); ++i) { - args[i] = mutate(op->args[i]); + auto [args, changed] = mutate_with_changes(op->args); + if (!changed) { + return op; } return Call::make(op->type, Call::make_struct, args, Call::Intrinsic); } @@ -271,6 +276,9 @@ class ExtractLanes : public IRGraphMutator { } Expr visit(const VectorReduce *op) override { + if (!needs_extracting(op)) { + return op; + } int factor = op->value.type().lanes() / op->type.lanes(); if (lane_stride != 1) { 
std::vector input_lanes; @@ -289,6 +297,9 @@ class ExtractLanes : public IRGraphMutator { ScopedValue old_new_lanes(new_lanes, new_lanes * factor); in = mutate(op->value); } + if (new_lanes == op->type.lanes() && in.same_as(op->value)) { + return op; + } return VectorReduce::make(op->op, in, new_lanes); } } @@ -323,31 +334,36 @@ class ExtractLanes : public IRGraphMutator { return mutate(flatten_nested_ramps(op)); } + if (new_lanes == op->type.lanes()) { + return op; + } return Broadcast::make(op->value, new_lanes); } Expr visit(const Load *op) override { - if (op->type.is_scalar()) { + if (!needs_extracting(op)) { return op; - } else { - Type t = op->type.with_lanes(new_lanes); - ModulusRemainder align = op->alignment; - // The alignment of a Load refers to the alignment of the first - // lane, so we can preserve the existing alignment metadata if the - // deinterleave is asking for any subset of lanes that includes the - // first. Otherwise we just drop it. We could check if the index is - // a ramp with constant stride or some other special case, but if - // that's the case, the simplifier is very good at figuring out the - // alignment, and it has access to context (e.g. the alignment of - // enclosing lets) that we do not have here. - if (starting_lane != 0) { - align = ModulusRemainder(); - } - return Load::make(t, op->name, mutate(op->index), op->image, op->param, mutate(op->predicate), align); } + Type t = op->type.with_lanes(new_lanes); + ModulusRemainder align = op->alignment; + // The alignment of a Load refers to the alignment of the first + // lane, so we can preserve the existing alignment metadata if the + // deinterleave is asking for any subset of lanes that includes the + // first. Otherwise we just drop it. We could check if the index is + // a ramp with constant stride or some other special case, but if + // that's the case, the simplifier is very good at figuring out the + // alignment, and it has access to context (e.g. 
the alignment of + // enclosing lets) that we do not have here. + if (starting_lane != 0) { + align = ModulusRemainder(); + } + return Load::make(t, op->name, mutate(op->index), op->image, op->param, mutate(op->predicate), align); } Expr visit(const Ramp *op) override { + if (!needs_extracting(op)) { + return op; + } int base_lanes = op->base.type().lanes(); if (base_lanes > 1) { if (new_lanes == 1) { @@ -390,56 +406,55 @@ class ExtractLanes : public IRGraphMutator { } Expr visit(const Variable *op) override { - if (op->type.is_scalar()) { + if (!needs_extracting(op)) { return op; - } else { + } - Type t = op->type.with_lanes(new_lanes); - /* - internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes) - << "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane - << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly." - << " Deinterleaver probably recursed too deep into types of different lane count."; - */ - - if (sliceable_lets.contains(op->name)) { - // The variable accessed is marked as sliceable by the caller. - // Let's request a slice and pretend it exists. 
- std::string sliced_var_name = variable_name_with_extracted_lanes( - op->name, op->type.lanes(), - starting_lane, lane_stride, new_lanes); - - VectorSlice new_sl; // When C++20 lands: Designated initializer - new_sl.start = starting_lane; - new_sl.stride = lane_stride; - new_sl.count = new_lanes; - new_sl.variable_name = sliced_var_name; - - if (auto *vec = requested_slices.shallow_find(op->name)) { - bool found = false; - for (const VectorSlice &existing_sl : *vec) { - if (existing_sl.start == starting_lane && - existing_sl.stride == lane_stride && - existing_sl.count == new_lanes) { - found = true; - break; - } - } - if (!found) { - vec->push_back(std::move(new_sl)); + Type t = op->type.with_lanes(new_lanes); + /* + internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes) + << "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane + << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly." + << " Deinterleaver probably recursed too deep into types of different lane count."; + */ + + if (sliceable_lets.contains(op->name)) { + // The variable accessed is marked as sliceable by the caller. + // Let's request a slice and pretend it exists. 
+ std::string sliced_var_name = variable_name_with_extracted_lanes( + op->name, op->type.lanes(), + starting_lane, lane_stride, new_lanes); + + VectorSlice new_sl; // When C++20 lands: Designated initializer + new_sl.start = starting_lane; + new_sl.stride = lane_stride; + new_sl.count = new_lanes; + new_sl.variable_name = sliced_var_name; + + if (auto *vec = requested_slices.shallow_find(op->name)) { + bool found = false; + for (const VectorSlice &existing_sl : *vec) { + if (existing_sl.start == starting_lane && + existing_sl.stride == lane_stride && + existing_sl.count == new_lanes) { + found = true; + break; } - } else { - requested_slices.push(op->name, {std::move(new_sl)}); } - return Variable::make(t, sliced_var_name, op->image, op->param, op->reduction_domain); + if (!found) { + vec->push_back(std::move(new_sl)); + } } else { - return give_up_and_shuffle(op); + requested_slices.push(op->name, {std::move(new_sl)}); } + return Variable::make(t, sliced_var_name, op->image, op->param, op->reduction_domain); + } else { + return give_up_and_shuffle(op); } } Expr visit(const Cast *op) override { - if (op->type.is_scalar()) { + if (!needs_extracting(op)) { return op; } else { Type t = op->type.with_lanes(new_lanes); @@ -451,7 +466,7 @@ class ExtractLanes : public IRGraphMutator { // Written with assistance from Gemini 3 Pro, which required a lot of baby-sitting. 
// Simple case of a scalar reinterpret: always one lane: - if (op->type.is_scalar()) { + if (!needs_extracting(op)) { return op; } @@ -553,23 +568,28 @@ class ExtractLanes : public IRGraphMutator { Expr visit(const Call *op) override { internal_assert(op->type.lanes() >= starting_lane + lane_stride * (new_lanes - 1)) << Expr(op) << starting_lane << " " << lane_stride << " " << new_lanes; - Type t = op->type.with_lanes(new_lanes); // Don't mutate scalars - if (op->type.is_scalar()) { + if (!needs_extracting(op)) { return op; } else { // Vector calls are always parallel across the lanes, so we // can just deinterleave the args. + Type t = op->type.with_lanes(new_lanes); // Beware of intrinsics for which this is not true! - auto args = mutate(op->args); + auto [args, changed] = mutate_with_changes(op->args); + internal_assert(changed); return Call::make(t, op->name, args, op->call_type, op->func, op->value_index, op->image, op->param); } } Expr visit(const Shuffle *op) override { + if (!needs_extracting(op)) { + return op; + } + // Special case 1: Scalar extraction if (new_lanes == 1) { // Find in which vector it sits. @@ -1012,7 +1032,7 @@ class Interleaver : public IRMutator { const Ramp *ri = stores[i].as()->index.as(); internal_assert(ri); - // Mismatched store vector laness. + // Mismatched store vector lanes. if (ri->lanes != lanes) { return Stmt(); } From 193779746f121da4ed782c42ecdff1b36f6a42ee Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 3 Mar 2026 11:45:28 +0100 Subject: [PATCH 24/47] Forgot brackets. 
--- src/Deinterleave.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index d817f085b2b5..b2ae5c74f7f5 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -228,8 +228,10 @@ class ExtractLanes : public IRGraphMutator { using IRMutator::visit; inline bool needs_extracting(const Expr &op) { - if (op.type.is_scalar()) { return false; } - return !(starting_lane == 0 && lane_stride == 1 && new_lanes == op.type.lanes()); + if (op.type().is_scalar()) { + return false; + } + return !(starting_lane == 0 && lane_stride == 1 && new_lanes == op.type().lanes()); } Expr extract_lanes_from_make_struct(const Call *op) { From a2f084b34398cd70a29617d160c900a9371d6616 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 4 Mar 2026 00:14:36 +0100 Subject: [PATCH 25/47] Clang-tidy. --- src/LegalizeVectors.cpp | 42 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 1333ed41f8cc..3d134fd081a3 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -321,28 +321,6 @@ class LegalizeVectors : public IRMutator { return IRMutator::visit(op); } - Expr make_binary_reduce_op(VectorReduce::Operator op, Expr a, Expr b) { - switch (op) { - case VectorReduce::Add: - return a + b; - case VectorReduce::SaturatingAdd: - return saturating_add(a, b); - case VectorReduce::Mul: - return a * b; - case VectorReduce::Min: - return min(a, b); - case VectorReduce::Max: - return max(a, b); - case VectorReduce::And: - return a && b; - case VectorReduce::Or: - return a || b; - default: - internal_error << "Unknown VectorReduce operator\n"; - return Expr(); - } - } - Expr visit(const VectorReduce *op) override { // Written with the help of Gemini 3 Pro. 
Expr value = mutate(op->value); @@ -416,7 +394,25 @@ class LegalizeVectors : public IRMutator { Expr res_hi = mutate(VectorReduce::make(op->op, arg_hi, 1)); // Combine using the standard binary operator for this reduction type - return make_binary_reduce_op(op->op, res_lo, res_hi); + switch (op->op) { + case VectorReduce::Add: + return res_lo + res_hi; + case VectorReduce::SaturatingAdd: + return saturating_add(res_lo, res_hi); + case VectorReduce::Mul: + return res_lo * res_hi; + case VectorReduce::Min: + return min(res_lo, res_hi); + case VectorReduce::Max: + return max(res_lo, res_hi); + case VectorReduce::And: + return res_lo && res_hi; + case VectorReduce::Or: + return res_lo || res_hi; + default: + internal_error << "Unknown VectorReduce operator\n"; + return Expr(); + } } internal_error << "Unreachable"; From 202d5c0c87300cf07c5264caa54016af8c0fb0c3 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 4 Mar 2026 00:15:32 +0100 Subject: [PATCH 26/47] Two bugs identified by Gemini in CodeGen_Hexagon Co-authored-by: Gemini 3.1 Pro --- src/CodeGen_Hexagon.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 92aeaf5018db..a6537787974d 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1158,8 +1158,8 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, llvm::Type *result_ty = get_vector_type(element_ty, result_elements); // Try to rewrite shuffles that only access the elements of b. - int min = indices[0]; - for (size_t i = 1; i < indices.size(); i++) { + int min = INT_MAX; + for (size_t i = 0; i < indices.size(); i++) { if (indices[i] != -1 && indices[i] < min) { min = indices[i]; } @@ -1171,7 +1171,7 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, i -= a_elements; } } - return shuffle_vectors(b, shifted_indices); + return shuffle_vectors(b, b, shifted_indices); } // Try to rewrite shuffles that only access the elements of a. 
From 0f22d8739525acecab87425f6298b180c95d01e7 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 6 Mar 2026 15:00:08 +0100 Subject: [PATCH 27/47] Fix the shuffle bug that's causing everything to fail. Co-authored-by: Gemini 3 Pro --- src/Deinterleave.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index b2ae5c74f7f5..0312a01e7bb0 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -595,7 +595,7 @@ class ExtractLanes : public IRGraphMutator { // Special case 1: Scalar extraction if (new_lanes == 1) { // Find in which vector it sits. - int index = starting_lane; + int index = op->indices[starting_lane]; for (const auto &vec : op->vectors) { if (index < vec.type().lanes()) { // We found the source vector. Extract the scalar from it. From 6f71253aa1f31721a452a1cf9b4cfaac99cb77cd Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 6 Mar 2026 17:20:24 +0100 Subject: [PATCH 28/47] Two bugs found by Gemini Pro. Co-authored-by: Gemini 3 Pro --- src/CodeGen_Hexagon.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 208672c648fd..e50fe7521813 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1158,10 +1158,10 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, llvm::Type *result_ty = get_vector_type(element_ty, result_elements); // Try to rewrite shuffles that only access the elements of b. - int min = indices[0]; - for (size_t i = 1; i < indices.size(); i++) { - if (indices[i] != -1 && indices[i] < min) { - min = indices[i]; + int min = INT_MAX; + for (int idx : indices) { + if (idx != -1 && idx < min) { + min = idx; } } if (min >= a_elements) { @@ -1683,7 +1683,7 @@ Value *CodeGen_Hexagon::vlut(Value *lut, Value *idx, int min_index, int max_inde // contains the result of each range, and a condition vector // indicating whether the result should be used. 
vector> ranges; - for (int min_index_i = 0; min_index_i < max_index; min_index_i += 256) { + for (int min_index_i = 0; min_index_i <= max_index; min_index_i += 256) { // Make a vector of the indices shifted such that the min of // this range is at 0. Use 16-bit indices for this. Value *min_index_i_val = create_vector(i16x_t, min_index_i); @@ -1697,9 +1697,11 @@ Value *CodeGen_Hexagon::vlut(Value *lut, Value *idx, int min_index, int max_inde // truncate to 8 bits, as vlut requires. indices = call_intrin(i8x_t, "halide.hexagon.pack.vh", {indices}); - int range_extent_i = std::min(max_index - min_index_i, 255); - Value *range_i = vlut256(slice_vector(lut, min_index_i, range_extent_i), - indices, 0, range_extent_i); + int local_max_index = std::min(max_index - min_index_i, 255); + int slice_size = local_max_index + 1; + + Value *range_i = vlut256(slice_vector(lut, min_index_i, slice_size), + indices, 0, local_max_index); ranges.emplace_back(range_i, use_index); } From 7d9370d713847efab49bc067efa8e417a84700b9 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 6 Mar 2026 17:22:18 +0100 Subject: [PATCH 29/47] Another bug found by Gemini Pro. Co-authored-by: Gemini 3 Pro --- src/CodeGen_Hexagon.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index e50fe7521813..ad6906efd951 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1171,7 +1171,7 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, i -= a_elements; } } - return shuffle_vectors(b, shifted_indices); + return shuffle_vectors(b, b, shifted_indices); } // Try to rewrite shuffles that only access the elements of a. From 3c5637807012e7863ef9b61b5705699f3653cc25 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Fri, 6 Mar 2026 23:40:37 +0100 Subject: [PATCH 30/47] Fix infinite recursion on shuffles of vectors with exclusively don't-care indices. 
--- src/CodeGen_Hexagon.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index ad6906efd951..0cadc5fe2a4b 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1159,11 +1159,16 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, // Try to rewrite shuffles that only access the elements of b. int min = INT_MAX; + int max = -1; for (int idx : indices) { - if (idx != -1 && idx < min) { - min = idx; + if (idx != -1) { + min = std::min(min, idx); + max = std::max(max, idx); } } + if (min == INT_MAX) { + return llvm::PoisonValue::get(result_ty); + } if (min >= a_elements) { vector shifted_indices(indices); for (int &i : shifted_indices) { @@ -1175,7 +1180,6 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, } // Try to rewrite shuffles that only access the elements of a. - int max = *std::max_element(indices.begin(), indices.end()); if (max < a_elements) { BitCastInst *a_cast = dyn_cast(a); CallInst *a_call = dyn_cast(a_cast ? a_cast->getOperand(0) : a); From 47797376b95ce4fe5ce1b0b2e3b20b95e8c5e8a1 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 7 Mar 2026 10:44:13 +0100 Subject: [PATCH 31/47] I somehow f*cked up the git merge yesterday. --- src/CodeGen_Hexagon.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 9ed1dfdac65e..46d29cb66700 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1157,7 +1157,7 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, internal_assert(result_elements > 0); llvm::Type *result_ty = get_vector_type(element_ty, result_elements); - // Try to rewrite shuffles that only access the elements of b. + // Find the range of non-dont-care indices. 
int min = INT_MAX; int max = -1; for (int idx : indices) { @@ -1166,6 +1166,11 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b, max = std::max(max, idx); } } + if (min == INT_MAX) { + return llvm::PoisonValue::get(result_ty); + } + + // Try to rewrite shuffles that only access the elements of b. if (min >= a_elements) { vector shifted_indices(indices); for (int &i : shifted_indices) { From 65658800259377240587e9930fbab103e4831130 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 10 Mar 2026 11:08:47 +0100 Subject: [PATCH 32/47] fix clang-tidy. --- src/Deinterleave.cpp | 6 +++--- src/Deinterleave.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 0312a01e7bb0..009984936aab 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -719,7 +719,7 @@ class ExtractLanes : public IRGraphMutator { } // namespace -Expr extract_lanes(Expr original_expr, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets, Scope> &requested_sliced_lets) { +Expr extract_lanes(const Expr &original_expr, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets, Scope> &requested_sliced_lets) { internal_assert(starting_lane + (new_lanes - 1) * lane_stride <= original_expr.type().lanes()) << "Extract lanes with start:" << starting_lane << ", stride:" << lane_stride << ", new_lanes:" << new_lanes << " " << "out of " << original_expr.type() << " which goes out of bounds."; @@ -742,10 +742,10 @@ Expr extract_lanes(Expr original_expr, int starting_lane, int lane_stride, int n return e; } -Expr extract_lanes(Expr e, int starting_lane, int lane_stride, int new_lanes) { +Expr extract_lanes(const Expr &e, int starting_lane, int lane_stride, int new_lanes) { Scope<> lets; Scope> req; - return extract_lanes(std::move(e), starting_lane, lane_stride, new_lanes, lets, req); + return extract_lanes(e, starting_lane, lane_stride, new_lanes, lets, req); } Expr 
extract_even_lanes(const Expr &e) { diff --git a/src/Deinterleave.h b/src/Deinterleave.h index 0332e0bfc8c0..630fa8e7ecc1 100644 --- a/src/Deinterleave.h +++ b/src/Deinterleave.h @@ -20,10 +20,10 @@ struct VectorSlice { }; /* Extract lanes and relying on the fact that the caller will provide new variables in Lets or LetStmts which correspond to slices of the original variable. */ -Expr extract_lanes(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &sliceable_lets, Scope> &requested_sliced_lets); +Expr extract_lanes(const Expr &e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &sliceable_lets, Scope> &requested_sliced_lets); /* Extract lanes without requesting any extra slices from variables. */ -Expr extract_lanes(Expr e, int starting_lane, int lane_stride, int new_lanes); +Expr extract_lanes(const Expr &e, int starting_lane, int lane_stride, int new_lanes); /** Extract the nth lane of a vector */ Expr extract_lane(const Expr &vec, int lane); From 61d5c55c0d57bae4939b63b1c0ae1542db352058 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 10 Mar 2026 12:49:31 +0100 Subject: [PATCH 33/47] Use CSE across stores during legalization. --- src/LegalizeVectors.cpp | 77 ++++++++++++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 3d134fd081a3..84d1111dc312 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -84,21 +84,25 @@ class LiftExceedingVectors : public IRMutator { vector> lets; bool just_in_let_definition{false}; - Expr visit(const Let *op) override { - internal_error << "We don't want to process Lets. 
They should have all been converted to LetStmts."; - return IRMutator::visit(op); - } - - Stmt visit(const LetStmt *op) override { + template + auto visit_let_or_letstmt(const LetOrLetStmt *op) -> decltype(op->body) { just_in_let_definition = true; Expr def = mutate(op->value); just_in_let_definition = false; - Stmt body = mutate(op->body); + decltype(op->body) body = mutate(op->body); if (def.same_as(op->value) && body.same_as(op->body)) { return op; } - return LetStmt::make(op->name, std::move(def), std::move(body)); + return LetOrLetStmt::make(op->name, std::move(def), std::move(body)); + } + + Expr visit(const Let *op) override { + return visit_let_or_letstmt(op); + } + + Stmt visit(const LetStmt *op) override { + return visit_let_or_letstmt(op); } Expr visit(const Call *op) override { @@ -206,7 +210,7 @@ class LegalizeVectors : public IRMutator { // First mark this Let as sliceable before mutating the body: ScopedBinding<> vector_is_slicable(sliceable_vectors, op->name); - Stmt body = mutate(op->body); + auto body = mutate(op->body); // Here we know which requested vector variable slices should be created for the body of the Let/LetStmt to work. if (std::vector *reqs = requested_slices.shallow_find(op->name)) { @@ -228,8 +232,8 @@ class LegalizeVectors : public IRMutator { } Expr visit(const Let *op) override { - // TODO is this still true? - internal_error << "Lets should have been lifted into LetStmts."; + bool exceeds_lanecount = op->value.type().lanes() > max_lanes; + internal_assert(!exceeds_lanecount) << "All illegal Let's should have been converted to LetStmts"; return IRMutator::visit(op); } @@ -238,20 +242,61 @@ class LegalizeVectors : public IRMutator { if (exceeds_lanecount) { // Split up in multiple stores int num_vecs = (op->index.type().lanes() + max_lanes - 1) / max_lanes; + + std::vector bundle_args; + bundle_args.reserve(num_vecs * 3); + + // Break up the index, predicate, and value of the Store into legal chunks. 
+ for (int i = 0; i < num_vecs; ++i) { + int lane_start = i * max_lanes; + int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); + + // Pack them in a known order: rhs, index, predicate + bundle_args.push_back(extract_lanes(op->value, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices)); + bundle_args.push_back(extract_lanes(op->index, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices)); + bundle_args.push_back(extract_lanes(op->predicate, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices)); + } + + // Run CSE on the joint bundle + Expr joint_bundle = Call::make(Int(32), Call::bundle, bundle_args, Call::PureIntrinsic); + joint_bundle = common_subexpression_elimination(joint_bundle); + + // Peel off the `Let` expressions introduced by the CSE pass + std::vector> let_bindings; + while (const Let *let = joint_bundle.as()) { + let_bindings.emplace_back(let->name, let->value); + joint_bundle = let->body; + } + + // Destructure the bundle to get our optimized expressions + const Call *struct_call = joint_bundle.as(); + internal_assert(struct_call && struct_call->is_intrinsic(Call::bundle)) + << "Expected the CSE bundle to remain a bundle Call."; + + // Construct the legal stores with the CSE'd expressions std::vector assignments; assignments.reserve(num_vecs); for (int i = 0; i < num_vecs; ++i) { int lane_start = i * max_lanes; - int lane_count_for_vec = std::min(op->value.type().lanes() - lane_start, max_lanes); - Expr rhs = extract_lanes(op->value, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices); - Expr index = extract_lanes(op->index, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices); - Expr predictate = extract_lanes(op->predicate, lane_start, 1, lane_count_for_vec, sliceable_vectors, requested_slices); + // Unpack in the same order we packed them + Expr rhs = struct_call->args[i * 3 + 0]; + Expr index = struct_call->args[i * 3 + 
1]; + Expr predicate = struct_call->args[i * 3 + 2]; + assignments.push_back(Store::make( op->name, std::move(rhs), std::move(index), - op->param, std::move(predictate), op->alignment + lane_start)); + op->param, std::move(predicate), op->alignment + lane_start)); } + Stmt result = Block::make(assignments); + + // Wrap the block in LetStmts to properly scope all shared expressions + // Iterate backwards to build the LetStmt tree from the inside out. + for (auto &let : reverse_view(let_bindings)) { + result = LetStmt::make(let.first, let.second, result); + } + debug(3) << "Legalized store " << Stmt(op) << " => " << result << "\n"; return result; } From 5f8e2269b9a94acf1c3be51aed426b98b524f2d5 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Tue, 10 Mar 2026 20:29:28 +0100 Subject: [PATCH 34/47] Address review comments. --- src/Deinterleave.cpp | 6 ------ src/LegalizeVectors.cpp | 14 ++++++++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 009984936aab..47cb3fd057d6 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -413,12 +413,6 @@ class ExtractLanes : public IRGraphMutator { } Type t = op->type.with_lanes(new_lanes); - /* - internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes) - << "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane - << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly." - << " Deinterleaver probably recursed too deep into types of different lane count."; - */ if (sliceable_lets.contains(op->name)) { // The variable accessed is marked as sliceable by the caller. 
diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 84d1111dc312..85fba6c3eca1 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -213,13 +213,14 @@ class LegalizeVectors : public IRMutator { auto body = mutate(op->body); // Here we know which requested vector variable slices should be created for the body of the Let/LetStmt to work. - if (std::vector *reqs = requested_slices.shallow_find(op->name)) { + if (const std::vector *reqs = requested_slices.find(op->name)) { for (const VectorSlice &sl : *reqs) { Expr value = extract_lanes(op->value, sl.start, sl.stride, sl.count, sliceable_vectors, requested_slices); value = mutate(value); body = LetOrLetStmt::make(sl.variable_name, value, body); debug(3) << " Add: let " << sl.variable_name << " = " << value << "\n"; } + requested_slices.pop(op->name); } return body; } else { @@ -277,16 +278,21 @@ class LegalizeVectors : public IRMutator { std::vector assignments; assignments.reserve(num_vecs); for (int i = 0; i < num_vecs; ++i) { - int lane_start = i * max_lanes; // Unpack in the same order we packed them Expr rhs = struct_call->args[i * 3 + 0]; Expr index = struct_call->args[i * 3 + 1]; Expr predicate = struct_call->args[i * 3 + 2]; + ModulusRemainder alignment = op->alignment; + if (i != 0) { + // In case i == 0, we are taking the first lane, and the alignment is still valid. + alignment = ModulusRemainder(); + } + assignments.push_back(Store::make( op->name, std::move(rhs), std::move(index), - op->param, std::move(predicate), op->alignment + lane_start)); + op->param, std::move(predicate), alignment)); } Stmt result = Block::make(assignments); @@ -304,7 +310,7 @@ class LegalizeVectors : public IRMutator { } Expr visit(const Shuffle *op) override { - // Primary violatation: there are too many output lanes. + // Primary violation: there are too many output lanes. 
if (op->type.lanes() > max_lanes) { // Break it down in multiple legal-output-length shuffles, and concatenate them back together. int total_lanes = op->type.lanes(); From 459eed2193d6ff5c0f5fc1189a90eb3e7faf6a3d Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 10 Mar 2026 16:38:49 -0700 Subject: [PATCH 35/47] Add a fuzzer for extract_lanes and fix issues found Several pre-existing bugs were found in the simplifier too. Co-authored-by: Claude Code --- src/Deinterleave.cpp | 158 +++++++-- src/LegalizeVectors.cpp | 1 + src/Simplify_Exprs.cpp | 36 +- src/Simplify_Shuffle.cpp | 8 +- test/correctness/CMakeLists.txt | 1 + test/correctness/fuzz_extract_lanes.cpp | 453 ++++++++++++++++++++++++ 6 files changed, 597 insertions(+), 60 deletions(-) create mode 100644 test/correctness/fuzz_extract_lanes.cpp diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 009984936aab..7535bd8bbe05 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -203,7 +203,7 @@ Stmt collect_strided_stores(const Stmt &stmt, const std::string &name, int strid return collect.mutate(stmt); } -class ExtractLanes : public IRGraphMutator { +class ExtractLanes : public IRMutator { public: ExtractLanes( int starting_lane, int lane_stride, int new_lanes, @@ -212,8 +212,8 @@ class ExtractLanes : public IRGraphMutator { : starting_lane(starting_lane), lane_stride(lane_stride), new_lanes(new_lanes), - sliceable_lets(sliceable_lets), requested_slices(requested_slices) { + this->sliceable_lets.set_containing_scope(&sliceable_lets); } private: @@ -221,9 +221,11 @@ class ExtractLanes : public IRGraphMutator { int lane_stride; int new_lanes; - // lets for which we have even and odd lane specializations - const Scope<> &sliceable_lets; - Scope> &requested_slices; // We populate this with the slices we need from the external_lets. + // vector lets we're allowed to request slices of + Scope<> sliceable_lets; + + // We populate this with the slices we need from the external_lets. 
+ Scope> &requested_slices; using IRMutator::visit; @@ -277,6 +279,72 @@ class ExtractLanes : public IRGraphMutator { internal_error << "Unhandled trace call in ExtractLanes: " << *event; } + Expr visit(const Let *op) override { + + // Visit an entire chain of lets in a single method to conserve stack space. + + // This logic is very similar to the same visit method in interleaver, but not + // the same. We don't mutate the let values by default, we just produce + // any requested slices of them. + + struct Frame { + const Let *op; + ScopedBinding<> binding; + Frame(const Let *op, Scope &scope) + : op(op), + binding(op->value.type().is_vector(), scope, op->name) { + } + }; + std::vector frames; + Expr result; + + do { + result = op->body; + frames.emplace_back(op, sliceable_lets); + } while ((op = result.template as())); + + result = mutate(result); + + std::set vars_used; + auto track_vars_used = [&](const Expr &e) { + return visit_with(e, + [&](auto *self, const Variable *var) { + vars_used.insert(var->name); + }); + }; + track_vars_used(result); + + for (const auto &frame : reverse_view(frames)) { + + // The original variable, if it's needed. 
+ if (vars_used.count(frame.op->name)) { + result = Let::make(frame.op->name, frame.op->value, result); + track_vars_used(frame.op->value); + } + + // For vector lets, we may additionally need lets for the requested + // slices of this variable: + if (frame.op->value.type().is_vector()) { + if (std::vector *reqs = requested_slices.shallow_find(frame.op->name)) { + for (const VectorSlice &sl : *reqs) { + Expr slice; + { + ScopedValue old_start(starting_lane, sl.start); + ScopedValue old_stride(lane_stride, sl.stride); + ScopedValue old_count(new_lanes, sl.count); + slice = mutate(frame.op->value); + } + track_vars_used(slice); + result = Let::make(sl.variable_name, slice, result); + } + requested_slices.pop(frame.op->name); + } + } + } + + return result; + } + Expr visit(const VectorReduce *op) override { if (!needs_extracting(op)) { return op; @@ -377,11 +445,19 @@ class ExtractLanes : public IRGraphMutator { return expr; } else if (base_lanes == lane_stride && starting_lane < base_lanes) { - // Base class mutator actually works fine in this - // case, but we only want one lane from the base and - // one lane from the stride. - ScopedValue old_new_lanes(new_lanes, 1); - return IRMutator::visit(op); + // We want one lane from the base and one lane from + // the stride, then build a new ramp with the right + // number of steps. + int ramp_lanes = new_lanes; + { + ScopedValue old_new_lanes(new_lanes, 1); + Expr new_base = mutate(op->base); + Expr new_stride = mutate(op->stride); + if (ramp_lanes == 1) { + return new_base; + } + return Ramp::make(new_base, new_stride, ramp_lanes); + } } else { // There is probably a more efficient way to this by // generalizing the two cases above. 
@@ -391,7 +467,7 @@ class ExtractLanes : public IRGraphMutator { Expr expr = op->base + cast(op->base.type(), starting_lane) * op->stride; internal_assert(expr.type() == op->base.type()); if (new_lanes > 1) { - expr = Ramp::make(expr, op->stride * lane_stride, new_lanes); + expr = Ramp::make(expr, op->stride * cast(op->base.type(), lane_stride), new_lanes); } return expr; } @@ -413,12 +489,6 @@ class ExtractLanes : public IRGraphMutator { } Type t = op->type.with_lanes(new_lanes); - /* - internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes) - << "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane - << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly." - << " Deinterleaver probably recursed too deep into types of different lane count."; - */ if (sliceable_lets.contains(op->name)) { // The variable accessed is marked as sliceable by the caller. @@ -502,17 +572,16 @@ class ExtractLanes : public IRGraphMutator { // we had to grab whole elements from the input, which can be coarser if out_bits > in_bits. // So calculate how many lanes we extracted, when measured in the reinterpreted output type. int intm_lanes = (num_input_lanes * in_bits) / out_bits; - Expr reinterprted = Reinterpret::make(op->type.with_lanes(intm_lanes), extracted_input_lanes); + Expr reinterpreted = Reinterpret::make(op->type.with_lanes(intm_lanes), extracted_input_lanes); // Now calculate how many we output Type lanes we need to trim away. 
int bits_to_strip_front = start_bit - (start_input_lane * in_bits); int lanes_to_strip_front = bits_to_strip_front / out_bits; - if (lanes_to_strip_front == 0) { - internal_assert(reinterprted.type().lanes() == new_lanes); - return reinterprted; + if (lanes_to_strip_front == 0 && intm_lanes == new_lanes) { + return reinterpreted; } else { - return Shuffle::make_slice(reinterprted, lanes_to_strip_front, 1, new_lanes); + return Shuffle::make_slice(reinterpreted, lanes_to_strip_front, 1, new_lanes); } } @@ -529,12 +598,12 @@ class ExtractLanes : public IRGraphMutator { int end_input_lane = (end_bit + in_bits - 1) / in_bits; int num_input_lanes = end_input_lane - start_input_lane; - // Grab this range of lanes from the input + // Grab this range of lanes from the input. Expr input_chunk; { - ScopedValue s_start(starting_lane, start_input_lane); - ScopedValue s_stride(lane_stride, 1); - ScopedValue s_len(new_lanes, num_input_lanes); + ScopedValue old_start(starting_lane, start_input_lane); + ScopedValue old_stride(lane_stride, 1); + ScopedValue old_count(new_lanes, num_input_lanes); input_chunk = mutate(op->value); } @@ -556,7 +625,7 @@ class ExtractLanes : public IRGraphMutator { int lane_offset = bit_offset / out_bits; if (lane_offset == 0 && reinterpreted_lanes == 1) { - chunks[i] = std::move(input_chunk); + chunks[i] = std::move(reinterpreted); } else { chunks[i] = Shuffle::make_extract_element(reinterpreted, lane_offset); } @@ -579,11 +648,17 @@ class ExtractLanes : public IRGraphMutator { // can just deinterleave the args. Type t = op->type.with_lanes(new_lanes); - // Beware of intrinsics for which this is not true! auto [args, changed] = mutate_with_changes(op->args); - internal_assert(changed); - return Call::make(t, op->name, args, op->call_type, - op->func, op->value_index, op->image, op->param); + if (!changed) { + // It's possible that this is a slice where output lanes = input + // lanes (e.g. 
reversing a vector) and the args are invariant + // under that slice (e.g. they are broadcasts). + internal_assert(t == op->type); + return op; + } else { + return Call::make(t, op->name, args, op->call_type, + op->func, op->value_index, op->image, op->param); + } } } @@ -618,7 +693,7 @@ class ExtractLanes : public IRGraphMutator { // Example: extract_lanes(interleave(A, B), stride=4) // result comes from either A or B, depending on starting lane modulo number of vectors, // required stride of said vector is lane_stride / num_vectors - if (lane_stride % n_vectors == 0) { + if (lane_stride > 0 && lane_stride % n_vectors == 0) { const Expr &vec = op->vectors[starting_lane % n_vectors]; if (vec.type().lanes() == new_lanes) { // We need all lanes of this vector, just return it. @@ -636,7 +711,7 @@ class ExtractLanes : public IRGraphMutator { // = extract_lanes(a0, b0, c0, d0, e0, f0, a1, b1, c1, d1, e1, f1, ...) // = (a2, c2, e2, c1, ...) // = interleave(a, c) - if (n_vectors % lane_stride == 0) { + if (lane_stride > 0 && n_vectors % lane_stride == 0) { int num_required_vectors = n_vectors / lane_stride; // The result is only an interleave if the number of constituent @@ -728,14 +803,15 @@ Expr extract_lanes(const Expr &original_expr, int starting_lane, int lane_stride << "(start:" << starting_lane << ", stride:" << lane_stride << ", new_lanes:" << new_lanes << "): " << original_expr << " of Type: " << original_expr.type() << "\n"; Type original_type = original_expr.type(); - Expr e = substitute_in_all_lets(original_expr); ExtractLanes d(starting_lane, lane_stride, new_lanes, lets, requested_sliced_lets); - e = d.mutate(e); + Expr e = d.mutate(original_expr); e = common_subexpression_elimination(e); debug(3) << " => " << e << "\n"; Type final_type = e.type(); - internal_assert(original_type.code() == final_type.code()) << "Underlying types not identical after extract_lanes."; - e = simplify(e); + internal_assert(original_type.code() == final_type.code()) + << 
"Underlying types not identical after extract_lanes:\n" + << "Before: " << original_expr << "\n" + << "After: " << e << "\n"; internal_assert(new_lanes == final_type.lanes()) << "Number of lanes incorrect after extract_lanes: " << final_type.lanes() << " while expected was " << new_lanes << ": extract_lanes(" << starting_lane << ", " << lane_stride << ", " << new_lanes << "):\n" << "Input: " << original_expr << "\nResult: " << e; @@ -814,12 +890,16 @@ class Interleaver : public IRMutator { // For vector lets, we may additionally need a lets for the requested slices of this variable: if (value.type().is_vector()) { - if (std::vector *reqs = requested_sliced_lets.shallow_find(frame.op->name)) { + if (std::vector *reqs = + requested_sliced_lets.shallow_find(frame.op->name)) { for (const VectorSlice &sl : *reqs) { result = LetOrLetStmt::make( sl.variable_name, - extract_lanes(value, sl.start, sl.stride, sl.count, vector_lets, requested_sliced_lets), result); + extract_lanes(value, sl.start, sl.stride, sl.count, + vector_lets, requested_sliced_lets), + result); } + requested_sliced_lets.pop(frame.op->name); } } } diff --git a/src/LegalizeVectors.cpp b/src/LegalizeVectors.cpp index 84d1111dc312..40da99815203 100644 --- a/src/LegalizeVectors.cpp +++ b/src/LegalizeVectors.cpp @@ -220,6 +220,7 @@ class LegalizeVectors : public IRMutator { body = LetOrLetStmt::make(sl.variable_name, value, body); debug(3) << " Add: let " << sl.variable_name << " = " << value << "\n"; } + requested_slices.pop(op->name); } return body; } else { diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 0eb3bbaf3c15..12857a8a10e2 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -136,7 +136,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) || rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || rewrite(h_min(broadcast(x, c0), lanes), 
h_min(x, lanes), factor % c0 == 0) || - rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0)) || + (lanes == 1 && rewrite(h_min(ramp(x, y, arg_lanes), lanes), x + min(y * (arg_lanes - 1), 0))) || false) { return mutate(rewrite.result, info); } @@ -150,7 +150,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) || rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || rewrite(h_max(broadcast(x, c0), lanes), h_max(x, lanes), factor % c0 == 0) || - rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0)) || + (lanes == 1 && rewrite(h_max(ramp(x, y, arg_lanes), lanes), x + max(y * (arg_lanes - 1), 0))) || false) { return mutate(rewrite.result, info); } @@ -164,14 +164,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) || rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || rewrite(h_and(broadcast(x, c0), lanes), h_and(x, lanes), factor % c0 == 0) || - rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), - x + max(y * (arg_lanes - 1), 0) < z) || - rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), - x + max(y * (arg_lanes - 1), 0) <= z) || - rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x < y + min(z * (arg_lanes - 1), 0)) || - rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x <= y + min(z * (arg_lanes - 1), 0)) || + (lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), + x + max(y * (arg_lanes - 1), 0) < z)) || + (lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), + x + max(y * (arg_lanes - 1), 0) <= z)) || + (lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + x < y + min(z * 
(arg_lanes - 1), 0))) || + (lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + x <= y + min(z * (arg_lanes - 1), 0))) || false) { return mutate(rewrite.result, info); } @@ -186,14 +186,14 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) || rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) || // type of arg_lanes is somewhat indeterminate - rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), - x + min(y * (arg_lanes - 1), 0) < z) || - rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), - x + min(y * (arg_lanes - 1), 0) <= z) || - rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x < y + max(z * (arg_lanes - 1), 0)) || - rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x <= y + max(z * (arg_lanes - 1), 0)) || + (lanes == 1 && rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), + x + min(y * (arg_lanes - 1), 0) < z)) || + (lanes == 1 && rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), + x + min(y * (arg_lanes - 1), 0) <= z)) || + (lanes == 1 && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + x < y + max(z * (arg_lanes - 1), 0))) || + (lanes == 1 && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), + x <= y + max(z * (arg_lanes - 1), 0))) || false) { return mutate(rewrite.result, info); } diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 01cfce0151d9..644418664ffc 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -106,7 +106,7 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { if (new_vectors.size() == 1) { const Ramp *ramp = new_vectors[0].as(); - if (ramp && op->is_slice()) { + if (ramp && ramp->base.type().is_scalar() && op->is_slice()) { int first_lane_in_src = op->indices[0]; int slice_stride = 
op->slice_stride(); if (slice_stride >= 1) { @@ -201,9 +201,11 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *info) { // Try to collapse a shuffle of broadcasts into a single // broadcast. Note that it doesn't matter what the indices - // are. + // are. Only applies when the broadcast value is scalar, + // because Broadcast::make(vec, N) has vec.lanes() * N total + // lanes. const Broadcast *b1 = new_vectors[0].as(); - if (b1) { + if (b1 && b1->value.type().is_scalar()) { bool can_collapse = true; for (size_t i = 1; i < new_vectors.size() && can_collapse; i++) { if (const Broadcast *b2 = new_vectors[i].as()) { diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 7970b42e1064..a7b5f98dbe39 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -126,6 +126,7 @@ tests(GROUPS correctness fuse_gpu_threads.cpp fused_where_inner_extent_is_zero.cpp fuzz_float_stores.cpp + fuzz_extract_lanes.cpp fuzz_schedule.cpp fuzz_simplify.cpp gameoflife.cpp diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp new file mode 100644 index 000000000000..67ce54645ea6 --- /dev/null +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -0,0 +1,453 @@ +#include "Halide.h" +#include +#include +#include + +// Fuzz test for deinterleave / extract_lane operations in Deinterleave.cpp. +// Constructs random vector expressions covering the IR node types that +// the Deinterleaver has visit methods for, evaluates them by JIT-compiling +// with a custom lowering pass, then checks that deinterleave() produces +// results consistent with the original expression. 
+ +namespace { + +using std::string; +using std::vector; +using namespace Halide; +using namespace Halide::Internal; + +using RandomEngine = std::mt19937_64; + +constexpr int fuzz_var_count = 3; +std::vector> fuzz_vars(fuzz_var_count); + +template +decltype(auto) random_choice(RandomEngine &rng, T &&choices) { + std::uniform_int_distribution dist(0, std::size(choices) - 1); + return choices[dist(rng)]; +} + +Type fuzz_types[] = {UInt(8), UInt(16), UInt(32), UInt(64), Int(8), Int(16), Int(32), Int(64)}; + +Type random_scalar_type(RandomEngine &rng) { + return random_choice(rng, fuzz_types); +} + +int random_factor(RandomEngine &rng, int x) { + vector factors; + factors.reserve(x); + for (int i = 1; i < x; i++) { + if (x % i == 0) { + factors.push_back(i); + } + } + return random_choice(rng, factors); +} + +Expr random_const(RandomEngine &rng, Type t) { + int val = (int)((int8_t)(rng() & 0x0f)); + if (t.is_vector()) { + return Broadcast::make(cast(t.element_of(), val), t.lanes()); + } else { + return cast(t, val); + } +} + +Expr random_leaf(RandomEngine &rng, Type t) { + if (t.is_scalar()) { + if (rng() & 1) { + // Variable + std::uniform_int_distribution dist(0, fuzz_var_count - 1); + return cast(t, fuzz_vars[dist(rng)]); + } else { + return random_const(rng, t); + } + } + // For vector types, build from Ramp or Broadcast + int lanes = t.lanes(); + if (rng() & 1) { + Expr base = random_leaf(rng, t.element_of()); + Expr stride = random_const(rng, t.element_of()); + return Ramp::make(base, stride, lanes); + } else { + Expr val = random_leaf(rng, t.element_of()); + return Broadcast::make(val, lanes); + } +} + +Expr random_vector_expr(RandomEngine &rng, Type t, int depth) { + if (depth <= 0 || t.lanes() == 1) { + return random_leaf(rng, t); + } + + // Weight the choices to cover all Deinterleaver visit methods: + // Broadcast, Ramp, Cast, Reinterpret, Call (via abs), Shuffle, + // VectorReduce, Add/Sub/Min/Max (handled by default IRMutator) + std::function ops[] = { + 
// Leaf + [&]() -> Expr { + return random_leaf(rng, t); + }, + // Add + [&]() -> Expr { + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + return a + b; + }, + // Sub (only for signed types to avoid unsigned underflow coercion errors) + [&]() -> Expr { + if (t.is_uint()) { + // Fall back to Add for unsigned types + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + return a + b; + } + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + return a - b; + }, + // Min + [&]() -> Expr { + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + return min(a, b); + }, + // Max + [&]() -> Expr { + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + internal_assert(a.type() == b.type()) << a << " " << b; + return max(a, b); + }, + // Select + [&]() -> Expr { + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + Expr c = random_vector_expr(rng, t, depth - 1); + Expr cond = (a > b); + return select(cond, a, c); + }, + // Cast + [&]() -> Expr { + // Cast from a different type + Type other = random_scalar_type(rng).with_lanes(t.lanes()); + while (other == t) { + other = random_scalar_type(rng).with_lanes(t.lanes()); + } + Expr e = random_vector_expr(rng, other, depth - 1); + return Cast::make(t, e); + }, + // Reinterpret (different bit width, changes lane count) + [&]() -> Expr { + int total_bits = t.bits() * t.lanes(); + // Pick a different bit width that divides the total bits evenly + int bit_widths[] = {8, 16, 32, 64}; + vector valid_widths; + for (int bw : bit_widths) { + if (total_bits % bw == 0) { + valid_widths.push_back(bw); + } + } + // Should at least be able to preserve the existing bit width and change signedness. 
+ internal_assert(!valid_widths.empty()); + int other_bits = random_choice(rng, valid_widths); + int other_lanes = total_bits / other_bits; + Type other = ((rng() & 1) ? Int(other_bits) : UInt(other_bits)).with_lanes(other_lanes); + Expr e = random_vector_expr(rng, other, depth - 1); + return Reinterpret::make(t, e); + }, + // Broadcast of sub-expression + [&]() -> Expr { + int f = random_factor(rng, t.lanes()); + Expr val = random_vector_expr(rng, t.with_lanes(f), depth - 1); + return Broadcast::make(val, t.lanes() / f); + }, + // Ramp + [&]() -> Expr { + int f = random_factor(rng, t.lanes()); + Type sub_t = t.with_lanes(f); + Expr base = random_vector_expr(rng, sub_t, depth - 1); + Expr stride = random_const(rng, sub_t); + return Ramp::make(base, stride, t.lanes() / f); + }, + // Shuffle (interleave) + [&]() -> Expr { + if (t.lanes() >= 4 && t.lanes() % 2 == 0) { + int half = t.lanes() / 2; + Expr a = random_vector_expr(rng, t.with_lanes(half), depth - 1); + Expr b = random_vector_expr(rng, t.with_lanes(half), depth - 1); + return Shuffle::make_interleave({a, b}); + } + // Fall back to a simple expression + return random_vector_expr(rng, t, depth - 1); + }, + // Shuffle (concat) + [&]() -> Expr { + if (t.lanes() >= 4 && t.lanes() % 2 == 0) { + int half = t.lanes() / 2; + Expr a = random_vector_expr(rng, t.with_lanes(half), depth - 1); + Expr b = random_vector_expr(rng, t.with_lanes(half), depth - 1); + return Shuffle::make_concat({a, b}); + } + return random_vector_expr(rng, t, depth - 1); + }, + // Shuffle (slice) + [&]() -> Expr { + // Make a wider vector and slice it + if (t.lanes() <= 8) { + int wider = t.lanes() * 2; + Expr e = random_vector_expr(rng, t.with_lanes(wider), depth - 1); + // Slice: take every other element starting at 0 or 1 + int start = rng() & 1; + return Shuffle::make_slice(e, start, 2, t.lanes()); + } + return random_vector_expr(rng, t, depth - 1); + }, + // VectorReduce (only when we can make it work with lane counts) + [&]() -> Expr { + 
// Input has more lanes, output has t.lanes() lanes + // factor must divide input lanes, and input lanes = t.lanes() * factor + int factor = (rng() % 3) + 2; + int input_lanes = t.lanes() * factor; + if (input_lanes <= 32) { + VectorReduce::Operator ops[] = { + VectorReduce::Add, + VectorReduce::Min, + VectorReduce::Max, + }; + auto op = random_choice(rng, ops); + Expr val = random_vector_expr(rng, t.with_lanes(input_lanes), depth - 1); + internal_assert(val.type().lanes() == input_lanes) << val; + return VectorReduce::make(op, val, t.lanes()); + } + return random_vector_expr(rng, t, depth - 1); + }, + // Call node (using a pure intrinsic like absd) + [&]() -> Expr { + Expr a = random_vector_expr(rng, t, depth - 1); + Expr b = random_vector_expr(rng, t, depth - 1); + return cast(t, absd(a, b)); + }, + }; + + Expr e = random_choice(rng, ops)(); + internal_assert(e.type() == t) << e.type() << " " << t << " " << e; + return e; +} + +// A custom lowering pass that replaces a specific dummy store RHS with the +// desired test expression. This lets us JIT-evaluate arbitrary vector Exprs. +class InjectExpr : public IRMutator { + using IRMutator::visit; + + string func_name; + const std::vector &replacements; + int idx = 0; + + Stmt visit(const Store *op) override { + // Replace calls to our dummy function with the replacement expr + internal_assert(idx < (int)replacements.size()); + if (op->name == func_name) { + return Store::make(op->name, flatten_nested_ramps(replacements[idx++]), + op->index, op->param, op->predicate, op->alignment); + } + return IRMutator::visit(op); + } + +public: + InjectExpr(const string &func_name, const std::vector &replacements) + : func_name(func_name), replacements(replacements) { + } +}; + +// Evaluate a vector expression by JIT-compiling it. Returns the values +// as a vector of int64_t (to hold any integer type). +// The expression may reference variables a, b, c which are set to fixed values. 
+bool evaluate_vector_exprs(const std::vector &e, + Buffer &result) { + Type t = e[0].type(); + int lanes = t.lanes(); + + // Create a Func that outputs a vector of the right size + Func f("test_func"); + Var x("x"), y("y"); + + // We define f(x, y) as a dummy, then inject our expressions via a custom + // lowering pass + Expr fuzz_var_sum = 0; + for (int i = 0; i < fuzz_var_count; i++) { + fuzz_var_sum += fuzz_vars[i]; + } + f(x, y) = cast(t.element_of(), fuzz_var_sum); + f.bound(x, 0, lanes) + .bound(y, 0, (int)e.size()) + .vectorize(x) + .unroll(y); + + // The custom lowering pass replaces the dummy RHS + InjectExpr injector(f.name(), e); + + auto buf = Runtime::Buffer<>(t.element_of(), {lanes, (int)e.size()}); + + Pipeline p(f); + p.add_custom_lowering_pass(&injector, nullptr); + p.realize(buf); + + // Upcast results to int64 for easier comparison + internal_assert(result.height() == (int)e.size()); + internal_assert(result.width() == lanes); + for (int y = 0; y < (int)e.size(); y++) { + for (int x = 0; x < lanes; x++) { + if (t.is_uint()) { + switch (t.bits()) { + case 8: + result(x, y) = buf.as()(x, y); + break; + case 16: + result(x, y) = buf.as()(x, y); + break; + case 32: + result(x, y) = buf.as()(x, y); + break; + case 64: + result(x, y) = buf.as()(x, y); + break; + default: + return false; + } + } else { + switch (t.bits()) { + case 8: + result(x, y) = buf.as()(x, y); + break; + case 16: + result(x, y) = buf.as()(x, y); + break; + case 32: + result(x, y) = buf.as()(x, y); + break; + case 64: + result(x, y) = buf.as()(x, y); + break; + default: + return false; + } + } + } + } + + return true; +} + +template +T initialize_rng() { + constexpr size_t kStateWords = T::state_size * sizeof(typename T::result_type) / sizeof(uint32_t); + vector random(kStateWords); + std::generate(random.begin(), random.end(), std::random_device{}); + std::seed_seq seed_seq(random.begin(), random.end()); + return T{seed_seq}; +} + +bool test_one(RandomEngine &rng) { + // Pick a 
random vector width and type + int lanes = std::uniform_int_distribution(4, 16)(rng); + Type scalar_t = random_scalar_type(rng); + Type t = scalar_t.with_lanes(lanes); + + // Pick random deinterleave parameters + int starting_lane = std::uniform_int_distribution(0, lanes - 1)(rng); + int ending_lane = std::uniform_int_distribution(0, lanes - 1)(rng); + int new_lanes = std::abs(ending_lane - starting_lane) + 1; + int lane_stride = std::uniform_int_distribution(1, new_lanes)(rng); + // bias it towards small strides + lane_stride = std::uniform_int_distribution(1, lane_stride)(rng); + new_lanes /= lane_stride; + if (starting_lane > ending_lane) { + lane_stride = -lane_stride; + } + + // Generate a batch of random vector expressions + constexpr int batch_size = 32; + constexpr int depth = 4; + std::vector original(batch_size); + std::vector sliced(batch_size); + + for (int i = 0; i < batch_size; i++) { + original[i] = random_vector_expr(rng, t, depth); + sliced[i] = extract_lanes(original[i], starting_lane, lane_stride, new_lanes); + internal_assert(sliced[i].type() == scalar_t.with_lanes(new_lanes)) + << sliced[i].type() << " vs " << scalar_t.with_lanes(new_lanes); + } + + // Pick random variable values + for (int i = 0; i < fuzz_var_count; i++) { + fuzz_vars[i].set((int)((int8_t)(rng() & 0x0f))); + } + + // Evaluate both + Buffer orig_vals({lanes, batch_size}), sliced_vals({new_lanes, batch_size}); + if (!evaluate_vector_exprs(original, orig_vals) || + !evaluate_vector_exprs(sliced, sliced_vals)) { + // Can't evaluate this for whatever reason + return true; + } + + // Check that the sliced values match the corresponding lanes of the original + for (int y = 0; y < batch_size; y++) { + for (int x = 0; x < new_lanes; x++) { + int orig_lane = starting_lane + x * lane_stride; + if (sliced_vals(x, y) != orig_vals(orig_lane, y)) { + std::cerr << "MISMATCH!\n" + << "Original expr: " << original[y] << "\n" + << "Original type: " << original[y].type() << "\n" + << 
"Deinterleave params: starting_lane=" << starting_lane + << " lane_stride=" << lane_stride + << " new_lanes=" << new_lanes << "\n" + << "Sliced expr: " << sliced[y] << "\n" + << "Variables:"; + for (int j = 0; j < fuzz_var_count; j++) { + std::cerr << " " << fuzz_vars[j].name() << "=" << fuzz_vars[j].get() << "\n"; + } + std::cerr << "\n" + << "Original values:"; + for (int j = 0; j < lanes; j++) { + std::cerr << " " << orig_vals(j, y); + } + std::cerr << "\n" + << "Sliced values:"; + for (int j = 0; j < new_lanes; j++) { + std::cerr << " " << sliced_vals(j, y); + } + std::cerr << "\n"; + return false; + } + } + } + + return true; +} + +} // namespace + +int main(int argc, char **argv) { + auto seed_generator = initialize_rng(); + + int num_iters = (argc > 1) ? 1 : 32; + + for (int i = 0; i < num_iters; i++) { + auto seed = seed_generator(); + if (argc > 1) { + std::istringstream{argv[1]} >> seed; + } + std::cout << "Seed: " << seed << std::endl; + RandomEngine rng{seed}; + + if (!test_one(rng)) { + std::cout << "Failed with seed " << seed << "\n"; + return 1; + } + } + + std::cout << "Success!\n"; + return 0; +} From 0c3b824831b319f2223c2389510557f7bcfaaf60 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 11 Mar 2026 14:22:40 +0100 Subject: [PATCH 36/47] Clang format. 
--- src/Simplify_Exprs.cpp | 8 ++++---- test/correctness/fuzz_extract_lanes.cpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 12857a8a10e2..e80e99d3f6fe 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -187,13 +187,13 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { rewrite(h_or(broadcast(x, c0), lanes), h_or(x, lanes), factor % c0 == 0) || // type of arg_lanes is somewhat indeterminate (lanes == 1 && rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes), - x + min(y * (arg_lanes - 1), 0) < z)) || + x + min(y * (arg_lanes - 1), 0) < z)) || (lanes == 1 && rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes), - x + min(y * (arg_lanes - 1), 0) <= z)) || + x + min(y * (arg_lanes - 1), 0) <= z)) || (lanes == 1 && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x < y + max(z * (arg_lanes - 1), 0))) || + x < y + max(z * (arg_lanes - 1), 0))) || (lanes == 1 && rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes), - x <= y + max(z * (arg_lanes - 1), 0))) || + x <= y + max(z * (arg_lanes - 1), 0))) || false) { return mutate(rewrite.result, info); } diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index 67ce54645ea6..ca6729b068dd 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -439,7 +439,7 @@ int main(int argc, char **argv) { if (argc > 1) { std::istringstream{argv[1]} >> seed; } - std::cout << "Seed: " << seed << std::endl; + std::cout << "Seed: " << seed << "\n"; RandomEngine rng{seed}; if (!test_one(rng)) { From a386c58af59c093c1241774f8d312b575beb1bc8 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Wed, 11 Mar 2026 14:32:54 +0100 Subject: [PATCH 37/47] Resolve ambiguous C++ call to Buffer constructor. 
--- test/correctness/fuzz_extract_lanes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index ca6729b068dd..0b8b1d29cc26 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -385,7 +385,7 @@ bool test_one(RandomEngine &rng) { } // Evaluate both - Buffer orig_vals({lanes, batch_size}), sliced_vals({new_lanes, batch_size}); + Buffer orig_vals(lanes, batch_size), sliced_vals(new_lanes, batch_size); if (!evaluate_vector_exprs(original, orig_vals) || !evaluate_vector_exprs(sliced, sliced_vals)) { // Can't evaluate this for whatever reason From 06dcb86d38e40a9f576b9b9eb199fd1f20b35789 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 12 Mar 2026 09:27:12 -0700 Subject: [PATCH 38/47] Fix lossless casts of vector reduces down to bools Fixes #9011 --- src/IROperator.cpp | 12 ++++++++++-- test/correctness/lossless_cast.cpp | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 285744ba6eef..c729539daa29 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -544,12 +544,20 @@ Expr lossless_cast(Type t, } } } else if (const VectorReduce *op = e.as()) { - if (op->op == VectorReduce::Add || + if ((t.bits() > 1 && op->op == VectorReduce::Add) || op->op == VectorReduce::Min || op->op == VectorReduce::Max) { Expr v = lossless_cast(t.with_lanes(op->value.type().lanes()), op->value, scope, cache); if (v.defined()) { - return VectorReduce::make(op->op, v, op->type.lanes()); + auto reduce_op = op->op; + if (t.bits() == 1) { + // UInt(1) == Bool() is the only 1-bit type we expect to see + internal_assert(t.is_uint()) << "Unexpected type: " << t << "\n"; + reduce_op = (op->op == VectorReduce::Min ? 
+ VectorReduce::And : + VectorReduce::Or); + } + return VectorReduce::make(reduce_op, v, op->type.lanes()); } } } diff --git a/test/correctness/lossless_cast.cpp b/test/correctness/lossless_cast.cpp index 77104d149d39..7633954fb003 100644 --- a/test/correctness/lossless_cast.cpp +++ b/test/correctness/lossless_cast.cpp @@ -81,6 +81,28 @@ int lossless_cast_test() { e = cast(i64, 1024) * cast(i64, 1024) * cast(i64, 1024); res |= check_lossless_cast(i32, e, (cast(i32, 1024) * 1024) * 1024); + // Check narrowing a vector reduction of something narrowable to bool ... + auto make_reduce = [&](Type t, VectorReduce::Operator op) { + return VectorReduce::make(op, + cast(t.with_lanes(4), Ramp::make(x, 1, 4) > 4), 2); + }; + + // It's OK to narrow it to 8-bit. + e = make_reduce(UInt(16), VectorReduce::Add); + res |= check_lossless_cast(UInt(8), e, make_reduce(UInt(8), VectorReduce::Add)); + + // ... but we can't reduce it all the way to bool if the operator isn't + // legal for bools (issue #9011) + e = make_reduce(UInt(8), VectorReduce::Add); + res |= check_lossless_cast(Bool(), e, Expr()); + + // Min or Max, however, can just become And and Or + e = make_reduce(UInt(8), VectorReduce::Min); + res |= check_lossless_cast(Bool(), e, make_reduce(Bool(), VectorReduce::And)); + + e = make_reduce(UInt(8), VectorReduce::Max); + res |= check_lossless_cast(Bool(), e, make_reduce(Bool(), VectorReduce::Or)); + return res; } From 4dabe54ce44665eb82e9824854a20523f9feb68c Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sat, 14 Mar 2026 12:06:59 +0100 Subject: [PATCH 39/47] Fix an ARM codegen issue. 
--- src/CodeGen_ARM.cpp | 2 +- test/correctness/fuzz_extract_lanes.cpp | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index d5c3879d36af..3731517047e6 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -1521,7 +1521,7 @@ void CodeGen_ARM::visit(const Store *op) { // Declare the function std::ostringstream instr; vector arg_types; - llvm::Type *intrin_llvm_type = llvm_type_with_constraint(intrin_type, false, is_sve ? VectorTypeConstraint::VScale : VectorTypeConstraint::Fixed); + llvm::Type *intrin_llvm_type = llvm_type_with_constraint(intrin_type, true, is_sve ? VectorTypeConstraint::VScale : VectorTypeConstraint::Fixed); if (target.bits == 32) { instr << "llvm.arm.neon.vst" << num_vecs diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index 0b8b1d29cc26..620cb70329b4 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -291,7 +291,13 @@ bool evaluate_vector_exprs(const std::vector &e, Pipeline p(f); p.add_custom_lowering_pass(&injector, nullptr); - p.realize(buf); + if (get_target_from_environment() == get_host_target()) { + p.realize(buf); + } else { + // Compile something, to be able to at least test CodeGen from the backends and LLVM. 
+ p.compile_to_assembly("fuzz_extract_lanes.s", {fuzz_vars[0], fuzz_vars[1], fuzz_vars[2]}, "fuzz_func"); + return false; + } // Upcast results to int64 for easier comparison internal_assert(result.height() == (int)e.size()); From c83dc510d4f3512b0b5a834972d5f6c1e607a8a4 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 00:04:02 +0100 Subject: [PATCH 40/47] Simplify result of extract_lanes --- src/Deinterleave.cpp | 1 + test/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 7535bd8bbe05..1f36d6808f55 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -805,6 +805,7 @@ Expr extract_lanes(const Expr &original_expr, int starting_lane, int lane_stride Type original_type = original_expr.type(); ExtractLanes d(starting_lane, lane_stride, new_lanes, lets, requested_sliced_lets); Expr e = d.mutate(original_expr); + e = simplify(e); e = common_subexpression_elimination(e); debug(3) << " => " << e << "\n"; Type final_type = e.type(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 64f2c9ac6825..9edfd0476cbf 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,7 +4,7 @@ include(CheckCXXCompilerFlag) # Internal tests are a special case. # HalideTestHelpers depends on this test being present. add_executable(_test_internal internal.cpp) -target_link_libraries(_test_internal PRIVATE Halide::Test) +target_link_libraries(_test_internal PRIVATE Halide::Test Halide::TerminateHandler) target_include_directories(_test_internal PRIVATE "${Halide_SOURCE_DIR}/src") target_precompile_headers(_test_internal PRIVATE ) if (Halide_CCACHE_BUILD) From a3077a2d3b20c4a79d58ca3e4ae45137b44aa3cb Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 00:07:43 +0100 Subject: [PATCH 41/47] Add skip for ARM in the extract_lanes fuzz tester. 
--- test/correctness/fuzz_extract_lanes.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index 620cb70329b4..122dd9295dc8 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -436,6 +436,10 @@ bool test_one(RandomEngine &rng) { } // namespace int main(int argc, char **argv) { + if (get_jit_target_from_environment().arch == Halide::Target::ARM) { + printf("[SKIP-WITH-ISSUE-9026] LLVM generates incorrect IR for some expressions.\n"); + return 0; + } auto seed_generator = initialize_rng(); int num_iters = (argc > 1) ? 1 : 32; From 2c0933038998a340b3bd12aae59d51ecbea7c966 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 13:10:52 +0100 Subject: [PATCH 42/47] Fixes #9030. Co-authored-by: Gemini 3.1 Pro --- src/Simplify_Exprs.cpp | 5 +++-- test/correctness/simplify.cpp | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 0eb3bbaf3c15..bbd67a5bace0 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -69,7 +69,7 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { return value; } - if (info && op->type.is_int()) { + if (info && op->type.is_int_or_uint()) { switch (op->op) { case VectorReduce::Add: // Alignment of result is the alignment of the arg. 
Bounds @@ -123,7 +123,8 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) { case VectorReduce::Add: { auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) || - rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes))) { + rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes)) || + rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes))) { return mutate(rewrite.result, info); } break; diff --git a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index 628de4d91504..de10bde5a1b9 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -810,6 +810,24 @@ void check_vectors() { int_vector); check(VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8), VectorReduce::make(VectorReduce::Max, Broadcast::make(int_vector, 4), 8)); + + { + // h_add(broadcast(x, 8), 4) should simplify to broadcast(x * 2, 4) + check(VectorReduce::make(VectorReduce::Add, broadcast(x, 8), 4), + broadcast(x * 2, 4)); + } + + { + Expr const_u8 = cast(UInt(8), 3); + check(VectorReduce::make(VectorReduce::Add, broadcast(const_u8, 9), 3), broadcast(cast(UInt(8), 9), 3)); + } + + { + // Test VectorReduce::Add on a variable of unsigned type to ensure the multiplied factor + // keeps the correct type and avoids type-mismatch assertion failures. + Expr u8_x = Variable::make(UInt(8), "u8_x"); + check(VectorReduce::make(VectorReduce::Add, broadcast(u8_x, 9), 3), broadcast(u8_x * cast(UInt(8), 3), 3)); + } } void check_bounds() { From af3bc76269298ac2167df78453ad25a12606c9ba Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 14:11:04 +0100 Subject: [PATCH 43/47] Move simplification calls down to where they are needed. 
--- src/Deinterleave.cpp | 15 +++++---------- test/correctness/fuzz_extract_lanes.cpp | 22 ++++++++++++++++------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 1f36d6808f55..d6d6463d614d 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -389,13 +389,9 @@ class ExtractLanes : public IRMutator { if (op->value.type().lanes() == 1) { return op->value; } else { - int old_starting_lane = starting_lane; - int old_lane_stride = lane_stride; - starting_lane = starting_lane % op->value.type().lanes(); - lane_stride = op->value.type().lanes(); + ScopedValue old_starting_lane(starting_lane, starting_lane % op->value.type().lanes()); + ScopedValue old_lane_stride(lane_stride, op->value.type().lanes()); Expr e = mutate(op->value); - starting_lane = old_starting_lane; - lane_stride = old_lane_stride; return e; } } @@ -438,7 +434,7 @@ class ExtractLanes : public IRMutator { if (base_lanes > 1) { if (new_lanes == 1) { int index = starting_lane / base_lanes; - Expr expr = op->base + cast(op->base.type(), index) * op->stride; + Expr expr = simplify(op->base + cast(op->base.type(), index) * op->stride); ScopedValue old_starting_lane(starting_lane, starting_lane % base_lanes); ScopedValue old_lane_stride(lane_stride, base_lanes); expr = mutate(expr); @@ -464,10 +460,10 @@ class ExtractLanes : public IRMutator { return mutate(flatten_nested_ramps(op)); } } - Expr expr = op->base + cast(op->base.type(), starting_lane) * op->stride; + Expr expr = simplify(op->base + cast(op->base.type(), starting_lane) * op->stride); internal_assert(expr.type() == op->base.type()); if (new_lanes > 1) { - expr = Ramp::make(expr, op->stride * cast(op->base.type(), lane_stride), new_lanes); + expr = Ramp::make(expr, simplify(op->stride * cast(op->base.type(), lane_stride)), new_lanes); } return expr; } @@ -805,7 +801,6 @@ Expr extract_lanes(const Expr &original_expr, int starting_lane, int lane_stride Type original_type = 
original_expr.type(); ExtractLanes d(starting_lane, lane_stride, new_lanes, lets, requested_sliced_lets); Expr e = d.mutate(original_expr); - e = simplify(e); e = common_subexpression_elimination(e); debug(3) << " => " << e << "\n"; Type final_type = e.type(); diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index 122dd9295dc8..afa97fa7a1e7 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -373,7 +373,7 @@ bool test_one(RandomEngine &rng) { } // Generate a batch of random vector expressions - constexpr int batch_size = 32; + constexpr int batch_size = 64; constexpr int depth = 4; std::vector original(batch_size); std::vector sliced(batch_size); @@ -403,10 +403,10 @@ bool test_one(RandomEngine &rng) { for (int x = 0; x < new_lanes; x++) { int orig_lane = starting_lane + x * lane_stride; if (sliced_vals(x, y) != orig_vals(orig_lane, y)) { - std::cerr << "MISMATCH!\n" + std::cerr << "MISMATCH! (y=" << y << ", x=" << x << ")\n" << "Original expr: " << original[y] << "\n" << "Original type: " << original[y].type() << "\n" - << "Deinterleave params: starting_lane=" << starting_lane + << "ExtractLanes params: starting_lane=" << starting_lane << " lane_stride=" << lane_stride << " new_lanes=" << new_lanes << "\n" << "Sliced expr: " << sliced[y] << "\n" @@ -436,16 +436,26 @@ bool test_one(RandomEngine &rng) { } // namespace int main(int argc, char **argv) { - if (get_jit_target_from_environment().arch == Halide::Target::ARM) { + if (get_jit_target_from_environment().has_feature(Target::SVE2)) { printf("[SKIP-WITH-ISSUE-9026] LLVM generates incorrect IR for some expressions.\n"); return 0; } auto seed_generator = initialize_rng(); - int num_iters = (argc > 1) ? 1 : 32; + /* Seeds known to have failed in the past: */ + std::vector seeds_to_try = { + 11290674455725750672ull, + 18322803614019275106ull, + 12847901530538798383ull, + }; + + int num_iters = (argc > 1) ? 
1 : 64; for (int i = 0; i < num_iters; i++) { - auto seed = seed_generator(); + uint64_t seed = seed_generator(); + if (i < seeds_to_try.size()) { + seed = seeds_to_try[i]; + } if (argc > 1) { std::istringstream{argv[1]} >> seed; } From a165925f16e07e5b1070f4e398ac2a0f0eee2429 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 14:40:09 +0100 Subject: [PATCH 44/47] Fix int type warning. --- test/correctness/fuzz_extract_lanes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index afa97fa7a1e7..a3b2cd125347 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -449,9 +449,9 @@ int main(int argc, char **argv) { 12847901530538798383ull, }; - int num_iters = (argc > 1) ? 1 : 64; + size_t num_iters = (argc > 1) ? 1 : 64; - for (int i = 0; i < num_iters; i++) { + for (size_t i = 0; i < num_iters; i++) { uint64_t seed = seed_generator(); if (i < seeds_to_try.size()) { seed = seeds_to_try[i]; From 01b6b1b5aa06704388c81425bb97b305cf278d52 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 21:39:08 +0100 Subject: [PATCH 45/47] Extra work on fuzz test. 
--- test/correctness/fuzz_extract_lanes.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index a3b2cd125347..ad8dc2348df4 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -428,6 +428,16 @@ bool test_one(RandomEngine &rng) { return false; } } + + std::cerr << "Original values:"; + for (int j = 0; j < lanes; j++) { + std::cerr << " " << orig_vals(j, y); + } + std::cerr << " Sliced values:"; + for (int j = 0; j < new_lanes; j++) { + std::cerr << " " << sliced_vals(j, y); + } + std::cerr << " Correct.\n"; } return true; @@ -447,6 +457,10 @@ int main(int argc, char **argv) { 11290674455725750672ull, 18322803614019275106ull, 12847901530538798383ull, + + // Failures on ARM: + 5792148528566212763, + 6300344786331520063, }; size_t num_iters = (argc > 1) ? 1 : 64; From cd60a772d8f04b3679f98cebd005c63aa9069957 Mon Sep 17 00:00:00 2001 From: Martijn Courteaux Date: Sun, 15 Mar 2026 23:49:56 +0100 Subject: [PATCH 46/47] Disable fuzz tester on non-x86_64 for now. --- test/correctness/fuzz_extract_lanes.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/correctness/fuzz_extract_lanes.cpp b/test/correctness/fuzz_extract_lanes.cpp index ad8dc2348df4..e23e60ccf0be 100644 --- a/test/correctness/fuzz_extract_lanes.cpp +++ b/test/correctness/fuzz_extract_lanes.cpp @@ -446,10 +446,15 @@ bool test_one(RandomEngine &rng) { } // namespace int main(int argc, char **argv) { - if (get_jit_target_from_environment().has_feature(Target::SVE2)) { + Target t = get_jit_target_from_environment(); + if (t.has_feature(Target::SVE2)) { printf("[SKIP-WITH-ISSUE-9026] LLVM generates incorrect IR for some expressions.\n"); return 0; } + if (t.arch != Target::X86 || t.bits != 64) { + printf("[SKIP-WITH-ISSUE-9040] Only running test on X86-64 for now. 
See also #9044.\n"); + return 0; + } auto seed_generator = initialize_rng(); /* Seeds known to have failed in the past: */ From 4e3750bf492c0bf476a22e61b36c2185f3ce1b05 Mon Sep 17 00:00:00 2001 From: "halide-ci[bot]" <266445882+halide-ci[bot]@users.noreply.github.com> Date: Sun, 15 Mar 2026 22:53:15 +0000 Subject: [PATCH 47/47] Apply pre-commit auto-fixes --- test/correctness/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index a498b7fce20b..d732f1e72284 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -124,8 +124,8 @@ tests(GROUPS correctness fuse.cpp fuse_gpu_threads.cpp fused_where_inner_extent_is_zero.cpp - fuzz_float_stores.cpp fuzz_extract_lanes.cpp + fuzz_float_stores.cpp fuzz_schedule.cpp fuzz_simplify.cpp gameoflife.cpp