Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/tvm/relax/struct_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,26 @@ class ObjectStructInfo : public StructInfo {
*/
class PrimStructInfoNode : public StructInfoNode {
public:
/*! \brief Underlying primitive value, if known */
Optional<PrimExpr> value;

/*! \brief Underlying data type of the primitive value */
DataType dtype;

void VisitAttrs(AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}

bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
return equal(value, other->value) && equal(dtype, other->dtype);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(value);
hash_reduce(dtype);
}

static constexpr const char* _type_key = "relax.PrimStructInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode);
Expand All @@ -84,8 +91,12 @@ class PrimStructInfoNode : public StructInfoNode {
*/
class PrimStructInfo : public StructInfo {
public:
/* Construct a PrimStructInfo with a known dtype, but unknown value */
TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());

/* Construct a PrimStructInfo with a known value */
TVM_DLL PrimStructInfo(PrimExpr value, Span span = Span());

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode);
};

Expand Down
63 changes: 59 additions & 4 deletions python/tvm/relax/struct_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from tvm.ir import Span, EnvFunc, Array, VDevice
from tvm.tir import PrimExpr
from tvm.runtime import DataType
from .expr import StructInfo, Expr, ShapeExpr

from . import _ffi_api, ty, expr
Expand All @@ -42,14 +43,68 @@ class PrimStructInfo(StructInfo):

Parameters
----------
dtype : str
The data type of the prim value.
dtype_or_expr : Union[str, DataType, PrimExpr]

The data type of the prim value, or a known expression for the prim
value.
"""

value: Optional[PrimExpr]
dtype: str

def __init__(self, dtype: str, span: Span = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.PrimStructInfo, dtype, span) # type: ignore
def __init__(
self,
dtype: Optional[Union[str, DataType]] = None,
value: Optional[Union[int, float, PrimExpr]] = None,
span: Span = None,
) -> None:
# Guard against incorrect usage. For backwards compatibility,
# the dtype and value are in the opposite order from most
# usages. While PrimStructInfo could take a single positional
# argument and check the type, this would require an API
# difference from TVMScript's PrimProxy, which cannot.
# (PrimProxy uses string arguments for datatype, and also for
# inline variable definitions when used in a function
# signature, and requires separate arguments to distinguish
# the two cases.)
Comment on lines +61 to +69
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very user-friendly touch :)

if isinstance(dtype, (PrimExpr, int, float)):
raise TypeError(
f"The first positional argument of PrimStructInfo must be the datatype, "
f", but received {type(dtype)}. "
f"The value can be specified as a keyword argument "
f"without needing specifying the dtype: "
f"PrimStructInfo(value=arg)."
)

if dtype is None and value is None:
raise TypeError(
"PrimStructInfo.__init__ missing required argument. "
"Must provide either 'dtype' or 'value'"
)

if dtype is not None:
if isinstance(value, PrimExpr):
assert value.dtype == dtype, (
"When providing both 'value' and 'dtype' to PrimStructInfo.__init__, "
"they must be consistent with each other. "
"However, the value {value} has dtype {value.dtype}, "
"but the specified dtype was {dtype}."
)
elif isinstance(value, (int, float)):
value = tvm.tir.const(value, dtype)

# Use relax's default integer type if not otherwise specified.
if isinstance(value, int):
value = tvm.tir.IntImm("int64", value)

if value is None:
self.__init_handle_by_constructor__(
_ffi_api.PrimStructInfoFromDtype, dtype, span
) # type: ignore
else:
self.__init_handle_by_constructor__(
_ffi_api.PrimStructInfoFromValue, value, span
) # type: ignore


@tvm._ffi.register_object("relax.ShapeStructInfo")
Expand Down
43 changes: 34 additions & 9 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def __init__(

def get_symbolic_vars(self) -> Set[str]:
if self.values is None:
return {}
return set()
else:
return {v for v in self.values if isinstance(v, str) and v.isidentifier()}

Expand Down Expand Up @@ -342,27 +342,52 @@ def Object() -> ObjectProxy:


class PrimProxy(StructInfoProxy):
dtype: str
"""The type of shape values.
dtype: Optional[str]
value: Optional[Union[int, float, str, PrimExpr]]

"""The type of TIR-representable values.

Parameters
----------
dtype : str
dtype : Optional[str]
The data type.

value: Optional[Union[int, float, str, PrimExpr]]
The known value
"""

def __init__(self, dtype: str) -> None:
def __init__(
self,
dtype: Optional[str] = None,
value: Optional[Union[int, float, str, PrimExpr]] = None,
) -> None:
if dtype is None and value is None:
raise TypeError(
"R.Prim missing required argument. " "Must provide either 'dtype' or 'value'"
)

self.dtype = dtype
self.value = value

def get_symbolic_vars(self) -> Set[str]:
return set()
if isinstance(self.value, str) and self.value.isidentifier():
return {self.value}
else:
return set()

def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> ShapeStructInfo:
return PrimStructInfo(self.dtype)
if self.value is None:
return PrimStructInfo(dtype=self.dtype)
else:
value = _eval_shape(self.value, dict_globals)
return PrimStructInfo(dtype=self.dtype, value=value)


def Prim(dtype: str) -> PrimProxy:
return PrimProxy(dtype)
def Prim(
dtype: Optional[str] = None,
value: Optional[Union[int, float, str, PrimExpr]] = None,
) -> PrimProxy:
return PrimProxy(dtype, value)


############################ R.match_cast #############################
Expand Down
3 changes: 3 additions & 0 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,9 @@ class SymbolicVarCollector : public relax::ExprVisitor,
this->VisitStructInfoExprField(val);
}
}
if (auto prim_value = expr.as<relax::PrimValue>()) {
this->VisitStructInfoExprField(prim_value.value()->value);
}
}

void VisitStructInfoExprField(const PrimExpr& expr) final {
Expand Down
2 changes: 1 addition & 1 deletion src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ TVM_REGISTER_GLOBAL("relax.Constant")
PrimValue::PrimValue(PrimExpr value, Span span) {
ObjectPtr<PrimValueNode> n = make_object<PrimValueNode>();
n->checked_type_ = PrimType(value.dtype());
n->struct_info_ = PrimStructInfo(value.dtype());
n->struct_info_ = PrimStructInfo(value);
n->value = std::move(value);
n->span = std::move(span);
data_ = std::move(n);
Expand Down
3 changes: 3 additions & 0 deletions src/relax/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ void ExprVisitor::VisitExpr_(const SeqExprNode* op) {

void ExprVisitor::VisitExpr_(const PrimValueNode* op) {
this->VisitPrimExpr(op->value);
if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
}
this->VisitSpan(op->span);
}

Expand Down
15 changes: 14 additions & 1 deletion src/relax/ir/struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,32 @@ TVM_REGISTER_GLOBAL("relax.ObjectStructInfo").set_body_typed([](Span span) {
});

// Prim
PrimStructInfo::PrimStructInfo(PrimExpr value, Span span) {
ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
n->dtype = value->dtype;
n->value = std::move(value);
n->span = span;
data_ = std::move(n);
}

PrimStructInfo::PrimStructInfo(DataType dtype, Span span) {
ObjectPtr<PrimStructInfoNode> n = make_object<PrimStructInfoNode>();
n->dtype = dtype;
n->value = NullOpt;
n->span = span;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(PrimStructInfoNode);

TVM_REGISTER_GLOBAL("relax.PrimStructInfo").set_body_typed([](DataType dtype, Span span) {
TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromDtype").set_body_typed([](DataType dtype, Span span) {
return PrimStructInfo(dtype, span);
});

TVM_REGISTER_GLOBAL("relax.PrimStructInfoFromValue").set_body_typed([](PrimExpr value, Span span) {
return PrimStructInfo(value, span);
});

// Shape
ShapeStructInfo::ShapeStructInfo(Array<PrimExpr> values, Span span) {
ObjectPtr<ShapeStructInfoNode> n = make_object<ShapeStructInfoNode>();
Expand Down
17 changes: 15 additions & 2 deletions src/relax/ir/struct_info_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ namespace relax {

void StructInfoVisitor::VisitStructInfo_(const ObjectStructInfoNode* op) {}

void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {}
void StructInfoVisitor::VisitStructInfo_(const PrimStructInfoNode* op) {
if (op->value.defined()) {
this->VisitStructInfoExprField(op->value.value());
}
}

void StructInfoVisitor::VisitStructInfo_(const ShapeStructInfoNode* op) {
if (op->values.defined()) {
Expand Down Expand Up @@ -68,7 +72,16 @@ StructInfo StructInfoMutator::VisitStructInfo_(const ObjectStructInfoNode* op) {
}

StructInfo StructInfoMutator::VisitStructInfo_(const PrimStructInfoNode* op) {
return GetRef<StructInfo>(op);
if (!op->value.defined()) {
return GetRef<StructInfo>(op);
}

auto new_expr = VisitStructInfoExprField(op->value.value());
if (new_expr.same_as(op->value)) {
return GetRef<StructInfo>(op);
} else {
return PrimStructInfo(new_expr);
}
}

StructInfo StructInfoMutator::VisitStructInfo_(const ShapeStructInfoNode* op) {
Expand Down
23 changes: 17 additions & 6 deletions src/script/printer/relax/struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return Relax(d, "Object");
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::PrimStructInfo>(
"", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
return Relax(d, "Prim")->Call({LiteralDoc::DataType(n->dtype, n_p->Attr("dtype"))});
});

ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d) {
ExprDoc expr_doc = d->AsDoc<ExprDoc>(e, e_p);
// Step 1. Find if `func_vars` are being collected
Expand Down Expand Up @@ -66,6 +60,23 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifie
return expr_doc;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::PrimStructInfo>(
"", [](relax::PrimStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
Array<ExprDoc, void> args;
Array<String> kwargs_keys;
Array<ExprDoc, void> kwargs_values;

if (n->value.defined()) {
kwargs_keys.push_back("value");
kwargs_values.push_back(PrintShapeVar(n->value.value(), n_p->Attr("value"), d));
} else {
args.push_back(LiteralDoc::DataType(n->dtype, n_p->Attr("dtype")));
}

return Relax(d, "Prim")->Call(args, kwargs_keys, kwargs_values);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::ShapeStructInfo>(
"", [](relax::ShapeStructInfo n, ObjectPath n_p, IRDocsifier d) -> Doc {
Expand Down
43 changes: 42 additions & 1 deletion tests/python/relax/test_analysis_struct_info_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def test_tir_vars_in_struct_info():
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), [n, m])


def test_symbolic_var_collector():
def test_collect_symbolic_var_from_tensor_shape():
n, m, k, q, p = (
tir.Var("n", "int64"),
tir.Var("m", "int64"),
Expand All @@ -600,5 +600,46 @@ def test_symbolic_var_collector():
assert free_vars == {n, p, q}


param_type = tvm.testing.parameter("shape_expr", "prim_value")
param_order = tvm.testing.parameter("definition_first", "usage_first")


def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
tir_n = tir.Var("n", "int64")
tir_m = tir.Var("m", "int64")

bb = rx.BlockBuilder()
arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m]))

if param_type == "shape_expr":
extra_params = [
rx.Var("shape_expr", rx.ShapeStructInfo([tir_n, tir_m])),
]
elif param_type == "prim_value":
extra_params = [
rx.Var("n", rx.PrimStructInfo(value=tir_n)),
rx.Var("m", rx.PrimStructInfo(value=tir_m)),
]
else:
raise ValueError(f"Unknown param_type: {param_type}")

if param_order == "definition_first":
params = [*extra_params, arg]
elif param_order == "usage_first":
params = [arg, *extra_params]
else:
raise ValueError(f"Unknown param_order: {param_order}")

with bb.function("main", params=params):
out = rx.op.reshape(arg, [tir_n, tir_m])
bb.emit_func_output(out)
func = bb.get()["main"]

defined_vars = set(rx.analysis.defined_symbolic_vars(func))
free_vars = set(rx.analysis.free_symbolic_vars(func))
assert defined_vars == {tir_n, tir_m}
assert free_vars == set()


if __name__ == "__main__":
tvm.testing.main()
Loading