Skip to content

Commit 3ae6d13

Browse files
committed
bug fix - avoid resorting data in joins
1 parent 2257366 commit 3ae6d13

File tree

4 files changed

+169
-50
lines changed

4 files changed

+169
-50
lines changed

src/join/closejoin.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,8 @@ function _fill_right_cols_table_close!(_res, x, ranges, total, borderval, fill_
274274

275275
end
276276

277-
function _change_refpool_find_range_for_close!(ranges, dsl, dsr, r_perms, oncols_left, oncols_right, direction, lmf, rmf, j; nsfpaj = true, threads = true)
277+
function _change_refpool_find_range_for_close!(ranges, dsl, dsr, r_perms, oncols_left, oncols_right, direction, lmf, rmf, j; nsfpaj=nsfpaj, threads = true)
278+
nsfpaj_in = nsfpaj[1]
278279
var_l = _columns(dsl)[oncols_left[j]]
279280
var_r = _columns(dsr)[oncols_right[j]]
280281
l_idx = oncols_left[j]
@@ -292,8 +293,8 @@ function _change_refpool_find_range_for_close!(ranges, dsl, dsr, r_perms, oncols
292293

293294
T1 = Core.Compiler.return_type(_fl, Tuple{eltype(var_l)})
294295

295-
if DataAPI.refpool(var_r) !== nothing && nsfpaj
296-
true && throw(ErrorException("we shouldn't end up here"))
296+
if DataAPI.refpool(var_r) !== nothing && nsfpaj_in
297+
throw(ErrorException("we shouldn't end up here"))
297298
else
298299
T2 = Core.Compiler.return_type(_fr, Tuple{eltype(var_r)})
299300
if direction == :backward
@@ -327,10 +328,10 @@ function _join_closejoin(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, m
327328
throw(ArgumentError("duplicate column names, pass `makeunique = true` to make them unique using a suffix automatically." ))
328329
end
329330

330-
nsfpaj = true
331+
nsfpaj = [true]
331332
# if the column for close join is a PA we cannot use the fast path
332333
if DataAPI.refpool(_columns(dsr)[oncols_right[end]]) !== nothing
333-
nsfpaj = false
334+
nsfpaj = [false]
334335
end
335336
if length(oncols_left) > 1 && method == :hash
336337
ranges, a, idx, minval, reps, sz, right_cols_2= _find_ranges_for_join_using_hash(dsl, dsr, onleft[1:end-1], onright[1:end-1], mapformats, true, Val(T), threads = threads)

src/join/join.jl

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -70,48 +70,63 @@ function _fill_range_for_accelerated_join!(ranges, starts, loc, x, f, sz, chunk;
7070
end
7171
end
7272
# TODO: how does the hashing behave for Categorical Arrays?
73-
function _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, chunk = 2^10; nsfpaj = true, threads = true)
73+
function _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, chunk = 2^10; nsfpaj=nsfpaj, threads = true)
74+
# nsfpaj has no value by default to make sure caller passes it
75+
# we use a one-element vector to represent nsfpaj, because we may override its value in place
76+
nsfpaj_in = nsfpaj[1]
77+
7478
if isempty(dsr)
7579
idx = []
7680
fill!(ranges, 1:nrow(dsr))
7781
last_valid_range = -1
7882
else
79-
if accelerate
80-
if mapformats[2]
81-
_fr = getformat(dsr, oncols_right[1])
82-
else
83-
_fr = identity
84-
end
85-
grng = _divide_for_fast_join(_columns(dsr)[oncols_right[1]], _fr, chunk; threads = threads)
86-
if mapformats[1]
87-
_fl = getformat(dsl, oncols_left[1])
88-
else
89-
_fl = identity
90-
end
91-
_fill_range_for_accelerated_join!(ranges, grng.starts, grng.starts_loc, _columns(dsl)[oncols_left[1]], _fl, nrow(dsr), chunk; threads = threads)
92-
if dsr isa SubDataset
93-
starts, idx, last_valid_range = _sortperm_v(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj, givenrange = grng, threads = threads)
94-
95-
else
96-
starts, idx, last_valid_range = _sortperm(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj, givenrange = grng, threads = threads)
97-
end
83+
# check if data already sorted, if so it overrides accelerate
84+
if _check_for_fast_sort(dsr, oncols_right, fill(false, length(oncols_right)), mapformats[2]; notsortpaforjoin = false, givenrange = nothing) == 0
85+
# if it is already sorted based on what we want, we can safely change nsfpaj to false
86+
nsfpaj[1] = false
87+
idx = 1:nrow(dsr)
88+
last_valid_range = _ngroups(dsr)
89+
fill!(ranges, 1:nrow(dsr))
9890
else
99-
if dsr isa SubDataset
100-
starts, idx, last_valid_range = _sortperm_v(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj, threads = threads)
91+
92+
if accelerate
93+
if mapformats[2]
94+
_fr = getformat(dsr, oncols_right[1])
95+
else
96+
_fr = identity
97+
end
98+
grng = _divide_for_fast_join(_columns(dsr)[oncols_right[1]], _fr, chunk; threads = threads)
99+
if mapformats[1]
100+
_fl = getformat(dsl, oncols_left[1])
101+
else
102+
_fl = identity
103+
end
104+
_fill_range_for_accelerated_join!(ranges, grng.starts, grng.starts_loc, _columns(dsl)[oncols_left[1]], _fl, nrow(dsr), chunk; threads = threads)
105+
if dsr isa SubDataset
106+
starts, idx, last_valid_range = _sortperm_v(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj_in, givenrange = grng, threads = threads)
107+
108+
else
109+
starts, idx, last_valid_range = _sortperm(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj_in, givenrange = grng, threads = threads)
110+
end
101111
else
102-
starts, idx, last_valid_range = _sortperm(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj, threads = threads)
112+
if dsr isa SubDataset
113+
starts, idx, last_valid_range = _sortperm_v(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj_in, threads = threads)
114+
else
115+
starts, idx, last_valid_range = _sortperm(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj_in, threads = threads)
116+
end
117+
fill!(ranges, 1:nrow(dsr))
103118
end
104-
fill!(ranges, 1:nrow(dsr))
105119
end
106120
end
107121
idx, last_valid_range == length(idx)
108122
end
109123

110124
function _sort_for_join_after_hash(dsr, oncols_right, stable, alg, mapformats, nsfpaj, grng; threads = true)
125+
nsfpaj_in = nsfpaj[1]
111126
if dsr isa SubDataset
112-
starts, idx, last_valid_range = _sortperm_v(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj, givenrange = grng, threads = threads)
127+
starts, idx, last_valid_range = _sortperm_v(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj_in, givenrange = grng, threads = threads)
113128
else
114-
starts, idx, last_valid_range = _sortperm(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj, givenrange = grng, threads = threads)
129+
starts, idx, last_valid_range = _sortperm(dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2], notsortpaforjoin = nsfpaj_in, givenrange = grng, threads = threads)
115130
end
116131
end
117132

@@ -423,7 +438,8 @@ function _mark_lt_part!(inbits, x_l, x_r, _fl::F1, _fr::F2, ranges, r_perms, en,
423438
our_cumsum!(revised_ends)
424439
end
425440

426-
function _change_refpool_find_range_for_join!(ranges, dsl, dsr, r_perms, oncols_left, oncols_right, lmf, rmf, j; type = :both, nsfpaj = true, threads = true)
441+
function _change_refpool_find_range_for_join!(ranges, dsl, dsr, r_perms, oncols_left, oncols_right, lmf, rmf, j; type = :both, nsfpaj=nsfpaj, threads = true)
442+
nsfpaj_in = nsfpaj[1]
427443
var_l = _columns(dsl)[oncols_left[j]]
428444
var_r = _columns(dsr)[oncols_right[j]]
429445
l_idx = oncols_left[j]
@@ -441,7 +457,7 @@ function _change_refpool_find_range_for_join!(ranges, dsl, dsr, r_perms, oncols_
441457

442458
T1 = Core.Compiler.return_type(DataAPI.unwrap_fl, Tuple{eltype(var_l)})
443459

444-
if DataAPI.refpool(var_r) !== nothing && nsfpaj
460+
if DataAPI.refpool(var_r) !== nothing && nsfpaj_in
445461
# sort taken care for refs ordering of modified values, but we still need to change refs
446462
if _fr == identity
447463
var_r_cpy = var_r
@@ -463,6 +479,7 @@ end
463479

464480
function _join_left(dsl, dsr, ::Val{T}; onleft, onright, makeunique = false, mapformats = [true, true], stable = false, alg = HeapSort, check = true, accelerate = false, method = :sort, threads = true, multiple_match::Bool = false, multiple_match_name = :multiple, obs_id = [false, false], obs_id_name = :obs_id) where T
465481
isempty(dsl) && return copy(dsl)
482+
nsfpaj = [true]
466483
if method == :hash
467484
ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash(dsl, dsr, onleft, onright, mapformats, makeunique, Val(T); threads = threads)
468485
elseif method == :sort
@@ -480,10 +497,10 @@ function _join_left(dsl, dsr, ::Val{T}; onleft, onright, makeunique = false, map
480497
return result
481498
end
482499
end
483-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate; threads = threads)
500+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate;nsfpaj = nsfpaj, threads = threads)
484501

485502
for j in 1:length(oncols_left)
486-
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j; threads = threads)
503+
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j; nsfpaj = nsfpaj, threads = threads)
487504
end
488505
end
489506
new_ends = map(x -> max(1, length(x)), ranges)
@@ -553,6 +570,7 @@ end
553570

554571
function _join_left!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, mapformats = [true, true], stable = false, alg = HeapSort, check = true, accelerate = false, method = :sort, threads = true, multiple_match = false, multiple_match_name = :multiple, obs_id = [false, false], obs_id_name = :obs_id) where T
555572
isempty(dsl) && return dsl
573+
nsfpaj = [true]
556574
if method == :hash
557575
ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash(dsl, dsr, onleft, onright, mapformats, makeunique, Val(T); threads = threads)
558576
elseif method == :sort
@@ -569,9 +587,9 @@ function _join_left!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onrig
569587
return result
570588
end
571589
end
572-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
590+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj = nsfpaj, threads = threads)
573591
for j in 1:length(oncols_left)
574-
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, threads = threads)
592+
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, nsfpaj = nsfpaj, threads = threads)
575593
end
576594
end
577595
if !all(x->length(x) <= 1, ranges)
@@ -660,11 +678,11 @@ function _join_inner(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, onrig
660678
throw(ArgumentError("duplicate column names, pass `makeunique = true` to make them unique using a suffix automatically." ))
661679
end
662680

663-
nsfpaj = true
681+
nsfpaj = [true]
664682
# if the columns for inequality like join are PA we cannot use the fast path
665683
if type != :both
666684
if any(i-> DataAPI.refpool(_columns(dsr)[i]) !== nothing, right_range_cols)
667-
nsfpaj = false
685+
nsfpaj = [false]
668686
end
669687
end
670688
# if (onright_range === nothing || length(onleft) > 1) is false, then we have inequality kind join with no exact match join
@@ -689,7 +707,7 @@ function _join_inner(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, onrig
689707
return result
690708
end
691709
end
692-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate && (onright_range == nothing || length(oncols_right)>1); nsfpaj = nsfpaj, threads = threads)
710+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate && (onright_range === nothing || length(oncols_right)>1); nsfpaj = nsfpaj, threads = threads)
693711

694712
for j in 1:length(oncols_left)-1
695713
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j; nsfpaj = nsfpaj, threads = threads)
@@ -784,7 +802,7 @@ function _in(dsl::AbstractDataset, dsr::AbstractDataset, ::Val{T}; onleft, onrig
784802
isempty(dsl) && return Bool[]
785803
oncols_left = onleft
786804
oncols_right = onright
787-
805+
nsfpaj = [true]
788806
# use Set when there is only one column in `on`
789807
if length(oncols_right) == 1
790808
if mapformats[1]
@@ -800,9 +818,9 @@ function _in(dsl::AbstractDataset, dsr::AbstractDataset, ::Val{T}; onleft, onrig
800818
return _in_use_Set(_columns(dsl)[oncols_left[1]], _columns(dsr)[oncols_right[1]], _fl, _fr, threads = threads)
801819
end
802820
ranges = Vector{UnitRange{T}}(undef, nrow(dsl))
803-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
821+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj = nsfpaj, threads = threads)
804822
for j in 1:length(oncols_left)
805-
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, threads = threads)
823+
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, nsfpaj = nsfpaj, threads = threads)
806824
end
807825
map(x -> length(x) == 0 ? false : true, ranges)
808826
end
@@ -875,6 +893,7 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
875893
(isempty(dsl) || isempty(dsr)) && throw(ArgumentError("in `outerjoin` both left and right tables must be non-empty"))
876894
oncols_left = onleft
877895
oncols_right = onright
896+
nsfpaj = [true]
878897
if method == :hash
879898
ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash(dsl, dsr, onleft, onright, mapformats, makeunique, Val(T); threads = threads)
880899
elseif method == :sort
@@ -889,9 +908,9 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
889908
return result
890909
end
891910
end
892-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
911+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj = nsfpaj, threads = threads)
893912
for j in 1:length(oncols_left)
894-
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, threads = threads)
913+
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, nsfpaj = nsfpaj, threads = threads)
895914
end
896915
end
897916
new_ends = map(x -> max(1, length(x)), ranges)

src/join/update.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ end
2828

2929
function _update!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright, check = true, allowmissing = true, mode = :all, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, usehash = true, method = :sort, threads = true, op = nothing) where T
3030
isempty(dsl) && return dsl
31+
nsfpaj = [true]
3132
if method == :hash
3233
ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash(dsl, dsr, onleft, onright, mapformats, true, Val(T); threads = threads)
3334
elseif method == :sort
@@ -42,10 +43,10 @@ function _update!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright,
4243
return result
4344
end
4445
end
45-
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
46+
idx, uniquemode = _find_permute_and_fill_range_for_join!(ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj=nsfpaj, threads = threads)
4647

4748
for j in 1:length(oncols_left)
48-
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, threads = threads)
49+
_change_refpool_find_range_for_join!(ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1], mapformats[2], j, nsfpaj = nsfpaj, threads = threads)
4950
end
5051
end
5152

0 commit comments

Comments
 (0)