From 5d4f0d289829b6669bb6ff88cc0b150e032b1ff9 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 6 Feb 2018 23:25:30 -0800 Subject: [PATCH 1/4] Revert "avoid per-batch blocking in metric (#9636)" This reverts commit 3fe694e7b1ed7fa6a2dcfeddeac44c14ab77b015. --- python/mxnet/metric.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index e91fd3b13ee6..fc2b9014e8cc 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -28,7 +28,6 @@ from .base import numeric_types, string_types from . import ndarray from . import registry -from .context import cpu def check_label_shapes(labels, preds, shape=0): @@ -389,7 +388,6 @@ def update(self, labels, preds): """ check_label_shapes(labels, preds) - results = [] for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) @@ -401,10 +399,8 @@ def update(self, labels, preds): if pred_label.context != label.context: pred_label = pred_label.as_in_context(label.context) - self.num_inst += pred_label.size - results.append((pred_label.reshape((-1,)) == label.reshape((-1,))) - .sum().as_in_context(cpu())) - self.sum_metric += ndarray.add_n(*results).asscalar() + self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar() + self.num_inst += numpy.prod(pred_label.shape) @register From 35d7654d4a851477607d0e9ff58b4c2580d4c203 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 6 Feb 2018 23:25:51 -0800 Subject: [PATCH 2/4] Revert "proper flatten in acc (#9619)" This reverts commit ed823b2e187eb859d9475eb651465edf714c6c5f. --- python/mxnet/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index fc2b9014e8cc..f1cdae26a235 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -399,7 +399,7 @@ def update(self, labels, preds): if pred_label.context != label.context: pred_label = pred_label.as_in_context(label.context) - self.sum_metric += (pred_label.reshape((-1,)) == label.reshape((-1,))).sum().asscalar() + self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() self.num_inst += numpy.prod(pred_label.shape) From 46e33610691c68668957f64e5a238ad137b6fd06 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 6 Feb 2018 23:26:15 -0800 Subject: [PATCH 3/4] Revert "use nd for accuracy calculation (#9583)" This reverts commit f5f1b91ff972ad70e9131d3cd1d7408ddddb7684. --- python/mxnet/metric.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index f1cdae26a235..5b0780aeccee 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -380,27 +380,23 @@ def update(self, labels, preds): Parameters ---------- labels : list of `NDArray` - The labels of the data with class indices as values, one per sample. + The labels of the data. preds : list of `NDArray` - Prediction values for samples. Each prediction value can either be the class index, - or a vector of likelihoods for all classes. + Predicted values. """ check_label_shapes(labels, preds) for label, pred_label in zip(labels, preds): if pred_label.shape != label.shape: pred_label = ndarray.argmax(pred_label, axis=self.axis) - pred_label = pred_label.astype('int32') - label = label.astype('int32') + pred_label = pred_label.asnumpy().astype('int32') + label = label.asnumpy().astype('int32') check_label_shapes(label, pred_label) - if pred_label.context != label.context: - pred_label = pred_label.as_in_context(label.context) - - self.sum_metric += (pred_label.flatten() == label.flatten()).sum().asscalar() - self.num_inst += numpy.prod(pred_label.shape) + self.sum_metric += (pred_label.flat == label.flat).sum() + self.num_inst += len(pred_label.flat) @register From 6b407594f374cd05933f1b6b57880522417227c5 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 6 Feb 2018 23:39:24 -0800 Subject: [PATCH 4/4] keep doc change --- python/mxnet/metric.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 5b0780aeccee..8bb3f6ee0a81 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -380,10 +380,11 @@ def update(self, labels, preds): Parameters ---------- labels : list of `NDArray` - The labels of the data. + The labels of the data with class indices as values, one per sample. preds : list of `NDArray` - Predicted values. + Prediction values for samples. Each prediction value can either be the class index, + or a vector of likelihoods for all classes. """ check_label_shapes(labels, preds)