From d0822ec6177b8f7e5a78fefc286cf5303c8f818b Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 29 Aug 2025 15:18:18 +0100 Subject: [PATCH 01/33] Functional transformation for mesh-independent optimization --- firedrake/adjoint/transformed_functional.py | 192 ++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 firedrake/adjoint/transformed_functional.py diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py new file mode 100644 index 0000000000..fcf2cffdd7 --- /dev/null +++ b/firedrake/adjoint/transformed_functional.py @@ -0,0 +1,192 @@ +from functools import cached_property + +import firedrake as fd +import finat +from pyadjoint.control import Control +from pyadjoint.enlisting import Enlist +from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional +from pyadjoint.tape import no_annotations + +__all__ = \ + [ + "L2TransformedFunctional" + ] + + +class L2Cholesky: + def __init__(self, space): + self._space = space + + @property + def space(self): + return self._space + + @cached_property + def M(self): + return fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx) + + @cached_property + def solver(self): + return fd.LinearSolver(self.M, solver_parameters={"ksp_type": "preonly", + "pc_type": "cholesky", + "pc_factor_mat_ordering_type": "nd"}) + + @cached_property + def pc(self): + return self.solver.ksp.getPC() + + def C_inv_action(self, u): + v = fd.Cofunction(self.space.dual()) + with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v: + self.pc.applySymmetricLeft(u_v, v_v) + return v + + def C_T_inv_action(self, u): + v = fd.Function(self.space) + with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v: + self.pc.applySymmetricRight(u_v, v_v) + return v + + +def dg_space(space): + return fd.FunctionSpace( + space.mesh(), + finat.ufl.BrokenElement(space.ufl_element())) + + +class L2TransformedFunctional(AbstractReducedFunctional): + r"""Represents the functional + + .. math:: + + J \circ \Pi \circ \Xi + + where + + - :math:`J` is the functional definining an optimization problem. + - :math:`\Pi` is the :math:`L^2` projection onto a discontinuous + superspace of the control space. + - :math:`\Xi` represents a change of basis from an :math:`L^2` + othogonal basis to the finite element basis for the discontinuous + superspace. + + The optimization is therefore transformed into an optimization problem + using an :math:`L^2` orthonormal basis for a discontinuous finite element + space. This can be used for mesh-independent optimization for libraries + which support only an :math:`l_2` inner product. + + Parameters + ---------- + + functional : OverloadedType + Functional defining the optimization problem, :math:`J`. + controls : Control or Sequence[Control, ...] + Controls. Must be :class:`firedrake.Function` objects. + tape : Tape + Tape used in evaluations involving :math:`J`. + """ + + @no_annotations + def __init__(self, functional, controls, *, tape=None, + project_solver_parameters=None): + if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)): + raise TypeError("controls must be Function objects") + if project_solver_parameters is None: + project_solver_parameters = {} + + self._J = ReducedFunctional(functional, controls, tape=tape) + self._space = tuple(control.control.function_space() + for control in self._J.controls) + self._space_d = tuple(map(dg_space, self._space)) + self._C = tuple(map(L2Cholesky, self._space_d)) + self._controls = tuple(Control(fd.Cofunction(space_d.dual()), riesz_map="l2") + for space_d in self._space_d) + self._project_solver_parameters = dict(project_solver_parameters) + + # Map the initial guess + controls_t = self._primal_transform(tuple(control.control for control in self._J.controls)) + for control, control_t in zip(self._controls, controls_t): + control.assign(control_t) + + @property + def controls(self) -> list[Control]: + return list(self._controls) + + def _primal_transform(self, u): + u = Enlist(u) + if len(u) != len(self.controls): + raise ValueError("Invalid length") + + def transform(C, u, space_d): + # Map function to transformed 'cofunction': + # C_W^{-1} P_{VW}^* + v = fd.assemble(fd.inner(u, fd.TestFunction(space_d)) * fd.dx) + v = C.C_inv_action(v) + return v + + v = tuple(map(transform, self._C, u, self._space_d)) + return u.delist(v) + + def _dual_transform(self, u): + u = Enlist(u) + if len(u) != len(self.controls): + raise ValueError("Invalid length") + + def transform(C, u, space): + # Map transformed 'cofunction' to function: + # M_V^{-1} P_{VW} C_W^{-*} + v = C.C_T_inv_action(u) # for complex would need to be adjoint + v = fd.Function(space).project( + v, solver_parameters=self._project_solver_parameters) + return v + + v = tuple(map(transform, self._C, u, self._space)) + return u.delist(v) + + @no_annotations + def map_result(self, m): + """Map the result of an optimization. + + Parameters + ---------- + + m : firedrake.Cofunction or Sequence[firedrake.Cofunction, ...] + The result of the optimization. Represents an expansion in an + :math:`L^2` orthonormal basis for the discontinuous space. + + Returns + ------- + + firedrake.Function or list[firedrake.Function, ...] + The mapped control value in the domain of the functional + :math:`J`. + """ + + return self._dual_transform(m) + + @no_annotations + def __call__(self, values): + v = self._dual_transform(values) + return self._J(v) + + @no_annotations + def derivative(self, adj_input=1.0, apply_riesz=False): + if adj_input != 1: + raise ValueError("adj_input != 1 not supported") + + u = Enlist(self._J.derivative()) + v = self._primal_transform( + tuple(u_i.riesz_representation(solver_options=self._project_solver_parameters) + for u_i in u)) + if apply_riesz: + v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) + for v_i, control in zip(v, self.controls)) + return u.delist(v) + + @no_annotations + def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False): + raise NotImplementedError("hessian not implemented") + + @no_annotations + def tlm(self, m_dot): + raise NotImplementedError("tlm not implemented") From be85023031d563fb63964413d5b62c458917fca3 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 29 Aug 2025 18:38:15 +0100 Subject: [PATCH 02/33] Mass inverse test for L2TransformedFunctional --- .../adjoint/test_transformed_functional.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tests/firedrake/adjoint/test_transformed_functional.py diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py new file mode 100644 index 0000000000..6d7f383090 --- /dev/null +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -0,0 +1,65 @@ +import firedrake as fd +from firedrake.adjoint import ( + Control, ReducedFunctional, continue_annotation, minimize, + pause_annotation) +from firedrake.adjoint.transformed_functional import L2TransformedFunctional +import numpy as np +from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy +from pyadjoint.tape import set_working_tape +import pytest + + +@pytest.fixture(scope="module", autouse=True) +def setup_tape(): + with set_working_tape(): + pause_annotation() + yield + pause_annotation() + + +def test_transformed_functional_mass_inverse(): + mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") + x, y = fd.SpatialCoordinate(mesh) + space = fd.FunctionSpace(mesh, "Lagrange", 1) + + m_ref = fd.Function(space, name="m_ref").interpolate( + fd.exp(x) * fd.sin(fd.pi * x) * fd.cos(fd.pi * x)) + + def forward(m): + return fd.assemble(fd.inner(m - m_ref, m - m_ref) * fd.dx) + + continue_annotation() + m_0 = fd.Function(space, name="m_0") + J = forward(m_0) + pause_annotation() + c = Control(m_0, riesz_map="l2") + + class MinimizeCallback: + def __init__(self): + self._ncalls = 0 + + @property + def ncalls(self): + return self._ncalls + + def __call__(self, xk): + self._ncalls += 1 + + J_hat = ReducedFunctional(J, c) + cb = MinimizeCallback() + m_opt = minimize(J_hat, method="L-BFGS-B", + callback=cb, + options={"ftol": 0, + "gtol": 1e-6}) + assert fd.norm(m_opt - m_ref, "L2") < 1e-4 + assert cb.ncalls > 10 # == 13 + + J_hat = L2TransformedFunctional(J, c, alpha=1) + cb = MinimizeCallback() + m_opt = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B", + callback=cb, + options={"ftol": 0, + "gtol": 1e-6}) + m_opt = J_hat.map_result(m_opt) + assert fd.norm(m_opt - m_ref, "L2") < 1e-10 + assert cb.ncalls == 2 From ec0db62df9e30275bdfdf72c75ab40fda7d4b664 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 29 Aug 2025 18:59:08 +0100 Subject: [PATCH 03/33] Optionally resolve ill-posedness introduced by the extension to a discontinuous space --- firedrake/adjoint/transformed_functional.py | 64 ++++++++++++++++----- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index fcf2cffdd7..d8b4938394 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -1,7 +1,9 @@ from functools import cached_property +from operator import itemgetter import firedrake as fd import finat +from petsctools import flatten_parameters from pyadjoint.control import Control from pyadjoint.enlisting import Enlist from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional @@ -84,11 +86,25 @@ class L2TransformedFunctional(AbstractReducedFunctional): Controls. Must be :class:`firedrake.Function` objects. tape : Tape Tape used in evaluations involving :math:`J`. + alpha : Real + Modifies the derivative, equivalent to adding an extra term to + :math:`J \circ \Pi` + + .. math:: + + \frac{1}{2} \alpha \left\| m_D - \Pi m_D \right\|_{L^2}^2. + + e.g. in a minimization problem this adds a penalty term which can + be used to avoid ill-posedness due to the use of a larger discontinuous + space. + project_solver_parameters : Mapping + Solver parameters for an :math:`L^2` projection onto the domain of the + functional :math:`J`. """ @no_annotations def __init__(self, functional, controls, *, tape=None, - project_solver_parameters=None): + alpha=0, project_solver_parameters=None): if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)): raise TypeError("controls must be Function objects") if project_solver_parameters is None: @@ -101,7 +117,9 @@ def __init__(self, functional, controls, *, tape=None, self._C = tuple(map(L2Cholesky, self._space_d)) self._controls = tuple(Control(fd.Cofunction(space_d.dual()), riesz_map="l2") for space_d in self._space_d) - self._project_solver_parameters = dict(project_solver_parameters) + self._alpha = alpha + self._project_solver_parameters = flatten_parameters(project_solver_parameters) + self._m_k = None # Map the initial guess controls_t = self._primal_transform(tuple(control.control for control in self._J.controls)) @@ -112,19 +130,27 @@ def __init__(self, functional, controls, *, tape=None, def controls(self) -> list[Control]: return list(self._controls) - def _primal_transform(self, u): + def _primal_transform(self, u, u_D=None): u = Enlist(u) if len(u) != len(self.controls): raise ValueError("Invalid length") + if u_D is None: + u_D = tuple(None for _ in u) + else: + u_D = Enlist(u_D) + if len(u_D) != len(self.controls): + raise ValueError("Invalid length") - def transform(C, u, space_d): + def transform(C, u, u_D, space_d): # Map function to transformed 'cofunction': # C_W^{-1} P_{VW}^* v = fd.assemble(fd.inner(u, fd.TestFunction(space_d)) * fd.dx) + if u_D is not None: + v.dat.axpy(1, u_D.dat) v = C.C_inv_action(v) return v - v = tuple(map(transform, self._C, u, self._space_d)) + v = tuple(map(transform, self._C, u, u_D, self._space_d)) return u.delist(v) def _dual_transform(self, u): @@ -136,12 +162,12 @@ def transform(C, u, space): # Map transformed 'cofunction' to function: # M_V^{-1} P_{VW} C_W^{-*} v = C.C_T_inv_action(u) # for complex would need to be adjoint - v = fd.Function(space).project( + w = fd.Function(space).project( v, solver_parameters=self._project_solver_parameters) - return v + return v, w - v = tuple(map(transform, self._C, u, self._space)) - return u.delist(v) + vw = tuple(map(transform, self._C, u, self._space)) + return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw))) @no_annotations def map_result(self, m): @@ -162,12 +188,13 @@ def map_result(self, m): :math:`J`. """ - return self._dual_transform(m) + _, m_J = self._dual_transform(m) + return m_J @no_annotations def __call__(self, values): - v = self._dual_transform(values) - return self._J(v) + _, m_J = self._m_k = self._dual_transform(values) + return self._J(m_J) @no_annotations def derivative(self, adj_input=1.0, apply_riesz=False): @@ -175,9 +202,16 @@ def derivative(self, adj_input=1.0, apply_riesz=False): raise ValueError("adj_input != 1 not supported") u = Enlist(self._J.derivative()) - v = self._primal_transform( - tuple(u_i.riesz_representation(solver_options=self._project_solver_parameters) - for u_i in u)) + + v = tuple(u_i.riesz_representation(solver_options=self._project_solver_parameters) + for u_i in u) + if self._alpha == 0: + v_alpha = None + else: + v_alpha = tuple( + fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, fd.TestFunction(space_d)) * fd.dx) + for space_d, m_D_i, m_J_i in zip(self._space_d, *self._m_k)) + v = self._primal_transform(v, v_alpha) if apply_riesz: v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) for v_i, control in zip(v, self.controls)) From e8f33f9e228ebde4144033ece60fbfdba5a41bc0 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 29 Aug 2025 19:06:34 +0100 Subject: [PATCH 04/33] Tidying --- firedrake/adjoint/transformed_functional.py | 20 +++++++++---------- .../adjoint/test_transformed_functional.py | 7 +++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index d8b4938394..7245cb7dc4 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -69,7 +69,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): - :math:`\Pi` is the :math:`L^2` projection onto a discontinuous superspace of the control space. - :math:`\Xi` represents a change of basis from an :math:`L^2` - othogonal basis to the finite element basis for the discontinuous + orthonormal basis to the finite element basis for the discontinuous superspace. The optimization is therefore transformed into an optimization problem @@ -113,10 +113,10 @@ def __init__(self, functional, controls, *, tape=None, self._J = ReducedFunctional(functional, controls, tape=tape) self._space = tuple(control.control.function_space() for control in self._J.controls) - self._space_d = tuple(map(dg_space, self._space)) - self._C = tuple(map(L2Cholesky, self._space_d)) - self._controls = tuple(Control(fd.Cofunction(space_d.dual()), riesz_map="l2") - for space_d in self._space_d) + self._space_D = tuple(map(dg_space, self._space)) + self._C = tuple(map(L2Cholesky, self._space_D)) + self._controls = tuple(Control(fd.Cofunction(space_D.dual()), riesz_map="l2") + for space_D in self._space_D) self._alpha = alpha self._project_solver_parameters = flatten_parameters(project_solver_parameters) self._m_k = None @@ -141,16 +141,16 @@ def _primal_transform(self, u, u_D=None): if len(u_D) != len(self.controls): raise ValueError("Invalid length") - def transform(C, u, u_D, space_d): + def transform(C, u, u_D, space_D): # Map function to transformed 'cofunction': # C_W^{-1} P_{VW}^* - v = fd.assemble(fd.inner(u, fd.TestFunction(space_d)) * fd.dx) + v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx) if u_D is not None: v.dat.axpy(1, u_D.dat) v = C.C_inv_action(v) return v - v = tuple(map(transform, self._C, u, u_D, self._space_d)) + v = tuple(map(transform, self._C, u, u_D, self._space_D)) return u.delist(v) def _dual_transform(self, u): @@ -209,8 +209,8 @@ def derivative(self, adj_input=1.0, apply_riesz=False): v_alpha = None else: v_alpha = tuple( - fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, fd.TestFunction(space_d)) * fd.dx) - for space_d, m_D_i, m_J_i in zip(self._space_d, *self._m_k)) + fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, fd.TestFunction(space_D)) * fd.dx) + for space_D, m_D_i, m_J_i in zip(self._space_D, *self._m_k)) v = self._primal_transform(v, v_alpha) if apply_riesz: v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index 6d7f383090..b5335d0f11 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -3,7 +3,6 @@ Control, ReducedFunctional, continue_annotation, minimize, pause_annotation) from firedrake.adjoint.transformed_functional import L2TransformedFunctional -import numpy as np from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy from pyadjoint.tape import set_working_tape import pytest @@ -23,7 +22,7 @@ def test_transformed_functional_mass_inverse(): space = fd.FunctionSpace(mesh, "Lagrange", 1) m_ref = fd.Function(space, name="m_ref").interpolate( - fd.exp(x) * fd.sin(fd.pi * x) * fd.cos(fd.pi * x)) + fd.exp(x) * fd.sin(fd.pi * x) * fd.cos(fd.pi * y)) def forward(m): return fd.assemble(fd.inner(m - m_ref, m - m_ref) * fd.dx) @@ -52,7 +51,7 @@ def __call__(self, xk): options={"ftol": 0, "gtol": 1e-6}) assert fd.norm(m_opt - m_ref, "L2") < 1e-4 - assert cb.ncalls > 10 # == 13 + assert cb.ncalls > 10 # == 14 J_hat = L2TransformedFunctional(J, c, alpha=1) cb = MinimizeCallback() @@ -61,5 +60,5 @@ def __call__(self, xk): options={"ftol": 0, "gtol": 1e-6}) m_opt = J_hat.map_result(m_opt) - assert fd.norm(m_opt - m_ref, "L2") < 1e-10 + assert fd.norm(m_opt - m_ref, "L2") < 1e-8 assert cb.ncalls == 2 From b1c9188030fa5a5d8bcc26d1fa57970a517192d0 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 1 Sep 2025 09:40:58 +0100 Subject: [PATCH 05/33] Adding missing functional term --- firedrake/adjoint/transformed_functional.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 7245cb7dc4..3f29b93ae5 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -87,7 +87,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): tape : Tape Tape used in evaluations involving :math:`J`. alpha : Real - Modifies the derivative, equivalent to adding an extra term to + Modifies the functional, equivalent to adding an extra term to :math:`J \circ \Pi` .. math:: @@ -193,8 +193,12 @@ def map_result(self, m): @no_annotations def __call__(self, values): - _, m_J = self._m_k = self._dual_transform(values) - return self._J(m_J) + m_D, m_J = self._m_k = self._dual_transform(values) + J = self._J(m_J) + if self._alpha != 0: + for m_D, m_J in zip(*self._m_k): + J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D - m_J, m_D - m_J) * fd.dx) + return J @no_annotations def derivative(self, adj_input=1.0, apply_riesz=False): From fb4222ef7c215fbce3439a2b28a22151c6fa39b8 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 1 Sep 2025 18:46:39 +0100 Subject: [PATCH 06/33] Poisson equation constrained optimization unit test --- .../adjoint/test_transformed_functional.py | 154 +++++++++++++++--- 1 file changed, 127 insertions(+), 27 deletions(-) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index b5335d0f11..e6ffe72bf3 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -1,11 +1,16 @@ +from collections.abc import Sequence +from functools import partial + import firedrake as fd from firedrake.adjoint import ( Control, ReducedFunctional, continue_annotation, minimize, pause_annotation) from firedrake.adjoint.transformed_functional import L2TransformedFunctional +import numpy as np from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy from pyadjoint.tape import set_working_tape import pytest +import ufl @pytest.fixture(scope="module", autouse=True) @@ -16,49 +21,144 @@ def setup_tape(): pause_annotation() +class MinimizeCallback(Sequence): + def __init__(self, m_0, error_norm): + self._space = m_0.function_space() + self._error_norm = error_norm + self._data = [] + + self(np.asarray(m_0._ad_to_list(m_0))) + + def __len__(self): + return len(self._data) + + def __getitem__(self, key): + return self._data[key] + + def __call__(self, xk): + k = len(self) + if ufl.duals.is_primal(self._space): + m_k = fd.Function(self._space, name="m_k") + elif ufl.duals.is_dual(self._space): + m_k = fd.Cofunction(self._space, name="m_k") + else: + raise ValueError("space is neither primal nor dual") + m_k._ad_assign_numpy(m_k, xk, 0) + error_norm = self._error_norm(m_k) + print(f"{k=} {error_norm=:6g}") + self._data.append(error_norm) + + def test_transformed_functional_mass_inverse(): mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") x, y = fd.SpatialCoordinate(mesh) space = fd.FunctionSpace(mesh, "Lagrange", 1) - m_ref = fd.Function(space, name="m_ref").interpolate( - fd.exp(x) * fd.sin(fd.pi * x) * fd.cos(fd.pi * y)) - def forward(m): return fd.assemble(fd.inner(m - m_ref, m - m_ref) * fd.dx) + m_ref = fd.Function(space, name="m_ref").interpolate( + fd.exp(x) * fd.sin(fd.pi * x) * fd.cos(fd.pi * y)) + continue_annotation() m_0 = fd.Function(space, name="m_0") J = forward(m_0) pause_annotation() + + c = Control(m_0, riesz_map="l2") + J_hat = ReducedFunctional(J, c) + + def error_norm(m): + return fd.norm(m - m_ref, norm_type="L2") + + cb = MinimizeCallback(m_0, error_norm) + _ = minimize(J_hat, method="L-BFGS-B", + callback=cb, + options={"ftol": 0, + "gtol": 1e-6}) + assert 1e-6 < cb[-1] < 1e-5 + assert len(cb) > 12 # == 15 + c = Control(m_0, riesz_map="l2") + J_hat = L2TransformedFunctional(J, c, alpha=1) - class MinimizeCallback: - def __init__(self): - self._ncalls = 0 + def error_norm(m): + m = J_hat.map_result(m) + return fd.norm(m - m_ref, norm_type="L2") - @property - def ncalls(self): - return self._ncalls + cb = MinimizeCallback(J_hat.controls[0].control, error_norm) + _ = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B", + callback=cb, + options={"ftol": 0, + "gtol": 1e-6}) + assert cb[-1] < 1e-8 + assert len(cb) == 3 - def __call__(self, xk): - self._ncalls += 1 +def test_transformed_functional_poisson(): + mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") + x, y = fd.SpatialCoordinate(mesh) + space = fd.FunctionSpace(mesh, "Lagrange", 1) + test = fd.TestFunction(space) + trial = fd.TrialFunction(space) + bc = fd.DirichletBC(space, 0, "on_boundary") + + def pre_process(m): + m_0 = fd.Function(space, name="m_0").assign(m) + bc.apply(m_0) + m_1 = fd.Function(space, name="m_1").assign(m - m_0) + return m_0, m_1 + + def forward(m): + m_0, m_1 = pre_process(m) + u = fd.Function(space, name="u") + fd.solve(fd.inner(fd.grad(trial), fd.grad(test)) * fd.dx + == fd.inner(m_0, test) * fd.dx, + u, bc) + return m_0, m_1, u + + def forward_J(m, u_ref, alpha): + _, m_1, u = forward(m) + return fd.assemble(fd.inner(u - u_ref, u - u_ref) * fd.dx + + fd.Constant(alpha ** 2) * fd.inner(m_1, m_1) * fd.ds) + + m_ref = fd.Function(space, name="m_ref").interpolate( + fd.exp(x) * fd.sin(fd.pi * x) * fd.sin(fd.pi * y)) + m_ref, _, u_ref = forward(m_ref) + forward_J = partial(forward_J, u_ref=u_ref, alpha=1) + + continue_annotation() + m_0 = fd.Function(space, name="m_0") + J = forward_J(m_0) + pause_annotation() + + c = Control(m_0, riesz_map="l2") J_hat = ReducedFunctional(J, c) - cb = MinimizeCallback() - m_opt = minimize(J_hat, method="L-BFGS-B", - callback=cb, - options={"ftol": 0, - "gtol": 1e-6}) - assert fd.norm(m_opt - m_ref, "L2") < 1e-4 - assert cb.ncalls > 10 # == 14 - J_hat = L2TransformedFunctional(J, c, alpha=1) - cb = MinimizeCallback() - m_opt = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B", - callback=cb, - options={"ftol": 0, - "gtol": 1e-6}) - m_opt = J_hat.map_result(m_opt) - assert fd.norm(m_opt - m_ref, "L2") < 1e-8 - assert cb.ncalls == 2 + def error_norm(m): + m, _ = pre_process(m) + return fd.norm(m - m_ref, norm_type="L2") + + cb = MinimizeCallback(m_0, error_norm) + _ = minimize(J_hat, method="L-BFGS-B", + callback=cb, + options={"ftol": 0, + "gtol": 1e-10}) + assert 1e-2 < cb[-1] < 5e-2 + assert len(cb) > 80 # == 85 + + c = Control(m_0, riesz_map="l2") + J_hat = L2TransformedFunctional(J, c, alpha=1e-5) + + def error_norm(m): + m = J_hat.map_result(m) + m, _ = pre_process(m) + return fd.norm(m - m_ref, norm_type="L2") + + cb = MinimizeCallback(J_hat.controls[0].control, error_norm) + _ = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B", + callback=cb, + options={"ftol": 0, + "gtol": 1e-10}) + assert 1e-4 < cb[-1] < 5e-4 + assert len(cb) < 55 # == 50 From 5117ed60dc677de63ee2fc30de0eb815d19f7bf0 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 1 Sep 2025 19:07:02 +0100 Subject: [PATCH 07/33] Add checks for number of forward evaluations --- .../adjoint/test_transformed_functional.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index e6ffe72bf3..eea2a68ec3 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from functools import partial +from functools import partial, wraps import firedrake as fd from firedrake.adjoint import ( @@ -21,6 +21,37 @@ def setup_tape(): pause_annotation() +def count_calls(cls): + def init(fn): + @wraps(fn) + def wrapped(self, *args, **kwargs): + self._test_transformed_functional__ncalls = 0 + return fn(self, *args, **kwargs) + return wrapped + + def call(fn): + @wraps(fn) + def wrapped(self, *args, **kwargs): + self._test_transformed_functional__ncalls += 1 + return fn(self, *args, **kwargs) + return wrapped + + cls.__init__ = init(cls.__init__) + cls.__call__ = call(cls.__call__) + + return cls + + +@count_calls +class ReducedFunctional(ReducedFunctional): + pass + + +@count_calls +class L2TransformedFunctional(L2TransformedFunctional): + pass + + class MinimizeCallback(Sequence): def __init__(self, m_0, error_norm): self._space = m_0.function_space() @@ -78,6 +109,7 @@ def error_norm(m): "gtol": 1e-6}) assert 1e-6 < cb[-1] < 1e-5 assert len(cb) > 12 # == 15 + assert J_hat._test_transformed_functional__ncalls > 12 # == 15 c = Control(m_0, riesz_map="l2") J_hat = L2TransformedFunctional(J, c, alpha=1) @@ -93,6 +125,7 @@ def error_norm(m): "gtol": 1e-6}) assert cb[-1] < 1e-8 assert len(cb) == 3 + assert J_hat._test_transformed_functional__ncalls == 3 def test_transformed_functional_poisson(): @@ -146,6 +179,7 @@ def error_norm(m): "gtol": 1e-10}) assert 1e-2 < cb[-1] < 5e-2 assert len(cb) > 80 # == 85 + assert J_hat._test_transformed_functional__ncalls > 90 # == 95 c = Control(m_0, riesz_map="l2") J_hat = L2TransformedFunctional(J, c, alpha=1e-5) @@ -162,3 +196,4 @@ def error_norm(m): "gtol": 1e-10}) assert 1e-4 < cb[-1] < 5e-4 assert len(cb) < 55 # == 50 + assert J_hat._test_transformed_functional__ncalls < 55 # == 51 From 6fa3b2ff65b822a05f54244ae0a0f68c0d525bee Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 1 Sep 2025 19:16:40 +0100 Subject: [PATCH 08/33] Tidying --- .../adjoint/test_transformed_functional.py | 41 +++++++------------ 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index eea2a68ec3..ee09c66999 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from functools import partial, wraps +from functools import partial import firedrake as fd from firedrake.adjoint import ( @@ -21,35 +21,24 @@ def setup_tape(): pause_annotation() -def count_calls(cls): - def init(fn): - @wraps(fn) - def wrapped(self, *args, **kwargs): - self._test_transformed_functional__ncalls = 0 - return fn(self, *args, **kwargs) - return wrapped - - def call(fn): - @wraps(fn) - def wrapped(self, *args, **kwargs): - self._test_transformed_functional__ncalls += 1 - return fn(self, *args, **kwargs) - return wrapped - - cls.__init__ = init(cls.__init__) - cls.__call__ = call(cls.__call__) - - return cls - - -@count_calls class ReducedFunctional(ReducedFunctional): - pass + def __init__(self, *args, **kwargs): + self._test_transformed_functional__ncalls = 0 + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + self._test_transformed_functional__ncalls += 1 + return super().__call__(*args, **kwargs) -@count_calls class L2TransformedFunctional(L2TransformedFunctional): - pass + def __init__(self, *args, **kwargs): + self._test_transformed_functional__ncalls = 0 + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + self._test_transformed_functional__ncalls += 1 + return super().__call__(*args, **kwargs) class MinimizeCallback(Sequence): From dd94235d2d24df30cab731d536c548dfd8499fee Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 2 Sep 2025 09:38:57 +0100 Subject: [PATCH 09/33] Fix + tidying --- firedrake/adjoint/transformed_functional.py | 2 +- tests/firedrake/adjoint/test_transformed_functional.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 3f29b93ae5..27bd690f43 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -66,7 +66,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): where - :math:`J` is the functional definining an optimization problem. - - :math:`\Pi` is the :math:`L^2` projection onto a discontinuous + - :math:`\Pi` is the :math:`L^2` projection from a discontinuous superspace of the control space. - :math:`\Xi` represents a change of basis from an :math:`L^2` orthonormal basis to the finite element basis for the discontinuous diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index ee09c66999..da7c15435a 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -84,8 +84,8 @@ def forward(m): m_0 = fd.Function(space, name="m_0") J = forward(m_0) pause_annotation() - c = Control(m_0, riesz_map="l2") + J_hat = ReducedFunctional(J, c) def error_norm(m): @@ -100,7 +100,6 @@ def error_norm(m): assert len(cb) > 12 # == 15 assert J_hat._test_transformed_functional__ncalls > 12 # == 15 - c = Control(m_0, riesz_map="l2") J_hat = L2TransformedFunctional(J, c, alpha=1) def error_norm(m): @@ -153,8 +152,8 @@ def forward_J(m, u_ref, alpha): m_0 = fd.Function(space, name="m_0") J = forward_J(m_0) pause_annotation() - c = Control(m_0, riesz_map="l2") + J_hat = ReducedFunctional(J, c) def error_norm(m): @@ -170,7 +169,6 @@ def error_norm(m): assert len(cb) > 80 # == 85 assert J_hat._test_transformed_functional__ncalls > 90 # == 95 - c = Control(m_0, riesz_map="l2") J_hat = L2TransformedFunctional(J, c, alpha=1e-5) def error_norm(m): From f5f40486d76789c705116cd255204178166512a7 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 2 Sep 2025 10:39:13 +0100 Subject: [PATCH 10/33] DG optimization --- firedrake/adjoint/transformed_functional.py | 61 ++++++++++++------- .../adjoint/test_transformed_functional.py | 15 +++-- 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 27bd690f43..f014aba477 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -50,10 +50,16 @@ def C_T_inv_action(self, u): return v +def is_dg_space(space): + e, _ = finat.element_factory.convert(space.ufl_element()) + return e.is_dg() + + def dg_space(space): - return fd.FunctionSpace( - space.mesh(), - finat.ufl.BrokenElement(space.ufl_element())) + if is_dg_space(space): + return space + else: + return fd.FunctionSpace(space.mesh(), finat.ufl.BrokenElement(space.ufl_element())) class L2TransformedFunctional(AbstractReducedFunctional): @@ -122,7 +128,7 @@ def __init__(self, functional, controls, *, tape=None, self._m_k = None # Map the initial guess - controls_t = self._primal_transform(tuple(control.control for control in self._J.controls)) + controls_t = self._primal_transform(tuple(control.control for control in self._J.controls), apply_riesz=False) for control, control_t in zip(self._controls, controls_t): control.assign(control_t) @@ -130,7 +136,7 @@ def __init__(self, functional, controls, *, tape=None, def controls(self) -> list[Control]: return list(self._controls) - def _primal_transform(self, u, u_D=None): + def _primal_transform(self, u, u_D=None, *, apply_riesz=False): u = Enlist(u) if len(u) != len(self.controls): raise ValueError("Invalid length") @@ -141,16 +147,24 @@ def _primal_transform(self, u, u_D=None): if len(u_D) != len(self.controls): raise ValueError("Invalid length") - def transform(C, u, u_D, space_D): + def transform(C, u, u_D, space, space_D): # Map function to transformed 'cofunction': - # C_W^{-1} P_{VW}^* - v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx) + if apply_riesz: + # C_W^{-1} P_{VW}^* M_V^{-1} + if space is space_D: + v = u + else: + v = u.riesz_representation(solver_options=self._project_solver_parameters) + v = fd.assemble(fd.inner(v, fd.TestFunction(space_D)) * fd.dx) + else: + # C_W^{-1} P_{VW}^* + v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx) if u_D is not None: v.dat.axpy(1, u_D.dat) v = C.C_inv_action(v) return v - v = tuple(map(transform, self._C, u, u_D, self._space_D)) + v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D)) return u.delist(v) def _dual_transform(self, u): @@ -158,15 +172,18 @@ def _dual_transform(self, u): if len(u) != len(self.controls): raise ValueError("Invalid length") - def transform(C, u, space): + def transform(C, u, space, space_D): # Map transformed 'cofunction' to function: # M_V^{-1} P_{VW} C_W^{-*} v = C.C_T_inv_action(u) # for complex would need to be adjoint - w = fd.Function(space).project( - v, solver_parameters=self._project_solver_parameters) + if space is space_D: + w = v + else: + w = fd.Function(space).project( + v, solver_parameters=self._project_solver_parameters) return v, w - vw = tuple(map(transform, self._C, u, self._space)) + vw = tuple(map(transform, self._C, u, self._space, self._space_D)) return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw))) @no_annotations @@ -196,8 +213,9 @@ def __call__(self, values): m_D, m_J = self._m_k = self._dual_transform(values) J = self._J(m_J) if self._alpha != 0: - for m_D, m_J in zip(*self._m_k): - J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D - m_J, m_D - m_J) * fd.dx) + for space, space_D, m_D, m_J in zip(self._space, self._space_D, *self._m_k): + if space is not space_D: + J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D - m_J, m_D - m_J) * fd.dx) return J @no_annotations @@ -207,15 +225,16 @@ def derivative(self, adj_input=1.0, apply_riesz=False): u = Enlist(self._J.derivative()) - v = tuple(u_i.riesz_representation(solver_options=self._project_solver_parameters) - for u_i in u) if self._alpha == 0: v_alpha = None else: - v_alpha = tuple( - fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, fd.TestFunction(space_D)) * fd.dx) - for space_D, m_D_i, m_J_i in zip(self._space_D, *self._m_k)) - v = self._primal_transform(v, v_alpha) + v_alpha = [] + for space, space_D, m_D, m_J in zip(self._space, self._space_D, *self._m_k): + if space is space_D: + v_alpha.append(None) + else: + v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, fd.TestFunction(space_D)) * fd.dx)) + v = self._primal_transform(u, v_alpha, apply_riesz=True) if apply_riesz: v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) for v_i, control in zip(v, self.controls)) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index da7c15435a..fe6ce5711c 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -69,10 +69,11 @@ def __call__(self, xk): self._data.append(error_norm) -def test_transformed_functional_mass_inverse(): +@pytest.mark.parametrize("family", ("Lagrange", "Discontinuous Lagrange")) +def test_transformed_functional_mass_inverse(family): mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") x, y = fd.SpatialCoordinate(mesh) - space = fd.FunctionSpace(mesh, "Lagrange", 1) + space = fd.FunctionSpace(mesh, family, 1, variant="equispaced") def forward(m): return fd.assemble(fd.inner(m - m_ref, m - m_ref) * fd.dx) @@ -97,8 +98,14 @@ def error_norm(m): options={"ftol": 0, "gtol": 1e-6}) assert 1e-6 < cb[-1] < 1e-5 - assert len(cb) > 12 # == 15 - assert J_hat._test_transformed_functional__ncalls > 12 # == 15 + if family == "Lagrange": + assert len(cb) > 12 # == 15 + assert J_hat._test_transformed_functional__ncalls > 12 # == 15 + elif family == "Discontinuous Lagrange": + assert len(cb) == 5 + assert J_hat._test_transformed_functional__ncalls == 6 + else: + raise ValueError(f"Invalid element family: '{family}'") J_hat = L2TransformedFunctional(J, c, alpha=1) From b2a6a6b1cff0a95ede6f38360439579cec619785 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 2 Sep 2025 12:09:39 +0100 Subject: [PATCH 11/33] Add comment --- firedrake/adjoint/transformed_functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index f014aba477..9e86a0ad19 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -154,6 +154,8 @@ def transform(C, u, u_D, space, space_D): if space is space_D: v = u else: + # Might be replaced with interpolation (does this work for + # all spaces?) v = u.riesz_representation(solver_options=self._project_solver_parameters) v = fd.assemble(fd.inner(v, fd.TestFunction(space_D)) * fd.dx) else: From 83cad0c662cbd0d9ae5f14b701e94e452c883071 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 2 Sep 2025 12:18:26 +0100 Subject: [PATCH 12/33] Minor fix --- firedrake/adjoint/transformed_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 9e86a0ad19..57083a069e 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -130,7 +130,7 @@ def __init__(self, functional, controls, *, tape=None, # Map the initial guess controls_t = self._primal_transform(tuple(control.control for control in self._J.controls), apply_riesz=False) for control, control_t in zip(self._controls, controls_t): - control.assign(control_t) + control.control.assign(control_t) @property def controls(self) -> list[Control]: From 7db73d0ee0f9bb7cd4f9a074efeaa759cdd4ad05 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 4 Sep 2025 12:10:21 +0100 Subject: [PATCH 13/33] Add L2TransformedFunctional.hessian, add TAO test using tao_type_nls, other minor updates --- firedrake/adjoint/transformed_functional.py | 50 ++++++++---- .../adjoint/test_transformed_functional.py | 81 +++++++++++++++++++ 2 files changed, 117 insertions(+), 14 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 57083a069e..1f46f2e6b6 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -88,7 +88,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): functional : OverloadedType Functional defining the optimization problem, :math:`J`. - controls : Control or Sequence[Control, ...] + controls : Control or Sequence[Control] Controls. Must be :class:`firedrake.Function` objects. tape : Tape Tape used in evaluations involving :math:`J`. @@ -98,14 +98,15 @@ class L2TransformedFunctional(AbstractReducedFunctional): .. math:: - \frac{1}{2} \alpha \left\| m_D - \Pi m_D \right\|_{L^2}^2. + \frac{1}{2} \alpha \left\| m_D - \Pi ( m_D ) \right\|_{L^2}^2. e.g. in a minimization problem this adds a penalty term which can be used to avoid ill-posedness due to the use of a larger discontinuous space. project_solver_parameters : Mapping - Solver parameters for an :math:`L^2` projection onto the domain of the - functional :math:`J`. + Solver parameters for an :math:`L^2` projection from the discontinuous + space onto the control space. Ignored for controls in DG spaces, + where the projection is an identity. """ @no_annotations @@ -116,13 +117,15 @@ def __init__(self, functional, controls, *, tape=None, if project_solver_parameters is None: project_solver_parameters = {} + super().__init__() self._J = ReducedFunctional(functional, controls, tape=tape) self._space = tuple(control.control.function_space() for control in self._J.controls) self._space_D = tuple(map(dg_space, self._space)) self._C = tuple(map(L2Cholesky, self._space_D)) - self._controls = tuple(Control(fd.Cofunction(space_D.dual()), riesz_map="l2") + self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2") for space_D in self._space_D) + self._controls_delist = Enlist(Enlist(controls).delist(tuple(None for _ in self._controls))).delist self._alpha = alpha self._project_solver_parameters = flatten_parameters(project_solver_parameters) self._m_k = None @@ -133,8 +136,8 @@ def __init__(self, functional, controls, *, tape=None, control.control.assign(control_t) @property - def controls(self) -> list[Control]: - return list(self._controls) + def controls(self) -> Enlist[Control]: + return Enlist(self._controls_delist(self._controls)) def _primal_transform(self, u, u_D=None, *, apply_riesz=False): u = Enlist(u) @@ -164,7 +167,7 @@ def transform(C, u, u_D, space, space_D): if u_D is not None: v.dat.axpy(1, u_D.dat) v = C.C_inv_action(v) - return v + return v.riesz_representation("l2") v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D)) return u.delist(v) @@ -195,16 +198,15 @@ def map_result(self, m): Parameters ---------- - m : firedrake.Cofunction or Sequence[firedrake.Cofunction, ...] + m : firedrake.Cofunction or Sequence[firedrake.Cofunction] The result of the optimization. Represents an expansion in an :math:`L^2` orthonormal basis for the discontinuous space. Returns ------- - firedrake.Function or list[firedrake.Function, ...] - The mapped control value in the domain of the functional - :math:`J`. + firedrake.Function or list[firedrake.Function] + The mapped result in the original control space. """ _, m_J = self._dual_transform(m) @@ -223,7 +225,7 @@ def __call__(self, values): @no_annotations def derivative(self, adj_input=1.0, apply_riesz=False): if adj_input != 1: - raise ValueError("adj_input != 1 not supported") + raise NotImplementedError("adj_input != 1 not supported") u = Enlist(self._J.derivative()) @@ -244,7 +246,27 @@ def derivative(self, adj_input=1.0, apply_riesz=False): @no_annotations def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False): - raise NotImplementedError("hessian not implemented") + if hessian_input is not None: + raise NotImplementedError("hessian_input not None not supported") + + m_dot = Enlist(m_dot) + m_dot_D, m_dot_J = self._dual_transform(m_dot) + u = Enlist(self._J.hessian(m_dot.delist(m_dot_J), evaluate_tlm=evaluate_tlm)) + + if self._alpha == 0: + v_alpha = None + else: + v_alpha = [] + for space, space_D, m_dot_D, m_dot_J in zip(self._space, self._space_D, m_dot_D, m_dot_J): + if space is space_D: + v_alpha.append(None) + else: + v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_dot_D - m_dot_J, fd.TestFunction(space_D)) * fd.dx)) + v = self._primal_transform(u, v_alpha, apply_riesz=True) + if apply_riesz: + v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) + for v_i, control in zip(v, self.controls)) + return u.delist(v) @no_annotations def tlm(self, m_dot): diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index fe6ce5711c..6b5ff9c094 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -7,6 +7,7 @@ pause_annotation) from firedrake.adjoint.transformed_functional import L2TransformedFunctional import numpy as np +from pyadjoint import MinimizationProblem, TAOSolver from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy from pyadjoint.tape import set_working_tape import pytest @@ -191,3 +192,83 @@ def error_norm(m): assert 1e-4 < cb[-1] < 5e-4 assert len(cb) < 55 # == 50 assert J_hat._test_transformed_functional__ncalls < 55 # == 51 + + +def test_transformed_functional_poisson_tao_nls(): + mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") + x, y = fd.SpatialCoordinate(mesh) + space = fd.FunctionSpace(mesh, "Lagrange", 1) + test = fd.TestFunction(space) + trial = fd.TrialFunction(space) + bc = fd.DirichletBC(space, 0, "on_boundary") + + def pre_process(m): + m_0 = fd.Function(space, name="m_0").assign(m) + bc.apply(m_0) + m_1 = fd.Function(space, name="m_1").assign(m - m_0) + return m_0, m_1 + + def forward(m): + m_0, m_1 = pre_process(m) + u = fd.Function(space, name="u") + fd.solve(fd.inner(fd.grad(trial), fd.grad(test)) * fd.dx + == fd.inner(m_0, test) * fd.dx, + u, bc) + return m_0, m_1, u + + def forward_J(m, u_ref, alpha): + _, m_1, u = forward(m) + return fd.assemble(fd.inner(u - u_ref, u - u_ref) * fd.dx + + fd.Constant(alpha ** 2) * fd.inner(m_1, m_1) * fd.ds) + + m_ref = fd.Function(space, name="m_ref").interpolate( + fd.exp(x) * fd.sin(fd.pi * x) * fd.sin(fd.pi * y)) + m_ref, _, u_ref = forward(m_ref) + forward_J = partial(forward_J, u_ref=u_ref, alpha=1) + + continue_annotation() + m_0 = fd.Function(space, name="m_0") + J = forward_J(m_0) + pause_annotation() + c = Control(m_0) + + J_hat = ReducedFunctional(J, c) + + def error_norm(m): + m, _ = pre_process(m) + return fd.norm(m - m_ref, norm_type="L2") + + problem = MinimizationProblem(J_hat) + solver = TAOSolver(problem, {"tao_type": "nls", + "tao_monitor": None, + "tao_converged_reason": None, + "tao_gatol": 1.0e-5, + "tao_grtol": 0.0, + "tao_gttol": 1.0e-6, + "tao_monitor": None}) + m_opt = solver.solve() + error_norm_opt = error_norm(m_opt) + print(f"{error_norm_opt=}") + assert 1e-2 < error_norm_opt < 5e-2 + assert J_hat._test_transformed_functional__ncalls > 22 # == 24 + + J_hat = L2TransformedFunctional(J, c, alpha=1e-5) + + def error_norm(m): + m = J_hat.map_result(m) + m, _ = pre_process(m) + return fd.norm(m - m_ref, norm_type="L2") + + problem = MinimizationProblem(J_hat) + solver = TAOSolver(problem, {"tao_type": "nls", + "tao_monitor": None, + "tao_converged_reason": None, + "tao_gatol": 1.0e-5, + "tao_grtol": 0.0, + "tao_gttol": 1.0e-6, + "tao_monitor": None}) + m_opt = solver.solve() + error_norm_opt = error_norm(m_opt) + print(f"{error_norm_opt=}") + assert 1e-3 < error_norm_opt < 1e-2 + assert J_hat._test_transformed_functional__ncalls < 18 # == 16 From 9cba940e9f55d83429ef089003864b2e8bfd2167 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 4 Sep 2025 12:33:03 +0100 Subject: [PATCH 14/33] Add L2TransformedFunctional.tlm, minor updates --- firedrake/adjoint/transformed_functional.py | 26 ++++++++++++++++--- .../adjoint/test_transformed_functional.py | 4 +-- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 1f46f2e6b6..a47aa4a81b 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -180,7 +180,10 @@ def _dual_transform(self, u): def transform(C, u, space, space_D): # Map transformed 'cofunction' to function: # M_V^{-1} P_{VW} C_W^{-*} - v = C.C_T_inv_action(u) # for complex would need to be adjoint + if fd.utils.complex_mode: + # Would need to be adjoint + raise NotImplementedError("complex not supported") + v = C.C_T_inv_action(u) if space is space_D: w = v else: @@ -214,6 +217,7 @@ def map_result(self, m): @no_annotations def __call__(self, values): + values = Enlist(values) m_D, m_J = self._m_k = self._dual_transform(values) J = self._J(m_J) if self._alpha != 0: @@ -237,6 +241,8 @@ def derivative(self, adj_input=1.0, apply_riesz=False): if space is space_D: v_alpha.append(None) else: + if fd.utils.complex_mode: + raise RuntimeError("Not complex differentiable") v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, fd.TestFunction(space_D)) * fd.dx)) v = self._primal_transform(u, v_alpha, apply_riesz=True) if apply_riesz: @@ -257,11 +263,13 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals v_alpha = None else: v_alpha = [] - for space, space_D, m_dot_D, m_dot_J in zip(self._space, self._space_D, m_dot_D, m_dot_J): + for space, space_D, m_dot_D_i, m_dot_J_i in zip(self._space, self._space_D, m_dot_D, m_dot_J): if space is space_D: v_alpha.append(None) else: - v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_dot_D - m_dot_J, fd.TestFunction(space_D)) * fd.dx)) + if fd.utils.complex_mode: + raise RuntimeError("Not complex differentiable") + v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_dot_D_i - m_dot_J_i, fd.TestFunction(space_D)) * fd.dx)) v = self._primal_transform(u, v_alpha, apply_riesz=True) if apply_riesz: v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) @@ -270,4 +278,14 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals @no_annotations def tlm(self, m_dot): - raise NotImplementedError("tlm not implemented") + m_dot = Enlist(m_dot) + m_dot_D, m_dot_J = self._dual_transform(m_dot) + tau_J = self._J.tlm(m_dot.delist(m_dot_J)) + + if self._alpha != 0: + for space, space_D, m_dot_D_i, m_D, m_J in zip(self._space, self._space_D, m_dot_D, *self._m_k): + if space is not space_D: + if fd.utils.complex_mode: + raise RuntimeError("Not complex differentiable") + tau_J += fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, m_dot_D_i) * fd.dx) + return tau_J diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index 6b5ff9c094..51e6cbca25 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -248,7 +248,7 @@ def error_norm(m): "tao_monitor": None}) m_opt = solver.solve() error_norm_opt = error_norm(m_opt) - print(f"{error_norm_opt=}") + print(f"{error_norm_opt=:.6g}") assert 1e-2 < error_norm_opt < 5e-2 assert J_hat._test_transformed_functional__ncalls > 22 # == 24 @@ -269,6 +269,6 @@ def error_norm(m): "tao_monitor": None}) m_opt = solver.solve() error_norm_opt = error_norm(m_opt) - print(f"{error_norm_opt=}") + print(f"{error_norm_opt=:.6g}") assert 1e-3 < error_norm_opt < 1e-2 assert J_hat._test_transformed_functional__ncalls < 18 # == 16 From 8f392384d22f59982a7226252fad70795bb44127 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 4 Sep 2025 13:19:43 +0100 Subject: [PATCH 15/33] Tidying --- firedrake/adjoint/transformed_functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index a47aa4a81b..d9dcbcf5ac 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -125,7 +125,7 @@ def __init__(self, functional, controls, *, tape=None, self._C = tuple(map(L2Cholesky, self._space_D)) self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2") for space_D in self._space_D) - self._controls_delist = Enlist(Enlist(controls).delist(tuple(None for _ in self._controls))).delist + self._controls = Enlist(Enlist(controls).delist(self._controls)) self._alpha = alpha self._project_solver_parameters = flatten_parameters(project_solver_parameters) self._m_k = None @@ -137,7 +137,7 @@ def __init__(self, functional, controls, *, tape=None, @property def controls(self) -> Enlist[Control]: - return Enlist(self._controls_delist(self._controls)) + return Enlist(self._controls.delist()) def _primal_transform(self, u, u_D=None, *, apply_riesz=False): u = Enlist(u) From 6c3239c34145ba8a7cc211e6454a7c13342126fa Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 18 Sep 2025 15:52:14 +0100 Subject: [PATCH 16/33] Use a RieszMap for the projection and adjoint projection, update imports --- firedrake/adjoint/__init__.py | 1 + firedrake/adjoint/transformed_functional.py | 67 ++++++++++++------- .../adjoint/test_transformed_functional.py | 22 +++--- 3 files changed, 52 insertions(+), 38 deletions(-) diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py index d3d28e6129..e2e57fc452 100644 --- a/firedrake/adjoint/__init__.py +++ b/firedrake/adjoint/__init__.py @@ -38,6 +38,7 @@ from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \ UFLEqualityConstraint # noqa F401 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional # noqa F401 +from firedrake.adjoint.transformed_functional import L2RieszMap, L2TransformedFunctional # noqa: F401 import numpy_adjoint # noqa F401 import firedrake.ufl_expr import types diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index d9dcbcf5ac..9aab49cd07 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -2,15 +2,16 @@ from operator import itemgetter import firedrake as fd +from firedrake.adjoint import Control, ReducedFunctional import finat -from petsctools import flatten_parameters -from pyadjoint.control import Control +from pyadjoint import no_annotations from pyadjoint.enlisting import Enlist -from pyadjoint.reduced_functional import AbstractReducedFunctional, ReducedFunctional -from pyadjoint.tape import no_annotations +from pyadjoint.reduced_functional import AbstractReducedFunctional +import ufl __all__ = \ [ + "L2RieszMap", "L2TransformedFunctional" ] @@ -50,6 +51,25 @@ def C_T_inv_action(self, u): return v +class L2RieszMap(fd.RieszMap): + """An :math:`L^2` Riesz map. + + Parameters + ---------- + + target : WithGeometry + Target space. + + Keyword arguments are passed to the :class:`firedrake.RieszMap` + constructor. + """ + + def __init__(self, target, **kwargs): + if not isinstance(target, fd.functionspaceimpl.WithGeometry): + raise TypeError("Target must be a WithGeometry") + super().__init__(target, ufl.L2, **kwargs) + + def is_dg_space(space): e, _ = finat.element_factory.convert(space.ufl_element()) return e.is_dg() @@ -90,8 +110,9 @@ class L2TransformedFunctional(AbstractReducedFunctional): Functional defining the optimization problem, :math:`J`. controls : Control or Sequence[Control] Controls. Must be :class:`firedrake.Function` objects. - tape : Tape - Tape used in evaluations involving :math:`J`. + riesz_map : L2RieszMap or Sequence[L2RieszMap] + Used for projecting from the discontinuous space onto the control + space. Ignored for DG controls. alpha : Real Modifies the functional, equivalent to adding an extra term to :math:`J \circ \Pi` @@ -103,19 +124,14 @@ class L2TransformedFunctional(AbstractReducedFunctional): e.g. in a minimization problem this adds a penalty term which can be used to avoid ill-posedness due to the use of a larger discontinuous space. - project_solver_parameters : Mapping - Solver parameters for an :math:`L^2` projection from the discontinuous - space onto the control space. Ignored for controls in DG spaces, - where the projection is an identity. + tape : Tape + Tape used in evaluations involving :math:`J`. """ @no_annotations - def __init__(self, functional, controls, *, tape=None, - alpha=0, project_solver_parameters=None): + def __init__(self, functional, controls, *, riesz_map=None, alpha=0, tape=None): if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)): raise TypeError("controls must be Function objects") - if project_solver_parameters is None: - project_solver_parameters = {} super().__init__() self._J = ReducedFunctional(functional, controls, tape=tape) @@ -127,9 +143,14 @@ def __init__(self, functional, controls, *, tape=None, for space_D in self._space_D) self._controls = Enlist(Enlist(controls).delist(self._controls)) self._alpha = alpha - self._project_solver_parameters = flatten_parameters(project_solver_parameters) self._m_k = None + if riesz_map is None: + riesz_map = tuple(map(L2RieszMap, self._space)) + self._riesz_map = Enlist(riesz_map) + if len(self._riesz_map) != len(self._controls): + raise ValueError("Invalid length") + # Map the initial guess controls_t = self._primal_transform(tuple(control.control for control in self._J.controls), apply_riesz=False) for control, control_t in zip(self._controls, controls_t): @@ -150,17 +171,14 @@ def _primal_transform(self, u, u_D=None, *, apply_riesz=False): if len(u_D) != len(self.controls): raise ValueError("Invalid length") - def transform(C, u, u_D, space, space_D): + def transform(C, u, u_D, space, space_D, riesz_map): # Map function to transformed 'cofunction': if apply_riesz: # C_W^{-1} P_{VW}^* M_V^{-1} if space is space_D: v = u else: - # Might be replaced with interpolation (does this work for - # all spaces?) - v = u.riesz_representation(solver_options=self._project_solver_parameters) - v = fd.assemble(fd.inner(v, fd.TestFunction(space_D)) * fd.dx) + v = fd.assemble(fd.inner(riesz_map(u), fd.TestFunction(space_D)) * fd.dx) else: # C_W^{-1} P_{VW}^* v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx) @@ -169,7 +187,7 @@ def transform(C, u, u_D, space, space_D): v = C.C_inv_action(v) return v.riesz_representation("l2") - v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D)) + v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D, self._riesz_map)) return u.delist(v) def _dual_transform(self, u): @@ -177,7 +195,7 @@ def _dual_transform(self, u): if len(u) != len(self.controls): raise ValueError("Invalid length") - def transform(C, u, space, space_D): + def transform(C, u, space, space_D, riesz_map): # Map transformed 'cofunction' to function: # M_V^{-1} P_{VW} C_W^{-*} if fd.utils.complex_mode: @@ -187,11 +205,10 @@ def transform(C, u, space, space_D): if space is space_D: w = v else: - w = fd.Function(space).project( - v, solver_parameters=self._project_solver_parameters) + w = riesz_map(fd.assemble(fd.inner(v, fd.TestFunction(space)) * fd.dx)) return v, w - vw = tuple(map(transform, self._C, u, self._space, self._space_D)) + vw = tuple(map(transform, self._C, u, self._space, self._space_D, self._riesz_map)) return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw))) @no_annotations diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index 51e6cbca25..f42a9c3ba5 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -3,13 +3,11 @@ import firedrake as fd from firedrake.adjoint import ( - Control, ReducedFunctional, continue_annotation, minimize, - pause_annotation) -from firedrake.adjoint.transformed_functional import L2TransformedFunctional + Control, L2TransformedFunctional, MinimizationProblem, ReducedFunctional, + continue_annotation, minimize, pause_annotation, set_working_tape) import numpy as np -from pyadjoint import MinimizationProblem, TAOSolver +from pyadjoint import TAOSolver from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy -from pyadjoint.tape import set_working_tape import pytest import ufl @@ -119,7 +117,7 @@ def error_norm(m): callback=cb, options={"ftol": 0, "gtol": 1e-6}) - assert cb[-1] < 1e-8 + assert cb[-1] < 1e-10 assert len(cb) == 3 assert J_hat._test_transformed_functional__ncalls == 3 @@ -190,8 +188,8 @@ def error_norm(m): options={"ftol": 0, "gtol": 1e-10}) assert 1e-4 < cb[-1] < 5e-4 - assert len(cb) < 55 # == 50 - assert J_hat._test_transformed_functional__ncalls < 55 # == 51 + assert len(cb) < 55 # == 51 + assert J_hat._test_transformed_functional__ncalls < 60 # == 55 def test_transformed_functional_poisson_tao_nls(): @@ -244,13 +242,12 @@ def error_norm(m): "tao_converged_reason": None, "tao_gatol": 1.0e-5, "tao_grtol": 0.0, - "tao_gttol": 1.0e-6, - "tao_monitor": None}) + "tao_gttol": 1.0e-6}) m_opt = solver.solve() error_norm_opt = error_norm(m_opt) print(f"{error_norm_opt=:.6g}") assert 1e-2 < error_norm_opt < 5e-2 - assert J_hat._test_transformed_functional__ncalls > 22 # == 24 + assert J_hat._test_transformed_functional__ncalls > 22 # == 25 J_hat = L2TransformedFunctional(J, c, alpha=1e-5) @@ -265,8 +262,7 @@ def error_norm(m): "tao_converged_reason": None, "tao_gatol": 1.0e-5, "tao_grtol": 0.0, - "tao_gttol": 1.0e-6, - "tao_monitor": None}) + "tao_gttol": 1.0e-6}) m_opt = solver.solve() error_norm_opt = error_norm(m_opt) print(f"{error_norm_opt=:.6g}") From ea3fd8a9560c020932224679f2f3dec087f2eca4 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 19 Sep 2025 09:55:55 +0100 Subject: [PATCH 17/33] Add reference --- firedrake/adjoint/transformed_functional.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 9aab49cd07..bf8a76550c 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -103,6 +103,10 @@ class L2TransformedFunctional(AbstractReducedFunctional): space. This can be used for mesh-independent optimization for libraries which support only an :math:`l_2` inner product. + The transformation is related to the factorization in section 4.1 of + https://doi.org/10.1137/18M1175239 -- specifically the factorization + in their equation (4.2) can be related to :math:`\Pi \circ \Xi`. + Parameters ---------- From c530ba475ebbbb734449b18cc544e00b27ae41f0 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 19 Sep 2025 12:20:43 +0100 Subject: [PATCH 18/33] Docstring updates --- firedrake/adjoint/transformed_functional.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index bf8a76550c..d1c1587380 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -100,8 +100,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): The optimization is therefore transformed into an optimization problem using an :math:`L^2` orthonormal basis for a discontinuous finite element - space. This can be used for mesh-independent optimization for libraries - which support only an :math:`l_2` inner product. + space. The transformation is related to the factorization in section 4.1 of https://doi.org/10.1137/18M1175239 -- specifically the factorization @@ -222,7 +221,7 @@ def map_result(self, m): Parameters ---------- - m : firedrake.Cofunction or Sequence[firedrake.Cofunction] + m : firedrake.Function or Sequence[firedrake.Function] The result of the optimization. Represents an expansion in an :math:`L^2` orthonormal basis for the discontinuous space. From 2214c6ef34f222df4a27597343f6b5181b55676e Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Fri, 19 Sep 2025 14:22:00 +0100 Subject: [PATCH 19/33] Remove some comments (notes to myself) --- firedrake/adjoint/transformed_functional.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index d1c1587380..2d45685f17 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -175,15 +175,12 @@ def _primal_transform(self, u, u_D=None, *, apply_riesz=False): raise ValueError("Invalid length") def transform(C, u, u_D, space, space_D, riesz_map): - # Map function to transformed 'cofunction': if apply_riesz: - # C_W^{-1} P_{VW}^* M_V^{-1} if space is space_D: v = u else: v = fd.assemble(fd.inner(riesz_map(u), fd.TestFunction(space_D)) * fd.dx) else: - # C_W^{-1} P_{VW}^* v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx) if u_D is not None: v.dat.axpy(1, u_D.dat) @@ -199,8 +196,6 @@ def _dual_transform(self, u): raise ValueError("Invalid length") def transform(C, u, space, space_D, riesz_map): - # Map transformed 'cofunction' to function: - # M_V^{-1} P_{VW} C_W^{-*} if fd.utils.complex_mode: # Would need to be adjoint raise NotImplementedError("complex not supported") From af86f5b57069d800e23fe23fe14c42b99349e6d9 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 22 Sep 2025 09:17:14 +0100 Subject: [PATCH 20/33] Minor tidying --- firedrake/adjoint/transformed_functional.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 2d45685f17..b5dc899360 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -233,12 +233,13 @@ def map_result(self, m): @no_annotations def __call__(self, values): values = Enlist(values) - m_D, m_J = self._m_k = self._dual_transform(values) + m_D, m_J = self._dual_transform(values) J = self._J(m_J) if self._alpha != 0: - for space, space_D, m_D, m_J in zip(self._space, self._space_D, *self._m_k): + for space, space_D, m_D_i, m_J_i in zip(self._space, self._space_D, m_D, m_J): if space is not space_D: - J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D - m_J, m_D - m_J) * fd.dx) + J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, m_D_i - m_J_i) * fd.dx) + self._m_k = m_D, m_J return J @no_annotations From 060be0ef9b381bad2eee9d84ea385deddf4d54b7 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 22 Sep 2025 09:27:42 +0100 Subject: [PATCH 21/33] Parallelize --- firedrake/adjoint/transformed_functional.py | 32 +++++++++++++++------ 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index b5dc899360..86d1bc7b45 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from functools import cached_property from operator import itemgetter @@ -16,6 +17,14 @@ ] +@contextmanager +def local_vector(u, *, readonly=False): + u_local = u.createLocalVector() + u.getLocalVector(u_local, readonly=readonly) + yield u_local + u.restoreLocalVector(u_local, readonly=readonly) + + class L2Cholesky: def __init__(self, space): self._space = space @@ -26,28 +35,35 @@ def space(self): @cached_property def M(self): - return fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx) + return fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx, + mat_type="aij") @cached_property - def solver(self): - return fd.LinearSolver(self.M, solver_parameters={"ksp_type": "preonly", - "pc_type": "cholesky", - "pc_factor_mat_ordering_type": "nd"}) + def M_local(self): + return self.M.petscmat.getDiagonalBlock() @cached_property def pc(self): - return self.solver.ksp.getPC() + import petsc4py.PETSc as PETSc + pc = PETSc.PC().create(self.M_local.comm) + pc.setType(PETSc.PC.Type.CHOLESKY) + pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC) + pc.setOperators(self.M_local) + pc.setUp() + return pc def C_inv_action(self, u): v = fd.Cofunction(self.space.dual()) with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v: - self.pc.applySymmetricLeft(u_v, v_v) + with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s: + self.pc.applySymmetricLeft(u_v_s, v_v_s) return v def C_T_inv_action(self, u): v = fd.Function(self.space) with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v: - self.pc.applySymmetricRight(u_v, v_v) + with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s: + self.pc.applySymmetricRight(u_v_s, v_v_s) return v From 06e69e673562c7dd82ddac18fb01d27d76b53d96 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 22 Sep 2025 09:40:35 +0100 Subject: [PATCH 22/33] Control L2Cholesky caching --- firedrake/adjoint/transformed_functional.py | 39 +++++++++++---------- firedrake/cofunction.py | 9 +++++ 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 86d1bc7b45..93a462ca43 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -1,5 +1,4 @@ from contextlib import contextmanager -from functools import cached_property from operator import itemgetter import firedrake as fd @@ -26,30 +25,33 @@ def local_vector(u, *, readonly=False): class L2Cholesky: - def __init__(self, space): + def __init__(self, space, *, constant_jacobian=True): self._space = space + self._constant_jacobian = constant_jacobian + self._pc = None @property def space(self): return self._space - @cached_property - def M(self): - return fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx, - mat_type="aij") - - @cached_property - def M_local(self): - return self.M.petscmat.getDiagonalBlock() - - @cached_property + @property def pc(self): import petsc4py.PETSc as PETSc - pc = PETSc.PC().create(self.M_local.comm) - pc.setType(PETSc.PC.Type.CHOLESKY) - pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC) - pc.setOperators(self.M_local) - pc.setUp() + + pc = self._pc + if self._pc is None: + M = fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx, + mat_type="aij") + M_local = M.petscmat.getDiagonalBlock() + + pc = PETSc.PC().create(M_local.comm) + pc.setType(PETSc.PC.Type.CHOLESKY) + pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC) + pc.setOperators(M_local) + pc.setUp() + if self._constant_jacobian: + self._pc = pc + return pc def C_inv_action(self, u): @@ -157,7 +159,6 @@ def __init__(self, functional, controls, *, riesz_map=None, alpha=0, tape=None): self._space = tuple(control.control.function_space() for control in self._J.controls) self._space_D = tuple(map(dg_space, self._space)) - self._C = tuple(map(L2Cholesky, self._space_D)) self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2") for space_D in self._space_D) self._controls = Enlist(Enlist(controls).delist(self._controls)) @@ -169,6 +170,8 @@ def __init__(self, functional, controls, *, riesz_map=None, alpha=0, tape=None): self._riesz_map = Enlist(riesz_map) if len(self._riesz_map) != len(self._controls): raise ValueError("Invalid length") + self._C = tuple(L2Cholesky(space_D, constant_jacobian=riesz_map.constant_jacobian) + for space_D, riesz_map in zip(self._space_D, self._riesz_map)) # Map the initial guess controls_t = self._primal_transform(tuple(control.control for control in self._J.controls), apply_riesz=False) diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index 9577a909ba..b66bbb7939 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -435,6 +435,8 @@ class RieszMap: variational problem that solves for the Riesz map. restrict: bool If `True`, use restricted function spaces in the Riesz map solver. + constant_jacobian : bool + Whether the matrix associated with the map is constant. """ def __init__(self, function_space_or_inner_product=None, @@ -539,3 +541,10 @@ def __call__(self, value): f"Unable to ascertain if {value} is primal or dual." ) return output + + @property + def constant_jacobian(self) -> bool: + """Whether the matrix associated with the map is constant. + """ + + return self._constant_jacobian From 83d583f8eed98ba62dac4b911a2010f3b42c3e11 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 22 Sep 2025 09:55:22 +0100 Subject: [PATCH 23/33] Defensively hold references to matrices --- firedrake/adjoint/transformed_functional.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 93a462ca43..03d6662797 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -38,7 +38,6 @@ def space(self): def pc(self): import petsc4py.PETSc as PETSc - pc = self._pc if self._pc is None: M = fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx, mat_type="aij") @@ -49,8 +48,11 @@ def pc(self): pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC) pc.setOperators(M_local) pc.setUp() - if self._constant_jacobian: - self._pc = pc + + if self._constant_jacobian: + self._pc = M, M_local, pc + else: + _, _, pc = self._pc return pc From 4fb84a697795b249df0db2d734d69db38c88acdc Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 23 Sep 2025 18:33:27 +0100 Subject: [PATCH 24/33] Docstring update, minor edits --- firedrake/adjoint/transformed_functional.py | 131 ++++++++++++++++---- 1 file changed, 110 insertions(+), 21 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 03d6662797..9dcf997dcc 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -25,20 +25,36 @@ def local_vector(u, *, readonly=False): class L2Cholesky: + """Mass matrix Cholesky factorization for a (real) DG space. + + Parameters + ---------- + + space : WithGeometry + DG space. + constant_jacobian : bool + Whether the mass matrix is constant. + """ + def __init__(self, space, *, constant_jacobian=True): + if fd.utils.complex_mode: + raise NotImplementedError("complex not supported") + self._space = space self._constant_jacobian = constant_jacobian - self._pc = None + self._cached_pc = None @property - def space(self): + def space(self) -> fd.functionspaceimpl.WithGeometry: + """Function space. + """ + return self._space - @property - def pc(self): + def _pc(self): import petsc4py.PETSc as PETSc - if self._pc is None: + if self._cached_pc is None: M = fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx, mat_type="aij") M_local = M.petscmat.getDiagonalBlock() @@ -50,24 +66,70 @@ def pc(self): pc.setUp() if self._constant_jacobian: - self._pc = M, M_local, pc + self._cached_pc = M, M_local, pc else: - _, _, pc = self._pc + _, _, pc = self._cached_pc return pc def C_inv_action(self, u): + """For the Cholesky factorization + + ... math : + + M = C C^T, + + compute the action of :math:`C^{-1}`. + + Parameters + ---------- + + u : Function or Cofunction + Compute :math:`C^{-1} \tilde{u}` where :math:`\tilde{u}` is the + vector of degrees of freedom for :math:`u`. + + Returns + ------- + + v : Cofunction + Has vector of degrees of freedom :math:`C^{-1} \tilde{u}`. + """ + + pc = self._pc() v = fd.Cofunction(self.space.dual()) with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v: with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s: - self.pc.applySymmetricLeft(u_v_s, v_v_s) + pc.applySymmetricLeft(u_v_s, v_v_s) return v def C_T_inv_action(self, u): + """For the Cholesky factorization + + ... math : + + M = C C^T, + + compute the action of :math:`C^{-T}`. + + Parameters + ---------- + + u : Function or Cofunction + Compute :math:`C^{-T} \tilde{u}` where :math:`\tilde{u}` is the + vector of degrees of freedom for :math:`u`. + + Returns + ------- + + v : Function + Has vector of degrees of freedom :math:`C^{-T} \tilde{u}`. + """ + + pc = self._pc() v = fd.Function(self.space) with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v: with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s: - self.pc.applySymmetricRight(u_v_s, v_v_s) + pc.applySymmetricRight(u_v_s, v_v_s) return v @@ -78,7 +140,7 @@ class L2RieszMap(fd.RieszMap): ---------- target : WithGeometry - Target space. + Function space. Keyword arguments are passed to the :class:`firedrake.RieszMap` constructor. @@ -91,11 +153,41 @@ def __init__(self, target, **kwargs): def is_dg_space(space): + """Return whether a function space is DG. + + Parameters + ---------- + + space : WithGeometry + The function space. + + Returns + ------- + + bool + Whether the function space is DG. + """ + e, _ = finat.element_factory.convert(space.ufl_element()) return e.is_dg() def dg_space(space): + """Construct a DG space containing a given function space as a subspace. + + Parameters + ---------- + + space : WithGeometry + A function space. + + Returns + ------- + + WithGeometry + A DG space containing `space` as a subspace. May be `space`. + """ + if is_dg_space(space): return space else: @@ -112,15 +204,13 @@ class L2TransformedFunctional(AbstractReducedFunctional): where - :math:`J` is the functional definining an optimization problem. - - :math:`\Pi` is the :math:`L^2` projection from a discontinuous - superspace of the control space. + - :math:`\Pi` is the :math:`L^2` projection from a DG space containing + the control space as a subspace. - :math:`\Xi` represents a change of basis from an :math:`L^2` - orthonormal basis to the finite element basis for the discontinuous - superspace. + orthonormal basis to the finite element basis for the DG space. The optimization is therefore transformed into an optimization problem - using an :math:`L^2` orthonormal basis for a discontinuous finite element - space. + using an :math:`L^2` orthonormal basis for a DG finite element space. The transformation is related to the factorization in section 4.1 of https://doi.org/10.1137/18M1175239 -- specifically the factorization @@ -134,8 +224,8 @@ class L2TransformedFunctional(AbstractReducedFunctional): controls : Control or Sequence[Control] Controls. Must be :class:`firedrake.Function` objects. riesz_map : L2RieszMap or Sequence[L2RieszMap] - Used for projecting from the discontinuous space onto the control - space. Ignored for DG controls. + Used for projecting from the DG space onto the control space. Ignored + for DG controls. alpha : Real Modifies the functional, equivalent to adding an extra term to :math:`J \circ \Pi` @@ -145,8 +235,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): \frac{1}{2} \alpha \left\| m_D - \Pi ( m_D ) \right\|_{L^2}^2. e.g. in a minimization problem this adds a penalty term which can - be used to avoid ill-posedness due to the use of a larger discontinuous - space. + be used to avoid ill-posedness due to the use of a larger DG space. tape : Tape Tape used in evaluations involving :math:`J`. """ @@ -239,7 +328,7 @@ def map_result(self, m): m : firedrake.Function or Sequence[firedrake.Function] The result of the optimization. Represents an expansion in an - :math:`L^2` orthonormal basis for the discontinuous space. + :math:`L^2` orthonormal basis for the DG space. Returns ------- From 11f4e6ac6cf1c781addf5e43d62e2807f9615248 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 23 Sep 2025 18:34:28 +0100 Subject: [PATCH 25/33] Add optional space_D argument --- firedrake/adjoint/transformed_functional.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 9dcf997dcc..005bbac7db 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -223,6 +223,8 @@ class L2TransformedFunctional(AbstractReducedFunctional): Functional defining the optimization problem, :math:`J`. controls : Control or Sequence[Control] Controls. Must be :class:`firedrake.Function` objects. + space_D : None, WithGeometry, or Sequence[None or WithGeometry] + DG space containing the control space. riesz_map : L2RieszMap or Sequence[L2RieszMap] Used for projecting from the DG space onto the control space. Ignored for DG controls. @@ -241,18 +243,27 @@ class L2TransformedFunctional(AbstractReducedFunctional): """ @no_annotations - def __init__(self, functional, controls, *, riesz_map=None, alpha=0, tape=None): + def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=0, tape=None): if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)): raise TypeError("controls must be Function objects") super().__init__() self._J = ReducedFunctional(functional, controls, tape=tape) + self._space = tuple(control.control.function_space() for control in self._J.controls) - self._space_D = tuple(map(dg_space, self._space)) + if space_D is None: + space_D = tuple(None for _ in self._space) + self._space_D = Enlist(space_D) + if len(self._space_D) != len(self._space): + raise ValueError("Invalid length") + self._space_D = tuple(dg_space(space) if space_D is None else space_D + for space, space_D in zip(self._space, self._space_D)) + self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2") for space_D in self._space_D) self._controls = Enlist(Enlist(controls).delist(self._controls)) + self._alpha = alpha self._m_k = None From b8287d9521b7470e6455ffd54074d060ce052f8b Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 24 Sep 2025 09:37:29 +0100 Subject: [PATCH 26/33] Tidying --- firedrake/adjoint/transformed_functional.py | 26 +++++++++---------- .../adjoint/test_transformed_functional.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 005bbac7db..cbe40667cc 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -264,9 +264,6 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha= for space_D in self._space_D) self._controls = Enlist(Enlist(controls).delist(self._controls)) - self._alpha = alpha - self._m_k = None - if riesz_map is None: riesz_map = tuple(map(L2RieszMap, self._space)) self._riesz_map = Enlist(riesz_map) @@ -275,8 +272,11 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha= self._C = tuple(L2Cholesky(space_D, constant_jacobian=riesz_map.constant_jacobian) for space_D, riesz_map in zip(self._space_D, self._riesz_map)) + self._alpha = alpha + self._m_k = None + # Map the initial guess - controls_t = self._primal_transform(tuple(control.control for control in self._J.controls), apply_riesz=False) + controls_t = self._dual_transform(tuple(control.control for control in self._J.controls), apply_riesz=False) for control, control_t in zip(self._controls, controls_t): control.control.assign(control_t) @@ -284,7 +284,7 @@ def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha= def controls(self) -> Enlist[Control]: return Enlist(self._controls.delist()) - def _primal_transform(self, u, u_D=None, *, apply_riesz=False): + def _dual_transform(self, u, u_D=None, *, apply_riesz=False): u = Enlist(u) if len(u) != len(self.controls): raise ValueError("Invalid length") @@ -311,7 +311,7 @@ def transform(C, u, u_D, space, space_D, riesz_map): v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D, self._riesz_map)) return u.delist(v) - def _dual_transform(self, u): + def _primal_transform(self, u): u = Enlist(u) if len(u) != len(self.controls): raise ValueError("Invalid length") @@ -344,17 +344,17 @@ def map_result(self, m): Returns ------- - firedrake.Function or list[firedrake.Function] + firedrake.Function or Sequence[firedrake.Function] The mapped result in the original control space. """ - _, m_J = self._dual_transform(m) + _, m_J = self._primal_transform(m) return m_J @no_annotations def __call__(self, values): values = Enlist(values) - m_D, m_J = self._dual_transform(values) + m_D, m_J = self._primal_transform(values) J = self._J(m_J) if self._alpha != 0: for space, space_D, m_D_i, m_J_i in zip(self._space, self._space_D, m_D, m_J): @@ -381,7 +381,7 @@ def derivative(self, adj_input=1.0, apply_riesz=False): if fd.utils.complex_mode: raise RuntimeError("Not complex differentiable") v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, fd.TestFunction(space_D)) * fd.dx)) - v = self._primal_transform(u, v_alpha, apply_riesz=True) + v = self._dual_transform(u, v_alpha, apply_riesz=True) if apply_riesz: v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) for v_i, control in zip(v, self.controls)) @@ -393,7 +393,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals raise NotImplementedError("hessian_input not None not supported") m_dot = Enlist(m_dot) - m_dot_D, m_dot_J = self._dual_transform(m_dot) + m_dot_D, m_dot_J = self._primal_transform(m_dot) u = Enlist(self._J.hessian(m_dot.delist(m_dot_J), evaluate_tlm=evaluate_tlm)) if self._alpha == 0: @@ -407,7 +407,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals if fd.utils.complex_mode: raise RuntimeError("Not complex differentiable") v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_dot_D_i - m_dot_J_i, fd.TestFunction(space_D)) * fd.dx)) - v = self._primal_transform(u, v_alpha, apply_riesz=True) + v = self._dual_transform(u, v_alpha, apply_riesz=True) if apply_riesz: v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map) for v_i, control in zip(v, self.controls)) @@ -416,7 +416,7 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals @no_annotations def tlm(self, m_dot): m_dot = Enlist(m_dot) - m_dot_D, m_dot_J = self._dual_transform(m_dot) + m_dot_D, m_dot_J = self._primal_transform(m_dot) tau_J = self._J.tlm(m_dot.delist(m_dot_J)) if self._alpha != 0: diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index f42a9c3ba5..ada587b5e4 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -247,7 +247,7 @@ def error_norm(m): error_norm_opt = error_norm(m_opt) print(f"{error_norm_opt=:.6g}") assert 1e-2 < error_norm_opt < 5e-2 - assert J_hat._test_transformed_functional__ncalls > 22 # == 25 + assert J_hat._test_transformed_functional__ncalls > 22 # == 24 J_hat = L2TransformedFunctional(J, c, alpha=1e-5) From ba01fd8db5da3933c95f2f49ab1ca21ca2c5680c Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 11 Nov 2025 16:48:37 +0000 Subject: [PATCH 27/33] skipcomplex --- tests/firedrake/adjoint/test_transformed_functional.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index ada587b5e4..2369925144 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -69,6 +69,7 @@ def __call__(self, xk): @pytest.mark.parametrize("family", ("Lagrange", "Discontinuous Lagrange")) +@pytest.mark.skipcomplex def test_transformed_functional_mass_inverse(family): mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") x, y = fd.SpatialCoordinate(mesh) @@ -122,6 +123,7 @@ def error_norm(m): assert J_hat._test_transformed_functional__ncalls == 3 +@pytest.mark.skipcomplex def test_transformed_functional_poisson(): mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") x, y = fd.SpatialCoordinate(mesh) @@ -192,6 +194,7 @@ def error_norm(m): assert J_hat._test_transformed_functional__ncalls < 60 # == 55 +@pytest.mark.skipcomplex def test_transformed_functional_poisson_tao_nls(): mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed") x, y = fd.SpatialCoordinate(mesh) From 821e525d715be99cd37e1904896ccb946e075f56 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 12 Nov 2025 09:52:51 +0000 Subject: [PATCH 28/33] Update test --- tests/firedrake/adjoint/test_transformed_functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index 2369925144..a048cdd610 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -250,7 +250,7 @@ def error_norm(m): error_norm_opt = error_norm(m_opt) print(f"{error_norm_opt=:.6g}") assert 1e-2 < error_norm_opt < 5e-2 - assert J_hat._test_transformed_functional__ncalls > 22 # == 24 + assert J_hat._test_transformed_functional__ncalls > 8 # == 10 J_hat = L2TransformedFunctional(J, c, alpha=1e-5) @@ -270,4 +270,4 @@ def error_norm(m): error_norm_opt = error_norm(m_opt) print(f"{error_norm_opt=:.6g}") assert 1e-3 < error_norm_opt < 1e-2 - assert J_hat._test_transformed_functional__ncalls < 18 # == 16 + assert J_hat._test_transformed_functional__ncalls < 10 # == 8 From 78f000bcdc35831efbeddf2200671dad4be5a149 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 12 Nov 2025 10:06:56 +0000 Subject: [PATCH 29/33] Rerun forward, set the Riesz map --- .../adjoint/test_transformed_functional.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py index a048cdd610..a8f4d5a5bc 100644 --- a/tests/firedrake/adjoint/test_transformed_functional.py +++ b/tests/firedrake/adjoint/test_transformed_functional.py @@ -107,6 +107,12 @@ def error_norm(m): else: raise ValueError(f"Invalid element family: '{family}'") + continue_annotation() + m_0 = fd.Function(space, name="m_0") + J = forward(m_0) + pause_annotation() + c = Control(m_0, riesz_map="l2") + J_hat = L2TransformedFunctional(J, c, alpha=1) def error_norm(m): @@ -177,6 +183,12 @@ def error_norm(m): assert len(cb) > 80 # == 85 assert J_hat._test_transformed_functional__ncalls > 90 # == 95 + continue_annotation() + m_0 = fd.Function(space, name="m_0") + J = forward_J(m_0) + pause_annotation() + c = Control(m_0, riesz_map="l2") + J_hat = L2TransformedFunctional(J, c, alpha=1e-5) def error_norm(m): @@ -231,7 +243,7 @@ def forward_J(m, u_ref, alpha): m_0 = fd.Function(space, name="m_0") J = forward_J(m_0) pause_annotation() - c = Control(m_0) + c = Control(m_0, riesz_map="l2") J_hat = ReducedFunctional(J, c) @@ -250,7 +262,13 @@ def error_norm(m): error_norm_opt = error_norm(m_opt) print(f"{error_norm_opt=:.6g}") assert 1e-2 < error_norm_opt < 5e-2 - assert J_hat._test_transformed_functional__ncalls > 8 # == 10 + assert J_hat._test_transformed_functional__ncalls < 10 + + continue_annotation() + m_0 = fd.Function(space, name="m_0") + J = forward_J(m_0) + pause_annotation() + c = Control(m_0, riesz_map="l2") J_hat = L2TransformedFunctional(J, c, alpha=1e-5) From dbf6555414806d18832bb81f7673a580ea5dabfa Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 12 Nov 2025 12:27:55 +0000 Subject: [PATCH 30/33] Documentation fixes --- firedrake/adjoint/transformed_functional.py | 66 ++++++++++++--------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index cbe40667cc..f53796f75b 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -1,10 +1,14 @@ +from collections.abc import Sequence from contextlib import contextmanager +from numbers import Real from operator import itemgetter +from typing import Optional, Union import firedrake as fd -from firedrake.adjoint import Control, ReducedFunctional +from firedrake.adjoint import Control, ReducedFunctional, Tape +from firedrake.functionspaceimpl import WithGeometry import finat -from pyadjoint import no_annotations +from pyadjoint import OverloadedType, no_annotations from pyadjoint.enlisting import Enlist from pyadjoint.reduced_functional import AbstractReducedFunctional import ufl @@ -30,13 +34,13 @@ class L2Cholesky: Parameters ---------- - space : WithGeometry + space DG space. - constant_jacobian : bool + constant_jacobian Whether the mass matrix is constant. """ - def __init__(self, space, *, constant_jacobian=True): + def __init__(self, space: WithGeometry, *, constant_jacobian: Optional[bool] = True): if fd.utils.complex_mode: raise NotImplementedError("complex not supported") @@ -72,8 +76,8 @@ def _pc(self): return pc - def C_inv_action(self, u): - """For the Cholesky factorization + def C_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Cofunction: + r"""For the Cholesky factorization ... math : @@ -84,14 +88,14 @@ def C_inv_action(self, u): Parameters ---------- - u : Function or Cofunction + u Compute :math:`C^{-1} \tilde{u}` where :math:`\tilde{u}` is the vector of degrees of freedom for :math:`u`. Returns ------- - v : Cofunction + firedrake.Cofunction Has vector of degrees of freedom :math:`C^{-1} \tilde{u}`. """ @@ -102,8 +106,8 @@ def C_inv_action(self, u): pc.applySymmetricLeft(u_v_s, v_v_s) return v - def C_T_inv_action(self, u): - """For the Cholesky factorization + def C_T_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Function: + r"""For the Cholesky factorization ... math : @@ -114,14 +118,14 @@ def C_T_inv_action(self, u): Parameters ---------- - u : Function or Cofunction + u Compute :math:`C^{-T} \tilde{u}` where :math:`\tilde{u}` is the vector of degrees of freedom for :math:`u`. Returns ------- - v : Function + firedrake.Function Has vector of degrees of freedom :math:`C^{-T} \tilde{u}`. """ @@ -139,26 +143,26 @@ class L2RieszMap(fd.RieszMap): Parameters ---------- - target : WithGeometry + target Function space. Keyword arguments are passed to the :class:`firedrake.RieszMap` constructor. """ - def __init__(self, target, **kwargs): + def __init__(self, target: WithGeometry, **kwargs): if not isinstance(target, fd.functionspaceimpl.WithGeometry): raise TypeError("Target must be a WithGeometry") super().__init__(target, ufl.L2, **kwargs) -def is_dg_space(space): +def is_dg_space(space: WithGeometry) -> bool: """Return whether a function space is DG. Parameters ---------- - space : WithGeometry + space The function space. Returns @@ -172,19 +176,19 @@ def is_dg_space(space): return e.is_dg() -def dg_space(space): +def dg_space(space: WithGeometry) -> WithGeometry: """Construct a DG space containing a given function space as a subspace. Parameters ---------- - space : WithGeometry + space A function space. Returns ------- - WithGeometry + firedrake.functionspaceimpl.WithGeometry A DG space containing `space` as a subspace. May be `space`. """ @@ -219,16 +223,16 @@ class L2TransformedFunctional(AbstractReducedFunctional): Parameters ---------- - functional : OverloadedType + functional Functional defining the optimization problem, :math:`J`. - controls : Control or Sequence[Control] + controls Controls. Must be :class:`firedrake.Function` objects. - space_D : None, WithGeometry, or Sequence[None or WithGeometry] + space_D DG space containing the control space. - riesz_map : L2RieszMap or Sequence[L2RieszMap] + riesz_map Used for projecting from the DG space onto the control space. Ignored for DG controls. - alpha : Real + alpha Modifies the functional, equivalent to adding an extra term to :math:`J \circ \Pi` @@ -238,12 +242,16 @@ class L2TransformedFunctional(AbstractReducedFunctional): e.g. in a minimization problem this adds a penalty term which can be used to avoid ill-posedness due to the use of a larger DG space. - tape : Tape + tape Tape used in evaluations involving :math:`J`. """ @no_annotations - def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=0, tape=None): + def __init__(self, functional: OverloadedType, controls: Union[Control, Sequence[Control]], *, + space_D: Optional[Union[None, WithGeometry, Sequence[Union[None, WithGeometry]]]] = None, + riesz_map: Optional[Union[L2RieszMap, Sequence[L2RieszMap]]] = None, + alpha: Optional[Real] = 0, + tape: Optional[Tape] = None): if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)): raise TypeError("controls must be Function objects") @@ -331,13 +339,13 @@ def transform(C, u, space, space_D, riesz_map): return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw))) @no_annotations - def map_result(self, m): + def map_result(self, m: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.Function, Sequence[fd.Function]]: """Map the result of an optimization. Parameters ---------- - m : firedrake.Function or Sequence[firedrake.Function] + m The result of the optimization. Represents an expansion in an :math:`L^2` orthonormal basis for the DG space. From 03c22da99a25440c2cbc10e80d6d8e6f6fd16805 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Wed, 12 Nov 2025 13:36:14 +0000 Subject: [PATCH 31/33] Documentation fixes --- firedrake/adjoint/transformed_functional.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index f53796f75b..f8ec5b3951 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -145,9 +145,9 @@ class L2RieszMap(fd.RieszMap): target Function space. - - Keyword arguments are passed to the :class:`firedrake.RieszMap` - constructor. + kwargs + Keyword arguments are passed to the :class:`firedrake.RieszMap` + constructor. """ def __init__(self, target: WithGeometry, **kwargs): @@ -209,7 +209,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): - :math:`J` is the functional definining an optimization problem. - :math:`\Pi` is the :math:`L^2` projection from a DG space containing - the control space as a subspace. + the control space as a subspace. - :math:`\Xi` represents a change of basis from an :math:`L^2` orthonormal basis to the finite element basis for the DG space. From b1a5ee9417c4feee61c719a25eda6dc36a21cad1 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 13 Nov 2025 11:21:08 +0000 Subject: [PATCH 32/33] Documentation fixes --- firedrake/adjoint/transformed_functional.py | 93 ++++++++++++++++++--- 1 file changed, 82 insertions(+), 11 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index f8ec5b3951..599cf9edaa 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -8,7 +8,8 @@ from firedrake.adjoint import Control, ReducedFunctional, Tape from firedrake.functionspaceimpl import WithGeometry import finat -from pyadjoint import OverloadedType, no_annotations +import pyadjoint +from pyadjoint import no_annotations from pyadjoint.enlisting import Enlist from pyadjoint.reduced_functional import AbstractReducedFunctional import ufl @@ -146,8 +147,7 @@ class L2RieszMap(fd.RieszMap): target Function space. kwargs - Keyword arguments are passed to the :class:`firedrake.RieszMap` - constructor. + Keyword arguments are passed to the base class constructor. """ def __init__(self, target: WithGeometry, **kwargs): @@ -226,7 +226,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): functional Functional defining the optimization problem, :math:`J`. controls - Controls. Must be :class:`firedrake.Function` objects. + Controls. space_D DG space containing the control space. riesz_map @@ -247,7 +247,7 @@ class L2TransformedFunctional(AbstractReducedFunctional): """ @no_annotations - def __init__(self, functional: OverloadedType, controls: Union[Control, Sequence[Control]], *, + def __init__(self, functional: pyadjoint.OverloadedType, controls: Union[Control, Sequence[Control]], *, space_D: Optional[Union[None, WithGeometry, Sequence[Union[None, WithGeometry]]]] = None, riesz_map: Optional[Union[L2RieszMap, Sequence[L2RieszMap]]] = None, alpha: Optional[Real] = 0, @@ -346,7 +346,7 @@ def map_result(self, m: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.F ---------- m - The result of the optimization. Represents an expansion in an + The result of the optimization. Represents an expansion in the :math:`L^2` orthonormal basis for the DG space. Returns @@ -360,7 +360,22 @@ def map_result(self, m: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.F return m_J @no_annotations - def __call__(self, values): + def __call__(self, values: Union[fd.Function, Sequence[fd.Function]]) -> pyadjoint.AdjFloat: + """Evaluate the functional. + + Parameters + --------- + + value + Control values. + + Returns + ------- + + pyadjoint.AdjFloat + The functional value. + """ + values = Enlist(values) m_D, m_J = self._primal_transform(values) J = self._J(m_J) @@ -372,8 +387,26 @@ def __call__(self, values): return J @no_annotations - def derivative(self, adj_input=1.0, apply_riesz=False): - if adj_input != 1: + def derivative(self, adj_input: Optional[Real] = 1.0, + apply_riesz: Optional[bool] = False) -> Union[fd.Function, fd.Cofunction, list[fd.Function, fd.Cofunction]]: + """Evaluate the derivative. + + Parameters + --------- + + adj_value + Not supported. + apply_riesz + Whether to apply the Riesz map to the result. + + Returns + ------- + + Function, Cofunction, or list[Function or Cofunction] + The derivative. + """ + + if not isinstance(adj_input, Real) or adj_input != 1: raise NotImplementedError("adj_input != 1 not supported") u = Enlist(self._J.derivative()) @@ -396,7 +429,30 @@ def derivative(self, adj_input=1.0, apply_riesz=False): return u.delist(v) @no_annotations - def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False): + def hessian(self, m_dot: Union[fd.Function, Sequence[fd.Function]], + hessian_input: Optional[None] = None, evaluate_tlm: Optional[bool] = True, + apply_riesz: Optional[bool] = False) -> Union[fd.Function, fd.Cofunction, list[fd.Function, fd.Cofunction]]: + """Evaluate the Hessian action. + + Parameters + ---------- + + m_dot + Action direction. + hessian_input + Not supported. + evaluate_tlm + Whether to re-evaluate the tangent-linear. + apply_riesz + Whether to apply the Riesz map to the result. + + Returns + ------- + + Function, Cofunction, or list[Function or Cofunction] + The Hessian action. + """ + if hessian_input is not None: raise NotImplementedError("hessian_input not None not supported") @@ -422,7 +478,22 @@ def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=Fals return u.delist(v) @no_annotations - def tlm(self, m_dot): + def tlm(self, m_dot: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.Function, list[fd.Function]]: + """Evaluate a Jacobian action. + + Parameters + ---------- + + m_dot + Action direction. + + Returns + ------- + + Function or list[Function] + The Jacobian action. + """ + m_dot = Enlist(m_dot) m_dot_D, m_dot_J = self._primal_transform(m_dot) tau_J = self._J.tlm(m_dot.delist(m_dot_J)) From 354be4dd75832df419570e87a918ef5fa682dd02 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Thu, 13 Nov 2025 11:50:02 +0000 Subject: [PATCH 33/33] Documentation fixes --- firedrake/adjoint/transformed_functional.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py index 599cf9edaa..d3eeb2f46b 100644 --- a/firedrake/adjoint/transformed_functional.py +++ b/firedrake/adjoint/transformed_functional.py @@ -96,7 +96,7 @@ def C_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Cofunction: Returns ------- - firedrake.Cofunction + firedrake.cofunction.Cofunction Has vector of degrees of freedom :math:`C^{-1} \tilde{u}`. """ @@ -126,7 +126,7 @@ def C_T_inv_action(self, u: Union[fd.Function, fd.Cofunction]) -> fd.Function: Returns ------- - firedrake.Function + firedrake.function.Function Has vector of degrees of freedom :math:`C^{-T} \tilde{u}`. """ @@ -352,7 +352,7 @@ def map_result(self, m: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.F Returns ------- - firedrake.Function or Sequence[firedrake.Function] + firedrake.function.Function or Sequence[firedrake.function.Function] The mapped result in the original control space. """ @@ -402,7 +402,7 @@ def derivative(self, adj_input: Optional[Real] = 1.0, Returns ------- - Function, Cofunction, or list[Function or Cofunction] + firedrake.function.Function, firedrake.cofunction.Cofunction, or list[firedrake.function.Function or firedrake.cofunction.Cofunction] The derivative. """ @@ -449,7 +449,7 @@ def hessian(self, m_dot: Union[fd.Function, Sequence[fd.Function]], Returns ------- - Function, Cofunction, or list[Function or Cofunction] + firedrake.function.Function, firedrake.cofunction.Cofunction, or list[firedrake.function.Function or firedrake.cofunction.Cofunction] The Hessian action. """ @@ -490,7 +490,7 @@ def tlm(self, m_dot: Union[fd.Function, Sequence[fd.Function]]) -> Union[fd.Func Returns ------- - Function or list[Function] + firedrake.function.Function or list[firedrake.function.Function] The Jacobian action. """