Skip to content

Commit f260315

Browse files
committed
Quicksort: copy twice instead of scan-scatter
1 parent f8fa366 commit f260315

File tree

2 files changed

+66
-205
lines changed

2 files changed

+66
-205
lines changed

src/quicksort.jl

+60-205
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,13 @@ function Base.sort!(
2323
a = @set a.smallsize = a.basesize
2424
end
2525
ys = view(v, lo:hi)
26-
_quicksort!(
27-
similar(ys),
28-
ys,
29-
a,
30-
o,
31-
Vector{Int8}(undef, length(ys)),
32-
false, # ys_is_result
33-
true, # mutable_xs
34-
)
26+
xs = similar(ys)
27+
_quicksort!(ys, xs, a, o)
3528
return v
3629
end
3730

38-
function _quicksort!(
39-
ys,
40-
xs,
41-
alg,
42-
order,
43-
cs = Vector{Int8}(undef, length(ys)),
44-
ys_is_result = true,
45-
mutable_xs = false,
46-
)
47-
@check length(ys) == length(xs)
48-
if length(ys) <= max(8, alg.basesize)
49-
return _quicksort_serial!(ys, xs, alg, order, cs, ys_is_result, mutable_xs)
50-
end
51-
pivot = _median(
31+
function choose_pivot(xs, order)
32+
return _median(
5233
order,
5334
(
5435
xs[1],
@@ -62,239 +43,113 @@ function _quicksort!(
6243
xs[end],
6344
),
6445
)
46+
end
47+
48+
function _quicksort!(ys, xs, alg, order)
49+
@check length(ys) == length(xs)
50+
if length(ys) <= max(8, alg.basesize)
51+
return _quicksort_serial!(ys, xs, alg, order)
52+
end
53+
pivot = choose_pivot(ys, order)
6554
chunksize = alg.basesize
6655

6756
# TODO: Calculate extrema during the first pass if it's possible
6857
# to use counting sort.
69-
# TODO: When recursing, fuse copying _from_ `ys` to `xs` with the
70-
# first pass.
7158

72-
# Compute sizes of each partition for each chunks.
59+
# (1) `quicksort_partition!` -- partition each chunk in parallel
7360
xs_chunk_list = _partition(xs, chunksize)
74-
cs_chunk_list = _partition(cs, chunksize)
61+
ys_chunk_list = _partition(ys, chunksize)
7562
nchunks = cld(length(xs), chunksize)
7663
nbelows = Vector{Int}(undef, nchunks)
77-
nequals = Vector{Int}(undef, nchunks)
7864
naboves = Vector{Int}(undef, nchunks)
7965
@DBG begin
8066
VERSION >= v"1.4" &&
81-
@check length(xs_chunk_list) == length(cs_chunk_list) == nchunks
67+
@check length(xs_chunk_list) == length(ys_chunk_list) == nchunks
8268
fill!(nbelows, -1)
83-
fill!(nequals, -1)
8469
fill!(naboves, -1)
8570
end
86-
@sync for (nb, ne, na, xs_chunk, cs_chunk) in zip(
71+
@sync for (nb, na, xs_chunk, ys_chunk) in zip(
8772
referenceable(nbelows),
88-
referenceable(nequals),
8973
referenceable(naboves),
9074
xs_chunk_list,
91-
cs_chunk_list,
75+
ys_chunk_list,
9276
)
93-
@spawn partition_sizes!(nb, ne, na, xs_chunk, cs_chunk, pivot, order)
77+
@spawn (nb[], na[]) = quicksort_partition!(xs_chunk, ys_chunk, pivot, order)
9478
end
9579
@DBG begin
9680
@check all(>=(0), nbelows)
97-
@check all(>=(0), nequals)
9881
@check all(>=(0), naboves)
82+
@check nbelows .+ nbelows == map(length, xs_chunk_list)
9983
end
10084

10185
below_offsets = nbelows
102-
equal_offsets = nequals
10386
above_offsets = naboves
10487
acc = exclusive_cumsum!(below_offsets)
105-
acc = exclusive_cumsum!(equal_offsets, acc)
10688
acc = exclusive_cumsum!(above_offsets, acc)
10789
@check acc == length(xs)
10890

109-
@inline function singleton_chunkid(i)
110-
nb = @inbounds get(below_offsets, i + 1, equal_offsets[1]) - below_offsets[i]
111-
ne = @inbounds get(equal_offsets, i + 1, above_offsets[1]) - equal_offsets[i]
112-
na = @inbounds get(above_offsets, i + 1, length(ys)) - above_offsets[i]
113-
if (nb > 0) + (ne > 0) + (na > 0) == 1
114-
return 1 * (nb > 0) + 2 * (ne > 0) + 3 * (na > 0)
115-
else
116-
return 0
117-
end
118-
end
119-
120-
@sync begin
121-
for (i, (xs_chunk, cs_chunk)) in enumerate(zip(xs_chunk_list, cs_chunk_list))
122-
singleton_chunkid(i) > 0 && continue
123-
@spawn unsafe_quicksort_scatter!(
124-
ys,
125-
xs_chunk,
126-
cs_chunk,
127-
below_offsets[i],
128-
equal_offsets[i],
129-
above_offsets[i],
130-
)
131-
end
132-
for (i, xs_chunk) in enumerate(xs_chunk_list)
133-
sid = singleton_chunkid(i)
134-
sid > 0 || continue
135-
idx = (
136-
below_offsets[i]+1:get(below_offsets, i + 1, equal_offsets[1]),
137-
equal_offsets[i]+1:get(equal_offsets, i + 1, above_offsets[1]),
138-
above_offsets[i]+1:get(above_offsets, i + 1, length(ys)),
139-
)[sid]
140-
# There is only one partition. Short-circuit scattering.
141-
ys_chunk = view(ys, idx)
142-
copyto!(ys_chunk, xs_chunk)
143-
# Is it better to multi-thread this?
144-
end
91+
total_nbelows = above_offsets[1]
92+
total_nbelows == 0 && return sort!(ys, alg.smallsort, order)
93+
# TODO: Fallback to parallel mergesort? Scan the array to check degeneracy
94+
# and also to estimate a good pivot?
95+
96+
# (2) `quicksort_copyback!` -- Copy partitions back to the original
97+
# (destination) array `ys` in the natural order
98+
@sync for (i, (xs_chunk, below_offset, above_offset)) in
99+
enumerate(zip(xs_chunk_list, below_offsets, above_offsets))
100+
local nb = get(below_offsets, i + 1, total_nbelows) - below_offsets[i]
101+
@spawn quicksort_copyback!(ys, xs_chunk, nb, below_offset, above_offset)
145102
end
146103

147-
partitions = (1:equal_offsets[1], above_offsets[1]+1:length(xs))
104+
# (3) Recursively sort each partion
105+
below = 1:total_nbelows
106+
above = total_nbelows+1:length(xs)
148107
@sync begin
149-
for idx in partitions
150-
length(idx) <= alg.smallsize && continue
151-
ys_new = view(ys, idx)
152-
xs_new = view(xs, idx)
153-
cs_new = view(cs, idx)
154-
@spawn let zs
155-
if mutable_xs
156-
zs = xs_new
157-
else
158-
zs = similar(ys_new)
159-
end
160-
_quicksort!(zs, ys_new, alg, order, cs_new, !ys_is_result, true)
161-
end
162-
end
163-
for idx in partitions
164-
length(idx) <= alg.smallsize || continue
165-
if ys_is_result
166-
ys_new = view(ys, idx)
167-
else
168-
ys_new = copyto!(view(xs, idx), view(ys, idx))
169-
end
170-
sort!(ys_new, alg.smallsort, order)
171-
end
172-
if !ys_is_result
173-
let idx = equal_offsets[1]+1:above_offsets[1]
174-
copyto!(view(xs, idx), view(ys, idx))
175-
end
176-
end
108+
@spawn _quicksort!(view(ys, above), view(xs, above), alg, order)
109+
_quicksort!(view(ys, below), view(xs, below), alg, order)
177110
end
178111

179-
return ys_is_result ? ys : xs
112+
return ys
180113
end
181114

182-
function _quicksort_serial!(
183-
ys,
184-
xs,
185-
alg,
186-
order,
187-
cs = Vector{Int8}(undef, length(ys)),
188-
ys_is_result = true,
189-
mutable_xs = false,
190-
)
115+
function _quicksort_serial!(ys, xs, alg, order)
191116
# @check length(ys) == length(xs)
192117
if length(ys) <= max(8, alg.smallsize)
193-
if ys_is_result
194-
zs = copyto!(ys, xs)
195-
else
196-
zs = xs
197-
end
198-
return sort!(zs, alg.smallsort, order)
118+
return sort!(ys, alg.smallsort, order)
199119
end
200-
pivot = _median(
201-
order,
202-
(
203-
xs[1],
204-
xs[end÷8],
205-
xs[end÷4],
206-
xs[3*(end÷8)],
207-
xs[end÷2],
208-
xs[5*(end÷8)],
209-
xs[3*(end÷4)],
210-
xs[7*(end÷8)],
211-
xs[end],
212-
),
213-
)
120+
pivot = choose_pivot(ys, order)
121+
122+
nbelows, naboves = quicksort_partition!(xs, ys, pivot, order)
123+
@DBG @check nbelows + naboves == length(xs)
124+
nbelows == 0 && return sort!(ys, alg.smallsort, order)
214125

215-
(nbelows, nequals) = partition_sizes!(xs, cs, pivot, order)
216-
if nequals == length(xs)
217-
if ys_is_result
218-
copyto!(ys, xs)
219-
return ys
220-
else
221-
return xs
222-
end
223-
end
224-
@assert nequals > 0
225126
below_offset = 0
226-
equal_offset = nbelows
227-
above_offset = nbelows + nequals
228-
unsafe_quicksort_scatter!(ys, xs, cs, below_offset, equal_offset, above_offset)
127+
above_offset = nbelows
128+
quicksort_copyback!(ys, xs, nbelows, below_offset, above_offset)
229129

230-
below = 1:equal_offset
130+
below = 1:above_offset
231131
above = above_offset+1:length(xs)
232-
ya = view(ys, above)
233-
yb = view(ys, below)
234-
ca = view(cs, above)
235-
cb = view(cs, below)
236-
if mutable_xs
237-
_quicksort_serial!(view(xs, above), ya, alg, order, ca, !ys_is_result, true)
238-
_quicksort_serial!(view(xs, below), yb, alg, order, cb, !ys_is_result, true)
239-
else
240-
let zs = similar(ys)
241-
_quicksort_serial!(view(zs, above), ya, alg, order, ca, !ys_is_result, true)
242-
_quicksort_serial!(view(zs, below), yb, alg, order, cb, !ys_is_result, true)
243-
end
244-
end
245-
if !ys_is_result
246-
let idx = equal_offset+1:above_offset
247-
copyto!(view(xs, idx), view(ys, idx))
248-
end
249-
end
250-
251-
return ys_is_result ? ys : xs
252-
end
132+
_quicksort_serial!(view(xs, below), view(ys, below), alg, order)
133+
_quicksort_serial!(view(xs, above), view(ys, above), alg, order)
253134

254-
function partition_sizes!(nbelows, nequals, naboves, xs, cs, pivot, order)
255-
(nb, ne) = partition_sizes!(xs, cs, pivot, order)
256-
nbelows[] = nb
257-
nequals[] = ne
258-
naboves[] = length(xs) - (nb + ne)
259-
return
135+
return ys
260136
end
261137

262-
function partition_sizes!(xs, cs, pivot, order)
263-
nbelows = 0
264-
nequals = 0
265-
@inbounds for i in eachindex(xs, cs)
266-
x = xs[i]
138+
function quicksort_partition!(xs, ys, pivot, order)
139+
_foldl((0, 0), Unroll{4}(eachindex(xs, ys))) do (nbelows, naboves), i
140+
@_inline_meta
141+
x = @inbounds ys[i]
267142
b = Base.lt(order, x, pivot)
268-
a = Base.lt(order, pivot, x)
269-
cs[i] = ifelse(b, -Int8(1), ifelse(a, Int8(1), Int8(0)))
270143
nbelows += Int(b)
271-
nequals += Int(!(a | b))
144+
naboves += Int(!b)
145+
@inbounds xs[ifelse(b, nbelows, end - naboves + 1)] = x
146+
return (nbelows, naboves)
272147
end
273-
return (nbelows, nequals)
274148
end
275149

276-
function unsafe_quicksort_scatter!(
277-
ys,
278-
xs_chunk,
279-
cs_chunk,
280-
below_offset,
281-
equal_offset,
282-
above_offset,
283-
)
284-
b = below_offset
285-
e = equal_offset
286-
a = above_offset
287-
_foldl((b, a, e), Unroll{4}(eachindex(xs_chunk, cs_chunk))) do (b, a, e), i
288-
@inbounds x = xs_chunk[i]
289-
@inbounds c = cs_chunk[i]
290-
is_equal = c == 0
291-
is_above = c > 0
292-
is_below = c < 0
293-
e += Int(is_equal)
294-
a += Int(is_above)
295-
b += Int(is_below)
296-
@inbounds ys[ifelse(is_equal, e, ifelse(is_above, a, b))] = x
297-
(b, a, e)
150+
function quicksort_copyback!(ys, xs_chunk, nbelows, below_offset, above_offset)
151+
copyto!(ys, below_offset + 1, xs_chunk, firstindex(xs_chunk), nbelows)
152+
@simd ivdep for i in 1:length(xs_chunk)-nbelows
153+
@inbounds ys[above_offset+i] = xs_chunk[end-i+1]
298154
end
299-
return
300155
end

src/utils.jl

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
@static if VERSION >= v"1.8"
2+
@eval const $(Symbol("@_inline_meta")) = $(Symbol("@inline"))
3+
else
4+
using Base: @_inline_meta
5+
end
6+
17
function adhoc_partition(xs, n)
28
@check firstindex(xs) == 1
39
m = cld(length(xs), n)

0 commit comments

Comments
 (0)