Fewer Radix Passes: Tuning CUB Sort for GPU TopK with Tiered Bit Dispatch
We were optimizing a GPU TopK indexer kernel that selects the top-2048 indices from scored sequences. Profiling showed torch.topk consumed 77% of compute for large sequences (N=16384), making it the dominant bottleneck.
The idea: CUB radix sort with fewer passes
CUB’s DeviceRadixSort::SortPairsDescending supports begin_bit and end_bit parameters that control which bits get sorted. IEEE 754 float32 has sign (bit 31), exponent (bits 23-30), and mantissa (bits 0-22). By setting begin_bit=16, we sort only the upper 16 bits — sign, exponent, and top 7 mantissa bits — cutting from 4 radix passes to 2.
// 2 passes instead of 4: sort only upper 16 bits
cub::DeviceRadixSort::SortPairsDescending(
temp, temp_bytes,
scores_in, scores_out,
indices_in, indices_out,
N, /*begin_bit=*/16, /*end_bit=*/32, stream
);
The precision trap
This worked perfectly for N >= 4096 (128/128 correct). But lowering the threshold to N=2048 broke 7 workloads — scores near the top-K boundary differed only in the lower 16 mantissa bits, causing the sort to produce different orderings than torch.topk.
We tested begin_bit=8 (3 passes) at the same threshold — the exact same 7 workloads failed. This confirmed the issue was precision-dependent, not pass-count-dependent: those workloads had near-identical scores at the selection boundary where every mantissa bit matters.
The fix: tiered dispatch
Instead of one threshold, we dispatch based on sequence length:
static constexpr int CUB_THRESHOLD_LOW = 2048;
static constexpr int CUB_THRESHOLD_HIGH = 4096;
if (N >= CUB_THRESHOLD_HIGH) {
// Large N: 2 passes (begin_bit=16), scores well-separated
cub_sort(scores, N, /*begin_bit=*/16, /*end_bit=*/32);
} else if (N >= CUB_THRESHOLD_LOW) {
// Medium N: 4 passes (begin_bit=0), full precision needed
cub_sort(scores, N, /*begin_bit=*/0, /*end_bit=*/32);
} else {
// Small N: torch.topk is faster
torch_topk(scores, N, K);
}
Full precision CUB (4 passes) at N=2048-4095 is still faster than torch.topk, while 2-pass CUB handles the bulk of large workloads. Combined with earlier optimizations (vectorized FP8 dequant, fused ReLU+multiply CUDA kernel, C++ ATen pipeline), this pushed us from ~5.9x to ~6.9x average speedup across 128 workloads.
Takeaway
Radix sort’s begin_bit parameter is a powerful but dangerous knob — it trades mantissa precision for throughput. For TopK selection, the safe strategy is tiered dispatch: fewer passes where scores are well-separated (large N), full precision where boundary scores are tight (small N).