mcmayer

Reputation: 2023

Remove bounds checking in Rust loop in attempt to reach optimal compiler output

To determine whether I can (or should) use Rust instead of the default C/C++, I'm looking into various edge cases, mostly with this question in mind: in the 0.1% of cases where it does matter, can I always get compiler output as good as gcc's (with the appropriate optimization flags)? The answer is most likely no, but let's see...

On Reddit there is a rather idiosyncratic example that studies the compiler output of a subroutine for a branchless sort algorithm.

Here is the benchmark C code:

#include <stdint.h>
#include <stdlib.h>
int32_t* foo(int32_t* elements, int32_t* buffer, int32_t pivot)
{
    size_t buffer_index = 0;

    for (size_t i = 0; i < 64; ++i) {
        buffer[buffer_index] = (int32_t)i;
        buffer_index += (size_t)(elements[i] < pivot);
    }
}

And here is the godbolt link with compiler output.

The first attempt with Rust looks like this:

pub fn foo0(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        buffer[buffer_index] = i as i32;
        buffer_index += (elements[i] < pivot) as usize; 
    }
}

There's quite a bit of bounds checking going on, see godbolt.

The next attempt eliminates one of the bounds checks, the one on elements:

pub unsafe fn foo1(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        unsafe {
            buffer[buffer_index] = i as i32;
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

That's a little better (see the same godbolt link as above).

Finally, let's try to remove the bounds checks altogether:

use std::ptr;

pub unsafe fn foo2(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) -> () {
    let mut buffer_index: usize = 0;
    unsafe {
        for i in 0..buffer.len() {
            ptr::replace(&mut buffer[buffer_index], i as i32);
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

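This produces the same output as foo1, so the write through ptr::replace is still bounds-checked. I'm certainly out of my depth here with those unsafe operations. That leads to my two questions:

  1. How can the bounds check be eliminated?
  2. Does it even make sense to analyze edge cases like this? Or would the Rust compiler see through all this if presented with the whole algorithm instead of only a small fraction thereof?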

Regarding the last point, I'm curious, in general, whether Rust can be butchered to the point where it is as "literal", i.e. close to the metal, as C is. Seasoned Rust programmers will probably cringe at this line of investigation, but here it is...

Upvotes: 11

Views: 6409

Answers (2)

Angelicos Phosphoros

Reputation: 3057

You can achieve this using old-school pointer arithmetic.

const N: usize = 64;
pub fn foo2(elements: &Vec<i32>, mut buffer: [i32; N], pivot: i32) -> () {
    assert!(elements.len() >= N);
    let elements = &elements[..N];
    let mut buff_ptr = buffer.as_mut_ptr();
    for (i, &elem) in elements.iter().enumerate(){
        unsafe{
            // SAFETY: the pointer is advanced at most once per iteration, after the
            // write, so every write lands inside the N-element buffer.
            *buff_ptr = i as i32;
            if elem < pivot{
                buff_ptr = buff_ptr.add(1);
            }
        }
    }
}

This version compiles into:

example::foo2:
        push    rax
        cmp     qword ptr [rdi + 16], 64
        jb      .LBB7_4
        mov     r9, qword ptr [rdi]
        lea     r8, [r9 + 256]
        xor     edi, edi

        // Loop goes here
.LBB7_2:
        mov     ecx, dword ptr [r9 + 4*rdi]
        mov     dword ptr [rsi], edi
        lea     rax, [rsi + 4]
        cmp     ecx, edx
        cmovge  rax, rsi
        mov     ecx, dword ptr [r9 + 4*rdi + 4]
        lea     esi, [rdi + 1]
        mov     dword ptr [rax], esi
        lea     rsi, [rax + 4]
        cmp     ecx, edx
        cmovge  rsi, rax
        mov     eax, dword ptr [r9 + 4*rdi + 8]
        lea     ecx, [rdi + 2]
        mov     dword ptr [rsi], ecx
        lea     rcx, [rsi + 4]
        cmp     eax, edx
        cmovge  rcx, rsi
        mov     r10d, dword ptr [r9 + 4*rdi + 12]
        lea     esi, [rdi + 3]
        lea     rax, [r9 + 4*rdi + 16]
        add     rdi, 4
        mov     dword ptr [rcx], esi
        lea     rsi, [rcx + 4]
        cmp     r10d, edx
        cmovge  rsi, rcx
        // Conditional branch to the loop beginning
        cmp     rax, r8
        jne     .LBB7_2
        pop     rax
        ret
.LBB7_4:
        call    std::panicking::begin_panic
        ud2

As you can see, the loop is unrolled and the only branch is the jump back to the start of the loop.

However, I am surprised that this function is not eliminated, since it has no effects: it should compile to a simple no-op. It would probably be reduced to one after inlining.
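
To make the writes observable, one option is to borrow the buffer mutably and return how many slots were filled. The following is only a minimal sketch under that assumption; the name `partition_indices` and the `usize` return value are hypothetical and not part of the original code.

```rust
const N: usize = 64;

/// Writes the indices of all elements smaller than `pivot` to the front of
/// `buffer` and returns how many were written.
/// (`partition_indices` is a hypothetical name used for illustration.)
pub fn partition_indices(elements: &[i32], buffer: &mut [i32; N], pivot: i32) -> usize {
    // One up-front length check replaces the per-iteration checks on `elements`.
    assert!(elements.len() >= N);
    let elements = &elements[..N];
    let start = buffer.as_mut_ptr();
    let mut ptr = start;
    for (i, &elem) in elements.iter().enumerate() {
        unsafe {
            // SAFETY: `ptr` advances at most once per iteration, after the write,
            // so at iteration `i` it is at most `i <= N - 1` elements past `start`.
            *ptr = i as i32;
            if elem < pivot {
                ptr = ptr.add(1);
            }
        }
    }
    // SAFETY: both pointers derive from `buffer`, and `ptr` is never before `start`.
    unsafe { ptr.offset_from(start) as usize }
}
```

Because the caller can now read `buffer[..count]`, the compiler has to keep the stores even after inlining.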

Also, note that changing the buffer parameter to &mut doesn't change the generated code:

example::foo2:
        push    rax
        cmp     qword ptr [rdi + 16], 64
        jb      .LBB7_4
        mov     r9, qword ptr [rdi]
        lea     r8, [r9 + 256]
        xor     edi, edi
.LBB7_2:
        mov     ecx, dword ptr [r9 + 4*rdi]
        mov     dword ptr [rsi], edi
        lea     rax, [rsi + 4]
        cmp     ecx, edx
        cmovge  rax, rsi
        mov     ecx, dword ptr [r9 + 4*rdi + 4]
        lea     esi, [rdi + 1]
        mov     dword ptr [rax], esi
        lea     rsi, [rax + 4]
        cmp     ecx, edx
        cmovge  rsi, rax
        mov     eax, dword ptr [r9 + 4*rdi + 8]
        lea     ecx, [rdi + 2]
        mov     dword ptr [rsi], ecx
        lea     rcx, [rsi + 4]
        cmp     eax, edx
        cmovge  rcx, rsi
        mov     r10d, dword ptr [r9 + 4*rdi + 12]
        lea     esi, [rdi + 3]
        lea     rax, [r9 + 4*rdi + 16]
        add     rdi, 4
        mov     dword ptr [rcx], esi
        lea     rsi, [rcx + 4]
        cmp     r10d, edx
        cmovge  rsi, rcx
        cmp     rax, r8
        jne     .LBB7_2
        pop     rax
        ret
.LBB7_4:
        call    std::panicking::begin_panic
        ud2

So rustc probably passes the buffer parameter as a pointer in the LLVM IR it emits, unfortunately.

Upvotes: 2

E_net4

Reputation: 29972

  • How can the bounds check be eliminated?

Arrays, via their coercion to a slice, have an unchecked form of mutable get too: get_unchecked_mut.

pub unsafe fn foo(elements: &Vec<i32>, mut buffer: [i32; 64], pivot: i32) {
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        unsafe {
            *buffer.get_unchecked_mut(buffer_index) = i as i32;
            buffer_index += (elements.get_unchecked(i) < &pivot) as usize; 
        }
    }
}

This may result in the same machine code as the one obtained by compiling the equivalent C code with Clang. https://godbolt.org/z/ddxP1P
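
The function above is `unsafe` because nothing guarantees that `elements` holds at least 64 values. A minimal sketch of a safe wrapper, assuming a single up-front length check is acceptable, could look like this. The name `foo_checked_once`, the `&mut` buffer and the returned count are assumptions for illustration.

```rust
/// Same loop as above, but callable from safe code: one assert up front
/// establishes the invariant that the unchecked accesses rely on.
/// (`foo_checked_once` is a hypothetical name.)
pub fn foo_checked_once(elements: &[i32], buffer: &mut [i32; 64], pivot: i32) -> usize {
    assert!(elements.len() >= buffer.len(), "need at least 64 elements");
    let mut buffer_index: usize = 0;
    for i in 0..buffer.len() {
        // SAFETY: `i < 64 <= elements.len()` by the assert above, and
        // `buffer_index <= i < 64` because it grows by at most 1 per iteration.
        unsafe {
            *buffer.get_unchecked_mut(buffer_index) = i as i32;
            buffer_index += (*elements.get_unchecked(i) < pivot) as usize;
        }
    }
    buffer_index
}
```

The single `assert!` is outside the loop, so it should not affect the loop body itself; whether the output still matches the C version would need to be checked on godbolt.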

  • Does it even make sense to analyze edge cases like this? Or would the Rust compiler see through all this if presented with the whole algorithm instead of only a small fraction thereof?

As always, benchmark any of these situations in the event that you have identified a bottleneck in that part of the code. Otherwise, it is a premature optimization that could one day be regretted. Particularly in Rust, the decision to write unsafe code should not be taken lightly. It is safe to say that in many cases, the effort and risk of removing bounds checks alone outweigh the intended performance benefits.
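
As a rough illustration of what such a measurement could look like, here is a minimal timing sketch using only the standard library. The helper `time_it` and the safe baseline `partition_safe` are hypothetical names; a dedicated benchmarking harness would give more trustworthy numbers than this.

```rust
use std::hint::black_box;
use std::time::{Duration, Instant};

/// Safe baseline with the bounds checks left in place.
/// Panics if `elements` has fewer than 64 values.
pub fn partition_safe(elements: &[i32], buffer: &mut [i32; 64], pivot: i32) -> usize {
    let mut buffer_index = 0;
    for i in 0..buffer.len() {
        buffer[buffer_index] = i as i32;
        buffer_index += (elements[i] < pivot) as usize;
    }
    buffer_index
}

/// Runs `f` `iterations` times and returns the total elapsed time together
/// with the last result. `black_box` keeps the call from being optimized away.
pub fn time_it<F: FnMut() -> usize>(iterations: u32, mut f: F) -> (Duration, usize) {
    let start = Instant::now();
    let mut last = 0;
    for _ in 0..iterations {
        last = black_box(f());
    }
    (start.elapsed(), last)
}
```

Calling `time_it(1_000_000, || partition_safe(&elements[..], &mut buffer, pivot))` once for the safe version and once for an unchecked one, in a release build, gives a first impression of whether the difference is measurable at all.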

Regarding the last point, I'm curious, in general, whether Rust can be butchered to the point where it is as "literal", i.e. close to the metal, as C is.

No, and you wouldn't want this for two major reasons:

  1. Despite the power of abstractions in Rust, the principle of not paying for what you don't use is still very pertinent, in a similar fashion to C++. See what makes an abstraction zero-cost. In the case of bounds checking, this is merely a consequence of a language design decision to always perform spatial checks when the compiler cannot ensure that such an access is memory safe.
  2. C is not that low-level anyway. It may seem literal and close to the metal until it really isn't.
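
To illustrate the first point: where the access pattern lets safe code avoid indexing altogether, there is no index left to check. In the sketch below (the name `partition_iter` and the array-typed parameters are assumptions), `elements` is traversed with an iterator, so only the data-dependent write into `buffer` remains a checked access.

```rust
/// Safe variant: `elements` is read through an iterator, so no index into it
/// is ever formed. (`partition_iter` is a hypothetical name.)
pub fn partition_iter(elements: &[i32; 64], buffer: &mut [i32; 64], pivot: i32) -> usize {
    let mut buffer_index = 0;
    for (i, &elem) in elements.iter().enumerate() {
        // Still a checked write: the compiler may not be able to prove
        // `buffer_index < 64`, since it depends on the data.
        buffer[buffer_index] = i as i32;
        buffer_index += (elem < pivot) as usize;
    }
    buffer_index
}
```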

Upvotes: 4
