diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt index 879bc28d017a35..6c0cec4eae4680 100644 --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen MLIRMathToFuncs MLIRMathToLLVM MLIRMathToLibm + MLIRMathToROCDL MLIROpenMPToLLVM MLIRBuiltinToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 21154902d23f8f..185ab78ea4a379 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -34,6 +34,7 @@ #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -3559,6 +3560,14 @@ class FIRToLLVMLowering // as passes here. mlir::OpPassManager mathConvertionPM("builtin.module"); + bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN(); + // If compiling for AMD target some math operations must be lowered to AMD + // GPU library calls, the rest can be converted to LLVM intrinsics, which + // is handled in the mathToLLVM conversion. The lowering to libm calls is + // not needed since all math operations are handled this way. + if (isAMDGCN) + mathConvertionPM.addPass(mlir::createConvertMathToROCDL()); + // Convert math::FPowI operations to inline implementation // only if the exponent's width is greater than 32, otherwise, // it will be lowered to LLVM intrinsic operation by a later conversion. @@ -3598,7 +3607,8 @@ class FIRToLLVMLowering pattern); // Math operations that have not been converted yet must be converted // to Libm. - mlir::populateMathToLibmConversionPatterns(pattern); + if (!isAMDGCN) + mlir::populateMathToLibmConversionPatterns(pattern); mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern); mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern); diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h new file mode 100644 index 00000000000000..04e148422487c3 --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -0,0 +1,21 @@ +//===- MathToROCDL.h - Utils to convert from the complex dialect --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ +#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ + +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMATHTOROCDL +#include "mlir/Conversion/Passes.h.inc" + +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 2179ae18ac074b..2dc6c0b4fd1241 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -44,6 +44,7 @@ #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index d094ee3b36ab95..eb2ca172110b21 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -707,6 +707,22 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// MathToROCDL +//===----------------------------------------------------------------------===// + +def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { + let summary = "Convert Math dialect to rocdl calls"; + let description = [{ + This pass converts supported Math ops to rocdl calls. + }]; + let dependentDialects = [ + "func::FuncDialect", + "math::MathDialect", + "vector::VectorDialect", + ]; +} + //===----------------------------------------------------------------------===// // MathToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 41ab7046b91ce3..f2aa98d6130811 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -34,6 +34,7 @@ add_subdirectory(LLVMCommon) add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) +add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt new file mode 100644 index 00000000000000..2771955aa94939 --- /dev/null +++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_conversion_library(MLIRMathToROCDL + MathToROCDL.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRFuncDialect + MLIRGPUToGPURuntimeTransforms + MLIRMathDialect + MLIRLLVMCommonConversion + MLIRPass + MLIRTransformUtils + MLIRVectorDialect + MLIRVectorUtils + ) diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp new file mode 100644 index 00000000000000..310cc834f40179 --- /dev/null +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -0,0 +1,147 @@ +//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "../GPUCommon/GPUOpsLowering.h" +#include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOROCDL +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-rocdl" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +template +static void populateOpPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns, StringRef f32Func, + StringRef f64Func) { + patterns.add>(converter); + patterns.add>(converter, f32Func, f64Func); +} + +static void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns) { + // Handled by mathToLLVM: math::AbsIOp + // Handled by mathToLLVM: math::CopySignOp + // Handled by mathToLLVM: math::CountLeadingZerosOp + // Handled by mathToLLVM: math::CountTrailingZerosOp + // Handled by mathToLLVM: math::CgPopOp + // Handled by mathToLLVM: math::FmaOp + // FIXME: math::IPowIOp + // FIXME: math::FPowIOp + // Handled by mathToLLVM: math::RoundEvenOp + // Handled by mathToLLVM: math::RoundOp + // Handled by mathToLLVM: math::TruncOp + + populateOpPatterns(converter, patterns, "__ocml_fabs_f32", + "__ocml_fabs_f64"); + populateOpPatterns(converter, patterns, "__ocml_acos_f32", + "__ocml_acos_f64"); + populateOpPatterns(converter, patterns, "__ocml_acosh_f32", + "__ocml_acosh_f64"); + populateOpPatterns(converter, patterns, "__ocml_asin_f32", + "__ocml_asin_f64"); + populateOpPatterns(converter, patterns, "__ocml_asinh_f32", + "__ocml_asinh_f64"); + populateOpPatterns(converter, patterns, "__ocml_atan_f32", + "__ocml_atan_f64"); + populateOpPatterns(converter, patterns, "__ocml_atanh_f32", + "__ocml_atanh_f64"); + populateOpPatterns(converter, patterns, "__ocml_atan2_f32", + "__ocml_atan2_f64"); + populateOpPatterns(converter, patterns, "__ocml_cbrt_f32", + "__ocml_cbrt_f64"); + populateOpPatterns(converter, patterns, "__ocml_ceil_f32", + "__ocml_ceil_f64"); + populateOpPatterns(converter, patterns, "__ocml_cos_f32", + "__ocml_cos_f64"); + populateOpPatterns(converter, patterns, "__ocml_cosh_f32", + "__ocml_cosh_f64"); + populateOpPatterns(converter, patterns, "__ocml_sinh_f32", + "__ocml_sinh_f64"); + populateOpPatterns(converter, patterns, "__ocml_exp_f32", + "__ocml_exp_f64"); + populateOpPatterns(converter, patterns, "__ocml_exp2_f32", + "__ocml_exp2_f64"); + populateOpPatterns(converter, patterns, "__ocml_expm1_f32", + "__ocml_expm1_f64"); + populateOpPatterns(converter, patterns, "__ocml_floor_f32", + "__ocml_floor_f64"); + // FIXME: Different pass or new op in math? + // populateOpPatterns(converter, patterns, "__ocml_fmod_f32", + // "__ocml_fmod_f64"); + populateOpPatterns(converter, patterns, "__ocml_log_f32", + "__ocml_log_f64"); + populateOpPatterns(converter, patterns, "__ocml_log10_f32", + "__ocml_log10_f64"); + populateOpPatterns(converter, patterns, "__ocml_log1p_f32", + "__ocml_log1p_f64"); + populateOpPatterns(converter, patterns, "__ocml_log2_f32", + "__ocml_log2_f64"); + populateOpPatterns(converter, patterns, "__ocml_pow_f32", + "__ocml_pow_f64"); + populateOpPatterns(converter, patterns, "__ocml_rsqrt_f32", + "__ocml_rsqrt_f64"); + populateOpPatterns(converter, patterns, "__ocml_sin_f32", + "__ocml_sin_f64"); + populateOpPatterns(converter, patterns, "__ocml_sqrt_f32", + "__ocml_sqrt_f64"); + populateOpPatterns(converter, patterns, "__ocml_tanh_f32", + "__ocml_tanh_f64"); + populateOpPatterns(converter, patterns, "__ocml_tan_f32", + "__ocml_tan_f64"); + populateOpPatterns(converter, patterns, "__ocml_erf_f32", + "__ocml_erf_f64"); +} + +namespace { +struct ConvertMathToROCDLPass + : public impl::ConvertMathToROCDLBase { + ConvertMathToROCDLPass() = default; + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToROCDLPass::runOnOperation() { + auto m = getOperation(); + MLIRContext *ctx = m.getContext(); + + + RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions options(ctx, DataLayout(m)); + LLVMTypeConverter converter(ctx, options); + populateMathToROCDLConversionPatterns(converter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(m, target, std::move(patterns)))) + signalPassFailure(); +}