Skip to content

Commit

Permalink
Auto merge of #122050 - erikdesjardins:sret, r=nikic
Browse files Browse the repository at this point in the history
Stop using LLVM struct types for byval/sret

For `byval` and `sret`, the type has no semantic meaning, only the size matters\*†. Using `[N x i8]` is a more direct way to specify that we want `N` bytes, and avoids relying on LLVM's struct layout.

\*: The alignment would matter, if we didn't explicitly specify it. From what I can tell, we always specified the alignment for `sret`; for `byval`, we didn't until #112157.

†: For `byval`, the hidden copy may be impacted by padding in the LLVM struct type, i.e. padding bytes may not be copied. (I'm not sure if this is done today, but I think it would be legal.) But we manually pad our LLVM struct types specifically to avoid there ever being LLVM-visible padding, so that shouldn't be an issue.

Split out from #121577.

r? `@nikic`
  • Loading branch information
bors committed Mar 11, 2024
2 parents c69fda7 + f18c2f8 commit a6d93ac
Show file tree
Hide file tree
Showing 17 changed files with 329 additions and 132 deletions.
20 changes: 16 additions & 4 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Indirect { attrs, meta_attrs: _, on_stack } => {
assert!(!on_stack);
let i = apply(attrs);
let sret = llvm::CreateStructRetAttr(cx.llcx, self.ret.layout.llvm_type(cx));
let sret = llvm::CreateStructRetAttr(
cx.llcx,
cx.type_array(cx.type_i8(), self.ret.layout.size.bytes()),
);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Argument(i), &[sret]);
}
PassMode::Cast { cast, pad_i32: _ } => {
Expand All @@ -437,7 +440,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Ignore => {}
PassMode::Indirect { attrs, meta_attrs: None, on_stack: true } => {
let i = apply(attrs);
let byval = llvm::CreateByValAttr(cx.llcx, arg.layout.llvm_type(cx));
let byval = llvm::CreateByValAttr(
cx.llcx,
cx.type_array(cx.type_i8(), arg.layout.size.bytes()),
);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Argument(i), &[byval]);
}
PassMode::Direct(attrs)
Expand Down Expand Up @@ -486,7 +492,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Indirect { attrs, meta_attrs: _, on_stack } => {
assert!(!on_stack);
let i = apply(bx.cx, attrs);
let sret = llvm::CreateStructRetAttr(bx.cx.llcx, self.ret.layout.llvm_type(bx));
let sret = llvm::CreateStructRetAttr(
bx.cx.llcx,
bx.cx.type_array(bx.cx.type_i8(), self.ret.layout.size.bytes()),
);
attributes::apply_to_callsite(callsite, llvm::AttributePlace::Argument(i), &[sret]);
}
PassMode::Cast { cast, pad_i32: _ } => {
Expand All @@ -513,7 +522,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Ignore => {}
PassMode::Indirect { attrs, meta_attrs: None, on_stack: true } => {
let i = apply(bx.cx, attrs);
let byval = llvm::CreateByValAttr(bx.cx.llcx, arg.layout.llvm_type(bx));
let byval = llvm::CreateByValAttr(
bx.cx.llcx,
bx.cx.type_array(bx.cx.type_i8(), arg.layout.size.bytes()),
);
attributes::apply_to_callsite(
callsite,
llvm::AttributePlace::Argument(i),
Expand Down
37 changes: 33 additions & 4 deletions compiler/rustc_target/src/abi/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,24 @@ pub enum PassMode {
///
/// The argument has a layout abi of `ScalarPair`.
Pair(ArgAttributes, ArgAttributes),
/// Pass the argument after casting it. See the `CastTarget` docs for details. The bool
/// indicates if a `Reg::i32()` dummy argument is emitted before the real argument.
/// Pass the argument after casting it. See the `CastTarget` docs for details.
///
/// `pad_i32` indicates if a `Reg::i32()` dummy argument is emitted before the real argument.
Cast { pad_i32: bool, cast: Box<CastTarget> },
/// Pass the argument indirectly via a hidden pointer.
///
/// The `meta_attrs` value, if any, is for the metadata (vtable or length) of an unsized
/// argument. (This is the only mode that supports unsized arguments.)
///
/// `on_stack` defines that the value should be passed at a fixed stack offset in accordance to
/// the ABI rather than passed using a pointer. This corresponds to the `byval` LLVM argument
/// attribute (using the Rust type of this argument). `on_stack` cannot be true for unsized
/// arguments, i.e., when `meta_attrs` is `Some`.
/// attribute. The `byval` argument will use a byte array with the same size as the Rust type
/// (which ensures that padding is preserved and that we do not rely on LLVM's struct layout),
/// and will use the alignment specified in `attrs.pointee_align` (if `Some`) or the type's
/// alignment (if `None`). This means that the alignment will not always
/// match the Rust type's alignment; see documentation of `make_indirect_byval` for more info.
///
/// `on_stack` cannot be true for unsized arguments, i.e., when `meta_attrs` is `Some`.
Indirect { attrs: ArgAttributes, meta_attrs: Option<ArgAttributes>, on_stack: bool },
}

Expand Down Expand Up @@ -596,6 +604,8 @@ impl<'a, Ty> ArgAbi<'a, Ty> {
}
}

/// Pass this argument indirectly, by passing a (thin or fat) pointer to the argument instead.
/// This is valid for both sized and unsized arguments.
pub fn make_indirect(&mut self) {
match self.mode {
PassMode::Direct(_) | PassMode::Pair(_, _) => {
Expand All @@ -609,7 +619,26 @@ impl<'a, Ty> ArgAbi<'a, Ty> {
}
}

/// Pass this argument indirectly, by placing it at a fixed stack offset.
/// This corresponds to the `byval` LLVM argument attribute.
/// This is only valid for sized arguments.
///
/// `byval_align` specifies the alignment of the `byval` stack slot, which does not need to
/// correspond to the type's alignment. This will be `Some` if the target's ABI specifies that
/// stack slots used for arguments passed by-value have specific alignment requirements which
/// differ from the alignment used in other situations.
///
/// If `None`, the type's alignment is used.
///
/// If the resulting alignment differs from the type's alignment,
/// the argument will be copied to an alloca with sufficient alignment,
/// either in the caller (if the type's alignment is lower than the byval alignment)
/// or in the callee† (if the type's alignment is higher than the byval alignment),
/// to ensure that Rust code never sees an underaligned pointer.
///
/// † This is currently broken, see <https://github.com/rust-lang/rust/pull/122212>.
pub fn make_indirect_byval(&mut self, byval_align: Option<Align>) {
assert!(!self.layout.is_unsized(), "used byval ABI for unsized layout");
self.make_indirect();
match self.mode {
PassMode::Indirect { ref mut attrs, meta_attrs: _, ref mut on_stack } => {
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_target/src/abi/call/x86_64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ where
match cls_or_mem {
Err(Memory) => {
if is_arg {
// The x86_64 ABI doesn't have any special requirements for `byval` alignment,
// the type's alignment is always used.
arg.make_indirect_byval(None);
} else {
// `sret` parameter thus one less integer register available
Expand Down
8 changes: 4 additions & 4 deletions tests/codegen/align-byval-vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ pub struct DoubleFoo {
}

extern "C" {
// x86-linux: declare void @f({{.*}}byval(%Foo) align 4{{.*}})
// x86-darwin: declare void @f({{.*}}byval(%Foo) align 16{{.*}})
// x86-linux: declare void @f({{.*}}byval([32 x i8]) align 4{{.*}})
// x86-darwin: declare void @f({{.*}}byval([32 x i8]) align 16{{.*}})
fn f(foo: Foo);

// x86-linux: declare void @g({{.*}}byval(%DoubleFoo) align 4{{.*}})
// x86-darwin: declare void @g({{.*}}byval(%DoubleFoo) align 16{{.*}})
// x86-linux: declare void @g({{.*}}byval([64 x i8]) align 4{{.*}})
// x86-darwin: declare void @g({{.*}}byval([64 x i8]) align 16{{.*}})
fn g(foo: DoubleFoo);
}

Expand Down
92 changes: 46 additions & 46 deletions tests/codegen/align-byval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@ pub unsafe fn call_na1(x: NaturalAlign1) {
// CHECK: start:

// m68k: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 1
// m68k: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}} [[ALLOCA]])
// m68k: call void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}} [[ALLOCA]])

// wasm: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 1
// wasm: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}} [[ALLOCA]])
// wasm: call void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}} [[ALLOCA]])

// x86_64-linux: call void @natural_align_1(i16

// x86_64-windows: call void @natural_align_1(i16

// i686-linux: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 4
// i686-linux: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}} [[ALLOCA]])
// i686-linux: call void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}} [[ALLOCA]])

// i686-windows: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 4
// i686-windows: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}} [[ALLOCA]])
// i686-windows: call void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}} [[ALLOCA]])
natural_align_1(x);
}

Expand All @@ -135,10 +135,10 @@ pub unsafe fn call_na2(x: NaturalAlign2) {
// x86_64-windows-NEXT: call void @natural_align_2

// i686-linux: [[ALLOCA:%[0-9]+]] = alloca %NaturalAlign2, align 4
// i686-linux: call void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}} [[ALLOCA]])
// i686-linux: call void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}} [[ALLOCA]])

// i686-windows: [[ALLOCA:%[0-9]+]] = alloca %NaturalAlign2, align 4
// i686-windows: call void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}} [[ALLOCA]])
// i686-windows: call void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}} [[ALLOCA]])
natural_align_2(x);
}

Expand Down Expand Up @@ -199,141 +199,141 @@ pub unsafe fn call_fa16(x: ForceAlign16) {
}

extern "C" {
// m68k: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}})
// m68k: declare void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}})

// wasm: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}})
// wasm: declare void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}})

// x86_64-linux: declare void @natural_align_1(i16)

// x86_64-windows: declare void @natural_align_1(i16)

// i686-linux: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}})
// i686-linux: declare void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}})

// i686-windows: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}})
// i686-windows: declare void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}})
fn natural_align_1(x: NaturalAlign1);

// m68k: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 2{{.*}})
// m68k: declare void @natural_align_2({{.*}}byval([34 x i8]) align 2{{.*}})

// wasm: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 2{{.*}})
// wasm: declare void @natural_align_2({{.*}}byval([34 x i8]) align 2{{.*}})

// x86_64-linux: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 2{{.*}})
// x86_64-linux: declare void @natural_align_2({{.*}}byval([34 x i8]) align 2{{.*}})

// x86_64-windows: declare void @natural_align_2(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 2{{.*}})

// i686-linux: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}})
// i686-linux: declare void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}})

// i686-windows: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}})
// i686-windows: declare void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}})
fn natural_align_2(x: NaturalAlign2);

// m68k: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// m68k: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// wasm: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// wasm: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// x86_64-linux: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// x86_64-linux: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// x86_64-windows: declare void @force_align_4(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 4{{.*}})

// i686-linux: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// i686-linux: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// i686-windows: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// i686-windows: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})
fn force_align_4(x: ForceAlign4);

// m68k: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 4{{.*}})
// m68k: declare void @natural_align_8({{.*}}byval([24 x i8]) align 4{{.*}})

// wasm: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 8{{.*}})
// wasm: declare void @natural_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 8{{.*}})
// x86_64-linux: declare void @natural_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @natural_align_8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 4{{.*}})
// i686-linux: declare void @natural_align_8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 4{{.*}})
// i686-windows: declare void @natural_align_8({{.*}}byval([24 x i8]) align 4{{.*}})
fn natural_align_8(x: NaturalAlign8);

// m68k: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 8{{.*}})
// m68k: declare void @force_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// wasm: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 8{{.*}})
// wasm: declare void @force_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 8{{.*}})
// x86_64-linux: declare void @force_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @force_align_8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 4{{.*}})
// i686-linux: declare void @force_align_8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @force_align_8(
// i686-windows-NOT: byval
// i686-windows-SAME: align 8{{.*}})
fn force_align_8(x: ForceAlign8);

// m68k: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 4{{.*}})
// m68k: declare void @lower_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// wasm: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 8{{.*}})
// wasm: declare void @lower_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 8{{.*}})
// x86_64-linux: declare void @lower_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @lower_fa8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 4{{.*}})
// i686-linux: declare void @lower_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 4{{.*}})
// i686-windows: declare void @lower_fa8({{.*}}byval([24 x i8]) align 4{{.*}})
fn lower_fa8(x: LowerFA8);

// m68k: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 8{{.*}})
// m68k: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// wasm: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 8{{.*}})
// wasm: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 8{{.*}})
// x86_64-linux: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @wrapped_fa8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 4{{.*}})
// i686-linux: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @wrapped_fa8(
// i686-windows-NOT: byval
// i686-windows-SAME: align 8{{.*}})
fn wrapped_fa8(x: WrappedFA8);

// m68k: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 8{{.*}})
// m68k: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// wasm: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 8{{.*}})
// wasm: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 8{{.*}})
// x86_64-linux: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @transparent_fa8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 4{{.*}})
// i686-linux: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @transparent_fa8(
// i686-windows-NOT: byval
// i686-windows-SAME: align 8{{.*}})
fn transparent_fa8(x: TransparentFA8);

// m68k: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 16{{.*}})
// m68k: declare void @force_align_16({{.*}}byval([80 x i8]) align 16{{.*}})

// wasm: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 16{{.*}})
// wasm: declare void @force_align_16({{.*}}byval([80 x i8]) align 16{{.*}})

// x86_64-linux: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 16{{.*}})
// x86_64-linux: declare void @force_align_16({{.*}}byval([80 x i8]) align 16{{.*}})

// x86_64-windows: declare void @force_align_16(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 16{{.*}})

// i686-linux: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 4{{.*}})
// i686-linux: declare void @force_align_16({{.*}}byval([80 x i8]) align 4{{.*}})

// i686-windows: declare void @force_align_16(
// i686-windows-NOT: byval
Expand Down
1 change: 0 additions & 1 deletion tests/codegen/align-enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub enum Align64 {
A(u32),
B(u32),
}
// CHECK: %Align64 = type { i32, [15 x i32] }

pub struct Nested64 {
a: u8,
Expand Down
4 changes: 0 additions & 4 deletions tests/codegen/align-struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,23 @@

#[repr(align(64))]
pub struct Align64(i32);
// CHECK: %Align64 = type { i32, [15 x i32] }

pub struct Nested64 {
a: Align64,
b: i32,
c: i32,
d: i8,
}
// CHECK: %Nested64 = type { %Align64, i32, i32, i8, [55 x i8] }

pub enum Enum4 {
A(i32),
B(i32),
}
// No Aggregate type, and hence nothing in LLVM IR.

pub enum Enum64 {
A(Align64),
B(i32),
}
// CHECK: %Enum64 = type { i32, [31 x i32] }

// CHECK-LABEL: @align64
#[no_mangle]
Expand Down
Loading

0 comments on commit a6d93ac

Please sign in to comment.