From c169703016c3beb44dba8798a6f9e4dea8b66452 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 19 Feb 2026 14:34:59 -0600 Subject: [PATCH 1/2] test: add vmap gradient test for Gaussian object --- tests/jax/test_deriv_gsobject.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/jax/test_deriv_gsobject.py b/tests/jax/test_deriv_gsobject.py index 1c24beb9..7f3ee17a 100644 --- a/tests/jax/test_deriv_gsobject.py +++ b/tests/jax/test_deriv_gsobject.py @@ -70,3 +70,27 @@ def _run(val_): atol = 1e-5 np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=atol) + + +def test_deriv_gsobject_params_vmap(): + val = jnp.array([2.0, 3.0]) + eps = 1e-5 + + def _run(val_): + return jnp.max( + jgs.Gaussian( + half_light_radius=val_, + gsparams=jgs.GSParams(minimum_fft_size=64, maximum_fft_size=64), + ) + .drawImage(nx=48, ny=48, scale=0.2) + .array[24, 24] + ** 2 + ) + + _vmap_run = jax.vmap(_run) + gfunc = jax.jit(jax.vmap(jax.grad(_run))) + gval = gfunc(val) + + gfdiff = (_vmap_run(val + eps) - _vmap_run(val - eps)) / 2.0 / eps + + np.testing.assert_allclose(gval, gfdiff, rtol=0, atol=1e-6) From fec8e020f642a6d07e3d2c3c21bea78d229b4171 Mon Sep 17 00:00:00 2001 From: "Matthew R. Becker" Date: Thu, 19 Feb 2026 14:36:34 -0600 Subject: [PATCH 2/2] Apply suggestion from @beckermr --- tests/jax/test_deriv_gsobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_deriv_gsobject.py b/tests/jax/test_deriv_gsobject.py index 7f3ee17a..ae4bee5b 100644 --- a/tests/jax/test_deriv_gsobject.py +++ b/tests/jax/test_deriv_gsobject.py @@ -87,7 +87,7 @@ def _run(val_): ** 2 ) - _vmap_run = jax.vmap(_run) + _vmap_run = jax.jit(jax.vmap(_run)) gfunc = jax.jit(jax.vmap(jax.grad(_run))) gval = gfunc(val)