Skip to content

Commit

Permalink
replace isLoopWithIdenticalConfiguration with checkFusionStructuralLe…
Browse files Browse the repository at this point in the history
…gality
  • Loading branch information
srcarroll committed Jun 5, 2024

Verified

This commit was signed with the committer’s verified signature. The key has expired.
addaleax Anna Henningsen
1 parent b73238a commit f5bbd13
Showing 3 changed files with 29 additions and 57 deletions.
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
@@ -160,8 +160,8 @@ void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
// Fusion related helpers
//===----------------------------------------------------------------------===//

template <typename LoopTy>
bool checkFusionStructuralLegality(Operation *target, Operation *source);
bool checkFusionStructuralLegality(LoopLikeOpInterface &target,
LoopLikeOpInterface &source);

/// Prepends operations of firstPloop's body into secondPloop's body.
/// Updates secondPloop with new loop.
50 changes: 15 additions & 35 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
@@ -442,34 +442,6 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
return DiagnosedSilenceableFailure::success();
}

/// Check if `target` scf loop can be fused into `source` scf loop.
/// Applies for scf.for, scf.forall, and scf.parallel.
///
/// This simply checks if both loops have the same bounds, steps and mapping.
/// No attempt is made at checking that the side effects of `target` and
/// `source` are independent of each other.
template <typename LoopTy>
static bool isLoopWithIdenticalConfiguration(Operation *target,
Operation *source) {
static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
scf::ParallelOp>::value,
"applies to only `forall`, `for` and `parallel`");
auto targetOp = dyn_cast<LoopTy>(target);
auto sourceOp = dyn_cast<LoopTy>(source);
if (!targetOp || !sourceOp)
return false;

if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
targetOp.getMapping() == sourceOp.getMapping();
else
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
targetOp.getStep() == sourceOp.getStep();
}

DiagnosedSilenceableFailure
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@@ -485,29 +457,37 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
<< "source handle (got " << llvm::range_size(sourceOps) << ")";
}

Operation *target = *targetOps.begin();
Operation *source = *sourceOps.begin();
LoopLikeOpInterface target =
dyn_cast<LoopLikeOpInterface>(*targetOps.begin());
LoopLikeOpInterface source =
dyn_cast<LoopLikeOpInterface>(*sourceOps.begin());
if (!target || !source)
return emitSilenceableFailure(target->getLoc())
<< "target or source is not a loop op";

// Check if the target and source are siblings.
DiagnosedSilenceableFailure diag = isOpSibling(target, source);
if (!diag.succeeded())
return diag;

if (!mlir::checkFusionStructuralLegality(target, source))
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";

Operation *fusedLoop;
/// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
fusedLoop = fuseIndependentSiblingForLoops(
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
} else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
} else if (isa<scf::ForallOp>(target) && isa<scf::ForallOp>(source)) {
fusedLoop = fuseIndependentSiblingForallLoops(
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
} else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
source)) {
} else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
fusedLoop = fuseIndependentSiblingParallelLoops(
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
} else
return emitSilenceableFailure(target->getLoc())
<< "operations cannot be fused";
<< "unsupported loop type for fusion";

assert(fusedLoop && "failed to fuse operations");

32 changes: 12 additions & 20 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
@@ -1195,26 +1195,18 @@ static bool hasNestedParallelOp(scf::ParallelOp ploop) {
return walkResult.wasInterrupted();
}

template <typename LoopTy>
static bool checkFusionStructuralLegality(Operation *target,
Operation *source) {
static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
scf::ParallelOp>::value,
"applies to only `forall`, `for` and `parallel`");
auto targetOp = dyn_cast<LoopTy>(target);
auto sourceOp = dyn_cast<LoopTy>(source);
if (!targetOp || !sourceOp)
return false;

if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
targetOp.getMapping() == sourceOp.getMapping();
else
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
targetOp.getStep() == sourceOp.getStep();
bool mlir::checkFusionStructuralLegality(LoopLikeOpInterface &target,
LoopLikeOpInterface &source) {
auto iterSpaceEq =
target.getMixedLowerBound() == source.getMixedLowerBound() &&
target.getMixedUpperBound() == source.getMixedUpperBound() &&
target.getMixedStep() == source.getMixedStep();
auto forAllTarget = dyn_cast<scf::ForallOp>(*target);
auto forAllSource = dyn_cast<scf::ForallOp>(*source);
if (forAllTarget && forAllSource)
return iterSpaceEq &&
forAllTarget.getMapping() == forAllSource.getMapping();
return iterSpaceEq;
}

static bool isFusionLegal(scf::ParallelOp firstPloop,

0 comments on commit f5bbd13

Please sign in to comment.