From fab549181b99a7ef53bd1eb3e768f69ab44ed4c9 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 17 Oct 2022 15:03:36 -0700 Subject: [PATCH 1/4] [Relay] Rewrite division by constant to multiply Convert division by a scalar constant into multiplication by the inverse of the constant. Multiplication is faster than division and also allows for more optimization opportunities. Only applies to float32 and float64. --- python/tvm/relay/transform/transform.py | 7 +- src/relay/transforms/div_to_mul.cc | 67 +++++++++++++++++++ .../fake_quantization_to_integer.cc | 2 +- tests/python/unittest/test_div_to_mul.py | 31 +++++++++ 4 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 src/relay/transforms/div_to_mul.cc create mode 100644 tests/python/unittest/test_div_to_mul.py diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 196f7ef81293..c1f184671780 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -27,8 +27,8 @@ from tvm import relay, te from tvm.runtime import ndarray as _nd -from . import _ffi_api from ..backend.utils import mangle_module_name +from . import _ffi_api def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None): @@ -1484,3 +1484,8 @@ def CollagePartition(config, cost_estimator=None): cost_estimator = relay.collage.CostEstimator() return _ffi_api.CollagePartition(config, cost_estimator) + + +def DivToMul(): + """Transform division by a constant to multiplication by the inverse of the constant""" + return _ffi_api.DivToMul() diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc new file mode 100644 index 000000000000..b5fdbb8a4281 --- /dev/null +++ b/src/relay/transforms/div_to_mul.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +class DivToMulRewrite : public MixedModeMutator { + Expr Rewrite_(const CallNode* pre, const Expr& post) final { + if (const CallNode* call_node = post.as()) { + if (call_node->op == Op::Get("divide")) { + auto rhs = call_node->args[1].as(); + if (rhs != nullptr) { + auto inv = + runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device); + std::string dtype = DLDataType2String(rhs->data.DataType()); + if (dtype == "float32") { + static_cast(inv->data)[0] = 1. / static_cast(rhs->data->data)[0]; + } else if (dtype == "float64") { + static_cast(inv->data)[0] = 1. 
/ static_cast(rhs->data->data)[0]; + } else { + // Cannot do 1/int because it will truncate + return post; + } + return Multiply(call_node->args[0], Constant(inv)); + } + } + } + return post; + } +}; + +namespace transform { + +Pass DivToMul() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(DivToMulRewrite().Mutate(f)); + }; + return CreateFunctionPass(pass_func, 0, "DivToMul", {"InferType", "FoldConstant"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.DivToMul").set_body_typed(DivToMul); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index 0b9daed896c3..eb176df5c978 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -542,7 +542,7 @@ Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) { [=](Function f, IRModule m, PassContext pc) { return Downcast(FakeQuantizationToInteger(f, m, hard_fail, use_qat)); }; - return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"}); + return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType", "DivToMul"}); } TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger") diff --git a/tests/python/unittest/test_div_to_mul.py b/tests/python/unittest/test_div_to_mul.py new file mode 100644 index 000000000000..349e6fffea77 --- /dev/null +++ b/tests/python/unittest/test_div_to_mul.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +import pytest +import numpy as np + + +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_div_to_mul(dtype): + x = relay.var("x", relay.TensorType((), dtype)) + y = relay.Constant(tvm.nd.array(np.array([1.5]).astype(dtype))) + z = x / y + mod = tvm.IRModule.from_expr(z) + transformed = relay.transform.DivToMul()(mod) + assert transformed["main"].body.op.name == "multiply" + np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1/1.5) From bff454936884a52d9c5703ed7c76a5fd52d66c57 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 24 Oct 2022 13:12:23 -0700 Subject: [PATCH 2/4] formatting --- tests/python/unittest/test_div_to_mul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_div_to_mul.py b/tests/python/unittest/test_div_to_mul.py index 349e6fffea77..9fbc6d8acf26 100644 --- a/tests/python/unittest/test_div_to_mul.py +++ b/tests/python/unittest/test_div_to_mul.py @@ -28,4 +28,4 @@ def test_div_to_mul(dtype): mod = tvm.IRModule.from_expr(z) transformed = relay.transform.DivToMul()(mod) assert transformed["main"].body.op.name == "multiply" - np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1/1.5) + np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5) From 90827d297cec581204226f3498dc096dfb5d5286 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 24 Oct 2022 13:38:20 -0700 Subject: [PATCH 3/4] handle division by zero --- src/relay/transforms/div_to_mul.cc | 14 
++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc index b5fdbb8a4281..641764b424c6 100644 --- a/src/relay/transforms/div_to_mul.cc +++ b/src/relay/transforms/div_to_mul.cc @@ -35,9 +35,19 @@ class DivToMulRewrite : public MixedModeMutator { runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device); std::string dtype = DLDataType2String(rhs->data.DataType()); if (dtype == "float32") { - static_cast(inv->data)[0] = 1. / static_cast(rhs->data->data)[0]; + float rhs_val = static_cast(rhs->data->data)[0]; + // Check for division by zero + if (rhs_val == 0.) { + return post; + } + static_cast(inv->data)[0] = 1. / rhs_val; } else if (dtype == "float64") { - static_cast(inv->data)[0] = 1. / static_cast(rhs->data->data)[0]; + double rhs_val = static_cast(rhs->data->data)[0]; + // Check for division by zero + if (rhs_val == 0.) { + return post; + } + static_cast(inv->data)[0] = 1. / rhs_val; } else { // Cannot do 1/int because it will truncate return post; From adeb0843ef5cd97dfd2b5c61a18333874d74466b Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 24 Oct 2022 13:53:18 -0700 Subject: [PATCH 4/4] handle float16 --- src/relay/transforms/div_to_mul.cc | 9 +++++++++ tests/python/unittest/test_div_to_mul.py | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc index 641764b424c6..42983c520682 100644 --- a/src/relay/transforms/div_to_mul.cc +++ b/src/relay/transforms/div_to_mul.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "pattern_utils.h" @@ -48,6 +49,14 @@ class DivToMulRewrite : public MixedModeMutator { return post; } static_cast(inv->data)[0] = 1. 
/ rhs_val; + } else if (dtype == "float16") { + // Do f16 math in f32 + float rhs_val = __gnu_h2f_ieee(static_cast(rhs->data->data)[0]); + // Check for division by zero + if (rhs_val == 0.) { + return post; + } + static_cast(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val); } else { // Cannot do 1/int because it will truncate return post; diff --git a/tests/python/unittest/test_div_to_mul.py b/tests/python/unittest/test_div_to_mul.py index 9fbc6d8acf26..60c67ae2499c 100644 --- a/tests/python/unittest/test_div_to_mul.py +++ b/tests/python/unittest/test_div_to_mul.py @@ -20,12 +20,12 @@ import numpy as np -@pytest.mark.parametrize("dtype", ["float32", "float64"]) -def test_div_to_mul(dtype): +@pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)]) +def test_div_to_mul(dtype, rtol): x = relay.var("x", relay.TensorType((), dtype)) y = relay.Constant(tvm.nd.array(np.array([1.5]).astype(dtype))) z = x / y mod = tvm.IRModule.from_expr(z) transformed = relay.transform.DivToMul()(mod) assert transformed["main"].body.op.name == "multiply" - np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5) + np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5, rtol=rtol)