diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc
index 42983c520682..3e21e9878b6b 100644
--- a/src/relay/transforms/div_to_mul.cc
+++ b/src/relay/transforms/div_to_mul.cc
@@ -26,42 +26,61 @@
 namespace tvm {
 namespace relay {
 
+template <typename T>
+inline bool const_has_values(size_t size, const ConstantNode* const_node,
+                             const std::vector<T>&& values) {
+  for (size_t i = 0; i < size; i++) {
+    T data = static_cast<T*>(const_node->data->data)[i];
+    for (const T& v : values) {
+      if (data == v) return true;
+    }
+  }
+  return false;
+}
+
+inline size_t get_num_elements_const(const ConstantNode* const_node) {
+  const auto& shape = const_node->data.Shape();
+
+  size_t cnt_elements = 1;
+  for (const auto& dim : shape) {
+    cnt_elements *= dim;
+  }
+
+  return cnt_elements;
+}
+
 class DivToMulRewrite : public MixedModeMutator {
   Expr Rewrite_(const CallNode* pre, const Expr& post) final {
     if (const CallNode* call_node = post.as<CallNode>()) {
       if (call_node->op == Op::Get("divide")) {
         auto rhs = call_node->args[1].as<ConstantNode>();
         if (rhs != nullptr) {
-          auto inv =
-              runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
+          auto one = runtime::NDArray::Empty({}, rhs->data.DataType(), rhs->data->device);
+          size_t num_ele = get_num_elements_const(rhs);
           std::string dtype = DLDataType2String(rhs->data.DataType());
+
+          bool const_has_zero_flag = false;
           if (dtype == "float32") {
-            float rhs_val = static_cast<float*>(rhs->data->data)[0];
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<float*>(inv->data)[0] = 1. / rhs_val;
+            static_cast<float*>(one->data)[0] = 1.;
+            const_has_zero_flag = const_has_values<float>(num_ele, rhs, {0.});
           } else if (dtype == "float64") {
-            double rhs_val = static_cast<double*>(rhs->data->data)[0];
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<double*>(inv->data)[0] = 1. / rhs_val;
+            static_cast<double*>(one->data)[0] = 1.;
+            const_has_zero_flag = const_has_values<double>(num_ele, rhs, {0.});
           } else if (dtype == "float16") {
-            // Do f16 math in f32
-            float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
+            static_cast<uint16_t*>(one->data)[0] = __gnu_f2h_ieee(1.);
+            // have to handle both + and - zero semantics manually here
+            const_has_zero_flag = const_has_values<uint16_t>(num_ele, rhs, {0x0000, 0x8000});
           } else {
-            // Cannot do 1/int because it will truncate
+            LOG(WARNING) << "Unknown dtype not handled for div_to_mul: " << rhs->data.DataType();
             return post;
           }
-          return Multiply(call_node->args[0], Constant(inv));
+
+          if (const_has_zero_flag) {
+            return post;
+          }
+
+          // rely on constant folding to fold things
+          return Multiply(call_node->args[0], Divide(Constant(one), call_node->args[1]));
         }
       }
     }
diff --git a/tests/python/unittest/test_div_to_mul.py b/tests/python/unittest/test_div_to_mul.py
index 60c67ae2499c..32f977bfb89a 100644
--- a/tests/python/unittest/test_div_to_mul.py
+++ b/tests/python/unittest/test_div_to_mul.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+import pytest
+
 import tvm
 from tvm import relay
-import pytest
-import numpy as np
 
 
 @pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)])
@@ -27,5 +28,31 @@ def test_div_to_mul(dtype, rtol):
     z = x / y
     mod = tvm.IRModule.from_expr(z)
     transformed = relay.transform.DivToMul()(mod)
+    transformed = relay.transform.FoldConstant()(transformed)
     assert transformed["main"].body.op.name == "multiply"
     np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5, rtol=rtol)
+
+
+@pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)])
+def test_div_to_mul_vector(dtype, rtol):
+    x = relay.var("x", relay.TensorType([5], dtype))
+    y = relay.Constant(tvm.nd.array(np.array([2, 2, 2, 4, 5]).astype(dtype)))
+    z = x / y
+    mod = tvm.IRModule.from_expr(z)
+    transformed = relay.transform.DivToMul()(mod)
+    transformed = relay.transform.FoldConstant()(transformed)
+    assert transformed["main"].body.op.name == "multiply"
+    np.testing.assert_allclose(
+        transformed["main"].body.args[1].data.numpy(), [0.5, 0.5, 0.5, 0.25, 0.2], rtol=rtol
+    )
+
+
+@pytest.mark.parametrize("dtype", [("float16"), ("float32"), ("float64")])
+def test_do_not_simplify_zero_div(dtype):
+    x = relay.var("x", relay.TensorType([5], dtype))
+    y = relay.Constant(tvm.nd.array(np.array([2, 2, 2, 4, 0]).astype(dtype)))
+    z = x / y
+    mod = tvm.IRModule.from_expr(z)
+    transformed = relay.transform.DivToMul()(mod)
+    transformed = relay.transform.FoldConstant()(transformed)
+    assert transformed["main"].body.op.name == "divide"