Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class GraphCreator : public ExprVisitor {
ICHECK_NOTNULL(binding_var_node);

static const Op& call_tir_op_ = Op::Get("relax.call_tir");
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");

OpPatternKind pattern = OpPatternKind::kOpaque;
Array<Expr> args = call->args;

Expand All @@ -191,7 +193,7 @@ class GraphCreator : public ExprVisitor {
// - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we
// recurse into the call expression.
const auto* op = call->op.as<OpNode>();
if (op == call_tir_op_.get()) {
if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) {
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

Expand Down Expand Up @@ -377,7 +379,8 @@ class FunctionCreator : public ExprMutator {
* function accordingly
* \param binding The binding to be appended
* \note Allowed bindings are:
* - VarBinding with value being a call node calling `relax.call_tir`.
* - VarBinding with value being a call node calling `relax.call_tir` or
* `relax.call_tir_inplace`.
* - VarBinding with value being a tuple-get-item node.
* // TODO(tvm-team): handle match shape
*/
Expand All @@ -387,7 +390,8 @@ class FunctionCreator : public ExprMutator {

if (const auto* var_binding = binding.as<VarBindingNode>()) {
if (const auto* call = var_binding->value.as<CallNode>()) {
if (call->op == Op::Get("relax.call_tir")) {
if (call->op == Op::Get("relax.call_tir") ||
call->op == Op::Get("relax.call_tir_inplace")) {
// Update the name of the function.
name_hint_ = name_hint_ + "_" + Downcast<GlobalVar>(call->args[0])->name_hint;

Expand Down
158 changes: 128 additions & 30 deletions src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
Expand Down Expand Up @@ -367,17 +368,22 @@ class FusedTIRConstructor : public ExprVisitor {
* \brief Construct a fused TIR PrimFunc from a relax sub-function
* \param mod The IRModule
* \param gv The global var of relax subfunction to be fused into one PrimFunc
* \return The fused TIR PrimFunc
* \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call)
*/
static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
static std::pair<tir::PrimFunc, Array<Integer>> GetFusedTIR(const IRModule& mod,
const GlobalVar& gv) {
FusedTIRConstructor visitor(mod, gv->name_hint);
BaseFunc f = mod->Lookup(gv);
CHECK(f->IsInstance<relax::FunctionNode>())
<< "Expected relax functions, but got: " << f->GetTypeKey();
CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
<< "Expected a function with attr `kPrimitive`";
visitor(Downcast<relax::Function>(f));
return visitor.fused_tir_;
Array<Integer> inplace_indices;
for (size_t idx : visitor.inplace_indices_) {
inplace_indices.push_back(Integer(idx));
}
return {visitor.fused_tir_, inplace_indices};
}

private:
Expand Down Expand Up @@ -438,9 +444,38 @@ class FusedTIRConstructor : public ExprVisitor {
auto it = func_info_.expr2buffers.find(body);
ICHECK(it != func_info_.expr2buffers.end())
<< "Fail to detect output buffers for function body";

const Array<tir::Buffer>& buffers = (*it).second;

// map of input buffers to indices (helpful for detecting in-place inputs)
std::unordered_map<tir::Buffer, size_t, ObjectPtrHash, ObjectPtrEqual> buffer_to_idx;
std::unordered_map<tir::Var, size_t, ObjectPtrHash, ObjectPtrEqual> input_to_idx;
for (size_t i = 0; i < func_info_.params.size(); i++) {
input_to_idx[func_info_.params[i]] = i;
}
for (auto [var, buffer] : func_info_.buffer_map) {
if (auto it = input_to_idx.find(var); it != input_to_idx.end()) {
buffer_to_idx[buffer] = (*it).second;
}
}

// numbered separately because the number of output *vars* might differ from the
// number of outputs if there are in-place inputs
int out_idx = 0;
for (size_t i = 0; i < buffers.size(); ++i) {
tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle()));
// Do not add output vars for in-place inputs
// (i.e., already listed in the buffer map. This would result
// in duplicates in the buffer map otherwise)
if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) {
auto idx = (*it).second;
CHECK(!inplace_indices_.count(idx))
<< "In-place index " << idx << " used twice! An argument must be aliased.";
inplace_indices_.insert(idx);
continue;
}

tir::Var param = tir::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle()));
out_idx++;
func_info_.buffer_map.Set(param, buffers[i]);
func_info_.params.push_back(param);
func_info_.output_buffers.insert(buffers[i].get());
Expand Down Expand Up @@ -476,8 +511,11 @@ class FusedTIRConstructor : public ExprVisitor {
void VisitExpr_(const CallNode* call) final {
ExprVisitor::VisitExpr_(call);
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
ICHECK(call->op == call_tir_op_)
<< "Only call_tir is supported in primitive function, but got: " << GetRef<Expr>(call);
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");

ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
<< "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
<< GetRef<Expr>(call);

// Step 1. Get Global var and PrimFunc
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
Expand All @@ -503,7 +541,7 @@ class FusedTIRConstructor : public ExprVisitor {
MapInputBuffer(prim_func, call->args[1]);
const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes(call);

AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, output_buffer_shapes);
AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes);

// Step 6. Update tir_vars
if (call->args.size() > 2) {
Expand Down Expand Up @@ -566,7 +604,8 @@ class FusedTIRConstructor : public ExprVisitor {
*/
static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
ICHECK(call->op.same_as(call_tir_op_));
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_));
ICHECK_EQ(call->sinfo_args.size(), 1);
auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
Expand Down Expand Up @@ -611,7 +650,7 @@ class FusedTIRConstructor : public ExprVisitor {
}
}
}
// Make sure every buffers are mapped.
// Make sure every buffer is mapped.
ICHECK_EQ(buffer_idx, buffers.size());
}

Expand Down Expand Up @@ -639,28 +678,49 @@ class FusedTIRConstructor : public ExprVisitor {
MapArgsToBuffer(arg_list, buffer_list);
}

static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func, size_t output_size) {
/*!
 * \brief Resolve the `inplace_indices` of an in-place call into concrete positions.
 *
 * Each non-negative entry is kept unchanged. Each negative entry is replaced by the
 * next unused position, counting upwards from `num_inputs` in the order the negative
 * entries appear.
 *
 * \param inplace_indices One entry per output; non-negative entries are passed through,
 *        negative entries are renumbered.
 * \param num_inputs The first position handed out to a negative entry.
 * \return An array with the same length as `inplace_indices` holding the resolved positions.
 * \note NOTE(review): the result is presumably used to index the callee PrimFunc's
 *       parameter list (inputs first, then freshly allocated outputs) -- confirm against
 *       the caller.
 */
static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
                                              int num_inputs) {
  Array<Integer> resolved;
  // Next position to hand out to an entry that does not name an input.
  int next_free = num_inputs;
  for (auto entry : inplace_indices) {
    const int requested = entry.IntValue();
    const int position = (requested >= 0) ? requested : next_free++;
    resolved.push_back(Integer(position));
  }
  return resolved;
}

static Array<tir::Var> GetPrimFuncOutputParams(const tir::PrimFunc& func,
const Array<Integer>& output_indices) {
size_t n = func->params.size();
int symbolic_var_index = -1;
size_t output_size = output_indices.size();
ICHECK_GE(n, output_size);
for (size_t i = 0; i < n; ++i) {
const tir::Var& param = func->params[i];

Array<tir::Var> ret;
for (auto idx : output_indices) {
int i = idx.IntValue();
const tir::Var& param = func->params[static_cast<size_t>(i)];
if (param->dtype.is_int() || param->dtype.is_uint()) {
if (symbolic_var_index == -1) symbolic_var_index = i;
} else if (param->dtype.is_handle()) {
CHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the "
"parameter list.";
ret.push_back(param);
} else {
LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: "
<< param->dtype;
}
}

size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index;
ICHECK_GE(end_index, output_size);
size_t begin_index = end_index - output_size;
Array<tir::Var> output_params{func->params.begin() + begin_index,
func->params.begin() + end_index};
return output_params;
return ret;
}

/*!
Expand All @@ -670,18 +730,39 @@ class FusedTIRConstructor : public ExprVisitor {
* \param func The old TIR PrimFunc
* \param output_shapes The shape of output params.
*/
void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func,
const Array<Array<PrimExpr>>& output_shapes) {
bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace"));

size_t n = func->params.size();
int num_inputs = Downcast<Tuple>(call->args[1])->fields.size();
size_t output_size = output_shapes.size();
ICHECK_GE(n, output_size);
// Allocate intermediate buffer
Array<tir::Buffer> alloc_buffers;
Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_size);
Array<tir::Buffer> output_buffers;
Array<Integer> output_idxs;
if (is_inplace) {
const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices, num_inputs));
} else {
for (size_t i = 0; i < output_size; i++) {
output_idxs.push_back(num_inputs + i);
}
}

Array<tir::Var> output_params = GetPrimFuncOutputParams(func, output_idxs);
auto input_buffers = func_info_.expr2buffers.Get(call->args[1]);
for (size_t i = 0; i < output_size; ++i) {
const tir::Var& param = output_params[i];
const tir::Buffer& buffer = func->buffer_map.at(param);

// if this is an inplace output, do not do an intermediate allocation
if (output_idxs[i].IntValue() < num_inputs) {
CHECK(input_buffers.defined()) << "Inplace functions must have some defined input";
output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]);
continue;
}

auto unify_name_hints = [this, &buffer]() {
String base_name = buffer->name;
String unique_name = base_name + "_intermediate";
Expand All @@ -703,14 +784,14 @@ class FusedTIRConstructor : public ExprVisitor {
n->name = unify_name_hints();
tir::Buffer new_buffer(n);
func_info_.alloc_buffers.push_back(new_buffer);
alloc_buffers.push_back(new_buffer);
output_buffers.push_back(new_buffer);

// Match the shape of the output buffer with the shape
func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
func_info_.buffer_subst_map.Set(buffer, new_buffer);
}
// Update expr2buffers
func_info_.expr2buffers.Set(expr, alloc_buffers);
func_info_.expr2buffers.Set(GetRef<Expr>(call), output_buffers);
}

/*!
Expand Down Expand Up @@ -858,6 +939,8 @@ class FusedTIRConstructor : public ExprVisitor {
FuseFuncInfo func_info_;
/*! \brief The tir function after fusion*/
tir::PrimFunc fused_tir_;
/*! \brief Indices of inputs that are used for in-place computation */
std::unordered_set<size_t> inplace_indices_;
};

std::vector<size_t> GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) {
Expand Down Expand Up @@ -897,8 +980,11 @@ class TIRFuseMutator : public ExprMutator {
for (const auto& [gv, func] : mod->functions) {
// Only fuse primitive relax functions
if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
mutator.fused_tir_funcs_.Set(gv, fused_tir);
const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv);
mutator.fused_tir_funcs_.Set(gv, prim_func);
if (!indices.empty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: Using if (indices.size()) instead of if (!indices.empty()) would avoid a double negative and make the condition easier to read. (Though this is personal preference, as if (indices.size()) relies on the conversion of a non-zero size_t to true.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of subjective, I think "not empty" is readable (sounds more like how it would be phrased verbally). I'll change it if you prefer it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably more personal preference, so either works.

mutator.inplace_indices_.Set(gv, indices);
}
}
}

Expand Down Expand Up @@ -945,6 +1031,7 @@ class TIRFuseMutator : public ExprMutator {

Expr VisitExpr_(const CallNode* op) final {
static const Op& call_tir_op_ = Op::Get("relax.call_tir");
static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");

Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));

Expand Down Expand Up @@ -985,26 +1072,34 @@ class TIRFuseMutator : public ExprMutator {
CHECK(prim_value->value.defined())
<< "FuseTIR requires all R.Prim arguments to have a known value.";
PrimExpr expr = prim_value->value.value();
CHECK(expr->IsInstance<tir::VarNode>())
<< "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var.";
CHECK(expr->IsInstance<tir::VarNode>()) << "FuseTIR currently requires all R.Prim "
"arguments to provide a single tir::Var.";
tir_vars.push_back(expr);

} else {
arg_list.push_back(arg);
}
}
// Step b. Create call_tir
// Step b. Create call_tir or call_tir_inplace
Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
if (!tir_vars.empty()) {
call_args.push_back(ShapeExpr(tir_vars));
}
return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
Op call_op = call_tir_op_;
Attrs call_attrs = call->attrs;
if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) {
call_op = call_tir_inplace_op_;
auto inplace_attrs = make_object<CallTIRInplaceAttrs>();
inplace_attrs->inplace_indices = (*it).second;
call_attrs = Attrs(inplace_attrs);
}
return Call(call_op, call_args, call_attrs, {GetStructInfo(call)});
} else {
// Case 1.2. The callee function is not primitive, nothing to do.
return call;
}
} else if (call->op == call_tir_op_) {
// Case 2. It is a call_tir, re-emit the PrimFunc.
} else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) {
// Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc.
if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
Expand All @@ -1023,6 +1118,9 @@ class TIRFuseMutator : public ExprMutator {
const IRModule& mod_;
/*! \brief The map from global var of primitive relax function to generated prim func. */
Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
/*! \brief The map from global var of primitive relax function to in-place indices
* (if there are any). */
Map<GlobalVar, Array<Integer>> inplace_indices_;
};

IRModule FuseTIR(IRModule mod) {
Expand Down
Loading