diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index 007a3cc6e3fb..105765fccc61 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -352,7 +352,7 @@ Example:: Expr count = make_const(inputs[0]->dtype, 1); for (auto& i : r_axes) { - count *= inputs[0]->shape[i]; + count *= cast(inputs[0]->dtype, inputs[0]->shape[i]); } return Array{