From c4bfd01a40b7f5a59fd4a433f8b84bbbffb03c42 Mon Sep 17 00:00:00 2001 From: Alexey Bataev Date: Fri, 18 Oct 2024 13:54:30 -0700 Subject: [PATCH] [SLP]Check that operand of abs does not overflow before making it part of minbitwidth transformation Need to check that the operand of the abs intrinsic can be safely truncated before making it part of the minbitwidth transformation. Fixes #112577 (cherry picked from commit 709abacdc350d63c61888607edb28ce272daa0a0) --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 16 ++++++++++++++++ .../abs-overflow-incorrect-minbws.ll | 6 ++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index ab2b96cdc42db8..746ba51a981fe0 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -15440,9 +15440,25 @@ bool BoUpSLP::collectValuesToDemote( MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL))); }); }; + auto AbsChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast(V); + unsigned SignBits = OrigBitWidth - BitWidth; + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth - 1); + unsigned Op0SignBits = + ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT); + return SignBits <= Op0SignBits && + ((SignBits != Op0SignBits && + !isKnownNonNegative(I->getOperand(0), SimplifyQuery(*DL))) || + MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL))); + }); + }; if (ID != Intrinsic::abs) { Operands.push_back(getOperandEntry(&E, 1)); CallChecker = CompChecker; + } else { + CallChecker = AbsChecker; } InstructionCost BestCost = std::numeric_limits::max(); diff --git a/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll b/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll index a936b076138d07..51b635837d3b59 100644 --- a/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll +++ b/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll @@ -8,8 +8,10 @@ define i32 @test(i32 %n) { ; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[N]], i32 0 ; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> poison, <2 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i32> [[TMP1]], -; CHECK-NEXT: [[TMP3:%.*]] = mul <2 x i32> [[TMP2]], -; CHECK-NEXT: [[TMP4:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP3]], i1 false) +; CHECK-NEXT: [[TMP3:%.*]] = zext <2 x i32> [[TMP2]] to <2 x i64> +; CHECK-NEXT: [[TMP7:%.*]] = mul nuw nsw <2 x i64> [[TMP3]], +; CHECK-NEXT: [[TMP8:%.*]] = call <2 x i64> @llvm.abs.v2i64(<2 x i64> [[TMP7]], i1 true) +; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP8]] to <2 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i32> [[TMP4]], i32 0 ; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP4]], i32 1 ; CHECK-NEXT: [[RES1:%.*]] = add i32 [[TMP5]], [[TMP6]]