We were optimizing a GPU TopK indexer kernel that selects the top-2048 indices from scored sequences. Profiling showed that torch.topk accounted for 77% of compute time at large sequence lengths (N=16384), making it the dominant bottleneck.

The idea: CUB radix sort with fewer passes

CUB’s DeviceRadixSort::SortPairsDescending accepts begin_bit and end_bit parameters that restrict the sort to key bits [begin_bit, end_bit). An IEEE 754 float32 has a sign bit (bit 31), an 8-bit exponent (bits 23-30), and a 23-bit mantissa (bits 0-22). Setting begin_bit=16 sorts only the upper 16 bits (the sign, the exponent, and the top 7 mantissa bits), which cuts the sort from 4 radix passes to 2, assuming 8 bits are consumed per pass.

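#include <cub/cub.cuh>

// 2 passes instead of 4: sort only the upper 16 bits (bits 16..31)
// temp/temp_bytes come from a prior sizing call with temp == nullptr
cub::DeviceRadixSort::SortPairsDescending(
    temp, temp_bytes,
    scores_in, scores_out,      // float keys
    indices_in, indices_out,    // values: candidate indices
    N, /*begin_bit=*/16, /*end_bit=*/32, stream
);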

The precision trap

With the threshold at N >= 4096, this passed all 128 workloads. But lowering the threshold to N=2048 broke 7 of them: near the top-K boundary, some scores differed only in bits 0-15 (the low 16 of the 23 mantissa bits). The truncated sort treats such scores as ties and, being stable, leaves them in input order, so its ordering diverged from torch.topk.

We also tried begin_bit=8 (3 passes) at the same threshold, and exactly the same 7 workloads failed. So the problem was the lost precision itself, not anything specific to the 2-pass configuration: those workloads had near-identical scores at the selection boundary, some agreeing in their top 24 bits and differing only in bits 0-7, so every mantissa bit matters.

The fix: tiered dispatch

Instead of one threshold, we dispatch based on sequence length:

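// cub_sort: wrapper around the SortPairsDescending call above on (score, index)
// pairs; the top-K indices are the first K entries of its sorted index output.
// torch_topk: fallback to torch.topk.
static constexpr int CUB_THRESHOLD_LOW  = 2048;
static constexpr int CUB_THRESHOLD_HIGH = 4096;

if (N >= CUB_THRESHOLD_HIGH) {
    // Large N: 2 passes (begin_bit=16); boundary scores well-separated in our workloads
    cub_sort(scores, N, /*begin_bit=*/16, /*end_bit=*/32);
} else if (N >= CUB_THRESHOLD_LOW) {
    // Medium N: 4 passes (begin_bit=0); full 32-bit precision needed
    cub_sort(scores, N, /*begin_bit=*/0, /*end_bit=*/32);
} else {
    // Small N: torch.topk is faster
    torch_topk(scores, N, K);
}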

In our measurements, full-precision CUB (4 passes) is still faster than torch.topk for N=2048-4095, while 2-pass CUB handles the bulk of the large workloads. Combined with earlier optimizations (vectorized FP8 dequant, a fused ReLU+multiply CUDA kernel, a C++ ATen pipeline), this raised the average speedup across 128 workloads from ~5.9x to ~6.9x.

Takeaway

Radix sort’s begin_bit parameter is a powerful but dangerous knob, because it trades mantissa precision for throughput. For TopK selection, the safe strategy is tiered dispatch: use fewer passes where boundary scores are well-separated (large N in our workloads), and full precision where they are tight (smaller N).