From a5cd27880a258df7df32ec1553d9e4ca7e1868a0 Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Fri, 5 Nov 2021 19:11:55 +0300 Subject: [PATCH] [IR] Improve member `ShuffleVectorInst::isReplicationMask()` When we have an actual shuffle, we can impose the additional restriction that the mask replicates the elements of the first operand, so we know the replication factor as a ratio of output and op0 vector sizes. --- llvm/include/llvm/IR/Instructions.h | 9 +-------- llvm/lib/IR/Instructions.cpp | 15 +++++++++++++++ llvm/unittests/IR/InstructionsTest.cpp | 10 ++++++++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h index b380e34523a7a3..0ef78881c6d7ac 100644 --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -2373,14 +2373,7 @@ class ShuffleVectorInst : public Instruction { } /// Return true if this shuffle mask is a replication mask. - bool isReplicationMask(int &ReplicationFactor, int &VF) const { - // Not possible to express a shuffle mask for a scalable vector for this - // case. - if (isa(getType())) - return false; - - return isReplicationMask(ShuffleMask, ReplicationFactor, VF); - } + bool isReplicationMask(int &ReplicationFactor, int &VF) const; /// Change values in a shuffle permute mask assuming the two vector operands /// of length InVecNumElts have swapped position. diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 63dd07543f4347..c42df49d97ea21 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -2502,6 +2502,21 @@ bool ShuffleVectorInst::isReplicationMask(ArrayRef Mask, return false; } +bool ShuffleVectorInst::isReplicationMask(int &ReplicationFactor, + int &VF) const { + // Not possible to express a shuffle mask for a scalable vector for this + // case. + if (isa(getType())) + return false; + + VF = cast(Op<0>()->getType())->getNumElements(); + if (ShuffleMask.size() % VF != 0) + return false; + ReplicationFactor = ShuffleMask.size() / VF; + + return isReplicationMaskWithParams(ShuffleMask, ReplicationFactor, VF); +} + //===----------------------------------------------------------------------===// // InsertValueInst Class //===----------------------------------------------------------------------===// diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp index 213435f4c8d337..a4a96714e82d0b 100644 --- a/llvm/unittests/IR/InstructionsTest.cpp +++ b/llvm/unittests/IR/InstructionsTest.cpp @@ -1126,6 +1126,16 @@ TEST(InstructionsTest, ShuffleMaskIsReplicationMask) { ReplicatedMask, GuessedReplicationFactor, GuessedVF)); EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor); EXPECT_EQ(GuessedVF, VF); + + for (int OpVF : seq_inclusive(VF, 2 * VF + 1)) { + LLVMContext Ctx; + Type *OpVFTy = FixedVectorType::get(IntegerType::getInt1Ty(Ctx), OpVF); + Value *Op = ConstantVector::getNullValue(OpVFTy); + ShuffleVectorInst *SVI = new ShuffleVectorInst(Op, Op, ReplicatedMask); + EXPECT_EQ(SVI->isReplicationMask(GuessedReplicationFactor, GuessedVF), + OpVF == VF); + delete SVI; + } } } }