From 0418e2e678d9cc6b008a01e27f81dffcfd2f5252 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Tue, 23 Feb 2021 10:20:27 -0700 Subject: [PATCH] Support creating Bool constants in the pattern_utils --- src/relay/transforms/pattern_utils.h | 3 +++ tests/python/relay/test_pass_simplify_expr.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index bc0fcc9f2988..c1eebde15fba 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -86,6 +86,9 @@ namespace relay { } else if (type == DataType::UInt(8)) { \ typedef uint8_t DType; \ { __VA_ARGS__ } \ + } else if (type == DataType::Bool()) { \ + typedef bool DType; \ + { __VA_ARGS__ } \ } else if ((*tvm::runtime::Registry::Get("runtime._datatype_get_type_registered"))( \ static_cast(type.code()))) { \ typedef double DType; \ diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 3d925bcfc759..423f0a4f213d 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -117,7 +117,7 @@ def after_right(x, elem_op, value): assert tvm.ir.structural_equal(zz, after) for shape in [[10], [10, 10], [10, 10, 10]]: - for dtype in ["float32", "int32"]: + for dtype in ["float32", "int32", "bool"]: for value in [0, 1, 2]: validate(shape, value, dtype)