diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 8265586f..32a9a1c9 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -576,37 +576,6 @@ function Base.:(:)(start::RaggedEnd, step::Integer, stop::Integer) end Base.broadcastable(x::RaggedRange) = Ref(x) -# Specialized method for type stability when last index is RaggedEnd with dim=0 (resolved column index) -# This handles the common case: vec[i, end] where end -> RaggedEnd(0, lastindex) -Base.@propagate_inbounds function Base.getindex( - A::AbstractVectorOfArray, i::Int, re::RaggedEnd - ) - if re.dim == 0 - # Sentinel case: RaggedEnd(0, offset) means offset is the resolved column index - return A.u[re.offset][i] - else - # Non-sentinel case: resolve the ragged index for the last column - col = lastindex(A.u) - resolved_idx = lastindex(A.u[col], re.dim) + re.offset - return A.u[col][i, resolved_idx] - end -end - -# Specialized method for type stability when first index is RaggedEnd (row dimension) -# This handles the common case: vec[end, col] where end -> RaggedEnd(1, 0) -Base.@propagate_inbounds function Base.getindex( - A::AbstractVectorOfArray, re::RaggedEnd, col::Int - ) - if re.dim == 0 - # Sentinel case: RaggedEnd(0, offset) means offset is a plain index - return A.u[col][re.offset] - else - # Non-sentinel case: resolve the ragged index for the given column - resolved_idx = lastindex(A.u[col], re.dim) + re.offset - return A.u[col][resolved_idx] - end -end - @inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer) length(VA.u) <= 1 && return false first_size = size(VA.u[1], d) @@ -740,8 +709,8 @@ Base.@propagate_inbounds function _getindex( return getindex(A, all_variable_symbols(A), args...) end -@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx === Colon() ? - eachindex(VA.u) : idx +@inline _column_indices(VA::AbstractVectorOfArray, idx) = idx +@inline _column_indices(VA::AbstractVectorOfArray, idx::Colon) = eachindex(VA.u) @inline function _column_indices(VA::AbstractVectorOfArray, idx::AbstractArray{Bool}) return findall(idx) end @@ -874,106 +843,115 @@ end n = ndims(A) # Special-case when user provided one fewer index than ndims(A): last index is column selector. if length(I) == n - 1 - raw_cols = last(I) - # Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve) - is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0 - false # Inner dimension - don't preserve type - elseif raw_cols isa RaggedRange && raw_cols.dim != 0 - true # Inner dimension range converted to column range - DO preserve type - else - true # Column selection (dim == 0 or not ragged) - end + return _ragged_getindex_nm1dims(A, I...) + else + return _ragged_getindex_full(A, I...) + end +end - # If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector. - cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0 - lastindex(A.u) + raw_cols.offset - elseif raw_cols isa RaggedRange && raw_cols.dim != 0 - # Convert inner-dimension range to column range by resolving bounds - start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start - stop_val = lastindex(A.u) + raw_cols.offset - Base.range(start_val; step = raw_cols.step, stop = stop_val) - else - _column_indices(A, raw_cols) - end - prefix = Base.front(I) - if cols isa Int - resolved_prefix = _resolve_ragged_indices(prefix, A, cols) - inner_nd = ndims(A.u[cols]) - n_missing = inner_nd - length(resolved_prefix) - padded = if n_missing > 0 - if all(idx -> idx === Colon(), resolved_prefix) - (resolved_prefix..., ntuple(_ -> Colon(), n_missing)...) - else - ( - resolved_prefix..., - (lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)..., - ) - end +@inline function _ragged_getindex_nm1dims(A::AbstractVectorOfArray, I...) + raw_cols = last(I) + # Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve) + is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0 + false # Inner dimension - don't preserve type + elseif raw_cols isa RaggedRange && raw_cols.dim != 0 + true # Inner dimension range converted to column range - DO preserve type + else + true # Column selection (dim == 0 or not ragged) + end + + # If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector. + cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0 + lastindex(A.u) + raw_cols.offset + elseif raw_cols isa RaggedRange && raw_cols.dim != 0 + # Convert inner-dimension range to column range by resolving bounds + start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start + stop_val = lastindex(A.u) + raw_cols.offset + Base.range(start_val; step = raw_cols.step, stop = stop_val) + else + _column_indices(A, raw_cols) + end + prefix = Base.front(I) + if cols isa Int + resolved_prefix = _resolve_ragged_indices(prefix, A, cols) + inner_nd = ndims(A.u[cols]) + n_missing = inner_nd - length(resolved_prefix) + padded = if n_missing > 0 + if all(idx -> idx === Colon(), resolved_prefix) + (resolved_prefix..., ntuple(_ -> Colon(), n_missing)...) else - resolved_prefix + ( + resolved_prefix..., + (lastindex(A.u[cols], length(resolved_prefix) + i) for i in 1:n_missing)..., + ) end - return A.u[cols][padded...] else - u_slice = [ - begin - resolved_prefix = _resolve_ragged_indices(prefix, A, col) - inner_nd = ndims(A.u[col]) - n_missing = inner_nd - length(resolved_prefix) - padded = if n_missing > 0 - if all(idx -> idx === Colon(), resolved_prefix) - ( - resolved_prefix..., - ntuple(_ -> Colon(), n_missing)..., - ) - else - ( - resolved_prefix..., - ( - lastindex( - A.u[col], - length(resolved_prefix) + i - ) for i in 1:n_missing - )..., - ) - end + resolved_prefix + end + return A.u[cols][padded...] + else + u_slice = [ + begin + resolved_prefix = _resolve_ragged_indices(prefix, A, col) + inner_nd = ndims(A.u[col]) + n_missing = inner_nd - length(resolved_prefix) + padded = if n_missing > 0 + if all(idx -> idx === Colon(), resolved_prefix) + ( + resolved_prefix..., + ntuple(_ -> Colon(), n_missing)..., + ) else - resolved_prefix - end - A.u[col][padded...] + ( + resolved_prefix..., + ( + lastindex( + A.u[col], + length(resolved_prefix) + i + ) for i in 1:n_missing + )..., + ) end - for col in cols - ] - # Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions - if is_column_selection - return _preserve_array_type(A, u_slice, cols) - else - return VectorOfArray(u_slice) - end + else + resolved_prefix + end + A.u[col][padded...] + end + for col in cols + ] + # Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions + if is_column_selection + return _preserve_array_type(A, u_slice, cols) + else + return VectorOfArray(u_slice) end end +end +@inline function _padded_resolved_indices(prefix, A::AbstractVectorOfArray, col) + resolved = _resolve_ragged_indices(prefix, A, col) + inner_nd = ndims(A.u[col]) + padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...) + return padded +end + +@inline function _ragged_getindex_full(A::AbstractVectorOfArray, I...) # Otherwise, use the full-length interpretation (last index is column selector; missing columns default to Colon()). - if length(I) == n - cols = last(I) - prefix = Base.front(I) + n = ndims(A) + cols, prefix = if length(I) == n + last(I), Base.front(I) else - cols = Colon() - prefix = I + Colon(), I end if cols isa Int if all(idx -> idx === Colon(), prefix) return A.u[cols] end - resolved = _resolve_ragged_indices(prefix, A, cols) - inner_nd = ndims(A.u[cols]) - padded = (resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))...) - return A.u[cols][padded...] + return A.u[cols][_padded_resolved_indices(prefix, A, cols)...] else col_idxs = _column_indices(A, cols) # Resolve sentinel RaggedEnd/RaggedRange (dim==0) for column selection - if col_idxs isa RaggedEnd - col_idxs = _resolve_ragged_index(col_idxs, A, 1) - elseif col_idxs isa RaggedRange + if col_idxs isa RaggedEnd || col_idxs isa RaggedRange col_idxs = _resolve_ragged_index(col_idxs, A, 1) end # If we're selecting whole inner arrays (all leading indices are Colons), @@ -986,23 +964,14 @@ end end end # If col_idxs resolved to a single Int, handle it directly - if col_idxs isa Int - resolved = _resolve_ragged_indices(prefix, A, col_idxs) - inner_nd = ndims(A.u[col_idxs]) - padded = ( - resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))..., - ) - return A.u[col_idxs][padded...] - end vals = map(col_idxs) do col - resolved = _resolve_ragged_indices(prefix, A, col) - inner_nd = ndims(A.u[col]) - padded = ( - resolved..., ntuple(_ -> Colon(), max(inner_nd - length(resolved), 0))..., - ) - A.u[col][padded...] + A.u[col][_padded_resolved_indices(prefix, A, col)...] + end + if col_idxs isa Int + return vals + else + return stack(vals) end - return stack(vals) end end diff --git a/test/interface_tests.jl b/test/interface_tests.jl index aa0aa0b3..655a441c 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -67,16 +67,19 @@ push!(testda, [-1, -2, -3, -4]) @inferred sum(VectorOfArray([VectorOfArray([zeros(4, 4)])])) @inferred mapreduce(string, *, testva) # Type stability for `end` indexing (issue #525) -testva_end = VectorOfArray([fill(2.0, 2) for i in 1:10]) +testva_end = VectorOfArray(fill(fill(2.0, 2), 10)) # Use lastindex directly since `end` doesn't work in SafeTestsets last_col = lastindex(testva_end, 2) @inferred testva_end[1, last_col] +@inferred testva_end[1, 1:last_col] @test testva_end[1, last_col] == 2.0 last_col = lastindex(testva_end) @inferred testva_end[1, last_col] +@inferred testva_end[1, 1:last_col] @test testva_end[1, last_col] == 2.0 last_row = lastindex(testva_end, 1) @inferred testva_end[last_row, 1] +@inferred testva_end[1:last_row, 1] @test testva_end[last_row, 1] == 2.0 # mapreduce