diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 5d802b50a8c76c..2218a00d98db60 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -36,7 +36,8 @@ namespace flangomp { #include "flang/Optimizer/OpenMP/Passes.h.inc" } // namespace flangomp -#define DEBUG_TYPE "fopenmp-do-concurrent-conversion" +#define DEBUG_TYPE "do-concurrent-conversion" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") namespace Fortran { namespace lower { @@ -175,9 +176,24 @@ bool isIndVarUltimateOperand(mlir::Operation *op, fir::DoLoopOp doLoop) { return false; } +/// For the \p doLoop parameter, find the operations that declares its induction +/// variable or allocates memory for it. +mlir::Operation *findLoopIndVarMemDecl(fir::DoLoopOp doLoop) { + mlir::Value result = nullptr; + mlir::visitUsedValuesDefinedAbove( + doLoop.getRegion(), [&](mlir::OpOperand *operand) { + if (isIndVarUltimateOperand(operand->getOwner(), doLoop)) { + assert(result == nullptr && + "loop can have only one induction variable"); + result = operand->get(); + } + }); + + assert(result != nullptr && result.getDefiningOp() != nullptr); + return result.getDefiningOp(); +} + /// Collect the list of values used inside the loop but defined outside of it. -/// The first item in the returned list is always the loop's induction -/// variable. void collectLoopLiveIns(fir::DoLoopOp doLoop, llvm::SmallVectorImpl &liveIns) { llvm::SmallDenseSet seenValues; @@ -194,9 +210,6 @@ void collectLoopLiveIns(fir::DoLoopOp doLoop, return; liveIns.push_back(operand->get()); - - if (isIndVarUltimateOperand(operand->getOwner(), doLoop)) - std::swap(*liveIns.begin(), *liveIns.rbegin()); }); } @@ -286,24 +299,78 @@ void collectIndirectConstOpChain(mlir::Operation *link, opChain.insert(link); } +/// Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff +/// there are no operations in \p outerloop's other than: +/// +/// 1. the operations needed to assing/update \p outerLoop's induction variable. +/// 2. \p innerLoop itself. +/// +/// \p return true if \p innerLoop is perfectly nested inside \p outerLoop +/// according to the above definition. +bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) { + mlir::BackwardSliceOptions backwardSliceOptions; + backwardSliceOptions.inclusive = true; + // We will collect the backward slices for innerLoop's LB, UB, and step. + // However, we want to limit the scope of these slices to the scope of + // outerLoop's region. + backwardSliceOptions.filter = [&](mlir::Operation *op) { + return !mlir::areValuesDefinedAbove(op->getResults(), + outerLoop.getRegion()); + }; + + mlir::ForwardSliceOptions forwardSliceOptions; + forwardSliceOptions.inclusive = true; + // We don't care about the outer-loop's induction variable's uses within the + // inner-loop, so we filter out these uses. + forwardSliceOptions.filter = [&](mlir::Operation *op) { + return mlir::areValuesDefinedAbove(op->getResults(), innerLoop.getRegion()); + }; + + llvm::SetVector indVarSlice; + mlir::getForwardSlice(outerLoop.getInductionVar(), &indVarSlice, + forwardSliceOptions); + llvm::DenseSet innerLoopSetupOpsSet(indVarSlice.begin(), + indVarSlice.end()); + + llvm::DenseSet loopBodySet; + outerLoop.walk([&](mlir::Operation *op) { + if (op == outerLoop) + return mlir::WalkResult::advance(); + + if (op == innerLoop) + return mlir::WalkResult::skip(); + + if (mlir::isa(op)) + return mlir::WalkResult::advance(); + + loopBodySet.insert(op); + return mlir::WalkResult::advance(); + }); + + bool result = (loopBodySet == innerLoopSetupOpsSet); + mlir::Location loc = outerLoop.getLoc(); + LLVM_DEBUG(DBGS() << "Loop pair starting at location " << loc << " is" + << (result ? "" : " not") << " perfectly nested\n"); + + return result; +} + /// Starting with `outerLoop` collect a perfectly nested loop nest, if any. This /// function collects as much as possible loops in the nest; it case it fails to /// recognize a certain nested loop as part of the nest it just returns the /// parent loops it discovered before. -mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop, +mlir::LogicalResult collectLoopNest(fir::DoLoopOp currentLoop, LoopNestToIndVarMap &loopNest) { - assert(outerLoop.getUnordered()); - llvm::SmallVector outerLoopLiveIns; - collectLoopLiveIns(outerLoop, outerLoopLiveIns); + assert(currentLoop.getUnordered()); while (true) { loopNest.try_emplace( - outerLoop, + currentLoop, InductionVariableInfo{ - outerLoopLiveIns.front().getDefiningOp(), - std::move(looputils::extractIndVarUpdateOps(outerLoop))}); + findLoopIndVarMemDecl(currentLoop), + std::move(looputils::extractIndVarUpdateOps(currentLoop))}); - auto directlyNestedLoops = outerLoop.getRegion().getOps(); + auto directlyNestedLoops = currentLoop.getRegion().getOps(); llvm::SmallVector unorderedLoops; for (auto nestedLoop : directlyNestedLoops) @@ -318,69 +385,10 @@ mlir::LogicalResult collectLoopNest(fir::DoLoopOp outerLoop, fir::DoLoopOp nestedUnorderedLoop = unorderedLoops.front(); - if ((nestedUnorderedLoop.getLowerBound().getDefiningOp() == nullptr) || - (nestedUnorderedLoop.getUpperBound().getDefiningOp() == nullptr) || - (nestedUnorderedLoop.getStep().getDefiningOp() == nullptr)) - return mlir::failure(); - - llvm::SmallVector nestedLiveIns; - collectLoopLiveIns(nestedUnorderedLoop, nestedLiveIns); - - llvm::DenseSet outerLiveInsSet; - llvm::DenseSet nestedLiveInsSet; - - // Returns a "unified" view of an mlir::Value. This utility checks if the - // value is defined by an op, and if so, return the first value defined by - // that op (if there are many), otherwise just returns the value. - // - // This serves the purpose that if, for example, `%op_res#0` is used in the - // outer loop and `%op_res#1` is used in the nested loop (or vice versa), - // that we detect both as the same value. If we did not do so, we might - // falesely detect that the 2 loops are not perfectly nested since they use - // "different" sets of values. - auto getUnifiedLiveInView = [](mlir::Value liveIn) { - return liveIn.getDefiningOp() != nullptr - ? liveIn.getDefiningOp()->getResult(0) - : liveIn; - }; - - // Re-package both lists of live-ins into sets so that we can use set - // equality to compare the values used in the outerloop vs. the nestd one. - - for (auto liveIn : nestedLiveIns) - nestedLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - mlir::Value outerLoopIV; - for (auto liveIn : outerLoopLiveIns) { - outerLiveInsSet.insert(getUnifiedLiveInView(liveIn)); - - // Keep track of the IV of the outerloop. See `isPerfectlyNested` for more - // info on the reason. - if (outerLoopIV == nullptr) - outerLoopIV = getUnifiedLiveInView(liveIn); - } - - // For the 2 loops to be perfectly nested, either: - // * both would have exactly the same set of live-in values or, - // * the outer loop would have exactly 1 extra live-in value: the outer - // loop's induction variable; this happens when the outer loop's IV is - // *not* referenced in the nested loop. - bool isPerfectlyNested = [&]() { - if (outerLiveInsSet == nestedLiveInsSet) - return true; - - if ((outerLiveInsSet.size() == nestedLiveIns.size() + 1) && - !nestedLiveInsSet.contains(outerLoopIV)) - return true; - - return false; - }(); - - if (!isPerfectlyNested) + if (!isPerfectlyNested(currentLoop, nestedUnorderedLoop)) return mlir::failure(); - outerLoop = nestedUnorderedLoop; - outerLoopLiveIns = std::move(nestedLiveIns); + currentLoop = nestedUnorderedLoop; } return mlir::success(); @@ -554,10 +562,6 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "defining operation."); } - llvm::SmallVector outermostLoopLiveIns; - looputils::collectLoopLiveIns(doLoop, outermostLoopLiveIns); - assert(!outermostLoopLiveIns.empty()); - looputils::LoopNestToIndVarMap loopNest; bool hasRemainingNestedLoops = failed(looputils::collectLoopNest(doLoop, loopNest)); @@ -566,15 +570,19 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { "Some `do concurent` loops are not perfectly-nested. " "These will be serialzied."); + llvm::SmallVector loopNestLiveIns; + looputils::collectLoopLiveIns(loopNest.back().first, loopNestLiveIns); + assert(!loopNestLiveIns.empty()); + llvm::SetVector locals; looputils::collectLoopLocalValues(loopNest.back().first, locals); // We do not want to map "loop-local" values to the device through // `omp.map.info` ops. Therefore, we remove them from the list of live-ins. - outermostLoopLiveIns.erase(llvm::remove_if(outermostLoopLiveIns, - [&](mlir::Value liveIn) { - return locals.contains(liveIn); - }), - outermostLoopLiveIns.end()); + loopNestLiveIns.erase(llvm::remove_if(loopNestLiveIns, + [&](mlir::Value liveIn) { + return locals.contains(liveIn); + }), + loopNestLiveIns.end()); looputils::sinkLoopIVArgs(rewriter, loopNest); @@ -590,24 +598,25 @@ class DoConcurrentConversion : public mlir::OpConversionPattern { loopNestClauseOps, &targetClauseOps); // Prevent mapping host-evaluated variables. - outermostLoopLiveIns.erase( - llvm::remove_if(outermostLoopLiveIns, + loopNestLiveIns.erase( + llvm::remove_if(loopNestLiveIns, [&](mlir::Value liveIn) { return llvm::is_contained( targetClauseOps.hostEvalVars, liveIn); }), - outermostLoopLiveIns.end()); + loopNestLiveIns.end()); // The outermost loop will contain all the live-in values in all nested // loops since live-in values are collected recursively for all nested // ops. - for (mlir::Value liveIn : outermostLoopLiveIns) + for (mlir::Value liveIn : loopNestLiveIns) targetClauseOps.mapVars.push_back( genMapInfoOpForLiveIn(rewriter, liveIn)); targetOp = - genTargetOp(doLoop.getLoc(), rewriter, mapper, outermostLoopLiveIns, + genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns, targetClauseOps, loopNestClauseOps); + genTeamsOp(doLoop.getLoc(), rewriter); } @@ -998,10 +1007,11 @@ class DoConcurrentConversionPass context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device, concurrentLoopsToSkip); mlir::ConversionTarget target(*context); - target.addLegalDialect< - fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect, - mlir::func::FuncDialect, mlir::omp::OpenMPDialect, - mlir::cf::ControlFlowDialect, mlir::math::MathDialect>(); + target + .addLegalDialect(); target.addDynamicallyLegalOp([&](fir::DoLoopOp op) { return !op.getUnordered() || concurrentLoopsToSkip.contains(op); diff --git a/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 new file mode 100644 index 00000000000000..c73a8be64bed63 --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/loop_nest_test.f90 @@ -0,0 +1,87 @@ +! Tests loop-nest detection algorithm for do-concurrent mapping. + +! REQUIRES: asserts + +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host \ +! RUN: -mmlir -debug %s -o - 2> %t.log || true + +! RUN: FileCheck %s < %t.log + +program main + implicit none + +contains + +subroutine foo(n) + implicit none + integer :: n, m + integer :: i, j, k + integer :: x + integer, dimension(n) :: a + integer, dimension(n, n, n) :: b + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested + do concurrent(i=1:n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested + do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m)) + a(i) = n + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=bar(n, x):n) + do concurrent(j=1:bar(n*m, n/m)) + a(i) = n + end do + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=1:n) + x = 10 + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + end do + x = 10 + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is not perfectly nested + do concurrent(i=1:n) + do concurrent(j=1:m) + b(i,j,k) = i * j + k + x = 10 + end do + end do + + ! CHECK: Loop pair starting at location + ! CHECK: loc("{{.*}}":[[# @LINE + 1]]:{{.*}}) is perfectly nested + do concurrent(i=bar(n, x):n, j=1:bar(n*m, n/m), k=1:bar(n*m, bar(n*m, n/m))) + a(i) = n + end do + + +end subroutine + +pure function bar(n, m) + implicit none + integer, intent(in) :: n, m + integer :: bar + + bar = n + m +end function + +end program main diff --git a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 index 13ee9bce85944f..7939ec9e99debb 100644 --- a/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 +++ b/flang/test/Transforms/DoConcurrent/multiple_iteration_ranges.f90 @@ -8,88 +8,42 @@ ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %t/multi_range.f90 -o - \ ! RUN: | FileCheck %s --check-prefixes=DEVICE,COMMON -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %t/perfectly_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=HOST,COMMON - -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %t/perfectly_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=DEVICE,COMMON - -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %t/partially_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=HOST,COMMON - -! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=device %t/partially_nested.f90 -o - \ -! RUN: | FileCheck %s --check-prefixes=DEVICE,COMMON - -! This is temporarily disabled since the IR for `do concurrent` loops is different after -! https://github.com/llvm/llvm-project/pull/114020. This will be enabled again soon. -! XFAIL: true - !--- multi_range.f90 program main - integer, parameter :: n = 10 - integer, parameter :: m = 20 - integer, parameter :: l = 30 + integer, parameter :: n = 20 + integer, parameter :: m = 40 + integer, parameter :: l = 60 integer :: a(n, m, l) - do concurrent(i=1:n, j=1:m, k=1:l) + do concurrent(i=3:n, j=5:m, k=7:l) a(i,j,k) = i * j + k end do end -!--- perfectly_nested.f90 -program main - integer, parameter :: n = 10 - integer, parameter :: m = 20 - integer, parameter :: l = 30 - integer :: a(n, m, l) - - do concurrent(i=1:n) - do concurrent(j=1:m) - do concurrent(k=1:l) - a(i,j,k) = i * j + k - end do - end do - end do -end - -!--- partially_nested.f90 -program main - integer, parameter :: n = 10 - integer, parameter :: m = 20 - integer, parameter :: l = 30 - integer :: a(n, m, l) - - do concurrent(i=1:n, j=1:m) - do concurrent(k=1:l) - a(i,j,k) = i * j + k - end do - end do -end - ! COMMON: func.func @_QQmain -! DEVICE: %[[DUPLICATED_C1_1:.*]] = arith.constant 1 : i32 -! DEVICE: %[[DUPLICATED_LB_I:.*]] = fir.convert %[[DUPLICATED_C1_1]] : (i32) -> index -! DEVICE: %[[DUPLICATED_C10:.*]] = arith.constant 10 : i32 -! DEVICE: %[[DUPLICATED_UB_I:.*]] = fir.convert %[[DUPLICATED_C10]] : (i32) -> index +! DEVICE: %[[DUPLICATED_C3:.*]] = arith.constant 3 : i32 +! DEVICE: %[[DUPLICATED_LB_I:.*]] = fir.convert %[[DUPLICATED_C3]] : (i32) -> index +! DEVICE: %[[DUPLICATED_C20:.*]] = arith.constant 20 : i32 +! DEVICE: %[[DUPLICATED_UB_I:.*]] = fir.convert %[[DUPLICATED_C20]] : (i32) -> index ! DEVICE: %[[DUPLICATED_STEP_I:.*]] = arith.constant 1 : index -! DEVICE: %[[C1_1:.*]] = arith.constant 1 : i32 -! DEVICE: %[[HOST_LB_I:.*]] = fir.convert %[[C1_1]] : (i32) -> index -! DEVICE: %[[C10:.*]] = arith.constant 10 : i32 -! DEVICE: %[[HOST_UB_I:.*]] = fir.convert %[[C10]] : (i32) -> index +! DEVICE: %[[C3:.*]] = arith.constant 3 : i32 +! DEVICE: %[[HOST_LB_I:.*]] = fir.convert %[[C3]] : (i32) -> index +! DEVICE: %[[C20:.*]] = arith.constant 20 : i32 +! DEVICE: %[[HOST_UB_I:.*]] = fir.convert %[[C20]] : (i32) -> index ! DEVICE: %[[HOST_STEP_I:.*]] = arith.constant 1 : index -! DEVICE: %[[C1_2:.*]] = arith.constant 1 : i32 -! DEVICE: %[[HOST_LB_J:.*]] = fir.convert %[[C1_2]] : (i32) -> index -! DEVICE: %[[C20:.*]] = arith.constant 20 : i32 -! DEVICE: %[[HOST_UB_J:.*]] = fir.convert %[[C20]] : (i32) -> index +! DEVICE: %[[C5:.*]] = arith.constant 5 : i32 +! DEVICE: %[[HOST_LB_J:.*]] = fir.convert %[[C5]] : (i32) -> index +! DEVICE: %[[C40:.*]] = arith.constant 40 : i32 +! DEVICE: %[[HOST_UB_J:.*]] = fir.convert %[[C40]] : (i32) -> index ! DEVICE: %[[HOST_STEP_J:.*]] = arith.constant 1 : index -! DEVICE: %[[C1_3:.*]] = arith.constant 1 : i32 -! DEVICE: %[[HOST_LB_K:.*]] = fir.convert %[[C1_3]] : (i32) -> index -! DEVICE: %[[C30:.*]] = arith.constant 30 : i32 -! DEVICE: %[[HOST_UB_K:.*]] = fir.convert %[[C30]] : (i32) -> index +! DEVICE: %[[C7:.*]] = arith.constant 7 : i32 +! DEVICE: %[[HOST_LB_K:.*]] = fir.convert %[[C7]] : (i32) -> index +! DEVICE: %[[C60:.*]] = arith.constant 60 : i32 +! DEVICE: %[[HOST_UB_K:.*]] = fir.convert %[[C60]] : (i32) -> index ! DEVICE: %[[HOST_STEP_K:.*]] = arith.constant 1 : index ! DEVICE: omp.target host_eval( @@ -103,6 +57,7 @@ program main ! DEVICE-SAME: %[[HOST_UB_K]] -> %[[UB_K:[[:alnum:]]+]], ! DEVICE-SAME: %[[HOST_STEP_K]] -> %[[STEP_K:[[:alnum:]]+]] : ! DEVICE-SAME: index, index, index, index, index, index, index, index, index) + ! DEVICE: omp.teams ! HOST-NOT: omp.target @@ -119,22 +74,22 @@ program main ! COMMON-NEXT: %[[ITER_VAR_K:.*]] = fir.alloca i32 {bindc_name = "k"} ! COMMON-NEXT: %[[BINDING_K:.*]]:2 = hlfir.declare %[[ITER_VAR_K]] {uniq_name = "_QFEk"} -! HOST: %[[C1_1:.*]] = arith.constant 1 : i32 -! HOST: %[[LB_I:.*]] = fir.convert %[[C1_1]] : (i32) -> index -! HOST: %[[C10:.*]] = arith.constant 10 : i32 -! HOST: %[[UB_I:.*]] = fir.convert %[[C10]] : (i32) -> index +! HOST: %[[C3:.*]] = arith.constant 3 : i32 +! HOST: %[[LB_I:.*]] = fir.convert %[[C3]] : (i32) -> index +! HOST: %[[C20:.*]] = arith.constant 20 : i32 +! HOST: %[[UB_I:.*]] = fir.convert %[[C20]] : (i32) -> index ! HOST: %[[STEP_I:.*]] = arith.constant 1 : index -! HOST: %[[C1_2:.*]] = arith.constant 1 : i32 -! HOST: %[[LB_J:.*]] = fir.convert %[[C1_2]] : (i32) -> index -! HOST: %[[C20:.*]] = arith.constant 20 : i32 -! HOST: %[[UB_J:.*]] = fir.convert %[[C20]] : (i32) -> index +! HOST: %[[C5:.*]] = arith.constant 5 : i32 +! HOST: %[[LB_J:.*]] = fir.convert %[[C5]] : (i32) -> index +! HOST: %[[C40:.*]] = arith.constant 40 : i32 +! HOST: %[[UB_J:.*]] = fir.convert %[[C40]] : (i32) -> index ! HOST: %[[STEP_J:.*]] = arith.constant 1 : index -! HOST: %[[C1_3:.*]] = arith.constant 1 : i32 -! HOST: %[[LB_K:.*]] = fir.convert %[[C1_3]] : (i32) -> index -! HOST: %[[C30:.*]] = arith.constant 30 : i32 -! HOST: %[[UB_K:.*]] = fir.convert %[[C30]] : (i32) -> index +! HOST: %[[C7:.*]] = arith.constant 7 : i32 +! HOST: %[[LB_K:.*]] = fir.convert %[[C7]] : (i32) -> index +! HOST: %[[C60:.*]] = arith.constant 60 : i32 +! HOST: %[[UB_K:.*]] = fir.convert %[[C60]] : (i32) -> index ! HOST: %[[STEP_K:.*]] = arith.constant 1 : index ! DEVICE: omp.distribute