@@ -45,12 +45,16 @@ function choose_pivot(xs, order)
45
45
)
46
46
end
47
47
48
- function _quicksort! (ys, xs, alg, order)
48
+ function _quicksort! (ys, xs, alg, order, givenpivot = nothing )
49
49
@check length (ys) == length (xs)
50
50
if length (ys) <= max (8 , alg. basesize)
51
51
return _quicksort_serial! (ys, xs, alg, order)
52
52
end
53
- pivot = choose_pivot (ys, order)
53
+ pivot = if givenpivot === nothing
54
+ choose_pivot (ys, order)
55
+ else
56
+ something (givenpivot)
57
+ end
54
58
chunksize = alg. basesize
55
59
56
60
# TODO : Calculate extrema during the first pass if it's possible
@@ -89,9 +93,12 @@ function _quicksort!(ys, xs, alg, order)
89
93
@check acc == length (xs)
90
94
91
95
total_nbelows = above_offsets[1 ]
92
- total_nbelows == 0 && return sort! (ys, alg. smallsort, order)
93
- # TODO : Fallback to parallel mergesort? Scan the array to check degeneracy
94
- # and also to estimate a good pivot?
96
+ if total_nbelows == 0
97
+ @assert givenpivot === nothing
98
+ betterpivot, ishomogenous = refine_pivot (ys, pivot, alg. basesize, order)
99
+ ishomogenous && return ys
100
+ return _quicksort! (ys, xs, alg, order, Some (betterpivot))
101
+ end
95
102
96
103
# (2) `quicksort_copyback!` -- Copy partitions back to the original
97
104
# (destination) array `ys` in the natural order
@@ -153,3 +160,69 @@ function quicksort_copyback!(ys, xs_chunk, nbelows, below_offset, above_offset)
153
160
@inbounds ys[above_offset+ i] = xs_chunk[end - i+ 1 ]
154
161
end
155
162
end
163
+
164
+ """
165
+ refine_pivot(ys, badpivot::T, basesize, order) -> (pivot::T, ishomogenous::Bool)
166
+
167
+ Iterate over `ys` for refining `badpivot` and checking if all elements in `ys`
168
+ are `order`-equal to `badpivot` (i.e., it is impossible to refine `badpivot`).
169
+
170
+ Return a value `pivot` in `ys` and a boolean `ishomogenous` indicating if `pivot`
171
+ is not `order`-greater than `badpivot`.
172
+
173
+ Given the precondition:
174
+
175
+ badpivot ∈ ys
176
+ all(!(y < badpivot) for y in ys) # i.e., total_nbelows == 0
177
+
178
+ `ishomogenous` implies all elements in `ys` are `order`-equal to `badpivot` and
179
+ `pivot` is better than `badpivot` if and only if `!ishomogenous`.
180
+ """
181
+ function refine_pivot (ys, badpivot, basesize, order)
182
+ chunksize = max (basesize, cld (length (ys), Threads. nthreads ()))
183
+ nchunks = cld (length (ys), chunksize)
184
+ nchunks == 1 && return refine_pivot_serial (ys, badpivot, order)
185
+ ishomogenous = Vector {Bool} (undef, nchunks)
186
+ pivots = Vector {eltype(ys)} (undef, nchunks)
187
+ @sync for (i, ys_chunk) in enumerate (_partition (ys, chunksize))
188
+ @spawn (pivots[i], ishomogenous[i]) = refine_pivot_serial (ys_chunk, badpivot, order)
189
+ end
190
+ allishomogenous = all (ishomogenous)
191
+ allishomogenous && return (badpivot, true )
192
+ @DBG for (i, p) in pairs (pivots)
193
+ ishomogenous[i] && @check eq (order, p, badpivot)
194
+ end
195
+ # Find the smallest `pivot` that is not `badpivot`. Assuming that there are
196
+ # a lot of `badpivot` entries, this is perhaps better than using the median
197
+ # of `pivots`.
198
+ i0 = findfirst (! , ishomogenous)
199
+ goodpivot = pivots[i0]
200
+ for i in i0+ 1 : nchunks
201
+ if @inbounds ! ishomogenous[i]
202
+ p = @inbounds pivots[i]
203
+ if Base. lt (order, p, goodpivot)
204
+ goodpivot = p
205
+ end
206
+ end
207
+ end
208
+ return (goodpivot, false )
209
+ end
210
+
211
+ function refine_pivot_serial (ys, badpivot, order)
212
+ for y in ys
213
+ if Base. lt (order, badpivot, y)
214
+ return (y, false )
215
+ else
216
+ # Since `refine_pivot` is called only if `total_nbelows == 0` and
217
+ # `y1` is the bad pivot, we have:
218
+ @DBG @check ! Base. lt (order, y, badpivot) # i.e., y == y1
219
+ end
220
+ end
221
+ return (badpivot, true )
222
+ end
223
+ # TODO : online median approximation
224
+ # TODO : Check if the homogeneity check can be done in `quicksort_partition!`
225
+ # without overall performance degradation? Use it to determine the pivot
226
+ # for the next recursion.
227
+ # TODO : Do this right after `choose_pivot` if it finds out that all samples are
228
+ # equivalent?
0 commit comments