Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,106 @@ class SimplifyConsecutiveCast : public DFPatternRewrite {
DFPattern cast1_;
};

bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value) {
if (dtype.is_int() || dtype.is_uint()) {
double ubound = static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
double lbound = static_cast<double>(Downcast<IntImm>(tvm::min_value(dtype))->value);
return ubound == max_value && lbound == min_value;
} else if (dtype.is_float()) {
double ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
double lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
return ubound == max_value && lbound == min_value;
}

return false;
}

/*!
* \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->cast and remove redundant
* casts.
* Analysis of "redundancy" is done based on clip min/max values and min/max values of casted data
* type.
*/
class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
public:
SimplifyClipAndConsecutiveCast() {
clip_ = IsOp("clip")({IsWildcard()});
cast1_ = IsOp("cast")({clip_});
pattern_ = IsOp("cast")({cast1_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto clip = Downcast<Call>(node_map[clip_][0]);
const CallNode* clip_node = clip.as<CallNode>();
const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();
DataType clip_dtype = Downcast<TensorType>(clip->checked_type())->dtype;

auto cast1 = Downcast<Call>(node_map[cast1_][0]);
DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;

auto cast2 = Downcast<Call>(post);
DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;

if (clip_dtype == cast2_dtype &&
CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
// Case 1:
// Data type of Clip == target data type of second Cast and min/max value of Clip == min/max
// value of first Clip target data type. In this case both Clip ops can be removed.
// Example:
// %0 == [type=int32]
// %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
// %2 = cast(%1, dtype="uint8") [type=uint8]
// %3 = cast(%2, dtype="int32") [type=int32]
//
// Optimized to (both casts can be removed):
// %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
return node_map[clip_][0];
}
return post;
}

protected:
DFPattern clip_, cast1_;
};

/*!
* \brief SimplifyCastClip matches the pattern cast->clip and remove redundant Cast based on Clip
* min/max values and min/max values of Cast target data type.
*
* Example:
* %1 = cast(%0, dtype="uint8") [type=uint8]
* %2 = clip(%1, a_min=0f, a_max=255f) [type=int8]
*
* Optimized to (remove Clip):
* %1 = cast(%0, dtype="uint8") [type=uint8]
*/
class SimplifyCastClip : public DFPatternRewrite {
public:
SimplifyCastClip() {
cast_ = IsOp("cast")({IsWildcard()});
pattern_ = IsOp("clip")({cast_});
}

Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto cast = Downcast<Call>(node_map[cast_][0]);
DataType cast_dtype = Downcast<TensorType>(cast->checked_type())->dtype;

auto clip = Downcast<Call>(post);
const CallNode* clip_node = clip.as<CallNode>();
const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();

if (CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
return node_map[cast_][0];
}
return post;
}

protected:
DFPattern clip_, cast_;
};

/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
Expand Down Expand Up @@ -804,6 +904,8 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<SimplifyDQArgMax>();
composer.AddRewrite<SimplifyDQArgMin>();
composer.AddRewrite<SimplifyDQArgSort>();
composer.AddRewrite<SimplifyClipAndConsecutiveCast>();
composer.AddRewrite<SimplifyCastClip>();
return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}

Expand Down
31 changes: 31 additions & 0 deletions tests/python/relay/test_pass_simplify_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,5 +669,36 @@ def expected():
assert tvm.ir.structural_equal(opt, after)


def test_simplify_clip_cast():
x = relay.var("x", shape=(4, 8), dtype="int32")

def before():
clip = relay.clip(x, a_min=0.0, a_max=255.0)
cast = relay.cast(clip, "uint8")
return relay.cast(cast, "int32")

def expected():
return relay.clip(x, a_min=0.0, a_max=255.0)

opt = run_opt_pass(before(), transform.SimplifyExpr())
ref = run_infer_type(expected())
assert tvm.ir.structural_equal(opt, ref)


def test_simplify_cast_clip():
x = relay.var("x", shape=(4, 8), dtype="int32")

def before():
cast = relay.cast(x, "uint8")
return relay.clip(cast, a_min=0.0, a_max=255.0)

def expected():
return relay.cast(x, "uint8")

opt = run_opt_pass(before(), transform.SimplifyExpr())
ref = run_infer_type(expected())
assert tvm.ir.structural_equal(opt, ref)


if __name__ == "__main__":
pytest.main([__file__])