diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index e91fd3b13ee6..8bb3f6ee0a81 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -28,7 +28,6 @@ from .base import numeric_types, string_types from . import ndarray from . import registry -from .context import cpu def check_label_shapes(labels, preds, shape=0): @@ -389,22 +388,16 @@ def update(self, labels, preds): """ check_label_shapes(labels, preds) - results = [] for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) - pred_label = pred_label.astype('int32') - label = label.astype('int32') + pred_label = pred_label.asnumpy().astype('int32') + label = label.asnumpy().astype('int32') check_label_shapes(label, pred_label) - if pred_label.context != label.context: - pred_label = pred_label.as_in_context(label.context) - - self.num_inst += pred_label.size - results.append((pred_label.reshape((-1,)) == label.reshape((-1,))) - .sum().as_in_context(cpu())) - self.sum_metric += ndarray.add_n(*results).asscalar() + self.sum_metric += (pred_label.flat == label.flat).sum() + self.num_inst += len(pred_label.flat) @register