diff --git a/python_bindings/src/PyImageParam.cpp b/python_bindings/src/PyImageParam.cpp index 1a3e35f50a4d..ec9ef529723a 100644 --- a/python_bindings/src/PyImageParam.cpp +++ b/python_bindings/src/PyImageParam.cpp @@ -31,6 +31,7 @@ void define_image_param(py::module &m) { .def("host_alignment", &OutputImageParam::host_alignment) .def("set_estimates", &OutputImageParam::set_estimates, py::arg("estimates")) .def("set_host_alignment", &OutputImageParam::set_host_alignment) + .def("is_host_aligned", &OutputImageParam::is_host_aligned) .def("store_in", &OutputImageParam::store_in, py::arg("memory_type")) .def("dimensions", &OutputImageParam::dimensions) .def("left", &OutputImageParam::left) diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index b98fd62ff86d..1cb5a7a5f85a 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -637,8 +637,11 @@ Stmt add_image_checks_inner(Stmt s, int alignment_required = param.host_alignment(); Expr u64t_host_ptr = reinterpret(host_ptr); Expr align_condition = (u64t_host_ptr % alignment_required) == 0; - Expr error = Call::make(Int(32), "halide_error_unaligned_host_ptr", - {name, alignment_required}, Call::Extern); + Expr error = 0; + if (!no_asserts) { + error = Call::make(Int(32), "halide_error_unaligned_host_ptr", + {name, alignment_required}, Call::Extern); + } asserts_host_alignment.push_back(AssertStmt::make(align_condition, error)); } } @@ -661,7 +664,6 @@ Stmt add_image_checks_inner(Stmt s, if (!no_asserts) { // Inject the code that checks the host pointers. prepend_stmts(&asserts_host_non_null); - prepend_stmts(&asserts_host_alignment); prepend_stmts(&asserts_device_not_dirty); prepend_stmts(&dims_no_overflow_asserts); prepend_lets(&lets_overflow); @@ -680,6 +682,7 @@ Stmt add_image_checks_inner(Stmt s, // Inject the code that checks the constraints are correct. We // need these regardless of how NoAsserts is set, because they are // what gets Halide to actually exploit the constraint. 
+ prepend_stmts(&asserts_host_alignment); prepend_stmts(&asserts_constrained); if (!no_asserts) { diff --git a/src/AlignLoads.cpp b/src/AlignLoads.cpp index 7df913a5b28c..f32187fb339d 100644 --- a/src/AlignLoads.cpp +++ b/src/AlignLoads.cpp @@ -2,7 +2,6 @@ #include "AlignLoads.h" #include "Bounds.h" -#include "HexagonAlignment.h" #include "IRMutator.h" #include "IROperator.h" #include "ModulusRemainder.h" @@ -22,12 +21,10 @@ namespace { class AlignLoads : public IRMutator { public: AlignLoads(int alignment) - : alignment_analyzer(alignment), required_alignment(alignment) { + : required_alignment(alignment) { } private: - HexagonAlignmentAnalyzer alignment_analyzer; - // Loads and stores should ideally be aligned to the vector width in bytes. int required_alignment; @@ -75,14 +72,12 @@ class AlignLoads : public IRMutator { return IRMutator::visit(op); } - int64_t aligned_offset = 0; - bool is_aligned = - alignment_analyzer.is_aligned(op, &aligned_offset); - // We know the alignment_analyzer has been able to reason about alignment - // if the following is true. - bool known_alignment = is_aligned || (!is_aligned && aligned_offset != 0); int lanes = ramp->lanes; int native_lanes = required_alignment / op->type.bytes(); + int64_t aligned_offset = + op->alignment.modulus % native_lanes == 0 ? op->alignment.remainder % native_lanes : 0; + bool is_aligned = op->alignment.contains(native_lanes); + bool known_alignment = is_aligned || aligned_offset != 0; int stride = static_cast(*const_stride); if (stride != 1) { internal_assert(stride >= 0); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index eda89aca82da..6f8a66bf1015 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -581,9 +581,6 @@ void CodeGen_LLVM::begin_func(LinkageType linkage, const std::string &name, size_t i = 0; for (auto &arg : function->args()) { if (args[i].is_buffer()) { - // Track this buffer name so that loads and stores from it - // don't try to be too aligned. 
- external_buffer.insert(args[i].name); sym_push(args[i].name + ".buffer", &arg); } else { Type passed_type = upgrade_type_for_argument_passing(args[i].type); @@ -2169,7 +2166,6 @@ void CodeGen_LLVM::codegen_predicated_vector_store(const Store *op) { Value *vpred = codegen(op->predicate); Halide::Type value_type = op->value.type(); Value *val = codegen(op->value); - bool is_external = (external_buffer.find(op->name) != external_buffer.end()); int alignment = value_type.bytes(); int native_bits = native_vector_bits(); int native_bytes = native_bits / 8; @@ -2184,14 +2180,6 @@ void CodeGen_LLVM::codegen_predicated_vector_store(const Store *op) { alignment *= 2; } - // If it is an external buffer, then we cannot assume that the host pointer - // is aligned to at least the native vector width. However, we may be able to do - // better than just assuming that it is unaligned. - if (is_external && op->param.defined()) { - int host_alignment = op->param.host_alignment(); - alignment = gcd(alignment, host_alignment); - } - // For dense vector stores wider than the native vector // width, bust them up into native vectors. int store_lanes = value_type.lanes(); @@ -2255,7 +2243,6 @@ Value *CodeGen_LLVM::codegen_dense_vector_load(const Load *load, Value *vpred) { const Ramp *ramp = load->index.as(); internal_assert(ramp && is_const_one(ramp->stride)) << "Should be dense vector load\n"; - bool is_external = (external_buffer.find(load->name) != external_buffer.end()); int alignment = load->type.bytes(); // The size of a single element int native_bits = native_vector_bits(); @@ -2275,19 +2262,6 @@ Value *CodeGen_LLVM::codegen_dense_vector_load(const Load *load, Value *vpred) { alignment *= 2; } - // If it is an external buffer, then we cannot assume that the host pointer - // is aligned to at least native vector width. However, we may be able to do - // better than just assuming that it is unaligned. 
- if (is_external) { - if (load->param.defined()) { - int host_alignment = load->param.host_alignment(); - alignment = gcd(alignment, host_alignment); - } else if (get_target().has_feature(Target::JIT) && load->image.defined()) { - // If we're JITting, use the actual pointer value to determine alignment for embedded buffers. - alignment = gcd(alignment, (int)(((uintptr_t)load->image.data()) & std::numeric_limits::max())); - } - } - // For dense vector loads wider than the native vector // width, bust them up into native vectors int load_lanes = load->type.lanes(); @@ -3979,7 +3953,6 @@ void CodeGen_LLVM::visit(const Store *op) { } Value *val = codegen(op->value); - bool is_external = (external_buffer.find(op->name) != external_buffer.end()); // Scalar if (value_type.is_scalar()) { Value *ptr = codegen_buffer_pointer(op->name, value_type, op->index); @@ -4006,14 +3979,6 @@ void CodeGen_LLVM::visit(const Store *op) { alignment *= 2; } - // If it is an external buffer, then we cannot assume that the host pointer - // is aligned to at least the native vector width. However, we may be able to do - // better than just assuming that it is unaligned. - if (is_external && op->param.defined()) { - int host_alignment = op->param.host_alignment(); - alignment = gcd(alignment, host_alignment); - } - // For dense vector stores wider than the native vector // width, bust them up into native vectors. int store_lanes = value_type.lanes(); diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 092bc7713b5b..6026b53d3efc 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -414,10 +414,6 @@ class CodeGen_LLVM : public IRVisitor { */ size_t requested_alloca_total = 0; - /** Which buffers came in from the outside world (and so we can't - * guarantee their alignment) */ - std::set external_buffer; - /** The user_context argument. May be a constant null if the * function is being compiled without a user context. 
*/ llvm::Value *get_user_context() const; diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index a5259d20fe52..447b3522b7b8 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -491,8 +491,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Load *op) { internal_assert(op->type.is_vector()); ostringstream rhs; - if ((op->alignment.modulus % op->type.lanes() == 0) && - (op->alignment.remainder % op->type.lanes() == 0)) { + if (op->alignment.contains(op->type.lanes())) { // Get the rhs just for the cache. string id_ramp_base = print_expr(ramp_base / op->type.lanes()); string array_indexing = print_array_access(op->name, op->type, id_ramp_base); @@ -658,8 +657,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Store *op) { if (ramp_base.defined()) { internal_assert(op->value.type().is_vector()); - if ((op->alignment.modulus % op->value.type().lanes() == 0) && - (op->alignment.remainder % op->value.type().lanes() == 0)) { + if (op->alignment.contains(op->value.type().lanes())) { string id_ramp_base = print_expr(ramp_base / op->value.type().lanes()); string array_indexing = print_array_access(op->name, t, id_ramp_base); stream << get_indent() << array_indexing << " = " << id_value << ";\n"; diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 26822cda2296..54101953b736 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -346,7 +346,7 @@ void CodeGen_PTX_Dev::visit(const Load *op) { // TODO: lanes >= 4, not lanes == 4 if (is_const_one(op->predicate) && r && is_const_one(r->stride) && r->lanes == 4 && op->type.bits() == 32) { ModulusRemainder align = op->alignment; - if (align.modulus % 4 == 0 && align.remainder % 4 == 0) { + if (align.contains(4)) { Expr index = simplify(r->base / 4); Expr equiv = Load::make(UInt(128), op->name, index, op->image, op->param, const_true(), align / 4); @@ -371,7 +371,7 @@ void CodeGen_PTX_Dev::visit(const Store *op) { // TODO: lanes >= 4, not lanes 
== 4 if (is_const_one(op->predicate) && r && is_const_one(r->stride) && r->lanes == 4 && op->value.type().bits() == 32) { ModulusRemainder align = op->alignment; - if (align.modulus % 4 == 0 && align.remainder % 4 == 0) { + if (align.contains(4)) { Expr index = simplify(r->base / 4); Expr value = reinterpret(UInt(128), op->value); Stmt equiv = Store::make(op->name, value, index, op->param, const_true(), align / 4); @@ -411,8 +411,7 @@ class RewriteLoadsAs32Bit : public IRMutator { if (idx && is_const_one(op->predicate) && is_const_one(idx->stride) && - op->alignment.modulus % sub_lanes == 0 && - op->alignment.remainder % sub_lanes == 0) { + op->alignment.contains(sub_lanes)) { Expr new_idx = simplify(idx->base / sub_lanes); int load_lanes = op->type.lanes() / sub_lanes; if (op->type.lanes() > sub_lanes) { diff --git a/src/Generator.h b/src/Generator.h index 5cbdcf0a3ca0..79cc0f410d45 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -1778,6 +1778,7 @@ class GeneratorInput_Buffer : public GeneratorInputImpl { HALIDE_FORWARD_METHOD_CONST(ImageParam, dim) HALIDE_FORWARD_METHOD_CONST(ImageParam, host_alignment) HALIDE_FORWARD_METHOD(ImageParam, set_host_alignment) + HALIDE_FORWARD_METHOD(ImageParam, is_host_aligned) HALIDE_FORWARD_METHOD(ImageParam, store_in) HALIDE_FORWARD_METHOD_CONST(ImageParam, dimensions) HALIDE_FORWARD_METHOD_CONST(ImageParam, left) @@ -2521,6 +2522,7 @@ class GeneratorOutput_Buffer : public GeneratorOutputImpl { HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dim) HALIDE_FORWARD_METHOD_CONST(OutputImageParam, host_alignment) HALIDE_FORWARD_METHOD(OutputImageParam, set_host_alignment) + HALIDE_FORWARD_METHOD(OutputImageParam, is_host_aligned) HALIDE_FORWARD_METHOD(OutputImageParam, store_in) HALIDE_FORWARD_METHOD_CONST(OutputImageParam, dimensions) HALIDE_FORWARD_METHOD_CONST(OutputImageParam, left) diff --git a/src/HexagonAlignment.h b/src/HexagonAlignment.h deleted file mode 100644 index 71ae88fce820..000000000000 --- a/src/HexagonAlignment.h 
+++ /dev/null @@ -1,69 +0,0 @@ -#ifndef HALIDE_HEXAGON_ALIGNMENT_H -#define HALIDE_HEXAGON_ALIGNMENT_H - -/** \file - * Class for analyzing Alignment of loads and stores for Hexagon. - */ - -#include "IR.h" - -namespace Halide { -namespace Internal { - -// TODO: This class is barely stateful, and could probably be replaced with free functions. -class HexagonAlignmentAnalyzer { - const int required_alignment; - -public: - HexagonAlignmentAnalyzer(int required_alignment) - : required_alignment(required_alignment) { - internal_assert(required_alignment != 0); - } - - /** Analyze the index of a load/store instruction for alignment - * Returns true if it can determing that the address of the store or load is aligned, false otherwise. - */ - template - bool is_aligned_impl(const T *op, int native_lanes, int64_t *aligned_offset) { - debug(3) << "HexagonAlignmentAnalyzer: Check if " << op->index << " is aligned to a " - << required_alignment << " byte boundary\n" - << "native_lanes: " << native_lanes << "\n"; - Expr index = op->index; - const Ramp *ramp = index.as(); - if (ramp) { - index = ramp->base; - } else if (index.type().is_vector()) { - debug(3) << "Is Unaligned\n"; - return false; - } - - internal_assert(native_lanes != 0) << "Type is larger than required alignment of " << required_alignment << " bytes\n"; - - // If this is a parameter, the base_alignment should be - // host_alignment. Otherwise, this is an internal buffer, - // which we assume has been aligned to the required alignment. 
- if (op->param.defined() && ((op->param.host_alignment() % required_alignment) != 0)) { - return false; - } - - bool known_alignment = (op->alignment.modulus % native_lanes) == 0; - if (known_alignment) { - *aligned_offset = op->alignment.remainder % native_lanes; - } - return known_alignment && (*aligned_offset == 0); - } - - bool is_aligned(const Load *op, int64_t *aligned_offset) { - int native_lanes = required_alignment / op->type.bytes(); - return is_aligned_impl(op, native_lanes, aligned_offset); - } - - bool is_aligned(const Store *op, int64_t *aligned_offset) { - int native_lanes = required_alignment / op->value.type().bytes(); - return is_aligned_impl(op, native_lanes, aligned_offset); - } -}; - -} // namespace Internal -} // namespace Halide -#endif diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index fcd0de374d81..f11205dbf187 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -5,7 +5,6 @@ #include "ConciseCasts.h" #include "ExprUsesVar.h" #include "FindIntrinsics.h" -#include "HexagonAlignment.h" #include "IREquality.h" #include "IRMatch.h" #include "IRMutator.h" @@ -1419,10 +1418,7 @@ class EliminateInterleaves : public IRMutator { Scope vars; // We need to know when loads are a multiple of 2 native vectors. - int native_vector_bits; - - // Alignment analyzer for loads and stores - HexagonAlignmentAnalyzer alignment_analyzer; + const int native_vector_bytes; // Check if x is an expression that is either an interleave, or // transitively is an interleave. 
@@ -1884,9 +1880,8 @@ } internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope"); bool &aligned_accesses = aligned_buffer_access.ref(op->name); - int64_t aligned_offset = 0; - - if (!alignment_analyzer.is_aligned(op, &aligned_offset)) { + const int native_vector_lanes = native_vector_bytes / value.type().bytes(); + if (!op->alignment.contains(native_vector_lanes)) { aligned_accesses = false; } } @@ -1906,7 +1901,7 @@ Expr visit(const Load *op) override { if (buffers.contains(op->name)) { - if ((op->type.lanes() * op->type.bits()) % (native_vector_bits * 2) == 0) { + if ((op->type.lanes() * op->type.bytes()) % (native_vector_bytes * 2) == 0) { // This is a double vector load, we might be able to // deinterleave the storage of this buffer. // We don't want to actually do anything to the buffer @@ -1918,9 +1913,8 @@ // interleave). internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope"); bool &aligned_accesses = aligned_buffer_access.ref(op->name); - int64_t aligned_offset = 0; - - if (!alignment_analyzer.is_aligned(op, &aligned_offset)) { + const int native_vector_lanes = native_vector_bytes / op->type.bytes(); + if (!op->alignment.contains(native_vector_lanes)) { aligned_accesses = false; } } else { @@ -1941,7 +1935,7 @@ public: EliminateInterleaves(int native_vector_bytes) - : native_vector_bits(native_vector_bytes * 8), alignment_analyzer(native_vector_bytes) { + : native_vector_bytes(native_vector_bytes) { } }; diff --git a/src/IR.h b/src/IR.h index ce5e16cac996..359932227dd6 100644 --- a/src/IR.h +++ b/src/IR.h @@ -208,8 +208,8 @@ struct Load : public ExprNode<Load> { // If it's a load from an image parameter, this points to that Parameter param; - // The alignment of the index.
 If the index is a vector, this is - // the alignment of the first lane. + // The alignment of the loaded address. If the index is a vector, + // this is the alignment of the first lane. ModulusRemainder alignment; static Expr make(Type type, const std::string &name, @@ -318,7 +318,7 @@ struct Store : public StmtNode<Store> { // If it's a store to an output buffer, then this parameter points to it. Parameter param; - // The alignment of the index. If the index is a vector, this is + // The alignment of the stored address. If the index is a vector, this is // the alignment of the first lane. ModulusRemainder alignment; diff --git a/src/ModulusRemainder.h b/src/ModulusRemainder.h index c0341b75abf6..f92507b28438 100644 --- a/src/ModulusRemainder.h +++ b/src/ModulusRemainder.h @@ -47,6 +47,11 @@ struct ModulusRemainder { bool operator==(const ModulusRemainder &other) const { return (modulus == other.modulus) && (remainder == other.remainder); } + + // Check if every value in this congruence class is a multiple of x (i.e. the class is aligned to x). + bool contains(int64_t x) const { + return modulus % x == 0 && remainder % x == 0; + } }; ModulusRemainder operator+(const ModulusRemainder &a, const ModulusRemainder &b); diff --git a/src/OutputImageParam.cpp b/src/OutputImageParam.cpp index 0a668c3395a9..69ecf8eee82f 100644 --- a/src/OutputImageParam.cpp +++ b/src/OutputImageParam.cpp @@ -40,6 +40,12 @@ OutputImageParam &OutputImageParam::set_host_alignment(int bytes) { return *this; } +Expr OutputImageParam::is_host_aligned(int bytes) { + Expr host_ptr = Internal::Variable::make(Handle(), param.name(), Buffer<>(), param, Internal::ReductionDomain()); + Expr u64t_host_ptr = reinterpret<uint64_t>(host_ptr); + return (u64t_host_ptr % cast<uint64_t>(bytes)) == 0; +} + int OutputImageParam::dimensions() const { return param.dimensions(); } diff --git a/src/OutputImageParam.h b/src/OutputImageParam.h index c5b6b30371c7..62824355dfe3 100644 --- a/src/OutputImageParam.h +++ b/src/OutputImageParam.h @@ -65,7 +65,11 @@ class OutputImageParam { int host_alignment()
const; /** Set the expected alignment of the host pointer in bytes. */ - OutputImageParam &set_host_alignment(int); + OutputImageParam &set_host_alignment(int bytes); + + /** Returns a boolean Expr that is true if the host pointer is + * aligned to `bytes`. */ + Expr is_host_aligned(int bytes); /** Get the dimensionality of this image parameter */ int dimensions() const; diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 12fd76d96eb5..413bc02b8f88 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -202,6 +202,8 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { const Mod *m = eq->a.as(); const int64_t *modulus = m ? as_const_int(m->b) : nullptr; const int64_t *remainder = m ? as_const_int(eq->b) : nullptr; + const uint64_t *umodulus = m ? as_const_uint(m->b) : nullptr; + const uint64_t *uremainder = m ? as_const_uint(eq->b) : nullptr; if (v) { if (is_const(eq->b) || eq->b.as()) { // TODO: consider other cases where we might want to entirely substitute @@ -249,6 +251,27 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { } simplify->bounds_and_alignment_info.push(v->name, expr_info); bounds_pop_list.push_back(v); + } else if (umodulus && uremainder) { + Expr m_a = m->a; + if (const Call *c = Call::as_intrinsic(m_a, {Call::reinterpret})) { + if (c->args[0].type().is_handle()) { + // Ignore reinterprets of pointers for the purposes of learning alignment. + m_a = c->args[0]; + } + } + if ((v = m_a.as())) { + // Learn from expressions of the form x % 8 == 3 + Simplify::ExprInfo expr_info; + expr_info.alignment.modulus = *umodulus; + expr_info.alignment.remainder = *uremainder; + if (simplify->bounds_and_alignment_info.contains(v->name)) { + // We already know something about this variable and don't want to suppress it. 
+ auto existing_knowledge = simplify->bounds_and_alignment_info.get(v->name); + expr_info.intersect(existing_knowledge); + } + simplify->bounds_and_alignment_info.push(v->name, expr_info); + bounds_pop_list.push_back(v); + } } } else if (const LT *lt = fact.as()) { const Variable *v = lt->a.as(); diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index d0b2da780e40..91b448eef3a9 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -307,6 +307,30 @@ Expr Simplify::visit(const Load *op, ExprInfo *bounds) { ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment); + if (!allocations.contains(op->name)) { + // For external buffers, we also need to know something about the + // alignment of the pointer. + ModulusRemainder ptr_alignment(1, 0); + if (bounds_and_alignment_info.contains(op->name)) { + ptr_alignment = bounds_and_alignment_info.get(op->name).alignment; + // The alignment of the ptr is in bytes, we need it + // in values. + int type_bytes = op->type.bytes(); + if (ptr_alignment.contains(type_bytes)) { + ptr_alignment.modulus /= type_bytes; + ptr_alignment.remainder /= type_bytes; + } else { + ptr_alignment = ModulusRemainder(1, 0); + } + } + align = ptr_alignment + align; + + // Simplification should never reduce the alignment, but it + // does happen if simplifying without context containing + // information about the pointer. 
+ align = ModulusRemainder::intersect(op->alignment, align); + } + const Broadcast *b_index = index.as(); const Shuffle *s_index = index.as(); if (is_const_zero(predicate)) { diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 845aaa07527d..3568f697d88e 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -174,6 +174,8 @@ class Simplify : public VariadicVisitor { // Only tracked for integer let vars Scope bounds_and_alignment_info; + Scope<> allocations; + // Symbols used by rewrite rules IRMatcher::Wild<0> x; IRMatcher::Wild<1> y; diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 3c53b2b34a66..65f3553da591 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -270,6 +270,30 @@ Stmt Simplify::visit(const Store *op) { ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment); + if (!allocations.contains(op->name)) { + // For external buffers, we also need to know something about the + // alignment of the pointer. + ModulusRemainder ptr_alignment(1, 0); + if (bounds_and_alignment_info.contains(op->name)) { + ptr_alignment = bounds_and_alignment_info.get(op->name).alignment; + // The alignment of the ptr is in bytes, we need it + // in values. + int type_bytes = op->value.type().bytes(); + if (ptr_alignment.contains(type_bytes)) { + ptr_alignment.modulus /= type_bytes; + ptr_alignment.remainder /= type_bytes; + } else { + ptr_alignment = {1, 0}; + } + } + align = ptr_alignment + align; + + // Simplification should never reduce the alignment, but it + // does happen if simplifying without context containing + // information about the pointer. 
+ align = ModulusRemainder::intersect(op->alignment, align); + } + if (is_const_zero(predicate)) { // Predicate is always false return Evaluate::make(0); @@ -293,12 +317,13 @@ Stmt Simplify::visit(const Allocate *op) { new_extents.push_back(mutate(op->extents[i], nullptr)); all_extents_unmodified &= new_extents[i].same_as(op->extents[i]); } - Stmt body = mutate(op->body); Expr condition = mutate(op->condition, nullptr); Expr new_expr; if (op->new_expr.defined()) { new_expr = mutate(op->new_expr, nullptr); } + ScopedBinding<> allocated(allocations, op->name); + Stmt body = mutate(op->body); const IfThenElse *body_if = body.as(); if (body_if && op->condition.defined() && diff --git a/test/correctness/host_alignment.cpp b/test/correctness/host_alignment.cpp index 4dc7bf40a376..0f827ce9168c 100644 --- a/test/correctness/host_alignment.cpp +++ b/test/correctness/host_alignment.cpp @@ -3,134 +3,76 @@ #include #include -namespace { - using std::map; using std::string; -using std::vector; using namespace Halide; using namespace Halide::Internal; -class FindErrorHandler : public IRVisitor { -public: - bool result; - FindErrorHandler() - : result(false) { - } - using IRVisitor::visit; - void visit(const Call *op) override { - if (op->name == "halide_error_unaligned_host_ptr" && - op->call_type == Call::Extern) { - result = true; - return; - } - IRVisitor::visit(op); - } -}; - -class ParseCondition : public IRVisitor { +class CheckLoadsStoresAligned : public IRMutator { public: - Expr condition; - - using IRVisitor::visit; - void visit(const Mod *op) override { - condition = op; + const map &alignments_needed; + CheckLoadsStoresAligned(const map &m) + : alignments_needed(m) { } - void visit(const Call *op) override { - if (op->is_intrinsic(Call::bitwise_and)) { - condition = op; - } else { - IRVisitor::visit(op); + using IRMutator::visit; + + void check_alignment(const string &name, const ModulusRemainder &alignment) { + auto i = alignments_needed.find(name); + 
ModulusRemainder expected_alignment = + i != alignments_needed.end() ? i->second : ModulusRemainder(1, 0); + if (alignment.modulus != expected_alignment.modulus || + alignment.remainder != expected_alignment.remainder) { + printf("Load/store of %s is (%d, %d), expected (%d, %d)\n", + name.c_str(), (int)alignment.modulus, (int)alignment.remainder, + (int)expected_alignment.modulus, (int)expected_alignment.remainder); + abort(); } } -}; -class CountHostAlignmentAsserts : public IRVisitor { -public: - int count; - std::map alignments_needed; - CountHostAlignmentAsserts(std::map m) - : count(0), - alignments_needed(m) { + Expr visit(const Load *op) override { + check_alignment(op->name, op->alignment); + return IRMutator::visit(op); } - using IRVisitor::visit; - - void visit(const AssertStmt *op) override { - Expr m = op->message; - FindErrorHandler f; - m.accept(&f); - if (f.result) { - Expr c = op->condition; - ParseCondition p; - c.accept(&p); - if (p.condition.defined()) { - Expr left, right; - if (const Mod *mod = p.condition.as()) { - left = mod->a; - right = mod->b; - } else if (const Call *call = Call::as_intrinsic(p.condition, {Call::bitwise_and})) { - left = call->args[0]; - right = call->args[1]; - } - const Call *reinterpret_call = left.as(); - if (!reinterpret_call || - !reinterpret_call->is_intrinsic(Call::reinterpret)) return; - Expr name = reinterpret_call->args[0]; - const Variable *V = name.as(); - string name_host_ptr = V->name; - int expected_alignment = alignments_needed[name_host_ptr]; - if (is_const(right, expected_alignment) || is_const(right, expected_alignment - 1)) { - count++; - alignments_needed.erase(name_host_ptr); - } - } - } + Stmt visit(const Store *op) override { + check_alignment(op->name, op->alignment); + return IRMutator::visit(op); } }; -void set_alignment_host_ptr(ImageParam &i, int align, std::map &m) { - i.set_host_alignment(align); - m.insert(std::pair(i.name(), align)); -} - -int count_host_alignment_asserts(Func f, std::map 
m) { - Target t = get_jit_target_from_environment(); - t.set_feature(Target::NoBoundsQuery); - f.compute_root(); - Stmt s = Internal::lower_main_stmt({f.function()}, f.name(), t); - CountHostAlignmentAsserts c(m); - s.accept(&c); - return c.count; -} - -int test() { - Var x, y, c; - std::map m; +int main(int argc, char **argv) { ImageParam i1(Int(8), 1); ImageParam i2(Int(8), 1); ImageParam i3(Int(8), 1); - - set_alignment_host_ptr(i1, 128, m); - set_alignment_host_ptr(i2, 32, m); - - Func f("f"); - f(x) = i1(x) + i2(x) + i3(x); - f.output_buffer().set_host_alignment(128); - m.insert(std::pair("f", 128)); - int cnt = count_host_alignment_asserts(f, m); - if (cnt != 3) { - printf("Error: expected 3 host alignment assertions in code, but got %d\n", cnt); - return -1; - } + ImageParam i4(Int(8), 1); + + Var x; + Func f; + f(x) = i1(x) + i2(x * 2) + i3(x / 2) + i4(x + 1); + + i1.dim(0).set_min(0); + i2.set_host_alignment(4); + i2.dim(0).set_min(0); + i4.set_host_alignment(8); + i4.dim(0).set_min(0); + f.output_buffer().set_host_alignment(3); + f.output_buffer().dim(0).set_min(0); + f.vectorize(x, 12, TailStrategy::RoundUp); + f.specialize(i3.is_host_aligned(4) && i3.dim(0).min() == 0); + f.specialize_fail("No unaligned loads"); + + map expected_alignment = { + {i1.name(), {1, 0}}, + {i2.name(), {4, 0}}, + {i3.name(), {2, 0}}, + {i4.name(), {4, 1}}, + {f.name(), {3, 0}}, + }; + f.add_custom_lowering_pass(new CheckLoadsStoresAligned(expected_alignment), []() {}); + // Test with NoAsserts to make sure the host alignment asserts are present. + f.compile_jit(get_jit_target_from_environment().with_feature(Target::NoAsserts)); printf("Success!\n"); return 0; } - -} // namespace - -int main(int argc, char **argv) { - return test(); -}