Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply initializes attribute to DSE #113630

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 209 additions & 42 deletions llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRangeList.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
Expand Down Expand Up @@ -164,6 +165,11 @@ static cl::opt<bool>
OptimizeMemorySSA("dse-optimize-memoryssa", cl::init(true), cl::Hidden,
cl::desc("Allow DSE to optimize memory accesses."));

// TODO: turn on and remove this flag.
static cl::opt<bool> EnableInitializesImprovement(
"enable-dse-initializes-attr-improvement", cl::init(false), cl::Hidden,
cl::desc("Enable the initializes attr improvement in DSE"));

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -809,8 +815,10 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
// defined by `MemDef`.
struct MemoryLocationWrapper {
MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
: MemLoc(MemLoc), MemDef(MemDef) {
MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef,
bool DefByInitializesAttr)
: MemLoc(MemLoc), MemDef(MemDef),
DefByInitializesAttr(DefByInitializesAttr) {
assert(MemLoc.Ptr && "MemLoc should be not null");
UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
DefInst = MemDef->getMemoryInst();
Expand All @@ -820,20 +828,59 @@ struct MemoryLocationWrapper {
const Value *UnderlyingObject;
MemoryDef *MemDef;
Instruction *DefInst;
bool DefByInitializesAttr = false;
};

// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
// defined by this MemoryDef.
struct MemoryDefWrapper {
MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
MemoryDefWrapper(MemoryDef *MemDef,
ArrayRef<std::pair<MemoryLocation, bool>> MemLocations) {
DefInst = MemDef->getMemoryInst();
if (MemLoc.has_value())
DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
for (auto &[MemLoc, DefByInitializesAttr] : MemLocations)
DefinedLocations.push_back(
MemoryLocationWrapper(MemLoc, MemDef, DefByInitializesAttr));
}
Instruction *DefInst;
std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
};

bool hasInitializesAttr(Instruction *I) {
CallBase *CB = dyn_cast<CallBase>(I);
return CB && CB->getArgOperandWithAttribute(Attribute::Initializes);
}

struct ArgumentInitInfo {
unsigned Idx;
bool IsDeadOrInvisibleOnUnwind;
ConstantRangeList Inits;
};

// Return the intersected range list of the initializes attributes of "Args".
// "Args" are call arguments that alias to each other.
// If any argument in "Args" doesn't have dead_on_unwind attr and
// "CallHasNoUnwindAttr" is false, return empty.
ConstantRangeList getIntersectedInitRangeList(ArrayRef<ArgumentInitInfo> Args,
bool CallHasNoUnwindAttr) {
if (Args.empty())
return {};

// To address unwind, the function should have nounwind attribute or the
// arguments have dead or invisible on unwind. Otherwise, return empty.
for (const auto &Arg : Args) {
if (!CallHasNoUnwindAttr && !Arg.IsDeadOrInvisibleOnUnwind)
return {};
if (Arg.Inits.empty())
return {};
}

ConstantRangeList IntersectedIntervals = Args.front().Inits;
for (auto &Arg : Args.drop_front())
IntersectedIntervals = IntersectedIntervals.intersectWith(Arg.Inits);

return IntersectedIntervals;
}

struct DSEState {
Function &F;
AliasAnalysis &AA;
Expand Down Expand Up @@ -911,7 +958,8 @@ struct DSEState {

auto *MD = dyn_cast_or_null<MemoryDef>(MA);
if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
(getLocForWrite(&I) || isMemTerminatorInst(&I)))
(getLocForWrite(&I) || isMemTerminatorInst(&I) ||
(EnableInitializesImprovement && hasInitializesAttr(&I))))
MemDefs.push_back(MD);
}
}
Expand Down Expand Up @@ -1147,13 +1195,26 @@ struct DSEState {
return MemoryLocation::getOrNone(I);
}

std::optional<MemoryLocation> getLocForInst(Instruction *I) {
// Returns a list of <MemoryLocation, bool> pairs written by I.
// The bool means whether the write is from Initializes attr.
SmallVector<std::pair<MemoryLocation, bool>, 1>
getLocForInst(Instruction *I, bool ConsiderInitializesAttr) {
SmallVector<std::pair<MemoryLocation, bool>, 1> Locations;
if (isMemTerminatorInst(I)) {
if (auto Loc = getLocForTerminator(I)) {
return Loc->first;
if (auto Loc = getLocForTerminator(I))
Locations.push_back(std::make_pair(Loc->first, false));
return Locations;
}

if (auto Loc = getLocForWrite(I))
Locations.push_back(std::make_pair(*Loc, false));

if (ConsiderInitializesAttr) {
for (auto &MemLoc : getInitializesArgMemLoc(I)) {
Locations.push_back(std::make_pair(MemLoc, true));
}
}
return getLocForWrite(I);
return Locations;
}

/// Assuming this instruction has a dead analyzable write, can we delete
Expand Down Expand Up @@ -1365,7 +1426,8 @@ struct DSEState {
getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
const MemoryLocation &KillingLoc, const Value *KillingUndObj,
unsigned &ScanLimit, unsigned &WalkerStepLimit,
bool IsMemTerm, unsigned &PartialLimit) {
bool IsMemTerm, unsigned &PartialLimit,
bool IsInitializesAttrMemLoc) {
if (ScanLimit == 0 || WalkerStepLimit == 0) {
LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n");
return std::nullopt;
Expand Down Expand Up @@ -1602,7 +1664,16 @@ struct DSEState {

// Uses which may read the original MemoryDef mean we cannot eliminate the
// original MD. Stop walk.
if (isReadClobber(MaybeDeadLoc, UseInst)) {
// If KillingDef is a CallInst with "initializes" attribute, the reads in
// the callee would be dominated by initializations, so it should be safe.
bool IsKillingDefFromInitAttr = false;
if (IsInitializesAttrMemLoc) {
if (KillingI == UseInst &&
KillingUndObj == getUnderlyingObject(MaybeDeadLoc.Ptr))
IsKillingDefFromInitAttr = true;
}

if (isReadClobber(MaybeDeadLoc, UseInst) && !IsKillingDefFromInitAttr) {
LLVM_DEBUG(dbgs() << " ... found read clobber\n");
return std::nullopt;
}
Expand Down Expand Up @@ -2171,6 +2242,16 @@ struct DSEState {
return MadeChange;
}

// Return the locations written by the initializes attribute.
// Note that this function considers:
// 1. Unwind edge: use "initializes" attribute only if the callee has
// "nounwind" attribute, or the argument has "dead_on_unwind" attribute,
// or the argument is invisible to caller on unwind. That is, we don't
// perform incorrect DSE on unwind edges in the current function.
// 2. Argument alias: for aliasing arguments, the "initializes" attribute is
// the intersected range list of their "initializes" attributes.
SmallVector<MemoryLocation, 1> getInitializesArgMemLoc(const Instruction *I);

// Try to eliminate dead defs that access `KillingLocWrapper.MemLoc` and are
// killed by `KillingLocWrapper.MemDef`. Return whether
// any changes were made, and whether `KillingLocWrapper.DefInst` was deleted.
Expand All @@ -2182,6 +2263,75 @@ struct DSEState {
bool eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper);
};

SmallVector<MemoryLocation, 1>
DSEState::getInitializesArgMemLoc(const Instruction *I) {
const CallBase *CB = dyn_cast<CallBase>(I);
if (!CB)
return {};

// Collect aliasing arguments and their initializes ranges.
SmallMapVector<Value *, SmallVector<ArgumentInitInfo, 2>, 2> Arguments;
for (unsigned Idx = 0, Count = CB->arg_size(); Idx < Count; ++Idx) {
ConstantRangeList Inits;
Attribute InitializesAttr = CB->getParamAttr(Idx, Attribute::Initializes);
if (InitializesAttr.isValid())
Inits = InitializesAttr.getValueAsConstantRangeList();

Value *CurArg = CB->getArgOperand(Idx);
// We don't perform incorrect DSE on unwind edges in the current function,
// and use the "initializes" attribute to kill dead stores if:
// - The call does not throw exceptions, "CB->doesNotThrow()".
// - Or the callee parameter has "dead_on_unwind" attribute.
// - Or the argument is invisible to caller on unwind, and there are no
// unwind edges from this call in the current function (e.g. `CallInst`).
bool IsDeadOrInvisibleOnUnwind =
CB->paramHasAttr(Idx, Attribute::DeadOnUnwind) ||
(isa<CallInst>(CB) && isInvisibleToCallerOnUnwind(CurArg));
ArgumentInitInfo InitInfo{Idx, IsDeadOrInvisibleOnUnwind, Inits};
bool FoundAliasing = false;
for (auto &[Arg, AliasList] : Arguments) {
auto AAR = BatchAA.alias(MemoryLocation::getBeforeOrAfter(Arg),
MemoryLocation::getBeforeOrAfter(CurArg));
if (AAR == AliasResult::NoAlias) {
continue;
} else if (AAR == AliasResult::MustAlias) {
FoundAliasing = true;
AliasList.push_back(InitInfo);
} else {
// For PartialAlias and MayAlias, there is an offset or may be an
// unknown offset between the arguments and we insert an empty init
// range to discard the entire initializes info while intersecting.
FoundAliasing = true;
AliasList.push_back(ArgumentInitInfo{Idx, IsDeadOrInvisibleOnUnwind,
ConstantRangeList()});
}
}
if (!FoundAliasing)
Arguments[CurArg] = {InitInfo};
}

SmallVector<MemoryLocation, 1> Locations;
for (const auto &[_, Args] : Arguments) {
auto IntersectedRanges =
getIntersectedInitRangeList(Args, CB->doesNotThrow());
if (IntersectedRanges.empty())
continue;

for (const auto &Arg : Args) {
for (const auto &Range : IntersectedRanges) {
int64_t Start = Range.getLower().getSExtValue();
int64_t End = Range.getUpper().getSExtValue();
// For now, we only handle locations starting at offset 0.
if (Start == 0)
Locations.push_back(MemoryLocation(CB->getArgOperand(Arg.Idx),
LocationSize::precise(End - Start),
CB->getAAMetadata()));
}
}
}
return Locations;
}

std::pair<bool, bool>
DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
bool Changed = false;
Expand All @@ -2208,7 +2358,8 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit,
KillingLocWrapper.DefByInitializesAttr);

if (!MaybeDeadAccess) {
LLVM_DEBUG(dbgs() << " finished walk\n");
Expand All @@ -2231,10 +2382,20 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
}
continue;
}
// We cannot apply the initializes attribute to DeadAccess/DeadDef.
// It would incorrectly consider a call instruction as redundant store
// and remove this call instruction.
// TODO: this conflates the existence of a MemoryLocation with being able
// to delete the instruction. Fix isRemovable() to consider calls with
// side effects that cannot be removed, e.g. calls with the initializes
// attribute, and remove getLocForInst(ConsiderInitializesAttr = false).
MemoryDefWrapper DeadDefWrapper(
cast<MemoryDef>(DeadAccess),
getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst(),
/*ConsiderInitializesAttr=*/false));
assert(DeadDefWrapper.DefinedLocations.size() == 1);
MemoryLocationWrapper &DeadLocWrapper =
DeadDefWrapper.DefinedLocations.front();
LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
NumGetDomMemoryDefPassed++;
Expand Down Expand Up @@ -2309,37 +2470,42 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
}

bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) {
if (!KillingDefWrapper.DefinedLocation.has_value()) {
if (KillingDefWrapper.DefinedLocations.empty()) {
LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
<< *KillingDefWrapper.DefInst << "\n");
return false;
}

auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
<< *KillingLocWrapper.MemDef << " ("
<< *KillingLocWrapper.DefInst << ")\n");
auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);

// Check if the store is a no-op.
if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
<< *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
NumRedundantStores++;
return true;
}
// Can we form a calloc from a memset/malloc pair?
if (!DeletedKillingLoc &&
tryFoldIntoCalloc(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
<< " DEAD: " << *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
return true;
bool MadeChange = false;
for (auto &KillingLocWrapper : KillingDefWrapper.DefinedLocations) {
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
<< *KillingLocWrapper.MemDef << " ("
<< *KillingLocWrapper.DefInst << ")\n");
auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);
MadeChange |= Changed;

// Check if the store is a no-op.
if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
<< *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
NumRedundantStores++;
MadeChange = true;
continue;
}
// Can we form a calloc from a memset/malloc pair?
if (!DeletedKillingLoc &&
tryFoldIntoCalloc(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
<< " DEAD: " << *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
MadeChange = true;
continue;
}
}
return Changed;
return MadeChange;
}

static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
Expand All @@ -2355,7 +2521,8 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
continue;

MemoryDefWrapper KillingDefWrapper(
KillingDef, State.getLocForInst(KillingDef->getMemoryInst()));
KillingDef, State.getLocForInst(KillingDef->getMemoryInst(),
EnableInitializesImprovement));
MadeChange |= State.eliminateDeadDefs(KillingDefWrapper);
}

Expand Down
Loading
Loading