From 88e70b42011b6a5fc2e4333ac83e7d2cafc72950 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 5 Mar 2024 13:58:37 +0800 Subject: [PATCH 1/2] Fix frule for static array constructor that converts eltype --- src/extra_rules.jl | 10 ++++++++-- test/extra_rules.jl | 35 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 test/extra_rules.jl diff --git a/src/extra_rules.jl b/src/extra_rules.jl index b9bcff7e..303419cd 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -179,8 +179,14 @@ end Base.view(t::Tangent{T}, inds) where T<:SVector = view(T(ChainRulesCore.backing(t.data)), inds) Base.getindex(t::Tangent{<:SVector, <:NamedTuple}, ind::Int) = ChainRulesCore.backing(t.data)[ind] -function ChainRules.frule((_, ∂x), ::Type{SArray{S, T, N, L}}, x::NTuple{L,Any}) where {S, T, N, L} - SArray{S, T, N, L}(x), SArray{S}(∂x) +function ChainRules.frule( + (_, ∂x)::Tuple{Any, Tangent{TUP}}, + ::Type{SArray{S, T, N, L}}, + x::TUP, +) where {L, TUP<:NTuple{L, Number}, S, T<:Number, N} + y = SArray{S, T, N, L}(x) + ∂y = SArray{S, T, N, L}(ChainRulesCore.backing(∂x)) + return y, ∂y end @ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T) diff --git a/test/extra_rules.jl b/test/extra_rules.jl new file mode 100644 index 00000000..ac14cd7a --- /dev/null +++ b/test/extra_rules.jl @@ -0,0 +1,35 @@ +using Diffractor +using StaticArrays +using ChainRulesCore +using Test + +@testset "StaticArrays constructor" begin + #frule(::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.Tangent{Tuple{Int64, Vararg{Float64, 9}}, Tuple{Int64, Vararg{Float64, 9}}}}, ::Type{StaticArraysCore.SVector{10, Float64}}, x::Tuple{Int64, Vararg{Float64, 9}}) + # @ Diffractor ~/.julia/packages/Diffractor/yCsbI/src/extra_rules.jl:183 + + @testset "homogenious type" begin + x = (10.0, 20.0, 30.0) + ẋ = zero_tangent(x) + y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) + @test y == @SVector [10.0, 20.0, 30.0] + @test ẏ == @SVector [0.0, 0.0, 0.0] + end + + @testset "convertable type" begin + x = (10, 20.0, 30.0) + ẋ = zero_tangent(x) + y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) + # all are float + @test y == @SVector [10.0, 20.0, 30.0] + @test ẏ == @SVector [0.0, 0.0, 0.0] + end + + @testset "convertable type with ZeroTangent()" begin + x = (10, 20.0, 30.0) + ẋ = Tangent{typeof(x)}(ZeroTangent(), 1.0, 2.0) + y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) + # all are float + @test y == @SVector [10.0, 20.0, 30.0] + @test ẏ == @SVector [0.0, 1.0, 2.0] + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0acd3416..01cbc825 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ const bwd = Diffractor.PrimeDerivativeBack @testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run @testset "$file" for file in ( + "extra_rules.jl" "stage2_fwd.jl", "tangent.jl", "forward_diff_no_inf.jl", From 2ddb972f0ea1098498ed64129bdb09b9704598c4 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 5 Mar 2024 14:45:44 +0800 Subject: [PATCH 2/2] Make obvious the type of x --- test/extra_rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/extra_rules.jl b/test/extra_rules.jl index ac14cd7a..2b860048 100644 --- a/test/extra_rules.jl +++ b/test/extra_rules.jl @@ -16,7 +16,7 @@ using Test end @testset "convertable type" begin - x = (10, 20.0, 30.0) + x::Tuple{Int, Float64, Float64} = (10, 20.0, 30.0) ẋ = zero_tangent(x) y, ẏ = frule((NoTangent(), ẋ), StaticArraysCore.SVector{3, Float64}, x) # all are float