From 5e3e747ec0237c36e9ddb1a84542de01c09eeda3 Mon Sep 17 00:00:00 2001 From: Stephan Kramer Date: Mon, 1 Dec 2025 15:57:29 +0000 Subject: [PATCH] BUG: fix Function._ad_dot for H1 (#4590) Fixes _ad_dot for H1 riesz_representation and adds tests for this and other representations. --- firedrake/adjoint_utils/function.py | 2 +- .../adjoint/test_reduced_functional.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/firedrake/adjoint_utils/function.py b/firedrake/adjoint_utils/function.py index b66bca7716..da2b3051f4 100644 --- a/firedrake/adjoint_utils/function.py +++ b/firedrake/adjoint_utils/function.py @@ -269,7 +269,7 @@ def _ad_dot(self, other, options=None): return assemble(firedrake.inner(self, other)*firedrake.dx) elif riesz_representation == "H1": return assemble((firedrake.inner(self, other) - + firedrake.inner(firedrake.grad(self), other))*firedrake.dx) + + firedrake.inner(firedrake.grad(self), firedrake.grad(other)))*firedrake.dx) else: raise NotImplementedError( "Unknown Riesz representation %s" % riesz_representation) diff --git a/tests/firedrake/adjoint/test_reduced_functional.py b/tests/firedrake/adjoint/test_reduced_functional.py index 06469d1bca..33803b3f2f 100644 --- a/tests/firedrake/adjoint/test_reduced_functional.py +++ b/tests/firedrake/adjoint/test_reduced_functional.py @@ -283,3 +283,31 @@ def test_real_space_parallel(): Jhat = ReducedFunctional(J, Control(m)) opt = minimize(Jhat) parallel_assert(np.allclose(opt.dat.data_ro, 1)) + + +@pytest.mark.parametrize("riesz_representation", ["l2", "L2", "H1"]) +@pytest.mark.skipcomplex +def test_ad_dot(riesz_representation): + mesh = IntervalMesh(10, 0, 1) + V = FunctionSpace(mesh, "Lagrange", 1) + + c = Constant(1) + f = Function(V) + x = SpatialCoordinate(mesh) + f.interpolate(x[0]) + + u = Function(V) + v = TestFunction(V) + bc = DirichletBC(V, Constant(1), "on_boundary") + + F = inner(grad(u), grad(v))*dx - f**2*v*dx + solve(F == 0, u, bc) + + J = assemble(c**2*u*dx) + Jhat = ReducedFunctional(J, Control(f, riesz_map=riesz_representation)) + dJhat = Jhat.derivative(apply_riesz=True) + + h = Function(V) + h.dat.data[:] = rand(V.dof_dset.size) + dJdh = dJhat._ad_dot(h, options={'riesz_representation': riesz_representation}) + assert taylor_test(Jhat, f, h, dJdm=dJdh) > 1.9