Reputation: 614
In my project, I found that sorting performance is the bottleneck. After some googling, I came up with a parallel version of radix sort (base 256). However, it is not behaving as I expected.
First, changing the base to 2^16 doesn't give any speedup, even though in theory it should be about 2x faster, since it halves the number of passes (4 instead of 8 for 64-bit keys).
Second, in my parallel version I split the array into 4 parts (the number of cores), radix sort each part, and then merge the results. Again, it runs in almost the same time as the serial version.
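import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class RadixSortPrototype {

    // Parallel LSD radix sort, base 256 (8 passes of 8 bits over 64-bit keys).
    // Keys are ordered as unsigned values, so negative longs sort after positive ones.
    public static void parallelSort(long[] arr) {
        long[] output = new long[arr.length];
        int MAX_PART = 1_000_000;
        int numProc = Runtime.getRuntime().availableProcessors();
        // Contiguous part length: one part per core, capped at MAX_PART elements.
        int partL = Math.min((int) Math.ceil(arr.length / (double) numProc), MAX_PART);
        int parts = (int) Math.ceil(arr.length / (double) partL);
        Future<?>[] threads = new Future<?>[parts];
        ExecutorService worker = Executors.newFixedThreadPool(numProc);
        try {
            for (int i = 0; i < 8; i++) {
                int[][] counts = new int[parts][256];
                int radix = i;
                // 1) Each part builds its own histogram of the current byte.
                for (int j = 0; j < parts; j++) {
                    int part = j;
                    threads[j] = worker.submit(() -> {
                        for (int k = part * partL; k < (part + 1) * partL && k < arr.length; k++) {
                            int chunk = (int) ((arr[k] >> (radix * 8)) & 255);
                            counts[part][chunk]++;
                        }
                    });
                }
                barrier(parts, threads);
                // 2) Exclusive prefix sum, digit-major then part-major, so each part
                //    gets its own output offset per digit and the sort stays stable.
                int base = 0;
                for (int k = 0; k <= 255; k++) {
                    for (int j = 0; j < parts; j++) {
                        int t = counts[j][k];
                        counts[j][k] = base;
                        base += t;
                    }
                }
                // 3) Each part scatters its elements into output.
                for (int j = 0; j < parts; j++) {
                    int part = j;
                    threads[j] = worker.submit(() -> {
                        for (int k = part * partL; k < (part + 1) * partL && k < arr.length; k++) {
                            int chunk = (int) ((arr[k] >> (radix * 8)) & 255);
                            output[counts[part][chunk]] = arr[k];
                            counts[part][chunk]++;
                        }
                    });
                }
                barrier(parts, threads);
                // 4) Copy output back into arr for the next pass.
                for (int j = 0; j < parts; j++) {
                    int part = j;
                    threads[j] = worker.submit(() -> {
                        for (int k = part * partL; k < (part + 1) * partL && k < arr.length; k++) {
                            arr[k] = output[k];
                        }
                    });
                }
                barrier(parts, threads);
            }
        } finally {
            worker.shutdownNow();
        }
    }

    // Waits for all submitted tasks; rethrows failures instead of silently continuing.
    private static void barrier(int parts, Future<?>[] threads) {
        for (int j = 0; j < parts; j++) {
            try {
                threads[j].get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (ExecutionException e) {
                throw new IllegalStateException(e.getCause());
            }
        }
    }
}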
Any ideas why it is running so slow? What is the recommended way to tackle this optimization?
I'm really curious about the answer.
Thanks!
Update
Based on the answer, I improved data locality, so it now uses all the cores. I have updated the code snippet. Here are the results on a 2-core, 4-thread CPU:
Java Parallel: 1130 ms
Radixsort Serial: 1218 ms
Radixsort Parallel: 625 ms
The question remains whether it can be improved further.
Upvotes: 0
Views: 361
Reputation: 28826
Using base 2^16 = 65536 ends up a bit slower, even though it halves the number of passes, because the L1 data cache is typically 32768 bytes per core, while each base 2^16 count|index array of 32-bit ints uses 2^16 × 4 = 2^18 = 262144 bytes.
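To make the trade-off concrete, here is a small sketch that computes the number of passes and the size of one count array for a few digit widths. It assumes 32-bit int counts and a 32 KiB L1 data cache per core; the class and method names are hypothetical. Note that the parallel version in the question keeps one such array per part, so the total is larger still.

```java
public class RadixFootprint {

    // Number of LSD passes needed to cover 64-bit keys with digits of `bits` bits.
    public static int passes(int bits) {
        return (64 + bits - 1) / bits;
    }

    // Bytes used by one count/index array of 32-bit ints for a `bits`-bit digit.
    public static long countArrayBytes(int bits) {
        return (1L << bits) * Integer.BYTES;
    }

    public static void main(String[] args) {
        final long l1Bytes = 32 * 1024; // assumed 32 KiB L1 data cache per core
        for (int bits : new int[] {8, 11, 16}) {
            long bytes = countArrayBytes(bits);
            System.out.printf("%2d-bit digit: %d passes, %7d bytes per count array, fits in L1: %b%n",
                    bits, passes(bits), bytes, bytes <= l1Bytes);
        }
    }
}
```

With 8-bit digits the array is 1 KiB and 16-bit digits need 256 KiB, so halving the pass count is paid for with count-array cache misses on every increment.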
The issue with radix sort is that the reads are sequential, but the writes are as random as the data. Based on the comment, the program is sorting 20 million longs at 8 bytes each, so 160 MB of data, and assuming an 8 MB L3 cache, most of those writes are going to be cache misses. The parallel operations don't help much because most of the writes compete for the same 160 MB of uncached main memory.
To avoid this issue, I used an alternate implementation where the first pass does a most significant digit (MSD) radix sort to produce 256 bins (each bin contains integers with the same most significant byte). Then each bin is sorted using a conventional least significant digit (LSD) first radix sort. For reasonably uniform pseudo-random data, the 256 bins end up nearly equal in size, so the 160 MB is split into 256 bins of about 625000 bytes each. With 4 threads, 8 of these bins are in use at a time (4 being read, 4 being written), about 5 MB plus the count|index arrays, and all of this fits in the 8 MB, 16-way associative L3 cache shared by all 4 cores.
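A minimal Java sketch of this hybrid approach follows. It is not the answerer's C++ code; it assumes signed 64-bit keys, base 256, and a fixed thread pool, and the names `MsdLsdRadixSort`, `sort`, `lsdBin` and `topByte` are hypothetical. For simplicity the initial MSD pass is single-threaded; the per-bin LSD passes run as independent tasks on disjoint ranges, so each thread only touches its own slice of the two buffers.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MsdLsdRadixSort {

    // Sorts a in ascending (signed) order using `threads` worker threads.
    public static void sort(long[] a, int threads) {
        long[] tmp = new long[a.length];

        // MSD pass: histogram of the top byte, then scatter a -> tmp into 256 bins.
        int[] start = new int[257];
        for (long x : a) {
            start[topByte(x) + 1]++;
        }
        for (int b = 0; b < 256; b++) {
            start[b + 1] += start[b];
        }
        int[] next = Arrays.copyOf(start, 256);
        for (long x : a) {
            tmp[next[topByte(x)]++] = x;
        }

        // LSD passes: each bin is an independent task on the disjoint range [lo, hi).
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<?>> tasks = new ArrayList<>();
            for (int b = 0; b < 256; b++) {
                int lo = start[b];
                int hi = start[b + 1];
                tasks.add(pool.submit(() -> lsdBin(tmp, a, lo, hi)));
            }
            for (Future<?> f : tasks) {
                f.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    // LSD radix sort of tmp[lo, hi) on bytes 0..6, ping-ponging between tmp and a.
    // Seven passes is odd, so the sorted bin ends up in a[lo, hi).
    private static void lsdBin(long[] tmp, long[] a, int lo, int hi) {
        long[] src = tmp;
        long[] dst = a;
        int[] count = new int[256];
        for (int shift = 0; shift < 56; shift += 8) {
            Arrays.fill(count, 0);
            for (int i = lo; i < hi; i++) {
                count[(int) (src[i] >>> shift) & 0xFF]++;
            }
            int pos = lo; // exclusive prefix sum -> first output index per digit
            for (int d = 0; d < 256; d++) {
                int c = count[d];
                count[d] = pos;
                pos += c;
            }
            for (int i = lo; i < hi; i++) {
                dst[count[(int) (src[i] >>> shift) & 0xFF]++] = src[i];
            }
            long[] t = src;
            src = dst;
            dst = t;
        }
    }

    // Top byte with the sign bit flipped, so signed order matches unsigned byte order.
    private static int topByte(long x) {
        return (int) ((x ^ Long.MIN_VALUE) >>> 56);
    }
}
```

Usage would be something like `MsdLsdRadixSort.sort(data, Runtime.getRuntime().availableProcessors())`. With heavily skewed data, one bin can hold most of the keys, which limits both the cache benefit and the parallel speedup.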
For larger arrays, the initial pass could split up the array into 512 to 4096 or more bins.
I did some testing with some old C++ code I have for radix sorting pseudo-random 64-bit integers, using base 2^8 = 256. I tested 3 implementations: single-threaded least significant digit (LSD) first, single-threaded most significant digit (MSD) first, and 4-threaded MSD first. When the number of integers was a power of 2, it caused some cache conflicts, which affected the time in some cases.
16000000        - 8 bins + index arrays fit in the 8 MB L3 cache.
16777216 = 2^24 - 8 bins + index arrays fit in the 8 MB L3 cache.
30000000        - 8 bins + index arrays fit in the 8 MB L3 cache.
33554432 = 2^25 - 8 bins + index arrays are a bit larger than 8 MB.
36000000        - 8 bins + index arrays are a bit larger than 8 MB.
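The fit/no-fit pattern above follows from simple arithmetic: with 256 equal bins of 8-byte keys, 4 threads each reading one bin and writing one scratch bin touch 8 × n × 8 / 256 bytes at once. This sketch assumes perfectly uniform bins and an 8 MiB L3, and ignores the few KB of count arrays; the class name `BinFootprint` is hypothetical.

```java
public class BinFootprint {

    // Bytes touched at once by 4 threads, each reading one bin and writing one
    // scratch bin, assuming n 64-bit keys split into 256 equally sized bins.
    public static long eightBinBytes(long n) {
        long binBytes = n * Long.BYTES / 256;
        return 8 * binBytes;
    }

    public static void main(String[] args) {
        final long l3 = 8L << 20; // assumed 8 MiB shared L3 cache
        long[] counts = {16_000_000L, 16_777_216L, 30_000_000L, 33_554_432L, 36_000_000L};
        for (long n : counts) {
            long bytes = eightBinBytes(n);
            System.out.printf("%,d keys: 8 bins = %,d bytes (%.2f MiB), under 8 MiB: %b%n",
                    n, bytes, bytes / (double) (1 << 20), bytes < l3);
        }
    }
}
```

For 2^25 keys the 8 bins alone are exactly 8 MiB, so adding the index arrays pushes the working set just past the L3 size.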
Win 7 Pro 64-bit, VS 2015, Intel 3770K 3.5 GHz (times in seconds)
count       1 thread LSD   1 thread MSD   4 thread MSD
16000000    0.59           0.38           0.16
16777216    1.35           0.48           0.30
30000000    0.82           0.70           0.30
33554432    3.20           1.09           0.68
36000000    0.95           0.82           0.39
Win 10 Pro 64-bit, VS 2019, Intel 10510U 1.8 GHz to 4.9 GHz (times in seconds)
count       1 thread LSD   1 thread MSD   4 thread MSD
16000000    0.312          0.230          0.125
16777216    0.897          0.242          0.150
30000000    0.480          0.430          0.236
33554432    2.880          0.510          0.250
36000000    0.568          0.530          0.305
Upvotes: 1