Skip to content

Commit

Permalink
[MLIR][NFC] Add fast path to fused loc flattening. (llvm#75312)
Browse files Browse the repository at this point in the history
This is a follow-up on [PR
75218](llvm#75218) that avoids
reconstructing a fused loc in the `FlattenFusedLocationRecursively`
helper when there has been no change.
  • Loading branch information
bchetioui authored Dec 13, 2023
1 parent 35dacf2 commit 6fe3cd5
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions mlir/lib/Transforms/Utils/FoldUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
/// child fused locations are the same---this to avoid breaking cases where
/// metadata matter.
static Location FlattenFusedLocationRecursively(const Location loc) {
if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
SetVector<Location> flattenedLocs;
Attribute metadata = fusedLoc.getMetadata();

for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);

if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
flattenedFusedLoc.getMetadata() == metadata)) {
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
} else {
flattenedLocs.insert(flattenedLoc);
}
auto fusedLoc = dyn_cast<FusedLoc>(loc);
if (!fusedLoc)
return loc;

SetVector<Location> flattenedLocs;
Attribute metadata = fusedLoc.getMetadata();
ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
bool hasAnyNestedLocChanged = false;

for (const Location &unflattenedLoc : unflattenedLocs) {
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);

auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
flattenedFusedLoc.getMetadata() == metadata)) {
hasAnyNestedLocChanged = true;
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
} else {
if (flattenedLoc != unflattenedLoc)
hasAnyNestedLocChanged = true;

flattenedLocs.insert(flattenedLoc);
}
}

return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
fusedLoc.getMetadata());
if (!hasAnyNestedLocChanged &&
unflattenedLocs.size() == flattenedLocs.size()) {
return loc;
}

return loc;
return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
fusedLoc.getMetadata());
}

void OperationFolder::appendFoldedLocation(Operation *retainedOp,
Expand Down

0 comments on commit 6fe3cd5

Please sign in to comment.