Steer round to even to Z's hardware operation (for Z), and MLIR/LLVM roundEven for the other platforms. #2970

Merged
merged 30 commits into from
Oct 10, 2024
Changes from 25 commits
e5d44b9
use one over scale for quantize linear
AlexandreEichenberger Oct 1, 2024
dcd0fd4
add krnl.round
AlexandreEichenberger Oct 1, 2024
23beab9
update
AlexandreEichenberger Oct 3, 2024
ac301f0
update
AlexandreEichenberger Oct 3, 2024
cf2bd9e
fix lit test because of 1/scale usage
AlexandreEichenberger Oct 3, 2024
17cff60
try to use roundEven
AlexandreEichenberger Oct 4, 2024
b85672b
update
AlexandreEichenberger Oct 4, 2024
dea818e
update
AlexandreEichenberger Oct 7, 2024
5ac6305
first attempt with hw roundeven
AlexandreEichenberger Oct 8, 2024
c7a39f5
cleaner interface
AlexandreEichenberger Oct 8, 2024
0c0a908
use shapecast
AlexandreEichenberger Oct 8, 2024
b128e6e
added lowering pattern for shapecast
AlexandreEichenberger Oct 8, 2024
a022acf
fix merge issue, add reg pressure to quant lin
AlexandreEichenberger Oct 8, 2024
e41eb81
update
AlexandreEichenberger Oct 8, 2024
e08351f
update
AlexandreEichenberger Oct 8, 2024
b8ebc2e
enable scalar fiebr
AlexandreEichenberger Oct 8, 2024
2d25e55
update
AlexandreEichenberger Oct 8, 2024
5f1999c
update
AlexandreEichenberger Oct 8, 2024
c6239bd
update
AlexandreEichenberger Oct 8, 2024
9cd309b
cleanup
AlexandreEichenberger Oct 8, 2024
4f92d93
cleanup
AlexandreEichenberger Oct 8, 2024
e938dd0
fix lit tests
AlexandreEichenberger Oct 9, 2024
02e2f96
added doc for new Krnl op
AlexandreEichenberger Oct 9, 2024
f620eb3
format
AlexandreEichenberger Oct 9, 2024
42040e7
Merge branch 'main' into round-opt-v2
AlexandreEichenberger Oct 9, 2024
887ccd5
update
AlexandreEichenberger Oct 9, 2024
5b6b867
Merge branch 'round-opt-v2' of https://github.com/AlexandreEichenberg…
AlexandreEichenberger Oct 9, 2024
23031aa
respond to comments
AlexandreEichenberger Oct 9, 2024
3977369
reverted to using round even with emulation due to a macOS issue
AlexandreEichenberger Oct 9, 2024
9622cee
also reverted the non-z16 tests
AlexandreEichenberger Oct 9, 2024
25 changes: 25 additions & 0 deletions docs/Dialects/krnl.md
Original file line number Diff line number Diff line change
@@ -191,6 +191,12 @@ Interfaces: `MemoryEffectOpInterface`
| :-----: | ----------- |
| `parameters` | variadic of any type

#### Results:

| Result | Description |
| :----: | ----------- |
| `returnValue` | variadic of floating-point or integer

### `krnl.copy_from_tile_buffer` (KrnlCopyFromBufferOp)

_Copy from buffer._
@@ -1193,6 +1199,25 @@ create a new memref inside the region and use it outside of the region.

Traits: `AffineScope`, `NoTerminator`, `SingleBlock`

### `krnl.round_even` (KrnlRoundEvenOp)

_Krnl round to nearest even operation_

Krnl round to nearest even operation. Accepts scalar or vector float values.
A vector must be 1D, with a size that is a multiple of the hardware vector size.

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `in` | floating-point-like

#### Results:

| Result | Description |
| :----: | ----------- |
| `out` | floating-point-like
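
To make the rounding rule concrete: ties go to the nearest even integer, so 0.5 rounds to 0 while 1.5 and 2.5 both round to 2. The following is a minimal C++ sketch of these semantics, not the op's lowering. It assumes the default floating-point environment (`FE_TONEAREST`), in which `std::nearbyint` rounds ties to even. `roundEvenRef` is a hypothetical helper name and is not part of onnx-mlir.

```cpp
#include <cassert>
#include <cfenv>
#include <cmath>

// Reference semantics for round-to-nearest-even on a single fp32 value.
// Assumption: the current rounding mode is FE_TONEAREST (the C++ default),
// under which std::nearbyint resolves ties toward the even neighbor.
float roundEvenRef(float x) {
  assert(std::fegetround() == FE_TONEAREST && "expected default rounding mode");
  return std::nearbyint(x);
}
```

Under that assumption, `roundEvenRef(2.5f)` and `roundEvenRef(1.5f)` both yield 2, which is the behavior the op description refers to.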

### `krnl.seqalloc` (KrnlSeqAllocOp)

_Krnl create a sequence_
1 change: 1 addition & 0 deletions src/Conversion/KrnlToLLVM/CMakeLists.txt
@@ -12,6 +12,7 @@ add_onnx_mlir_library(OMKrnlToLLVM
KrnlPrintTensor.cpp
KrnlPrint.cpp
KrnlRandomNormal.cpp
KrnlRoundEven.cpp
KrnlStrlen.cpp
KrnlStrncmp.cpp
KrnlToLLVMHelper.cpp
2 changes: 2 additions & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
@@ -198,6 +198,7 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
patterns, vector::VectorTransformsOptions());
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorShapeCastLoweringPatterns(patterns);

populateAffineToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
@@ -971,6 +972,7 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
krnl::populateLoweringKrnlUnaryMathOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlStrncmpOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlNoneOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlRoundEvenOpPattern(typeConverter, patterns, ctx);
}

} // namespace krnl
4 changes: 4 additions & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp
@@ -107,6 +107,10 @@ void populateLoweringKrnlVectorTypeCastOpPattern(
void populateLoweringKrnlNoneOpPattern(mlir::LLVMTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);

void populateLoweringKrnlRoundEvenOpPattern(
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
mlir::MLIRContext *ctx);

void determineOwnershipForOutputOMTensors(mlir::ModuleOp &module,
llvm::SmallVectorImpl<bool> &outputOMTensorOwnerships);

115 changes: 115 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp
@@ -0,0 +1,115 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------ KrnlRoundEven.cpp - Lower KrnlRoundEvenOp ---------------------===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the KrnlRoundEvenOp operator.
//
// Currently limited to fp32 values, although the instructions support other
// data types.
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "krnl_to_llvm"

using namespace mlir;

namespace onnx_mlir {
namespace krnl {

class KrnlRoundEvenOpLowering : public ConversionPattern {
public:
explicit KrnlRoundEvenOpLowering(
LLVMTypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(
typeConverter, KrnlRoundEvenOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
KrnlRoundEvenOp::Adaptor operandAdaptor(operands);
Value input = operandAdaptor.getIn();

// Scalar or Vector?
Type inputType = input.getType();
Type inputElemType = getElementTypeOrSelf(inputType);
assert(mlir::isa<FloatType>(inputElemType) && "expected float");
int64_t inputBitWidth = inputElemType.getIntOrFloatBitWidth();
assert(inputBitWidth == 32 && "expected 32bit float");
VectorType inputVecType = mlir::dyn_cast<VectorType>(inputType);
assert(VectorMachineSupport::requireCustomASM(
GenericOps::roundEvenGop, inputElemType) &&
"expected custom requirement");
// Common between scalar and vector
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();

if (inputVecType) {
// Vector of 4 elements.
Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4);
Type vecTypeF32 = LLVM::getFixedVectorType(f32Ty, 4);
// Use integer as container for inputs.
Value inputVecI32 = create.llvm.bitcast(vecTypeI32, input);
SmallVector<Value> asmVals{inputVecI32};
// SIMD ASM round to nearest even (M5=4) op
const char *asmStr = "VFISB $0,$1,0,4";
const char *asmConstraints = "=v,v";
Value outVecI32 =
rewriter
.create<LLVM::InlineAsmOp>(loc, vecTypeI32,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/asmConstraints, /*has_side_effects=*/false,
/*is_align_stack=*/false,
/*asm_dialect=*/LLVM::AsmDialectAttr(),
/*operand_attrs=*/ArrayAttr())
.getResult(0);
// Cast output back to float.
Value outVecF32 = create.llvm.bitcast(vecTypeF32, outVecI32);
rewriter.replaceOp(op, {outVecF32});
return success();
} else {
// Scalar types.
Type typeF32 = rewriter.getF32Type();
SmallVector<Value> asmVals{input};
// Scalar ASM round to the nearest even (M3=4) op.
const char *asmStr = "FIEBR $0,4,$1";
const char *asmConstraints = "=f,f";
Value outF32 =
rewriter
.create<LLVM::InlineAsmOp>(loc, typeF32,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/asmConstraints, /*has_side_effects=*/false,
/*is_align_stack=*/false,
/*asm_dialect=*/LLVM::AsmDialectAttr(),
/*operand_attrs=*/ArrayAttr())
.getResult(0);
rewriter.replaceOp(op, {outF32});
return success();
}
llvm_unreachable("not supported");
}
};

void populateLoweringKrnlRoundEvenOpPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<KrnlRoundEvenOpLowering>(typeConverter, ctx);
}

} // namespace krnl
} // namespace onnx_mlir
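
The vector path above wraps the inline assembly in two bitcasts, using an i32 vector as the container for the f32 lanes. A bitcast only reinterprets bits and never converts values, which the portable C++ sketch below illustrates. It is a rough model under stated assumptions: `std::nearbyint` per lane (default rounding mode) stands in for the hardware instruction, and `Vec4f`, `Vec4i`, `bitcastToI32`, `bitcastToF32`, `roundEvenLanes` and `roundEvenVec4` are hypothetical names rather than onnx-mlir or LLVM APIs.

```cpp
#include <array>
#include <cmath>
#include <cstdint>
#include <cstring>

using Vec4f = std::array<float, 4>;
using Vec4i = std::array<std::int32_t, 4>;

// Reinterpret the 128 bits of a 4 x f32 vector as 4 x i32 (no value change).
Vec4i bitcastToI32(const Vec4f &v) {
  Vec4i out;
  std::memcpy(out.data(), v.data(), sizeof(out));
  return out;
}

// Reinterpret the 128 bits of a 4 x i32 vector as 4 x f32 (no value change).
Vec4f bitcastToF32(const Vec4i &v) {
  Vec4f out;
  std::memcpy(out.data(), v.data(), sizeof(out));
  return out;
}

// Stand-in for the SIMD instruction: consumes and produces the i32 container,
// rounding each f32 lane to the nearest integer with ties to even.
Vec4i roundEvenLanes(const Vec4i &container) {
  Vec4f lanes = bitcastToF32(container);
  for (float &lane : lanes)
    lane = std::nearbyint(lane);
  return bitcastToI32(lanes);
}

// Same shape as the lowering: bitcast in, apply the op, bitcast out.
Vec4f roundEvenVec4(const Vec4f &input) {
  return bitcastToF32(roundEvenLanes(bitcastToI32(input)));
}
```

Because the two bitcasts are inverses of each other, the only value-changing step in the chain is the rounding itself.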
18 changes: 11 additions & 7 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1287,21 +1287,25 @@ struct ScalarOp<ONNXRoundOp> {

template <>
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
- return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
- {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
- {GenericOps::FloorGop, 2},
- {GenericOps::EstimatedVectorRegisterPressure,
- 4 /* Little parallelism in code. */}};
+ // If using roundEven emulation, cost is as below.
+ // return {{GenericOps::ArithmeticGop, 1}, {GenericOps::MulGop, 2},
+ // {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
+ // {GenericOps::FloorGop, 2},
+ // {GenericOps::EstimatedVectorRegisterPressure,
+ // 4 /* Little parallelism in code. */}};
+
+ // Assume here that there is a hw op to handle this.
+ return {{GenericOps::ArithmeticGop, 1}};
}
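
The commented-out cost above describes an emulated round-to-even built from floor, multiply, compare and select operations. A minimal C++ sketch of one such emulation follows, to show why it is much more expensive than a single hardware op. This is an illustrative assumption, not necessarily the sequence onnx-mlir emits, and `roundEvenEmulated` is a hypothetical name.

```cpp
#include <cmath>

// Round-to-nearest-even emulated with floor, multiply, compare and select.
// Hypothetical helper: a sketch of the technique, not onnx-mlir's emulation.
float roundEvenEmulated(float x) {
  float r = std::floor(x); // Largest integer <= x.
  float diff = x - r;      // Fractional part, in [0, 1).
  // r is even exactly when floor(r / 2) * 2 reproduces r.
  float halfFloor = std::floor(r * 0.5f);
  bool rIsEven = (halfFloor * 2.0f == r);
  float up = r + 1.0f;
  // On a tie (diff == 0.5) pick the even neighbor; otherwise the nearest one.
  float tieResult = rIsEven ? r : up;
  float nonTie = (diff > 0.5f) ? up : r;
  return (diff == 0.5f) ? tieResult : nonTie;
}
```

The sketch uses two floors, two multiplies, three compares and three selects per element, which is the kind of mix the emulation cost accounts for, whereas the hardware path is modeled as a single arithmetic op.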

template <>
Value emitScalarOpFor<ONNXRoundOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {
Value x = scalarOperands[0];
- MultiDialectBuilder<MathBuilder> create(rewriter, loc);
+ MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);
CheckIfCustomScalarOpIsSupported<ONNXRoundOp>(elementType);
- return create.math.round(x);
+ return create.krnl.roundEven(x);
}

//===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ void emitDynamicQuantizationLinearScalarParameters(
// Saturate zero point.
Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax);
// Round zero point.
- zeroPoint = create.math.round(saturateZeroPoint);
+ zeroPoint = create.krnl.roundEven(saturateZeroPoint);
} else {
zeroPoint = zero;
}
20 changes: 12 additions & 8 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -34,6 +34,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,

// Types
Type quantizedElementType = quantizedType.getElementType();
Type inputElementType = inputType.getElementType();
int64_t rank = inputType.getRank();

// Flatten the input data and outputs
@@ -51,14 +52,17 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
if (enableSIMD) {
int64_t innermostLoopCollapse = 1; // Only innermost is simdized.
bool canOverCompute = false;
- GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
+ GenOpMix mixAdjust;
+ if (hasZeroPoint)
+ mixAdjust = {{GenericOps::ArithmeticGop, 1}};
+ GenOpMix mixRound = getGenOpMix<ONNXRoundOp>(inputElementType, op);
+ GenOpMix mixOthers = {{GenericOps::DivGop, 1},
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
- {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
- {GenericOps::FloorGop, 2},
- {GenericOps::EstimatedVectorRegisterPressure,
- 8 /* Little parallelism in code. */}};
+ {GenericOps::EstimatedVectorRegisterPressure, 8}};
+ GenOpMix mix1 = computeGenOpMixUnion(mixAdjust, mixRound);
+ GenOpMix mix2 = computeGenOpMixUnion(mix1, mixOthers);
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
- innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
+ innermostLoopCollapse, mix2, canOverCompute, simdLoopStaticTripCount,
simdOnly);
}
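
The hunk above builds the operation mix from three parts (adjust, round, others) instead of one hard-coded list, so the cost of rounding follows whatever `getGenOpMix<ONNXRoundOp>` reports. The C++ sketch below shows how such a union could combine per-op counts. It rests on an assumption that is not confirmed by this diff: counts for the same op are added, while the register-pressure estimate keeps the larger value. `Op`, `OpMix` and `mergeOpMix` are hypothetical stand-ins for `GenericOps`, `GenOpMix` and `computeGenOpMixUnion`.

```cpp
#include <algorithm>
#include <cstdint>
#include <map>

// Hypothetical stand-ins for onnx-mlir's GenericOps and GenOpMix.
enum class Op { Arithmetic, Div, Conversion, MinMax, RegisterPressure };
using OpMix = std::map<Op, std::int64_t>;

// Merge two op mixes. Assumption: op counts add up, while the register
// pressure estimate keeps the larger of the two values.
OpMix mergeOpMix(const OpMix &a, const OpMix &b) {
  OpMix merged = a;
  for (const auto &entry : b) {
    auto it = merged.find(entry.first);
    if (it == merged.end())
      merged[entry.first] = entry.second; // Op only present in b.
    else if (entry.first == Op::RegisterPressure)
      it->second = std::max(it->second, entry.second);
    else
      it->second += entry.second;
  }
  return merged;
}
```

With this model, merging an adjust mix of one arithmetic op with a round mix of one arithmetic op gives two arithmetic ops, and merging in the remaining ops leaves their counts unchanged.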

@@ -74,12 +78,12 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
- MultiDialectBuilder<MathBuilder> create(kb);
+ MultiDialectBuilder<KrnlBuilder, MathBuilder> create(kb);
Value x = inputVals[0];
// Scale
Value scaleX = create.math.div(x, scale);
// Round
- Value roundX = create.math.round(scaleX);
+ Value roundX = create.krnl.roundEven(scaleX);
// Adjust
Value adjustX;
if (hasZeroPoint)
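
For orientation, the per-element steps in the SIMD body above (scale, round, adjust) can be sketched in scalar C++. The diff is cut off after the adjust step, so the final saturate and convert steps are an assumption taken from the ONNX QuantizeLinear definition, as is the uint8 output range. `quantizeLinearU8` is a hypothetical name, and `std::nearbyint` (default rounding mode) stands in for `create.krnl.roundEven`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch of quantize-linear for one element with uint8 output.
// Assumptions: saturate to [0, 255] and convert after the adjust step,
// following the ONNX QuantizeLinear definition; the zero point is passed
// as a float, as in the adjust step of the lowering.
std::uint8_t quantizeLinearU8(float x, float scale, float zeroPoint) {
  float scaled = x / scale;               // Scale.
  float rounded = std::nearbyint(scaled); // Round, ties to even.
  float adjusted = rounded + zeroPoint;   // Adjust by the zero point.
  float saturated = std::min(std::max(adjusted, 0.0f), 255.0f); // Saturate.
  return static_cast<std::uint8_t>(saturated);                  // Convert.
}
```

The rounding mode matters here: with ties to even, a scaled value of 2.5 quantizes to 2, whereas a round-half-away-from-zero implementation would produce 3.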
43 changes: 43 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.cpp
@@ -349,6 +349,49 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name,
alignment.value_or(nullptr));
}

//===----------------------------------------------------------------------===//
// Math style functions.

Value KrnlBuilder::roundEven(Value input) const {
Type elementType = getElementTypeOrSelf(input.getType());
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
// Both scalars and 1D vectors are handled below; vecType is null for scalars.
VectorType vecType = mlir::dyn_cast<VectorType>(input.getType());
if (VectorMachineSupport::requireCustomASM(
GenericOps::roundEvenGop, elementType)) {
// Use Krnl round even op as LLVM does not support roundEven.
if (vecType) {
// Vector, enable unrolling of multiple archVL.
int64_t archVL = VectorMachineSupport::getArchVectorLength(
GenericOps::roundEvenGop, elementType);
assert(archVL > 1 && "expected vector with archVL>1");
assert(vecType.getRank() == 1 && "1D vec only");
int64_t vecSize = vecType.getShape()[0];
assert(vecSize % archVL == 0 && "expected multiple of archVL");
int64_t numArchVec = vecSize / archVL;
VectorType vecType2D = VectorType::get({numArchVec, archVL}, elementType);
Value input2D = create.vec.shapeCast(vecType2D, input);
Value output2D = input2D;
for (int64_t i = 0; i < numArchVec; ++i) {
Value subInput = create.vec.extractFrom2D(input2D, i);
Value subOutput =
b().create<KrnlRoundEvenOp>(loc(), subInput.getType(), subInput);
output2D = create.vec.insertInto2D(subOutput, output2D, i);
}
return create.vec.shapeCast(vecType, output2D);
} else {
Review comment (Collaborator): This `else` looks redundant since there is a `return` inside the `if`.

Reply (Collaborator Author): Thanks, will fix.

// Scalar.
return b().create<KrnlRoundEvenOp>(loc(), input.getType(), input);
}
}
// No need for custom support, use math roundEven. May want to evaluate
// whether to use the mlir roundEven or our own emulation.
return create.math.roundEven(input);
}
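
The vector branch of `KrnlBuilder::roundEven` views a long 1D vector as `numArchVec` rows of `archVL` elements and applies one hardware-width op per row. A plain C++ sketch of that chunking follows, assuming `std::nearbyint` per element (default rounding mode) as a stand-in for one `KrnlRoundEvenOp` on a row. `roundEvenChunked` is a hypothetical name, and `std::vector` replaces the MLIR vector values.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Apply a round-even op of hardware width archVL to a longer 1D vector by
// processing it as numArchVec consecutive rows of archVL elements each.
std::vector<float> roundEvenChunked(
    const std::vector<float> &input, std::size_t archVL) {
  assert(archVL > 1 && "expected vector with archVL>1");
  assert(input.size() % archVL == 0 && "expected multiple of archVL");
  std::size_t numArchVec = input.size() / archVL;
  std::vector<float> output(input.size());
  for (std::size_t i = 0; i < numArchVec; ++i) {
    // Row i of the conceptual [numArchVec x archVL] view of the input.
    for (std::size_t lane = 0; lane < archVL; ++lane) {
      std::size_t idx = i * archVL + lane;
      // Stand-in for one hardware op applied to the whole row.
      output[idx] = std::nearbyint(input[idx]);
    }
  }
  return output;
}
```

For an 8-element input with `archVL` equal to 4, the loop processes two rows, which mirrors the shape-cast to 2D, per-row op, and shape-cast back to 1D in the builder.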

//===----------------------------------------------------------------------===//
// C library functions.

void KrnlBuilder::memcpy(Value dest, Value src, Value numElems) const {
MultiDialectBuilder<MathBuilder> create(*this);
Value zero = create.math.constantIndex(0);
3 changes: 3 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.hpp
@@ -265,6 +265,9 @@ struct KrnlBuilder : public DialectBuilder {
std::optional<mlir::IntegerAttr> offset = std::nullopt,
std::optional<mlir::IntegerAttr> alignment = std::nullopt) const;

// Math style functions
mlir::Value roundEven(mlir::Value input) const;

// C library functions.
void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems) const;
void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems,
11 changes: 11 additions & 0 deletions src/Dialect/Krnl/Krnl.td
@@ -567,6 +567,17 @@ def KrnlParallelClauseOp : Op<Krnl_Dialect, "parallel_clause"> {
}];
}

def KrnlRoundEvenOp : Op<Krnl_Dialect, "round_even"> {
let summary = "Krnl round to nearest even operation";
let description = [{
Krnl round to nearest even operation. Accepts scalar or vector float values.
A vector must be 1D, with a size that is a multiple of the hardware vector size.
}];

let arguments = (ins FloatLike:$in);
let results = (outs FloatLike:$out);
}

def KrnlErfOp : Op<Krnl_Dialect, "erf"> {
let summary = "Krnl erf scalar operation";
let description = [{