From 39a316a2c83fefe9b740990339328818dbd74122 Mon Sep 17 00:00:00 2001 From: solin319 Date: Fri, 22 Sep 2017 21:28:41 +0800 Subject: [PATCH] Update metric.py --- python/mxnet/metric.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 1b192f233f60..55d9859c6643 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -390,13 +390,13 @@ def update(self, labels, preds): for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) - label = label.astype('int32') - pred_label = pred_label.astype('int32').as_in_context(label.context) + pred_label = pred_label.asnumpy().astype('int32') + label = label.asnumpy().astype('int32') check_label_shapes(label, pred_label) - self.sum_metric += ndarray.sum(label == pred_label).asscalar() - self.num_inst += label.size + self.sum_metric += (pred_label.flat == label.flat).sum() + self.num_inst += len(pred_label.flat) @register