From 89d58b919d28d27a805971a391832b186a5e7a7f Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Mon, 29 Jan 2018 11:18:17 -0800 Subject: [PATCH] proper flatten in acc --- python/mxnet/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index f1cdae26a235..fc2b9014e8cc 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -399,7 +399,7 @@ def update(self, labels, preds): if pred_label.context != label.context: pred_label = pred_label.as_in_context(label.context) - self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() + self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar() self.num_inst += numpy.prod(pred_label.shape)