diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index ab2b96cdc42db8..746ba51a981fe0 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -15440,9 +15440,25 @@ bool BoUpSLP::collectValuesToDemote( MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL))); }); }; + auto AbsChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) { + assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!"); + return all_of(E.Scalars, [&](Value *V) { + auto *I = cast(V); + unsigned SignBits = OrigBitWidth - BitWidth; + APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth - 1); + unsigned Op0SignBits = + ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT); + return SignBits <= Op0SignBits && + ((SignBits != Op0SignBits && + !isKnownNonNegative(I->getOperand(0), SimplifyQuery(*DL))) || + MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL))); + }); + }; if (ID != Intrinsic::abs) { Operands.push_back(getOperandEntry(&E, 1)); CallChecker = CompChecker; + } else { + CallChecker = AbsChecker; } InstructionCost BestCost = std::numeric_limits::max(); diff --git a/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll b/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll index a936b076138d07..51b635837d3b59 100644 --- a/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll +++ b/llvm/test/Transforms/SLPVectorizer/abs-overflow-incorrect-minbws.ll @@ -8,8 +8,10 @@ define i32 @test(i32 %n) { ; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[N]], i32 0 ; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> poison, <2 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i32> [[TMP1]], -; CHECK-NEXT: [[TMP3:%.*]] = mul <2 x i32> [[TMP2]], -; CHECK-NEXT: [[TMP4:%.*]] = call <2 x i32> @llvm.abs.v2i32(<2 x i32> [[TMP3]], i1 false) +; CHECK-NEXT: [[TMP3:%.*]] = zext <2 x i32> [[TMP2]] to <2 x i64> +; CHECK-NEXT: [[TMP7:%.*]] = mul nuw nsw <2 x i64> [[TMP3]], +; CHECK-NEXT: [[TMP8:%.*]] = call <2 x i64> @llvm.abs.v2i64(<2 x i64> [[TMP7]], i1 true) +; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP8]] to <2 x i32> ; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i32> [[TMP4]], i32 0 ; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP4]], i32 1 ; CHECK-NEXT: [[RES1:%.*]] = add i32 [[TMP5]], [[TMP6]]