From 3528bb599e6c287296454130a605d624deda5473 Mon Sep 17 00:00:00 2001 From: Li Xiaoquan Date: Wed, 26 Dec 2018 09:12:23 +0800 Subject: [PATCH] [NNVM] Fix dtype of output of mean. dtype of count is the same as dtype of inputs[0] when created, but its type may change when multiplied by inputs[0]->shape[i]. Which causes dtype of output is not same as dtype of input. --- nnvm/src/top/tensor/reduce.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index 007a3cc6e3fb..105765fccc61 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -352,7 +352,7 @@ Example:: Expr count = make_const(inputs[0]->dtype, 1); for (auto& i : r_axes) { - count *= inputs[0]->shape[i]; + count *= cast(inputs[0]->dtype, inputs[0]->shape[i]); } return Array{