@@ -70,48 +70,63 @@ function _fill_range_for_accelerated_join!(ranges, starts, loc, x, f, sz, chunk;
7070 end
7171end
7272# TODO: how does the hashing behave for Categorical Arrays?
# _find_permute_and_fill_range_for_join! (shown as a diff hunk: `-` lines are the previous body, `+` lines the new one).
# Initialises `ranges` (via `fill!` or `_fill_range_for_accelerated_join!`) and returns the tuple
# `(idx, last_valid_range == length(idx))`.
# In the `+` version `nsfpaj` is read and written through `nsfpaj[1]`, so callers pass a one-element vector such as `[true]`.
73- function _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, chunk = 2 ^ 10 ; nsfpaj = true , threads = true )
73+ function _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, chunk = 2 ^ 10 ; nsfpaj= nsfpaj, threads = true )
74+ # nsfpaj deliberately has no usable default (its default refers to itself), so that every caller has to pass it explicitly.
# NOTE(review): a required keyword argument (`; nsfpaj, threads = true`) would presumably express this more directly - confirm.
75+ # nsfpaj is represented as a vector because this function may override its value (see `nsfpaj[1] = false` below) and the caller must see that change
76+ nsfpaj_in = nsfpaj[1 ] # value on entry; this is what gets forwarded as `notsortpaforjoin`
77+
7478 if isempty (dsr)
7579 idx = []
7680 fill! (ranges, 1 : nrow (dsr))
7781 last_valid_range = - 1
7882 else
79- if accelerate
80- if mapformats[2 ]
81- _fr = getformat (dsr, oncols_right[1 ])
82- else
83- _fr = identity
84- end
85- grng = _divide_for_fast_join (_columns (dsr)[oncols_right[1 ]], _fr, chunk; threads = threads)
86- if mapformats[1 ]
87- _fl = getformat (dsl, oncols_left[1 ])
88- else
89- _fl = identity
90- end
91- _fill_range_for_accelerated_join! (ranges, grng. starts, grng. starts_loc, _columns (dsl)[oncols_left[1 ]], _fl, nrow (dsr), chunk; threads = threads)
92- if dsr isa SubDataset
93- starts, idx, last_valid_range = _sortperm_v (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj, givenrange = grng, threads = threads)
94-
95- else
96- starts, idx, last_valid_range = _sortperm (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj, givenrange = grng, threads = threads)
97- end
83+ # check whether the data are already sorted; if so, this takes precedence over `accelerate`
84+ if _check_for_fast_sort (dsr, oncols_right, fill (false , length (oncols_right)), mapformats[2 ]; notsortpaforjoin = false , givenrange = nothing ) == 0
85+ # if dsr is already sorted the way we need, we can safely change nsfpaj to false
86+ nsfpaj[1 ] = false
87+ idx = 1 : nrow (dsr)
88+ last_valid_range = _ngroups (dsr)
89+ fill! (ranges, 1 : nrow (dsr))
9890 else
99- if dsr isa SubDataset
100- starts, idx, last_valid_range = _sortperm_v (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj, threads = threads)
91+
92+ if accelerate
93+ if mapformats[2 ]
94+ _fr = getformat (dsr, oncols_right[1 ])
95+ else
96+ _fr = identity
97+ end
98+ grng = _divide_for_fast_join (_columns (dsr)[oncols_right[1 ]], _fr, chunk; threads = threads)
99+ if mapformats[1 ]
100+ _fl = getformat (dsl, oncols_left[1 ])
101+ else
102+ _fl = identity
103+ end
104+ _fill_range_for_accelerated_join! (ranges, grng. starts, grng. starts_loc, _columns (dsl)[oncols_left[1 ]], _fl, nrow (dsr), chunk; threads = threads)
105+ if dsr isa SubDataset
106+ starts, idx, last_valid_range = _sortperm_v (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj_in, givenrange = grng, threads = threads)
107+
108+ else
109+ starts, idx, last_valid_range = _sortperm (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj_in, givenrange = grng, threads = threads)
110+ end
101111 else
102- starts, idx, last_valid_range = _sortperm (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj, threads = threads)
112+ if dsr isa SubDataset
113+ starts, idx, last_valid_range = _sortperm_v (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj_in, threads = threads)
114+ else
115+ starts, idx, last_valid_range = _sortperm (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj_in, threads = threads)
116+ end
117+ fill! (ranges, 1 : nrow (dsr))
103118 end
104- fill! (ranges, 1 : nrow (dsr))
105119 end
106120 end
107121 idx, last_valid_range == length (idx) # the second value is true only when last_valid_range equals length(idx)
108122end
109123
# _sort_for_join_after_hash (shown as a diff hunk: `-` lines are the previous body, `+` lines the new one).
# Calls `_sortperm_v` when `dsr` is a `SubDataset` and `_sortperm` otherwise, forwarding `grng` as `givenrange`.
# The `starts, idx, last_valid_range` assignment is the last expression, so its right-hand side is the function's value.
# In the `+` version `nsfpaj` is read as `nsfpaj[1]`, i.e. it is expected to be an indexable one-element container.
110124function _sort_for_join_after_hash (dsr, oncols_right, stable, alg, mapformats, nsfpaj, grng; threads = true )
125+ nsfpaj_in = nsfpaj[1 ]
111126 if dsr isa SubDataset
112- starts, idx, last_valid_range = _sortperm_v (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj , givenrange = grng, threads = threads)
127+ starts, idx, last_valid_range = _sortperm_v (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj_in , givenrange = grng, threads = threads)
113128 else
114- starts, idx, last_valid_range = _sortperm (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj , givenrange = grng, threads = threads)
129+ starts, idx, last_valid_range = _sortperm (dsr, oncols_right, stable = stable, a = alg, mapformats = mapformats[2 ], notsortpaforjoin = nsfpaj_in , givenrange = grng, threads = threads)
115130 end
116131end
117132
@@ -423,7 +438,8 @@ function _mark_lt_part!(inbits, x_l, x_r, _fl::F1, _fr::F2, ranges, r_perms, en,
423438 our_cumsum! (revised_ends)
424439end
425440
426- function _change_refpool_find_range_for_join! (ranges, dsl, dsr, r_perms, oncols_left, oncols_right, lmf, rmf, j; type = :both , nsfpaj = true , threads = true )
441+ function _change_refpool_find_range_for_join! (ranges, dsl, dsr, r_perms, oncols_left, oncols_right, lmf, rmf, j; type = :both , nsfpaj= nsfpaj, threads = true )
442+ nsfpaj_in = nsfpaj[1 ]
427443 var_l = _columns (dsl)[oncols_left[j]]
428444 var_r = _columns (dsr)[oncols_right[j]]
429445 l_idx = oncols_left[j]
@@ -441,7 +457,7 @@ function _change_refpool_find_range_for_join!(ranges, dsl, dsr, r_perms, oncols_
441457
442458 T1 = Core. Compiler. return_type (DataAPI. unwrap∘ _fl, Tuple{eltype (var_l)})
443459
444- if DataAPI. refpool (var_r) != = nothing && nsfpaj
460+ if DataAPI. refpool (var_r) != = nothing && nsfpaj_in
445461 # sort taken care for refs ordering of modified values, but we still need to change refs
446462 if _fr == identity
447463 var_r_cpy = var_r
463479
464480function _join_left (dsl, dsr, :: Val{T} ; onleft, onright, makeunique = false , mapformats = [true , true ], stable = false , alg = HeapSort, check = true , accelerate = false , method = :sort , threads = true , multiple_match:: Bool = false , multiple_match_name = :multiple , obs_id = [false , false ], obs_id_name = :obs_id ) where T
465481 isempty (dsl) && return copy (dsl)
482+ nsfpaj = [true ]
466483 if method == :hash
467484 ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash (dsl, dsr, onleft, onright, mapformats, makeunique, Val (T); threads = threads)
468485 elseif method == :sort
@@ -480,10 +497,10 @@ function _join_left(dsl, dsr, ::Val{T}; onleft, onright, makeunique = false, map
480497 return result
481498 end
482499 end
483- idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate; threads = threads)
500+ idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate;nsfpaj = nsfpaj, threads = threads)
484501
485502 for j in 1 : length (oncols_left)
486- _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j; threads = threads)
503+ _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j; nsfpaj = nsfpaj, threads = threads)
487504 end
488505 end
489506 new_ends = map (x -> max (1 , length (x)), ranges)
553570
554571function _join_left! (dsl:: Dataset , dsr:: AbstractDataset , :: Val{T} ; onleft, onright, makeunique = false , mapformats = [true , true ], stable = false , alg = HeapSort, check = true , accelerate = false , method = :sort , threads = true , multiple_match = false , multiple_match_name = :multiple , obs_id = [false , false ], obs_id_name = :obs_id ) where T
555572 isempty (dsl) && return dsl
573+ nsfpaj = [true ]
556574 if method == :hash
557575 ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash (dsl, dsr, onleft, onright, mapformats, makeunique, Val (T); threads = threads)
558576 elseif method == :sort
@@ -569,9 +587,9 @@ function _join_left!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onrig
569587 return result
570588 end
571589 end
572- idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
590+ idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj = nsfpaj, threads = threads)
573591 for j in 1 : length (oncols_left)
574- _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j, threads = threads)
592+ _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j, nsfpaj = nsfpaj, threads = threads)
575593 end
576594 end
577595 if ! all (x-> length (x) <= 1 , ranges)
@@ -660,11 +678,11 @@ function _join_inner(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, onrig
660678 throw (ArgumentError (" duplicate column names, pass `makeunique = true` to make them unique using a suffix automatically." ))
661679 end
662680
663- nsfpaj = true
681+ nsfpaj = [ true ]
664682 # if the columns for inequality like join are PA we cannot use the fast path
665683 if type != :both
666684 if any (i-> DataAPI. refpool (_columns (dsr)[i]) != = nothing , right_range_cols)
667- nsfpaj = false
685+ nsfpaj = [ false ]
668686 end
669687 end
670688 # if (onright_range === nothing || length(onleft) > 1) is false, then we have inequality kind join with no exact match join
@@ -689,7 +707,7 @@ function _join_inner(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, onrig
689707 return result
690708 end
691709 end
692- idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate && (onright_range == nothing || length (oncols_right)> 1 ); nsfpaj = nsfpaj, threads = threads)
710+ idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate && (onright_range === nothing || length (oncols_right)> 1 ); nsfpaj = nsfpaj, threads = threads)
693711
694712 for j in 1 : length (oncols_left)- 1
695713 _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j; nsfpaj = nsfpaj, threads = threads)
@@ -784,7 +802,7 @@ function _in(dsl::AbstractDataset, dsr::AbstractDataset, ::Val{T}; onleft, onrig
784802 isempty (dsl) && return Bool[]
785803 oncols_left = onleft
786804 oncols_right = onright
787-
805+ nsfpaj = [ true ]
788806 # use Set when there is only one column in `on`
789807 if length (oncols_right) == 1
790808 if mapformats[1 ]
@@ -800,9 +818,9 @@ function _in(dsl::AbstractDataset, dsr::AbstractDataset, ::Val{T}; onleft, onrig
800818 return _in_use_Set (_columns (dsl)[oncols_left[1 ]], _columns (dsr)[oncols_right[1 ]], _fl, _fr, threads = threads)
801819 end
802820 ranges = Vector {UnitRange{T}} (undef, nrow (dsl))
803- idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
821+ idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj = nsfpaj, threads = threads)
804822 for j in 1 : length (oncols_left)
805- _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j, threads = threads)
823+ _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j, nsfpaj = nsfpaj, threads = threads)
806824 end
807825 map (x -> length (x) == 0 ? false : true , ranges)
808826end
@@ -875,6 +893,7 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
875893 (isempty (dsl) || isempty (dsr)) && throw (ArgumentError (" in `outerjoin` both left and right tables must be non-empty" ))
876894 oncols_left = onleft
877895 oncols_right = onright
896+ nsfpaj = [true ]
878897 if method == :hash
879898 ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash (dsl, dsr, onleft, onright, mapformats, makeunique, Val (T); threads = threads)
880899 elseif method == :sort
@@ -889,9 +908,9 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
889908 return result
890909 end
891910 end
892- idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, threads = threads)
911+ idx, uniquemode = _find_permute_and_fill_range_for_join! (ranges, dsr, dsl, oncols_right, oncols_left, stable, alg, mapformats, accelerate, nsfpaj = nsfpaj, threads = threads)
893912 for j in 1 : length (oncols_left)
894- _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j, threads = threads)
913+ _change_refpool_find_range_for_join! (ranges, dsl, dsr, idx, oncols_left, oncols_right, mapformats[1 ], mapformats[2 ], j, nsfpaj = nsfpaj, threads = threads)
895914 end
896915 end
897916 new_ends = map (x -> max (1 , length (x)), ranges)
0 commit comments