65 changes: 42 additions & 23 deletions src/relay/transforms/div_to_mul.cc
@@ -26,42 +26,61 @@
 namespace tvm {
 namespace relay {
 
+template <typename T>
+inline bool const_has_values(size_t size, const ConstantNode* const_node,
+                             const std::vector<T>&& values) {
+  for (size_t i = 0; i < size; i++) {
+    T data = static_cast<T*>(const_node->data->data)[i];
+    for (const T& v : values) {
+      if (data == v) return true;
+    }
+  }
+  return false;
+}
+
+inline size_t get_num_elements_const(const ConstantNode* const_node) {
+  const auto& shape = const_node->data.Shape();
+
+  size_t cnt_elements = 1;
+  for (const auto& dim : shape) {
+    cnt_elements *= dim;
+  }
+
+  return cnt_elements;
+}
+
 class DivToMulRewrite : public MixedModeMutator {
   Expr Rewrite_(const CallNode* pre, const Expr& post) final {
     if (const CallNode* call_node = post.as<CallNode>()) {
       if (call_node->op == Op::Get("divide")) {
         auto rhs = call_node->args[1].as<ConstantNode>();
         if (rhs != nullptr) {
-          auto inv =
-              runtime::NDArray::Empty(rhs->data.Shape(), rhs->data.DataType(), rhs->data->device);
+          auto one = runtime::NDArray::Empty({}, rhs->data.DataType(), rhs->data->device);
+          size_t num_ele = get_num_elements_const(rhs);
           std::string dtype = DLDataType2String(rhs->data.DataType());
+
+          bool const_has_zero_flag = false;
           if (dtype == "float32") {
-            float rhs_val = static_cast<float*>(rhs->data->data)[0];
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<float*>(inv->data)[0] = 1. / rhs_val;
+            static_cast<float*>(one->data)[0] = 1.;
+            const_has_zero_flag = const_has_values<float>(num_ele, rhs, {0.});
           } else if (dtype == "float64") {
-            double rhs_val = static_cast<double*>(rhs->data->data)[0];
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<double*>(inv->data)[0] = 1. / rhs_val;
+            static_cast<double*>(one->data)[0] = 1.;
+            const_has_zero_flag = const_has_values<double>(num_ele, rhs, {0.});
           } else if (dtype == "float16") {
-            // Do f16 math in f32
-            float rhs_val = __gnu_h2f_ieee(static_cast<uint16_t*>(rhs->data->data)[0]);
-            // Check for division by zero
-            if (rhs_val == 0.) {
-              return post;
-            }
-            static_cast<uint16_t*>(inv->data)[0] = __gnu_f2h_ieee(1. / rhs_val);
+            static_cast<uint16_t*>(one->data)[0] = __gnu_f2h_ieee(1.);
+            // have to handle both + and - zero semantics manually here
+            const_has_zero_flag = const_has_values<uint16_t>(num_ele, rhs, {0x0000, 0x8000});
           } else {
             // Cannot do 1/int because it will truncate
             LOG(WARNING) << "Unknown dtype not handled for div_to_mul: " << rhs->data.DataType();
             return post;
           }
-          return Multiply(call_node->args[0], Constant(inv));
+
+          if (const_has_zero_flag) {
+            return post;
+          }
+
+          // rely on constant folding to fold things
[Review thread on the line above]
Contributor: Are you sure constant folding is going to fire at the right time? I can't remember if the divisions need to be rewritten before FakeQuantization or if it doesn't matter when.
Contributor (author): Hmm good point, I think there is a way to specify a required pass to be run after this one, so maybe will try to figure that out.
+          return Multiply(call_node->args[0], Divide(Constant(one), call_node->args[1]));
         }
       }
     }
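Aside: the float16 branch above scans the constant's raw uint16 words rather than converted float values, which is why it must list both encodings of zero. IEEE 754 half precision stores +0.0 as 0x0000 and -0.0 as 0x8000, and the two compare equal as floats but differ bitwise. A minimal numpy illustration of this (not part of the PR):

import numpy as np

# +0.0 and -0.0 compare equal as float16 values...
assert np.float16(0.0) == np.float16(-0.0)

# ...but have distinct half-precision bit patterns, so a raw uint16
# comparison (as in const_has_values<uint16_t>) must check both.
print(hex(int(np.array(0.0, dtype=np.float16).view(np.uint16))))   # 0x0
print(hex(int(np.array(-0.0, dtype=np.float16).view(np.uint16))))  # 0x8000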
Expand Down
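To make the review question concrete: after DivToMul alone, the divide is still present as Divide(1, c) feeding the multiply, and it only disappears once FoldConstant runs afterwards, so the rewrite does depend on folding firing later in the pipeline. A small sketch of that dependency, built on the same pass API the tests below use (the caller-side chaining here is illustrative; it is not the in-pass "required" mechanism the author mentions):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", relay.TensorType([4], "float32"))
y = relay.Constant(tvm.nd.array(np.full(4, 2.0, dtype="float32")))
mod = tvm.IRModule.from_expr(x / y)

# DivToMul alone leaves multiply(x, divide(1, c)) behind...
after_div_to_mul = relay.transform.DivToMul()(mod)
assert after_div_to_mul["main"].body.args[1].op.name == "divide"

# ...and only after FoldConstant does it become multiply(x, <reciprocal const>).
folded = relay.transform.FoldConstant()(after_div_to_mul)
assert folded["main"].body.op.name == "multiply"
np.testing.assert_allclose(folded["main"].body.args[1].data.numpy(), np.full(4, 0.5), rtol=1e-7)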
31 changes: 29 additions & 2 deletions tests/python/unittest/test_div_to_mul.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+import pytest
+
 import tvm
 from tvm import relay
-import pytest
-import numpy as np
 
 
 @pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)])
@@ -27,5 +28,31 @@ def test_div_to_mul(dtype, rtol):
     z = x / y
     mod = tvm.IRModule.from_expr(z)
     transformed = relay.transform.DivToMul()(mod)
+    transformed = relay.transform.FoldConstant()(transformed)
     assert transformed["main"].body.op.name == "multiply"
     np.testing.assert_allclose(transformed["main"].body.args[1].data.numpy()[0], 1 / 1.5, rtol=rtol)
+
+
+@pytest.mark.parametrize("dtype, rtol", [("float16", 1e-3), ("float32", 1e-7), ("float64", 1e-12)])
+def test_div_to_mul_vector(dtype, rtol):
+    x = relay.var("x", relay.TensorType([5], dtype))
+    y = relay.Constant(tvm.nd.array(np.array([2, 2, 2, 4, 5]).astype(dtype)))
+    z = x / y
+    mod = tvm.IRModule.from_expr(z)
+    transformed = relay.transform.DivToMul()(mod)
+    transformed = relay.transform.FoldConstant()(transformed)
+    assert transformed["main"].body.op.name == "multiply"
+    np.testing.assert_allclose(
+        transformed["main"].body.args[1].data.numpy(), [0.5, 0.5, 0.5, 0.25, 0.2], rtol=rtol
+    )
+
+
+@pytest.mark.parametrize("dtype", [("float16"), ("float32"), ("float64")])
+def test_do_not_simplify_zero_div(dtype):
+    x = relay.var("x", relay.TensorType([5], dtype))
+    y = relay.Constant(tvm.nd.array(np.array([2, 2, 2, 4, 0]).astype(dtype)))
+    z = x / y
+    mod = tvm.IRModule.from_expr(z)
+    transformed = relay.transform.DivToMul()(mod)
+    transformed = relay.transform.FoldConstant()(transformed)
+    assert transformed["main"].body.op.name == "divide"