Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 122 additions & 35 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,18 @@ end
function Base.:(:)(start::Integer, step::Integer, stop::RaggedEnd)
return RaggedRange(stop.dim, Int(start), Int(step), stop.offset)
end
function Base.:(:)(start::RaggedEnd, stop::RaggedEnd)
return RaggedRange(stop.dim, start.offset, 1, stop.offset)
end
function Base.:(:)(start::RaggedEnd, step::Integer, stop::RaggedEnd)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing start with raggedend and stop with integer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right. That's also possible. I'll add that in a minute.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed. Thanks for catching that!

return RaggedRange(stop.dim, start.offset, Int(step), stop.offset)
end
function Base.:(:)(start::RaggedEnd, stop::Integer)
return RaggedRange(start.dim, start.offset, 1, Int(stop))
end
function Base.:(:)(start::RaggedEnd, step::Integer, stop::Integer)
return RaggedRange(start.dim, start.offset, Int(step), Int(stop))
end
Base.broadcastable(x::RaggedRange) = Ref(x)

@inline function _is_ragged_dim(VA::AbstractVectorOfArray, d::Integer)
Expand All @@ -579,6 +591,12 @@ Base.@propagate_inbounds function _getindex(
return A.u[I]
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::NotSymbolic, ::Colon, I::Int
)
return A.u[I]
end

Base.@propagate_inbounds function _getindex(
A::AbstractVectorOfArray, ::NotSymbolic,
I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...
Expand All @@ -589,6 +607,33 @@ Base.@propagate_inbounds function _getindex(
stack(getindex.(A.u[last(I)], tuple.(Base.front(I))...))
end
end

Base.@propagate_inbounds function _getindex(
A::AbstractDiffEqArray, ::NotSymbolic,
I::Union{Int, AbstractArray{Int}, AbstractArray{Bool}, Colon}...
)
return if last(I) isa Int
A.u[last(I)][Base.front(I)...]
else
col_idxs = last(I)
# Only preserve DiffEqArray type if all prefix indices are Colons (selecting whole inner arrays)
if all(idx -> idx isa Colon, Base.front(I))
# For Colon, select all columns
if col_idxs isa Colon
col_idxs = eachindex(A.u)
end
# For DiffEqArray, we need to preserve the time values and type
# Create a vector of sliced arrays instead of stacking into higher-dim array
u_slice = [A.u[col][Base.front(I)...] for col in col_idxs]
# Return as DiffEqArray with sliced time values
return DiffEqArray(u_slice, A.t[col_idxs], parameter_values(A), symbolic_container(A))
else
# Prefix indices are not all Colons - do the same as VectorOfArray
# (stack the results into a higher-dimensional array)
return stack(getindex.(A.u[col_idxs], tuple.(Base.front(I))...))
end
end
end
Base.@propagate_inbounds function _getindex(
VA::AbstractVectorOfArray, ::NotSymbolic, ii::CartesianIndex
)
Expand Down Expand Up @@ -674,6 +719,17 @@ end
return idx.dim == 0 ? idx.offset : idx
end

@inline function _column_indices(VA::AbstractVectorOfArray, idx::RaggedRange)
# RaggedRange with dim=0 means it's a column range with pre-resolved indices
if idx.dim == 0
# Create a range with the offset as the stop value
return Base.range(idx.start; step = idx.step, stop = idx.offset)
else
# dim != 0 means it's an inner-dimension range that needs column expansion
return idx
end
end

@inline _resolve_ragged_index(idx, ::AbstractVectorOfArray, ::Any) = idx
@inline function _resolve_ragged_index(idx::RaggedEnd, VA::AbstractVectorOfArray, col)
if idx.dim == 0
Expand Down Expand Up @@ -757,27 +813,54 @@ end
return (Base.front(args)..., resolved_last)
end
elseif args[end] isa RaggedRange
resolved_last = _resolve_ragged_index(args[end], A, 1)
if length(args) == 1
return (resolved_last,)
# Only pre-resolve if it's an inner-dimension range (dim != 0)
# Column ranges (dim == 0) are handled later by _column_indices
if args[end].dim == 0
# Column range - let _column_indices handle it
return args
else
return (Base.front(args)..., resolved_last)
resolved_last = _resolve_ragged_index(args[end], A, 1)
if length(args) == 1
return (resolved_last,)
else
return (Base.front(args)..., resolved_last)
end
end
end
return args
end

# Helper function to preserve DiffEqArray type when slicing
@inline function _preserve_array_type(A::AbstractVectorOfArray, u_slice, col_idxs)
return VectorOfArray(u_slice)
end

@inline function _preserve_array_type(A::AbstractDiffEqArray, u_slice, col_idxs)
return DiffEqArray(u_slice, A.t[col_idxs], parameter_values(A), symbolic_container(A))
end

@inline function _ragged_getindex(A::AbstractVectorOfArray, I...)
n = ndims(A)
# Special-case when user provided one fewer index than ndims(A): last index is column selector.
if length(I) == n - 1
raw_cols = last(I)
# Determine if we're doing column selection (preserve type) or inner-dimension selection (don't preserve)
is_column_selection = if raw_cols isa RaggedEnd && raw_cols.dim != 0
false # Inner dimension - don't preserve type
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
true # Inner dimension range converted to column range - DO preserve type
else
true # Column selection (dim == 0 or not ragged)
end

# If the raw selector is a RaggedEnd/RaggedRange referring to inner dims, reinterpret as column selector.
cols = if raw_cols isa RaggedEnd && raw_cols.dim != 0
lastindex(A.u) + raw_cols.offset
elseif raw_cols isa RaggedRange && raw_cols.dim != 0
# Convert inner-dimension range to column range by resolving bounds
start_val = raw_cols.start < 0 ? lastindex(A.u) + raw_cols.start : raw_cols.start
stop_val = lastindex(A.u) + raw_cols.offset
Base.range(raw_cols.start; step = raw_cols.step, stop = stop_val)
Base.range(start_val; step = raw_cols.step, stop = stop_val)
else
_column_indices(A, raw_cols)
end
Expand All @@ -800,37 +883,41 @@ end
end
return A.u[cols][padded...]
else
return VectorOfArray(
[
begin
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
inner_nd = ndims(A.u[col])
n_missing = inner_nd - length(resolved_prefix)
padded = if n_missing > 0
if all(idx -> idx === Colon(), resolved_prefix)
(
resolved_prefix...,
ntuple(_ -> Colon(), n_missing)...,
)
else
(
resolved_prefix...,
(
lastindex(
A.u[col],
length(resolved_prefix) + i
) for i in 1:n_missing
)...,
)
end
u_slice = [
begin
resolved_prefix = _resolve_ragged_indices(prefix, A, col)
inner_nd = ndims(A.u[col])
n_missing = inner_nd - length(resolved_prefix)
padded = if n_missing > 0
if all(idx -> idx === Colon(), resolved_prefix)
(
resolved_prefix...,
ntuple(_ -> Colon(), n_missing)...,
)
else
resolved_prefix
end
A.u[col][padded...]
(
resolved_prefix...,
(
lastindex(
A.u[col],
length(resolved_prefix) + i
) for i in 1:n_missing
)...,
)
end
for col in cols
]
)
else
resolved_prefix
end
A.u[col][padded...]
end
for col in cols
]
# Only preserve DiffEqArray type if we're selecting actual columns, not inner dimensions
if is_column_selection
return _preserve_array_type(A, u_slice, cols)
else
return VectorOfArray(u_slice)
end
end
end

Expand Down Expand Up @@ -864,7 +951,7 @@ end
if col_idxs isa Int
return A.u[col_idxs]
else
return VectorOfArray(A.u[col_idxs])
return _preserve_array_type(A, A.u[col_idxs], col_idxs)
end
end
# If col_idxs resolved to a single Int, handle it directly
Expand Down
40 changes: 39 additions & 1 deletion test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ diffeq = DiffEqArray(recs, t)
@test diffeq[:, 1] == testa[:, 1]
@test diffeq.u == recs
@test diffeq[:, end] == testa[:, end]
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:length(recs)], t)
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:length(recs)], t[2:end])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was actually pretty confusing because in the previous test t of the left hand side is not the same as the t on the right hand side, but it is not caught because of

.

@test diffeq[:, 2:end].t == t[2:end]
@test diffeq[:, (end - 1):end] == DiffEqArray([recs[i] for i in (length(recs) - 1):length(recs)], t[(length(t) - 1):length(t)])
@test diffeq[:, (end - 1):end].t == t[(length(t) - 1):length(t)]
@test diffeq[:, (end - 5):8] == DiffEqArray([recs[i] for i in (length(t) - 5):8], t[(length(t) - 5):8])
@test diffeq[:, (end - 5):8].t == t[(length(t) - 5):8]

# ## (Int, Int)
@test testa[5, 4] == testva[5, 4]
Expand Down Expand Up @@ -148,6 +153,12 @@ diffeq = DiffEqArray(recs, t)
@test testva[1:2, 1:2] == [1 3; 2 5]
@test diffeq[:, 1] == recs[1]
@test diffeq[1:2, 1:2] == [1 3; 2 5]
@test diffeq[:, 1:2] == DiffEqArray([recs[i] for i in 1:2], t[1:2])
@test diffeq[:, 1:2].t == t[1:2]
@test diffeq[:, 2:end] == DiffEqArray([recs[i] for i in 2:3], t[2:end])
@test diffeq[:, 2:end].t == t[2:end]
@test diffeq[:, (end - 1):end] == DiffEqArray([recs[i] for i in (length(recs) - 1):length(recs)], t[(length(t) - 1):length(t)])
@test diffeq[:, (end - 1):end].t == t[(length(t) - 1):length(t)]

# Test views of heterogeneous arrays (issue #453)
f = VectorOfArray([[1.0], [2.0, 3.0]])
Expand Down Expand Up @@ -179,6 +190,7 @@ ragged = VectorOfArray([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]])
@test ragged[1:end, 3] == [6.0, 7.0, 8.0, 9.0]
@test ragged[:, end] == [6.0, 7.0, 8.0, 9.0]
@test ragged[:, 2:end] == VectorOfArray(ragged.u[2:end])
@test ragged[:, (end - 1):end] == VectorOfArray(ragged.u[(end - 1):end])

ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
@test ragged2[end, 1] == 4.0
Expand All @@ -199,6 +211,7 @@ ragged2 = VectorOfArray([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0], [7.0, 8.0, 9.0]])
@test ragged2[1:(end - 1), 1] == [1.0, 2.0, 3.0]
@test ragged2[1:(end - 1), 2] == [5.0]
@test ragged2[1:(end - 1), 3] == [7.0, 8.0]
@test ragged2[:, (end - 1):end] == VectorOfArray(ragged2.u[(end - 1):end])

# Test that RaggedEnd and RaggedRange broadcast as scalars
# (fixes issue with SymbolicIndexingInterface where broadcasting over RaggedEnd would fail)
Expand All @@ -222,6 +235,31 @@ u = VectorOfArray([[1.0], [2.0, 3.0]])
u[:, 2] .= [10.0, 11.0]
@test u.u[2] == [10.0, 11.0]

# Test DiffEqArray with 2D inner arrays (matrices)
t = 1:2
recs_2d = [rand(2, 3), rand(2, 4)]
diffeq_2d = DiffEqArray(recs_2d, t)
@test diffeq_2d[:, 1] == recs_2d[1]
@test diffeq_2d[:, 2] == recs_2d[2]
@test diffeq_2d[:, 1:2] == DiffEqArray(recs_2d[1:2], t[1:2])
@test diffeq_2d[:, 1:2].t == t[1:2]
@test diffeq_2d[:, 2:end] == DiffEqArray(recs_2d[2:end], t[2:end])
@test diffeq_2d[:, 2:end].t == t[2:end]
@test diffeq_2d[:, (end - 1):end] == DiffEqArray(recs_2d[(end - 1):end], t[(end - 1):end])
@test diffeq_2d[:, (end - 1):end].t == t[(end - 1):end]

# Test DiffEqArray with 3D inner arrays (tensors)
recs_3d = [rand(2, 3, 4), rand(2, 3, 5)]
diffeq_3d = DiffEqArray(recs_3d, t)
@test diffeq_3d[:, :, :, 1] == recs_3d[1]
@test diffeq_3d[:, :, :, 2] == recs_3d[2]
@test diffeq_3d[:, :, :, 1:2] == DiffEqArray(recs_3d[1:2], t[1:2])
@test diffeq_3d[:, :, :, 1:2].t == t[1:2]
@test diffeq_3d[:, :, :, 2:end] == DiffEqArray(recs_3d[2:end], t[2:end])
@test diffeq_3d[:, :, :, 2:end].t == t[2:end]
@test diffeq_3d[:, :, :, (end - 1):end] == DiffEqArray(recs_3d[(end - 1):end], t[(end - 1):end])
@test diffeq_3d[:, :, :, (end - 1):end].t == t[(end - 1):end]

# 2D inner arrays (matrices) with ragged second dimension
u = VectorOfArray([zeros(1, n) for n in (2, 3)])
@test length(view(u, 1, :, 1)) == 2
Expand Down
Loading