Skip to content

Commit 501b8a6

Browse files
committed
Try to refine pivot when partition fails
1 parent f260315 commit 501b8a6

File tree

2 files changed

+80
-5
lines changed

2 files changed

+80
-5
lines changed

src/quicksort.jl

+78-5
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,16 @@ function choose_pivot(xs, order)
4545
)
4646
end
4747

48-
function _quicksort!(ys, xs, alg, order)
48+
function _quicksort!(ys, xs, alg, order, givenpivot = nothing)
4949
@check length(ys) == length(xs)
5050
if length(ys) <= max(8, alg.basesize)
5151
return _quicksort_serial!(ys, xs, alg, order)
5252
end
53-
pivot = choose_pivot(ys, order)
53+
pivot = if givenpivot === nothing
54+
choose_pivot(ys, order)
55+
else
56+
something(givenpivot)
57+
end
5458
chunksize = alg.basesize
5559

5660
# TODO: Calculate extrema during the first pass if it's possible
@@ -89,9 +93,12 @@ function _quicksort!(ys, xs, alg, order)
8993
@check acc == length(xs)
9094

9195
total_nbelows = above_offsets[1]
92-
total_nbelows == 0 && return sort!(ys, alg.smallsort, order)
93-
# TODO: Fallback to parallel mergesort? Scan the array to check degeneracy
94-
# and also to estimate a good pivot?
96+
if total_nbelows == 0
97+
@assert givenpivot === nothing
98+
betterpivot, ishomogenous = refine_pivot(ys, pivot, alg.basesize, order)
99+
ishomogenous && return ys
100+
return _quicksort!(ys, xs, alg, order, Some(betterpivot))
101+
end
95102

96103
# (2) `quicksort_copyback!` -- Copy partitions back to the original
97104
# (destination) array `ys` in the natural order
@@ -153,3 +160,69 @@ function quicksort_copyback!(ys, xs_chunk, nbelows, below_offset, above_offset)
153160
@inbounds ys[above_offset+i] = xs_chunk[end-i+1]
154161
end
155162
end
163+
164+
"""
165+
refine_pivot(ys, badpivot::T, basesize, order) -> (pivot::T, ishomogenous::Bool)
166+
167+
Iterate over `ys` for refining `badpivot` and checking if all elements in `ys`
168+
are `order`-equal to `badpivot` (i.e., it is impossible to refine `badpivot`).
169+
170+
Return a value `pivot` in `ys` and a boolean `ishomogenous` indicating if `pivot`
171+
is not `order`-greater than `badpivot`.
172+
173+
Given the precondition:
174+
175+
badpivot ∈ ys
176+
all(!(y < badpivot) for y in ys) # i.e., total_nbelows == 0
177+
178+
`ishomogenous` implies all elements in `ys` are `order`-equal to `badpivot` and
179+
`pivot` is better than `badpivot` if and only if `!ishomogenous`.
180+
"""
181+
function refine_pivot(ys, badpivot, basesize, order)
182+
chunksize = max(basesize, cld(length(ys), Threads.nthreads()))
183+
nchunks = cld(length(ys), chunksize)
184+
nchunks == 1 && return refine_pivot_serial(ys, badpivot, order)
185+
ishomogenous = Vector{Bool}(undef, nchunks)
186+
pivots = Vector{eltype(ys)}(undef, nchunks)
187+
@sync for (i, ys_chunk) in enumerate(_partition(ys, chunksize))
188+
@spawn (pivots[i], ishomogenous[i]) = refine_pivot_serial(ys_chunk, badpivot, order)
189+
end
190+
allishomogenous = all(ishomogenous)
191+
allishomogenous && return (badpivot, true)
192+
@DBG for (i, p) in pairs(pivots)
193+
ishomogenous[i] && @check eq(order, p, badpivot)
194+
end
195+
# Find the smallest `pivot` that is not `badpivot`. Assuming that there are
196+
# a lot of `badpivot` entries, this is perhaps better than using the median
197+
# of `pivots`.
198+
i0 = findfirst(!, ishomogenous)
199+
goodpivot = pivots[i0]
200+
for i in i0+1:nchunks
201+
if @inbounds !ishomogenous[i]
202+
p = @inbounds pivots[i]
203+
if Base.lt(order, p, goodpivot)
204+
goodpivot = p
205+
end
206+
end
207+
end
208+
return (goodpivot, false)
209+
end
210+
211+
function refine_pivot_serial(ys, badpivot, order)
212+
for y in ys
213+
if Base.lt(order, badpivot, y)
214+
return (y, false)
215+
else
216+
# Since `refine_pivot` is called only if `total_nbelows == 0` and
217+
# `y1` is the bad pivot, we have:
218+
@DBG @check !Base.lt(order, y, badpivot) # i.e., y == y1
219+
end
220+
end
221+
return (badpivot, true)
222+
end
223+
# TODO: online median approximation
224+
# TODO: Check if the homogeneity check can be done in `quicksort_partition!`
225+
# without overall performance degradation? Use it to determine the pivot
226+
# for the next recursion.
227+
# TODO: Do this right after `choose_pivot` if it finds out that all samples are
228+
# equivalent?

src/utils.jl

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ function elsizeof(::Type{T}) where {T}
3737
end
3838
end
3939

40+
eq(order, a, b) = !(Base.lt(order, a, b) || Base.lt(order, b, a))
41+
4042
function _median(order, (a, b, c)::NTuple{3,Any})
4143
# Sort `(a, b, c)`:
4244
if Base.lt(order, b, a)

0 commit comments

Comments
 (0)