diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index c01002476b68..52a2b139ebcf 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -633,14 +633,14 @@ class Interleaver : public IRMutator { const int64_t *stride_ptr = as_const_int(r0->stride); - // The stride isn't a constant or is <= 0 - if (!stride_ptr || *stride_ptr < 1) { + // The stride isn't a constant or is <= 1 + if (!stride_ptr || *stride_ptr <= 1) { return Stmt(); } const int64_t stride = *stride_ptr; const int lanes = r0->lanes; - const int64_t expected_stores = stride == 1 ? lanes : stride; + const int64_t expected_stores = stride; // Collect the rest of the stores. std::vector stores; @@ -690,53 +690,11 @@ class Interleaver : public IRMutator { if (*offs < min_offset) { min_offset = *offs; } - - if (stride == 1) { - // Difference between bases is not a multiple of the lanes. - if (*offs % lanes != 0) { - return Stmt(); - } - - // This case only triggers if we have an immediate load of the correct stride on the RHS. - // TODO: Could we consider mutating the RHS so that we can handle more complex Expr's than just loads? - const Load *load = stores[i].as()->value.as(); - if (!load) { - return Stmt(); - } - // TODO(psuriana): Predicated load is not currently handled. - if (!is_const_one(load->predicate)) { - return Stmt(); - } - - const Ramp *ramp = load->index.as(); - if (!ramp) { - return Stmt(); - } - - // Load stride or lanes is not equal to the store lanes. - if (!is_const(ramp->stride, lanes) || ramp->lanes != lanes) { - return Stmt(); - } - - if (i == 0) { - load_name = load->name; - load_image = load->image; - load_param = load->param; - } else { - if (load->name != load_name) { - return Stmt(); - } - } - } } // Gather the args for interleaving. for (size_t i = 0; i < stores.size(); ++i) { int j = offsets[i] - min_offset; - if (stride == 1) { - j /= stores.size(); - } - if (j == 0) { base = stores[i].as()->index.as()->base; } @@ -751,14 +709,7 @@ class Interleaver : public IRMutator { return Stmt(); } - if (stride == 1) { - // Convert multiple dense vector stores of strided vector loads - // into one dense vector store of interleaving dense vector loads. - args[j] = Load::make(t, load_name, stores[i].as()->index, - load_image, load_param, const_true(t.lanes()), ModulusRemainder()); - } else { - args[j] = stores[i].as()->value; - } + args[j] = stores[i].as()->value; predicates[j] = stores[i].as()->predicate; } diff --git a/test/correctness/interleave.cpp b/test/correctness/interleave.cpp index 82b1332c1006..1cb8b10b8c0a 100644 --- a/test/correctness/interleave.cpp +++ b/test/correctness/interleave.cpp @@ -353,43 +353,24 @@ int main(int argc, char **argv) { } { - // Test that transposition works when vectorizing either dimension: + // Test transposition Func square("square"); square(x, y) = cast(UInt(16), 5 * x + y); - Func trans1("trans1"); - trans1(x, y) = square(y, x); - - Func trans2("trans2"); - trans2(x, y) = square(y, x); + Func trans("trans2"); + trans(x, y) = square(y, x); square.compute_root() .bound(x, 0, 8) .bound(y, 0, 8); - trans1.compute_root() - .bound(x, 0, 8) - .bound(y, 0, 8) - .vectorize(x) - .unroll(y); - - trans2.compute_root() + trans.compute_root() .bound(x, 0, 8) .bound(y, 0, 8) .unroll(x) .vectorize(y); - trans1.output_buffer() - .dim(0) - .set_min(0) - .set_stride(1) - .set_extent(8) - .dim(1) - .set_min(0) - .set_stride(8) - .set_extent(8); - - trans2.output_buffer() + trans.output_buffer() .dim(0) .set_min(0) .set_stride(1) @@ -399,19 +380,12 @@ int main(int argc, char **argv) { .set_stride(8) .set_extent(8); - Buffer result6(8, 8); Buffer result7(8, 8); - trans1.realize(result6); - trans2.realize(result7); + trans.realize(result7); for (int x = 0; x < 8; x++) { for (int y = 0; y < 8; y++) { int correct = 5 * y + x; - if (result6(x, y) != correct) { - printf("result(%d) = %d instead of %d\n", x, result6(x, y), correct); - return -1; - } - if (result7(x, y) != correct) { printf("result(%d) = %d instead of %d\n", x, result7(x, y), correct); return -1; @@ -419,8 +393,7 @@ int main(int argc, char **argv) { } } - check_interleave_count(trans1, 1); - check_interleave_count(trans2, 1); + check_interleave_count(trans, 1); } printf("Success!\n");