From 6770f6b77252cf17a89ea7aeb292a2a54190cfff Mon Sep 17 00:00:00 2001 From: Gaoxiong Date: Sat, 20 Oct 2018 10:41:04 +0800 Subject: [PATCH] Fix non-zero extent of access_ptr out of range (#1937) --- src/lang/buffer.cc | 4 ++-- tests/python/unittest/test_lang_buffer.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 183a52f785bd..524cad2eeac6 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -357,9 +357,9 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr } else if (self->strides.size() == self->shape.size()) { int highest_dim = 0; extent = arith::ComputeExpr( - self->strides[highest_dim], self->shape[highest_dim]); + self->strides[highest_dim], self->shape[highest_dim]) - offset; } else { - extent = arith::ComputeReduce(self->shape, Expr()); + extent = arith::ComputeReduce(self->shape, Expr()) - offset; } Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index 51f1e3abb7e9..85c9fbeee53e 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -41,6 +41,18 @@ def test_buffer_access_ptr_offset(): assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE +def test_buffer_access_ptr_extent(): + m = tvm.var('m') + n = tvm.var('n') + Ab = tvm.decl_buffer((m, n), tvm.float32) + aptr = Ab.access_ptr("rw") + assert tvm.ir_pass.Equal(aptr.args[3], m * n) + aptr = Ab.access_ptr("rw", offset=100) + assert tvm.ir_pass.Equal(aptr.args[3], m * n - 100) + Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1 , 1]) + aptr = Ab.access_ptr("rw", offset=100) + assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100) + def test_buffer_vload(): m = tvm.var('m') n = tvm.var('n') @@ -84,5 +96,6 @@ def assert_simplified_equal(index_simplified, index_direct): test_buffer() test_buffer_access_ptr() test_buffer_access_ptr_offset() + test_buffer_access_ptr_extent() test_buffer_vload() test_buffer_index_merge_mult_mod()