diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 196f7ef81293..c1f184671780 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -27,8 +27,8 @@ from tvm import relay, te
 from tvm.runtime import ndarray as _nd
 
-from . import _ffi_api
 from ..backend.utils import mangle_module_name
+from . import _ffi_api
 
 
 def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None):
@@ -1484,3 +1484,8 @@ def CollagePartition(config, cost_estimator=None):
         cost_estimator = relay.collage.CostEstimator()
 
     return _ffi_api.CollagePartition(config, cost_estimator)
+
+
+def DivToMul():
+    """Transform division by a constant to multiplication by the inverse of the constant"""
+    return _ffi_api.DivToMul()
diff --git a/src/relay/transforms/div_to_mul.cc b/src/relay/transforms/div_to_mul.cc
new file mode 100644
index 000000000000..42983c520682
--- /dev/null
+++ b/src/relay/transforms/div_to_mul.cc
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+class DivToMulRewrite : public MixedModeMutator {
+  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      if (call_node->op == Op::Get("divide")) {
+        auto rhs = call_node->args[1].as<ConstantNode>();
+        if (rhs != nullptr) {
+          auto inv =
+              runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
+          std::string dtype = DLDataType2String(rhs->data.DataType());
+          if (dtype == "float32") {
+            float rhs_val = static_cast<float*>(rhs->data->data)[0];
+            // Check for division by zero
+            if (rhs_val == 0.) {
+              return post;
+            }
+            static_cast<float*>(inv->data)[0] = 1. / rhs_val;
+          } else if (dtype == "float64") {
+            double rhs_val = static_cast<double*>(rhs->data->data)[0];
+            // Check for division by zero
+            if (rhs_val == 0.) {
+              return post;
+            }
+            static_cast<double*>(inv->data)[0] = 1. / rhs_val;
+          } else if (dtype == "float16") {
+            // Do f16 math in f32
+            float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
+            // Check for division by zero
+            if (rhs_val == 0.) {
+              return post;
+            }
+            static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
+          } else {
+            // Cannot do 1/int because it will truncate
+            return post;
+          }
+          return Multiply(call_node->args[0], Constant(inv));
+        }
+      }
+    }
+    return post;
+  }
+};
+
+namespace transform {
+
+Pass DivToMul() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(DivToMulRewrite().Mutate(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "DivToMul", {"InferType", "FoldConstant"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.DivToMul").set_body_typed(DivToMul);
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
index 0b9daed896c3..eb176df5c978 100644
--- a/src/relay/transforms/fake_quantization_to_integer.cc
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -542,7 +542,7 @@ Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) {
       [=](Function f, IRModule m, PassContext pc) {
         return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail, use_qat));
       };
-  return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType"});
+  return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType", "DivToMul"});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger")
diff --git a/tests/python/unittest/test_div_to_mul.py b/tests/python/unittest/test_div_to_mul.py
new file mode 100644
index 000000000000..60c67ae2499c
--- /dev/null
+++ b/tests/python/unittest/test_div_to_mul.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+import pytest
+import numpy as np
+
+
+@pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)])
+def test_div_to_mul(dtype, rtol):
+    x = relay.var("x", relay.TensorType((), dtype))
+    y = relay.Constant(tvm.nd.array(np.array([1.5]).astype(dtype)))
+    z = x / y
+    mod = tvm.IRModule.from_expr(z)
+    transformed = relay.transform.DivToMul()(mod)
+    assert transformed["main"].body.op.name == "multiply"
+    np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5, rtol=rtol)
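
Reviewer note (not part of the patch): a minimal sketch of exercising the new pass from Python, mirroring the added unit test. The input shape and the 2.0 divisor below are arbitrary illustration choices, not values used anywhere in the diff.

    import numpy as np
    import tvm
    from tvm import relay

    # Build a tiny module whose body is a division by a float constant.
    x = relay.var("x", shape=(4,), dtype="float32")
    y = relay.Constant(tvm.nd.array(np.array([2.0], dtype="float32")))
    mod = tvm.IRModule.from_expr(x / y)

    # DivToMul rewrites x / 2.0 into x * 0.5 when the divisor is a non-zero float constant.
    mod = relay.transform.DivToMul()(mod)
    print(mod["main"].body.op.name)  # expected: "multiply"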