Added fma, with tests.#203
Conversation
Tests required a third value set (PRIMAL etc) in the testing framework, which was also added. Fixes issue #202.
| @inline function Base.fma(x::Dual, y::Dual, z::Dual) | ||
| vx, vy = value(x), value(y) | ||
| result = fma(vx, vy, value(z)) | ||
| return Dual(result, |
There was a problem hiding this comment.
Wouldn't it be better to define a fma function for Partials as well, so that fma instructions will be used?
There was a problem hiding this comment.
Did you mean for the
Base.fma(x::Dual, y::Real, z::Dual)case? For the others I don't see how it would help.
There was a problem hiding this comment.
I was thinking something like (where R is the value part and E is the partials part):
Partials(x * y + z)
=
Partials[(Rx * [Ey1, Ey2, ...]) +
(Ry * [Ex1, Ex2, ...])
+ [Ez1, Ez2, ...]
]
=
[Rx * Ey1 + Ry * Ex1 + Ez1,
Rx * Ey2 + Ry * Ex2 + Ez2,
...]
=
[fma(Rx, Ey1, fma(Ry, Ex1, Ez1)),
fma(Rx, Ey2, fma(Ry, Ex2, Ez2)),
...]
An implementation would be:
@generated function Base.fma{N}(x::Dual{N}, y::Dual{N}, z::Dual{N})
ex = Expr(:tuple, [:(fma(value(x), partials(y)[$i], fma(value(y), partials(x)[$i], partials(z)[$i]))) for i in 1:N]...)
return quote
v = fma(value(x), value(y), value(z))
Dual(v, $ex)
end
end
so that:
julia> x, y, z = Dual(rand(9)...) , Dual(rand(9)...), Dual(rand(9)...);
julia> x*y + z - fma(x,y,z)
Dual(-5.551115123125783e-17,0.0,0.0,0.0,0.0,0.0,0.0,-2.220446049250313e-16,0.0)
This gives a bunch of fused multiply adds in the generated code:
vfmadd213sd 8(%rcx), %xmm2, %xmm8
vfmadd231sd 8(%rdx), %xmm1, %xmm8
but I'm not sure how big of a performance difference it is in general.
There was a problem hiding this comment.
Some benchmarking with and without O3 shows that the fma version is just barely (5-10%) faster than a function f(x,y,z) = x*y + z (with my processor at least). Might not be worth the extra complexity.
There was a problem hiding this comment.
My understanding is that the point of fma is not speed but precision. Try
for i in 1:100
x, y, z = rand(3)
if fma(x,y,z) != x*y+z
println("*")
end
endThere was a problem hiding this comment.
I thought it was both. Anyway, regarding accuracy, I guess the function I posted has a purpose then, since it does more fusing...?
There was a problem hiding this comment.
Possibly. Could you perhaps merge my PR as is, then you could add your function?
There was a problem hiding this comment.
I let @jrevels do the merging since this is his baby :)
|
Thanks! |
Tests required a third value set (PRIMAL etc) in the testing
framework, which was also added.
Fixes issue #202.