diff --git a/library/core/src/ptr/mod.rs b/library/core/src/ptr/mod.rs
index 2fcff3dfc5c44..509854225c4f3 100644
--- a/library/core/src/ptr/mod.rs
+++ b/library/core/src/ptr/mod.rs
@@ -1557,11 +1557,10 @@ pub(crate) unsafe fn align_offset<T>(p: *const T, a: usize) -> usize {
     // FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
     // 1, where the method versions of these operations are not inlined.
     use intrinsics::{
-        unchecked_shl, unchecked_shr, unchecked_sub, wrapping_add, wrapping_mul, wrapping_sub,
+        cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
+        wrapping_add, wrapping_mul, wrapping_sub,
     };
 
-    let addr = p.addr();
-
     /// Calculate multiplicative modular inverse of `x` modulo `m`.
     ///
     /// This implementation is tailored for `align_offset` and has following preconditions:
@@ -1611,36 +1610,61 @@ pub(crate) unsafe fn align_offset<T>(p: *const T, a: usize) -> usize {
         }
     }
 
+    let addr = p.addr();
     let stride = mem::size_of::<T>();
 
     // SAFETY: `a` is a power-of-two, therefore non-zero.
     let a_minus_one = unsafe { unchecked_sub(a, 1) };
-    if stride == 1 {
-        // `stride == 1` case can be computed more simply through `-p (mod a)`, but doing so
-        // inhibits LLVM's ability to select instructions like `lea`. Instead we compute
+
+    if stride == 0 {
+        // SPECIAL_CASE: handle 0-sized types. No matter how many times we step, the address will
+        // stay the same, so no offset will be able to align the pointer unless it is already
+        // aligned. This branch _will_ be optimized out as `stride` is known at compile-time.
+        let p_mod_a = addr & a_minus_one;
+        return if p_mod_a == 0 { 0 } else { usize::MAX };
+    }
+
+    // SAFETY: `stride == 0` case has been handled by the special case above.
+    let a_mod_stride = unsafe { unchecked_rem(a, stride) };
+    if a_mod_stride == 0 {
+        // SPECIAL_CASE: In cases where `a` is divisible by `stride`, the byte offset to align a
+        // pointer can be computed more simply through `-p (mod a)`. In the off-chance the byte
+        // offset is not a multiple of `stride`, the input pointer was misaligned and no pointer
+        // offset will be able to produce a `p` aligned to the specified `a`.
         //
-        //    round_up_to_next_alignment(p, a) - p
+        // The naive `-p (mod a)` equation inhibits LLVM's ability to select instructions
+        // like `lea`. We compute `(round_up_to_next_alignment(p, a) - p)` instead. This
+        // redistributes operations around the load-bearing, but pessimizing `and` instruction
+        // sufficiently for LLVM to be able to utilize the various optimizations it knows about.
        //
-        // which distributes operations around the load-bearing, but pessimizing `and` sufficiently
-        // for LLVM to be able to utilize the various optimizations it knows about.
-        return wrapping_sub(wrapping_add(addr, a_minus_one) & wrapping_sub(0, a), addr);
-    }
+        // LLVM handles the branch here particularly nicely. If this branch needs to be evaluated
+        // at runtime, it will produce a mask `if addr_mod_stride == 0 { 0 } else { usize::MAX }`
+        // in a branch-free way and then bitwise-OR it with whatever result the `-p mod a`
+        // computation produces.
+
+        // SAFETY: `stride == 0` case has been handled by the special case above.
+        let addr_mod_stride = unsafe { unchecked_rem(addr, stride) };
 
-    let pmoda = addr & a_minus_one;
-    if pmoda == 0 {
-        // Already aligned. Yay!
-        return 0;
-    } else if stride == 0 {
-        // If the pointer is not aligned, and the element is zero-sized, then no amount of
-        // elements will ever align the pointer.
-        return usize::MAX;
+        return if addr_mod_stride == 0 {
+            let aligned_address = wrapping_add(addr, a_minus_one) & wrapping_sub(0, a);
+            let byte_offset = wrapping_sub(aligned_address, addr);
+            // SAFETY: `stride` is non-zero. This is guaranteed to divide exactly as well, because
+            // addr has been verified to be aligned to the original type’s alignment requirements.
+            unsafe { exact_div(byte_offset, stride) }
+        } else {
+            usize::MAX
+        };
     }
 
-    let smoda = stride & a_minus_one;
+    // GENERAL_CASE: From here on we’re handling the very general case where `addr` may be
+    // misaligned, there isn’t an obvious relationship between `stride` and `a` that we can take
+    // advantage of, etc. This case produces machine code that isn’t particularly high quality,
+    // compared to the special cases above. The code produced here is still within the realm of
+    // miracles, given the situations this case has to deal with.
+
     // SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
-    let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
+    let gcdpow = unsafe { cttz_nonzero(stride).min(cttz_nonzero(a)) };
     // SAFETY: gcdpow has an upper-bound that’s at most the number of bits in a usize.
     let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
-    // SAFETY: gcd is always greater or equal to 1.
     if addr & unsafe { unchecked_sub(gcd, 1) } == 0 {
         // This branch solves for the following linear congruence equation:
@@ -1656,14 +1680,13 @@ pub(crate) unsafe fn align_offset<T>(p: *const T, a: usize) -> usize {
         // ` p' + s'o = 0 mod a' `
         // ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
         //
-        // The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the second
-        // term is "how does incrementing `p` by `s` bytes change the relative alignment of `p`" (again
-        // divided by `g`).
-        // Division by `g` is necessary to make the inverse well formed if `a` and `s` are not
-        // co-prime.
+        // The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the
+        // second term is "how does incrementing `p` by `s` bytes change the relative alignment of
+        // `p`" (again divided by `g`). Division by `g` is necessary to make the inverse well
+        // formed if `a` and `s` are not co-prime.
        //
        // Furthermore, the result produced by this solution is not "minimal", so it is necessary
-        // to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
+        // to take the result `o mod lcm(s, a)`. This `lcm(s, a)` is the same as `a'`.
 
        // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
        // `a`.
@@ -1673,11 +1696,11 @@ pub(crate) unsafe fn align_offset<T>(p: *const T, a: usize) -> usize {
         let a2minus1 = unsafe { unchecked_sub(a2, 1) };
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`.
-        let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
+        let s2 = unsafe { unchecked_shr(stride & a_minus_one, gcdpow) };
         // SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
         // `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
         // always be strictly greater than `(p % a) >> gcdpow`.
-        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
+        let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(addr & a_minus_one, gcdpow)) };
         // SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
         // because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
         return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
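
The `a_mod_stride == 0` fast path above is easier to see with ordinary operators. The following standalone sketch is illustrative only (the function name is hypothetical, and this is not the library code): it computes `(round_up_to_next_alignment(p, a) - p) / stride` exactly as the special case describes, and cross-checks the result against a brute-force search.

fn align_offset_divisible(addr: usize, stride: usize, a: usize) -> usize {
    // Preconditions mirrored from the special case: `a` is a power of two,
    // `stride` is non-zero and divides `a`.
    debug_assert!(a.is_power_of_two() && stride != 0 && a % stride == 0);
    let a_minus_one = a - 1;
    if addr % stride == 0 {
        // `-addr (mod a)`, written as `round_up_to_next_alignment(addr, a) - addr`
        // so that LLVM can lower it with `lea`/`and`/`sub`.
        let aligned_address = addr.wrapping_add(a_minus_one) & 0usize.wrapping_sub(a);
        let byte_offset = aligned_address.wrapping_sub(addr);
        // Divides exactly: both `addr` and `aligned_address` are multiples of `stride`.
        byte_offset / stride
    } else {
        // A misaligned `addr` can never be fixed up by whole-element steps.
        usize::MAX
    }
}

fn main() {
    let (stride, a) = (4usize, 32usize);
    for addr in 1usize..128 {
        // Brute force: the smallest element offset that aligns `addr`, if any.
        let expected = (0..a).find(|&o| (addr + o * stride) % a == 0).unwrap_or(usize::MAX);
        assert_eq!(align_offset_divisible(addr, stride, a), expected);
    }
}
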
diff --git a/library/core/tests/ptr.rs b/library/core/tests/ptr.rs
index bab2b1792f680..12861794c2d2c 100644
--- a/library/core/tests/ptr.rs
+++ b/library/core/tests/ptr.rs
@@ -359,7 +359,7 @@ fn align_offset_zst() {
 }
 
 #[test]
-fn align_offset_stride1() {
+fn align_offset_stride_one() {
     // For pointers of stride = 1, the pointer can always be aligned. The offset is equal to
     // number of bytes.
     let mut align = 1;
@@ -380,24 +380,8 @@ fn align_offset_stride1() {
 }
 
 #[test]
-fn align_offset_weird_strides() {
-    #[repr(packed)]
-    struct A3(u16, u8);
-    struct A4(u32);
-    #[repr(packed)]
-    struct A5(u32, u8);
-    #[repr(packed)]
-    struct A6(u32, u16);
-    #[repr(packed)]
-    struct A7(u32, u16, u8);
-    #[repr(packed)]
-    struct A8(u32, u32);
-    #[repr(packed)]
-    struct A9(u32, u32, u8);
-    #[repr(packed)]
-    struct A10(u32, u32, u16);
-
-    unsafe fn test_weird_stride<T>(ptr: *const T, align: usize) -> bool {
+fn align_offset_various_strides() {
+    unsafe fn test_stride<T>(ptr: *const T, align: usize) -> bool {
         let numptr = ptr as usize;
         let mut expected = usize::MAX;
         // Naive but definitely correct way to find the *first* aligned element of stride::<T>.
@@ -431,14 +415,39 @@ fn align_offset_weird_strides() {
     while align < limit {
         for ptr in 1usize..4 * align {
             unsafe {
-                x |= test_weird_stride::<A3>(ptr::invalid::<A3>(ptr), align);
-                x |= test_weird_stride::<A4>(ptr::invalid::<A4>(ptr), align);
-                x |= test_weird_stride::<A5>(ptr::invalid::<A5>(ptr), align);
-                x |= test_weird_stride::<A6>(ptr::invalid::<A6>(ptr), align);
-                x |= test_weird_stride::<A7>(ptr::invalid::<A7>(ptr), align);
-                x |= test_weird_stride::<A8>(ptr::invalid::<A8>(ptr), align);
-                x |= test_weird_stride::<A9>(ptr::invalid::<A9>(ptr), align);
-                x |= test_weird_stride::<A10>(ptr::invalid::<A10>(ptr), align);
+                #[repr(packed)]
+                struct A3(u16, u8);
+                x |= test_stride::<A3>(ptr::invalid::<A3>(ptr), align);
+
+                struct A4(u32);
+                x |= test_stride::<A4>(ptr::invalid::<A4>(ptr), align);
+
+                #[repr(packed)]
+                struct A5(u32, u8);
+                x |= test_stride::<A5>(ptr::invalid::<A5>(ptr), align);
+
+                #[repr(packed)]
+                struct A6(u32, u16);
+                x |= test_stride::<A6>(ptr::invalid::<A6>(ptr), align);
+
+                #[repr(packed)]
+                struct A7(u32, u16, u8);
+                x |= test_stride::<A7>(ptr::invalid::<A7>(ptr), align);
+
+                #[repr(packed)]
+                struct A8(u32, u32);
+                x |= test_stride::<A8>(ptr::invalid::<A8>(ptr), align);
+
+                #[repr(packed)]
+                struct A9(u32, u32, u8);
+                x |= test_stride::<A9>(ptr::invalid::<A9>(ptr), align);
+
+                #[repr(packed)]
+                struct A10(u32, u32, u16);
+                x |= test_stride::<A10>(ptr::invalid::<A10>(ptr), align);
+
+                x |= test_stride::<u32>(ptr::invalid::<u32>(ptr), align);
+                x |= test_stride::<u128>(ptr::invalid::<u128>(ptr), align);
             }
         }
         align = (align + 1).next_power_of_two();
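
The GENERAL_CASE congruence that these tests exercise can be prototyped with plain operators and a naive modular inverse. This is an illustrative sketch only: the names are hypothetical, and the real code uses the Hensel-lifting `mod_inv` plus unchecked intrinsics rather than a linear search. It is checked against a brute-force search, much like `test_stride` does.

// Naive inverse of an odd `x` modulo a power of two `m`.
fn naive_mod_inv(x: usize, m: usize) -> usize {
    (1..m).find(|&i| (x * i) % m == 1).expect("x must be odd")
}

// Solve `addr + stride * o == 0 (mod a)` for the smallest `o`, assuming `a` is a
// power of two and `stride != 0`, with `g = 2^min(cttz(stride), cttz(a))`,
// `a' = a / g`, `s' = (stride mod a) / g`, `p' = (addr mod a) / g`.
fn general_case(addr: usize, stride: usize, a: usize) -> usize {
    let gcdpow = stride.trailing_zeros().min(a.trailing_zeros());
    let g = 1usize << gcdpow;
    if addr & (g - 1) != 0 {
        return usize::MAX; // `g` does not divide `addr`: no solution exists.
    }
    let a2 = a >> gcdpow;
    if a2 <= 1 {
        return 0; // everything is congruent modulo `a' == 1`
    }
    let s2 = (stride & (a - 1)) >> gcdpow;
    let minusp2 = a2 - ((addr & (a - 1)) >> gcdpow);
    minusp2.wrapping_mul(naive_mod_inv(s2, a2)) & (a2 - 1)
}

fn main() {
    for a in [1usize, 2, 4, 8, 16, 32] {
        for stride in 1usize..20 {
            for addr in 1usize..40 {
                let expected =
                    (0..a).find(|&o| (addr + o * stride) % a == 0).unwrap_or(usize::MAX);
                assert_eq!(general_case(addr, stride, a), expected, "addr={addr} stride={stride} a={a}");
            }
        }
    }
}
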
diff --git a/src/test/assembly/align_offset.rs b/src/test/assembly/align_offset.rs
new file mode 100644
index 0000000000000..c5eefca3467bb
--- /dev/null
+++ b/src/test/assembly/align_offset.rs
@@ -0,0 +1,48 @@
+// assembly-output: emit-asm
+// compile-flags: -Copt-level=1
+// only-x86_64
+// min-llvm-version: 14.0
+#![crate_type="rlib"]
+
+// CHECK-LABEL: align_offset_byte_ptr
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+#[no_mangle]
+pub fn align_offset_byte_ptr(ptr: *const u8) -> usize {
+    ptr.align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_byte_slice
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+#[no_mangle]
+pub fn align_offset_byte_slice(slice: &[u8]) -> usize {
+    slice.as_ptr().align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_word_ptr
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+// CHECK: shrq
+// This `ptr` is not known to be aligned, so it is required to check if it is at all possible to
+// align. LLVM applies a simple mask.
+// CHECK: orq
+#[no_mangle]
+pub fn align_offset_word_ptr(ptr: *const u32) -> usize {
+    ptr.align_offset(32)
+}
+
+// CHECK-LABEL: align_offset_word_slice
+// CHECK: leaq 31
+// CHECK: andq $-32
+// CHECK: subq
+// CHECK: shrq
+// `slice` is known to be aligned, so `!0` is not possible as a return
+// CHECK-NOT: orq
+#[no_mangle]
+pub fn align_offset_word_slice(slice: &[u32]) -> usize {
+    slice.as_ptr().align_offset(32)
+}
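
At the source level, the behaviour the new assembly test pins down looks roughly like the following. This is an illustrative sketch and not part of the patch.

fn main() {
    // A byte pointer can always be aligned: the result is just `-addr mod 32`,
    // which is what the `leaq`/`andq`/`subq` sequence computes.
    let bytes = [0u8; 64];
    assert!(bytes.as_ptr().align_offset(32) < 32);

    // A deliberately odd `*const u32` address can never become 32-aligned by
    // stepping in whole `u32`s, so `align_offset` reports `usize::MAX`; the
    // branch-free `orq` mask in the generated code exists for exactly this case.
    let odd = 1usize as *const u32;
    assert_eq!(odd.align_offset(32), usize::MAX);

    // A `&[u32]`'s pointer is already 4-aligned, so a usable offset always
    // exists and the mask (and its `orq`) can be optimized away.
    let words = [0u32; 16];
    assert!(words.as_ptr().align_offset(32) < 8);
}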