diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f4f76efdcf7..82e4530f428 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1388,7 +1388,7 @@ def __getitem__(self, key): return value def __setitem__(self, key, value): - if DASK_VERSION >= "2021.04.0+17": + if DASK_VERSION >= "2021.04.1": if isinstance(key, BasicIndexer): self.array[key.tuple] = value elif isinstance(key, VectorizedIndexer): diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 8b875662ab8..f1896ebe652 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -110,39 +110,37 @@ def test_indexing(self): self.assertLazyAndIdentical(u[0], v[0]) self.assertLazyAndIdentical(u[:1], v[:1]) self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]]) - if LooseVersion(dask.__version__) >= LooseVersion("2021.04.0+17"): - # TODO: use @pytest.mark.parametrize to parametrize this - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([99, 2, 3, 4])) - arr[0] = 99 # Indexing by integers - assert_identical(arr, expected) - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([99, 99, 99, 4])) - arr[2::-1] = 99 # Indexing by slices - assert_identical(arr, expected) - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([99, 99, 3, 99])) - arr[[0, -1, 1]] = 99 # Indexing by a list of integers - assert_identical(arr, expected) - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([99, 99, 99, 4])) - arr[np.arange(3)] = 99 # Indexing by a 1-d numpy array of integers - assert_identical(arr, expected) - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([1, 99, 99, 99])) - arr[[False, True, True, True]] = 99 # Indexing by a list of booleans - assert_identical(arr, expected) - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([1, 99, 99, 99])) - arr[np.arange(4) > 0] = 99 # Indexing by a 1-d numpy array of booleans - assert_identical(arr, expected) - arr = Variable(("x"), da.array([1, 2, 3, 4])) - expected = Variable(("x"), da.array([99, 99, 99, 99])) - arr[arr > 0] = 99 # Indexing by one broadcastable Array of booleans - assert_identical(arr, expected) - else: - with pytest.raises(TypeError, match=r"stored in a dask array"): - v[:1] = 0 + + @pytest.mark.skipif( + LooseVersion(dask.__version__) < LooseVersion("2021.04.1"), + reason="Requires dask v2021.04.1 or later", + ) + @pytest.mark.parametrize( + "expected_data, index", + [ + (da.array([99, 2, 3, 4]), 0), + (da.array([99, 99, 99, 4]), slice(2, None, -1)), + (da.array([99, 99, 3, 99]), [0, -1, 1]), + (da.array([99, 99, 99, 4]), np.arange(3)), + (da.array([1, 99, 99, 99]), [False, True, True, True]), + (da.array([1, 99, 99, 99]), np.arange(4) > 0), + (da.array([99, 99, 99, 99]), Variable(("x"), da.array([1, 2, 3, 4])) > 0), + ], + ) + def test_setitem_dask_array(self, expected_data, index): + arr = Variable(("x"), da.array([1, 2, 3, 4])) + expected = Variable(("x"), expected_data) + arr[index] = 99 + assert_identical(arr, expected) + + @pytest.mark.skipif( + LooseVersion(dask.__version__) >= LooseVersion("2021.04.1"), + reason="Requires dask v2021.04.0 or earlier", + ) + def test_setitem_dask_array_error(self): + with pytest.raises(TypeError, match=r"stored in a dask array"): + v = self.lazy_var + v[:1] = 0 def test_squeeze(self): u = self.eager_var