From a2677ff7d7f2ffdc568ad7a4d71def86612c5f07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Schw=C3=B6rer?= Date: Thu, 22 Apr 2021 15:25:30 +0200 Subject: [PATCH 1/2] Skip mean over empty axis Avoids changing the datatype if the data does not have the requested axis. --- doc/whats-new.rst | 4 ++++ xarray/core/duck_array_ops.py | 5 +++++ xarray/tests/test_duck_array_ops.py | 10 ++++++++++ 3 files changed, 19 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d0e2ef3bd59..27c0f3d15fc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -85,6 +85,10 @@ Breaking changes as positional, all others need to be passed are keyword arguments. This is part of the refactor to support external backends (:issue:`4309`, :pull:`4989`). By `Alessandro Amici `_. +- :py:func:`mean` does not change the data if axis is None. This + ensures that Datasets where some variables do not have the averaged + dimensions are not accidentially changed (:issue:`4885`). + By `David Schwörer `_ Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 9dcd7906ef7..e543fdeef33 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -537,6 +537,11 @@ def mean(array, axis=None, skipna=None, **kwargs): dtypes""" from .common import _contains_cftime_datetimes + # The mean over an empty axis shouldn't change the data + # See https://github.com/pydata/xarray/issues/4885 + if axis == tuple(): + return array + array = asarray(array) if array.dtype.kind in "Mm": offset = _datetime_nanmin(array) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 1dd26bab6b6..3f1d8e6edbf 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -373,6 +373,16 @@ def test_cftime_datetime_mean_dask_error(): da.mean() +def test_mean_dtype(): + ds = Dataset() + ds["pos"] = [1, 2, 3] + ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + ds["var"] = "pos", [2, 3, 4] + ds2 = ds.mean(dim="time") + assert all(ds2["var"] == ds["var"]) + assert ds2["var"].dtype == ds["var"].dtype + + @pytest.mark.parametrize("dim_num", [1, 2]) @pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) @pytest.mark.parametrize("dask", [False, True]) From fcebe5e5f3bcd2d93df614966431c845384a3b2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Schw=C3=B6rer?= Date: Fri, 23 Apr 2021 15:33:00 +0200 Subject: [PATCH 2/2] Improvements based on feedback * Better testing * Clarify comment * Handle other functions as well, like sum, min, max --- doc/whats-new.rst | 9 ++++---- xarray/core/duck_array_ops.py | 33 +++++++++++++++++------------ xarray/tests/test_duck_array_ops.py | 11 +++++----- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 27c0f3d15fc..029231a3753 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -85,10 +85,11 @@ Breaking changes as positional, all others need to be passed are keyword arguments. This is part of the refactor to support external backends (:issue:`4309`, :pull:`4989`). By `Alessandro Amici `_. -- :py:func:`mean` does not change the data if axis is None. This - ensures that Datasets where some variables do not have the averaged - dimensions are not accidentially changed (:issue:`4885`). - By `David Schwörer `_ +- Functions that are identities for 0d data return the unchanged data + if axis is empty. This ensures that Datasets where some variables do + not have the averaged dimensions are not accidentially changed + (:issue:`4885`, :pull:`5207`). By `David Schwörer + `_ Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index e543fdeef33..8947ecd7477 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -310,13 +310,21 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False): +def _create_nan_agg_method( + name, dask_module=dask_array, coerce_strings=False, invariant_0d=False +): from . import nanops def f(values, axis=None, skipna=None, **kwargs): if kwargs.pop("out", None) is not None: raise TypeError(f"`out` is not valid for {name}") + # The data is invariant in the case of 0d data, so do not + # change the data (and dtype) + # See https://github.com/pydata/xarray/issues/4885 + if invariant_0d and axis == (): + return values + values = asarray(values) if coerce_strings and values.dtype.kind in "SU": @@ -354,28 +362,30 @@ def f(values, axis=None, skipna=None, **kwargs): # See ops.inject_reduce_methods argmax = _create_nan_agg_method("argmax", coerce_strings=True) argmin = _create_nan_agg_method("argmin", coerce_strings=True) -max = _create_nan_agg_method("max", coerce_strings=True) -min = _create_nan_agg_method("min", coerce_strings=True) -sum = _create_nan_agg_method("sum") +max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True) +min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True) +sum = _create_nan_agg_method("sum", invariant_0d=True) sum.numeric_only = True sum.available_min_count = True std = _create_nan_agg_method("std") std.numeric_only = True var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method("median", dask_module=dask_array_compat) +median = _create_nan_agg_method( + "median", dask_module=dask_array_compat, invariant_0d=True +) median.numeric_only = True -prod = _create_nan_agg_method("prod") +prod = _create_nan_agg_method("prod", invariant_0d=True) prod.numeric_only = True prod.available_min_count = True -cumprod_1d = _create_nan_agg_method("cumprod") +cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True) cumprod_1d.numeric_only = True -cumsum_1d = _create_nan_agg_method("cumsum") +cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True) cumsum_1d.numeric_only = True unravel_index = _dask_or_eager_func("unravel_index") -_mean = _create_nan_agg_method("mean") +_mean = _create_nan_agg_method("mean", invariant_0d=True) def _datetime_nanmin(array): @@ -537,11 +547,6 @@ def mean(array, axis=None, skipna=None, **kwargs): dtypes""" from .common import _contains_cftime_datetimes - # The mean over an empty axis shouldn't change the data - # See https://github.com/pydata/xarray/issues/4885 - if axis == tuple(): - return array - array = asarray(array) if array.dtype.kind in "Mm": offset = _datetime_nanmin(array) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 3f1d8e6edbf..ef81a6108dd 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -26,7 +26,7 @@ where, ) from xarray.core.pycompat import dask_array_type -from xarray.testing import assert_allclose, assert_equal +from xarray.testing import assert_allclose, assert_equal, assert_identical from . import ( arm_xfail, @@ -373,14 +373,15 @@ def test_cftime_datetime_mean_dask_error(): da.mean() -def test_mean_dtype(): +def test_empty_axis_dtype(): ds = Dataset() ds["pos"] = [1, 2, 3] ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] ds["var"] = "pos", [2, 3, 4] - ds2 = ds.mean(dim="time") - assert all(ds2["var"] == ds["var"]) - assert ds2["var"].dtype == ds["var"].dtype + assert_identical(ds.mean(dim="time")["var"], ds["var"]) + assert_identical(ds.max(dim="time")["var"], ds["var"]) + assert_identical(ds.min(dim="time")["var"], ds["var"]) + assert_identical(ds.sum(dim="time")["var"], ds["var"]) @pytest.mark.parametrize("dim_num", [1, 2])