Optimize layout of function arguments in the Rust ABI - take 2 #97559

Closed · wants to merge 12 commits
49 changes: 35 additions & 14 deletions compiler/rustc_middle/src/ty/layout.rs
@@ -14,7 +14,8 @@ use rustc_session::{config::OptLevel, DataTypeKind, FieldInfo, SizeKind, Variant
use rustc_span::symbol::Symbol;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::call::{
ArgAbi, ArgAttribute, ArgAttributes, ArgExtension, Conv, FnAbi, PassMode, Reg, RegKind,
ArgAbi, ArgAttribute, ArgAttributes, ArgExtension, Conv, FnAbi, HomogeneousAggregate, PassMode,
Reg, RegKind,
};
use rustc_target::abi::*;
use rustc_target::spec::{abi::Abi as SpecAbi, HasTargetSpec, PanicStrategy, Target};
@@ -3360,7 +3361,38 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> {
}

match arg.layout.abi {
Abi::Aggregate { .. } => {}
Abi::Aggregate { .. } => {
// Pass and return structures up to 1 pointer in size by value.
let ptr_size = Pointer.size(self);
let max_by_val_size = ptr_size;
let size = arg.layout.size;

if arg.layout.is_unsized() || size > max_by_val_size {
Member: The old code used size > ptr_size. You changed it to size > ptr_size * 2.

Member (Author): Oops, you're right, I forgot about this. It explains the performance regressions.
I will revert this change. Thanks for catching it.

arg.make_indirect();
} else if let Ok(HomogeneousAggregate::Homogeneous(Reg {
kind: RegKind::Float,
..
})) = arg.layout.homogeneous_aggregate(self)
{
// We don't want to pass floats as an aggregate of integers
// because this hurts the generated assembly (#93490).
//
// As an optimization, we pass homogeneous aggregates of floats
// larger than a pointer indirectly.
if size > ptr_size {
arg.make_indirect();
}
} else {
// We want to pass small aggregates as immediates, but using
// a LLVM aggregate type for this leads to bad optimizations,
// so we pick an appropriately sized integer type instead.
//
// NOTE: This is sub-optimal because for (f32, f32, u32, u32) we could
// use ([f32; 2], u64), which is better, but this is the best we can
// do right now.
arg.cast_to(Reg { kind: RegKind::Integer, size });
}
}

// This is a fun case! The gist of what this is doing is
// that we want callers and callees to always agree on the
@@ -3386,20 +3418,9 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> {
&& self.tcx.sess.target.simd_types_indirect =>
{
arg.make_indirect();
return;
}

_ => return,
}

let size = arg.layout.size;
if arg.layout.is_unsized() || size > Pointer.size(self) {
arg.make_indirect();
} else {
// We want to pass small aggregates as immediates, but using
// a LLVM aggregate type for this leads to bad optimizations,
// so we pick an appropriately sized integer type instead.
arg.cast_to(Reg { kind: RegKind::Integer, size });
_ => {}
}
};
fixup(&mut fn_abi.ret);
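To make the new Abi::Aggregate arm easier to follow outside the diff context, here is a minimal, self-contained sketch of its decision tree, assuming an 8-byte pointer as on x86_64. The classify function, its parameters, and the returned strings are hypothetical illustrations for this discussion, not compiler APIs.

/// Hypothetical helper mirroring the new `Abi::Aggregate` arm above:
/// unsized or large aggregates go indirect, small homogeneous float
/// aggregates keep their float layout, and other small aggregates are
/// cast to a same-sized integer.
fn classify(size: u64, ptr_size: u64, is_unsized: bool, homogeneous_floats: bool) -> &'static str {
    let max_by_val_size = ptr_size;
    if is_unsized || size > max_by_val_size {
        "make_indirect"
    } else if homogeneous_floats {
        // Keeping the float layout lets LLVM keep the values in SSE registers (#93490).
        if size > ptr_size { "make_indirect" } else { "unchanged float aggregate" }
    } else {
        // Small mixed aggregates become an integer immediate of the same size.
        "cast to an integer of the same size"
    }
}

fn main() {
    let ptr = 8; // x86_64 pointer size in bytes
    assert_eq!(classify(8, ptr, false, true), "unchanged float aggregate");          // e.g. [f32; 2]
    assert_eq!(classify(8, ptr, false, false), "cast to an integer of the same size"); // e.g. (u32, u32)
    assert_eq!(classify(16, ptr, false, true), "make_indirect");                     // e.g. [f32; 4]
}

Note that with max_by_val_size equal to ptr_size (the threshold the review comments above settle on), the inner size > ptr_size check in the float branch is never reached with a true result; it only matters if the by-value threshold is raised above one pointer.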
48 changes: 48 additions & 0 deletions src/test/assembly/x86-64-homogenous-floats.rs
@@ -0,0 +1,48 @@
// assembly-output: emit-asm
// needs-llvm-components: x86
// compile-flags: --target x86_64-unknown-linux-gnu
// compile-flags: -C llvm-args=--x86-asm-syntax=intel
// compile-flags: -C opt-level=3

#![crate_type = "rlib"]
#![no_std]

// CHECK-LABEL: sum_f32:
// CHECK: addss xmm0, xmm1
// CHECK-NEXT: ret
#[no_mangle]
pub fn sum_f32(a: f32, b: f32) -> f32 {
a + b
}

// CHECK-LABEL: sum_f64x2:
// CHECK: mov rax, [[PTR_IN:.*]]
// CHECK-NEXT: movupd [[XMMA:xmm[0-9]]], xmmword ptr [rsi]
// CHECK-NEXT: movupd [[XMMB:xmm[0-9]]], xmmword ptr [rdx]
// CHECK-NEXT: addpd [[XMMB]], [[XMMA]]
// CHECK-NEXT: movupd xmmword ptr {{\[}}[[PTR_IN]]{{\]}}, [[XMMB]]
// CHECK-NEXT: ret
#[no_mangle]
pub fn sum_f64x2(a: [f64; 2], b: [f64; 2]) -> [f64; 2] {
[
a[0] + b[0],
a[1] + b[1],
]
}

// CHECK-LABEL: sum_f32x4:
// CHECK: mov rax, [[PTR_IN:.*]]
// CHECK-NEXT: movups [[XMMA:xmm[0-9]]], xmmword ptr [rsi]
// CHECK-NEXT: movups [[XMMB:xmm[0-9]]], xmmword ptr [rdx]
// CHECK-NEXT: addps [[XMMB]], [[XMMA]]
// CHECK-NEXT: movups xmmword ptr {{\[}}[[PTR_IN]]{{\]}}, [[XMMB]]
// CHECK-NEXT: ret
#[no_mangle]
pub fn sum_f32x4(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
[
a[0] + b[0],
a[1] + b[1],
a[2] + b[2],
a[3] + b[3],
]
}
32 changes: 32 additions & 0 deletions src/test/codegen/homogeneous-floats.rs
@@ -0,0 +1,32 @@
//! Check that small homogeneous float aggregates (up to 128 bits on x86_64) are passed either
//! as an array or by a pointer

// compile-flags: -C no-prepopulate-passes -O
// only-x86_64

#![crate_type = "lib"]

pub struct Foo {
bar1: f32,
bar2: f32,
bar3: f32,
bar4: f32,
}

// CHECK: define [2 x float] @array_f32x2([2 x float] %0, [2 x float] %1)
#[no_mangle]
pub fn array_f32x2(a: [f32; 2], b: [f32; 2]) -> [f32; 2] {
todo!()
}

// CHECK: define void @array_f32x4([4 x float]* {{.*}} sret([4 x float]) {{.*}} %0, [4 x float]* {{.*}} %a, [4 x float]* {{.*}} %b)
#[no_mangle]
pub fn array_f32x4(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
todo!()
}

// CHECK: define void @array_f32x4_nested(%Foo* {{.*}} sret(%Foo) {{.*}} %0, %Foo* {{.*}} %a, %Foo* {{.*}} %b)
#[no_mangle]
pub fn array_f32x4_nested(a: Foo, b: Foo) -> Foo {
todo!()
}
4 changes: 2 additions & 2 deletions src/test/codegen/issue-37945.rs
@@ -17,7 +17,7 @@ pub fn is_empty_1(xs: Iter<f32>) -> bool {
// CHECK-NEXT: start:
// CHECK-NEXT: [[A:%.*]] = icmp ne {{i32\*|ptr}} %xs.1, null
// CHECK-NEXT: tail call void @llvm.assume(i1 [[A]])
// CHECK-NEXT: [[B:%.*]] = icmp eq {{i32\*|ptr}} %xs.1, %xs.0
// CHECK-NEXT: [[B:%.*]] = icmp eq {{i32\*|ptr}} {{%xs.1, %xs.0|%xs.0, %xs.1}}
// CHECK-NEXT: ret i1 [[B:%.*]]
{xs}.next().is_none()
}
@@ -28,7 +28,7 @@ pub fn is_empty_2(xs: Iter<f32>) -> bool {
// CHECK-NEXT: start:
// CHECK-NEXT: [[C:%.*]] = icmp ne {{i32\*|ptr}} %xs.1, null
// CHECK-NEXT: tail call void @llvm.assume(i1 [[C]])
// CHECK-NEXT: [[D:%.*]] = icmp eq {{i32\*|ptr}} %xs.1, %xs.0
// CHECK-NEXT: [[D:%.*]] = icmp eq {{i32\*|ptr}} {{%xs.1, %xs.0|%xs.0, %xs.1}}
// CHECK-NEXT: ret i1 [[D:%.*]]
xs.map(|&x| x).next().is_none()
}
184 changes: 184 additions & 0 deletions src/test/ui/abi/homogenous-floats-target-feature-mixup.rs
@@ -0,0 +1,184 @@
// This test checks that even if we mix up the target features of functions taking homogeneous floats,
// the ABI is sound and still produces the right answer.
//
// This is basically the same test as src/test/ui/simd/target-feature-mixup.rs but for floats and
// without #[repr(simd)]

// run-pass
// ignore-emscripten
// ignore-sgx no processes

#![feature(target_feature, cfg_target_feature)]
#![feature(avx512_target_feature)]

#![allow(overflowing_literals)]
#![allow(unused_variables)]
#![allow(stable_features)]

use std::process::{Command, ExitStatus};
use std::env;

fn main() {
if let Some(level) = env::args().nth(1) {
return test::main(&level)
}

let me = env::current_exe().unwrap();
for level in ["sse", "avx", "avx512"].iter() {
let status = Command::new(&me).arg(level).status().unwrap();
if status.success() {
println!("success with {}", level);
continue
}

// We don't actually know if our computer has the requisite target features
// for the test below. Testing for that will get added to libstd later so
// for now just assume sigill means this is a machine that can't run this test.
if is_sigill(status) {
println!("sigill with {}, assuming spurious", level);
continue
}
panic!("invalid status at {}: {}", level, status);
}
}

#[cfg(unix)]
fn is_sigill(status: ExitStatus) -> bool {
use std::os::unix::prelude::*;
status.signal() == Some(4)
}

#[cfg(windows)]
fn is_sigill(status: ExitStatus) -> bool {
status.code() == Some(0xc000001d)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(nonstandard_style)]
mod test {
#[derive(PartialEq, Debug, Clone, Copy)]
struct f32x2(f32, f32);

#[derive(PartialEq, Debug, Clone, Copy)]
struct f32x4(f32, f32, f32, f32);

#[derive(PartialEq, Debug, Clone, Copy)]
struct f32x8(f32, f32, f32, f32, f32, f32, f32, f32);

pub fn main(level: &str) {
unsafe {
main_normal(level);
main_sse(level);
if level == "sse" {
return
}
main_avx(level);
if level == "avx" {
return
}
main_avx512(level);
}
}

macro_rules! mains {
($(
$(#[$attr:meta])*
unsafe fn $main:ident(level: &str) {
...
}
)*) => ($(
$(#[$attr])*
unsafe fn $main(level: &str) {
let m128 = f32x2(1., 2.);
let m256 = f32x4(3., 4., 5., 6.);
let m512 = f32x8(7., 8., 9., 10., 11., 12., 13., 14.);
assert_eq!(id_sse_128(m128), m128);
assert_eq!(id_sse_256(m256), m256);
assert_eq!(id_sse_512(m512), m512);

if level == "sse" {
return
}
assert_eq!(id_avx_128(m128), m128);
assert_eq!(id_avx_256(m256), m256);
assert_eq!(id_avx_512(m512), m512);

if level == "avx" {
return
}
assert_eq!(id_avx512_128(m128), m128);
assert_eq!(id_avx512_256(m256), m256);
assert_eq!(id_avx512_512(m512), m512);
}
)*)
}

mains! {
unsafe fn main_normal(level: &str) { ... }
#[target_feature(enable = "sse2")]
unsafe fn main_sse(level: &str) { ... }
#[target_feature(enable = "avx")]
unsafe fn main_avx(level: &str) { ... }
#[target_feature(enable = "avx512bw")]
unsafe fn main_avx512(level: &str) { ... }
}

#[target_feature(enable = "sse2")]
unsafe fn id_sse_128(a: f32x2) -> f32x2 {
assert_eq!(a, f32x2(1., 2.));
a.clone()
}

#[target_feature(enable = "sse2")]
unsafe fn id_sse_256(a: f32x4) -> f32x4 {
assert_eq!(a, f32x4(3., 4., 5., 6.));
a.clone()
}

#[target_feature(enable = "sse2")]
unsafe fn id_sse_512(a: f32x8) -> f32x8 {
assert_eq!(a, f32x8(7., 8., 9., 10., 11., 12., 13., 14.));
a.clone()
}

#[target_feature(enable = "avx")]
unsafe fn id_avx_128(a: f32x2) -> f32x2 {
assert_eq!(a, f32x2(1., 2.));
a.clone()
}

#[target_feature(enable = "avx")]
unsafe fn id_avx_256(a: f32x4) -> f32x4 {
assert_eq!(a, f32x4(3., 4., 5., 6.));
a.clone()
}

#[target_feature(enable = "avx")]
unsafe fn id_avx_512(a: f32x8) -> f32x8 {
assert_eq!(a, f32x8(7., 8., 9., 10., 11., 12., 13., 14.));
a.clone()
}

#[target_feature(enable = "avx512bw")]
unsafe fn id_avx512_128(a: f32x2) -> f32x2 {
assert_eq!(a, f32x2(1., 2.));
a.clone()
}

#[target_feature(enable = "avx512bw")]
unsafe fn id_avx512_256(a: f32x4) -> f32x4 {
assert_eq!(a, f32x4(3., 4., 5., 6.));
a.clone()
}

#[target_feature(enable = "avx512bw")]
unsafe fn id_avx512_512(a: f32x8) -> f32x8 {
assert_eq!(a, f32x8(7., 8., 9., 10., 11., 12., 13., 14.));
a.clone()
}
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
mod test {
pub fn main(level: &str) {}
}