diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index 2c65d2526437..684bae708662 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -49,7 +49,11 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, // validate the weight shape is proper if defined // Assign weight type Array wshape({param->units, dshape[dshape.size() - 1]}); - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); + // It is possible for weight to be nullptr in which case we will use + // data dtype as the weight dtype. However if weight dtype is explicitly + // present we will use that. + auto weight_dtype = (weight == nullptr ? data->dtype : weight->dtype); + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); oshape.Set((oshape.size() - 1), param->units); } else { if (weight == nullptr) return false; diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 4a07662554b9..187bd26bda1f 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -382,6 +382,21 @@ def test_dense(): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +def test_dense_dtype(): + data_dtype = 'uint8' + weight_dtype = 'int8' + out_dtype = 'uint8' + n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), data_dtype)) + w = relay.var("w", relay.TensorType((2, w), weight_dtype)) + y = relay.nn.dense(x, w, units=2, out_dtype=out_dtype) + assert "units=2" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((n, c, h, 2), out_dtype) + assert run_infer_type(yy.args[0]).checked_type.dtype == 'uint8' + assert run_infer_type(yy.args[1]).checked_type.dtype == 'int8' + + def test_bitserial_dense(): m, k = tvm.var("m"), tvm.var("k") x = relay.var("x", relay.TensorType((m, k), "int16")) @@ -405,3 +420,4 @@ def test_bitserial_dense(): test_batch_norm() test_dense()
test_bitserial_dense() + test_dense_dtype()