From 719e50c8b6da777d8c8085aa8539e81a403cc588 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Thu, 3 Oct 2024 11:26:06 +0100 Subject: [PATCH] [MLIR][OpenMP] Introduce host_eval clause to omp.target This patch defines a map-like clause named `host_eval` used to capture host values for use inside of target regions on restricted cases: - As `num_teams` or `thread_limit` of a nested `omp.target` operation. - As `num_threads` of a nested `omp.parallel` operation or as bounds or steps of a nested `omp.loop_nest`, if it is a target SPMD kernel. This replaces the following `omp.target` arguments: `trip_count`, `num_threads`, `num_teams_lower`, `num_teams_upper` and `teams_thread_limit`. --- mlir/docs/Dialects/OpenMPDialect/_index.md | 56 ++++++++++++- .../mlir/Dialect/OpenMP/OpenMPClauses.td | 38 +++++++++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 31 +++---- .../Dialect/OpenMP/OpenMPOpsInterfaces.td | 27 ++++-- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 82 +++++++++++-------- mlir/test/Dialect/OpenMP/invalid.mlir | 71 +++++++++++++++- mlir/test/Dialect/OpenMP/ops.mlir | 40 ++++++++- 7 files changed, 284 insertions(+), 61 deletions(-) diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md index b4e359284edae69..01a0b7decbf6556 100644 --- a/mlir/docs/Dialects/OpenMPDialect/_index.md +++ b/mlir/docs/Dialects/OpenMPDialect/_index.md @@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the introduction of private copies of the same underlying variable defined outside the MLIR operation the clause is attached to. Currently, clauses with this property can be classified into three main categories: - - Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`. + - Map-like clauses: `host_eval`, `map`, `use_device_addr` and +`use_device_ptr`. - Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`. - Privatization clauses: `private`. @@ -526,3 +527,56 @@ omp.parallel ... { omp.terminator } {omp.composite} ``` + +## Host-Evaluated Clauses in Target Regions + +The `omp.target` operation, used to represent the OpenMP `target` construct, is +an `IsolatedFromAbove` operation, which means no outside MLIR values are allowed +inside of the region defined by it. This is a good match for the semantics of +the construct, since host values used inside of the target region must be +privatized or mapped to be used. + +Regularly, the evaluation of clauses applied to a given construct must be +completed prior to entering that construct. However, there are clauses for which +the OpenMP specification defines exceptions when nested inside of a target +region. Specifically, the `num_teams` and `thread_limit` clauses of the `teams` +construct must be evaluated in the host if it is nested inside of or combined +with a `target` construct. + +Additionally, the runtime library targeted by the MLIR to LLVM IR translation of +the OpenMP dialect supports the optimized launch of SPMD kernels (i.e. +`target teams distribute parallel {do,for}` in OpenMP), which requires +specifying in advance what the total trip count of the loop is. Consequently, it +is also beneficial to evaluate it in the host prior to the kernel launch. + +These host-evaluated values in MLIR would need to be placed outside of the +`omp.target` region and also attached to the corresponding nested operations, +which is not possible because of the `IsolatedFromAbove` trait. The solution +implemented to address this problem has been to introduce the `host_eval` +argument to the `omp.target` operation. It works similarly to a `map` clause, +but its only intended use is to forward host-evaluated values to their +corresponding operation inside of the region. Any uses outside of the previously +described result in a verifier error. + +```mlir +// Initialize %0, %1, %2, %3... +omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) { + omp.teams num_teams(to %nt : i32) { + omp.parallel { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { + // ... + omp.yield + } + omp.terminator + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator +} +``` diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index 886554f66afffc3..ddcde74a363d46a 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -419,6 +419,44 @@ class OpenMP_HintClauseSkip< def OpenMP_HintClause : OpenMP_HintClauseSkip<>; +//===----------------------------------------------------------------------===// +// Not in the spec: Clause-like structure to hold host-evaluated values. +//===----------------------------------------------------------------------===// + +class OpenMP_HostEvalClauseSkip< + bit traits = false, bit arguments = false, bit assemblyFormat = false, + bit description = false, bit extraClassDeclaration = false + > : OpenMP_Clause { + let traits = [ + BlockArgOpenMPOpInterface + ]; + + let arguments = (ins + Variadic:$host_eval_vars + ); + + let extraClassDeclaration = [{ + unsigned numHostEvalBlockArgs() { + return getHostEvalVars().size(); + } + }]; + + let description = [{ + The optional `host_eval_vars` holds values defined outside of the region of + the `IsolatedFromAbove` operation for which a corresponding entry block + argument is defined. The only legal uses for these captured values are the + following: + - `num_teams` or `thread_limit` clause of an immediately nested + `omp.teams` operation. + - If the operation is the top-level `omp.target` of a target SPMD kernel: + - `num_threads` clause of the nested `omp.parallel` operation. + - Bounds and steps of the nested `omp.loop_nest` operation. + }]; +} + +def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>; + //===----------------------------------------------------------------------===// // V5.2: [3.4] `if` clause //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 5bd0843f7f2f91c..c8500f1ba65b18c 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1100,20 +1100,16 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [ // 2.14.5 target construct //===----------------------------------------------------------------------===// -// TODO: Remove num_threads, teams_thread_limit and trip_count and implement the -// passthrough approach described here: -// https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106. def TargetOp : OpenMP_Op<"target", traits = [ AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove, OutlineableOpenMPOpInterface ], clauses = [ // TODO: Complete clause list (defaultmap, uses_allocators). OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause, - OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause, - OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip, - OpenMP_NowaitClause, OpenMP_NumTeamsClauseSkip, - OpenMP_NumThreadsClauseSkip, OpenMP_PrivateClause, - OpenMP_ThreadLimitClause + OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause, + OpenMP_InReductionClause, OpenMP_IsDevicePtrClause, + OpenMP_MapClauseSkip, OpenMP_NowaitClause, + OpenMP_PrivateClause, OpenMP_ThreadLimitClause ], singleRegion = true> { let summary = "target construct"; let description = [{ @@ -1140,10 +1136,6 @@ def TargetOp : OpenMP_Op<"target", traits = [ an `omp.parallel`. }] # clausesDescription; - let arguments = !con(clausesArgs, - (ins Optional:$trip_count, - Optional:$teams_thread_limit)); - let builders = [ OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)> ]; @@ -1168,15 +1160,12 @@ def TargetOp : OpenMP_Op<"target", traits = [ bool isTargetSPMDLoop(); }] # clausesExtraClassDeclaration; - let assemblyFormat = clausesReqAssemblyFormat # - " oilist(" # clausesOptAssemblyFormat # [{ - | `trip_count` `(` $trip_count `:` type($trip_count) `)` - | `teams_thread_limit` `(` $teams_thread_limit `:` type($teams_thread_limit) `)` - }] # ")" # [{ - custom( - $region, $in_reduction_vars, type($in_reduction_vars), - $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars), - $private_vars, type($private_vars), $private_syms) attr-dict + let assemblyFormat = clausesAssemblyFormat # [{ + custom( + $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars, + type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms, + $map_vars, type($map_vars), $private_vars, type($private_vars), + $private_syms) attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td index 22521b08637cf8c..46395ddc4701664 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let methods = [ // Default-implemented methods to be overriden by the corresponding clauses. + InterfaceMethod<"Get number of block arguments defined by `host_eval`.", + "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{ + return 0; + }]>, InterfaceMethod<"Get number of block arguments defined by `in_reduction`.", "unsigned", "numInReductionBlockArgs", (ins), [{}], [{ return 0; @@ -55,9 +59,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { }]>, // Unified access methods for clause-associated entry block arguments. + InterfaceMethod<"Get start index of block arguments defined by `host_eval`.", + "unsigned", "getHostEvalBlockArgsStart", (ins), [{ + return 0; + }]>, InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.", "unsigned", "getInReductionBlockArgsStart", (ins), [{ - return 0; + auto iface = ::llvm::cast(*$_op); + return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs(); }]>, InterfaceMethod<"Get start index of block arguments defined by `map`.", "unsigned", "getMapBlockArgsStart", (ins), [{ @@ -91,6 +100,13 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs(); }]>, + InterfaceMethod<"Get block arguments defined by `host_eval`.", + "::llvm::MutableArrayRef<::mlir::BlockArgument>", + "getHostEvalBlockArgs", (ins), [{ + auto iface = ::llvm::cast(*$_op); + return $_op->getRegion(0).getArguments().slice( + iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs()); + }]>, InterfaceMethod<"Get block arguments defined by `in_reduction`.", "::llvm::MutableArrayRef<::mlir::BlockArgument>", "getInReductionBlockArgs", (ins), [{ @@ -147,10 +163,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> { let verify = [{ auto iface = ::llvm::cast($_op); - unsigned expectedArgs = iface.numInReductionBlockArgs() + - iface.numMapBlockArgs() + iface.numPrivateBlockArgs() + - iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() + - iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs(); + unsigned expectedArgs = iface.numHostEvalBlockArgs() + + iface.numInReductionBlockArgs() + iface.numMapBlockArgs() + + iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() + + iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() + + iface.numUseDevicePtrBlockArgs(); if ($_op->getRegion(0).getNumArguments() < expectedArgs) return $_op->emitOpError() << "expected at least " << expectedArgs << " entry block argument(s)"; diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 28ed12700349114..a671a4d6fab7231 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -498,6 +498,7 @@ struct ReductionParseArgs { : vars(vars), types(types), byref(byref), syms(syms) {} }; struct AllRegionParseArgs { + std::optional hostEvalArgs; std::optional inReductionArgs; std::optional mapArgs; std::optional privateArgs; @@ -624,6 +625,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args) { llvm::SmallVector entryBlockArgs; + if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval", + args.hostEvalArgs))) + return parser.emitError(parser.getCurrentLocation()) + << "invalid `host_eval` format"; + if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction", args.inReductionArgs))) return parser.emitError(parser.getCurrentLocation()) @@ -662,8 +668,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, return parser.parseRegion(region, entryBlockArgs); } -static ParseResult parseInReductionMapPrivateRegion( +static ParseResult parseHostEvalInReductionMapPrivateRegion( OpAsmParser &parser, Region ®ion, + SmallVectorImpl &hostEvalVars, + SmallVectorImpl &hostEvalTypes, SmallVectorImpl &inReductionVars, SmallVectorImpl &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, @@ -672,6 +680,7 @@ static ParseResult parseInReductionMapPrivateRegion( llvm::SmallVectorImpl &privateVars, llvm::SmallVectorImpl &privateTypes, ArrayAttr &privateSyms) { AllRegionParseArgs args; + args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); @@ -785,6 +794,7 @@ struct ReductionPrintArgs { : vars(vars), types(types), byref(byref), syms(syms) {} }; struct AllRegionPrintArgs { + std::optional hostEvalArgs; std::optional inReductionArgs; std::optional mapArgs; std::optional privateArgs; @@ -863,6 +873,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, auto iface = llvm::cast(op); MLIRContext *ctx = op->getContext(); + printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(), + args.hostEvalArgs); printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(), args.inReductionArgs); printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(), @@ -883,12 +895,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, p.printRegion(region, /*printEntryBlockArgs=*/false); } -static void printInReductionMapPrivateRegion( - OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, +static void printHostEvalInReductionMapPrivateRegion( + OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars, + TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) { AllRegionPrintArgs args; + args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes); args.inReductionArgs.emplace(inReductionVars, inReductionTypes, inReductionByref, inReductionSyms); args.mapArgs.emplace(mapVars, mapTypes); @@ -966,6 +980,7 @@ static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes); printBlockArgRegion(p, op, region, args); } + /// Verifies Reduction Clause static LogicalResult verifyReductionVarList(Operation *op, std::optional reductionSyms, @@ -1651,14 +1666,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, // inReductionByref, inReductionSyms. TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars, - clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr, + clauses.device, clauses.hasDeviceAddrVars, + clauses.hostEvalVars, clauses.ifExpr, /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, - clauses.mapVars, clauses.nowait, /*num_teams_lower=*/nullptr, - /*num_teams_upper=*/nullptr, /*num_threads_var=*/nullptr, - clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), - clauses.threadLimit, - /*trip_count=*/nullptr, /*teams_thread_limit=*/nullptr); + clauses.mapVars, clauses.nowait, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit); } /// Only allow OpenMP terminators and non-OpenMP ops that have known memory @@ -1707,18 +1720,31 @@ LogicalResult TargetOp::verify() { if (std::distance(teamsOps.begin(), teamsOps.end()) > 1) return emitError("target containing multiple teams constructs"); - if (!isTargetSPMDLoop() && getTripCount()) - return emitError("trip_count set on non-SPMD target region"); - - if (teamsOps.empty()) { - if (getNumTeamsLower() || getNumTeamsUpper() || getTeamsThreadLimit()) - return emitError( - "num_teams and teams_thread_limit arguments only allowed if there is " - "an omp.teams child operation"); - } else { - if (failed(verifyNumTeamsClause(*this, getNumTeamsLower(), - getNumTeamsUpper()))) - return failure(); + // Check that host_eval values are only used in legal ways. + bool isTargetSPMD = isTargetSPMDLoop(); + for (Value hostEvalArg : + cast(getOperation()).getHostEvalBlockArgs()) { + for (Operation *user : hostEvalArg.getUsers()) { + if (auto teamsOp = dyn_cast(user)) { + if (llvm::is_contained({teamsOp.getNumTeamsLower(), + teamsOp.getNumTeamsUpper(), + teamsOp.getThreadLimit()}, + hostEvalArg)) + continue; + } else if (auto parallelOp = dyn_cast(user)) { + if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads()) + continue; + } else if (auto loopNestOp = dyn_cast(user)) { + if (isTargetSPMD && + (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || + llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || + llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) { + continue; + } + } + return emitOpError() << "host_eval argument illegal use in '" + << user->getName() << "' operation"; + } } LogicalResult verifyDependVars = @@ -1948,17 +1974,9 @@ LogicalResult TeamsOp::verify() { return emitError("expected to be nested inside of omp.target or not nested " "in any OpenMP dialect operations"); - auto offloadModOp = - llvm::cast(*(*this)->getParentOfType()); - if (targetOp && !offloadModOp.getIsTargetDevice()) { - if (getNumTeamsLower() || getNumTeamsUpper() || getThreadLimit()) - return emitError("num_teams and thread_limit arguments expected to be " - "attached to parent omp.target operation"); - } else { - if (failed(verifyNumTeamsClause(*this, getNumTeamsLower(), - getNumTeamsUpper()))) - return failure(); - } + if (failed( + verifyNumTeamsClause(*this, getNumTeamsLower(), getNumTeamsUpper()))) + return failure(); // Check for allocate clause restrictions if (getAllocateVars().size() != getAllocatorVars().size()) diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 57c1e34a2c40e76..d77f5929435da78 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2187,11 +2187,80 @@ func.func @omp_target_update_data_depend(%a: memref) { // ----- +func.func @omp_target_multiple_teams() { + // expected-error @below {{target containing multiple teams constructs}} + omp.target { + omp.teams { + omp.terminator + } + omp.teams { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval1(%x : !llvm.ptr) { + // expected-error @below {{op host_eval argument illegal use in 'llvm.load' operation}} + omp.target host_eval(%x -> %arg0 : !llvm.ptr) { + %0 = llvm.load %arg0 : !llvm.ptr -> f32 + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval2(%x : i1) { + // expected-error @below {{op host_eval argument illegal use in 'omp.teams' operation}} + omp.target host_eval(%x -> %arg0 : i1) { + omp.teams if(%arg0) { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval3(%x : i32) { + // expected-error @below {{op host_eval argument illegal use in 'omp.parallel' operation}} + omp.target host_eval(%x -> %arg0 : i32) { + omp.parallel num_threads(%arg0 : i32) { + omp.terminator + } + omp.terminator + } + return +} + +// ----- + +func.func @omp_target_host_eval3(%x : i32) { + // expected-error @below {{op host_eval argument illegal use in 'omp.loop_nest' operation}} + omp.target host_eval(%x -> %arg0 : i32) { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) { + omp.yield + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- + func.func @omp_target_depend(%data_var: memref) { // expected-error @below {{op expected as many depend values as depend variables}} "omp.target"(%data_var) ({ "omp.terminator"() : () -> () - }) {depend_kinds = [], operandSegmentSizes = array} : (memref) -> () + }) {depend_kinds = [], operandSegmentSizes = array} : (memref) -> () "func.return"() : () -> () } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index d7ad1ec62e6bec7..2b902972aee045b 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -832,7 +832,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic "omp.target"(%device, %if_cond, %num_threads) ({ // CHECK: omp.terminator omp.terminator - }) {nowait, operandSegmentSizes = array} : ( si32, i1, i32 ) -> () + }) {nowait, operandSegmentSizes = array} : ( si32, i1, i32 ) -> () // Test with optional map clause. // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref, tensor) map_clauses(tofrom) capture(ByRef) -> memref {name = ""} @@ -2840,3 +2840,41 @@ func.func @omp_target_private(%map1: memref, %map2: memref, %priv_ return } + +func.func @omp_target_host_eval(%x : i32) { + // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { + // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32) + // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32) + omp.target host_eval(%x -> %arg0 : i32) { + omp.teams num_teams(to %arg0 : i32) thread_limit(%arg0 : i32) { + omp.terminator + } + omp.terminator + } + + // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) { + // CHECK: omp.teams + // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) { + // CHECK: omp.distribute { + // CHECK: omp.wsloop { + // CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[HOST_ARG]]) to (%[[HOST_ARG]]) step (%[[HOST_ARG]]) { + omp.target host_eval(%x -> %arg0 : i32) { + omp.teams { + omp.parallel num_threads(%arg0 : i32) { + omp.distribute { + omp.wsloop { + omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) { + omp.yield + } + omp.terminator + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } {omp.composite} + omp.terminator + } + omp.terminator + } + return +}