Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLVM][IR] Add native vector support to ConstantInt & ConstantFP. #74502

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,20 @@ class ConstantData : public Constant {
/// Class for constant integers.
class ConstantInt final : public ConstantData {
friend class Constant;
friend class ConstantVector;

APInt Val;

ConstantInt(IntegerType *Ty, const APInt &V);
ConstantInt(Type *Ty, const APInt &V);

void destroyConstantImpl();

/// Return a ConstantInt with the specified value and an implied Type. The
/// type is the vector type whose integer element type corresponds to the bit
/// width of the value.
static ConstantInt *get(LLVMContext &Context, ElementCount EC,
const APInt &V);

public:
ConstantInt(const ConstantInt &) = delete;

Expand Down Expand Up @@ -136,7 +143,7 @@ class ConstantInt final : public ConstantData {
/// Return the constant's value.
inline const APInt &getValue() const { return Val; }

/// getBitWidth - Return the bitwidth of this constant.
/// getBitWidth - Return the scalar bitwidth of this constant.
unsigned getBitWidth() const { return Val.getBitWidth(); }

/// Return the constant as a 64-bit unsigned integer value after it
Expand Down Expand Up @@ -259,13 +266,20 @@ class ConstantInt final : public ConstantData {
///
class ConstantFP final : public ConstantData {
friend class Constant;
friend class ConstantVector;

APFloat Val;

ConstantFP(Type *Ty, const APFloat &V);

void destroyConstantImpl();

/// Return a ConstantFP with the specified value and an implied Type. The
/// type is the vector type whose element type has the same floating point
/// semantics as the value.
static ConstantFP *get(LLVMContext &Context, ElementCount EC,
const APFloat &V);

public:
ConstantFP(const ConstantFP &) = delete;

Expand Down
55 changes: 28 additions & 27 deletions llvm/lib/Bitcode/Reader/BitcodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3060,48 +3060,49 @@ Error BitcodeReader::parseConstants() {
V = Constant::getNullValue(CurTy);
break;
case bitc::CST_CODE_INTEGER: // INTEGER: [intval]
if (!CurTy->isIntegerTy() || Record.empty())
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
return error("Invalid integer const record");
V = ConstantInt::get(CurTy, decodeSignRotatedValue(Record[0]));
break;
case bitc::CST_CODE_WIDE_INTEGER: {// WIDE_INTEGER: [n x intval]
if (!CurTy->isIntegerTy() || Record.empty())
if (!CurTy->isIntOrIntVectorTy() || Record.empty())
return error("Invalid wide integer const record");

APInt VInt =
readWideAPInt(Record, cast<IntegerType>(CurTy)->getBitWidth());
V = ConstantInt::get(Context, VInt);

auto *ScalarTy = cast<IntegerType>(CurTy->getScalarType());
APInt VInt = readWideAPInt(Record, ScalarTy->getBitWidth());
V = ConstantInt::get(CurTy, VInt);
break;
}
case bitc::CST_CODE_FLOAT: { // FLOAT: [fpval]
if (Record.empty())
return error("Invalid float const record");
if (CurTy->isHalfTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
APInt(16, (uint16_t)Record[0])));
else if (CurTy->isBFloatTy())
V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
APInt(16, (uint32_t)Record[0])));
else if (CurTy->isFloatTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
APInt(32, (uint32_t)Record[0])));
else if (CurTy->isDoubleTy())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEdouble(),
APInt(64, Record[0])));
else if (CurTy->isX86_FP80Ty()) {

auto *ScalarTy = CurTy->getScalarType();
if (ScalarTy->isHalfTy())
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEhalf(),
APInt(16, (uint16_t)Record[0])));
else if (ScalarTy->isBFloatTy())
V = ConstantFP::get(
CurTy, APFloat(APFloat::BFloat(), APInt(16, (uint32_t)Record[0])));
else if (ScalarTy->isFloatTy())
V = ConstantFP::get(CurTy, APFloat(APFloat::IEEEsingle(),
APInt(32, (uint32_t)Record[0])));
else if (ScalarTy->isDoubleTy())
V = ConstantFP::get(
CurTy, APFloat(APFloat::IEEEdouble(), APInt(64, Record[0])));
else if (ScalarTy->isX86_FP80Ty()) {
// Bits are not stored the same way as a normal i80 APInt, compensate.
uint64_t Rearrange[2];
Rearrange[0] = (Record[1] & 0xffffLL) | (Record[0] << 16);
Rearrange[1] = Record[0] >> 48;
V = ConstantFP::get(Context, APFloat(APFloat::x87DoubleExtended(),
APInt(80, Rearrange)));
} else if (CurTy->isFP128Ty())
V = ConstantFP::get(Context, APFloat(APFloat::IEEEquad(),
APInt(128, Record)));
else if (CurTy->isPPC_FP128Ty())
V = ConstantFP::get(Context, APFloat(APFloat::PPCDoubleDouble(),
APInt(128, Record)));
V = ConstantFP::get(
CurTy, APFloat(APFloat::x87DoubleExtended(), APInt(80, Rearrange)));
} else if (ScalarTy->isFP128Ty())
V = ConstantFP::get(CurTy,
APFloat(APFloat::IEEEquad(), APInt(128, Record)));
else if (ScalarTy->isPPC_FP128Ty())
V = ConstantFP::get(
CurTy, APFloat(APFloat::PPCDoubleDouble(), APInt(128, Record)));
else
V = UndefValue::get(CurTy);
break;
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2624,7 +2624,7 @@ void ModuleBitcodeWriter::writeConstants(unsigned FirstVal, unsigned LastVal,
}
} else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
Code = bitc::CST_CODE_FLOAT;
Type *Ty = CFP->getType();
Type *Ty = CFP->getType()->getScalarType();
if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
Ty->isDoubleTy()) {
Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
Expand Down
31 changes: 27 additions & 4 deletions llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1502,16 +1502,39 @@ static void WriteAPFloatInternal(raw_ostream &Out, const APFloat &APF) {
static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
AsmWriterContext &WriterCtx) {
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV)) {
if (CI->getType()->isIntegerTy(1)) {
Out << (CI->getZExtValue() ? "true" : "false");
return;
Type *Ty = CI->getType();

if (Ty->isVectorTy()) {
Out << "splat (";
WriterCtx.TypePrinter->print(Ty->getScalarType(), Out);
Out << " ";
}
Out << CI->getValue();

if (Ty->getScalarType()->isIntegerTy(1))
Out << (CI->getZExtValue() ? "true" : "false");
else
Out << CI->getValue();

if (Ty->isVectorTy())
Out << ")";

return;
}

if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CV)) {
Type *Ty = CFP->getType();

if (Ty->isVectorTy()) {
Out << "splat (";
WriterCtx.TypePrinter->print(Ty->getScalarType(), Out);
Out << " ";
}

WriteAPFloatInternal(Out, CFP->getValueAPF());

if (Ty->isVectorTy())
Out << ")";

return;
}

Expand Down
82 changes: 79 additions & 3 deletions llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
using namespace llvm;
using namespace PatternMatch;

// As set of temporary options to help migrate how splats are represented.
static cl::opt<bool> UseConstantIntForFixedLengthSplat(
"use-constant-int-for-fixed-length-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantInt's native fixed-length vector splat support."));
static cl::opt<bool> UseConstantFPForFixedLengthSplat(
"use-constant-fp-for-fixed-length-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantFP's native fixed-length vector splat support."));
static cl::opt<bool> UseConstantIntForScalableSplat(
"use-constant-int-for-scalable-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantInt's native scalable vector splat support."));
static cl::opt<bool> UseConstantFPForScalableSplat(
"use-constant-fp-for-scalable-splat", cl::init(false), cl::Hidden,
cl::desc("Use ConstantFP's native scalable vector splat support."));

Comment on lines +38 to +51
Copy link
Contributor

@Dinistro Dinistro Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@paulwalker-arm @nikic Do you think it makes sense to make these available in a header file? I tried to set them directly from within a custom tool, but failed to get access to these. For us, setting them via. the command line is impractical.
If so, I can provide a PR that moves their declaration to a header file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say tentatively "no". We only really export options if we need to read them from multiple files. If we wanted to make these settable via API, I think the way we'd typically do that is by moving them into LLVMContext and adding setters.

FWIW, I don't think these options are mature enough yet for practical usage.

I tried to set them directly from within a custom tool, but failed to get access to these. For us, setting them via. the command line is impractical.

Can you provide more details on this? The usual way to set the options is via cl::ParseCommandLineOptions, which doesn't mean that they actually have to come from the command line. It's pretty typical to just hardcode some options in the invocation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fast response.

Can you provide more details on this?

I attempted to directly access the underlying Option structures, but that did not succeed.

The usual way to set the options is via cl::ParseCommandLineOptions, which doesn't mean that they actually have to come from the command line. It's pretty typical to just hardcode some options in the invocation.

I was not aware that one can executed this function twice to pass custom options to it, thanks for the hint. This should work, once some other, only tangentially related, issues have been addressed.

//===----------------------------------------------------------------------===//
// Constant Class
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -825,9 +839,11 @@ bool Constant::isManifestConstant() const {
// ConstantInt
//===----------------------------------------------------------------------===//

ConstantInt::ConstantInt(IntegerType *Ty, const APInt &V)
ConstantInt::ConstantInt(Type *Ty, const APInt &V)
: ConstantData(Ty, ConstantIntVal), Val(V) {
assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
assert(V.getBitWidth() ==
cast<IntegerType>(Ty->getScalarType())->getBitWidth() &&
"Invalid constant for type");
}

ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
Expand Down Expand Up @@ -885,6 +901,26 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
return Slot.get();
}

// Get a ConstantInt vector with each lane set to the same APInt.
ConstantInt *ConstantInt::get(LLVMContext &Context, ElementCount EC,
const APInt &V) {
nikic marked this conversation as resolved.
Show resolved Hide resolved
// Get an existing value or the insertion position.
std::unique_ptr<ConstantInt> &Slot =
Context.pImpl->IntSplatConstants[std::make_pair(EC, V)];
if (!Slot) {
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
VectorType *VTy = VectorType::get(ITy, EC);
Slot.reset(new ConstantInt(VTy, V));
}

#ifndef NDEBUG
IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
VectorType *VTy = VectorType::get(ITy, EC);
assert(Slot->getType() == VTy);
#endif
return Slot.get();
}

Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);

Expand Down Expand Up @@ -1024,6 +1060,26 @@ ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
return Slot.get();
}

// Get a ConstantFP vector with each lane set to the same APFloat.
ConstantFP *ConstantFP::get(LLVMContext &Context, ElementCount EC,
const APFloat &V) {
// Get an existing value or the insertion position.
std::unique_ptr<ConstantFP> &Slot =
Context.pImpl->FPSplatConstants[std::make_pair(EC, V)];
if (!Slot) {
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
VectorType *VTy = VectorType::get(EltTy, EC);
Slot.reset(new ConstantFP(VTy, V));
}

#ifndef NDEBUG
Type *EltTy = Type::getFloatingPointTy(Context, V.getSemantics());
VectorType *VTy = VectorType::get(EltTy, EC);
assert(Slot->getType() == VTy);
#endif
return Slot.get();
}

Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
const fltSemantics &Semantics = Ty->getScalarType()->getFltSemantics();
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
Expand All @@ -1036,7 +1092,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {

ConstantFP::ConstantFP(Type *Ty, const APFloat &V)
: ConstantData(Ty, ConstantFPVal), Val(V) {
assert(&V.getSemantics() == &Ty->getFltSemantics() &&
assert(&V.getSemantics() == &Ty->getScalarType()->getFltSemantics() &&
"FP type Mismatch");
}

Expand Down Expand Up @@ -1384,6 +1440,16 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {

Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
if (!EC.isScalable()) {
// Maintain special handling of zero.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering whether this is something you want to keep long term or just initially?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy to take offers but see it as temporary.

if (!V->isNullValue()) {
if (UseConstantIntForFixedLengthSplat && isa<ConstantInt>(V))
return ConstantInt::get(V->getContext(), EC,
cast<ConstantInt>(V)->getValue());
if (UseConstantFPForFixedLengthSplat && isa<ConstantFP>(V))
return ConstantFP::get(V->getContext(), EC,
cast<ConstantFP>(V)->getValue());
}

// If this splat is compatible with ConstantDataVector, use it instead of
// ConstantVector.
if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
Expand All @@ -1394,6 +1460,16 @@ Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
return get(Elts);
}

// Maintain special handling of zero.
if (!V->isNullValue()) {
if (UseConstantIntForScalableSplat && isa<ConstantInt>(V))
return ConstantInt::get(V->getContext(), EC,
cast<ConstantInt>(V)->getValue());
if (UseConstantFPForScalableSplat && isa<ConstantFP>(V))
return ConstantFP::get(V->getContext(), EC,
cast<ConstantFP>(V)->getValue());
}

Type *VTy = VectorType::get(V->getType(), EC);

if (V->isNullValue())
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/IR/LLVMContextImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ LLVMContextImpl::~LLVMContextImpl() {
IntZeroConstants.clear();
IntOneConstants.clear();
IntConstants.clear();
IntSplatConstants.clear();
FPConstants.clear();
FPSplatConstants.clear();
CDSConstants.clear();

// Destroy attribute node lists.
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/IR/LLVMContextImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1488,8 +1488,12 @@ class LLVMContextImpl {
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntZeroConstants;
DenseMap<unsigned, std::unique_ptr<ConstantInt>> IntOneConstants;
DenseMap<APInt, std::unique_ptr<ConstantInt>> IntConstants;
DenseMap<std::pair<ElementCount, APInt>, std::unique_ptr<ConstantInt>>
IntSplatConstants;

DenseMap<APFloat, std::unique_ptr<ConstantFP>> FPConstants;
DenseMap<std::pair<ElementCount, APFloat>, std::unique_ptr<ConstantFP>>
FPSplatConstants;

FoldingSet<AttributeImpl> AttrsSet;
FoldingSet<AttributeListImpl> AttrsLists;
Expand Down
61 changes: 61 additions & 0 deletions llvm/test/Bitcode/constant-splat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
; RUN: llvm-as -use-constant-int-for-fixed-length-splat \
; RUN: -use-constant-fp-for-fixed-length-splat \
; RUN: -use-constant-int-for-scalable-splat \
; RUN: -use-constant-fp-for-scalable-splat \
; RUN: < %s | llvm-dis -use-constant-int-for-fixed-length-splat \
; RUN: -use-constant-fp-for-fixed-length-splat \
; RUN: -use-constant-int-for-scalable-splat \
; RUN: -use-constant-fp-for-scalable-splat \
; RUN: | FileCheck %s

; CHECK: @constant.splat.i1 = constant <1 x i1> splat (i1 true)
@constant.splat.i1 = constant <1 x i1> splat (i1 true)

; CHECK: @constant.splat.i32 = constant <5 x i32> splat (i32 7)
@constant.splat.i32 = constant <5 x i32> splat (i32 7)

; CHECK: @constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272)
@constant.splat.i128 = constant <7 x i128> splat (i128 85070591730234615870450834276742070272)

; CHECK: @constant.splat.f16 = constant <2 x half> splat (half 0xHBC00)
@constant.splat.f16 = constant <2 x half> splat (half 0xHBC00)

; CHECK: @constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00)
@constant.splat.f32 = constant <4 x float> splat (float -2.000000e+00)

; CHECK: @constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00)
@constant.splat.f64 = constant <6 x double> splat (double -3.000000e+00)

; CHECK: @constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000)
@constant.splat.128 = constant <8 x fp128> splat (fp128 0xL00000000000000018000000000000000)

; CHECK: @constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0)
@constant.splat.bf16 = constant <1 x bfloat> splat (bfloat 0xRC0A0)

; CHECK: @constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)
@constant.splat.x86_fp80 = constant <3 x x86_fp80> splat (x86_fp80 0xK4000C8F5C28F5C28F800)

; CHECK: @constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)
@constant.splat.ppc_fp128 = constant <7 x ppc_fp128> splat (ppc_fp128 0xM80000000000000000000000000000000)

define void @add_fixed_lenth_vector_splat_i32(<4 x i32> %a) {
; CHECK: %add = add <4 x i32> %a, splat (i32 137)
%add = add <4 x i32> %a, splat (i32 137)
ret void
}

define <4 x i32> @ret_fixed_lenth_vector_splat_i32() {
; CHECK: ret <4 x i32> splat (i32 56)
ret <4 x i32> splat (i32 56)
}

define void @add_fixed_lenth_vector_splat_double(<vscale x 2 x double> %a) {
; CHECK: %add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
%add = fadd <vscale x 2 x double> %a, splat (double 5.700000e+00)
ret void
}

define <vscale x 4 x i32> @ret_scalable_vector_splat_i32() {
; CHECK: ret <vscale x 4 x i32> splat (i32 78)
ret <vscale x 4 x i32> splat (i32 78)
}
Loading