diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 2c01b9143155..75c09ac05073 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -519,8 +519,8 @@ class ConstIntBoundAnalyzer::Impl
    */
   static Entry MakeBound(int64_t min_value, int64_t max_value) {
     Entry e;
-    e.min_value = min_value;
-    e.max_value = max_value;
+    e.min_value = (min_value == kPosInf) ? min_value - 1 : min_value;
+    e.max_value = (max_value == kNegInf) ? max_value + 1 : max_value;
     return e;
   }
   /*!
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 23be70c1e442..a93532895b5a 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -243,7 +243,18 @@ class ForwardPrep : private ExprVisitor {
     }
   }
   // Visitor pattern override.
-  void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; }
+  void VisitExpr_(const LetNode* op) {
+    ExprVisitor::VisitExpr_(op);
+    // Do not let the message pass through a Let binding:
+    // assigning NullValue<Message> to both the bound value
+    // and the body means the scale-fold signal cannot
+    // propagate into these subexpressions.
+    auto flazy = [this, op]() {
+      this->Update(op->value, NullValue<Message>());
+      this->Update(op->body, NullValue<Message>());
+    };
+    flist_.push_back(flazy);
+  }
   void VisitExpr_(const FunctionNode* op) {
     ExprVisitor::VisitExpr_(op);
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index 421c6c5e8ef2..3c2dc82cb07b 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -311,6 +311,44 @@ def check(shape, channels, blocking, in_scale):
         check((2, 11, 10, 2, 2), 4, (2, 2), in_scale)
 
 
+def test_fold_fwd_let_fail():
+    """Testcase where we cannot fold."""
+
+    def before(x, conv_weight, in_bias, in_scale, channels):
+        args = [x, conv_weight, in_bias]
+        x = relay.multiply(x, in_scale)
+        x = relay.nn.relu(x)
+        x = relay.add(x, in_bias)
+        x_var = relay.Var("x_var")
+        y1 = relay.nn.conv2d(
+            x_var,
+            conv_weight,
+            channels=channels,
+            kernel_size=(3, 3),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+            padding=(1, 1),
+        )
+        z = relay.add(y1, x)
+        let = relay.Let(x_var, x, z)
+        return relay.Function(args, let)
+
+    def check(shape, channels):
+        x = relay.var("x", shape=shape)
+        in_channels = shape[-1]
+        in_bias = relay.var("in_bias", shape=(in_channels,))
+        in_scale = relay.const(_get_positive_scale(size=(in_channels,)))
+        # test depthwise
+        assert in_channels == channels
+        weight = relay.var("weight")
+        y1 = before(x, weight, in_bias, in_scale, channels)
+        y1 = run_opt_pass(y1, transform.InferType())
+        y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
+        assert tvm.ir.structural_equal(y1, y1_folded)
+
+    check((2, 11, 10, 4), 4)
+
+
 def test_fold_fwd_negative_scale():
     """Testcase of folding negative scale"""
diff --git a/tests/python/unittest/test_arith_const_int_bound.py b/tests/python/unittest/test_arith_const_int_bound.py
index 84fc7fd64614..57e488f4f302 100644
--- a/tests/python/unittest/test_arith_const_int_bound.py
+++ b/tests/python/unittest/test_arith_const_int_bound.py
@@ -76,6 +76,20 @@ def test_add_sub_bound():
     assert bd.min_value == bd.NEG_INF
     assert bd.max_value == 1
 
+    ## constants with negative or positive max(int64) occasionally show up
+    ## in models; this is to ensure we can handle those cases
+    analyzer.update(x, tvm.arith.ConstIntBound(bd.NEG_INF, bd.NEG_INF), override=True)
+    analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
+    bd = analyzer.const_int_bound(x + y)
+    assert bd.min_value == bd.NEG_INF
+    assert bd.max_value == bd.POS_INF
+
+    analyzer.update(x, tvm.arith.ConstIntBound(bd.POS_INF, bd.POS_INF), override=True)
+    analyzer.update(y, tvm.arith.ConstIntBound(bd.NEG_INF, bd.POS_INF), override=True)
+    bd = analyzer.const_int_bound(x + y)
+    assert bd.min_value == bd.NEG_INF
+    assert bd.max_value == bd.POS_INF
+
 
 def test_mul_bound():
     analyzer = tvm.arith.Analyzer()
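
Note on the const_int_bound change: kPosInf and kNegInf are int64 sentinel values, and the clamp in MakeBound nudges a computed min_value off the kPosInf sentinel (and a max_value off kNegInf) so an Entry always remains a representable [min, max] interval. Below is a minimal Python sketch of the scenario the new unit test exercises; it assumes only the public tvm.arith API already used in the test above.

import tvm
from tvm import te

analyzer = tvm.arith.Analyzer()
x = te.var("x", dtype="int64")
y = te.var("y", dtype="int64")

NEG_INF = tvm.arith.ConstIntBound.NEG_INF
POS_INF = tvm.arith.ConstIntBound.POS_INF

# x is pinned to -inf on both ends; y is completely unknown.
analyzer.update(x, tvm.arith.ConstIntBound(NEG_INF, NEG_INF))
analyzer.update(y, tvm.arith.ConstIntBound(NEG_INF, POS_INF))

# The addition mixes the two infinities; with the MakeBound clamp the
# result degrades gracefully to the unbounded interval instead of
# yielding an entry that sits exactly on a sentinel value.
bd = analyzer.const_int_bound(x + y)
assert bd.min_value == NEG_INF
assert bd.max_value == POS_INF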
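
Note on the fold_scale_axis change: ForwardPrep previously hit LOG(FATAL) on any Let node; it now visits the binding and assigns NullValue<Message>() to both op->value and op->body, so the scale signal simply stops at the binding (which is what the new test_fold_fwd_let_fail asserts) instead of aborting the pass. For contrast, here is a sketch suggesting that the same multiply-then-conv2d pattern in plain dataflow form does get rewritten by ForwardFoldScaleAxis. This is not part of the patch: run_opt_pass is a simplified local stand-in for the test file's helper, and the shapes are only illustrative.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform

def run_opt_pass(expr, opt_pass):
    # simplified stand-in for the helper defined in the test file
    mod = tvm.IRModule.from_expr(expr)
    return opt_pass(mod)["main"]

shape = (2, 11, 10, 4)  # NHWC, 4 input channels
x = relay.var("x", shape=shape)
weight = relay.var("weight", shape=(3, 3, 4, 4))  # HWIO
scale = relay.const(np.random.uniform(0.5, 1.0, size=(4,)).astype("float32"))

y = relay.multiply(x, scale)  # per-channel scale on the input channel axis
y = relay.nn.conv2d(
    y,
    weight,
    channels=4,
    kernel_size=(3, 3),
    data_layout="NHWC",
    kernel_layout="HWIO",
    padding=(1, 1),
)
func = relay.Function([x, weight], y)

func = run_opt_pass(func, transform.InferType())
folded = run_opt_pass(func, transform.ForwardFoldScaleAxis())
# In dataflow form the scale folds into the weight, so the IR changes.
assert not tvm.ir.structural_equal(func, folded)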