From b7f99bad6d0a82193ee583509422dfb801172b4f Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Thu, 23 Mar 2023 11:12:45 -0700 Subject: [PATCH 1/4] divtomul fix --- src/relay/transforms/div_to_mul.cc | 64 +++++++++++++++--------- tests/python/unittest/test_div_to_mul.py | 31 +++++++++++- 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc index 42983c520682..ea747e98a09e 100644 --- a/src/relay/transforms/div_to_mul.cc +++ b/src/relay/transforms/div_to_mul.cc @@ -26,42 +26,60 @@ namespace tvm { namespace relay { +template +inline bool const_has_values(size_t size, const ConstantNode* const_node, const std::vector&& values) { + for (size_t i = 0; i < size; i++) { + T data = static_cast(const_node->data->data)[i]; + for (const T& v: values) { + if (data == v) return true; + } + } + return false; +} + +inline size_t get_num_elements_const(const ConstantNode* const_node) { + const auto& shape = const_node -> data.Shape(); + + size_t cnt_elements = 1; + for (const auto& dim: shape) { + cnt_elements *= dim; + } + + return cnt_elements; +} + class DivToMulRewrite : public MixedModeMutator { Expr Rewrite_(const CallNode* pre, const Expr& post) final { if (const CallNode* call_node = post.as()) { if (call_node->op == Op::Get("divide")) { auto rhs = call_node->args[1].as(); if (rhs != nullptr) { - auto inv = - runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device); + auto one = + runtime::NDArray::Empty({}, rhs->data.DataType(), rhs->data->device); + size_t num_ele = get_num_elements_const(rhs); std::string dtype = DLDataType2String(rhs->data.DataType()); + + bool const_has_zero_flag = false; if (dtype == "float32") { - float rhs_val = static_cast(rhs->data->data)[0]; - // Check for division by zero - if (rhs_val == 0.) { - return post; - } - static_cast(inv->data)[0] = 1. / rhs_val; + static_cast(one->data)[0] = 1.; + const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); } else if (dtype == "float64") { - double rhs_val = static_cast(rhs->data->data)[0]; - // Check for division by zero - if (rhs_val == 0.) { - return post; - } - static_cast(inv->data)[0] = 1. / rhs_val; + static_cast(one->data)[0] = 1.; + const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); } else if (dtype == "float16") { - // Do f16 math in f32 - float rhs_val = __gnu_h2f_ieee(static_cast(rhs->data->data)[0]); - // Check for division by zero - if (rhs_val == 0.) { - return post; - } - static_cast(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val); + static_cast(one->data)[0] = __gnu_f2h_ieee(1.); + const_has_zero_flag = const_has_values(num_ele, rhs, {0}); } else { - // Cannot do 1/int because it will truncate + LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " << rhs->data.DataType(); + return post; + } + + if (const_has_zero_flag) { return post; } - return Multiply(call_node->args[0], Constant(inv)); + + // rely on constant folding to fold things + return Multiply(call_node->args[0], Divide(Constant(one), call_node->args[1])); } } } diff --git a/tests/python/unittest/test_div_to_mul.py b/tests/python/unittest/test_div_to_mul.py index 60c67ae2499c..32f977bfb89a 100644 --- a/tests/python/unittest/test_div_to_mul.py +++ b/tests/python/unittest/test_div_to_mul.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import numpy as np +import pytest + import tvm from tvm import relay -import pytest -import numpy as np @pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)]) @@ -27,5 +28,31 @@ def test_div_to_mul(dtype, rtol): z = x / y mod = tvm.IRModule.from_expr(z) transformed = relay.transform.DivToMul()(mod) + transformed = relay.transform.FoldConstant()(transformed) assert transformed["main"].body.op.name == "multiply" np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5, rtol=rtol) + + +@pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)]) +def test_div_to_mul_vector(dtype, rtol): + x = relay.var("x", relay.TensorType([5], dtype)) + y = relay.Constant(tvm.nd.array(np.array([2, 2, 2, 4, 5]).astype(dtype))) + z = x / y + mod = tvm.IRModule.from_expr(z) + transformed = relay.transform.DivToMul()(mod) + transformed = relay.transform.FoldConstant()(transformed) + assert transformed["main"].body.op.name == "multiply" + np.testing.assert_allclose( + transformed["main"].body.args[1].data.numpy(), [0.5, 0.5, 0.5, 0.25, 0.2], rtol=rtol + ) + + +@pytest.mark.parametrize("dtype", [("float16"), ("float32"), ("float64")]) +def test_do_not_simplify_zero_div(dtype): + x = relay.var("x", relay.TensorType([5], dtype)) + y = relay.Constant(tvm.nd.array(np.array([2, 2, 2, 4, 0]).astype(dtype))) + z = x / y + mod = tvm.IRModule.from_expr(z) + transformed = relay.transform.DivToMul()(mod) + transformed = relay.transform.FoldConstant()(transformed) + assert transformed["main"].body.op.name == "divide" From f396f27c9828575aa594f208c37343ab8a0034ce Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Thu, 23 Mar 2023 11:17:23 -0700 Subject: [PATCH 2/4] handle +/- zero --- src/relay/transforms/div_to_mul.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc index ea747e98a09e..0236f863a8ab 100644 --- a/src/relay/transforms/div_to_mul.cc +++ b/src/relay/transforms/div_to_mul.cc @@ -68,7 +68,8 @@ class DivToMulRewrite : public MixedModeMutator { const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); } else if (dtype == "float16") { static_cast(one->data)[0] = __gnu_f2h_ieee(1.); - const_has_zero_flag = const_has_values(num_ele, rhs, {0}); + // have to handle both + and - zero semantics manually here + const_has_zero_flag = const_has_values(num_ele, rhs, {0x00, 0x80}); } else { LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " << rhs->data.DataType(); return post; @@ -77,7 +78,7 @@ class DivToMulRewrite : public MixedModeMutator { if (const_has_zero_flag) { return post; } - + // rely on constant folding to fold things return Multiply(call_node->args[0], Divide(Constant(one), call_node->args[1])); } From df51972df17a00642c8700c8e9314e919e799b7f Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Thu, 23 Mar 2023 11:25:29 -0700 Subject: [PATCH 3/4] lint -- also need to fix precommits --- src/relay/transforms/div_to_mul.cc | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc index 0236f863a8ab..82434edc1b43 100644 --- a/src/relay/transforms/div_to_mul.cc +++ b/src/relay/transforms/div_to_mul.cc @@ -27,10 +27,11 @@ namespace tvm { namespace relay { template -inline bool const_has_values(size_t size, const ConstantNode* const_node, const std::vector&& values) { +inline bool const_has_values(size_t size, const ConstantNode* const_node, + const std::vector&& values) { for (size_t i = 0; i < size; i++) { T data = static_cast(const_node->data->data)[i]; - for (const T& v: values) { + for (const T& v : values) { if (data == v) return true; } } @@ -38,10 +39,10 @@ inline bool const_has_values(size_t size, const ConstantNode* const_node, const } inline size_t get_num_elements_const(const ConstantNode* const_node) { - const auto& shape = const_node -> data.Shape(); + const auto& shape = const_node->data.Shape(); size_t cnt_elements = 1; - for (const auto& dim: shape) { + for (const auto& dim : shape) { cnt_elements *= dim; } @@ -54,32 +55,31 @@ class DivToMulRewrite : public MixedModeMutator { if (call_node->op == Op::Get("divide")) { auto rhs = call_node->args[1].as(); if (rhs != nullptr) { - auto one = - runtime::NDArray::Empty({}, rhs->data.DataType(), rhs->data->device); + auto one = runtime::NDArray::Empty({}, rhs->data.DataType(), rhs->data->device); size_t num_ele = get_num_elements_const(rhs); std::string dtype = DLDataType2String(rhs->data.DataType()); bool const_has_zero_flag = false; if (dtype == "float32") { - static_cast(one->data)[0] = 1.; - const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); + static_cast(one->data)[0] = 1.; + const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); } else if (dtype == "float64") { - static_cast(one->data)[0] = 1.; - const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); + static_cast(one->data)[0] = 1.; + const_has_zero_flag = const_has_values(num_ele, rhs, {0.}); } else if (dtype == "float16") { - static_cast(one->data)[0] = __gnu_f2h_ieee(1.); - // have to handle both + and - zero semantics manually here - const_has_zero_flag = const_has_values(num_ele, rhs, {0x00, 0x80}); + static_cast(one->data)[0] = __gnu_f2h_ieee(1.); + // have to handle both + and - zero semantics manually here + const_has_zero_flag = const_has_values(num_ele, rhs, {0x00, 0x80}); } else { - LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " << rhs->data.DataType(); - return post; + LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " << rhs->data.DataType(); + return post; } - + if (const_has_zero_flag) { return post; } - // rely on constant folding to fold things + // rely on constant folding to fold things return Multiply(call_node->args[0], Divide(Constant(one), call_node->args[1])); } } From 3cb691d4345668dc5be95d9cd6b52dd4666691df Mon Sep 17 00:00:00 2001 From: Andrew Zhao Luo Date: Thu, 23 Mar 2023 12:54:15 -0700 Subject: [PATCH 4/4] always forget 2 hex for one byte --- src/relay/transforms/div_to_mul.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc index 82434edc1b43..3e21e9878b6b 100644 --- a/src/relay/transforms/div_to_mul.cc +++ b/src/relay/transforms/div_to_mul.cc @@ -69,7 +69,7 @@ class DivToMulRewrite : public MixedModeMutator { } else if (dtype == "float16") { static_cast(one->data)[0] = __gnu_f2h_ieee(1.); // have to handle both + and - zero semantics manually here - const_has_zero_flag = const_has_values(num_ele, rhs, {0x00, 0x80}); + const_has_zero_flag = const_has_values(num_ele, rhs, {0x0000, 0x8000}); } else { LOG(WARNING) << "Unknown dtype not handled for div_to_mull: " << rhs->data.DataType(); return post;