Commit 710e016

added a flag to optionally enable fast math
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
AlexandreEichenberger committed Oct 10, 2024
1 parent 4fea80f commit 710e016
Showing 12 changed files with 175 additions and 34 deletions.
16 changes: 15 additions & 1 deletion src/Compiler/CompilerOptions.cpp
@@ -73,6 +73,7 @@ int onnxOpTransformThreshold; // onnx-mlir only
bool onnxOpTransformReport; // onnx-mlir only
bool enableParallel; // onnx-mlir only
bool disableSimdOption; // onnx-mlir only
bool enableFastMathOption; // onnx-mlir only
bool disableRecomposeOption; // onnx-mlir only
bool enableSimdDataLayout; // onnx-mlir only
bool verifyInputTensors; // onnx-mlir only
@@ -516,6 +517,13 @@ static llvm::cl::opt<bool, true> disableSimdOptionOpt("disable-simd",
llvm::cl::location(disableSimdOption), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> enableFastMathOptionOpt("enable-fast-math",
llvm::cl::desc(
"Enable fast math optimizations (default=false). Set to `true` "
"to enable fast math options at O3."),
llvm::cl::location(enableFastMathOption), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));

static llvm::cl::opt<bool, true> enableSimdDataLayoutOpt("simd-data-layout",
llvm::cl::desc("Enable SIMD optimization for convolution (default=false)\n"
"Set to 'true' if you want to enable SIMD optimizations."),
@@ -1249,8 +1257,14 @@ void initCompilerConfig() {

// Enable aggressive optimization for NNPA with -O3
if (OptimizationLevel == OptLevel::O3 &&
getTargetAccel().find("NNPA") != std::string::npos &&
getTargetAccel().find("NNPA") != std::string::npos) {
// Have O3 and NNPA. May enable fast math default in the future.
}

// Enabling unsafe math.
if (enableFastMathOption &&
getLLVMOption().find("enable-unsafe-fp-math") == std::string::npos) {
// Fast math option is enabled (in general)
setLLVMOption(getLLVMOption() + " --enable-unsafe-fp-math");
}
}
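Note: the new flag follows LLVM's external-storage cl::opt pattern, where parsing --enable-fast-math flips a plain bool (enableFastMathOption) that initCompilerConfig() later reads to append --enable-unsafe-fp-math to the LLVM options. A minimal standalone sketch of that pattern (a hypothetical demo program with its own names, not onnx-mlir code):

#include "llvm/Support/CommandLine.h"

// Plain global that the rest of the program reads, mirroring enableFastMathOption.
static bool enableFastMathDemo;

// cl::opt<bool, true> stores the parsed value into the external bool above.
static llvm::cl::opt<bool, true> enableFastMathDemoOpt("enable-fast-math",
    llvm::cl::desc("Enable fast math optimizations (default=false)."),
    llvm::cl::location(enableFastMathDemo), llvm::cl::init(false));

int main(int argc, char **argv) {
  llvm::cl::ParseCommandLineOptions(argc, argv, "external-storage cl::opt demo\n");
  return enableFastMathDemo ? 0 : 1; // e.g. ./demo --enable-fast-math exits 0
}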
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -118,6 +118,7 @@ extern int onnxOpTransformThreshold; // onnx-mlir only
extern bool onnxOpTransformReport; // onnx-mlir only
extern bool enableParallel; // onnx-mlir only
extern bool disableSimdOption; // onnx-mlir only
extern bool enableFastMathOption; // onnx-mlir only
extern bool disableRecomposeOption; // onnx-mlir only
extern bool enableSimdDataLayout; // onnx-mlir only
extern bool verifyInputTensors; // onnx-mlir only
1 change: 1 addition & 0 deletions src/Compiler/CompilerPasses.cpp
@@ -189,6 +189,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE,
instrumentSignatureString));
pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
/*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel,
/*enableFastMath*/ optLevel >= 3 && enableFastMathOption,
/*opsToCall*/ opsForCall));
// An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization
17 changes: 10 additions & 7 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -190,7 +190,7 @@ std::map<std::string, std::string> ONNXEntryPointLowering::typeMap = {
void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx, DimAnalysis *dimAnalysis,
bool enableTiling, bool enableSIMD, bool enableParallel,
std::string opsForCall) {
bool enableFastMath, std::string opsForCall) {
// clang-format off
// Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is a ranked tensor.
@@ -224,8 +224,8 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
// ObjectDetection
populateLoweringONNXNonMaxSuppressionOpPattern(patterns, typeConverter, ctx);
// Quantization
populateLoweringONNXDynamicQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel);
populateLoweringONNXQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel);
populateLoweringONNXDynamicQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel, enableFastMath);
populateLoweringONNXQuantizeLinearOpPattern(patterns, typeConverter, ctx, enableSIMD, enableParallel, enableFastMath);
// Tensor
populateLoweringONNXArgMinMaxOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXDimOpPattern(patterns, typeConverter, ctx);
@@ -309,12 +309,13 @@ struct FrontendToKrnlLoweringPass
FrontendToKrnlLoweringPass(const FrontendToKrnlLoweringPass &pass)
: PassWrapper<FrontendToKrnlLoweringPass, OperationPass<ModuleOp>>() {}
FrontendToKrnlLoweringPass(bool enableTiling, bool enableSIMD,
bool enableParallel, std::string opsForCall) {
bool enableParallel, bool enableFastMath, std::string opsForCall) {
// Below, need explicit assignment to enable implicit conversion of bool to
// Option<bool>.
this->enableTiling = enableTiling;
this->enableSIMD = enableSIMD;
this->enableParallel = enableParallel;
this->enableFastMath = enableFastMath;
this->opsForCall = opsForCall;
}

@@ -343,6 +344,8 @@ struct FrontendToKrnlLoweringPass
llvm::cl::desc("Enable SIMD code gen"), llvm::cl::init(false)};
Option<bool> enableParallel{*this, "enable-parallel",
llvm::cl::desc("Enable parallelization"), llvm::cl::init(false)};
Option<bool> enableFastMath{*this, "enable-fast-math",
llvm::cl::desc("Enable fast math optimizations"), llvm::cl::init(false)};
Option<std::string> opsForCall{*this, "ops-for-call",
llvm::cl::desc("Specify ops to be lowered to krnl.call"),
llvm::cl::init("")};
@@ -430,7 +433,7 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
// Define patterns.
populateONNXToKrnlConversionPattern(patterns, krnlTypeConverter,
&getContext(), dimAnalysis, enableTiling, enableSIMD, enableParallel,
opsForCall);
enableFastMath, opsForCall);

// Rewrite patterns for accelerators.
for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators())
@@ -450,9 +453,9 @@ std::unique_ptr<Pass> createLowerToKrnlPass() {
}

std::unique_ptr<Pass> createLowerToKrnlPass(bool enableTiling, bool enableSIMD,
bool enableParallel, std::string opsForCall) {
bool enableParallel, bool enableFastMath, std::string opsForCall) {
return std::make_unique<FrontendToKrnlLoweringPass>(
enableTiling, enableSIMD, enableParallel, opsForCall);
enableTiling, enableSIMD, enableParallel, enableFastMath, opsForCall);
}

//===----------------------------------------------------------------------===//
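Note: a driver that builds its own pipeline would thread the new parameter through createLowerToKrnlPass roughly as below. This is a hedged sketch; the helper name lowerWithFastMath and the hard-coded flag values are this example's own, not part of the commit.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "src/Pass/Passes.hpp"

// Lower a loaded module with fast math requested (illustrative only).
bool lowerWithFastMath(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ true,
      /*enableSIMD*/ true, /*enableParallel*/ false,
      /*enableFastMath*/ true, /*opsForCall*/ ""));
  return mlir::succeeded(pm.run(module));
}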
6 changes: 3 additions & 3 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -308,7 +308,7 @@ class KrnlTypeConverter : public mlir::TypeConverter {
// For all ONNX operations.
void populateONNXToKrnlConversionPattern(mlir::RewritePatternSet &,
mlir::TypeConverter &, mlir::MLIRContext *, bool enableTiling,
bool enableParallel);
bool enableParallel, bool enableFastMath);

// `ControlFlow` directory methods:
void populateLoweringONNXIfOpPattern(
@@ -380,10 +380,10 @@ void populateLoweringONNXNonMaxSuppressionOpPattern(
// `Quantization` directory methods:
void populateLoweringONNXDynamicQuantizeLinearOpPattern(
mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *,
bool enableSIMD, bool enableParallel);
bool enableSIMD, bool enableParallel, bool enableFastMath);
void populateLoweringONNXQuantizeLinearOpPattern(mlir::RewritePatternSet &,
mlir::TypeConverter &, mlir::MLIRContext *, bool enableSIMD,
bool enableParallel);
bool enableParallel, bool enableFastMath);

// `RNN` directory methods:
void populateLoweringONNXGRUOpPattern(
12 changes: 7 additions & 5 deletions src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp
@@ -78,12 +78,14 @@ void emitDynamicQuantizationLinearScalarParameters(
struct ONNXDynamicQuantizeLinearOpLowering
: public OpConversionPattern<ONNXDynamicQuantizeLinearOp> {
ONNXDynamicQuantizeLinearOpLowering(TypeConverter &typeConverter,
MLIRContext *ctx, bool enableSIMD, bool enableParallel)
MLIRContext *ctx, bool enableSIMD, bool enableParallel,
bool enableFastMath)
: OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD),
enableParallel(enableParallel) {}
enableParallel(enableParallel), enableFastMath(enableFastMath) {}

bool enableSIMD = false;
bool enableParallel = false;
bool enableFastMath = false;

using LocalDialectBuilder = MultiDialectBuilder<KrnlBuilder,
IndexExprBuilderForKrnl, MathBuilder, MemRefBuilder>;
@@ -137,7 +139,7 @@ struct ONNXDynamicQuantizeLinearOpLowering
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale,
zeroPoint, wantZeroPoint /*wanted one, so we have a zero point*/,
enableSIMD, enableParallel);
enableSIMD, enableParallel, enableFastMath);

rewriter.replaceOp(op, {Y, YScale, YZeroPoint});
onnxToKrnlSimdReport(op);
@@ -147,9 +149,9 @@

void populateLoweringONNXDynamicQuantizeLinearOpPattern(
RewritePatternSet &patterns, TypeConverter &typeConverter, MLIRContext *ctx,
bool enableSIMD, bool enableParallel) {
bool enableSIMD, bool enableParallel, bool enableFastMath) {
patterns.insert<ONNXDynamicQuantizeLinearOpLowering>(
typeConverter, ctx, enableSIMD, enableParallel);
typeConverter, ctx, enableSIMD, enableParallel, enableFastMath);
}

} // namespace onnx_mlir
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp
@@ -17,14 +17,15 @@
namespace onnx_mlir {

// Given an input, scale, zero point, qMin, and qMax, perform a linear
// quantization and store in alloc.
// quantization and store in alloc. FastMath enables taking the reciprocal for
// faster results on machines where mul is faster than div.
void emitQuantizationLinearScalarParameters(
mlir::ConversionPatternRewriter &rewriter, mlir::Location loc,
mlir::Operation *op, mlir::MemRefType inputType,
mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims,
mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale,
mlir::Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
bool enableParallel);
bool enableParallel, bool enableFastMath);
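Note: the reciprocal trade-off described in the comment above, written out as a plain scalar sketch. This is an assumed illustration, not the actual SIMD lowering; quantizeOne and the int8 bounds are this example's own choices.

#include <algorithm>
#include <cmath>
#include <cstdint>

// With fast math, the per-element division by scale becomes a multiply by a
// reciprocal; the real lowering hoists 1/scale out of the SIMD loop.
int8_t quantizeOne(float x, float scale, float zeroPoint, bool enableFastMath) {
  float scaleReciprocal = 1.0f / scale;
  float scaled = enableFastMath ? x * scaleReciprocal : x / scale;
  float rounded = std::nearbyint(scaled); // rounds half to even under the default FP mode
  float adjusted = rounded + zeroPoint;
  return static_cast<int8_t>(std::clamp(adjusted, -128.0f, 127.0f)); // qMin..qMax for int8
}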

// Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin
// and qMax.
30 changes: 16 additions & 14 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -21,7 +21,7 @@

using namespace mlir;

#define DISABLE_FAST_MATH_FOR_QL 0 /* disable reciprocal (for debug) */
#define DISABLE_FAST_MATH 0 /* disable reciprocal (for debug) */

namespace onnx_mlir {

@@ -30,7 +30,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType,
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
bool enableParallel) {
bool enableParallel, bool enableFastMath) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, VectorBuilder, MathBuilder>
create(rewriter, loc);

@@ -77,12 +77,13 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
DimsExpr outputAF;
outputAF.emplace_back(zero);

Value oneOverScale;
bool useOneOverScale =
!DISABLE_FAST_MATH_FOR_QL && isa<FloatType>(inputElementType);
if (useOneOverScale) {
Value scaleReciprocal;
bool useReciprocal =
!DISABLE_FAST_MATH && enableFastMath && isa<FloatType>(inputElementType);
fprintf(stderr, "hi alex, use reciprocal %d\n", (int)useReciprocal);
if (useReciprocal) {
Value one = create.math.constant(inputElementType, 1.0);
oneOverScale = create.math.div(one, scale);
scaleReciprocal = create.math.div(one, scale);
}
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
@@ -91,8 +92,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
Value x = inputVals[0];
// Scale
Value scaleX;
if (useOneOverScale)
scaleX = create.math.mul(x, oneOverScale);
if (useReciprocal)
scaleX = create.math.mul(x, scaleReciprocal);
else
scaleX = create.math.div(x, scale);
// Round
@@ -120,12 +121,13 @@
struct ONNXQuantizeLinearOpLowering
: public OpConversionPattern<ONNXQuantizeLinearOp> {
ONNXQuantizeLinearOpLowering(TypeConverter &typeConverter, MLIRContext *ctx,
bool enableSIMD, bool enableParallel)
bool enableSIMD, bool enableParallel, bool enableFastMath)
: OpConversionPattern(typeConverter, ctx), enableSIMD(enableSIMD),
enableParallel(enableParallel) {}
enableParallel(enableParallel), enableFastMath(enableFastMath) {}

bool enableSIMD = false;
bool enableParallel = false;
bool enableFastMath = false;

using LocalDialectBuilder = MultiDialectBuilder<KrnlBuilder,
IndexExprBuilderForKrnl, MathBuilder, MemRefBuilder>;
@@ -201,7 +203,7 @@ struct ONNXQuantizeLinearOpLowering
}
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale,
zeroPoint, hasZeroPoint, enableSIMD, enableParallel);
zeroPoint, hasZeroPoint, enableSIMD, enableParallel, enableFastMath);

rewriter.replaceOp(op, {Y});
onnxToKrnlSimdReport(op);
@@ -211,9 +213,9 @@

void populateLoweringONNXQuantizeLinearOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx, bool enableSIMD,
bool enableParallel) {
bool enableParallel, bool enableFastMath) {
patterns.insert<ONNXQuantizeLinearOpLowering>(
typeConverter, ctx, enableSIMD, enableParallel);
typeConverter, ctx, enableSIMD, enableParallel, enableFastMath);
}

} // namespace onnx_mlir
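Note: the only numerical change fast math introduces here is computing x * (1/scale) instead of x / scale, which can drift by roughly a ulp and only matters when the scaled value lands near a rounding boundary. A throwaway check (not part of the commit) that makes any drift visible:

#include <cstdio>
#include <initializer_list>

int main() {
  float s = 0.1f;     // example quantization scale
  float r = 1.0f / s; // hoisted reciprocal
  for (float x : {0.05f, 1.25f, 3.3f}) {
    float d = x / s, m = x * r;
    std::printf("x=%.2f  div=%.9g  mul=%.9g  delta=%g\n", x, d, m, d - m);
  }
  return 0;
}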
1 change: 0 additions & 1 deletion src/Dialect/Krnl/DialectBuilder.cpp
@@ -355,7 +355,6 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name,
Value KrnlBuilder::roundEven(Value input) const {
Type elementType = getElementTypeOrSelf(input.getType());
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
// hi alex, may want to generalize support to scalar as well.
VectorType vecType = mlir::dyn_cast<VectorType>(input.getType());
if (VectorMachineSupport::requireCustomASM(
GenericOps::roundEvenGop, elementType)) {
3 changes: 2 additions & 1 deletion src/Pass/Passes.hpp
@@ -86,7 +86,8 @@ std::unique_ptr<mlir::Pass> createONNXPreKrnlVerifyPass();
/// Add pass for lowering to Krnl IR.
std::unique_ptr<mlir::Pass> createLowerToKrnlPass();
std::unique_ptr<mlir::Pass> createLowerToKrnlPass(bool enableTiling,
bool enableSIMD, bool enableParallel, std::string opsForCall);
bool enableSIMD, bool enableParallel, bool enableFastMath,
std::string opsForCall);
void configureOnnxToKrnlLoweringPass(bool reportOnParallel,
bool parallelIsEnabled, std::string specificParallelOps, bool reportOnSimd,
bool simdIsEnabled);
3 changes: 3 additions & 0 deletions src/Tools/onnx-mlir-opt/RegisterPasses.cpp
@@ -86,10 +86,13 @@ void registerOMPasses(int optLevel) {
return krnl::createConvertKrnlToAffinePass();
});

fprintf(stderr, "hi alex, init lower krnl from here\n");
// hi alex, refine this? use disable/enable?
mlir::registerPass([optLevel]() -> std::unique_ptr<mlir::Pass> {
return createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
/*enableSIMD, should consider disableSimdOption*/ optLevel >= 3,
/*enableParallel*/ false,
/*enableFastMath*/ false, /*default is still off*/
/*opsForCall*/ "");
});
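Note: the registration above hard-codes enableFastMath to false, and the "refine this?" remark suggests it may later honor the global options. One possible refinement is sketched below; it is an assumption, not part of the commit, and it presumes the CompilerOptions globals (disableSimdOption, enableParallel, enableFastMathOption) are visible in this file.

// Hypothetical refinement (not in this commit): derive the pass flags from the
// global compiler options instead of hard-coding them.
mlir::registerPass([optLevel]() -> std::unique_ptr<mlir::Pass> {
  return createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3,
      /*enableSIMD*/ optLevel >= 3 && !disableSimdOption,
      /*enableParallel*/ enableParallel,
      /*enableFastMath*/ optLevel >= 3 && enableFastMathOption,
      /*opsForCall*/ "");
});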
