Skip to content

Commit

Permalink
[MLIR] Update APInt construction to correctly set isSigned/implicitTr…
Browse files Browse the repository at this point in the history
…unc (llvm#110466)

This fixes all the places in MLIR that hit the new assertion added in
llvm#106524, in preparation for enabling it by default. That is, cases where
the value passed to the APInt constructor is not an N-bit
signed/unsigned integer, where N is the bit width and signedness is
determined by the isSigned flag.

The fixes either set the correct value for isSigned, or set the
implicitTrunc flag to retain the old behavior. I've left TODOs for the
latter case in some places, where I think that it may be worthwhile to
stop doing implicit truncation in the future.

Note that the assertion is currently still disabled by default, so this
patch is mostly NFC.

This is just the MLIR changes split off from
llvm#80309.
  • Loading branch information
nikic authored and EricWF committed Oct 22, 2024
1 parent b477a51 commit 00974ca
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 12 deletions.
4 changes: 3 additions & 1 deletion mlir/include/mlir/IR/BuiltinAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,10 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
return $_get(type.getContext(), type, apValue);
}

// TODO: Avoid implicit trunc?
IntegerType intTy = ::llvm::cast<IntegerType>(type);
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger());
APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger(),
/*implicitTrunc=*/true);
return $_get(type.getContext(), type, apValue);
}]>
];
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,8 @@ class AsmParser {
// zero for non-negated integers.
result =
(IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
if (APInt(uintResult.getBitWidth(), result) != uintResult)
if (APInt(uintResult.getBitWidth(), result, /*isSigned=*/true,
/*implicitTrunc=*/true) != uintResult)
return emitError(loc, "integer value too large");
return success();
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Type matchContainerType(Type element, Type container) {
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
Type eTy = shapedTy.getElementType();
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
return DenseIntElementsAttr::get(shapedTy, valueInt);
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ static ParseResult parseSwitchOpCases(
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value));
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));

Block *destination;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ static ParseResult parseSwitchOpCases(
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value));
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));

Block *destination;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,7 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
if (parser.parseInteger(value))
return failure();
shapeTmp++;
values.push_back(APInt(32, value));
values.push_back(APInt(32, value, /*isSigned=*/true));
return success();
};

Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,10 @@ DenseIntElementsAttr Builder::getIndexTensorAttr(ArrayRef<int64_t> values) {
}

IntegerAttr Builder::getI32IntegerAttr(int32_t value) {
return IntegerAttr::get(getIntegerType(32), APInt(32, value));
// The APInt always uses isSigned=true here because we accept the value
// as int32_t.
return IntegerAttr::get(getIntegerType(32),
APInt(32, value, /*isSigned=*/true));
}

IntegerAttr Builder::getSI32IntegerAttr(int32_t value) {
Expand All @@ -256,14 +259,19 @@ IntegerAttr Builder::getI16IntegerAttr(int16_t value) {
}

IntegerAttr Builder::getI8IntegerAttr(int8_t value) {
return IntegerAttr::get(getIntegerType(8), APInt(8, value));
// The APInt always uses isSigned=true here because we accept the value
// as int8_t.
return IntegerAttr::get(getIntegerType(8),
APInt(8, value, /*isSigned=*/true));
}

IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
if (type.isIndex())
return IntegerAttr::get(type, APInt(64, value));
return IntegerAttr::get(
type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger()));
// TODO: Avoid implicit trunc?
return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value,
type.isSignedInteger(),
/*implicitTrunc=*/true));
}

IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,8 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
} words = {operands[2], operands[3]};
value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
} else if (bitwidth <= 32) {
value = APInt(bitwidth, operands[2], /*isSigned=*/true);
value = APInt(bitwidth, operands[2], /*isSigned=*/true,
/*implicitTrunc=*/true);
}

auto attr = opBuilder.getIntegerAttr(intType, value);
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ TEST_F(SerializationTest, SignlessVsSignedIntegerConstantBitExtension) {
IntegerType::get(&context, 16, IntegerType::Signless);
auto signedInt16Type = IntegerType::get(&context, 16, IntegerType::Signed);
// Check the bit extension of same value under different signedness semantics.
APInt signlessIntConstVal(signlessInt16Type.getWidth(), -1,
APInt signlessIntConstVal(signlessInt16Type.getWidth(), 0xffff,
signlessInt16Type.getSignedness());
APInt signedIntConstVal(signedInt16Type.getWidth(), -1,
signedInt16Type.getSignedness());
Expand Down

0 comments on commit 00974ca

Please sign in to comment.