Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,36 @@ class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
}
};

/*!
* \brief Converts cast_like operator to cast. Not inheriting from ConcretizeLikeRewrite
* because even if shape is not static, still can concretize.
*/
class ConcretizeCastLikeRewrite : public DFPatternRewrite {
public:
ConcretizeCastLikeRewrite() {
data_pat_ = IsWildcard();
like_pat_ = IsWildcard();
pattern_ = IsOp("cast_like")({data_pat_, like_pat_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call_node = pre.as<CallNode>();
ICHECK(call_node);

if (!call_node->checked_type().as<TensorTypeNode>()) {
return post;
}

const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
return MakeCast(node_map[data_pat_][0], like_ty->dtype);
}

protected:
DFPattern data_pat_;
DFPattern like_pat_;
};

/*! \brief Eliminates expressions that are equivalent to identity. */
class EliminateIdentityRewrite : public DFPatternRewrite {
public:
Expand Down Expand Up @@ -762,6 +792,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
composer.AddRewrite<ConcretizeCastLikeRewrite>();
composer.AddRewrite<SimplifyRSqrt>();
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
Expand Down
13 changes: 12 additions & 1 deletion tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def test_simplify_consecutive_cast():
expr1 = relay.cast(x, "int32")
expr2 = relay.cast_like(expr1, y)
actual = run_opt_pass(expr2, relay.transform.SimplifyExpr())
expected = run_infer_type(expr2)
expected = run_infer_type(relay.cast(expr1, "float32"))
assert tvm.ir.structural_equal(actual, expected)


Expand Down Expand Up @@ -517,6 +517,17 @@ def test_concretize_broadcast_to_like():
assert tvm.ir.structural_equal(actual, expected)


def test_concretize_cast_like():
dim_any = tvm.tir.Any()
data = relay.var("data", shape=(3, dim_any, 5), dtype="float32")
dtype_like = relay.var("dtype_like", shape=(dim_any, 3, 3), dtype="int32")
expr = relay.cast_like(data, dtype_like)

expected = run_infer_type(relay.cast(data, "int32"))
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
assert tvm.ir.structural_equal(actual, expected)


def test_concretize_multiple():
x = relay.var("x", shape=(2, 3), dtype="float32")
y = relay.var("y", shape=(3,), dtype="float32")
Expand Down