From c9a244e2c79cc17f9e8c3ab91f9b929929be131b Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Thu, 4 Aug 2022 11:06:24 -0400 Subject: [PATCH] check for `iszero(partials(x))` in pow --- src/dual.jl | 2 +- test/DualTest.jl | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/dual.jl b/src/dual.jl index 8e99c9fe..c1522410 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -534,7 +534,7 @@ for f in (:(Base.:^), :(NaNMath.pow)) begin v = value(x) expv = ($f)(v, y) - if y == zero(y) + if y == zero(y) || iszero(partials(x)) new_partials = zero(partials(x)) else new_partials = partials(x) * y * ($f)(v, y - 1) diff --git a/test/DualTest.jl b/test/DualTest.jl index 66a1866a..ade38c43 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -12,6 +12,7 @@ using DiffRules import Calculus struct TestTag end +struct OuterTestTag end samerng() = MersenneTwister(1) @@ -26,6 +27,8 @@ dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T3,T4,T5}) where {T,T1,T2,T3,T4,T5} = er ForwardDiff.:≺(::Type{TestTag()}, ::Int) = true ForwardDiff.:≺(::Int, ::Type{TestTag()}) = false +ForwardDiff.:≺(::Type{TestTag}, ::Type{OuterTestTag}) = true +ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false for N in (0,3), M in (0,4), V in (Int, Float32) println(" ...testing Dual{TestTag(),$V,$N} and Dual{TestTag(),Dual{TestTag(),$V,$M},$N}") @@ -553,6 +556,9 @@ end @test pow(x3, 2) === x3^2 === x3 * x3 @test pow(x2, 1) === x2^1 === x2 @test pow(x1, 0) === x1^0 === Dual{:t1}(1.0, 0.0) + y = Dual{typeof(TestTag())}(1.0, 0.0, 1.0); + x = Dual{typeof(OuterTestTag())}(0*y, 0*y); + @test iszero(ForwardDiff.partials(ForwardDiff.partials(x^y)[1])) end @testset "Type min/max" begin