diff --git a/xrspatial/tests/test_zonal.py b/xrspatial/tests/test_zonal.py
index 8ff150e5..011de940 100644
--- a/xrspatial/tests/test_zonal.py
+++ b/xrspatial/tests/test_zonal.py
@@ -680,18 +680,8 @@ def test_stats_all_nan_zone(backend):
             'sum': [12.0],
             'count': [2],
         }
-    elif 'dask' in backend:
-        # dask uses nansum reduction, so count/sum of all-NaN become 0
-        expected = {
-            'zone': [1, 2],
-            'mean': [np.nan, 6.0],
-            'max': [np.nan, 7.0],
-            'min': [np.nan, 5.0],
-            'sum': [0.0, 12.0],
-            'count': [0, 2],
-        }
     else:
-        # numpy keeps empty zone with NaN for every stat
+        # numpy and dask both return NaN for all-NaN zones
         expected = {
             'zone': [1, 2],
             'mean': [np.nan, 6.0],
@@ -798,16 +788,8 @@ def test_stats_nodata_wipes_zone(backend):
             'sum': [10.0],
             'count': [2],
         }
-    elif 'dask' in backend:
-        expected = {
-            'zone': [1, 2],
-            'mean': [np.nan, 5.0],
-            'max': [np.nan, 7.0],
-            'min': [np.nan, 3.0],
-            'sum': [0.0, 10.0],
-            'count': [0, 2],
-        }
     else:
+        # numpy and dask both return NaN for zones with no valid values
        expected = {
             'zone': [1, 2],
             'mean': [np.nan, 5.0],
@@ -868,6 +850,71 @@ def test_zonal_stats_inputs_unmodified(backend, data_zones, data_values_2d, resu
     assert_input_data_unmodified(data_values_2d, copied_data_values_2d)
 
 
+@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
+@pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
+@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
+def test_stats_variance_numerical_stability_1090(backend):
+    """Dask std/var should match numpy for data with large mean, small spread.
+
+    Regression test for #1090: the naive one-pass formula
+    ``(Σx² − (Σx)²/n) / n`` loses precision through catastrophic
+    cancellation. The fix uses Chan-Golub-LeVeque parallel merge.
+    """
+    if 'dask' in backend and not dask_array_available():
+        pytest.skip("Requires Dask")
+
+    # Values near 1e8 with a spread of 1: the naive formula would lose
+    # most of the significant digits in float64.
+    zones_data = np.array([[1, 1, 1, 1, 1, 1]])
+    values_data = np.array([[1e8, 1e8 + 1, 1e8 + 2,
+                             1e8 + 3, 1e8 + 4, 1e8 + 5]], dtype=np.float64)
+
+    zones = create_test_raster(zones_data, backend, chunks=(1, 3))
+    values = create_test_raster(values_data, backend, chunks=(1, 3))
+
+    df_result = stats(zones=zones, values=values,
+                      stats_funcs=['mean', 'std', 'var'])
+
+    if hasattr(df_result, 'compute'):
+        df_result = df_result.compute()
+
+    # Reference: population variance of [0,1,2,3,4,5] = 35/12 ≈ 2.9167
+    expected_var = np.var(np.arange(6, dtype=np.float64))
+    expected_std = np.std(np.arange(6, dtype=np.float64))
+
+    actual_var = float(df_result['var'].iloc[0])
+    actual_std = float(df_result['std'].iloc[0])
+
+    assert abs(actual_var - expected_var) < 1e-6, (
+        f"var={actual_var}, expected={expected_var}"
+    )
+    assert abs(actual_std - expected_std) < 1e-6, (
+        f"std={actual_std}, expected={expected_std}"
+    )
+
+
+def test_stats_nodata_none_no_warning_1090():
+    """Passing nodata_values=None (the default) should not trigger warnings.
+
+    Regression test for #1090: ``zone_values != None`` triggered a numpy
+    FutureWarning.
+    """
+    import warnings
+
+    zones_data = np.array([[1, 1], [2, 2]], dtype=float)
+    values_data = np.array([[1.0, 2.0], [3.0, 4.0]])
+    zones = xr.DataArray(zones_data)
+    values = xr.DataArray(values_data)
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        df = stats(zones=zones, values=values, nodata_values=None)
+
+    assert len(df) == 2
+    assert float(df['mean'].iloc[0]) == 1.5
+    assert float(df['mean'].iloc[1]) == 3.5
+
+
 @pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
 @pytest.mark.filterwarnings("ignore:invalid value encountered in divide:RuntimeWarning")
 @pytest.mark.parametrize("backend", ['numpy', 'dask+numpy'])
diff --git a/xrspatial/zonal.py b/xrspatial/zonal.py
index eaefaabd..db5f8709 100644
--- a/xrspatial/zonal.py
+++ b/xrspatial/zonal.py
@@ -198,21 +198,77 @@ def _stats_majority(data):
     min=lambda z: z.min(),
     sum=lambda z: z.sum(),
     count=lambda z: _stats_count(z),
-    sum_squares=lambda z: (z**2).sum()
+    sum_squares=lambda z: ((z - z.mean()) ** 2).sum()  # block-level M2
 )
 
 
+def _nanreduce_preserve_allnan(blocks, func):
+    """Reduce across blocks, returning NaN when ALL blocks are NaN for a zone.
+
+    ``np.nansum`` returns 0 for all-NaN input; we want NaN so that zones
+    with no valid values propagate NaN, consistent with the numpy backend.
+    """
+    result = func(blocks, axis=0)
+    all_nan = np.all(np.isnan(blocks), axis=0)
+    result[all_nan] = np.nan
+    return result
+
+
 _DASK_STATS = dict(
-    max=lambda block_maxes: np.nanmax(block_maxes, axis=0),
-    min=lambda block_mins: np.nanmin(block_mins, axis=0),
-    sum=lambda block_sums: np.nansum(block_sums, axis=0),
-    count=lambda block_counts: np.nansum(block_counts, axis=0),
-    sum_squares=lambda block_sum_squares: np.nansum(block_sum_squares, axis=0),
-    squared_sum=lambda block_sums: np.nansum(block_sums, axis=0)**2,
+    max=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmax),
+    min=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nanmin),
+    sum=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
+    count=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
+    sum_squares=lambda blocks: _nanreduce_preserve_allnan(blocks, np.nansum),
 )
-def _dask_mean(sums, counts): return sums / counts  # noqa
-def _dask_std(sum_squares, squared_sum, n): return np.sqrt((sum_squares - squared_sum/n) / n)  # noqa
-def _dask_var(sum_squares, squared_sum, n): return (sum_squares - squared_sum/n) / n  # noqa
+
+
+def _dask_mean(sums, counts):  # noqa
+    return sums / counts
+
+
+def _parallel_variance(block_counts, block_sums, block_m2s):
+    """Population variance via Chan-Golub-LeVeque parallel merge.
+
+    Each input is (n_blocks, n_zones). ``block_m2s`` contains
+    per-block M2 values (sum of squared deviations from the block mean),
+    NOT raw sum-of-squares. Returns (n_zones,) population variance,
+    with NaN for zones that have no valid values in any block.
+    """
+    n_blocks = block_counts.shape[0]
+    n_zones = block_counts.shape[1]
+
+    n_acc = np.zeros(n_zones, dtype=np.float64)
+    mean_acc = np.zeros(n_zones, dtype=np.float64)
+    m2_acc = np.zeros(n_zones, dtype=np.float64)
+
+    for i in range(n_blocks):
+        nc = np.asarray(block_counts[i], dtype=np.float64)
+        sc = np.asarray(block_sums[i], dtype=np.float64)
+        m2_b = np.asarray(block_m2s[i], dtype=np.float64)
+
+        has_data = np.isfinite(nc) & (nc > 0)
+        nc_safe = np.where(has_data, nc, 1.0)  # avoid /0
+
+        with np.errstate(invalid='ignore', divide='ignore'):
+            mean_b = sc / nc_safe
+
+        nc = np.where(has_data, nc, 0.0)
+        n_ab = n_acc + nc
+
+        delta = mean_b - mean_acc
+        with np.errstate(invalid='ignore', divide='ignore'):
+            n_ab_safe = np.where(n_ab > 0, n_ab, 1.0)
+            correction = delta ** 2 * n_acc * nc / n_ab_safe
+            new_mean = mean_acc + delta * nc / n_ab_safe
+
+        m2_acc = np.where(has_data, m2_acc + m2_b + correction, m2_acc)
+        mean_acc = np.where(has_data, new_mean, mean_acc)
+        n_acc = np.where(has_data, n_ab, n_acc)
+
+    with np.errstate(invalid='ignore', divide='ignore'):
+        var = np.where(n_acc > 0, m2_acc / n_acc, np.nan)
+    return var
 
 
 @ngjit
@@ -269,7 +325,10 @@ def _calc_stats(
         if unique_zones[i] in zone_ids:
             zone_values = values_by_zones[start:end]
             # filter out non-finite and nodata_values
-            zone_values = zone_values[np.isfinite(zone_values) & (zone_values != nodata_values)]
+            mask = np.isfinite(zone_values)
+            if nodata_values is not None:
+                mask = mask & (zone_values != nodata_values)
+            zone_values = zone_values[mask]
             if len(zone_values) > 0:
                 results[i] = func(zone_values)
             start = end
@@ -342,9 +401,11 @@ def _stats_dask_numpy(
         sum=values.dtype,
         count=np.int64,
         sum_squares=values.dtype,
-        squared_sum=values.dtype,
     )
 
+    # Keep per-block stacked arrays for the parallel variance merge
+    stacked_blocks = {}
+
     for s in basis_stats:
         if s == 'sum_squares' and not compute_sum_squares:
             continue
@@ -358,6 +419,10 @@ for z, v in zip(zones_blocks,
                             values_blocks)
         ]
         zonal_stats = da.stack(stats_by_block, allow_unknown_chunksizes=True)
+
+        if compute_sum_squares and s in ('count', 'sum', 'sum_squares'):
+            stacked_blocks[s] = zonal_stats
+
         stats_func_by_block = delayed(_DASK_STATS[s])
         stats_dict[s] = da.from_delayed(
             stats_func_by_block(zonal_stats), shape=(np.nan,), dtype=np.float64
@@ -365,14 +430,23 @@ def _stats_dask_numpy(
     if 'mean' in stats_funcs:
         stats_dict['mean'] = _dask_mean(stats_dict['sum'], stats_dict['count'])
-    if 'std' in stats_funcs:
-        stats_dict['std'] = _dask_std(
-            stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
-        )
-    if 'var' in stats_funcs:
-        stats_dict['var'] = _dask_var(
-            stats_dict['sum_squares'], stats_dict['sum'] ** 2, stats_dict['count']
+
+    if 'std' in stats_funcs or 'var' in stats_funcs:
+        var_result = da.from_delayed(
+            delayed(_parallel_variance)(
+                stacked_blocks['count'],
+                stacked_blocks['sum'],
+                stacked_blocks['sum_squares'],
+            ),
+            shape=(np.nan,), dtype=np.float64,
         )
+        if 'var' in stats_funcs:
+            stats_dict['var'] = var_result
+        if 'std' in stats_funcs:
+            stats_dict['std'] = da.from_delayed(
+                delayed(np.sqrt)(var_result),
+                shape=(np.nan,), dtype=np.float64,
+            )
 
     # generate dask dataframe
     stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()],
                          axis=1, ignore_unknown_divisions=True)
@@ -846,9 +920,10 @@ def _single_zone_crosstab_2d(
 ):
     # 1D flatten zone_values, i.e, original data is 2D
     # filter out non-finite and nodata_values
-    zone_values = zone_values[
-        np.isfinite(zone_values) & (zone_values != nodata_values)
-    ]
+    mask = np.isfinite(zone_values)
+    if nodata_values is not None:
+        mask = mask & (zone_values != nodata_values)
+    zone_values = zone_values[mask]
     total_count = zone_values.shape[0]
     crosstab_dict[TOTAL_COUNT].append(total_count)
 
@@ -877,10 +952,10 @@ def _single_zone_crosstab_3d(
         if cat in cat_ids:
             zone_cat_data = zone_values[j]
             # filter out non-finite and nodata_values
-            zone_cat_data = zone_cat_data[
-                np.isfinite(zone_cat_data)
-                & (zone_cat_data != nodata_values)
-            ]
+            cat_mask = np.isfinite(zone_cat_data)
+            if nodata_values is not None:
+                cat_mask = cat_mask & (zone_cat_data != nodata_values)
+            zone_cat_data = zone_cat_data[cat_mask]
             crosstab_dict[cat].append(stats_func(zone_cat_data))
 
 