diff --git a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h index 1ec87505544f83..2a52d3ed1431e5 100644 --- a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h +++ b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h @@ -30,6 +30,7 @@ class FunctionCallee; class GlobalIFunc; class GlobalValue; class Constant; +class ConstantStruct; class Value; class Type; @@ -44,6 +45,14 @@ void appendToGlobalCtors(Module &M, Function *F, int Priority, void appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *Data = nullptr); +/// Apply 'Fn' to the list of global ctors of module M and replace contructor +/// record with the one returned by `Fn`. If `nullptr` was returned, the +/// corresponding constructor will be removed from the array. For details see +/// https://llvm.org/docs/LangRef.html#the-llvm-global-ctors-global-variable +using GlobalCtorTransformFn = llvm::function_ref; +void transformGlobalCtors(Module &M, const GlobalCtorTransformFn &Fn); +void transformGlobalDtors(Module &M, const GlobalCtorTransformFn &Fn); + /// Sets the KCFI type for the function. Used for compiler-generated functions /// that are indirectly called in instrumented code. void setKCFIType(Module &M, Function &F, StringRef MangledType); diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 122279160cc7e8..309cf8e70b979c 100644 --- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -79,6 +79,50 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *D appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data); } +static void transformGlobalArray(StringRef ArrayName, Module &M, + const GlobalCtorTransformFn &Fn) { + GlobalVariable *GVCtor = M.getNamedGlobal(ArrayName); + if (!GVCtor) + return; + + IRBuilder<> IRB(M.getContext()); + SmallVector CurrentCtors; + bool Changed = false; + StructType *EltTy = + cast(GVCtor->getValueType()->getArrayElementType()); + if (Constant *Init = GVCtor->getInitializer()) { + CurrentCtors.reserve(Init->getNumOperands()); + for (Value *OP : Init->operands()) { + Constant *C = cast(OP); + Constant *NewC = Fn(C); + Changed |= (!NewC || NewC != C); + if (NewC) + CurrentCtors.push_back(NewC); + } + } + if (!Changed) + return; + + GVCtor->eraseFromParent(); + + // Create a new initializer. + ArrayType *AT = ArrayType::get(EltTy, CurrentCtors.size()); + Constant *NewInit = ConstantArray::get(AT, CurrentCtors); + + // Create the new global variable and replace all uses of + // the old global variable with the new one. + (void)new GlobalVariable(M, NewInit->getType(), false, + GlobalValue::AppendingLinkage, NewInit, ArrayName); +} + +void llvm::transformGlobalCtors(Module &M, const GlobalCtorTransformFn &Fn) { + transformGlobalArray("llvm.global_ctors", M, Fn); +} + +void llvm::transformGlobalDtors(Module &M, const GlobalCtorTransformFn &Fn) { + transformGlobalArray("llvm.global_dtors", M, Fn); +} + static void collectUsedGlobals(GlobalVariable *GV, SmallSetVector &Init) { if (!GV || !GV->hasInitializer()) diff --git a/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp b/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp index 0ed7be9620a6f4..d4094c53070600 100644 --- a/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/ModuleUtilsTest.cpp @@ -70,17 +70,23 @@ TEST(ModuleUtils, AppendToUsedList2) { } using AppendFnType = decltype(&appendToGlobalCtors); -using ParamType = std::tuple; +using TransformFnType = decltype(&transformGlobalCtors); +using ParamType = std::tuple; class ModuleUtilsTest : public testing::TestWithParam { public: StringRef arrayName() const { return std::get<0>(GetParam()); } AppendFnType appendFn() const { return std::get(GetParam()); } + TransformFnType transformFn() const { + return std::get(GetParam()); + } }; INSTANTIATE_TEST_SUITE_P( ModuleUtilsTestCtors, ModuleUtilsTest, - ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors}, - ParamType{"llvm.global_dtors", &appendToGlobalDtors})); + ::testing::Values(ParamType{"llvm.global_ctors", &appendToGlobalCtors, + &transformGlobalCtors}, + ParamType{"llvm.global_dtors", &appendToGlobalDtors, + &transformGlobalDtors})); TEST_P(ModuleUtilsTest, AppendToMissingArray) { LLVMContext C; @@ -124,3 +130,39 @@ TEST_P(ModuleUtilsTest, AppendToArray) { 11, nullptr); EXPECT_EQ(3, getListSize(*M, arrayName())); } + +TEST_P(ModuleUtilsTest, UpdateArray) { + LLVMContext C; + + std::unique_ptr M = + parseIR(C, (R"(@)" + arrayName() + + R"( = appending global [2 x { i32, ptr, ptr }] [ + { i32, ptr, ptr } { i32 65535, ptr null, ptr null }, + { i32, ptr, ptr } { i32 0, ptr null, ptr null }] + )") + .str()); + + EXPECT_EQ(2, getListSize(*M, arrayName())); + transformFn()(*M, [](Constant *C) -> Constant * { + ConstantStruct *CS = dyn_cast(C); + if (!CS) + return nullptr; + StructType *EltTy = cast(C->getType()); + Constant *CSVals[3] = { + ConstantInt::getSigned(CS->getOperand(0)->getType(), 12), + CS->getOperand(1), + CS->getOperand(2), + }; + return ConstantStruct::get(EltTy, + ArrayRef(CSVals, EltTy->getNumElements())); + }); + EXPECT_EQ(1, getListSize(*M, arrayName())); + ConstantArray *CA = dyn_cast( + M->getGlobalVariable(arrayName())->getInitializer()); + ASSERT_NE(nullptr, CA); + ConstantStruct *CS = dyn_cast(CA->getOperand(0)); + ASSERT_NE(nullptr, CS); + ConstantInt *Pri = dyn_cast(CS->getOperand(0)); + ASSERT_NE(nullptr, Pri); + EXPECT_EQ(12u, Pri->getLimitedValue()); +}