Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PAC] Implement authentication for C++ member function pointers #99576

Merged
merged 4 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD);

/// Return the "other" type-specific discriminator for the given type.
uint16_t getPointerAuthTypeDiscriminator(QualType T) const;
uint16_t getPointerAuthTypeDiscriminator(QualType T);

/// Apply Objective-C protocol qualifiers to the given type.
/// \param allowOnPointerType specifies if we can apply protocol
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/PointerAuthOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ struct PointerAuthOptions {

/// The ABI for variadic C++ virtual function pointers.
PointerAuthSchema CXXVirtualVariadicFunctionPointers;

/// The ABI for C++ member function pointers.
PointerAuthSchema CXXMemberFunctionPointers;
};

} // end namespace clang
Expand Down
12 changes: 7 additions & 5 deletions clang/lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3407,7 +3407,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
}
}

uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) {
assert(!T->isDependentType() &&
"cannot compute type discriminator of a dependent type");

Expand All @@ -3417,11 +3417,13 @@ uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
if (T->isFunctionPointerType() || T->isFunctionReferenceType())
T = T->getPointeeType();

if (T->isFunctionType())
if (T->isFunctionType()) {
encodeTypeForFunctionPointerAuth(*this, Out, T);
else
llvm_unreachable(
"type discrimination of non-function type not implemented yet");
} else {
ojhunt marked this conversation as resolved.
Show resolved Hide resolved
T = T.getUnqualifiedType();
std::unique_ptr<MangleContext> MC(createMangleContext());
MC->mangleCanonicalTypeName(T, Out);
}

return llvm::getPointerAuthStableSipHash(Str);
}
Expand Down
213 changes: 115 additions & 98 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5034,7 +5034,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
ReturnValueSlot ReturnValue,
const CallArgList &CallArgs,
llvm::CallBase **callOrInvoke, bool IsMustTail,
SourceLocation Loc) {
SourceLocation Loc,
bool IsVirtualFunctionPointerThunk) {
// FIXME: We no longer need the types from CallArgs; lift up and simplify.

assert(Callee.isOrdinary() || Callee.isVirtual());
Expand Down Expand Up @@ -5098,7 +5099,11 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
RawAddress SRetAlloca = RawAddress::invalid();
llvm::Value *UnusedReturnSizePtr = nullptr;
if (RetAI.isIndirect() || RetAI.isInAlloca() || RetAI.isCoerceAndExpand()) {
if (!ReturnValue.isNull()) {
if (IsVirtualFunctionPointerThunk && RetAI.isIndirect()) {
SRetPtr = makeNaturalAddressForPointer(CurFn->arg_begin() +
IRFunctionArgs.getSRetArgNo(),
RetTy, CharUnits::fromQuantity(1));
} else if (!ReturnValue.isNull()) {
SRetPtr = ReturnValue.getAddress();
} else {
SRetPtr = CreateMemTemp(RetTy, "tmp", &SRetAlloca);
Expand Down Expand Up @@ -5877,119 +5882,131 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
CallArgs.freeArgumentMemory(*this);

// Extract the return value.
RValue Ret = [&] {
switch (RetAI.getKind()) {
case ABIArgInfo::CoerceAndExpand: {
auto coercionType = RetAI.getCoerceAndExpandType();

Address addr = SRetPtr.withElementType(coercionType);

assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
bool requiresExtract = isa<llvm::StructType>(CI->getType());
RValue Ret;

unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
llvm::Type *eltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
Address eltAddr = Builder.CreateStructGEP(addr, i);
llvm::Value *elt = CI;
if (requiresExtract)
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
else
assert(unpaddedIndex == 0);
Builder.CreateStore(elt, eltAddr);
// If the current function is a virtual function pointer thunk, avoid copying
// the return value of the musttail call to a temporary.
if (IsVirtualFunctionPointerThunk) {
Ret = RValue::get(CI);
} else {
Ret = [&] {
switch (RetAI.getKind()) {
case ABIArgInfo::CoerceAndExpand: {
auto coercionType = RetAI.getCoerceAndExpandType();

Address addr = SRetPtr.withElementType(coercionType);

assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
bool requiresExtract = isa<llvm::StructType>(CI->getType());

unsigned unpaddedIndex = 0;
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
llvm::Type *eltType = coercionType->getElementType(i);
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
continue;
Address eltAddr = Builder.CreateStructGEP(addr, i);
llvm::Value *elt = CI;
if (requiresExtract)
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
else
assert(unpaddedIndex == 0);
Builder.CreateStore(elt, eltAddr);
}
[[fallthrough]];
}
[[fallthrough]];
}

case ABIArgInfo::InAlloca:
case ABIArgInfo::Indirect: {
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
if (UnusedReturnSizePtr)
PopCleanupBlock();
return ret;
}

case ABIArgInfo::Ignore:
// If we are ignoring an argument that had a result, make sure to
// construct the appropriate return value for our caller.
return GetUndefRValue(RetTy);
case ABIArgInfo::InAlloca:
case ABIArgInfo::Indirect: {
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
if (UnusedReturnSizePtr)
PopCleanupBlock();
return ret;
}

case ABIArgInfo::Extend:
case ABIArgInfo::Direct: {
llvm::Type *RetIRTy = ConvertType(RetTy);
if (RetAI.getCoerceToType() == RetIRTy && RetAI.getDirectOffset() == 0) {
switch (getEvaluationKind(RetTy)) {
case TEK_Complex: {
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
return RValue::getComplex(std::make_pair(Real, Imag));
}
case TEK_Aggregate: {
Address DestPtr = ReturnValue.getAddress();
bool DestIsVolatile = ReturnValue.isVolatile();
case ABIArgInfo::Ignore:
// If we are ignoring an argument that had a result, make sure to
// construct the appropriate return value for our caller.
return GetUndefRValue(RetTy);

case ABIArgInfo::Extend:
case ABIArgInfo::Direct: {
llvm::Type *RetIRTy = ConvertType(RetTy);
if (RetAI.getCoerceToType() == RetIRTy &&
RetAI.getDirectOffset() == 0) {
switch (getEvaluationKind(RetTy)) {
case TEK_Complex: {
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
return RValue::getComplex(std::make_pair(Real, Imag));
}
case TEK_Aggregate: {
Address DestPtr = ReturnValue.getAddress();
bool DestIsVolatile = ReturnValue.isVolatile();

if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
DestIsVolatile = false;
if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
DestIsVolatile = false;
}
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
return RValue::getAggregate(DestPtr);
}
case TEK_Scalar: {
// If the argument doesn't match, perform a bitcast to coerce it.
// This can happen due to trivial type mismatches.
llvm::Value *V = CI;
if (V->getType() != RetIRTy)
V = Builder.CreateBitCast(V, RetIRTy);
return RValue::get(V);
}
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
return RValue::getAggregate(DestPtr);
}
llvm_unreachable("bad evaluation kind");
}
case TEK_Scalar: {
// If the argument doesn't match, perform a bitcast to coerce it. This
// can happen due to trivial type mismatches.

// If coercing a fixed vector from a scalable vector for ABI
// compatibility, and the types match, use the llvm.vector.extract
// intrinsic to perform the conversion.
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
llvm::Value *V = CI;
if (V->getType() != RetIRTy)
V = Builder.CreateBitCast(V, RetIRTy);
return RValue::get(V);
}
if (auto *ScalableSrcTy =
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
if (FixedDstTy->getElementType() ==
ScalableSrcTy->getElementType()) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
V = Builder.CreateExtractVector(FixedDstTy, V, Zero,
"cast.fixed");
return RValue::get(V);
}
}
}
llvm_unreachable("bad evaluation kind");
}

// If coercing a fixed vector from a scalable vector for ABI
// compatibility, and the types match, use the llvm.vector.extract
// intrinsic to perform the conversion.
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
llvm::Value *V = CI;
if (auto *ScalableSrcTy =
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
if (FixedDstTy->getElementType() == ScalableSrcTy->getElementType()) {
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
V = Builder.CreateExtractVector(FixedDstTy, V, Zero, "cast.fixed");
return RValue::get(V);
}
Address DestPtr = ReturnValue.getValue();
bool DestIsVolatile = ReturnValue.isVolatile();

if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "coerce");
DestIsVolatile = false;
}
}

Address DestPtr = ReturnValue.getValue();
bool DestIsVolatile = ReturnValue.isVolatile();
// An empty record can overlap other data (if declared with
// no_unique_address); omit the store for such types - as there is no
// actual data to store.
if (!isEmptyRecord(getContext(), RetTy, true)) {
// If the value is offset in memory, apply the offset now.
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
}

if (!DestPtr.isValid()) {
DestPtr = CreateMemTemp(RetTy, "coerce");
DestIsVolatile = false;
return convertTempToRValue(DestPtr, RetTy, SourceLocation());
}

// An empty record can overlap other data (if declared with
// no_unique_address); omit the store for such types - as there is no
// actual data to store.
if (!isEmptyRecord(getContext(), RetTy, true)) {
// If the value is offset in memory, apply the offset now.
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
case ABIArgInfo::Expand:
case ABIArgInfo::IndirectAliased:
llvm_unreachable("Invalid ABI kind for return argument");
}

return convertTempToRValue(DestPtr, RetTy, SourceLocation());
}

case ABIArgInfo::Expand:
case ABIArgInfo::IndirectAliased:
llvm_unreachable("Invalid ABI kind for return argument");
}

llvm_unreachable("Unhandled ABIArgInfo::Kind");
} ();
llvm_unreachable("Unhandled ABIArgInfo::Kind");
}();
}

// Emit the assume_aligned check on the return value.
if (Ret.isScalar() && TargetDecl) {
Expand Down
34 changes: 34 additions & 0 deletions clang/lib/CodeGen/CGPointerAuth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,40 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD,
return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType);
}

CGPointerAuthInfo CodeGenModule::getMemberFunctionPointerAuthInfo(QualType FT) {
assert(FT->getAs<MemberPointerType>() && "MemberPointerType expected");
const auto &Schema = getCodeGenOpts().PointerAuth.CXXMemberFunctionPointers;
if (!Schema)
return CGPointerAuthInfo();

assert(!Schema.isAddressDiscriminated() &&
"function pointers cannot use address-specific discrimination");

llvm::ConstantInt *Discriminator =
getPointerAuthOtherDiscriminator(Schema, GlobalDecl(), FT);
return CGPointerAuthInfo(Schema.getKey(), Schema.getAuthenticationMode(),
/* IsIsaPointer */ false,
/* AuthenticatesNullValues */ false, Discriminator);
}

llvm::Constant *CodeGenModule::getMemberFunctionPointer(llvm::Constant *Pointer,
QualType FT) {
if (CGPointerAuthInfo PointerAuth = getMemberFunctionPointerAuthInfo(FT))
return getConstantSignedPointer(
Pointer, PointerAuth.getKey(), nullptr,
cast_or_null<llvm::ConstantInt>(PointerAuth.getDiscriminator()));

return Pointer;
}

llvm::Constant *CodeGenModule::getMemberFunctionPointer(const FunctionDecl *FD,
llvm::Type *Ty) {
QualType FT = FD->getType();
FT = getContext().getMemberPointerType(
FT, cast<CXXMethodDecl>(FD)->getParent()->getTypeForDecl());
return getMemberFunctionPointer(getRawFunctionPointer(FD, Ty), FT);
}

std::optional<PointerAuthQualifier>
CodeGenModule::computeVTPointerAuthentication(const CXXRecordDecl *ThisClass) {
auto DefaultAuthentication = getCodeGenOpts().PointerAuth.CXXVTablePointers;
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -4374,7 +4374,8 @@ class CodeGenFunction : public CodeGenTypeCache {
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
ReturnValueSlot ReturnValue, const CallArgList &Args,
llvm::CallBase **callOrInvoke, bool IsMustTail,
SourceLocation Loc);
SourceLocation Loc,
bool IsVirtualFunctionPointerThunk = false);
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
ReturnValueSlot ReturnValue, const CallArgList &Args,
llvm::CallBase **callOrInvoke = nullptr,
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -973,8 +973,16 @@ class CodeGenModule : public CodeGenTypeCache {
llvm::Constant *getFunctionPointer(llvm::Constant *Pointer,
QualType FunctionType);

llvm::Constant *getMemberFunctionPointer(const FunctionDecl *FD,
llvm::Type *Ty = nullptr);

llvm::Constant *getMemberFunctionPointer(llvm::Constant *Pointer,
QualType FT);

CGPointerAuthInfo getFunctionPointerAuthInfo(QualType T);

CGPointerAuthInfo getMemberFunctionPointerAuthInfo(QualType FT);

CGPointerAuthInfo getPointerAuthInfoForPointeeType(QualType type);

CGPointerAuthInfo getPointerAuthInfoForType(QualType type);
Expand Down
Loading
Loading