From 16b85da4806a7b4f5e70b1897d8251eed2e118cd Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 22 Jan 2025 16:14:03 +0100 Subject: [PATCH 1/4] fix weighted polyfit --- xarray/core/dataset.py | 2 +- xarray/tests/test_dataset.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a943d9bfc57..74f90ce9eea 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -9206,7 +9206,7 @@ def polyfit( present_dims.update(other_dims) if w is not None: - rhs = rhs * w[:, np.newaxis] + rhs = rhs * w.reshape(-1, *((1,) * len(other_dims))) with warnings.catch_warnings(): if full: # Copy np.polyfit behavior diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8a90a05a4e3..f43bc25e7cb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6685,11 +6685,15 @@ def test_polyfit_output(self) -> None: assert len(out.data_vars) == 0 def test_polyfit_weighted(self) -> None: - # Make sure weighted polyfit does not change the original object (issue #5644) ds = create_test_data(seed=1) + ds = ds.broadcast_like(ds) # test more than 2 dimensions (issue #9972) ds_copy = ds.copy(deep=True) - ds.polyfit("dim2", 2, w=np.arange(ds.sizes["dim2"])) + result_weighted = ds.polyfit("dim2", 2, w=np.ones(ds.sizes["dim2"])) + result_unweighted = ds.polyfit("dim2", 2) + xr.testing.assert_identical(result_weighted, result_unweighted) + + # Make sure weighted polyfit does not change the original object (issue #5644) xr.testing.assert_identical(ds, ds_copy) def test_polyfit_coord(self) -> None: From 67cdd22bbbd4f86102ce4d66d0a351bbdd619004 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 22 Jan 2025 16:18:35 +0100 Subject: [PATCH 2/4] docs --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 17af655c02e..9b40a323f39 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -71,6 +71,8 @@ Bug fixes By `Kai Mühlbauer `_. - Use zarr-fixture to prevent thread leakage errors (:pull:`9967`). By `Kai Mühlbauer `_. +- Fix weighted ``polyfit`` for arrays with more than two dimensions (:issue:`9972`, :pull:`9974`). + By `Mattia Almansi `_. Documentation ~~~~~~~~~~~~~ From bf9699a432561fa5cc5e4845c9614470a17a0d45 Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 22 Jan 2025 16:22:46 +0100 Subject: [PATCH 3/4] cleanup --- xarray/tests/test_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f43bc25e7cb..c3302dd6c9d 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -395,7 +395,7 @@ def test_unicode_data(self) -> None: Size: 12B Dimensions: (foø: 1) Coordinates: - * foø (foø) {byteorder}U3 12B {'ba®'!r} + * foø (foø) {byteorder}U3 12B {"ba®"!r} Data variables: *empty* Attributes: @@ -6689,9 +6689,9 @@ def test_polyfit_weighted(self) -> None: ds = ds.broadcast_like(ds) # test more than 2 dimensions (issue #9972) ds_copy = ds.copy(deep=True) - result_weighted = ds.polyfit("dim2", 2, w=np.ones(ds.sizes["dim2"])) - result_unweighted = ds.polyfit("dim2", 2) - xr.testing.assert_identical(result_weighted, result_unweighted) + expected = ds.polyfit("dim2", 2) + actual = ds.polyfit("dim2", 2, w=np.ones(ds.sizes["dim2"])) + xr.testing.assert_identical(expected, actual) # Make sure weighted polyfit does not change the original object (issue #5644) xr.testing.assert_identical(ds, ds_copy) From 2459ffce9e4eec121b926c27deb73cba6945663e Mon Sep 17 00:00:00 2001 From: malmans2 Date: Wed, 22 Jan 2025 16:28:32 +0100 Subject: [PATCH 4/4] restore unicode test --- xarray/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c3302dd6c9d..f3867bd67d2 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -395,7 +395,7 @@ def test_unicode_data(self) -> None: Size: 12B Dimensions: (foø: 1) Coordinates: - * foø (foø) {byteorder}U3 12B {"ba®"!r} + * foø (foø) {byteorder}U3 12B {'ba®'!r} Data variables: *empty* Attributes: