From 5133aa4a433eaa2f96cfb33e003fbb22ec882f5a Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Tue, 31 Oct 2017 23:30:29 -0700 Subject: [PATCH] Use numpy's normalize_axis_index, if available --- xarray/core/nputils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index a721425b839..ac33c3d34e1 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -5,15 +5,20 @@ import pandas as pd import warnings - -def _validate_axis(data, axis): - ndim = data.ndim - if not -ndim <= axis < ndim: - raise IndexError('axis %r out of bounds [-%r, %r)' - % (axis, ndim, ndim)) - if axis < 0: - axis += ndim - return axis +# Numpy has a function for this as of 1.13 +_normalize_axis_index = getattr(np.core.multiarray, 'normalize_axis_index', None) +if _normalize_axis_index is not None: + def _validate_axis(data, axis): + return _normalize_axis_index(axis, data.ndim) +else: + def _validate_axis(data, axis): + ndim = data.ndim + if not -ndim <= axis < ndim: + raise IndexError('axis %r out of bounds [-%r, %r)' + % (axis, ndim, ndim)) + if axis < 0: + axis += ndim + return axis def _select_along_axis(values, idx, axis):