From 1a494eca2018dfc94e75ded00ab81a3efd34b15f Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Wed, 29 Nov 2023 13:54:33 +0000 Subject: [PATCH 1/3] [LLVM][IR] Add native vector support to ConstantInt & ConstantFP. NOTE: For brevity the following talks about ConstantInt but everything extends to cover ConstantFP as well. Whilst ConstantInt::get() supports the creation of vectors whereby each lane has the same value, it achieves this via other constants: * ConstantVector for fixed-length vectors * ConstantExprs for scalable vectors However, ConstantExprs are being deprecated and ConstantVector is not space efficient for larger vector types. By extending ConstantInt we can represent vector splats by only storing the underlying scalar value. More specifically: * ConstantInt gains an ElementCount variant of get(). * LLVMContext is extended to map ->ConstantInt. * BitcodeReader/Writer support is extended to allow vector types. Whilst this patch adds the base support, more work is required before it's production ready. For example, there's likely to be many places where isa assumes a scalar type. Accordingly the default behaviour of ConstantInt::get() remains unchanged but a set of flags are added to allow wider testing and thus help with the migration: --use-constant-int-for-fixed-length-splat --use-constant-fp-for-fixed-length-splat --use-constant-int-for-scalable-splat --use-constant-fp-for-scalable-splat NOTE: No change is required to the bitcode format because types and values are handled separately. NOTE: For similar reasons as above, code generation doesn't work out-the-box. --- llvm/include/llvm/IR/Constants.h | 12 +++- llvm/lib/Bitcode/Reader/BitcodeReader.cpp | 55 +++++++-------- llvm/lib/Bitcode/Writer/BitcodeWriter.cpp | 2 +- llvm/lib/IR/AsmWriter.cpp | 27 ++++++-- llvm/lib/IR/Constants.cpp | 82 ++++++++++++++++++++++- llvm/lib/IR/LLVMContextImpl.cpp | 2 + llvm/lib/IR/LLVMContextImpl.h | 4 ++ llvm/test/Bitcode/constant-splat.ll | 61 +++++++++++++++++ 8 files changed, 208 insertions(+), 37 deletions(-) create mode 100644 llvm/test/Bitcode/constant-splat.ll diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index b5dcc7fbc1d929..39eec1b738fabb 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -81,7 +81,7 @@ class ConstantInt final : public ConstantData { APInt Val; - ConstantInt(IntegerType *Ty, const APInt &V); + ConstantInt(Type *Ty, const APInt &V); void destroyConstantImpl(); @@ -123,6 +123,12 @@ class ConstantInt final : public ConstantData { /// type is the integer type that corresponds to the bit width of the value. static ConstantInt *get(LLVMContext &Context, const APInt &V); + /// Return a ConstantInt with the specified value and an implied Type. The + /// type is the vector type whose integer element type corresponds to the bit + /// width of the value. + static ConstantInt *get(LLVMContext &Context, ElementCount EC, + const APInt &V); + /// Return a ConstantInt constructed from the string strStart with the given /// radix. static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix); @@ -136,7 +142,7 @@ class ConstantInt final : public ConstantData { /// Return the constant's value. inline const APInt &getValue() const { return Val; } - /// getBitWidth - Return the bitwidth of this constant. + /// getBitWidth - Return the scalar bitwidth of this constant. unsigned getBitWidth() const { return Val.getBitWidth(); } /// Return the constant as a 64-bit unsigned integer value after it @@ -281,6 +287,8 @@ class ConstantFP final : public ConstantData { static Constant *get(Type *Ty, StringRef Str); static ConstantFP *get(LLVMContext &Context, const APFloat &V); + static ConstantFP *get(LLVMContext &Context, ElementCount EC, + const APFloat &V); static Constant *getNaN(Type *Ty, bool Negative = false, uint64_t Payload = 0); static Constant *getQNaN(Type *Ty, bool Negative = false, diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp index 515a1d0caa0415..832907a3f53f5f 100644 --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -3060,48 +3060,49 @@ Error BitcodeReader::parseConstants() { V = Constant::getNullValue(CurTy); break; case bitc::CST_CODE_INTEGER: // INTEGER: [intval] - if (!CurTy->isIntegerTy() || Record.empty()) + if (!CurTy->isIntOrIntVectorTy() || Record.empty()) return error("Invalid integer const record"); V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0])); break; case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval] - if (!CurTy->isIntegerTy() || Record.empty()) + if (!CurTy->isIntOrIntVectorTy() || Record.empty()) return error("Invalid wide integer const record"); - APInt VInt = - readWideAPInt(Record, cast(CurTy)->getBitWidth()); - V = ConstantInt::get(Context, VInt); - + auto *ScalarTy = cast(CurTy->getScalarType()); + APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth()); + V = ConstantInt::get(CurTy, VInt); break; } case bitc::CST_CODE_FLOAT: { // FLOAT: [fpval] if (Record.empty()) return error("Invalid float const record"); - if (CurTy->isHalfTy()) - V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(), - APInt(16, (uint16_t)Record[0]))); - else if (CurTy->isBFloatTy()) - V = ConstantFP::get(Context, APFloat(APFloat::BFloat(), - APInt(16, (uint32_t)Record[0]))); - else if (CurTy->isFloatTy()) - V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(), - APInt(32, (uint32_t)Record[0]))); - else if (CurTy->isDoubleTy()) - V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(), - APInt(64, Record[0]))); - else if (CurTy->isX86_FP80Ty()) { + + auto *ScalarTy = CurTy->getScalarType(); + if (ScalarTy->isHalfTy()) + V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(), + APInt(16, (uint16_t)Record[0]))); + else if (ScalarTy->isBFloatTy()) + V = ConstantFP::get( + CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0]))); + else if (ScalarTy->isFloatTy()) + V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(), + APInt(32, (uint32_t)Record[0]))); + else if (ScalarTy->isDoubleTy()) + V = ConstantFP::get( + CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0]))); + else if (ScalarTy->isX86_FP80Ty()) { // Bits are not stored the same way as a normal i80 APInt, compensate. uint64_t Rearrange[2]; Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16); Rearrange[1] = Record[0] >> 48; - V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(), - APInt(80, Rearrange))); - } else if (CurTy->isFP128Ty()) - V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(), - APInt(128, Record))); - else if (CurTy->isPPC_FP128Ty()) - V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(), - APInt(128, Record))); + V = ConstantFP::get( + CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange))); + } else if (ScalarTy->isFP128Ty()) + V = ConstantFP::get(CurTy, + APFloat(APFloat::IEEEquad(), APInt(128, Record))); + else if (ScalarTy->isPPC_FP128Ty()) + V = ConstantFP::get( + CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record))); else V = UndefValue::get(CurTy); break; diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp index 13be0b0c3307fb..656f2a6ce870f5 100644 --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -2624,7 +2624,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal, } } else if (const ConstantFP *CFP = dyn_cast(C)) { Code = bitc::CST_CODE_FLOAT; - Type *Ty = CFP->getType(); + Type *Ty = CFP->getType()->getScalarType(); if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy()) { Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue()); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 0ae720e8b7ce8c..1fcda6c384d96d 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -1502,16 +1502,35 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) { static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, AsmWriterContext &WriterCtx) { if (const ConstantInt *CI = dyn_cast(CV)) { - if (CI->getType()->isIntegerTy(1)) { - Out << (CI->getZExtValue() ? "true" : "false"); - return; + if (CI->getType()->isVectorTy()) { + Out << "splat ("; + WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out); + Out << " "; } - Out << CI->getValue(); + + if (CI->getType()->getScalarType()->isIntegerTy(1)) + Out << (CI->getZExtValue() ? "true" : "false"); + else + Out << CI->getValue(); + + if (CI->getType()->isVectorTy()) + Out << ")"; + return; } if (const ConstantFP *CFP = dyn_cast(CV)) { + if (CFP->getType()->isVectorTy()) { + Out << "splat ("; + WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out); + Out << " "; + } + WriteAPFloatInternal(Out, CFP->getValueAPF()); + + if (CFP->getType()->isVectorTy()) + Out << ")"; + return; } diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index a38b912164b130..b04d7955afe670 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -35,6 +35,20 @@ using namespace llvm; using namespace PatternMatch; +// As set of temporary options to help migrate how splats are represented. +static cl::opt UseConstantIntForFixedLengthSplat( + "use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden, + cl::desc("Use ConstantInt's native fixed-length vector splat support.")); +static cl::opt UseConstantFPForFixedLengthSplat( + "use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden, + cl::desc("Use ConstantFP's native fixed-length vector splat support.")); +static cl::opt UseConstantIntForScalableSplat( + "use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden, + cl::desc("Use ConstantInt's native scalable vector splat support.")); +static cl::opt UseConstantFPForScalableSplat( + "use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden, + cl::desc("Use ConstantFP's native scalable vector splat support.")); + //===----------------------------------------------------------------------===// // Constant Class //===----------------------------------------------------------------------===// @@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const { // ConstantInt //===----------------------------------------------------------------------===// -ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V) +ConstantInt::ConstantInt(Type *Ty, const APInt &V) : ConstantData(Ty, ConstantIntVal), Val(V) { - assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type"); + assert(V.getBitWidth() == + cast(Ty->getScalarType())->getBitWidth() && + "Invalid constant for type"); } ConstantInt *ConstantInt::getTrue(LLVMContext &Context) { @@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) { return Slot.get(); } +// Get a ConstantInt vector with each lane set to the same APInt. +ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC, + const APInt &V) { + // Get an existing value or the insertion position. + std::unique_ptr &Slot = + Context.pImpl->IntSplatConstants[std::make_pair(EC, V)]; + if (!Slot) { + IntegerType *ITy = IntegerType::get(Context, V.getBitWidth()); + VectorType *VTy = VectorType::get(ITy, EC); + Slot.reset(new ConstantInt(VTy, V)); + } + +#ifndef NDEBUG + IntegerType *ITy = IntegerType::get(Context, V.getBitWidth()); + VectorType *VTy = VectorType::get(ITy, EC); + assert(Slot->getType() == VTy); +#endif + return Slot.get(); +} + Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) { Constant *C = get(cast(Ty->getScalarType()), V, isSigned); @@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) { return Slot.get(); } +// Get a ConstantFP vector with each lane set to the same APFloat. +ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC, + const APFloat &V) { + // Get an existing value or the insertion position. + std::unique_ptr &Slot = + Context.pImpl->FPSplatConstants[std::make_pair(EC, V)]; + if (!Slot) { + Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics()); + VectorType *VTy = VectorType::get(EltTy, EC); + Slot.reset(new ConstantFP(VTy, V)); + } + +#ifndef NDEBUG + Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics()); + VectorType *VTy = VectorType::get(EltTy, EC); + assert(Slot->getType() == VTy); +#endif + return Slot.get(); +} + Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) { const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics(); Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative)); @@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) { ConstantFP::ConstantFP(Type *Ty, const APFloat &V) : ConstantData(Ty, ConstantFPVal), Val(V) { - assert(&V.getSemantics() == &Ty->getFltSemantics() && + assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() && "FP type Mismatch"); } @@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef V) { Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { if (!EC.isScalable()) { + // Maintain special handling of zero. + if (!V->isNullValue()) { + if (UseConstantIntForFixedLengthSplat && isa(V)) + return ConstantInt::get(V->getContext(), EC, + cast(V)->getValue()); + if (UseConstantFPForFixedLengthSplat && isa(V)) + return ConstantFP::get(V->getContext(), EC, + cast(V)->getValue()); + } + // If this splat is compatible with ConstantDataVector, use it instead of // ConstantVector. if ((isa(V) || isa(V)) && @@ -1394,6 +1460,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { return get(Elts); } + // Maintain special handling of zero. + if (!V->isNullValue()) { + if (UseConstantIntForScalableSplat && isa(V)) + return ConstantInt::get(V->getContext(), EC, + cast(V)->getValue()); + if (UseConstantFPForScalableSplat && isa(V)) + return ConstantFP::get(V->getContext(), EC, + cast(V)->getValue()); + } + Type *VTy = VectorType::get(V->getType(), EC); if (V->isNullValue()) diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp index 15c90a4fe7b2ec..a0bf9cae7926bb 100644 --- a/llvm/lib/IR/LLVMContextImpl.cpp +++ b/llvm/lib/IR/LLVMContextImpl.cpp @@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() { IntZeroConstants.clear(); IntOneConstants.clear(); IntConstants.clear(); + IntSplatConstants.clear(); FPConstants.clear(); + FPSplatConstants.clear(); CDSConstants.clear(); // Destroy attribute node lists. diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h index 6a20291344989d..2ee1080a1ffa29 100644 --- a/llvm/lib/IR/LLVMContextImpl.h +++ b/llvm/lib/IR/LLVMContextImpl.h @@ -1488,8 +1488,12 @@ class LLVMContextImpl { DenseMap> IntZeroConstants; DenseMap> IntOneConstants; DenseMap> IntConstants; + DenseMap, std::unique_ptr> + IntSplatConstants; DenseMap> FPConstants; + DenseMap, std::unique_ptr> + FPSplatConstants; FoldingSet AttrsSet; FoldingSet AttrsLists; diff --git a/llvm/test/Bitcode/constant-splat.ll b/llvm/test/Bitcode/constant-splat.ll new file mode 100644 index 00000000000000..d4921607d15b54 --- /dev/null +++ b/llvm/test/Bitcode/constant-splat.ll @@ -0,0 +1,61 @@ +; RUN: llvm-as -use-constant-int-for-fixed-length-splat \ +; RUN: -use-constant-fp-for-fixed-length-splat \ +; RUN: -use-constant-int-for-scalable-splat \ +; RUN: -use-constant-fp-for-scalable-splat \ +; RUN: < %s | llvm-dis -use-constant-int-for-fixed-length-splat \ +; RUN: -use-constant-fp-for-fixed-length-splat \ +; RUN: -use-constant-int-for-scalable-splat \ +; RUN: -use-constant-fp-for-scalable-splat \ +; RUN: | FileCheck %s + +; CHECK: @constant.splat.i1 = constant <1 x i1> splat (i1 true) +@constant.splat.i1 = constant <1 x i1> splat (i1 true) + +; CHECK: @constant.splat.i32 = constant <5 x i32> splat (i32 7) +@constant.splat.i32 = constant <5 x i32> splat (i32 7) + +; CHECK: @constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272) +@constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272) + +; CHECK: @constant.splat.f16 = constant <2 x half> splat (half 0xHBC00) +@constant.splat.f16 = constant <2 x half> splat (half 0xHBC00) + +; CHECK: @constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00) +@constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00) + +; CHECK: @constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00) +@constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00) + +; CHECK: @constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000) +@constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000) + +; CHECK: @constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0) +@constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0) + +; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800) +@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800) + +; CHECK: @constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000) +@constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000) + +define void @add_fixed_lenth_vector_splat_i32(<4 x i32> %a) { +; CHECK: %add = add <4 x i32> %a, splat (i32 137) + %add = add <4 x i32> %a, splat (i32 137) + ret void +} + +define <4 x i32> @ret_fixed_lenth_vector_splat_i32() { +; CHECK: ret <4 x i32> splat (i32 56) + ret <4 x i32> splat (i32 56) +} + +define void @add_fixed_lenth_vector_splat_double( %a) { +; CHECK: %add = fadd %a, splat (double 5.700000e+00) + %add = fadd %a, splat (double 5.700000e+00) + ret void +} + +define @ret_scalable_vector_splat_i32() { +; CHECK: ret splat (i32 78) + ret splat (i32 78) +} From 27b6edeea45dfb3be2ce4071a6d3e7833d3c7363 Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Tue, 13 Feb 2024 14:05:17 +0000 Subject: [PATCH 2/3] Make ElementCount get interfaces private. Reduce repeated calls to getType(). --- llvm/include/llvm/IR/Constants.h | 22 ++++++++++++++-------- llvm/lib/IR/AsmWriter.cpp | 18 +++++++++++------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index 39eec1b738fabb..c0ac9a4aa6750c 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -78,6 +78,7 @@ class ConstantData : public Constant { /// Class for constant integers. class ConstantInt final : public ConstantData { friend class Constant; + friend class ConstantVector; APInt Val; @@ -85,6 +86,12 @@ class ConstantInt final : public ConstantData { void destroyConstantImpl(); + /// Return a ConstantInt with the specified value and an implied Type. The + /// type is the vector type whose integer element type corresponds to the bit + /// width of the value. + static ConstantInt *get(LLVMContext &Context, ElementCount EC, + const APInt &V); + public: ConstantInt(const ConstantInt &) = delete; @@ -123,12 +130,6 @@ class ConstantInt final : public ConstantData { /// type is the integer type that corresponds to the bit width of the value. static ConstantInt *get(LLVMContext &Context, const APInt &V); - /// Return a ConstantInt with the specified value and an implied Type. The - /// type is the vector type whose integer element type corresponds to the bit - /// width of the value. - static ConstantInt *get(LLVMContext &Context, ElementCount EC, - const APInt &V); - /// Return a ConstantInt constructed from the string strStart with the given /// radix. static ConstantInt *get(IntegerType *Ty, StringRef Str, uint8_t Radix); @@ -265,6 +266,7 @@ class ConstantInt final : public ConstantData { /// class ConstantFP final : public ConstantData { friend class Constant; + friend class ConstantVector; APFloat Val; @@ -272,6 +274,12 @@ class ConstantFP final : public ConstantData { void destroyConstantImpl(); + /// Return a ConstantFP with the specified value and an implied Type. The + /// type is the vector type whose element type has the same floating point + /// semantics as the value. + static ConstantFP *get(LLVMContext &Context, ElementCount EC, + const APFloat &V); + public: ConstantFP(const ConstantFP &) = delete; @@ -287,8 +295,6 @@ class ConstantFP final : public ConstantData { static Constant *get(Type *Ty, StringRef Str); static ConstantFP *get(LLVMContext &Context, const APFloat &V); - static ConstantFP *get(LLVMContext &Context, ElementCount EC, - const APFloat &V); static Constant *getNaN(Type *Ty, bool Negative = false, uint64_t Payload = 0); static Constant *getQNaN(Type *Ty, bool Negative = false, diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 1fcda6c384d96d..00cc14296e9b05 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -1502,33 +1502,37 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) { static void WriteConstantInternal(raw_ostream &Out, const Constant *CV, AsmWriterContext &WriterCtx) { if (const ConstantInt *CI = dyn_cast(CV)) { - if (CI->getType()->isVectorTy()) { + Type *Ty = CI->getType(); + + if (Ty->isVectorTy()) { Out << "splat ("; - WriterCtx.TypePrinter->print(CI->getType()->getScalarType(), Out); + WriterCtx.TypePrinter->print(Ty->getScalarType(), Out); Out << " "; } - if (CI->getType()->getScalarType()->isIntegerTy(1)) + if (Ty->getScalarType()->isIntegerTy(1)) Out << (CI->getZExtValue() ? "true" : "false"); else Out << CI->getValue(); - if (CI->getType()->isVectorTy()) + if (Ty->isVectorTy()) Out << ")"; return; } if (const ConstantFP *CFP = dyn_cast(CV)) { - if (CFP->getType()->isVectorTy()) { + Type *Ty = CFP->getType(); + + if (Ty->isVectorTy()) { Out << "splat ("; - WriterCtx.TypePrinter->print(CFP->getType()->getScalarType(), Out); + WriterCtx.TypePrinter->print(Ty->getScalarType(), Out); Out << " "; } WriteAPFloatInternal(Out, CFP->getValueAPF()); - if (CFP->getType()->isVectorTy()) + if (Ty->isVectorTy()) Out << ")"; return; From ef185d1ac54ebb6a67e57feb150c6431b3ab3eed Mon Sep 17 00:00:00 2001 From: Paul Walker Date: Wed, 21 Feb 2024 11:29:34 +0000 Subject: [PATCH 3/3] Canonicalise splat like ConstantVectors to splat(value). Also adds extra testing for zeroinitializer. --- llvm/lib/IR/Constants.cpp | 12 ++++++++++-- llvm/test/Bitcode/constant-splat.ll | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index b04d7955afe670..e6b92aad392f66 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -1412,11 +1412,13 @@ Constant *ConstantVector::getImpl(ArrayRef V) { bool isZero = C->isNullValue(); bool isUndef = isa(C); bool isPoison = isa(C); + bool isSplatFP = UseConstantFPForFixedLengthSplat && isa(C); + bool isSplatInt = UseConstantIntForFixedLengthSplat && isa(C); - if (isZero || isUndef) { + if (isZero || isUndef || isSplatFP || isSplatInt) { for (unsigned i = 1, e = V.size(); i != e; ++i) if (V[i] != C) { - isZero = isUndef = isPoison = false; + isZero = isUndef = isPoison = isSplatFP = isSplatInt = false; break; } } @@ -1427,6 +1429,12 @@ Constant *ConstantVector::getImpl(ArrayRef V) { return PoisonValue::get(T); if (isUndef) return UndefValue::get(T); + if (isSplatFP) + return ConstantFP::get(C->getContext(), T->getElementCount(), + cast(C)->getValue()); + if (isSplatInt) + return ConstantInt::get(C->getContext(), T->getElementCount(), + cast(C)->getValue()); // Check to see if all of the elements are ConstantFP or ConstantInt and if // the element type is compatible with ConstantDataVector. If so, use it. diff --git a/llvm/test/Bitcode/constant-splat.ll b/llvm/test/Bitcode/constant-splat.ll index d4921607d15b54..2bcc3ddf3e4f3a 100644 --- a/llvm/test/Bitcode/constant-splat.ll +++ b/llvm/test/Bitcode/constant-splat.ll @@ -59,3 +59,18 @@ define @ret_scalable_vector_splat_i32() { ; CHECK: ret splat (i32 78) ret splat (i32 78) } + +define <4 x i32> @canonical_constant_vector() { +; CHECK: ret <4 x i32> splat (i32 7) + ret <4 x i32> +} + +define <4 x i32> @canonical_fixed_lnegth_vector_zero() { +; CHECK: ret <4 x i32> zeroinitializer + ret <4 x i32> zeroinitializer +} + +define @canonical_scalable_lnegth_vector_zero() { +; CHECK: ret zeroinitializer + ret zeroinitializer +}