Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][NFC] Add fast path to fused loc flattening. #75312

Merged
merged 1 commit into from
Dec 13, 2023

Conversation

bchetioui
Copy link
Member

This is a follow-up on PR 75218 that avoids reconstructing a fused loc in the FlattenFusedLocationRecursively helper when there has been no change.

@bchetioui bchetioui requested a review from joker-eph December 13, 2023 10:21
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 13, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 13, 2023

@llvm/pr-subscribers-mlir-core

Author: Benjamin Chetioui (bchetioui)

Changes

This is a follow-up on PR 75218 that avoids reconstructing a fused loc in the FlattenFusedLocationRecursively helper when there has been no change.


Full diff: https://github.com/llvm/llvm-project/pull/75312.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+28-17)
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e12..aa1e1ce01777d 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
 /// child fused locations are the same---this to avoid breaking cases where
 /// metadata matter.
 static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc) return loc;
+
+  SetVector<Location> flattenedLocs;
+  Attribute metadata = fusedLoc.getMetadata();
+  ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+  bool hasAnyNestedLocChanged = false;
+
+  for (const Location &unflattenedLoc : unflattenedLocs) {
+    Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+    auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+    if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                              flattenedFusedLoc.getMetadata() == metadata)) {
+      hasAnyNestedLocChanged = true;
+      ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+      flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+    } else {
+      if (flattenedLoc != unflattenedLoc) {
+        hasAnyNestedLocChanged = true;
       }
+
+      flattenedLocs.insert(flattenedLoc);
     }
+  }
 
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
+  if (!hasAnyNestedLocChanged &&
+      unflattenedLocs.size() == flattenedLocs.size()) {
+    return loc;
   }
 
-  return loc;
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                        fusedLoc.getMetadata());
 }
 
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,

@llvmbot
Copy link
Member

llvmbot commented Dec 13, 2023

@llvm/pr-subscribers-mlir

Author: Benjamin Chetioui (bchetioui)

Changes

This is a follow-up on PR 75218 that avoids reconstructing a fused loc in the FlattenFusedLocationRecursively helper when there has been no change.


Full diff: https://github.com/llvm/llvm-project/pull/75312.diff

1 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+28-17)
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..aa1e1ce01777db 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
 /// child fused locations are the same---this to avoid breaking cases where
 /// metadata matter.
 static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
+  auto fusedLoc = dyn_cast<FusedLoc>(loc);
+  if (!fusedLoc) return loc;
+
+  SetVector<Location> flattenedLocs;
+  Attribute metadata = fusedLoc.getMetadata();
+  ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
+  bool hasAnyNestedLocChanged = false;
+
+  for (const Location &unflattenedLoc : unflattenedLocs) {
+    Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
+
+    auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
+    if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
+                              flattenedFusedLoc.getMetadata() == metadata)) {
+      hasAnyNestedLocChanged = true;
+      ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
+      flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
+    } else {
+      if (flattenedLoc != unflattenedLoc) {
+        hasAnyNestedLocChanged = true;
       }
+
+      flattenedLocs.insert(flattenedLoc);
     }
+  }
 
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
+  if (!hasAnyNestedLocChanged &&
+      unflattenedLocs.size() == flattenedLocs.size()) {
+    return loc;
   }
 
-  return loc;
+  return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
+                        fusedLoc.getMetadata());
 }
 
 void OperationFolder::appendFoldedLocation(Operation *retainedOp,

Copy link

github-actions bot commented Dec 13, 2023

:white_check_mark: With the latest revision this PR passed the C/C++ code formatter.

@bchetioui bchetioui force-pushed the improve-fused-loc-flattening branch from a980bd5 to 8add012 Compare December 13, 2023 10:28
mlir/lib/Transforms/Utils/FoldUtils.cpp Outdated Show resolved Hide resolved
This is a follow-up on
[PR 75218](llvm#75218) that avoids
reconstructing a fused loc in the `FlattenFusedLocationRecursively`
helper when there has been no change.
@bchetioui bchetioui force-pushed the improve-fused-loc-flattening branch from 8add012 to 12fe97a Compare December 13, 2023 11:39
@bchetioui bchetioui merged commit 6fe3cd5 into llvm:main Dec 13, 2023
3 of 4 checks passed
@bchetioui bchetioui deleted the improve-fused-loc-flattening branch December 13, 2023 11:40
MaskRay added a commit to MaskRay/llvm-project that referenced this pull request Dec 13, 2023
This reverts commit 87e2e89.
and its follow-ups 0d1490f (llvm#75218)
and 6fe3cd5 (llvm#75312).

We observed significant OOM/timeout issues due to llvm#74670 to quite a few
services including google-research/swirl-lm. The follow-up llvm#75218 and
 llvm#75312 do not address the issue. Perhaps this is worth more
investigation.
MaskRay added a commit that referenced this pull request Dec 13, 2023
This reverts commit 87e2e89.
and its follow-ups 0d1490f (#75218)
and 6fe3cd5 (#75312).

We observed significant OOM/timeout issues due to #74670 to quite a few
services including google-research/swirl-lm. The follow-up #75218 and
 #75312 do not address the issue. Perhaps this is worth more
investigation.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants