From e5c9e971ea4dbf228edcebc90cb8b2c739efb439 Mon Sep 17 00:00:00 2001 From: Eric Boucher Date: Wed, 28 Jun 2023 23:13:00 +0200 Subject: [PATCH 1/2] Add rv_array to custom functions --- src/rasterstats/main.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/rasterstats/main.py b/src/rasterstats/main.py index ad77377..0dce152 100644 --- a/src/rasterstats/main.py +++ b/src/rasterstats/main.py @@ -1,3 +1,4 @@ +import inspect import sys import warnings @@ -283,10 +284,14 @@ def gen_zonal_stats( if add_stats is not None: for stat_name, stat_func in add_stats.items(): - try: + n_params = len(inspect.signature(stat_func).parameters.keys()) + if n_params == 3: + feature_stats[stat_name] = stat_func(masked, feat["properties"], rv_array) + # backwards compatible with two-argument function + elif n_params == 2: feature_stats[stat_name] = stat_func(masked, feat["properties"]) - except TypeError: - # backwards compatible with single-argument function + # backwards compatible with single-argument function + else: feature_stats[stat_name] = stat_func(masked) if raster_out: From 12d5169bb72c63bdd0ddc875434e012029fc03ec Mon Sep 17 00:00:00 2001 From: Eric Boucher Date: Wed, 28 Jun 2023 23:25:51 +0200 Subject: [PATCH 2/2] Add test --- tests/test_zonal.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_zonal.py b/tests/test_zonal.py index 02ef6c5..9b0c514 100644 --- a/tests/test_zonal.py +++ b/tests/test_zonal.py @@ -315,6 +315,18 @@ def mymean_prop(x, prop): for i in range(len(stats)): assert stats[i]["mymean_prop"] == stats[i]["mean"] * (i + 1) +def test_add_stats_prop_and_array(): + polygons = os.path.join(DATA, "polygons.shp") + + def mymean_prop_and_array(x, prop, rv_array): + # confirm that the object exists and is accessible. + assert rv_array is not None + return np.ma.mean(x) * prop["id"] + + stats = zonal_stats(polygons, raster, add_stats={"mymean_prop_and_array": mymean_prop_and_array}) + for i in range(len(stats)): + assert stats[i]["mymean_prop_and_array"] == stats[i]["mean"] * (i + 1) + def test_mini_raster(): polygons = os.path.join(DATA, "polygons.shp")