Skip to content

Commit

Permalink
Merge branch 'main' into mem_reduction_stickified
Browse files Browse the repository at this point in the history
  • Loading branch information
imaihal committed Oct 11, 2024
2 parents 1a4cb5d + 90aea66 commit 7af8bdd
Show file tree
Hide file tree
Showing 20 changed files with 699 additions and 584 deletions.
25 changes: 25 additions & 0 deletions docs/Dialects/krnl.md
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ Interfaces: `MemoryEffectOpInterface`
| :-----: | ----------- |
| `parameters` | variadic of any type

#### Results:

| Result | Description |
| :----: | ----------- |
| `returnValue` | variadic of floating-point or integer

### `krnl.copy_from_tile_buffer` (KrnlCopyFromBufferOp)

_Copy from buffer._
Expand Down Expand Up @@ -1193,6 +1199,25 @@ create a new memref inside the region and use it outside of the region.

Traits: `AffineScope`, `NoTerminator`, `SingleBlock`

### `krnl.round_even` (KrnlRoundEvenOp)

_Krnl round to nearest even operation_

Krnl round to nearest even operation. Accept scalar or vector float values.
Vector must be 1D of a size that is a multiple of the hardware vector size.

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `in` | floating-point-like

#### Results:

| Result | Description |
| :----: | ----------- |
| `out` | floating-point-like

### `krnl.seqalloc` (KrnlSeqAllocOp)

_Krnl create a sequence_
Expand Down
1 change: 1 addition & 0 deletions src/Conversion/KrnlToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMKrnlToLLVM
KrnlPrintTensor.cpp
KrnlPrint.cpp
KrnlRandomNormal.cpp
KrnlRoundEven.cpp
KrnlStrlen.cpp
KrnlStrncmp.cpp
KrnlToLLVMHelper.cpp
Expand Down
2 changes: 2 additions & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
patterns, vector::VectorTransformsOptions());
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorShapeCastLoweringPatterns(patterns);

populateAffineToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
Expand Down Expand Up @@ -984,6 +985,7 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter,
krnl::populateLoweringKrnlUnaryMathOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlStrncmpOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlNoneOpPattern(typeConverter, patterns, ctx);
krnl::populateLoweringKrnlRoundEvenOpPattern(typeConverter, patterns, ctx);
}

} // namespace krnl
Expand Down
4 changes: 4 additions & 0 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ void populateLoweringKrnlVectorTypeCastOpPattern(
void populateLoweringKrnlNoneOpPattern(mlir::LLVMTypeConverter &typeConverter,
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);

void populateLoweringKrnlRoundEvenOpPattern(
mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns,
mlir::MLIRContext *ctx);

void determineOwnershipForOutputOMTensors(mlir::ModuleOp &module,
llvm::SmallVectorImpl<bool> &outputOMTensorOwnerships);

Expand Down
115 changes: 115 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===------ KrnlRoundEven.cpp - Lower KrnlRoundEvenOp ---------------------===//
//
// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the KrnlRoundEvenOp operator.
//
// Currently limited to fp32 values; the underlying instructions also support
// other data types.
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "krnl_to_llvm"

using namespace mlir;

namespace onnx_mlir {
namespace krnl {

class KrnlRoundEvenOpLowering : public ConversionPattern {
public:
explicit KrnlRoundEvenOpLowering(
LLVMTypeConverter &typeConverter, MLIRContext *context)
: ConversionPattern(
typeConverter, KrnlRoundEvenOp::getOperationName(), 1, context) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
KrnlRoundEvenOp::Adaptor operandAdaptor(operands);
Value input = operandAdaptor.getIn();

// Scalar or Vector?
Type inputType = input.getType();
Type inputElemType = getElementTypeOrSelf(inputType);
assert(mlir::isa<FloatType>(inputElemType) && "expected float");
int64_t inputBitWidth = inputElemType.getIntOrFloatBitWidth();
assert(inputBitWidth == 32 && "expected 32bit float");
VectorType inputVecType = mlir::dyn_cast<VectorType>(inputType);
assert(VectorMachineSupport::requireCustomASM(
GenericOps::roundEvenGop, inputElemType) &&
"expected custom requirement");
// Common between scalar and vector
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();

if (inputVecType) {
// Vector of 4 elements.
Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4);
Type vecTypeF32 = LLVM::getFixedVectorType(f32Ty, 4);
// Use integer as container for inputs.
Value inputVecI32 = create.llvm.bitcast(vecTypeI32, input);
SmallVector<Value> asmVals{inputVecI32};
// SIMD ASM round to nearest even (M5=4) op
const char *asmStr = "VFISB $0,$1,0,4";
const char *asmConstraints = "=v,v";
Value outVecI32 =
rewriter
.create<LLVM::InlineAsmOp>(loc, vecTypeI32,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/asmConstraints, /*has_side_effects=*/false,
/*is_align_stack=*/false,
/*asm_dialect=*/LLVM::AsmDialectAttr(),
/*operand_attrs=*/ArrayAttr())
.getResult(0);
// Cast output back to float.
Value outVecF32 = create.llvm.bitcast(vecTypeF32, outVecI32);
rewriter.replaceOp(op, {outVecF32});
return success();
} else {
// Scalar types.
Type typeF32 = rewriter.getF32Type();
SmallVector<Value> asmVals{input};
// Scalar ASM round to the nearest even (M3=4) op.
const char *asmStr = "FIEBR $0,4,$1";
const char *asmConstraints = "=f,f";
Value outF32 =
rewriter
.create<LLVM::InlineAsmOp>(loc, typeF32,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/asmConstraints, /*has_side_effects=*/false,
/*is_align_stack=*/false,
/*asm_dialect=*/LLVM::AsmDialectAttr(),
/*operand_attrs=*/ArrayAttr())
.getResult(0);
rewriter.replaceOp(op, {outF32});
return success();
}
llvm_unreachable("not supported");
}
};

/// Register the pattern lowering KrnlRoundEvenOp into `patterns`, using
/// `typeConverter` for type conversion and `ctx` as the MLIR context.
void populateLoweringKrnlRoundEvenOpPattern(LLVMTypeConverter &typeConverter,
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.add<KrnlRoundEvenOpLowering>(typeConverter, ctx);
}

} // namespace krnl
} // namespace onnx_mlir
18 changes: 11 additions & 7 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1287,21 +1287,25 @@ struct ScalarOp<ONNXRoundOp> {

template <>
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
{GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
{GenericOps::FloorGop, 2},
{GenericOps::EstimatedVectorRegisterPressure,
4 /* Little parallelism in code. */}};
// If using roundEven emulation, cost is as below.
// return {{GenericOps::ArithmeticGop, 1}, {GenericOps::MulGop, 2},
// {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
// {GenericOps::FloorGop, 2},
// {GenericOps::EstimatedVectorRegisterPressure,
// 4 /* Little parallelism in code. */}};

// Assume here that there is a hw op to handle this.
return {{GenericOps::ArithmeticGop, 1}};
}

template <>
Value emitScalarOpFor<ONNXRoundOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {
Value x = scalarOperands[0];
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);
CheckIfCustomScalarOpIsSupported<ONNXRoundOp>(elementType);
return create.math.round(x);
return create.krnl.roundEven(x);
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void emitDynamicQuantizationLinearScalarParameters(
// Saturate zero point.
Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax);
// Round zero point.
zeroPoint = create.math.round(saturateZeroPoint);
zeroPoint = create.krnl.roundEven(saturateZeroPoint);
} else {
zeroPoint = zero;
}
Expand Down
20 changes: 12 additions & 8 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,

// Types
Type quantizedElementType = quantizedType.getElementType();
Type inputElementType = inputType.getElementType();
int64_t rank = inputType.getRank();

// Flatten the input data and outputs
Expand All @@ -51,14 +52,17 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
if (enableSIMD) {
int64_t innermostLoopCollapse = 1; // Only innermost is simdized.
bool canOverCompute = false;
GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
GenOpMix mixAdjust;
if (hasZeroPoint)
mixAdjust = {{GenericOps::ArithmeticGop, 1}};
GenOpMix mixRound = getGenOpMix<ONNXRoundOp>(inputElementType, op);
GenOpMix mixOthers = {{GenericOps::DivGop, 1},
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
{GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
{GenericOps::FloorGop, 2},
{GenericOps::EstimatedVectorRegisterPressure,
8 /* Little parallelism in code. */}};
{GenericOps::EstimatedVectorRegisterPressure, 8}};
GenOpMix mix1 = computeGenOpMixUnion(mixAdjust, mixRound);
GenOpMix mix2 = computeGenOpMixUnion(mix1, mixOthers);
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
innermostLoopCollapse, mix2, canOverCompute, simdLoopStaticTripCount,
simdOnly);
}

Expand All @@ -74,12 +78,12 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(kb);
Value x = inputVals[0];
// Scale
Value scaleX = create.math.div(x, scale);
// Round
Value roundX = create.math.round(scaleX);
Value roundX = create.krnl.roundEven(scaleX);
// Adjust
Value adjustX;
if (hasZeroPoint)
Expand Down
50 changes: 50 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,56 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name,
alignment.value_or(nullptr));
}

//===----------------------------------------------------------------------===//
// Math style functions.

/// Round `input` (scalar or 1D vector) to the nearest even value.
///
/// When the target requires custom assembly for round-even, KrnlRoundEvenOp
/// is emitted: once for a scalar, or once per hardware-sized chunk for a
/// vector whose length is a multiple of the architecture vector length.
/// Otherwise the math builder's round-even emulation is used.
Value KrnlBuilder::roundEven(Value input) const {
  MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
  Type inputType = input.getType();
  Type elemType = getElementTypeOrSelf(inputType);

  // No need for custom support: use the emulation. Whether to use the mlir
  // roundEven instead may be re-evaluated later; MacOS CI has an issue with
  // the roundEven instruction, hence the emulation for now.
  if (!VectorMachineSupport::requireCustomASM(
          GenericOps::roundEvenGop, elemType))
    return create.math.roundEvenEmulation(input);

  // Custom path: use the Krnl round even op as LLVM does not support
  // roundEven.
  auto emitRoundEven = [&](Value operand) -> Value {
    return b().create<KrnlRoundEvenOp>(loc(), operand.getType(), operand);
  };

  VectorType flatVecType = mlir::dyn_cast<VectorType>(inputType);
  if (!flatVecType)
    return emitRoundEven(input); // Scalar.

  // Vector: process it as a sequence of chunks of archVL values, each chunk
  // being handled by one KrnlRoundEvenOp.
  int64_t archVL = VectorMachineSupport::getArchVectorLength(
      GenericOps::roundEvenGop, elemType);
  assert(archVL > 1 && "expected vector with archVL>1");
  assert(flatVecType.getRank() == 1 && "1D vec only");
  int64_t totalLen = flatVecType.getShape()[0];
  assert(totalLen % archVL == 0 && "expected multiple of archVL");
  int64_t chunkCount = totalLen / archVL;

  // View the flat vector as chunkCount x archVL.
  VectorType chunkedVecType = VectorType::get({chunkCount, archVL}, elemType);
  Value chunkedInput = create.vec.shapeCast(chunkedVecType, input);
  // The chunked input doubles as the initial container for the results;
  // every chunk is overwritten below.
  Value chunkedOutput = chunkedInput;
  for (int64_t c = 0; c < chunkCount; ++c) {
    Value inChunk = create.vec.extractFrom2D(chunkedInput, c);
    Value outChunk = emitRoundEven(inChunk);
    chunkedOutput = create.vec.insertInto2D(outChunk, chunkedOutput, c);
  }
  // Back to the flat shape of the input.
  return create.vec.shapeCast(flatVecType, chunkedOutput);
}

//===----------------------------------------------------------------------===//
// C library functions.

void KrnlBuilder::memcpy(Value dest, Value src, Value numElems) const {
MultiDialectBuilder<MathBuilder> create(*this);
Value zero = create.math.constantIndex(0);
Expand Down
3 changes: 3 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ struct KrnlBuilder : public DialectBuilder {
std::optional<mlir::IntegerAttr> offset = std::nullopt,
std::optional<mlir::IntegerAttr> alignment = std::nullopt) const;

// Math style functions
mlir::Value roundEven(mlir::Value input) const;

// C library functions.
void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems) const;
void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems,
Expand Down
11 changes: 11 additions & 0 deletions src/Dialect/Krnl/Krnl.td
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,17 @@ def KrnlParallelClauseOp : Op<Krnl_Dialect, "parallel_clause"> {
}];
}

// Round-to-nearest-even operation of the Krnl dialect ("krnl.round_even").
// Takes one FloatLike operand `in` and produces one FloatLike result `out`.
// NOTE(review): no traits are listed on this op, so nothing here constrains
// the result type to match the operand type — confirm this is intended.
def KrnlRoundEvenOp : Op<Krnl_Dialect, "round_even"> {
let summary = "Krnl round to nearest even operation";
let description = [{
Krnl round to nearest even operation. Accept scalar or vector float values.
Vector must be 1D of a size that is a multiple of the hardware vector size.
}];

let arguments = (ins FloatLike:$in);
let results = (outs FloatLike:$out);
}

def KrnlErfOp : Op<Krnl_Dialect, "erf"> {
let summary = "Krnl erf scalar operation";
let description = [{
Expand Down
Loading

0 comments on commit 7af8bdd

Please sign in to comment.