diff --git a/docs/Dialects/krnl.md b/docs/Dialects/krnl.md index 80ad6224a9..0e6c8a065c 100644 --- a/docs/Dialects/krnl.md +++ b/docs/Dialects/krnl.md @@ -191,6 +191,12 @@ Interfaces: `MemoryEffectOpInterface` | :-----: | ----------- | | `parameters` | variadic of any type +#### Results: + +| Result | Description | +| :----: | ----------- | +| `returnValue` | variadic of floating-point or integer + ### `krnl.copy_from_tile_buffer` (KrnlCopyFromBufferOp) _Copy from buffer._ @@ -1193,6 +1199,25 @@ create a new memref inside the region and use it outside of the region. Traits: `AffineScope`, `NoTerminator`, `SingleBlock` +### `krnl.round_even` (KrnlRoundEvenOp) + +_Krnl round to nearest even operation_ + +Krnl round to nearest even operation. Accept scalar or vector float values. +Vector must be 1D of a size that is a multiple of the hardware vector size. + +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `in` | floating-point-like + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `out` | floating-point-like + ### `krnl.seqalloc` (KrnlSeqAllocOp) _Krnl create a sequence_ diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt index 52a583552f..92948137be 100644 --- a/src/Conversion/KrnlToLLVM/CMakeLists.txt +++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ add_onnx_mlir_library(OMKrnlToLLVM KrnlPrintTensor.cpp KrnlPrint.cpp KrnlRandomNormal.cpp + KrnlRoundEven.cpp KrnlStrlen.cpp KrnlStrncmp.cpp KrnlToLLVMHelper.cpp diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp index 62db84beff..a8d631b2d1 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp @@ -198,6 +198,7 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, patterns, vector::VectorTransformsOptions()); vector::populateVectorTransposeLoweringPatterns( patterns, 
vector::VectorTransformsOptions()); + vector::populateVectorShapeCastLoweringPatterns(patterns); populateAffineToStdConversionPatterns(patterns); populateSCFToControlFlowConversionPatterns(patterns); @@ -971,6 +972,7 @@ void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter, krnl::populateLoweringKrnlUnaryMathOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlStrncmpOpPattern(typeConverter, patterns, ctx); krnl::populateLoweringKrnlNoneOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlRoundEvenOpPattern(typeConverter, patterns, ctx); } } // namespace krnl diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp index c222913dfe..2309871db4 100644 --- a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp @@ -107,6 +107,10 @@ void populateLoweringKrnlVectorTypeCastOpPattern( void populateLoweringKrnlNoneOpPattern(mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateLoweringKrnlRoundEvenOpPattern( + mlir::LLVMTypeConverter &typeConverter, mlir::RewritePatternSet &patterns, + mlir::MLIRContext *ctx); + void determineOwnershipForOutputOMTensors(mlir::ModuleOp &module, llvm::SmallVectorImpl &outputOMTensorOwnerships); diff --git a/src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp b/src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp new file mode 100644 index 0000000000..81ea95eced --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlRoundEven.cpp @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlRoundEven.cpp - Lower KrnlRoundEvenOp ---------------------===// +// +// Copyright 2019-2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlRoundEvenOp operator. +// +// Currently limited to fp32 values; the instructions support other data types. 
+//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Dialect/Mlir/DialectBuilder.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlRoundEvenOpLowering : public ConversionPattern { +public: + explicit KrnlRoundEvenOpLowering( + LLVMTypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlRoundEvenOp::getOperationName(), 1, context) {} + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + KrnlRoundEvenOp::Adaptor operandAdaptor(operands); + Value input = operandAdaptor.getIn(); + + // Scalar or Vector? + Type inputType = input.getType(); + Type inputElemType = getElementTypeOrSelf(inputType); + assert(mlir::isa(inputElemType) && "expected float"); + int64_t inputBitWidth = inputElemType.getIntOrFloatBitWidth(); + assert(inputBitWidth == 32 && "expected 32bit float"); + VectorType inputVecType = mlir::dyn_cast(inputType); + assert(VectorMachineSupport::requireCustomASM( + GenericOps::roundEvenGop, inputElemType) && + "expected custom requirement"); + // Common between scalar and vector + MultiDialectBuilder create(rewriter, loc); + Type i32Ty = rewriter.getI32Type(); + Type f32Ty = rewriter.getF32Type(); + + if (inputVecType) { + // Vector of 4 elements. + Type vecTypeI32 = LLVM::getFixedVectorType(i32Ty, 4); + Type vecTypeF32 = LLVM::getFixedVectorType(f32Ty, 4); + // Use integer as container for inputs. 
+ Value inputVecI32 = create.llvm.bitcast(vecTypeI32, input); + SmallVector asmVals{inputVecI32}; + // SIMD ASM round to nearest even (M5=4) op + const char *asmStr = "VFISB $0,$1,0,4"; + const char *asmConstraints = "=v,v"; + Value outVecI32 = + rewriter + .create(loc, vecTypeI32, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/false, + /*is_align_stack=*/false, + /*asm_dialect=*/LLVM::AsmDialectAttr(), + /*operand_attrs=*/ArrayAttr()) + .getResult(0); + // Cast output back to float. + Value outVecF32 = create.llvm.bitcast(vecTypeF32, outVecI32); + rewriter.replaceOp(op, {outVecF32}); + return success(); + } else { + // Scalar types. + Type typeF32 = rewriter.getF32Type(); + SmallVector asmVals{input}; + // Scalar ASM round to the nearest even (M3=4) op. + const char *asmStr = "FIEBR $0,4,$1"; + const char *asmConstraints = "=f,f"; + Value outF32 = + rewriter + .create(loc, typeF32, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints, /*has_side_effects=*/false, + /*is_align_stack=*/false, + /*asm_dialect=*/LLVM::AsmDialectAttr(), + /*operand_attrs=*/ArrayAttr()) + .getResult(0); + rewriter.replaceOp(op, {outF32}); + return success(); + } + llvm_unreachable("not supported"); + } +}; + +void populateLoweringKrnlRoundEvenOpPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 8e6feaf540..ca592fffbe 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1287,11 +1287,15 @@ struct ScalarOp { template <> GenOpMix getGenOpMix(Type t, Operation *op) { - return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2}, - {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, - 
{GenericOps::FloorGop, 2}, - {GenericOps::EstimatedVectorRegisterPressure, - 4 /* Little parallelism in code. */}}; + // If using roundEven emulation, cost is as below. + // return {{GenericOps::ArithmeticGop, 1}, {GenericOps::MulGop, 2}, + // {GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3}, + // {GenericOps::FloorGop, 2}, + // {GenericOps::EstimatedVectorRegisterPressure, + // 4 /* Little parallelism in code. */}}; + + // Assume here that there is a hw op to handle this. + return {{GenericOps::ArithmeticGop, 1}}; } template <> @@ -1299,9 +1303,9 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, Operation *op, Type elementType, ArrayRef scalarOperands) { Value x = scalarOperands[0]; - MultiDialectBuilder create(rewriter, loc); + MultiDialectBuilder create(rewriter, loc); CheckIfCustomScalarOpIsSupported(elementType); - return create.math.round(x); + return create.krnl.roundEven(x); } //===----------------------------------------------------------------------===// diff --git a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp index 5484974624..ca74f8a480 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp @@ -68,7 +68,7 @@ void emitDynamicQuantizationLinearScalarParameters( // Saturate zero point. Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); // Round zero point. 
- zeroPoint = create.math.round(saturateZeroPoint); + zeroPoint = create.krnl.roundEven(saturateZeroPoint); } else { zeroPoint = zero; } diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 01293c81e9..9c4b7f7699 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -34,6 +34,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, // Types Type quantizedElementType = quantizedType.getElementType(); + Type inputElementType = inputType.getElementType(); int64_t rank = inputType.getRank(); // Flatten the input data and outputs @@ -51,14 +52,17 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, if (enableSIMD) { int64_t innermostLoopCollapse = 1; // Only innermost is simdized. bool canOverCompute = false; - GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5}, + GenOpMix mixAdjust; + if (hasZeroPoint) + mixAdjust = {{GenericOps::ArithmeticGop, 1}}; + GenOpMix mixRound = getGenOpMix(inputElementType, op); + GenOpMix mixOthers = {{GenericOps::DivGop, 1}, {GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2}, - {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, - {GenericOps::FloorGop, 2}, - {GenericOps::EstimatedVectorRegisterPressure, - 8 /* Little parallelism in code. 
*/}}; + {GenericOps::EstimatedVectorRegisterPressure, 8}}; + GenOpMix mix1 = computeGenOpMixUnion(mixAdjust, mixRound); + GenOpMix mix2 = computeGenOpMixUnion(mix1, mixOthers); totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, - innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, + innermostLoopCollapse, mix2, canOverCompute, simdLoopStaticTripCount, simdOnly); } @@ -74,12 +78,12 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { - MultiDialectBuilder create(kb); + MultiDialectBuilder create(kb); Value x = inputVals[0]; // Scale Value scaleX = create.math.div(x, scale); // Round - Value roundX = create.math.round(scaleX); + Value roundX = create.krnl.roundEven(scaleX); // Adjust Value adjustX; if (hasZeroPoint) diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index ae5054ca61..ec258f20c4 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ b/src/Dialect/Krnl/DialectBuilder.cpp @@ -349,6 +349,56 @@ Value KrnlBuilder::constant(MemRefType type, StringRef name, alignment.value_or(nullptr)); } +//===----------------------------------------------------------------------===// +// Math style functions. + +Value KrnlBuilder::roundEven(Value input) const { + Type elementType = getElementTypeOrSelf(input.getType()); + MultiDialectBuilder create(*this); + // hi alex, may want to generalize support to scalar as well. + VectorType vecType = mlir::dyn_cast(input.getType()); + if (VectorMachineSupport::requireCustomASM( + GenericOps::roundEvenGop, elementType)) { + // Use Krnl round even op as LLVM does not support roundEven. + if (!vecType) + // Scalar. + return b().create(loc(), input.getType(), input); + + // Vector, enable unrolling of multiple archVL. 
+ int64_t archVL = VectorMachineSupport::getArchVectorLength( + GenericOps::roundEvenGop, elementType); + assert(archVL > 1 && "expected vector with archVL>1"); + assert(vecType.getRank() == 1 && "1D vec only"); + int64_t vecSize = vecType.getShape()[0]; + assert(vecSize % archVL == 0 && "expected multiple of archVL"); + int64_t numArchVec = vecSize / archVL; + VectorType vecType2D = VectorType::get({numArchVec, archVL}, elementType); + // Cast input vector to a vector of chunks (archVL values that can be + // handled by one hardware SIMD instruction). + Value input2D = create.vec.shapeCast(vecType2D, input); + Value output2D = input2D; + // Iterates over all hardware SIMD chunks. + for (int64_t i = 0; i < numArchVec; ++i) { + // Extract one chunk, compute new value, insert result in corresponding + // output 2D vector. + Value subInput = create.vec.extractFrom2D(input2D, i); + Value subOutput = + b().create(loc(), subInput.getType(), subInput); + output2D = create.vec.insertInto2D(subOutput, output2D, i); + } + // Recast output 2D vector into the flat vector (same shape as input). + return create.vec.shapeCast(vecType, output2D); + } + // No need for custom support, use math roundEven. May want to evaluate + // whether to use the mlir roundEven or our own emulation. + // Note: MacOS CI has an issue with the roundEven instruction, thus continue + // to use emulation. May change in the future. + return create.math.roundEvenEmulation(input); +} + +//===----------------------------------------------------------------------===// +// C library functions. 
+ void KrnlBuilder::memcpy(Value dest, Value src, Value numElems) const { MultiDialectBuilder create(*this); Value zero = create.math.constantIndex(0); diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index 3a3c786aad..f810998dca 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -265,6 +265,9 @@ struct KrnlBuilder : public DialectBuilder { std::optional offset = std::nullopt, std::optional alignment = std::nullopt) const; + // Math style functions + mlir::Value roundEven(mlir::Value input) const; + // C library functions. void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems) const; void memcpy(mlir::Value dest, mlir::Value src, mlir::Value numElems, diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td index cac5423fcf..c8220dfc53 100644 --- a/src/Dialect/Krnl/Krnl.td +++ b/src/Dialect/Krnl/Krnl.td @@ -567,6 +567,17 @@ def KrnlParallelClauseOp : Op { }]; } +def KrnlRoundEvenOp : Op { + let summary = "Krnl round to nearest even operation"; + let description = [{ + Krnl round to nearest even operation. Accept scalar or vector float values. + Vector must be 1D of a size that is a multiple of the hardware vector size. 
+ }]; + + let arguments = (ins FloatLike:$in); + let results = (outs FloatLike:$out); +} + def KrnlErfOp : Op { let summary = "Krnl erf scalar operation"; let description = [{ diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index ff425513a5..44836d961d 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -279,6 +279,19 @@ Value MathBuilder::rem(Value lhs, Value rhs) const { Value MathBuilder::round(Value x) const { Type type = x.getType(); assert(isScalarOrVectorFloat(type) && "expected float"); + return b().create(loc(), x); +} + +Value MathBuilder::roundEven(Value x) const { + Type type = x.getType(); + assert(isScalarOrVectorFloat(type) && "expected float"); + return b().create(loc(), x); +} + +Value MathBuilder::roundEvenEmulation(Value x) const { + Type type = x.getType(); + assert(isScalarOrVectorFloat(type) && "expected float"); + // Use algorithm originally posted in ONNXtoKRNL/Math/Elementwise.cpp // lowering. @@ -2112,6 +2125,24 @@ void VectorBuilder::multiReduction(ArrayRef inputVecArray, } } +// Cast vectors to vectors of different shape (e.g. 1D to 2D and back). +Value VectorBuilder::shapeCast(VectorType newType, Value vector) const { + return b().create(loc(), newType, vector); +} + +// Extract 1D vector from 2D vector. +Value VectorBuilder::extractFrom2D(Value vector2D, int64_t position) const { + llvm::SmallVector pos = {position}; + return b().create(loc(), vector2D, pos); +} + +// Insert 1D vector into 2D vector. 
+Value VectorBuilder::insertInto2D( + Value vector, Value vector2D, int64_t position) const { + llvm::SmallVector pos = {position}; + return b().create(loc(), vector, vector2D, pos); +} + Value VectorBuilder::extractElement(Value vector, int64_t index) const { MultiDialectBuilder create(*this); VectorType type = llvm::cast(vector.getType()); diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index 32fc82e42c..1c1ce1775e 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -142,6 +142,8 @@ struct MathBuilder final : DialectBuilder { mlir::Value pow(mlir::Value base, mlir::Value exp) const; // B/Float only. mlir::Value rem(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value round(mlir::Value) const; // Float only. + mlir::Value roundEven(mlir::Value) const; // Float only. + mlir::Value roundEvenEmulation(mlir::Value) const; // Float only. mlir::Value sqrt(mlir::Value val) const; // Float only. mlir::Value sub(mlir::Value lhs, mlir::Value rhs) const; // B. mlir::Value tanh(mlir::Value val) const; // Float only. @@ -574,7 +576,14 @@ struct VectorBuilder final : DialectBuilder { void multiReduction(mlir::ArrayRef inputVecArray, F2 reductionFct, llvm::SmallVectorImpl &outputVecArray); - // Insert and extract. + // Cast vectors to vectors of different shape (e.g. 1D to 2D and back). + mlir::Value shapeCast(mlir::VectorType newType, mlir::Value vector) const; + // Extract and insert 1D vector from/to 2D vector. + mlir::Value extractFrom2D(mlir::Value vector2D, int64_t position) const; + mlir::Value insertInto2D( + mlir::Value vector, mlir::Value vector2D, int64_t position) const; + + // Insert and extract one element (scalar). 
mlir::Value extractElement(mlir::Value vector, int64_t position) const; mlir::Value insertElement( mlir::Value vector, mlir::Value element, int64_t position) const; diff --git a/src/Dialect/Mlir/VectorMachineSupport.cpp b/src/Dialect/Mlir/VectorMachineSupport.cpp index 75f46638a6..c03a8cdffb 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.cpp +++ b/src/Dialect/Mlir/VectorMachineSupport.cpp @@ -113,6 +113,7 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { int64_t processedValues = std::max(static_cast(1), vl); totProcessedValues += processedValues * num; } + // Compute final values int64_t totNum = vectorizedOpNum + scalarOpNum; if (!hasRegisterPressure) { @@ -127,6 +128,22 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) { // IBM Z servers // ============================================================================= +bool Z16VectorMachineSupport::needCustomASM( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); + bool isFloat = mlir::isa(elementType); + if (isFloat) { + switch (genOp) { + case GenericOps::roundEvenGop: + return true; + default: + return false; + } + } + // Integer + return false; +} + int64_t Z16VectorMachineSupport::computeArchVectorLength( GenericOps genOp, Type elementType) { assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); @@ -166,6 +183,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( case GenericOps::FmaGop: case GenericOps::MinMaxGop: case GenericOps::MulGop: + case GenericOps::roundEvenGop: case GenericOps::SqrtGop: return archVL; default: @@ -202,6 +220,12 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength( // This may be an approximation of the actual capabilities. 
// ============================================================================= +bool SSE42x86VectorMachineSupport::needCustomASM( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); + return false; +} + int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( GenericOps genOp, Type elementType) { assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); @@ -241,7 +265,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( case GenericOps::FmaGop: case GenericOps::MinMaxGop: case GenericOps::MulGop: - case GenericOps::RoundGop: + case GenericOps::roundEvenGop: case GenericOps::SqrtGop: case GenericOps::SumAcrossGop: return archVL; @@ -289,6 +313,12 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength( // This may be an approximation of the actual capabilities. // ============================================================================= +bool NeonVectorMachineSupport::needCustomASM( + GenericOps genOp, Type elementType) { + assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); + return false; +} + int64_t NeonVectorMachineSupport::computeArchVectorLength( GenericOps genOp, Type elementType) { assert(genOp < GenericOps::LastGop && "no metrics here, only genOps"); @@ -327,7 +357,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength( case GenericOps::FmaGop: case GenericOps::MinMaxGop: case GenericOps::MulGop: - case GenericOps::RoundGop: + case GenericOps::roundEvenGop: case GenericOps::SqrtGop: case GenericOps::SumAcrossGop: return archVL; @@ -382,7 +412,7 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) { u[genOp] = num; } // Merge entries from the second mix. 
- for (auto pair : mix1) { + for (auto pair : mix2) { GenericOps genOp = pair.first; int64_t num = pair.second; if (u.find(genOp) != u.end()) { diff --git a/src/Dialect/Mlir/VectorMachineSupport.hpp b/src/Dialect/Mlir/VectorMachineSupport.hpp index 0d1104bbad..f4597ce736 100644 --- a/src/Dialect/Mlir/VectorMachineSupport.hpp +++ b/src/Dialect/Mlir/VectorMachineSupport.hpp @@ -56,7 +56,7 @@ enum class GenericOps { MulGop, PowGop, RemGop, - RoundGop, + roundEvenGop, /* FP to FP round to nearest even ONNX */ ScalarOnlyGop, /* Any ops that are guaranteed to be scalar on any arch. */ SelectGop, ShiftGop, /* Shift operations: logical/arithmetic. */ @@ -107,6 +107,11 @@ class VectorMachineSupport { // support. static bool hasSimd() { return getArchVectorRegisterNum() > 0; } + // Determine if custom asm is needed (aka operation not supported by llvm). + static bool requireCustomASM(GenericOps gop, mlir::Type elementType) { + return vms()->needCustomASM(gop, elementType); + } + // When querying Vector length for machines with unsupported simd, UNSUPPORTED // (aka 0) is returned. static const int64_t UNSUPPORTED = 1; @@ -157,6 +162,7 @@ class VectorMachineSupport { protected: // Virtual functions that do the actual work. Called by the "get" functions. 
virtual std::string computeArchName() = 0; + virtual bool needCustomASM(GenericOps gop, mlir::Type elementType) = 0; virtual int64_t computeArchVectorRegisterNum() = 0; virtual int64_t computeArchVectorBitWidth() = 0; virtual int64_t computeArchVectorLength(mlir::Type elementType); @@ -179,6 +185,9 @@ class NoVectorMachineSupport : public VectorMachineSupport { virtual ~NoVectorMachineSupport() = default; std::string computeArchName() override { return "no_vector"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override { + return false; + } int64_t computeArchVectorRegisterNum() override { return 0; } int64_t computeArchVectorBitWidth() override { return 0; } int64_t computeArchVectorLength(mlir::Type elementType) override { @@ -198,6 +207,7 @@ class Z16VectorMachineSupport : public VectorMachineSupport { virtual ~Z16VectorMachineSupport() = default; std::string computeArchName() override { return "z16"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override; int64_t computeArchVectorRegisterNum() override { return 32; } int64_t computeArchVectorBitWidth() override { return 128; } int64_t computeArchVectorLength( @@ -215,6 +225,7 @@ class SSE42x86VectorMachineSupport : public VectorMachineSupport { virtual ~SSE42x86VectorMachineSupport() = default; std::string computeArchName() override { return "x86-sse4.2"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override; int64_t computeArchVectorRegisterNum() override { return 16; } int64_t computeArchVectorBitWidth() override { return 128; } int64_t computeArchVectorLength( @@ -238,6 +249,7 @@ class NeonVectorMachineSupport : public VectorMachineSupport { virtual ~NeonVectorMachineSupport() = default; std::string computeArchName() override { return "arm64-neon"; } + bool needCustomASM(GenericOps gop, mlir::Type elementType) override; int64_t computeArchVectorRegisterNum() override { return 32; } int64_t computeArchVectorBitWidth() override { return 128; } int64_t 
computeArchVectorLength( diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir index d560acc6da..735312d524 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize.mlir @@ -446,12 +446,13 @@ func.func @where(%arg0: tensor<2x2xi1>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x // ----- + func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { %0 = "onnx.Round"(%arg0) : (tensor<15xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> // mlir2FileCheck.py -// CHECK-LABEL: func @round +// CHECK-LABEL: func.func @round // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<15xf32>) -> memref<15xf32> { // CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 // CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 @@ -459,8 +460,8 @@ func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<15xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 15){ -// CHECK: [[IV:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]][[[IV]]] : memref<15xf32> +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32> // CHECK: [[VAR_3_:%.+]] = math.floor [[LOAD_PARAM_0_MEM_]] : f32 // CHECK: [[VAR_4_:%.+]] = arith.subf [[LOAD_PARAM_0_MEM_]], [[VAR_3_]] : f32 // CHECK-DAG: [[VAR_5_:%.+]] = arith.cmpf ogt, [[VAR_4_]], [[CST_5_dot_000000_]] : f32 @@ -477,7 +478,7 @@ func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { // CHECK-DAG: [[VAR_14_:%.+]] = arith.select [[VAR_12_]], [[VAR_13_]], [[VAR_3_]] : f32 
// CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpf oeq, [[VAR_4_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_16_:%.+]] = arith.select [[VAR_15_]], [[VAR_14_]], [[VAR_7_]] : f32 -// CHECK: krnl.store [[VAR_16_]], [[RES_]][[[IV]]] : memref<15xf32> +// CHECK: krnl.store [[VAR_16_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32> // CHECK: } // CHECK: return [[RES_]] : memref<15xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir index 5e149d2d96..7d397ccc9d 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Elementwise_with_canonicalize_O3.mlir @@ -62,6 +62,7 @@ func.func @test_mean(%arg0: tensor<30xf32>, %arg1: tensor<30xf32>, %arg2: tensor // ----- + func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { %0 = "onnx.Round"(%arg0) : (tensor<15xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -69,35 +70,31 @@ func.func @round(%arg0: tensor<15xf32>) -> tensor<*xf32> { // mlir2FileCheck.py // CHECK-LABEL: func.func @round // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<15xf32>) -> memref<15xf32> { -// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<64xi8> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}[] : memref<64xi8> to memref<15xf32> -// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with 
([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 15){ -// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> -// CHECK: [[VAR_3_:%.+]] = math.floor [[LOAD_PARAM_0_MEM_]] : vector<16xf32> -// CHECK: [[VAR_4_:%.+]] = arith.subf [[LOAD_PARAM_0_MEM_]], [[VAR_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = arith.cmpf ogt, [[VAR_4_]], [[VAR_cst_]] : vector<16xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_3_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_5_]], [[VAR_6_]], [[VAR_3_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = arith.mulf [[VAR_3_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : vector<16xf32> -// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_3_]], [[VAR_10_]] : vector<16xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.cmpf oeq, [[VAR_11_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = arith.addf [[VAR_3_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_14_:%.+]] = arith.select [[VAR_12_]], [[VAR_13_]], [[VAR_3_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpf oeq, [[VAR_4_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_16_:%.+]] = arith.select [[VAR_15_]], [[VAR_14_]], [[VAR_7_]] : vector<16xi1>, vector<16xf32> -// CHECK: vector.store [[VAR_16_]], [[VAR_view_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> +// CHECK: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}[] : memref<64xi8> to memref<15xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 
16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 15){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> +// CHECK: [[VAR_3_:%.+]] = vector.shape_cast [[LOAD_PARAM_0_MEM_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_4_:%.+]] = vector.extract [[VAR_3_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_5_:%.+]] = "krnl.round_even"([[VAR_4_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = vector.insert [[VAR_5_]], [[VAR_3_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_7_:%.+]] = vector.extract [[VAR_3_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_8_:%.+]] = "krnl.round_even"([[VAR_7_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.insert [[VAR_8_]], [[VAR_6_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.extract [[VAR_3_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_11_:%.+]] = "krnl.round_even"([[VAR_10_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.insert [[VAR_11_]], [[VAR_9_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.extract [[VAR_3_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_14_:%.+]] = "krnl.round_even"([[VAR_13_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_15_:%.+]] = vector.insert [[VAR_14_]], [[VAR_12_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK: [[VAR_16_:%.+]] = vector.shape_cast [[VAR_15_]] : vector<4x4xf32> to vector<16xf32> +// CHECK: vector.store [[VAR_16_]], [[VAR_view_]]{{.}}[[VAR_1_]]{{.}} : memref<15xf32>, vector<16xf32> +// CHECK: } // CHECK: } // CHECK: return [[VAR_view_]] : memref<15xf32> // CHECK: } diff --git 
a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir index b0bea0a414..818d8f2149 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -15,14 +15,8 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4096_:%.+]] = arith.constant 4096 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 @@ -37,21 +31,21 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref // 
CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> // CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref -// CHECK: vector.store [[VAR_cst_5_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK: [[VAR_20_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], 
[[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -71,67 +65,49 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 // CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : 
f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i32 -// CHECK: [[VAR_30_:%.+]] = arith.trunci [[VAR_29_]] : i32 to i8 -// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 +// CHECK: [[VAR_15_:%.+]] = "krnl.round_even"([[VAR_14_]]) : (f32) -> f32 +// CHECK: [[VAR_16_:%.+]] = arith.fptoui [[VAR_15_]] : f32 to i32 +// CHECK: [[VAR_17_:%.+]] = arith.trunci [[VAR_16_]] : i32 to i8 +// CHECK: [[VAR_18_:%.+]] = builtin.unrealized_conversion_cast [[VAR_17_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_18_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_8_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<256x16xui8>, memref<1xindex>) -> 
memref<4096xui8> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK: [[VAR_20_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_20_1_]]{{.}} : memref<4096xf32>, vector<16xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<16xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<16xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<16xf32> -// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<16xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<16xf32> -// 
CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xui8>, vector<16xui8> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_2_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_26_1_:%.+]] = "krnl.round_even"([[VAR_25_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = vector.insert [[VAR_26_1_]], [[LOAD_RES_6_MEM_2_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_29_:%.+]] = "krnl.round_even"([[VAR_28_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = vector.insert [[VAR_29_]], [[VAR_27_]] [1] : vector<4xf32> 
into vector<4x4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_32_:%.+]] = "krnl.round_even"([[VAR_31_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.insert [[VAR_32_]], [[VAR_30_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_35_:%.+]] = "krnl.round_even"([[VAR_34_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_36_:%.+]] = vector.insert [[VAR_35_]], [[VAR_33_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = vector.shape_cast [[VAR_36_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = vector.splat [[VAR_15_]] : vector<16xf32> +// CHECK: [[VAR_39_:%.+]] = arith.addf [[VAR_37_]], [[VAR_38_]] : vector<16xf32> +// CHECK: [[VAR_40_:%.+]] = arith.maxnumf [[VAR_39_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.minnumf [[VAR_40_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.fptoui [[VAR_41_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_43_:%.+]] = arith.trunci [[VAR_42_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_44_:%.+]] = builtin.unrealized_conversion_cast [[VAR_43_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_44_]], [[VAR_reshape_15_]]{{.}}[[VAR_20_1_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<256x16xui8>, memref, memref // CHECK: } } @@ -147,14 +123,8 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> // CHECK-DAG: 
[[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4335_:%.+]] = arith.constant 4335 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 @@ -169,34 +139,34 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<32xf32> // CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref -// CHECK: vector.store [[VAR_cst_5_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: 
krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4304){ -// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK: [[VAR_22_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_22_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_22_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_40_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK: [[LOOP_1_:%.+]] = 
krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_22_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_22_1_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_22_1_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> -// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK-DAG: [[VAR_27_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_28_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: krnl.store [[VAR_27_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK: krnl.store [[VAR_28_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : 
memref<32xf32>, vector<32xf32> @@ -216,96 +186,63 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: [[VAR_13_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_12_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_15_:%.+]] = arith.minnumf [[VAR_14_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_16_:%.+]] = math.floor [[VAR_15_]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_15_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.mulf [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_22_:%.+]] = math.floor [[VAR_21_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.mulf [[VAR_22_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.subf [[VAR_16_]], [[VAR_23_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.cmpf oeq, [[VAR_24_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_26_:%.+]] = arith.addf [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i32 -// CHECK: [[VAR_31_:%.+]] = arith.trunci [[VAR_30_]] : i32 to i8 -// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 +// CHECK: [[VAR_16_:%.+]] = "krnl.round_even"([[VAR_15_]]) : (f32) -> f32 +// CHECK: [[VAR_17_:%.+]] = arith.fptoui [[VAR_16_]] : f32 to i32 +// CHECK: [[VAR_18_:%.+]] = 
arith.trunci [[VAR_17_]] : i32 to i8 +// CHECK: [[VAR_19_:%.+]] = builtin.unrealized_conversion_cast [[VAR_18_]] : i8 to ui8 // CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_32_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_19_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_8_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4320){ -// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK: [[VAR_22_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xf32>, vector<16xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> // CHECK: 
[[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<16xf32> -// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<16xf32> -// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_43_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_44_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[VAR_45_:%.+]] = math.floor [[VAR_44_]] : vector<16xf32> -// CHECK: [[VAR_46_:%.+]] = arith.mulf [[VAR_45_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_47_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_46_]] : vector<16xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_47_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_48_]], [[VAR_49_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_:%.+]] = arith.select [[VAR_51_]], [[VAR_50_]], [[VAR_43_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_53_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_]], [[VAR_53_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_57_:%.+]] 
= arith.fptoui [[VAR_56_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_58_:%.+]] = arith.trunci [[VAR_57_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_59_:%.+]] = builtin.unrealized_conversion_cast [[VAR_58_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_59_]], [[VAR_reshape_21_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xui8>, vector<16xui8> +// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_1_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_27_2_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_28_2_:%.+]] = "krnl.round_even"([[VAR_27_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = vector.insert [[VAR_28_2_]], [[LOAD_RES_6_MEM_1_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_31_:%.+]] = "krnl.round_even"([[VAR_30_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.insert [[VAR_31_]], [[VAR_29_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_34_:%.+]] = "krnl.round_even"([[VAR_33_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.insert [[VAR_34_]], [[VAR_32_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = vector.extract [[LOAD_RES_6_MEM_1_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_37_:%.+]] = "krnl.round_even"([[VAR_36_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_38_:%.+]] = vector.insert [[VAR_37_]], [[VAR_35_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = vector.shape_cast [[VAR_38_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = vector.splat [[VAR_16_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.addf [[VAR_39_]], 
[[VAR_40_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.maxnumf [[VAR_41_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.minnumf [[VAR_42_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.fptoui [[VAR_43_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_45_:%.+]] = arith.trunci [[VAR_44_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_46_:%.+]] = builtin.unrealized_conversion_cast [[VAR_45_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_46_]], [[VAR_reshape_15_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_22_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_13_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xf32> // CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_11_]] : f32 -// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = math.floor [[LOAD_VAR_reshape_MEM_3_1_]] : f32 -// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_3_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_44_1_:%.+]] = math.floor 
[[VAR_43_1_]] : f32 -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[VAR_44_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_45_1_]] : f32 -// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.select [[VAR_47_1_]], [[VAR_48_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_51_1_:%.+]] = arith.select [[VAR_50_1_]], [[VAR_49_1_]], [[VAR_42_1_]] : f32 -// CHECK: [[VAR_52_1_:%.+]] = arith.addf [[VAR_51_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_53_1_:%.+]] = arith.maxnumf [[VAR_52_1_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_1_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_55_1_:%.+]] = arith.fptoui [[VAR_54_1_]] : f32 to i32 -// CHECK: [[VAR_56_1_:%.+]] = arith.trunci [[VAR_55_1_]] : i32 to i8 -// CHECK: [[VAR_57_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_57_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xui8> +// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = "krnl.round_even"([[LOAD_VAR_reshape_MEM_3_1_]]) : (f32) -> f32 +// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[VAR_16_]] : f32 +// CHECK: [[VAR_27_3_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_1_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_28_3_:%.+]] = arith.minnumf [[VAR_27_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_29_1_:%.+]] = arith.fptoui [[VAR_28_3_]] : f32 to i32 +// CHECK: [[VAR_30_1_:%.+]] = arith.trunci [[VAR_29_1_]] : i32 to i8 +// CHECK: [[VAR_31_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_31_1_]], 
[[VAR_reshape_15_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<255x17xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<255x17xui8>, memref, memref // CHECK: } } @@ -321,14 +258,8 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x8xf32>) -> (memref<1x8xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 @@ -343,21 +274,21 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> // CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref -// CHECK: vector.store [[VAR_cst_5_]], 
[[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_20_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_39_]], 
[[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -377,67 +308,43 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 // CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: 
separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i32 -// CHECK: [[VAR_30_:%.+]] = arith.trunci [[VAR_29_]] : i32 to i8 -// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 +// CHECK: [[VAR_15_:%.+]] = "krnl.round_even"([[VAR_14_]]) : (f32) -> f32 +// CHECK: [[VAR_16_:%.+]] = arith.fptoui [[VAR_15_]] : f32 to i32 +// CHECK: [[VAR_17_:%.+]] = arith.trunci [[VAR_16_]] : i32 to i8 +// CHECK: [[VAR_18_:%.+]] = builtin.unrealized_conversion_cast [[VAR_17_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_18_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: 
krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_20_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : 
vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<8xf32> to vector<8xi32> -// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<8xi32> to vector<8xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_2_]] : vector<8xf32> to vector<2x4xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][0] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_26_1_:%.+]] = "krnl.round_even"([[VAR_25_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = vector.insert [[VAR_26_1_]], [[LOAD_RES_6_MEM_2_]] [0] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][1] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_29_:%.+]] = "krnl.round_even"([[VAR_28_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_30_:%.+]] = vector.insert [[VAR_29_]], [[VAR_27_]] [1] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shape_cast [[VAR_30_]] : vector<2x4xf32> to vector<8xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.splat [[VAR_15_]] : vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_31_]], [[VAR_32_]] : 
vector<8xf32> +// CHECK: [[VAR_34_:%.+]] = arith.maxnumf [[VAR_33_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[VAR_34_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_36_:%.+]] = arith.fptoui [[VAR_35_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_37_:%.+]] = arith.trunci [[VAR_36_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_38_:%.+]] = builtin.unrealized_conversion_cast [[VAR_37_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_38_]], [[VAR_reshape_15_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<1x8xui8>, memref, memref // CHECK: } } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir index 14d809207c..86a9d3f14b 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir @@ -20,16 +20,10 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<256x16xf32>) -> (memref<256x16xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> -// CHECK-DAG: 
[[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> -// CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> -// CHECK-DAG: [[VAR_cst_7_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4096_:%.+]] = arith.constant 4096 : index @@ -49,63 +43,63 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { +// CHECK: [[VAR_21_:%.+]] = 
krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_22_:%.+]] = affine.apply [[MAP_0_]]([[VAR_21_]]) +// CHECK-DAG: [[VAR_23_:%.+]] = affine.min [[MAP_1_]]([[VAR_21_]]) +// CHECK-DAG: [[VAR_24_:%.+]] = affine.apply [[MAP_2_]]([[VAR_21_]]) +// CHECK: vector.store [[VAR_cst_4_]], [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_3_]], [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_25_:%.+]] = affine.min [[MAP_3_]]([[VAR_21_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_22_]] to [[VAR_25_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf 
[[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) -// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index -// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index -// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { +// CHECK: [[VAR_26_:%.+]] = affine.min [[MAP_4_]]([[VAR_21_]]) +// CHECK: [[VAR_27_:%.+]] = arith.remsi [[VAR_26_]], [[CST_32_]] : index +// CHECK: [[VAR_28_:%.+]] = arith.subi [[VAR_26_]], [[VAR_27_]] : index +// CHECK: [[VAR_29_:%.+]] = arith.addi [[VAR_22_]], [[VAR_28_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_29_]] to [[VAR_23_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : 
memref<256xf32> -// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_38_1_]], [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_39_1_]], [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_24_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_33_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_32_]], [[RES_5_]]{{.}}[[VAR_21_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_33_]], [[RES_7_]]{{.}}[[VAR_21_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref -// CHECK: vector.store [[VAR_cst_5_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, 
vector<1xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_21_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_22_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_21_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_23_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_21_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 -// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_26_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_22_1_]] : f32 +// CHECK-DAG: [[VAR_27_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_23_1_]] : f32 +// CHECK: krnl.store [[VAR_26_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_27_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } 
// CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -125,68 +119,50 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: [[VAR_13_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_12_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_15_:%.+]] = arith.minnumf [[VAR_14_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_16_:%.+]] = math.floor [[VAR_15_]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_15_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.mulf [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_22_:%.+]] = math.floor [[VAR_21_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.mulf [[VAR_22_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.subf [[VAR_16_]], [[VAR_23_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.cmpf oeq, [[VAR_24_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_26_:%.+]] = arith.addf [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i32 -// CHECK: [[VAR_31_:%.+]] = arith.trunci [[VAR_30_]] : i32 to i8 -// CHECK: [[VAR_32_:%.+]] = 
builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 +// CHECK: [[VAR_16_:%.+]] = "krnl.round_even"([[VAR_15_]]) : (f32) -> f32 +// CHECK: [[VAR_17_:%.+]] = arith.fptoui [[VAR_16_]] : f32 to i32 +// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_17_]] : i32 to i8 +// CHECK: [[VAR_19_:%.+]] = builtin.unrealized_conversion_cast [[VAR_18_]] : i8 to ui8 // CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_32_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_19_]], [[RES_2_]][] : memref // CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_10_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> +// CHECK-DAG: [[VAR_reshape_17_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> // CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_11_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> +// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[RES_]]([[RES_]]_18) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xf32>, vector<16xf32> -// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> 
-// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf 
[[VAR_54_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_57_:%.+]] = arith.trunci [[VAR_56_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_58_:%.+]] = builtin.unrealized_conversion_cast [[VAR_57_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_58_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xui8>, vector<16xui8> +// CHECK: [[VAR_21_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_22_1_:%.+]] = vector.load [[VAR_reshape_17_]]{{.}}[[VAR_21_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_23_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_24_1_:%.+]] = arith.divf [[VAR_22_1_]], [[VAR_23_2_]] : vector<16xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.shape_cast [[VAR_24_1_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_26_2_:%.+]] = vector.extract [[VAR_25_1_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_27_2_:%.+]] = "krnl.round_even"([[VAR_26_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_28_1_:%.+]] = vector.insert [[VAR_27_2_]], [[VAR_25_1_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_29_1_:%.+]] = vector.extract [[VAR_25_1_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = "krnl.round_even"([[VAR_29_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.insert [[LOAD_RES_4_MEM_2_]], [[VAR_28_1_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_32_1_:%.+]] = vector.extract [[VAR_25_1_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_33_1_:%.+]] = "krnl.round_even"([[VAR_32_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.insert [[VAR_33_1_]], [[LOAD_RES_6_MEM_2_]] [2] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = 
vector.extract [[VAR_25_1_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = "krnl.round_even"([[LOAD_VAR_reshape_MEM_3_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = vector.insert [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_38_2_:%.+]] = vector.shape_cast [[LOAD_RES_6_MEM_1_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_39_2_:%.+]] = vector.splat [[VAR_16_]] : vector<16xf32> +// CHECK: [[VAR_40_:%.+]] = arith.addf [[VAR_38_2_]], [[VAR_39_2_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.maxnumf [[VAR_40_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.minnumf [[VAR_41_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.fptoui [[VAR_42_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_44_:%.+]] = arith.trunci [[VAR_43_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_45_:%.+]] = builtin.unrealized_conversion_cast [[VAR_44_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_45_]], [[VAR_reshape_19_]]{{.}}[[VAR_21_2_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<256x16xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_7, [[RES_]]_8 : memref<256x16xui8>, memref, memref // CHECK: } } @@ -207,16 +183,10 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<255x17xf32>) -> (memref<255x17xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<16xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant 
dense<1.000000e+00> : vector<16xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> -// CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> -// CHECK-DAG: [[VAR_cst_7_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<1xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<0xFF800000> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0x7F800000> : vector<32xf32> // CHECK-DAG: [[CST_32_:%.+]] = arith.constant 32 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_4335_:%.+]] = arith.constant 4335 : index @@ -236,63 +206,63 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_0_]]([[VAR_35_]]) -// CHECK-DAG: [[VAR_37_:%.+]] = affine.min [[MAP_1_]]([[VAR_35_]]) -// CHECK-DAG: [[VAR_38_:%.+]] = affine.apply [[MAP_2_]]([[VAR_35_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_3_]]([[VAR_35_]]) -// 
CHECK: scf.for [[I_1_:%.+]] = [[VAR_36_]] to [[VAR_39_]] step [[CST_32_]] { +// CHECK: [[VAR_22_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_23_:%.+]] = affine.apply [[MAP_0_]]([[VAR_22_]]) +// CHECK-DAG: [[VAR_24_:%.+]] = affine.min [[MAP_1_]]([[VAR_22_]]) +// CHECK-DAG: [[VAR_25_:%.+]] = affine.apply [[MAP_2_]]([[VAR_22_]]) +// CHECK: vector.store [[VAR_cst_4_]], [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_3_]], [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_26_:%.+]] = affine.min [[MAP_3_]]([[VAR_22_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_23_]] to [[VAR_26_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_53_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_52_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_53_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.minnumf 
[[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_40_]], [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_40_:%.+]] = affine.min [[MAP_4_]]([[VAR_35_]]) -// CHECK: [[VAR_41_:%.+]] = arith.remsi [[VAR_40_]], [[CST_32_]] : index -// CHECK: [[VAR_42_:%.+]] = arith.subi [[VAR_40_]], [[VAR_41_]] : index -// CHECK: [[VAR_43_:%.+]] = arith.addi [[VAR_36_]], [[VAR_42_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_43_]] to [[VAR_37_]] step [[CST_1_]] { +// CHECK: [[VAR_27_:%.+]] = affine.min [[MAP_4_]]([[VAR_22_]]) +// CHECK: [[VAR_28_:%.+]] = arith.remsi [[VAR_27_]], [[CST_32_]] : index +// CHECK: [[VAR_29_:%.+]] = arith.subi [[VAR_27_]], [[VAR_28_]] : index +// CHECK: [[VAR_30_:%.+]] = arith.addi [[VAR_23_]], [[VAR_29_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_30_]] to [[VAR_24_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_53_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], 
[[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_52_1_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_53_1_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_25_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_47_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_46_]], [[RES_5_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_47_]], [[RES_7_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_34_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_33_]], [[RES_5_]]{{.}}[[VAR_22_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_34_]], [[RES_7_]]{{.}}[[VAR_22_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : 
memref -// CHECK: vector.store [[VAR_cst_5_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_37_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_22_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_23_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_22_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_24_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_22_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_36_1_]] : f32 -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_37_1_]] : f32 -// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_27_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_23_1_]] : f32 +// CHECK-DAG: [[VAR_28_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_24_1_]] : f32 +// CHECK: krnl.store [[VAR_27_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> 
+// CHECK: krnl.store [[VAR_28_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -312,97 +282,64 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: [[VAR_13_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_12_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_15_:%.+]] = arith.minnumf [[VAR_14_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_16_:%.+]] = math.floor [[VAR_15_]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_15_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.mulf [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_22_:%.+]] = math.floor [[VAR_21_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.mulf [[VAR_22_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.subf [[VAR_16_]], [[VAR_23_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.cmpf oeq, [[VAR_24_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_26_:%.+]] = arith.addf [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i32 -// CHECK: [[VAR_31_:%.+]] = 
arith.trunci [[VAR_30_]] : i32 to i8 -// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 +// CHECK: [[VAR_16_:%.+]] = "krnl.round_even"([[VAR_15_]]) : (f32) -> f32 +// CHECK: [[VAR_17_:%.+]] = arith.fptoui [[VAR_16_]] : f32 to i32 +// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_17_]] : i32 to i8 +// CHECK: [[VAR_19_:%.+]] = builtin.unrealized_conversion_cast [[VAR_18_]] : i8 to ui8 // CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_32_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_19_]], [[RES_2_]][] : memref // CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_10_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> +// CHECK-DAG: [[VAR_reshape_17_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> // CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_11_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> +// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[RES_]]([[RES_]]_18) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4320){ -// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_36_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> -// 
CHECK-DAG: [[VAR_37_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_37_2_]] : vector<16xf32> -// CHECK: [[VAR_39_1_:%.+]] = math.floor [[VAR_38_1_]] : vector<16xf32> -// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_39_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_1_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_39_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_47_1_:%.+]] = arith.subf [[VAR_39_1_]], [[VAR_46_1_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_47_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_43_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_53_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_2_]], [[VAR_53_2_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], 
[[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_57_:%.+]] = arith.fptoui [[VAR_56_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_58_:%.+]] = arith.trunci [[VAR_57_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_59_:%.+]] = builtin.unrealized_conversion_cast [[VAR_58_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_59_]], [[VAR_reshape_25_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xui8>, vector<16xui8> +// CHECK: [[VAR_22_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_23_1_:%.+]] = vector.load [[VAR_reshape_17_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_24_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_25_1_:%.+]] = arith.divf [[VAR_23_1_]], [[VAR_24_2_]] : vector<16xf32> +// CHECK: [[VAR_26_1_:%.+]] = vector.shape_cast [[VAR_25_1_]] : vector<16xf32> to vector<4x4xf32> +// CHECK: [[VAR_27_2_:%.+]] = vector.extract [[VAR_26_1_]][0] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_28_2_:%.+]] = "krnl.round_even"([[VAR_27_2_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_29_1_:%.+]] = vector.insert [[VAR_28_2_]], [[VAR_26_1_]] [0] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_30_1_:%.+]] = vector.extract [[VAR_26_1_]][1] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = "krnl.round_even"([[VAR_30_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.insert [[LOAD_RES_4_MEM_2_]], [[VAR_29_1_]] [1] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_33_1_:%.+]] = vector.extract [[VAR_26_1_]][2] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[VAR_34_1_:%.+]] = "krnl.round_even"([[VAR_33_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.insert [[VAR_34_1_]], [[LOAD_RES_6_MEM_2_]] [2] : vector<4xf32> 
into vector<4x4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.extract [[VAR_26_1_]][3] : vector<4xf32> from vector<4x4xf32> +// CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = "krnl.round_even"([[LOAD_VAR_reshape_MEM_3_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = vector.insert [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] [3] : vector<4xf32> into vector<4x4xf32> +// CHECK-DAG: [[VAR_39_2_:%.+]] = vector.shape_cast [[LOAD_RES_6_MEM_1_]] : vector<4x4xf32> to vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = vector.splat [[VAR_16_]] : vector<16xf32> +// CHECK: [[VAR_41_:%.+]] = arith.addf [[VAR_39_2_]], [[VAR_40_2_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = arith.maxnumf [[VAR_41_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.minnumf [[VAR_42_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.fptoui [[VAR_43_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_45_:%.+]] = arith.trunci [[VAR_44_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_46_:%.+]] = builtin.unrealized_conversion_cast [[VAR_45_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_46_]], [[VAR_reshape_19_]]{{.}}[[VAR_22_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[VAR_36_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> -// CHECK: [[VAR_37_3_:%.+]] = arith.divf [[VAR_36_1_1_]], [[VAR_11_]] : f32 -// CHECK: [[VAR_38_2_:%.+]] = math.floor [[VAR_37_3_]] : f32 -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_3_]], [[VAR_38_2_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : 
f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[VAR_38_2_]] : f32 -// CHECK-DAG: [[VAR_43_2_:%.+]] = arith.mulf [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_43_2_]] : f32 -// CHECK: [[LOAD_RES_6_MEM_2_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_2_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_46_2_:%.+]] = arith.subf [[VAR_38_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 -// CHECK-DAG: [[VAR_47_2_:%.+]] = arith.cmpf oeq, [[VAR_46_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_47_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_38_2_]] : f32 -// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_42_2_]] : f32 -// CHECK: [[VAR_52_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_53_3_:%.+]] = arith.maxnumf [[VAR_52_3_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_3_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_55_1_:%.+]] = arith.fptoui [[VAR_54_1_]] : f32 to i32 -// CHECK: [[VAR_56_1_:%.+]] = arith.trunci [[VAR_55_1_]] : i32 to i8 -// CHECK: [[VAR_57_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_57_1_]], [[VAR_reshape_25_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xui8> +// CHECK: [[VAR_22_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_23_1_1_:%.+]] = krnl.load [[VAR_reshape_17_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_24_3_:%.+]] = arith.divf [[VAR_23_1_1_]], [[VAR_11_]] : f32 +// CHECK: 
[[VAR_25_2_:%.+]] = "krnl.round_even"([[VAR_24_3_]]) : (f32) -> f32 +// CHECK: [[VAR_26_2_:%.+]] = arith.addf [[VAR_25_2_]], [[VAR_16_]] : f32 +// CHECK: [[VAR_27_3_:%.+]] = arith.maxnumf [[VAR_26_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_28_3_:%.+]] = arith.minnumf [[VAR_27_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_29_2_:%.+]] = arith.fptoui [[VAR_28_3_]] : f32 to i32 +// CHECK: [[VAR_30_2_:%.+]] = arith.trunci [[VAR_29_2_]] : i32 to i8 +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_2_]] : i8 to ui8 +// CHECK: krnl.store [[LOAD_RES_4_MEM_2_1_]], [[VAR_reshape_19_]]{{.}}[[VAR_22_3_]]{{.}} : memref<4335xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<255x17xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_7, [[RES_]]_8 : memref<255x17xui8>, memref, memref // CHECK: } } @@ -418,14 +355,8 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<1x8xf32>) -> (memref<1x8xui8>, memref, memref) { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<2.550000e+02> : vector<8xf32> // CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<5.000000e-01> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<2.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<1.000000e+00> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> -// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> -// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 -// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 -// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0xFF800000> : vector<8xf32> +// CHECK-DAG: 
[[VAR_cst_2_:%.+]] = arith.constant dense<0x7F800000> : vector<8xf32> // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 @@ -440,21 +371,21 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> // CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() : memref -// CHECK: vector.store [[VAR_cst_5_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_cst_2_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_20_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_20_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load 
[[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_25_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -474,68 +405,44 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[VAR_12_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_11_]] : f32 // CHECK: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[CST_0_dot_000000_]] : f32 // CHECK: [[VAR_14_:%.+]] = arith.minnumf [[VAR_13_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_14_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive 
DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_]], [[VAR_22_]] : f32 -// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i32 -// CHECK: [[VAR_30_:%.+]] = arith.trunci [[VAR_29_]] : i32 to i8 -// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 +// CHECK: [[VAR_15_:%.+]] = "krnl.round_even"([[VAR_14_]]) : (f32) -> f32 +// CHECK: [[VAR_16_:%.+]] = arith.fptoui [[VAR_15_]] : f32 to i32 +// CHECK: [[VAR_17_:%.+]] = arith.trunci [[VAR_16_]] : i32 to i8 +// CHECK: [[VAR_18_:%.+]] = builtin.unrealized_conversion_cast [[VAR_17_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_18_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> +// CHECK-DAG: [[VAR_reshape_13_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> // CHECK-DAG: [[RES_9_:%.+]] = 
memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> -// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> +// CHECK-DAG: [[VAR_reshape_15_:%.+]] = memref.reshape [[RES_]]([[RES_]]_14) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__1_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_20_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_13_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf 
[[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<8xf32> to vector<8xi32> -// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<8xi32> to vector<8xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = vector.shape_cast [[LOAD_RES_4_MEM_2_]] : vector<8xf32> to vector<2x4xf32> +// CHECK: [[VAR_25_1_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][0] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_26_1_:%.+]] = "krnl.round_even"([[VAR_25_1_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = vector.insert 
[[VAR_26_1_]], [[LOAD_RES_6_MEM_2_]] [0] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = vector.extract [[LOAD_RES_6_MEM_2_]][1] : vector<4xf32> from vector<2x4xf32> +// CHECK: [[VAR_29_:%.+]] = "krnl.round_even"([[VAR_28_]]) : (vector<4xf32>) -> vector<4xf32> +// CHECK: [[VAR_30_:%.+]] = vector.insert [[VAR_29_]], [[VAR_27_]] [1] : vector<4xf32> into vector<2x4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shape_cast [[VAR_30_]] : vector<2x4xf32> to vector<8xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.splat [[VAR_15_]] : vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_31_]], [[VAR_32_]] : vector<8xf32> +// CHECK: [[VAR_34_:%.+]] = arith.maxnumf [[VAR_33_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[VAR_34_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_36_:%.+]] = arith.fptoui [[VAR_35_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_37_:%.+]] = arith.trunci [[VAR_36_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_38_:%.+]] = builtin.unrealized_conversion_cast [[VAR_37_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_38_]], [[VAR_reshape_15_]]{{.}}[[VAR_20_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } -// CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref +// CHECK: return [[RES_]], [[RES_]]_5, [[RES_]]_6 : memref<1x8xui8>, memref, memref // CHECK: } } diff --git a/utils/analyze-simd.py b/utils/analyze-simd.py index bb7650f9e2..6141b3957f 100755 --- a/utils/analyze-simd.py +++ b/utils/analyze-simd.py @@ -86,8 +86,10 @@ def define_arch_op_names(arch): op_name["vfma"] = "vfma" op_name["vmul"] = "vfm.b" op_name["vdiv"] = "vfd" - # vector conversion between formats (NNPA <-> fp) - op_name["vconv"] = "(vclfnh|vclfnl|vcfn|vcrnf|vcnf)" + # vector conversion between formats (NNPA <-> fp, FP <-> int, int <-> int) + op_name["vconv"] = ( + "(vclfnh|vclfnl|vcfn|vcrnf|vcnf|vclgd|vclfeb|vclgdb|vpkh|vpkf|vpkg)" + ) # add | sub| max | min | compare 
op_name["vadd"] = "([vw]fa|[vw]fs|[vw]fmax|[vw]fmin|[vw]f[ck][eh])" op_name["load"] = "lg"