Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import collections

from ufl import as_tensor, as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm, Interpolate
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags
Expand Down Expand Up @@ -71,6 +71,11 @@ def split(self, form, argument_indices):
args = form.arguments()
self._arg_cache = {}
self.blocks = dict(enumerate(map(as_tuple, argument_indices)))

if isinstance(form, Interpolate) and not args:
dual_arg, _ = form.argument_slots()
args = dual_arg.arguments()

if len(args) == 0:
# Functional can't be split
return form
Expand Down Expand Up @@ -191,14 +196,14 @@ def interpolate(self, o, operand):
return self(ZeroBaseForm(o.arguments()))

dual_arg, _ = o.argument_slots()
if len(dual_arg.arguments()) == 1 or len(dual_arg.arguments()[-1].function_space()) == 1:
# The dual argument has been contracted or does not need to be split
dual_arguments = dual_arg.arguments()
if len(dual_arguments) == 1 and len(dual_arguments[0].function_space()) == 1:
return o._ufl_expr_reconstruct_(operand, dual_arg)

if not isinstance(dual_arg, Coargument):
if not isinstance(dual_arg, Coargument | Cofunction):
raise NotImplementedError(f"I do not know how to split an Interpolate with a {type(dual_arg).__name__}.")

indices = self.blocks[dual_arg.number()]
indices = self.blocks[dual_arguments[0].number()]
V = dual_arg.function_space()

# Split the target (dual) argument
Expand Down Expand Up @@ -254,6 +259,11 @@ def split_form(form, diagonal=False):
"""
splitter = ExtractSubBlock()
args = form.arguments()

if isinstance(form, Interpolate) and not args:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the args be added as a kwarg?

dual_arg, _ = form.argument_slots()
args = dual_arg.arguments()

shape = tuple(len(a.function_space()) for a in args)
forms = []
rank = len(shape)
Expand Down
25 changes: 1 addition & 24 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,20 +726,8 @@ def _get_callable(self, tensor=None, bcs=None, mat_type=None, sub_mat_type=None)
# We need to split the target space V and generate separate kernels
if self.rank == 2:
expressions = {(0,): self.ufl_interpolate}
elif isinstance(self.dual_arg, Coargument):
# Split in the coargument
expressions = dict(split_form(self.ufl_interpolate))
else:
assert isinstance(self.dual_arg, Cofunction)
# Split in the cofunction: split_form can only split in the coargument
# Replace the cofunction with a coargument to construct the Jacobian
interp = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space)
# Split the Jacobian into blocks
interp_split = dict(split_form(interp))
# Split the cofunction
dual_split = dict(split_form(self.dual_arg))
# Combine the splits by taking their action
expressions = {i: action(interp_split[i], dual_split[i[-1:]]) for i in interp_split}
expressions = dict(split_form(self.ufl_interpolate))

# Interpolate each sub expression into each function space
for indices, sub_expr in expressions.items():
Expand Down Expand Up @@ -1649,14 +1637,6 @@ def _get_sub_interpolators(
# See https://github.com/firedrakeproject/firedrake/issues/4668
space_equals = lambda V1, V2: V1 == V2 and V1.parent == V2.parent and V1.index == V2.index

# We need a Coargument in order to split the Interpolate
needs_action = not any(isinstance(a, Coargument) for a in self.interpolate_args)
if needs_action:
# Split the dual argument
dual_split = dict(split_form(self.dual_arg))
# Create the Jacobian to be split into blocks
self.ufl_interpolate = self.ufl_interpolate._ufl_expr_reconstruct_(self.operand, self.target_space)

# Get sub-interpolators and sub-bcs for each block
Isub: dict[tuple[int] | tuple[int, int], tuple[Interpolator, list[DirichletBC]]] = {}
for indices, form in split_form(self.ufl_interpolate):
Expand All @@ -1667,9 +1647,6 @@ def _get_sub_interpolators(
for space, index in zip(spaces, indices):
subspace = space.sub(index)
sub_bcs.extend(bc for bc in bcs if space_equals(bc.function_space(), subspace))
if needs_action:
# Take the action of each sub-cofunction against each block
form = action(form, dual_split[indices[-1:]])
Isub[indices] = (get_interpolator(form), sub_bcs)

return Isub
Expand Down
Loading