Optimization of metric evaluation #13471
Conversation
@mxnet-label-bot add [Metric, pr-awaiting-review]

Thanks @ptrendx, I'll have a look
vandanavk
left a comment
Overall LGTM. Just a few comments.
Also, could you test the following:
- image classification examples (train_mnist.py, train_imagenet.py)
- Speedometer's auto_reset=True and auto_reset=False
- tools/parse_log.py
| if "has_global_stats" in kwargs: | ||
| self._has_global_stats = kwargs["has_global_stats"] | ||
| else: | ||
| self._has_global_stats = False |
self._has_global_stats = kwargs.get("has_global_stats", False) ?
Done (with pop, so the has_global_stats key is not kept in kwargs and doesn't mess with deserialization later).
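For reference, a minimal sketch of what the pop-based version presumably looks like (the class skeleton here is illustrative, not the exact merged code):

    class EvalMetric:
        def __init__(self, name, **kwargs):
            self.name = name
            # pop (rather than get) removes the key from kwargs, so it is not
            # carried along into the serialized config and does not break
            # deserialization later
            self._has_global_stats = kwargs.pop("has_global_stats", False)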
        else:
    -       self.sum_metric = self.metrics.fscore * self.metrics.total_examples
    +       self.sum_metric = fscore * self.metrics.total_examples
    +       self.global_sum_metric = fscore * self.metrics.total_examples
self.sum_metric = self.global_sum_metric = ?
    self.sum_metric = fscore * self.metrics.total_examples
    self.global_sum_metric = fscore * self.metrics.total_examples
    self.num_inst = self.metrics.total_examples
    self.global_num_inst = self.metrics.total_examples

and similarly in the MCC metric:

    self.sum_metric = matthewscc * self._metrics.total_examples
    self.global_sum_metric = matthewscc * self._metrics.total_examples
    self.num_inst = self._metrics.total_examples
    self.global_num_inst = self._metrics.total_examples
I tested train_imagenet.py, tools/parse_log.py, and both auto_reset values for Speedometer.
@ptrendx thanks for the PR. Do you mind elaborating a bit more on how this PR avoids the GIL / speeds up metric evaluation?
It does not avoid the GIL; I just do less work in Python.
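As an illustration of the "less work in Python" point (a minimal sketch with hypothetical names mirroring the diff, not the actual MXNet code): each update computes its statistic once in NumPy and feeds both the local and the global accumulators, so per-batch and per-epoch values come from a single computation.

    import numpy as np

    class AccuracySketch:
        def __init__(self):
            self.sum_metric = 0.0         # local, reset between log intervals
            self.num_inst = 0
            self.global_sum_metric = 0.0  # global, kept for the whole epoch
            self.global_num_inst = 0

        def update(self, labels, preds):
            pred_label = preds.argmax(axis=1)
            num_correct = (pred_label == labels).sum()  # computed once
            self.sum_metric += num_correct
            self.global_sum_metric += num_correct
            self.num_inst += len(labels)
            self.global_num_inst += len(labels)

    labels = np.array([0, 1, 2])
    preds = np.eye(3, dtype='float32')  # perfect predictions
    m = AccuracySketch()
    m.update(labels, preds)
    print(m.sum_metric / m.num_inst)  # 1.0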
eric-haibin-lin
left a comment
Thanks for the fix and explanation. LGTM
Thanks for the contribution @ptrendx!
sandeep-krishnamurthy
left a comment
Thanks. A few comments inline.
        else:
            return (self.name, self.global_sum_metric / self.global_num_inst)
    else:
        return self.get()
If the user specifically asks for global statistics and they are not available, shouldn't we throw an exception rather than silently return the local ones? Same in other places.
I'm not sure if that is possible - doing this would break fully custom metrics (other than subclasses of CustomMetric class, for which I added support) that did not implement global stats.
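For context, the fallback being discussed looks roughly like this (a sketch assuming the _has_global_stats flag and the counters from the diff above, not the exact merged code):

    def get_global(self):
        """Return global (per-epoch) statistics if the metric tracks them,
        otherwise silently fall back to the local statistics, so fully
        custom metrics without global counters keep working."""
        if self._has_global_stats:
            if self.global_num_inst == 0:
                return (self.name, float('nan'))  # nothing accumulated yet
            return (self.name, self.global_sum_metric / self.global_num_inst)
        return self.get()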
        for metric in self.metrics:
            metric.reset_local()
    except AttributeError:
        pass
Why is this required? When will we reach here? Can we please document it?
This is not added in this PR and I'm not sure myself why it is needed here (this function is basically a copy of the reset function, but calls reset_local on the children instead of reset).
    -   self.sum_metric += (pred_label == label).sum()
    +   num_correct = (pred_label == label).sum()
    +   self.sum_metric += num_correct
    +   self.global_sum_metric += num_correct
Sorry for the trivial question, but when will the global metrics differ from the local metrics with this logic?
They will be different when you call reset_local (which is called from Speedometer, for example): it resets the local versions of the metrics while keeping the global versions intact, which was the point of this PR, to enable computing both per-batch and per-epoch statistics using only a single computation.
That said, this comment made me think about this again and I found a bug in how I handle global statistics in F1 and MCC metrics - fix and tests incoming. Thank you!
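A small self-contained illustration of how the two sets of counters diverge (a hypothetical minimal class, mirroring the PR's naming):

    class CounterSketch:
        def __init__(self):
            self.num_inst = 0
            self.global_num_inst = 0

        def update(self, n):
            self.num_inst += n
            self.global_num_inst += n

        def reset_local(self):
            self.num_inst = 0  # the global counter is left intact

    m = CounterSketch()
    m.update(10)
    m.reset_local()  # e.g. Speedometer logging a batch window
    m.update(5)
    print(m.num_inst, m.global_num_inst)  # 5 15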
        for label, pred_label in zip(labels, preds):
            assert(len(pred_label.shape) <= 2), 'Predictions should be no more than 2 dims'
    -       pred_label = numpy.argsort(pred_label.asnumpy().astype('float32'), axis=1)
    +       pred_label = numpy.argpartition(pred_label.asnumpy().astype('float32'), -self.top_k)
This is very nice. I think it warrants a comment on why argpartition is used here (for its performance benefit).
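For readers of this thread: the performance point is that numpy.argpartition does an O(n) partial selection of the top-k entries instead of argsort's full O(n log n) sort, and top-k accuracy only needs membership in the top k, not their order. A small standalone sketch of the equivalence (illustrative, not the PR's code):

    import numpy as np

    scores = np.random.rand(4, 1000).astype('float32')
    top_k = 5

    # Full sort: orders all 1000 classes per row.
    full = np.argsort(scores, axis=1)[:, -top_k:]
    # Partial selection: only guarantees that the last top_k columns
    # contain the top-k indices, in no particular order among themselves.
    part = np.argpartition(scores, -top_k, axis=1)[:, -top_k:]

    # Same set of top-k indices per row.
    assert all(set(f) == set(p) for f, p in zip(full, part))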
    Tuple of (str, float)
        Representing name of the metric and evaluation result.
    """
    num = self.global_num_inst if self.global_num_inst > 0 else float('nan')
Oh, this change got here by accident - I will revert it. Good catch, thank you!
When I added tests, it turned out that changing this is actually necessary to avoid a floating-point division-by-zero exception (I did change it slightly differently though, to match the get function from the base class).
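To make the division-by-zero point concrete (a standalone snippet using plain variables rather than the metric's attributes):

    # Dividing by NaN yields NaN instead of raising ZeroDivisionError:
    global_sum_metric, global_num_inst = 0.0, 0
    num = global_num_inst if global_num_inst > 0 else float('nan')
    print(global_sum_metric / num)  # nan

    # The base class's get guards explicitly instead:
    if global_num_inst == 0:
        result = float('nan')
    else:
        result = global_sum_metric / global_num_inst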
added test for global stats
I made the fixes and added tests for all metrics. @sandeep-krishnamurthy please review again.
Description
Currently, metrics are mostly evaluated on the CPU using NumPy. Due to the Python GIL, they are evaluated in a single thread, sequentially, which may become a bottleneck once the number of GPUs used is large enough.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
CC @vandanavk for comments.