Skip to content

Commit

Permalink
[CIR][CIRGen] Teach all uses of ApplyNonVirtualAndVirtualOffset to us…
Browse files Browse the repository at this point in the history
…e BaseClassAddrOp
  • Loading branch information
bcardosolopes committed Nov 15, 2024
1 parent c10f493 commit 3aed38c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 29 deletions.
26 changes: 13 additions & 13 deletions clang/lib/CIR/CodeGen/CIRGenClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,21 +662,19 @@ static Address ApplyNonVirtualAndVirtualOffset(
// Compute the offset from the static and dynamic components.
mlir::Value baseOffset;
if (!nonVirtualOffset.isZero()) {
mlir::Type OffsetType =
(CGF.CGM.getTarget().getCXXABI().isItaniumFamily() &&
CGF.CGM.getItaniumVTableContext().isRelativeLayout())
? CGF.SInt32Ty
: CGF.PtrDiffTy;
baseOffset = CGF.getBuilder().getConstInt(loc, OffsetType,
nonVirtualOffset.getQuantity());
if (virtualOffset) {
mlir::Type OffsetType =
(CGF.CGM.getTarget().getCXXABI().isItaniumFamily() &&
CGF.CGM.getItaniumVTableContext().isRelativeLayout())
? CGF.SInt32Ty
: CGF.PtrDiffTy;
baseOffset = CGF.getBuilder().getConstInt(loc, OffsetType,
nonVirtualOffset.getQuantity());
baseOffset = CGF.getBuilder().createBinop(
virtualOffset, cir::BinOpKind::Add, baseOffset);
} else if (baseValueTy) {
// TODO(cir): this should be used as a firt class in this function for the
// nonVirtualOffset cases, but all users of this function need to be
// updated first.
baseOffset.getDefiningOp()->erase();
} else {
assert(baseValueTy && "expected base type");
// If no virtualOffset is present this is the final stop.
return CGF.getBuilder().createBaseClassAddr(
loc, addr, baseValueTy, nonVirtualOffset.getQuantity(),
assumeNotNull);
Expand Down Expand Up @@ -725,6 +723,7 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
mlir::Value VirtualOffset{};
CharUnits NonVirtualOffset = CharUnits::Zero();

mlir::Type BaseValueTy;
if (CGM.getCXXABI().isVirtualOffsetNeededForVTableField(*this, Vptr)) {
// We need to use the virtual base offset offset because the virtual base
// might have a different offset in the most derived class.
Expand All @@ -734,14 +733,15 @@ void CIRGenFunction::initializeVTablePointer(mlir::Location loc,
} else {
// We can just use the base offset in the complete class.
NonVirtualOffset = Vptr.Base.getBaseOffset();
BaseValueTy = convertType(getContext().getTagDeclType(Vptr.Base.getBase()));
}

// Apply the offsets.
Address VTableField = LoadCXXThisAddress();
if (!NonVirtualOffset.isZero() || VirtualOffset) {
VTableField = ApplyNonVirtualAndVirtualOffset(
loc, *this, VTableField, NonVirtualOffset, VirtualOffset,
Vptr.VTableClass, Vptr.NearestVBase);
Vptr.VTableClass, Vptr.NearestVBase, BaseValueTy);
}

// Finally, store the address point. Use the same CIR types as the field.
Expand Down
9 changes: 3 additions & 6 deletions clang/test/CIR/CodeGen/multi-vtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,8 @@ int main() {
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV5Child, vtable_index = 1, address_point_index = 2) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>
// CIR: %{{[0-9]+}} = cir.const #cir.int<8> : !s64i
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!u8i>, %{{[0-9]+}} : !s64i), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!u8i>), !cir.ptr<!ty_Child>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_Child>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %7 = cir.base_class_addr(%1 : !cir.ptr<!ty_Child> nonnull) [8] -> !cir.ptr<!ty_Father>
// CIR: %8 = cir.cast(bitcast, %7 : !cir.ptr<!ty_Father>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>> loc(#loc8)
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.return
// CIR: }
Expand All @@ -70,7 +67,7 @@ int main() {

// LLVM-DAG: define linkonce_odr void @_ZN5ChildC2Ev(ptr %0)
// LLVM-DAG: store ptr getelementptr inbounds ({ [4 x ptr], [3 x ptr] }, ptr @_ZTV5Child, i32 0, i32 0, i32 2), ptr %{{[0-9]+}}, align 8
// LLVM-DAG: %{{[0-9]+}} = getelementptr i8, ptr %3, i64 8
// LLVM-DAG: %{{[0-9]+}} = getelementptr i8, ptr {{.*}}, i32 8
// LLVM-DAG: store ptr getelementptr inbounds ({ [4 x ptr], [3 x ptr] }, ptr @_ZTV5Child, i32 0, i32 1, i32 2), ptr %{{[0-9]+}}, align 8
// LLVM-DAG: ret void
// }
Expand Down
14 changes: 4 additions & 10 deletions clang/test/CIR/CodeGen/vtt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,13 @@ int f() {
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV1D, vtable_index = 2, address_point_index = 3) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>

// CIR: %{{[0-9]+}} = cir.const #cir.int<40> : !s64i
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_D>), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!u8i>, %{{[0-9]+}} : !s64i), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!u8i>), !cir.ptr<!ty_D>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_D>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.base_class_addr(%{{[0-9]+}} : !cir.ptr<!ty_D> nonnull) [40] -> !cir.ptr<!ty_A>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_A>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point(@_ZTV1D, vtable_index = 1, address_point_index = 3) : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>

// CIR: %{{[0-9]+}} = cir.const #cir.int<16> : !s64i
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_D>), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.ptr_stride(%{{[0-9]+}} : !cir.ptr<!u8i>, %{{[0-9]+}} : !s64i), !cir.ptr<!u8i>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!u8i>), !cir.ptr<!ty_D>
// CIR: %{{[0-9]+}} = cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_D>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.base_class_addr(%{{[0-9]+}} : !cir.ptr<!ty_D> nonnull) [16] -> !cir.ptr<!ty_C>
// CIR: cir.cast(bitcast, %{{[0-9]+}} : !cir.ptr<!ty_C>), !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.store %{{[0-9]+}}, %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>, !cir.ptr<!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>>
// CIR: cir.return
// CIR: }
Expand Down

0 comments on commit 3aed38c

Please sign in to comment.