diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index 3e3e5bfe2d6332..566cdc51f6e74a 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -184,6 +184,7 @@ class InferAddressSpaces : public FunctionPass {
 
 class InferAddressSpacesImpl {
   AssumptionCache &AC;
+  Function *F = nullptr;
   const DominatorTree *DT = nullptr;
   const TargetTransformInfo *TTI = nullptr;
   const DataLayout *DL = nullptr;
@@ -212,14 +213,17 @@ class InferAddressSpacesImpl {
       const PredicatedAddrSpaceMapTy &PredicatedAS,
       SmallVectorImpl<const Use *> *PoisonUsesToFix) const;
 
+  void performPointerReplacement(
+      Value *V, Value *NewV, Use &U, ValueToValueMapTy &ValueWithNewAddrSpace,
+      SmallVectorImpl<Instruction *> &DeadInstructions) const;
+
   // Changes the flat address expressions in function F to point to specific
   // address spaces if InferredAddrSpace says so. Postorder is the postorder of
   // all flat expressions in the use-def graph of function F.
-  bool
-  rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
-                              const ValueToAddrSpaceMapTy &InferredAddrSpace,
-                              const PredicatedAddrSpaceMapTy &PredicatedAS,
-                              Function *F) const;
+  bool rewriteWithNewAddressSpaces(
+      ArrayRef<WeakTrackingVH> Postorder,
+      const ValueToAddrSpaceMapTy &InferredAddrSpace,
+      const PredicatedAddrSpaceMapTy &PredicatedAS) const;
 
   void appendsFlatAddressExpressionToPostorderStack(
       Value *V, PostorderStackTy &PostorderStack,
@@ -842,8 +846,9 @@ unsigned InferAddressSpacesImpl::joinAddressSpaces(unsigned AS1,
   return (AS1 == AS2) ? AS1 : FlatAddrSpace;
 }
 
-bool InferAddressSpacesImpl::run(Function &F) {
-  DL = &F.getDataLayout();
+bool InferAddressSpacesImpl::run(Function &CurFn) {
+  F = &CurFn;
+  DL = &F->getDataLayout();
 
   if (AssumeDefaultIsFlatAddressSpace)
     FlatAddrSpace = 0;
@@ -855,7 +860,7 @@ bool InferAddressSpacesImpl::run(Function &F) {
   }
 
   // Collects all flat address expressions in postorder.
-  std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F);
+  std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(*F);
 
   // Runs a data-flow analysis to refine the address spaces of every expression
   // in Postorder.
@@ -865,8 +870,8 @@ bool InferAddressSpacesImpl::run(Function &F) {
 
   // Changes the address spaces of the flat address expressions who are inferred
   // to point to a specific address space.
-  return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS,
-                                     &F);
+  return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace,
+                                     PredicatedAS);
 }
 
 // Constants need to be tracked through RAUW to handle cases with nested
@@ -1168,10 +1173,103 @@ static Value::use_iterator skipToNextUser(Value::use_iterator I,
   return I;
 }
 
+void InferAddressSpacesImpl::performPointerReplacement(
+    Value *V, Value *NewV, Use &U, ValueToValueMapTy &ValueWithNewAddrSpace,
+    SmallVectorImpl<Instruction *> &DeadInstructions) const {
+
+  User *CurUser = U.getUser();
+
+  unsigned AddrSpace = V->getType()->getPointerAddressSpace();
+  if (replaceIfSimplePointerUse(*TTI, CurUser, AddrSpace, V, NewV))
+    return;
+
+  // Skip if the current user is the new value itself.
+  if (CurUser == NewV)
+    return;
+
+  auto *CurUserI = dyn_cast<Instruction>(CurUser);
+  if (!CurUserI || CurUserI->getFunction() != F)
+    return;
+
+  // Handle more complex cases like intrinsic that need to be remangled.
+  if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
+    if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
+      return;
+  }
+
+  if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+    if (rewriteIntrinsicOperands(II, V, NewV))
+      return;
+  }
+
+  if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUserI)) {
+    // If we can infer that both pointers are in the same addrspace,
+    // transform e.g.
+    //   %cmp = icmp eq float* %p, %q
+    // into
+    //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
+
+    unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+    int SrcIdx = U.getOperandNo();
+    int OtherIdx = (SrcIdx == 0) ? 1 : 0;
+    Value *OtherSrc = Cmp->getOperand(OtherIdx);
+
+    if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
+      if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
+        Cmp->setOperand(OtherIdx, OtherNewV);
+        Cmp->setOperand(SrcIdx, NewV);
+        return;
+      }
+    }
+
+    // Even if the type mismatches, we can cast the constant.
+    if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
+      if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
+        Cmp->setOperand(SrcIdx, NewV);
+        Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
+                                      KOtherSrc, NewV->getType()));
+        return;
+      }
+    }
+  }
+
+  if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUserI)) {
+    unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+    if (ASC->getDestAddressSpace() == NewAS) {
+      ASC->replaceAllUsesWith(NewV);
+      DeadInstructions.push_back(ASC);
+      return;
+    }
+  }
+
+  // Otherwise, replaces the use with flat(NewV).
+  if (Instruction *VInst = dyn_cast<Instruction>(V)) {
+    // Don't create a copy of the original addrspacecast.
+    if (U == V && isa<AddrSpaceCastInst>(V))
+      return;
+
+    // Insert the addrspacecast after NewV.
+    BasicBlock::iterator InsertPos;
+    if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
+      InsertPos = std::next(NewVInst->getIterator());
+    else
+      InsertPos = std::next(VInst->getIterator());
+
+    while (isa<PHINode>(InsertPos))
+      ++InsertPos;
+    // This instruction may contain multiple uses of V, update them all.
+    CurUser->replaceUsesOfWith(
+        V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
+  } else {
+    CurUserI->replaceUsesOfWith(
+        V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), V->getType()));
+  }
+}
+
 bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
     ArrayRef<WeakTrackingVH> Postorder,
     const ValueToAddrSpaceMapTy &InferredAddrSpace,
-    const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const {
+    const PredicatedAddrSpaceMapTy &PredicatedAS) const {
   // For each address expression to be modified, creates a clone of it with its
   // pointer operands converted to the new address space. Since the pointer
   // operands are converted, the clone is naturally in the new address space by
@@ -1262,100 +1360,13 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
     Value::use_iterator I, E, Next;
     for (I = V->use_begin(), E = V->use_end(); I != E;) {
       Use &U = *I;
-      User *CurUser = U.getUser();
 
       // Some users may see the same pointer operand in multiple operands. Skip
       // to the next instruction.
       I = skipToNextUser(I, E);
 
-      unsigned AddrSpace = V->getType()->getPointerAddressSpace();
-      if (replaceIfSimplePointerUse(*TTI, CurUser, AddrSpace, V, NewV))
-        continue;
-
-      // Skip if the current user is the new value itself.
-      if (CurUser == NewV)
-        continue;
-
-      if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
-          CurUserI && CurUserI->getFunction() != F)
-        continue;
-
-      // Handle more complex cases like intrinsic that need to be remangled.
-      if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
-        if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
-          continue;
-      }
-
-      if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
-        if (rewriteIntrinsicOperands(II, V, NewV))
-          continue;
-      }
-
-      if (isa<Instruction>(CurUser)) {
-        if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
-          // If we can infer that both pointers are in the same addrspace,
-          // transform e.g.
-          //   %cmp = icmp eq float* %p, %q
-          // into
-          //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
-
-          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-          int SrcIdx = U.getOperandNo();
-          int OtherIdx = (SrcIdx == 0) ? 1 : 0;
-          Value *OtherSrc = Cmp->getOperand(OtherIdx);
-
-          if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
-            if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
-              Cmp->setOperand(OtherIdx, OtherNewV);
-              Cmp->setOperand(SrcIdx, NewV);
-              continue;
-            }
-          }
-
-          // Even if the type mismatches, we can cast the constant.
-          if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
-            if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
-              Cmp->setOperand(SrcIdx, NewV);
-              Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
-                                            KOtherSrc, NewV->getType()));
-              continue;
-            }
-          }
-        }
-
-        if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
-          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-          if (ASC->getDestAddressSpace() == NewAS) {
-            ASC->replaceAllUsesWith(NewV);
-            DeadInstructions.push_back(ASC);
-            continue;
-          }
-        }
-
-        // Otherwise, replaces the use with flat(NewV).
-        if (Instruction *VInst = dyn_cast<Instruction>(V)) {
-          // Don't create a copy of the original addrspacecast.
-          if (U == V && isa<AddrSpaceCastInst>(V))
-            continue;
-
-          // Insert the addrspacecast after NewV.
-          BasicBlock::iterator InsertPos;
-          if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
-            InsertPos = std::next(NewVInst->getIterator());
-          else
-            InsertPos = std::next(VInst->getIterator());
-
-          while (isa<PHINode>(InsertPos))
-            ++InsertPos;
-          // This instruction may contain multiple uses of V, update them all.
-          CurUser->replaceUsesOfWith(
-              V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
-        } else {
-          CurUser->replaceUsesOfWith(
-              V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
-                                                V->getType()));
-        }
-      }
+      performPointerReplacement(V, NewV, U, ValueWithNewAddrSpace,
+                                DeadInstructions);
     }
 
     if (V->use_empty()) {
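
Note on the patch: the mechanical change is that the per-use loop body of rewriteWithNewAddressSpaces moves into the new performPointerReplacement helper, with each early `continue` becoming a `return`; the caller still advances the use iterator via skipToNextUser before delegating. Below is a minimal, self-contained C++ sketch of that extract-helper pattern using hypothetical names (processOne, Inputs), not LLVM code:

// Hypothetical illustration of the refactoring pattern applied in this patch:
// a loop body full of early `continue`s is extracted into a helper where each
// `continue` becomes a `return`.
#include <iostream>
#include <vector>

// Before the refactor, these checks lived inline in the caller's loop and
// bailed out with `continue`; after it, each bail-out is a plain `return`.
static void processOne(int Value, std::vector<int> &Out) {
  if (Value < 0)
    return; // was: continue;
  if (Value % 2 != 0)
    return; // was: continue;
  Out.push_back(Value * Value);
}

int main() {
  const std::vector<int> Inputs = {-3, 2, 5, 8};
  std::vector<int> Out;

  // The caller's loop now only iterates and delegates, mirroring how
  // rewriteWithNewAddressSpaces calls performPointerReplacement per use.
  for (int V : Inputs)
    processOne(V, Out);

  for (int V : Out)
    std::cout << V << '\n'; // prints 4 and 64
  return 0;
}

The transformation is behavior-preserving because every `continue` in the old loop body corresponds to a `return` in the helper, and the iterator advance happens in the caller before the helper is invoked in either version.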