# Patch summary (leading explanatory text; everything from the first "diff --git" header onward is unchanged).
#
# Purpose: make the extent computed in Buffer::access_ptr account for the
# caller-supplied `offset`.
#
# File 1: src/lang/buffer.cc (hunk inside Buffer::access_ptr)
#   * Branch taken when strides.size() == shape.size(): `extent` was the result
#     of arith::ComputeExpr(strides[highest_dim], shape[highest_dim]) with
#     highest_dim = 0; it is now that value minus `offset`.
#   * Final `else` branch: `extent` was arith::ComputeReduce(shape, Expr());
#     it is now that value minus `offset`.
#   * The context line after the branches adds the same `offset` to
#     self->elem_offset, so the start moves forward by `offset` while the
#     extent shrinks by `offset`.
#   * Any branch before the `} else if` is outside this hunk and is not
#     touched by the patch.
#
# File 2: tests/python/unittest/test_lang_buffer.py
#   * Adds test_buffer_access_ptr_extent, which checks aptr.args[3]
#     (presumably the extent argument of the generated call -- confirm) for a
#     buffer declared with symbolic shape (m, n):
#       - access_ptr("rw")              -> m * n
#       - access_ptr("rw", offset=100)  -> m * n - 100
#       - strides=[n + 1, 1], offset=100 -> Ab.strides[0] * m - 100
#   * Adds a call to the new test in the list of test calls at the end of the
#     file, between test_buffer_access_ptr_offset() and test_buffer_vload().
#
# NOTE(review): the hunk headers declare multi-line hunks (e.g. "-357,9 +357,9"),
# but the text below sits on a single physical line, so the original line
# breaks and indentation appear to have been lost. Restore them from the
# upstream change before trying to apply this patch.
# NOTE(review): arith::ComputeExpr / arith::ComputeReduce appear here without
# an explicit operator; the new test expects a product (strides[0] * m,
# m * n) -- confirm against the target file.
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc index 183a52f785bd..524cad2eeac6 100644 --- a/src/lang/buffer.cc +++ b/src/lang/buffer.cc @@ -357,9 +357,9 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr } else if (self->strides.size() == self->shape.size()) { int highest_dim = 0; extent = arith::ComputeExpr( - self->strides[highest_dim], self->shape[highest_dim]); + self->strides[highest_dim], self->shape[highest_dim]) - offset; } else { - extent = arith::ComputeReduce(self->shape, Expr()); + extent = arith::ComputeReduce(self->shape, Expr()) - offset; } Expr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { diff --git a/tests/python/unittest/test_lang_buffer.py b/tests/python/unittest/test_lang_buffer.py index 51f1e3abb7e9..85c9fbeee53e 100644 --- a/tests/python/unittest/test_lang_buffer.py +++ b/tests/python/unittest/test_lang_buffer.py @@ -41,6 +41,18 @@ def test_buffer_access_ptr_offset(): assert tvm.ir_pass.Equal(offset, tvm.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE +def test_buffer_access_ptr_extent(): + m = tvm.var('m') + n = tvm.var('n') + Ab = tvm.decl_buffer((m, n), tvm.float32) + aptr = Ab.access_ptr("rw") + assert tvm.ir_pass.Equal(aptr.args[3], m * n) + aptr = Ab.access_ptr("rw", offset=100) + assert tvm.ir_pass.Equal(aptr.args[3], m * n - 100) + Ab = tvm.decl_buffer((m, n), tvm.float32, strides=[n + 1 , 1]) + aptr = Ab.access_ptr("rw", offset=100) + assert tvm.ir_pass.Equal(aptr.args[3], Ab.strides[0] * m - 100) + def test_buffer_vload(): m = tvm.var('m') n = tvm.var('n') @@ -84,5 +96,6 @@ def assert_simplified_equal(index_simplified, index_direct): test_buffer() test_buffer_access_ptr() test_buffer_access_ptr_offset() + test_buffer_access_ptr_extent() test_buffer_vload() test_buffer_index_merge_mult_mod()