From dd684dca5216cad6a367bf96ba2bc777ec63e06e Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Wed, 25 Jun 2014 20:33:49 -0700 Subject: [PATCH 1/2] Add a NullCheckElimination pass This pass is not Rust-specific, in that all of its transformations are intended to be correct for arbitrary LLVM IR, but it targets idioms found in IR generated by `rustc`, e.g. heavy use of `inbounds` GEPs. --- include/llvm/InitializePasses.h | 3 + include/llvm/LinkAllPasses.h | 3 + include/llvm/Transforms/Scalar.h | 7 + lib/Transforms/Scalar/CMakeLists.txt | 1 + .../Scalar/NullCheckElimination.cpp | 273 ++++++++++++++++++ lib/Transforms/Scalar/Scalar.cpp | 1 + test/Transforms/NullCheckElimination/basic.ll | 165 +++++++++++ 7 files changed, 453 insertions(+) create mode 100644 lib/Transforms/Scalar/NullCheckElimination.cpp create mode 100644 test/Transforms/NullCheckElimination/basic.ll diff --git a/include/llvm/InitializePasses.h b/include/llvm/InitializePasses.h index 0c840f39f522..1e7090b4094f 100644 --- a/include/llvm/InitializePasses.h +++ b/include/llvm/InitializePasses.h @@ -275,6 +275,9 @@ void initializeBBVectorizePass(PassRegistry&); void initializeMachineFunctionPrinterPassPass(PassRegistry&); void initializeStackMapLivenessPass(PassRegistry&); void initializeLoadCombinePass(PassRegistry&); + +// Specific to the rust-lang llvm branch: +void initializeNullCheckEliminationPass(PassRegistry&); } #endif diff --git a/include/llvm/LinkAllPasses.h b/include/llvm/LinkAllPasses.h index b2309ffc2140..45a3256397e6 100644 --- a/include/llvm/LinkAllPasses.h +++ b/include/llvm/LinkAllPasses.h @@ -160,6 +160,9 @@ namespace { (void) llvm::createScalarizerPass(); (void) llvm::createSeparateConstOffsetFromGEPPass(); + // Specific to the rust-lang llvm branch: + (void) llvm::createNullCheckEliminationPass(); + (void)new llvm::IntervalPartition(); (void)new llvm::FindUsedTypes(); (void)new llvm::ScalarEvolution(); diff --git a/include/llvm/Transforms/Scalar.h 
b/include/llvm/Transforms/Scalar.h index 8ecfd801d0d8..4fe70c1d9ce1 100644 --- a/include/llvm/Transforms/Scalar.h +++ b/include/llvm/Transforms/Scalar.h @@ -388,6 +388,13 @@ FunctionPass *createSeparateConstOffsetFromGEPPass(); // BasicBlockPass *createLoadCombinePass(); +// Specific to the rust-lang llvm branch: +//===----------------------------------------------------------------------===// +// +// NullCheckElimination - Eliminate null checks. +// +FunctionPass *createNullCheckEliminationPass(); + } // End llvm namespace #endif diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index 2dcfa237ca33..08a789f848f7 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -22,6 +22,7 @@ add_llvm_library(LLVMScalarOpts LoopUnswitch.cpp LowerAtomic.cpp MemCpyOptimizer.cpp + NullCheckElimination.cpp PartiallyInlineLibCalls.cpp Reassociate.cpp Reg2Mem.cpp diff --git a/lib/Transforms/Scalar/NullCheckElimination.cpp b/lib/Transforms/Scalar/NullCheckElimination.cpp new file mode 100644 index 000000000000..1a921ccaaa86 --- /dev/null +++ b/lib/Transforms/Scalar/NullCheckElimination.cpp @@ -0,0 +1,273 @@ +//===-- NullCheckElimination.cpp - Null Check Elimination Pass ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Pass.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "null-check-elimination"
+
+namespace {
+  struct NullCheckElimination : public FunctionPass {
+    static char ID;
+    NullCheckElimination() : FunctionPass(ID) {
+      initializeNullCheckEliminationPass(*PassRegistry::getPassRegistry());
+    }
+    bool runOnFunction(Function &F) override;
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.setPreservesCFG();
+    }
+
+  private:
+    static const unsigned kPhiLimit = 16;
+    typedef SmallPtrSet<PHINode*, kPhiLimit> SmallPhiSet;
+    enum NullCheckResult {
+      NotNullCheck,
+      NullCheckEq,
+      NullCheckNe,
+    };
+
+    bool isNonNullOrPoisonPhi(SmallPhiSet *VisitedPhis, PHINode*);
+
+    NullCheckResult isCmpNullCheck(ICmpInst*);
+    std::pair<Use*, NullCheckResult> findNullCheck(Use*);
+
+    bool blockContainsLoadDerivedFrom(BasicBlock*, Value*);
+
+    DenseSet<Value*> NonNullOrPoisonValues;
+  };
+}
+
+char NullCheckElimination::ID = 0;
+INITIALIZE_PASS_BEGIN(NullCheckElimination,
+                      "null-check-elimination",
+                      "Null Check Elimination",
+                      false, false)
+INITIALIZE_PASS_END(NullCheckElimination,
+                    "null-check-elimination",
+                    "Null Check Elimination",
+                    false, false)
+
+FunctionPass *llvm::createNullCheckEliminationPass() {
+  return new NullCheckElimination();
+}
+
+bool NullCheckElimination::runOnFunction(Function &F) {
+  if (skipOptnoneFunction(F))
+    return false;
+
+  bool Changed = false;
+
+  // Collect arguments with the `nonnull` attribute.
+  for (auto &Arg : F.args()) {
+    if (Arg.hasNonNullAttr())
+      NonNullOrPoisonValues.insert(&Arg);
+  }
+
+  // Collect instructions that definitely produce nonnull-or-poison values.
+  // At the moment, this is restricted to inbounds GEPs. It would be slightly
+  // more difficult to include uses of values dominated by a null check, since
+  // then we would have to consider uses instead of mere values.
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+        if (GEP->isInBounds()) {
+          NonNullOrPoisonValues.insert(GEP);
+        }
+      }
+    }
+  }
+
+  // Find phis that are derived entirely from nonnull-or-poison values,
+  // including other phis that are themselves derived entirely from these
+  // values.
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      auto *PN = dyn_cast<PHINode>(&I);
+      if (!PN)
+        break;
+
+      SmallPhiSet VisitedPHIs;
+      if (isNonNullOrPoisonPhi(&VisitedPHIs, PN))
+        NonNullOrPoisonValues.insert(PN);
+    }
+  }
+
+  for (auto &BB : F) {
+    // This could also be extended to handle SwitchInst, but using a SwitchInst
+    // for a null check seems unlikely.
+    auto *BI = dyn_cast<BranchInst>(BB.getTerminator());
+    if (!BI || BI->isUnconditional())
+      continue;
+
+    // The first operand of a conditional branch is the condition.
+    auto result = findNullCheck(&BI->getOperandUse(0));
+    if (!result.first)
+      continue;
+    assert((result.second == NullCheckEq || result.second == NullCheckNe) &&
+           "valid null check kind expected if ICmpInst was found");
+
+    BasicBlock *NonNullBB;
+    if (result.second == NullCheckEq) {
+      // If the comparison instruction is checking for equality with null,
+      // then the pointer is nonnull on the `false` branch.
+      NonNullBB = BI->getSuccessor(1);
+    } else {
+      // Otherwise, if the comparison instruction is checking for inequality
+      // with null, the pointer is nonnull on the `true` branch.
+      NonNullBB = BI->getSuccessor(0);
+    }
+
+    Use *U = result.first;
+    ICmpInst *CI = cast<ICmpInst>(U->get());
+    unsigned nonConstantIndex;
+    if (isa<Constant>(CI->getOperand(0)))
+      nonConstantIndex = 1;
+    else
+      nonConstantIndex = 0;
+
+    // Due to the semantics of poison values in LLVM, we have to check that
+    // there is actually some externally visible side effect that is dependent
+    // on the poison value. Since poison values are otherwise treated as undef,
+    // and a load of undef is undefined behavior (which is externally visible),
+    // it suffices to look for a load of the nonnull-or-poison value.
+    //
+    // This could be extended to any block control-dependent on this branch of
+    // the null check, it's unclear if that will actually catch more cases in
+    // real code.
+    Value *PtrV = CI->getOperand(nonConstantIndex);
+    if (blockContainsLoadDerivedFrom(NonNullBB, PtrV)) {
+      Type *BoolTy = CI->getType();
+      U->set(ConstantInt::get(BoolTy, result.second == NullCheckNe));
+      Changed = true; // The IR was rewritten; report it to the pass manager.
+    }
+  }
+
+  NonNullOrPoisonValues.clear();
+
+  return Changed;
+}
+
+/// Checks whether a phi is derived from known nonnull-or-poison values,
+/// including other phis that are derived from the same. May return `false`
+/// conservatively in some cases, e.g. if exploring a large cycle of phis.
+bool
+NullCheckElimination::isNonNullOrPoisonPhi(SmallPhiSet *VisitedPhis,
+                                           PHINode *PN) {
+  // If we've already seen this phi, return `true`, even though it may not be
+  // nonnull, since some other operand in a cycle of phis may invalidate the
+  // optimistic assumption that the entire cycle is nonnull, including this phi.
+  if (!VisitedPhis->insert(PN))
+    return true;
+
+  // Use a sensible limit to avoid iterating over long chains of phis that are
+  // unlikely to be nonnull.
+  if (VisitedPhis->size() >= kPhiLimit)
+    return false;
+
+  unsigned numOperands = PN->getNumOperands();
+  for (unsigned i = 0; i < numOperands; ++i) {
+    Value *SrcValue = PN->getOperand(i);
+    if (NonNullOrPoisonValues.count(SrcValue)) {
+      continue;
+    } else if (auto *SrcPN = dyn_cast<PHINode>(SrcValue)) {
+      if (!isNonNullOrPoisonPhi(VisitedPhis, SrcPN))
+        return false;
+    } else {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+/// Determines whether an ICmpInst is a null check of a known nonnull-or-poison
+/// value.
+NullCheckElimination::NullCheckResult
+NullCheckElimination::isCmpNullCheck(ICmpInst *CI) {
+  if (!CI->isEquality())
+    return NotNullCheck;
+
+  unsigned constantIndex;
+  if (NonNullOrPoisonValues.count(CI->getOperand(0)))
+    constantIndex = 1;
+  else if (NonNullOrPoisonValues.count(CI->getOperand(1)))
+    constantIndex = 0;
+  else
+    return NotNullCheck;
+
+  auto *C = dyn_cast<Constant>(CI->getOperand(constantIndex));
+  if (!C || !C->isZeroValue())
+    return NotNullCheck;
+
+  return
+    CI->getPredicate() == llvm::CmpInst::ICMP_EQ ? NullCheckEq : NullCheckNe;
+}
+
+/// Finds the Use, if any, of an ICmpInst null check of a nonnull-or-poison
+/// value, looking through `or` and `and` instructions (first operand first).
+std::pair<Use*, NullCheckElimination::NullCheckResult>
+NullCheckElimination::findNullCheck(Use *U) {
+  auto *I = dyn_cast<Instruction>(U->get());
+  if (!I)
+    return std::make_pair(nullptr, NotNullCheck);
+
+  if (auto *CI = dyn_cast<ICmpInst>(I)) {
+    NullCheckResult result = isCmpNullCheck(CI);
+    if (result == NotNullCheck)
+      return std::make_pair(nullptr, NotNullCheck);
+    else
+      return std::make_pair(U, result);
+  }
+
+  unsigned opcode = I->getOpcode();
+  if (opcode == Instruction::Or || opcode == Instruction::And) {
+    auto result = findNullCheck(&I->getOperandUse(0));
+    if (result.second == NotNullCheck)
+      return findNullCheck(&I->getOperandUse(1));
+    else
+      return result;
+  }
+
+  return std::make_pair(nullptr, NotNullCheck);
+}
+
+/// Determines whether `BB` contains a load from `PtrV`, or any inbounds GEP
+/// derived from `PtrV`.
+bool
+NullCheckElimination::blockContainsLoadDerivedFrom(BasicBlock *BB,
+                                                   Value *PtrV) {
+  for (auto &I : *BB) {
+    auto *LI = dyn_cast<LoadInst>(&I);
+    if (!LI)
+      continue;
+
+    Value *V = LI->getPointerOperand();
+    while (NonNullOrPoisonValues.count(V)) {
+      if (V == PtrV)
+        return true;
+
+      auto *GEP = dyn_cast<GetElementPtrInst>(V);
+      if (!GEP)
+        break;
+
+      V = GEP->getOperand(0);
+    }
+  }
+
+  return false;
+}
+
diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp
index edf012d81171..f2aed1ea0f62 100644
--- a/lib/Transforms/Scalar/Scalar.cpp
+++ b/lib/Transforms/Scalar/Scalar.cpp
@@ -66,6 +66,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) {
   initializeTailCallElimPass(Registry);
   initializeSeparateConstOffsetFromGEPPass(Registry);
   initializeLoadCombinePass(Registry);
+  initializeNullCheckEliminationPass(Registry);
 }
 
 void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) {
diff --git a/test/Transforms/NullCheckElimination/basic.ll b/test/Transforms/NullCheckElimination/basic.ll
new file mode 100644
index 000000000000..f176fb32c3fd
--- /dev/null
+++ b/test/Transforms/NullCheckElimination/basic.ll
@@ -0,0 +1,165 @@
+; RUN: opt < %s -null-check-elimination -instsimplify -S | FileCheck %s
+
+define i64 @test_arg_simple(i64* nonnull %p) {
+entry:
+  br label %loop_body
+
+loop_body:
+  %p0 = phi i64* [ %p, %entry ], [ %p1, %match_else ]
+  %b0 = icmp eq i64* %p0, null
+  br i1 %b0, label %return, label %match_else
+
+; CHECK-LABEL: @test_arg_simple
+; CHECK-NOT: , null
+
+match_else:
+  %i0 = load i64* %p0
+  %p1 = getelementptr inbounds i64* %p0, i64 1
+  %b1 = icmp ugt i64 %i0, 10
+  br i1 %b1, label %return, label %loop_body
+
+return:
+  ret i64 0
+}
+
+define i64 @test_arg_simple_fail(i64* %p) {
+entry:
+  br label %loop_body
+
+loop_body:
+  %p0 = phi i64* [ %p, %entry ], [ %p1, %match_else ]
+  %b0 = icmp eq i64* %p0, null
+  br i1 %b0, label %return, label %match_else
+
+; CHECK-LABEL: @test_arg_simple_fail
+; CHECK: null
+
+match_else:
+  %i0 = load i64* %p0
+  %p1
= getelementptr inbounds i64* %p0, i64 1 + %b1 = icmp ugt i64 %i0, 10 + br i1 %b1, label %return, label %loop_body + +return: + ret i64 0 +} + +define i64 @test_inbounds_simple(i64* %p) { +entry: + %p0 = getelementptr inbounds i64* %p, i64 0 + br label %loop_body + +loop_body: + %p1 = phi i64* [ %p0, %entry ], [ %p2, %match_else ] + %b0 = icmp eq i64* %p1, null + br i1 %b0, label %return, label %match_else + +; CHECK-LABEL: @test_inbounds_simple +; CHECK-NOT: null + +match_else: + %i0 = load i64* %p1 + %p2 = getelementptr inbounds i64* %p1, i64 1 + %b1 = icmp ugt i64 %i0, 10 + br i1 %b1, label %return, label %loop_body + +return: + ret i64 0 +} + +define i64 @test_inbounds_simple_fail(i64* %p) { +entry: + %p0 = getelementptr i64* %p, i64 0 + br label %loop_body + +loop_body: + %p1 = phi i64* [ %p0, %entry ], [ %p2, %match_else ] + %b0 = icmp eq i64* %p1, null + br i1 %b0, label %return, label %match_else + +; CHECK-LABEL: @test_inbounds_simple_fail +; CHECK: null + +match_else: + %i0 = load i64* %p1 + %p2 = getelementptr inbounds i64* %p1, i64 1 + %b1 = icmp ugt i64 %i0, 10 + br i1 %b1, label %return, label %loop_body + +return: + ret i64 0 +} + +define i64 @test_inbounds_or(i64* %p, i64* %q) { +entry: + %p0 = getelementptr inbounds i64* %p, i64 0 + br label %loop_body + +loop_body: + %p1 = phi i64* [ %p0, %entry ], [ %p2, %match_else ] + %b0 = icmp eq i64* %p1, %q + %b1 = icmp eq i64* %p1, null + %b2 = or i1 %b0, %b1 + br i1 %b2, label %return, label %match_else + +; CHECK-LABEL: @test_inbounds_or +; CHECK-NOT: null + +match_else: + %i0 = load i64* %p1 + %p2 = getelementptr inbounds i64* %p1, i64 1 + %b3 = icmp ugt i64 %i0, 10 + br i1 %b3, label %return, label %loop_body + +return: + ret i64 0 +} + +define i64 @test_inbounds_and(i64* %p, i64* %q) { +entry: + %p0 = getelementptr inbounds i64* %p, i64 0 + br label %loop_body + +loop_body: + %p1 = phi i64* [ %p0, %entry ], [ %p2, %match_else ] + %b0 = icmp eq i64* %p1, %q + %b1 = icmp eq i64* %p1, null + %b2 = and i1 
%b0, %b1 + br i1 %b2, label %return, label %match_else + +; CHECK-LABEL: @test_inbounds_and +; CHECK-NOT: null + +match_else: + %i0 = load i64* %p1 + %p2 = getelementptr inbounds i64* %p1, i64 1 + %b3 = icmp ugt i64 %i0, 10 + br i1 %b3, label %return, label %loop_body + +return: + ret i64 0 +} + +define i64 @test_inbounds_derived_load(i64* %p) { +entry: + %p0 = getelementptr inbounds i64* %p, i64 0 + br label %loop_body + +loop_body: + %p1 = phi i64* [ %p0, %entry ], [ %p2, %match_else ] + %b0 = icmp eq i64* %p1, null + br i1 %b0, label %return, label %match_else + +; CHECK-LABEL: @test_inbounds_derived_load +; CHECK-NOT: null + +match_else: + %p2 = getelementptr inbounds i64* %p1, i64 1 + %i0 = load i64* %p2 + %b1 = icmp ugt i64 %i0, 10 + br i1 %b1, label %return, label %loop_body + +return: + ret i64 0 +} + From 0876764d07653e8f51d9b9ba943aab0920caf99d Mon Sep 17 00:00:00 2001 From: Cameron Zwarich Date: Thu, 26 Jun 2014 23:41:06 -0700 Subject: [PATCH 2/2] Add the NullCheckElimination pass to the default pass list Since the NullCheckElimination pass has a similar intent to the CorrelatedValuePropagation pass, I decided to run it right after both places where the latter runs. --- lib/Transforms/IPO/PassManagerBuilder.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index c20c717de5e7..9503edc72b9b 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -187,6 +187,8 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { MPM.add(createEarlyCSEPass()); // Catch trivial redundancies MPM.add(createJumpThreadingPass()); // Thread jumps. 
MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals + // Specific to the rust-lang llvm branch: + MPM.add(createNullCheckEliminationPass()); // Eliminate null checks MPM.add(createCFGSimplificationPass()); // Merge & remove BBs MPM.add(createInstructionCombiningPass()); // Combine silly seq's addExtensionsToPM(EP_Peephole, MPM); @@ -218,6 +220,8 @@ void PassManagerBuilder::populateModulePassManager(PassManagerBase &MPM) { addExtensionsToPM(EP_Peephole, MPM); MPM.add(createJumpThreadingPass()); // Thread jumps MPM.add(createCorrelatedValuePropagationPass()); + // Specific to the rust-lang llvm branch: + MPM.add(createNullCheckEliminationPass()); // Eliminate null checks MPM.add(createDeadStoreEliminationPass()); // Delete dead stores addExtensionsToPM(EP_ScalarOptimizerLate, MPM);