Skip to content

Commit

Permalink
InferAddressSpaces: Factor replacement loop into function [NFC] (llvm…
Browse files Browse the repository at this point in the history
  • Loading branch information
arsenm authored Aug 20, 2024
1 parent c670cb4 commit 90a8e5a
Showing 1 changed file with 111 additions and 100 deletions.
211 changes: 111 additions & 100 deletions llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class InferAddressSpaces : public FunctionPass {

class InferAddressSpacesImpl {
AssumptionCache ∾
Function *F = nullptr;
const DominatorTree *DT = nullptr;
const TargetTransformInfo *TTI = nullptr;
const DataLayout *DL = nullptr;
Expand Down Expand Up @@ -212,14 +213,17 @@ class InferAddressSpacesImpl {
const PredicatedAddrSpaceMapTy &PredicatedAS,
SmallVectorImpl<const Use *> *PoisonUsesToFix) const;

void performPointerReplacement(
Value *V, Value *NewV, Use &U, ValueToValueMapTy &ValueWithNewAddrSpace,
SmallVectorImpl<Instruction *> &DeadInstructions) const;

// Changes the flat address expressions in function F to point to specific
// address spaces if InferredAddrSpace says so. Postorder is the postorder of
// all flat expressions in the use-def graph of function F.
bool
rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
const ValueToAddrSpaceMapTy &InferredAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS,
Function *F) const;
bool rewriteWithNewAddressSpaces(
ArrayRef<WeakTrackingVH> Postorder,
const ValueToAddrSpaceMapTy &InferredAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS) const;

void appendsFlatAddressExpressionToPostorderStack(
Value *V, PostorderStackTy &PostorderStack,
Expand Down Expand Up @@ -842,8 +846,9 @@ unsigned InferAddressSpacesImpl::joinAddressSpaces(unsigned AS1,
return (AS1 == AS2) ? AS1 : FlatAddrSpace;
}

bool InferAddressSpacesImpl::run(Function &F) {
DL = &F.getDataLayout();
bool InferAddressSpacesImpl::run(Function &CurFn) {
F = &CurFn;
DL = &F->getDataLayout();

if (AssumeDefaultIsFlatAddressSpace)
FlatAddrSpace = 0;
Expand All @@ -855,7 +860,7 @@ bool InferAddressSpacesImpl::run(Function &F) {
}

// Collects all flat address expressions in postorder.
std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F);
std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(*F);

// Runs a data-flow analysis to refine the address spaces of every expression
// in Postorder.
Expand All @@ -865,8 +870,8 @@ bool InferAddressSpacesImpl::run(Function &F) {

// Changes the address spaces of the flat address expressions who are inferred
// to point to a specific address space.
return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, PredicatedAS,
&F);
return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace,
PredicatedAS);
}

// Constants need to be tracked through RAUW to handle cases with nested
Expand Down Expand Up @@ -1168,10 +1173,103 @@ static Value::use_iterator skipToNextUser(Value::use_iterator I,
return I;
}

void InferAddressSpacesImpl::performPointerReplacement(
Value *V, Value *NewV, Use &U, ValueToValueMapTy &ValueWithNewAddrSpace,
SmallVectorImpl<Instruction *> &DeadInstructions) const {

User *CurUser = U.getUser();

unsigned AddrSpace = V->getType()->getPointerAddressSpace();
if (replaceIfSimplePointerUse(*TTI, CurUser, AddrSpace, V, NewV))
return;

// Skip if the current user is the new value itself.
if (CurUser == NewV)
return;

auto *CurUserI = dyn_cast<Instruction>(CurUser);
if (!CurUserI || CurUserI->getFunction() != F)
return;

// Handle more complex cases like intrinsic that need to be remangled.
if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
return;
}

if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
if (rewriteIntrinsicOperands(II, V, NewV))
return;
}

if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUserI)) {
// If we can infer that both pointers are in the same addrspace,
// transform e.g.
// %cmp = icmp eq float* %p, %q
// into
// %cmp = icmp eq float addrspace(3)* %new_p, %new_q

unsigned NewAS = NewV->getType()->getPointerAddressSpace();
int SrcIdx = U.getOperandNo();
int OtherIdx = (SrcIdx == 0) ? 1 : 0;
Value *OtherSrc = Cmp->getOperand(OtherIdx);

if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
Cmp->setOperand(OtherIdx, OtherNewV);
Cmp->setOperand(SrcIdx, NewV);
return;
}
}

// Even if the type mismatches, we can cast the constant.
if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
Cmp->setOperand(SrcIdx, NewV);
Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
KOtherSrc, NewV->getType()));
return;
}
}
}

if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUserI)) {
unsigned NewAS = NewV->getType()->getPointerAddressSpace();
if (ASC->getDestAddressSpace() == NewAS) {
ASC->replaceAllUsesWith(NewV);
DeadInstructions.push_back(ASC);
return;
}
}

// Otherwise, replaces the use with flat(NewV).
if (Instruction *VInst = dyn_cast<Instruction>(V)) {
// Don't create a copy of the original addrspacecast.
if (U == V && isa<AddrSpaceCastInst>(V))
return;

// Insert the addrspacecast after NewV.
BasicBlock::iterator InsertPos;
if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
InsertPos = std::next(NewVInst->getIterator());
else
InsertPos = std::next(VInst->getIterator());

while (isa<PHINode>(InsertPos))
++InsertPos;
// This instruction may contain multiple uses of V, update them all.
CurUser->replaceUsesOfWith(
V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
} else {
CurUserI->replaceUsesOfWith(
V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), V->getType()));
}
}

bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
ArrayRef<WeakTrackingVH> Postorder,
const ValueToAddrSpaceMapTy &InferredAddrSpace,
const PredicatedAddrSpaceMapTy &PredicatedAS, Function *F) const {
const PredicatedAddrSpaceMapTy &PredicatedAS) const {
// For each address expression to be modified, creates a clone of it with its
// pointer operands converted to the new address space. Since the pointer
// operands are converted, the clone is naturally in the new address space by
Expand Down Expand Up @@ -1262,100 +1360,13 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
Value::use_iterator I, E, Next;
for (I = V->use_begin(), E = V->use_end(); I != E;) {
Use &U = *I;
User *CurUser = U.getUser();

// Some users may see the same pointer operand in multiple operands. Skip
// to the next instruction.
I = skipToNextUser(I, E);

unsigned AddrSpace = V->getType()->getPointerAddressSpace();
if (replaceIfSimplePointerUse(*TTI, CurUser, AddrSpace, V, NewV))
continue;

// Skip if the current user is the new value itself.
if (CurUser == NewV)
continue;

if (auto *CurUserI = dyn_cast<Instruction>(CurUser);
CurUserI && CurUserI->getFunction() != F)
continue;

// Handle more complex cases like intrinsic that need to be remangled.
if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) {
if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV))
continue;
}

if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
if (rewriteIntrinsicOperands(II, V, NewV))
continue;
}

if (isa<Instruction>(CurUser)) {
if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) {
// If we can infer that both pointers are in the same addrspace,
// transform e.g.
// %cmp = icmp eq float* %p, %q
// into
// %cmp = icmp eq float addrspace(3)* %new_p, %new_q

unsigned NewAS = NewV->getType()->getPointerAddressSpace();
int SrcIdx = U.getOperandNo();
int OtherIdx = (SrcIdx == 0) ? 1 : 0;
Value *OtherSrc = Cmp->getOperand(OtherIdx);

if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
Cmp->setOperand(OtherIdx, OtherNewV);
Cmp->setOperand(SrcIdx, NewV);
continue;
}
}

// Even if the type mismatches, we can cast the constant.
if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) {
if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) {
Cmp->setOperand(SrcIdx, NewV);
Cmp->setOperand(OtherIdx, ConstantExpr::getAddrSpaceCast(
KOtherSrc, NewV->getType()));
continue;
}
}
}

if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
unsigned NewAS = NewV->getType()->getPointerAddressSpace();
if (ASC->getDestAddressSpace() == NewAS) {
ASC->replaceAllUsesWith(NewV);
DeadInstructions.push_back(ASC);
continue;
}
}

// Otherwise, replaces the use with flat(NewV).
if (Instruction *VInst = dyn_cast<Instruction>(V)) {
// Don't create a copy of the original addrspacecast.
if (U == V && isa<AddrSpaceCastInst>(V))
continue;

// Insert the addrspacecast after NewV.
BasicBlock::iterator InsertPos;
if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
InsertPos = std::next(NewVInst->getIterator());
else
InsertPos = std::next(VInst->getIterator());

while (isa<PHINode>(InsertPos))
++InsertPos;
// This instruction may contain multiple uses of V, update them all.
CurUser->replaceUsesOfWith(
V, new AddrSpaceCastInst(NewV, V->getType(), "", InsertPos));
} else {
CurUser->replaceUsesOfWith(
V, ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
V->getType()));
}
}
performPointerReplacement(V, NewV, U, ValueWithNewAddrSpace,
DeadInstructions);
}

if (V->use_empty()) {
Expand Down

0 comments on commit 90a8e5a

Please sign in to comment.