Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate uses of bar.dyn_cast<Foo>() to dyn_cast<Foo>(bar) #394

Merged
merged 4 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 27 additions & 29 deletions lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1631,9 +1631,9 @@ class CopySimplification final : public OpRewritePattern<T> {
Type elTy = dstTy.getElementType();

size_t width = 1;
if (auto IT = elTy.dyn_cast<IntegerType>())
if (auto IT = llvm::dyn_cast<IntegerType>(elTy))
width = IT.getWidth() / 8;
else if (auto FT = elTy.dyn_cast<FloatType>())
else if (auto FT = llvm::dyn_cast<FloatType>(elTy))
width = FT.getWidth() / 8;
else {
// TODO extend to llvm compatible type
Expand Down Expand Up @@ -1734,9 +1734,9 @@ class SetSimplification final : public OpRewritePattern<T> {
return failure();

size_t width = 1;
if (auto IT = elTy.dyn_cast<IntegerType>())
if (auto IT = llvm::dyn_cast<IntegerType>(elTy))
width = IT.getWidth() / 8;
else if (auto FT = elTy.dyn_cast<FloatType>())
else if (auto FT = llvm::dyn_cast<FloatType>(elTy))
width = FT.getWidth() / 8;
else {
// TODO extend to llvm compatible type
Expand Down Expand Up @@ -1787,7 +1787,7 @@ class SetSimplification final : public OpRewritePattern<T> {
SmallVector<Value> idxs;
Value val;

if (auto IT = elTy.dyn_cast<IntegerType>())
if (auto IT = llvm::dyn_cast<IntegerType>(elTy))
val =
rewriter.create<arith::ConstantIntOp>(op.getLoc(), 0, IT.getWidth());
else {
Expand Down Expand Up @@ -2479,7 +2479,7 @@ class SelectOfExt final : public OpRewritePattern<arith::SelectOp> {

LogicalResult matchAndRewrite(arith::SelectOp op,
PatternRewriter &rewriter) const override {
auto ty = op.getType().dyn_cast<IntegerType>();
auto ty = llvm::dyn_cast<IntegerType>(op.getType());
if (!ty)
return failure();
IntegerAttr lhs, rhs;
Expand Down Expand Up @@ -2589,10 +2589,10 @@ class CmpProp final : public OpRewritePattern<CmpIOp> {
v.getDefiningOp<LLVM::UndefOp>() ||
v.getDefiningOp<polygeist::UndefOp>();
if (auto extOp = v.getDefiningOp<ExtUIOp>())
if (auto it = extOp.getIn().getType().dyn_cast<IntegerType>())
if (auto it = llvm::dyn_cast<IntegerType>(extOp.getIn().getType()))
change |= it.getWidth() == 1;
if (auto extOp = v.getDefiningOp<ExtSIOp>())
if (auto it = extOp.getIn().getType().dyn_cast<IntegerType>())
if (auto it = llvm::dyn_cast<IntegerType>(extOp.getIn().getType()))
change |= it.getWidth() == 1;
}
if (!change) {
Expand Down Expand Up @@ -3175,7 +3175,7 @@ bool valueCmp(Cmp cmp, Value bval, ValueOrInt val) {
}
}

if (auto baval = bval.dyn_cast<BlockArgument>()) {
if (auto baval = llvm::dyn_cast<BlockArgument>(bval)) {
if (affine::AffineForOp afFor =
dyn_cast<affine::AffineForOp>(baval.getOwner()->getParentOp())) {
auto for_lb = afFor.getLowerBoundMap().getResults()[baval.getArgNumber()];
Expand Down Expand Up @@ -3487,7 +3487,7 @@ bool valueCmp(Cmp cmp, AffineExpr expr, size_t numDim, ValueRange operands,

// Range is [lb, ub)
bool rangeIncludes(Value bval, ValueOrInt lb, ValueOrInt ub) {
if (auto baval = bval.dyn_cast<BlockArgument>()) {
if (auto baval = llvm::dyn_cast<BlockArgument>(bval)) {
if (affine::AffineForOp afFor =
dyn_cast<affine::AffineForOp>(baval.getOwner()->getParentOp())) {
return valueCmp(
Expand Down Expand Up @@ -3621,7 +3621,7 @@ struct AffineIfSinking : public OpRewritePattern<affine::AffineIfOp> {
if (!opd) {
return failure();
}
auto ival = op.getOperands()[opd.getPosition()].dyn_cast<BlockArgument>();
auto ival = llvm::dyn_cast<BlockArgument>(op.getOperands()[opd.getPosition()]);
if (!ival) {
return failure();
}
Expand Down Expand Up @@ -3663,7 +3663,7 @@ struct AffineIfSinking : public OpRewritePattern<affine::AffineIfOp> {
if (!par.getRegion().isAncestor(v.getParentRegion()) ||
op.getThenRegion().isAncestor(v.getParentRegion()))
return;
if (auto ba = v.dyn_cast<BlockArgument>()) {
if (auto ba = llvm::dyn_cast<BlockArgument>(v)) {
if (ba.getOwner()->getParentOp() == par) {
return;
}
Expand Down Expand Up @@ -4285,7 +4285,7 @@ struct MergeNestedAffineParallelIf
continue;
}
if (auto dim = cur.dyn_cast<AffineDimExpr>()) {
auto ival = operands[dim.getPosition()].dyn_cast<BlockArgument>();
auto ival = llvm::dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + dim;
if (failure)
Expand Down Expand Up @@ -4315,7 +4315,7 @@ struct MergeNestedAffineParallelIf

if (auto dim = bop.getLHS().dyn_cast<AffineDimExpr>()) {
auto ival =
operands[dim.getPosition()].dyn_cast<BlockArgument>();
llvm::dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + bop;
// While legal, this may run before parallel merging
Expand Down Expand Up @@ -4511,7 +4511,7 @@ struct MergeParallelInductions
continue;
}
if (auto dim = cur.dyn_cast<AffineDimExpr>()) {
auto ival = operands[dim.getPosition()].dyn_cast<BlockArgument>();
auto ival = llvm::dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + dim;
continue;
Expand All @@ -4538,7 +4538,7 @@ struct MergeParallelInductions
}

if (auto dim = bop.getLHS().dyn_cast<AffineDimExpr>()) {
auto ival = operands[dim.getPosition()].dyn_cast<BlockArgument>();
auto ival = llvm::dyn_cast<BlockArgument>(operands[dim.getPosition()]);
if (!ival || ival.getOwner()->getParentOp() != op) {
rhs = rhs + bop;
continue;
Expand Down Expand Up @@ -4983,8 +4983,8 @@ template <typename T> struct BufferElimination : public OpRewritePattern<T> {
auto opd = map.getResults()[0].dyn_cast<AffineDimExpr>();
if (!opd)
continue;
auto val = ((Value)load.getMapOperands()[opd.getPosition()])
.dyn_cast<BlockArgument>();
auto val = llvm::dyn_cast<BlockArgument>(
((Value)load.getMapOperands()[opd.getPosition()]));
if (!val)
continue;

Expand Down Expand Up @@ -5033,8 +5033,8 @@ template <typename T> struct BufferElimination : public OpRewritePattern<T> {
auto opd = map.getResults()[0].dyn_cast<AffineDimExpr>();
if (!opd)
continue;
auto val = ((Value)load.getMapOperands()[opd.getPosition()])
.dyn_cast<BlockArgument>();
auto val = llvm::dyn_cast<BlockArgument>(
((Value)load.getMapOperands()[opd.getPosition()]));
if (!val)
continue;

Expand Down Expand Up @@ -5279,7 +5279,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
}
if (storeIdxs[pair.index()].isValue) {
Value auval = storeIdxs[pair.index()].v_val;
BlockArgument bval = auval.dyn_cast<BlockArgument>();
BlockArgument bval = llvm::dyn_cast<BlockArgument>(auval);
if (!bval) {
LLVM_DEBUG(llvm::dbgs() << " + non bval expr " << bval << "\n");
continue;
Expand All @@ -5302,9 +5302,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
auto i = idxp.index();
if (!idx.isValue) {
if (auto ald = dyn_cast<affine::AffineLoadOp>(ld)) {
if (auto ac = ald.getAffineMap()
.getResult(i)
.dyn_cast<AffineConstantExpr>()) {
if (auto ac = ald.getAffineMap().getResult(i).dyn_cast<AffineConstantExpr>()) {
if (idx == ac.getValue())
continue;
}
Expand Down Expand Up @@ -5339,7 +5337,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (!VI.isValue)
continue;
auto V = VI.v_val;
auto BA = V.dyn_cast<BlockArgument>();
auto BA = llvm::dyn_cast<BlockArgument>(V);
if (!BA) {
LLVM_DEBUG(llvm::dbgs() << " + non map oper " << V << "\n");
return failure();
Expand Down Expand Up @@ -5426,7 +5424,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (!VI.isValue)
continue;
auto V = VI.v_val;
auto BA = V.dyn_cast<BlockArgument>();
auto BA = llvm::dyn_cast<BlockArgument>(V);
Operation *c = BA.getOwner()->getParentOp();
if (isa<affine::AffineParallelOp>(c) || isa<scf::ParallelOp>(c)) {
Operation *tmp = store;
Expand Down Expand Up @@ -5468,7 +5466,7 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (storeIdxSet.count(V))
continue;

if (auto BA = V.dyn_cast<BlockArgument>()) {
if (auto BA = llvm::dyn_cast<BlockArgument>(V)) {
Operation *parent = BA.getOwner()->getParentOp();

if (auto sop = storeVal.getDefiningOp())
Expand Down Expand Up @@ -5518,10 +5516,10 @@ struct AffineBufferElimination : public OpRewritePattern<T> {
if (!isa<MemoryEffects::Read>(res.getEffect()))
return false;
unsigned addr = 0;
if (auto MT = v.getType().dyn_cast<MemRefType>())
if (auto MT = llvm::dyn_cast<MemRefType>(v.getType()))
addr = MT.getMemorySpaceAsInt();
else if (auto LT =
v.getType().dyn_cast<LLVM::LLVMPointerType>())
llvm::dyn_cast<LLVM::LLVMPointerType>(v.getType()))
addr = LT.getAddressSpace();
else
return false;
Expand Down
8 changes: 4 additions & 4 deletions lib/polygeist/Passes/AffineCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ static bool legalCondition(Value en, bool dim = false) {
// if (!outer || legalCondition(IC.getOperand(), false)) return true;
//}
if (!dim)
if (auto BA = en.dyn_cast<BlockArgument>()) {
if (auto BA = llvm::dyn_cast<BlockArgument>(en)) {
if (isa<affine::AffineForOp, affine::AffineParallelOp>(
BA.getOwner()->getParentOp()))
return true;
Expand Down Expand Up @@ -730,7 +730,7 @@ static void setLocationAfter(PatternRewriter &b, mlir::Value val) {
it++;
b.setInsertionPoint(val.getDefiningOp()->getBlock(), it);
}
if (auto bop = val.dyn_cast<mlir::BlockArgument>())
if (auto bop = llvm::dyn_cast<mlir::BlockArgument>(val))
b.setInsertionPoint(bop.getOwner(), bop.getOwner()->begin());
}

Expand All @@ -745,7 +745,7 @@ struct IndexCastMovement : public OpRewritePattern<IndexCastOp> {
}

mlir::Value val = op.getOperand();
if (auto bop = val.dyn_cast<mlir::BlockArgument>()) {
if (auto bop = llvm::dyn_cast<mlir::BlockArgument>(val)) {
if (op.getOperation()->getBlock() != bop.getOwner()) {
op.getOperation()->moveBefore(bop.getOwner(), bop.getOwner()->begin());
return success();
Expand Down Expand Up @@ -1010,7 +1010,7 @@ bool isValidIndex(Value val) {
if (val.getDefiningOp<ConstantIntOp>())
return true;

if (auto ba = val.dyn_cast<BlockArgument>()) {
if (auto ba = llvm::dyn_cast<BlockArgument>(val)) {
auto *owner = ba.getOwner();
assert(owner);

Expand Down
28 changes: 14 additions & 14 deletions lib/polygeist/Passes/CanonicalizeFor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct ForBreakAddUpgrade : public OpRewritePattern<scf::ForOp> {

auto condition = outerIfOp.getCondition();
// and that the outermost if's condition is an iter arg of the for
auto condArg = condition.dyn_cast<BlockArgument>();
auto condArg = llvm::dyn_cast<BlockArgument>(condition);
if (!condArg)
return failure();
if (condArg.getOwner()->getParentOp() != forOp)
Expand All @@ -121,7 +121,7 @@ struct ForBreakAddUpgrade : public OpRewritePattern<scf::ForOp> {
// and is false unless coming from inside the if
auto forYieldOp = cast<scf::YieldOp>(block.getTerminator());
auto opres =
forYieldOp.getOperand(condArg.getArgNumber() - 1).dyn_cast<OpResult>();
llvm::dyn_cast<OpResult>(forYieldOp.getOperand(condArg.getArgNumber() - 1));
if (!opres)
return failure();
if (opres.getOwner() != outerIfOp)
Expand All @@ -143,7 +143,7 @@ struct ForBreakAddUpgrade : public OpRewritePattern<scf::ForOp> {
if (opres.getResultNumber() == regionArg.getArgNumber() - 1)
continue;

auto opres2 = forYieldOperand.dyn_cast<OpResult>();
auto opres2 = llvm::dyn_cast<OpResult>(forYieldOperand);
if (!opres2)
continue;
if (opres2.getOwner() != outerIfOp)
Expand Down Expand Up @@ -627,13 +627,13 @@ yop2.results()[idx]);
*/

bool isTopLevelArgValue(Value value, Region *region) {
if (auto arg = value.dyn_cast<BlockArgument>())
if (auto arg = llvm::dyn_cast<BlockArgument>(value))
return arg.getParentRegion() == region;
return false;
}

bool isBlockArg(Value value) {
if (auto arg = value.dyn_cast<BlockArgument>())
if (auto arg = llvm::dyn_cast<BlockArgument>(value))
return true;
return false;
}
Expand All @@ -642,7 +642,7 @@ bool dominateWhile(Value value, WhileOp loop) {
if (Operation *op = value.getDefiningOp()) {
DominanceInfo dom(loop);
return dom.properlyDominates(op, loop);
} else if (auto arg = value.dyn_cast<BlockArgument>()) {
} else if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
return arg.getOwner()->getParentOp()->isProperAncestor(loop);
} else {
assert("????");
Expand Down Expand Up @@ -682,11 +682,11 @@ struct WhileToForHelper {
negativeStep = false;

auto condOp = loop.getConditionOp();
indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
indVar = llvm::dyn_cast<BlockArgument>(cmpIOp.getLhs());
Type extType = nullptr;
// todo handle ext
if (auto ext = cmpIOp.getLhs().getDefiningOp<ExtSIOp>()) {
indVar = ext.getIn().dyn_cast<BlockArgument>();
indVar = llvm::dyn_cast<BlockArgument>(ext.getIn());
extType = ext.getType();
}
// Condition is not the same as an induction variable
Expand Down Expand Up @@ -1004,7 +1004,7 @@ struct MoveWhileAndDown : public OpRewritePattern<WhileOp> {

Value extraCmp = andIOp->getOperand(1 - i);
Value lookThrough = nullptr;
if (auto BA = extraCmp.dyn_cast<BlockArgument>()) {
if (auto BA = llvm::dyn_cast<BlockArgument>(extraCmp)) {
lookThrough = oldYield.getOperand(BA.getArgNumber());
}
if (!helper.computeLegality(/*sizeCheck*/ false, lookThrough)) {
Expand Down Expand Up @@ -1341,7 +1341,7 @@ struct MoveWhileDown2 : public OpRewritePattern<WhileOp> {
// yield i:pair<2>
// }
if (!std::get<0>(pair).use_empty()) {
if (auto blockArg = elseYielded.dyn_cast<BlockArgument>())
if (auto blockArg = llvm::dyn_cast<BlockArgument>(elseYielded))
if (blockArg.getOwner() == &op.getBefore().front()) {
if (afterYield.getResults()[blockArg.getArgNumber()] ==
std::get<2>(pair) &&
Expand Down Expand Up @@ -1580,7 +1580,7 @@ struct WhileCmpOffset : public OpRewritePattern<WhileOp> {
if (addI.getOperand(1).getDefiningOp() &&
!op.getBefore().isAncestor(
addI.getOperand(1).getDefiningOp()->getParentRegion()))
if (auto blockArg = addI.getOperand(0).dyn_cast<BlockArgument>()) {
if (auto blockArg = llvm::dyn_cast<BlockArgument>(addI.getOperand(0))) {
if (blockArg.getOwner() == &op.getBefore().front()) {
auto rng = llvm::make_early_inc_range(blockArg.getUses());

Expand Down Expand Up @@ -1859,7 +1859,7 @@ struct WhileLICM : public OpRewritePattern<WhileOp> {
auto isDefinedOutsideOfBody = [&](Value value) {
auto *definingOp = value.getDefiningOp();
if (!definingOp) {
if (auto ba = value.dyn_cast<BlockArgument>())
if (auto ba = llvm::dyn_cast<BlockArgument>(value))
definingOp = ba.getOwner()->getParentOp();
assert(definingOp);
}
Expand Down Expand Up @@ -2125,7 +2125,7 @@ struct WhileShiftToInduction : public OpRewritePattern<WhileOp> {
if (!matchPattern(cmpIOp.getRhs(), m_Zero()))
return failure();

auto indVar = cmpIOp.getLhs().dyn_cast<BlockArgument>();
auto indVar = llvm::dyn_cast<BlockArgument>(cmpIOp.getLhs());
if (!indVar)
return failure();

Expand All @@ -2144,7 +2144,7 @@ struct WhileShiftToInduction : public OpRewritePattern<WhileOp> {
if (!matchPattern(shiftOp.getRhs(), m_One()))
return failure();

auto prevIndVar = shiftOp.getLhs().dyn_cast<BlockArgument>();
auto prevIndVar = llvm::dyn_cast<BlockArgument>(shiftOp.getLhs());
if (!prevIndVar)
return failure();

Expand Down
4 changes: 2 additions & 2 deletions lib/polygeist/Passes/CollectKernelStatistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ std::array<StrideTy, 3> estimateStride(mlir::OperandRange indices,
} else {
return UNKNOWN;
}
} else if (auto ba = v.dyn_cast<BlockArgument>()) {
} else if (auto ba = llvm::dyn_cast<BlockArgument>(v)) {
return 0;
if (isa<gpu::GPUFuncOp>(ba.getOwner()->getParentOp())) {
return 0;
Expand Down Expand Up @@ -339,7 +339,7 @@ static void generateAlternativeKernelDescs(mlir::ModuleOp m) {
isa<arith::SubFOp>(&op) || isa<arith::AddFOp>(&op) ||
isa<arith::RemFOp>(&op) || false) {
int width =
op.getOperand(0).getType().dyn_cast<FloatType>().getWidth();
llvm::dyn_cast<FloatType>(op.getOperand(0).getType()).getWidth();
addTo(floatOps, width, blockTrips);
} else if (isa<arith::MulIOp>(&op) || isa<arith::DivUIOp>(&op) ||
isa<arith::DivSIOp>(&op) || isa<arith::SubIOp>(&op) ||
Expand Down
Loading
Loading