Skip to content

Commit

Permalink
[PAC] Implement pointer authentication for C++ member function pointers.
Browse files Browse the repository at this point in the history
Introduces type based signing of member function pointers. To support this
discrimination schema we no longer emit member function pointer to virtual
methods and indices into a vtable but migrate to using thunks. This does mean
member function pointers are no longer necessarily directly comparable,
however as such comparisons are UB this is acceptable.

We derive the discriminator from the C++ mangling of the type of the pointer
being authenticated.

Co-Authored-By: Akira Hatanaka ahatanaka@apple.com
Co-Authored-By: John McCall rjmccall@apple.com
  • Loading branch information
ahmedbougacha authored and ojhunt committed Jul 18, 2024
1 parent 4afdcd9 commit 547ce3e
Show file tree
Hide file tree
Showing 11 changed files with 855 additions and 117 deletions.
2 changes: 1 addition & 1 deletion clang/include/clang/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD);

/// Return the "other" type-specific discriminator for the given type.
uint16_t getPointerAuthTypeDiscriminator(QualType T) const;
uint16_t getPointerAuthTypeDiscriminator(QualType T);

/// Apply Objective-C protocol qualifiers to the given type.
/// \param allowOnPointerType specifies if we can apply protocol
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/PointerAuthOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ struct PointerAuthOptions {

/// The ABI for variadic C++ virtual function pointers.
PointerAuthSchema CXXVirtualVariadicFunctionPointers;

/// The ABI for C++ member function pointers.
PointerAuthSchema CXXMemberFunctionPointers;
};

} // end namespace clang
Expand Down
12 changes: 7 additions & 5 deletions clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3404,7 +3404,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
}
}

uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) {
assert(!T->isDependentType() &&
"cannot compute type discriminator of a dependent type");

Expand All @@ -3414,11 +3414,13 @@ uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
if (T->isFunctionPointerType() || T->isFunctionReferenceType())
T = T->getPointeeType();

if (T->isFunctionType())
if (T->isFunctionType()) {
encodeTypeForFunctionPointerAuth(*this, Out, T);
else
llvm_unreachable(
"type discrimination of non-function type not implemented yet");
} else {
T = T.getUnqualifiedType();
std::unique_ptr<MangleContext> MC(createMangleContext());
MC->mangleCanonicalTypeName(T, Out);
}

return llvm::getPointerAuthStableSipHash(Str);
}
Expand Down
212 changes: 114 additions & 98 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5034,7 +5034,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
ReturnValueSlot ReturnValue,
const CallArgList &CallArgs,
llvm::CallBase **callOrInvoke, bool IsMustTail,
SourceLocation Loc) {
SourceLocation Loc,
bool IsVirtualFunctionPointerThunk) {
// FIXME: We no longer need the types from CallArgs; lift up and simplify.

assert(Callee.isOrdinary() || Callee.isVirtual());
Expand Down Expand Up @@ -5098,7 +5099,11 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
RawAddress SRetAlloca = RawAddress::invalid();
llvm::Value *UnusedReturnSizePtr = nullptr;
if (RetAI.isIndirect() || RetAI.isInAlloca() || RetAI.isCoerceAndExpand()) {
if (!ReturnValue.isNull()) {
if (IsVirtualFunctionPointerThunk && RetAI.isIndirect()) {
SRetPtr = makeNaturalAddressForPointer(CurFn->arg_begin() +
IRFunctionArgs.getSRetArgNo(),
RetTy, CharUnits::fromQuantity(1));
} else if (!ReturnValue.isNull()) {
SRetPtr = ReturnValue.getAddress();
} else {
SRetPtr = CreateMemTemp(RetTy, "tmp", &SRetAlloca);
Expand Down Expand Up @@ -5877,119 +5882,130 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
CallArgs.freeArgumentMemory(*this);

// Extract the return value.
RValue Ret = [&] {
switch (RetAI.getKind()) {
case ABIArgInfo::CoerceAndExpand: {
auto coercionType = RetAI.getCoerceAndExpandType();

Address addr = SRetPtr.withElementType(coercionType);

assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
bool requiresExtract = isa<llvm::StructType>(CI->getType());
RValue Ret;

unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
llvm::Type *eltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
Address eltAddr = Builder.CreateStructGEP(addr, i);
llvm::Value *elt = CI;
if (requiresExtract)
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
else
assert(unpaddedIndex == 0);
Builder.CreateStore(elt, eltAddr);
// If the current function is a virtual function pointer thunk, avoid copying
// the return value of the musttail call to a temporary.
if (IsVirtualFunctionPointerThunk)
Ret = RValue::get(CI);
else
Ret = [&] {
switch (RetAI.getKind()) {
case ABIArgInfo::CoerceAndExpand: {
auto coercionType = RetAI.getCoerceAndExpandType();

Address addr = SRetPtr.withElementType(coercionType);

assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
bool requiresExtract = isa<llvm::StructType>(CI->getType());

unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
llvm::Type *eltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
continue;
Address eltAddr = Builder.CreateStructGEP(addr, i);
llvm::Value *elt = CI;
if (requiresExtract)
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
else
assert(unpaddedIndex == 0);
Builder.CreateStore(elt, eltAddr);
}
[[fallthrough]];
}
[[fallthrough]];
}

case ABIArgInfo::InAlloca:
case ABIArgInfo::Indirect: {
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
if (UnusedReturnSizePtr)
PopCleanupBlock();
return ret;
}

case ABIArgInfo::Ignore:
// If we are ignoring an argument that had a result, make sure to
// construct the appropriate return value for our caller.
return GetUndefRValue(RetTy);
case ABIArgInfo::InAlloca:
case ABIArgInfo::Indirect: {
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
if (UnusedReturnSizePtr)
PopCleanupBlock();
return ret;
}

case ABIArgInfo::Extend:
case ABIArgInfo::Direct: {
llvm::Type *RetIRTy = ConvertType(RetTy);
if (RetAI.getCoerceToType() == RetIRTy && RetAI.getDirectOffset() == 0) {
switch (getEvaluationKind(RetTy)) {
case TEK_Complex: {
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
return RValue::getComplex(std::make_pair(Real, Imag));
}
case TEK_Aggregate: {
Address DestPtr = ReturnValue.getAddress();
bool DestIsVolatile = ReturnValue.isVolatile();
case ABIArgInfo::Ignore:
// If we are ignoring an argument that had a result, make sure to
// construct the appropriate return value for our caller.
return GetUndefRValue(RetTy);

case ABIArgInfo::Extend:
case ABIArgInfo::Direct: {
llvm::Type *RetIRTy = ConvertType(RetTy);
if (RetAI.getCoerceToType() == RetIRTy &&
RetAI.getDirectOffset() == 0) {
switch (getEvaluationKind(RetTy)) {
case TEK_Complex: {
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
return RValue::getComplex(std::make_pair(Real, Imag));
}
case TEK_Aggregate: {
Address DestPtr = ReturnValue.getAddress();
bool DestIsVolatile = ReturnValue.isVolatile();

if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
DestIsVolatile = false;
if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
DestIsVolatile = false;
}
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
return RValue::getAggregate(DestPtr);
}
case TEK_Scalar: {
// If the argument doesn't match, perform a bitcast to coerce it.
// This can happen due to trivial type mismatches.
llvm::Value *V = CI;
if (V->getType() != RetIRTy)
V = Builder.CreateBitCast(V, RetIRTy);
return RValue::get(V);
}
}
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
return RValue::getAggregate(DestPtr);
llvm_unreachable("bad evaluation kind");
}
case TEK_Scalar: {
// If the argument doesn't match, perform a bitcast to coerce it. This
// can happen due to trivial type mismatches.

// If coercing a fixed vector from a scalable vector for ABI
// compatibility, and the types match, use the llvm.vector.extract
// intrinsic to perform the conversion.
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
llvm::Value *V = CI;
if (V->getType() != RetIRTy)
V = Builder.CreateBitCast(V, RetIRTy);
return RValue::get(V);
}
if (auto *ScalableSrcTy =
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
if (FixedDstTy->getElementType() ==
ScalableSrcTy->getElementType()) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
V = Builder.CreateExtractVector(FixedDstTy, V, Zero,
"cast.fixed");
return RValue::get(V);
}
}
}
llvm_unreachable("bad evaluation kind");
}

// If coercing a fixed vector from a scalable vector for ABI
// compatibility, and the types match, use the llvm.vector.extract
// intrinsic to perform the conversion.
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
llvm::Value *V = CI;
if (auto *ScalableSrcTy =
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
if (FixedDstTy->getElementType() == ScalableSrcTy->getElementType()) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
V = Builder.CreateExtractVector(FixedDstTy, V, Zero, "cast.fixed");
return RValue::get(V);
}
Address DestPtr = ReturnValue.getValue();
bool DestIsVolatile = ReturnValue.isVolatile();

if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "coerce");
DestIsVolatile = false;
}
}

Address DestPtr = ReturnValue.getValue();
bool DestIsVolatile = ReturnValue.isVolatile();
// An empty record can overlap other data (if declared with
// no_unique_address); omit the store for such types - as there is no
// actual data to store.
if (!isEmptyRecord(getContext(), RetTy, true)) {
// If the value is offset in memory, apply the offset now.
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
}

if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "coerce");
DestIsVolatile = false;
return convertTempToRValue(DestPtr, RetTy, SourceLocation());
}

// An empty record can overlap other data (if declared with
// no_unique_address); omit the store for such types - as there is no
// actual data to store.
if (!isEmptyRecord(getContext(), RetTy, true)) {
// If the value is offset in memory, apply the offset now.
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
case ABIArgInfo::Expand:
case ABIArgInfo::IndirectAliased:
llvm_unreachable("Invalid ABI kind for return argument");
}

return convertTempToRValue(DestPtr, RetTy, SourceLocation());
}

case ABIArgInfo::Expand:
case ABIArgInfo::IndirectAliased:
llvm_unreachable("Invalid ABI kind for return argument");
}

llvm_unreachable("Unhandled ABIArgInfo::Kind");
} ();
llvm_unreachable("Unhandled ABIArgInfo::Kind");
}();

// Emit the assume_aligned check on the return value.
if (Ret.isScalar() && TargetDecl) {
Expand Down
34 changes: 34 additions & 0 deletions clang/lib/CodeGen/CGPointerAuth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,40 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD,
return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType);
}

CGPointerAuthInfo CodeGenModule::getMemberFunctionPointerAuthInfo(QualType FT) {
assert(FT->getAs<MemberPointerType>() && "MemberPointerType expected");
auto &Schema = getCodeGenOpts().PointerAuth.CXXMemberFunctionPointers;
if (!Schema)
return CGPointerAuthInfo();

assert(!Schema.isAddressDiscriminated() &&
"function pointers cannot use address-specific discrimination");

llvm::ConstantInt *Discriminator =
getPointerAuthOtherDiscriminator(Schema, GlobalDecl(), FT);
return CGPointerAuthInfo(Schema.getKey(), Schema.getAuthenticationMode(),
/* IsIsaPointer */ false,
/* AuthenticatesNullValues */ false, Discriminator);
}

llvm::Constant *CodeGenModule::getMemberFunctionPointer(llvm::Constant *Pointer,
QualType FT) {
if (CGPointerAuthInfo PointerAuth = getMemberFunctionPointerAuthInfo(FT))
return getConstantSignedPointer(
Pointer, PointerAuth.getKey(), nullptr,
cast_or_null<llvm::ConstantInt>(PointerAuth.getDiscriminator()));

return Pointer;
}

llvm::Constant *CodeGenModule::getMemberFunctionPointer(const FunctionDecl *FD,
llvm::Type *Ty) {
QualType FT = FD->getType();
FT = getContext().getMemberPointerType(
FT, cast<CXXMethodDecl>(FD)->getParent()->getTypeForDecl());
return getMemberFunctionPointer(getRawFunctionPointer(FD, Ty), FT);
}

std::optional<PointerAuthQualifier>
CodeGenModule::computeVTPointerAuthentication(const CXXRecordDecl *ThisClass) {
auto DefaultAuthentication = getCodeGenOpts().PointerAuth.CXXVTablePointers;
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4373,7 +4373,8 @@ class CodeGenFunction : public CodeGenTypeCache {
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
ReturnValueSlot ReturnValue, const CallArgList &Args,
llvm::CallBase **callOrInvoke, bool IsMustTail,
SourceLocation Loc);
SourceLocation Loc,
bool IsVirtualFunctionPointerThunk = false);
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
ReturnValueSlot ReturnValue, const CallArgList &Args,
llvm::CallBase **callOrInvoke = nullptr,
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,8 +973,16 @@ class CodeGenModule : public CodeGenTypeCache {
llvm::Constant *getFunctionPointer(llvm::Constant *Pointer,
QualType FunctionType);

llvm::Constant *getMemberFunctionPointer(const FunctionDecl *FD,
llvm::Type *Ty = nullptr);

llvm::Constant *getMemberFunctionPointer(llvm::Constant *Pointer,
QualType FT);

CGPointerAuthInfo getFunctionPointerAuthInfo(QualType T);

CGPointerAuthInfo getMemberFunctionPointerAuthInfo(QualType FT);

CGPointerAuthInfo getPointerAuthInfoForPointeeType(QualType type);

CGPointerAuthInfo getPointerAuthInfoForType(QualType type);
Expand Down
Loading

0 comments on commit 547ce3e

Please sign in to comment.