From e627c9eaa4d6ecde255a3d931c41f53490c1590e Mon Sep 17 00:00:00 2001 From: shoubhik Date: Mon, 14 Oct 2019 17:43:41 -0700 Subject: [PATCH 1/3] Fix the inferred dtype of the kernel in dense. --- src/relay/op/nn/nn.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 2c65d2526437..e2be0318c2fc 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -49,7 +49,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, // validate the weight shape is proper if defined // Assign weight type Array wshape({param->units, dshape[dshape.size() - 1]}); - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight->dtype)); oshape.Set((oshape.size() - 1), param->units); } else { if (weight == nullptr) return false; From a53d1be55c9ad29d0f212ce1eb562c9f7cf94132 Mon Sep 17 00:00:00 2001 From: shoubhik Date: Tue, 15 Oct 2019 10:32:35 -0700 Subject: [PATCH 2/3] - Move the nullptr check for weight up, since it is now needed in both branches. - Add a test case validating that the data dtype and kernel dtype can be different.
--- src/relay/op/nn/nn.h | 3 +-- tests/python/relay/test_op_level1.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index e2be0318c2fc..45d663052de4 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -36,7 +36,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); - if (data == nullptr) return false; + if (data == nullptr || weight == nullptr) return false; const AttrType* param = attrs.as(); CHECK(param != nullptr); @@ -52,7 +52,6 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TensorTypeNode::make(wshape, weight->dtype)); oshape.Set((oshape.size() - 1), param->units); } else { - if (weight == nullptr) return false; Array wshape = weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); } diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 4a07662554b9..187bd26bda1f 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -382,6 +382,21 @@ def test_dense(): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +def test_dense_dtype(): + data_dtype = 'uint8' + weight_dtype = 'int8' + out_dtype = 'uint8' + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), data_dtype)) + w = relay.var("w", relay.TensorType((2, w), weight_dtype)) + y = relay.nn.dense(x, w, units=2, out_dtype=out_dtype) + assert "units=2" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), out_dtype) + assert run_infer_type(yy.args[0]).checked_type.dtype == 'uint8' + assert run_infer_type(yy.args[1]).checked_type.dtype == 'int8' + + def test_bitserial_dense(): m, k = tvm.var("m"), tvm.var("k") x = relay.var("x", relay.TensorType((m, 
k), "int16")) @@ -405,3 +420,4 @@ def test_bitserial_dense(): test_batch_norm() test_dense() test_bitserial_dense() + test_dense_dtype() From 5d7eb5ebd06ddb6e18a24ce4d59343ae69cbfcbb Mon Sep 17 00:00:00 2001 From: shoubhik Date: Tue, 15 Oct 2019 11:04:15 -0700 Subject: [PATCH 3/3] - Fix the dtype check for weight. If the weight is not present then we will use the data dtype. --- src/relay/op/nn/nn.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 45d663052de4..684bae708662 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -36,7 +36,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); - if (data == nullptr || weight == nullptr) return false; + if (data == nullptr) return false; const AttrType* param = attrs.as(); CHECK(param != nullptr); @@ -49,9 +49,14 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, // validate the weight shape is proper if defined // Assign weight type Array wshape({param->units, dshape[dshape.size() - 1]}); - reporter->Assign(types[1], TensorTypeNode::make(wshape, weight->dtype)); + // It is possible for weight to be nullptr in which case we will use + // data dtype as the weight dtype. However if weight dtype is explicitly + // present we will use that. + auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype); + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); oshape.Set((oshape.size() - 1), param->units); } else { + if (weight == nullptr) return false; Array wshape = weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); }