Consider:

```python
from firedrake import *
import numpy

mesh = UnitSquareMesh(3, 3)
V = FunctionSpace(mesh, "DG", 0)
x, y = SpatialCoordinate(mesh)
g = interpolate(x*y, V)
f = interpolate(x**2 - y**4, V)
h = interpolate(max_value(x*y, x**2 - y**4), V)
# copy g
z = Function(g)
# This should mean z_i <- max(z_i, f_i)
interpolate(f, z, access=MAX)
by_hand = numpy.where(f.dat.data_ro > g.dat.data_ro, f.dat.data_ro, g.dat.data_ro)
print("By hand:", numpy.allclose(by_hand, h.dat.data_ro))
print("Interpolate:", numpy.allclose(by_hand, z.dat.data_ro))
```
which prints:

```
By hand: True
Interpolate: False
```

If we look at the wrapper and kernel generated for the last expression, we can see the problem:
```c
static void expression_kernel(double *__restrict__ A, double const *__restrict__ w_0)
{
  A[0] = A[0] + w_0[0];
}

void wrap_expression_kernel(int32_t const start, int32_t const end, double *__restrict__ dat1, double const *__restrict__ dat0, int32_t const *__restrict__ map0)
{
  double t0[1];
  double t1[1];
  for (int32_t n = start; n <= -1 + end; ++n)
  {
    t1[0] = dat1[map0[n]];
    t0[0] = dat0[map0[n]];
    expression_kernel(&(t1[0]), &(t0[0]));
    dat1[map0[n]] = fmax(dat1[map0[n]], t1[0]);
  }
}
```

TSFC always generates kernels that accumulate into the output argument, and requires the caller to zero that argument beforehand. This is partly a legacy of TSFC initially supporting only assembly kernel generation, so it has no notion of access descriptors for the output variable.
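A NumPy transliteration of the generated kernel and wrapper above makes the mismatch concrete. This is a sketch, not Firedrake or PyOP2 code: the Python functions only mirror the C shown above, and the cell values for `g` and `f` are made up rather than taken from the `UnitSquareMesh` example.

```python
import numpy as np


def expression_kernel(A, w_0):
    # Mirrors the TSFC kernel: it accumulates into A and relies on the
    # caller having zeroed A beforehand.
    A[0] = A[0] + w_0[0]


def wrap_expression_kernel(dat1, dat0, map0, kernel):
    # Mirrors the PyOP2 wrapper generated for access=MAX: gather the
    # current global values into local temporaries, call the kernel,
    # then reduce the result back into dat1 with max.
    for n in range(len(map0)):
        t1 = np.array([dat1[map0[n]]])  # copied from dat1, not zeroed
        t0 = np.array([dat0[map0[n]]])
        kernel(t1, t0)
        dat1[map0[n]] = max(dat1[map0[n]], t1[0])


# Hypothetical DG0 cell values standing in for g (the target) and f.
g = np.array([0.5, 0.2, -0.3, 0.1])
f = np.array([0.4, -0.6, 0.2, 0.3])
cell_map = np.arange(len(g))  # identity map, one dof per cell

z = g.copy()
wrap_expression_kernel(z, f, cell_map, expression_kernel)

expected = np.maximum(g, f)
# The wrapper stores max(z, z + f) = z + max(f, 0), not max(z, f).
print("Matches max(g, f):", np.allclose(z, expected))  # False
print("Matches g + max(f, 0):", np.allclose(z, g + np.maximum(f, 0.0)))  # True
```

Because `t1` starts from the global value and the kernel adds to it, the `fmax` in the wrapper compares `z_i` against `z_i + f_i` instead of against `f_i`.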
PyOP2, on the other hand, does know about access descriptors and builds the wrapper for MAX assuming that the kernel itself performs the MAX reduction. The wrapper gathers the current value of `dat1` into `t1`, so with the accumulating TSFC kernel it ends up storing max(z_i, z_i + f_i) rather than max(z_i, f_i). PyOP2 expects the inner kernel to be written as:

```c
static void expression_kernel(double *__restrict__ A, double const *__restrict__ w_0)
{
  A[0] = fmax(A[0], w_0[0]);
}
```

What's the best way to fix this? Should we change the semantics of PyOP2 kernels so that, for reductions, the contract becomes "PyOP2 generates the appropriate pre/post staging code to perform the reduction against the global version; you should just put the right thing in the local temporary"? Or should we extend TSFC code generation to support access descriptors on arguments, so that it can deal appropriately with the return values?
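The two options can be compared in a small NumPy sketch that mirrors the C wrapper and kernels from this issue. Everything here is hypothetical: `wrap_zeroed` stands in for what option 1 might generate, `max_expression_kernel` for what option 2 might generate, and the cell values are made up; none of these are existing PyOP2 or TSFC functions.

```python
import numpy as np


def expression_kernel(A, w_0):
    # Current TSFC kernel: accumulate into an output the caller must zero.
    A[0] = A[0] + w_0[0]


def max_expression_kernel(A, w_0):
    # Option 2 (hypothetical): TSFC knows the output has MAX access and
    # emits the reduction itself, like the hand-written fmax kernel.
    A[0] = max(A[0], w_0[0])


def wrap_gather(dat1, dat0, map0, kernel):
    # Current PyOP2 MAX wrapper: t1 starts from the global value.
    for n in range(len(map0)):
        t1 = np.array([dat1[map0[n]]])
        t0 = np.array([dat0[map0[n]]])
        kernel(t1, t0)
        dat1[map0[n]] = max(dat1[map0[n]], t1[0])


def wrap_zeroed(dat1, dat0, map0, kernel):
    # Option 1 (hypothetical): PyOP2 zeroes the local temporary, the kernel
    # only puts the local value there, and the wrapper alone performs the
    # MAX reduction against the global data.
    for n in range(len(map0)):
        t1 = np.zeros(1)
        t0 = np.array([dat0[map0[n]]])
        kernel(t1, t0)
        dat1[map0[n]] = max(dat1[map0[n]], t1[0])


g = np.array([0.5, 0.2, -0.3, 0.1])  # hypothetical target values
f = np.array([0.4, -0.6, 0.2, 0.3])  # hypothetical source values
cell_map = np.arange(len(g))  # identity map, one dof per cell

z1 = g.copy()
wrap_zeroed(z1, f, cell_map, expression_kernel)  # option 1
z2 = g.copy()
wrap_gather(z2, f, cell_map, max_expression_kernel)  # option 2

print(np.allclose(z1, np.maximum(g, f)), np.allclose(z2, np.maximum(g, f)))  # True True
```

In this sketch both options produce max(g, f); they differ only in where the reduction lives. Option 1 leaves TSFC's accumulate-into-zeroed-output contract untouched but changes what PyOP2 kernels with reduction access are expected to do, while option 2 keeps PyOP2's current contract and requires TSFC to know each argument's access descriptor.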