Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations.";
PrimExpr index = load->indices[0];
PrimExpr index = analyzer_->Simplify(load->buffer->elem_offset + load->indices[0]);
if (const RampNode* r = index.as<RampNode>()) {
index = r->base;
}
Expand Down Expand Up @@ -1253,15 +1253,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) {
llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) {
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second->value, op->value))
ICHECK(deep_equal_(it->second, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
let_binding_[op->var] = op->value;
}
auto var_value = MakeValue(op->value);
var_map_[op->var.get()] = var_value;
var_value->setName(op->var->name_hint.c_str());
analyzer_->Bind(op->var, op->value);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we need bind here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`Bind` means the analyzer will always expand the bound value expression during simplification and other analyses. That breaks the evaluation order specified by the lets, which should be respected in the codegen phase. The issue can be triggered on existing test cases by the simplification this PR adds.

I can use a local analyzer for this PR's purposes, but I think this is still an issue that needs to be resolved.

return MakeValue(op->body);
}

Expand Down Expand Up @@ -1301,6 +1300,7 @@ void CodeGenLLVM::BufferAccessHelper(
}

PrimExpr last_index = indices[indices.size() - 1];

ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());

bool is_volatile = volatile_buf_.count(buffer->data.get());
Expand Down Expand Up @@ -1355,8 +1355,15 @@ void CodeGenLLVM::BufferAccessHelper(
std::vector<llvm::Value*> all_index_values = earlier_index_values;
all_index_values.push_back(last_index_value);

llvm::Value* buffer_data_begin = MakeValue(buffer->data);
if (!analyzer_->CanProveEqual(buffer->elem_offset, 0)) {
buffer_data_begin = CreateBufferPtr(buffer_data_begin, buffer_element_dtype,
{MakeValue(buffer->elem_offset)}, buffer_element_dtype)
.addr;
}

TypedPointer buffer_ptr =
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
CreateBufferPtr(buffer_data_begin, buffer_element_dtype, all_index_values,
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), last_index);
Expand Down Expand Up @@ -1632,7 +1639,6 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) {
llvm::Value* value = MakeValue(op->value);
value->setName(v->name_hint.c_str());
var_map_[v] = value;
analyzer_->Bind(op->var, op->value);
if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) {
builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v),
alloc_storage_info_[v].alignment);
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A local expression (e.g., the result of a simplify call) can be visited, and some pointers to it are recorded. But when the local scope ends, the backing object expires.

// Cache potential common path ops to slightly improve lookup time.
// global symbol table.
OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol");
Expand Down
13 changes: 7 additions & 6 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,8 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations.";
os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))";
PrimExpr index = analyzer_.Simplify(load->buffer->elem_offset + load->indices[0]);
os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), index) << "))";
} else if (op->op.same_as(builtin::tvm_struct_get())) {
ICHECK_EQ(op->args.size(), 3U);
os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as<IntImmNode>()->value);
Expand Down Expand Up @@ -669,7 +670,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
ICHECK_EQ(op->indices.size(), 1) << "Load from non-flat memory not supported.";

DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
PrimExpr index = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]);
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;

Expand All @@ -684,7 +685,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
const RampNode* ramp = index.as<RampNode>();
ICHECK(ramp);
arith::ModularSet me = arith::Analyzer().modular_set(ramp->base);
arith::ModularSet me = analyzer_.modular_set(ramp->base);
// The condition: {k * coeff + base} divisible by the alignment for any k
if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) {
can_vector_load = true;
Expand Down Expand Up @@ -733,7 +734,7 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {

DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index_expr = op->indices[0];
PrimExpr index_expr = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]);
Var buffer_var = op->buffer->data;

if (value_dtype.lanes() == element_dtype.lanes()) {
Expand Down Expand Up @@ -786,10 +787,10 @@ void CodeGenC::VisitStmt_(const BufferStoreNode* op) {
void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second->value, op->value))
ICHECK(deep_equal_(it->second, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
let_binding_[op->var] = op->value;
}
std::string value = PrintExpr(op->value);
var_idmap_[op->var.get()] = value;
Expand Down
5 changes: 4 additions & 1 deletion src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/op.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/analysis.h>
Expand Down Expand Up @@ -280,7 +281,9 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
// deep comparison of PrimExpr
ExprDeepEqual deep_equal_;
// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// analyzer
arith::Analyzer analyzer_;
};

} // namespace codegen
Expand Down
6 changes: 4 additions & 2 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,17 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) {
// Overload tvm_address_of to add storage scope (e.g. __global).
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 0) << "CodeGenOpenCL only supports flat memory allocations.";
ICHECK_EQ(load->indices.size(), 1) << "CodeGenOpenCL only supports flat memory allocations.";
PrimExpr index = analyzer_.Simplify(load->buffer->elem_offset + load->indices[0]);

os << "((";
auto it = alloc_storage_scope_.find(load->buffer->data.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
this->PrintType(load->dtype.element_of(), os);
os << " *)" << this->GetVarID(load->buffer->data.get()) << " + ";
this->PrintExpr(load->indices[0], os);
this->PrintExpr(index, os);
os << ')';
} else if (op->op.same_as(builtin::texture2d_store())) {
auto* ptr_type = op->args[0].as<VarNode>()->type_annotation.as<PointerTypeNode>();
Expand Down
3 changes: 3 additions & 0 deletions src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
#define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_

#include <tvm/arith/analyzer.h>
#include <tvm/target/codegen.h>

#include <string>
Expand Down Expand Up @@ -85,6 +86,8 @@ class CodeGenOpenCL final : public CodeGenC {
// Mapping from buffer to allocation size.
// Useful to track when a scalar store of a vectorized texture load is required.
std::unordered_map<const Object*, size_t> allocation_size_;
// analyzer
arith::Analyzer analyzer_;
};

} // namespace codegen
Expand Down
10 changes: 4 additions & 6 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second->value, op->value))
ICHECK(deep_equal_(it->second, op->value))
<< "Let cannot bind the same var to two different values";
} else {
let_binding_[op->var] = op;
let_binding_[op->var] = op->value;
}
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
return MakeValue(op->body);
}

Expand Down Expand Up @@ -428,7 +427,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
spirv::Value CodeGenSPIRV::VisitExpr_(const BufferLoadNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
Var buffer_var = op->buffer->data;
PrimExpr prim_index = op->indices[0];
PrimExpr prim_index = analyzer_->Simplify(op->buffer->elem_offset + op->indices[0]);

DataType desired_read_type = op->dtype;
if (desired_read_type == DataType::Bool()) {
Expand Down Expand Up @@ -501,7 +500,7 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv:
void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "SPIR-V codegen expects flat memory buffers";
Var buffer_var = op->buffer->data;
PrimExpr prim_index = op->indices[0];
PrimExpr prim_index = analyzer_->Simplify(op->buffer->elem_offset + op->indices[0]);

auto it = storage_info_.find(buffer_var.get());
ICHECK(it != storage_info_.end());
Expand Down Expand Up @@ -716,7 +715,6 @@ void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
ICHECK(!var_map_.count(op->var.get()));
ICHECK(!op->var.dtype().is_handle());
var_map_[op->var.get()] = MakeValue(op->value);
analyzer_->Bind(op->var, op->value);
this->VisitStmt(op->body);
}

Expand Down
2 changes: 1 addition & 1 deletion src/target/spirv/codegen_spirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
ExprDeepEqual deep_equal_;

// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;

// Running total of the number of bytes of shared memory used.
// Checked against the max_shared_memory_per_group
Expand Down
9 changes: 5 additions & 4 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. "
<< "Has StorageFlatten (TE-based schedules) or "
<< "FlattenBuffer (TIR-based schedules) been run?";
auto index = op->indices[0];
auto index = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]);

this->Push(op->buffer->data);
StackVM::OpCode code = StackVM::GetLoad(op->dtype);
Expand All @@ -170,7 +170,7 @@ void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. "
<< "Has StorageFlatten (TE-based schedules) or "
<< "FlattenBuffer (TIR-based schedules) been run?";
auto index = op->indices[0];
auto index = analyzer_.Simplify(op->buffer->elem_offset + op->indices[0]);

this->Push(op->buffer->data);
StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
Expand All @@ -195,10 +195,11 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
if (op->op.same_as(builtin::address_of())) {
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
ICHECK(op->args.size() == 1 && load);
ICHECK_EQ(load->indices.size(), 0) << "CodeGenStackVM only supports flat memory allocations.";
ICHECK_EQ(load->indices.size(), 1) << "CodeGenStackVM only supports flat memory allocations.";
PrimExpr index = analyzer_.Simplify(load->buffer->elem_offset + load->indices[0]);

this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get()));
this->Push(load->indices[0]);
this->Push(index);
this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes());
this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD);
Expand Down
3 changes: 3 additions & 0 deletions src/target/stackvm/codegen_stackvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
#define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_

#include <tvm/arith/analyzer.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
Expand Down Expand Up @@ -156,6 +157,8 @@ class CodeGenStackVM : public ExprFunctor<void(const PrimExpr&)>,
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each global function */
std::unordered_map<std::string, int> extern_fun_idmap_;
/*! \brief analyzer */
arith::Analyzer analyzer_;

Op backend_alloc_workspace_op_ = Op::Get("tir.TVMBackendAllocWorkspace");
Op backend_free_workspace_op_ = Op::Get("tir.TVMBackendFreeWorkspace");
Expand Down
50 changes: 47 additions & 3 deletions src/tir/transforms/narrow_datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,16 @@ class DataTypeVisitor final : public StmtExprVisitor {
StmtExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const BufferLoadNode* op) {
VisitExpr(op->buffer->elem_offset);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) {
VisitExpr(op->buffer->elem_offset);
StmtExprVisitor::VisitStmt_(op);
}

// the narrowed datatype of Var and IntImm
std::unordered_map<const PrimExprNode*, DataType> vmap;

Expand Down Expand Up @@ -217,11 +227,13 @@ class DataTypeRewriter : public StmtExprMutator {
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = GetRef<BufferStore>(op);

auto buffer = VisitBuffer(op->buffer);
auto value = this->VisitExpr(op->value);
auto indices = VisitIndices(op->indices);

if (!value.same_as(op->value) || !indices.same_as(op->indices)) {
if (!buffer.same_as(op->buffer) || !value.same_as(op->value) || !indices.same_as(op->indices)) {
auto writer = store.CopyOnWrite();
writer->buffer = buffer;
writer->value = value;
writer->indices = indices;
}
Expand All @@ -232,10 +244,12 @@ class DataTypeRewriter : public StmtExprMutator {
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad load = GetRef<BufferLoad>(op);

auto buffer = VisitBuffer(op->buffer);
auto indices = VisitIndices(op->indices);

if (!indices.same_as(op->indices)) {
if (!buffer.same_as(op->buffer) || !indices.same_as(op->indices)) {
auto writer = load.CopyOnWrite();
writer->buffer = buffer;
writer->indices = indices;
}

Expand All @@ -253,6 +267,23 @@ class DataTypeRewriter : public StmtExprMutator {
return indices;
}

// Returns the buffer to use in place of `origin_buffer`.
//
// The buffer's elem_offset is visited with is_index_ set; if the visited
// expression differs from the original, the returned buffer is a copy carrying
// the new elem_offset, otherwise it is `origin_buffer` itself. The result is
// stored in buffer_remap_ so later lookups of the same buffer return the same
// replacement.
Buffer VisitBuffer(const Buffer& origin_buffer) {
  // Reuse the replacement computed by an earlier call, if there is one.
  const auto cached = buffer_remap_.find(origin_buffer);
  if (cached != buffer_remap_.end()) {
    return cached->second;
  }

  // Visit the offset expression with the index flag raised, then lower it.
  is_index_ = true;
  const PrimExpr new_offset = VisitExpr(origin_buffer->elem_offset);
  is_index_ = false;

  Buffer result = origin_buffer;
  const bool offset_changed = !new_offset.same_as(origin_buffer->elem_offset);
  if (offset_changed) {
    result.CopyOnWrite()->elem_offset = new_offset;
  }

  // Record the mapping (also when nothing changed) for subsequent lookups.
  buffer_remap_.emplace(origin_buffer, result);
  return result;
}

Stmt VisitStmt_(const ForNode* op) final {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<ForNode>();
Expand Down Expand Up @@ -375,6 +406,9 @@ class DataTypeRewriter : public StmtExprMutator {
PrimExpr VisitExpr_(const GENode* op) final;
PrimExpr VisitExpr_(const CallNode* op) final;

// buffer remapping dict
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

private:
// the internal visitor to deduce the narrowed dtype
DataTypeVisitor visitor_;
Expand Down Expand Up @@ -449,7 +483,17 @@ namespace transform {
Pass NarrowDataType(int target_bits) {
auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = DataTypeRewriter(target_bits)(std::move(n->body));
DataTypeRewriter rewriter(target_bits);
n->body = rewriter(std::move(n->body));
for (const Var& param : f->params) {
auto it = n->buffer_map.find(param);
if (it != n->buffer_map.end()) {
auto remap_it = rewriter.buffer_remap_.find((*it).second);
if (remap_it != rewriter.buffer_remap_.end()) {
n->buffer_map.Set(param, remap_it->second);
}
}
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {});
Expand Down
Loading