Merged
7 changes: 4 additions & 3 deletions include/tvm/relax/transform.h
@@ -356,12 +356,13 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_opt
Array<runtime::String> entry_functions);

/*!
* \brief Simplify normalization operators during inference. For example, the result
* \brief Decompose composite operators during inference. For example, the result
* of a batch norm which is indexed at tuple index 0 will be unpacked into a
* number of simplified operators.
* number of simplified operators. Operators like Attention, Erf, etc. can also be
* simplified into several operators.
* \return The Pass.
*/
TVM_DLL Pass SimplifyNormInference();
TVM_DLL Pass DecomposeCompositeOps();

/*!
* \brief Returns a pass which replaces PrimFuncs which have matching kOperatorName attribute in \p
14 changes: 14 additions & 0 deletions python/tvm/relax/op/base.py
@@ -402,3 +402,17 @@ def shape_of(expr: Expr) -> Expr:
        A relax Call, which gets the shape of the input
    """
    return _ffi_api.shape_of(expr)  # type: ignore # pylint: disable=no-member


def tensor_to_shape(expr: Expr) -> Expr:
    """Convert tensor to shape expr.
    Parameters
    ----------
    expr : Expr
        The input Expr
    Returns
    -------
    result : Expr
        A relax Call, which transforms the tensor values to the shape
    """
    return _ffi_api.tensor_to_shape(expr)  # type: ignore # pylint: disable=no-member
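
As a quick illustration of the new builder function, a minimal sketch with illustrative names; it only assumes the unity-era relax Python API touched by this PR:

    from tvm import relax

    # A rank-1 int64 tensor whose values are interpreted as the target shape.
    t = relax.Var("t", relax.TensorStructInfo([3], "int64"))
    call = relax.op.tensor_to_shape(t)  # a relax.Call to the "relax.tensor_to_shape" op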
6 changes: 5 additions & 1 deletion python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -23,7 +23,7 @@
from tvm import topi, tir, relax, te
from tvm.tir.expr import IntImm
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, Var, Tuple, TupleGetItem
from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr
from .common import TEFunc, LegalizeFunc, register_legalize


@@ -32,6 +32,10 @@ def _reshape(
) -> LegalizeFunc:
    def reshape_call_te(bb: BlockBuilder, call: Call):
        tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like else call.args[1]
        # If target shape is Var, pass its bound expr only when it is ShapeExpr
        if isinstance(tgt_shape, Var):
            tgt_shape = bb.lookup_binding(tgt_shape)
            assert isinstance(tgt_shape, ShapeExpr)
        return bb.call_te(te_func, call.args[0], tgt_shape, primfunc_name_hint=primfunc_name)

    return reshape_call_te
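
To illustrate the branch added above, a hedged sketch (illustrative names; assumes the unity-era BlockBuilder and LegalizeOps APIs) where the reshape target is a Var bound to a ShapeExpr, which the legalizer now resolves through bb.lookup_binding:

    import tvm
    from tvm import relax

    x = relax.Var("x", relax.TensorStructInfo([4, 4], "float32"))
    bb = relax.BlockBuilder()
    with bb.function("main", [x]):
        with bb.dataflow():
            lv = bb.emit(relax.ShapeExpr([16]))           # a Var bound to a ShapeExpr
            gv = bb.emit_output(relax.op.reshape(x, lv))  # target shape passed as the Var
        bb.emit_func_output(gv)
    legalized = relax.transform.LegalizeOps()(bb.get())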
11 changes: 6 additions & 5 deletions python/tvm/relax/transform/transform.py
@@ -577,18 +577,19 @@ def MetaScheduleTuneIRMod(
    return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global)  # type: ignore


def SimplifyNormInference() -> tvm.ir.transform.Pass:
    """Simplify normalization operators during inference. For example, the result
    of a batch norm which is indexed at tuple index 0 will be unpacked into a
    number of simplified operators.
def DecomposeCompositeOps() -> tvm.ir.transform.Pass:
    """Decompose composite operators that are composed of other operators during inference.
    For example, the result of a batch norm which is indexed at tuple index 0 will be unpacked
    into a number of simplified operators. Attention, tensor_to_shape, etc. can also be
    decomposed into a number of simplified operators.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass
    """

    return _ffi_api.SimplifyNormInference()  # type: ignore
    return _ffi_api.DecomposeCompositeOps()  # type: ignore


def AlterOpImpl(
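
A hedged usage sketch of the renamed pass on the batch-norm pattern the docstring describes (shapes and names are illustrative, not taken from a test in this PR); the same pass also rewrites tensor_to_shape calls, as shown on the C++ side further below:

    import tvm
    from tvm import relax

    x = relax.Var("x", relax.TensorStructInfo([1, 4, 8, 8], "float32"))
    gamma = relax.Var("gamma", relax.TensorStructInfo([4], "float32"))
    beta = relax.Var("beta", relax.TensorStructInfo([4], "float32"))
    mean = relax.Var("mean", relax.TensorStructInfo([4], "float32"))
    var = relax.Var("var", relax.TensorStructInfo([4], "float32"))

    bb = relax.BlockBuilder()
    with bb.function("main", [x, gamma, beta, mean, var]):
        with bb.dataflow():
            bn = bb.emit(relax.op.nn.batch_norm(x, gamma, beta, mean, var, axis=1))
            gv = bb.emit_output(relax.TupleGetItem(bn, 0))  # batch_norm indexed at tuple index 0
        bb.emit_func_output(gv)

    mod = relax.transform.DecomposeCompositeOps()(bb.get())
    mod.show()  # the batch_norm call should now be expanded into simpler operators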
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -97,6 +97,7 @@
    prod,
    repeat,
    reshape,
    tensor_to_shape,
    round,
    shape_of,
    std,
@@ -612,6 +613,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"prod",
"repeat",
"reshape",
"tensor_to_shape",
"round",
"shape",
"shape_of",
20 changes: 17 additions & 3 deletions src/relax/backend/vm/vm_builtin_lower.cc
@@ -129,9 +129,23 @@
Expr Reshape(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->struct_info_.defined());
CHECK(call_node->args[1]->IsInstance<ShapeExprNode>())
<< "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr";
return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
auto arg = call_node->args[1];
CHECK(arg->IsInstance<ShapeExprNode>() || arg->IsInstance<VarNode>())
<< "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr or VarNode bound "
"to a ShapeExpr";

if (arg->IsInstance<ShapeExprNode>()) {
return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
} else {
// Handling the case when arg is VarNode
Optional<Expr> _bound_val = LookupBinding(Downcast<Var>(arg));
ICHECK(_bound_val.defined());
Expr bound_val = _bound_val.value();
CHECK(bound_val->IsInstance<ShapeExprNode>())
<< "VMBuiltinLower expects bound value to be a ShapeExpr";
return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(),
{GetStructInfo(call_node)});
}
}

Expr ShapeOf(const Call& call_node) {
26 changes: 26 additions & 0 deletions src/relax/op/op.cc
@@ -315,6 +315,32 @@ Expr MakeShapeOf(Expr expr) {

TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf);

// tensor_to_shape

StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) {
ICHECK(call->args.size() == 1);
ICHECK(call->args[0]->struct_info_.defined());
const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
ICHECK(tsinfo && tsinfo->shape.defined());
[Review thread]

slyubomirsky (Contributor, Mar 13, 2023):
I think this condition is more restrictive than necessary. I think it's okay for the shape not to be known at compile time and to return ShapeStructInfo(ndim=-1) if the input rank/shape is unknown at compile time.
Edit: I guess you use the rank in vm_builtin_lower.cc, but I don't see why it has to be implemented that way. You could check all the necessary properties dynamically (inside the builtin).

sunggg (Contributor, Author):
Do you mean these lines?

  // define symbolic variables
  Array<PrimExpr> shape_var;
  for (int i = 0; i < sinfo->ndim; i++) {
    shape_var.push_back(tir::Var("x", DataType::Int(64)));
  }

Initially, I also wanted to support the unknown rank but realized it is trickier than I thought. The problem is that you need to insert these symbolic variables at compile time, so we need this info.

Contributor:
What exactly are the new shape vars needed for?

sunggg (Contributor, Author):
To MatchCast and put them in the ShapeExpr.

  // define symbolic variables
  Array<PrimExpr> shape_var;
  for (int i = 0; i < sinfo->ndim; i++) {
    shape_var.push_back(tir::Var("x", DataType::Int(64)));
  }

  // bind symbolic variables to the shape tuple
  relax::Var var("y", ShapeStructInfo(shape_var));
  builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var)));
  return ShapeExpr(shape_var);

Contributor:
That much I understand; what I am more curious about is why we need to construct this shape expression.

Contributor:
Why does lowering require the dimensions to be bound to a variable? That doesn't seem like a restriction that needs to exist.

sunggg (Contributor, Author):
@slyubomirsky To pre-allocate memory for TIR functions? TE compute takes the output shape as an Array<PrimExpr> in its first argument:

  Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::string tag,
                 Map<String, ObjectRef> attrs)

Member:
Great discussion, folks! Looks like the limitation for shape_expr comes from TE. Interesting.

Contributor:
Hmmmm, I think we should put this on the list to fix later. Let's please note this limitation in a comment, since it is not obvious from looking at the code.

sunggg (Contributor, Author):
Added this in the comment. Would it be good to go now?

[End of review thread]

ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
ICHECK(shape_expr->values.size() == 1);
const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
ICHECK(ndim);
return ShapeStructInfo(ndim->value);
}

RELAY_REGISTER_OP("relax.tensor_to_shape")
.set_num_inputs(1)
.add_argument("input", "Expr", "The input expression")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnTensorToShapeStructInfo);

Expr MakeTensorToShape(Expr expr) {
static const Op& op = Op::Get("relax.tensor_to_shape");
return Call(op, {expr}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape);

// alloc_tensor

StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& ctx) {
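
A small struct-info check of the FInferStructInfo registration above, as a non-authoritative sketch with illustrative names: a tensor with static shape (3,) should be inferred to a shape of ndim=3.

    from tvm import relax

    t = relax.Var("t", relax.TensorStructInfo([3], "int64"))
    bb = relax.BlockBuilder()
    with bb.function("main", [t]):
        lv = bb.emit(relax.op.tensor_to_shape(t))
        bb.emit_func_output(lv)
    print(lv.struct_info)  # expected: R.Shape(ndim=3)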
7 changes: 6 additions & 1 deletion src/relax/op/tensor/manipulate.cc
@@ -751,7 +751,12 @@ StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) {
<< new_shape_prod);
}
}
return TensorStructInfo(call->args[1], data_sinfo->dtype);
Expr target_shape = call->args[1];
// If shape values are defined, use them
if (target_shape->IsInstance<VarNode>() && new_shape_sinfo->values.defined()) {
return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), data_sinfo->dtype);
}
return TensorStructInfo(target_shape, data_sinfo->dtype);
}

TVM_REGISTER_OP("relax.reshape")
@@ -21,6 +21,7 @@

#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/nn.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>

#include "utils.h"
Expand Down Expand Up @@ -110,21 +111,63 @@ class NormInferenceSimplifier : public ExprMutator {
Map<Expr, Expr> batch_norm_map_;
};

class OpDecomposer : public ExprMutator {
[Review thread]
Member:
The name OpDecomposer is too generic, since it's only for tensor_to_shape.

sunggg (Contributor, Author):
We can extend this to add other composite ops like erf and attention, which are coming soon.
[End of review thread]

public:
static Expr Decompose(Expr expr) { return OpDecomposer()(expr); }

private:
using ExprMutator::VisitExpr_;
Expr TensorToShape(const Call& call_node) {
ICHECK(call_node->struct_info_.defined());
Expr expr = call_node->args[0];
const ShapeStructInfoNode* sinfo = GetStructInfoAs<ShapeStructInfoNode>(call_node);
ICHECK(sinfo);
// call builtin function that converts tensor to shape tuple
// TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape"
Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {},
{GetRef<ShapeStructInfo>(sinfo)}));

// Operators like reshape take the output of `TensorToShape` as their output shape.
// Because TOPI expects the output shape to be at least symbolic (i.e., an
// Array<PrimExpr>), we define symbolic variables and return them as a ShapeExpr.
Array<PrimExpr> shape_var;
for (int i = 0; i < sinfo->ndim; i++) {
shape_var.push_back(tir::Var("x", DataType::Int(64)));
}
// bind symbolic variables to the shape tuple
relax::Var var("y", ShapeStructInfo(shape_var));
builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var)));
return ShapeExpr(shape_var);
}

Expr VisitExpr_(const CallNode* call_node) final {
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
if (call->op == tensor_to_shape_op_) {
return TensorToShape(call);
} else {
return call;
}
}

const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
};

namespace transform {
Pass SimplifyNormInference() {
Pass DecomposeCompositeOps() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
f = Downcast<Function>(NormInferenceSimplifier::Simplify(f));
// Remove original batch_norm op if it's not used.
f = Downcast<Function>(OpDecomposer::Decompose(f));
// Remove the original ops if they are not used.
return RemoveAllUnused(f);
};
return CreateFunctionPass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
/*pass_name=*/"SimplifyNormInference", //
/*pass_name=*/"DecomposeCompositeOps", //
/*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.SimplifyNormInference").set_body_typed(SimplifyNormInference);
TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOps").set_body_typed(DecomposeCompositeOps);

} // namespace transform
} // namespace relax
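
To make the TensorToShape rewrite concrete, a hedged end-to-end sketch (illustrative names, unity-era API): after DecomposeCompositeOps, the call should be replaced by a vm.builtin.tensor_to_shape packed call plus a match_cast that introduces the symbolic shape variables discussed in the review thread above.

    from tvm import relax

    t = relax.Var("t", relax.TensorStructInfo([2], "int64"))
    bb = relax.BlockBuilder()
    with bb.function("main", [t]):
        with bb.dataflow():
            lv = bb.emit(relax.op.tensor_to_shape(t))
            gv = bb.emit_output(lv)
        bb.emit_func_output(gv)

    mod = relax.transform.DecomposeCompositeOps()(bb.get())
    mod.show()  # expect ExternFunc("vm.builtin.tensor_to_shape") followed by a match_cast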
70 changes: 62 additions & 8 deletions src/relax/transform/fold_constant.cc
@@ -196,6 +196,10 @@ class ConstantFolder : public ExprMutator {

using ExprMutator::VisitExpr_;

// TODO(@sunggg):
// The next PR will support folding with PackedFunc and MatchCast.
// Until then, DecomposeCompositeOps() should be applied after
// this pass to fold the `tensor_to_shape` op.
Expr VisitExpr_(const CallNode* call) final {
// post-order mutation
Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
@@ -217,14 +221,64 @@
return VisitCallTIR(post_call).value_or(post_call);
}

// If we are in a dataflow block, we can fold ops by lowering them to call_tir.
if (builder_->CurrentBlockIsDataFlow() && legalize_map.count(op)) {
// Get the legalized expression
Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call));
// If the legalized expression is call_tir, try to fold it.
const CallNode* call = legalized_expr.as<CallNode>();
if (call && call->op.same_as(call_tir_op)) {
return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
// Special logic to fold ShapeExpr between operators
// e.g.,
// <Before>
// lv: R.Shape([16, 16]) = R.shape([16, 16])
// gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, lv)
// <After>
// gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16, 16]))
//
Array<Expr> new_args;
for (auto arg : post_call->args) {
if (arg->IsInstance<VarNode>()) {
Optional<Expr> val = LookupBinding(Downcast<Var>(arg));
if (val.defined() && val.value()->IsInstance<ShapeExprNode>()) {
new_args.push_back(val.value());
continue;
}
}
new_args.push_back(arg);
}
post_call =
Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span);

// If we are in a dataflow block, we can fold ops.
if (builder_->CurrentBlockIsDataFlow()) {
// Check if we can legalize them to call_tir
if (legalize_map.count(op)) {
// Get the legalized expression
Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call));
// If the legalized expression is call_tir, try to fold it.
const CallNode* call = legalized_expr.as<CallNode>();
if (call && call->op.same_as(call_tir_op)) {
return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
}
} else if (op->name == "relax.tensor_to_shape") {
// Special handling for composite op "relax.tensor_to_shape"
// If its input is constant, we can access its value and create ShapeExpr
// TODO(@sunggg):
// currently, we do not have an info map about decomposition.
// Thus, this is a temporary solution until we have a consensus about
// how to deal with composite ops. One possibility is we register the
// decomposition map for each op in a similar way we do for legalization.
ICHECK_EQ(post_call->args.size(), 1);
Expr arg = post_call->args[0];
if (arg->IsInstance<ConstantNode>()) {
Constant constant = Downcast<Constant>(arg);
runtime::NDArray ndarray = constant->data;
ICHECK_EQ(ndarray->device.device_type, kDLCPU);
ICHECK(ndarray->strides == nullptr);
ICHECK_EQ(ndarray->byte_offset, 0);
ICHECK_EQ(ndarray->ndim, 1);
const int64_t* data = static_cast<const int64_t*>(ndarray->data);
int64_t num_elems = ndarray->shape[0];
Array<PrimExpr> shape_values;
for (int64_t i = 0; i < num_elems; i++) {
shape_values.push_back(IntImm(DataType::Int(64), data[i]));
}
return ShapeExpr(shape_values);
}
}
}

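
A hedged sketch of the new constant-input path (illustrative names): when the argument of tensor_to_shape is a constant int64 tensor inside a dataflow block, FoldConstant should rewrite the call into a plain ShapeExpr.

    import numpy as np
    from tvm import relax

    bb = relax.BlockBuilder()
    with bb.function("main", []):
        with bb.dataflow():
            const = relax.const(np.array([16, 16], dtype="int64"))
            lv = bb.emit(relax.op.tensor_to_shape(const))
            gv = bb.emit_output(lv)
        bb.emit_func_output(gv)

    mod = relax.transform.FoldConstant()(bb.get())
    mod.show()  # expect gv to be bound to R.shape([16, 16])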
34 changes: 34 additions & 0 deletions src/runtime/relax_vm/builtin.cc
@@ -380,6 +380,40 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetVal
*rv = arr;
});

TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) {
NDArray arr = data;
if (data->device.device_type != kDLCPU) {
arr = data.CopyTo(DLDevice{kDLCPU, 0});
[Review thread]
Contributor:
Does this copy fail if the NDArray is on another device? I'm a little hesitant to have an op that just will not work depending on the device.

Member:
Is it possible that we have a shape tensor not on host device?

Contributor:
Not sure where this might come up. I wouldn't be surprised if, say, we import an ONNX model (this is where this use-case first came up) that we mean to run on GPU and every tensor (including those that stand for shapes) is stored on GPU.

sunggg (Contributor, Author, Mar 15, 2023):
Based on my current understanding, we want to keep the shape-related computation on the host side since its result could be related to the memory planning. Even for general GPU execution besides this case, we are running output shape computation (inserted by VMShapeLower pass) on the host side.
[End of review thread]

}

ICHECK_EQ(arr->ndim, 1);
ICHECK_EQ(arr->dtype.code, kDLInt);

std::vector<int64_t> out_shape;
for (int i = 0; i < arr.Shape()[0]; ++i) {
int64_t result;
switch (arr->dtype.bits) {
case 16: {
result = reinterpret_cast<int16_t*>(arr->data)[i];
break;
}
case 32: {
result = reinterpret_cast<int32_t*>(arr->data)[i];
break;
}
case 64: {
result = reinterpret_cast<int64_t*>(arr->data)[i];
break;
}
default:
LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype);
throw;
}
out_shape.push_back(result);
}
return ShapeTuple(out_shape);
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
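
A quick runtime-level check of the new builtin, again as a sketch rather than a test from this PR: it looks up the PackedFunc by the name registered above and feeds it a CPU int64 NDArray.

    import numpy as np
    import tvm

    tensor_to_shape = tvm.get_global_func("vm.builtin.tensor_to_shape")
    arr = tvm.nd.array(np.array([2, 3, 4], dtype="int64"))
    shape = tensor_to_shape(arr)
    print(shape)  # expected: a ShapeTuple equal to (2, 3, 4)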