From cb3b18785b2a968637605f06a89039d53569b818 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 9 Feb 2022 22:18:52 +0800 Subject: [PATCH 1/6] Make `StructArrayStyle` track inputs dimension fix #185 --- src/structarray.jl | 14 +++++++++++--- test/runtests.jl | 18 +++++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 650a6b44..c676d210 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -445,7 +445,15 @@ end # broadcast import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle -struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end +struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end + +# Here we define the dimension tracking behaviour of StructArrayStyle +function StructArrayStyle{S,M}(::Val{N}) where {S,M,N} + if S <: AbstractArrayStyle{M} + return StructArrayStyle{typeof(S(Val(N))),N}() + end + return StructArrayStyle{S,N}() +end @inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = combine_style_types(BroadcastStyle(A), args...) @@ -455,9 +463,9 @@ combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) -BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}() +BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}() -Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = +Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,ElType} = isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) # for aliasing analysis during broadcast diff --git a/test/runtests.jl b/test/runtests.jl index 1fe7be57..20412dc5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -926,8 +926,24 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El # used inside of broadcast but we also test it here explicitly @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) - s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2)))) + s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2)))) @test_throws MethodError s .+ s + + # test for dimensionality track + @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} + @test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} + @test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} + @test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} + + a = StructArray([1;2+im]) + b = StructArray([1;;2+im]) + @test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) + + # issue #185 + A = StructArray(randn(ComplexF64, 3, 3)) + B = randn(ComplexF64, 3, 3) + c = StructArray(randn(ComplexF64, 3)) + @test (A .= B .* c) === A end @testset "staticarrays" begin From 7a8fbf418ad8b2e11599b18f7dbf5f4a5e9c7938 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 20 Feb 2022 22:07:49 +0800 Subject: [PATCH 2/6] Add test for unstable broadcast --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 20412dc5..af86c7d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -938,6 +938,7 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El a = StructArray([1;2+im]) b = StructArray([1;;2+im]) @test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) + @test a .+ Any[1] isa StructArray # issue #185 A = StructArray(randn(ComplexF64, 3, 3)) From a5e68f4be7ad3dccdcc5a145ffad595b4553b239 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 20 Feb 2022 22:18:51 +0800 Subject: [PATCH 3/6] fix test --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index af86c7d6..2f7beddd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -931,9 +931,9 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El # test for dimensionality track @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} - @test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} - @test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} - @test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} + @test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} + @test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} + @test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} a = StructArray([1;2+im]) b = StructArray([1;;2+im]) From 2dfc5e7708222eefabc7f17debbedba84c815423 Mon Sep 17 00:00:00 2001 From: piever Date: Sun, 20 Feb 2022 15:27:46 +0100 Subject: [PATCH 4/6] style tweaks --- src/structarray.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index c676d210..27d52d0d 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -445,14 +445,12 @@ end # broadcast import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle -struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end +struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end -# Here we define the dimension tracking behaviour of StructArrayStyle -function StructArrayStyle{S,M}(::Val{N}) where {S,M,N} - if S <: AbstractArrayStyle{M} - return StructArrayStyle{typeof(S(Val(N))),N}() - end - return StructArrayStyle{S,N}() +# Here we define the dimension tracking behavior of StructArrayStyle +function StructArrayStyle{S, M}(::Val{N}) where {S, M, N} + T = S <: AbstractArrayStyle{M} ? typeof(S(Val(N))) : S + return StructArrayStyle{T,N}() end @inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = @@ -463,9 +461,9 @@ combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) -BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}() +BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)), ndims(SA)}() -Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,ElType} = +Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType} = isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) # for aliasing analysis during broadcast From 61d4aa79c2487e40bb34fe6c279a8fbe029ca363 Mon Sep 17 00:00:00 2001 From: piever Date: Sun, 20 Feb 2022 15:28:05 +0100 Subject: [PATCH 5/6] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f67e5041..66739491 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StructArrays" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.4" +version = "0.6.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 0a0d142b7a3ed73d3d224cab2fe0a569d82a3961 Mon Sep 17 00:00:00 2001 From: piever Date: Sun, 20 Feb 2022 15:31:01 +0100 Subject: [PATCH 6/6] style --- src/structarray.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index 27d52d0d..114c514e 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -450,18 +450,18 @@ struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end # Here we define the dimension tracking behavior of StructArrayStyle function StructArrayStyle{S, M}(::Val{N}) where {S, M, N} T = S <: AbstractArrayStyle{M} ? typeof(S(Val(N))) : S - return StructArrayStyle{T,N}() + return StructArrayStyle{T, N}() end -@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = +@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(BroadcastStyle(A), args...) -@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where A<:AbstractArray = +@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...) combine_style_types(s::BroadcastStyle) = s -Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) +Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...) -BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)), ndims(SA)}() +BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}() Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:DefaultArrayStyle, N, ElType} = isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))