From 04f34e6c73ca331dc9bad49471ce57d5bb49158a Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sun, 10 Feb 2019 07:59:37 +0000 Subject: [PATCH 1/6] fix storage_rewrite bug when input is big --- src/pass/storage_rewrite.cc | 5 +++- .../unittest/test_pass_storage_rewrite.py | 28 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 331b60a865ed..134676c287e1 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -564,6 +564,9 @@ class StoragePlanRewriter : public IRMutator { Expr combo_size; for (const Allocate* op : e->allocs) { Expr sz = arith::ComputeReduce(op->extents, make_const(Int(32), 1)); + if (const auto* imm = sz.as()) { + sz = make_const(Int(64), imm->value); + } // transform to bits auto sz_nbits = sz * (op->type.bits() * op->type.lanes()); if (combo_size.defined()) { @@ -578,7 +581,7 @@ class StoragePlanRewriter : public IRMutator { combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { - combo_size = combo_size + make_const(Int(32), 1); + combo_size = combo_size + make_const(Int(64), 1); } combo_size = ir::Simplify(combo_size); e->new_alloc = Allocate::make( diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index faf70204c29e..9573d0cdf8cb 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -477,6 +477,33 @@ def test_replace_dataflow(): assert isinstance(bounds, tvm.container.Map) +def test_large_input(): + @tvm.hybrid.script + def compute(a, b): + n = 16384 + c = output_tensor((n, n), 'int32') + for i in range(n): + for j in range(n): + c[i, j] = a[i, j] - b[i, j] + return c + + n = 16384 + shape = (n, n) + a = tvm.placeholder(shape, name='a', dtype='int32') + b = tvm.placeholder(shape, name='b', dtype='int32') + c = tvm.compute(shape, lambda i, j: compute(a, b)[i, j]) + c = tvm.compute(shape, lambda i, j: 1 + c[i, j]) + s = tvm.create_schedule(c.op) + stmt = tvm.lower(s, [a, b, c], simple_mode=True) + num_alloc = [0] + def verify(n): + if isinstance(n, tvm.stmt.Allocate): + num_alloc[0] += 1 + assert n.extents[0].value == 268435456 + tvm.ir_pass.PostOrderVisit(stmt, verify) + assert num_alloc[0] == 1 + + if __name__ == "__main__": test_alloc_seq() test_alloc_different_dtypes() @@ -492,3 +519,4 @@ def test_replace_dataflow(): test_alloc_seq_type2() test_reuse_small_buffer() test_replace_dataflow() + test_large_input() From b2965b3b86fd85346417c7d4606d2f452cebce59 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Mon, 11 Feb 2019 18:50:31 +0000 Subject: [PATCH 2/6] cast when necessary --- src/pass/storage_rewrite.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 134676c287e1..065894871ab9 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -564,11 +564,14 @@ class StoragePlanRewriter : public IRMutator { Expr combo_size; for (const Allocate* op : e->allocs) { Expr sz = arith::ComputeReduce(op->extents, make_const(Int(32), 1)); + auto nbits = op->type.bits() * op->type.lanes(); if (const auto* imm = sz.as()) { - sz = make_const(Int(64), imm->value); + if (imm->value > std::numeric_limits::max() / nbits) { + sz = make_const(Int(64), imm->value); + } } // transform to bits - auto sz_nbits = sz * (op->type.bits() * op->type.lanes()); + auto sz_nbits = sz * nbits; if (combo_size.defined()) { combo_size = max(combo_size, sz_nbits); } else { @@ -581,7 +584,7 @@ class StoragePlanRewriter : public IRMutator { combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { - combo_size = combo_size + make_const(Int(64), 1); + combo_size = combo_size + make_const(Int(32), 1); } combo_size = ir::Simplify(combo_size); e->new_alloc = Allocate::make( From 00c41dab19819901d2ed3b8fcab7653570241869 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 12 Feb 2019 20:20:05 +0000 Subject: [PATCH 3/6] simplification --- src/pass/storage_rewrite.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 065894871ab9..340c00b0443a 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -550,8 +550,10 @@ class StoragePlanRewriter : public IRMutator { } if (e->allocs.size() == 1) { // simply use the original allocation. + Expr sz = arith::ComputeReduce(e->allocs[0]->extents, + make_const(Int(32), 1)); e->new_alloc = Allocate::make( - e->alloc_var, alloc_type, e->allocs[0]->extents, + e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate::make(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); @@ -567,6 +569,11 @@ class StoragePlanRewriter : public IRMutator { auto nbits = op->type.bits() * op->type.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { + LOG(WARNING) << "The allocation requires : " << imm->value + << " * " << nbits + << " bits, which is greater than the maximum of " + "int32. The size is cast to int64." + << "\n"; sz = make_const(Int(64), imm->value); } } From 56f32348b40d6031b0f0f5d7491d9c90dee0e017 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 12 Feb 2019 20:20:05 +0000 Subject: [PATCH 4/6] simplification --- tests/python/unittest/test_pass_storage_rewrite.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 9573d0cdf8cb..52851d4afe95 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -495,13 +495,10 @@ def compute(a, b): c = tvm.compute(shape, lambda i, j: 1 + c[i, j]) s = tvm.create_schedule(c.op) stmt = tvm.lower(s, [a, b, c], simple_mode=True) - num_alloc = [0] def verify(n): if isinstance(n, tvm.stmt.Allocate): - num_alloc[0] += 1 assert n.extents[0].value == 268435456 tvm.ir_pass.PostOrderVisit(stmt, verify) - assert num_alloc[0] == 1 if __name__ == "__main__": From f512daa379942eef68184be2fb570b7edc620e06 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 13 Feb 2019 21:43:46 +0000 Subject: [PATCH 5/6] int64->uint32 --- src/pass/storage_rewrite.cc | 4 ++-- tests/python/unittest/test_pass_storage_rewrite.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 340c00b0443a..cd8ca73cc7c2 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -572,9 +572,9 @@ class StoragePlanRewriter : public IRMutator { LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits << " bits, which is greater than the maximum of " - "int32. The size is cast to int64." + "int32. The size is cast to uint32." << "\n"; - sz = make_const(Int(64), imm->value); + sz = make_const(UInt(32), imm->value); } } // transform to bits diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 52851d4afe95..bbf573ae1d31 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -480,14 +480,14 @@ def test_replace_dataflow(): def test_large_input(): @tvm.hybrid.script def compute(a, b): - n = 16384 + n = 8192 c = output_tensor((n, n), 'int32') for i in range(n): for j in range(n): c[i, j] = a[i, j] - b[i, j] return c - n = 16384 + n = 8192 shape = (n, n) a = tvm.placeholder(shape, name='a', dtype='int32') b = tvm.placeholder(shape, name='b', dtype='int32') @@ -497,7 +497,7 @@ def compute(a, b): stmt = tvm.lower(s, [a, b, c], simple_mode=True) def verify(n): if isinstance(n, tvm.stmt.Allocate): - assert n.extents[0].value == 268435456 + assert n.extents[0].value == 67108864 tvm.ir_pass.PostOrderVisit(stmt, verify) From 3ed16800ddae7ccde4c2f7ad6a9a1e7f3863bab6 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 13 Feb 2019 22:06:05 +0000 Subject: [PATCH 6/6] revert uint32->int64 --- src/pass/storage_rewrite.cc | 6 +++--- tests/python/unittest/test_pass_storage_rewrite.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index cd8ca73cc7c2..9ba9dcde63c9 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -571,10 +571,10 @@ class StoragePlanRewriter : public IRMutator { if (imm->value > std::numeric_limits::max() / nbits) { LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits - << " bits, which is greater than the maximum of " - "int32. The size is cast to uint32." + << " bits, which is greater than the maximum of" + " int32. The size is cast to int64." << "\n"; - sz = make_const(UInt(32), imm->value); + sz = make_const(Int(64), imm->value); } } // transform to bits diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index bbf573ae1d31..52851d4afe95 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -480,14 +480,14 @@ def test_replace_dataflow(): def test_large_input(): @tvm.hybrid.script def compute(a, b): - n = 8192 + n = 16384 c = output_tensor((n, n), 'int32') for i in range(n): for j in range(n): c[i, j] = a[i, j] - b[i, j] return c - n = 8192 + n = 16384 shape = (n, n) a = tvm.placeholder(shape, name='a', dtype='int32') b = tvm.placeholder(shape, name='b', dtype='int32') @@ -497,7 +497,7 @@ def compute(a, b): stmt = tvm.lower(s, [a, b, c], simple_mode=True) def verify(n): if isinstance(n, tvm.stmt.Allocate): - assert n.extents[0].value == 67108864 + assert n.extents[0].value == 268435456 tvm.ir_pass.PostOrderVisit(stmt, verify)