Skip to content

Commit

Permalink
[MLIR] Flatten fused locations when merging constants. (llvm#75218)
Browse files Browse the repository at this point in the history
[PR 74670](llvm#74670) added
support for merging locations at constant folding time. We have
discovered that in some cases, the number of locations grows so big as
to cause a compilation process to OOM. In that case, many of the
locations end up appearing several times in nested fused locations.

We add here a helper that always flattens fused locations in order to
eliminate duplicates in the case of nested fused locations.
  • Loading branch information
bchetioui authored Dec 12, 2023
1 parent fe6f137 commit 0d1490f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
37 changes: 35 additions & 2 deletions mlir/lib/Transforms/Utils/FoldUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
return newIt.first->second;
}

/// Helper that flattens nested fused locations to a single fused location.
/// Fused locations nested under non-fused locations are not flattened, and
/// calling this on non-fused locations is a no-op as a result.
///
/// Fused locations are only flattened into parent fused locations if the
/// child fused location has no metadata, or if the metadata of the parent and
/// child fused locations are the same---this to avoid breaking cases where
/// metadata matter.
static Location FlattenFusedLocationRecursively(const Location loc) {
if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
SetVector<Location> flattenedLocs;
Attribute metadata = fusedLoc.getMetadata();

for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);

if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
flattenedFusedLoc.getMetadata() == metadata)) {
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
} else {
flattenedLocs.insert(flattenedLoc);
}
}

return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
fusedLoc.getMetadata());
}

return loc;
}

void OperationFolder::appendFoldedLocation(Operation *retainedOp,
Location foldedLocation) {
// Append into existing fused location if it has the same tag.
Expand All @@ -344,7 +377,7 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
locations.insert(foldedLocation);
Location newFusedLoc = FusedLoc::get(
retainedOp->getContext(), locations.takeVector(), existingMetadata);
retainedOp->setLoc(newFusedLoc);
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
return;
}
}
Expand All @@ -357,5 +390,5 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
Location newFusedLoc =
FusedLoc::get(retainedOp->getContext(),
{retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
retainedOp->setLoc(newFusedLoc);
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
}
13 changes: 10 additions & 3 deletions mlir/test/Transforms/canonicalize-debuginfo.mlir
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s

// CHECK-LABEL: func @merge_constants
func.func @merge_constants() -> (index, index, index, index) {
func.func @merge_constants() -> (index, index, index, index, index, index, index) {
// CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
%0 = arith.constant 42 : index loc("merge_constants":0:0)
%1 = arith.constant 42 : index loc("merge_constants":1:0)
%2 = arith.constant 42 : index loc("merge_constants":2:0)
%3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
return %0, %1, %2, %3: index, index, index, index
%4 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
%5 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
%6 = arith.constant 43 : index loc(fused<"some_other_label">["merge_constants":3:0])
return %0, %1, %2, %3, %4, %5, %6 : index, index, index, index, index, index, index
}

// CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
// CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
// CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
// CHECK-DAG: #[[FusedLoc_CSE_1:.*]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
// CHECK-DAG: #[[FusedLoc_Some_Label:.*]] = loc(fused<"some_label">[#[[LocConst3]]])
// CHECK-DAG: #[[FusedLoc_Some_Other_Label:.*]] = loc(fused<"some_other_label">[#[[LocConst3]]])
// CHECK-DAG: #[[FusedLoc_CSE_2:.*]] = loc(fused<"CSE">[#[[FusedLoc_Some_Label]], #[[FusedLoc_Some_Other_Label]]])

// -----

Expand Down

0 comments on commit 0d1490f

Please sign in to comment.