From 95ed505c4f4617bf0e8b4c02184a6dc020c4fc4a Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Tue, 25 Oct 2022 18:15:58 +0200
Subject: [PATCH] Fix transform dialect test and cleanup some usages. (#10782)

This revision cherry-picks LLVM commit d8cab3f407070c6d80396553ce024e17a0659b04
and manually resolves conflicts.

This is necessary to fix a bug in the transform dialect.
This allows reactivating the softmax gpu example.
---
 .../TransformExtensions/CommonExtensions.cpp  | 74 +++++++++++--------
 .../TransformExtensions/LLVMGPUExtensions.cpp | 11 ++-
 .../IR/StructuredTransformOpsExt.cpp          |  4 +-
 tests/transform_dialect/cuda/BUILD            | 72 +++++++++---------
 tests/transform_dialect/cuda/CMakeLists.txt   | 25 +++++++
 .../cuda/reduction_codegen_spec.mlir          |  4 +-
 tests/transform_dialect/cuda/softmax.mlir     | 16 ++--
 .../cuda/softmax_codegen_spec.mlir            |  2 +-
 .../cuda/softmax_fused_codegen_spec.mlir      | 15 ++--
 .../softmax_fused_codegen_spec.mlir.broken    | 57 --------------
 10 files changed, 134 insertions(+), 146 deletions(-)
 delete mode 100644 tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken

diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
index 68a7c1cac2bc..78b59fe4093a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp
@@ -101,10 +101,10 @@ DiagnosedSilenceableFailure transform_dialect::ApplyPatternsOp::applyToOne(
     Operation *target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-    target->emitOpError(
+    return mlir::emitDefiniteFailure(
+        target,
         "applies only to isolated-from-above targets because it needs to apply "
         "patterns greedily");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
   }
   MLIRContext *ctx = target->getContext();
   RewritePatternSet patterns(ctx);
@@ -121,8 +121,13 @@ DiagnosedSilenceableFailure transform_dialect::ApplyPatternsOp::applyToOne(
   LogicalResult result = applyPatternsAndFoldGreedily(
       target, std::move(patterns), config, &listener);
   LogicalResult listenerResult = listener.checkErrorState();
-  if (failed(result) || failed(listenerResult))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+  if (failed(result)) {
+    return mlir::emitDefiniteFailure(target,
+                                     "greedy pattern application failed");
+  }
+  if (failed(listenerResult))
+    return mlir::emitDefiniteFailure(target, "listener tracking failed");
+
   results.assign({target});
   return DiagnosedSilenceableFailure(success());
 }
@@ -207,10 +212,10 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
   if (payload.size() != 1 ||
       !isa<IREE::HAL::ExecutableOp, IREE::HAL::ExecutableVariantOp>(
           payload.front())) {
-    state.getTopLevel()->emitOpError(
+    return mlir::emitDefiniteFailure(
+        state.getTopLevel(),
         "requires exactly a single HAL::ExecutableOp or "
         "HAL::ExecutableVariantOp target op.");
-    return DiagnosedSilenceableFailure(failure());
   }
   PassManager pm(getContext());
   // Bufferize the dispatch.
@@ -237,9 +242,14 @@ DiagnosedSilenceableFailure transform_dialect::IREEBufferizeOp::apply(
     }
     return WalkResult::advance();
   });
+
+  if (res.wasInterrupted())
+    return DiagnosedSilenceableFailure::definiteFailure();
+
   results.set(getOperation()->getOpResult(0), payload.front());
-  return DiagnosedSilenceableFailure(failure(res.wasInterrupted()));
+  return DiagnosedSilenceableFailure::success();
 }
+
 /// Populate the workgroup_count region of `dispatchOp`.
 /// For now, this only supports constant index ops and empty workload operands.
 /// Assumes the HAL::ExecutableExportOp is built with an empty region.
@@ -437,11 +447,11 @@ transform_dialect::ForeachThreadToWorkgroupOp::applyToOne(
     func::FuncOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   if (!isa<IREE::HAL::ExecutableOp, IREE::HAL::ExecutableVariantOp>(
           state.getTopLevel())) {
-    state.getTopLevel()->emitOpError(
+    return mlir::emitDefiniteFailure(
+        state.getTopLevel(),
         "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel "
         "to attach the workgroup size information to a nested "
         "ExecutableExportOp");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
   }
   IREE::HAL::ExecutableExportOp exportOp;
@@ -449,8 +459,9 @@ transform_dialect::ForeachThreadToWorkgroupOp::applyToOne(
     if (op.getSymName() == target.getName()) exportOp = op;
   });
   if (!exportOp) {
-    state.getTopLevel()->emitOpError("no IREE::HAL::ExecutableExportOp found");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    results.assign(1, nullptr);
+    return mlir::emitSilenceableFailure(
+        target, "no IREE::HAL::ExecutableExportOp found");
   }

   scf::ForeachThreadOp topLevelForeachThreadOp;
@@ -463,18 +474,19 @@ transform_dialect::ForeachThreadToWorkgroupOp::applyToOne(
   });

   if (walkResult.wasInterrupted()) {
-    state.getTopLevel()->emitOpError(
-        "could not find a unique topLevel scf.foreach_thread");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    results.assign(1, nullptr);
+    return mlir::emitSilenceableFailure(
+        target, "could not find a unique topLevel scf.foreach_thread");
   }

   SimplePatternRewriter rewriter(topLevelForeachThreadOp);
   if (failed(rewriteForeachThreadToWorkgroup(topLevelForeachThreadOp, exportOp,
-                                             rewriter)))
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+                                             rewriter))) {
+    return mlir::emitDefiniteFailure(target,
+                                     "rewriteForeachThreadToWorkgroup failed");
+  }

   results.assign({target});
-  return DiagnosedSilenceableFailure(success());
 }
@@ -560,10 +572,7 @@ void transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::getEffects(
   transform::onlyReadsHandle(getTileSizes(), effects);
   transform::onlyReadsHandle(getNumThreads(), effects);
   transform::producesHandle(getResults(), effects);
-  effects.emplace_back(MemoryEffects::Read::get(),
-                       transform::PayloadIRResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(),
-                       transform::PayloadIRResource::get());
+  transform::modifiesPayload(effects);
 }

 DiagnosedSilenceableFailure
@@ -575,15 +584,14 @@ transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::apply(
   auto funcOp = targetOps.front()->getParentOfType<func::FuncOp>();
   FailureOr<IREE::HAL::ExecutableExportOp> exportOp = getEntryPoint(funcOp);
   if (failed(exportOp)) {
-    state.getTopLevel()->emitOpError("couldn't find export op for func");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(funcOp));
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "couldn't find export op for func");
   }

   SmallVector<OpFoldResult> mixedTileSizes = getMixedTileSizes();
   if (mixedTileSizes.empty()) {
-    exportOp.value()->emitOpError("require tile sizes to be specified");
-    return DiagnosedSilenceableFailure(
-        reportUnknownTransformError(exportOp.value()));
+    return mlir::emitDefiniteFailure(exportOp.value(),
+                                     "require tile sizes to be specified");
   }

   /// Lower the workgroup count region in keeping with the way dispatch
@@ -591,9 +599,8 @@ transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::apply(
   IRRewriter rewriter(getContext());
   if (failed(lowerWorkgroupCountComputingRegion(rewriter, exportOp.value(),
                                                 mixedTileSizes))) {
-    exportOp.value()->emitOpError("failed to lower workgroup count region");
-    return DiagnosedSilenceableFailure(
-        reportUnknownTransformError(exportOp.value()));
+    return mlir::emitDefiniteFailure(exportOp.value(),
+                                     "failed to lower workgroup count region");
   }

   ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
@@ -607,11 +614,16 @@ transform_dialect::TileToForeachThreadAndWorkgroupCountRegion::apply(
           targets, getMixedNumThreads(), getMixedTileSizes(),
           getThreadDimMapping(), tileOps, tiledOps);

-  if (!diag.succeeded()) return diag;
+  if (!diag.succeeded()) {
+    transformResults.set(getForeachThreadOp().cast<OpResult>(),
+                         SmallVector<Operation *>{});
+    transformResults.set(getTiledOp().cast<OpResult>(),
+                         SmallVector<Operation *>{});
+    return diag;
+  }

   transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
-  return DiagnosedSilenceableFailure(success());
 }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
index 2088d83de15b..771d3362029c 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp
@@ -238,12 +238,11 @@ transform_dialect::VectorToWarpExecuteOnLane0Op::applyToOne(
     scf::IfOp target, SmallVectorImpl<Operation *> &results,
     transform::TransformState &state) {
   if (!isa<IREE::HAL::ExecutableOp, IREE::HAL::ExecutableVariantOp>(
           state.getTopLevel())) {
-    state.getTopLevel()->emitOpError(
-        "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel so "
-        "that IR is properly isolated. This is required so we can safely "
-        "inspect the HAL::ExecutableExportOp under multi-threaded pass "
-        "assumptions.");
-    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    return emitDefaultSilenceableFailure(state.getTopLevel())
+           << "requires HAL::ExecutableOp or HAL::ExecutableVariantOp toplevel "
+              "so that IR is properly isolated. This is required so we can "
+              "safely inspect the HAL::ExecutableExportOp under multi-threaded "
+              "pass assumptions.";
   }
   auto halExecutableVariantOp =
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
index 969443837d6a..9b8c05b43dae 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgTransform/IR/StructuredTransformOpsExt.cpp
@@ -569,7 +569,7 @@ forgetUnnecessaryHandles(transform::TransformState &state,
       continue;
     for (Operation *payload : state.getPayloadOps(operand)) {
-      if (seen.contains(payload))
+      if (!payload || seen.contains(payload))
         continue;
       SmallVector<Value> allHandles;
       (void)state.getHandlesForPayloadOp(payload, allHandles);
@@ -592,7 +592,7 @@ forgetUnnecessaryHandles(transform::TransformState &state,
     if (!result.getUses().empty())
       continue;
     for (Operation *payload : state.getPayloadOps(result)) {
-      if (seen.contains(payload))
+      if (!payload || seen.contains(payload))
         continue;
       listener->removeMappings(payload);
       seen.insert(payload);
diff --git a/tests/transform_dialect/cuda/BUILD b/tests/transform_dialect/cuda/BUILD
index c093771fff3a..93948e982242 100644
--- a/tests/transform_dialect/cuda/BUILD
+++ b/tests/transform_dialect/cuda/BUILD
@@ -7,7 +7,7 @@
 # Tests for end-to-end IREE support of entire models or their close derivatives.

 load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content")
-#load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")

 package(
     features = ["layering_check"],
@@ -23,37 +23,39 @@ endif()
     inline = True,
 )

-# TODO: re-enable the tests
-# iree_lit_test_suite(
-#     name = "lit",
-#     srcs = [
-#         "reduction.mlir",
-#         "softmax.mlir",
-#     ],
-#     cfg = "//tests:lit.cfg.py",
-#     # transform dialect spec files are MLIR files that specify a transformation,
-#     # they need to be included as data.
-#     data = [
-#         "reduction_codegen_spec.mlir",
-#         "softmax_codegen_spec.mlir",
-#         # FIXME: This cannot be retired yet as there is some writeonly vs readwrite
-#         # issue and we even end up emitting out of bounds accesses.
-#         "softmax_dispatch_spec.mlir",
-#         "softmax_fused_codegen_spec.mlir",
-#     ],
-#     tags = [
-#         # CUDA cuInit fails with sanitizer on.
-#         "noasan",
-#         "nomsan",
-#         "notsan",
-#         "noubsan",
-#         "requires-gpu-nvidia",
-#         "driver=cuda",
-#     ],
-#     tools = [
-#         "//tools:iree-compile",
-#         "//tools:iree-opt",
-#         "//tools:iree-run-module",
-#         "@llvm-project//llvm:FileCheck",
-#     ],
-# )
+iree_lit_test_suite(
+    name = "lit",
+    srcs = [
+        "reduction.mlir",
+        "softmax.mlir",
+    ],
+    cfg = "//tests:lit.cfg.py",
+    # transform dialect spec files are MLIR files that specify a transformation,
+    # they need to be included as data.
+    data = [
+        "reduction_codegen_spec.mlir",
+        "softmax_codegen_spec.mlir",
+        #
+        # FIXME: Fused codegen must be used with the custom dispatch region formation
+        # because IREE's pulls in tensor.empty by default.
+        # This results in threadprivate allocations and prevents vector distribution.
+        #
+        "softmax_dispatch_spec.mlir",
+        "softmax_fused_codegen_spec.mlir",
+    ],
+    tags = [
+        # CUDA cuInit fails with sanitizer on.
+ "noasan", + "nomsan", + "notsan", + "noubsan", + "requires-gpu-nvidia", + "driver=cuda", + ], + tools = [ + "//tools:iree-compile", + "//tools:iree-opt", + "//tools:iree-run-module", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/tests/transform_dialect/cuda/CMakeLists.txt b/tests/transform_dialect/cuda/CMakeLists.txt index 79331ad86a63..139817fe1b6c 100644 --- a/tests/transform_dialect/cuda/CMakeLists.txt +++ b/tests/transform_dialect/cuda/CMakeLists.txt @@ -14,4 +14,29 @@ if(NOT IREE_HAL_DRIVER_CUDA OR NOT IREE_TARGET_BACKEND_CUDA) return() endif() +iree_lit_test_suite( + NAME + lit + SRCS + "reduction.mlir" + "softmax.mlir" + TOOLS + FileCheck + iree-compile + iree-opt + iree-run-module + DATA + reduction_codegen_spec.mlir + softmax_codegen_spec.mlir + softmax_dispatch_spec.mlir + softmax_fused_codegen_spec.mlir + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-nvidia" + "driver=cuda" +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir index 18db174cb889..f011e7c38d44 100644 --- a/tests/transform_dialect/cuda/reduction_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/reduction_codegen_spec.mlir @@ -1,6 +1,6 @@ // RUN: iree-opt %s -transform.structured.canonicalized_sequence failures(propagate) { +transform.structured.canonicalized_sequence failures(suppress) { ^bb1(%variant_op: !pdl.operation): %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op @@ -15,7 +15,7 @@ transform.structured.canonicalized_sequence failures(propagate) { // The mapping to block ids can only happen after bufferization atm. %foreach_thread_grid, %grid_combiner_op = transform.iree.tile_to_foreach_thread_and_workgroup_count_region %combiner_op tile_sizes [1] - %not_combiner = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op + %not_combiner = transform.merge_handles %fill, %more_parallel_fill_op, %more_parallel_op : !pdl.operation transform.structured.fuse_into_containing_op %not_combiner into %foreach_thread_grid // Second level of tiling + fusion parallelizes to threads. diff --git a/tests/transform_dialect/cuda/softmax.mlir b/tests/transform_dialect/cuda/softmax.mlir index e7959d47805b..e2bebb844da5 100644 --- a/tests/transform_dialect/cuda/softmax.mlir +++ b/tests/transform_dialect/cuda/softmax.mlir @@ -13,13 +13,14 @@ // RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \ // RUN: FileCheck %s +/// +/// FIXME: Fused codegen must be used with the custom dispatch region formation +/// because IREE's pulls in tensor.empty by default. +/// This results in threadprivate allocations and prevents vector distribution. +/// // RUN: iree-opt %s --iree-hal-target-backends=cuda \ // RUN: --iree-abi-transformation-pipeline \ // RUN: --iree-flow-transformation-pipeline \ -/// -/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite -/// issue and we even end up emitting out of bounds accesses. 
-/// // RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \ // RUN: --iree-stream-transformation-pipeline \ // RUN: --iree-hal-configuration-pipeline | \ @@ -27,11 +28,12 @@ // RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \ // RUN: FileCheck %s --check-prefix=CHECK-SHUFFLE -// RUN: iree-compile %s --iree-hal-target-backends=cuda \ /// -/// FIXME: This cannot be retired yet as there is some writeonly vs readwrite -/// issue and we even end up emitting out of bounds accesses. +/// FIXME: Fused codegen must be used with the custom dispatch region formation +/// because IREE's pulls in tensor.empty by default. +/// This results in threadprivate allocations and prevents vector distribution. /// +// RUN: iree-compile %s --iree-hal-target-backends=cuda \ // RUN: --iree-flow-dispatch-use-transform-dialect=%p/softmax_dispatch_spec.mlir \ // RUN: --iree-codegen-llvmgpu-use-transform-dialect=%p/softmax_fused_codegen_spec.mlir | \ // RUN: iree-run-module --entry_function=max_sub_exp --device=cuda | \ diff --git a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir index 4607570487ce..e3f06381cc65 100644 --- a/tests/transform_dialect/cuda/softmax_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_codegen_spec.mlir @@ -10,7 +10,7 @@ transform.structured.canonicalized_sequence failures(propagate) { %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op %red = transform.structured.match interface{LinalgOp} attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %not_root = merge_handles %fill, %red + %not_root = merge_handles %fill, %red : !pdl.operation %foreach_thread, %tiled_generic = transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 4] transform.structured.fuse_into_containing_op %not_root into %foreach_thread diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir index 67d896ad8eac..d4235a5b9420 100644 --- a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir +++ b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir @@ -1,8 +1,7 @@ // RUN: iree-opt %s // Codegen -transform.structured.canonicalized_sequence failures(propagate) { -// transform.sequence %arg0 failures(propagate) { +transform.structured.canonicalized_sequence failures(suppress) { ^bb1(%variant_op: !pdl.operation): // First level of tiling + fusion parallelizes to blocks. // The mapping to block ids can only happen after bufferization atm @@ -11,12 +10,18 @@ transform.structured.canonicalized_sequence failures(propagate) { %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op %red = transform.structured.match interface{LinalgOp} attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %not_root = merge_handles %fill, %red + %not_root = merge_handles %fill, %red : !pdl.operation + // This must be used with the custom dispatch region formation because IREE's + // pulls in tensor.empty by default. This results in threadprivate allocations + // and prevents vector distribution down the line. 
%foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %root tile_sizes [1, 1] (mapped to dims [0, 1, 2]) + // %foreach_thread, %tiled_generic = + // transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 1] + // (mapped to dims [0, 1, 2]) transform.structured.fuse_into_containing_op %not_root into %foreach_thread - + // Second level of tiling + fusion parallelizes to threads. // Leaving the reduction untiled on threadIdx.x makes it sequential on // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is @@ -25,7 +30,7 @@ transform.structured.canonicalized_sequence failures(propagate) { %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op %reduction_linalg = transform.structured.match ops{["linalg.generic"]} attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %not_root_2 = merge_handles %fill_linalg, %reduction_linalg + %not_root_2 = merge_handles %fill_linalg, %reduction_linalg : !pdl.operation %parallel_linalg = transform.structured.match ops{["linalg.generic"]} attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op %foreach_thread_2, %parallel_linalg_2 = diff --git a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken b/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken deleted file mode 100644 index 68f891ed38f1..000000000000 --- a/tests/transform_dialect/cuda/softmax_fused_codegen_spec.mlir.broken +++ /dev/null @@ -1,57 +0,0 @@ -// RUN: iree-opt %s - -// Codegen -transform.structured.canonicalized_sequence failures(propagate) { -// transform.sequence %arg0 failures(propagate) { -^bb1(%variant_op: !pdl.operation): - // First level of tiling + fusion parallelizes to blocks. - // The mapping to block ids can only happen after bufferization atm - %root = transform.structured.match interface{LinalgOp} - attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op - %fill = transform.structured.match ops{["linalg.fill"]} in %variant_op - %red = transform.structured.match interface{LinalgOp} - attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %not_root = merge_handles %fill, %red - %foreach_thread, %tiled_generic = - transform.iree.tile_to_foreach_thread_and_workgroup_count_region %root tile_sizes [1, 1] - (mapped to dims [0, 1, 2]) - transform.structured.fuse_into_containing_op %not_root into %foreach_thread - - // Second level of tiling + fusion parallelizes to threads. - // Leaving the reduction untiled on threadIdx.x makes it sequential on - // threadIdx.x. After distribution, predication by if (threadIdx.x == 0) is - // introduced and opportunities for distributing vector ops across warps - // appear. - %fill_linalg = transform.structured.match ops{["linalg.fill"]} in %variant_op - %reduction_linalg = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = ["parallel", "parallel", "reduction"]} in %variant_op - %not_root_2 = merge_handles %fill_linalg, %reduction_linalg - %parallel_linalg = transform.structured.match ops{["linalg.generic"]} - attributes{iterator_types = ["parallel", "parallel", "parallel"]} in %variant_op - %foreach_thread_2, %parallel_linalg_2 = - transform.structured.tile_to_foreach_thread_op %parallel_linalg tile_sizes [1, 1, 0] - (mapped to dims [2, 1, 0]) - transform.structured.fuse_into_containing_op %not_root_2 into %foreach_thread_2 - - // Rank-reduce and vectorize. 
- %func = transform.structured.match ops{["func.func"]} in %variant_op - %funcx = transform.iree.apply_patterns %func { rank_reducing } - transform.structured.vectorize %funcx - - // Bufferization is necessary for: - // 1. lowering scf.foreach_thread to workgroup (block level parallelism) - // 2. lowering scf.foreach_thread to gpu (thread level parallelism) - // 3. introducing predication (due to 1. + 2.) which enables rewriting to - // warp_execute_on_lane_0 and later vector distribution. - %variant_op_2 = transform.iree.bufferize { target_gpu } %variant_op - %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_2 - %func_3 = transform.iree.foreach_thread_to_workgroup %func_2 - transform.iree.map_nested_foreach_thread_to_gpu_threads %func_3 - { workgroup_size = [32, 1, 1] } - - // Vector distribution needs to happen on buffers. - %end_func = transform.structured.match ops{["func.func"]} in %variant_op_2 - %if_op = transform.structured.match ops{["scf.if"]} in %variant_op_2 - %warp = transform.iree.vector.to_warp_execute_on_lane_0 %if_op { warp_size = 32 } - transform.iree.vector.warp_distribute %end_func -}
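Note on the diagnostic idiom used above (this note and the sketch are editorial context, not part of the patch): the recurring change in CommonExtensions.cpp and LLVMGPUExtensions.cpp replaces op->emitOpError(...) plus DiagnosedSilenceableFailure(reportUnknownTransformError(...)) with the mlir::emitDefiniteFailure / mlir::emitSilenceableFailure helpers, and it keeps populating results on the silenceable path. The sketch below only illustrates that idiom under stated assumptions; MyTransformOp, the condition checks, and the exact header set are hypothetical and do not come from this patch.

// Illustrative sketch only, assuming an MLIR snapshot (circa late 2022) that
// provides DiagnosedSilenceableFailure and the emitDefiniteFailure /
// emitSilenceableFailure helpers used throughout the hunks above.
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical transform op; stands in for ops such as ApplyPatternsOp above.
DiagnosedSilenceableFailure
MyTransformOp::applyToOne(Operation *target,
                          SmallVectorImpl<Operation *> &results,
                          transform::TransformState &state) {
  // Unrecoverable problem: a definite failure aborts the transform script
  // regardless of the enclosing sequence's failure-propagation mode.
  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>())
    return mlir::emitDefiniteFailure(
        target, "expected an isolated-from-above payload op");

  // Recoverable problem: a silenceable failure can be ignored by the caller,
  // but the op's results must still be set (here with a null placeholder),
  // mirroring ForeachThreadToWorkgroupOp in the patch.
  if (target->getNumRegions() == 0) {
    results.assign(1, nullptr);
    return mlir::emitSilenceableFailure(target, "target has no region");
  }

  results.assign({target});
  return DiagnosedSilenceableFailure::success();
}

This pairs with the spec changes in the patch that switch transform.structured.canonicalized_sequence from failures(propagate) to failures(suppress): under suppress, silenceable failures no longer abort the whole sequence, while definite failures still do.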