From 3c0f9f7c4658de35b3be3325a7895710b8d0bb40 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Fri, 26 Jan 2018 16:34:09 -0800 Subject: [PATCH 1/2] use nd for accuracy calculation --- python/mxnet/metric.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 5b0780aeccee..bb2f71aff122 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -380,23 +380,24 @@ def update(self, labels, preds): Parameters ---------- labels : list of `NDArray` - The labels of the data. + The labels of the data with class indices as values, one per sample. preds : list of `NDArray` - Predicted values. + Prediction values for samples. Each prediction value can either be the class index, + or a vector of likelihoods for all classes. """ check_label_shapes(labels, preds) for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) - pred_label = pred_label.asnumpy().astype('int32') - label = label.asnumpy().astype('int32') + pred_label = pred_label.astype('int32') + label = label.astype('int32') check_label_shapes(label, pred_label) - self.sum_metric += (pred_label.flat == label.flat).sum() - self.num_inst += len(pred_label.flat) + self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() + self.num_inst += numpy.prod(pred_label.shape) @register From 99bc0345e44538686f133f4659eafc5c1f63cf14 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Fri, 26 Jan 2018 18:03:20 -0800 Subject: [PATCH 2/2] check for context --- python/mxnet/metric.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index bb2f71aff122..f1cdae26a235 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -396,6 +396,9 @@ def update(self, labels, preds): check_label_shapes(label, pred_label) + if pred_label.context != label.context: + pred_label = pred_label.as_in_context(label.context) + self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() self.num_inst += numpy.prod(pred_label.shape)