diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 24848f8cfd89..53c8f7754602 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -815,12 +815,12 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(begin));
   TypedPointer arg_tcode =
-      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
+      CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
   llvm::Value* ret_value = builder_->CreateInBoundsGEP(
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(end));
   TypedPointer ret_tcode =
-      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
+      CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));

 #if TVM_LLVM_VERSION >= 90
   auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc
index 32587030ba17..28127da9a64b 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -76,6 +76,8 @@ class CodeGenHexagon final : public CodeGenLLVM {
   llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};

  private:
+  TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
+                               llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype) final;
   TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind);

   // Check if the call to packed function is successful
@@ -320,12 +322,12 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array<PrimExpr>&
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(begin));
   TypedPointer arg_tcode =
-      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
+      CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
   llvm::Value* ret_value = builder_->CreateInBoundsGEP(
       t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
       ConstInt32(end));
   TypedPointer ret_tcode =
-      CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
+      CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));

 #if TVM_LLVM_VERSION >= 90
   auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
@@ -570,6 +572,31 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) {
   return CodeGenLLVM::CreateIntrinsic(op);
 }

+CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_ptr,
+                                                          DataType buffer_element_dtype,
+                                                          llvm::ArrayRef<llvm::Value*> indices,
+                                                          DataType value_dtype) {
+  // Flat indices get delegated to the LLVM codegen.
+  if (indices.size() == 1) {
+    return CodeGenLLVM::CreateBufferPtr(buffer_ptr, buffer_element_dtype, indices, value_dtype);
+  }
+
+  ICHECK_EQ(indices.size(), 2) << "CodegenHexagon supports 1-d and 2-d physical buffers, received "
+                               << indices.size() << "-d buffer indices";
+
+  // Use the first index to identify the pointer.
+  DataType dtype_void_ptr = DataType::Handle();
+  CodeGenLLVM::TypedPointer buffer_chunk_ptr_ptr =
+      CodeGenLLVM::CreateBufferPtr(buffer_ptr, dtype_void_ptr, {indices[0]}, dtype_void_ptr);
+  llvm::Value* buffer_chunk_ptr =
+      builder_->CreateLoad(buffer_chunk_ptr_ptr.type, buffer_chunk_ptr_ptr.addr);
+
+  // Then delegate to CodeGenLLVM to locate the value within that
+  // chunk, using the second index.
+  return CodeGenLLVM::CreateBufferPtr(buffer_chunk_ptr, buffer_element_dtype, {indices[1]},
+                                      value_dtype);
+}
+
 CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf,
                                                              llvm::Value* index, int kind) {
   static const std::map<int, int> field_index = {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 26aadd4ff881..3ddf4af12bea 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -791,7 +791,11 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {

 CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
                                                        DataType buffer_element_dtype,
-                                                       llvm::Value* index, DataType value_dtype) {
+                                                       llvm::ArrayRef<llvm::Value*> indices,
+                                                       DataType value_dtype) {
+  ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers.";
+  llvm::Value* index = indices[0];
+
   llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
   ICHECK(buffer_ptr_type != nullptr);
   auto address_space = buffer_ptr_type->getAddressSpace();
@@ -1010,7 +1014,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
       index = r->base;
     }
     TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
-                                              MakeValue(index), load->dtype);
+                                              {MakeValue(index)}, load->dtype);
     unsigned addrspace =
         llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
     return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
@@ -1274,39 +1278,56 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
 }

 void CodeGenLLVM::BufferAccessHelper(
-    Buffer buffer, PrimExpr index, DataType value_dtype,
+    Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
     std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
                                      bool is_volatile)>
         make_instruction) {
   DataType buffer_element_dtype = buffer->dtype;

-  ICHECK_EQ(value_dtype.lanes(), index.dtype().lanes() * buffer_element_dtype.lanes());
+  ICHECK_GE(indices.size(), 1)
+      << "Buffer " << buffer->name << " is accessed with no indices. "
+      << "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen.";
+
+  // Only the last index is allowed to be multi-lane. All earlier
+  // indices must be scalar. This only matters for subclasses of
+  // CodeGenLLVM, because the default implementation of
+  // CreateBufferPtr requires 1-d indices.
+  std::vector<llvm::Value*> earlier_index_values;
+  for (size_t i = 0; i < indices.size() - 1; i++) {
+    ICHECK_EQ(indices[i].dtype().lanes(), 1)
+        << "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i
+        << ". Multi-lane indices are only supported as the last index.";
+    earlier_index_values.push_back(MakeValue(indices[i]));
+  }
+
+  PrimExpr last_index = indices[indices.size() - 1];
+  ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());

   bool is_volatile = volatile_buf_.count(buffer->data.get());

   // If the buffer index is a contiguous ramp node, we only need to
   // access the first element, then cast to the value type.
-  if (const RampNode* ramp_index = index.as<RampNode>()) {
+  if (const RampNode* ramp_index = last_index.as<RampNode>()) {
     if (ramp_index && is_one(ramp_index->stride)) {
-      index = ramp_index->base;
+      last_index = ramp_index->base;
     }
   }

   // All TVM arrays are densely packed. If the vectorized LLVM type
   // contains padding for alignment, we need to index based on the
   // size of the scalar type to avoid introducing that padding.
-  if (index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
-    index = buffer_element_dtype.lanes() * index;
+  if (last_index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
+    last_index = buffer_element_dtype.lanes() * last_index;
     buffer_element_dtype = buffer_element_dtype.element_of();
   }

   int alignment;
-  if (index.dtype().lanes() == 1) {
+  if (last_index.dtype().lanes() == 1) {
     // If we are accessing with a single index, then the vectorized
     // element being accessed may require more alignment than the
     // underlying data type.
     int native_bits;
-    GetAlignment(value_dtype, buffer->data.get(), index, &alignment, &native_bits);
+    GetAlignment(value_dtype, buffer->data.get(), last_index, &alignment, &native_bits);
   } else {
     // Otherwise, alignment is based on the return value's scalar
     // type.
@@ -1315,35 +1336,35 @@ void CodeGenLLVM::BufferAccessHelper(
   }

   llvm::Value* cached_vector_index = nullptr;
-  for (int i = 0; i < index.dtype().lanes(); ++i) {
-    llvm::Value* index_value;
+  for (int i = 0; i < last_index.dtype().lanes(); ++i) {
+    llvm::Value* last_index_value;
     int subelement_i = i;
-    if (const RampNode* ramp = index.as<RampNode>()) {
+    if (const RampNode* ramp = last_index.as<RampNode>()) {
       PrimExpr offset = ramp->base + (ramp->stride * i);
-      index_value = MakeValue(offset);
-    } else if (index.dtype().lanes() > 1) {
+      last_index_value = MakeValue(offset);
+    } else if (last_index.dtype().lanes() > 1) {
       if (i == 0) {
-        cached_vector_index = MakeValue(index);
+        cached_vector_index = MakeValue(last_index);
       }
-      index_value = builder_->CreateExtractElement(cached_vector_index, i);
+      last_index_value = builder_->CreateExtractElement(cached_vector_index, i);
     } else {
-      index_value = MakeValue(index);
+      last_index_value = MakeValue(last_index);
       subelement_i = -1;
     }

+    std::vector<llvm::Value*> all_index_values = earlier_index_values;
+    all_index_values.push_back(last_index_value);
+
     TypedPointer buffer_ptr =
-        CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, index_value,
-                        value_dtype.with_lanes(value_dtype.lanes() / index.dtype().lanes()));
+        CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
+                        value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
     auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
-    AddAliasInfo(instruction, buffer->data.get(), index);
+    AddAliasInfo(instruction, buffer->data.get(), last_index);
   }
 }

 llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
-  ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";
-
   DataType value_dtype = op->dtype;
-  PrimExpr index = op->indices[0];

   std::vector<llvm::Value*> loads;

@@ -1363,7 +1384,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
     return load;
   };

-  BufferAccessHelper(op->buffer, index, value_dtype, make_load);
+  // Pass all indices into BufferAccessHelper. In CodeGenLLVM,
+  // non-flat indices will result in an error in CreateBufferPtr, but
+  // a subclass may override CreateBufferPtr.
+  BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load);

   if (loads.size() == 1) {
     return loads[0];
@@ -1441,11 +1465,8 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
 }

 void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
-  ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";
-
   DataType value_dtype = op->value.dtype();
   Var buffer_var = op->buffer->data;
-  PrimExpr buffer_index = op->indices[0];

   llvm::Value* value = MakeValue(op->value);

@@ -1463,7 +1484,10 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
 #endif
   };

-  BufferAccessHelper(op->buffer, buffer_index, value_dtype, make_store);
+  // Pass all indices into BufferAccessHelper. In CodeGenLLVM,
+  // non-flat indices will result in an error in CreateBufferPtr, but
+  // a subclass may override CreateBufferPtr.
+  BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store);
 }

 void CodeGenLLVM::VisitStmt_(const ForNode* op) {
@@ -1528,6 +1552,10 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
 }

 void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
+  ICHECK_EQ(op->extents.size(), 1)
+      << "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
+      << op->buffer_var->name_hint << " is " << op->extents << "-d";
+
   ICHECK(!is_zero(op->condition));
   llvm::Value* buf = nullptr;

diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 3ec0881d5251..559ce97f8fc4 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -265,7 +265,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
    *
    * \param buffer The buffer being accessed
    *
-   * \param index The index at which the buffer is being accessed.
+   * \param indices The indices at which the buffer is being accessed.
    *
    * \param value_dtype The datatype to be read from (BufferLoad) or
    * written to (BufferStore) the buffer.
@@ -286,7 +286,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
    * - Should return the generated expression.
    */
   void BufferAccessHelper(
-      Buffer buffer, PrimExpr index, DataType value_dtype,
+      Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
       std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
                                        bool is_volatile)>
           make_instruction);
@@ -372,8 +372,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
   llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
-  TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
-                               llvm::Value* index, DataType value_dtype);
+  virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
+                                       llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
   // Vector concatenation.
   llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
   llvm::Value* CreateVecFlip(llvm::Value* vec);
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 2bc081483ccd..ed36f5828d13 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -67,8 +67,13 @@ class BufferShapeLegalize : public StmtExprMutator {
       bound_analyzer(func->body);

+      auto pass = BufferShapeLegalize(func->buffer_map, &bound_analyzer);
+
       auto fptr = func.CopyOnWrite();
-      fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+      fptr->body = pass(std::move(fptr->body));
+      if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
+        func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
+      }
       return func;
     };
     return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferShapeLegalize", {});
@@ -89,6 +94,19 @@ class BufferShapeLegalize : public StmtExprMutator {
     }
   }

+  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
+    Map<Buffer, Array<IndexMap>> output;
+    for (const auto& kv : orig) {
+      auto it = buf_map_.find(kv.first);
+      if (it != buf_map_.end()) {
+        output.Set(it->second.remap_to, kv.second);
+      } else {
+        output.Set(kv.first, kv.second);
+      }
+    }
+    return output;
+  }
+
   PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = var_remap_.find(op);
     if (it != var_remap_.end()) {
@@ -379,8 +397,13 @@ class BufferStrideLegalize : public StmtExprMutator {
       bound_analyzer(func->body);

+      auto pass = BufferStrideLegalize(func->buffer_map, &bound_analyzer);
+
       auto fptr = func.CopyOnWrite();
-      fptr->body = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+      fptr->body = pass(std::move(fptr->body));
+      if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
+        func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
+      }
       return func;
     };
     return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferStrideLegalize", {});
@@ -403,6 +426,19 @@ class BufferStrideLegalize : public StmtExprMutator {
     }
   }

+  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
+    Map<Buffer, Array<IndexMap>> output;
+    for (const auto& kv : orig) {
+      auto it = buf_map_.find(kv.first);
+      if (it != buf_map_.end()) {
+        output.Set(it->second.remap_to, kv.second);
+      } else {
+        output.Set(kv.first, kv.second);
+      }
+    }
+    return output;
+  }
+
   Map<Var, Buffer> UpdatedExternBufferMap() const { return updated_extern_buffer_map_; }

   Buffer WithStrides(Buffer buf) {
@@ -595,8 +631,13 @@ class ThreadScopePropagate : public StmtExprMutator {
  public:
   static transform::Pass Pass() {
     auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) {
+      auto pass = ThreadScopePropagate(func->buffer_map);
+
       auto fptr = func.CopyOnWrite();
-      fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));
+      fptr->body = pass(std::move(fptr->body));
+      if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
+        func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
+      }
       return func;
     };
     return transform::CreatePrimFuncPass(pass_func, 0, "tir.ThreadScopePropagate", {});
@@ -610,6 +651,19 @@ class ThreadScopePropagate : public StmtExprMutator {
     }
   }

+  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
+    Map<Buffer, Array<IndexMap>> output;
+    for (const auto& kv : orig) {
+      auto it = buf_remap_.find(kv.first->data);
+      if (it != buf_remap_.end()) {
+        output.Set(it->second, kv.second);
+      } else {
+        output.Set(kv.first, kv.second);
+      }
+    }
+    return output;
+  }
+
   PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = buf_remap_.find(GetRef<Var>(op));
     if (it != buf_remap_.end()) {
@@ -761,8 +815,10 @@ class BufferBindUnwrapper : public StmtExprMutator {
       bound_analyzer(func->body);

+      auto pass = BufferBindUnwrapper(func->buffer_map, &bound_analyzer);
+
       auto fptr = func.CopyOnWrite();
-      fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
+      fptr->body = pass(std::move(fptr->body));
       return func;
     };
     return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferBindUnwrapper", {});
@@ -779,6 +835,20 @@ class BufferBindUnwrapper : public StmtExprMutator {
     }
   }

+  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
+    Map<Buffer, Array<IndexMap>> output;
+    for (const auto& kv : orig) {
+      const BufferEntry& e = GetBufferEntry(kv.first);
+
+      if (e.remap) {
+        output.Set(e.remap->target, kv.second);
+      } else {
+        output.Set(kv.first, kv.second);
+      }
+    }
+    return output;
+  }
+
   Stmt VisitStmt_(const StoreNode* op) final {
     LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
     return Stmt();
@@ -1357,7 +1427,8 @@ class StorageFlattener : public StmtExprMutator {
     }

     e.buffer = Buffer(op->buffer->data, op->buffer->dtype, op->buffer->shape, op->buffer->strides,
-                      PrimExpr(), op->buffer->name, align, 0, kDefault);
+                      PrimExpr(), op->buffer->name, align, 0, kDefault,
+                      op->buffer->axis_separators, op->buffer->span);

     e.flattened_buffer = e.buffer.GetFlattenedBuffer();

     // TODO(Lunderberg): Move the handling of boolean into a
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
new file mode 100755
index 000000000000..9093956bcfca
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import sys
+import tempfile
+import pathlib
+
+import pytest
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.tir.stmt_functor import post_order_visit
+
+from .conftest import requires_hexagon_toolchain
+
+# Needed to register the link_shared packedfunc.
+import tvm.contrib.hexagon.hexagon
+
+
+dtype = tvm.testing.parameter("int8")
+batch_size = tvm.testing.parameter(16)
+input_channels = tvm.testing.parameter(32)
+output_channels = tvm.testing.parameter(32)
+input_image_shape = tvm.testing.parameter((64, 64))
+filter_size = tvm.testing.parameter((5, 5))
+
+input_layout = tvm.testing.parameter(
+    "nhwc",
+    "nchw-8h8w32c",
+    "nchw-8h8w32c-flat",
+)
+working_layout = tvm.testing.parameter(
+    "nhwc",
+    "nchw-8h8w32c",
+    "nchw-8h8w32c-flat",
+)
+output_layout = tvm.testing.parameter(
+    "nhwc",
+    "nchw-8h8w32c",
+    "nchw-8h8w32c-flat",
+)
+working_scope = tvm.testing.parameter(
+    "global",
+    "global.vtcm",
+)
+
+
+@tvm.testing.fixture
+def target_host(target):
+    target = tvm.target.Target(target)
+
+    if target.kind.name == "hexagon":
+        # Shouldn't need to modify the target here; this is a
+        # workaround for now. In the future, the parameter handling
+        # should move from tvm.target to target_kind.cc.
+        target = tvm.target.hexagon("v68", link_params=True)
+        host = target
+    else:
+        host = None
+    return tvm.target.Target(target, host=host)
+
+
+@tvm.testing.fixture
+def input_shape(batch_size, input_channels, input_image_shape):
+    return [batch_size, *input_image_shape, input_channels]
+
+
+def transform_shape(shape, layout):
+    if layout == "nhwc":
+        return shape
+    elif layout in ["nchw-8h8w32c", "nchw-8h8w32c-flat"]:
+        N, H, W, C = shape
+        return [N, (C + 31) // 32, (H + 7) // 8, (W + 7) // 8, 8, 8, 32]
+    else:
+        raise RuntimeError(f"Unexpected layout '{layout}'")
+
+
+@tvm.testing.fixture
+def transformed_input_shape(input_shape, input_layout):
+    return transform_shape(input_shape, input_layout)
+
+
+@tvm.testing.fixture
+def transformed_output_shape(output_shape, output_layout):
+    return transform_shape(output_shape, output_layout)
+
+
+@tvm.testing.fixture
+def input_np(input_shape, dtype):
+    return (100 * np.random.uniform(size=input_shape)).astype(dtype)
+
+
+def layout_transform_1d(n, h, w, c):
+    return [
+        n,
+        c // 32,
+        h // 8,
+        w // 8,
+        h % 8,
+        w % 8,
+        c % 32,
+    ]
+
+
+def layout_transform_2d(n, h, w, c):
+    return [
+        n,
+        c // 32,
+        h // 8,
+        w // 8,
+        te.AXIS_SEPARATOR,
+        h % 8,
+        w % 8,
+        c % 32,
+    ]
+
+
+def extract_buffers(stmt):
+    buffers = []
+
+    def visitor(node):
+        if isinstance(node, (tvm.tir.BufferLoad, tvm.tir.BufferStore, tvm.tir.BufferRealize)):
+            buffers.append(node.buffer)
+
+    post_order_visit(stmt, visitor)
+    return buffers
+
+
+class TestElementWise:
+    @tvm.testing.fixture
+    def output_np(self, input_np):
+        return 2 * input_np
+
+    @tvm.testing.fixture
+    def output_shape(self, input_shape):
+        return input_shape
+
+    @tvm.testing.fixture
+    def schedule_args(
+        self,
+        input_shape,
+        dtype,
+        input_layout,
+        output_layout,
+        working_layout,
+        working_scope,
+    ):
+        InputTensor = te.placeholder(input_shape, dtype, name="Input")
+        OutputTensor = te.compute(
+            shape=InputTensor.shape,
+            fcompute=lambda *indices: 2 * InputTensor[indices],
+            name="Output",
+        )
+        schedule = te.create_schedule(OutputTensor.op)
+
+        WriteCache = schedule.cache_write(OutputTensor, working_scope)
+        ReadCache = schedule.cache_read(InputTensor, working_scope, [WriteCache])
+
+        def apply_transform(tensor, layout):
+            if layout == "nhwc":
+                pass
+            elif layout == "nchw-8h8w32c":
+                return schedule[tensor].transform_layout(layout_transform_2d)
+            elif layout == "nchw-8h8w32c-flat":
+                return schedule[tensor].transform_layout(layout_transform_1d)
+            else:
+                raise RuntimeError(f"Unexpected layout '{layout}'")
+
+        apply_transform(InputTensor, input_layout)
+        compute_loopnest = apply_transform(OutputTensor, output_layout) or OutputTensor.op.axis
+        schedule[WriteCache].compute_at(schedule[OutputTensor], compute_loopnest[0])
+
+        apply_transform(ReadCache, working_layout)
+        apply_transform(WriteCache, working_layout)
+
+        return [schedule, [InputTensor, OutputTensor]]
+
+    @tvm.testing.fixture
+    def ir_module(self, schedule_args):
+        # If the two buffers are accessed with the same indices, CSE
+        # will replace them with a Let binding. Since that would make
+        # it harder to test what the transformed indices are, disable
+        # the CSE pass for this test.
+        with tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
+            return tvm.lower(*schedule_args)
+
+    @tvm.testing.fixture
+    def uses_unsupported_physical_dimensions(
+        self, target_host, input_layout, working_layout, output_layout
+    ):
+        uses_2d_memory = "nchw-8h8w32c" in [input_layout, working_layout, output_layout]
+        can_handle_2d_memory = target_host.kind.name == "hexagon"
+
+        return uses_2d_memory and not can_handle_2d_memory
+
+    def test_param_shapes(self, ir_module, transformed_input_shape, transformed_output_shape):
+        func = ir_module["main"]
+        primfunc_input_shape, primfunc_output_shape = [
+            list(func.preflattened_buffer_map[param].shape) for param in func.params
+        ]
+        assert primfunc_input_shape == transformed_input_shape
+        assert primfunc_output_shape == transformed_output_shape
+
+    def test_cache_shape(self, ir_module, input_layout, working_layout, output_layout):
+        func = ir_module["main"]
+        for buffer in extract_buffers(func.body):
+            buffer_layout = {
+                "Input": input_layout,
+                "Input.global": working_layout,
+                "Output.global": working_layout,
+                "Input.global.vtcm": working_layout,
+                "Output.global.vtcm": working_layout,
+                "Output": output_layout,
+            }[buffer.name]
+
+            expected_physical_dimensions = {
+                "nhwc": 1,
+                "nchw-8h8w32c": 2,
+                "nchw-8h8w32c-flat": 1,
+            }[buffer_layout]
+
+            assert len(buffer.shape) == expected_physical_dimensions
+
+    def test_lower(self, schedule_args):
+        return tvm.lower(*schedule_args)
+
+    @requires_hexagon_toolchain
+    def test_build(self, schedule_args, target_host, input_layout, working_layout, output_layout):
+        # contextlib.nullcontext wasn't added until Python 3.7, and
+        # the CI currently runs on Python 3.6. Therefore, use
+        # ExitStack to manage an optional context instead.
+        stack = contextlib.ExitStack()
+
+        with stack:
+            is_hexagon = target_host.kind.name == "hexagon"
+            uses_2d_memory = "nchw-8h8w32c" in [input_layout, working_layout, output_layout]
+            if uses_2d_memory and not is_hexagon:
+                stack.enter_context(pytest.raises(tvm.TVMError))
+
+            tvm.build(*schedule_args, target=target_host)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
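
---

For readers trying out the change, here is a minimal sketch (not part of the patch) of the lowering path it enables, following the same `transform_layout` pattern as the test above. It assumes a TVM build that includes this patch; the tensor names and shapes are illustrative only. `te.AXIS_SEPARATOR` marks where the transformed axes split into separate physical dimensions, so `tvm.lower` (which is target-independent) produces a 2-d physical buffer, while `tvm.build` succeeds only on targets whose codegen overrides `CreateBufferPtr` with 2-d support (here, `CodeGenHexagon`).

```python
import tvm
from tvm import te

# Elementwise doubling, the same pattern as TestElementWise above.
x = te.placeholder((1, 64, 64, 32), "int8", name="X")
y = te.compute(x.shape, lambda *indices: 2 * x[indices], name="Y")
schedule = te.create_schedule(y.op)

# Axes listed before te.AXIS_SEPARATOR are flattened into the first
# physical index; axes after it become the second physical index.
# Dropping the separator gives the usual flat 1-d layout (the
# "-flat" parameterizations in the test).
schedule[y].transform_layout(
    lambda n, h, w, c: [
        n,
        c // 32,
        h // 8,
        w // 8,
        te.AXIS_SEPARATOR,
        h % 8,
        w % 8,
        c % 32,
    ]
)

# Lowering is target-independent; the printed PrimFunc should show a
# 2-d physical shape for Y's buffer.
print(tvm.lower(schedule, [x, y]))

# tvm.build(schedule, [x, y], target="llvm") would raise a TVMError,
# since CodeGenLLVM::CreateBufferPtr rejects non-flat buffers; a
# Hexagon target instead reaches CodeGenHexagon's 2-d override.
```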