diff --git a/src/structarray.jl b/src/structarray.jl index 3175f84d..34fe3bd1 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -22,14 +22,18 @@ index_type(::Type{NamedTuple{names, types}}) where {names, types} = index_type(t index_type(::Type{Tuple{}}) = Int function index_type(::Type{T}) where {T<:Tuple} S, U = tuple_type_head(T), tuple_type_tail(T) - IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U) + IndexStyle(S) isa IndexCartesian ? CartesianIndex{ndims(S)} : index_type(U) end index_type(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = I +array_types(::Type{StructArray{T, N, C, I}}) where {T, N, C, I} = array_types(C) +array_types(::Type{NamedTuple{names, types}}) where {names, types} = types +array_types(::Type{TT}) where {TT<:Tuple} = TT + function StructArray{T}(c::C) where {T, C<:Tup} cols = strip_params(staticschema(T))(c) - N = isempty(cols) ? 1 : ndims(cols[1]) + N = isempty(cols) ? 1 : ndims(cols[1]) StructArray{T, N, typeof(cols)}(cols) end @@ -225,3 +229,21 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T showfields(io, Tuple(fieldarrays(s))) toplevel && print(io, " with eltype ", T) end + +# broadcast +import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle + +struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end + +@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = + combine_style_types(BroadcastStyle(A), args...) +@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where A<:AbstractArray = + combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...) +combine_style_types(s::BroadcastStyle) = s + +Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parameters...) + +BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA))}() + +Base.similar(bc::Broadcasted{StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = + isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) diff --git a/test/runtests.jl b/test/runtests.jl index 756a63cc..463e51b2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -714,3 +714,32 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.c isa Array @test t.b.d isa Array end + +struct MyArray{T,N} <: AbstractArray{T,N} + A::Array{T,N} +end +MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz)) +Base.IndexStyle(::Type{<:MyArray}) = IndexLinear() +Base.getindex(A::MyArray, i::Int) = A.A[i] +Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val +Base.size(A::MyArray) = Base.size(A.A) +Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}() +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType = + MyArray{ElType}(undef, size(bc)) + +@testset "broadcast" begin + s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) + @test isa(@inferred(s .+ s), StructArray) + @test (s .+ s).re == 2*s.re + @test (s .+ s).im == 2*s.im + @test isa(@inferred(broadcast(t->1, s)), Array) + @test all(x->x==1, broadcast(t->1, s)) + @test isa(@inferred(s .+ 1), StructArray) + @test s .+ 1 == StructArray{ComplexF64}((s.re .+ 1, s.im)) + r = rand(2,2) + @test isa(@inferred(s .+ r), StructArray) + @test s .+ r == StructArray{ComplexF64}((s.re .+ r, s.im)) + + s = StructArray{ComplexF64}((MyArray(rand(2,2)), MyArray(rand(2,2)))) + @test_throws MethodError s .+ s +end