Skip to content

Commit

Permalink
[IR] Improve member ShuffleVectorInst::isReplicationMask()
Browse files Browse the repository at this point in the history
When we have an actual shuffle, we can impose the additional restriction
that the mask replicates the elements of the first operand, so we know
the replication factor as a ratio of output and op0 vector sizes.
  • Loading branch information
LebedevRI committed Nov 5, 2021
1 parent 6d48e25 commit a5cd278
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
9 changes: 1 addition & 8 deletions llvm/include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2373,14 +2373,7 @@ class ShuffleVectorInst : public Instruction {
}

/// Return true if this shuffle mask is a replication mask.
bool isReplicationMask(int &ReplicationFactor, int &VF) const {
// Not possible to express a shuffle mask for a scalable vector for this
// case.
if (isa<ScalableVectorType>(getType()))
return false;

return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
}
bool isReplicationMask(int &ReplicationFactor, int &VF) const;

/// Change values in a shuffle permute mask assuming the two vector operands
/// of length InVecNumElts have swapped position.
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2502,6 +2502,21 @@ bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
return false;
}

bool ShuffleVectorInst::isReplicationMask(int &ReplicationFactor,
int &VF) const {
// Not possible to express a shuffle mask for a scalable vector for this
// case.
if (isa<ScalableVectorType>(getType()))
return false;

VF = cast<FixedVectorType>(Op<0>()->getType())->getNumElements();
if (ShuffleMask.size() % VF != 0)
return false;
ReplicationFactor = ShuffleMask.size() / VF;

return isReplicationMaskWithParams(ShuffleMask, ReplicationFactor, VF);
}

//===----------------------------------------------------------------------===//
// InsertValueInst Class
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions llvm/unittests/IR/InstructionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,16 @@ TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
ReplicatedMask, GuessedReplicationFactor, GuessedVF));
EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
EXPECT_EQ(GuessedVF, VF);

for (int OpVF : seq_inclusive(VF, 2 * VF + 1)) {
LLVMContext Ctx;
Type *OpVFTy = FixedVectorType::get(IntegerType::getInt1Ty(Ctx), OpVF);
Value *Op = ConstantVector::getNullValue(OpVFTy);
ShuffleVectorInst *SVI = new ShuffleVectorInst(Op, Op, ReplicatedMask);
EXPECT_EQ(SVI->isReplicationMask(GuessedReplicationFactor, GuessedVF),
OpVF == VF);
delete SVI;
}
}
}
}
Expand Down

0 comments on commit a5cd278

Please sign in to comment.