Skip to content

Commit

Permalink
use [N x i8] for byval/sret types
Browse files Browse the repository at this point in the history
This avoids depending on LLVM's struct types to determine the size of
the byval/sret slot.
  • Loading branch information
erikdesjardins committed Mar 5, 2024
1 parent 3c02972 commit 96a7267
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 73 deletions.
20 changes: 16 additions & 4 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Indirect { attrs, meta_attrs: _, on_stack } => {
assert!(!on_stack);
let i = apply(attrs);
let sret = llvm::CreateStructRetAttr(cx.llcx, self.ret.layout.llvm_type(cx));
let sret = llvm::CreateStructRetAttr(
cx.llcx,
cx.type_array(cx.type_i8(), self.ret.layout.size.bytes()),
);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Argument(i), &[sret]);
}
PassMode::Cast { cast, pad_i32: _ } => {
Expand All @@ -437,7 +440,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Ignore => {}
PassMode::Indirect { attrs, meta_attrs: None, on_stack: true } => {
let i = apply(attrs);
let byval = llvm::CreateByValAttr(cx.llcx, arg.layout.llvm_type(cx));
let byval = llvm::CreateByValAttr(
cx.llcx,
cx.type_array(cx.type_i8(), arg.layout.size.bytes()),
);
attributes::apply_to_llfn(llfn, llvm::AttributePlace::Argument(i), &[byval]);
}
PassMode::Direct(attrs)
Expand Down Expand Up @@ -486,7 +492,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Indirect { attrs, meta_attrs: _, on_stack } => {
assert!(!on_stack);
let i = apply(bx.cx, attrs);
let sret = llvm::CreateStructRetAttr(bx.cx.llcx, self.ret.layout.llvm_type(bx));
let sret = llvm::CreateStructRetAttr(
bx.cx.llcx,
bx.cx.type_array(bx.cx.type_i8(), self.ret.layout.size.bytes()),
);
attributes::apply_to_callsite(callsite, llvm::AttributePlace::Argument(i), &[sret]);
}
PassMode::Cast { cast, pad_i32: _ } => {
Expand All @@ -513,7 +522,10 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
PassMode::Ignore => {}
PassMode::Indirect { attrs, meta_attrs: None, on_stack: true } => {
let i = apply(bx.cx, attrs);
let byval = llvm::CreateByValAttr(bx.cx.llcx, arg.layout.llvm_type(bx));
let byval = llvm::CreateByValAttr(
bx.cx.llcx,
bx.cx.type_array(bx.cx.type_i8(), arg.layout.size.bytes()),
);
attributes::apply_to_callsite(
callsite,
llvm::AttributePlace::Argument(i),
Expand Down
8 changes: 4 additions & 4 deletions tests/codegen/align-byval-vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ pub struct DoubleFoo {
}

extern "C" {
// x86-linux: declare void @f({{.*}}byval(%Foo) align 4{{.*}})
// x86-darwin: declare void @f({{.*}}byval(%Foo) align 16{{.*}})
// x86-linux: declare void @f({{.*}}byval([32 x i8]) align 4{{.*}})
// x86-darwin: declare void @f({{.*}}byval([32 x i8]) align 16{{.*}})
fn f(foo: Foo);

// x86-linux: declare void @g({{.*}}byval(%DoubleFoo) align 4{{.*}})
// x86-darwin: declare void @g({{.*}}byval(%DoubleFoo) align 16{{.*}})
// x86-linux: declare void @g({{.*}}byval([64 x i8]) align 4{{.*}})
// x86-darwin: declare void @g({{.*}}byval([64 x i8]) align 16{{.*}})
fn g(foo: DoubleFoo);
}

Expand Down
92 changes: 46 additions & 46 deletions tests/codegen/align-byval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,20 @@ pub unsafe fn call_na1(x: NaturalAlign1) {
// CHECK: start:

// m68k: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 1
// m68k: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}} [[ALLOCA]])
// m68k: call void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}} [[ALLOCA]])

// wasm: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 1
// wasm: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}} [[ALLOCA]])
// wasm: call void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}} [[ALLOCA]])

// x86_64-linux: call void @natural_align_1(i16

// x86_64-windows: call void @natural_align_1(i16

// i686-linux: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 4
// i686-linux: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}} [[ALLOCA]])
// i686-linux: call void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}} [[ALLOCA]])

// i686-windows: [[ALLOCA:%[a-z0-9+]]] = alloca %NaturalAlign1, align 4
// i686-windows: call void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}} [[ALLOCA]])
// i686-windows: call void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}} [[ALLOCA]])
natural_align_1(x);
}

Expand All @@ -135,10 +135,10 @@ pub unsafe fn call_na2(x: NaturalAlign2) {
// x86_64-windows-NEXT: call void @natural_align_2

// i686-linux: [[ALLOCA:%[0-9]+]] = alloca %NaturalAlign2, align 4
// i686-linux: call void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}} [[ALLOCA]])
// i686-linux: call void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}} [[ALLOCA]])

// i686-windows: [[ALLOCA:%[0-9]+]] = alloca %NaturalAlign2, align 4
// i686-windows: call void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}} [[ALLOCA]])
// i686-windows: call void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}} [[ALLOCA]])
natural_align_2(x);
}

Expand Down Expand Up @@ -199,141 +199,141 @@ pub unsafe fn call_fa16(x: ForceAlign16) {
}

extern "C" {
// m68k: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}})
// m68k: declare void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}})

// wasm: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 1{{.*}})
// wasm: declare void @natural_align_1({{.*}}byval([2 x i8]) align 1{{.*}})

// x86_64-linux: declare void @natural_align_1(i16)

// x86_64-windows: declare void @natural_align_1(i16)

// i686-linux: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}})
// i686-linux: declare void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}})

// i686-windows: declare void @natural_align_1({{.*}}byval(%NaturalAlign1) align 4{{.*}})
// i686-windows: declare void @natural_align_1({{.*}}byval([2 x i8]) align 4{{.*}})
fn natural_align_1(x: NaturalAlign1);

// m68k: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 2{{.*}})
// m68k: declare void @natural_align_2({{.*}}byval([34 x i8]) align 2{{.*}})

// wasm: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 2{{.*}})
// wasm: declare void @natural_align_2({{.*}}byval([34 x i8]) align 2{{.*}})

// x86_64-linux: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 2{{.*}})
// x86_64-linux: declare void @natural_align_2({{.*}}byval([34 x i8]) align 2{{.*}})

// x86_64-windows: declare void @natural_align_2(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 2{{.*}})

// i686-linux: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}})
// i686-linux: declare void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}})

// i686-windows: declare void @natural_align_2({{.*}}byval(%NaturalAlign2) align 4{{.*}})
// i686-windows: declare void @natural_align_2({{.*}}byval([34 x i8]) align 4{{.*}})
fn natural_align_2(x: NaturalAlign2);

// m68k: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// m68k: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// wasm: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// wasm: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// x86_64-linux: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// x86_64-linux: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// x86_64-windows: declare void @force_align_4(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 4{{.*}})

// i686-linux: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// i686-linux: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})

// i686-windows: declare void @force_align_4({{.*}}byval(%ForceAlign4) align 4{{.*}})
// i686-windows: declare void @force_align_4({{.*}}byval([20 x i8]) align 4{{.*}})
fn force_align_4(x: ForceAlign4);

// m68k: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 4{{.*}})
// m68k: declare void @natural_align_8({{.*}}byval([24 x i8]) align 4{{.*}})

// wasm: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 8{{.*}})
// wasm: declare void @natural_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 8{{.*}})
// x86_64-linux: declare void @natural_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @natural_align_8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 4{{.*}})
// i686-linux: declare void @natural_align_8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @natural_align_8({{.*}}byval(%NaturalAlign8) align 4{{.*}})
// i686-windows: declare void @natural_align_8({{.*}}byval([24 x i8]) align 4{{.*}})
fn natural_align_8(x: NaturalAlign8);

// m68k: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 8{{.*}})
// m68k: declare void @force_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// wasm: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 8{{.*}})
// wasm: declare void @force_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 8{{.*}})
// x86_64-linux: declare void @force_align_8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @force_align_8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @force_align_8({{.*}}byval(%ForceAlign8) align 4{{.*}})
// i686-linux: declare void @force_align_8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @force_align_8(
// i686-windows-NOT: byval
// i686-windows-SAME: align 8{{.*}})
fn force_align_8(x: ForceAlign8);

// m68k: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 4{{.*}})
// m68k: declare void @lower_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// wasm: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 8{{.*}})
// wasm: declare void @lower_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 8{{.*}})
// x86_64-linux: declare void @lower_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @lower_fa8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 4{{.*}})
// i686-linux: declare void @lower_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @lower_fa8({{.*}}byval(%LowerFA8) align 4{{.*}})
// i686-windows: declare void @lower_fa8({{.*}}byval([24 x i8]) align 4{{.*}})
fn lower_fa8(x: LowerFA8);

// m68k: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 8{{.*}})
// m68k: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// wasm: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 8{{.*}})
// wasm: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 8{{.*}})
// x86_64-linux: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @wrapped_fa8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @wrapped_fa8({{.*}}byval(%WrappedFA8) align 4{{.*}})
// i686-linux: declare void @wrapped_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @wrapped_fa8(
// i686-windows-NOT: byval
// i686-windows-SAME: align 8{{.*}})
fn wrapped_fa8(x: WrappedFA8);

// m68k: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 8{{.*}})
// m68k: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// wasm: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 8{{.*}})
// wasm: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-linux: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 8{{.*}})
// x86_64-linux: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 8{{.*}})

// x86_64-windows: declare void @transparent_fa8(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 8{{.*}})

// i686-linux: declare void @transparent_fa8({{.*}}byval(%TransparentFA8) align 4{{.*}})
// i686-linux: declare void @transparent_fa8({{.*}}byval([24 x i8]) align 4{{.*}})

// i686-windows: declare void @transparent_fa8(
// i686-windows-NOT: byval
// i686-windows-SAME: align 8{{.*}})
fn transparent_fa8(x: TransparentFA8);

// m68k: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 16{{.*}})
// m68k: declare void @force_align_16({{.*}}byval([80 x i8]) align 16{{.*}})

// wasm: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 16{{.*}})
// wasm: declare void @force_align_16({{.*}}byval([80 x i8]) align 16{{.*}})

// x86_64-linux: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 16{{.*}})
// x86_64-linux: declare void @force_align_16({{.*}}byval([80 x i8]) align 16{{.*}})

// x86_64-windows: declare void @force_align_16(
// x86_64-windows-NOT: byval
// x86_64-windows-SAME: align 16{{.*}})

// i686-linux: declare void @force_align_16({{.*}}byval(%ForceAlign16) align 4{{.*}})
// i686-linux: declare void @force_align_16({{.*}}byval([80 x i8]) align 4{{.*}})

// i686-windows: declare void @force_align_16(
// i686-windows-NOT: byval
Expand Down
1 change: 0 additions & 1 deletion tests/codegen/align-enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub enum Align64 {
A(u32),
B(u32),
}
// CHECK: %Align64 = type { i32, [15 x i32] }

pub struct Nested64 {
a: u8,
Expand Down
4 changes: 0 additions & 4 deletions tests/codegen/align-struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,23 @@

#[repr(align(64))]
pub struct Align64(i32);
// CHECK: %Align64 = type { i32, [15 x i32] }

pub struct Nested64 {
a: Align64,
b: i32,
c: i32,
d: i8,
}
// CHECK: %Nested64 = type { %Align64, i32, i32, i8, [55 x i8] }

pub enum Enum4 {
A(i32),
B(i32),
}
// No Aggregate type, and hence nothing in LLVM IR.

pub enum Enum64 {
A(Align64),
B(i32),
}
// CHECK: %Enum64 = type { i32, [31 x i32] }

// CHECK-LABEL: @align64
#[no_mangle]
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/function-arguments-noopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn borrow_call(x: &i32, f: fn(&i32) -> &i32) -> &i32 {
f(x)
}

// CHECK: void @struct_(ptr sret(%S) align 4{{( %_0)?}}, ptr align 4 %x)
// CHECK: void @struct_(ptr sret([32 x i8]) align 4{{( %_0)?}}, ptr align 4 %x)
#[no_mangle]
pub fn struct_(x: S) -> S {
x
Expand All @@ -51,7 +51,7 @@ pub fn struct_(x: S) -> S {
// CHECK-LABEL: @struct_call
#[no_mangle]
pub fn struct_call(x: S, f: fn(S) -> S) -> S {
// CHECK: call void %f(ptr sret(%S) align 4{{( %_0)?}}, ptr align 4 %{{.+}})
// CHECK: call void %f(ptr sret([32 x i8]) align 4{{( %_0)?}}, ptr align 4 %{{.+}})
f(x)
}

Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/function-arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ pub fn notunpin_box(x: Box<NotUnpin>) -> Box<NotUnpin> {
x
}

// CHECK: @struct_return(ptr noalias nocapture noundef sret(%S) align 4 dereferenceable(32){{( %_0)?}})
// CHECK: @struct_return(ptr noalias nocapture noundef sret([32 x i8]) align 4 dereferenceable(32){{( %_0)?}})
#[no_mangle]
pub fn struct_return() -> S {
S {
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/mem-replace-big-type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn replace_big(dst: &mut Big, src: Big) -> Big {
// CHECK-NOT: call void @llvm.memcpy

// For a large type, we expect exactly three `memcpy`s
// CHECK-LABEL: define internal void @{{.+}}mem{{.+}}replace{{.+}}sret(%Big)
// CHECK-LABEL: define internal void @{{.+}}mem{{.+}}replace{{.+}}sret([56 x i8])
// CHECK-NOT: call void @llvm.memcpy
// CHECK: call void @llvm.memcpy.{{.+}}(ptr align 8 %result, ptr align 8 %dest, i{{.*}} 56, i1 false)
// CHECK-NOT: call void @llvm.memcpy
Expand Down
Loading

0 comments on commit 96a7267

Please sign in to comment.