206 changes: 195 additions & 11 deletions src/relax/transform/fuse_ops.cc
@@ -38,9 +38,11 @@

#include <optional>
#include <unordered_map>
#include <vector>

#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
#include "../../support/ordered_set.h"
#include "tvm/relax/expr.h"
#include "utils.h"

Expand Down Expand Up @@ -360,6 +362,169 @@ class GraphCreator : public ExprVisitor {
std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
};

class InferredCommonSubexpressionCollector : relax::ExprVisitor,
StructInfoVisitor,
tir::ExprVisitor {
public:
struct InferResult {
// A list of additional symbolic variables that must be provided
// to the function. These variables cannot be inferred from the
// StructInfo of the existing parameters.
Array<tir::Var> symbolic_vars;

// A list of expressions, each of which must be remapped to a new
// symbolic variable. These expressions can be inferred from the
// StructInfo of the existing parameters, but may contain
// sub-expressions that cannot.
Array<PrimExpr> symbolic_expressions;
};

static InferResult Infer(Array<Var> params, Expr body) {
InferredCommonSubexpressionCollector collector;
collector.VisitStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
collector.phase_ = Phase::CollectRequiredExpressions;
collector.VisitExpr(body);

return InferResult{
Array<tir::Var>(collector.required_symbolic_vars_.begin(),
collector.required_symbolic_vars_.end()),
Array<PrimExpr>(collector.required_symbolic_exprs_.begin(),
collector.required_symbolic_exprs_.end()),
};
}

private:
using relax::ExprVisitor::VisitExpr;
using relax::ExprVisitor::VisitExpr_;
using tir::ExprVisitor::VisitExpr;
using tir::ExprVisitor::VisitExpr_;

void VisitExprDepStructInfoField(const StructInfo& struct_info) override {
VisitStructInfo(struct_info);
}
void VisitStructInfoExprField(const Expr& expr) override { VisitStructInfo(GetStructInfo(expr)); }
void VisitStructInfoExprField(const PrimExpr& expr) override {
if (expr->IsInstance<IntImmNode>()) {
return;
}

switch (phase_) {
case Phase::CollectInferableExpressions:
inferable_expressions_.insert(expr);
break;

case Phase::CollectRequiredExpressions:
VisitExpr(expr);
break;

default:
LOG(FATAL) << "Invalid value for Phase: " << static_cast<int>(phase_);
break;
}
}

void VisitExpr(const PrimExpr& expr) override {
if (inferable_expressions_.count(expr)) {
required_symbolic_exprs_.insert(expr);
} else {
tir::ExprVisitor::VisitExpr(expr);
}
}

void VisitExpr_(const tir::VarNode* op) override {
required_symbolic_vars_.push_back(GetRef<tir::Var>(op));
}
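
// The collection proceeds in two phases over the same visitor.  The
// first phase (CollectInferableExpressions) scans the StructInfo of
// the function parameters, recording every non-constant shape
// expression as inferable.  The second phase
// (CollectRequiredExpressions) scans the function body: an expression
// already known to be inferable is collected as a single unit, while
// anything else is recursed into until the underlying tir::Vars are
// reached.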

enum class Phase {
CollectInferableExpressions,
CollectRequiredExpressions,
};
Phase phase_ = Phase::CollectInferableExpressions;
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> inferable_expressions_;
support::OrderedSet<tir::Var> required_symbolic_vars_;
support::OrderedSet<PrimExpr, StructuralHash, StructuralEqual> required_symbolic_exprs_;
};
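
// An illustrative sketch (hypothetical shapes, not taken from this
// PR's tests): for a parameter `x` of shape `[n + 1, 1024]` and a
// body whose StructInfo mentions both `n + 1` and `n`, the first
// phase records `n + 1` as inferable, and `Infer` returns
//
//   InferResult{
//       /*symbolic_vars=*/{n},             // not inferable from `x`
//       /*symbolic_expressions=*/{n + 1},  // inferable, but contains `n`
//   };
//
// Because the collector does not recurse into an expression once it
// is known to be inferable, `n` is reported as a required variable
// only because the body also uses it on its own.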

/*! \brief Replace occurrences of a PrimExpr within symbolic shape expressions
*
* In most cases, the `tvm::relax::Bind` utility should be used
* instead. Here, though, we are replacing a `PrimExpr` with a
* `tir::Var`, whereas `tvm::relax::Bind` supports the more standard
* case of replacing a `tir::Var` with a `PrimExpr`.
*/
class SymbolicSubexprReplacer : relax::ExprMutator, StructInfoMutator, tir::ExprMutator {
public:
/*! \brief Replace occurrences of a PrimExpr within symbolic shape expressions
*
* In most cases, the `tvm::relax::Bind` utility should be used
* instead. Here, though, we are replacing a `PrimExpr` with a
* `tir::Var`, rather than the other way around.
*
* \param relax_expr The expression in which to replace symbolic expressions
*
* \param symbolic_exprs A list of expressions, each of which should
* be replaced with a new symbolic variable. This is provided as a
* list, rather than as a replacement map, to allow context-dependent
* names to be generated for these expressions.
*
* \returns The updated relax expression.
*/
static Expr Replace(const Expr& relax_expr, Array<PrimExpr> symbolic_exprs) {
std::unordered_map<PrimExpr, Optional<tir::Var>, StructuralHash, StructuralEqual> replacements;
for (const auto& expr : symbolic_exprs) {
replacements.insert({expr, NullOpt});
}

SymbolicSubexprReplacer mutator(replacements);
return mutator(relax_expr);
}

private:
using relax::ExprMutator::operator();
using relax::ExprMutator::VisitExpr;
using tir::ExprMutator::operator();
using tir::ExprMutator::VisitExpr;

SymbolicSubexprReplacer(
std::unordered_map<PrimExpr, Optional<tir::Var>, StructuralHash, StructuralEqual>
replacements)
: replacements_(replacements) {}

StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
return VisitStructInfo(struct_info);
}
Expr VisitStructInfoExprField(const Expr& expr) override { return VisitExpr(expr); }
PrimExpr VisitStructInfoExprField(const PrimExpr& expr) override { return VisitExpr(expr); }
PrimExpr VisitPrimExpr(const PrimExpr& expr) override { return VisitExpr(expr); }

PrimExpr VisitExpr(const PrimExpr& expr) override {
if (auto replacement = GetReplacement(expr)) {
return replacement.value();
} else {
return tir::ExprMutator::VisitExpr(expr);
}
}

Optional<tir::Var> GetReplacement(const PrimExpr& expr) {
auto it = replacements_.find(expr);
if (it == replacements_.end()) {
return NullOpt;
}

Optional<tir::Var>& opt_var = it->second;
if (!opt_var.defined()) {
// Ideally, this path would never be reached, as it doesn't
// provide as much context in the variable name. However, it's
// useful as a fallback.
opt_var = tir::Var("fused_expr", expr->dtype);
}

return opt_var.value();
}

std::unordered_map<PrimExpr, Optional<tir::Var>, StructuralHash, StructuralEqual> replacements_;
};
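
// A minimal usage sketch (hypothetical; assume `fused` is a
// relax::Function whose parameter has shape `[n + 1, 1024]`):
//
//   Array<PrimExpr> exprs = {n + 1};
//   auto replaced =
//       Downcast<Function>(SymbolicSubexprReplacer::Replace(fused, exprs));
//
// Afterwards, every structural occurrence of `n + 1` within `replaced`
// is a single fresh tir::Var, so the function's shapes no longer
// mention `n` at all.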

/*!
* \brief The ExprMutator used to create a new grouped function
* \details The workflow of this ExprMutator is:
@@ -533,25 +698,44 @@ class FunctionCreator : public ExprMutator {
function_ = NullOpt;
} else {
Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
- body = SeqExpr({new_block}, body);
- body = builder_->Normalize(body);
body = builder_->Normalize(SeqExpr({new_block}, body));

// Any symbolic variables that are required within the body of
// the function, but cannot be inferred from the parameters of
// the function, must be exposed using an additional argument.
auto [symbolic_vars, symbolic_expressions] =
InferredCommonSubexpressionCollector::Infer(params_, body);
if (symbolic_vars.size()) {
auto symbolic_vars_as_expr =
symbolic_vars.Map([](tir::Var var) -> PrimExpr { return var; });
params_.push_back(Var("tir_vars", ShapeStructInfo(symbolic_vars_as_expr)));
arguments_.push_back(ShapeExpr(symbolic_vars_as_expr));
}
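
// For example (hypothetical): if the body requires `batch_size`, but
// no parameter's StructInfo provides it, the fused function gains a
// trailing parameter `tir_vars: R.Shape([batch_size])`, and each call
// site passes the matching `ShapeExpr` as an extra argument.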

group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
Function function = Function(/*params=*/params_, //
/*body=*/body, //
/*ret_struct_info=*/NullOpt, //
/*is_pure=*/true, //
/*attrs=*/DictAttrs(group_attrs));
- Array<PrimExpr> free_vars =
- FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; });
- if (!free_vars.empty()) {
- params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars)));
- arguments_.push_back(ShapeExpr(free_vars));
- function = Function(/*params=*/params_, //
- /*body=*/body, //
- /*ret_struct_info=*/NullOpt, //
- /*is_pure=*/true, //
- /*attrs=*/DictAttrs(group_attrs));
- }

// If the function contains symbolic expressions that can be
// inferred from the parameters, but contain subexpressions that
// cannot be inferred from the parameters, those expressions
// should be replaced with symbolic variables.
//
// For example, suppose a fused function maps from a tensor of
// shape `[batch_size+1, 1024]` to `[batch_size+1, 1024]`. It
// cannot infer `batch_size`, but can infer the value of
// `batch_size+1`. By introducing `batch_size_plus_one =
// batch_size+1`, we can rely on just the inferable symbolic
// vars.
if (symbolic_expressions.size()) {
function =
Downcast<Function>(SymbolicSubexprReplacer::Replace(function, symbolic_expressions));
}

function_ = SymbolicVarRenewMutator::Renew(function);
}
}
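
Taken together, the two steps above change the interface of the fused function. A hedged sketch of the net effect, reusing the `batch_size+1` example from the comments (signatures illustrative, in TVMScript notation):

// Before this change, `batch_size` was free within the fused function,
// so it had to be passed through a trailing shape parameter:
//   fused(x: R.Tensor([batch_size + 1, 1024]),
//         tir_vars: R.Shape([batch_size]))
// After this change, `batch_size + 1` is bound to a fresh variable,
// making every symbolic value inferable from the tensor alone:
//   fused(x: R.Tensor([batch_size_plus_one, 1024]))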
25 changes: 20 additions & 5 deletions src/support/ordered_set.h
@@ -26,6 +26,7 @@

#include <tvm/runtime/object.h>

#include <functional>
#include <list>
#include <unordered_map>

@@ -39,17 +40,31 @@ namespace detail {
*/
template <typename T, typename = void>
struct OrderedSetLookupType {
- using MapType = std::unordered_map<T, typename std::list<T>::iterator>;
using Hash = std::hash<T>;
using Equal = std::equal_to<T>;
};

template <typename T>
struct OrderedSetLookupType<T, std::enable_if_t<std::is_base_of_v<runtime::ObjectRef, T>>> {
- using MapType = std::unordered_map<T, typename std::list<T>::iterator, runtime::ObjectPtrHash,
- runtime::ObjectPtrEqual>;
using Hash = runtime::ObjectPtrHash;
using Equal = runtime::ObjectPtrEqual;
};
} // namespace detail
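
// For example (illustrative): `OrderedSetLookupType<int>::Hash` is
// `std::hash<int>`, while `OrderedSetLookupType<tir::Var>::Hash` is
// `runtime::ObjectPtrHash`, because `tir::Var` derives from
// `runtime::ObjectRef` and the partial specialization above applies.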

- template <typename T>
/*! \brief Utility to hold an ordered set
*
* \tparam T The type held by the OrderedSet
*
* \tparam LookupHash The hash implementation to use for detecting
* duplicate entries. If unspecified, defaults to `ObjectPtrHash` for
* TVM types, and `std::hash<T>` otherwise.
*
* \tparam LookupEqual The equality-checker to use for detecting
* duplicate entries. If unspecified, defaults to `ObjectPtrEqual`
* for TVM types, and `std::equal_to<T>` otherwise.
*/
template <typename T, typename LookupHash = typename detail::OrderedSetLookupType<T>::Hash,
typename LookupEqual = typename detail::OrderedSetLookupType<T>::Equal>
class OrderedSet {
public:
OrderedSet() = default;
@@ -91,7 +106,7 @@ class OrderedSet {

private:
std::list<T> elements_;
- typename detail::OrderedSetLookupType<T>::MapType elem_to_iter_;
std::unordered_map<T, typename std::list<T>::iterator, LookupHash, LookupEqual> elem_to_iter_;
};

} // namespace support
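
A usage sketch of the new customization point (hypothetical snippet; assumes the usual TVM headers, and uses the same `push_back` interface relied on by `fuse_ops.cc` above):

tir::Var x("x", DataType::Int(64));
PrimExpr a = x + 1;
PrimExpr b = x + 1;  // structurally equal to `a`, but a distinct object

support::OrderedSet<PrimExpr> by_identity;  // defaults: ObjectPtrHash/ObjectPtrEqual
by_identity.push_back(a);
by_identity.push_back(b);  // kept: lookup is by object identity

support::OrderedSet<PrimExpr, StructuralHash, StructuralEqual> by_structure;
by_structure.push_back(a);
by_structure.push_back(b);  // deduplicated: structurally equal to `a`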
77 changes: 77 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1217,5 +1217,82 @@ def inner_func(
tvm.ir.assert_structural_equal(Expected, After)


def test_matmul_symbolic_expr():
"""Like `test_matmul_symbolic_var`, but with a PrimExpr shape

The shape of the weights used in the matmul is `[1024, M + 1024]`,
which can result from `CombineParallelMatmul`. If the fused
function is written in terms of `M`, then `M` must be provided as
an additional `ShapeExpr`, as it cannot be inferred from the
tensor shapes. This can cause issues for downstream passes, as
CodeGenJSON, used by TVM's runtime for cublas and cutlass, only
supports `R.Tensor` and tuples of `R.Tensor`.

If a symbolic variable only appears within expressions that are
themselves inferable from the tensor shapes, then the fused
function can be written in terms of those expressions, removing
the need for the `ShapeExpr`. Here, the expression `M + 1024` is
replaced by the variable `w2_size`.
"""

@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor(["batch_size", 1024], dtype="float16"),
w1: R.Tensor([1024, 1024], dtype="float16"),
w2: R.Tensor([1024, "M"], dtype="float16"),
) -> R.Tensor(["batch_size", "M + 1024"], "float16"):
with R.dataflow():
concat = R.concat([w1, w2], axis=1)
out = R.matmul(x, concat)
R.output(out)
return out

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor(["batch_size", 1024], dtype="float16"),
w1: R.Tensor([1024, 1024], dtype="float16"),
w2: R.Tensor([1024, "M"], dtype="float16"),
) -> R.Tensor(["batch_size", "M + 1024"], "float16"):
cls = Expected
with R.dataflow():
concat = R.concat([w1, w2], axis=1)
out = cls.fused_relax_matmul_cublas(x, concat)
R.output(out)
return out

@R.function
def fused_relax_matmul_cublas(
x: R.Tensor(["batch_size", 1024], dtype="float16"),
w2: R.Tensor([1024, "w2_size"], dtype="float16"),
) -> R.Tensor(["batch_size", "w2_size"], dtype="float16"):
batch_size = T.int64()
w2_size = T.int64()
R.func_attr({"Codegen": "cublas"})

@R.function
def inner_func(
x: R.Tensor([batch_size, 1024], dtype="float16"),
w2: R.Tensor((1024, w2_size), dtype="float16"),
) -> R.Tensor([batch_size, w2_size], dtype="float16"):
R.func_attr({"Composite": "cublas.matmul"})
with R.dataflow():
out = R.matmul(x, w2)
R.output(out)
return out

out = inner_func(x, w2)
return out

patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
After = relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(
Before
)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
pytest.main([__file__])